responsemanager_test.go 14.6 KB
Newer Older
1 2 3 4
package responsemanager

import (
	"context"
5
	"errors"
6 7 8 9 10 11 12
	"math"
	"math/rand"
	"reflect"
	"sync"
	"testing"
	"time"

13
	"github.com/ipfs/go-graphsync"
14 15 16
	gsmsg "github.com/ipfs/go-graphsync/message"
	"github.com/ipfs/go-graphsync/responsemanager/peerresponsemanager"
	"github.com/ipfs/go-graphsync/testutil"
17
	"github.com/ipfs/go-peertaskqueue/peertask"
18 19
	ipld "github.com/ipld/go-ipld-prime"
	cidlink "github.com/ipld/go-ipld-prime/linking/cid"
20
	"github.com/libp2p/go-libp2p-core/peer"
21 22 23 24 25
)

type fakeQueryQueue struct {
	popWait   sync.WaitGroup
	queriesLk sync.RWMutex
26
	queries   []*peertask.QueueTask
27 28
}

29
func (fqq *fakeQueryQueue) PushTasks(to peer.ID, tasks ...peertask.Task) {
30
	fqq.queriesLk.Lock()
31 32 33 34 35 36

	// This isn't quite right as the queue should deduplicate requests, but
	// it's good enough.
	for _, task := range tasks {
		fqq.queries = append(fqq.queries, peertask.NewQueueTask(task, to, time.Now()))
	}
37 38 39
	fqq.queriesLk.Unlock()
}

40
func (fqq *fakeQueryQueue) PopTasks(targetWork int) (peer.ID, []*peertask.Task, int) {
41 42 43 44
	fqq.popWait.Wait()
	fqq.queriesLk.Lock()
	defer fqq.queriesLk.Unlock()
	if len(fqq.queries) == 0 {
45
		return "", nil, -1
46
	}
47 48
	// We're not bothering to implement "work"
	task := fqq.queries[0]
49
	fqq.queries = fqq.queries[1:]
50
	return task.Target, []*peertask.Task{&task.Task}, 0
51 52
}

53
func (fqq *fakeQueryQueue) Remove(topic peertask.Topic, p peer.ID) {
54 55 56
	fqq.queriesLk.Lock()
	defer fqq.queriesLk.Unlock()
	for i, query := range fqq.queries {
57 58
		if query.Target == p && query.Topic == topic {
			fqq.queries = append(fqq.queries[:i], fqq.queries[i+1:]...)
59 60 61 62
		}
	}
}

63 64 65 66
// TasksDone is a no-op: this fake does not track active tasks.
func (fqq *fakeQueryQueue) TasksDone(to peer.ID, tasks ...*peertask.Task) {}

67 68 69
// ThawRound is a no-op in this fake.
func (fqq *fakeQueryQueue) ThawRound() {}
70 71 72 73 74 75 76 77 78 79 80 81

// fakePeerManager is a test double that hands out one fixed
// PeerResponseSender for every peer and remembers the most recent peer it was
// asked about.
type fakePeerManager struct {
	lastPeer           peer.ID                                // peer passed to the latest SenderForPeer call
	peerResponseSender peerresponsemanager.PeerResponseSender // sender returned for every peer
}

// SenderForPeer records p as the last requested peer and returns the single
// configured sender, regardless of which peer was asked for.
func (fpm *fakePeerManager) SenderForPeer(p peer.ID) peerresponsemanager.PeerResponseSender {
	sender := fpm.peerResponseSender
	fpm.lastPeer = p
	return sender
}

type sentResponse struct {
82
	requestID graphsync.RequestID
83 84 85 86
	link      ipld.Link
	data      []byte
}

87 88 89 90 91 92 93 94 95
// sentExtension pairs a piece of extension data with the request it was sent
// for. Field order matters: values are built with positional literals.
type sentExtension struct {
	requestID graphsync.RequestID
	extension graphsync.ExtensionData
}

// completedRequest records that a request finished and the status code it
// finished with. Field order matters: values are built with positional
// literals.
type completedRequest struct {
	requestID graphsync.RequestID
	result    graphsync.ResponseStatusCode
}
96 97
type fakePeerResponseSender struct {
	sentResponses        chan sentResponse
98 99
	sentExtensions       chan sentExtension
	lastCompletedRequest chan completedRequest
100 101 102 103 104 105
}

// Startup is a no-op in this fake.
func (fprs *fakePeerResponseSender) Startup() {}

// Shutdown is a no-op in this fake.
func (fprs *fakePeerResponseSender) Shutdown() {}

func (fprs *fakePeerResponseSender) SendResponse(
106
	requestID graphsync.RequestID,
107 108 109 110 111 112
	link ipld.Link,
	data []byte,
) {
	fprs.sentResponses <- sentResponse{requestID, link, data}
}

113 114 115 116 117 118 119
// SendExtensionData publishes the extension data for requestID on the
// sentExtensions channel. It blocks until the value is accepted by the channel.
func (fprs *fakePeerResponseSender) SendExtensionData(
	requestID graphsync.RequestID,
	extension graphsync.ExtensionData,
) {
	record := sentExtension{requestID: requestID, extension: extension}
	fprs.sentExtensions <- record
}

120
func (fprs *fakePeerResponseSender) FinishRequest(requestID graphsync.RequestID) {
121
	fprs.lastCompletedRequest <- completedRequest{requestID, graphsync.RequestCompletedFull}
122 123
}

124
func (fprs *fakePeerResponseSender) FinishWithError(requestID graphsync.RequestID, status graphsync.ResponseStatusCode) {
125
	fprs.lastCompletedRequest <- completedRequest{requestID, status}
126 127 128 129 130 131
}

func TestIncomingQuery(t *testing.T) {
	ctx := context.Background()
	ctx, cancel := context.WithTimeout(ctx, 40*time.Millisecond)
	defer cancel()
132 133

	blockStore := make(map[ipld.Link][]byte)
134
	loader, storer := testutil.NewTestStore(blockStore)
135 136 137
	blockChain := testutil.SetupBlockChain(ctx, t, loader, storer, 100, 5)
	blks := blockChain.AllBlocks()

138
	requestIDChan := make(chan completedRequest, 1)
139
	sentResponses := make(chan sentResponse, len(blks))
140 141
	sentExtensions := make(chan sentExtension, 1)
	fprs := &fakePeerResponseSender{lastCompletedRequest: requestIDChan, sentResponses: sentResponses, sentExtensions: sentExtensions}
142 143
	peerManager := &fakePeerManager{peerResponseSender: fprs}
	queryQueue := &fakeQueryQueue{}
144
	responseManager := New(ctx, loader, peerManager, queryQueue)
145 146
	responseManager.Startup()

147
	requestID := graphsync.RequestID(rand.Int31())
148
	requests := []gsmsg.GraphSyncRequest{
149
		gsmsg.NewRequest(requestID, blockChain.TipLink.(cidlink.Link).Cid, blockChain.Selector(), graphsync.Priority(math.MaxInt32)),
150 151
	}
	p := testutil.GeneratePeers(1)[0]
152
	responseManager.ProcessRequests(ctx, p, requests)
153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181
	select {
	case <-ctx.Done():
		t.Fatal("Should have completed request but didn't")
	case <-requestIDChan:
	}
	for i := 0; i < len(blks); i++ {
		select {
		case sentResponse := <-sentResponses:
			k := sentResponse.link.(cidlink.Link)
			blockIndex := testutil.IndexOf(blks, k.Cid)
			if blockIndex == -1 {
				t.Fatal("sent incorrect link")
			}
			if !reflect.DeepEqual(sentResponse.data, blks[blockIndex].RawData()) {
				t.Fatal("sent incorrect data")
			}
			if sentResponse.requestID != requestID {
				t.Fatal("incorrect response id")
			}
		case <-ctx.Done():
			t.Fatal("did not send enough responses")
		}
	}
}

func TestCancellationQueryInProgress(t *testing.T) {
	ctx := context.Background()
	ctx, cancel := context.WithTimeout(ctx, 40*time.Millisecond)
	defer cancel()
182 183

	blockStore := make(map[ipld.Link][]byte)
184
	loader, storer := testutil.NewTestStore(blockStore)
185 186 187
	blockChain := testutil.SetupBlockChain(ctx, t, loader, storer, 100, 5)
	blks := blockChain.AllBlocks()

188
	requestIDChan := make(chan completedRequest)
189
	sentResponses := make(chan sentResponse)
190 191
	sentExtensions := make(chan sentExtension, 1)
	fprs := &fakePeerResponseSender{lastCompletedRequest: requestIDChan, sentResponses: sentResponses, sentExtensions: sentExtensions}
192 193
	peerManager := &fakePeerManager{peerResponseSender: fprs}
	queryQueue := &fakeQueryQueue{}
194
	responseManager := New(ctx, loader, peerManager, queryQueue)
195 196
	responseManager.Startup()

197
	requestID := graphsync.RequestID(rand.Int31())
198
	requests := []gsmsg.GraphSyncRequest{
199
		gsmsg.NewRequest(requestID, blockChain.TipLink.(cidlink.Link).Cid, blockChain.Selector(), graphsync.Priority(math.MaxInt32)),
200 201
	}
	p := testutil.GeneratePeers(1)[0]
202
	responseManager.ProcessRequests(ctx, p, requests)
203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225

	// read one block
	select {
	case sentResponse := <-sentResponses:
		k := sentResponse.link.(cidlink.Link)
		blockIndex := testutil.IndexOf(blks, k.Cid)
		if blockIndex == -1 {
			t.Fatal("sent incorrect link")
		}
		if !reflect.DeepEqual(sentResponse.data, blks[blockIndex].RawData()) {
			t.Fatal("sent incorrect data")
		}
		if sentResponse.requestID != requestID {
			t.Fatal("incorrect response id")
		}
	case <-ctx.Done():
		t.Fatal("did not send responses")
	}

	// send a cancellation
	requests = []gsmsg.GraphSyncRequest{
		gsmsg.CancelRequest(requestID),
	}
226
	responseManager.ProcessRequests(ctx, p, requests)
227 228 229

	responseManager.synchronize()

230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255
	// at this point we should receive at most one more block, then traversal
	// should complete
	additionalMessageCount := 0
drainqueue:
	for {
		select {
		case <-ctx.Done():
			t.Fatal("Should have completed request but didn't")
		case sentResponse := <-sentResponses:
			if additionalMessageCount > 0 {
				t.Fatal("should not send any more responses")
			}
			k := sentResponse.link.(cidlink.Link)
			blockIndex := testutil.IndexOf(blks, k.Cid)
			if blockIndex == -1 {
				t.Fatal("sent incorrect link")
			}
			if !reflect.DeepEqual(sentResponse.data, blks[blockIndex].RawData()) {
				t.Fatal("sent incorrect data")
			}
			if sentResponse.requestID != requestID {
				t.Fatal("incorrect response id")
			}
			additionalMessageCount++
		case <-requestIDChan:
			break drainqueue
256 257 258 259 260 261 262 263
		}
	}
}

func TestEarlyCancellation(t *testing.T) {
	ctx := context.Background()
	ctx, cancel := context.WithTimeout(ctx, 40*time.Millisecond)
	defer cancel()
264 265

	blockStore := make(map[ipld.Link][]byte)
266
	loader, storer := testutil.NewTestStore(blockStore)
267 268
	blockChain := testutil.SetupBlockChain(ctx, t, loader, storer, 100, 5)

269
	requestIDChan := make(chan completedRequest)
270
	sentResponses := make(chan sentResponse)
271 272
	sentExtensions := make(chan sentExtension, 1)
	fprs := &fakePeerResponseSender{lastCompletedRequest: requestIDChan, sentResponses: sentResponses, sentExtensions: sentExtensions}
273 274 275
	peerManager := &fakePeerManager{peerResponseSender: fprs}
	queryQueue := &fakeQueryQueue{}
	queryQueue.popWait.Add(1)
276
	responseManager := New(ctx, loader, peerManager, queryQueue)
277 278
	responseManager.Startup()

279
	requestID := graphsync.RequestID(rand.Int31())
280
	requests := []gsmsg.GraphSyncRequest{
281
		gsmsg.NewRequest(requestID, blockChain.TipLink.(cidlink.Link).Cid, blockChain.Selector(), graphsync.Priority(math.MaxInt32)),
282 283
	}
	p := testutil.GeneratePeers(1)[0]
284
	responseManager.ProcessRequests(ctx, p, requests)
285 286 287 288 289

	// send a cancellation
	requests = []gsmsg.GraphSyncRequest{
		gsmsg.CancelRequest(requestID),
	}
290
	responseManager.ProcessRequests(ctx, p, requests)
291 292 293 294 295 296 297 298 299 300 301 302 303 304 305

	responseManager.synchronize()

	// unblock popping from queue
	queryQueue.popWait.Done()

	// verify no responses processed
	select {
	case <-ctx.Done():
	case <-sentResponses:
		t.Fatal("should not send any more responses")
	case <-requestIDChan:
		t.Fatal("should not send have completed response")
	}
}
306 307 308 309 310

func TestValidationAndExtensions(t *testing.T) {
	ctx := context.Background()
	ctx, cancel := context.WithTimeout(ctx, 40*time.Millisecond)
	defer cancel()
311 312

	blockStore := make(map[ipld.Link][]byte)
313
	loader, storer := testutil.NewTestStore(blockStore)
314 315
	blockChain := testutil.SetupBlockChain(ctx, t, loader, storer, 100, 5)

316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335
	completedRequestChan := make(chan completedRequest, 1)
	sentResponses := make(chan sentResponse, 100)
	sentExtensions := make(chan sentExtension, 1)
	fprs := &fakePeerResponseSender{lastCompletedRequest: completedRequestChan, sentResponses: sentResponses, sentExtensions: sentExtensions}
	peerManager := &fakePeerManager{peerResponseSender: fprs}
	queryQueue := &fakeQueryQueue{}

	extensionData := testutil.RandomBytes(100)
	extensionName := graphsync.ExtensionName("AppleSauce/McGee")
	extension := graphsync.ExtensionData{
		Name: extensionName,
		Data: extensionData,
	}
	extensionResponseData := testutil.RandomBytes(100)
	extensionResponse := graphsync.ExtensionData{
		Name: extensionName,
		Data: extensionResponseData,
	}

	t.Run("with invalid selector", func(t *testing.T) {
336
		selectorSpec := testutil.NewInvalidSelectorSpec()
337 338
		requestID := graphsync.RequestID(rand.Int31())
		requests := []gsmsg.GraphSyncRequest{
339
			gsmsg.NewRequest(requestID, blockChain.TipLink.(cidlink.Link).Cid, selectorSpec, graphsync.Priority(math.MaxInt32), extension),
340 341 342 343
		}
		p := testutil.GeneratePeers(1)[0]

		t.Run("on its own, should fail validation", func(t *testing.T) {
344
			responseManager := New(ctx, loader, peerManager, queryQueue)
345 346 347 348 349 350 351 352 353 354 355 356 357
			responseManager.Startup()
			responseManager.ProcessRequests(ctx, p, requests)
			select {
			case <-ctx.Done():
				t.Fatal("Should have completed request but didn't")
			case lastRequest := <-completedRequestChan:
				if !gsmsg.IsTerminalFailureCode(lastRequest.result) {
					t.Fatal("Request should have failed but didn't")
				}
			}
		})

		t.Run("if non validating hook succeeds, does not pass validation", func(t *testing.T) {
358
			responseManager := New(ctx, loader, peerManager, queryQueue)
359
			responseManager.Startup()
360 361
			responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
				hookActions.SendExtensionData(extensionResponse)
362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382
			})
			responseManager.ProcessRequests(ctx, p, requests)
			select {
			case <-ctx.Done():
				t.Fatal("Should have completed request but didn't")
			case lastRequest := <-completedRequestChan:
				if !gsmsg.IsTerminalFailureCode(lastRequest.result) {
					t.Fatal("Request should have succeeded but didn't")
				}
			}
			select {
			case <-ctx.Done():
				t.Fatal("Should have sent extension response but didn't")
			case receivedExtension := <-sentExtensions:
				if !reflect.DeepEqual(receivedExtension.extension, extensionResponse) {
					t.Fatal("Proper Extension response should have been sent but wasn't")
				}
			}
		})

		t.Run("if validating hook succeeds, should pass validation", func(t *testing.T) {
383
			responseManager := New(ctx, loader, peerManager, queryQueue)
384
			responseManager.Startup()
385 386 387
			responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
				hookActions.ValidateRequest()
				hookActions.SendExtensionData(extensionResponse)
388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411
			})
			responseManager.ProcessRequests(ctx, p, requests)
			select {
			case <-ctx.Done():
				t.Fatal("Should have completed request but didn't")
			case lastRequest := <-completedRequestChan:
				if !gsmsg.IsTerminalSuccessCode(lastRequest.result) {
					t.Fatal("Request should have succeeded but didn't")
				}
			}
			select {
			case <-ctx.Done():
				t.Fatal("Should have sent extension response but didn't")
			case receivedExtension := <-sentExtensions:
				if !reflect.DeepEqual(receivedExtension.extension, extensionResponse) {
					t.Fatal("Proper Extension response should have been sent but wasn't")
				}
			}
		})
	})

	t.Run("with valid selector", func(t *testing.T) {
		requestID := graphsync.RequestID(rand.Int31())
		requests := []gsmsg.GraphSyncRequest{
412
			gsmsg.NewRequest(requestID, blockChain.TipLink.(cidlink.Link).Cid, blockChain.Selector(), graphsync.Priority(math.MaxInt32), extension),
413 414 415 416
		}
		p := testutil.GeneratePeers(1)[0]

		t.Run("on its own, should pass validation", func(t *testing.T) {
417
			responseManager := New(ctx, loader, peerManager, queryQueue)
418 419 420 421 422 423 424 425 426 427 428 429 430
			responseManager.Startup()
			responseManager.ProcessRequests(ctx, p, requests)
			select {
			case <-ctx.Done():
				t.Fatal("Should have completed request but didn't")
			case lastRequest := <-completedRequestChan:
				if !gsmsg.IsTerminalSuccessCode(lastRequest.result) {
					t.Fatal("Request should have failed but didn't")
				}
			}
		})

		t.Run("if any hook fails, should fail", func(t *testing.T) {
431
			responseManager := New(ctx, loader, peerManager, queryQueue)
432
			responseManager.Startup()
433 434 435
			responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
				hookActions.SendExtensionData(extensionResponse)
				hookActions.TerminateWithError(errors.New("everything went to crap"))
436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456
			})
			responseManager.ProcessRequests(ctx, p, requests)
			select {
			case <-ctx.Done():
				t.Fatal("Should have completed request but didn't")
			case lastRequest := <-completedRequestChan:
				if !gsmsg.IsTerminalFailureCode(lastRequest.result) {
					t.Fatal("Request should have succeeded but didn't")
				}
			}
			select {
			case <-ctx.Done():
				t.Fatal("Should have sent extension response but didn't")
			case receivedExtension := <-sentExtensions:
				if !reflect.DeepEqual(receivedExtension.extension, extensionResponse) {
					t.Fatal("Proper Extension response should have been sent but wasn't")
				}
			}
		})
	})
}