Commit 77a15c1b authored by hannahhoward's avatar hannahhoward

refactor(asyncloader): return single channel

Async loader returns a single channel result type
parent 7a583a29
...@@ -10,10 +10,9 @@ import ( ...@@ -10,10 +10,9 @@ import (
) )
type loadRequest struct { type loadRequest struct {
requestID gsmsg.GraphSyncRequestID requestID gsmsg.GraphSyncRequestID
link ipld.Link link ipld.Link
responseChan chan []byte resultChan chan AsyncLoadResult
errChan chan error
} }
var loadRequestPool = sync.Pool{ var loadRequestPool = sync.Pool{
...@@ -24,13 +23,11 @@ var loadRequestPool = sync.Pool{ ...@@ -24,13 +23,11 @@ var loadRequestPool = sync.Pool{
func newLoadRequest(requestID gsmsg.GraphSyncRequestID, func newLoadRequest(requestID gsmsg.GraphSyncRequestID,
link ipld.Link, link ipld.Link,
responseChan chan []byte, resultChan chan AsyncLoadResult) *loadRequest {
errChan chan error) *loadRequest {
lr := loadRequestPool.Get().(*loadRequest) lr := loadRequestPool.Get().(*loadRequest)
lr.requestID = requestID lr.requestID = requestID
lr.link = link lr.link = link
lr.responseChan = responseChan lr.resultChan = resultChan
lr.errChan = errChan
return lr return lr
} }
...@@ -60,6 +57,12 @@ type finishRequestMessage struct { ...@@ -60,6 +57,12 @@ type finishRequestMessage struct {
// bytes nil, error nil = did not load, but try again later // bytes nil, error nil = did not load, but try again later
type LoadAttempter func(gsmsg.GraphSyncRequestID, ipld.Link) ([]byte, error) type LoadAttempter func(gsmsg.GraphSyncRequestID, ipld.Link) ([]byte, error)
// AsyncLoadResult is sent once over the channel returned by an async load.
type AsyncLoadResult struct {
Data []byte
Err error
}
// AsyncLoader is used to make multiple attempts to load a blocks over the // AsyncLoader is used to make multiple attempts to load a blocks over the
// course of a request - as long as a request is in progress, it will make multiple // course of a request - as long as a request is in progress, it will make multiple
// attempts to load a block until it gets a definitive result of whether the block // attempts to load a block until it gets a definitive result of whether the block
...@@ -89,16 +92,15 @@ func New(ctx context.Context, loadAttempter LoadAttempter) *AsyncLoader { ...@@ -89,16 +92,15 @@ func New(ctx context.Context, loadAttempter LoadAttempter) *AsyncLoader {
// AsyncLoad asynchronously loads the given link for the given request ID. It returns a channel for data and a channel // AsyncLoad asynchronously loads the given link for the given request ID. It returns a channel for data and a channel
// for errors -- only one message will be sent over either. // for errors -- only one message will be sent over either.
func (abl *AsyncLoader) AsyncLoad(requestID gsmsg.GraphSyncRequestID, link ipld.Link) (<-chan []byte, <-chan error) { func (abl *AsyncLoader) AsyncLoad(requestID gsmsg.GraphSyncRequestID, link ipld.Link) <-chan AsyncLoadResult {
responseChan := make(chan []byte, 1) resultChan := make(chan AsyncLoadResult, 1)
errChan := make(chan error, 1) lr := newLoadRequest(requestID, link, resultChan)
lr := newLoadRequest(requestID, link, responseChan, errChan)
select { select {
case <-abl.ctx.Done(): case <-abl.ctx.Done():
abl.terminateWithError("Context Closed", responseChan, errChan) abl.terminateWithError("Context Closed", resultChan)
case abl.incomingMessages <- lr: case abl.incomingMessages <- lr:
} }
return responseChan, errChan return resultChan
} }
// NewResponsesAvailable indicates that the async loader should make another attempt to load // NewResponsesAvailable indicates that the async loader should make another attempt to load
...@@ -179,22 +181,20 @@ func (abl *AsyncLoader) messageQueueWorker() { ...@@ -179,22 +181,20 @@ func (abl *AsyncLoader) messageQueueWorker() {
func (lr *loadRequest) handle(abl *AsyncLoader) { func (lr *loadRequest) handle(abl *AsyncLoader) {
_, ok := abl.activeRequests[lr.requestID] _, ok := abl.activeRequests[lr.requestID]
if !ok { if !ok {
abl.terminateWithError("No active request", lr.responseChan, lr.errChan) abl.terminateWithError("No active request", lr.resultChan)
returnLoadRequest(lr) returnLoadRequest(lr)
return return
} }
response, err := abl.loadAttempter(lr.requestID, lr.link) response, err := abl.loadAttempter(lr.requestID, lr.link)
if err != nil { if err != nil {
lr.errChan <- err lr.resultChan <- AsyncLoadResult{nil, err}
close(lr.errChan) close(lr.resultChan)
close(lr.responseChan)
returnLoadRequest(lr) returnLoadRequest(lr)
return return
} }
if response != nil { if response != nil {
lr.responseChan <- response lr.resultChan <- AsyncLoadResult{response, nil}
close(lr.errChan) close(lr.resultChan)
close(lr.responseChan)
returnLoadRequest(lr) returnLoadRequest(lr)
return return
} }
...@@ -211,7 +211,7 @@ func (frm *finishRequestMessage) handle(abl *AsyncLoader) { ...@@ -211,7 +211,7 @@ func (frm *finishRequestMessage) handle(abl *AsyncLoader) {
abl.pausedRequests = nil abl.pausedRequests = nil
for _, lr := range pausedRequests { for _, lr := range pausedRequests {
if lr.requestID == frm.requestID { if lr.requestID == frm.requestID {
abl.terminateWithError("No active request", lr.responseChan, lr.errChan) abl.terminateWithError("No active request", lr.resultChan)
returnLoadRequest(lr) returnLoadRequest(lr)
} else { } else {
abl.pausedRequests = append(abl.pausedRequests, lr) abl.pausedRequests = append(abl.pausedRequests, lr)
...@@ -232,8 +232,7 @@ func (nram *newResponsesAvailableMessage) handle(abl *AsyncLoader) { ...@@ -232,8 +232,7 @@ func (nram *newResponsesAvailableMessage) handle(abl *AsyncLoader) {
} }
} }
func (abl *AsyncLoader) terminateWithError(errMsg string, responseChan chan<- []byte, errChan chan<- error) { func (abl *AsyncLoader) terminateWithError(errMsg string, resultChan chan<- AsyncLoadResult) {
errChan <- errors.New(errMsg) resultChan <- AsyncLoadResult{nil, errors.New(errMsg)}
close(errChan) close(resultChan)
close(responseChan)
} }
...@@ -27,24 +27,19 @@ func TestAsyncLoadWhenRequestNotInProgress(t *testing.T) { ...@@ -27,24 +27,19 @@ func TestAsyncLoadWhenRequestNotInProgress(t *testing.T) {
link := testbridge.NewMockLink() link := testbridge.NewMockLink()
requestID := gsmsg.GraphSyncRequestID(rand.Int31()) requestID := gsmsg.GraphSyncRequestID(rand.Int31())
responseChan, errChan := asyncLoader.AsyncLoad(requestID, link) resultChan := asyncLoader.AsyncLoad(requestID, link)
select { select {
case _, ok := <-responseChan: case result := <-resultChan:
if ok { if result.Data != nil {
t.Fatal("should not have sent responses") t.Fatal("should not have sent responses")
} }
case <-ctx.Done(): if result.Err == nil {
t.Fatal("should have closed response channel")
}
select {
case _, ok := <-errChan:
if !ok {
t.Fatal("should have sent an error") t.Fatal("should have sent an error")
} }
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("should have closed error channel") t.Fatal("should have produced result")
} }
if callCount > 0 { if callCount > 0 {
...@@ -67,24 +62,18 @@ func TestAsyncLoadWhenInitialLoadSucceeds(t *testing.T) { ...@@ -67,24 +62,18 @@ func TestAsyncLoadWhenInitialLoadSucceeds(t *testing.T) {
link := testbridge.NewMockLink() link := testbridge.NewMockLink()
requestID := gsmsg.GraphSyncRequestID(rand.Int31()) requestID := gsmsg.GraphSyncRequestID(rand.Int31())
asyncLoader.StartRequest(requestID) asyncLoader.StartRequest(requestID)
responseChan, errChan := asyncLoader.AsyncLoad(requestID, link) resultChan := asyncLoader.AsyncLoad(requestID, link)
select { select {
case _, ok := <-responseChan: case result := <-resultChan:
if !ok { if result.Data == nil {
t.Fatal("should have sent a response") t.Fatal("should have sent a response")
} }
case <-ctx.Done(): if result.Err != nil {
t.Fatal("should have closed response channel")
}
select {
case _, ok := <-errChan:
if ok {
t.Fatal("should not have sent an error") t.Fatal("should not have sent an error")
} }
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("should have closed error channel") t.Fatal("should have closed response channel")
} }
if callCount == 0 { if callCount == 0 {
...@@ -107,24 +96,18 @@ func TestAsyncLoadInitialLoadFails(t *testing.T) { ...@@ -107,24 +96,18 @@ func TestAsyncLoadInitialLoadFails(t *testing.T) {
link := testbridge.NewMockLink() link := testbridge.NewMockLink()
requestID := gsmsg.GraphSyncRequestID(rand.Int31()) requestID := gsmsg.GraphSyncRequestID(rand.Int31())
asyncLoader.StartRequest(requestID) asyncLoader.StartRequest(requestID)
responseChan, errChan := asyncLoader.AsyncLoad(requestID, link) resultChan := asyncLoader.AsyncLoad(requestID, link)
select { select {
case _, ok := <-responseChan: case result := <-resultChan:
if ok { if result.Data != nil {
t.Fatal("should not have sent responses") t.Fatal("should not have sent responses")
} }
case <-ctx.Done(): if result.Err == nil {
t.Fatal("should have closed response channel")
}
select {
case _, ok := <-errChan:
if !ok {
t.Fatal("should have sent an error") t.Fatal("should have sent an error")
} }
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("should have closed error channel") t.Fatal("should have closed response channel")
} }
if callCount == 0 { if callCount == 0 {
...@@ -153,34 +136,26 @@ func TestAsyncLoadInitialLoadIndeterminateThenSucceeds(t *testing.T) { ...@@ -153,34 +136,26 @@ func TestAsyncLoadInitialLoadIndeterminateThenSucceeds(t *testing.T) {
link := testbridge.NewMockLink() link := testbridge.NewMockLink()
requestID := gsmsg.GraphSyncRequestID(rand.Int31()) requestID := gsmsg.GraphSyncRequestID(rand.Int31())
asyncLoader.StartRequest(requestID) asyncLoader.StartRequest(requestID)
responseChan, errChan := asyncLoader.AsyncLoad(requestID, link) resultChan := asyncLoader.AsyncLoad(requestID, link)
select { select {
case <-called: case <-called:
case <-responseChan: case <-resultChan:
t.Fatal("Should not have sent message on response chan") t.Fatal("Should not have sent message on response chan")
case <-errChan:
t.Fatal("Should not have sent messages on error chan")
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("should have attempted load once") t.Fatal("should have attempted load once")
} }
asyncLoader.NewResponsesAvailable() asyncLoader.NewResponsesAvailable()
select { select {
case _, ok := <-responseChan: case result := <-resultChan:
if !ok { if result.Data == nil {
t.Fatal("should have sent a response") t.Fatal("should have sent a response")
} }
case <-ctx.Done(): if result.Err != nil {
t.Fatal("should have closed response channel")
}
select {
case _, ok := <-errChan:
if ok {
t.Fatal("should not have sent an error") t.Fatal("should not have sent an error")
} }
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("should have closed error channel") t.Fatal("should have closed response channel")
} }
if callCount < 2 { if callCount < 2 {
...@@ -209,13 +184,11 @@ func TestAsyncLoadInitialLoadIndeterminateThenRequestFinishes(t *testing.T) { ...@@ -209,13 +184,11 @@ func TestAsyncLoadInitialLoadIndeterminateThenRequestFinishes(t *testing.T) {
link := testbridge.NewMockLink() link := testbridge.NewMockLink()
requestID := gsmsg.GraphSyncRequestID(rand.Int31()) requestID := gsmsg.GraphSyncRequestID(rand.Int31())
asyncLoader.StartRequest(requestID) asyncLoader.StartRequest(requestID)
responseChan, errChan := asyncLoader.AsyncLoad(requestID, link) resultChan := asyncLoader.AsyncLoad(requestID, link)
select { select {
case <-called: case <-called:
case <-responseChan: case <-resultChan:
t.Fatal("Should not have sent message on response chan") t.Fatal("Should not have sent message on response chan")
case <-errChan:
t.Fatal("Should not have sent messages on error chan")
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("should have attempted load once") t.Fatal("should have attempted load once")
} }
...@@ -223,22 +196,17 @@ func TestAsyncLoadInitialLoadIndeterminateThenRequestFinishes(t *testing.T) { ...@@ -223,22 +196,17 @@ func TestAsyncLoadInitialLoadIndeterminateThenRequestFinishes(t *testing.T) {
asyncLoader.NewResponsesAvailable() asyncLoader.NewResponsesAvailable()
select { select {
case _, ok := <-responseChan: case result := <-resultChan:
if ok { if result.Data != nil {
t.Fatal("should not have sent responses") t.Fatal("should not have sent responses")
} }
case <-ctx.Done(): if result.Err == nil {
t.Fatal("should have closed response channel")
}
select {
case _, ok := <-errChan:
if !ok {
t.Fatal("should have sent an error") t.Fatal("should have sent an error")
} }
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("should have closed error channel") t.Fatal("should have closed response channel")
} }
if callCount > 1 { if callCount > 1 {
t.Fatal("should only have attempted one call but attempted multiple") t.Fatal("should only have attempted one call but attempted multiple")
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment