graphsync_test.go 14.8 KB
Newer Older
1 2 3 4
package graphsync

import (
	"context"
5
	"errors"
6 7
	"math"
	"math/rand"
8 9 10 11
	"reflect"
	"testing"
	"time"

12
	"github.com/ipld/go-ipld-prime/fluent"
13 14
	ipldfree "github.com/ipld/go-ipld-prime/impl/free"

15
	cidlink "github.com/ipld/go-ipld-prime/linking/cid"
16

17
	blocks "github.com/ipfs/go-block-format"
18
	cid "github.com/ipfs/go-cid"
19
	"github.com/ipfs/go-graphsync"
20

21
	"github.com/ipfs/go-graphsync/ipldbridge"
22 23 24 25 26
	gsmsg "github.com/ipfs/go-graphsync/message"
	gsnet "github.com/ipfs/go-graphsync/network"
	"github.com/ipfs/go-graphsync/testbridge"
	"github.com/ipfs/go-graphsync/testutil"
	ipld "github.com/ipld/go-ipld-prime"
Edgar Lee's avatar
Edgar Lee committed
27
	ipldselector "github.com/ipld/go-ipld-prime/traversal/selector"
28
	"github.com/ipld/go-ipld-prime/traversal/selector/builder"
29
	"github.com/libp2p/go-libp2p-core/host"
30
	"github.com/libp2p/go-libp2p-core/peer"
31
	mocknet "github.com/libp2p/go-libp2p/p2p/net/mock"
32
	mh "github.com/multiformats/go-multihash"
33 34 35 36 37 38 39
)

func TestMakeRequestToNetwork(t *testing.T) {
	// create network
	ctx := context.Background()
	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()
40
	td := newGsTestData(ctx, t)
41
	r := &receiver{
42
		messageReceived: make(chan receivedMessage),
43
	}
44 45
	td.gsnet2.SetDelegate(r)
	graphSync := td.GraphSyncHost1()
46

47
	blockChainLength := 100
48
	blockChain := setupBlockChain(ctx, t, td.storer1, td.bridge, 100, blockChainLength)
49

50
	spec := blockChainSelector(blockChainLength)
51

52 53
	requestCtx, requestCancel := context.WithCancel(ctx)
	defer requestCancel()
54
	graphSync.Request(requestCtx, td.host2.ID(), blockChain.tipLink, spec, td.extension)
55

56
	var message receivedMessage
57 58 59
	select {
	case <-ctx.Done():
		t.Fatal("did not receive message sent")
60
	case message = <-r.messageReceived:
61 62
	}

63
	sender := message.sender
64
	if sender != td.host1.ID() {
65 66 67
		t.Fatal("received message from wrong node")
	}

68
	received := message.message
69 70 71 72 73
	receivedRequests := received.Requests()
	if len(receivedRequests) != 1 {
		t.Fatal("Did not add request to received message")
	}
	receivedRequest := receivedRequests[0]
74
	receivedSpec, err := td.bridge.DecodeNode(receivedRequest.Selector())
75 76 77 78 79 80
	if err != nil {
		t.Fatal("unable to decode transmitted selector")
	}
	if !reflect.DeepEqual(spec, receivedSpec) {
		t.Fatal("did not transmit selector spec correctly")
	}
81
	_, err = td.bridge.ParseSelector(receivedSpec)
82 83 84
	if err != nil {
		t.Fatal("did not receive parsible selector on other side")
	}
85

86 87
	returnedData, found := receivedRequest.Extension(td.extensionName)
	if !found || !reflect.DeepEqual(td.extensionData, returnedData) {
88 89
		t.Fatal("Failed to encode extension")
	}
90
}
91 92 93 94 95 96

func TestSendResponseToIncomingRequest(t *testing.T) {
	// create network
	ctx := context.Background()
	ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
	defer cancel()
97
	td := newGsTestData(ctx, t)
98 99 100
	r := &receiver{
		messageReceived: make(chan receivedMessage),
	}
101
	td.gsnet1.SetDelegate(r)
102 103

	var receivedRequestData []byte
104
	// initialize graphsync on second node to response to requests
105 106
	gsnet := td.GraphSyncHost2()
	err := gsnet.RegisterRequestReceivedHook(
107
		func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
108
			var has bool
109
			receivedRequestData, has = requestData.Extension(td.extensionName)
110 111 112
			if !has {
				t.Fatal("did not have expected extension")
			}
113
			hookActions.SendExtensionData(td.extensionResponse)
114 115 116 117 118
		},
	)
	if err != nil {
		t.Fatal("error registering extension")
	}
119

120
	blockChainLength := 100
121
	blockChain := setupBlockChain(ctx, t, td.storer2, td.bridge, 100, blockChainLength)
122

123 124 125
	spec := blockChainSelector(blockChainLength)

	selectorData, err := td.bridge.EncodeNode(spec)
126 127 128
	if err != nil {
		t.Fatal("could not encode selector spec")
	}
129
	requestID := graphsync.RequestID(rand.Int31())
130 131

	message := gsmsg.New()
132
	message.AddRequest(gsmsg.NewRequest(requestID, blockChain.tipLink.(cidlink.Link).Cid, selectorData, graphsync.Priority(math.MaxInt32), td.extension))
133
	// send request across network
134
	td.gsnet1.SendMessage(ctx, td.host2.ID(), message)
135 136 137
	// read the values sent back to requestor
	var received gsmsg.GraphSyncMessage
	var receivedBlocks []blocks.Block
138
	var receivedExtensions [][]byte
139 140 141 142 143 144 145
readAllMessages:
	for {
		select {
		case <-ctx.Done():
			t.Fatal("did not receive complete response")
		case message := <-r.messageReceived:
			sender := message.sender
146
			if sender != td.host2.ID() {
147 148 149 150 151 152
				t.Fatal("received message from wrong node")
			}

			received = message.message
			receivedBlocks = append(receivedBlocks, received.Blocks()...)
			receivedResponses := received.Responses()
153
			receivedExtension, found := receivedResponses[0].Extension(td.extensionName)
154 155 156
			if found {
				receivedExtensions = append(receivedExtensions, receivedExtension)
			}
157 158 159 160 161 162
			if len(receivedResponses) != 1 {
				t.Fatal("Did not receive response")
			}
			if receivedResponses[0].RequestID() != requestID {
				t.Fatal("Sent response for incorrect request id")
			}
163
			if receivedResponses[0].Status() != graphsync.PartialResponse {
164 165 166 167 168
				break readAllMessages
			}
		}
	}

169
	if len(receivedBlocks) != blockChainLength {
170 171
		t.Fatal("Send incorrect number of blocks or there were duplicate blocks")
	}
172

173
	if !reflect.DeepEqual(td.extensionData, receivedRequestData) {
174 175 176 177 178 179 180
		t.Fatal("did not receive correct request extension data")
	}

	if len(receivedExtensions) != 1 {
		t.Fatal("should have sent extension responses but didn't")
	}

181
	if !reflect.DeepEqual(receivedExtensions[0], td.extensionResponseData) {
182 183
		t.Fatal("did not return correct extension data")
	}
184
}
185 186 187 188 189 190

func TestGraphsyncRoundTrip(t *testing.T) {
	// create network
	ctx := context.Background()
	ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
	defer cancel()
191
	td := newGsTestData(ctx, t)
192

193
	// initialize graphsync on first node to make requests
194
	requestor := td.GraphSyncHost1()
195 196

	// setup receiving peer to just record message coming in
197
	blockChainLength := 100
198
	blockChain := setupBlockChain(ctx, t, td.storer2, td.bridge, 100, blockChainLength)
199 200

	// initialize graphsync on second node to response to requests
201
	responder := td.GraphSyncHost2()
202 203 204 205

	var receivedResponseData []byte
	var receivedRequestData []byte

206
	err := requestor.RegisterResponseReceivedHook(
207
		func(p peer.ID, responseData graphsync.ResponseData) error {
208
			data, has := responseData.Extension(td.extensionName)
209 210 211 212 213 214 215 216 217 218 219
			if has {
				receivedResponseData = data
			}
			return nil
		})
	if err != nil {
		t.Fatal("Error setting up extension")
	}

	err = responder.RegisterRequestReceivedHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
		var has bool
220
		receivedRequestData, has = requestData.Extension(td.extensionName)
221 222 223
		if !has {
			hookActions.TerminateWithError(errors.New("Missing extension"))
		} else {
224
			hookActions.SendExtensionData(td.extensionResponse)
225 226 227 228 229 230
		}
	})

	if err != nil {
		t.Fatal("Error setting up extension")
	}
231

232
	spec := blockChainSelector(blockChainLength)
233

234
	progressChan, errChan := requestor.Request(ctx, td.host2.ID(), blockChain.tipLink, spec, td.extension)
235 236 237 238

	responses := testutil.CollectResponses(ctx, t, progressChan)
	errs := testutil.CollectErrors(ctx, t, errChan)

239
	if len(responses) != blockChainLength*2 {
240 241
		t.Fatal("did not traverse all nodes")
	}
242 243
	if len(errs) != 0 {
		t.Fatal("errors during traverse")
244
	}
245
	if len(td.blockStore1) != blockChainLength {
246 247
		t.Fatal("did not store all blocks")
	}
248 249 250 251 252 253 254 255 256 257 258 259 260 261

	expectedPath := ""
	for i, response := range responses {
		if response.Path.String() != expectedPath {
			t.Fatal("incorrect path")
		}
		if i%2 == 0 {
			if expectedPath == "" {
				expectedPath = "Parents"
			} else {
				expectedPath = expectedPath + "/Parents"
			}
		} else {
			expectedPath = expectedPath + "/0"
262 263
		}
	}
264 265

	// verify extension roundtrip
266
	if !reflect.DeepEqual(receivedRequestData, td.extensionData) {
267 268 269
		t.Fatal("did not receive correct extension request data")
	}

270
	if !reflect.DeepEqual(receivedResponseData, td.extensionResponseData) {
271 272
		t.Fatal("did not receive correct extension response data")
	}
273
}
274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290

// TestRoundTripLargeBlocksSlowNetwork test verifies graphsync continues to work
// under a specific of adverse conditions:
// -- large blocks being returned by a query
// -- slow network connection
// It verifies that Graphsync will properly break up network message packets
// so they can still be decoded on the client side, instead of building up a huge
// backlog of blocks and then sending them in one giant network packet that can't
// be decoded on the client side
func TestRoundTripLargeBlocksSlowNetwork(t *testing.T) {
	// create network
	if testing.Short() {
		t.Skip()
	}
	ctx := context.Background()
	ctx, cancel := context.WithTimeout(ctx, 20*time.Second)
	defer cancel()
291 292 293 294 295
	td := newGsTestData(ctx, t)
	td.mn.SetLinkDefaults(mocknet.LinkOptions{Latency: 100 * time.Millisecond, Bandwidth: 3000000})

	// initialize graphsync on first node to make requests
	requestor := td.GraphSyncHost1()
296

297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339
	// setup receiving peer to just record message coming in
	blockChainLength := 40
	blockChain := setupBlockChain(ctx, t, td.storer2, td.bridge, 200000, blockChainLength)

	// initialize graphsync on second node to response to requests
	td.GraphSyncHost2()

	spec := blockChainSelector(blockChainLength)
	progressChan, errChan := requestor.Request(ctx, td.host2.ID(), blockChain.tipLink, spec)

	responses := testutil.CollectResponses(ctx, t, progressChan)
	errs := testutil.CollectErrors(ctx, t, errChan)

	if len(responses) != blockChainLength*2 {
		t.Fatal("did not traverse all nodes")
	}
	if len(errs) != 0 {
		t.Fatal("errors during traverse")
	}
}

// gsTestData bundles the fixtures shared by the graphsync tests:
//   - a mock network with two hosts;
//   - a graphsync network and a map-backed block store for each host;
//   - an IPLD bridge;
//   - the extension payloads exchanged in requests and responses.
type gsTestData struct {
	mn  mocknet.Mocknet
	ctx context.Context

	// peers and the graphsync networks wrapping them
	host1  host.Host
	host2  host.Host
	gsnet1 gsnet.GraphSyncNetwork
	gsnet2 gsnet.GraphSyncNetwork

	// per-host block storage and its loader/storer views
	blockStore1 map[ipld.Link][]byte
	blockStore2 map[ipld.Link][]byte
	loader1     ipld.Loader
	loader2     ipld.Loader
	storer1     ipld.Storer
	storer2     ipld.Storer

	bridge ipldbridge.IPLDBridge

	// extension attached to outgoing requests
	extensionData []byte
	extensionName graphsync.ExtensionName
	extension     graphsync.ExtensionData

	// extension the responder sends back
	extensionResponseData []byte
	extensionResponse     graphsync.ExtensionData
}

func newGsTestData(ctx context.Context, t *testing.T) *gsTestData {
	td := &gsTestData{ctx: ctx}
	td.mn = mocknet.New(ctx)
	var err error
340
	// setup network
341
	td.host1, err = td.mn.GenPeer()
342 343 344
	if err != nil {
		t.Fatal("error generating host")
	}
345
	td.host2, err = td.mn.GenPeer()
346 347 348
	if err != nil {
		t.Fatal("error generating host")
	}
349
	err = td.mn.LinkAll()
350 351 352 353
	if err != nil {
		t.Fatal("error linking hosts")
	}

354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372
	td.gsnet1 = gsnet.NewFromLibp2pHost(td.host1)
	td.gsnet2 = gsnet.NewFromLibp2pHost(td.host2)
	td.blockStore1 = make(map[ipld.Link][]byte)
	td.loader1, td.storer1 = testbridge.NewMockStore(td.blockStore1)
	td.blockStore2 = make(map[ipld.Link][]byte)
	td.loader2, td.storer2 = testbridge.NewMockStore(td.blockStore2)
	td.bridge = ipldbridge.NewIPLDBridge()
	// setup extension handlers
	td.extensionData = testutil.RandomBytes(100)
	td.extensionName = graphsync.ExtensionName("AppleSauce/McGee")
	td.extension = graphsync.ExtensionData{
		Name: td.extensionName,
		Data: td.extensionData,
	}
	td.extensionResponseData = testutil.RandomBytes(100)
	td.extensionResponse = graphsync.ExtensionData{
		Name: td.extensionName,
		Data: td.extensionResponseData,
	}
373

374 375
	return td
}
376

377 378 379
func (td *gsTestData) GraphSyncHost1() graphsync.GraphExchange {
	return New(td.ctx, td.gsnet1, td.bridge, td.loader1, td.storer1)
}
380

381
func (td *gsTestData) GraphSyncHost2() graphsync.GraphExchange {
382

383 384
	return New(td.ctx, td.gsnet2, td.bridge, td.loader2, td.storer2)
}
385

386 387 388 389
type receivedMessage struct {
	message gsmsg.GraphSyncMessage
	sender  peer.ID
}
390

391 392 393 394
// Receiver is an interface for receiving messages from the GraphSyncNetwork.
type receiver struct {
	messageReceived chan receivedMessage
}
395

396 397 398 399
func (r *receiver) ReceiveMessage(
	ctx context.Context,
	sender peer.ID,
	incoming gsmsg.GraphSyncMessage) {
400

401 402 403 404 405
	select {
	case <-ctx.Done():
	case r.messageReceived <- receivedMessage{incoming, sender}:
	}
}
406

407 408
func (r *receiver) ReceiveError(err error) {
}
409

410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452
func (r *receiver) Connected(p peer.ID) {
}

func (r *receiver) Disconnected(p peer.ID) {
}

// blockChain records every node of a linear test chain together with the
// link it was stored under. setupBlockChain fills this struct with a
// positional composite literal, so the field order must stay in sync with it.
// (The "genisis" spelling of the first two fields is kept as-is.)
type blockChain struct {
	genisisNode ipld.Node   // first block; created with no parents
	genisisLink ipld.Link   // link the genesis block was stored under
	middleNodes []ipld.Node // blocks between genesis and tip, in creation order
	middleLinks []ipld.Link // links for middleNodes, in the same order
	tipNode     ipld.Node   // last block created
	tipLink     ipld.Link   // link to the tip; the tests start requests here
}

// createBlock builds a map node with two entries:
//   - "Parents": a list of links to the given parent blocks (empty for a
//     genesis block);
//   - "Messages": a list holding a single byte value of `size` random bytes.
func createBlock(nb ipldbridge.NodeBuilder, parents []ipld.Link, size int64) ipld.Node {
	return nb.CreateMap(func(mb ipldbridge.MapBuilder, keyNB ipldbridge.NodeBuilder, valNB ipldbridge.NodeBuilder) {
		parentsKey := keyNB.CreateString("Parents")
		parentsVal := valNB.CreateList(func(lb ipldbridge.ListBuilder, elemNB ipldbridge.NodeBuilder) {
			for _, p := range parents {
				lb.Append(elemNB.CreateLink(p))
			}
		})
		mb.Insert(parentsKey, parentsVal)

		messagesKey := keyNB.CreateString("Messages")
		messagesVal := valNB.CreateList(func(lb ipldbridge.ListBuilder, elemNB ipldbridge.NodeBuilder) {
			payload := testutil.RandomBytes(size)
			lb.Append(elemNB.CreateBytes(payload))
		})
		mb.Insert(messagesKey, messagesVal)
	})
}

func setupBlockChain(
	ctx context.Context,
	t *testing.T,
	storer ipldbridge.Storer,
	bridge ipldbridge.IPLDBridge,
	size int64,
	blockChainLength int) *blockChain {
	linkBuilder := cidlink.LinkBuilder{Prefix: cid.NewPrefixV1(cid.DagCBOR, mh.SHA2_256)}
	var genisisNode ipld.Node
	err := fluent.Recover(func() {
		nb := fluent.WrapNodeBuilder(ipldfree.NodeBuilder())
		genisisNode = createBlock(nb, []ipld.Link{}, size)
	})
	if err != nil {
		t.Fatal("Error creating genesis block")
453
	}
454 455 456
	genesisLink, err := linkBuilder.Build(ctx, ipldbridge.LinkContext{}, genisisNode, storer)
	if err != nil {
		t.Fatal("Error creating link to genesis block")
457
	}
458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499
	parent := genesisLink
	middleNodes := make([]ipld.Node, 0, blockChainLength-2)
	middleLinks := make([]ipld.Link, 0, blockChainLength-2)
	for i := 0; i < blockChainLength-2; i++ {
		var node ipld.Node
		err := fluent.Recover(func() {
			nb := fluent.WrapNodeBuilder(ipldfree.NodeBuilder())
			node = createBlock(nb, []ipld.Link{parent}, size)
		})
		if err != nil {
			t.Fatal("Error creating middle block")
		}
		middleNodes = append(middleNodes, node)
		link, err := linkBuilder.Build(ctx, ipldbridge.LinkContext{}, node, storer)
		if err != nil {
			t.Fatal("Error creating link to middle block")
		}
		middleLinks = append(middleLinks, link)
		parent = link
	}
	var tipNode ipld.Node
	err = fluent.Recover(func() {
		nb := fluent.WrapNodeBuilder(ipldfree.NodeBuilder())
		tipNode = createBlock(nb, []ipld.Link{parent}, size)
	})
	if err != nil {
		t.Fatal("Error creating tip block")
	}
	tipLink, err := linkBuilder.Build(ctx, ipldbridge.LinkContext{}, tipNode, storer)
	if err != nil {
		t.Fatal("Error creating link to tip block")
	}
	return &blockChain{genisisNode, genesisLink, middleNodes, middleLinks, tipNode, tipLink}
}

// blockChainSelector returns a selector spec node built as follows:
//   - an ExploreRecursive with a depth limit of blockChainLength;
//   - inside it, an ExploreFields that applies ExploreAll to the "Parents"
//     field;
//   - ExploreAll in turn recurses back into the selector through
//     ExploreRecursiveEdge.
func blockChainSelector(blockChainLength int) ipld.Node {
	ssb := builder.NewSelectorSpecBuilder(ipldfree.NodeBuilder())
	limit := ipldselector.RecursionLimitDepth(blockChainLength)
	followParents := ssb.ExploreFields(func(efsb ipldbridge.ExploreFieldsSpecBuilder) {
		efsb.Insert("Parents", ssb.ExploreAll(ssb.ExploreRecursiveEdge()))
	})
	return ssb.ExploreRecursive(limit, followParents).Node()
}