Commit 3a078bd5 authored by hannahhoward's avatar hannahhoward

refactor(responsemanager): refactor hook interface

refactor hook interface to pass an actions object
parent 492e1fd4
......@@ -115,12 +115,20 @@ type RequestData interface {
IsCancel() bool
}
// RequestReceivedHookActions are actions that a request hook can take to change
// behavior for the response
type RequestReceivedHookActions interface {
SendExtensionData(ExtensionData)
TerminateWithError(error)
ValidateRequest()
}
// OnRequestReceivedHook is a hook that runs each time a request is received.
// It receives the peer that sent the request and all data about the request.
// It should return:
// extensionData - any extension data to add to the outgoing response
// err - error - if not nil, halt request and return RequestRejected with the responseData
type OnRequestReceivedHook func(p peer.ID, request RequestData) (extensionData []ExtensionData, err error)
type OnRequestReceivedHook func(p peer.ID, request RequestData, hookActions RequestReceivedHookActions)
// GraphExchange is a protocol that can exchange IPLD graphs based on a selector
type GraphExchange interface {
......@@ -131,5 +139,5 @@ type GraphExchange interface {
// If overrideDefaultValidation is set to true, then if the hook does not error,
// it is considered to have "validated" the request -- and that validation supersedes
// the normal validation of requests Graphsync does (i.e. all selectors can be accepted)
RegisterRequestReceivedHook(overrideDefaultValidation bool, hook OnRequestReceivedHook) error
RegisterRequestReceivedHook(hook OnRequestReceivedHook) error
}
......@@ -90,8 +90,8 @@ func (gs *GraphSync) Request(ctx context.Context, p peer.ID, root ipld.Link, sel
// If overrideDefaultValidation is set to true, then if the hook does not error,
// it is considered to have "validated" the request -- and that validation supersedes
// the normal validation of requests Graphsync does (i.e. all selectors can be accepted)
func (gs *GraphSync) RegisterRequestReceivedHook(overrideDefaultValidation bool, hook graphsync.OnRequestReceivedHook) error {
gs.responseManager.RegisterHook(overrideDefaultValidation, hook)
func (gs *GraphSync) RegisterRequestReceivedHook(hook graphsync.OnRequestReceivedHook) error {
gs.responseManager.RegisterHook(hook)
// may be a need to return errors here in the future...
return nil
}
......
......@@ -278,14 +278,14 @@ func TestSendResponseToIncomingRequest(t *testing.T) {
var receivedRequestData []byte
// initialize graphsync on second node to response to requests
gsnet := New(ctx, gsnet2, bridge, loader, storer)
err = gsnet.RegisterRequestReceivedHook(false,
func(p peer.ID, requestData graphsync.RequestData) ([]graphsync.ExtensionData, error) {
err = gsnet.RegisterRequestReceivedHook(
func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
var has bool
receivedRequestData, has = requestData.Extension(extensionName)
if !has {
t.Fatal("did not have expected extension")
}
return []graphsync.ExtensionData{extensionResponse}, nil
hookActions.SendExtensionData(extensionResponse)
},
)
if err != nil {
......
......@@ -13,7 +13,6 @@ import (
"github.com/ipfs/go-peertaskqueue/peertask"
ipld "github.com/ipld/go-ipld-prime"
cidlink "github.com/ipld/go-ipld-prime/linking/cid"
"github.com/ipld/go-ipld-prime/traversal/selector"
"github.com/libp2p/go-libp2p-core/peer"
)
......@@ -40,8 +39,7 @@ type responseTaskData struct {
}
type requestHook struct {
overrideDefaultValidation bool
hook graphsync.OnRequestReceivedHook
hook graphsync.OnRequestReceivedHook
}
// QueryQueue is an interface that can receive new selector query tasks
......@@ -116,11 +114,9 @@ func (rm *ResponseManager) ProcessRequests(ctx context.Context, p peer.ID, reque
}
// RegisterHook registers an extension to process new incoming requests
func (rm *ResponseManager) RegisterHook(
overrideDefaultValidation bool,
hook graphsync.OnRequestReceivedHook) {
func (rm *ResponseManager) RegisterHook(hook graphsync.OnRequestReceivedHook) {
select {
case rm.messages <- &requestHook{overrideDefaultValidation, hook}:
case rm.messages <- &requestHook{hook}:
case <-rm.ctx.Done():
}
}
......@@ -195,14 +191,50 @@ func noopVisitor(tp ipldbridge.TraversalProgress, n ipld.Node, tr ipldbridge.Tra
return nil
}
type hookActions struct {
isValidated bool
requestID graphsync.RequestID
peerResponseSender peerresponsemanager.PeerResponseSender
err error
}
func (ha *hookActions) SendExtensionData(ext graphsync.ExtensionData) {
ha.peerResponseSender.SendExtensionData(ha.requestID, ext)
}
func (ha *hookActions) TerminateWithError(err error) {
ha.err = err
ha.peerResponseSender.FinishWithError(ha.requestID, graphsync.RequestFailedUnknown)
}
func (ha *hookActions) ValidateRequest() {
ha.isValidated = true
}
func (rm *ResponseManager) executeQuery(ctx context.Context,
p peer.ID,
request gsmsg.GraphSyncRequest) {
peerResponseSender := rm.peerManager.SenderForPeer(p)
extensionData, selector, err := rm.validateRequest(p, request)
for _, datum := range extensionData {
peerResponseSender.SendExtensionData(request.ID(), datum)
selectorSpec, err := rm.ipldBridge.DecodeNode(request.Selector())
if err != nil {
peerResponseSender.FinishWithError(request.ID(), graphsync.RequestFailedUnknown)
return
}
ha := &hookActions{false, request.ID(), peerResponseSender, nil}
for _, requestHook := range rm.requestHooks {
requestHook.hook(p, request, ha)
if ha.err != nil {
return
}
}
if !ha.isValidated {
err = selectorvalidator.ValidateSelector(rm.ipldBridge, selectorSpec, maxRecursionDepth)
if err != nil {
peerResponseSender.FinishWithError(request.ID(), graphsync.RequestFailedUnknown)
return
}
}
selector, err := rm.ipldBridge.ParseSelector(selectorSpec)
if err != nil {
peerResponseSender.FinishWithError(request.ID(), graphsync.RequestFailedUnknown)
return
......@@ -217,33 +249,6 @@ func (rm *ResponseManager) executeQuery(ctx context.Context,
peerResponseSender.FinishRequest(request.ID())
}
func (rm *ResponseManager) validateRequest(p peer.ID, request graphsync.RequestData) ([]graphsync.ExtensionData, selector.Selector, error) {
selectorSpec, err := rm.ipldBridge.DecodeNode(request.Selector())
if err != nil {
return nil, nil, err
}
var isValidated bool
var allExtensionData []graphsync.ExtensionData
for _, requestHook := range rm.requestHooks {
extensionData, err := requestHook.hook(p, request)
allExtensionData = append(allExtensionData, extensionData...)
if err != nil {
return allExtensionData, nil, err
}
if requestHook.overrideDefaultValidation {
isValidated = true
}
}
if !isValidated {
err = selectorvalidator.ValidateSelector(rm.ipldBridge, selectorSpec, maxRecursionDepth)
if err != nil {
return allExtensionData, nil, err
}
}
selector, err := rm.ipldBridge.ParseSelector(selectorSpec)
return allExtensionData, selector, err
}
// Startup starts processing for the WantManager.
func (rm *ResponseManager) Startup() {
go rm.run()
......
......@@ -2,7 +2,7 @@ package responsemanager
import (
"context"
"fmt"
"errors"
"math"
"math/rand"
"reflect"
......@@ -387,8 +387,8 @@ func TestValidationAndExtensions(t *testing.T) {
t.Run("if non validating hook succeeds, does not pass validation", func(t *testing.T) {
responseManager := New(ctx, loader, ipldBridge, peerManager, queryQueue)
responseManager.Startup()
responseManager.RegisterHook(false, func(p peer.ID, requestData graphsync.RequestData) ([]graphsync.ExtensionData, error) {
return []graphsync.ExtensionData{extensionResponse}, nil
responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
hookActions.SendExtensionData(extensionResponse)
})
responseManager.ProcessRequests(ctx, p, requests)
select {
......@@ -412,8 +412,9 @@ func TestValidationAndExtensions(t *testing.T) {
t.Run("if validating hook succeeds, should pass validation", func(t *testing.T) {
responseManager := New(ctx, loader, ipldBridge, peerManager, queryQueue)
responseManager.Startup()
responseManager.RegisterHook(true, func(p peer.ID, requestData graphsync.RequestData) ([]graphsync.ExtensionData, error) {
return []graphsync.ExtensionData{extensionResponse}, nil
responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
hookActions.ValidateRequest()
hookActions.SendExtensionData(extensionResponse)
})
responseManager.ProcessRequests(ctx, p, requests)
select {
......@@ -464,8 +465,9 @@ func TestValidationAndExtensions(t *testing.T) {
t.Run("if any hook fails, should fail", func(t *testing.T) {
responseManager := New(ctx, loader, ipldBridge, peerManager, queryQueue)
responseManager.Startup()
responseManager.RegisterHook(false, func(p peer.ID, requestData graphsync.RequestData) ([]graphsync.ExtensionData, error) {
return []graphsync.ExtensionData{extensionResponse}, fmt.Errorf("everything went to crap")
responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
hookActions.SendExtensionData(extensionResponse)
hookActions.TerminateWithError(errors.New("everything went to crap"))
})
responseManager.ProcessRequests(ctx, p, requests)
select {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment