requestmanager.go 13.3 KB
Newer Older
1 2 3 4 5 6 7
package requestmanager

import (
	"context"
	"fmt"
	"math"

8
	blocks "github.com/ipfs/go-block-format"
9
	"github.com/ipfs/go-graphsync"
10
	ipldutil "github.com/ipfs/go-graphsync/ipldutil"
11
	gsmsg "github.com/ipfs/go-graphsync/message"
12 13 14 15 16
	"github.com/ipfs/go-graphsync/metadata"
	"github.com/ipfs/go-graphsync/requestmanager/loader"
	"github.com/ipfs/go-graphsync/requestmanager/types"
	logging "github.com/ipfs/go-log"
	"github.com/ipld/go-ipld-prime"
17
	cidlink "github.com/ipld/go-ipld-prime/linking/cid"
18
	"github.com/ipld/go-ipld-prime/traversal/selector"
19
	"github.com/libp2p/go-libp2p-core/peer"
20 21
)

22
var log = logging.Logger("graphsync")
23 24 25

const (
	// maxPriority is the max priority as defined by the bitswap protocol
26
	maxPriority = graphsync.Priority(math.MaxInt32)
27 28 29
)

type inProgressRequestStatus struct {
30 31 32 33
	ctx          context.Context
	cancelFn     func()
	p            peer.ID
	networkError chan error
34 35
}

36
type responseHook struct {
37
	key  uint64
38 39 40
	hook graphsync.OnResponseReceivedHook
}

41 42
// PeerHandler is an interface that can send requests to peers
type PeerHandler interface {
43
	SendRequest(p peer.ID, graphSyncRequest gsmsg.GraphSyncRequest)
44 45
}

46 47 48
// AsyncLoader is an interface for loading links asynchronously, returning
// results as new responses are processed
type AsyncLoader interface {
49 50
	StartRequest(requestID graphsync.RequestID)
	ProcessResponse(responses map[graphsync.RequestID]metadata.Metadata,
51
		blks []blocks.Block)
52 53 54
	AsyncLoad(requestID graphsync.RequestID, link ipld.Link) <-chan types.AsyncLoadResult
	CompleteResponsesFor(requestID graphsync.RequestID)
	CleanupRequest(requestID graphsync.RequestID)
55 56
}

57 58 59 60 61 62 63 64
// RequestManager tracks outgoing requests and processes incoming reponses
// to them.
type RequestManager struct {
	ctx         context.Context
	cancel      func()
	messages    chan requestManagerMessage
	peerHandler PeerHandler
	rc          *responseCollector
65
	asyncLoader AsyncLoader
66
	// dont touch out side of run loop
67 68
	nextRequestID             graphsync.RequestID
	inProgressRequestStatuses map[graphsync.RequestID]*inProgressRequestStatus
69
	responseHookNextKey       uint64
70
	responseHooks             []responseHook
71 72 73 74 75 76 77
}

// requestManagerMessage is one unit of work queued on RequestManager.messages.
// The run loop calls handle, so implementations may freely read and write
// the loop-owned fields of rm.
type requestManagerMessage interface {
	handle(rm *RequestManager)
}

// New generates a new request manager from a context, network, and selectorQuerier
78
func New(ctx context.Context, asyncLoader AsyncLoader) *RequestManager {
79 80 81 82
	ctx, cancel := context.WithCancel(ctx)
	return &RequestManager{
		ctx:                       ctx,
		cancel:                    cancel,
83
		asyncLoader:               asyncLoader,
84 85
		rc:                        newResponseCollector(ctx),
		messages:                  make(chan requestManagerMessage, 16),
86
		inProgressRequestStatuses: make(map[graphsync.RequestID]*inProgressRequestStatus),
87 88 89 90 91 92 93 94 95
	}
}

// SetDelegate sets the PeerHandler used to send requests and cancellations
// out to peers. The field is written without synchronization, so call this
// before Startup.
func (rm *RequestManager) SetDelegate(peerHandler PeerHandler) {
	rm.peerHandler = peerHandler
}

type inProgressRequest struct {
96 97
	requestID     graphsync.RequestID
	incoming      chan graphsync.ResponseProgress
98
	incomingError chan error
99 100 101 102
}

type newRequestMessage struct {
	p                     peer.ID
103
	root                  ipld.Link
104
	selector              ipld.Node
105
	extensions            []graphsync.ExtensionData
106 107 108 109
	inProgressRequestChan chan<- inProgressRequest
}

// SendRequest initiates a new GraphSync request to the given peer.
110 111
func (rm *RequestManager) SendRequest(ctx context.Context,
	p peer.ID,
112
	root ipld.Link,
113 114
	selector ipld.Node,
	extensions ...graphsync.ExtensionData) (<-chan graphsync.ResponseProgress, <-chan error) {
115
	if _, err := ipldutil.ParseSelector(selector); err != nil {
116
		return rm.singleErrorResponse(fmt.Errorf("Invalid Selector Spec"))
117 118 119 120 121
	}

	inProgressRequestChan := make(chan inProgressRequest)

	select {
122
	case rm.messages <- &newRequestMessage{p, root, selector, extensions, inProgressRequestChan}:
123
	case <-rm.ctx.Done():
124
		return rm.emptyResponse()
125
	case <-ctx.Done():
126
		return rm.emptyResponse()
127 128 129 130
	}
	var receivedInProgressRequest inProgressRequest
	select {
	case <-rm.ctx.Done():
131
		return rm.emptyResponse()
132 133 134
	case receivedInProgressRequest = <-inProgressRequestChan:
	}

135 136 137 138 139 140 141 142 143 144
	return rm.rc.collectResponses(ctx,
		receivedInProgressRequest.incoming,
		receivedInProgressRequest.incomingError,
		func() {
			rm.cancelRequest(receivedInProgressRequest.requestID,
				receivedInProgressRequest.incoming,
				receivedInProgressRequest.incomingError)
		})
}

145 146
func (rm *RequestManager) emptyResponse() (chan graphsync.ResponseProgress, chan error) {
	ch := make(chan graphsync.ResponseProgress)
147
	close(ch)
148
	errCh := make(chan error)
149 150 151 152
	close(errCh)
	return ch, errCh
}

153 154
func (rm *RequestManager) singleErrorResponse(err error) (chan graphsync.ResponseProgress, chan error) {
	ch := make(chan graphsync.ResponseProgress)
155
	close(ch)
156
	errCh := make(chan error, 1)
157
	errCh <- err
158 159
	close(errCh)
	return ch, errCh
160 161 162
}

type cancelRequestMessage struct {
163
	requestID graphsync.RequestID
164 165
}

166 167
func (rm *RequestManager) cancelRequest(requestID graphsync.RequestID,
	incomingResponses chan graphsync.ResponseProgress,
168
	incomingErrors chan error) {
169
	cancelMessageChannel := rm.messages
170
	for cancelMessageChannel != nil || incomingResponses != nil || incomingErrors != nil {
171 172 173 174 175 176 177
		select {
		case cancelMessageChannel <- &cancelRequestMessage{requestID}:
			cancelMessageChannel = nil
		// clear out any remaining responses, in case and "incoming reponse"
		// messages get processed before our cancel message
		case _, ok := <-incomingResponses:
			if !ok {
178 179 180 181 182
				incomingResponses = nil
			}
		case _, ok := <-incomingErrors:
			if !ok {
				incomingErrors = nil
183 184 185 186 187 188 189 190
			}
		case <-rm.ctx.Done():
			return
		}
	}
}

type processResponseMessage struct {
191
	p         peer.ID
192 193
	responses []gsmsg.GraphSyncResponse
	blks      []blocks.Block
194 195 196 197
}

// ProcessResponses ingests the given responses from the network and
// and updates the in progress requests based on those responses.
198
func (rm *RequestManager) ProcessResponses(p peer.ID, responses []gsmsg.GraphSyncResponse,
199
	blks []blocks.Block) {
200
	select {
201
	case rm.messages <- &processResponseMessage{p, responses, blks}:
202 203 204 205
	case <-rm.ctx.Done():
	}
}

206 207 208 209 210
// registerHookMessage asks the run loop to register hook. The run loop replies
// with the matching unregister function on unregisterHookChan. It is built
// with a positional composite literal, so the field order must not change.
type registerHookMessage struct {
	hook               graphsync.OnResponseReceivedHook
	unregisterHookChan chan graphsync.UnregisterHookFunc
}

211 212
// RegisterHook registers an extension to processincoming responses
func (rm *RequestManager) RegisterHook(
213 214 215 216 217 218 219
	hook graphsync.OnResponseReceivedHook) graphsync.UnregisterHookFunc {
	response := make(chan graphsync.UnregisterHookFunc)
	select {
	case rm.messages <- &registerHookMessage{hook, response}:
	case <-rm.ctx.Done():
		return nil
	}
220
	select {
221 222
	case unregister := <-response:
		return unregister
223
	case <-rm.ctx.Done():
224
		return nil
225 226 227
	}
}

228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254
// Startup starts processing for the RequestManager by launching its run loop
// in a new goroutine.
func (rm *RequestManager) Startup() {
	go rm.run()
}

// Shutdown ends processing for the RequestManager by cancelling its context.
// This stops the run loop, whose deferred cleanup then cancels every request
// still in progress.
func (rm *RequestManager) Shutdown() {
	rm.cancel()
}

// run is the event loop that owns the manager's bookkeeping. It handles queued
// messages one at a time until the manager's context is cancelled. On exit it
// cancels every request still being tracked.
func (rm *RequestManager) run() {
	// NOTE: Do not open any streams or connections from anywhere in this
	// event loop. Really, just don't do anything likely to block.
	defer rm.cleanupInProcessRequests()

	for {
		select {
		case <-rm.ctx.Done():
			return
		case msg := <-rm.messages:
			msg.handle(rm)
		}
	}
}

func (rm *RequestManager) cleanupInProcessRequests() {
	for _, requestStatus := range rm.inProgressRequestStatuses {
255
		requestStatus.cancelFn()
256 257 258
	}
}

259
type terminateRequestMessage struct {
260
	requestID graphsync.RequestID
261
}
262

263
func (nrm *newRequestMessage) handle(rm *RequestManager) {
264 265 266
	requestID := rm.nextRequestID
	rm.nextRequestID++

267
	inProgressChan, inProgressErr := rm.setupRequest(requestID, nrm.p, nrm.root, nrm.selector, nrm.extensions)
268 269 270

	select {
	case nrm.inProgressRequestChan <- inProgressRequest{
271 272 273
		requestID:     requestID,
		incoming:      inProgressChan,
		incomingError: inProgressErr,
274 275 276 277 278
	}:
	case <-rm.ctx.Done():
	}
}

279 280 281 282 283
// handle stops tracking the request, if it is still tracked, and lets the
// async loader release the request's state.
func (trm *terminateRequestMessage) handle(rm *RequestManager) {
	id := trm.requestID
	delete(rm.inProgressRequestStatuses, id)
	rm.asyncLoader.CleanupRequest(id)
}

284 285 286 287 288 289
func (crm *cancelRequestMessage) handle(rm *RequestManager) {
	inProgressRequestStatus, ok := rm.inProgressRequestStatuses[crm.requestID]
	if !ok {
		return
	}

290
	rm.peerHandler.SendRequest(inProgressRequestStatus.p, gsmsg.CancelRequest(crm.requestID))
291
	delete(rm.inProgressRequestStatuses, crm.requestID)
292
	inProgressRequestStatus.cancelFn()
293 294 295
}

func (prm *processResponseMessage) handle(rm *RequestManager) {
296
	filteredResponses := rm.filterResponsesForPeer(prm.responses, prm.p)
297
	filteredResponses = rm.processExtensions(filteredResponses, prm.p)
298
	responseMetadata := metadataForResponses(filteredResponses)
299 300 301 302
	rm.asyncLoader.ProcessResponse(responseMetadata, prm.blks)
	rm.processTerminations(filteredResponses)
}

303 304 305 306 307 308 309 310 311 312 313 314 315 316 317
func (rhm *registerHookMessage) handle(rm *RequestManager) {
	rh := responseHook{rm.responseHookNextKey, rhm.hook}
	rm.responseHookNextKey++
	rm.responseHooks = append(rm.responseHooks, rh)
	select {
	case rhm.unregisterHookChan <- func() {
		for i, matchHook := range rm.responseHooks {
			if rh.key == matchHook.key {
				rm.responseHooks = append(rm.responseHooks[:i], rm.responseHooks[i+1:]...)
				return
			}
		}
	}:
	case <-rm.ctx.Done():
	}
318 319
}

320 321 322 323 324 325
func (rm *RequestManager) filterResponsesForPeer(responses []gsmsg.GraphSyncResponse, p peer.ID) []gsmsg.GraphSyncResponse {
	responsesForPeer := make([]gsmsg.GraphSyncResponse, 0, len(responses))
	for _, response := range responses {
		requestStatus, ok := rm.inProgressRequestStatuses[response.RequestID()]
		if !ok || requestStatus.p != p {
			continue
326
		}
327
		responsesForPeer = append(responsesForPeer, response)
328
	}
329 330
	return responsesForPeer
}
331

332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359
// processExtensions runs the registered response hooks over each response. It
// returns, in order, only the responses for which every hook succeeded.
func (rm *RequestManager) processExtensions(responses []gsmsg.GraphSyncResponse, p peer.ID) []gsmsg.GraphSyncResponse {
	passed := make([]gsmsg.GraphSyncResponse, 0, len(responses))
	for _, response := range responses {
		if !rm.processExtensionsForResponse(p, response) {
			continue
		}
		passed = append(passed, response)
	}
	return passed
}

// processExtensionsForResponse calls each registered hook on response, in the
// order the hooks appear in rm.responseHooks. On the first hook error it:
//   - reports a RequestFailedUnknown error on the request's networkError
//     channel (giving up if the request's context is done);
//   - cancels the request;
//   - returns false.
// It returns true when every hook succeeds.
func (rm *RequestManager) processExtensionsForResponse(p peer.ID, response gsmsg.GraphSyncResponse) bool {
	for _, rh := range rm.responseHooks {
		if err := rh.hook(p, response); err == nil {
			continue
		}
		// The hook's own error is discarded; the requester only sees the
		// generic failure below.
		status := rm.inProgressRequestStatuses[response.RequestID()]
		failure := rm.generateResponseErrorFromStatus(graphsync.RequestFailedUnknown)
		select {
		case <-status.ctx.Done():
		case status.networkError <- failure:
		}
		status.cancelFn()
		return false
	}
	return true
}

360 361
func (rm *RequestManager) processTerminations(responses []gsmsg.GraphSyncResponse) {
	for _, response := range responses {
362
		if gsmsg.IsTerminalResponseCode(response.Status()) {
363 364 365 366 367 368
			if gsmsg.IsTerminalFailureCode(response.Status()) {
				requestStatus := rm.inProgressRequestStatuses[response.RequestID()]
				responseError := rm.generateResponseErrorFromStatus(response.Status())
				select {
				case requestStatus.networkError <- responseError:
				case <-requestStatus.ctx.Done():
369
				}
370
				requestStatus.cancelFn()
371
			}
372 373
			rm.asyncLoader.CompleteResponsesFor(response.RequestID())
			delete(rm.inProgressRequestStatuses, response.RequestID())
374 375 376
		}
	}
}
377

378
func (rm *RequestManager) generateResponseErrorFromStatus(status graphsync.ResponseStatusCode) error {
379
	switch status {
380
	case graphsync.RequestFailedBusy:
381
		return fmt.Errorf("Request Failed - Peer Is Busy")
382
	case graphsync.RequestFailedContentNotFound:
383
		return fmt.Errorf("Request Failed - Content Not Found")
384
	case graphsync.RequestFailedLegal:
385
		return fmt.Errorf("Request Failed - For Legal Reasons")
386
	case graphsync.RequestFailedUnknown:
387
		return fmt.Errorf("Request Failed - Unknown Reason")
388
	default:
389
		return fmt.Errorf("Unknown")
390 391
	}
}
392

393
func (rm *RequestManager) setupRequest(requestID graphsync.RequestID, p peer.ID, root ipld.Link, selectorSpec ipld.Node, extensions []graphsync.ExtensionData) (chan graphsync.ResponseProgress, chan error) {
394
	_, err := ipldutil.EncodeNode(selectorSpec)
395 396 397
	if err != nil {
		return rm.singleErrorResponse(err)
	}
398
	selector, err := ipldutil.ParseSelector(selectorSpec)
399 400 401
	if err != nil {
		return rm.singleErrorResponse(err)
	}
402 403 404 405
	asCidLink, ok := root.(cidlink.Link)
	if !ok {
		return rm.singleErrorResponse(fmt.Errorf("request failed: link has no cid"))
	}
406 407 408 409 410 411
	networkErrorChan := make(chan error, 1)
	ctx, cancel := context.WithCancel(rm.ctx)
	rm.inProgressRequestStatuses[requestID] = &inProgressRequestStatus{
		ctx, cancel, p, networkErrorChan,
	}
	rm.asyncLoader.StartRequest(requestID)
412
	rm.peerHandler.SendRequest(p, gsmsg.NewRequest(requestID, asCidLink.Cid, selectorSpec, maxPriority, extensions...))
413 414 415 416 417
	return rm.executeTraversal(ctx, requestID, root, selector, networkErrorChan)
}

func (rm *RequestManager) executeTraversal(
	ctx context.Context,
418
	requestID graphsync.RequestID,
419
	root ipld.Link,
420
	selector selector.Selector,
421
	networkErrorChan chan error,
422 423
) (chan graphsync.ResponseProgress, chan error) {
	inProgressChan := make(chan graphsync.ResponseProgress)
424 425 426 427
	inProgressErr := make(chan error)
	loaderFn := loader.WrapAsyncLoader(ctx, rm.asyncLoader.AsyncLoad, requestID, inProgressErr)
	visitor := visitToChannel(ctx, inProgressChan)
	go func() {
428
		_ = ipldutil.Traverse(ctx, loaderFn, nil, root, selector, visitor)
429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445
		select {
		case networkError := <-networkErrorChan:
			select {
			case <-rm.ctx.Done():
			case inProgressErr <- networkError:
			}
		default:
		}
		select {
		case <-ctx.Done():
		case rm.messages <- &terminateRequestMessage{requestID}:
		}
		close(inProgressChan)
		close(inProgressErr)
	}()
	return inProgressChan, inProgressErr
}