Commit df28dcfc authored by David Braun's avatar David Braun

Add CORS middleware handler to the API.

parent a586024e
...@@ -8,6 +8,8 @@ import ( ...@@ -8,6 +8,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/rs/cors"
context "github.com/ipfs/go-ipfs/Godeps/_workspace/src/golang.org/x/net/context" context "github.com/ipfs/go-ipfs/Godeps/_workspace/src/golang.org/x/net/context"
cmds "github.com/ipfs/go-ipfs/commands" cmds "github.com/ipfs/go-ipfs/commands"
...@@ -16,10 +18,17 @@ import ( ...@@ -16,10 +18,17 @@ import (
var log = u.Logger("commands/http") var log = u.Logger("commands/http")
type Handler struct { // the internal handler for the API
type internalHandler struct {
ctx cmds.Context ctx cmds.Context
root *cmds.Command root *cmds.Command
origin string }
// The Handler struct is funny because we want to wrap our internal handler
// with CORS while keeping our fields.
type Handler struct {
internalHandler
corsHandler http.Handler
} }
var ErrNotFound = errors.New("404 page not found") var ErrNotFound = errors.New("404 page not found")
...@@ -39,16 +48,31 @@ var mimeTypes = map[string]string{ ...@@ -39,16 +48,31 @@ var mimeTypes = map[string]string{
cmds.Text: "text/plain", cmds.Text: "text/plain",
} }
func NewHandler(ctx cmds.Context, root *cmds.Command, origin string) *Handler { func NewHandler(ctx cmds.Context, root *cmds.Command, allowedOrigin string) *Handler {
// allow whitelisted origins (so we can make API requests from the browser) // allow whitelisted origins (so we can make API requests from the browser)
if len(origin) > 0 { if len(allowedOrigin) > 0 {
log.Info("Allowing API requests from origin: " + origin) log.Info("Allowing API requests from origin: " + allowedOrigin)
} }
return &Handler{ctx, root, origin} // Create a handler for the API.
internal := internalHandler{ctx, root}
// Create a CORS object for wrapping the internal handler.
c := cors.New(cors.Options{
AllowedMethods: []string{"GET", "POST", "PUT"},
// use AllowOriginFunc instead of AllowedOrigins because we want to be
// restrictive by default.
AllowOriginFunc: func(origin string) bool {
return (allowedOrigin == "*") || (origin == allowedOrigin)
},
})
// Wrap the internal handler with CORS handling-middleware.
return &Handler{internal, c.Handler(internal)}
} }
func (i Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (i internalHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.Debug("Incoming API request: ", r.URL) log.Debug("Incoming API request: ", r.URL)
// error on external referers (to prevent CSRF attacks) // error on external referers (to prevent CSRF attacks)
...@@ -65,11 +89,6 @@ func (i Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ...@@ -65,11 +89,6 @@ func (i Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
if len(i.origin) > 0 {
w.Header().Set("Access-Control-Allow-Origin", i.origin)
}
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
req, err := Parse(r, i.root) req, err := Parse(r, i.root)
if err != nil { if err != nil {
if err == ErrNotFound { if err == ErrNotFound {
...@@ -168,6 +187,11 @@ func (i Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ...@@ -168,6 +187,11 @@ func (i Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
flushCopy(w, out) flushCopy(w, out)
} }
func (i Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Call the CORS handler which wraps the internal handler.
i.corsHandler.ServeHTTP(w, r)
}
// flushCopy Copies from an io.Reader to a http.ResponseWriter. // flushCopy Copies from an io.Reader to a http.ResponseWriter.
// Flushes chunks over HTTP stream as they are read (if supported by transport). // Flushes chunks over HTTP stream as they are read (if supported by transport).
func flushCopy(w http.ResponseWriter, out io.Reader) error { func flushCopy(w http.ResponseWriter, out io.Reader) error {
......
package http
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/ipfs/go-ipfs/commands"
)
func assertHeaders(t *testing.T, resHeaders http.Header, reqHeaders map[string]string) {
for name, value := range reqHeaders {
if resHeaders.Get(name) != value {
t.Errorf("Invalid header `%s', wanted `%s', got `%s'", name, value, resHeaders.Get(name))
}
}
}
func TestDisallowedOrigin(t *testing.T) {
res := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "http://example.com/foo", nil)
req.Header.Add("Origin", "http://barbaz.com")
handler := NewHandler(commands.Context{}, nil, "")
handler.ServeHTTP(res, req)
assertHeaders(t, res.Header(), map[string]string{
"Access-Control-Allow-Origin": "",
"Access-Control-Allow-Methods": "",
"Access-Control-Allow-Credentials": "",
"Access-Control-Max-Age": "",
"Access-Control-Expose-Headers": "",
})
}
func TestWildcardOrigin(t *testing.T) {
res := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "http://example.com/foo", nil)
req.Header.Add("Origin", "http://foobar.com")
handler := NewHandler(commands.Context{}, nil, "*")
handler.ServeHTTP(res, req)
assertHeaders(t, res.Header(), map[string]string{
"Access-Control-Allow-Origin": "http://foobar.com",
"Access-Control-Allow-Methods": "",
"Access-Control-Allow-Headers": "",
"Access-Control-Allow-Credentials": "",
"Access-Control-Max-Age": "",
"Access-Control-Expose-Headers": "",
})
}
func TestAllowedMethod(t *testing.T) {
res := httptest.NewRecorder()
req, _ := http.NewRequest("OPTIONS", "http://example.com/foo", nil)
req.Header.Add("Origin", "http://www.foobar.com")
req.Header.Add("Access-Control-Request-Method", "PUT")
handler := NewHandler(commands.Context{}, nil, "http://www.foobar.com")
handler.ServeHTTP(res, req)
assertHeaders(t, res.Header(), map[string]string{
"Access-Control-Allow-Origin": "http://www.foobar.com",
"Access-Control-Allow-Methods": "PUT",
"Access-Control-Allow-Headers": "",
"Access-Control-Allow-Credentials": "",
"Access-Control-Max-Age": "",
"Access-Control-Expose-Headers": "",
})
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment