Commit 9b5a0e67 authored by Adin Schmahmann's avatar Adin Schmahmann Committed by Steven Allen

feat: calling FindProvidersAsync with a count of zero now completes the query

parent ac9f0335
......@@ -539,14 +539,22 @@ func (dht *IpfsDHT) FindProviders(ctx context.Context, c cid.Cid) ([]peer.AddrIn
// FindProvidersAsync is the same thing as FindProviders, but returns a channel.
// Peers will be returned on the channel as soon as they are found, even before
// the search query completes.
// the search query completes. If count is zero then the query will run until it
// completes. Note: not reading from the returned channel may block the query
// from progressing.
func (dht *IpfsDHT) FindProvidersAsync(ctx context.Context, key cid.Cid, count int) <-chan peer.AddrInfo {
peerOut := make(chan peer.AddrInfo, count)
if !dht.enableProviders {
peerOut := make(chan peer.AddrInfo)
close(peerOut)
return peerOut
}
chSize := count
if count == 0 {
chSize = 1
}
peerOut := make(chan peer.AddrInfo, chSize)
keyMH := key.Hash()
logger.Event(ctx, "findProviders", multihashLoggableKey(keyMH))
......@@ -558,7 +566,14 @@ func (dht *IpfsDHT) findProvidersAsyncRoutine(ctx context.Context, key multihash
defer logger.EventBegin(ctx, "findProvidersAsync", multihashLoggableKey(key)).Done()
defer close(peerOut)
ps := peer.NewLimitedSet(count)
findAll := count == 0
var ps *peer.Set
if findAll {
ps = peer.NewSet()
} else {
ps = peer.NewLimitedSet(count)
}
provs := dht.ProviderManager.GetProviders(ctx, key)
for _, p := range provs {
// NOTE: Assuming that this list of peers is unique
......@@ -573,7 +588,7 @@ func (dht *IpfsDHT) findProvidersAsyncRoutine(ctx context.Context, key multihash
// If we have enough peers locally, don't bother with remote RPC
// TODO: is this a DOS vector?
if ps.Size() >= count {
if !findAll && ps.Size() >= count {
return
}
}
......@@ -610,7 +625,7 @@ func (dht *IpfsDHT) findProvidersAsyncRoutine(ctx context.Context, key multihash
return nil, ctx.Err()
}
}
if ps.Size() >= count {
if !findAll && ps.Size() >= count {
logger.Debugf("got enough providers (%d/%d)", ps.Size(), count)
return nil, nil
}
......@@ -630,7 +645,7 @@ func (dht *IpfsDHT) findProvidersAsyncRoutine(ctx context.Context, key multihash
return peers, nil
},
func(peerset *kpeerset.SortedPeerset) bool {
return ps.Size() >= count
return !findAll && ps.Size() >= count
},
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment