Unverified Commit 6e60b859 authored by dirkmc's avatar dirkmc Committed by GitHub

feat: fire network error when network disconnects during request (#164)

parent 611735cb
......@@ -4,6 +4,8 @@ import (
"context"
"errors"
"fmt"
"github.com/hannahhoward/go-pubsub"
"golang.org/x/xerrors"
"sync/atomic"
blocks "github.com/ipfs/go-block-format"
......@@ -70,6 +72,7 @@ type RequestManager struct {
peerHandler PeerHandler
rc *responseCollector
asyncLoader AsyncLoader
disconnectNotif *pubsub.PubSub
// dont touch out side of run loop
nextRequestID graphsync.RequestID
inProgressRequestStatuses map[graphsync.RequestID]*inProgressRequestStatus
......@@ -111,6 +114,7 @@ func New(ctx context.Context,
ctx: ctx,
cancel: cancel,
asyncLoader: asyncLoader,
disconnectNotif: pubsub.New(disconnectDispatcher),
rc: newResponseCollector(ctx),
messages: make(chan requestManagerMessage, 16),
inProgressRequestStatuses: make(map[graphsync.RequestID]*inProgressRequestStatus),
......@@ -128,6 +132,7 @@ func (rm *RequestManager) SetDelegate(peerHandler PeerHandler) {
type inProgressRequest struct {
requestID graphsync.RequestID
request gsmsg.GraphSyncRequest
incoming chan graphsync.ResponseProgress
incomingError chan error
}
......@@ -166,6 +171,11 @@ func (rm *RequestManager) SendRequest(ctx context.Context,
case receivedInProgressRequest = <-inProgressRequestChan:
}
// If the connection to the peer is disconnected, fire an error
unsub := rm.listenForDisconnect(p, func(neterr error) {
rm.networkErrorListeners.NotifyNetworkErrorListeners(p, receivedInProgressRequest.request, neterr)
})
return rm.rc.collectResponses(ctx,
receivedInProgressRequest.incoming,
receivedInProgressRequest.incomingError,
......@@ -173,7 +183,34 @@ func (rm *RequestManager) SendRequest(ctx context.Context,
rm.cancelRequest(receivedInProgressRequest.requestID,
receivedInProgressRequest.incoming,
receivedInProgressRequest.incomingError)
})
},
// Once the request has completed, stop listening for disconnect events
unsub,
)
}
// Dispatch the Disconnect event to subscribers
func disconnectDispatcher(p pubsub.Event, subscriberFn pubsub.SubscriberFn) error {
listener := subscriberFn.(func(peer.ID))
listener(p.(peer.ID))
return nil
}
// Listen for the Disconnect event for the given peer
func (rm *RequestManager) listenForDisconnect(p peer.ID, onDisconnect func(neterr error)) func() {
// Subscribe to Disconnect notifications
return rm.disconnectNotif.Subscribe(func(evtPeer peer.ID) {
// If the peer is the one we're interested in, call the listener
if evtPeer == p {
onDisconnect(xerrors.Errorf("disconnected from peer %s", p))
}
})
}
// Disconnected is called when a peer disconnects
func (rm *RequestManager) Disconnected(p peer.ID) {
// Notify any listeners that a peer has disconnected
rm.disconnectNotif.Publish(p)
}
func (rm *RequestManager) emptyResponse() (chan graphsync.ResponseProgress, chan error) {
......@@ -311,17 +348,19 @@ type terminateRequestMessage struct {
requestID graphsync.RequestID
}
func (nrm *newRequestMessage) setupRequest(requestID graphsync.RequestID, rm *RequestManager) (chan graphsync.ResponseProgress, chan error) {
func (nrm *newRequestMessage) setupRequest(requestID graphsync.RequestID, rm *RequestManager) (gsmsg.GraphSyncRequest, chan graphsync.ResponseProgress, chan error) {
request, hooksResult, err := rm.validateRequest(requestID, nrm.p, nrm.root, nrm.selector, nrm.extensions)
if err != nil {
return rm.singleErrorResponse(err)
rp, err := rm.singleErrorResponse(err)
return request, rp, err
}
doNotSendCidsData, has := request.Extension(graphsync.ExtensionDoNotSendCIDs)
var doNotSendCids *cid.Set
if has {
doNotSendCids, err = cidset.DecodeCidSet(doNotSendCidsData)
if err != nil {
return rm.singleErrorResponse(err)
rp, err := rm.singleErrorResponse(err)
return request, rp, err
}
} else {
doNotSendCids = cid.NewSet()
......@@ -355,14 +394,14 @@ func (nrm *newRequestMessage) setupRequest(requestID graphsync.RequestID, rm *Re
ResumeMessages: resumeMessages,
PauseMessages: pauseMessages,
})
return incoming, incomingError
return request, incoming, incomingError
}
func (nrm *newRequestMessage) handle(rm *RequestManager) {
var ipr inProgressRequest
ipr.requestID = rm.nextRequestID
rm.nextRequestID++
ipr.incoming, ipr.incomingError = nrm.setupRequest(ipr.requestID, rm)
ipr.request, ipr.incoming, ipr.incomingError = nrm.setupRequest(ipr.requestID, rm)
select {
case nrm.inProgressRequestChan <- ipr:
......
......@@ -352,6 +352,42 @@ func TestRequestReturnsMissingBlocks(t *testing.T) {
require.NotEqual(t, len(errs), 0, "did not send errors")
}
func TestDisconnectNotification(t *testing.T) {
ctx := context.Background()
td := newTestData(ctx, t)
requestCtx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
peers := testutil.GeneratePeers(2)
// Listen for network errors
networkErrors := make(chan peer.ID, 1)
td.networkErrorListeners.Register(func(p peer.ID, request graphsync.RequestData, err error) {
networkErrors <- p
})
// Send a request to the target peer
targetPeer := peers[0]
td.requestManager.SendRequest(requestCtx, targetPeer, td.blockChain.TipLink, td.blockChain.Selector())
// Disconnect a random peer, should not fire any events
randomPeer := peers[1]
td.requestManager.Disconnected(randomPeer)
select {
case <-networkErrors:
t.Fatal("should not fire network error when unrelated peer disconnects")
default:
}
// Disconnect the target peer, should fire a network error
td.requestManager.Disconnected(targetPeer)
select {
case p:= <-networkErrors:
require.Equal(t, p, targetPeer)
default:
t.Fatal("should fire network error when peer disconnects")
}
}
func TestEncodingExtensions(t *testing.T) {
ctx := context.Background()
td := newTestData(ctx, t)
......
......@@ -18,7 +18,9 @@ func (rc *responseCollector) collectResponses(
requestCtx context.Context,
incomingResponses <-chan graphsync.ResponseProgress,
incomingErrors <-chan error,
cancelRequest func()) (<-chan graphsync.ResponseProgress, <-chan error) {
cancelRequest func(),
onComplete func(),
) (<-chan graphsync.ResponseProgress, <-chan error) {
returnedResponses := make(chan graphsync.ResponseProgress)
returnedErrors := make(chan error)
......@@ -26,6 +28,7 @@ func (rc *responseCollector) collectResponses(
go func() {
var receivedResponses []graphsync.ResponseProgress
defer close(returnedResponses)
defer onComplete()
outgoingResponses := func() chan<- graphsync.ResponseProgress {
if len(receivedResponses) == 0 {
return nil
......
......@@ -26,7 +26,7 @@ func TestBufferingResponseProgress(t *testing.T) {
cancelRequest := func() {}
outgoingResponses, outgoingErrors := rc.collectResponses(
requestCtx, incomingResponses, incomingErrors, cancelRequest)
requestCtx, incomingResponses, incomingErrors, cancelRequest, func(){})
blockStore := make(map[ipld.Link][]byte)
loader, storer := testutil.NewTestStore(blockStore)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment