diff --git a/http/config.go b/http/config.go index 9bc7421cdca350fa98acb78fb98f861fde6e6527..aff4164dd19a9c9b7c6d2f79c01fab7cf6658e19 100644 --- a/http/config.go +++ b/http/config.go @@ -22,6 +22,12 @@ type ServerConfig struct { // Headers is an optional map of headers that is written out. Headers map[string][]string + // HandledMethods set which methods will be handled for the HTTP + // requests. Other methods will return 405. This is different from CORS + // AllowedMethods (the API may handle GET and POST, but only allow GETs + // for CORS-enabled requests via AllowedMethods). + HandledMethods []string + // corsOpts is a set of options for CORS headers. corsOpts *cors.Options @@ -32,6 +38,7 @@ type ServerConfig struct { func NewServerConfig() *ServerConfig { cfg := new(ServerConfig) cfg.corsOpts = new(cors.Options) + cfg.HandledMethods = []string{http.MethodPost} return cfg } @@ -142,3 +149,16 @@ func allowReferer(r *http.Request, cfg *ServerConfig) bool { return false } + +// handleRequestMethod returns true if the request method is among +// HandledMethods. +func handleRequestMethod(r *http.Request, cfg *ServerConfig) bool { + // For very small slices as these, this should be faster than + // a map lookup. + for _, m := range cfg.HandledMethods { + if r.Method == m { + return true + } + } + return false +} diff --git a/http/errors_test.go b/http/errors_test.go index 88c963dd9880167515754894970d2b9358511db0..3c004eee4326912a79877a05ce3d1148bdb1cdd1 100644 --- a/http/errors_test.go +++ b/http/errors_test.go @@ -9,7 +9,7 @@ import ( "strings" "testing" - "github.com/ipfs/go-ipfs-cmds" + cmds "github.com/ipfs/go-ipfs-cmds" ) func TestErrors(t *testing.T) { @@ -116,7 +116,7 @@ func TestErrors(t *testing.T) { mkTest := func(tc testcase) func(*testing.T) { return func(t *testing.T) { - _, srv := getTestServer(t, nil) // handler_test:/^func getTestServer/ + _, srv := getTestServer(t, nil, nil) // handler_test:/^func getTestServer/ c := NewClient(srv.URL) req, err := cmds.NewRequest(context.Background(), tc.path, tc.opts, nil, nil, cmdRoot) if err != nil { @@ -158,3 +158,15 @@ func TestErrors(t *testing.T) { t.Run(fmt.Sprintf("%d-%s", i, strings.Join(tc.path, "/")), mkTest(tc)) } } + +func TestUnhandledMethod(t *testing.T) { + tc := httpTestCase{ + Method: "GET", + HandledMethods: []string{"POST"}, + Code: http.StatusMethodNotAllowed, + ResHeaders: map[string]string{ + "Allow": "POST", + }, + } + tc.test(t) +} diff --git a/http/handler.go b/http/handler.go index 35b069ef48cf07e1ab27af5b3284504a8dc604a8..3e9f078a87f3e547ed46945deba139a7a8f13379 100644 --- a/http/handler.go +++ b/http/handler.go @@ -95,6 +95,15 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } }() + // First of all, check if we are allowed to handle the request method + // or we are configured not to. + if !handleRequestMethod(r, h.cfg) { + setAllowedHeaders(w, h.cfg.HandledMethods) + http.Error(w, "405 - Method Not Allowed", http.StatusMethodNotAllowed) + log.Warningf("The IPFS API does not support %s requests. All requests must use %s", h.cfg.HandledMethods) + return + } + if !allowOrigin(r, h.cfg) || !allowReferer(r, h.cfg) { http.Error(w, "403 - Forbidden", http.StatusForbidden) log.Warningf("API blocked request to %s. (possible CSRF)", r.URL) @@ -170,3 +179,9 @@ func sanitizedErrStr(err error) string { s = strings.Split(s, "\r")[0] return s } + +func setAllowedHeaders(w http.ResponseWriter, methods []string) { + for _, m := range methods { + w.Header().Add("Allow", m) + } +} diff --git a/http/handler_test.go b/http/handler_test.go index 10cd9f358d4de82146fbd96cd5b4dc50633fe361..9c268a8139beafc7f083772f71b8c6cfb581876b 100644 --- a/http/handler_test.go +++ b/http/handler_test.go @@ -292,7 +292,7 @@ var ( } ) -func getTestServer(t *testing.T, origins []string) (cmds.Environment, *httptest.Server) { +func getTestServer(t *testing.T, origins []string, handledMethods []string) (cmds.Environment, *httptest.Server) { if len(origins) == 0 { origins = defaultOrigins } @@ -305,7 +305,15 @@ func getTestServer(t *testing.T, origins []string) (cmds.Environment, *httptest. wait: make(chan struct{}), } - return env, httptest.NewServer(NewHandler(env, cmdRoot, originCfg(origins))) + srvCfg := originCfg(origins) + + if len(handledMethods) == 0 { + srvCfg.HandledMethods = []string{"GET", "POST"} + } else { + srvCfg.HandledMethods = handledMethods + } + + return env, httptest.NewServer(NewHandler(env, cmdRoot, srvCfg)) } func errEq(err1, err2 error) bool { diff --git a/http/http_test.go b/http/http_test.go index 850c77fe1b62684eb7d886bb9f7f9c1ac9f70677..4792d53375447148bc3f212c3712baa8f57cdb51 100644 --- a/http/http_test.go +++ b/http/http_test.go @@ -12,9 +12,9 @@ import ( "strings" "testing" - "github.com/ipfs/go-ipfs-cmds" + cmds "github.com/ipfs/go-ipfs-cmds" - "github.com/ipfs/go-ipfs-files" + files "github.com/ipfs/go-ipfs-files" ) func newReaderPathFile(t *testing.T, path string, reader io.ReadCloser, stat os.FileInfo) files.File { @@ -88,7 +88,7 @@ func TestHTTP(t *testing.T) { mkTest := func(tc testcase) func(*testing.T) { return func(t *testing.T) { - env, srv := getTestServer(t, nil) // handler_test:/^func getTestServer/ + env, srv := getTestServer(t, nil, nil) // handler_test:/^func getTestServer/ c := NewClient(srv.URL) req, err := cmds.NewRequest(context.Background(), tc.path, nil, nil, nil, cmdRoot) if err != nil { diff --git a/http/reforigin_test.go b/http/reforigin_test.go index 1d00db7ea4d9269350cf9f6e85b46075c7eafcc3..08958cccaac36116a4155e3779e88b5ffb00c9fe 100644 --- a/http/reforigin_test.go +++ b/http/reforigin_test.go @@ -6,7 +6,7 @@ import ( "net/url" "testing" - "github.com/ipfs/go-ipfs-cmds" + cmds "github.com/ipfs/go-ipfs-cmds" ) func assertHeaders(t *testing.T, resHeaders http.Header, reqHeaders map[string]string) { @@ -27,6 +27,7 @@ func originCfg(origins []string) *ServerConfig { cfg := NewServerConfig() cfg.SetAllowedOrigins(origins...) cfg.SetAllowedMethods("GET", "PUT", "POST") + cfg.HandledMethods = []string{"GET", "POST"} return cfg } @@ -38,14 +39,15 @@ var defaultOrigins = []string{ } type httpTestCase struct { - Method string - Path string - Code int - Origin string - Referer string - AllowOrigins []string - ReqHeaders map[string]string - ResHeaders map[string]string + Method string + Path string + Code int + Origin string + Referer string + AllowOrigins []string + HandledMethods []string + ReqHeaders map[string]string + ResHeaders map[string]string } func (tc *httpTestCase) test(t *testing.T) { @@ -83,7 +85,7 @@ func (tc *httpTestCase) test(t *testing.T) { } // server - _, server := getTestServer(t, tc.AllowOrigins) + _, server := getTestServer(t, tc.AllowOrigins, tc.HandledMethods) if server == nil { return }