responsemanager_test.go 14.1 KB
Newer Older
1 2 3 4
package responsemanager

import (
	"context"
5
	"errors"
6 7 8 9 10 11
	"math"
	"math/rand"
	"sync"
	"testing"
	"time"

12
	"github.com/ipfs/go-graphsync"
13 14 15
	gsmsg "github.com/ipfs/go-graphsync/message"
	"github.com/ipfs/go-graphsync/responsemanager/peerresponsemanager"
	"github.com/ipfs/go-graphsync/testutil"
16
	"github.com/ipfs/go-peertaskqueue/peertask"
17 18
	ipld "github.com/ipld/go-ipld-prime"
	cidlink "github.com/ipld/go-ipld-prime/linking/cid"
19
	"github.com/libp2p/go-libp2p-core/peer"
Hannah Howard's avatar
Hannah Howard committed
20
	"github.com/stretchr/testify/require"
21 22 23 24 25
)

type fakeQueryQueue struct {
	popWait   sync.WaitGroup
	queriesLk sync.RWMutex
26
	queries   []*peertask.QueueTask
27 28
}

29
// PushTasks appends one queue entry per task for peer to, each stamped with
// the current time.
//
// This isn't quite right as the real queue should deduplicate requests, but
// it's good enough for these tests.
func (fqq *fakeQueryQueue) PushTasks(to peer.ID, tasks ...peertask.Task) {
	fqq.queriesLk.Lock()
	defer fqq.queriesLk.Unlock()

	for _, task := range tasks {
		fqq.queries = append(fqq.queries, peertask.NewQueueTask(task, to, time.Now()))
	}
}

40
func (fqq *fakeQueryQueue) PopTasks(targetWork int) (peer.ID, []*peertask.Task, int) {
41 42 43 44
	fqq.popWait.Wait()
	fqq.queriesLk.Lock()
	defer fqq.queriesLk.Unlock()
	if len(fqq.queries) == 0 {
45
		return "", nil, -1
46
	}
47 48
	// We're not bothering to implement "work"
	task := fqq.queries[0]
49
	fqq.queries = fqq.queries[1:]
50
	return task.Target, []*peertask.Task{&task.Task}, 0
51 52
}

53
func (fqq *fakeQueryQueue) Remove(topic peertask.Topic, p peer.ID) {
54 55 56
	fqq.queriesLk.Lock()
	defer fqq.queriesLk.Unlock()
	for i, query := range fqq.queries {
57 58
		if query.Target == p && query.Topic == topic {
			fqq.queries = append(fqq.queries[:i], fqq.queries[i+1:]...)
59 60 61 62
		}
	}
}

63 64 65 66
// TasksDone is part of the queue API the response manager is given via New.
// The fake keeps no record of active tasks, so there is nothing to mark as
// finished and the call does nothing.
func (fqq *fakeQueryQueue) TasksDone(to peer.ID, tasks ...*peertask.Task) {}

67 68 69
// ThawRound is part of the queue API the response manager is given via New.
// The fake has no frozen state to thaw, so the call does nothing.
func (fqq *fakeQueryQueue) ThawRound() {}
70 71 72 73 74 75 76 77 78 79 80 81

// fakePeerManager hands out one fixed PeerResponseSender for every peer and
// remembers which peer asked most recently.
type fakePeerManager struct {
	// lastPeer is the peer passed to the latest SenderForPeer call.
	lastPeer peer.ID

	// peerResponseSender is returned for every peer.
	peerResponseSender peerresponsemanager.PeerResponseSender
}

// SenderForPeer records p as the last requested peer and returns the single
// preconfigured sender, whichever peer is asking.
func (fpm *fakePeerManager) SenderForPeer(p peer.ID) peerresponsemanager.PeerResponseSender {
	sender := fpm.peerResponseSender
	fpm.lastPeer = p
	return sender
}

type sentResponse struct {
82
	requestID graphsync.RequestID
83 84 85 86
	link      ipld.Link
	data      []byte
}

87 88 89 90 91 92 93 94 95
// sentExtension captures the arguments of one
// fakePeerResponseSender.SendExtensionData call.
type sentExtension struct {
	requestID graphsync.RequestID // request the extension was attached to
	extension graphsync.ExtensionData
}

// completedRequest records how a request was finished: FinishRequest reports
// graphsync.RequestCompletedFull, FinishWithError reports the given status.
type completedRequest struct {
	requestID graphsync.RequestID         // request that was finished
	result    graphsync.ResponseStatusCode // final status sent for it
}
96 97
type fakePeerResponseSender struct {
	sentResponses        chan sentResponse
98 99
	sentExtensions       chan sentExtension
	lastCompletedRequest chan completedRequest
100 101 102 103 104 105
}

// Startup does nothing; the fake has no background work to start.
func (fprs *fakePeerResponseSender) Startup() {
}

// Shutdown does nothing; the fake holds nothing that needs releasing.
func (fprs *fakePeerResponseSender) Shutdown() {
}

func (fprs *fakePeerResponseSender) SendResponse(
106
	requestID graphsync.RequestID,
107 108 109 110 111 112
	link ipld.Link,
	data []byte,
) {
	fprs.sentResponses <- sentResponse{requestID, link, data}
}

113 114 115 116 117 118 119
// SendExtensionData publishes the extension sent for requestID on
// sentExtensions, blocking until the channel accepts it.
func (fprs *fakePeerResponseSender) SendExtensionData(
	requestID graphsync.RequestID,
	extension graphsync.ExtensionData,
) {
	record := sentExtension{requestID: requestID, extension: extension}
	fprs.sentExtensions <- record
}

120
func (fprs *fakePeerResponseSender) FinishRequest(requestID graphsync.RequestID) {
121
	fprs.lastCompletedRequest <- completedRequest{requestID, graphsync.RequestCompletedFull}
122 123
}

124
func (fprs *fakePeerResponseSender) FinishWithError(requestID graphsync.RequestID, status graphsync.ResponseStatusCode) {
125
	fprs.lastCompletedRequest <- completedRequest{requestID, status}
126 127 128 129
}

// TestIncomingQuery sends a single request for the whole test block chain and
// checks that the request completes and that every block is sent back with
// the right link, data and request ID.
func TestIncomingQuery(t *testing.T) {
	ctx := context.Background()
	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	blockStore := make(map[ipld.Link][]byte)
	loader, storer := testutil.NewTestStore(blockStore)
	blockChain := testutil.SetupBlockChain(ctx, t, loader, storer, 100, 5)
	blks := blockChain.AllBlocks()

	// sentResponses can hold every block, so the traversal can run to
	// completion before the test starts draining responses below.
	requestIDChan := make(chan completedRequest, 1)
	sentResponses := make(chan sentResponse, len(blks))
	sentExtensions := make(chan sentExtension, 1)
	fprs := &fakePeerResponseSender{lastCompletedRequest: requestIDChan, sentResponses: sentResponses, sentExtensions: sentExtensions}
	peerManager := &fakePeerManager{peerResponseSender: fprs}
	queryQueue := &fakeQueryQueue{}
	responseManager := New(ctx, loader, peerManager, queryQueue)
	responseManager.Startup()

	requestID := graphsync.RequestID(rand.Int31())
	requests := []gsmsg.GraphSyncRequest{
		gsmsg.NewRequest(requestID, blockChain.TipLink.(cidlink.Link).Cid, blockChain.Selector(), graphsync.Priority(math.MaxInt32)),
	}
	p := testutil.GeneratePeers(1)[0]
	responseManager.ProcessRequests(ctx, p, requests)
	testutil.AssertDoesReceive(ctx, t, requestIDChan, "Should have completed request but didn't")

	for range blks {
		var resp sentResponse
		testutil.AssertReceive(ctx, t, sentResponses, &resp, "did not send responses")
		k := resp.link.(cidlink.Link)
		blockIndex := testutil.IndexOf(blks, k.Cid)
		require.NotEqual(t, -1, blockIndex, "sent incorrect link")
		require.Equal(t, blks[blockIndex].RawData(), resp.data, "sent incorrect data")
		require.Equal(t, requestID, resp.requestID, "has incorrect response id")
	}
}

