responsecollector.go 2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
package requestmanager

import "context"

// responseCollector relays a request's response and error streams to the
// channels handed back to callers (see collectResponses).
type responseCollector struct {
	// ctx is watched by every goroutine collectResponses starts; once it is
	// done, forwarding stops and the output channels are closed.
	ctx context.Context
}

// newResponseCollector builds a responseCollector bound to ctx. Once ctx is
// done, every forwarding goroutine started by the collector's
// collectResponses method returns.
func newResponseCollector(ctx context.Context) *responseCollector {
	collector := responseCollector{ctx: ctx}
	return &collector
}

func (rc *responseCollector) collectResponses(
	requestCtx context.Context,
15 16 17
	incomingResponses <-chan ResponseProgress,
	incomingErrors <-chan ResponseError,
	cancelRequest func()) (<-chan ResponseProgress, <-chan ResponseError) {
18 19

	returnedResponses := make(chan ResponseProgress)
20
	returnedErrors := make(chan ResponseError)
21 22

	go func() {
23 24
		var receivedResponses []ResponseProgress
		var receivedErrors []ResponseError
25
		defer close(returnedResponses)
26
		defer close(returnedErrors)
27 28 29 30 31 32 33 34 35 36 37 38
		outgoingResponses := func() chan<- ResponseProgress {
			if len(receivedResponses) == 0 {
				return nil
			}
			return returnedResponses
		}
		nextResponse := func() ResponseProgress {
			if len(receivedResponses) == 0 {
				return nil
			}
			return receivedResponses[0]
		}
39 40 41 42 43 44 45 46 47 48 49 50 51 52
		outgoingErrors := func() chan<- ResponseError {
			if len(receivedErrors) == 0 {
				return nil
			}
			return returnedErrors
		}
		nextError := func() ResponseError {
			if len(receivedErrors) == 0 {
				return ResponseError{}
			}
			return receivedErrors[0]
		}

		for len(receivedResponses) > 0 || len(receivedErrors) > 0 || incomingResponses != nil || incomingErrors != nil {
53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
			select {
			case <-rc.ctx.Done():
				return
			case <-requestCtx.Done():
				if incomingResponses != nil {
					cancelRequest()
				}
				return
			case response, ok := <-incomingResponses:
				if !ok {
					incomingResponses = nil
				} else {
					receivedResponses = append(receivedResponses, response)
				}
			case outgoingResponses() <- nextResponse():
				receivedResponses = receivedResponses[1:]
69 70 71 72 73 74 75 76
			case error, ok := <-incomingErrors:
				if !ok {
					incomingErrors = nil
				} else {
					receivedErrors = append(receivedErrors, error)
				}
			case outgoingErrors() <- nextError():
				receivedErrors = receivedErrors[1:]
77 78 79
			}
		}
	}()
80
	return returnedResponses, returnedErrors
81
}