responsemanager_test.go 32.9 KB
Newer Older
1 2 3 4
package responsemanager

import (
	"context"
5
	"errors"
6 7 8 9 10
	"math/rand"
	"sync"
	"testing"
	"time"

11
	"github.com/ipfs/go-graphsync"
12
	gsmsg "github.com/ipfs/go-graphsync/message"
Hannah Howard's avatar
Hannah Howard committed
13
	"github.com/ipfs/go-graphsync/responsemanager/hooks"
14
	"github.com/ipfs/go-graphsync/responsemanager/peerresponsemanager"
15
	"github.com/ipfs/go-graphsync/responsemanager/persistenceoptions"
16
	"github.com/ipfs/go-graphsync/selectorvalidator"
17
	"github.com/ipfs/go-graphsync/testutil"
18
	"github.com/ipfs/go-peertaskqueue/peertask"
19
	ipld "github.com/ipld/go-ipld-prime"
20
	ipldfree "github.com/ipld/go-ipld-prime/impl/free"
21
	cidlink "github.com/ipld/go-ipld-prime/linking/cid"
22
	"github.com/libp2p/go-libp2p-core/peer"
Hannah Howard's avatar
Hannah Howard committed
23
	"github.com/stretchr/testify/require"
24 25 26 27 28
)

type fakeQueryQueue struct {
	popWait   sync.WaitGroup
	queriesLk sync.RWMutex
29
	queries   []*peertask.QueueTask
30 31
}

32
func (fqq *fakeQueryQueue) PushTasks(to peer.ID, tasks ...peertask.Task) {
33
	fqq.queriesLk.Lock()
34 35 36 37 38 39

	// This isn't quite right as the queue should deduplicate requests, but
	// it's good enough.
	for _, task := range tasks {
		fqq.queries = append(fqq.queries, peertask.NewQueueTask(task, to, time.Now()))
	}
40 41 42
	fqq.queriesLk.Unlock()
}

43
func (fqq *fakeQueryQueue) PopTasks(targetWork int) (peer.ID, []*peertask.Task, int) {
44 45 46 47
	fqq.popWait.Wait()
	fqq.queriesLk.Lock()
	defer fqq.queriesLk.Unlock()
	if len(fqq.queries) == 0 {
48
		return "", nil, -1
49
	}
50 51
	// We're not bothering to implement "work"
	task := fqq.queries[0]
52
	fqq.queries = fqq.queries[1:]
53
	return task.Target, []*peertask.Task{&task.Task}, 0
54 55
}

56
func (fqq *fakeQueryQueue) Remove(topic peertask.Topic, p peer.ID) {
57 58 59
	fqq.queriesLk.Lock()
	defer fqq.queriesLk.Unlock()
	for i, query := range fqq.queries {
60 61
		if query.Target == p && query.Topic == topic {
			fqq.queries = append(fqq.queries[:i], fqq.queries[i+1:]...)
62 63 64 65
		}
	}
}

66 67 68 69
// TasksDone does nothing; both arguments are ignored.
func (fqq *fakeQueryQueue) TasksDone(to peer.ID, tasks ...*peertask.Task) {
	// We don't track active tasks so this is a no-op
}

70 71 72
// ThawRound is a no-op on the fake queue.
func (fqq *fakeQueryQueue) ThawRound() {

}
73 74 75 76 77 78 79 80 81 82 83 84

// fakePeerManager hands out one fixed PeerResponseSender for every peer and
// remembers, in lastPeer, the peer most recently passed to SenderForPeer.
type fakePeerManager struct {
	lastPeer           peer.ID
	peerResponseSender peerresponsemanager.PeerResponseSender
}

// SenderForPeer records p as the last requested peer and returns the single
// configured sender, regardless of which peer was asked for.
func (fpm *fakePeerManager) SenderForPeer(p peer.ID) peerresponsemanager.PeerResponseSender {
	fpm.lastPeer = p
	return fpm.peerResponseSender
}

type sentResponse struct {
85
	requestID graphsync.RequestID
86 87 88 89
	link      ipld.Link
	data      []byte
}

90 91 92 93 94 95 96 97 98
// sentExtension captures the arguments of one SendExtensionData call made on
// the fake peer response sender.
type sentExtension struct {
	requestID graphsync.RequestID
	extension graphsync.ExtensionData
}

// completedRequest is what the fake peer response sender reports when a
// request is finished: the request ID and the status it finished with.
type completedRequest struct {
	requestID graphsync.RequestID
	result    graphsync.ResponseStatusCode
}
Hannah Howard's avatar
Hannah Howard committed
99 100 101 102
// pausedRequest is what the fake peer response sender reports when
// PauseRequest is called for a request.
type pausedRequest struct {
	requestID graphsync.RequestID
}

103 104
type fakePeerResponseSender struct {
	sentResponses        chan sentResponse
105 106
	sentExtensions       chan sentExtension
	lastCompletedRequest chan completedRequest
Hannah Howard's avatar
Hannah Howard committed
107
	pausedRequests       chan pausedRequest
108 109 110 111 112
}

// Startup and Shutdown are no-ops on the fake sender.
func (fprs *fakePeerResponseSender) Startup()  {}
func (fprs *fakePeerResponseSender) Shutdown() {}

113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
// fakeBlkData is the block description returned by the fake sender's
// SendResponse: the link that was sent and the length of its data.
type fakeBlkData struct {
	link ipld.Link
	size uint64
}

// Link returns the stored link.
func (fbd fakeBlkData) Link() ipld.Link {
	return fbd.link
}

// BlockSize returns the stored size.
func (fbd fakeBlkData) BlockSize() uint64 {
	return fbd.size
}

// BlockSizeOnWire returns the same stored size as BlockSize; the fake makes
// no distinction between the two.
func (fbd fakeBlkData) BlockSizeOnWire() uint64 {
	return fbd.size
}

130
func (fprs *fakePeerResponseSender) SendResponse(
131
	requestID graphsync.RequestID,
132 133
	link ipld.Link,
	data []byte,
134
) graphsync.BlockData {
135
	fprs.sentResponses <- sentResponse{requestID, link, data}
136
	return fakeBlkData{link, uint64(len(data))}
137 138
}

139 140 141 142 143 144 145
// SendExtensionData publishes the request ID and extension on
// sentExtensions instead of sending anything to a peer.
func (fprs *fakePeerResponseSender) SendExtensionData(
	requestID graphsync.RequestID,
	extension graphsync.ExtensionData,
) {
	record := sentExtension{requestID: requestID, extension: extension}
	fprs.sentExtensions <- record
}

146
func (fprs *fakePeerResponseSender) FinishRequest(requestID graphsync.RequestID) {
147
	fprs.lastCompletedRequest <- completedRequest{requestID, graphsync.RequestCompletedFull}
148 149
}

150
func (fprs *fakePeerResponseSender) FinishWithError(requestID graphsync.RequestID, status graphsync.ResponseStatusCode) {
151
	fprs.lastCompletedRequest <- completedRequest{requestID, status}
152 153
}

Hannah Howard's avatar
Hannah Howard committed
154 155 156 157
// PauseRequest reports requestID on pausedRequests.
func (fprs *fakePeerResponseSender) PauseRequest(requestID graphsync.RequestID) {
	paused := pausedRequest{requestID: requestID}
	fprs.pausedRequests <- paused
}

158
func TestIncomingQuery(t *testing.T) {
159 160 161 162
	td := newTestData(t)
	defer td.cancel()
	blks := td.blockChain.AllBlocks()

Hannah Howard's avatar
Hannah Howard committed
163
	responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks)
164
	td.requestHooks.Register(selectorvalidator.SelectorValidator(100))
165 166
	responseManager.Startup()

167 168
	responseManager.ProcessRequests(td.ctx, td.p, td.requests)
	testutil.AssertDoesReceive(td.ctx, t, td.completedRequestChan, "Should have completed request but didn't")
169
	for i := 0; i < len(blks); i++ {
Hannah Howard's avatar
Hannah Howard committed
170
		var sentResponse sentResponse
171
		testutil.AssertReceive(td.ctx, t, td.sentResponses, &sentResponse, "did not send responses")
Hannah Howard's avatar
Hannah Howard committed
172 173 174 175
		k := sentResponse.link.(cidlink.Link)
		blockIndex := testutil.IndexOf(blks, k.Cid)
		require.NotEqual(t, blockIndex, -1, "sent incorrect link")
		require.Equal(t, blks[blockIndex].RawData(), sentResponse.data, "sent incorrect data")
176
		require.Equal(t, td.requestID, sentResponse.requestID, "has incorrect response id")
177 178 179 180
	}
}

func TestCancellationQueryInProgress(t *testing.T) {
181 182 183
	td := newTestData(t)
	defer td.cancel()
	blks := td.blockChain.AllBlocks()
Hannah Howard's avatar
Hannah Howard committed
184
	responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks)
185
	td.requestHooks.Register(selectorvalidator.SelectorValidator(100))
186
	responseManager.Startup()
187
	responseManager.ProcessRequests(td.ctx, td.p, td.requests)
188 189

	// read one block
Hannah Howard's avatar
Hannah Howard committed
190
	var sentResponse sentResponse
191
	testutil.AssertReceive(td.ctx, t, td.sentResponses, &sentResponse, "did not send response")
Hannah Howard's avatar
Hannah Howard committed
192 193 194 195
	k := sentResponse.link.(cidlink.Link)
	blockIndex := testutil.IndexOf(blks, k.Cid)
	require.NotEqual(t, blockIndex, -1, "sent incorrect link")
	require.Equal(t, blks[blockIndex].RawData(), sentResponse.data, "sent incorrect data")
196
	require.Equal(t, td.requestID, sentResponse.requestID, "has incorrect response id")
197 198

	// send a cancellation
199 200
	cancelRequests := []gsmsg.GraphSyncRequest{
		gsmsg.CancelRequest(td.requestID),
201
	}
202
	responseManager.ProcessRequests(td.ctx, td.p, cancelRequests)
203 204 205

	responseManager.synchronize()

206 207
	// at this point we should receive at most one more block, then traversal
	// should complete
208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224
	additionalBlocks := 0
	for {
		select {
		case <-td.ctx.Done():
			t.Fatal("should complete request before context closes")
		case sentResponse = <-td.sentResponses:
			k = sentResponse.link.(cidlink.Link)
			blockIndex = testutil.IndexOf(blks, k.Cid)
			require.NotEqual(t, blockIndex, -1, "did not send correct link")
			require.Equal(t, blks[blockIndex].RawData(), sentResponse.data, "sent incorrect data")
			require.Equal(t, td.requestID, sentResponse.requestID, "incorrect response id")
			additionalBlocks++
		case <-td.completedRequestChan:
			require.LessOrEqual(t, additionalBlocks, 1, "should send at most 1 additional block")
			return
		}
	}
225 226 227
}

func TestEarlyCancellation(t *testing.T) {
228 229 230
	td := newTestData(t)
	defer td.cancel()
	td.queryQueue.popWait.Add(1)
Hannah Howard's avatar
Hannah Howard committed
231
	responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks)
232
	responseManager.Startup()
233
	responseManager.ProcessRequests(td.ctx, td.p, td.requests)
234 235

	// send a cancellation
236 237
	cancelRequests := []gsmsg.GraphSyncRequest{
		gsmsg.CancelRequest(td.requestID),
238
	}
239
	responseManager.ProcessRequests(td.ctx, td.p, cancelRequests)
240 241 242 243

	responseManager.synchronize()

	// unblock popping from queue
244
	td.queryQueue.popWait.Done()
245

246
	timer := time.NewTimer(time.Second)
247
	// verify no responses processed
248
	testutil.AssertDoesReceiveFirst(t, timer.C, "should not process more responses", td.sentResponses, td.completedRequestChan)
249
}
250 251

func TestValidationAndExtensions(t *testing.T) {
252
	t.Run("on its own, should fail validation", func(t *testing.T) {
253 254
		td := newTestData(t)
		defer td.cancel()
Hannah Howard's avatar
Hannah Howard committed
255
		responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks)
256
		responseManager.Startup()
257
		responseManager.ProcessRequests(td.ctx, td.p, td.requests)
258
		var lastRequest completedRequest
259
		testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
260 261 262 263
		require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")
	})

	t.Run("if non validating hook succeeds, does not pass validation", func(t *testing.T) {
264 265
		td := newTestData(t)
		defer td.cancel()
Hannah Howard's avatar
Hannah Howard committed
266
		responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks)
267
		responseManager.Startup()
268 269
		td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
			hookActions.SendExtensionData(td.extensionResponse)
270
		})
271
		responseManager.ProcessRequests(td.ctx, td.p, td.requests)
272
		var lastRequest completedRequest
273
		testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
274 275
		require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")
		var receivedExtension sentExtension
276 277
		testutil.AssertReceive(td.ctx, t, td.sentExtensions, &receivedExtension, "should send extension response")
		require.Equal(t, td.extensionResponse, receivedExtension.extension, "incorrect extension response sent")
278
	})
279

280
	t.Run("if validating hook succeeds, should pass validation", func(t *testing.T) {
281 282
		td := newTestData(t)
		defer td.cancel()
Hannah Howard's avatar
Hannah Howard committed
283
		responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks)
284
		responseManager.Startup()
285
		td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
286
			hookActions.ValidateRequest()
287
			hookActions.SendExtensionData(td.extensionResponse)
288
		})
289
		responseManager.ProcessRequests(td.ctx, td.p, td.requests)
290
		var lastRequest completedRequest
291
		testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
292 293
		require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
		var receivedExtension sentExtension
294 295
		testutil.AssertReceive(td.ctx, t, td.sentExtensions, &receivedExtension, "should send extension response")
		require.Equal(t, td.extensionResponse, receivedExtension.extension, "incorrect extension response sent")
296 297
	})

298
	t.Run("if any hook fails, should fail", func(t *testing.T) {
299 300
		td := newTestData(t)
		defer td.cancel()
Hannah Howard's avatar
Hannah Howard committed
301
		responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks)
302
		responseManager.Startup()
303
		td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
304
			hookActions.ValidateRequest()
305
		})
306 307
		td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
			hookActions.SendExtensionData(td.extensionResponse)
308 309
			hookActions.TerminateWithError(errors.New("everything went to crap"))
		})
310
		responseManager.ProcessRequests(td.ctx, td.p, td.requests)
311
		var lastRequest completedRequest
312
		testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
313 314
		require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")
		var receivedExtension sentExtension
315 316
		testutil.AssertReceive(td.ctx, t, td.sentExtensions, &receivedExtension, "should send extension response")
		require.Equal(t, td.extensionResponse, receivedExtension.extension, "incorrect extension response sent")
317
	})
318

319
	t.Run("hooks can be unregistered", func(t *testing.T) {
320 321
		td := newTestData(t)
		defer td.cancel()
Hannah Howard's avatar
Hannah Howard committed
322
		responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks)
323
		responseManager.Startup()
324
		unregister := td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
325
			hookActions.ValidateRequest()
326
			hookActions.SendExtensionData(td.extensionResponse)
327
		})
328 329

		// hook validates request
330
		responseManager.ProcessRequests(td.ctx, td.p, td.requests)
331
		var lastRequest completedRequest
332
		testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
333 334
		require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
		var receivedExtension sentExtension
335 336
		testutil.AssertReceive(td.ctx, t, td.sentExtensions, &receivedExtension, "should send extension response")
		require.Equal(t, td.extensionResponse, receivedExtension.extension, "incorrect extension response sent")
337 338 339 340 341

		// unregister
		unregister()

		// now the same request should fail
342 343
		responseManager.ProcessRequests(td.ctx, td.p, td.requests)
		testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
344
		require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")
345
	})
346 347

	t.Run("hooks can alter the loader", func(t *testing.T) {
348 349
		td := newTestData(t)
		defer td.cancel()
350 351
		obs := make(map[ipld.Link][]byte)
		oloader, _ := testutil.NewTestStore(obs)
Hannah Howard's avatar
Hannah Howard committed
352
		responseManager := New(td.ctx, oloader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks)
353 354
		responseManager.Startup()
		// add validating hook -- so the request SHOULD succeed
355
		td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
356 357 358 359 360
			hookActions.ValidateRequest()
		})

		// request fails with base loader reading from block store that's missing data
		var lastRequest completedRequest
361 362
		responseManager.ProcessRequests(td.ctx, td.p, td.requests)
		testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
363 364
		require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")

365
		err := td.peristenceOptions.Register("chainstore", td.loader)
366
		require.NoError(t, err)
367
		// register hook to use different loader
368 369
		_ = td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
			if _, found := requestData.Extension(td.extensionName); found {
370
				hookActions.UsePersistenceOption("chainstore")
371
				hookActions.SendExtensionData(td.extensionResponse)
372 373 374
			}
		})
		// hook uses different loader that should make request succeed
375 376
		responseManager.ProcessRequests(td.ctx, td.p, td.requests)
		testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
377 378
		require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
		var receivedExtension sentExtension
379 380
		testutil.AssertReceive(td.ctx, t, td.sentExtensions, &receivedExtension, "should send extension response")
		require.Equal(t, td.extensionResponse, receivedExtension.extension, "incorrect extension response sent")
381 382 383
	})

	t.Run("hooks can alter the node builder chooser", func(t *testing.T) {
384 385
		td := newTestData(t)
		defer td.cancel()
Hannah Howard's avatar
Hannah Howard committed
386
		responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks)
387 388 389
		responseManager.Startup()

		customChooserCallCount := 0
390
		customChooser := func(ipld.Link, ipld.LinkContext) (ipld.NodeBuilder, error) {
391
			customChooserCallCount++
392
			return ipldfree.NodeBuilder(), nil
393 394 395
		}

		// add validating hook -- so the request SHOULD succeed
396
		td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
397 398 399 400 401
			hookActions.ValidateRequest()
		})

		// with default chooser, custom chooser not called
		var lastRequest completedRequest
402 403
		responseManager.ProcessRequests(td.ctx, td.p, td.requests)
		testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
404 405 406 407
		require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
		require.Equal(t, 0, customChooserCallCount)

		// register hook to use custom chooser
408 409
		_ = td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
			if _, found := requestData.Extension(td.extensionName); found {
410
				hookActions.UseNodeBuilderChooser(customChooser)
411
				hookActions.SendExtensionData(td.extensionResponse)
412 413 414 415
			}
		})

		// verify now that request succeeds and uses custom chooser
416 417
		responseManager.ProcessRequests(td.ctx, td.p, td.requests)
		testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
418 419
		require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
		var receivedExtension sentExtension
420 421
		testutil.AssertReceive(td.ctx, t, td.sentExtensions, &receivedExtension, "should send extension response")
		require.Equal(t, td.extensionResponse, receivedExtension.extension, "incorrect extension response sent")
422 423
		require.Equal(t, 5, customChooserCallCount)
	})
424 425 426 427 428

	t.Run("test block hook processing", func(t *testing.T) {
		t.Run("can send extension data", func(t *testing.T) {
			td := newTestData(t)
			defer td.cancel()
Hannah Howard's avatar
Hannah Howard committed
429
			responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks)
430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450
			responseManager.Startup()
			td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
				hookActions.ValidateRequest()
			})
			td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) {
				hookActions.SendExtensionData(td.extensionResponse)
			})
			responseManager.ProcessRequests(td.ctx, td.p, td.requests)
			var lastRequest completedRequest
			testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
			require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
			for i := 0; i < td.blockChainLength; i++ {
				var receivedExtension sentExtension
				testutil.AssertReceive(td.ctx, t, td.sentExtensions, &receivedExtension, "should send extension response")
				require.Equal(t, td.extensionResponse, receivedExtension.extension, "incorrect extension response sent")
			}
		})

		t.Run("can send errors", func(t *testing.T) {
			td := newTestData(t)
			defer td.cancel()
Hannah Howard's avatar
Hannah Howard committed
451
			responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks)
452 453 454 455 456 457 458 459 460 461
			responseManager.Startup()
			td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
				hookActions.ValidateRequest()
			})
			td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) {
				hookActions.TerminateWithError(errors.New("failed"))
			})
			responseManager.ProcessRequests(td.ctx, td.p, td.requests)
			var lastRequest completedRequest
			testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
Hannah Howard's avatar
Hannah Howard committed
462
			require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "request should fail")
463 464 465 466 467
		})

		t.Run("can pause/unpause", func(t *testing.T) {
			td := newTestData(t)
			defer td.cancel()
Hannah Howard's avatar
Hannah Howard committed
468
			responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks)
469 470 471 472
			responseManager.Startup()
			td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
				hookActions.ValidateRequest()
			})
Hannah Howard's avatar
Hannah Howard committed
473
			blkIndex := 0
474 475
			blockCount := 3
			td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) {
Hannah Howard's avatar
Hannah Howard committed
476 477
				blkIndex++
				if blkIndex == blockCount {
478 479
					hookActions.PauseResponse()
				}
Hannah Howard's avatar
Hannah Howard committed
480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511
			})
			responseManager.ProcessRequests(td.ctx, td.p, td.requests)
			timer := time.NewTimer(500 * time.Millisecond)
			testutil.AssertDoesReceiveFirst(t, timer.C, "should not complete request while paused", td.completedRequestChan)
			for i := 0; i < blockCount; i++ {
				testutil.AssertDoesReceive(td.ctx, t, td.sentResponses, "should sent block")
			}
			testutil.AssertChannelEmpty(t, td.sentResponses, "should not send more blocks")
			var pausedRequest pausedRequest
			testutil.AssertReceive(td.ctx, t, td.pausedRequests, &pausedRequest, "should pause request")
			err := responseManager.UnpauseResponse(td.p, td.requestID)
			require.NoError(t, err)
			var lastRequest completedRequest
			testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
			require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
		})

	})

	t.Run("test update hook processing", func(t *testing.T) {

		// The block hook pauses the response on the third block; the update
		// hook unpauses it when the update carries the test extension. The
		// request must not complete while paused and must succeed after the
		// update is processed.
		t.Run("can pause/unpause", func(t *testing.T) {
			td := newTestData(t)
			defer td.cancel()
			responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks)
			responseManager.Startup()
			td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
				hookActions.ValidateRequest()
			})
			blkIndex := 0
			blockCount := 3
			td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) {
				blkIndex++
				if blkIndex == blockCount {
					hookActions.PauseResponse()
				}
			})
			td.updateHooks.Register(func(p peer.ID, requestData graphsync.RequestData, updateData graphsync.RequestData, hookActions graphsync.RequestUpdatedHookActions) {
				if _, found := updateData.Extension(td.extensionName); found {
					hookActions.UnpauseResponse()
				}
			})
			responseManager.ProcessRequests(td.ctx, td.p, td.requests)
			timer := time.NewTimer(500 * time.Millisecond)
			testutil.AssertDoesReceiveFirst(t, timer.C, "should not complete request while paused", td.completedRequestChan)

			// Exactly blockCount blocks are expected: receive that many and
			// then require the channel to be empty. (A never-populated
			// sentResponses slice and a length check on it, which could not
			// fail, used to sit here and have been removed.)
			for i := 0; i < blockCount; i++ {
				testutil.AssertDoesReceive(td.ctx, t, td.sentResponses, "should sent block")
			}
			testutil.AssertChannelEmpty(t, td.sentResponses, "should not send more blocks")
			var pausedRequest pausedRequest
			testutil.AssertReceive(td.ctx, t, td.pausedRequests, &pausedRequest, "should pause request")

			responseManager.ProcessRequests(td.ctx, td.p, td.updateRequests)
			var lastRequest completedRequest
			testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
			require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
		})

Hannah Howard's avatar
Hannah Howard committed
539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700
		t.Run("can send extension data", func(t *testing.T) {
			t.Run("when unpaused", func(t *testing.T) {
				td := newTestData(t)
				defer td.cancel()
				responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks)
				responseManager.Startup()
				td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
					hookActions.ValidateRequest()
				})
				blkIndex := 0
				blockCount := 3
				wait := make(chan struct{})
				sent := make(chan struct{})
				td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) {
					blkIndex++
					if blkIndex == blockCount {
						close(sent)
						<-wait
					}
				})
				td.updateHooks.Register(func(p peer.ID, requestData graphsync.RequestData, updateData graphsync.RequestData, hookActions graphsync.RequestUpdatedHookActions) {
					if _, found := updateData.Extension(td.extensionName); found {
						hookActions.SendExtensionData(td.extensionResponse)
					}
				})
				responseManager.ProcessRequests(td.ctx, td.p, td.requests)
				testutil.AssertDoesReceive(td.ctx, t, sent, "sends blocks")
				responseManager.ProcessRequests(td.ctx, td.p, td.updateRequests)
				responseManager.synchronize()
				close(wait)
				var lastRequest completedRequest
				testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
				require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
				var receivedExtension sentExtension
				testutil.AssertReceive(td.ctx, t, td.sentExtensions, &receivedExtension, "should send extension response")
				require.Equal(t, td.extensionResponse, receivedExtension.extension, "incorrect extension response sent")
			})

			// The block hook pauses the response on the third block; the
			// update hook then sends extension data. The extension must be
			// delivered while the request stays paused (no completion).
			t.Run("when paused", func(t *testing.T) {
				td := newTestData(t)
				defer td.cancel()
				responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks)
				responseManager.Startup()
				td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
					hookActions.ValidateRequest()
				})
				blkIndex := 0
				blockCount := 3
				td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) {
					blkIndex++
					if blkIndex == blockCount {
						hookActions.PauseResponse()
					}
				})
				td.updateHooks.Register(func(p peer.ID, requestData graphsync.RequestData, updateData graphsync.RequestData, hookActions graphsync.RequestUpdatedHookActions) {
					if _, found := updateData.Extension(td.extensionName); found {
						hookActions.SendExtensionData(td.extensionResponse)
					}
				})
				responseManager.ProcessRequests(td.ctx, td.p, td.requests)

				// Exactly blockCount blocks are expected: receive that many
				// and then require the channel to be empty. (A
				// never-populated sentResponses slice and a length check on
				// it, which could not fail, have been removed.)
				for i := 0; i < blockCount; i++ {
					testutil.AssertDoesReceive(td.ctx, t, td.sentResponses, "should sent block")
				}
				testutil.AssertChannelEmpty(t, td.sentResponses, "should not send more blocks")
				var pausedRequest pausedRequest
				testutil.AssertReceive(td.ctx, t, td.pausedRequests, &pausedRequest, "should pause request")

				// send update
				responseManager.ProcessRequests(td.ctx, td.p, td.updateRequests)

				// receive data, and check its payload as the "when unpaused"
				// variant does
				var receivedExtension sentExtension
				testutil.AssertReceive(td.ctx, t, td.sentExtensions, &receivedExtension, "should send extension response")
				require.Equal(t, td.extensionResponse, receivedExtension.extension, "incorrect extension response sent")

				// should still be paused
				timer := time.NewTimer(500 * time.Millisecond)
				testutil.AssertDoesReceiveFirst(t, timer.C, "should not complete request while paused", td.completedRequestChan)
			})
		})

		t.Run("can send errors", func(t *testing.T) {
			t.Run("when unpaused", func(t *testing.T) {
				td := newTestData(t)
				defer td.cancel()
				responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks)
				responseManager.Startup()
				td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
					hookActions.ValidateRequest()
				})
				blkIndex := 0
				blockCount := 3
				wait := make(chan struct{})
				sent := make(chan struct{})
				td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) {
					blkIndex++
					if blkIndex == blockCount {
						close(sent)
						<-wait
					}
				})
				td.updateHooks.Register(func(p peer.ID, requestData graphsync.RequestData, updateData graphsync.RequestData, hookActions graphsync.RequestUpdatedHookActions) {
					if _, found := updateData.Extension(td.extensionName); found {
						hookActions.TerminateWithError(errors.New("something went wrong"))
					}
				})
				responseManager.ProcessRequests(td.ctx, td.p, td.requests)
				testutil.AssertDoesReceive(td.ctx, t, sent, "sends blocks")
				responseManager.ProcessRequests(td.ctx, td.p, td.updateRequests)
				responseManager.synchronize()
				close(wait)
				var lastRequest completedRequest
				testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
				require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "request should fail")
			})

			// The block hook pauses the response on the third block; the
			// update hook then terminates the request with an error. The
			// request must complete with a failure code and can no longer be
			// unpaused afterwards.
			t.Run("when paused", func(t *testing.T) {
				td := newTestData(t)
				defer td.cancel()
				responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks)
				responseManager.Startup()
				td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
					hookActions.ValidateRequest()
				})
				blkIndex := 0
				blockCount := 3
				td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) {
					blkIndex++
					if blkIndex == blockCount {
						hookActions.PauseResponse()
					}
				})
				td.updateHooks.Register(func(p peer.ID, requestData graphsync.RequestData, updateData graphsync.RequestData, hookActions graphsync.RequestUpdatedHookActions) {
					if _, found := updateData.Extension(td.extensionName); found {
						hookActions.TerminateWithError(errors.New("something went wrong"))
					}
				})
				responseManager.ProcessRequests(td.ctx, td.p, td.requests)

				// Exactly blockCount blocks are expected: receive that many
				// and then require the channel to be empty. (A
				// never-populated sentResponses slice and a length check on
				// it, which could not fail, have been removed.)
				for i := 0; i < blockCount; i++ {
					testutil.AssertDoesReceive(td.ctx, t, td.sentResponses, "should sent block")
				}
				testutil.AssertChannelEmpty(t, td.sentResponses, "should not send more blocks")
				var pausedRequest pausedRequest
				testutil.AssertReceive(td.ctx, t, td.pausedRequests, &pausedRequest, "should pause request")

				// send update
				responseManager.ProcessRequests(td.ctx, td.p, td.updateRequests)

				// should terminate
				var lastRequest completedRequest
				testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
				require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "request should fail")

				// cannot unpause
				err := responseManager.UnpauseResponse(td.p, td.requestID)
				require.Error(t, err)
			})
		})

701 702 703 704 705 706 707 708 709 710 711 712 713 714
	})
}

type testData struct {
	ctx                   context.Context
	cancel                func()
	blockStore            map[ipld.Link][]byte
	loader                ipld.Loader
	storer                ipld.Storer
	blockChainLength      int
	blockChain            *testutil.TestBlockChain
	completedRequestChan  chan completedRequest
	sentResponses         chan sentResponse
	sentExtensions        chan sentExtension
Hannah Howard's avatar
Hannah Howard committed
715
	pausedRequests        chan pausedRequest
716 717 718 719 720 721 722
	peerManager           *fakePeerManager
	queryQueue            *fakeQueryQueue
	extensionData         []byte
	extensionName         graphsync.ExtensionName
	extension             graphsync.ExtensionData
	extensionResponseData []byte
	extensionResponse     graphsync.ExtensionData
Hannah Howard's avatar
Hannah Howard committed
723 724
	extensionUpdateData   []byte
	extensionUpdate       graphsync.ExtensionData
725 726
	requestID             graphsync.RequestID
	requests              []gsmsg.GraphSyncRequest
Hannah Howard's avatar
Hannah Howard committed
727
	updateRequests        []gsmsg.GraphSyncRequest
728 729
	p                     peer.ID
	peristenceOptions     *persistenceoptions.PersistenceOptions
Hannah Howard's avatar
Hannah Howard committed
730 731 732
	requestHooks          *hooks.IncomingRequestHooks
	blockHooks            *hooks.OutgoingBlockHooks
	updateHooks           *hooks.RequestUpdatedHooks
733 734 735 736 737 738 739 740 741 742 743 744 745 746 747
}

func newTestData(t *testing.T) testData {
	ctx := context.Background()
	td := testData{}
	td.ctx, td.cancel = context.WithTimeout(ctx, 10*time.Second)

	td.blockStore = make(map[ipld.Link][]byte)
	td.loader, td.storer = testutil.NewTestStore(td.blockStore)
	td.blockChainLength = 5
	td.blockChain = testutil.SetupBlockChain(ctx, t, td.loader, td.storer, 100, td.blockChainLength)

	td.completedRequestChan = make(chan completedRequest, 1)
	td.sentResponses = make(chan sentResponse, td.blockChainLength*2)
	td.sentExtensions = make(chan sentExtension, td.blockChainLength*2)
Hannah Howard's avatar
Hannah Howard committed
748 749
	td.pausedRequests = make(chan pausedRequest, 1)
	fprs := &fakePeerResponseSender{lastCompletedRequest: td.completedRequestChan, sentResponses: td.sentResponses, sentExtensions: td.sentExtensions, pausedRequests: td.pausedRequests}
750 751 752 753 754 755 756 757 758 759 760 761 762 763
	td.peerManager = &fakePeerManager{peerResponseSender: fprs}
	td.queryQueue = &fakeQueryQueue{}

	td.extensionData = testutil.RandomBytes(100)
	td.extensionName = graphsync.ExtensionName("AppleSauce/McGee")
	td.extension = graphsync.ExtensionData{
		Name: td.extensionName,
		Data: td.extensionData,
	}
	td.extensionResponseData = testutil.RandomBytes(100)
	td.extensionResponse = graphsync.ExtensionData{
		Name: td.extensionName,
		Data: td.extensionResponseData,
	}
Hannah Howard's avatar
Hannah Howard committed
764 765 766 767 768
	td.extensionUpdateData = testutil.RandomBytes(100)
	td.extensionUpdate = graphsync.ExtensionData{
		Name: td.extensionName,
		Data: td.extensionUpdateData,
	}
769 770 771 772
	td.requestID = graphsync.RequestID(rand.Int31())
	td.requests = []gsmsg.GraphSyncRequest{
		gsmsg.NewRequest(td.requestID, td.blockChain.TipLink.(cidlink.Link).Cid, td.blockChain.Selector(), graphsync.Priority(0), td.extension),
	}
Hannah Howard's avatar
Hannah Howard committed
773 774 775
	td.updateRequests = []gsmsg.GraphSyncRequest{
		gsmsg.UpdateRequest(td.requestID, td.extensionUpdate),
	}
776 777
	td.p = testutil.GeneratePeers(1)[0]
	td.peristenceOptions = persistenceoptions.New()
Hannah Howard's avatar
Hannah Howard committed
778 779 780
	td.requestHooks = hooks.NewRequestHooks(td.peristenceOptions)
	td.blockHooks = hooks.NewBlockHooks()
	td.updateHooks = hooks.NewUpdateHooks()
781
	return td
782
}