// TestCancellationQueryInProgress cancels a request after its first block has
// been received. It then expects exactly one further block, followed by the
// request's completion notice and nothing else.
func TestCancellationQueryInProgress(t *testing.T) {
	ctx := context.Background()
	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	blockStore := make(map[ipld.Link][]byte)
	loader, storer := testutil.NewTestStore(blockStore)
	blockChain := testutil.SetupBlockChain(ctx, t, loader, storer, 100, 5)
	blks := blockChain.AllBlocks()

	// Unbuffered: every send by the fake sender waits for the test to receive
	// it, which keeps the traversal from running ahead of the cancellation.
	requestIDChan := make(chan completedRequest)
	sentResponses := make(chan sentResponse)
	sentExtensions := make(chan sentExtension, 1)
	fprs := &fakePeerResponseSender{lastCompletedRequest: requestIDChan, sentResponses: sentResponses, sentExtensions: sentExtensions}
	peerManager := &fakePeerManager{peerResponseSender: fprs}
	queryQueue := &fakeQueryQueue{}
	responseManager := New(ctx, loader, peerManager, queryQueue)
	responseManager.Startup()

	requestID := graphsync.RequestID(rand.Int31())
	requests := []gsmsg.GraphSyncRequest{
		gsmsg.NewRequest(requestID, blockChain.TipLink.(cidlink.Link).Cid, blockChain.Selector(), graphsync.Priority(math.MaxInt32)),
	}
	p := testutil.GeneratePeers(1)[0]
	responseManager.ProcessRequests(ctx, p, requests)

	// read one block
	var resp sentResponse
	testutil.AssertReceive(ctx, t, sentResponses, &resp, "did not send response")
	k := resp.link.(cidlink.Link)
	blockIndex := testutil.IndexOf(blks, k.Cid)
	require.NotEqual(t, -1, blockIndex, "sent incorrect link")
	require.Equal(t, blks[blockIndex].RawData(), resp.data, "sent incorrect data")
	require.Equal(t, requestID, resp.requestID, "has incorrect response id")

	// send a cancellation
	requests = []gsmsg.GraphSyncRequest{
		gsmsg.CancelRequest(requestID),
	}
	responseManager.ProcessRequests(ctx, p, requests)

	responseManager.synchronize()

	// Exactly one more block is expected before completion. Presumably the
	// traversal is already waiting to hand its next block over the
	// unbuffered channel when the cancellation lands.
	testutil.AssertReceiveFirst(t, sentResponses, &resp, "should send one additional response", ctx.Done(), requestIDChan)
	k = resp.link.(cidlink.Link)
	blockIndex = testutil.IndexOf(blks, k.Cid)
	require.NotEqual(t, -1, blockIndex, "did not send correct link")
	require.Equal(t, blks[blockIndex].RawData(), resp.data, "sent incorrect data")
	require.Equal(t, requestID, resp.requestID, "incorrect response id")

	// We should now be done
	testutil.AssertDoesReceiveFirst(t, requestIDChan, "should complete request", ctx.Done(), sentResponses)
}

func TestEarlyCancellation(t *testing.T) {
	ctx := context.Background()
Hannah Howard's avatar
Hannah Howard committed
223
	ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
224
	defer cancel()
225 226

	blockStore := make(map[ipld.Link][]byte)
227
	loader, storer := testutil.NewTestStore(blockStore)
228 229
	blockChain := testutil.SetupBlockChain(ctx, t, loader, storer, 100, 5)

230
	requestIDChan := make(chan completedRequest)
231
	sentResponses := make(chan sentResponse)
232 233
	sentExtensions := make(chan sentExtension, 1)
	fprs := &fakePeerResponseSender{lastCompletedRequest: requestIDChan, sentResponses: sentResponses, sentExtensions: sentExtensions}
234 235 236
	peerManager := &fakePeerManager{peerResponseSender: fprs}
	queryQueue := &fakeQueryQueue{}
	queryQueue.popWait.Add(1)
237
	responseManager := New(ctx, loader, peerManager, queryQueue)
238 239
	responseManager.Startup()

240
	requestID := graphsync.RequestID(rand.Int31())
241
	requests := []gsmsg.GraphSyncRequest{
242
		gsmsg.NewRequest(requestID, blockChain.TipLink.(cidlink.Link).Cid, blockChain.Selector(), graphsync.Priority(math.MaxInt32)),
243 244
	}
	p := testutil.GeneratePeers(1)[0]
245
	responseManager.ProcessRequests(ctx, p, requests)
246 247 248 249 250

	// send a cancellation
	requests = []gsmsg.GraphSyncRequest{
		gsmsg.CancelRequest(requestID),
	}
251
	responseManager.ProcessRequests(ctx, p, requests)
252 253 254 255 256 257 258

	responseManager.synchronize()

	// unblock popping from queue
	queryQueue.popWait.Done()

	// verify no responses processed
Hannah Howard's avatar
Hannah Howard committed
259
	testutil.AssertDoesReceiveFirst(t, ctx.Done(), "should not process more responses", sentResponses, requestIDChan)
260
}
261 262 263

// TestValidationAndExtensions exercises request validation together with
// request-received hooks:
//   - a request with an invalid selector fails unless a hook validates it;
//   - a valid request succeeds unless a hook terminates it with an error;
//   - extension data sent by a hook is delivered in every case.
func TestValidationAndExtensions(t *testing.T) {
	ctx := context.Background()
	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	blockStore := make(map[ipld.Link][]byte)
	loader, storer := testutil.NewTestStore(blockStore)
	blockChain := testutil.SetupBlockChain(ctx, t, loader, storer, 100, 5)

	completedRequestChan := make(chan completedRequest, 1)
	sentResponses := make(chan sentResponse, 100)
	sentExtensions := make(chan sentExtension, 1)
	fprs := &fakePeerResponseSender{lastCompletedRequest: completedRequestChan, sentResponses: sentResponses, sentExtensions: sentExtensions}
	peerManager := &fakePeerManager{peerResponseSender: fprs}

	// Each subtest below starts its own response manager. Managers from
	// earlier subtests are never shut down (they live until ctx is cancelled
	// when this test returns). Every manager therefore gets a fresh
	// fakeQueryQueue, so a still-running earlier manager cannot take a task
	// that was queued for a later one.

	extensionData := testutil.RandomBytes(100)
	extensionName := graphsync.ExtensionName("AppleSauce/McGee")
	extension := graphsync.ExtensionData{
		Name: extensionName,
		Data: extensionData,
	}
	extensionResponseData := testutil.RandomBytes(100)
	extensionResponse := graphsync.ExtensionData{
		Name: extensionName,
		Data: extensionResponseData,
	}

	t.Run("with invalid selector", func(t *testing.T) {
		selectorSpec := testutil.NewInvalidSelectorSpec()
		requestID := graphsync.RequestID(rand.Int31())
		requests := []gsmsg.GraphSyncRequest{
			gsmsg.NewRequest(requestID, blockChain.TipLink.(cidlink.Link).Cid, selectorSpec, graphsync.Priority(math.MaxInt32), extension),
		}
		p := testutil.GeneratePeers(1)[0]

		t.Run("on its own, should fail validation", func(t *testing.T) {
			responseManager := New(ctx, loader, peerManager, &fakeQueryQueue{})
			responseManager.Startup()
			responseManager.ProcessRequests(ctx, p, requests)
			var lastRequest completedRequest
			testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
			require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")
		})

		t.Run("if non validating hook succeeds, does not pass validation", func(t *testing.T) {
			responseManager := New(ctx, loader, peerManager, &fakeQueryQueue{})
			responseManager.Startup()
			responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
				hookActions.SendExtensionData(extensionResponse)
			})
			responseManager.ProcessRequests(ctx, p, requests)
			var lastRequest completedRequest
			testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
			require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")
			var receivedExtension sentExtension
			testutil.AssertReceive(ctx, t, sentExtensions, &receivedExtension, "should send extension response")
			require.Equal(t, extensionResponse, receivedExtension.extension, "incorrect extension response sent")
		})

		t.Run("if validating hook succeeds, should pass validation", func(t *testing.T) {
			responseManager := New(ctx, loader, peerManager, &fakeQueryQueue{})
			responseManager.Startup()
			responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
				hookActions.ValidateRequest()
				hookActions.SendExtensionData(extensionResponse)
			})
			responseManager.ProcessRequests(ctx, p, requests)
			var lastRequest completedRequest
			testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
			require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
			var receivedExtension sentExtension
			testutil.AssertReceive(ctx, t, sentExtensions, &receivedExtension, "should send extension response")
			require.Equal(t, extensionResponse, receivedExtension.extension, "incorrect extension response sent")
		})
	})

	t.Run("with valid selector", func(t *testing.T) {
		requestID := graphsync.RequestID(rand.Int31())
		requests := []gsmsg.GraphSyncRequest{
			gsmsg.NewRequest(requestID, blockChain.TipLink.(cidlink.Link).Cid, blockChain.Selector(), graphsync.Priority(math.MaxInt32), extension),
		}
		p := testutil.GeneratePeers(1)[0]

		t.Run("on its own, should pass validation", func(t *testing.T) {
			responseManager := New(ctx, loader, peerManager, &fakeQueryQueue{})
			responseManager.Startup()
			responseManager.ProcessRequests(ctx, p, requests)
			var lastRequest completedRequest
			testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
			require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
		})

		t.Run("if any hook fails, should fail", func(t *testing.T) {
			responseManager := New(ctx, loader, peerManager, &fakeQueryQueue{})
			responseManager.Startup()
			responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
				hookActions.SendExtensionData(extensionResponse)
				hookActions.TerminateWithError(errors.New("everything went to crap"))
			})
			responseManager.ProcessRequests(ctx, p, requests)
			var lastRequest completedRequest
			testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
			require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")
			var receivedExtension sentExtension
			testutil.AssertReceive(ctx, t, sentExtensions, &receivedExtension, "should send extension response")
			require.Equal(t, extensionResponse, receivedExtension.extension, "incorrect extension response sent")
		})
	})
}