responsemanager_test.go 14.7 KB
Newer Older
1 2 3 4
package responsemanager

import (
	"context"
5
	"errors"
6 7 8 9 10 11
	"math"
	"math/rand"
	"sync"
	"testing"
	"time"

12
	"github.com/ipfs/go-graphsync"
13 14
	gsmsg "github.com/ipfs/go-graphsync/message"
	"github.com/ipfs/go-graphsync/responsemanager/peerresponsemanager"
15
	"github.com/ipfs/go-graphsync/selectorvalidator"
16
	"github.com/ipfs/go-graphsync/testutil"
17
	"github.com/ipfs/go-peertaskqueue/peertask"
18 19
	ipld "github.com/ipld/go-ipld-prime"
	cidlink "github.com/ipld/go-ipld-prime/linking/cid"
20
	"github.com/libp2p/go-libp2p-core/peer"
Hannah Howard's avatar
Hannah Howard committed
21
	"github.com/stretchr/testify/require"
22 23 24 25 26
)

type fakeQueryQueue struct {
	popWait   sync.WaitGroup
	queriesLk sync.RWMutex
27
	queries   []*peertask.QueueTask
28 29
}

30
func (fqq *fakeQueryQueue) PushTasks(to peer.ID, tasks ...peertask.Task) {
31
	fqq.queriesLk.Lock()
32 33 34 35 36 37

	// This isn't quite right as the queue should deduplicate requests, but
	// it's good enough.
	for _, task := range tasks {
		fqq.queries = append(fqq.queries, peertask.NewQueueTask(task, to, time.Now()))
	}
38 39 40
	fqq.queriesLk.Unlock()
}

41
func (fqq *fakeQueryQueue) PopTasks(targetWork int) (peer.ID, []*peertask.Task, int) {
42 43 44 45
	fqq.popWait.Wait()
	fqq.queriesLk.Lock()
	defer fqq.queriesLk.Unlock()
	if len(fqq.queries) == 0 {
46
		return "", nil, -1
47
	}
48 49
	// We're not bothering to implement "work"
	task := fqq.queries[0]
50
	fqq.queries = fqq.queries[1:]
51
	return task.Target, []*peertask.Task{&task.Task}, 0
52 53
}

54
func (fqq *fakeQueryQueue) Remove(topic peertask.Topic, p peer.ID) {
55 56 57
	fqq.queriesLk.Lock()
	defer fqq.queriesLk.Unlock()
	for i, query := range fqq.queries {
58 59
		if query.Target == p && query.Topic == topic {
			fqq.queries = append(fqq.queries[:i], fqq.queries[i+1:]...)
60 61 62 63
		}
	}
}

64 65 66 67
func (fqq *fakeQueryQueue) TasksDone(to peer.ID, tasks ...*peertask.Task) {
	// We don't track active tasks so this is a no-op
}

68 69 70
func (fqq *fakeQueryQueue) ThawRound() {

}
71 72 73 74 75 76 77 78 79 80 81 82

type fakePeerManager struct {
	lastPeer           peer.ID
	peerResponseSender peerresponsemanager.PeerResponseSender
}

// SenderForPeer remembers p as the last requested peer and hands back the
// configured sender, whichever peer was asked for.
func (fpm *fakePeerManager) SenderForPeer(p peer.ID) peerresponsemanager.PeerResponseSender {
	sender := fpm.peerResponseSender
	fpm.lastPeer = p
	return sender
}

type sentResponse struct {
83
	requestID graphsync.RequestID
84 85 86 87
	link      ipld.Link
	data      []byte
}

88 89 90 91 92 93 94 95 96
type sentExtension struct {
	requestID graphsync.RequestID
	extension graphsync.ExtensionData
}

// completedRequest records how a request finished: its ID and the terminal
// status passed to FinishRequest or FinishWithError.
// Field order matters: the finish methods may build this positionally.
type completedRequest struct {
	requestID graphsync.RequestID
	result    graphsync.ResponseStatusCode
}
97 98
type fakePeerResponseSender struct {
	sentResponses        chan sentResponse
99 100
	sentExtensions       chan sentExtension
	lastCompletedRequest chan completedRequest
101 102 103 104 105 106
}

// Startup is a no-op for the fake sender.
func (fprs *fakePeerResponseSender) Startup() {}

// Shutdown is a no-op for the fake sender.
func (fprs *fakePeerResponseSender) Shutdown() {}

func (fprs *fakePeerResponseSender) SendResponse(
107
	requestID graphsync.RequestID,
108 109 110 111 112 113
	link ipld.Link,
	data []byte,
) {
	fprs.sentResponses <- sentResponse{requestID, link, data}
}

114 115 116 117 118 119 120
func (fprs *fakePeerResponseSender) SendExtensionData(
	requestID graphsync.RequestID,
	extension graphsync.ExtensionData,
) {
	fprs.sentExtensions <- sentExtension{requestID, extension}
}

121
func (fprs *fakePeerResponseSender) FinishRequest(requestID graphsync.RequestID) {
122
	fprs.lastCompletedRequest <- completedRequest{requestID, graphsync.RequestCompletedFull}
123 124
}

125
func (fprs *fakePeerResponseSender) FinishWithError(requestID graphsync.RequestID, status graphsync.ResponseStatusCode) {
126
	fprs.lastCompletedRequest <- completedRequest{requestID, status}
127 128 129 130
}

func TestIncomingQuery(t *testing.T) {
	ctx := context.Background()
Hannah Howard's avatar
Hannah Howard committed
131
	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
132
	defer cancel()
133 134

	blockStore := make(map[ipld.Link][]byte)
135
	loader, storer := testutil.NewTestStore(blockStore)
136 137 138
	blockChain := testutil.SetupBlockChain(ctx, t, loader, storer, 100, 5)
	blks := blockChain.AllBlocks()

139
	requestIDChan := make(chan completedRequest, 1)
140
	sentResponses := make(chan sentResponse, len(blks))
141 142
	sentExtensions := make(chan sentExtension, 1)
	fprs := &fakePeerResponseSender{lastCompletedRequest: requestIDChan, sentResponses: sentResponses, sentExtensions: sentExtensions}
143 144
	peerManager := &fakePeerManager{peerResponseSender: fprs}
	queryQueue := &fakeQueryQueue{}
145
	responseManager := New(ctx, loader, peerManager, queryQueue)
146
	responseManager.RegisterHook(selectorvalidator.SelectorValidator(100))
147 148
	responseManager.Startup()

149
	requestID := graphsync.RequestID(rand.Int31())
150
	requests := []gsmsg.GraphSyncRequest{
151
		gsmsg.NewRequest(requestID, blockChain.TipLink.(cidlink.Link).Cid, blockChain.Selector(), graphsync.Priority(math.MaxInt32)),
152 153
	}
	p := testutil.GeneratePeers(1)[0]
154
	responseManager.ProcessRequests(ctx, p, requests)
Hannah Howard's avatar
Hannah Howard committed
155
	testutil.AssertDoesReceive(ctx, t, requestIDChan, "Should have completed request but didn't")
156
	for i := 0; i < len(blks); i++ {
Hannah Howard's avatar
Hannah Howard committed
157 158 159 160 161 162 163
		var sentResponse sentResponse
		testutil.AssertReceive(ctx, t, sentResponses, &sentResponse, "did not send responses")
		k := sentResponse.link.(cidlink.Link)
		blockIndex := testutil.IndexOf(blks, k.Cid)
		require.NotEqual(t, blockIndex, -1, "sent incorrect link")
		require.Equal(t, blks[blockIndex].RawData(), sentResponse.data, "sent incorrect data")
		require.Equal(t, requestID, sentResponse.requestID, "has incorrect response id")
164 165 166 167 168
	}
}

func TestCancellationQueryInProgress(t *testing.T) {
	ctx := context.Background()
Hannah Howard's avatar
Hannah Howard committed
169
	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
170
	defer cancel()
171 172

	blockStore := make(map[ipld.Link][]byte)
173
	loader, storer := testutil.NewTestStore(blockStore)
174 175 176
	blockChain := testutil.SetupBlockChain(ctx, t, loader, storer, 100, 5)
	blks := blockChain.AllBlocks()

177
	requestIDChan := make(chan completedRequest)
178
	sentResponses := make(chan sentResponse)
179 180
	sentExtensions := make(chan sentExtension, 1)
	fprs := &fakePeerResponseSender{lastCompletedRequest: requestIDChan, sentResponses: sentResponses, sentExtensions: sentExtensions}
181 182
	peerManager := &fakePeerManager{peerResponseSender: fprs}
	queryQueue := &fakeQueryQueue{}
183
	responseManager := New(ctx, loader, peerManager, queryQueue)
184
	responseManager.RegisterHook(selectorvalidator.SelectorValidator(100))
185 186
	responseManager.Startup()

187
	requestID := graphsync.RequestID(rand.Int31())
188
	requests := []gsmsg.GraphSyncRequest{
189
		gsmsg.NewRequest(requestID, blockChain.TipLink.(cidlink.Link).Cid, blockChain.Selector(), graphsync.Priority(math.MaxInt32)),
190 191
	}
	p := testutil.GeneratePeers(1)[0]
192
	responseManager.ProcessRequests(ctx, p, requests)
193 194

	// read one block
Hannah Howard's avatar
Hannah Howard committed
195 196 197 198 199 200 201
	var sentResponse sentResponse
	testutil.AssertReceive(ctx, t, sentResponses, &sentResponse, "did not send response")
	k := sentResponse.link.(cidlink.Link)
	blockIndex := testutil.IndexOf(blks, k.Cid)
	require.NotEqual(t, blockIndex, -1, "sent incorrect link")
	require.Equal(t, blks[blockIndex].RawData(), sentResponse.data, "sent incorrect data")
	require.Equal(t, requestID, sentResponse.requestID, "has incorrect response id")
202 203 204 205 206

	// send a cancellation
	requests = []gsmsg.GraphSyncRequest{
		gsmsg.CancelRequest(requestID),
	}
207
	responseManager.ProcessRequests(ctx, p, requests)
208 209 210

	responseManager.synchronize()

211 212
	// at this point we should receive at most one more block, then traversal
	// should complete
Hannah Howard's avatar
Hannah Howard committed
213 214 215 216 217 218 219 220 221
	testutil.AssertReceiveFirst(t, sentResponses, &sentResponse, "should send one additional response", ctx.Done(), requestIDChan)
	k = sentResponse.link.(cidlink.Link)
	blockIndex = testutil.IndexOf(blks, k.Cid)
	require.NotEqual(t, blockIndex, -1, "did not send correct link")
	require.Equal(t, blks[blockIndex].RawData(), sentResponse.data, "sent incorrect data")
	require.Equal(t, requestID, sentResponse.requestID, "incorrect response id")

	// We should now be done
	testutil.AssertDoesReceiveFirst(t, requestIDChan, "should complete request", ctx.Done(), sentResponses)
222 223 224 225
}

func TestEarlyCancellation(t *testing.T) {
	ctx := context.Background()
Hannah Howard's avatar
Hannah Howard committed
226
	ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
227
	defer cancel()
228 229

	blockStore := make(map[ipld.Link][]byte)
230
	loader, storer := testutil.NewTestStore(blockStore)
231 232
	blockChain := testutil.SetupBlockChain(ctx, t, loader, storer, 100, 5)

233
	requestIDChan := make(chan completedRequest)
234
	sentResponses := make(chan sentResponse)
235 236
	sentExtensions := make(chan sentExtension, 1)
	fprs := &fakePeerResponseSender{lastCompletedRequest: requestIDChan, sentResponses: sentResponses, sentExtensions: sentExtensions}
237 238 239
	peerManager := &fakePeerManager{peerResponseSender: fprs}
	queryQueue := &fakeQueryQueue{}
	queryQueue.popWait.Add(1)
240
	responseManager := New(ctx, loader, peerManager, queryQueue)
241 242
	responseManager.Startup()

243
	requestID := graphsync.RequestID(rand.Int31())
244
	requests := []gsmsg.GraphSyncRequest{
245
		gsmsg.NewRequest(requestID, blockChain.TipLink.(cidlink.Link).Cid, blockChain.Selector(), graphsync.Priority(math.MaxInt32)),
246 247
	}
	p := testutil.GeneratePeers(1)[0]
248
	responseManager.ProcessRequests(ctx, p, requests)
249 250 251 252 253

	// send a cancellation
	requests = []gsmsg.GraphSyncRequest{
		gsmsg.CancelRequest(requestID),
	}
254
	responseManager.ProcessRequests(ctx, p, requests)
255 256 257 258 259 260 261

	responseManager.synchronize()

	// unblock popping from queue
	queryQueue.popWait.Done()

	// verify no responses processed
Hannah Howard's avatar
Hannah Howard committed
262
	testutil.AssertDoesReceiveFirst(t, ctx.Done(), "should not process more responses", sentResponses, requestIDChan)
263
}
264 265 266

func TestValidationAndExtensions(t *testing.T) {
	ctx := context.Background()
Hannah Howard's avatar
Hannah Howard committed
267
	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
268
	defer cancel()
269 270

	blockStore := make(map[ipld.Link][]byte)
271
	loader, storer := testutil.NewTestStore(blockStore)
272 273
	blockChain := testutil.SetupBlockChain(ctx, t, loader, storer, 100, 5)

274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292
	completedRequestChan := make(chan completedRequest, 1)
	sentResponses := make(chan sentResponse, 100)
	sentExtensions := make(chan sentExtension, 1)
	fprs := &fakePeerResponseSender{lastCompletedRequest: completedRequestChan, sentResponses: sentResponses, sentExtensions: sentExtensions}
	peerManager := &fakePeerManager{peerResponseSender: fprs}
	queryQueue := &fakeQueryQueue{}

	extensionData := testutil.RandomBytes(100)
	extensionName := graphsync.ExtensionName("AppleSauce/McGee")
	extension := graphsync.ExtensionData{
		Name: extensionName,
		Data: extensionData,
	}
	extensionResponseData := testutil.RandomBytes(100)
	extensionResponse := graphsync.ExtensionData{
		Name: extensionName,
		Data: extensionResponseData,
	}

293 294 295 296 297
	requestID := graphsync.RequestID(rand.Int31())
	requests := []gsmsg.GraphSyncRequest{
		gsmsg.NewRequest(requestID, blockChain.TipLink.(cidlink.Link).Cid, blockChain.Selector(), graphsync.Priority(math.MaxInt32), extension),
	}
	p := testutil.GeneratePeers(1)[0]
298

299 300 301 302 303 304 305 306 307 308 309 310 311 312
	t.Run("on its own, should fail validation", func(t *testing.T) {
		responseManager := New(ctx, loader, peerManager, queryQueue)
		responseManager.Startup()
		responseManager.ProcessRequests(ctx, p, requests)
		var lastRequest completedRequest
		testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
		require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")
	})

	t.Run("if non validating hook succeeds, does not pass validation", func(t *testing.T) {
		responseManager := New(ctx, loader, peerManager, queryQueue)
		responseManager.Startup()
		responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
			hookActions.SendExtensionData(extensionResponse)
313
		})
314 315 316 317 318 319 320 321
		responseManager.ProcessRequests(ctx, p, requests)
		var lastRequest completedRequest
		testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
		require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")
		var receivedExtension sentExtension
		testutil.AssertReceive(ctx, t, sentExtensions, &receivedExtension, "should send extension response")
		require.Equal(t, extensionResponse, receivedExtension.extension, "incorrect extension response sent")
	})
322

323 324 325 326 327 328
	t.Run("if validating hook succeeds, should pass validation", func(t *testing.T) {
		responseManager := New(ctx, loader, peerManager, queryQueue)
		responseManager.Startup()
		responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
			hookActions.ValidateRequest()
			hookActions.SendExtensionData(extensionResponse)
329
		})
330 331 332 333 334 335 336
		responseManager.ProcessRequests(ctx, p, requests)
		var lastRequest completedRequest
		testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
		require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
		var receivedExtension sentExtension
		testutil.AssertReceive(ctx, t, sentExtensions, &receivedExtension, "should send extension response")
		require.Equal(t, extensionResponse, receivedExtension.extension, "incorrect extension response sent")
337 338
	})

339 340 341 342 343
	t.Run("if any hook fails, should fail", func(t *testing.T) {
		responseManager := New(ctx, loader, peerManager, queryQueue)
		responseManager.Startup()
		responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
			hookActions.ValidateRequest()
344
		})
345 346 347 348 349 350 351 352 353 354 355 356
		responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
			hookActions.SendExtensionData(extensionResponse)
			hookActions.TerminateWithError(errors.New("everything went to crap"))
		})
		responseManager.ProcessRequests(ctx, p, requests)
		var lastRequest completedRequest
		testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
		require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")
		var receivedExtension sentExtension
		testutil.AssertReceive(ctx, t, sentExtensions, &receivedExtension, "should send extension response")
		require.Equal(t, extensionResponse, receivedExtension.extension, "incorrect extension response sent")
	})
357

358 359 360 361 362 363
	t.Run("hooks can be unregistered", func(t *testing.T) {
		responseManager := New(ctx, loader, peerManager, queryQueue)
		responseManager.Startup()
		unregister := responseManager.RegisterHook(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.RequestReceivedHookActions) {
			hookActions.ValidateRequest()
			hookActions.SendExtensionData(extensionResponse)
364
		})
365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381

		// hook validates request
		responseManager.ProcessRequests(ctx, p, requests)
		var lastRequest completedRequest
		testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
		require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed")
		var receivedExtension sentExtension
		testutil.AssertReceive(ctx, t, sentExtensions, &receivedExtension, "should send extension response")
		require.Equal(t, extensionResponse, receivedExtension.extension, "incorrect extension response sent")

		// unregister
		unregister()

		// no same request should fail
		responseManager.ProcessRequests(ctx, p, requests)
		testutil.AssertReceive(ctx, t, completedRequestChan, &lastRequest, "should complete request")
		require.True(t, gsmsg.IsTerminalFailureCode(lastRequest.result), "should terminate with failure")
382 383
	})
}