Unverified Commit 93f57247 authored by Hannah Howard's avatar Hannah Howard Committed by GitHub

Merge pull request #50 from ipfs/feat/response-hooks

Add response hooks
parents 64619a79 dffc8e69
......@@ -2,6 +2,7 @@ package graphsync
import (
"context"
"errors"
"github.com/ipfs/go-cid"
"github.com/ipld/go-ipld-prime"
......@@ -83,6 +84,11 @@ const (
RequestFailedContentNotFound = ResponseStatusCode(34)
)
var (
// ErrExtensionAlreadyRegistered means a user extension can be registered only once
ErrExtensionAlreadyRegistered = errors.New("extension already registered")
)
// ResponseProgress is the fundamental unit of responses making progress in Graphsync.
type ResponseProgress struct {
Node ipld.Node // a node which matched the graphsync query
......@@ -115,6 +121,19 @@ type RequestData interface {
IsCancel() bool
}
// ResponseData describes a received Graphsync response
type ResponseData interface {
// RequestID returns the request ID for this response
RequestID() RequestID
// Status returns the status for a response
Status() ResponseStatusCode
// Extension returns the content for an extension on a response, or errors
// if extension is not present
Extension(name ExtensionName) ([]byte, bool)
}
// RequestReceivedHookActions are actions that a request hook can take to change
// behavior for the response
type RequestReceivedHookActions interface {
......@@ -130,6 +149,11 @@ type RequestReceivedHookActions interface {
// err - error - if not nil, halt request and return RequestRejected with the responseData
type OnRequestReceivedHook func(p peer.ID, request RequestData, hookActions RequestReceivedHookActions)
// OnResponseReceivedHook is a hook that runs each time a response is received.
// It receives the peer that sent the response and all data about the response.
// If it returns an error processing is halted and the original request is cancelled.
type OnResponseReceivedHook func(p peer.ID, responseData ResponseData) error
// GraphExchange is a protocol that can exchange IPLD graphs based on a selector
type GraphExchange interface {
// Request initiates a new GraphSync request to the given peer using the given selector spec.
......@@ -140,4 +164,7 @@ type GraphExchange interface {
// it is considered to have "validated" the request -- and that validation supersedes
// the normal validation of requests Graphsync does (i.e. all selectors can be accepted)
RegisterRequestReceivedHook(hook OnRequestReceivedHook) error
// RegisterResponseReceivedHook adds a hook that runs when a response is received
RegisterResponseReceivedHook(OnResponseReceivedHook) error
}
......@@ -92,7 +92,12 @@ func (gs *GraphSync) Request(ctx context.Context, p peer.ID, root ipld.Link, sel
// the normal validation of requests Graphsync does (i.e. all selectors can be accepted)
func (gs *GraphSync) RegisterRequestReceivedHook(hook graphsync.OnRequestReceivedHook) error {
gs.responseManager.RegisterHook(hook)
// may be a need to return errors here in the future...
return nil
}
// RegisterResponseReceivedHook adds a hook that runs when a response is received
func (gs *GraphSync) RegisterResponseReceivedHook(hook graphsync.OnResponseReceivedHook) error {
gs.requestManager.RegisterHook(hook)
return nil
}
......
This diff is collapsed.
......@@ -32,6 +32,10 @@ type inProgressRequestStatus struct {
networkError chan error
}
type responseHook struct {
hook graphsync.OnResponseReceivedHook
}
// PeerHandler is an interface that can send requests to peers
type PeerHandler interface {
SendRequest(p peer.ID, graphSyncRequest gsmsg.GraphSyncRequest)
......@@ -61,6 +65,7 @@ type RequestManager struct {
// dont touch out side of run loop
nextRequestID graphsync.RequestID
inProgressRequestStatuses map[graphsync.RequestID]*inProgressRequestStatus
responseHooks []responseHook
}
type requestManagerMessage interface {
......@@ -197,6 +202,15 @@ func (rm *RequestManager) ProcessResponses(p peer.ID, responses []gsmsg.GraphSyn
}
}
// RegisterHook registers an extension to processincoming responses
func (rm *RequestManager) RegisterHook(
hook graphsync.OnResponseReceivedHook) {
select {
case rm.messages <- &responseHook{hook}:
case <-rm.ctx.Done():
}
}
// Startup starts processing for the WantManager.
func (rm *RequestManager) Startup() {
go rm.run()
......@@ -266,11 +280,16 @@ func (crm *cancelRequestMessage) handle(rm *RequestManager) {
func (prm *processResponseMessage) handle(rm *RequestManager) {
filteredResponses := rm.filterResponsesForPeer(prm.responses, prm.p)
filteredResponses = rm.processExtensions(filteredResponses, prm.p)
responseMetadata := metadataForResponses(filteredResponses, rm.ipldBridge)
rm.asyncLoader.ProcessResponse(responseMetadata, prm.blks)
rm.processTerminations(filteredResponses)
}
func (rh *responseHook) handle(rm *RequestManager) {
rm.responseHooks = append(rm.responseHooks, *rh)
}
func (rm *RequestManager) filterResponsesForPeer(responses []gsmsg.GraphSyncResponse, p peer.ID) []gsmsg.GraphSyncResponse {
responsesForPeer := make([]gsmsg.GraphSyncResponse, 0, len(responses))
for _, response := range responses {
......@@ -283,6 +302,34 @@ func (rm *RequestManager) filterResponsesForPeer(responses []gsmsg.GraphSyncResp
return responsesForPeer
}
func (rm *RequestManager) processExtensions(responses []gsmsg.GraphSyncResponse, p peer.ID) []gsmsg.GraphSyncResponse {
remainingResponses := make([]gsmsg.GraphSyncResponse, 0, len(responses))
for _, response := range responses {
success := rm.processExtensionsForResponse(p, response)
if success {
remainingResponses = append(remainingResponses, response)
}
}
return remainingResponses
}
func (rm *RequestManager) processExtensionsForResponse(p peer.ID, response gsmsg.GraphSyncResponse) bool {
for _, responseHook := range rm.responseHooks {
err := responseHook.hook(p, response)
if err != nil {
requestStatus := rm.inProgressRequestStatuses[response.RequestID()]
responseError := rm.generateResponseErrorFromStatus(graphsync.RequestFailedUnknown)
select {
case requestStatus.networkError <- responseError:
case <-requestStatus.ctx.Done():
}
requestStatus.cancelFn()
return false
}
}
return true
}
func (rm *RequestManager) processTerminations(responses []gsmsg.GraphSyncResponse) {
for _, response := range responses {
if gsmsg.IsTerminalResponseCode(response.Status()) {
......
......@@ -2,6 +2,7 @@ package requestmanager
import (
"context"
"errors"
"fmt"
"reflect"
"sync"
......@@ -631,7 +632,19 @@ func TestEncodingExtensions(t *testing.T) {
Name: extensionName2,
Data: extensionData2,
}
_, _ = requestManager.SendRequest(requestCtx, peers[0], root, selector, extension1, extension2)
expectedError := make(chan error, 2)
receivedExtensionData := make(chan []byte, 2)
hook := func(p peer.ID, responseData graphsync.ResponseData) error {
data, has := responseData.Extension(extensionName1)
if !has {
t.Fatal("Did not receive extension data in response")
}
receivedExtensionData <- data
return <-expectedError
}
requestManager.RegisterHook(hook)
returnedResponseChan, returnedErrorChan := requestManager.SendRequest(requestCtx, peers[0], root, selector, extension1, extension2)
rr := readNNetworkRequests(requestCtx, t, requestRecordChan, 1)[0]
......@@ -646,4 +659,55 @@ func TestEncodingExtensions(t *testing.T) {
t.Fatal("Failed to encode first extension")
}
t.Run("responding to extensions", func(t *testing.T) {
expectedData := testutil.RandomBytes(100)
firstResponses := []gsmsg.GraphSyncResponse{
gsmsg.NewResponse(gsr.ID(),
graphsync.PartialResponse, graphsync.ExtensionData{
Name: graphsync.ExtensionMetadata,
Data: nil,
},
graphsync.ExtensionData{
Name: extensionName1,
Data: expectedData,
},
),
}
expectedError <- nil
requestManager.ProcessResponses(peers[0], firstResponses, nil)
select {
case <-requestCtx.Done():
t.Fatal("Should have checked extension but didn't")
case received := <-receivedExtensionData:
if !reflect.DeepEqual(received, expectedData) {
t.Fatal("Did not receive correct extension data from resposne")
}
}
nextExpectedData := testutil.RandomBytes(100)
secondResponses := []gsmsg.GraphSyncResponse{
gsmsg.NewResponse(gsr.ID(),
graphsync.PartialResponse, graphsync.ExtensionData{
Name: graphsync.ExtensionMetadata,
Data: nil,
},
graphsync.ExtensionData{
Name: extensionName1,
Data: nextExpectedData,
},
),
}
expectedError <- errors.New("a terrible thing happened")
requestManager.ProcessResponses(peers[0], secondResponses, nil)
select {
case <-requestCtx.Done():
t.Fatal("Should have checked extension but didn't")
case received := <-receivedExtensionData:
if !reflect.DeepEqual(received, nextExpectedData) {
t.Fatal("Did not receive correct extension data from resposne")
}
}
testutil.VerifySingleTerminalError(requestCtx, t, returnedErrorChan)
testutil.VerifyEmptyResponse(requestCtx, t, returnedResponseChan)
})
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment