requestmanager.go 15.2 KB
Newer Older
1 2 3 4 5 6 7
package requestmanager

import (
	"context"
	"fmt"
	"math"

8
	blocks "github.com/ipfs/go-block-format"
9
	"github.com/ipfs/go-graphsync"
10
	ipldutil "github.com/ipfs/go-graphsync/ipldutil"
11
	gsmsg "github.com/ipfs/go-graphsync/message"
12 13 14 15 16
	"github.com/ipfs/go-graphsync/metadata"
	"github.com/ipfs/go-graphsync/requestmanager/loader"
	"github.com/ipfs/go-graphsync/requestmanager/types"
	logging "github.com/ipfs/go-log"
	"github.com/ipld/go-ipld-prime"
17
	cidlink "github.com/ipld/go-ipld-prime/linking/cid"
18
	"github.com/ipld/go-ipld-prime/traversal"
19
	"github.com/ipld/go-ipld-prime/traversal/selector"
20
	"github.com/libp2p/go-libp2p-core/peer"
21 22
)

23
// log is the package-level logger, registered under the "graphsync" name.
var log = logging.Logger("graphsync")
24 25 26

const (
	// maxPriority is the max priority as defined by the bitswap protocol
27
	maxPriority = graphsync.Priority(math.MaxInt32)
28 29 30
)

type inProgressRequestStatus struct {
31 32 33 34
	ctx          context.Context
	cancelFn     func()
	p            peer.ID
	networkError chan error
35 36
}

37
type responseHook struct {
38
	key  uint64
39 40 41 42 43 44
	hook graphsync.OnIncomingResponseHook
}

type requestHook struct {
	key  uint64
	hook graphsync.OnOutgoingRequestHook
45 46
}

47 48
// PeerHandler is an interface that can send requests to peers
type PeerHandler interface {
49
	SendRequest(p peer.ID, graphSyncRequest gsmsg.GraphSyncRequest)
50 51
}

52 53 54
// AsyncLoader is an interface for loading links asynchronously, returning
// results as new responses are processed
type AsyncLoader interface {
55
	StartRequest(graphsync.RequestID, string) error
56
	ProcessResponse(responses map[graphsync.RequestID]metadata.Metadata,
57
		blks []blocks.Block)
58 59 60
	AsyncLoad(requestID graphsync.RequestID, link ipld.Link) <-chan types.AsyncLoadResult
	CompleteResponsesFor(requestID graphsync.RequestID)
	CleanupRequest(requestID graphsync.RequestID)
61 62
}

63 64 65 66 67 68 69 70
// RequestManager tracks outgoing requests and processes incoming reponses
// to them.
type RequestManager struct {
	ctx         context.Context
	cancel      func()
	messages    chan requestManagerMessage
	peerHandler PeerHandler
	rc          *responseCollector
71
	asyncLoader AsyncLoader
72
	// dont touch out side of run loop
73 74
	nextRequestID             graphsync.RequestID
	inProgressRequestStatuses map[graphsync.RequestID]*inProgressRequestStatus
75
	hooksNextKey              uint64
76
	responseHooks             []responseHook
77
	requestHooks              []requestHook
78 79 80 81 82 83 84
}

// requestManagerMessage is a unit of work for the run loop; the loop calls
// handle on every message it receives from the messages channel.
type requestManagerMessage interface {
	handle(rm *RequestManager)
}

// New generates a new request manager from a context, network, and selectorQuerier
85
func New(ctx context.Context, asyncLoader AsyncLoader) *RequestManager {
86 87 88 89
	ctx, cancel := context.WithCancel(ctx)
	return &RequestManager{
		ctx:                       ctx,
		cancel:                    cancel,
90
		asyncLoader:               asyncLoader,
91 92
		rc:                        newResponseCollector(ctx),
		messages:                  make(chan requestManagerMessage, 16),
93
		inProgressRequestStatuses: make(map[graphsync.RequestID]*inProgressRequestStatus),
94 95 96 97 98 99 100 101 102
	}
}

// SetDelegate specifies who will send messages out to the internet.
// The field is written directly, without going through the message channel.
func (rm *RequestManager) SetDelegate(peerHandler PeerHandler) {
	rm.peerHandler = peerHandler
}

type inProgressRequest struct {
103 104
	requestID     graphsync.RequestID
	incoming      chan graphsync.ResponseProgress
105
	incomingError chan error
106 107 108 109
}

type newRequestMessage struct {
	p                     peer.ID
110
	root                  ipld.Link
111
	selector              ipld.Node
112
	extensions            []graphsync.ExtensionData
113 114 115 116
	inProgressRequestChan chan<- inProgressRequest
}

// SendRequest initiates a new GraphSync request to the given peer.
117 118
func (rm *RequestManager) SendRequest(ctx context.Context,
	p peer.ID,
119
	root ipld.Link,
120 121
	selector ipld.Node,
	extensions ...graphsync.ExtensionData) (<-chan graphsync.ResponseProgress, <-chan error) {
122
	if _, err := ipldutil.ParseSelector(selector); err != nil {
123
		return rm.singleErrorResponse(fmt.Errorf("Invalid Selector Spec"))
124 125 126 127 128
	}

	inProgressRequestChan := make(chan inProgressRequest)

	select {
129
	case rm.messages <- &newRequestMessage{p, root, selector, extensions, inProgressRequestChan}:
130
	case <-rm.ctx.Done():
131
		return rm.emptyResponse()
132
	case <-ctx.Done():
133
		return rm.emptyResponse()
134 135 136 137
	}
	var receivedInProgressRequest inProgressRequest
	select {
	case <-rm.ctx.Done():
138
		return rm.emptyResponse()
139 140 141
	case receivedInProgressRequest = <-inProgressRequestChan:
	}

142 143 144 145 146 147 148 149 150 151
	return rm.rc.collectResponses(ctx,
		receivedInProgressRequest.incoming,
		receivedInProgressRequest.incomingError,
		func() {
			rm.cancelRequest(receivedInProgressRequest.requestID,
				receivedInProgressRequest.incoming,
				receivedInProgressRequest.incomingError)
		})
}

152 153
func (rm *RequestManager) emptyResponse() (chan graphsync.ResponseProgress, chan error) {
	ch := make(chan graphsync.ResponseProgress)
154
	close(ch)
155
	errCh := make(chan error)
156 157 158 159
	close(errCh)
	return ch, errCh
}

160 161
func (rm *RequestManager) singleErrorResponse(err error) (chan graphsync.ResponseProgress, chan error) {
	ch := make(chan graphsync.ResponseProgress)
162
	close(ch)
163
	errCh := make(chan error, 1)
164
	errCh <- err
165 166
	close(errCh)
	return ch, errCh
167 168 169
}

type cancelRequestMessage struct {
170
	requestID graphsync.RequestID
171 172
}

173 174
func (rm *RequestManager) cancelRequest(requestID graphsync.RequestID,
	incomingResponses chan graphsync.ResponseProgress,
175
	incomingErrors chan error) {
176
	cancelMessageChannel := rm.messages
177
	for cancelMessageChannel != nil || incomingResponses != nil || incomingErrors != nil {
178 179 180 181 182 183 184
		select {
		case cancelMessageChannel <- &cancelRequestMessage{requestID}:
			cancelMessageChannel = nil
		// clear out any remaining responses, in case and "incoming reponse"
		// messages get processed before our cancel message
		case _, ok := <-incomingResponses:
			if !ok {
185 186 187 188 189
				incomingResponses = nil
			}
		case _, ok := <-incomingErrors:
			if !ok {
				incomingErrors = nil
190 191 192 193 194 195 196 197
			}
		case <-rm.ctx.Done():
			return
		}
	}
}

type processResponseMessage struct {
198
	p         peer.ID
199 200
	responses []gsmsg.GraphSyncResponse
	blks      []blocks.Block
201 202 203 204
}

// ProcessResponses ingests the given responses from the network and
// and updates the in progress requests based on those responses.
205
func (rm *RequestManager) ProcessResponses(p peer.ID, responses []gsmsg.GraphSyncResponse,
206
	blks []blocks.Block) {
207
	select {
208
	case rm.messages <- &processResponseMessage{p, responses, blks}:
209 210 211 212
	case <-rm.ctx.Done():
	}
}

213 214 215 216 217 218 219
// registerRequestHookMessage asks the run loop to register an outgoing
// request hook; the function that unregisters it is sent back on
// unregisterHookChan.
type registerRequestHookMessage struct {
	hook               graphsync.OnOutgoingRequestHook
	unregisterHookChan chan graphsync.UnregisterHookFunc
}

type registerResponseHookMessage struct {
	hook               graphsync.OnIncomingResponseHook
220 221 222
	unregisterHookChan chan graphsync.UnregisterHookFunc
}

223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241
// RegisterRequestHook registers an extension to process outgoing requests.
// It returns a function that removes the hook again, or nil if the manager's
// context ends before the registration completes.
func (rm *RequestManager) RegisterRequestHook(hook graphsync.OnOutgoingRequestHook) graphsync.UnregisterHookFunc {
	reply := make(chan graphsync.UnregisterHookFunc)
	msg := &registerRequestHookMessage{hook, reply}
	select {
	case <-rm.ctx.Done():
		return nil
	case rm.messages <- msg:
	}
	select {
	case <-rm.ctx.Done():
		return nil
	case unregisterFn := <-reply:
		return unregisterFn
	}
}

// RegisterResponseHook registers an extension to process incoming responses
func (rm *RequestManager) RegisterResponseHook(
	hook graphsync.OnIncomingResponseHook) graphsync.UnregisterHookFunc {
242 243
	response := make(chan graphsync.UnregisterHookFunc)
	select {
244
	case rm.messages <- &registerResponseHookMessage{hook, response}:
245 246 247
	case <-rm.ctx.Done():
		return nil
	}
248
	select {
249 250
	case unregister := <-response:
		return unregister
251
	case <-rm.ctx.Done():
252
		return nil
253 254 255
	}
}

256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282
// Startup starts processing for the RequestManager by launching its run loop
// in a new goroutine.
func (rm *RequestManager) Startup() {
	go rm.run()
}

// Shutdown ends processing for the RequestManager by cancelling its context,
// which makes the run loop return.
func (rm *RequestManager) Shutdown() {
	rm.cancel()
}

// run is the manager's event loop: it handles queued messages one at a time
// until the manager's context ends, then cancels every request still in
// progress.
func (rm *RequestManager) run() {
	// NOTE: Do not open any streams or connections from anywhere in this
	// event loop. Really, just don't do anything likely to block.
	defer rm.cleanupInProcessRequests()

	done := rm.ctx.Done()
	for {
		select {
		case <-done:
			return
		case msg := <-rm.messages:
			msg.handle(rm)
		}
	}
}

func (rm *RequestManager) cleanupInProcessRequests() {
	for _, requestStatus := range rm.inProgressRequestStatuses {
283
		requestStatus.cancelFn()
284 285 286
	}
}

287
type terminateRequestMessage struct {
288
	requestID graphsync.RequestID
289
}
290

291
func (nrm *newRequestMessage) handle(rm *RequestManager) {
292 293 294
	requestID := rm.nextRequestID
	rm.nextRequestID++

295
	inProgressChan, inProgressErr := rm.setupRequest(requestID, nrm.p, nrm.root, nrm.selector, nrm.extensions)
296 297 298

	select {
	case nrm.inProgressRequestChan <- inProgressRequest{
299 300 301
		requestID:     requestID,
		incoming:      inProgressChan,
		incomingError: inProgressErr,
302 303 304 305 306
	}:
	case <-rm.ctx.Done():
	}
}

307 308 309 310 311
// handle drops the run loop's record of the request and tells the async
// loader to clean up after it.
func (trm *terminateRequestMessage) handle(rm *RequestManager) {
	delete(rm.inProgressRequestStatuses, trm.requestID)
	rm.asyncLoader.CleanupRequest(trm.requestID)
}

312 313 314 315 316 317
func (crm *cancelRequestMessage) handle(rm *RequestManager) {
	inProgressRequestStatus, ok := rm.inProgressRequestStatuses[crm.requestID]
	if !ok {
		return
	}

318
	rm.peerHandler.SendRequest(inProgressRequestStatus.p, gsmsg.CancelRequest(crm.requestID))
319
	delete(rm.inProgressRequestStatuses, crm.requestID)
320
	inProgressRequestStatus.cancelFn()
321 322 323
}

func (prm *processResponseMessage) handle(rm *RequestManager) {
324
	filteredResponses := rm.filterResponsesForPeer(prm.responses, prm.p)
325
	filteredResponses = rm.processExtensions(filteredResponses, prm.p)
326
	responseMetadata := metadataForResponses(filteredResponses)
327 328 329 330
	rm.asyncLoader.ProcessResponse(responseMetadata, prm.blks)
	rm.processTerminations(filteredResponses)
}

331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350
// handle registers the request hook under a fresh key and sends the matching
// unregister function back to the caller, unless the manager's context ends
// first.
//
// NOTE(review): the unregister function edits rm.requestHooks from whichever
// goroutine calls it, while RequestManager marks that field as run-loop-only.
// Confirm callers cannot race with the run loop.
func (rhm *registerRequestHookMessage) handle(rm *RequestManager) {
	registered := requestHook{rm.hooksNextKey, rhm.hook}
	rm.hooksNextKey++
	rm.requestHooks = append(rm.requestHooks, registered)

	var unregister graphsync.UnregisterHookFunc = func() {
		for idx := range rm.requestHooks {
			if rm.requestHooks[idx].key != registered.key {
				continue
			}
			rm.requestHooks = append(rm.requestHooks[:idx], rm.requestHooks[idx+1:]...)
			return
		}
	}

	select {
	case rhm.unregisterHookChan <- unregister:
	case <-rm.ctx.Done():
	}
}

func (rhm *registerResponseHookMessage) handle(rm *RequestManager) {
	rh := responseHook{rm.hooksNextKey, rhm.hook}
	rm.hooksNextKey++
351 352 353 354 355 356 357 358 359 360 361 362
	rm.responseHooks = append(rm.responseHooks, rh)
	select {
	case rhm.unregisterHookChan <- func() {
		for i, matchHook := range rm.responseHooks {
			if rh.key == matchHook.key {
				rm.responseHooks = append(rm.responseHooks[:i], rm.responseHooks[i+1:]...)
				return
			}
		}
	}:
	case <-rm.ctx.Done():
	}
363 364
}

365 366 367 368 369 370
func (rm *RequestManager) filterResponsesForPeer(responses []gsmsg.GraphSyncResponse, p peer.ID) []gsmsg.GraphSyncResponse {
	responsesForPeer := make([]gsmsg.GraphSyncResponse, 0, len(responses))
	for _, response := range responses {
		requestStatus, ok := rm.inProgressRequestStatuses[response.RequestID()]
		if !ok || requestStatus.p != p {
			continue
371
		}
372
		responsesForPeer = append(responsesForPeer, response)
373
	}
374 375
	return responsesForPeer
}
376

377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404
// processExtensions runs the response hooks over each response and returns,
// in order, the responses for which every hook succeeded.
func (rm *RequestManager) processExtensions(responses []gsmsg.GraphSyncResponse, p peer.ID) []gsmsg.GraphSyncResponse {
	accepted := make([]gsmsg.GraphSyncResponse, 0, len(responses))
	for _, resp := range responses {
		if rm.processExtensionsForResponse(p, resp) {
			accepted = append(accepted, resp)
		}
	}
	return accepted
}

// processExtensionsForResponse runs every registered response hook on the
// response and reports whether all of them succeeded. On the first hook
// error the request is failed with an "unknown reason" error, its context is
// cancelled, and false is returned.
//
// The status lookup is not checked for a missing entry; callers are expected
// to pass only responses for requests that are still in progress.
func (rm *RequestManager) processExtensionsForResponse(p peer.ID, response gsmsg.GraphSyncResponse) bool {
	for _, registered := range rm.responseHooks {
		if hookErr := registered.hook(p, response); hookErr == nil {
			continue
		}
		status := rm.inProgressRequestStatuses[response.RequestID()]
		failure := rm.generateResponseErrorFromStatus(graphsync.RequestFailedUnknown)
		// Hand the failure over unless the request has already been cancelled.
		select {
		case status.networkError <- failure:
		case <-status.ctx.Done():
		}
		status.cancelFn()
		return false
	}
	return true
}

405 406
func (rm *RequestManager) processTerminations(responses []gsmsg.GraphSyncResponse) {
	for _, response := range responses {
407
		if gsmsg.IsTerminalResponseCode(response.Status()) {
408 409 410 411 412 413
			if gsmsg.IsTerminalFailureCode(response.Status()) {
				requestStatus := rm.inProgressRequestStatuses[response.RequestID()]
				responseError := rm.generateResponseErrorFromStatus(response.Status())
				select {
				case requestStatus.networkError <- responseError:
				case <-requestStatus.ctx.Done():
414
				}
415
				requestStatus.cancelFn()
416
			}
417 418
			rm.asyncLoader.CompleteResponsesFor(response.RequestID())
			delete(rm.inProgressRequestStatuses, response.RequestID())
419 420 421
		}
	}
}
422

423
func (rm *RequestManager) generateResponseErrorFromStatus(status graphsync.ResponseStatusCode) error {
424
	switch status {
425
	case graphsync.RequestFailedBusy:
426
		return fmt.Errorf("Request Failed - Peer Is Busy")
427
	case graphsync.RequestFailedContentNotFound:
428
		return fmt.Errorf("Request Failed - Content Not Found")
429
	case graphsync.RequestFailedLegal:
430
		return fmt.Errorf("Request Failed - For Legal Reasons")
431
	case graphsync.RequestFailedUnknown:
432
		return fmt.Errorf("Request Failed - Unknown Reason")
433
	default:
434
		return fmt.Errorf("Unknown")
435 436
	}
}
437

438 439 440 441 442 443 444 445 446 447 448 449 450
// hookActions records the choices outgoing-request hooks make for a request:
// which persistence option to start the request with and which node builder
// chooser to use for the traversal.
type hookActions struct {
	persistenceOption  string
	nodeBuilderChooser traversal.NodeBuilderChooser
}

// UsePersistenceOption records the name of the persistence option to use for
// the request, replacing any previously recorded name.
func (ha *hookActions) UsePersistenceOption(name string) {
	ha.persistenceOption = name
}

// UseNodeBuilderChooser records the node builder chooser to use for the
// request's traversal, replacing any previously recorded chooser.
func (ha *hookActions) UseNodeBuilderChooser(nodeBuilderChooser traversal.NodeBuilderChooser) {
	ha.nodeBuilderChooser = nodeBuilderChooser
}

451
func (rm *RequestManager) setupRequest(requestID graphsync.RequestID, p peer.ID, root ipld.Link, selectorSpec ipld.Node, extensions []graphsync.ExtensionData) (chan graphsync.ResponseProgress, chan error) {
452
	_, err := ipldutil.EncodeNode(selectorSpec)
453 454 455
	if err != nil {
		return rm.singleErrorResponse(err)
	}
456
	selector, err := ipldutil.ParseSelector(selectorSpec)
457 458 459
	if err != nil {
		return rm.singleErrorResponse(err)
	}
460 461 462 463
	asCidLink, ok := root.(cidlink.Link)
	if !ok {
		return rm.singleErrorResponse(fmt.Errorf("request failed: link has no cid"))
	}
464 465 466 467 468
	networkErrorChan := make(chan error, 1)
	ctx, cancel := context.WithCancel(rm.ctx)
	rm.inProgressRequestStatuses[requestID] = &inProgressRequestStatus{
		ctx, cancel, p, networkErrorChan,
	}
469 470 471 472 473 474 475 476 477 478 479
	request := gsmsg.NewRequest(requestID, asCidLink.Cid, selectorSpec, maxPriority, extensions...)
	ha := &hookActions{}
	for _, hook := range rm.requestHooks {
		hook.hook(p, request, ha)
	}
	err = rm.asyncLoader.StartRequest(requestID, ha.persistenceOption)
	if err != nil {
		return rm.singleErrorResponse(err)
	}
	rm.peerHandler.SendRequest(p, request)
	return rm.executeTraversal(ctx, requestID, root, selector, ha.nodeBuilderChooser, networkErrorChan)
480 481 482 483
}

func (rm *RequestManager) executeTraversal(
	ctx context.Context,
484
	requestID graphsync.RequestID,
485
	root ipld.Link,
486
	selector selector.Selector,
487
	nodeBuilderChooser traversal.NodeBuilderChooser,
488
	networkErrorChan chan error,
489 490
) (chan graphsync.ResponseProgress, chan error) {
	inProgressChan := make(chan graphsync.ResponseProgress)
491 492 493 494
	inProgressErr := make(chan error)
	loaderFn := loader.WrapAsyncLoader(ctx, rm.asyncLoader.AsyncLoad, requestID, inProgressErr)
	visitor := visitToChannel(ctx, inProgressChan)
	go func() {
495
		_ = ipldutil.Traverse(ctx, loaderFn, nodeBuilderChooser, root, selector, visitor)
496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512
		select {
		case networkError := <-networkErrorChan:
			select {
			case <-rm.ctx.Done():
			case inProgressErr <- networkError:
			}
		default:
		}
		select {
		case <-ctx.Done():
		case rm.messages <- &terminateRequestMessage{requestID}:
		}
		close(inProgressChan)
		close(inProgressErr)
	}()
	return inProgressChan, inProgressErr
}