Unverified Commit 90182cee authored by Steven Allen's avatar Steven Allen Committed by GitHub

Merge pull request #190 from ipfs/fix/api-post

http: configurable allowed request methods for the API.
parents 59c18d03 2fbebbec
......@@ -22,6 +22,12 @@ type ServerConfig struct {
// Headers is an optional map of headers that is written out.
Headers map[string][]string
// HandledMethods set which methods will be handled for the HTTP
// requests. Other methods will return 405. This is different from CORS
// AllowedMethods (the API may handle GET and POST, but only allow GETs
// for CORS-enabled requests via AllowedMethods).
HandledMethods []string
// corsOpts is a set of options for CORS headers.
corsOpts *cors.Options
......@@ -32,6 +38,7 @@ type ServerConfig struct {
func NewServerConfig() *ServerConfig {
cfg := new(ServerConfig)
cfg.corsOpts = new(cors.Options)
cfg.HandledMethods = []string{http.MethodPost}
return cfg
}
......@@ -142,3 +149,16 @@ func allowReferer(r *http.Request, cfg *ServerConfig) bool {
return false
}
// handleRequestMethod returns true if the request method is among
// HandledMethods.
func handleRequestMethod(r *http.Request, cfg *ServerConfig) bool {
// For very small slices as these, this should be faster than
// a map lookup.
for _, m := range cfg.HandledMethods {
if r.Method == m {
return true
}
}
return false
}
......@@ -9,7 +9,7 @@ import (
"strings"
"testing"
"github.com/ipfs/go-ipfs-cmds"
cmds "github.com/ipfs/go-ipfs-cmds"
)
func TestErrors(t *testing.T) {
......@@ -116,7 +116,7 @@ func TestErrors(t *testing.T) {
mkTest := func(tc testcase) func(*testing.T) {
return func(t *testing.T) {
_, srv := getTestServer(t, nil) // handler_test:/^func getTestServer/
_, srv := getTestServer(t, nil, nil) // handler_test:/^func getTestServer/
c := NewClient(srv.URL)
req, err := cmds.NewRequest(context.Background(), tc.path, tc.opts, nil, nil, cmdRoot)
if err != nil {
......@@ -158,3 +158,15 @@ func TestErrors(t *testing.T) {
t.Run(fmt.Sprintf("%d-%s", i, strings.Join(tc.path, "/")), mkTest(tc))
}
}
func TestUnhandledMethod(t *testing.T) {
tc := httpTestCase{
Method: "GET",
HandledMethods: []string{"POST"},
Code: http.StatusMethodNotAllowed,
ResHeaders: map[string]string{
"Allow": "POST",
},
}
tc.test(t)
}
......@@ -95,6 +95,15 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}()
// First of all, check if we are allowed to handle the request method
// or we are configured not to.
if !handleRequestMethod(r, h.cfg) {
setAllowedHeaders(w, h.cfg.HandledMethods)
http.Error(w, "405 - Method Not Allowed", http.StatusMethodNotAllowed)
log.Warningf("The IPFS API does not support %s requests. All requests must use %s", h.cfg.HandledMethods)
return
}
if !allowOrigin(r, h.cfg) || !allowReferer(r, h.cfg) {
http.Error(w, "403 - Forbidden", http.StatusForbidden)
log.Warningf("API blocked request to %s. (possible CSRF)", r.URL)
......@@ -170,3 +179,9 @@ func sanitizedErrStr(err error) string {
s = strings.Split(s, "\r")[0]
return s
}
func setAllowedHeaders(w http.ResponseWriter, methods []string) {
for _, m := range methods {
w.Header().Add("Allow", m)
}
}
......@@ -292,7 +292,7 @@ var (
}
)
func getTestServer(t *testing.T, origins []string) (cmds.Environment, *httptest.Server) {
func getTestServer(t *testing.T, origins []string, handledMethods []string) (cmds.Environment, *httptest.Server) {
if len(origins) == 0 {
origins = defaultOrigins
}
......@@ -305,7 +305,15 @@ func getTestServer(t *testing.T, origins []string) (cmds.Environment, *httptest.
wait: make(chan struct{}),
}
return env, httptest.NewServer(NewHandler(env, cmdRoot, originCfg(origins)))
srvCfg := originCfg(origins)
if len(handledMethods) == 0 {
srvCfg.HandledMethods = []string{"GET", "POST"}
} else {
srvCfg.HandledMethods = handledMethods
}
return env, httptest.NewServer(NewHandler(env, cmdRoot, srvCfg))
}
func errEq(err1, err2 error) bool {
......
......@@ -12,9 +12,9 @@ import (
"strings"
"testing"
"github.com/ipfs/go-ipfs-cmds"
cmds "github.com/ipfs/go-ipfs-cmds"
"github.com/ipfs/go-ipfs-files"
files "github.com/ipfs/go-ipfs-files"
)
func newReaderPathFile(t *testing.T, path string, reader io.ReadCloser, stat os.FileInfo) files.File {
......@@ -88,7 +88,7 @@ func TestHTTP(t *testing.T) {
mkTest := func(tc testcase) func(*testing.T) {
return func(t *testing.T) {
env, srv := getTestServer(t, nil) // handler_test:/^func getTestServer/
env, srv := getTestServer(t, nil, nil) // handler_test:/^func getTestServer/
c := NewClient(srv.URL)
req, err := cmds.NewRequest(context.Background(), tc.path, nil, nil, nil, cmdRoot)
if err != nil {
......
......@@ -6,7 +6,7 @@ import (
"net/url"
"testing"
"github.com/ipfs/go-ipfs-cmds"
cmds "github.com/ipfs/go-ipfs-cmds"
)
func assertHeaders(t *testing.T, resHeaders http.Header, reqHeaders map[string]string) {
......@@ -27,6 +27,7 @@ func originCfg(origins []string) *ServerConfig {
cfg := NewServerConfig()
cfg.SetAllowedOrigins(origins...)
cfg.SetAllowedMethods("GET", "PUT", "POST")
cfg.HandledMethods = []string{"GET", "POST"}
return cfg
}
......@@ -38,14 +39,15 @@ var defaultOrigins = []string{
}
type httpTestCase struct {
Method string
Path string
Code int
Origin string
Referer string
AllowOrigins []string
ReqHeaders map[string]string
ResHeaders map[string]string
Method string
Path string
Code int
Origin string
Referer string
AllowOrigins []string
HandledMethods []string
ReqHeaders map[string]string
ResHeaders map[string]string
}
func (tc *httpTestCase) test(t *testing.T) {
......@@ -83,7 +85,7 @@ func (tc *httpTestCase) test(t *testing.T) {
}
// server
_, server := getTestServer(t, tc.AllowOrigins)
_, server := getTestServer(t, tc.AllowOrigins, tc.HandledMethods)
if server == nil {
return
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment