Unverified Commit 67606650 authored by Hannah Howard's avatar Hannah Howard Committed by GitHub

Option to Reject requests by default (#58)

* refactor(hooks): refactor default validation as hook, add unregister option

* feat(graphsync): add disable default validation option

* fix(responsemanager): fix mutex  unlocking

cover case where unlocking was not happening
parent b3cc648d
......@@ -154,6 +154,9 @@ type OnRequestReceivedHook func(p peer.ID, request RequestData, hookActions Requ
// If it returns an error processing is halted and the original request is cancelled.
type OnResponseReceivedHook func(p peer.ID, responseData ResponseData) error
// UnregisterHookFunc is a function call to unregister a hook that was previously registered
type UnregisterHookFunc func()
// GraphExchange is a protocol that can exchange IPLD graphs based on a selector
type GraphExchange interface {
// Request initiates a new GraphSync request to the given peer using the given selector spec.
......@@ -163,8 +166,8 @@ type GraphExchange interface {
// If overrideDefaultValidation is set to true, then if the hook does not error,
// it is considered to have "validated" the request -- and that validation supersedes
// the normal validation of requests Graphsync does (i.e. all selectors can be accepted)
RegisterRequestReceivedHook(hook OnRequestReceivedHook) error
RegisterRequestReceivedHook(hook OnRequestReceivedHook) UnregisterHookFunc
// RegisterResponseReceivedHook adds a hook that runs when a response is received
RegisterResponseReceivedHook(OnResponseReceivedHook) error
RegisterResponseReceivedHook(OnResponseReceivedHook) UnregisterHookFunc
}
......@@ -4,15 +4,15 @@ import (
"context"
"github.com/ipfs/go-graphsync"
"github.com/ipfs/go-graphsync/requestmanager/asyncloader"
gsmsg "github.com/ipfs/go-graphsync/message"
"github.com/ipfs/go-graphsync/messagequeue"
gsnet "github.com/ipfs/go-graphsync/network"
"github.com/ipfs/go-graphsync/peermanager"
"github.com/ipfs/go-graphsync/requestmanager"
"github.com/ipfs/go-graphsync/requestmanager/asyncloader"
"github.com/ipfs/go-graphsync/responsemanager"
"github.com/ipfs/go-graphsync/responsemanager/peerresponsemanager"
"github.com/ipfs/go-graphsync/selectorvalidator"
logging "github.com/ipfs/go-log"
"github.com/ipfs/go-peertaskqueue"
ipld "github.com/ipld/go-ipld-prime"
......@@ -21,26 +21,41 @@ import (
var log = logging.Logger("graphsync")
const maxRecursionDepth = 100
// GraphSync is an instance of a GraphSync exchange that implements
// the graphsync protocol.
type GraphSync struct {
network gsnet.GraphSyncNetwork
loader ipld.Loader
storer ipld.Storer
requestManager *requestmanager.RequestManager
responseManager *responsemanager.ResponseManager
asyncLoader *asyncloader.AsyncLoader
peerResponseManager *peerresponsemanager.PeerResponseManager
peerTaskQueue *peertaskqueue.PeerTaskQueue
peerManager *peermanager.PeerMessageManager
ctx context.Context
cancel context.CancelFunc
network gsnet.GraphSyncNetwork
loader ipld.Loader
storer ipld.Storer
requestManager *requestmanager.RequestManager
responseManager *responsemanager.ResponseManager
asyncLoader *asyncloader.AsyncLoader
peerResponseManager *peerresponsemanager.PeerResponseManager
peerTaskQueue *peertaskqueue.PeerTaskQueue
peerManager *peermanager.PeerMessageManager
ctx context.Context
cancel context.CancelFunc
unregisterDefaultValidator graphsync.UnregisterHookFunc
}
// Option defines the functional option type that can be used to configure
// graphsync instances
type Option func(*GraphSync)
// RejectAllRequestsByDefault means that without hooks registered
// that perform their own request validation, all requests are rejected
func RejectAllRequestsByDefault() Option {
return func(gs *GraphSync) {
gs.unregisterDefaultValidator()
}
}
// New creates a new GraphSync Exchange on the given network,
// and the given link loader+storer.
func New(parent context.Context, network gsnet.GraphSyncNetwork,
loader ipld.Loader, storer ipld.Storer) graphsync.GraphExchange {
loader ipld.Loader, storer ipld.Storer, options ...Option) graphsync.GraphExchange {
ctx, cancel := context.WithCancel(parent)
createMessageQueue := func(ctx context.Context, p peer.ID) peermanager.PeerQueue {
......@@ -55,18 +70,24 @@ func New(parent context.Context, network gsnet.GraphSyncNetwork,
}
peerResponseManager := peerresponsemanager.New(ctx, createdResponseQueue)
responseManager := responsemanager.New(ctx, loader, peerResponseManager, peerTaskQueue)
unregisterDefaultValidator := responseManager.RegisterHook(selectorvalidator.SelectorValidator(maxRecursionDepth))
graphSync := &GraphSync{
network: network,
loader: loader,
storer: storer,
asyncLoader: asyncLoader,
requestManager: requestManager,
peerManager: peerManager,
peerTaskQueue: peerTaskQueue,
peerResponseManager: peerResponseManager,
responseManager: responseManager,
ctx: ctx,
cancel: cancel,
network: network,
loader: loader,
storer: storer,
asyncLoader: asyncLoader,
requestManager: requestManager,
peerManager: peerManager,
peerTaskQueue: peerTaskQueue,
peerResponseManager: peerResponseManager,
responseManager: responseManager,
ctx: ctx,
cancel: cancel,
unregisterDefaultValidator: unregisterDefaultValidator,
}
for _, option := range options {
option(graphSync)
}
asyncLoader.Startup()
......@@ -86,15 +107,13 @@ func (gs *GraphSync) Request(ctx context.Context, p peer.ID, root ipld.Link, sel
// If overrideDefaultValidation is set to true, then if the hook does not error,
// it is considered to have "validated" the request -- and that validation supersedes
// the normal validation of requests Graphsync does (i.e. all selectors can be accepted)
func (gs *GraphSync) RegisterRequestReceivedHook(hook graphsync.OnRequestReceivedHook) error {
gs.responseManager.RegisterHook(hook)
return nil
func (gs *GraphSync) RegisterRequestReceivedHook(hook graphsync.OnRequestReceivedHook) graphsync.UnregisterHookFunc {
return gs.responseManager.RegisterHook(hook)
}
// RegisterResponseReceivedHook adds a hook that runs when a response is received
func (gs *GraphSync) RegisterResponseReceivedHook(hook graphsync.OnResponseReceivedHook) error {
gs.requestManager.RegisterHook(hook)
return nil
func (gs *GraphSync) RegisterResponseReceivedHook(hook graphsync.OnResponseReceivedHook) graphsync.UnregisterHookFunc {
return gs.requestManager.RegisterHook(hook)
}
type graphSyncReceiver GraphSync
......
......@@ -99,7 +99,7 @@ func TestSendResponseToIncomingRequest(t *testing.T) {
var receivedRequestData []byte
// initialize graphsync on second node to response to requests
gsnet := td.GraphSyncHost2()
err := gsnet.RegisterRequestReceivedHook(
gsnet.RegisterRequestReceivedHook(
func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
var has bool
receivedRequestData, has = requestData.Extension(td.extensionName)
......@@ -107,7 +107,6 @@ func TestSendResponseToIncomingRequest(t *testing.T) {
hookActions.SendExtensionData(td.extensionResponse)
},
)
require.NoError(t, err, "error registering extension")
blockChainLength := 100
blockChain := testutil.SetupBlockChain(ctx, t, td.loader2, td.storer2, 100, blockChainLength)
......@@ -117,7 +116,7 @@ func TestSendResponseToIncomingRequest(t *testing.T) {
message := gsmsg.New()
message.AddRequest(gsmsg.NewRequest(requestID, blockChain.TipLink.(cidlink.Link).Cid, blockChain.Selector(), graphsync.Priority(math.MaxInt32), td.extension))
// send request across network
err = td.gsnet1.SendMessage(ctx, td.host2.ID(), message)
err := td.gsnet1.SendMessage(ctx, td.host2.ID(), message)
require.NoError(t, err)
// read the values sent back to requestor
var received gsmsg.GraphSyncMessage
......@@ -150,6 +149,27 @@ func TestSendResponseToIncomingRequest(t *testing.T) {
require.Equal(t, td.extensionResponseData, receivedExtensions[0], "did not return correct extension data")
}
func TestRejectRequestsByDefault(t *testing.T) {
// create network
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
defer cancel()
td := newGsTestData(ctx, t)
requestor := td.GraphSyncHost1()
// setup responder to disable default validation, meaning all requests are rejected
_ = td.GraphSyncHost2(RejectAllRequestsByDefault())
blockChainLength := 5
blockChain := testutil.SetupBlockChain(ctx, t, td.loader2, td.storer2, 5, blockChainLength)
// send request across network
progressChan, errChan := requestor.Request(ctx, td.host2.ID(), blockChain.TipLink, blockChain.Selector(), td.extension)
testutil.VerifyEmptyResponse(ctx, t, progressChan)
testutil.VerifySingleTerminalError(ctx, t, errChan)
}
func TestGraphsyncRoundTrip(t *testing.T) {
// create network
ctx := context.Background()
......@@ -170,7 +190,7 @@ func TestGraphsyncRoundTrip(t *testing.T) {
var receivedResponseData []byte
var receivedRequestData []byte
err := requestor.RegisterResponseReceivedHook(
requestor.RegisterResponseReceivedHook(
func(p peer.ID, responseData graphsync.ResponseData) error {
data, has := responseData.Extension(td.extensionName)
if has {
......@@ -178,9 +198,8 @@ func TestGraphsyncRoundTrip(t *testing.T) {
}
return nil
})
require.NoError(t, err, "Error setting up extension")
err = responder.RegisterRequestReceivedHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
responder.RegisterRequestReceivedHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
var has bool
receivedRequestData, has = requestData.Extension(td.extensionName)
if !has {
......@@ -189,7 +208,6 @@ func TestGraphsyncRoundTrip(t *testing.T) {
hookActions.SendExtensionData(td.extensionResponse)
}
})
require.NoError(t, err, "Error setting up extension")
progressChan, errChan := requestor.Request(ctx, td.host2.ID(), blockChain.TipLink, blockChain.Selector(), td.extension)
......@@ -342,15 +360,14 @@ func TestUnixFSFetch(t *testing.T) {
requestor := New(ctx, td.gsnet1, loader1, storer1)
responder := New(ctx, td.gsnet2, loader2, storer2)
extensionName := graphsync.ExtensionName("Free for all")
err = responder.RegisterRequestReceivedHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
responder.RegisterRequestReceivedHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
hookActions.ValidateRequest()
hookActions.SendExtensionData(graphsync.ExtensionData{
Name: extensionName,
Data: nil,
})
})
require.NoError(t, err)
// make a go-ipld-prime link for the root UnixFS node
clink := cidlink.Link{Cid: nd.Cid()}
......@@ -443,13 +460,13 @@ func newGsTestData(ctx context.Context, t *testing.T) *gsTestData {
return td
}
func (td *gsTestData) GraphSyncHost1() graphsync.GraphExchange {
return New(td.ctx, td.gsnet1, td.loader1, td.storer1)
func (td *gsTestData) GraphSyncHost1(options ...Option) graphsync.GraphExchange {
return New(td.ctx, td.gsnet1, td.loader1, td.storer1, options...)
}
func (td *gsTestData) GraphSyncHost2() graphsync.GraphExchange {
func (td *gsTestData) GraphSyncHost2(options ...Option) graphsync.GraphExchange {
return New(td.ctx, td.gsnet2, td.loader2, td.storer2)
return New(td.ctx, td.gsnet2, td.loader2, td.storer2, options...)
}
type receivedMessage struct {
......
......@@ -34,6 +34,7 @@ type inProgressRequestStatus struct {
}
type responseHook struct {
key uint64
hook graphsync.OnResponseReceivedHook
}
......@@ -65,6 +66,7 @@ type RequestManager struct {
// dont touch out side of run loop
nextRequestID graphsync.RequestID
inProgressRequestStatuses map[graphsync.RequestID]*inProgressRequestStatus
responseHookNextKey uint64
responseHooks []responseHook
}
......@@ -201,12 +203,25 @@ func (rm *RequestManager) ProcessResponses(p peer.ID, responses []gsmsg.GraphSyn
}
}
type registerHookMessage struct {
hook graphsync.OnResponseReceivedHook
unregisterHookChan chan graphsync.UnregisterHookFunc
}
// RegisterHook registers an extension to processincoming responses
func (rm *RequestManager) RegisterHook(
hook graphsync.OnResponseReceivedHook) {
hook graphsync.OnResponseReceivedHook) graphsync.UnregisterHookFunc {
response := make(chan graphsync.UnregisterHookFunc)
select {
case rm.messages <- &registerHookMessage{hook, response}:
case <-rm.ctx.Done():
return nil
}
select {
case rm.messages <- &responseHook{hook}:
case unregister := <-response:
return unregister
case <-rm.ctx.Done():
return nil
}
}
......@@ -285,8 +300,21 @@ func (prm *processResponseMessage) handle(rm *RequestManager) {
rm.processTerminations(filteredResponses)
}
func (rh *responseHook) handle(rm *RequestManager) {
rm.responseHooks = append(rm.responseHooks, *rh)
func (rhm *registerHookMessage) handle(rm *RequestManager) {
rh := responseHook{rm.responseHookNextKey, rhm.hook}
rm.responseHookNextKey++
rm.responseHooks = append(rm.responseHooks, rh)
select {
case rhm.unregisterHookChan <- func() {
for i, matchHook := range rm.responseHooks {
if rh.key == matchHook.key {
rm.responseHooks = append(rm.responseHooks[:i], rm.responseHooks[i+1:]...)
return
}
}
}:
case <-rm.ctx.Done():
}
}
func (rm *RequestManager) filterResponsesForPeer(responses []gsmsg.GraphSyncResponse, p peer.ID) []gsmsg.GraphSyncResponse {
......
......@@ -2,6 +2,7 @@ package responsemanager
import (
"context"
"sync"
"time"
"github.com/ipfs/go-graphsync"
......@@ -9,7 +10,6 @@ import (
gsmsg "github.com/ipfs/go-graphsync/message"
"github.com/ipfs/go-graphsync/responsemanager/loader"
"github.com/ipfs/go-graphsync/responsemanager/peerresponsemanager"
"github.com/ipfs/go-graphsync/responsemanager/selectorvalidator"
"github.com/ipfs/go-peertaskqueue/peertask"
ipld "github.com/ipld/go-ipld-prime"
cidlink "github.com/ipld/go-ipld-prime/linking/cid"
......@@ -19,7 +19,6 @@ import (
const (
maxInProcessRequests = 6
maxRecursionDepth = 100
thawSpeed = time.Millisecond * 100
)
......@@ -40,6 +39,7 @@ type responseTaskData struct {
}
type requestHook struct {
key uint64
hook graphsync.OnRequestReceivedHook
}
......@@ -75,6 +75,8 @@ type ResponseManager struct {
workSignal chan struct{}
ticker *time.Ticker
inProgressResponses map[responseKey]inProgressResponseStatus
requestHooksLk sync.RWMutex
requestHookNextKey uint64
requestHooks []requestHook
}
......@@ -113,10 +115,21 @@ func (rm *ResponseManager) ProcessRequests(ctx context.Context, p peer.ID, reque
}
// RegisterHook registers an extension to process new incoming requests
func (rm *ResponseManager) RegisterHook(hook graphsync.OnRequestReceivedHook) {
select {
case rm.messages <- &requestHook{hook}:
case <-rm.ctx.Done():
func (rm *ResponseManager) RegisterHook(hook graphsync.OnRequestReceivedHook) graphsync.UnregisterHookFunc {
rm.requestHooksLk.Lock()
rh := requestHook{rm.requestHookNextKey, hook}
rm.requestHookNextKey++
rm.requestHooks = append(rm.requestHooks, rh)
rm.requestHooksLk.Unlock()
return func() {
rm.requestHooksLk.Lock()
defer rm.requestHooksLk.Unlock()
for i, matchHook := range rm.requestHooks {
if rh.key == matchHook.key {
rm.requestHooks = append(rm.requestHooks[:i], rm.requestHooks[i+1:]...)
return
}
}
}
}
......@@ -217,18 +230,18 @@ func (rm *ResponseManager) executeQuery(ctx context.Context,
peerResponseSender := rm.peerManager.SenderForPeer(p)
selectorSpec := request.Selector()
ha := &hookActions{false, request.ID(), peerResponseSender, nil}
rm.requestHooksLk.RLock()
for _, requestHook := range rm.requestHooks {
requestHook.hook(p, request, ha)
if ha.err != nil {
rm.requestHooksLk.RUnlock()
return
}
}
rm.requestHooksLk.RUnlock()
if !ha.isValidated {
err := selectorvalidator.ValidateSelector(selectorSpec, maxRecursionDepth)
if err != nil {
peerResponseSender.FinishWithError(request.ID(), graphsync.RequestFailedUnknown)
return
}
peerResponseSender.FinishWithError(request.ID(), graphsync.RequestFailedUnknown)
return
}
selector, err := ipldutil.ParseSelector(selectorSpec)
if err != nil {
......@@ -304,10 +317,6 @@ func (prm *processRequestMessage) handle(rm *ResponseManager) {
}
}
func (rh *requestHook) handle(rm *ResponseManager) {
rm.requestHooks = append(rm.requestHooks, *rh)
}
func (rdr *responseDataRequest) handle(rm *ResponseManager) {
response, ok := rm.inProgressResponses[rdr.key]
var taskData *responseTaskData
......
......@@ -12,6 +12,7 @@ import (
"github.com/ipfs/go-graphsync"
gsmsg "github.com/ipfs/go-graphsync/message"
"github.com/ipfs/go-graphsync/responsemanager/peerresponsemanager"
"github.com/ipfs/go-graphsync/selectorvalidator"
"github.com/ipfs/go-graphsync/testutil"
"github.com/ipfs/go-peertaskqueue/peertask"
ipld "github.com/ipld/go-ipld-prime"
......@@ -142,6 +143,7 @@ func TestIncomingQuery(t *testing.T) {
peerManager := &fakePeerManager{peerResponseSender: fprs}
queryQueue := &fakeQueryQueue{}
responseManager := New(ctx, loader, peerManager, queryQueue)
responseManager.RegisterHook(selectorvalidator.SelectorValidator(100))
responseManager.Startup()
requestID := graphsync.RequestID(rand.Int31())
......@@ -179,6 +181,7 @@ func TestCancellationQueryInProgress(t *testing.T) {
peerManager := &fakePeerManager{peerResponseSender: fprs}
queryQueue := &fakeQueryQueue{}
responseManager := New(ctx, loader, peerManager, queryQueue)
responseManager.RegisterHook(selectorvalidator.SelectorValidator(100))
responseManager.Startup()
requestID := graphsync.RequestID(rand.Int31())
......@@ -287,85 +290,94 @@ func TestValidationAndExtensions(t *testing.T) {
Data: extensionResponseData,
}
t.Run("with invalid selector", func(t *testing.T) {
selectorSpec := testutil.NewInvalidSelectorSpec()
requestID := graphsync.RequestID(rand.Int31())
requests := []gsmsg.GraphSyncRequest{
gsmsg.NewRequest(requestID, blockChain.TipLink.(cidlink.Link).Cid, selectorSpec, graphsync.Priority(math.MaxInt32), extension),
}
p := testutil.GeneratePeers(1)[0]
t.Run("on its own, should fail validation", func(t *testing.T) {
responseManager := New(ctx, loader, peerManager, queryQueue)
responseManager.Startup()
responseManager.ProcessRequests(ctx, p, requests)
var lastRequest completedRequest
testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")
})
requestID := graphsync.RequestID(rand.Int31())
requests := []gsmsg.GraphSyncRequest{
gsmsg.NewRequest(requestID, blockChain.TipLink.(cidlink.Link).Cid, blockChain.Selector(), graphsync.Priority(math.MaxInt32), extension),
}
p := testutil.GeneratePeers(1)[0]
t.Run("if non validating hook succeeds, does not pass validation", func(t *testing.T) {
responseManager := New(ctx, loader, peerManager, queryQueue)
responseManager.Startup()
responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
hookActions.SendExtensionData(extensionResponse)
})
responseManager.ProcessRequests(ctx, p, requests)
var lastRequest completedRequest
testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")
var receivedExtension sentExtension
testutil.AssertReceive(ctx, t, sentExtensions, &receivedExtension, "should send extension response")
require.Equal(t, extensionResponse, receivedExtension.extension, "incorrect extension response sent")
t.Run("on its own, should fail validation", func(t *testing.T) {
responseManager := New(ctx, loader, peerManager, queryQueue)
responseManager.Startup()
responseManager.ProcessRequests(ctx, p, requests)
var lastRequest completedRequest
testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")
})
t.Run("if non validating hook succeeds, does not pass validation", func(t *testing.T) {
responseManager := New(ctx, loader, peerManager, queryQueue)
responseManager.Startup()
responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
hookActions.SendExtensionData(extensionResponse)
})
responseManager.ProcessRequests(ctx, p, requests)
var lastRequest completedRequest
testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")
var receivedExtension sentExtension
testutil.AssertReceive(ctx, t, sentExtensions, &receivedExtension, "should send extension response")
require.Equal(t, extensionResponse, receivedExtension.extension, "incorrect extension response sent")
})
t.Run("if validating hook succeeds, should pass validation", func(t *testing.T) {
responseManager := New(ctx, loader, peerManager, queryQueue)
responseManager.Startup()
responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
hookActions.ValidateRequest()
hookActions.SendExtensionData(extensionResponse)
})
responseManager.ProcessRequests(ctx, p, requests)
var lastRequest completedRequest
testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
var receivedExtension sentExtension
testutil.AssertReceive(ctx, t, sentExtensions, &receivedExtension, "should send extension response")
require.Equal(t, extensionResponse, receivedExtension.extension, "incorrect extension response sent")
t.Run("if validating hook succeeds, should pass validation", func(t *testing.T) {
responseManager := New(ctx, loader, peerManager, queryQueue)
responseManager.Startup()
responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
hookActions.ValidateRequest()
hookActions.SendExtensionData(extensionResponse)
})
responseManager.ProcessRequests(ctx, p, requests)
var lastRequest completedRequest
testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
var receivedExtension sentExtension
testutil.AssertReceive(ctx, t, sentExtensions, &receivedExtension, "should send extension response")
require.Equal(t, extensionResponse, receivedExtension.extension, "incorrect extension response sent")
})
t.Run("with valid selector", func(t *testing.T) {
requestID := graphsync.RequestID(rand.Int31())
requests := []gsmsg.GraphSyncRequest{
gsmsg.NewRequest(requestID, blockChain.TipLink.(cidlink.Link).Cid, blockChain.Selector(), graphsync.Priority(math.MaxInt32), extension),
}
p := testutil.GeneratePeers(1)[0]
t.Run("on its own, should pass validation", func(t *testing.T) {
responseManager := New(ctx, loader, peerManager, queryQueue)
responseManager.Startup()
responseManager.ProcessRequests(ctx, p, requests)
var lastRequest completedRequest
testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
t.Run("if any hook fails, should fail", func(t *testing.T) {
responseManager := New(ctx, loader, peerManager, queryQueue)
responseManager.Startup()
responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
hookActions.ValidateRequest()
})
responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
hookActions.SendExtensionData(extensionResponse)
hookActions.TerminateWithError(errors.New("everything went to crap"))
})
responseManager.ProcessRequests(ctx, p, requests)
var lastRequest completedRequest
testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")
var receivedExtension sentExtension
testutil.AssertReceive(ctx, t, sentExtensions, &receivedExtension, "should send extension response")
require.Equal(t, extensionResponse, receivedExtension.extension, "incorrect extension response sent")
})
t.Run("if any hook fails, should fail", func(t *testing.T) {
responseManager := New(ctx, loader, peerManager, queryQueue)
responseManager.Startup()
responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
hookActions.SendExtensionData(extensionResponse)
hookActions.TerminateWithError(errors.New("everything went to crap"))
})
responseManager.ProcessRequests(ctx, p, requests)
var lastRequest completedRequest
testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")
var receivedExtension sentExtension
testutil.AssertReceive(ctx, t, sentExtensions, &receivedExtension, "should send extension response")
require.Equal(t, extensionResponse, receivedExtension.extension, "incorrect extension response sent")
t.Run("hooks can be unregistered", func(t *testing.T) {
responseManager := New(ctx, loader, peerManager, queryQueue)
responseManager.Startup()
unregister := responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
hookActions.ValidateRequest()
hookActions.SendExtensionData(extensionResponse)
})
// hook validates request
responseManager.ProcessRequests(ctx, p, requests)
var lastRequest completedRequest
testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
var receivedExtension sentExtension
testutil.AssertReceive(ctx, t, sentExtensions, &receivedExtension, "should send extension response")
require.Equal(t, extensionResponse, receivedExtension.extension, "incorrect extension response sent")
// unregister
unregister()
// no same request should fail
responseManager.ProcessRequests(ctx, p, requests)
testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")
})
}
......@@ -3,11 +3,13 @@ package selectorvalidator
import (
"errors"
"github.com/ipfs/go-graphsync"
ipld "github.com/ipld/go-ipld-prime"
ipldfree "github.com/ipld/go-ipld-prime/impl/free"
"github.com/ipld/go-ipld-prime/traversal"
"github.com/ipld/go-ipld-prime/traversal/selector"
"github.com/ipld/go-ipld-prime/traversal/selector/builder"
"github.com/libp2p/go-libp2p-core/peer"
)
var (
......@@ -16,10 +18,21 @@ var (
ErrInvalidLimit = errors.New("unsupported recursive selector limit")
)
// ValidateSelector applies the default selector validation policy to a selector
// on an incoming request -- which by default is to limit recursive selectors
// to a fixed depth
func ValidateSelector(node ipld.Node, maxAcceptedDepth int) error {
// SelectorValidator returns an OnRequestReceivedHook that only validates
// requests if their selector only has no recursions that are greater than
// maxAcceptedDepth
func SelectorValidator(maxAcceptedDepth int) graphsync.OnRequestReceivedHook {
return func(p peer.ID, request graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
err := ValidateMaxRecursionDepth(request.Selector(), maxAcceptedDepth)
if err == nil {
hookActions.ValidateRequest()
}
}
}
// ValidateMaxRecursionDepth examines the given selector node and verifies
// recursive selectors are limited to the given fixed depth
func ValidateMaxRecursionDepth(node ipld.Node, maxAcceptedDepth int) error {
ssb := builder.NewSelectorSpecBuilder(ipldfree.NodeBuilder())
// this selector is a selector for traversing selectors...
......
......@@ -11,7 +11,7 @@ import (
"github.com/ipld/go-ipld-prime/traversal/selector/builder"
)
func TestValidateSelector(t *testing.T) {
func TestValidateMaxRecusionDepth(t *testing.T) {
ssb := builder.NewSelectorSpecBuilder(ipldfree.NodeBuilder())
successBase := ssb.ExploreRecursive(selector.RecursionLimitDepth(80), ssb.ExploreRecursiveEdge())
......@@ -19,11 +19,11 @@ func TestValidateSelector(t *testing.T) {
failNoneBase := ssb.ExploreRecursive(selector.RecursionLimitNone(), ssb.ExploreRecursiveEdge())
verifyOutcomes := func(t *testing.T, success ipld.Node, fail ipld.Node, failNone ipld.Node) {
err := ValidateSelector(success, 100)
err := ValidateMaxRecursionDepth(success, 100)
require.NoError(t, err, "valid selector should validate")
err = ValidateSelector(fail, 100)
err = ValidateMaxRecursionDepth(fail, 100)
require.Equal(t, ErrInvalidLimit, err, "selector should fail on invalid limit")
err = ValidateSelector(failNone, 100)
err = ValidateMaxRecursionDepth(failNone, 100)
require.Equal(t, ErrInvalidLimit, err, "selector should fail on no limit")
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment