responsemanager_test.go 34.6 KB
Newer Older
1 2 3 4
package responsemanager

import (
	"context"
5
	"errors"
6 7 8 9 10
	"math/rand"
	"sync"
	"testing"
	"time"

11
	"github.com/ipfs/go-graphsync"
12
	gsmsg "github.com/ipfs/go-graphsync/message"
Hannah Howard's avatar
Hannah Howard committed
13
	"github.com/ipfs/go-graphsync/responsemanager/hooks"
14
	"github.com/ipfs/go-graphsync/responsemanager/peerresponsemanager"
15
	"github.com/ipfs/go-graphsync/responsemanager/persistenceoptions"
16
	"github.com/ipfs/go-graphsync/selectorvalidator"
17
	"github.com/ipfs/go-graphsync/testutil"
18
	"github.com/ipfs/go-peertaskqueue/peertask"
19 20
	ipld "github.com/ipld/go-ipld-prime"
	cidlink "github.com/ipld/go-ipld-prime/linking/cid"
Hannah Howard's avatar
Hannah Howard committed
21
	basicnode "github.com/ipld/go-ipld-prime/node/basic"
22
	"github.com/libp2p/go-libp2p-core/peer"
Hannah Howard's avatar
Hannah Howard committed
23
	"github.com/stretchr/testify/require"
24 25 26 27 28
)

type fakeQueryQueue struct {
	popWait   sync.WaitGroup
	queriesLk sync.RWMutex
29
	queries   []*peertask.QueueTask
30 31
}

32
func (fqq *fakeQueryQueue) PushTasks(to peer.ID, tasks ...peertask.Task) {
33
	fqq.queriesLk.Lock()
34 35 36 37 38 39

	// This isn't quite right as the queue should deduplicate requests, but
	// it's good enough.
	for _, task := range tasks {
		fqq.queries = append(fqq.queries, peertask.NewQueueTask(task, to, time.Now()))
	}
40 41 42
	fqq.queriesLk.Unlock()
}

43
func (fqq *fakeQueryQueue) PopTasks(targetWork int) (peer.ID, []*peertask.Task, int) {
44 45 46 47
	fqq.popWait.Wait()
	fqq.queriesLk.Lock()
	defer fqq.queriesLk.Unlock()
	if len(fqq.queries) == 0 {
48
		return "", nil, -1
49
	}
50 51
	// We're not bothering to implement "work"
	task := fqq.queries[0]
52
	fqq.queries = fqq.queries[1:]
53
	return task.Target, []*peertask.Task{&task.Task}, 0
54 55
}

56
func (fqq *fakeQueryQueue) Remove(topic peertask.Topic, p peer.ID) {
57 58 59
	fqq.queriesLk.Lock()
	defer fqq.queriesLk.Unlock()
	for i, query := range fqq.queries {
60 61
		if query.Target == p && query.Topic == topic {
			fqq.queries = append(fqq.queries[:i], fqq.queries[i+1:]...)
62 63 64 65
		}
	}
}

66 67 68 69
// TasksDone is intentionally empty: this fake never records which tasks are
// in flight, so there is nothing to release when they finish.
func (fqq *fakeQueryQueue) TasksDone(to peer.ID, tasks ...*peertask.Task) {}

70 71 72
// ThawRound is a no-op; this fake queue keeps no per-peer freeze state.
func (fqq *fakeQueryQueue) ThawRound() {}
73 74 75 76 77 78 79 80 81 82 83 84

// fakePeerManager hands out a single preconfigured PeerResponseSender for
// every peer and remembers the most recent peer it was asked about.
type fakePeerManager struct {
	// lastPeer is the peer passed to the most recent SenderForPeer call.
	lastPeer           peer.ID
	// peerResponseSender is returned by every SenderForPeer call.
	peerResponseSender peerresponsemanager.PeerResponseSender
}

// SenderForPeer records p as the last requested peer and returns the shared
// fake sender; the same sender is returned regardless of which peer is asked.
func (fpm *fakePeerManager) SenderForPeer(p peer.ID) peerresponsemanager.PeerResponseSender {
	sender := fpm.peerResponseSender
	fpm.lastPeer = p
	return sender
}

type sentResponse struct {
85
	requestID graphsync.RequestID
86 87 88 89
	link      ipld.Link
	data      []byte
}

90 91 92 93 94 95 96 97 98
// sentExtension captures one call to fakePeerResponseSender.SendExtensionData.
// It is built positionally, so the field order must not change.
type sentExtension struct {
	requestID graphsync.RequestID
	extension graphsync.ExtensionData
}

// completedRequest records how a request was finished by the fake sender:
// result is RequestCompletedFull for FinishRequest, or the status passed to
// FinishWithError. It is built positionally, so field order must not change.
type completedRequest struct {
	requestID graphsync.RequestID
	result    graphsync.ResponseStatusCode
}
Hannah Howard's avatar
Hannah Howard committed
99 100 101 102
// pausedRequest records a call to fakePeerResponseSender.PauseRequest.
type pausedRequest struct {
	requestID graphsync.RequestID
}

103 104
type fakePeerResponseSender struct {
	sentResponses        chan sentResponse
105 106
	sentExtensions       chan sentExtension
	lastCompletedRequest chan completedRequest
Hannah Howard's avatar
Hannah Howard committed
107
	pausedRequests       chan pausedRequest
108 109 110 111 112
}

// Startup is a no-op; the fake sender has nothing to start.
func (fprs *fakePeerResponseSender) Startup() {
}

// Shutdown is a no-op; the fake sender has nothing to tear down.
func (fprs *fakePeerResponseSender) Shutdown() {
}

113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
// fakeBlkData is the graphsync.BlockData value returned by the fake sender.
// SendResponse builds it positionally, so the field order must not change.
type fakeBlkData struct {
	// link is the link of the block that was "sent".
	link ipld.Link
	// size is reported as both the block size and the on-wire size.
	size uint64
}

// Link returns the link the fake block data was created with.
func (fbd fakeBlkData) Link() ipld.Link { return fbd.link }

// BlockSize returns the recorded size of the block.
func (fbd fakeBlkData) BlockSize() uint64 { return fbd.size }

// BlockSizeOnWire reports the same value as BlockSize: the fake does not
// model any extra on-the-wire overhead.
func (fbd fakeBlkData) BlockSizeOnWire() uint64 {
	return fbd.BlockSize()
}

130
func (fprs *fakePeerResponseSender) SendResponse(
131
	requestID graphsync.RequestID,
132 133
	link ipld.Link,
	data []byte,
134
) graphsync.BlockData {
135
	fprs.sentResponses <- sentResponse{requestID, link, data}
136
	return fakeBlkData{link, uint64(len(data))}
137 138
}

139 140 141 142 143 144 145
// SendExtensionData records the extension sent for requestID by pushing it
// onto sentExtensions; the send blocks if the channel has no free buffer space.
func (fprs *fakePeerResponseSender) SendExtensionData(requestID graphsync.RequestID, extension graphsync.ExtensionData) {
	record := sentExtension{requestID: requestID, extension: extension}
	fprs.sentExtensions <- record
}

146
func (fprs *fakePeerResponseSender) FinishRequest(requestID graphsync.RequestID) graphsync.ResponseStatusCode {
147
	fprs.lastCompletedRequest <- completedRequest{requestID, graphsync.RequestCompletedFull}
148
	return graphsync.RequestCompletedFull
149 150
}

151
func (fprs *fakePeerResponseSender) FinishWithError(requestID graphsync.RequestID, status graphsync.ResponseStatusCode) {
152
	fprs.lastCompletedRequest <- completedRequest{requestID, status}
153 154
}

Hannah Howard's avatar
Hannah Howard committed
155 156 157 158
// PauseRequest records that requestID was paused by sending it on
// pausedRequests.
func (fprs *fakePeerResponseSender) PauseRequest(requestID graphsync.RequestID) {
	paused := pausedRequest{requestID: requestID}
	fprs.pausedRequests <- paused
}

159
func TestIncomingQuery(t *testing.T) {
160 161 162 163
	td := newTestData(t)
	defer td.cancel()
	blks := td.blockChain.AllBlocks()

164
	responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners)
165
	td.requestHooks.Register(selectorvalidator.SelectorValidator(100))
166 167
	responseManager.Startup()

168 169
	responseManager.ProcessRequests(td.ctx, td.p, td.requests)
	testutil.AssertDoesReceive(td.ctx, t, td.completedRequestChan, "Should have completed request but didn't")
170
	for i := 0; i < len(blks); i++ {
Hannah Howard's avatar
Hannah Howard committed
171
		var sentResponse sentResponse
172
		testutil.AssertReceive(td.ctx, t, td.sentResponses, &sentResponse, "did not send responses")
Hannah Howard's avatar
Hannah Howard committed
173 174 175 176
		k := sentResponse.link.(cidlink.Link)
		blockIndex := testutil.IndexOf(blks, k.Cid)
		require.NotEqual(t, blockIndex, -1, "sent incorrect link")
		require.Equal(t, blks[blockIndex].RawData(), sentResponse.data, "sent incorrect data")
177
		require.Equal(t, td.requestID, sentResponse.requestID, "has incorrect response id")
178 179 180 181
	}
}

func TestCancellationQueryInProgress(t *testing.T) {
182 183 184
	td := newTestData(t)
	defer td.cancel()
	blks := td.blockChain.AllBlocks()
185
	responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners)
186
	td.requestHooks.Register(selectorvalidator.SelectorValidator(100))
187
	responseManager.Startup()
188
	responseManager.ProcessRequests(td.ctx, td.p, td.requests)
189 190

	// read one block
Hannah Howard's avatar
Hannah Howard committed
191
	var sentResponse sentResponse
192
	testutil.AssertReceive(td.ctx, t, td.sentResponses, &sentResponse, "did not send response")
Hannah Howard's avatar
Hannah Howard committed
193 194 195 196
	k := sentResponse.link.(cidlink.Link)
	blockIndex := testutil.IndexOf(blks, k.Cid)
	require.NotEqual(t, blockIndex, -1, "sent incorrect link")
	require.Equal(t, blks[blockIndex].RawData(), sentResponse.data, "sent incorrect data")
197
	require.Equal(t, td.requestID, sentResponse.requestID, "has incorrect response id")
198 199

	// send a cancellation
200 201
	cancelRequests := []gsmsg.GraphSyncRequest{
		gsmsg.CancelRequest(td.requestID),
202
	}
203
	responseManager.ProcessRequests(td.ctx, td.p, cancelRequests)
204 205 206

	responseManager.synchronize()

207 208
	// at this point we should receive at most one more block, then traversal
	// should complete
209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225
	additionalBlocks := 0
	for {
		select {
		case <-td.ctx.Done():
			t.Fatal("should complete request before context closes")
		case sentResponse = <-td.sentResponses:
			k = sentResponse.link.(cidlink.Link)
			blockIndex = testutil.IndexOf(blks, k.Cid)
			require.NotEqual(t, blockIndex, -1, "did not send correct link")
			require.Equal(t, blks[blockIndex].RawData(), sentResponse.data, "sent incorrect data")
			require.Equal(t, td.requestID, sentResponse.requestID, "incorrect response id")
			additionalBlocks++
		case <-td.completedRequestChan:
			require.LessOrEqual(t, additionalBlocks, 1, "should send at most 1 additional block")
			return
		}
	}
226 227 228
}

func TestEarlyCancellation(t *testing.T) {
229 230 231
	td := newTestData(t)
	defer td.cancel()
	td.queryQueue.popWait.Add(1)
232
	responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners)
233
	responseManager.Startup()
234
	responseManager.ProcessRequests(td.ctx, td.p, td.requests)
235 236

	// send a cancellation
237 238
	cancelRequests := []gsmsg.GraphSyncRequest{
		gsmsg.CancelRequest(td.requestID),
239
	}
240
	responseManager.ProcessRequests(td.ctx, td.p, cancelRequests)
241 242 243 244

	responseManager.synchronize()

	// unblock popping from queue
245
	td.queryQueue.popWait.Done()
246

247
	timer := time.NewTimer(time.Second)
248
	// verify no responses processed
249
	testutil.AssertDoesReceiveFirst(t, timer.C, "should not process more responses", td.sentResponses, td.completedRequestChan)
250
}
251 252

func TestValidationAndExtensions(t *testing.T) {
253
	t.Run("on its own, should fail validation", func(t *testing.T) {
254 255
		td := newTestData(t)
		defer td.cancel()
256
		responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners)
257
		responseManager.Startup()
258
		responseManager.ProcessRequests(td.ctx, td.p, td.requests)
259
		var lastRequest completedRequest
260
		testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
261 262 263 264
		require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")
	})

	t.Run("if non validating hook succeeds, does not pass validation", func(t *testing.T) {
265 266
		td := newTestData(t)
		defer td.cancel()
267
		responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners)
268
		responseManager.Startup()
269 270
		td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
			hookActions.SendExtensionData(td.extensionResponse)
271
		})
272
		responseManager.ProcessRequests(td.ctx, td.p, td.requests)
273
		var lastRequest completedRequest
274
		testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
275 276
		require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")
		var receivedExtension sentExtension
277 278
		testutil.AssertReceive(td.ctx, t, td.sentExtensions, &receivedExtension, "should send extension response")
		require.Equal(t, td.extensionResponse, receivedExtension.extension, "incorrect extension response sent")
279
	})
280

281
	t.Run("if validating hook succeeds, should pass validation", func(t *testing.T) {
282 283
		td := newTestData(t)
		defer td.cancel()
284
		responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners)
285
		responseManager.Startup()
286
		td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
287
			hookActions.ValidateRequest()
288
			hookActions.SendExtensionData(td.extensionResponse)
289
		})
290
		responseManager.ProcessRequests(td.ctx, td.p, td.requests)
291
		var lastRequest completedRequest
292
		testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
293 294
		require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
		var receivedExtension sentExtension
295 296
		testutil.AssertReceive(td.ctx, t, td.sentExtensions, &receivedExtension, "should send extension response")
		require.Equal(t, td.extensionResponse, receivedExtension.extension, "incorrect extension response sent")
297 298
	})

299
	t.Run("if any hook fails, should fail", func(t *testing.T) {
300 301
		td := newTestData(t)
		defer td.cancel()
302
		responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners)
303
		responseManager.Startup()
304
		td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
305
			hookActions.ValidateRequest()
306
		})
307 308
		td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
			hookActions.SendExtensionData(td.extensionResponse)
309 310
			hookActions.TerminateWithError(errors.New("everything went to crap"))
		})
311
		responseManager.ProcessRequests(td.ctx, td.p, td.requests)
312
		var lastRequest completedRequest
313
		testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
314 315
		require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")
		var receivedExtension sentExtension
316 317
		testutil.AssertReceive(td.ctx, t, td.sentExtensions, &receivedExtension, "should send extension response")
		require.Equal(t, td.extensionResponse, receivedExtension.extension, "incorrect extension response sent")
318
	})
319

320
	t.Run("hooks can be unregistered", func(t *testing.T) {
321 322
		td := newTestData(t)
		defer td.cancel()
323
		responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners)
324
		responseManager.Startup()
325
		unregister := td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
326
			hookActions.ValidateRequest()
327
			hookActions.SendExtensionData(td.extensionResponse)
328
		})
329 330

		// hook validates request
331
		responseManager.ProcessRequests(td.ctx, td.p, td.requests)
332
		var lastRequest completedRequest
333
		testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
334 335
		require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
		var receivedExtension sentExtension
336 337
		testutil.AssertReceive(td.ctx, t, td.sentExtensions, &receivedExtension, "should send extension response")
		require.Equal(t, td.extensionResponse, receivedExtension.extension, "incorrect extension response sent")
338 339 340 341 342

		// unregister
		unregister()

		// no same request should fail
343 344
		responseManager.ProcessRequests(td.ctx, td.p, td.requests)
		testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
345
		require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")
346
	})
347 348

	t.Run("hooks can alter the loader", func(t *testing.T) {
349 350
		td := newTestData(t)
		defer td.cancel()
351 352
		obs := make(map[ipld.Link][]byte)
		oloader, _ := testutil.NewTestStore(obs)
353
		responseManager := New(td.ctx, oloader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners)
354 355
		responseManager.Startup()
		// add validating hook -- so the request SHOULD succeed
356
		td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
357 358 359 360 361
			hookActions.ValidateRequest()
		})

		// request fails with base loader reading from block store that's missing data
		var lastRequest completedRequest
362 363
		responseManager.ProcessRequests(td.ctx, td.p, td.requests)
		testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
364 365
		require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")

366
		err := td.peristenceOptions.Register("chainstore", td.loader)
367
		require.NoError(t, err)
368
		// register hook to use different loader
369 370
		_ = td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
			if _, found := requestData.Extension(td.extensionName); found {
371
				hookActions.UsePersistenceOption("chainstore")
372
				hookActions.SendExtensionData(td.extensionResponse)
373 374 375
			}
		})
		// hook uses different loader that should make request succeed
376 377
		responseManager.ProcessRequests(td.ctx, td.p, td.requests)
		testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
378 379
		require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
		var receivedExtension sentExtension
380 381
		testutil.AssertReceive(td.ctx, t, td.sentExtensions, &receivedExtension, "should send extension response")
		require.Equal(t, td.extensionResponse, receivedExtension.extension, "incorrect extension response sent")
382 383 384
	})

	t.Run("hooks can alter the node builder chooser", func(t *testing.T) {
385 386
		td := newTestData(t)
		defer td.cancel()
387
		responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners)
388 389 390
		responseManager.Startup()

		customChooserCallCount := 0
Hannah Howard's avatar
Hannah Howard committed
391
		customChooser := func(ipld.Link, ipld.LinkContext) (ipld.NodeStyle, error) {
392
			customChooserCallCount++
Hannah Howard's avatar
Hannah Howard committed
393
			return basicnode.Style.Any, nil
394 395 396
		}

		// add validating hook -- so the request SHOULD succeed
397
		td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
398 399 400 401 402
			hookActions.ValidateRequest()
		})

		// with default chooser, customer chooser not called
		var lastRequest completedRequest
403 404
		responseManager.ProcessRequests(td.ctx, td.p, td.requests)
		testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
405 406 407 408
		require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
		require.Equal(t, 0, customChooserCallCount)

		// register hook to use custom chooser
409 410
		_ = td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
			if _, found := requestData.Extension(td.extensionName); found {
Hannah Howard's avatar
Hannah Howard committed
411
				hookActions.UseLinkTargetNodeStyleChooser(customChooser)
412
				hookActions.SendExtensionData(td.extensionResponse)
413 414 415 416
			}
		})

		// verify now that request succeeds and uses custom chooser
417 418
		responseManager.ProcessRequests(td.ctx, td.p, td.requests)
		testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
419 420
		require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
		var receivedExtension sentExtension
421 422
		testutil.AssertReceive(td.ctx, t, td.sentExtensions, &receivedExtension, "should send extension response")
		require.Equal(t, td.extensionResponse, receivedExtension.extension, "incorrect extension response sent")
423 424
		require.Equal(t, 5, customChooserCallCount)
	})
425 426 427 428 429

	t.Run("test block hook processing", func(t *testing.T) {
		t.Run("can send extension data", func(t *testing.T) {
			td := newTestData(t)
			defer td.cancel()
430
			responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners)
431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451
			responseManager.Startup()
			td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
				hookActions.ValidateRequest()
			})
			td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) {
				hookActions.SendExtensionData(td.extensionResponse)
			})
			responseManager.ProcessRequests(td.ctx, td.p, td.requests)
			var lastRequest completedRequest
			testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
			require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
			for i := 0; i < td.blockChainLength; i++ {
				var receivedExtension sentExtension
				testutil.AssertReceive(td.ctx, t, td.sentExtensions, &receivedExtension, "should send extension response")
				require.Equal(t, td.extensionResponse, receivedExtension.extension, "incorrect extension response sent")
			}
		})

		t.Run("can send errors", func(t *testing.T) {
			td := newTestData(t)
			defer td.cancel()
452
			responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners)
453 454 455 456 457 458 459 460 461 462
			responseManager.Startup()
			td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
				hookActions.ValidateRequest()
			})
			td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) {
				hookActions.TerminateWithError(errors.New("failed"))
			})
			responseManager.ProcessRequests(td.ctx, td.p, td.requests)
			var lastRequest completedRequest
			testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
Hannah Howard's avatar
Hannah Howard committed
463
			require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "request should fail")
464 465 466 467 468
		})

		t.Run("can pause/unpause", func(t *testing.T) {
			td := newTestData(t)
			defer td.cancel()
469
			responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners)
470 471 472 473
			responseManager.Startup()
			td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
				hookActions.ValidateRequest()
			})
Hannah Howard's avatar
Hannah Howard committed
474
			blkIndex := 0
475 476
			blockCount := 3
			td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) {
Hannah Howard's avatar
Hannah Howard committed
477 478
				blkIndex++
				if blkIndex == blockCount {
479 480
					hookActions.PauseResponse()
				}
Hannah Howard's avatar
Hannah Howard committed
481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504
			})
			responseManager.ProcessRequests(td.ctx, td.p, td.requests)
			timer := time.NewTimer(500 * time.Millisecond)
			testutil.AssertDoesReceiveFirst(t, timer.C, "should not complete request while paused", td.completedRequestChan)
			for i := 0; i < blockCount; i++ {
				testutil.AssertDoesReceive(td.ctx, t, td.sentResponses, "should sent block")
			}
			testutil.AssertChannelEmpty(t, td.sentResponses, "should not send more blocks")
			var pausedRequest pausedRequest
			testutil.AssertReceive(td.ctx, t, td.pausedRequests, &pausedRequest, "should pause request")
			err := responseManager.UnpauseResponse(td.p, td.requestID)
			require.NoError(t, err)
			var lastRequest completedRequest
			testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
			require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
		})

	})

	t.Run("test update hook processing", func(t *testing.T) {

		t.Run("can pause/unpause", func(t *testing.T) {
			td := newTestData(t)
			defer td.cancel()
505
			responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners)
Hannah Howard's avatar
Hannah Howard committed
506 507 508 509 510 511 512
			responseManager.Startup()
			td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
				hookActions.ValidateRequest()
			})
			blkIndex := 0
			blockCount := 3
			td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) {
513
				blkIndex++
Hannah Howard's avatar
Hannah Howard committed
514 515 516 517 518 519 520 521
				if blkIndex == blockCount {
					hookActions.PauseResponse()
				}
			})
			td.updateHooks.Register(func(p peer.ID, requestData graphsync.RequestData, updateData graphsync.RequestData, hookActions graphsync.RequestUpdatedHookActions) {
				if _, found := updateData.Extension(td.extensionName); found {
					hookActions.UnpauseResponse()
				}
522 523 524 525 526
			})
			responseManager.ProcessRequests(td.ctx, td.p, td.requests)
			timer := time.NewTimer(500 * time.Millisecond)
			testutil.AssertDoesReceiveFirst(t, timer.C, "should not complete request while paused", td.completedRequestChan)
			var sentResponses []sentResponse
Hannah Howard's avatar
Hannah Howard committed
527 528
			for i := 0; i < blockCount; i++ {
				testutil.AssertDoesReceive(td.ctx, t, td.sentResponses, "should sent block")
529
			}
Hannah Howard's avatar
Hannah Howard committed
530 531 532
			testutil.AssertChannelEmpty(t, td.sentResponses, "should not send more blocks")
			var pausedRequest pausedRequest
			testutil.AssertReceive(td.ctx, t, td.pausedRequests, &pausedRequest, "should pause request")
533
			require.LessOrEqual(t, len(sentResponses), blockCount)
Hannah Howard's avatar
Hannah Howard committed
534
			responseManager.ProcessRequests(td.ctx, td.p, td.updateRequests)
535 536 537 538 539
			var lastRequest completedRequest
			testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
			require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
		})

Hannah Howard's avatar
Hannah Howard committed
540 541 542 543
		t.Run("can send extension data", func(t *testing.T) {
			t.Run("when unpaused", func(t *testing.T) {
				td := newTestData(t)
				defer td.cancel()
544
				responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners)
Hannah Howard's avatar
Hannah Howard committed
545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580
				responseManager.Startup()
				td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
					hookActions.ValidateRequest()
				})
				blkIndex := 0
				blockCount := 3
				wait := make(chan struct{})
				sent := make(chan struct{})
				td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) {
					blkIndex++
					if blkIndex == blockCount {
						close(sent)
						<-wait
					}
				})
				td.updateHooks.Register(func(p peer.ID, requestData graphsync.RequestData, updateData graphsync.RequestData, hookActions graphsync.RequestUpdatedHookActions) {
					if _, found := updateData.Extension(td.extensionName); found {
						hookActions.SendExtensionData(td.extensionResponse)
					}
				})
				responseManager.ProcessRequests(td.ctx, td.p, td.requests)
				testutil.AssertDoesReceive(td.ctx, t, sent, "sends blocks")
				responseManager.ProcessRequests(td.ctx, td.p, td.updateRequests)
				responseManager.synchronize()
				close(wait)
				var lastRequest completedRequest
				testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
				require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
				var receivedExtension sentExtension
				testutil.AssertReceive(td.ctx, t, td.sentExtensions, &receivedExtension, "should send extension response")
				require.Equal(t, td.extensionResponse, receivedExtension.extension, "incorrect extension response sent")
			})

			t.Run("when paused", func(t *testing.T) {
				td := newTestData(t)
				defer td.cancel()
581
				responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners)
Hannah Howard's avatar
Hannah Howard committed
582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625
				responseManager.Startup()
				td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
					hookActions.ValidateRequest()
				})
				blkIndex := 0
				blockCount := 3
				td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) {
					blkIndex++
					if blkIndex == blockCount {
						hookActions.PauseResponse()
					}
				})
				td.updateHooks.Register(func(p peer.ID, requestData graphsync.RequestData, updateData graphsync.RequestData, hookActions graphsync.RequestUpdatedHookActions) {
					if _, found := updateData.Extension(td.extensionName); found {
						hookActions.SendExtensionData(td.extensionResponse)
					}
				})
				responseManager.ProcessRequests(td.ctx, td.p, td.requests)
				var sentResponses []sentResponse
				for i := 0; i < blockCount; i++ {
					testutil.AssertDoesReceive(td.ctx, t, td.sentResponses, "should sent block")
				}
				testutil.AssertChannelEmpty(t, td.sentResponses, "should not send more blocks")
				var pausedRequest pausedRequest
				testutil.AssertReceive(td.ctx, t, td.pausedRequests, &pausedRequest, "should pause request")
				require.LessOrEqual(t, len(sentResponses), blockCount)

				// send update
				responseManager.ProcessRequests(td.ctx, td.p, td.updateRequests)

				// receive data
				var receivedExtension sentExtension
				testutil.AssertReceive(td.ctx, t, td.sentExtensions, &receivedExtension, "should send extension response")

				// should still be paused
				timer := time.NewTimer(500 * time.Millisecond)
				testutil.AssertDoesReceiveFirst(t, timer.C, "should not complete request while paused", td.completedRequestChan)
			})
		})

		// Verifies that an update hook calling TerminateWithError ends the
		// response with a terminal failure, whether the response is still
		// sending blocks or has been paused.
		t.Run("can send errors", func(t *testing.T) {
			t.Run("when unpaused", func(t *testing.T) {
				td := newTestData(t)
				defer td.cancel()
				responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners)
				responseManager.Startup()
				td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
					hookActions.ValidateRequest()
				})
				blkIndex := 0
				blockCount := 3
				// sent is closed when the hook reaches block number blockCount;
				// the hook then holds that block until wait is closed, so the
				// update below is processed while the response is in flight.
				wait := make(chan struct{})
				sent := make(chan struct{})
				td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) {
					blkIndex++
					if blkIndex == blockCount {
						close(sent)
						// Also give up when the test context ends, so a test that
						// fails before close(wait) cannot leave this hook blocked
						// forever.
						select {
						case <-wait:
						case <-td.ctx.Done():
						}
					}
				})
				td.updateHooks.Register(func(p peer.ID, requestData graphsync.RequestData, updateData graphsync.RequestData, hookActions graphsync.RequestUpdatedHookActions) {
					if _, found := updateData.Extension(td.extensionName); found {
						hookActions.TerminateWithError(errors.New("something went wrong"))
					}
				})
				responseManager.ProcessRequests(td.ctx, td.p, td.requests)
				testutil.AssertDoesReceive(td.ctx, t, sent, "sends blocks")
				responseManager.ProcessRequests(td.ctx, td.p, td.updateRequests)
				responseManager.synchronize()
				close(wait)
				var lastRequest completedRequest
				testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
				require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "request should fail")
			})

			t.Run("when paused", func(t *testing.T) {
				td := newTestData(t)
				defer td.cancel()
				responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners)
				responseManager.Startup()
				td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
					hookActions.ValidateRequest()
				})
				blkIndex := 0
				blockCount := 3
				td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) {
					blkIndex++
					if blkIndex == blockCount {
						hookActions.PauseResponse()
					}
				})
				td.updateHooks.Register(func(p peer.ID, requestData graphsync.RequestData, updateData graphsync.RequestData, hookActions graphsync.RequestUpdatedHookActions) {
					if _, found := updateData.Extension(td.extensionName); found {
						hookActions.TerminateWithError(errors.New("something went wrong"))
					}
				})
				responseManager.ProcessRequests(td.ctx, td.p, td.requests)
				for i := 0; i < blockCount; i++ {
					testutil.AssertDoesReceive(td.ctx, t, td.sentResponses, "should sent block")
				}
				var pausedRequest pausedRequest
				testutil.AssertReceive(td.ctx, t, td.pausedRequests, &pausedRequest, "should pause request")
				// Checked only after the pause is confirmed, so any extra block
				// the manager emits before pausing has had time to arrive. This
				// replaces a length check on a slice that was never filled.
				testutil.AssertChannelEmpty(t, td.sentResponses, "should not send more blocks")

				// send update
				responseManager.ProcessRequests(td.ctx, td.p, td.updateRequests)

				// should terminate
				var lastRequest completedRequest
				testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request")
				require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "request should fail")

				// cannot unpause
				err := responseManager.UnpauseResponse(td.p, td.requestID)
				require.Error(t, err)
			})
		})

702
	})
703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725
	// A successful response must report a terminal success status both to the
	// fake peer sender and to every registered completed-response listener.
	t.Run("final response status listeners", func(t *testing.T) {
		td := newTestData(t)
		defer td.cancel()

		responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners)
		responseManager.Startup()

		// Capture the status handed to completed-response listeners. The send
		// is non-blocking, so a second invocation is dropped rather than
		// stalling the listener once the one-slot buffer is full.
		statuses := make(chan graphsync.ResponseStatusCode, 1)
		td.completedListeners.Register(func(p peer.ID, requestData graphsync.RequestData, status graphsync.ResponseStatusCode) {
			select {
			case statuses <- status:
			default:
			}
		})
		td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) {
			hookActions.ValidateRequest()
		})

		responseManager.ProcessRequests(td.ctx, td.p, td.requests)

		var completed completedRequest
		testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &completed, "should complete request")
		require.True(t, gsmsg.IsTerminalSuccessCode(completed.result), "request should succeed")

		var finalStatus graphsync.ResponseStatusCode
		testutil.AssertReceive(td.ctx, t, statuses, &finalStatus, "should receive status")
		require.True(t, gsmsg.IsTerminalSuccessCode(finalStatus), "request should succeed")
	})
726 727 728 729 730 731 732 733 734 735 736 737 738
}

