Commit 92717dbb authored by hannahhoward's avatar hannahhoward

refactor(providerquerymanager): don't use session ids

removed session id user completely from providerquerymanager
parent 56d9e3fc
......@@ -21,7 +21,7 @@ const (
type inProgressRequestStatus struct {
providersSoFar []peer.ID
listeners map[uint64]chan peer.ID
listeners map[chan peer.ID]struct{}
}
// ProviderQueryNetwork is an interface for finding providers and connecting to
......@@ -46,14 +46,13 @@ type finishedProviderQueryMessage struct {
}
type newProvideQueryMessage struct {
ses uint64
k cid.Cid
inProgressRequestChan chan<- inProgressRequest
}
type cancelRequestMessage struct {
ses uint64
k cid.Cid
incomingProviders chan peer.ID
k cid.Cid
}
// ProviderQueryManager manages requests to find more providers for blocks
......@@ -98,7 +97,7 @@ func (pqm *ProviderQueryManager) Startup() {
type inProgressRequest struct {
providersSoFar []peer.ID
incoming <-chan peer.ID
incoming chan peer.ID
}
// SetFindProviderTimeout changes the timeout for finding providers
......@@ -109,12 +108,11 @@ func (pqm *ProviderQueryManager) SetFindProviderTimeout(findProviderTimeout time
}
// FindProvidersAsync finds providers for the given block.
func (pqm *ProviderQueryManager) FindProvidersAsync(sessionCtx context.Context, k cid.Cid, ses uint64) <-chan peer.ID {
func (pqm *ProviderQueryManager) FindProvidersAsync(sessionCtx context.Context, k cid.Cid) <-chan peer.ID {
inProgressRequestChan := make(chan inProgressRequest)
select {
case pqm.providerQueryMessages <- &newProvideQueryMessage{
ses: ses,
k: k,
inProgressRequestChan: inProgressRequestChan,
}:
......@@ -131,10 +129,10 @@ func (pqm *ProviderQueryManager) FindProvidersAsync(sessionCtx context.Context,
case receivedInProgressRequest = <-inProgressRequestChan:
}
return pqm.receiveProviders(sessionCtx, k, ses, receivedInProgressRequest)
return pqm.receiveProviders(sessionCtx, k, receivedInProgressRequest)
}
func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k cid.Cid, ses uint64, receivedInProgressRequest inProgressRequest) <-chan peer.ID {
func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k cid.Cid, receivedInProgressRequest inProgressRequest) <-chan peer.ID {
// maintains an unbuffered queue for incoming providers for given request for a given session
// essentially, as a provider comes in, for a given CID, we want to immediately broadcast to all
// sessions that queried that CID, without worrying about whether the client code is actually
......@@ -162,8 +160,8 @@ func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k
select {
case <-sessionCtx.Done():
pqm.providerQueryMessages <- &cancelRequestMessage{
ses: ses,
k: k,
incomingProviders: incomingProviders,
k: k,
}
// clear out any remaining providers
for range incomingProviders {
......@@ -269,7 +267,7 @@ func (pqm *ProviderQueryManager) providerRequestBufferWorker() {
func (pqm *ProviderQueryManager) cleanupInProcessRequests() {
for _, requestStatus := range pqm.inProgressRequestStatuses {
for _, listener := range requestStatus.listeners {
for listener := range requestStatus.listeners {
close(listener)
}
}
......@@ -305,7 +303,7 @@ func (rpm *receivedProviderMessage) handle(pqm *ProviderQueryManager) {
return
}
requestStatus.providersSoFar = append(requestStatus.providersSoFar, rpm.p)
for _, listener := range requestStatus.listeners {
for listener := range requestStatus.listeners {
select {
case listener <- rpm.p:
case <-pqm.ctx.Done():
......@@ -324,21 +322,21 @@ func (fpqm *finishedProviderQueryMessage) handle(pqm *ProviderQueryManager) {
log.Errorf("Ended request for cid (%s) not in progress", fpqm.k.String())
return
}
for _, listener := range requestStatus.listeners {
for listener := range requestStatus.listeners {
close(listener)
}
delete(pqm.inProgressRequestStatuses, fpqm.k)
}
func (npqm *newProvideQueryMessage) debugMessage() string {
return fmt.Sprintf("New Provider Query on cid: %s from session: %d", npqm.k.String(), npqm.ses)
return fmt.Sprintf("New Provider Query on cid: %s", npqm.k.String())
}
func (npqm *newProvideQueryMessage) handle(pqm *ProviderQueryManager) {
requestStatus, ok := pqm.inProgressRequestStatuses[npqm.k]
if !ok {
requestStatus = &inProgressRequestStatus{
listeners: make(map[uint64]chan peer.ID),
listeners: make(map[chan peer.ID]struct{}),
}
pqm.inProgressRequestStatuses[npqm.k] = requestStatus
select {
......@@ -347,31 +345,32 @@ func (npqm *newProvideQueryMessage) handle(pqm *ProviderQueryManager) {
return
}
}
requestStatus.listeners[npqm.ses] = make(chan peer.ID)
inProgressChan := make(chan peer.ID)
requestStatus.listeners[inProgressChan] = struct{}{}
select {
case npqm.inProgressRequestChan <- inProgressRequest{
providersSoFar: requestStatus.providersSoFar,
incoming: requestStatus.listeners[npqm.ses],
incoming: inProgressChan,
}:
case <-pqm.ctx.Done():
}
}
func (crm *cancelRequestMessage) debugMessage() string {
return fmt.Sprintf("Cancel provider query on cid: %s from session: %d", crm.k.String(), crm.ses)
return fmt.Sprintf("Cancel provider query on cid: %s", crm.k.String())
}
func (crm *cancelRequestMessage) handle(pqm *ProviderQueryManager) {
requestStatus, ok := pqm.inProgressRequestStatuses[crm.k]
if !ok {
log.Errorf("Attempt to cancel request for session (%d) for cid (%s) not in progress", crm.ses, crm.k.String())
log.Errorf("Attempt to cancel request for cid (%s) not in progress", crm.k.String())
return
}
listener, ok := requestStatus.listeners[crm.ses]
listener := crm.incomingProviders
if !ok {
log.Errorf("Attempt to cancel request for session (%d) for cid (%s) this is not a listener", crm.ses, crm.k.String())
log.Errorf("Attempt to cancel request for for cid (%s) this is not a listener", crm.k.String())
return
}
close(listener)
delete(requestStatus.listeners, crm.ses)
delete(requestStatus.listeners, listener)
}
......@@ -62,13 +62,11 @@ func TestNormalSimultaneousFetch(t *testing.T) {
providerQueryManager := New(ctx, fpn)
providerQueryManager.Startup()
keys := testutil.GenerateCids(2)
sessionID1 := testutil.GenerateSessionID()
sessionID2 := testutil.GenerateSessionID()
sessionCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
defer cancel()
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0], sessionID1)
secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[1], sessionID2)
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0])
secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[1])
var firstPeersReceived []peer.ID
for p := range firstRequestChan {
......@@ -102,13 +100,11 @@ func TestDedupingProviderRequests(t *testing.T) {
providerQueryManager := New(ctx, fpn)
providerQueryManager.Startup()
key := testutil.GenerateCids(1)[0]
sessionID1 := testutil.GenerateSessionID()
sessionID2 := testutil.GenerateSessionID()
sessionCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
defer cancel()
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, sessionID1)
secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, sessionID2)
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key)
secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key)
var firstPeersReceived []peer.ID
for p := range firstRequestChan {
......@@ -145,16 +141,14 @@ func TestCancelOneRequestDoesNotTerminateAnother(t *testing.T) {
providerQueryManager.Startup()
key := testutil.GenerateCids(1)[0]
sessionID1 := testutil.GenerateSessionID()
sessionID2 := testutil.GenerateSessionID()
// first session will cancel before done
firstSessionCtx, firstCancel := context.WithTimeout(ctx, 3*time.Millisecond)
defer firstCancel()
firstRequestChan := providerQueryManager.FindProvidersAsync(firstSessionCtx, key, sessionID1)
firstRequestChan := providerQueryManager.FindProvidersAsync(firstSessionCtx, key)
secondSessionCtx, secondCancel := context.WithTimeout(ctx, 100*time.Millisecond)
defer secondCancel()
secondRequestChan := providerQueryManager.FindProvidersAsync(secondSessionCtx, key, sessionID2)
secondRequestChan := providerQueryManager.FindProvidersAsync(secondSessionCtx, key)
var firstPeersReceived []peer.ID
for p := range firstRequestChan {
......@@ -193,13 +187,11 @@ func TestCancelManagerExitsGracefully(t *testing.T) {
providerQueryManager.Startup()
key := testutil.GenerateCids(1)[0]
sessionID1 := testutil.GenerateSessionID()
sessionID2 := testutil.GenerateSessionID()
sessionCtx, cancel := context.WithTimeout(ctx, 20*time.Millisecond)
defer cancel()
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, sessionID1)
secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, sessionID2)
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key)
secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key)
var firstPeersReceived []peer.ID
for p := range firstRequestChan {
......@@ -229,13 +221,11 @@ func TestPeersWithConnectionErrorsNotAddedToPeerList(t *testing.T) {
providerQueryManager.Startup()
key := testutil.GenerateCids(1)[0]
sessionID1 := testutil.GenerateSessionID()
sessionID2 := testutil.GenerateSessionID()
sessionCtx, cancel := context.WithTimeout(ctx, 20*time.Millisecond)
defer cancel()
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, sessionID1)
secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, sessionID2)
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key)
secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key)
var firstPeersReceived []peer.ID
for p := range firstRequestChan {
......@@ -266,12 +256,11 @@ func TestRateLimitingRequests(t *testing.T) {
providerQueryManager.Startup()
keys := testutil.GenerateCids(maxInProcessRequests + 1)
sessionID := testutil.GenerateSessionID()
sessionCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
defer cancel()
var requestChannels []<-chan peer.ID
for i := 0; i < maxInProcessRequests+1; i++ {
requestChannels = append(requestChannels, providerQueryManager.FindProvidersAsync(sessionCtx, keys[i], sessionID))
requestChannels = append(requestChannels, providerQueryManager.FindProvidersAsync(sessionCtx, keys[i]))
}
time.Sleep(9 * time.Millisecond)
fpn.queriesMadeMutex.Lock()
......@@ -303,11 +292,10 @@ func TestFindProviderTimeout(t *testing.T) {
providerQueryManager.Startup()
providerQueryManager.SetFindProviderTimeout(2 * time.Millisecond)
keys := testutil.GenerateCids(1)
sessionID1 := testutil.GenerateSessionID()
sessionCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
defer cancel()
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0], sessionID1)
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0])
var firstPeersReceived []peer.ID
for p := range firstRequestChan {
firstPeersReceived = append(firstPeersReceived, p)
......
......@@ -26,7 +26,7 @@ type PeerTagger interface {
// PeerProviderFinder is an interface for finding providers
type PeerProviderFinder interface {
FindProvidersAsync(context.Context, cid.Cid, uint64) <-chan peer.ID
FindProvidersAsync(context.Context, cid.Cid) <-chan peer.ID
}
type peerMessage interface {
......@@ -108,8 +108,8 @@ func (spm *SessionPeerManager) GetOptimizedPeers() []peer.ID {
// providers for the given Cid
func (spm *SessionPeerManager) FindMorePeers(ctx context.Context, c cid.Cid) {
go func(k cid.Cid) {
for p := range spm.providerFinder.FindProvidersAsync(ctx, k, spm.id) {
for p := range spm.providerFinder.FindProvidersAsync(ctx, k) {
select {
case spm.peerMessages <- &peerFoundMessage{p}:
case <-ctx.Done():
......
......@@ -18,7 +18,7 @@ type fakePeerProviderFinder struct {
completed chan struct{}
}
func (fppf *fakePeerProviderFinder) FindProvidersAsync(ctx context.Context, c cid.Cid, ses uint64) <-chan peer.ID {
func (fppf *fakePeerProviderFinder) FindProvidersAsync(ctx context.Context, c cid.Cid) <-chan peer.ID {
peerCh := make(chan peer.ID)
go func() {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment