Commit e1a25234 authored by hannahhoward's avatar hannahhoward

test(sessionmanager): Add unit test

Add a unit test and do some additional decoupling
parent d7a532d0
......@@ -16,9 +16,10 @@ import (
bsnet "github.com/ipfs/go-bitswap/network"
notifications "github.com/ipfs/go-bitswap/notifications"
bspm "github.com/ipfs/go-bitswap/peermanager"
bssession "github.com/ipfs/go-bitswap/session"
bssm "github.com/ipfs/go-bitswap/sessionmanager"
bsspm "github.com/ipfs/go-bitswap/sessionpeermanager"
bswm "github.com/ipfs/go-bitswap/wantmanager"
blocks "github.com/ipfs/go-block-format"
cid "github.com/ipfs/go-cid"
blockstore "github.com/ipfs/go-ipfs-blockstore"
......@@ -102,6 +103,13 @@ func New(parent context.Context, network bsnet.BitSwapNetwork,
}
wm := bswm.New(ctx)
sessionFactory := func(ctx context.Context, id uint64, pm bssession.PeerManager) bssm.Session {
return bssession.New(ctx, id, wm, pm)
}
sessionPeerManagerFactory := func(ctx context.Context, id uint64) bssession.PeerManager {
return bsspm.New(ctx, id, network)
}
bs := &Bitswap{
blockstore: bstore,
notifications: notif,
......@@ -113,7 +121,7 @@ func New(parent context.Context, network bsnet.BitSwapNetwork,
provideKeys: make(chan cid.Cid, provideKeysBufferSize),
wm: wm,
pm: bspm.New(ctx, peerQueueFactory),
sm: bssm.New(ctx, wm, network),
sm: bssm.New(ctx, sessionFactory, sessionPeerManagerFactory),
counters: new(counters),
dupMetric: dupHist,
allMetric: allHist,
......
......@@ -8,22 +8,34 @@ import (
cid "github.com/ipfs/go-cid"
bssession "github.com/ipfs/go-bitswap/session"
bsspm "github.com/ipfs/go-bitswap/sessionpeermanager"
exchange "github.com/ipfs/go-ipfs-exchange-interface"
peer "github.com/libp2p/go-libp2p-peer"
)
// Session is a session that is managed by the session manager
type Session interface {
exchange.Fetcher
InterestedIn(cid.Cid) bool
ReceiveBlockFrom(peer.ID, blocks.Block)
}
type sesTrk struct {
session *bssession.Session
pm *bsspm.SessionPeerManager
session Session
pm bssession.PeerManager
}
// SessionFactory generates a new session for the SessionManager to track.
type SessionFactory func(ctx context.Context, id uint64, pm bssession.PeerManager) Session
// PeerManagerFactory generates a new peer manager for a session.
type PeerManagerFactory func(ctx context.Context, id uint64) bssession.PeerManager
// SessionManager is responsible for creating, managing, and dispatching to
// sessions.
type SessionManager struct {
wm bssession.WantManager
network bsspm.PeerNetwork
ctx context.Context
ctx context.Context
sessionFactory SessionFactory
peerManagerFactory PeerManagerFactory
// Sessions
sessLk sync.Mutex
sessions []sesTrk
......@@ -34,11 +46,11 @@ type SessionManager struct {
}
// New creates a new SessionManager.
func New(ctx context.Context, wm bssession.WantManager, network bsspm.PeerNetwork) *SessionManager {
func New(ctx context.Context, sessionFactory SessionFactory, peerManagerFactory PeerManagerFactory) *SessionManager {
return &SessionManager{
ctx: ctx,
wm: wm,
network: network,
ctx: ctx,
sessionFactory: sessionFactory,
peerManagerFactory: peerManagerFactory,
}
}
......@@ -48,8 +60,8 @@ func (sm *SessionManager) NewSession(ctx context.Context) exchange.Fetcher {
id := sm.GetNextSessionID()
sessionctx, cancel := context.WithCancel(ctx)
pm := bsspm.New(sessionctx, id, sm.network)
session := bssession.New(sessionctx, id, sm.wm, pm)
pm := sm.peerManagerFactory(sessionctx, id)
session := sm.sessionFactory(sessionctx, id, pm)
tracked := sesTrk{session, pm}
sm.sessLk.Lock()
sm.sessions = append(sm.sessions, tracked)
......@@ -94,11 +106,9 @@ func (sm *SessionManager) ReceiveBlockFrom(from peer.ID, blk blocks.Block) {
defer sm.sessLk.Unlock()
k := blk.Cid()
ks := []cid.Cid{k}
for _, s := range sm.sessions {
if s.session.InterestedIn(k) {
s.session.ReceiveBlockFrom(from, blk)
sm.wm.CancelWants(sm.ctx, ks, nil, s.session.ID())
}
}
}
package sessionmanager
import (
"context"
"testing"
"time"
bssession "github.com/ipfs/go-bitswap/session"
blocks "github.com/ipfs/go-block-format"
cid "github.com/ipfs/go-cid"
peer "github.com/libp2p/go-libp2p-peer"
)
type fakeSession struct {
interested bool
receivedBlock bool
id uint64
pm *fakePeerManager
}
func (*fakeSession) GetBlock(context.Context, cid.Cid) (blocks.Block, error) {
return nil, nil
}
func (*fakeSession) GetBlocks(context.Context, []cid.Cid) (<-chan blocks.Block, error) {
return nil, nil
}
func (fs *fakeSession) InterestedIn(cid.Cid) bool { return fs.interested }
func (fs *fakeSession) ReceiveBlockFrom(peer.ID, blocks.Block) { fs.receivedBlock = true }
type fakePeerManager struct {
id uint64
}
func (*fakePeerManager) FindMorePeers(context.Context, cid.Cid) {}
func (*fakePeerManager) GetOptimizedPeers() []peer.ID { return nil }
func (*fakePeerManager) RecordPeerRequests([]peer.ID, []cid.Cid) {}
func (*fakePeerManager) RecordPeerResponse(peer.ID, cid.Cid) {}
var nextInterestedIn bool
func sessionFactory(ctx context.Context, id uint64, pm bssession.PeerManager) Session {
return &fakeSession{
interested: nextInterestedIn,
receivedBlock: false,
id: id,
pm: pm.(*fakePeerManager),
}
}
func peerManagerFactory(ctx context.Context, id uint64) bssession.PeerManager {
return &fakePeerManager{id}
}
func TestAddingSessions(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
sm := New(ctx, sessionFactory, peerManagerFactory)
p := peer.ID(123)
block := blocks.NewBlock([]byte("block"))
// we'll be interested in all blocks for this test
nextInterestedIn = true
currentID := sm.GetNextSessionID()
firstSession := sm.NewSession(ctx).(*fakeSession)
if firstSession.id != firstSession.pm.id ||
firstSession.id != currentID+1 {
t.Fatal("session does not have correct id set")
}
secondSession := sm.NewSession(ctx).(*fakeSession)
if secondSession.id != secondSession.pm.id ||
secondSession.id != firstSession.id+1 {
t.Fatal("session does not have correct id set")
}
sm.GetNextSessionID()
thirdSession := sm.NewSession(ctx).(*fakeSession)
if thirdSession.id != thirdSession.pm.id ||
thirdSession.id != secondSession.id+2 {
t.Fatal("session does not have correct id set")
}
sm.ReceiveBlockFrom(p, block)
if !firstSession.receivedBlock ||
!secondSession.receivedBlock ||
!thirdSession.receivedBlock {
t.Fatal("should have received blocks but didn't")
}
}
func TestReceivingBlocksWhenNotInterested(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
sm := New(ctx, sessionFactory, peerManagerFactory)
p := peer.ID(123)
block := blocks.NewBlock([]byte("block"))
// we'll be interested in all blocks for this test
nextInterestedIn = false
firstSession := sm.NewSession(ctx).(*fakeSession)
nextInterestedIn = true
secondSession := sm.NewSession(ctx).(*fakeSession)
nextInterestedIn = false
thirdSession := sm.NewSession(ctx).(*fakeSession)
sm.ReceiveBlockFrom(p, block)
if firstSession.receivedBlock ||
!secondSession.receivedBlock ||
thirdSession.receivedBlock {
t.Fatal("did not receive blocks only for interested sessions")
}
}
func TestRemovingPeersWhenManagerContextCancelled(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
sm := New(ctx, sessionFactory, peerManagerFactory)
p := peer.ID(123)
block := blocks.NewBlock([]byte("block"))
// we'll be interested in all blocks for this test
nextInterestedIn = true
firstSession := sm.NewSession(ctx).(*fakeSession)
secondSession := sm.NewSession(ctx).(*fakeSession)
thirdSession := sm.NewSession(ctx).(*fakeSession)
cancel()
// wait for sessions to get removed
time.Sleep(10 * time.Millisecond)
sm.ReceiveBlockFrom(p, block)
if firstSession.receivedBlock ||
secondSession.receivedBlock ||
thirdSession.receivedBlock {
t.Fatal("received blocks for sessions after manager is shutdown")
}
}
func TestRemovingPeersWhenSessionContextCancelled(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
sm := New(ctx, sessionFactory, peerManagerFactory)
p := peer.ID(123)
block := blocks.NewBlock([]byte("block"))
// we'll be interested in all blocks for this test
nextInterestedIn = true
firstSession := sm.NewSession(ctx).(*fakeSession)
sessionCtx, sessionCancel := context.WithCancel(ctx)
secondSession := sm.NewSession(sessionCtx).(*fakeSession)
thirdSession := sm.NewSession(ctx).(*fakeSession)
sessionCancel()
// wait for sessions to get removed
time.Sleep(10 * time.Millisecond)
sm.ReceiveBlockFrom(p, block)
if !firstSession.receivedBlock ||
secondSession.receivedBlock ||
!thirdSession.receivedBlock {
t.Fatal("received blocks for sessions that are canceled")
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment