Unverified Commit 4dfbfd8d authored by Hector Sanjuan's avatar Hector Sanjuan Committed by GitHub

Merge pull request #191 from ipfs/fix/method-handling

change HandledMethods to AllowGet and cleanup method handling
parents ef934e8d 3093cad8
...@@ -22,11 +22,13 @@ type ServerConfig struct { ...@@ -22,11 +22,13 @@ type ServerConfig struct {
// Headers is an optional map of headers that is written out. // Headers is an optional map of headers that is written out.
Headers map[string][]string Headers map[string][]string
// HandledMethods set which methods will be handled for the HTTP // AllowGet indicates whether or not this server accepts GET requests.
// requests. Other methods will return 405. This is different from CORS // When unset, the server only accepts POST, HEAD, and OPTIONS.
// AllowedMethods (the API may handle GET and POST, but only allow GETs //
// for CORS-enabled requests via AllowedMethods). // This is different from CORS AllowedMethods. The API may allow GET
HandledMethods []string // requests in general, but reject them in CORS. That will allow
// websites to include resources from the API but not _read_ them.
AllowGet bool
// corsOpts is a set of options for CORS headers. // corsOpts is a set of options for CORS headers.
corsOpts *cors.Options corsOpts *cors.Options
...@@ -38,7 +40,6 @@ type ServerConfig struct { ...@@ -38,7 +40,6 @@ type ServerConfig struct {
func NewServerConfig() *ServerConfig { func NewServerConfig() *ServerConfig {
cfg := new(ServerConfig) cfg := new(ServerConfig)
cfg.corsOpts = new(cors.Options) cfg.corsOpts = new(cors.Options)
cfg.HandledMethods = []string{http.MethodPost}
return cfg return cfg
} }
...@@ -149,16 +150,3 @@ func allowReferer(r *http.Request, cfg *ServerConfig) bool { ...@@ -149,16 +150,3 @@ func allowReferer(r *http.Request, cfg *ServerConfig) bool {
return false return false
} }
// handleRequestMethod returns true if the request method is among
// HandledMethods.
func handleRequestMethod(r *http.Request, cfg *ServerConfig) bool {
// For very small slices as these, this should be faster than
// a map lookup.
for _, m := range cfg.HandledMethods {
if r.Method == m {
return true
}
}
return false
}
...@@ -116,7 +116,7 @@ func TestErrors(t *testing.T) { ...@@ -116,7 +116,7 @@ func TestErrors(t *testing.T) {
mkTest := func(tc testcase) func(*testing.T) { mkTest := func(tc testcase) func(*testing.T) {
return func(t *testing.T) { return func(t *testing.T) {
_, srv := getTestServer(t, nil, nil) // handler_test:/^func getTestServer/ _, srv := getTestServer(t, nil, false) // handler_test:/^func getTestServer/
c := NewClient(srv.URL) c := NewClient(srv.URL)
req, err := cmds.NewRequest(context.Background(), tc.path, tc.opts, nil, nil, cmdRoot) req, err := cmds.NewRequest(context.Background(), tc.path, tc.opts, nil, nil, cmdRoot)
if err != nil { if err != nil {
...@@ -161,11 +161,11 @@ func TestErrors(t *testing.T) { ...@@ -161,11 +161,11 @@ func TestErrors(t *testing.T) {
func TestUnhandledMethod(t *testing.T) { func TestUnhandledMethod(t *testing.T) {
tc := httpTestCase{ tc := httpTestCase{
Method: "GET", Method: "GET",
HandledMethods: []string{"POST"}, AllowGet: false,
Code: http.StatusMethodNotAllowed, Code: http.StatusMethodNotAllowed,
ResHeaders: map[string]string{ ResHeaders: map[string]string{
"Allow": "POST", "Allow": "POST, HEAD, OPTIONS",
}, },
} }
tc.test(t) tc.test(t)
......
...@@ -97,10 +97,27 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ...@@ -97,10 +97,27 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// First of all, check if we are allowed to handle the request method // First of all, check if we are allowed to handle the request method
// or we are configured not to. // or we are configured not to.
if !handleRequestMethod(r, h.cfg) { //
setAllowedHeaders(w, h.cfg.HandledMethods) // Always allow OPTIONS, POST
switch r.Method {
case http.MethodOptions:
// If we get here, this is a normal (non-preflight) request.
// The CORS library handles all other requests.
// Tell the user the allowed methods, and return.
setAllowedHeaders(w, h.cfg.AllowGet)
w.WriteHeader(http.StatusNoContent)
return
case http.MethodPost:
case http.MethodGet, http.MethodHead:
if h.cfg.AllowGet {
break
}
fallthrough
default:
setAllowedHeaders(w, h.cfg.AllowGet)
http.Error(w, "405 - Method Not Allowed", http.StatusMethodNotAllowed) http.Error(w, "405 - Method Not Allowed", http.StatusMethodNotAllowed)
log.Warnf("The IPFS API does not support %s requests. All requests must use %s", h.cfg.HandledMethods) log.Warnf("The IPFS API does not support %s requests.", r.Method)
return return
} }
...@@ -139,6 +156,13 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ...@@ -139,6 +156,13 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
// set user's headers first.
for k, v := range h.cfg.Headers {
if !skipAPIHeader(k) {
w.Header()[k] = v
}
}
// Handle the timeout up front. // Handle the timeout up front.
var cancel func() var cancel func()
if timeoutStr, ok := req.Options[cmds.TimeoutOpt]; ok { if timeoutStr, ok := req.Options[cmds.TimeoutOpt]; ok {
...@@ -163,13 +187,6 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ...@@ -163,13 +187,6 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer done() defer done()
} }
// set user's headers first.
for k, v := range h.cfg.Headers {
if !skipAPIHeader(k) {
w.Header()[k] = v
}
}
h.root.Call(req, re, h.env) h.root.Call(req, re, h.env)
} }
...@@ -180,8 +197,11 @@ func sanitizedErrStr(err error) string { ...@@ -180,8 +197,11 @@ func sanitizedErrStr(err error) string {
return s return s
} }
func setAllowedHeaders(w http.ResponseWriter, methods []string) { func setAllowedHeaders(w http.ResponseWriter, allowGet bool) {
for _, m := range methods { w.Header().Add("Allow", http.MethodHead)
w.Header().Add("Allow", m) w.Header().Add("Allow", http.MethodOptions)
w.Header().Add("Allow", http.MethodPost)
if allowGet {
w.Header().Add("Allow", http.MethodGet)
} }
} }
...@@ -292,7 +292,7 @@ var ( ...@@ -292,7 +292,7 @@ var (
} }
) )
func getTestServer(t *testing.T, origins []string, handledMethods []string) (cmds.Environment, *httptest.Server) { func getTestServer(t *testing.T, origins []string, allowGet bool) (cmds.Environment, *httptest.Server) {
if len(origins) == 0 { if len(origins) == 0 {
origins = defaultOrigins origins = defaultOrigins
} }
...@@ -306,12 +306,7 @@ func getTestServer(t *testing.T, origins []string, handledMethods []string) (cmd ...@@ -306,12 +306,7 @@ func getTestServer(t *testing.T, origins []string, handledMethods []string) (cmd
} }
srvCfg := originCfg(origins) srvCfg := originCfg(origins)
srvCfg.AllowGet = allowGet
if len(handledMethods) == 0 {
srvCfg.HandledMethods = []string{"GET", "POST"}
} else {
srvCfg.HandledMethods = handledMethods
}
return env, httptest.NewServer(NewHandler(env, cmdRoot, srvCfg)) return env, httptest.NewServer(NewHandler(env, cmdRoot, srvCfg))
} }
......
...@@ -88,7 +88,7 @@ func TestHTTP(t *testing.T) { ...@@ -88,7 +88,7 @@ func TestHTTP(t *testing.T) {
mkTest := func(tc testcase) func(*testing.T) { mkTest := func(tc testcase) func(*testing.T) {
return func(t *testing.T) { return func(t *testing.T) {
env, srv := getTestServer(t, nil, nil) // handler_test:/^func getTestServer/ env, srv := getTestServer(t, nil, true) // handler_test:/^func getTestServer/
c := NewClient(srv.URL) c := NewClient(srv.URL)
req, err := cmds.NewRequest(context.Background(), tc.path, nil, nil, nil, cmdRoot) req, err := cmds.NewRequest(context.Background(), tc.path, nil, nil, nil, cmdRoot)
if err != nil { if err != nil {
......
...@@ -4,15 +4,42 @@ import ( ...@@ -4,15 +4,42 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
"strings"
"testing" "testing"
cmds "github.com/ipfs/go-ipfs-cmds" cmds "github.com/ipfs/go-ipfs-cmds"
) )
func assertHeaders(t *testing.T, resHeaders http.Header, reqHeaders map[string]string) { func assertHeaders(t *testing.T, resHeaders http.Header, reqHeaders map[string]string) {
t.Helper()
t.Logf("headers: %v", resHeaders)
for name, value := range reqHeaders { for name, value := range reqHeaders {
if resHeaders.Get(name) != value { header := resHeaders[http.CanonicalHeaderKey(name)]
t.Errorf("Invalid header '%s', wanted '%s', got '%s'", name, value, resHeaders.Get(name)) switch len(header) {
case 0:
if value != "" {
t.Errorf("expected a header for %s", name)
}
case 1:
if header[0] != value {
t.Errorf("Invalid header '%s', wanted '%s', got '%s'", name, value, header[0])
}
default:
values := strings.Split(value, ",")
set := make(map[string]bool, len(values))
for _, v := range values {
set[strings.Trim(v, " ")] = true
}
for _, got := range header {
if !set[got] {
t.Errorf("found unexpected value %s in header %s", got, name)
continue
}
delete(set, got)
}
for missing := range set {
t.Errorf("missing value %s in header %s", missing, name)
}
} }
} }
} }
...@@ -27,7 +54,7 @@ func originCfg(origins []string) *ServerConfig { ...@@ -27,7 +54,7 @@ func originCfg(origins []string) *ServerConfig {
cfg := NewServerConfig() cfg := NewServerConfig()
cfg.SetAllowedOrigins(origins...) cfg.SetAllowedOrigins(origins...)
cfg.SetAllowedMethods("GET", "PUT", "POST") cfg.SetAllowedMethods("GET", "PUT", "POST")
cfg.HandledMethods = []string{"GET", "POST"} cfg.AllowGet = true
return cfg return cfg
} }
...@@ -39,18 +66,19 @@ var defaultOrigins = []string{ ...@@ -39,18 +66,19 @@ var defaultOrigins = []string{
} }
type httpTestCase struct { type httpTestCase struct {
Method string Method string
Path string Path string
Code int Code int
Origin string Origin string
Referer string Referer string
AllowOrigins []string AllowOrigins []string
HandledMethods []string AllowGet bool
ReqHeaders map[string]string ReqHeaders map[string]string
ResHeaders map[string]string ResHeaders map[string]string
} }
func (tc *httpTestCase) test(t *testing.T) { func (tc *httpTestCase) test(t *testing.T) {
t.Helper()
// defaults // defaults
method := tc.Method method := tc.Method
if method == "" { if method == "" {
...@@ -85,7 +113,7 @@ func (tc *httpTestCase) test(t *testing.T) { ...@@ -85,7 +113,7 @@ func (tc *httpTestCase) test(t *testing.T) {
} }
// server // server
_, server := getTestServer(t, tc.AllowOrigins, tc.HandledMethods) _, server := getTestServer(t, tc.AllowOrigins, tc.AllowGet)
if server == nil { if server == nil {
return return
} }
...@@ -114,6 +142,7 @@ func TestDisallowedOrigins(t *testing.T) { ...@@ -114,6 +142,7 @@ func TestDisallowedOrigins(t *testing.T) {
return httpTestCase{ return httpTestCase{
Origin: origin, Origin: origin,
AllowOrigins: allowedOrigins, AllowOrigins: allowedOrigins,
AllowGet: true,
ResHeaders: map[string]string{ ResHeaders: map[string]string{
ACAOrigin: "", ACAOrigin: "",
ACAMethods: "", ACAMethods: "",
...@@ -144,6 +173,7 @@ func TestAllowedOrigins(t *testing.T) { ...@@ -144,6 +173,7 @@ func TestAllowedOrigins(t *testing.T) {
return httpTestCase{ return httpTestCase{
Origin: origin, Origin: origin,
AllowOrigins: allowedOrigins, AllowOrigins: allowedOrigins,
AllowGet: true,
ResHeaders: map[string]string{ ResHeaders: map[string]string{
ACAOrigin: origin, ACAOrigin: origin,
ACAMethods: "", ACAMethods: "",
...@@ -171,6 +201,7 @@ func TestWildcardOrigin(t *testing.T) { ...@@ -171,6 +201,7 @@ func TestWildcardOrigin(t *testing.T) {
gtc := func(origin string, allowedOrigins []string) httpTestCase { gtc := func(origin string, allowedOrigins []string) httpTestCase {
return httpTestCase{ return httpTestCase{
Origin: origin, Origin: origin,
AllowGet: true,
AllowOrigins: allowedOrigins, AllowOrigins: allowedOrigins,
ResHeaders: map[string]string{ ResHeaders: map[string]string{
ACAOrigin: "*", ACAOrigin: "*",
...@@ -204,6 +235,7 @@ func TestDisallowedReferer(t *testing.T) { ...@@ -204,6 +235,7 @@ func TestDisallowedReferer(t *testing.T) {
return httpTestCase{ return httpTestCase{
Origin: "http://localhost", Origin: "http://localhost",
Referer: referer, Referer: referer,
AllowGet: true,
AllowOrigins: allowedOrigins, AllowOrigins: allowedOrigins,
ResHeaders: map[string]string{ ResHeaders: map[string]string{
ACAOrigin: "http://localhost", ACAOrigin: "http://localhost",
...@@ -232,6 +264,7 @@ func TestAllowedReferer(t *testing.T) { ...@@ -232,6 +264,7 @@ func TestAllowedReferer(t *testing.T) {
return httpTestCase{ return httpTestCase{
Origin: "http://localhost", Origin: "http://localhost",
AllowOrigins: allowedOrigins, AllowOrigins: allowedOrigins,
AllowGet: true,
ResHeaders: map[string]string{ ResHeaders: map[string]string{
ACAOrigin: "http://localhost", ACAOrigin: "http://localhost",
ACAMethods: "", ACAMethods: "",
...@@ -260,6 +293,7 @@ func TestWildcardReferer(t *testing.T) { ...@@ -260,6 +293,7 @@ func TestWildcardReferer(t *testing.T) {
return httpTestCase{ return httpTestCase{
Origin: origin, Origin: origin,
AllowOrigins: allowedOrigins, AllowOrigins: allowedOrigins,
AllowGet: true,
ResHeaders: map[string]string{ ResHeaders: map[string]string{
ACAOrigin: "*", ACAOrigin: "*",
ACAMethods: "", ACAMethods: "",
...@@ -338,6 +372,7 @@ func TestEncoding(t *testing.T) { ...@@ -338,6 +372,7 @@ func TestEncoding(t *testing.T) {
return httpTestCase{ return httpTestCase{
Method: "GET", Method: "GET",
Path: path, Path: path,
AllowGet: true,
Origin: "http://localhost", Origin: "http://localhost",
AllowOrigins: []string{"*"}, AllowOrigins: []string{"*"},
ReqHeaders: map[string]string{ ReqHeaders: map[string]string{
......
...@@ -106,7 +106,7 @@ func (re *responseEmitter) Emit(value interface{}) error { ...@@ -106,7 +106,7 @@ func (re *responseEmitter) Emit(value interface{}) error {
var err error var err error
// return immediately if this is a head request // return immediately if this is a head request
if re.method == "HEAD" { if re.method == http.MethodHead {
return nil return nil
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment