Commit 87ef592d authored by hannahhoward's avatar hannahhoward Committed by Hannah Howard

feat(extensions): send extension in request

add the ability to send extensions when creating requests
parent 61a51d9f
......@@ -2,7 +2,6 @@ package graphsync
import (
"context"
"errors"
"github.com/ipld/go-ipld-prime"
peer "github.com/libp2p/go-libp2p-peer"
......@@ -83,11 +82,6 @@ const (
RequestFailedContentNotFound = ResponseStatusCode(34)
)
var (
// ErrExtensionNotPresent means the looked up extension was not found
ErrExtensionNotPresent = errors.New("Extension is missing from this message")
)
// ResponseProgress is the fundamental unit of responses making progress in Graphsync.
type ResponseProgress struct {
Node ipld.Node // a node which matched the graphsync query
......@@ -100,5 +94,6 @@ type ResponseProgress struct {
// GraphExchange is a protocol that can exchange IPLD graphs based on a selector
type GraphExchange interface {
Request(ctx context.Context, p peer.ID, root ipld.Link, selector ipld.Node) (<-chan ResponseProgress, <-chan error)
// Request initiates a new GraphSync request to the given peer using the given selector spec.
Request(ctx context.Context, p peer.ID, root ipld.Link, selector ipld.Node, extensions ...ExtensionData) (<-chan ResponseProgress, <-chan error)
}
......@@ -82,8 +82,8 @@ func New(parent context.Context, network gsnet.GraphSyncNetwork,
}
// Request initiates a new GraphSync request to the given peer using the given selector spec.
func (gs *GraphSync) Request(ctx context.Context, p peer.ID, root ipld.Link, selector ipld.Node) (<-chan graphsync.ResponseProgress, <-chan error) {
return gs.requestManager.SendRequest(ctx, p, root, selector)
func (gs *GraphSync) Request(ctx context.Context, p peer.ID, root ipld.Link, selector ipld.Node, extensions ...graphsync.ExtensionData) (<-chan graphsync.ResponseProgress, <-chan error) {
return gs.requestManager.SendRequest(ctx, p, root, selector, extensions...)
}
type graphSyncReceiver GraphSync
......
......@@ -175,9 +175,17 @@ func TestMakeRequestToNetwork(t *testing.T) {
if err != nil {
t.Fatal("Failed creating selector")
}
extensionData := testutil.RandomBytes(100)
extensionName := graphsync.ExtensionName("AppleSauce/McGee")
extension := graphsync.ExtensionData{
Name: extensionName,
Data: extensionData,
}
requestCtx, requestCancel := context.WithCancel(ctx)
defer requestCancel()
graphSync.Request(requestCtx, host2.ID(), blockChain.tipLink, spec)
graphSync.Request(requestCtx, host2.ID(), blockChain.tipLink, spec, extension)
var message receivedMessage
select {
......@@ -208,6 +216,11 @@ func TestMakeRequestToNetwork(t *testing.T) {
if err != nil {
t.Fatal("did not receive parsible selector on other side")
}
returnedData, found := receivedRequest.Extension(extensionName)
if !found || !reflect.DeepEqual(extensionData, returnedData) {
t.Fatal("Failed to encode extension")
}
}
func TestSendResponseToIncomingRequest(t *testing.T) {
......
......@@ -96,6 +96,7 @@ type newRequestMessage struct {
p peer.ID
root ipld.Link
selector ipld.Node
extensions []graphsync.ExtensionData
inProgressRequestChan chan<- inProgressRequest
}
......@@ -103,7 +104,8 @@ type newRequestMessage struct {
func (rm *RequestManager) SendRequest(ctx context.Context,
p peer.ID,
root ipld.Link,
selector ipld.Node) (<-chan graphsync.ResponseProgress, <-chan error) {
selector ipld.Node,
extensions ...graphsync.ExtensionData) (<-chan graphsync.ResponseProgress, <-chan error) {
if _, err := rm.ipldBridge.ParseSelector(selector); err != nil {
return rm.singleErrorResponse(fmt.Errorf("Invalid Selector Spec"))
}
......@@ -111,7 +113,7 @@ func (rm *RequestManager) SendRequest(ctx context.Context,
inProgressRequestChan := make(chan inProgressRequest)
select {
case rm.messages <- &newRequestMessage{p, root, selector, inProgressRequestChan}:
case rm.messages <- &newRequestMessage{p, root, selector, extensions, inProgressRequestChan}:
case <-rm.ctx.Done():
return rm.emptyResponse()
case <-ctx.Done():
......@@ -234,7 +236,7 @@ func (nrm *newRequestMessage) handle(rm *RequestManager) {
requestID := rm.nextRequestID
rm.nextRequestID++
inProgressChan, inProgressErr := rm.setupRequest(requestID, nrm.p, nrm.root, nrm.selector)
inProgressChan, inProgressErr := rm.setupRequest(requestID, nrm.p, nrm.root, nrm.selector, nrm.extensions)
select {
case nrm.inProgressRequestChan <- inProgressRequest{
......@@ -314,7 +316,7 @@ func (rm *RequestManager) generateResponseErrorFromStatus(status graphsync.Respo
}
}
func (rm *RequestManager) setupRequest(requestID graphsync.RequestID, p peer.ID, root ipld.Link, selectorSpec ipld.Node) (chan graphsync.ResponseProgress, chan error) {
func (rm *RequestManager) setupRequest(requestID graphsync.RequestID, p peer.ID, root ipld.Link, selectorSpec ipld.Node, extensions []graphsync.ExtensionData) (chan graphsync.ResponseProgress, chan error) {
selectorBytes, err := rm.ipldBridge.EncodeNode(selectorSpec)
if err != nil {
return rm.singleErrorResponse(err)
......@@ -333,7 +335,7 @@ func (rm *RequestManager) setupRequest(requestID graphsync.RequestID, p peer.ID,
ctx, cancel, p, networkErrorChan,
}
rm.asyncLoader.StartRequest(requestID)
rm.peerHandler.SendRequest(p, gsmsg.NewRequest(requestID, asCidLink.Cid, selectorBytes, maxPriority))
rm.peerHandler.SendRequest(p, gsmsg.NewRequest(requestID, asCidLink.Cid, selectorBytes, maxPriority, extensions...))
return rm.executeTraversal(ctx, requestID, root, selector, networkErrorChan)
}
......
......@@ -600,3 +600,50 @@ func TestRequestReturnsMissingBlocks(t *testing.T) {
}
}
func TestEncodingExtensions(t *testing.T) {
requestRecordChan := make(chan requestRecord, 2)
fph := &fakePeerHandler{requestRecordChan}
fakeIPLDBridge := testbridge.NewMockIPLDBridge()
ctx := context.Background()
fal := newFakeAsyncLoader()
requestManager := New(ctx, fal, fakeIPLDBridge)
requestManager.SetDelegate(fph)
requestManager.Startup()
requestCtx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
peers := testutil.GeneratePeers(1)
cids := testutil.GenerateCids(1)
root := cidlink.Link{Cid: cids[0]}
selector := testbridge.NewMockSelectorSpec(cids)
extensionData1 := testutil.RandomBytes(100)
extensionName1 := graphsync.ExtensionName("AppleSauce/McGee")
extension1 := graphsync.ExtensionData{
Name: extensionName1,
Data: extensionData1,
}
extensionData2 := testutil.RandomBytes(100)
extensionName2 := graphsync.ExtensionName("HappyLand/Happenstance")
extension2 := graphsync.ExtensionData{
Name: extensionName2,
Data: extensionData2,
}
_, _ = requestManager.SendRequest(requestCtx, peers[0], root, selector, extension1, extension2)
rr := readNNetworkRequests(requestCtx, t, requestRecordChan, 1)[0]
gsr := rr.gsr
returnedData1, found := gsr.Extension(extensionName1)
if !found || !reflect.DeepEqual(extensionData1, returnedData1) {
t.Fatal("Failed to encode first extension")
}
returnedData2, found := gsr.Extension(extensionName2)
if !found || !reflect.DeepEqual(extensionData2, returnedData2) {
t.Fatal("Failed to encode first extension")
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment