Commit 960f6971 authored by Dirk McCormick

refactor: simplify session peer management

parent 0ba089b4
@@ -148,10 +148,10 @@ func New(parent context.Context, network bsnet.BitSwapNetwork,
 	provSearchDelay time.Duration,
 	rebroadcastDelay delay.D,
 	self peer.ID) bssm.Session {
-	return bssession.New(ctx, id, wm, spm, sim, pm, bpm, notif, provSearchDelay, rebroadcastDelay, self)
+	return bssession.New(ctx, id, wm, spm, pqm, sim, pm, bpm, notif, provSearchDelay, rebroadcastDelay, self)
 }
 sessionPeerManagerFactory := func(ctx context.Context, id uint64) bssession.SessionPeerManager {
-	return bsspm.New(ctx, id, network.ConnectionManager(), pqm)
+	return bsspm.New(id, network.ConnectionManager())
 }
 notif := notifications.New()
 sm := bssm.New(ctx, sessionFactory, sim, sessionPeerManagerFactory, bpm, pm, notif, network.Self())
......
package session
import (
peer "github.com/libp2p/go-libp2p-core/peer"
)
// peerAvailabilityManager keeps track of which peers have available space
// to receive want requests
type peerAvailabilityManager struct {
peerAvailable map[peer.ID]bool
}
func newPeerAvailabilityManager() *peerAvailabilityManager {
return &peerAvailabilityManager{
peerAvailable: make(map[peer.ID]bool),
}
}
func (pam *peerAvailabilityManager) addPeer(p peer.ID) {
pam.peerAvailable[p] = false
}
func (pam *peerAvailabilityManager) isAvailable(p peer.ID) (bool, bool) {
is, ok := pam.peerAvailable[p]
return is, ok
}
func (pam *peerAvailabilityManager) setPeerAvailability(p peer.ID, isAvailable bool) {
pam.peerAvailable[p] = isAvailable
}
func (pam *peerAvailabilityManager) haveAvailablePeers() bool {
for _, isAvailable := range pam.peerAvailable {
if isAvailable {
return true
}
}
return false
}
func (pam *peerAvailabilityManager) availablePeers() []peer.ID {
var available []peer.ID
for p, isAvailable := range pam.peerAvailable {
if isAvailable {
available = append(available, p)
}
}
return available
}
func (pam *peerAvailabilityManager) allPeers() []peer.ID {
var available []peer.ID
for p := range pam.peerAvailable {
available = append(available, p)
}
return available
}
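Aside (illustrative only, not part of this commit): a hypothetical helper showing how a caller in the same package might consult the tracker, preferring peers that have signalled availability and falling back to all known peers otherwise.

package session

import (
	peer "github.com/libp2p/go-libp2p-core/peer"
)

// chooseTargets is a hypothetical helper: if any peer has signalled
// availability, target those peers; otherwise fall back to every known
// peer (e.g. for a broadcast).
func chooseTargets(pam *peerAvailabilityManager) []peer.ID {
	if pam.haveAvailablePeers() {
		return pam.availablePeers()
	}
	return pam.allPeers()
}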
package session
import (
"testing"
"github.com/ipfs/go-bitswap/internal/testutil"
)
func TestPeerAvailabilityManager(t *testing.T) {
peers := testutil.GeneratePeers(2)
pam := newPeerAvailabilityManager()
isAvailable, ok := pam.isAvailable(peers[0])
if isAvailable || ok {
t.Fatal("expected not to have any availability yet")
}
if pam.haveAvailablePeers() {
t.Fatal("expected not to have any availability yet")
}
pam.addPeer(peers[0])
isAvailable, ok = pam.isAvailable(peers[0])
if !ok {
t.Fatal("expected to have a peer")
}
if isAvailable {
t.Fatal("expected not to have any availability yet")
}
if pam.haveAvailablePeers() {
t.Fatal("expected not to have any availability yet")
}
if len(pam.availablePeers()) != 0 {
t.Fatal("expected not to have any availability yet")
}
if len(pam.allPeers()) != 1 {
t.Fatal("expected one peer")
}
pam.setPeerAvailability(peers[0], true)
isAvailable, ok = pam.isAvailable(peers[0])
if !ok {
t.Fatal("expected to have a peer")
}
if !isAvailable {
t.Fatal("expected peer to be available")
}
if !pam.haveAvailablePeers() {
t.Fatal("expected peer to be available")
}
if len(pam.availablePeers()) != 1 {
t.Fatal("expected peer to be available")
}
if len(pam.allPeers()) != 1 {
t.Fatal("expected one peer")
}
pam.addPeer(peers[1])
if len(pam.availablePeers()) != 1 {
t.Fatal("expected one peer to be available")
}
if len(pam.allPeers()) != 2 {
t.Fatal("expected two peers")
}
pam.setPeerAvailability(peers[0], false)
isAvailable, ok = pam.isAvailable(peers[0])
if !ok {
t.Fatal("expected to have a peer")
}
if isAvailable {
t.Fatal("expected peer to not be available")
}
}
...@@ -2,7 +2,6 @@ package session ...@@ -2,7 +2,6 @@ package session
import ( import (
"context" "context"
"sync"
"time" "time"
// lu "github.com/ipfs/go-bitswap/internal/logutil" // lu "github.com/ipfs/go-bitswap/internal/logutil"
...@@ -49,23 +48,26 @@ type PeerManager interface { ...@@ -49,23 +48,26 @@ type PeerManager interface {
SendWants(ctx context.Context, peerId peer.ID, wantBlocks []cid.Cid, wantHaves []cid.Cid) SendWants(ctx context.Context, peerId peer.ID, wantBlocks []cid.Cid, wantHaves []cid.Cid)
} }
-// PeerManager provides an interface for tracking and optimize peers, and
-// requesting more when neccesary.
+// SessionPeerManager keeps track of peers in the session
 type SessionPeerManager interface {
-	// ReceiveFrom is called when blocks and HAVEs are received from a peer.
-	// It returns a boolean indicating if the peer is new to the session.
-	ReceiveFrom(peerId peer.ID, blks []cid.Cid, haves []cid.Cid) bool
-	// Peers returns the set of peers in the session.
-	Peers() *peer.Set
-	// FindMorePeers queries Content Routing to discover providers of the given cid
-	FindMorePeers(context.Context, cid.Cid)
-	// RecordPeerRequests records the time that a cid was requested from a peer
-	RecordPeerRequests([]peer.ID, []cid.Cid)
-	// RecordPeerResponse records the time that a response for a cid arrived
-	// from a peer
-	RecordPeerResponse(peer.ID, []cid.Cid)
-	// RecordCancels records that cancels were sent for the given cids
-	RecordCancels([]cid.Cid)
+	// PeersDiscovered indicates if any peers have been discovered yet
+	PeersDiscovered() bool
+	// Shutdown the SessionPeerManager
+	Shutdown()
+	// Adds a peer to the session, returning true if the peer is new
+	AddPeer(peer.ID) bool
+	// Removes a peer from the session, returning true if the peer existed
+	RemovePeer(peer.ID) bool
+	// All peers in the session
+	Peers() []peer.ID
+	// Whether there are any peers in the session
+	HasPeers() bool
+}
+
+// ProviderFinder is used to find providers for a given key
+type ProviderFinder interface {
+	// FindProvidersAsync searches for peers that provide the given CID
+	FindProvidersAsync(ctx context.Context, k cid.Cid) <-chan peer.ID
 }
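Aside (illustrative only, not part of this commit): the new ProviderFinder interface can be satisfied by wrapping whatever content-routing client the node already has. A minimal sketch follows; the contentRouter interface and routingProviderFinder type are hypothetical stand-ins, and only the channel-returning FindProvidersAsync shape comes from the interface above.

package example

import (
	"context"

	cid "github.com/ipfs/go-cid"
	peer "github.com/libp2p/go-libp2p-core/peer"
)

// contentRouter is a hypothetical stand-in for the node's routing client
// (e.g. a DHT wrapper); only this method shape is assumed here.
type contentRouter interface {
	Providers(ctx context.Context, k cid.Cid) []peer.ID
}

// routingProviderFinder adapts a contentRouter to the session's
// ProviderFinder interface by streaming provider IDs on a channel.
type routingProviderFinder struct {
	router contentRouter
}

func (f *routingProviderFinder) FindProvidersAsync(ctx context.Context, k cid.Cid) <-chan peer.ID {
	out := make(chan peer.ID)
	go func() {
		defer close(out)
		for _, p := range f.router.Providers(ctx, k) {
			select {
			case out <- p: // forward each provider as it is found
			case <-ctx.Done():
				return
			}
		}
	}()
	return out
}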
// opType is the kind of operation that is being processed by the event loop // opType is the kind of operation that is being processed by the event loop
...@@ -80,6 +82,8 @@ const ( ...@@ -80,6 +82,8 @@ const (
opCancel opCancel
// Broadcast want-haves // Broadcast want-haves
opBroadcast opBroadcast
// Wants sent to peers
opWantsSent
) )
type op struct { type op struct {
...@@ -92,10 +96,11 @@ type op struct { ...@@ -92,10 +96,11 @@ type op struct {
// info to, and who to request blocks from. // info to, and who to request blocks from.
type Session struct { type Session struct {
// dependencies // dependencies
ctx context.Context ctx context.Context
wm WantManager wm WantManager
sprm SessionPeerManager sprm SessionPeerManager
sim *bssim.SessionInterestManager providerFinder ProviderFinder
sim *bssim.SessionInterestManager
sw sessionWants sw sessionWants
sws sessionWantSender sws sessionWantSender
...@@ -127,6 +132,7 @@ func New(ctx context.Context, ...@@ -127,6 +132,7 @@ func New(ctx context.Context,
id uint64, id uint64,
wm WantManager, wm WantManager,
sprm SessionPeerManager, sprm SessionPeerManager,
providerFinder ProviderFinder,
sim *bssim.SessionInterestManager, sim *bssim.SessionInterestManager,
pm PeerManager, pm PeerManager,
bpm *bsbpm.BlockPresenceManager, bpm *bsbpm.BlockPresenceManager,
...@@ -140,6 +146,7 @@ func New(ctx context.Context, ...@@ -140,6 +146,7 @@ func New(ctx context.Context,
ctx: ctx, ctx: ctx,
wm: wm, wm: wm,
sprm: sprm, sprm: sprm,
providerFinder: providerFinder,
sim: sim, sim: sim,
incoming: make(chan op, 128), incoming: make(chan op, 128),
latencyTrkr: latencyTracker{}, latencyTrkr: latencyTracker{},
...@@ -151,7 +158,7 @@ func New(ctx context.Context, ...@@ -151,7 +158,7 @@ func New(ctx context.Context,
periodicSearchDelay: periodicSearchDelay, periodicSearchDelay: periodicSearchDelay,
self: self, self: self,
} }
-	s.sws = newSessionWantSender(ctx, id, pm, bpm, s.onWantsSent, s.onPeersExhausted)
+	s.sws = newSessionWantSender(ctx, id, pm, sprm, bpm, s.onWantsSent, s.onPeersExhausted)
go s.run(ctx) go s.run(ctx)
...@@ -164,44 +171,25 @@ func (s *Session) ID() uint64 { ...@@ -164,44 +171,25 @@ func (s *Session) ID() uint64 {
 // ReceiveFrom receives incoming blocks from the given peer.
 func (s *Session) ReceiveFrom(from peer.ID, ks []cid.Cid, haves []cid.Cid, dontHaves []cid.Cid) {
+	// The SessionManager tells each Session about all keys that it may be
+	// interested in. Here the Session filters the keys to the ones that this
+	// particular Session is interested in.
 	interestedRes := s.sim.FilterSessionInterested(s.id, ks, haves, dontHaves)
 	ks = interestedRes[0]
 	haves = interestedRes[1]
 	dontHaves = interestedRes[2]
 	// s.logReceiveFrom(from, ks, haves, dontHaves)

-	// Add any newly discovered peers that have blocks we're interested in to
-	// the peer set
-	isNewPeer := s.sprm.ReceiveFrom(from, ks, haves)
-
-	// Record response timing only if the blocks came from the network
-	// (blocks can also be received from the local node)
-	if len(ks) > 0 && from != "" {
-		s.sprm.RecordPeerResponse(from, ks)
-	}
-
-	// Update want potential
-	s.sws.Update(from, ks, haves, dontHaves, isNewPeer)
+	// Inform the session want sender that a message has been received
+	s.sws.Update(from, ks, haves, dontHaves)

 	if len(ks) == 0 {
 		return
 	}

-	// Record which blocks have been received and figure out the total latency
-	// for fetching the blocks
-	wanted, totalLatency := s.sw.BlocksReceived(ks)
-	s.latencyTrkr.receiveUpdate(len(wanted), totalLatency)
-
-	if len(wanted) == 0 {
-		return
-	}
-
-	// Inform the SessionInterestManager that this session is no longer
-	// expecting to receive the wanted keys
-	s.sim.RemoveSessionWants(s.id, wanted)
-
+	// Inform the session that blocks have been received
 	select {
-	case s.incoming <- op{op: opReceive, keys: wanted}:
+	case s.incoming <- op{op: opReceive, keys: ks}:
 	case <-s.ctx.Done():
 	}
 }
...@@ -220,28 +208,6 @@ func (s *Session) ReceiveFrom(from peer.ID, ks []cid.Cid, haves []cid.Cid, dontH ...@@ -220,28 +208,6 @@ func (s *Session) ReceiveFrom(from peer.ID, ks []cid.Cid, haves []cid.Cid, dontH
// } // }
// } // }
func (s *Session) onWantsSent(p peer.ID, wantBlocks []cid.Cid, wantHaves []cid.Cid) {
allBlks := append(wantBlocks[:len(wantBlocks):len(wantBlocks)], wantHaves...)
s.sw.WantsSent(allBlks)
s.sprm.RecordPeerRequests([]peer.ID{p}, allBlks)
}
func (s *Session) onPeersExhausted(ks []cid.Cid) {
// We don't want to block the sessionWantSender if the incoming channel
// is full. So if we can't immediately send on the incoming channel spin
// it off into a go-routine.
select {
case s.incoming <- op{op: opBroadcast, keys: ks}:
default:
go func() {
select {
case s.incoming <- op{op: opBroadcast, keys: ks}:
case <-s.ctx.Done():
}
}()
}
}
// GetBlock fetches a single block. // GetBlock fetches a single block.
func (s *Session) GetBlock(parent context.Context, k cid.Cid) (blocks.Block, error) { func (s *Session) GetBlock(parent context.Context, k cid.Cid) (blocks.Block, error) {
return bsgetter.SyncGetBlock(parent, k, s.GetBlocks) return bsgetter.SyncGetBlock(parent, k, s.GetBlocks)
...@@ -278,6 +244,34 @@ func (s *Session) SetBaseTickDelay(baseTickDelay time.Duration) { ...@@ -278,6 +244,34 @@ func (s *Session) SetBaseTickDelay(baseTickDelay time.Duration) {
} }
} }
// onWantsSent is called when wants are sent to a peer by the session wants sender
func (s *Session) onWantsSent(p peer.ID, wantBlocks []cid.Cid, wantHaves []cid.Cid) {
allBlks := append(wantBlocks[:len(wantBlocks):len(wantBlocks)], wantHaves...)
s.nonBlockingEnqueue(op{op: opWantsSent, keys: allBlks})
}
// onPeersExhausted is called when all available peers have sent DONT_HAVE for
// a set of cids (or all peers become unavailable)
func (s *Session) onPeersExhausted(ks []cid.Cid) {
s.nonBlockingEnqueue(op{op: opBroadcast, keys: ks})
}
// We don't want to block the sessionWantSender if the incoming channel
// is full. So if we can't immediately send on the incoming channel spin
// it off into a go-routine.
func (s *Session) nonBlockingEnqueue(o op) {
select {
case s.incoming <- o:
default:
go func() {
select {
case s.incoming <- o:
case <-s.ctx.Done():
}
}()
}
}
// Session run loop -- everything in this function should not be called // Session run loop -- everything in this function should not be called
// outside of this loop // outside of this loop
func (s *Session) run(ctx context.Context) { func (s *Session) run(ctx context.Context) {
...@@ -290,23 +284,34 @@ func (s *Session) run(ctx context.Context) { ...@@ -290,23 +284,34 @@ func (s *Session) run(ctx context.Context) {
case oper := <-s.incoming: case oper := <-s.incoming:
switch oper.op { switch oper.op {
case opReceive: case opReceive:
// Received blocks
s.handleReceive(oper.keys) s.handleReceive(oper.keys)
case opWant: case opWant:
// Client wants blocks
s.wantBlocks(ctx, oper.keys) s.wantBlocks(ctx, oper.keys)
case opCancel: case opCancel:
// Wants were cancelled
s.sw.CancelPending(oper.keys) s.sw.CancelPending(oper.keys)
case opWantsSent:
// Wants were sent to a peer
s.sw.WantsSent(oper.keys)
case opBroadcast: case opBroadcast:
// Broadcast want-haves to all peers
s.broadcastWantHaves(ctx, oper.keys) s.broadcastWantHaves(ctx, oper.keys)
default: default:
panic("unhandled operation") panic("unhandled operation")
} }
case <-s.idleTick.C: case <-s.idleTick.C:
// The session hasn't received blocks for a while, broadcast
s.broadcastWantHaves(ctx, nil) s.broadcastWantHaves(ctx, nil)
case <-s.periodicSearchTimer.C: case <-s.periodicSearchTimer.C:
// Periodically search for a random live want
s.handlePeriodicSearch(ctx) s.handlePeriodicSearch(ctx)
case baseTickDelay := <-s.tickDelayReqs: case baseTickDelay := <-s.tickDelayReqs:
// Set the base tick delay
s.baseTickDelay = baseTickDelay s.baseTickDelay = baseTickDelay
case <-ctx.Done(): case <-ctx.Done():
// Shutdown
s.handleShutdown() s.handleShutdown()
return return
} }
...@@ -327,7 +332,6 @@ func (s *Session) broadcastWantHaves(ctx context.Context, wants []cid.Cid) { ...@@ -327,7 +332,6 @@ func (s *Session) broadcastWantHaves(ctx context.Context, wants []cid.Cid) {
// log.Infof("Ses%d: broadcast %d keys\n", s.id, len(live)) // log.Infof("Ses%d: broadcast %d keys\n", s.id, len(live))
// Broadcast a want-have for the live wants to everyone we're connected to // Broadcast a want-have for the live wants to everyone we're connected to
s.sprm.RecordPeerRequests(nil, wants)
s.wm.BroadcastWantHaves(ctx, s.id, wants) s.wm.BroadcastWantHaves(ctx, s.id, wants)
// do not find providers on consecutive ticks // do not find providers on consecutive ticks
...@@ -337,7 +341,7 @@ func (s *Session) broadcastWantHaves(ctx context.Context, wants []cid.Cid) { ...@@ -337,7 +341,7 @@ func (s *Session) broadcastWantHaves(ctx context.Context, wants []cid.Cid) {
// Typically if the provider has the first block they will have // Typically if the provider has the first block they will have
// the rest of the blocks also. // the rest of the blocks also.
log.Warnf("Ses%d: FindMorePeers with want 0 of %d wants", s.id, len(wants)) log.Warnf("Ses%d: FindMorePeers with want 0 of %d wants", s.id, len(wants))
s.sprm.FindMorePeers(ctx, wants[0]) s.findMorePeers(ctx, wants[0])
} }
s.resetIdleTick() s.resetIdleTick()
...@@ -347,6 +351,8 @@ func (s *Session) broadcastWantHaves(ctx context.Context, wants []cid.Cid) { ...@@ -347,6 +351,8 @@ func (s *Session) broadcastWantHaves(ctx context.Context, wants []cid.Cid) {
} }
} }
// handlePeriodicSearch is called periodically to search for providers of a
// randomly chosen CID in the session.
func (s *Session) handlePeriodicSearch(ctx context.Context) { func (s *Session) handlePeriodicSearch(ctx context.Context) {
randomWant := s.sw.RandomLiveWant() randomWant := s.sw.RandomLiveWant()
if !randomWant.Defined() { if !randomWant.Defined() {
...@@ -355,40 +361,74 @@ func (s *Session) handlePeriodicSearch(ctx context.Context) { ...@@ -355,40 +361,74 @@ func (s *Session) handlePeriodicSearch(ctx context.Context) {
// TODO: come up with a better strategy for determining when to search // TODO: come up with a better strategy for determining when to search
// for new providers for blocks. // for new providers for blocks.
s.sprm.FindMorePeers(ctx, randomWant) s.findMorePeers(ctx, randomWant)
s.wm.BroadcastWantHaves(ctx, s.id, []cid.Cid{randomWant}) s.wm.BroadcastWantHaves(ctx, s.id, []cid.Cid{randomWant})
s.periodicSearchTimer.Reset(s.periodicSearchDelay.NextWaitTime()) s.periodicSearchTimer.Reset(s.periodicSearchDelay.NextWaitTime())
} }
// findMorePeers attempts to find more peers for a session by searching for
// providers for the given Cid
func (s *Session) findMorePeers(ctx context.Context, c cid.Cid) {
go func(k cid.Cid) {
for p := range s.providerFinder.FindProvidersAsync(ctx, k) {
// When a provider indicates that it has a cid, it's equivalent to
// the providing peer sending a HAVE
s.sws.Update(p, nil, []cid.Cid{c}, nil)
}
}(c)
}
// handleShutdown is called when the session shuts down
func (s *Session) handleShutdown() { func (s *Session) handleShutdown() {
// Stop the idle timer
s.idleTick.Stop() s.idleTick.Stop()
// Shut down the session peer manager
s.sprm.Shutdown()
// Remove the session from the want manager
s.wm.RemoveSession(s.ctx, s.id) s.wm.RemoveSession(s.ctx, s.id)
} }
// handleReceive is called when the session receives blocks from a peer
func (s *Session) handleReceive(ks []cid.Cid) { func (s *Session) handleReceive(ks []cid.Cid) {
// Record which blocks have been received and figure out the total latency
// for fetching the blocks
wanted, totalLatency := s.sw.BlocksReceived(ks)
if len(wanted) == 0 {
return
}
// Record latency
s.latencyTrkr.receiveUpdate(len(wanted), totalLatency)
// Inform the SessionInterestManager that this session is no longer
// expecting to receive the wanted keys
s.sim.RemoveSessionWants(s.id, wanted)
s.idleTick.Stop() s.idleTick.Stop()
// We've received new wanted blocks, so reset the number of ticks // We've received new wanted blocks, so reset the number of ticks
// that have occurred since the last new block // that have occurred since the last new block
s.consecutiveTicks = 0 s.consecutiveTicks = 0
s.sprm.RecordCancels(ks)
s.resetIdleTick() s.resetIdleTick()
} }
// wantBlocks is called when blocks are requested by the client
func (s *Session) wantBlocks(ctx context.Context, newks []cid.Cid) { func (s *Session) wantBlocks(ctx context.Context, newks []cid.Cid) {
if len(newks) > 0 { if len(newks) > 0 {
// Inform the SessionInterestManager that this session is interested in the keys
s.sim.RecordSessionInterest(s.id, newks) s.sim.RecordSessionInterest(s.id, newks)
// Tell the sessionWants tracker that the wants have been requested
s.sw.BlocksRequested(newks) s.sw.BlocksRequested(newks)
// Tell the sessionWantSender that the blocks have been requested
s.sws.Add(newks) s.sws.Add(newks)
} }
// If we have discovered peers already, the SessionPotentialManager will // If we have discovered peers already, the sessionWantSender will
// send wants to them // send wants to them
if s.sprm.Peers().Size() > 0 { if s.sprm.PeersDiscovered() {
return return
} }
...@@ -396,7 +436,6 @@ func (s *Session) wantBlocks(ctx context.Context, newks []cid.Cid) { ...@@ -396,7 +436,6 @@ func (s *Session) wantBlocks(ctx context.Context, newks []cid.Cid) {
ks := s.sw.GetNextWants(broadcastLiveWantsLimit) ks := s.sw.GetNextWants(broadcastLiveWantsLimit)
if len(ks) > 0 { if len(ks) > 0 {
log.Infof("Ses%d: No peers - broadcasting %d want HAVE requests\n", s.id, len(ks)) log.Infof("Ses%d: No peers - broadcasting %d want HAVE requests\n", s.id, len(ks))
s.sprm.RecordPeerRequests(nil, ks)
s.wm.BroadcastWantHaves(ctx, s.id, ks) s.wm.BroadcastWantHaves(ctx, s.id, ks)
} }
} }
...@@ -415,29 +454,19 @@ func (s *Session) resetIdleTick() { ...@@ -415,29 +454,19 @@ func (s *Session) resetIdleTick() {
} }
type latencyTracker struct { type latencyTracker struct {
sync.RWMutex
totalLatency time.Duration totalLatency time.Duration
count int count int
} }
func (lt *latencyTracker) hasLatency() bool { func (lt *latencyTracker) hasLatency() bool {
lt.RLock()
defer lt.RUnlock()
return lt.totalLatency > 0 && lt.count > 0 return lt.totalLatency > 0 && lt.count > 0
} }
func (lt *latencyTracker) averageLatency() time.Duration { func (lt *latencyTracker) averageLatency() time.Duration {
lt.RLock()
defer lt.RUnlock()
return lt.totalLatency / time.Duration(lt.count) return lt.totalLatency / time.Duration(lt.count)
} }
func (lt *latencyTracker) receiveUpdate(count int, totalLatency time.Duration) { func (lt *latencyTracker) receiveUpdate(count int, totalLatency time.Duration) {
lt.Lock()
defer lt.Unlock()
lt.totalLatency += totalLatency lt.totalLatency += totalLatency
lt.count += count lt.count += count
} }
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
notifications "github.com/ipfs/go-bitswap/internal/notifications" notifications "github.com/ipfs/go-bitswap/internal/notifications"
bspm "github.com/ipfs/go-bitswap/internal/peermanager" bspm "github.com/ipfs/go-bitswap/internal/peermanager"
bssim "github.com/ipfs/go-bitswap/internal/sessioninterestmanager" bssim "github.com/ipfs/go-bitswap/internal/sessioninterestmanager"
bsspm "github.com/ipfs/go-bitswap/internal/sessionpeermanager"
"github.com/ipfs/go-bitswap/internal/testutil" "github.com/ipfs/go-bitswap/internal/testutil"
cid "github.com/ipfs/go-cid" cid "github.com/ipfs/go-cid"
blocksutil "github.com/ipfs/go-ipfs-blocksutil" blocksutil "github.com/ipfs/go-ipfs-blocksutil"
...@@ -38,40 +39,41 @@ func (fwm *fakeWantManager) BroadcastWantHaves(ctx context.Context, sesid uint64 ...@@ -38,40 +39,41 @@ func (fwm *fakeWantManager) BroadcastWantHaves(ctx context.Context, sesid uint64
} }
func (fwm *fakeWantManager) RemoveSession(context.Context, uint64) {} func (fwm *fakeWantManager) RemoveSession(context.Context, uint64) {}
type fakeSessionPeerManager struct { func newFakeSessionPeerManager() *bsspm.SessionPeerManager {
peers *peer.Set return bsspm.New(1, newFakePeerTagger())
findMorePeersRequested chan cid.Cid
} }
func newFakeSessionPeerManager() *fakeSessionPeerManager { type fakePeerTagger struct {
return &fakeSessionPeerManager{
peers: peer.NewSet(),
findMorePeersRequested: make(chan cid.Cid, 1),
}
} }
func (fpm *fakeSessionPeerManager) FindMorePeers(ctx context.Context, k cid.Cid) { func newFakePeerTagger() *fakePeerTagger {
select { return &fakePeerTagger{}
case fpm.findMorePeersRequested <- k:
case <-ctx.Done():
}
} }
func (fpm *fakeSessionPeerManager) Peers() *peer.Set { func (fpt *fakePeerTagger) TagPeer(p peer.ID, tag string, val int) {
return fpm.peers }
func (fpt *fakePeerTagger) UntagPeer(p peer.ID, tag string) {
} }
func (fpm *fakeSessionPeerManager) ReceiveFrom(p peer.ID, ks []cid.Cid, haves []cid.Cid) bool { type fakeProviderFinder struct {
if !fpm.peers.Contains(p) { findMorePeersRequested chan cid.Cid
fpm.peers.Add(p) }
return true
func newFakeProviderFinder() *fakeProviderFinder {
return &fakeProviderFinder{
findMorePeersRequested: make(chan cid.Cid, 1),
} }
return false
} }
func (fpm *fakeSessionPeerManager) RecordCancels(c []cid.Cid) {}
func (fpm *fakeSessionPeerManager) RecordPeerRequests([]peer.ID, []cid.Cid) {} func (fpf *fakeProviderFinder) FindProvidersAsync(ctx context.Context, k cid.Cid) <-chan peer.ID {
func (fpm *fakeSessionPeerManager) RecordPeerResponse(p peer.ID, c []cid.Cid) { go func() {
fpm.peers.Add(p) select {
case fpf.findMorePeersRequested <- k:
case <-ctx.Done():
}
}()
return make(chan peer.ID)
} }
type fakePeerManager struct { type fakePeerManager struct {
...@@ -88,22 +90,24 @@ func (pm *fakePeerManager) UnregisterSession(uint64) ...@@ -88,22 +90,24 @@ func (pm *fakePeerManager) UnregisterSession(uint64)
func (pm *fakePeerManager) SendWants(context.Context, peer.ID, []cid.Cid, []cid.Cid) {} func (pm *fakePeerManager) SendWants(context.Context, peer.ID, []cid.Cid, []cid.Cid) {}
func TestSessionGetBlocks(t *testing.T) { func TestSessionGetBlocks(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel() defer cancel()
fwm := newFakeWantManager() fwm := newFakeWantManager()
fpm := newFakeSessionPeerManager() fpm := newFakeSessionPeerManager()
fpf := newFakeProviderFinder()
sim := bssim.New() sim := bssim.New()
bpm := bsbpm.New() bpm := bsbpm.New()
notif := notifications.New() notif := notifications.New()
defer notif.Shutdown() defer notif.Shutdown()
id := testutil.GenerateSessionID() id := testutil.GenerateSessionID()
session := New(ctx, id, fwm, fpm, sim, newFakePeerManager(), bpm, notif, time.Second, delay.Fixed(time.Minute), "") session := New(ctx, id, fwm, fpm, fpf, sim, newFakePeerManager(), bpm, notif, time.Second, delay.Fixed(time.Minute), "")
blockGenerator := blocksutil.NewBlockGenerator() blockGenerator := blocksutil.NewBlockGenerator()
blks := blockGenerator.Blocks(broadcastLiveWantsLimit * 2) blks := blockGenerator.Blocks(broadcastLiveWantsLimit * 2)
var cids []cid.Cid var cids []cid.Cid
for _, block := range blks { for _, block := range blks {
cids = append(cids, block.Cid()) cids = append(cids, block.Cid())
} }
_, err := session.GetBlocks(ctx, cids) _, err := session.GetBlocks(ctx, cids)
if err != nil { if err != nil {
...@@ -125,14 +129,16 @@ func TestSessionGetBlocks(t *testing.T) { ...@@ -125,14 +129,16 @@ func TestSessionGetBlocks(t *testing.T) {
} }
// Simulate receiving HAVEs from several peers // Simulate receiving HAVEs from several peers
peers := testutil.GeneratePeers(broadcastLiveWantsLimit) peers := testutil.GeneratePeers(5)
for i, p := range peers { for i, p := range peers {
blk := blks[testutil.IndexOf(blks, receivedWantReq.cids[i])] blk := blks[testutil.IndexOf(blks, receivedWantReq.cids[i])]
session.ReceiveFrom(p, []cid.Cid{}, []cid.Cid{blk.Cid()}, []cid.Cid{}) session.ReceiveFrom(p, []cid.Cid{}, []cid.Cid{blk.Cid()}, []cid.Cid{})
} }
time.Sleep(10 * time.Millisecond)
// Verify new peers were recorded // Verify new peers were recorded
if !testutil.MatchPeersIgnoreOrder(fpm.Peers().Peers(), peers) { if !testutil.MatchPeersIgnoreOrder(fpm.Peers(), peers) {
t.Fatal("peers not recorded by the peer manager") t.Fatal("peers not recorded by the peer manager")
} }
...@@ -145,6 +151,8 @@ func TestSessionGetBlocks(t *testing.T) { ...@@ -145,6 +151,8 @@ func TestSessionGetBlocks(t *testing.T) {
// Simulate receiving DONT_HAVE for a CID // Simulate receiving DONT_HAVE for a CID
session.ReceiveFrom(peers[0], []cid.Cid{}, []cid.Cid{}, []cid.Cid{blks[0].Cid()}) session.ReceiveFrom(peers[0], []cid.Cid{}, []cid.Cid{}, []cid.Cid{blks[0].Cid()})
time.Sleep(10 * time.Millisecond)
// Verify session still wants received blocks // Verify session still wants received blocks
_, unwanted = sim.SplitWantedUnwanted(blks) _, unwanted = sim.SplitWantedUnwanted(blks)
if len(unwanted) > 0 { if len(unwanted) > 0 {
...@@ -154,6 +162,8 @@ func TestSessionGetBlocks(t *testing.T) { ...@@ -154,6 +162,8 @@ func TestSessionGetBlocks(t *testing.T) {
// Simulate receiving block for a CID // Simulate receiving block for a CID
session.ReceiveFrom(peers[1], []cid.Cid{blks[0].Cid()}, []cid.Cid{}, []cid.Cid{}) session.ReceiveFrom(peers[1], []cid.Cid{blks[0].Cid()}, []cid.Cid{}, []cid.Cid{})
time.Sleep(100 * time.Millisecond)
// Verify session no longer wants received block // Verify session no longer wants received block
wanted, unwanted := sim.SplitWantedUnwanted(blks) wanted, unwanted := sim.SplitWantedUnwanted(blks)
if len(unwanted) != 1 || !unwanted[0].Cid().Equals(blks[0].Cid()) { if len(unwanted) != 1 || !unwanted[0].Cid().Equals(blks[0].Cid()) {
...@@ -169,12 +179,13 @@ func TestSessionFindMorePeers(t *testing.T) { ...@@ -169,12 +179,13 @@ func TestSessionFindMorePeers(t *testing.T) {
defer cancel() defer cancel()
fwm := newFakeWantManager() fwm := newFakeWantManager()
fpm := newFakeSessionPeerManager() fpm := newFakeSessionPeerManager()
fpf := newFakeProviderFinder()
sim := bssim.New() sim := bssim.New()
bpm := bsbpm.New() bpm := bsbpm.New()
notif := notifications.New() notif := notifications.New()
defer notif.Shutdown() defer notif.Shutdown()
id := testutil.GenerateSessionID() id := testutil.GenerateSessionID()
session := New(ctx, id, fwm, fpm, sim, newFakePeerManager(), bpm, notif, time.Second, delay.Fixed(time.Minute), "") session := New(ctx, id, fwm, fpm, fpf, sim, newFakePeerManager(), bpm, notif, time.Second, delay.Fixed(time.Minute), "")
session.SetBaseTickDelay(200 * time.Microsecond) session.SetBaseTickDelay(200 * time.Microsecond)
blockGenerator := blocksutil.NewBlockGenerator() blockGenerator := blocksutil.NewBlockGenerator()
blks := blockGenerator.Blocks(broadcastLiveWantsLimit * 2) blks := blockGenerator.Blocks(broadcastLiveWantsLimit * 2)
...@@ -223,7 +234,7 @@ func TestSessionFindMorePeers(t *testing.T) { ...@@ -223,7 +234,7 @@ func TestSessionFindMorePeers(t *testing.T) {
// The session should eventually try to find more peers // The session should eventually try to find more peers
select { select {
case <-fpm.findMorePeersRequested: case <-fpf.findMorePeersRequested:
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("Did not find more peers") t.Fatal("Did not find more peers")
} }
...@@ -234,12 +245,14 @@ func TestSessionOnPeersExhausted(t *testing.T) { ...@@ -234,12 +245,14 @@ func TestSessionOnPeersExhausted(t *testing.T) {
defer cancel() defer cancel()
fwm := newFakeWantManager() fwm := newFakeWantManager()
fpm := newFakeSessionPeerManager() fpm := newFakeSessionPeerManager()
fpf := newFakeProviderFinder()
sim := bssim.New() sim := bssim.New()
bpm := bsbpm.New() bpm := bsbpm.New()
notif := notifications.New() notif := notifications.New()
defer notif.Shutdown() defer notif.Shutdown()
id := testutil.GenerateSessionID() id := testutil.GenerateSessionID()
session := New(ctx, id, fwm, fpm, sim, newFakePeerManager(), bpm, notif, time.Second, delay.Fixed(time.Minute), "") session := New(ctx, id, fwm, fpm, fpf, sim, newFakePeerManager(), bpm, notif, time.Second, delay.Fixed(time.Minute), "")
blockGenerator := blocksutil.NewBlockGenerator() blockGenerator := blocksutil.NewBlockGenerator()
blks := blockGenerator.Blocks(broadcastLiveWantsLimit + 5) blks := blockGenerator.Blocks(broadcastLiveWantsLimit + 5)
var cids []cid.Cid var cids []cid.Cid
...@@ -277,12 +290,13 @@ func TestSessionFailingToGetFirstBlock(t *testing.T) { ...@@ -277,12 +290,13 @@ func TestSessionFailingToGetFirstBlock(t *testing.T) {
defer cancel() defer cancel()
fwm := newFakeWantManager() fwm := newFakeWantManager()
fpm := newFakeSessionPeerManager() fpm := newFakeSessionPeerManager()
fpf := newFakeProviderFinder()
sim := bssim.New() sim := bssim.New()
bpm := bsbpm.New() bpm := bsbpm.New()
notif := notifications.New() notif := notifications.New()
defer notif.Shutdown() defer notif.Shutdown()
id := testutil.GenerateSessionID() id := testutil.GenerateSessionID()
session := New(ctx, id, fwm, fpm, sim, newFakePeerManager(), bpm, notif, 10*time.Millisecond, delay.Fixed(100*time.Millisecond), "") session := New(ctx, id, fwm, fpm, fpf, sim, newFakePeerManager(), bpm, notif, 10*time.Millisecond, delay.Fixed(100*time.Millisecond), "")
blockGenerator := blocksutil.NewBlockGenerator() blockGenerator := blocksutil.NewBlockGenerator()
blks := blockGenerator.Blocks(4) blks := blockGenerator.Blocks(4)
var cids []cid.Cid var cids []cid.Cid
...@@ -314,7 +328,7 @@ func TestSessionFailingToGetFirstBlock(t *testing.T) { ...@@ -314,7 +328,7 @@ func TestSessionFailingToGetFirstBlock(t *testing.T) {
// Wait for a request to find more peers to occur // Wait for a request to find more peers to occur
select { select {
case k := <-fpm.findMorePeersRequested: case k := <-fpf.findMorePeersRequested:
if testutil.IndexOf(blks, k) == -1 { if testutil.IndexOf(blks, k) == -1 {
t.Fatal("did not rebroadcast an active want") t.Fatal("did not rebroadcast an active want")
} }
...@@ -369,14 +383,14 @@ func TestSessionFailingToGetFirstBlock(t *testing.T) { ...@@ -369,14 +383,14 @@ func TestSessionFailingToGetFirstBlock(t *testing.T) {
// Should not have tried to find peers on consecutive ticks // Should not have tried to find peers on consecutive ticks
select { select {
case <-fpm.findMorePeersRequested: case <-fpf.findMorePeersRequested:
t.Fatal("Should not have tried to find peers on consecutive ticks") t.Fatal("Should not have tried to find peers on consecutive ticks")
default: default:
} }
// Wait for rebroadcast to occur // Wait for rebroadcast to occur
select { select {
case k := <-fpm.findMorePeersRequested: case k := <-fpf.findMorePeersRequested:
if testutil.IndexOf(blks, k) == -1 { if testutil.IndexOf(blks, k) == -1 {
t.Fatal("did not rebroadcast an active want") t.Fatal("did not rebroadcast an active want")
} }
...@@ -388,6 +402,7 @@ func TestSessionFailingToGetFirstBlock(t *testing.T) { ...@@ -388,6 +402,7 @@ func TestSessionFailingToGetFirstBlock(t *testing.T) {
func TestSessionCtxCancelClosesGetBlocksChannel(t *testing.T) { func TestSessionCtxCancelClosesGetBlocksChannel(t *testing.T) {
fwm := newFakeWantManager() fwm := newFakeWantManager()
fpm := newFakeSessionPeerManager() fpm := newFakeSessionPeerManager()
fpf := newFakeProviderFinder()
sim := bssim.New() sim := bssim.New()
bpm := bsbpm.New() bpm := bsbpm.New()
notif := notifications.New() notif := notifications.New()
...@@ -396,7 +411,7 @@ func TestSessionCtxCancelClosesGetBlocksChannel(t *testing.T) { ...@@ -396,7 +411,7 @@ func TestSessionCtxCancelClosesGetBlocksChannel(t *testing.T) {
// Create a new session with its own context // Create a new session with its own context
sessctx, sesscancel := context.WithTimeout(context.Background(), 100*time.Millisecond) sessctx, sesscancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
session := New(sessctx, id, fwm, fpm, sim, newFakePeerManager(), bpm, notif, time.Second, delay.Fixed(time.Minute), "") session := New(sessctx, id, fwm, fpm, fpf, sim, newFakePeerManager(), bpm, notif, time.Second, delay.Fixed(time.Minute), "")
timerCtx, timerCancel := context.WithTimeout(context.Background(), 10*time.Millisecond) timerCtx, timerCancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer timerCancel() defer timerCancel()
...@@ -430,12 +445,14 @@ func TestSessionReceiveMessageAfterShutdown(t *testing.T) { ...@@ -430,12 +445,14 @@ func TestSessionReceiveMessageAfterShutdown(t *testing.T) {
ctx, cancelCtx := context.WithTimeout(context.Background(), 10*time.Millisecond) ctx, cancelCtx := context.WithTimeout(context.Background(), 10*time.Millisecond)
fwm := newFakeWantManager() fwm := newFakeWantManager()
fpm := newFakeSessionPeerManager() fpm := newFakeSessionPeerManager()
fpf := newFakeProviderFinder()
sim := bssim.New() sim := bssim.New()
bpm := bsbpm.New() bpm := bsbpm.New()
notif := notifications.New() notif := notifications.New()
defer notif.Shutdown() defer notif.Shutdown()
id := testutil.GenerateSessionID() id := testutil.GenerateSessionID()
session := New(ctx, id, fwm, fpm, sim, newFakePeerManager(), bpm, notif, time.Second, delay.Fixed(time.Minute), "") session := New(ctx, id, fwm, fpm, fpf, sim, newFakePeerManager(), bpm, notif, time.Second, delay.Fixed(time.Minute), "")
blockGenerator := blocksutil.NewBlockGenerator() blockGenerator := blocksutil.NewBlockGenerator()
blks := blockGenerator.Blocks(2) blks := blockGenerator.Blocks(2)
cids := []cid.Cid{blks[0].Cid(), blks[1].Cid()} cids := []cid.Cid{blks[0].Cid(), blks[1].Cid()}
......
...@@ -3,7 +3,6 @@ package session ...@@ -3,7 +3,6 @@ package session
import ( import (
"fmt" "fmt"
"math/rand" "math/rand"
"sync"
"time" "time"
cid "github.com/ipfs/go-cid" cid "github.com/ipfs/go-cid"
...@@ -12,7 +11,6 @@ import ( ...@@ -12,7 +11,6 @@ import (
// sessionWants keeps track of which cids are waiting to be sent out, and which // sessionWants keeps track of which cids are waiting to be sent out, and which
// peers are "live" - ie, we've sent a request but haven't received a block yet // peers are "live" - ie, we've sent a request but haven't received a block yet
type sessionWants struct { type sessionWants struct {
sync.RWMutex
toFetch *cidQueue toFetch *cidQueue
liveWants map[cid.Cid]time.Time liveWants map[cid.Cid]time.Time
} }
...@@ -30,9 +28,6 @@ func (sw *sessionWants) String() string { ...@@ -30,9 +28,6 @@ func (sw *sessionWants) String() string {
// BlocksRequested is called when the client makes a request for blocks // BlocksRequested is called when the client makes a request for blocks
func (sw *sessionWants) BlocksRequested(newWants []cid.Cid) { func (sw *sessionWants) BlocksRequested(newWants []cid.Cid) {
sw.Lock()
defer sw.Unlock()
for _, k := range newWants { for _, k := range newWants {
sw.toFetch.Push(k) sw.toFetch.Push(k)
} }
...@@ -43,9 +38,6 @@ func (sw *sessionWants) BlocksRequested(newWants []cid.Cid) { ...@@ -43,9 +38,6 @@ func (sw *sessionWants) BlocksRequested(newWants []cid.Cid) {
func (sw *sessionWants) GetNextWants(limit int) []cid.Cid { func (sw *sessionWants) GetNextWants(limit int) []cid.Cid {
now := time.Now() now := time.Now()
sw.Lock()
defer sw.Unlock()
// Move CIDs from fetch queue to the live wants queue (up to the limit) // Move CIDs from fetch queue to the live wants queue (up to the limit)
currentLiveCount := len(sw.liveWants) currentLiveCount := len(sw.liveWants)
toAdd := limit - currentLiveCount toAdd := limit - currentLiveCount
...@@ -63,10 +55,6 @@ func (sw *sessionWants) GetNextWants(limit int) []cid.Cid { ...@@ -63,10 +55,6 @@ func (sw *sessionWants) GetNextWants(limit int) []cid.Cid {
// WantsSent is called when wants are sent to a peer // WantsSent is called when wants are sent to a peer
func (sw *sessionWants) WantsSent(ks []cid.Cid) { func (sw *sessionWants) WantsSent(ks []cid.Cid) {
now := time.Now() now := time.Now()
sw.Lock()
defer sw.Unlock()
for _, c := range ks { for _, c := range ks {
if _, ok := sw.liveWants[c]; !ok { if _, ok := sw.liveWants[c]; !ok {
sw.toFetch.Remove(c) sw.toFetch.Remove(c)
...@@ -86,12 +74,8 @@ func (sw *sessionWants) BlocksReceived(ks []cid.Cid) ([]cid.Cid, time.Duration) ...@@ -86,12 +74,8 @@ func (sw *sessionWants) BlocksReceived(ks []cid.Cid) ([]cid.Cid, time.Duration)
} }
now := time.Now() now := time.Now()
sw.Lock()
defer sw.Unlock()
for _, c := range ks { for _, c := range ks {
if sw.unlockedIsWanted(c) { if sw.isWanted(c) {
wanted = append(wanted, c) wanted = append(wanted, c)
sentAt, ok := sw.liveWants[c] sentAt, ok := sw.liveWants[c]
...@@ -113,10 +97,6 @@ func (sw *sessionWants) BlocksReceived(ks []cid.Cid) ([]cid.Cid, time.Duration) ...@@ -113,10 +97,6 @@ func (sw *sessionWants) BlocksReceived(ks []cid.Cid) ([]cid.Cid, time.Duration)
// live want CIDs. // live want CIDs.
func (sw *sessionWants) PrepareBroadcast() []cid.Cid { func (sw *sessionWants) PrepareBroadcast() []cid.Cid {
now := time.Now() now := time.Now()
sw.Lock()
defer sw.Unlock()
live := make([]cid.Cid, 0, len(sw.liveWants)) live := make([]cid.Cid, 0, len(sw.liveWants))
for c := range sw.liveWants { for c := range sw.liveWants {
live = append(live, c) live = append(live, c)
...@@ -127,9 +107,6 @@ func (sw *sessionWants) PrepareBroadcast() []cid.Cid { ...@@ -127,9 +107,6 @@ func (sw *sessionWants) PrepareBroadcast() []cid.Cid {
// CancelPending removes the given CIDs from the fetch queue. // CancelPending removes the given CIDs from the fetch queue.
func (sw *sessionWants) CancelPending(keys []cid.Cid) { func (sw *sessionWants) CancelPending(keys []cid.Cid) {
sw.Lock()
defer sw.Unlock()
for _, k := range keys { for _, k := range keys {
sw.toFetch.Remove(k) sw.toFetch.Remove(k)
} }
...@@ -137,9 +114,6 @@ func (sw *sessionWants) CancelPending(keys []cid.Cid) { ...@@ -137,9 +114,6 @@ func (sw *sessionWants) CancelPending(keys []cid.Cid) {
// LiveWants returns a list of live wants // LiveWants returns a list of live wants
func (sw *sessionWants) LiveWants() []cid.Cid { func (sw *sessionWants) LiveWants() []cid.Cid {
sw.RLock()
defer sw.RUnlock()
live := make([]cid.Cid, 0, len(sw.liveWants)) live := make([]cid.Cid, 0, len(sw.liveWants))
for c := range sw.liveWants { for c := range sw.liveWants {
live = append(live, c) live = append(live, c)
...@@ -148,16 +122,12 @@ func (sw *sessionWants) LiveWants() []cid.Cid { ...@@ -148,16 +122,12 @@ func (sw *sessionWants) LiveWants() []cid.Cid {
} }
func (sw *sessionWants) RandomLiveWant() cid.Cid { func (sw *sessionWants) RandomLiveWant() cid.Cid {
i := rand.Uint64()
sw.RLock()
defer sw.RUnlock()
if len(sw.liveWants) == 0 { if len(sw.liveWants) == 0 {
return cid.Cid{} return cid.Cid{}
} }
i %= uint64(len(sw.liveWants))
// picking a random live want // picking a random live want
i := rand.Intn(len(sw.liveWants))
for k := range sw.liveWants { for k := range sw.liveWants {
if i == 0 { if i == 0 {
return k return k
...@@ -169,13 +139,11 @@ func (sw *sessionWants) RandomLiveWant() cid.Cid { ...@@ -169,13 +139,11 @@ func (sw *sessionWants) RandomLiveWant() cid.Cid {
// Has live wants indicates if there are any live wants // Has live wants indicates if there are any live wants
func (sw *sessionWants) HasLiveWants() bool { func (sw *sessionWants) HasLiveWants() bool {
sw.RLock()
defer sw.RUnlock()
return len(sw.liveWants) > 0 return len(sw.liveWants) > 0
} }
-func (sw *sessionWants) unlockedIsWanted(c cid.Cid) bool {
+// Indicates whether the want is in either of the fetch or live queues
+func (sw *sessionWants) isWanted(c cid.Cid) bool {
_, ok := sw.liveWants[c] _, ok := sw.liveWants[c]
if !ok { if !ok {
ok = sw.toFetch.Has(c) ok = sw.toFetch.Has(c)
......
This diff is collapsed.
...@@ -45,12 +45,12 @@ func (fs *fakeSession) ReceiveFrom(p peer.ID, ks []cid.Cid, wantBlocks []cid.Cid ...@@ -45,12 +45,12 @@ func (fs *fakeSession) ReceiveFrom(p peer.ID, ks []cid.Cid, wantBlocks []cid.Cid
type fakeSesPeerManager struct { type fakeSesPeerManager struct {
} }
-func (*fakeSesPeerManager) ReceiveFrom(peer.ID, []cid.Cid, []cid.Cid) bool { return true }
-func (*fakeSesPeerManager) Peers() *peer.Set                               { return nil }
-func (*fakeSesPeerManager) FindMorePeers(context.Context, cid.Cid)         {}
-func (*fakeSesPeerManager) RecordPeerRequests([]peer.ID, []cid.Cid)        {}
-func (*fakeSesPeerManager) RecordPeerResponse(peer.ID, []cid.Cid)          {}
-func (*fakeSesPeerManager) RecordCancels(c []cid.Cid)                      {}
+func (*fakeSesPeerManager) Peers() []peer.ID        { return nil }
+func (*fakeSesPeerManager) PeersDiscovered() bool   { return false }
+func (*fakeSesPeerManager) Shutdown()               {}
+func (*fakeSesPeerManager) AddPeer(peer.ID) bool    { return false }
+func (*fakeSesPeerManager) RemovePeer(peer.ID) bool { return false }
+func (*fakeSesPeerManager) HasPeers() bool          { return false }
type fakePeerManager struct { type fakePeerManager struct {
} }
......
package sessionpeermanager
import (
"time"
"github.com/ipfs/go-cid"
)
type requestData struct {
startedAt time.Time
wasCancelled bool
timeoutFunc *time.Timer
}
type latencyTracker struct {
requests map[cid.Cid]*requestData
}
func newLatencyTracker() *latencyTracker {
return &latencyTracker{requests: make(map[cid.Cid]*requestData)}
}
type afterTimeoutFunc func(cid.Cid)
func (lt *latencyTracker) SetupRequests(keys []cid.Cid, timeoutDuration time.Duration, afterTimeout afterTimeoutFunc) {
startedAt := time.Now()
for _, k := range keys {
if _, ok := lt.requests[k]; !ok {
lt.requests[k] = &requestData{
startedAt,
false,
time.AfterFunc(timeoutDuration, makeAfterTimeout(afterTimeout, k)),
}
}
}
}
func makeAfterTimeout(afterTimeout afterTimeoutFunc, k cid.Cid) func() {
return func() { afterTimeout(k) }
}
func (lt *latencyTracker) CheckDuration(key cid.Cid) (time.Duration, bool) {
request, ok := lt.requests[key]
var latency time.Duration
if ok {
latency = time.Since(request.startedAt)
}
return latency, ok
}
func (lt *latencyTracker) RemoveRequest(key cid.Cid) {
request, ok := lt.requests[key]
if ok {
request.timeoutFunc.Stop()
delete(lt.requests, key)
}
}
func (lt *latencyTracker) RecordCancel(keys []cid.Cid) {
for _, key := range keys {
request, ok := lt.requests[key]
if ok {
request.wasCancelled = true
}
}
}
func (lt *latencyTracker) WasCancelled(key cid.Cid) bool {
request, ok := lt.requests[key]
return ok && request.wasCancelled
}
func (lt *latencyTracker) Shutdown() {
for _, request := range lt.requests {
request.timeoutFunc.Stop()
}
}
package sessionpeermanager
import (
"time"
"github.com/ipfs/go-cid"
)
const (
newLatencyWeight = 0.5
)
type peerData struct {
hasLatency bool
latency time.Duration
lt *latencyTracker
}
func newPeerData() *peerData {
return &peerData{
hasLatency: false,
lt: newLatencyTracker(),
latency: 0,
}
}
func (pd *peerData) AdjustLatency(k cid.Cid, hasFallbackLatency bool, fallbackLatency time.Duration) {
latency, hasLatency := pd.lt.CheckDuration(k)
pd.lt.RemoveRequest(k)
if !hasLatency {
latency, hasLatency = fallbackLatency, hasFallbackLatency
}
if hasLatency {
if pd.hasLatency {
pd.latency = time.Duration(float64(pd.latency)*(1.0-newLatencyWeight) + float64(latency)*newLatencyWeight)
} else {
pd.latency = latency
pd.hasLatency = true
}
}
}
package sessionpeermanager package sessionpeermanager
import ( import (
"context"
"fmt" "fmt"
"math/rand" "sync"
"sort"
"time"
bssd "github.com/ipfs/go-bitswap/internal/sessiondata"
logging "github.com/ipfs/go-log" logging "github.com/ipfs/go-log"
cid "github.com/ipfs/go-cid"
peer "github.com/libp2p/go-libp2p-core/peer" peer "github.com/libp2p/go-libp2p-core/peer"
) )
var log = logging.Logger("bs:sprmgr") var log = logging.Logger("bs:sprmgr")
const ( const (
defaultTimeoutDuration = 5 * time.Second // Connection Manager tag value for session peers. Indicates to connection
maxOptimizedPeers = 32 // manager that it should keep the connection to the peer.
unoptimizedTagValue = 5 // tag value for "unoptimized" session peers. sessionPeerTagValue = 5
optimizedTagValue = 10 // tag value for "optimized" session peers.
) )
// PeerTagger is an interface for tagging peers with metadata // PeerTagger is an interface for tagging peers with metadata
...@@ -29,362 +23,100 @@ type PeerTagger interface { ...@@ -29,362 +23,100 @@ type PeerTagger interface {
UntagPeer(p peer.ID, tag string) UntagPeer(p peer.ID, tag string)
} }
// PeerProviderFinder is an interface for finding providers // SessionPeerManager keeps track of peers for a session, and takes care of
type PeerProviderFinder interface { // ConnectionManager tagging.
FindProvidersAsync(context.Context, cid.Cid) <-chan peer.ID
}
type peerMessage interface {
handle(spm *SessionPeerManager)
}
// SessionPeerManager tracks and manages peers for a session, and provides
// the best ones to the session
type SessionPeerManager struct { type SessionPeerManager struct {
ctx context.Context tagger PeerTagger
tagger PeerTagger tag string
providerFinder PeerProviderFinder
peers *peer.Set
tag string
id uint64
peerMessages chan peerMessage
// do not touch outside of run loop plk sync.RWMutex
activePeers map[peer.ID]*peerData peers map[peer.ID]struct{}
unoptimizedPeersArr []peer.ID peersDiscovered bool
optimizedPeersArr []peer.ID
broadcastLatency *latencyTracker
timeoutDuration time.Duration
} }
// New creates a new SessionPeerManager // New creates a new SessionPeerManager
func New(ctx context.Context, id uint64, tagger PeerTagger, providerFinder PeerProviderFinder) *SessionPeerManager { func New(id uint64, tagger PeerTagger) *SessionPeerManager {
spm := &SessionPeerManager{ return &SessionPeerManager{
ctx: ctx, tag: fmt.Sprint("bs-ses-", id),
id: id, tagger: tagger,
tagger: tagger, peers: make(map[peer.ID]struct{}),
providerFinder: providerFinder,
peers: peer.NewSet(),
peerMessages: make(chan peerMessage, 128),
activePeers: make(map[peer.ID]*peerData),
broadcastLatency: newLatencyTracker(),
timeoutDuration: defaultTimeoutDuration,
}
spm.tag = fmt.Sprint("bs-ses-", id)
go spm.run(ctx)
return spm
}
func (spm *SessionPeerManager) ReceiveFrom(p peer.ID, ks []cid.Cid, haves []cid.Cid) bool {
if len(ks) > 0 || len(haves) > 0 && !spm.peers.Contains(p) {
log.Infof("Added peer %s to session: %d peers\n", p, spm.peers.Size())
spm.peers.Add(p)
return true
}
return false
}
func (spm *SessionPeerManager) Peers() *peer.Set {
return spm.peers
}
// RecordPeerResponse records that a peer received some blocks, and adds the
// peer to the list of peers if it wasn't already added
func (spm *SessionPeerManager) RecordPeerResponse(p peer.ID, ks []cid.Cid) {
select {
case spm.peerMessages <- &peerResponseMessage{p, ks}:
case <-spm.ctx.Done():
}
}
// RecordCancels records the fact that cancellations were sent to peers,
// so if blocks don't arrive, don't let it affect the peer's timeout
func (spm *SessionPeerManager) RecordCancels(ks []cid.Cid) {
select {
case spm.peerMessages <- &cancelMessage{ks}:
case <-spm.ctx.Done():
}
}
// RecordPeerRequests records that a given set of peers requested the given cids.
func (spm *SessionPeerManager) RecordPeerRequests(p []peer.ID, ks []cid.Cid) {
select {
case spm.peerMessages <- &peerRequestMessage{p, ks}:
case <-spm.ctx.Done():
}
}
// GetOptimizedPeers returns the best peers available for a session, along with
// a rating for how good they are, in comparison to the best peer.
func (spm *SessionPeerManager) GetOptimizedPeers() []bssd.OptimizedPeer {
// right now this just returns all peers, but soon we might return peers
// ordered by optimization, or only a subset
resp := make(chan []bssd.OptimizedPeer, 1)
select {
case spm.peerMessages <- &getPeersMessage{resp}:
case <-spm.ctx.Done():
return nil
}
select {
case peers := <-resp:
return peers
case <-spm.ctx.Done():
return nil
}
}
// FindMorePeers attempts to find more peers for a session by searching for
// providers for the given Cid
func (spm *SessionPeerManager) FindMorePeers(ctx context.Context, c cid.Cid) {
go func(k cid.Cid) {
for p := range spm.providerFinder.FindProvidersAsync(ctx, k) {
select {
case spm.peerMessages <- &peerFoundMessage{p}:
case <-ctx.Done():
case <-spm.ctx.Done():
}
}
}(c)
}
// SetTimeoutDuration changes the length of time used to timeout recording of
// requests
func (spm *SessionPeerManager) SetTimeoutDuration(timeoutDuration time.Duration) {
select {
case spm.peerMessages <- &setTimeoutMessage{timeoutDuration}:
case <-spm.ctx.Done():
}
}
func (spm *SessionPeerManager) run(ctx context.Context) {
for {
select {
case pm := <-spm.peerMessages:
pm.handle(spm)
case <-ctx.Done():
spm.handleShutdown()
return
}
}
}
func (spm *SessionPeerManager) tagPeer(p peer.ID, data *peerData) {
var value int
if data.hasLatency {
value = optimizedTagValue
} else {
value = unoptimizedTagValue
} }
spm.tagger.TagPeer(p, spm.tag, value)
} }
func (spm *SessionPeerManager) insertPeer(p peer.ID, data *peerData) { // AddPeer adds the peer to the SessionPeerManager.
if data.hasLatency { // Returns true if the peer is a new peer, false if it already existed.
insertPos := sort.Search(len(spm.optimizedPeersArr), func(i int) bool { func (spm *SessionPeerManager) AddPeer(p peer.ID) bool {
return spm.activePeers[spm.optimizedPeersArr[i]].latency > data.latency spm.plk.Lock()
}) defer spm.plk.Unlock()
spm.optimizedPeersArr = append(spm.optimizedPeersArr[:insertPos],
append([]peer.ID{p}, spm.optimizedPeersArr[insertPos:]...)...)
} else {
spm.unoptimizedPeersArr = append(spm.unoptimizedPeersArr, p)
}
if !spm.peers.Contains(p) { // Check if the peer is a new peer
log.Infof("Added peer %s to session: %d peers\n", p, spm.peers.Size()) if _, ok := spm.peers[p]; ok {
spm.peers.Add(p) return false
} }
}
func (spm *SessionPeerManager) removeOptimizedPeer(p peer.ID) { spm.peers[p] = struct{}{}
for i := 0; i < len(spm.optimizedPeersArr); i++ { spm.peersDiscovered = true
if spm.optimizedPeersArr[i] == p {
spm.optimizedPeersArr = append(spm.optimizedPeersArr[:i], spm.optimizedPeersArr[i+1:]...)
return
}
}
}
func (spm *SessionPeerManager) removeUnoptimizedPeer(p peer.ID) { // Tag the peer with the ConnectionManager so it doesn't discard the
for i := 0; i < len(spm.unoptimizedPeersArr); i++ { // connection
if spm.unoptimizedPeersArr[i] == p { spm.tagger.TagPeer(p, spm.tag, sessionPeerTagValue)
spm.unoptimizedPeersArr[i] = spm.unoptimizedPeersArr[len(spm.unoptimizedPeersArr)-1]
spm.unoptimizedPeersArr = spm.unoptimizedPeersArr[:len(spm.unoptimizedPeersArr)-1]
return
}
}
}
func (spm *SessionPeerManager) recordResponse(p peer.ID, ks []cid.Cid) { log.Infof("Added peer %s to session: %d peers\n", p, len(spm.peers))
data, ok := spm.activePeers[p] return true
wasOptimized := ok && data.hasLatency
if wasOptimized {
spm.removeOptimizedPeer(p)
} else {
if ok {
spm.removeUnoptimizedPeer(p)
} else {
data = newPeerData()
spm.activePeers[p] = data
}
}
for _, k := range ks {
fallbackLatency, hasFallbackLatency := spm.broadcastLatency.CheckDuration(k)
data.AdjustLatency(k, hasFallbackLatency, fallbackLatency)
}
if !ok || wasOptimized != data.hasLatency {
spm.tagPeer(p, data)
}
spm.insertPeer(p, data)
} }
type peerFoundMessage struct { // RemovePeer removes the peer from the SessionPeerManager.
p peer.ID // Returns true if the peer was removed, false if it did not exist.
} func (spm *SessionPeerManager) RemovePeer(p peer.ID) bool {
spm.plk.Lock()
defer spm.plk.Unlock()
func (pfm *peerFoundMessage) handle(spm *SessionPeerManager) { if _, ok := spm.peers[p]; !ok {
p := pfm.p return false
if _, ok := spm.activePeers[p]; !ok {
spm.activePeers[p] = newPeerData()
spm.insertPeer(p, spm.activePeers[p])
spm.tagPeer(p, spm.activePeers[p])
} }
}
type peerResponseMessage struct { delete(spm.peers, p)
p peer.ID spm.tagger.UntagPeer(p, spm.tag)
ks []cid.Cid return true
} }
func (prm *peerResponseMessage) handle(spm *SessionPeerManager) { // PeersDiscovered indicates whether peers have been discovered yet.
spm.recordResponse(prm.p, prm.ks) // Returns true once a peer has been discovered by the session (even if all
} // peers are later removed from the session).
func (spm *SessionPeerManager) PeersDiscovered() bool {
spm.plk.RLock()
defer spm.plk.RUnlock()
type peerRequestMessage struct { return spm.peersDiscovered
peers []peer.ID
keys []cid.Cid
} }
func (spm *SessionPeerManager) makeTimeout(p peer.ID) afterTimeoutFunc { func (spm *SessionPeerManager) Peers() []peer.ID {
return func(k cid.Cid) { spm.plk.RLock()
select { defer spm.plk.RUnlock()
case spm.peerMessages <- &peerTimeoutMessage{p, k}:
case <-spm.ctx.Done():
}
}
}
func (prm *peerRequestMessage) handle(spm *SessionPeerManager) { peers := make([]peer.ID, 0, len(spm.peers))
if prm.peers == nil { for p := range spm.peers {
spm.broadcastLatency.SetupRequests(prm.keys, spm.timeoutDuration, func(k cid.Cid) { peers = append(peers, p)
select {
case spm.peerMessages <- &broadcastTimeoutMessage{k}:
case <-spm.ctx.Done():
}
})
} else {
for _, p := range prm.peers {
if data, ok := spm.activePeers[p]; ok {
data.lt.SetupRequests(prm.keys, spm.timeoutDuration, spm.makeTimeout(p))
}
}
} }
}
type getPeersMessage struct { return peers
resp chan<- []bssd.OptimizedPeer
} }
// Get all optimized peers in order followed by randomly ordered unoptimized func (spm *SessionPeerManager) HasPeers() bool {
// peers, with a limit of maxOptimizedPeers spm.plk.RLock()
func (prm *getPeersMessage) handle(spm *SessionPeerManager) { defer spm.plk.RUnlock()
randomOrder := rand.Perm(len(spm.unoptimizedPeersArr))
// Number of peers to get in total: unoptimized + optimized
// limited by maxOptimizedPeers
maxPeers := len(spm.unoptimizedPeersArr) + len(spm.optimizedPeersArr)
if maxPeers > maxOptimizedPeers {
maxPeers = maxOptimizedPeers
}
// The best peer latency is the first optimized peer's latency.
// If we haven't recorded any peer's latency, use 0.
var bestPeerLatency float64
if len(spm.optimizedPeersArr) > 0 {
bestPeerLatency = float64(spm.activePeers[spm.optimizedPeersArr[0]].latency)
} else {
bestPeerLatency = 0
}
optimizedPeers := make([]bssd.OptimizedPeer, 0, maxPeers) return len(spm.peers) > 0
for i := 0; i < maxPeers; i++ {
// First add optimized peers in order
if i < len(spm.optimizedPeersArr) {
p := spm.optimizedPeersArr[i]
optimizedPeers = append(optimizedPeers, bssd.OptimizedPeer{
Peer: p,
OptimizationRating: bestPeerLatency / float64(spm.activePeers[p].latency),
})
} else {
// Then add unoptimized peers in random order
p := spm.unoptimizedPeersArr[randomOrder[i-len(spm.optimizedPeersArr)]]
optimizedPeers = append(optimizedPeers, bssd.OptimizedPeer{Peer: p, OptimizationRating: 0.0})
}
}
prm.resp <- optimizedPeers
} }
type cancelMessage struct { // Shutdown untags all the peers
ks []cid.Cid func (spm *SessionPeerManager) Shutdown() {
} spm.plk.Lock()
defer spm.plk.Unlock()
func (cm *cancelMessage) handle(spm *SessionPeerManager) {
for _, data := range spm.activePeers {
data.lt.RecordCancel(cm.ks)
}
}
func (spm *SessionPeerManager) handleShutdown() { // Untag the peers with the ConnectionManager so that it can release
for p, data := range spm.activePeers { // connections to those peers
for p := range spm.peers {
spm.tagger.UntagPeer(p, spm.tag) spm.tagger.UntagPeer(p, spm.tag)
data.lt.Shutdown()
} }
} }
type peerTimeoutMessage struct {
p peer.ID
k cid.Cid
}
func (ptm *peerTimeoutMessage) handle(spm *SessionPeerManager) {
data, ok := spm.activePeers[ptm.p]
// If the request was cancelled, make sure we clean up the request tracker
if ok && data.lt.WasCancelled(ptm.k) {
data.lt.RemoveRequest(ptm.k)
} else {
// If the request was not cancelled, record the latency. Note that we
// do this even if we didn't previously know about this peer.
spm.recordResponse(ptm.p, []cid.Cid{ptm.k})
}
}
type broadcastTimeoutMessage struct {
k cid.Cid
}
func (btm *broadcastTimeoutMessage) handle(spm *SessionPeerManager) {
spm.broadcastLatency.RemoveRequest(btm.k)
}
type setTimeoutMessage struct {
timeoutDuration time.Duration
}
func (stm *setTimeoutMessage) handle(spm *SessionPeerManager) {
spm.timeoutDuration = stm.timeoutDuration
}
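Aside (illustrative only, not part of this commit): a minimal sketch of driving the simplified SessionPeerManager directly. Because the package is internal, code like this would have to live inside the go-bitswap module; noopTagger is a hypothetical stand-in for the network's ConnectionManager.

package main

import (
	"fmt"

	bsspm "github.com/ipfs/go-bitswap/internal/sessionpeermanager"
	peer "github.com/libp2p/go-libp2p-core/peer"
)

// noopTagger satisfies the PeerTagger interface but does nothing; in bitswap
// the network's ConnectionManager() plays this role.
type noopTagger struct{}

func (noopTagger) TagPeer(p peer.ID, tag string, val int) {}
func (noopTagger) UntagPeer(p peer.ID, tag string)        {}

func main() {
	spm := bsspm.New(1, noopTagger{})
	defer spm.Shutdown() // untags any peers still in the session

	p := peer.ID("example-peer")
	fmt.Println(spm.AddPeer(p))        // true: first time the peer is seen
	fmt.Println(spm.AddPeer(p))        // false: already tracked
	fmt.Println(spm.HasPeers())        // true
	fmt.Println(spm.PeersDiscovered()) // true, and stays true...
	fmt.Println(spm.RemovePeer(p))     // true: the peer was removed
	fmt.Println(spm.PeersDiscovered()) // ...even after all peers are removed
}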