responsecollector.go 2.16 KB
Newer Older
1 2
package requestmanager

3 4
import (
	"context"
5 6

	"github.com/ipfs/go-graphsync/requestmanager/types"
7
)
8 9 10 11 12 13 14 15 16 17 18

// responseCollector hands request responses and errors to consumers through
// goroutines started by collectResponses. Its ctx bounds the lifetime of every
// such goroutine: once ctx is done they stop forwarding and close their output.
type responseCollector struct {
	ctx context.Context
}

// newResponseCollector returns a new responseCollector bound to ctx.
func newResponseCollector(ctx context.Context) *responseCollector {
	rc := new(responseCollector)
	rc.ctx = ctx
	return rc
}

func (rc *responseCollector) collectResponses(
	requestCtx context.Context,
19
	incomingResponses <-chan types.ResponseProgress,
20
	incomingErrors <-chan error,
21
	cancelRequest func()) (<-chan types.ResponseProgress, <-chan error) {
22

23
	returnedResponses := make(chan types.ResponseProgress)
24
	returnedErrors := make(chan error)
25 26

	go func() {
27
		var receivedResponses []types.ResponseProgress
28
		defer close(returnedResponses)
29
		outgoingResponses := func() chan<- types.ResponseProgress {
30 31 32 33 34
			if len(receivedResponses) == 0 {
				return nil
			}
			return returnedResponses
		}
35
		nextResponse := func() types.ResponseProgress {
36
			if len(receivedResponses) == 0 {
37
				return types.ResponseProgress{}
38 39 40
			}
			return receivedResponses[0]
		}
41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
		for len(receivedResponses) > 0 || incomingResponses != nil {
			select {
			case <-rc.ctx.Done():
				return
			case <-requestCtx.Done():
				if incomingResponses != nil {
					cancelRequest()
				}
				return
			case response, ok := <-incomingResponses:
				if !ok {
					incomingResponses = nil
				} else {
					receivedResponses = append(receivedResponses, response)
				}
			case outgoingResponses() <- nextResponse():
				receivedResponses = receivedResponses[1:]
			}
		}
	}()
	go func() {
		var receivedErrors []error
		defer close(returnedErrors)

65
		outgoingErrors := func() chan<- error {
66 67 68 69 70
			if len(receivedErrors) == 0 {
				return nil
			}
			return returnedErrors
		}
71
		nextError := func() error {
72
			if len(receivedErrors) == 0 {
73
				return nil
74 75 76 77
			}
			return receivedErrors[0]
		}

78
		for len(receivedErrors) > 0 || incomingErrors != nil {
79 80 81 82 83
			select {
			case <-rc.ctx.Done():
				return
			case <-requestCtx.Done():
				return
hannahhoward's avatar
hannahhoward committed
84
			case err, ok := <-incomingErrors:
85 86 87
				if !ok {
					incomingErrors = nil
				} else {
hannahhoward's avatar
hannahhoward committed
88
					receivedErrors = append(receivedErrors, err)
89 90 91
				}
			case outgoingErrors() <- nextError():
				receivedErrors = receivedErrors[1:]
92 93 94
			}
		}
	}()
95
	return returnedResponses, returnedErrors
96
}