type testData struct {
	ctx                   context.Context
	cancel                func()
	blockStore            map[ipld.Link][]byte
	loader                ipld.Loader
	storer                ipld.Storer
	blockChainLength      int
	blockChain            *testutil.TestBlockChain
	completedRequestChan  chan completedRequest
	sentResponses         chan sentResponse
	sentExtensions        chan sentExtension
Hannah Howard's avatar
Hannah Howard committed
739
	pausedRequests        chan pausedRequest
740 741 742 743 744 745 746
	peerManager           *fakePeerManager
	queryQueue            *fakeQueryQueue
	extensionData         []byte
	extensionName         graphsync.ExtensionName
	extension             graphsync.ExtensionData
	extensionResponseData []byte
	extensionResponse     graphsync.ExtensionData
Hannah Howard's avatar
Hannah Howard committed
747 748
	extensionUpdateData   []byte
	extensionUpdate       graphsync.ExtensionData
749 750
	requestID             graphsync.RequestID
	requests              []gsmsg.GraphSyncRequest
Hannah Howard's avatar
Hannah Howard committed
751
	updateRequests        []gsmsg.GraphSyncRequest
752 753
	p                     peer.ID
	peristenceOptions     *persistenceoptions.PersistenceOptions
Hannah Howard's avatar
Hannah Howard committed
754 755 756
	requestHooks          *hooks.IncomingRequestHooks
	blockHooks            *hooks.OutgoingBlockHooks
	updateHooks           *hooks.RequestUpdatedHooks
757
	completedListeners    *hooks.CompletedResponseListeners
758 759 760 761 762 763 764 765 766 767 768 769 770 771 772
}

func newTestData(t *testing.T) testData {
	ctx := context.Background()
	td := testData{}
	td.ctx, td.cancel = context.WithTimeout(ctx, 10*time.Second)

	td.blockStore = make(map[ipld.Link][]byte)
	td.loader, td.storer = testutil.NewTestStore(td.blockStore)
	td.blockChainLength = 5
	td.blockChain = testutil.SetupBlockChain(ctx, t, td.loader, td.storer, 100, td.blockChainLength)

	td.completedRequestChan = make(chan completedRequest, 1)
	td.sentResponses = make(chan sentResponse, td.blockChainLength*2)
	td.sentExtensions = make(chan sentExtension, td.blockChainLength*2)
Hannah Howard's avatar
Hannah Howard committed
773 774
	td.pausedRequests = make(chan pausedRequest, 1)
	fprs := &fakePeerResponseSender{lastCompletedRequest: td.completedRequestChan, sentResponses: td.sentResponses, sentExtensions: td.sentExtensions, pausedRequests: td.pausedRequests}
775 776 777 778 779 780 781 782 783 784 785 786 787 788
	td.peerManager = &fakePeerManager{peerResponseSender: fprs}
	td.queryQueue = &fakeQueryQueue{}

	td.extensionData = testutil.RandomBytes(100)
	td.extensionName = graphsync.ExtensionName("AppleSauce/McGee")
	td.extension = graphsync.ExtensionData{
		Name: td.extensionName,
		Data: td.extensionData,
	}
	td.extensionResponseData = testutil.RandomBytes(100)
	td.extensionResponse = graphsync.ExtensionData{
		Name: td.extensionName,
		Data: td.extensionResponseData,
	}
Hannah Howard's avatar
Hannah Howard committed
789 790 791 792 793
	td.extensionUpdateData = testutil.RandomBytes(100)
	td.extensionUpdate = graphsync.ExtensionData{
		Name: td.extensionName,
		Data: td.extensionUpdateData,
	}
794 795 796 797
	td.requestID = graphsync.RequestID(rand.Int31())
	td.requests = []gsmsg.GraphSyncRequest{
		gsmsg.NewRequest(td.requestID, td.blockChain.TipLink.(cidlink.Link).Cid, td.blockChain.Selector(), graphsync.Priority(0), td.extension),
	}
Hannah Howard's avatar
Hannah Howard committed
798 799 800
	td.updateRequests = []gsmsg.GraphSyncRequest{
		gsmsg.UpdateRequest(td.requestID, td.extensionUpdate),
	}
801 802
	td.p = testutil.GeneratePeers(1)[0]
	td.peristenceOptions = persistenceoptions.New()
Hannah Howard's avatar
Hannah Howard committed
803 804 805
	td.requestHooks = hooks.NewRequestHooks(td.peristenceOptions)
	td.blockHooks = hooks.NewBlockHooks()
	td.updateHooks = hooks.NewUpdateHooks()
806
	td.completedListeners = hooks.NewCompletedResponseListeners()
807
	return td
808
}