Unverified Commit 4dfbfd8d authored by Hector Sanjuan's avatar Hector Sanjuan Committed by GitHub

Merge pull request #191 from ipfs/fix/method-handling

change HandledMethods to AllowGet and cleanup method handling
parents ef934e8d 3093cad8
......@@ -22,11 +22,13 @@ type ServerConfig struct {
// Headers is an optional map of headers that is written out.
Headers map[string][]string
// HandledMethods set which methods will be handled for the HTTP
// requests. Other methods will return 405. This is different from CORS
// AllowedMethods (the API may handle GET and POST, but only allow GETs
// for CORS-enabled requests via AllowedMethods).
HandledMethods []string
// AllowGet indicates whether or not this server accepts GET requests.
// When unset, the server only accepts POST, HEAD, and OPTIONS.
//
// This is different from CORS AllowedMethods. The API may allow GET
// requests in general, but reject them in CORS. That will allow
// websites to include resources from the API but not _read_ them.
AllowGet bool
// corsOpts is a set of options for CORS headers.
corsOpts *cors.Options
......@@ -38,7 +40,6 @@ type ServerConfig struct {
func NewServerConfig() *ServerConfig {
cfg := new(ServerConfig)
cfg.corsOpts = new(cors.Options)
cfg.HandledMethods = []string{http.MethodPost}
return cfg
}
......@@ -149,16 +150,3 @@ func allowReferer(r *http.Request, cfg *ServerConfig) bool {
return false
}
// handleRequestMethod returns true if the request method is among
// HandledMethods.
func handleRequestMethod(r *http.Request, cfg *ServerConfig) bool {
// For very small slices as these, this should be faster than
// a map lookup.
for _, m := range cfg.HandledMethods {
if r.Method == m {
return true
}
}
return false
}
......@@ -116,7 +116,7 @@ func TestErrors(t *testing.T) {
mkTest := func(tc testcase) func(*testing.T) {
return func(t *testing.T) {
_, srv := getTestServer(t, nil, nil) // handler_test:/^func getTestServer/
_, srv := getTestServer(t, nil, false) // handler_test:/^func getTestServer/
c := NewClient(srv.URL)
req, err := cmds.NewRequest(context.Background(), tc.path, tc.opts, nil, nil, cmdRoot)
if err != nil {
......@@ -161,11 +161,11 @@ func TestErrors(t *testing.T) {
func TestUnhandledMethod(t *testing.T) {
tc := httpTestCase{
Method: "GET",
HandledMethods: []string{"POST"},
Code: http.StatusMethodNotAllowed,
Method: "GET",
AllowGet: false,
Code: http.StatusMethodNotAllowed,
ResHeaders: map[string]string{
"Allow": "POST",
"Allow": "POST, HEAD, OPTIONS",
},
}
tc.test(t)
......
......@@ -97,10 +97,27 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// First of all, check if we are allowed to handle the request method
// or we are configured not to.
if !handleRequestMethod(r, h.cfg) {
setAllowedHeaders(w, h.cfg.HandledMethods)
//
// Always allow OPTIONS, POST
switch r.Method {
case http.MethodOptions:
// If we get here, this is a normal (non-preflight) request.
// The CORS library handles all other requests.
// Tell the user the allowed methods, and return.
setAllowedHeaders(w, h.cfg.AllowGet)
w.WriteHeader(http.StatusNoContent)
return
case http.MethodPost:
case http.MethodGet, http.MethodHead:
if h.cfg.AllowGet {
break
}
fallthrough
default:
setAllowedHeaders(w, h.cfg.AllowGet)
http.Error(w, "405 - Method Not Allowed", http.StatusMethodNotAllowed)
log.Warnf("The IPFS API does not support %s requests. All requests must use %s", h.cfg.HandledMethods)
log.Warnf("The IPFS API does not support %s requests.", r.Method)
return
}
......@@ -139,6 +156,13 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
// set user's headers first.
for k, v := range h.cfg.Headers {
if !skipAPIHeader(k) {
w.Header()[k] = v
}
}
// Handle the timeout up front.
var cancel func()
if timeoutStr, ok := req.Options[cmds.TimeoutOpt]; ok {
......@@ -163,13 +187,6 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer done()
}
// set user's headers first.
for k, v := range h.cfg.Headers {
if !skipAPIHeader(k) {
w.Header()[k] = v
}
}
h.root.Call(req, re, h.env)
}
......@@ -180,8 +197,11 @@ func sanitizedErrStr(err error) string {
return s
}
func setAllowedHeaders(w http.ResponseWriter, methods []string) {
for _, m := range methods {
w.Header().Add("Allow", m)
func setAllowedHeaders(w http.ResponseWriter, allowGet bool) {
w.Header().Add("Allow", http.MethodHead)
w.Header().Add("Allow", http.MethodOptions)
w.Header().Add("Allow", http.MethodPost)
if allowGet {
w.Header().Add("Allow", http.MethodGet)
}
}
......@@ -292,7 +292,7 @@ var (
}
)
func getTestServer(t *testing.T, origins []string, handledMethods []string) (cmds.Environment, *httptest.Server) {
func getTestServer(t *testing.T, origins []string, allowGet bool) (cmds.Environment, *httptest.Server) {
if len(origins) == 0 {
origins = defaultOrigins
}
......@@ -306,12 +306,7 @@ func getTestServer(t *testing.T, origins []string, handledMethods []string) (cmd
}
srvCfg := originCfg(origins)
if len(handledMethods) == 0 {
srvCfg.HandledMethods = []string{"GET", "POST"}
} else {
srvCfg.HandledMethods = handledMethods
}
srvCfg.AllowGet = allowGet
return env, httptest.NewServer(NewHandler(env, cmdRoot, srvCfg))
}
......
......@@ -88,7 +88,7 @@ func TestHTTP(t *testing.T) {
mkTest := func(tc testcase) func(*testing.T) {
return func(t *testing.T) {
env, srv := getTestServer(t, nil, nil) // handler_test:/^func getTestServer/
env, srv := getTestServer(t, nil, true) // handler_test:/^func getTestServer/
c := NewClient(srv.URL)
req, err := cmds.NewRequest(context.Background(), tc.path, nil, nil, nil, cmdRoot)
if err != nil {
......
......@@ -4,15 +4,42 @@ import (
"fmt"
"net/http"
"net/url"
"strings"
"testing"
cmds "github.com/ipfs/go-ipfs-cmds"
)
func assertHeaders(t *testing.T, resHeaders http.Header, reqHeaders map[string]string) {
t.Helper()
t.Logf("headers: %v", resHeaders)
for name, value := range reqHeaders {
if resHeaders.Get(name) != value {
t.Errorf("Invalid header '%s', wanted '%s', got '%s'", name, value, resHeaders.Get(name))
header := resHeaders[http.CanonicalHeaderKey(name)]
switch len(header) {
case 0:
if value != "" {
t.Errorf("expected a header for %s", name)
}
case 1:
if header[0] != value {
t.Errorf("Invalid header '%s', wanted '%s', got '%s'", name, value, header[0])
}
default:
values := strings.Split(value, ",")
set := make(map[string]bool, len(values))
for _, v := range values {
set[strings.Trim(v, " ")] = true
}
for _, got := range header {
if !set[got] {
t.Errorf("found unexpected value %s in header %s", got, name)
continue
}
delete(set, got)
}
for missing := range set {
t.Errorf("missing value %s in header %s", missing, name)
}
}
}
}
......@@ -27,7 +54,7 @@ func originCfg(origins []string) *ServerConfig {
cfg := NewServerConfig()
cfg.SetAllowedOrigins(origins...)
cfg.SetAllowedMethods("GET", "PUT", "POST")
cfg.HandledMethods = []string{"GET", "POST"}
cfg.AllowGet = true
return cfg
}
......@@ -39,18 +66,19 @@ var defaultOrigins = []string{
}
type httpTestCase struct {
Method string
Path string
Code int
Origin string
Referer string
AllowOrigins []string
HandledMethods []string
ReqHeaders map[string]string
ResHeaders map[string]string
Method string
Path string
Code int
Origin string
Referer string
AllowOrigins []string
AllowGet bool
ReqHeaders map[string]string
ResHeaders map[string]string
}
func (tc *httpTestCase) test(t *testing.T) {
t.Helper()
// defaults
method := tc.Method
if method == "" {
......@@ -85,7 +113,7 @@ func (tc *httpTestCase) test(t *testing.T) {
}
// server
_, server := getTestServer(t, tc.AllowOrigins, tc.HandledMethods)
_, server := getTestServer(t, tc.AllowOrigins, tc.AllowGet)
if server == nil {
return
}
......@@ -114,6 +142,7 @@ func TestDisallowedOrigins(t *testing.T) {
return httpTestCase{
Origin: origin,
AllowOrigins: allowedOrigins,
AllowGet: true,
ResHeaders: map[string]string{
ACAOrigin: "",
ACAMethods: "",
......@@ -144,6 +173,7 @@ func TestAllowedOrigins(t *testing.T) {
return httpTestCase{
Origin: origin,
AllowOrigins: allowedOrigins,
AllowGet: true,
ResHeaders: map[string]string{
ACAOrigin: origin,
ACAMethods: "",
......@@ -171,6 +201,7 @@ func TestWildcardOrigin(t *testing.T) {
gtc := func(origin string, allowedOrigins []string) httpTestCase {
return httpTestCase{
Origin: origin,
AllowGet: true,
AllowOrigins: allowedOrigins,
ResHeaders: map[string]string{
ACAOrigin: "*",
......@@ -204,6 +235,7 @@ func TestDisallowedReferer(t *testing.T) {
return httpTestCase{
Origin: "http://localhost",
Referer: referer,
AllowGet: true,
AllowOrigins: allowedOrigins,
ResHeaders: map[string]string{
ACAOrigin: "http://localhost",
......@@ -232,6 +264,7 @@ func TestAllowedReferer(t *testing.T) {
return httpTestCase{
Origin: "http://localhost",
AllowOrigins: allowedOrigins,
AllowGet: true,
ResHeaders: map[string]string{
ACAOrigin: "http://localhost",
ACAMethods: "",
......@@ -260,6 +293,7 @@ func TestWildcardReferer(t *testing.T) {
return httpTestCase{
Origin: origin,
AllowOrigins: allowedOrigins,
AllowGet: true,
ResHeaders: map[string]string{
ACAOrigin: "*",
ACAMethods: "",
......@@ -338,6 +372,7 @@ func TestEncoding(t *testing.T) {
return httpTestCase{
Method: "GET",
Path: path,
AllowGet: true,
Origin: "http://localhost",
AllowOrigins: []string{"*"},
ReqHeaders: map[string]string{
......
......@@ -106,7 +106,7 @@ func (re *responseEmitter) Emit(value interface{}) error {
var err error
// return immediately if this is a head request
if re.method == "HEAD" {
if re.method == http.MethodHead {
return nil
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment