Commit 6d98af58 authored by hannahhoward's avatar hannahhoward

feat(responsemanager): process sent extensions

When the request manager receives extensions in a request, it processes them and returns their
response
parent d4193eda
......@@ -395,7 +395,6 @@ package "go-graphsync" {
package ipldbridge {
interface IPLDBridge {
BuildNode(func(NodeBuilder) ipld.Node) (ipld.Node, error)
EncodeNode(ipld.Node) ([]byte, error)
DecodeNode([]byte) (ipld.Node, error)
ParseSelector(selector ipld.Node) (Selector, error)
......
......@@ -91,6 +91,8 @@ func (gs *GraphSync) Request(ctx context.Context, p peer.ID, root ipld.Link, sel
// it is considered to have "validated" the request -- and that validation supersedes
// the normal validation of requests Graphsync does (i.e. all selectors can be accepted)
func (gs *GraphSync) RegisterRequestReceivedHook(overrideDefaultValidation bool, hook graphsync.OnRequestReceivedHook) error {
gs.responseManager.RegisterHook(overrideDefaultValidation, hook)
// may be a need to return errors here in the future...
return nil
}
......
......@@ -8,6 +8,7 @@ import (
"testing"
"time"
"github.com/ipld/go-ipld-prime/fluent"
ipldfree "github.com/ipld/go-ipld-prime/impl/free"
cidlink "github.com/ipld/go-ipld-prime/linking/cid"
......@@ -89,8 +90,10 @@ func setupBlockChain(
size int64,
blockChainLength int) *blockChain {
linkBuilder := cidlink.LinkBuilder{Prefix: cid.NewPrefixV1(cid.DagCBOR, mh.SHA2_256)}
genisisNode, err := bridge.BuildNode(func(nb ipldbridge.NodeBuilder) ipld.Node {
return createBlock(nb, []ipld.Link{}, size)
var genisisNode ipld.Node
err := fluent.Recover(func() {
nb := fluent.WrapNodeBuilder(ipldfree.NodeBuilder())
genisisNode = createBlock(nb, []ipld.Link{}, size)
})
if err != nil {
t.Fatal("Error creating genesis block")
......@@ -103,8 +106,10 @@ func setupBlockChain(
middleNodes := make([]ipld.Node, 0, blockChainLength-2)
middleLinks := make([]ipld.Link, 0, blockChainLength-2)
for i := 0; i < blockChainLength-2; i++ {
node, err := bridge.BuildNode(func(nb ipldbridge.NodeBuilder) ipld.Node {
return createBlock(nb, []ipld.Link{parent}, size)
var node ipld.Node
err := fluent.Recover(func() {
nb := fluent.WrapNodeBuilder(ipldfree.NodeBuilder())
node = createBlock(nb, []ipld.Link{parent}, size)
})
if err != nil {
t.Fatal("Error creating middle block")
......@@ -117,8 +122,10 @@ func setupBlockChain(
middleLinks = append(middleLinks, link)
parent = link
}
tipNode, err := bridge.BuildNode(func(nb ipldbridge.NodeBuilder) ipld.Node {
return createBlock(nb, []ipld.Link{parent}, size)
var tipNode ipld.Node
err = fluent.Recover(func() {
nb := fluent.WrapNodeBuilder(ipldfree.NodeBuilder())
tipNode = createBlock(nb, []ipld.Link{parent}, size)
})
if err != nil {
t.Fatal("Error creating tip block")
......@@ -256,8 +263,34 @@ func TestSendResponseToIncomingRequest(t *testing.T) {
loader, storer := testbridge.NewMockStore(blockStore)
bridge := ipldbridge.NewIPLDBridge()
extensionData := testutil.RandomBytes(100)
extensionName := graphsync.ExtensionName("AppleSauce/McGee")
extension := graphsync.ExtensionData{
Name: extensionName,
Data: extensionData,
}
extensionResponseData := testutil.RandomBytes(100)
extensionResponse := graphsync.ExtensionData{
Name: extensionName,
Data: extensionResponseData,
}
var receivedRequestData []byte
// initialize graphsync on second node to response to requests
New(ctx, gsnet2, bridge, loader, storer)
gsnet := New(ctx, gsnet2, bridge, loader, storer)
err = gsnet.RegisterRequestReceivedHook(false,
func(p peer.ID, requestData graphsync.RequestData) ([]graphsync.ExtensionData, error) {
var has bool
receivedRequestData, has = requestData.Extension(extensionName)
if !has {
t.Fatal("did not have expected extension")
}
return []graphsync.ExtensionData{extensionResponse}, nil
},
)
if err != nil {
t.Fatal("error registering extension")
}
blockChainLength := 100
blockChain := setupBlockChain(ctx, t, storer, bridge, 100, blockChainLength)
......@@ -275,12 +308,13 @@ func TestSendResponseToIncomingRequest(t *testing.T) {
requestID := graphsync.RequestID(rand.Int31())
message := gsmsg.New()
message.AddRequest(gsmsg.NewRequest(requestID, blockChain.tipLink.(cidlink.Link).Cid, selectorData, graphsync.Priority(math.MaxInt32)))
message.AddRequest(gsmsg.NewRequest(requestID, blockChain.tipLink.(cidlink.Link).Cid, selectorData, graphsync.Priority(math.MaxInt32), extension))
// send request across network
gsnet1.SendMessage(ctx, host2.ID(), message)
// read the values sent back to requestor
var received gsmsg.GraphSyncMessage
var receivedBlocks []blocks.Block
var receivedExtensions [][]byte
readAllMessages:
for {
select {
......@@ -295,6 +329,10 @@ readAllMessages:
received = message.message
receivedBlocks = append(receivedBlocks, received.Blocks()...)
receivedResponses := received.Responses()
receivedExtension, found := receivedResponses[0].Extension(extensionName)
if found {
receivedExtensions = append(receivedExtensions, receivedExtension)
}
if len(receivedResponses) != 1 {
t.Fatal("Did not receive response")
}
......@@ -310,6 +348,18 @@ readAllMessages:
if len(receivedBlocks) != blockChainLength {
t.Fatal("Send incorrect number of blocks or there were duplicate blocks")
}
if !reflect.DeepEqual(extensionData, receivedRequestData) {
t.Fatal("did not receive correct request extension data")
}
if len(receivedExtensions) != 1 {
t.Fatal("should have sent extension responses but didn't")
}
if !reflect.DeepEqual(receivedExtensions[0], extensionResponseData) {
t.Fatal("did not return correct extension data")
}
}
func TestGraphsyncRoundTrip(t *testing.T) {
......
......@@ -4,8 +4,6 @@ import (
"bytes"
"context"
"github.com/ipld/go-ipld-prime/fluent"
ipld "github.com/ipld/go-ipld-prime"
"github.com/ipld/go-ipld-prime/encoding/dagcbor"
free "github.com/ipld/go-ipld-prime/impl/free"
......@@ -24,30 +22,6 @@ func NewIPLDBridge() IPLDBridge {
return &ipldBridge{}
}
func (rb *ipldBridge) ExtractData(node ipld.Node, buildFn func(SimpleNode) interface{}) (interface{}, error) {
var value interface{}
err := fluent.Recover(func() {
simpleNode := fluent.WrapNode(node)
value = buildFn(simpleNode)
})
if err != nil {
return nil, err
}
return value, nil
}
func (rb *ipldBridge) BuildNode(buildFn func(NodeBuilder) ipld.Node) (ipld.Node, error) {
var node ipld.Node
err := fluent.Recover(func() {
nb := fluent.WrapNodeBuilder(free.NodeBuilder())
node = buildFn(nb)
})
if err != nil {
return nil, err
}
return node, nil
}
func (rb *ipldBridge) Traverse(ctx context.Context, loader Loader, root ipld.Link, s Selector, fn AdvVisitFn) error {
node, err := root.Load(ctx, LinkContext{}, free.NodeBuilder(), loader)
if err != nil {
......
......@@ -72,14 +72,6 @@ type SimpleNode = fluent.Node
// replaced with alternative implementations
type IPLDBridge interface {
// ExtractData provides an efficient mechanism for reading nodes w/ fluent
// interface
ExtractData(ipld.Node, func(SimpleNode) interface{}) (interface{}, error)
// BuildNode provides an efficient mechanism for assembling nodes w/ fluent
// interface
BuildNode(func(NodeBuilder) ipld.Node) (ipld.Node, error)
// EncodeNode encodes an IPLD Node to bytes for network transfer.
EncodeNode(ipld.Node) ([]byte, error)
......
......@@ -3,6 +3,8 @@ package metadata
import (
"github.com/ipfs/go-graphsync/ipldbridge"
"github.com/ipld/go-ipld-prime"
"github.com/ipld/go-ipld-prime/fluent"
ipldfree "github.com/ipld/go-ipld-prime/impl/free"
)
// Item is a single link traversed in a repsonse
......@@ -22,7 +24,9 @@ func DecodeMetadata(data []byte, ipldBridge ipldbridge.IPLDBridge) (Metadata, er
if err != nil {
return nil, err
}
decodedData, err := ipldBridge.ExtractData(node, func(simpleNode ipldbridge.SimpleNode) interface{} {
var decodedData interface{}
err = fluent.Recover(func() {
simpleNode := fluent.WrapNode(node)
iterator := simpleNode.ListIterator()
var metadata Metadata
if simpleNode.Length() != -1 {
......@@ -35,7 +39,7 @@ func DecodeMetadata(data []byte, ipldBridge ipldbridge.IPLDBridge) (Metadata, er
blockPresent := item.LookupString("blockPresent").AsBool()
metadata = append(metadata, Item{link, blockPresent})
}
return metadata
decodedData = metadata
})
if err != nil {
return nil, err
......@@ -45,8 +49,10 @@ func DecodeMetadata(data []byte, ipldBridge ipldbridge.IPLDBridge) (Metadata, er
// EncodeMetadata encodes metadata to an IPLD node then serializes to raw bytes
func EncodeMetadata(entries Metadata, ipldBridge ipldbridge.IPLDBridge) ([]byte, error) {
node, err := ipldBridge.BuildNode(func(nb ipldbridge.NodeBuilder) ipld.Node {
return nb.CreateList(func(lb ipldbridge.ListBuilder, nb ipldbridge.NodeBuilder) {
var node ipld.Node
err := fluent.Recover(func() {
nb := fluent.WrapNodeBuilder(ipldfree.NodeBuilder())
node = nb.CreateList(func(lb ipldbridge.ListBuilder, nb ipldbridge.NodeBuilder) {
for _, item := range entries {
lb.Append(
nb.CreateMap(func(mb ipldbridge.MapBuilder, knb ipldbridge.NodeBuilder, vnb ipldbridge.NodeBuilder) {
......
......@@ -57,6 +57,7 @@ type PeerResponseSender interface {
link ipld.Link,
data []byte,
)
SendExtensionData(graphsync.RequestID, graphsync.ExtensionData)
FinishRequest(requestID graphsync.RequestID)
FinishWithError(requestID graphsync.RequestID, status graphsync.ResponseStatusCode)
}
......@@ -86,6 +87,14 @@ func (prm *peerResponseSender) Shutdown() {
prm.cancel()
}
func (prm *peerResponseSender) SendExtensionData(requestID graphsync.RequestID, extension graphsync.ExtensionData) {
if prm.buildResponse(0, func(responseBuilder *responsebuilder.ResponseBuilder) {
responseBuilder.AddExtensionData(requestID, extension)
}) {
prm.signalWork()
}
}
// SendResponse sends a given link for a given
// requestID across the wire, as well as its corresponding
// block if the block is present and has not already been sent
......
......@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"math/rand"
"reflect"
"testing"
"time"
......@@ -294,6 +295,84 @@ func TestPeerResponseManagerSendsVeryLargeBlocksResponses(t *testing.T) {
}
func TestPeerResponseManagerSendsExtensionData(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, 20*time.Millisecond)
defer cancel()
p := testutil.GeneratePeers(1)[0]
requestID1 := graphsync.RequestID(rand.Int31())
blks := testutil.GenerateBlocksOfSize(5, 100)
links := make([]ipld.Link, 0, len(blks))
for _, block := range blks {
links = append(links, cidlink.Link{Cid: block.Cid()})
}
done := make(chan struct{}, 1)
sent := make(chan struct{}, 1)
fph := &fakePeerHandler{
done: done,
sent: sent,
}
ipldBridge := testbridge.NewMockIPLDBridge()
peerResponseManager := NewResponseSender(ctx, p, fph, ipldBridge)
peerResponseManager.Startup()
peerResponseManager.SendResponse(requestID1, links[0], blks[0].RawData())
select {
case <-ctx.Done():
t.Fatal("Did not send first message")
case <-sent:
}
if len(fph.lastBlocks) != 1 || fph.lastBlocks[0].Cid() != blks[0].Cid() {
t.Fatal("Did not send correct blocks for first message")
}
if len(fph.lastResponses) != 1 || fph.lastResponses[0].RequestID() != requestID1 ||
fph.lastResponses[0].Status() != graphsync.PartialResponse {
t.Fatal("Did not send correct responses for first message")
}
extensionData1 := testutil.RandomBytes(100)
extensionName1 := graphsync.ExtensionName("AppleSauce/McGee")
extension1 := graphsync.ExtensionData{
Name: extensionName1,
Data: extensionData1,
}
extensionData2 := testutil.RandomBytes(100)
extensionName2 := graphsync.ExtensionName("HappyLand/Happenstance")
extension2 := graphsync.ExtensionData{
Name: extensionName2,
Data: extensionData2,
}
peerResponseManager.SendResponse(requestID1, links[1], blks[1].RawData())
peerResponseManager.SendExtensionData(requestID1, extension1)
peerResponseManager.SendExtensionData(requestID1, extension2)
// let peer reponse manager know last message was sent so message sending can continue
done <- struct{}{}
select {
case <-ctx.Done():
t.Fatal("Should have sent second message but didn't")
case <-sent:
}
if len(fph.lastResponses) != 1 {
t.Fatal("Did not send correct number of responses for second message")
}
lastResponse := fph.lastResponses[0]
returnedData1, found := lastResponse.Extension(extensionName1)
if !found || !reflect.DeepEqual(extensionData1, returnedData1) {
t.Fatal("Failed to encode first extension")
}
returnedData2, found := lastResponse.Extension(extensionName2)
if !found || !reflect.DeepEqual(extensionData2, returnedData2) {
t.Fatal("Failed to encode first extension")
}
}
func findResponseForRequestID(responses []gsmsg.GraphSyncResponse, requestID graphsync.RequestID) (gsmsg.GraphSyncResponse, error) {
for _, response := range responses {
if response.RequestID() == requestID {
......
......@@ -17,6 +17,7 @@ type ResponseBuilder struct {
blkSize int
completedResponses map[graphsync.RequestID]graphsync.ResponseStatusCode
outgoingResponses map[graphsync.RequestID]metadata.Metadata
extensions map[graphsync.RequestID][]graphsync.ExtensionData
}
// New generates a new ResponseBuilder.
......@@ -24,6 +25,7 @@ func New() *ResponseBuilder {
return &ResponseBuilder{
completedResponses: make(map[graphsync.RequestID]graphsync.ResponseStatusCode),
outgoingResponses: make(map[graphsync.RequestID]metadata.Metadata),
extensions: make(map[graphsync.RequestID][]graphsync.ExtensionData),
}
}
......@@ -33,6 +35,11 @@ func (rb *ResponseBuilder) AddBlock(block blocks.Block) {
rb.outgoingBlocks = append(rb.outgoingBlocks, block)
}
// AddExtensionData adds the given extension data to to the response
func (rb *ResponseBuilder) AddExtensionData(requestID graphsync.RequestID, extension graphsync.ExtensionData) {
rb.extensions[requestID] = append(rb.extensions[requestID], extension)
}
// BlockSize returns the total size of all blocks in this response
func (rb *ResponseBuilder) BlockSize() int {
return rb.blkSize
......@@ -69,12 +76,12 @@ func (rb *ResponseBuilder) Build(ipldBridge ipldbridge.IPLDBridge) ([]gsmsg.Grap
if err != nil {
return nil, nil, err
}
md := graphsync.ExtensionData{
rb.extensions[requestID] = append(rb.extensions[requestID], graphsync.ExtensionData{
Name: graphsync.ExtensionMetadata,
Data: mdRaw,
}
})
status, isComplete := rb.completedResponses[requestID]
responses = append(responses, gsmsg.NewResponse(requestID, responseCode(status, isComplete), md))
responses = append(responses, gsmsg.NewResponse(requestID, responseCode(status, isComplete), rb.extensions[requestID]...))
}
return responses, rb.outgoingBlocks, nil
}
......
......@@ -52,6 +52,22 @@ func TestMessageBuilding(t *testing.T) {
if rb.BlockSize() != 300 {
t.Fatal("did not calculate block size correctly")
}
extensionData1 := testutil.RandomBytes(100)
extensionName1 := graphsync.ExtensionName("AppleSauce/McGee")
extension1 := graphsync.ExtensionData{
Name: extensionName1,
Data: extensionData1,
}
extensionData2 := testutil.RandomBytes(100)
extensionName2 := graphsync.ExtensionName("HappyLand/Happenstance")
extension2 := graphsync.ExtensionData{
Name: extensionName2,
Data: extensionData2,
}
rb.AddExtensionData(requestID1, extension1)
rb.AddExtensionData(requestID3, extension2)
responses, sentBlocks, err := rb.Build(ipldBridge)
if err != nil {
......@@ -80,6 +96,11 @@ func TestMessageBuilding(t *testing.T) {
t.Fatal("Metadata did not match expected")
}
response1ReturnedExtensionData, found := response1.Extension(extensionName1)
if !found || !reflect.DeepEqual(extensionData1, response1ReturnedExtensionData) {
t.Fatal("Failed to encode first extension")
}
response2, err := findResponseForRequestID(responses, requestID2)
if err != nil || response2.Status() != graphsync.RequestCompletedFull {
t.Fatal("did not generate completed partial response")
......@@ -113,6 +134,11 @@ func TestMessageBuilding(t *testing.T) {
t.Fatal("Metadata did not match expected")
}
response3ReturnedExtensionData, found := response3.Extension(extensionName2)
if !found || !reflect.DeepEqual(extensionData2, response3ReturnedExtensionData) {
t.Fatal("Failed to encode second extension")
}
response4, err := findResponseForRequestID(responses, requestID4)
if err != nil || response4.Status() != graphsync.RequestCompletedFull {
t.Fatal("did not generate completed partial response")
......
......@@ -38,6 +38,11 @@ type responseTaskData struct {
request gsmsg.GraphSyncRequest
}
type requestHook struct {
overrideDefaultValidation bool
hook graphsync.OnRequestReceivedHook
}
// QueryQueue is an interface that can receive new selector query tasks
// and prioritize them as needed, and pop them off later
type QueryQueue interface {
......@@ -70,6 +75,7 @@ type ResponseManager struct {
workSignal chan struct{}
ticker *time.Ticker
inProgressResponses map[responseKey]inProgressResponseStatus
requestHooks []requestHook
}
// New creates a new response manager from the given context, loader,
......@@ -108,6 +114,16 @@ func (rm *ResponseManager) ProcessRequests(ctx context.Context, p peer.ID, reque
}
}
// RegisterExtension registers an extension to process new incoming requests
func (rm *ResponseManager) RegisterHook(
overrideDefaultValidation bool,
hook graphsync.OnRequestReceivedHook) {
select {
case rm.messages <- &requestHook{overrideDefaultValidation, hook}:
case <-rm.ctx.Done():
}
}
type synchronizeMessage struct {
sync chan struct{}
}
......@@ -187,13 +203,30 @@ func (rm *ResponseManager) executeQuery(ctx context.Context,
peerResponseSender.FinishWithError(request.ID(), graphsync.RequestFailedUnknown)
return
}
err = selectorvalidator.ValidateSelector(rm.ipldBridge, selectorSpec, maxRecursionDepth)
if err != nil {
peerResponseSender.FinishWithError(request.ID(), graphsync.RequestFailedUnknown)
return
var isValidated bool
for _, requestHook := range rm.requestHooks {
extensionData, err := requestHook.hook(p, request)
for _, datum := range extensionData {
peerResponseSender.SendExtensionData(request.ID(), datum)
}
if err != nil {
peerResponseSender.FinishWithError(request.ID(), graphsync.RequestFailedUnknown)
return
}
if requestHook.overrideDefaultValidation {
isValidated = true
}
}
if !isValidated {
err = selectorvalidator.ValidateSelector(rm.ipldBridge, selectorSpec, maxRecursionDepth)
if err != nil {
peerResponseSender.FinishWithError(request.ID(), graphsync.RequestFailedUnknown)
return
}
}
rootLink := cidlink.Link{Cid: request.Root()}
selector, err := rm.ipldBridge.ParseSelector(selectorSpec)
if err != nil {
peerResponseSender.FinishWithError(request.ID(), graphsync.RequestFailedUnknown)
return
......@@ -265,6 +298,10 @@ func (prm *processRequestMessage) handle(rm *ResponseManager) {
}
}
func (rh *requestHook) handle(rm *ResponseManager) {
rm.requestHooks = append(rm.requestHooks, *rh)
}
func (rdr *responseDataRequest) handle(rm *ResponseManager) {
response, ok := rm.inProgressResponses[rdr.key]
var taskData *responseTaskData
......
......@@ -2,6 +2,7 @@ package responsemanager
import (
"context"
"fmt"
"math"
"math/rand"
"reflect"
......@@ -87,9 +88,19 @@ type sentResponse struct {
data []byte
}
type sentExtension struct {
requestID graphsync.RequestID
extension graphsync.ExtensionData
}
type completedRequest struct {
requestID graphsync.RequestID
result graphsync.ResponseStatusCode
}
type fakePeerResponseSender struct {
sentResponses chan sentResponse
lastCompletedRequest chan graphsync.RequestID
sentExtensions chan sentExtension
lastCompletedRequest chan completedRequest
}
func (fprs *fakePeerResponseSender) Startup() {}
......@@ -103,12 +114,19 @@ func (fprs *fakePeerResponseSender) SendResponse(
fprs.sentResponses <- sentResponse{requestID, link, data}
}
func (fprs *fakePeerResponseSender) SendExtensionData(
requestID graphsync.RequestID,
extension graphsync.ExtensionData,
) {
fprs.sentExtensions <- sentExtension{requestID, extension}
}
func (fprs *fakePeerResponseSender) FinishRequest(requestID graphsync.RequestID) {
fprs.lastCompletedRequest <- requestID
fprs.lastCompletedRequest <- completedRequest{requestID, graphsync.RequestCompletedFull}
}
func (fprs *fakePeerResponseSender) FinishWithError(requestID graphsync.RequestID, status graphsync.ResponseStatusCode) {
fprs.lastCompletedRequest <- requestID
fprs.lastCompletedRequest <- completedRequest{requestID, status}
}
func TestIncomingQuery(t *testing.T) {
......@@ -118,9 +136,10 @@ func TestIncomingQuery(t *testing.T) {
blks := testutil.GenerateBlocksOfSize(5, 20)
loader := testbridge.NewMockLoader(blks)
ipldBridge := testbridge.NewMockIPLDBridge()
requestIDChan := make(chan graphsync.RequestID, 1)
requestIDChan := make(chan completedRequest, 1)
sentResponses := make(chan sentResponse, len(blks))
fprs := &fakePeerResponseSender{lastCompletedRequest: requestIDChan, sentResponses: sentResponses}
sentExtensions := make(chan sentExtension, 1)
fprs := &fakePeerResponseSender{lastCompletedRequest: requestIDChan, sentResponses: sentResponses, sentExtensions: sentExtensions}
peerManager := &fakePeerManager{peerResponseSender: fprs}
queryQueue := &fakeQueryQueue{}
responseManager := New(ctx, loader, ipldBridge, peerManager, queryQueue)
......@@ -173,9 +192,10 @@ func TestCancellationQueryInProgress(t *testing.T) {
blks := testutil.GenerateBlocksOfSize(5, 20)
loader := testbridge.NewMockLoader(blks)
ipldBridge := testbridge.NewMockIPLDBridge()
requestIDChan := make(chan graphsync.RequestID)
requestIDChan := make(chan completedRequest)
sentResponses := make(chan sentResponse)
fprs := &fakePeerResponseSender{lastCompletedRequest: requestIDChan, sentResponses: sentResponses}
sentExtensions := make(chan sentExtension, 1)
fprs := &fakePeerResponseSender{lastCompletedRequest: requestIDChan, sentResponses: sentResponses, sentExtensions: sentExtensions}
peerManager := &fakePeerManager{peerResponseSender: fprs}
queryQueue := &fakeQueryQueue{}
responseManager := New(ctx, loader, ipldBridge, peerManager, queryQueue)
......@@ -260,9 +280,10 @@ func TestEarlyCancellation(t *testing.T) {
blks := testutil.GenerateBlocksOfSize(5, 20)
loader := testbridge.NewMockLoader(blks)
ipldBridge := testbridge.NewMockIPLDBridge()
requestIDChan := make(chan graphsync.RequestID)
requestIDChan := make(chan completedRequest)
sentResponses := make(chan sentResponse)
fprs := &fakePeerResponseSender{lastCompletedRequest: requestIDChan, sentResponses: sentResponses}
sentExtensions := make(chan sentExtension, 1)
fprs := &fakePeerResponseSender{lastCompletedRequest: requestIDChan, sentResponses: sentResponses, sentExtensions: sentExtensions}
peerManager := &fakePeerManager{peerResponseSender: fprs}
queryQueue := &fakeQueryQueue{}
queryQueue.popWait.Add(1)
......@@ -305,3 +326,164 @@ func TestEarlyCancellation(t *testing.T) {
t.Fatal("should not send have completed response")
}
}
func TestValidationAndExtensions(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, 40*time.Millisecond)
defer cancel()
blks := testutil.GenerateBlocksOfSize(5, 20)
loader := testbridge.NewMockLoader(blks)
ipldBridge := testbridge.NewMockIPLDBridge()
completedRequestChan := make(chan completedRequest, 1)
sentResponses := make(chan sentResponse, 100)
sentExtensions := make(chan sentExtension, 1)
fprs := &fakePeerResponseSender{lastCompletedRequest: completedRequestChan, sentResponses: sentResponses, sentExtensions: sentExtensions}
peerManager := &fakePeerManager{peerResponseSender: fprs}
queryQueue := &fakeQueryQueue{}
cids := make([]cid.Cid, 0, 5)
for _, block := range blks {
cids = append(cids, block.Cid())
}
extensionData := testutil.RandomBytes(100)
extensionName := graphsync.ExtensionName("AppleSauce/McGee")
extension := graphsync.ExtensionData{
Name: extensionName,
Data: extensionData,
}
extensionResponseData := testutil.RandomBytes(100)
extensionResponse := graphsync.ExtensionData{
Name: extensionName,
Data: extensionResponseData,
}
t.Run("with invalid selector", func(t *testing.T) {
selectorSpec := testbridge.NewInvalidSelectorSpec(cids)
selector, err := ipldBridge.EncodeNode(selectorSpec)
if err != nil {
t.Fatal("error encoding selector")
}
requestID := graphsync.RequestID(rand.Int31())
requests := []gsmsg.GraphSyncRequest{
gsmsg.NewRequest(requestID, cids[0], selector, graphsync.Priority(math.MaxInt32), extension),
}
p := testutil.GeneratePeers(1)[0]
t.Run("on its own, should fail validation", func(t *testing.T) {
responseManager := New(ctx, loader, ipldBridge, peerManager, queryQueue)
responseManager.Startup()
responseManager.ProcessRequests(ctx, p, requests)
select {
case <-ctx.Done():
t.Fatal("Should have completed request but didn't")
case lastRequest := <-completedRequestChan:
if !gsmsg.IsTerminalFailureCode(lastRequest.result) {
t.Fatal("Request should have failed but didn't")
}
}
})
t.Run("if non validating hook succeeds, does not pass validation", func(t *testing.T) {
responseManager := New(ctx, loader, ipldBridge, peerManager, queryQueue)
responseManager.Startup()
responseManager.RegisterHook(false, func(p peer.ID, requestData graphsync.RequestData) ([]graphsync.ExtensionData, error) {
return []graphsync.ExtensionData{extensionResponse}, nil
})
responseManager.ProcessRequests(ctx, p, requests)
select {
case <-ctx.Done():
t.Fatal("Should have completed request but didn't")
case lastRequest := <-completedRequestChan:
if !gsmsg.IsTerminalFailureCode(lastRequest.result) {
t.Fatal("Request should have succeeded but didn't")
}
}
select {
case <-ctx.Done():
t.Fatal("Should have sent extension response but didn't")
case receivedExtension := <-sentExtensions:
if !reflect.DeepEqual(receivedExtension.extension, extensionResponse) {
t.Fatal("Proper Extension response should have been sent but wasn't")
}
}
})
t.Run("if validating hook succeeds, should pass validation", func(t *testing.T) {
responseManager := New(ctx, loader, ipldBridge, peerManager, queryQueue)
responseManager.Startup()
responseManager.RegisterHook(true, func(p peer.ID, requestData graphsync.RequestData) ([]graphsync.ExtensionData, error) {
return []graphsync.ExtensionData{extensionResponse}, nil
})
responseManager.ProcessRequests(ctx, p, requests)
select {
case <-ctx.Done():
t.Fatal("Should have completed request but didn't")
case lastRequest := <-completedRequestChan:
if !gsmsg.IsTerminalSuccessCode(lastRequest.result) {
t.Fatal("Request should have succeeded but didn't")
}
}
select {
case <-ctx.Done():
t.Fatal("Should have sent extension response but didn't")
case receivedExtension := <-sentExtensions:
if !reflect.DeepEqual(receivedExtension.extension, extensionResponse) {
t.Fatal("Proper Extension response should have been sent but wasn't")
}
}
})
})
t.Run("with valid selector", func(t *testing.T) {
selectorSpec := testbridge.NewMockSelectorSpec(cids)
selector, err := ipldBridge.EncodeNode(selectorSpec)
if err != nil {
t.Fatal("error encoding selector")
}
requestID := graphsync.RequestID(rand.Int31())
requests := []gsmsg.GraphSyncRequest{
gsmsg.NewRequest(requestID, cids[0], selector, graphsync.Priority(math.MaxInt32), extension),
}
p := testutil.GeneratePeers(1)[0]
t.Run("on its own, should pass validation", func(t *testing.T) {
responseManager := New(ctx, loader, ipldBridge, peerManager, queryQueue)
responseManager.Startup()
responseManager.ProcessRequests(ctx, p, requests)
select {
case <-ctx.Done():
t.Fatal("Should have completed request but didn't")
case lastRequest := <-completedRequestChan:
if !gsmsg.IsTerminalSuccessCode(lastRequest.result) {
t.Fatal("Request should have failed but didn't")
}
}
})
t.Run("if any hook fails, should fail", func(t *testing.T) {
responseManager := New(ctx, loader, ipldBridge, peerManager, queryQueue)
responseManager.Startup()
responseManager.RegisterHook(false, func(p peer.ID, requestData graphsync.RequestData) ([]graphsync.ExtensionData, error) {
return []graphsync.ExtensionData{extensionResponse}, fmt.Errorf("everything went to crap")
})
responseManager.ProcessRequests(ctx, p, requests)
select {
case <-ctx.Done():
t.Fatal("Should have completed request but didn't")
case lastRequest := <-completedRequestChan:
if !gsmsg.IsTerminalFailureCode(lastRequest.result) {
t.Fatal("Request should have succeeded but didn't")
}
}
select {
case <-ctx.Done():
t.Fatal("Should have sent extension response but didn't")
case receivedExtension := <-sentExtensions:
if !reflect.DeepEqual(receivedExtension.extension, extensionResponse) {
t.Fatal("Proper Extension response should have been sent but wasn't")
}
}
})
})
}
......@@ -8,29 +8,35 @@ import (
)
type mockSelectorSpec struct {
cidsVisited []cid.Cid
failValidation bool
failEncode bool
CidsVisited []cid.Cid
FalseParse bool
FailEncode bool
FailValidation bool
}
// NewMockSelectorSpec returns a new mock selector that will visit the given
// cids.
func NewMockSelectorSpec(cidsVisited []cid.Cid) ipld.Node {
return &mockSelectorSpec{cidsVisited, false, false}
return &mockSelectorSpec{cidsVisited, false, false, false}
}
// NewInvalidSelectorSpec returns a spec that will fail when you attempt to
// NewUnparsableSelectorSpec returns a spec that will fail when you attempt to
// validate it or decompose to a node + selector.
func NewUnparsableSelectorSpec(cidsVisited []cid.Cid) ipld.Node {
return &mockSelectorSpec{cidsVisited, true, false, false}
}
// NewInvalidSelectorSpec returns a spec that will fail when you attempt to
// encode it.
func NewInvalidSelectorSpec(cidsVisited []cid.Cid) ipld.Node {
return &mockSelectorSpec{cidsVisited, true, false}
return &mockSelectorSpec{cidsVisited, false, false, true}
}
// NewUnencodableSelectorSpec returns a spec that will fail when you attempt to
// encode it.
func NewUnencodableSelectorSpec(cidsVisited []cid.Cid) ipld.Node {
return &mockSelectorSpec{cidsVisited, false, true}
return &mockSelectorSpec{cidsVisited, false, true, false}
}
func (mss *mockSelectorSpec) ReprKind() ipld.ReprKind { return ipld.ReprKind_Null }
func (mss *mockSelectorSpec) Lookup(key ipld.Node) (ipld.Node, error) {
return nil, fmt.Errorf("404")
......
......@@ -11,7 +11,7 @@ type mockSelector struct {
}
func newMockSelector(mss *mockSelectorSpec) ipldbridge.Selector {
return &mockSelector{mss.cidsVisited}
return &mockSelector{mss.CidsVisited}
}
func (ms *mockSelector) Explore(ipld.Node, ipld.PathSegment) ipldbridge.Selector {
......
......@@ -11,7 +11,6 @@ import (
ipldbridge "github.com/ipfs/go-graphsync/ipldbridge"
ipld "github.com/ipld/go-ipld-prime"
"github.com/ipld/go-ipld-prime/encoding/dagjson"
"github.com/ipld/go-ipld-prime/fluent"
free "github.com/ipld/go-ipld-prime/impl/free"
cidlink "github.com/ipld/go-ipld-prime/linking/cid"
multihash "github.com/multiformats/go-multihash"
......@@ -20,42 +19,16 @@ import (
type mockIPLDBridge struct {
}
// NewMockIPLDBridge returns an IPLD bridge that works with MockSelectors
// NewMockIPLDBridge returns an ipld bridge with mocked behavior
func NewMockIPLDBridge() ipldbridge.IPLDBridge {
return &mockIPLDBridge{}
}
func (mb *mockIPLDBridge) ExtractData(
node ipld.Node,
buildFn func(ipldbridge.SimpleNode) interface{}) (interface{}, error) {
var value interface{}
err := fluent.Recover(func() {
simpleNode := fluent.WrapNode(node)
value = buildFn(simpleNode)
})
if err != nil {
return nil, err
}
return value, nil
}
func (mb *mockIPLDBridge) BuildNode(buildFn func(ipldbridge.NodeBuilder) ipld.Node) (ipld.Node, error) {
var node ipld.Node
err := fluent.Recover(func() {
nb := fluent.WrapNodeBuilder(free.NodeBuilder())
node = buildFn(nb)
})
if err != nil {
return nil, err
}
return node, nil
}
func (mb *mockIPLDBridge) EncodeNode(node ipld.Node) ([]byte, error) {
spec, ok := node.(*mockSelectorSpec)
if ok {
if !spec.failEncode {
data, err := json.Marshal(spec.cidsVisited)
if !spec.FailEncode {
data, err := json.Marshal(*spec)
if err != nil {
return nil, err
}
......@@ -72,10 +45,10 @@ func (mb *mockIPLDBridge) EncodeNode(node ipld.Node) ([]byte, error) {
}
func (mb *mockIPLDBridge) DecodeNode(data []byte) (ipld.Node, error) {
var cidsVisited []cid.Cid
err := json.Unmarshal(data, &cidsVisited)
var spec mockSelectorSpec
err := json.Unmarshal(data, &spec)
if err == nil {
return &mockSelectorSpec{cidsVisited, false, false}, nil
return &spec, nil
}
reader := bytes.NewReader(data)
return dagjson.Decoder(free.NodeBuilder(), reader)
......@@ -83,7 +56,7 @@ func (mb *mockIPLDBridge) DecodeNode(data []byte) (ipld.Node, error) {
func (mb *mockIPLDBridge) ParseSelector(selectorSpec ipld.Node) (ipldbridge.Selector, error) {
spec, ok := selectorSpec.(*mockSelectorSpec)
if !ok || spec.failValidation {
if !ok || spec.FalseParse {
return nil, fmt.Errorf("not a selector")
}
return newMockSelector(spec), nil
......@@ -113,6 +86,10 @@ func (mb *mockIPLDBridge) Traverse(ctx context.Context, loader ipldbridge.Loader
}
func (mb *mockIPLDBridge) WalkMatching(node ipld.Node, s ipldbridge.Selector, fn ipldbridge.VisitFn) error {
spec, ok := node.(*mockSelectorSpec)
if ok && spec.FailValidation {
return fmt.Errorf("not a valid kind of selector")
}
return nil
}
......
......@@ -65,6 +65,7 @@ func TestEncodeParseSelector(t *testing.T) {
spec := NewMockSelectorSpec(cids)
bridge := NewMockIPLDBridge()
data, err := bridge.EncodeNode(spec)
fmt.Println(string(data))
if err != nil {
t.Fatal("error encoding selector spec")
}
......@@ -76,17 +77,17 @@ func TestEncodeParseSelector(t *testing.T) {
if !ok {
t.Fatal("did not decode a selector")
}
if len(returnedSpec.cidsVisited) != 5 {
if len(returnedSpec.CidsVisited) != 5 {
t.Fatal("did not decode enough cids")
}
if !reflect.DeepEqual(cids, returnedSpec.cidsVisited) {
if !reflect.DeepEqual(cids, returnedSpec.CidsVisited) {
t.Fatal("did not decode correct cids")
}
}
func TestFailValidationSelectorSpec(t *testing.T) {
func TestFailParseSelectorSpec(t *testing.T) {
cids := testutil.GenerateCids(5)
spec := NewInvalidSelectorSpec(cids)
spec := NewUnparsableSelectorSpec(cids)
bridge := NewMockIPLDBridge()
_, err := bridge.ParseSelector(spec)
if err == nil {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment