Unverified Commit 52d6095e authored by Steven Allen's avatar Steven Allen Committed by GitHub

Merge pull request #87 from ipfs/fix/leak

fix(prq): fix a bunch of goroutine leaks and deadlocks
parents ee93aa83 04e47665
...@@ -124,17 +124,24 @@ func (pqm *ProviderQueryManager) FindProvidersAsync(sessionCtx context.Context, ...@@ -124,17 +124,24 @@ func (pqm *ProviderQueryManager) FindProvidersAsync(sessionCtx context.Context,
inProgressRequestChan: inProgressRequestChan, inProgressRequestChan: inProgressRequestChan,
}: }:
case <-pqm.ctx.Done(): case <-pqm.ctx.Done():
return nil ch := make(chan peer.ID)
close(ch)
return ch
case <-sessionCtx.Done(): case <-sessionCtx.Done():
return nil ch := make(chan peer.ID)
close(ch)
return ch
} }
// DO NOT select on sessionCtx. We only want to abort here if we're
// shutting down because we can't actually _cancel_ the request till we
// get to receiveProviders.
var receivedInProgressRequest inProgressRequest var receivedInProgressRequest inProgressRequest
select { select {
case <-pqm.ctx.Done(): case <-pqm.ctx.Done():
return nil ch := make(chan peer.ID)
case <-sessionCtx.Done(): close(ch)
return nil return ch
case receivedInProgressRequest = <-inProgressRequestChan: case receivedInProgressRequest = <-inProgressRequestChan:
} }
...@@ -170,7 +177,9 @@ func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k ...@@ -170,7 +177,9 @@ func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k
case <-pqm.ctx.Done(): case <-pqm.ctx.Done():
return return
case <-sessionCtx.Done(): case <-sessionCtx.Done():
pqm.cancelProviderRequest(k, incomingProviders) if incomingProviders != nil {
pqm.cancelProviderRequest(k, incomingProviders)
}
return return
case provider, ok := <-incomingProviders: case provider, ok := <-incomingProviders:
if !ok { if !ok {
...@@ -228,7 +237,7 @@ func (pqm *ProviderQueryManager) findProviderWorker() { ...@@ -228,7 +237,7 @@ func (pqm *ProviderQueryManager) findProviderWorker() {
wg.Add(1) wg.Add(1)
go func(p peer.ID) { go func(p peer.ID) {
defer wg.Done() defer wg.Done()
err := pqm.network.ConnectTo(pqm.ctx, p) err := pqm.network.ConnectTo(findProviderCtx, p)
if err != nil { if err != nil {
log.Debugf("failed to connect to provider %s: %s", p, err) log.Debugf("failed to connect to provider %s: %s", p, err)
return return
...@@ -397,12 +406,12 @@ func (crm *cancelRequestMessage) debugMessage() string { ...@@ -397,12 +406,12 @@ func (crm *cancelRequestMessage) debugMessage() string {
func (crm *cancelRequestMessage) handle(pqm *ProviderQueryManager) { func (crm *cancelRequestMessage) handle(pqm *ProviderQueryManager) {
requestStatus, ok := pqm.inProgressRequestStatuses[crm.k] requestStatus, ok := pqm.inProgressRequestStatuses[crm.k]
if !ok { if !ok {
log.Errorf("Attempt to cancel request for cid (%s) not in progress", crm.k.String()) // Request finished while queued.
return return
} }
_, ok = requestStatus.listeners[crm.incomingProviders] _, ok = requestStatus.listeners[crm.incomingProviders]
if !ok { if !ok {
log.Errorf("Attempt to cancel request for for cid (%s) this is not a listener", crm.k.String()) // Request finished and _restarted_ while queued.
return return
} }
delete(requestStatus.listeners, crm.incomingProviders) delete(requestStatus.listeners, crm.incomingProviders)
......
...@@ -304,3 +304,60 @@ func TestFindProviderTimeout(t *testing.T) { ...@@ -304,3 +304,60 @@ func TestFindProviderTimeout(t *testing.T) {
t.Fatal("Find provider request should have timed out, did not") t.Fatal("Find provider request should have timed out, did not")
} }
} }
func TestFindProviderPreCanceled(t *testing.T) {
peers := testutil.GeneratePeers(10)
fpn := &fakeProviderNetwork{
peersFound: peers,
delay: 1 * time.Millisecond,
}
ctx := context.Background()
providerQueryManager := New(ctx, fpn)
providerQueryManager.Startup()
providerQueryManager.SetFindProviderTimeout(100 * time.Millisecond)
keys := testutil.GenerateCids(1)
sessionCtx, cancel := context.WithCancel(ctx)
cancel()
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0])
if firstRequestChan == nil {
t.Fatal("expected non-nil channel")
}
select {
case <-firstRequestChan:
case <-time.After(10 * time.Millisecond):
t.Fatal("shouldn't have blocked waiting on a closed context")
}
}
func TestCancelFindProvidersAfterCompletion(t *testing.T) {
peers := testutil.GeneratePeers(2)
fpn := &fakeProviderNetwork{
peersFound: peers,
delay: 1 * time.Millisecond,
}
ctx := context.Background()
providerQueryManager := New(ctx, fpn)
providerQueryManager.Startup()
providerQueryManager.SetFindProviderTimeout(100 * time.Millisecond)
keys := testutil.GenerateCids(1)
sessionCtx, cancel := context.WithCancel(ctx)
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0])
<-firstRequestChan // wait for everything to start.
time.Sleep(10 * time.Millisecond) // wait for the incoming providres to stop.
cancel() // cancel the context.
timer := time.NewTimer(10 * time.Millisecond)
defer timer.Stop()
for {
select {
case _, ok := <-firstRequestChan:
if !ok {
return
}
case <-timer.C:
t.Fatal("should have finished receiving responses within timeout")
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment