Commit e1a25234 authored by hannahhoward's avatar hannahhoward

test(sessionmanager): Add unit test

Add a unit test and do some additional decoupling
parent d7a532d0
...@@ -16,9 +16,10 @@ import ( ...@@ -16,9 +16,10 @@ import (
bsnet "github.com/ipfs/go-bitswap/network" bsnet "github.com/ipfs/go-bitswap/network"
notifications "github.com/ipfs/go-bitswap/notifications" notifications "github.com/ipfs/go-bitswap/notifications"
bspm "github.com/ipfs/go-bitswap/peermanager" bspm "github.com/ipfs/go-bitswap/peermanager"
bssession "github.com/ipfs/go-bitswap/session"
bssm "github.com/ipfs/go-bitswap/sessionmanager" bssm "github.com/ipfs/go-bitswap/sessionmanager"
bsspm "github.com/ipfs/go-bitswap/sessionpeermanager"
bswm "github.com/ipfs/go-bitswap/wantmanager" bswm "github.com/ipfs/go-bitswap/wantmanager"
blocks "github.com/ipfs/go-block-format" blocks "github.com/ipfs/go-block-format"
cid "github.com/ipfs/go-cid" cid "github.com/ipfs/go-cid"
blockstore "github.com/ipfs/go-ipfs-blockstore" blockstore "github.com/ipfs/go-ipfs-blockstore"
...@@ -102,6 +103,13 @@ func New(parent context.Context, network bsnet.BitSwapNetwork, ...@@ -102,6 +103,13 @@ func New(parent context.Context, network bsnet.BitSwapNetwork,
} }
wm := bswm.New(ctx) wm := bswm.New(ctx)
sessionFactory := func(ctx context.Context, id uint64, pm bssession.PeerManager) bssm.Session {
return bssession.New(ctx, id, wm, pm)
}
sessionPeerManagerFactory := func(ctx context.Context, id uint64) bssession.PeerManager {
return bsspm.New(ctx, id, network)
}
bs := &Bitswap{ bs := &Bitswap{
blockstore: bstore, blockstore: bstore,
notifications: notif, notifications: notif,
...@@ -113,7 +121,7 @@ func New(parent context.Context, network bsnet.BitSwapNetwork, ...@@ -113,7 +121,7 @@ func New(parent context.Context, network bsnet.BitSwapNetwork,
provideKeys: make(chan cid.Cid, provideKeysBufferSize), provideKeys: make(chan cid.Cid, provideKeysBufferSize),
wm: wm, wm: wm,
pm: bspm.New(ctx, peerQueueFactory), pm: bspm.New(ctx, peerQueueFactory),
sm: bssm.New(ctx, wm, network), sm: bssm.New(ctx, sessionFactory, sessionPeerManagerFactory),
counters: new(counters), counters: new(counters),
dupMetric: dupHist, dupMetric: dupHist,
allMetric: allHist, allMetric: allHist,
......
...@@ -8,22 +8,34 @@ import ( ...@@ -8,22 +8,34 @@ import (
cid "github.com/ipfs/go-cid" cid "github.com/ipfs/go-cid"
bssession "github.com/ipfs/go-bitswap/session" bssession "github.com/ipfs/go-bitswap/session"
bsspm "github.com/ipfs/go-bitswap/sessionpeermanager"
exchange "github.com/ipfs/go-ipfs-exchange-interface" exchange "github.com/ipfs/go-ipfs-exchange-interface"
peer "github.com/libp2p/go-libp2p-peer" peer "github.com/libp2p/go-libp2p-peer"
) )
// Session is a session that is managed by the session manager
type Session interface {
exchange.Fetcher
InterestedIn(cid.Cid) bool
ReceiveBlockFrom(peer.ID, blocks.Block)
}
type sesTrk struct { type sesTrk struct {
session *bssession.Session session Session
pm *bsspm.SessionPeerManager pm bssession.PeerManager
} }
// SessionFactory generates a new session for the SessionManager to track.
type SessionFactory func(ctx context.Context, id uint64, pm bssession.PeerManager) Session
// PeerManagerFactory generates a new peer manager for a session.
type PeerManagerFactory func(ctx context.Context, id uint64) bssession.PeerManager
// SessionManager is responsible for creating, managing, and dispatching to // SessionManager is responsible for creating, managing, and dispatching to
// sessions. // sessions.
type SessionManager struct { type SessionManager struct {
wm bssession.WantManager ctx context.Context
network bsspm.PeerNetwork sessionFactory SessionFactory
ctx context.Context peerManagerFactory PeerManagerFactory
// Sessions // Sessions
sessLk sync.Mutex sessLk sync.Mutex
sessions []sesTrk sessions []sesTrk
...@@ -34,11 +46,11 @@ type SessionManager struct { ...@@ -34,11 +46,11 @@ type SessionManager struct {
} }
// New creates a new SessionManager. // New creates a new SessionManager.
func New(ctx context.Context, wm bssession.WantManager, network bsspm.PeerNetwork) *SessionManager { func New(ctx context.Context, sessionFactory SessionFactory, peerManagerFactory PeerManagerFactory) *SessionManager {
return &SessionManager{ return &SessionManager{
ctx: ctx, ctx: ctx,
wm: wm, sessionFactory: sessionFactory,
network: network, peerManagerFactory: peerManagerFactory,
} }
} }
...@@ -48,8 +60,8 @@ func (sm *SessionManager) NewSession(ctx context.Context) exchange.Fetcher { ...@@ -48,8 +60,8 @@ func (sm *SessionManager) NewSession(ctx context.Context) exchange.Fetcher {
id := sm.GetNextSessionID() id := sm.GetNextSessionID()
sessionctx, cancel := context.WithCancel(ctx) sessionctx, cancel := context.WithCancel(ctx)
pm := bsspm.New(sessionctx, id, sm.network) pm := sm.peerManagerFactory(sessionctx, id)
session := bssession.New(sessionctx, id, sm.wm, pm) session := sm.sessionFactory(sessionctx, id, pm)
tracked := sesTrk{session, pm} tracked := sesTrk{session, pm}
sm.sessLk.Lock() sm.sessLk.Lock()
sm.sessions = append(sm.sessions, tracked) sm.sessions = append(sm.sessions, tracked)
...@@ -94,11 +106,9 @@ func (sm *SessionManager) ReceiveBlockFrom(from peer.ID, blk blocks.Block) { ...@@ -94,11 +106,9 @@ func (sm *SessionManager) ReceiveBlockFrom(from peer.ID, blk blocks.Block) {
defer sm.sessLk.Unlock() defer sm.sessLk.Unlock()
k := blk.Cid() k := blk.Cid()
ks := []cid.Cid{k}
for _, s := range sm.sessions { for _, s := range sm.sessions {
if s.session.InterestedIn(k) { if s.session.InterestedIn(k) {
s.session.ReceiveBlockFrom(from, blk) s.session.ReceiveBlockFrom(from, blk)
sm.wm.CancelWants(sm.ctx, ks, nil, s.session.ID())
} }
} }
} }
package sessionmanager
import (
"context"
"testing"
"time"
bssession "github.com/ipfs/go-bitswap/session"
blocks "github.com/ipfs/go-block-format"
cid "github.com/ipfs/go-cid"
peer "github.com/libp2p/go-libp2p-peer"
)
type fakeSession struct {
interested bool
receivedBlock bool
id uint64
pm *fakePeerManager
}
func (*fakeSession) GetBlock(context.Context, cid.Cid) (blocks.Block, error) {
return nil, nil
}
func (*fakeSession) GetBlocks(context.Context, []cid.Cid) (<-chan blocks.Block, error) {
return nil, nil
}
func (fs *fakeSession) InterestedIn(cid.Cid) bool { return fs.interested }
func (fs *fakeSession) ReceiveBlockFrom(peer.ID, blocks.Block) { fs.receivedBlock = true }
type fakePeerManager struct {
id uint64
}
func (*fakePeerManager) FindMorePeers(context.Context, cid.Cid) {}
func (*fakePeerManager) GetOptimizedPeers() []peer.ID { return nil }
func (*fakePeerManager) RecordPeerRequests([]peer.ID, []cid.Cid) {}
func (*fakePeerManager) RecordPeerResponse(peer.ID, cid.Cid) {}
var nextInterestedIn bool
func sessionFactory(ctx context.Context, id uint64, pm bssession.PeerManager) Session {
return &fakeSession{
interested: nextInterestedIn,
receivedBlock: false,
id: id,
pm: pm.(*fakePeerManager),
}
}
func peerManagerFactory(ctx context.Context, id uint64) bssession.PeerManager {
return &fakePeerManager{id}
}
func TestAddingSessions(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
sm := New(ctx, sessionFactory, peerManagerFactory)
p := peer.ID(123)
block := blocks.NewBlock([]byte("block"))
// we'll be interested in all blocks for this test
nextInterestedIn = true
currentID := sm.GetNextSessionID()
firstSession := sm.NewSession(ctx).(*fakeSession)
if firstSession.id != firstSession.pm.id ||
firstSession.id != currentID+1 {
t.Fatal("session does not have correct id set")
}
secondSession := sm.NewSession(ctx).(*fakeSession)
if secondSession.id != secondSession.pm.id ||
secondSession.id != firstSession.id+1 {
t.Fatal("session does not have correct id set")
}
sm.GetNextSessionID()
thirdSession := sm.NewSession(ctx).(*fakeSession)
if thirdSession.id != thirdSession.pm.id ||
thirdSession.id != secondSession.id+2 {
t.Fatal("session does not have correct id set")
}
sm.ReceiveBlockFrom(p, block)
if !firstSession.receivedBlock ||
!secondSession.receivedBlock ||
!thirdSession.receivedBlock {
t.Fatal("should have received blocks but didn't")
}
}
func TestReceivingBlocksWhenNotInterested(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
sm := New(ctx, sessionFactory, peerManagerFactory)
p := peer.ID(123)
block := blocks.NewBlock([]byte("block"))
// we'll be interested in all blocks for this test
nextInterestedIn = false
firstSession := sm.NewSession(ctx).(*fakeSession)
nextInterestedIn = true
secondSession := sm.NewSession(ctx).(*fakeSession)
nextInterestedIn = false
thirdSession := sm.NewSession(ctx).(*fakeSession)
sm.ReceiveBlockFrom(p, block)
if firstSession.receivedBlock ||
!secondSession.receivedBlock ||
thirdSession.receivedBlock {
t.Fatal("did not receive blocks only for interested sessions")
}
}
func TestRemovingPeersWhenManagerContextCancelled(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
sm := New(ctx, sessionFactory, peerManagerFactory)
p := peer.ID(123)
block := blocks.NewBlock([]byte("block"))
// we'll be interested in all blocks for this test
nextInterestedIn = true
firstSession := sm.NewSession(ctx).(*fakeSession)
secondSession := sm.NewSession(ctx).(*fakeSession)
thirdSession := sm.NewSession(ctx).(*fakeSession)
cancel()
// wait for sessions to get removed
time.Sleep(10 * time.Millisecond)
sm.ReceiveBlockFrom(p, block)
if firstSession.receivedBlock ||
secondSession.receivedBlock ||
thirdSession.receivedBlock {
t.Fatal("received blocks for sessions after manager is shutdown")
}
}
func TestRemovingPeersWhenSessionContextCancelled(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
sm := New(ctx, sessionFactory, peerManagerFactory)
p := peer.ID(123)
block := blocks.NewBlock([]byte("block"))
// we'll be interested in all blocks for this test
nextInterestedIn = true
firstSession := sm.NewSession(ctx).(*fakeSession)
sessionCtx, sessionCancel := context.WithCancel(ctx)
secondSession := sm.NewSession(sessionCtx).(*fakeSession)
thirdSession := sm.NewSession(ctx).(*fakeSession)
sessionCancel()
// wait for sessions to get removed
time.Sleep(10 * time.Millisecond)
sm.ReceiveBlockFrom(p, block)
if !firstSession.receivedBlock ||
secondSession.receivedBlock ||
!thirdSession.receivedBlock {
t.Fatal("received blocks for sessions that are canceled")
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment