Commit 30f40ece authored by hannahhoward's avatar hannahhoward

fix(providerquerymanager): minor channel cleanup

Keep channels unblocked in cancelling request -- refactored to function. Also cancel find provider
context as soon as it can be.
parent b48b3c33
...@@ -170,22 +170,8 @@ func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k ...@@ -170,22 +170,8 @@ func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k
case <-pqm.ctx.Done(): case <-pqm.ctx.Done():
return return
case <-sessionCtx.Done(): case <-sessionCtx.Done():
pqm.providerQueryMessages <- &cancelRequestMessage{ pqm.cancelProviderRequest(k, incomingProviders)
incomingProviders: incomingProviders,
k: k,
}
// clear out any remaining providers, in case and "incoming provider"
// messages get processed before our cancel message
for {
select {
case _, ok := <-incomingProviders:
if !ok {
return
}
case <-pqm.ctx.Done():
return return
}
}
case provider, ok := <-incomingProviders: case provider, ok := <-incomingProviders:
if !ok { if !ok {
incomingProviders = nil incomingProviders = nil
...@@ -200,6 +186,27 @@ func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k ...@@ -200,6 +186,27 @@ func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k
return returnedProviders return returnedProviders
} }
func (pqm *ProviderQueryManager) cancelProviderRequest(k cid.Cid, incomingProviders chan peer.ID) {
cancelMessageChannel := pqm.providerQueryMessages
for {
select {
case cancelMessageChannel <- &cancelRequestMessage{
incomingProviders: incomingProviders,
k: k,
}:
cancelMessageChannel = nil
// clear out any remaining providers, in case and "incoming provider"
// messages get processed before our cancel message
case _, ok := <-incomingProviders:
if !ok {
return
}
case <-pqm.ctx.Done():
return
}
}
}
func (pqm *ProviderQueryManager) findProviderWorker() { func (pqm *ProviderQueryManager) findProviderWorker() {
// findProviderWorker just cycles through incoming provider queries one // findProviderWorker just cycles through incoming provider queries one
// at a time. We have six of these workers running at once // at a time. We have six of these workers running at once
...@@ -215,7 +222,6 @@ func (pqm *ProviderQueryManager) findProviderWorker() { ...@@ -215,7 +222,6 @@ func (pqm *ProviderQueryManager) findProviderWorker() {
pqm.timeoutMutex.RLock() pqm.timeoutMutex.RLock()
findProviderCtx, cancel := context.WithTimeout(fpr.ctx, pqm.findProviderTimeout) findProviderCtx, cancel := context.WithTimeout(fpr.ctx, pqm.findProviderTimeout)
pqm.timeoutMutex.RUnlock() pqm.timeoutMutex.RUnlock()
defer cancel()
providers := pqm.network.FindProvidersAsync(findProviderCtx, k, maxProviders) providers := pqm.network.FindProvidersAsync(findProviderCtx, k, maxProviders)
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
for p := range providers { for p := range providers {
...@@ -237,6 +243,7 @@ func (pqm *ProviderQueryManager) findProviderWorker() { ...@@ -237,6 +243,7 @@ func (pqm *ProviderQueryManager) findProviderWorker() {
} }
}(p) }(p)
} }
cancel()
wg.Wait() wg.Wait()
select { select {
case pqm.providerQueryMessages <- &finishedProviderQueryMessage{ case pqm.providerQueryMessages <- &finishedProviderQueryMessage{
...@@ -389,19 +396,19 @@ func (crm *cancelRequestMessage) debugMessage() string { ...@@ -389,19 +396,19 @@ func (crm *cancelRequestMessage) debugMessage() string {
func (crm *cancelRequestMessage) handle(pqm *ProviderQueryManager) { func (crm *cancelRequestMessage) handle(pqm *ProviderQueryManager) {
requestStatus, ok := pqm.inProgressRequestStatuses[crm.k] requestStatus, ok := pqm.inProgressRequestStatuses[crm.k]
if ok { if !ok {
_, ok := requestStatus.listeners[crm.incomingProviders] log.Errorf("Attempt to cancel request for cid (%s) not in progress", crm.k.String())
if ok { return
}
_, ok = requestStatus.listeners[crm.incomingProviders]
if !ok {
log.Errorf("Attempt to cancel request for for cid (%s) this is not a listener", crm.k.String())
return
}
delete(requestStatus.listeners, crm.incomingProviders) delete(requestStatus.listeners, crm.incomingProviders)
close(crm.incomingProviders)
if len(requestStatus.listeners) == 0 { if len(requestStatus.listeners) == 0 {
delete(pqm.inProgressRequestStatuses, crm.k) delete(pqm.inProgressRequestStatuses, crm.k)
requestStatus.cancelFn() requestStatus.cancelFn()
} }
} else {
log.Errorf("Attempt to cancel request for for cid (%s) this is not a listener", crm.k.String())
}
} else {
log.Errorf("Attempt to cancel request for cid (%s) not in progress", crm.k.String())
}
close(crm.incomingProviders)
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment