package responsemanager

import (
	"context"
	"errors"
	"math"
	"math/rand"
	"sync"
	"testing"
	"time"

	"github.com/ipfs/go-graphsync"
	gsmsg "github.com/ipfs/go-graphsync/message"
	"github.com/ipfs/go-graphsync/responsemanager/peerresponsemanager"
	"github.com/ipfs/go-graphsync/selectorvalidator"
	"github.com/ipfs/go-graphsync/testutil"
	"github.com/ipfs/go-peertaskqueue/peertask"
	ipld "github.com/ipld/go-ipld-prime"
	ipldfree "github.com/ipld/go-ipld-prime/impl/free"
	cidlink "github.com/ipld/go-ipld-prime/linking/cid"
	"github.com/libp2p/go-libp2p-core/peer"
	"github.com/stretchr/testify/require"
)

type fakeQueryQueue struct {
	popWait   sync.WaitGroup
	queriesLk sync.RWMutex
28
	queries   []*peertask.QueueTask
29 30
}

31
func (fqq *fakeQueryQueue) PushTasks(to peer.ID, tasks ...peertask.Task) {
32
	fqq.queriesLk.Lock()
33 34 35 36 37 38

	// This isn't quite right as the queue should deduplicate requests, but
	// it's good enough.
	for _, task := range tasks {
		fqq.queries = append(fqq.queries, peertask.NewQueueTask(task, to, time.Now()))
	}
39 40 41
	fqq.queriesLk.Unlock()
}

42
func (fqq *fakeQueryQueue) PopTasks(targetWork int) (peer.ID, []*peertask.Task, int) {
43 44 45 46
	fqq.popWait.Wait()
	fqq.queriesLk.Lock()
	defer fqq.queriesLk.Unlock()
	if len(fqq.queries) == 0 {
47
		return "", nil, -1
48
	}
49 50
	// We're not bothering to implement "work"
	task := fqq.queries[0]
51
	fqq.queries = fqq.queries[1:]
52
	return task.Target, []*peertask.Task{&task.Task}, 0
53 54
}

55
func (fqq *fakeQueryQueue) Remove(topic peertask.Topic, p peer.ID) {
56 57 58
	fqq.queriesLk.Lock()
	defer fqq.queriesLk.Unlock()
	for i, query := range fqq.queries {
59 60
		if query.Target == p && query.Topic == topic {
			fqq.queries = append(fqq.queries[:i], fqq.queries[i+1:]...)
61 62 63 64
		}
	}
}

65 66 67 68
// TasksDone accepts completion notifications for tasks and discards them.
func (fqq *fakeQueryQueue) TasksDone(to peer.ID, tasks ...*peertask.Task) {
	// We don't track active tasks so this is a no-op
}

69 70 71
// ThawRound is a no-op; it is present only so the fake offers the same
// method set as the other queue operations defined here.
func (fqq *fakeQueryQueue) ThawRound() {

}
72 73 74 75 76 77 78 79 80 81 82 83

// fakePeerManager is a test double that hands out one fixed response sender
// regardless of which peer is asked for.
type fakePeerManager struct {
	// lastPeer is the peer most recently passed to SenderForPeer.
	lastPeer           peer.ID
	// peerResponseSender is the sender returned by every SenderForPeer call.
	peerResponseSender peerresponsemanager.PeerResponseSender
}

// SenderForPeer records p as the last requested peer and returns the single
// preconfigured response sender.
func (fpm *fakePeerManager) SenderForPeer(p peer.ID) peerresponsemanager.PeerResponseSender {
	fpm.lastPeer = p
	return fpm.peerResponseSender
}

type sentResponse struct {
84
	requestID graphsync.RequestID
85 86 87 88
	link      ipld.Link
	data      []byte
}

89 90 91 92 93 94 95 96 97
// sentExtension captures the arguments of one SendExtensionData call on the
// fake sender.
type sentExtension struct {
	requestID graphsync.RequestID
	extension graphsync.ExtensionData
}

// completedRequest records that a request finished and with which status.
// FinishRequest reports graphsync.RequestCompletedFull; FinishWithError
// reports the status it was given.
type completedRequest struct {
	requestID graphsync.RequestID
	result    graphsync.ResponseStatusCode
}
98 99
type fakePeerResponseSender struct {
	sentResponses        chan sentResponse
100 101
	sentExtensions       chan sentExtension
	lastCompletedRequest chan completedRequest
102 103 104 105 106 107
}

// Startup and Shutdown are no-ops: the fake sender does nothing on either call.
func (fprs *fakePeerResponseSender) Startup()  {}
func (fprs *fakePeerResponseSender) Shutdown() {}

func (fprs *fakePeerResponseSender) SendResponse(
108
	requestID graphsync.RequestID,
109 110 111 112 113 114
	link ipld.Link,
	data []byte,
) {
	fprs.sentResponses <- sentResponse{requestID, link, data}
}

115 116 117 118 119 120 121
// SendExtensionData publishes the request ID and extension on sentExtensions.
// It blocks until the test receives the value or the channel has buffer space.
func (fprs *fakePeerResponseSender) SendExtensionData(
	requestID graphsync.RequestID,
	extension graphsync.ExtensionData,
) {
	fprs.sentExtensions <- sentExtension{requestID, extension}
}

122
func (fprs *fakePeerResponseSender) FinishRequest(requestID graphsync.RequestID) {
123
	fprs.lastCompletedRequest <- completedRequest{requestID, graphsync.RequestCompletedFull}
124 125
}

126
func (fprs *fakePeerResponseSender) FinishWithError(requestID graphsync.RequestID, status graphsync.ResponseStatusCode) {
127
	fprs.lastCompletedRequest <- completedRequest{requestID, status}
128 129 130 131
}

func TestIncomingQuery(t *testing.T) {
	ctx := context.Background()
Hannah Howard's avatar
Hannah Howard committed
132
	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
133
	defer cancel()
134 135

	blockStore := make(map[ipld.Link][]byte)
136
	loader, storer := testutil.NewTestStore(blockStore)
137 138 139
	blockChain := testutil.SetupBlockChain(ctx, t, loader, storer, 100, 5)
	blks := blockChain.AllBlocks()

140
	requestIDChan := make(chan completedRequest, 1)
141
	sentResponses := make(chan sentResponse, len(blks))
142 143
	sentExtensions := make(chan sentExtension, 1)
	fprs := &fakePeerResponseSender{lastCompletedRequest: requestIDChan, sentResponses: sentResponses, sentExtensions: sentExtensions}
144 145
	peerManager := &fakePeerManager{peerResponseSender: fprs}
	queryQueue := &fakeQueryQueue{}
146
	responseManager := New(ctx, loader, peerManager, queryQueue)
147
	responseManager.RegisterHook(selectorvalidator.SelectorValidator(100))
148 149
	responseManager.Startup()

150
	requestID := graphsync.RequestID(rand.Int31())
151
	requests := []gsmsg.GraphSyncRequest{
152
		gsmsg.NewRequest(requestID, blockChain.TipLink.(cidlink.Link).Cid, blockChain.Selector(), graphsync.Priority(math.MaxInt32)),
153 154
	}
	p := testutil.GeneratePeers(1)[0]
155
	responseManager.ProcessRequests(ctx, p, requests)
Hannah Howard's avatar
Hannah Howard committed
156
	testutil.AssertDoesReceive(ctx, t, requestIDChan, "Should have completed request but didn't")
157
	for i := 0; i < len(blks); i++ {
Hannah Howard's avatar
Hannah Howard committed
158 159 160 161 162 163 164
		var sentResponse sentResponse
		testutil.AssertReceive(ctx, t, sentResponses, &sentResponse, "did not send responses")
		k := sentResponse.link.(cidlink.Link)
		blockIndex := testutil.IndexOf(blks, k.Cid)
		require.NotEqual(t, blockIndex, -1, "sent incorrect link")
		require.Equal(t, blks[blockIndex].RawData(), sentResponse.data, "sent incorrect data")
		require.Equal(t, requestID, sentResponse.requestID, "has incorrect response id")
165 166 167 168 169
	}
}

func TestCancellationQueryInProgress(t *testing.T) {
	ctx := context.Background()
Hannah Howard's avatar
Hannah Howard committed
170
	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
171
	defer cancel()
172 173

	blockStore := make(map[ipld.Link][]byte)
174
	loader, storer := testutil.NewTestStore(blockStore)
175 176 177
	blockChain := testutil.SetupBlockChain(ctx, t, loader, storer, 100, 5)
	blks := blockChain.AllBlocks()

178
	requestIDChan := make(chan completedRequest)
179
	sentResponses := make(chan sentResponse)
180 181
	sentExtensions := make(chan sentExtension, 1)
	fprs := &fakePeerResponseSender{lastCompletedRequest: requestIDChan, sentResponses: sentResponses, sentExtensions: sentExtensions}
182 183
	peerManager := &fakePeerManager{peerResponseSender: fprs}
	queryQueue := &fakeQueryQueue{}
184
	responseManager := New(ctx, loader, peerManager, queryQueue)
185
	responseManager.RegisterHook(selectorvalidator.SelectorValidator(100))
186 187
	responseManager.Startup()

188
	requestID := graphsync.RequestID(rand.Int31())
189
	requests := []gsmsg.GraphSyncRequest{
190
		gsmsg.NewRequest(requestID, blockChain.TipLink.(cidlink.Link).Cid, blockChain.Selector(), graphsync.Priority(math.MaxInt32)),
191 192
	}
	p := testutil.GeneratePeers(1)[0]
193
	responseManager.ProcessRequests(ctx, p, requests)
194 195

	// read one block
Hannah Howard's avatar
Hannah Howard committed
196 197 198 199 200 201 202
	var sentResponse sentResponse
	testutil.AssertReceive(ctx, t, sentResponses, &sentResponse, "did not send response")
	k := sentResponse.link.(cidlink.Link)
	blockIndex := testutil.IndexOf(blks, k.Cid)
	require.NotEqual(t, blockIndex, -1, "sent incorrect link")
	require.Equal(t, blks[blockIndex].RawData(), sentResponse.data, "sent incorrect data")
	require.Equal(t, requestID, sentResponse.requestID, "has incorrect response id")
203 204 205 206 207

	// send a cancellation
	requests = []gsmsg.GraphSyncRequest{
		gsmsg.CancelRequest(requestID),
	}
208
	responseManager.ProcessRequests(ctx, p, requests)
209 210 211

	responseManager.synchronize()

212 213
	// at this point we should receive at most one more block, then traversal
	// should complete
Hannah Howard's avatar
Hannah Howard committed
214 215 216 217 218 219 220 221 222
	testutil.AssertReceiveFirst(t, sentResponses, &sentResponse, "should send one additional response", ctx.Done(), requestIDChan)
	k = sentResponse.link.(cidlink.Link)
	blockIndex = testutil.IndexOf(blks, k.Cid)
	require.NotEqual(t, blockIndex, -1, "did not send correct link")
	require.Equal(t, blks[blockIndex].RawData(), sentResponse.data, "sent incorrect data")
	require.Equal(t, requestID, sentResponse.requestID, "incorrect response id")

	// We should now be done
	testutil.AssertDoesReceiveFirst(t, requestIDChan, "should complete request", ctx.Done(), sentResponses)
223 224 225 226
}

func TestEarlyCancellation(t *testing.T) {
	ctx := context.Background()
Hannah Howard's avatar
Hannah Howard committed
227
	ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
228
	defer cancel()
229 230

	blockStore := make(map[ipld.Link][]byte)
231
	loader, storer := testutil.NewTestStore(blockStore)
232 233
	blockChain := testutil.SetupBlockChain(ctx, t, loader, storer, 100, 5)

234
	requestIDChan := make(chan completedRequest)
235
	sentResponses := make(chan sentResponse)
236 237
	sentExtensions := make(chan sentExtension, 1)
	fprs := &fakePeerResponseSender{lastCompletedRequest: requestIDChan, sentResponses: sentResponses, sentExtensions: sentExtensions}
238 239 240
	peerManager := &fakePeerManager{peerResponseSender: fprs}
	queryQueue := &fakeQueryQueue{}
	queryQueue.popWait.Add(1)
241
	responseManager := New(ctx, loader, peerManager, queryQueue)
242 243
	responseManager.Startup()

244
	requestID := graphsync.RequestID(rand.Int31())
245
	requests := []gsmsg.GraphSyncRequest{
246
		gsmsg.NewRequest(requestID, blockChain.TipLink.(cidlink.Link).Cid, blockChain.Selector(), graphsync.Priority(math.MaxInt32)),
247 248
	}
	p := testutil.GeneratePeers(1)[0]
249
	responseManager.ProcessRequests(ctx, p, requests)
250 251 252 253 254

	// send a cancellation
	requests = []gsmsg.GraphSyncRequest{
		gsmsg.CancelRequest(requestID),
	}
255
	responseManager.ProcessRequests(ctx, p, requests)
256 257 258 259 260 261 262

	responseManager.synchronize()

	// unblock popping from queue
	queryQueue.popWait.Done()

	// verify no responses processed
Hannah Howard's avatar
Hannah Howard committed
263
	testutil.AssertDoesReceiveFirst(t, ctx.Done(), "should not process more responses", sentResponses, requestIDChan)
264
}
265 266 267

func TestValidationAndExtensions(t *testing.T) {
	ctx := context.Background()
Hannah Howard's avatar
Hannah Howard committed
268
	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
269
	defer cancel()
270 271

	blockStore := make(map[ipld.Link][]byte)
272
	loader, storer := testutil.NewTestStore(blockStore)
273 274
	blockChain := testutil.SetupBlockChain(ctx, t, loader, storer, 100, 5)

275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293
	completedRequestChan := make(chan completedRequest, 1)
	sentResponses := make(chan sentResponse, 100)
	sentExtensions := make(chan sentExtension, 1)
	fprs := &fakePeerResponseSender{lastCompletedRequest: completedRequestChan, sentResponses: sentResponses, sentExtensions: sentExtensions}
	peerManager := &fakePeerManager{peerResponseSender: fprs}
	queryQueue := &fakeQueryQueue{}

	extensionData := testutil.RandomBytes(100)
	extensionName := graphsync.ExtensionName("AppleSauce/McGee")
	extension := graphsync.ExtensionData{
		Name: extensionName,
		Data: extensionData,
	}
	extensionResponseData := testutil.RandomBytes(100)
	extensionResponse := graphsync.ExtensionData{
		Name: extensionName,
		Data: extensionResponseData,
	}

294 295 296 297 298
	requestID := graphsync.RequestID(rand.Int31())
	requests := []gsmsg.GraphSyncRequest{
		gsmsg.NewRequest(requestID, blockChain.TipLink.(cidlink.Link).Cid, blockChain.Selector(), graphsync.Priority(math.MaxInt32), extension),
	}
	p := testutil.GeneratePeers(1)[0]
299

300 301 302 303 304 305 306 307 308 309 310 311 312 313
	t.Run("on its own, should fail validation", func(t *testing.T) {
		responseManager := New(ctx, loader, peerManager, queryQueue)
		responseManager.Startup()
		responseManager.ProcessRequests(ctx, p, requests)
		var lastRequest completedRequest
		testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
		require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")
	})

	t.Run("if non validating hook succeeds, does not pass validation", func(t *testing.T) {
		responseManager := New(ctx, loader, peerManager, queryQueue)
		responseManager.Startup()
		responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
			hookActions.SendExtensionData(extensionResponse)
314
		})
315 316 317 318 319 320 321 322
		responseManager.ProcessRequests(ctx, p, requests)
		var lastRequest completedRequest
		testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
		require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")
		var receivedExtension sentExtension
		testutil.AssertReceive(ctx, t, sentExtensions, &receivedExtension, "should send extension response")
		require.Equal(t, extensionResponse, receivedExtension.extension, "incorrect extension response sent")
	})
323

324 325 326 327 328 329
	t.Run("if validating hook succeeds, should pass validation", func(t *testing.T) {
		responseManager := New(ctx, loader, peerManager, queryQueue)
		responseManager.Startup()
		responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
			hookActions.ValidateRequest()
			hookActions.SendExtensionData(extensionResponse)
330
		})
331 332 333 334 335 336 337
		responseManager.ProcessRequests(ctx, p, requests)
		var lastRequest completedRequest
		testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
		require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
		var receivedExtension sentExtension
		testutil.AssertReceive(ctx, t, sentExtensions, &receivedExtension, "should send extension response")
		require.Equal(t, extensionResponse, receivedExtension.extension, "incorrect extension response sent")
338 339
	})

340 341 342 343 344
	t.Run("if any hook fails, should fail", func(t *testing.T) {
		responseManager := New(ctx, loader, peerManager, queryQueue)
		responseManager.Startup()
		responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
			hookActions.ValidateRequest()
345
		})
346 347 348 349 350 351 352 353 354 355 356 357
		responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
			hookActions.SendExtensionData(extensionResponse)
			hookActions.TerminateWithError(errors.New("everything went to crap"))
		})
		responseManager.ProcessRequests(ctx, p, requests)
		var lastRequest completedRequest
		testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
		require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")
		var receivedExtension sentExtension
		testutil.AssertReceive(ctx, t, sentExtensions, &receivedExtension, "should send extension response")
		require.Equal(t, extensionResponse, receivedExtension.extension, "incorrect extension response sent")
	})
358

359 360 361 362 363 364
	t.Run("hooks can be unregistered", func(t *testing.T) {
		responseManager := New(ctx, loader, peerManager, queryQueue)
		responseManager.Startup()
		unregister := responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
			hookActions.ValidateRequest()
			hookActions.SendExtensionData(extensionResponse)
365
		})
366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382

		// hook validates request
		responseManager.ProcessRequests(ctx, p, requests)
		var lastRequest completedRequest
		testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
		require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
		var receivedExtension sentExtension
		testutil.AssertReceive(ctx, t, sentExtensions, &receivedExtension, "should send extension response")
		require.Equal(t, extensionResponse, receivedExtension.extension, "incorrect extension response sent")

		// unregister
		unregister()

		// no same request should fail
		responseManager.ProcessRequests(ctx, p, requests)
		testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
		require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")
383
	})
384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456

	t.Run("hooks can alter the loader", func(t *testing.T) {
		obs := make(map[ipld.Link][]byte)
		oloader, _ := testutil.NewTestStore(obs)
		responseManager := New(ctx, oloader, peerManager, queryQueue)
		responseManager.Startup()
		// add validating hook -- so the request SHOULD succeed
		responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
			hookActions.ValidateRequest()
		})

		// request fails with base loader reading from block store that's missing data
		var lastRequest completedRequest
		responseManager.ProcessRequests(ctx, p, requests)
		testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
		require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")

		// register hook to use different loader
		_ = responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
			if _, found := requestData.Extension(extensionName); found {
				hookActions.UseLoader(loader)
				hookActions.SendExtensionData(extensionResponse)
			}
		})

		// hook uses different loader that should make request succeed
		responseManager.ProcessRequests(ctx, p, requests)
		testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
		require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
		var receivedExtension sentExtension
		testutil.AssertReceive(ctx, t, sentExtensions, &receivedExtension, "should send extension response")
		require.Equal(t, extensionResponse, receivedExtension.extension, "incorrect extension response sent")
	})

	t.Run("hooks can alter the node builder chooser", func(t *testing.T) {
		responseManager := New(ctx, loader, peerManager, queryQueue)
		responseManager.Startup()

		customChooserCallCount := 0
		customChooser := func(ipld.Link, ipld.LinkContext) ipld.NodeBuilder {
			customChooserCallCount++
			return ipldfree.NodeBuilder()
		}

		// add validating hook -- so the request SHOULD succeed
		responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
			hookActions.ValidateRequest()
		})

		// with default chooser, customer chooser not called
		var lastRequest completedRequest
		responseManager.ProcessRequests(ctx, p, requests)
		testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
		require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
		require.Equal(t, 0, customChooserCallCount)

		// register hook to use custom chooser
		_ = responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
			if _, found := requestData.Extension(extensionName); found {
				hookActions.UseNodeBuilderChooser(customChooser)
				hookActions.SendExtensionData(extensionResponse)
			}
		})

		// verify now that request succeeds and uses custom chooser
		responseManager.ProcessRequests(ctx, p, requests)
		testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
		require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
		var receivedExtension sentExtension
		testutil.AssertReceive(ctx, t, sentExtensions, &receivedExtension, "should send extension response")
		require.Equal(t, extensionResponse, receivedExtension.extension, "incorrect extension response sent")
		require.Equal(t, 5, customChooserCallCount)
	})
457
}