Unverified commit c980d7ed authored by Steven Allen, committed by GitHub

Merge pull request #177 from ipfs/refactor/use-global-pubsub-notifier

Refactor: use global pubsub notifier
parents a5fe0d4b 994279bd
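
The change in a nutshell: instead of each session constructing and shutting down its own notifications.PubSub, Bitswap now creates a single notifier up front, threads it through the SessionManager to every session, and shuts it down when the Bitswap process closes. Below is a minimal sketch of the shared-notifier behaviour, using only the notifications API that appears in this diff (New, Subscribe, Publish, Shutdown); the block contents are illustrative, not from the PR.

package main

import (
	"context"
	"fmt"

	notifications "github.com/ipfs/go-bitswap/notifications"
	blocks "github.com/ipfs/go-block-format"
)

func main() {
	// One notifier shared by all sessions (previously one per session).
	notif := notifications.New()
	defer notif.Shutdown()

	blk := blocks.NewBlock([]byte("example data"))

	// Two subscribers stand in for two sessions' GetBlocks calls.
	ch1 := notif.Subscribe(context.Background(), blk.Cid())
	ch2 := notif.Subscribe(context.Background(), blk.Cid())

	// A single Publish (done by Bitswap when a block arrives) reaches both.
	notif.Publish(blk)
	fmt.Println((<-ch1).Cid(), (<-ch2).Cid())
}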
......@@ -16,6 +16,7 @@ import (
bsmsg "github.com/ipfs/go-bitswap/message"
bsmq "github.com/ipfs/go-bitswap/messagequeue"
bsnet "github.com/ipfs/go-bitswap/network"
notifications "github.com/ipfs/go-bitswap/notifications"
bspm "github.com/ipfs/go-bitswap/peermanager"
bspqm "github.com/ipfs/go-bitswap/providerquerymanager"
bssession "github.com/ipfs/go-bitswap/session"
......@@ -116,9 +117,10 @@ func New(parent context.Context, network bsnet.BitSwapNetwork,
pqm := bspqm.New(ctx, network)
sessionFactory := func(ctx context.Context, id uint64, pm bssession.PeerManager, srs bssession.RequestSplitter,
notif notifications.PubSub,
provSearchDelay time.Duration,
rebroadcastDelay delay.D) bssm.Session {
return bssession.New(ctx, id, wm, pm, srs, provSearchDelay, rebroadcastDelay)
return bssession.New(ctx, id, wm, pm, srs, notif, provSearchDelay, rebroadcastDelay)
}
sessionPeerManagerFactory := func(ctx context.Context, id uint64) bssession.PeerManager {
return bsspm.New(ctx, id, network.ConnectionManager(), pqm)
......@@ -126,6 +128,7 @@ func New(parent context.Context, network bsnet.BitSwapNetwork,
sessionRequestSplitterFactory := func(ctx context.Context) bssession.RequestSplitter {
return bssrs.New(ctx)
}
notif := notifications.New()
bs := &Bitswap{
blockstore: bstore,
......@@ -136,7 +139,8 @@ func New(parent context.Context, network bsnet.BitSwapNetwork,
provideKeys: make(chan cid.Cid, provideKeysBufferSize),
wm: wm,
pqm: pqm,
sm: bssm.New(ctx, sessionFactory, sessionPeerManagerFactory, sessionRequestSplitterFactory),
sm: bssm.New(ctx, sessionFactory, sessionPeerManagerFactory, sessionRequestSplitterFactory, notif),
notif: notif,
counters: new(counters),
dupMetric: dupHist,
allMetric: allHist,
......@@ -163,6 +167,7 @@ func New(parent context.Context, network bsnet.BitSwapNetwork,
go func() {
<-px.Closing() // process closes first
cancelFunc()
notif.Shutdown()
}()
procctx.CloseAfterContext(px, ctx) // parent cancelled first
......@@ -187,6 +192,9 @@ type Bitswap struct {
// NB: ensure threadsafety
blockstore blockstore.Blockstore
// manages channels of outgoing blocks for sessions
notif notifications.PubSub
// newBlocks is a channel for newly added blocks to be provided to the
// network. blocks pushed down this channel get buffered and fed to the
// provideKeys channel later on to avoid too much network activity
......@@ -307,18 +315,38 @@ func (bs *Bitswap) receiveBlocksFrom(from peer.ID, blks []blocks.Block) error {
// to the same node. We should address this soon, but I'm not going to do
// it now as it requires more thought and isn't causing immediate problems.
// Send all blocks (including duplicates) to any sessions that want them.
allKs := make([]cid.Cid, 0, len(blks))
for _, b := range blks {
allKs = append(allKs, b.Cid())
}
wantedKs := allKs
if len(blks) != len(wanted) {
wantedKs = make([]cid.Cid, 0, len(wanted))
for _, b := range wanted {
wantedKs = append(wantedKs, b.Cid())
}
}
// Send all block keys (including duplicates) to any sessions that want them.
// (The duplicates are needed by sessions for accounting purposes)
bs.sm.ReceiveBlocksFrom(from, blks)
bs.sm.ReceiveFrom(from, allKs)
// Send wanted blocks to decision engine
bs.engine.AddBlocks(wanted)
// Send wanted block keys to decision engine
bs.engine.AddBlocks(wantedKs)
// Publish the blocks to any Bitswap clients that had requested them
// (the sessions use this pubsub mechanism to inform clients of received
// blocks)
for _, b := range wanted {
bs.notif.Publish(b)
}
// If the reprovider is enabled, send wanted blocks to reprovider
if bs.provideEnabled {
for _, b := range wanted {
for _, k := range wantedKs {
select {
case bs.newBlocks <- b.Cid():
case bs.newBlocks <- k:
// send block off to be reprovided
case <-bs.process.Closing():
return bs.process.Close()
......
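
For reference, the key-splitting added above reads as a small helper: every received key, duplicates included, goes to the sessions (which need the duplicates for accounting), while only the still-wanted keys reach the decision engine and the reprovider. A hedged sketch under the imports already used in this file; splitKeys is a hypothetical name, not part of this PR.

// Hypothetical helper equivalent to the inline logic in receiveBlocksFrom.
func splitKeys(blks, wanted []blocks.Block) (allKs, wantedKs []cid.Cid) {
	allKs = make([]cid.Cid, 0, len(blks))
	for _, b := range blks {
		allKs = append(allKs, b.Cid())
	}
	if len(blks) == len(wanted) {
		// Nothing was filtered out; reuse the slice and skip a second pass.
		return allKs, allKs
	}
	wantedKs = make([]cid.Cid, 0, len(wanted))
	for _, b := range wanted {
		wantedKs = append(wantedKs, b.Cid())
	}
	return allKs, wantedKs
}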
......@@ -10,7 +10,6 @@ import (
"github.com/google/uuid"
bsmsg "github.com/ipfs/go-bitswap/message"
wl "github.com/ipfs/go-bitswap/wantlist"
blocks "github.com/ipfs/go-block-format"
cid "github.com/ipfs/go-cid"
bstore "github.com/ipfs/go-ipfs-blockstore"
logging "github.com/ipfs/go-log"
......@@ -312,13 +311,13 @@ func (e *Engine) MessageReceived(p peer.ID, m bsmsg.BitSwapMessage) {
}
}
func (e *Engine) addBlocks(blocks []blocks.Block) {
func (e *Engine) addBlocks(ks []cid.Cid) {
work := false
for _, l := range e.ledgerMap {
l.lk.Lock()
for _, block := range blocks {
if entry, ok := l.WantListContains(block.Cid()); ok {
for _, k := range ks {
if entry, ok := l.WantListContains(k); ok {
e.peerRequestQueue.PushBlock(l.Partner, peertask.Task{
Identifier: entry.Cid,
Priority: entry.Priority,
......@@ -337,11 +336,11 @@ func (e *Engine) addBlocks(blocks []blocks.Block) {
// AddBlocks is called when new blocks are received and added to a block store,
// meaning there may be peers who want those blocks, so we should send the blocks
// to them.
func (e *Engine) AddBlocks(blocks []blocks.Block) {
func (e *Engine) AddBlocks(ks []cid.Cid) {
e.lock.Lock()
defer e.lock.Unlock()
e.addBlocks(blocks)
e.addBlocks(ks)
}
// TODO add contents of m.WantList() to my local wantlist? NB: could introduce
......
......@@ -61,15 +61,19 @@ func SyncGetBlock(p context.Context, k cid.Cid, gb GetBlocksFunc) (blocks.Block,
type WantFunc func(context.Context, []cid.Cid)
// AsyncGetBlocks takes a set of block cids, a pubsub channel for incoming
// blocks, a want function, and a close function,
// and returns a channel of incoming blocks.
func AsyncGetBlocks(ctx context.Context, keys []cid.Cid, notif notifications.PubSub, want WantFunc, cwants func([]cid.Cid)) (<-chan blocks.Block, error) {
// blocks, a want function, and a close function, and returns a channel of
// incoming blocks. The channel is closed if either the request context or
// the session context is cancelled.
func AsyncGetBlocks(ctx context.Context, sessctx context.Context, keys []cid.Cid, notif notifications.PubSub,
want WantFunc, cwants func([]cid.Cid)) (<-chan blocks.Block, error) {
// If there are no keys supplied, just return a closed channel
if len(keys) == 0 {
out := make(chan blocks.Block)
close(out)
return out, nil
}
// Use a PubSub notifier to listen for incoming blocks for each key
remaining := cid.NewSet()
promise := notif.Subscribe(ctx, keys...)
for _, k := range keys {
......@@ -77,24 +81,36 @@ func AsyncGetBlocks(ctx context.Context, keys []cid.Cid, notif notifications.Pub
remaining.Add(k)
}
// Send the want request for the keys to the network
want(ctx, keys)
out := make(chan blocks.Block)
go handleIncoming(ctx, remaining, promise, out, cwants)
go handleIncoming(ctx, sessctx, remaining, promise, out, cwants)
return out, nil
}
func handleIncoming(ctx context.Context, remaining *cid.Set, in <-chan blocks.Block, out chan blocks.Block, cfun func([]cid.Cid)) {
// Listens for incoming blocks, passing them to the out channel.
// If the context is cancelled or the incoming channel closes, calls cfun with
// any keys corresponding to blocks that were never received.
func handleIncoming(ctx context.Context, sessctx context.Context, remaining *cid.Set,
in <-chan blocks.Block, out chan blocks.Block, cfun func([]cid.Cid)) {
ctx, cancel := context.WithCancel(ctx)
// Clean up before exiting this function, and call the cancel function on
// any remaining keys
defer func() {
cancel()
close(out)
// We can't just defer cfun(remaining.Keys()) directly: defer evaluates its
// arguments when the statement runs, so the keys would be captured too
// early rather than at function exit
cfun(remaining.Keys())
}()
for {
select {
case blk, ok := <-in:
// If the channel is closed, we're done (note that PubSub closes
// the channel once all the keys have been received)
if !ok {
return
}
......@@ -104,9 +120,13 @@ func handleIncoming(ctx context.Context, remaining *cid.Set, in <-chan blocks.Bl
case out <- blk:
case <-ctx.Done():
return
case <-sessctx.Done():
return
}
case <-ctx.Done():
return
case <-sessctx.Done():
return
}
}
}
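
handleIncoming now watches two contexts: the per-request context handed to GetBlocks and the session's own context, so a request ends as soon as either is cancelled, and any unreceived keys are always cancelled on the way out. A generic sketch of the pattern, with illustrative names:

// watch exits when the producer closes in, or when either context ends;
// cleanup always runs, mirroring handleIncoming's deferred cfun call.
func watch(reqCtx, sessCtx context.Context, in <-chan blocks.Block, cleanup func()) {
	defer cleanup()
	for {
		select {
		case _, ok := <-in:
			if !ok {
				return // all subscribed keys received; PubSub closed the channel
			}
		case <-reqCtx.Done():
			return // the caller abandoned this request
		case <-sessCtx.Done():
			return // the whole session shut down
		}
	}
}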
......@@ -60,8 +60,8 @@ func (ps *impl) Shutdown() {
}
// Subscribe returns a channel of blocks for the given |keys|. |blockChannel|
// is closed if the |ctx| times out or is cancelled, or after sending len(keys)
// blocks.
// is closed if the |ctx| times out or is cancelled, or after receiving the blocks
// corresponding to |keys|.
func (ps *impl) Subscribe(ctx context.Context, keys ...cid.Cid) <-chan blocks.Block {
blocksCh := make(chan blocks.Block, len(keys))
......@@ -82,6 +82,8 @@ func (ps *impl) Subscribe(ctx context.Context, keys ...cid.Cid) <-chan blocks.Bl
default:
}
// AddSubOnceEach listens for each key in the list, and closes the channel
// once all keys have been received
ps.wrapped.AddSubOnceEach(valuesCh, toStrings(keys)...)
go func() {
defer func() {
......
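
The revised Subscribe comment above is the load-bearing contract: the returned channel is buffered to len(keys) and closes itself once a block for every subscribed key has been received, which is what lets handleIncoming treat a closed channel as completion. A usage sketch; k1, k2, b1, and b2 are placeholder keys and their corresponding blocks:

ch := notif.Subscribe(ctx, k1, k2)
notif.Publish(b1)
notif.Publish(b2)
for blk := range ch { // the loop ends once blocks for both keys have arrived
	fmt.Println("received", blk.Cid())
}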
......@@ -52,9 +52,9 @@ type interestReq struct {
resp chan bool
}
type blksRecv struct {
type rcvFrom struct {
from peer.ID
blks []blocks.Block
ks []cid.Cid
}
// Session holds state for an individual bitswap transfer operation.
......@@ -68,7 +68,7 @@ type Session struct {
srs RequestSplitter
// channels
incoming chan blksRecv
incoming chan rcvFrom
newReqs chan []cid.Cid
cancelKeys chan []cid.Cid
interestReqs chan interestReq
......@@ -101,6 +101,7 @@ func New(ctx context.Context,
wm WantManager,
pm PeerManager,
srs RequestSplitter,
notif notifications.PubSub,
initialSearchDelay time.Duration,
periodicSearchDelay delay.D) *Session {
s := &Session{
......@@ -116,8 +117,8 @@ func New(ctx context.Context,
wm: wm,
pm: pm,
srs: srs,
incoming: make(chan blksRecv),
notif: notifications.New(),
incoming: make(chan rcvFrom),
notif: notif,
uuid: loggables.Uuid("GetBlockRequest"),
baseTickDelay: time.Millisecond * 500,
id: id,
......@@ -133,10 +134,10 @@ func New(ctx context.Context,
return s
}
// ReceiveBlocksFrom receives incoming blocks from the given peer.
func (s *Session) ReceiveBlocksFrom(from peer.ID, blocks []blocks.Block) {
// ReceiveFrom receives the keys of incoming blocks from the given peer.
func (s *Session) ReceiveFrom(from peer.ID, ks []cid.Cid) {
select {
case s.incoming <- blksRecv{from: from, blks: blocks}:
case s.incoming <- rcvFrom{from: from, ks: ks}:
case <-s.ctx.Done():
}
}
......@@ -181,7 +182,8 @@ func (s *Session) GetBlock(parent context.Context, k cid.Cid) (blocks.Block, err
// guaranteed on the returned blocks.
func (s *Session) GetBlocks(ctx context.Context, keys []cid.Cid) (<-chan blocks.Block, error) {
ctx = logging.ContextWithLoggable(ctx, s.uuid)
return bsgetter.AsyncGetBlocks(ctx, keys, s.notif,
return bsgetter.AsyncGetBlocks(ctx, s.ctx, keys, s.notif,
func(ctx context.Context, keys []cid.Cid) {
select {
case s.newReqs <- keys:
......@@ -231,13 +233,13 @@ func (s *Session) run(ctx context.Context) {
for {
select {
case rcv := <-s.incoming:
s.cancelIncomingBlocks(ctx, rcv)
s.cancelIncoming(ctx, rcv)
// Record statistics only if the blocks came from the network
// (blocks can also be received from the local node)
if rcv.from != "" {
s.updateReceiveCounters(ctx, rcv)
}
s.handleIncomingBlocks(ctx, rcv)
s.handleIncoming(ctx, rcv)
case keys := <-s.newReqs:
s.handleNewRequest(ctx, keys)
case keys := <-s.cancelKeys:
......@@ -259,23 +261,23 @@ func (s *Session) run(ctx context.Context) {
}
}
func (s *Session) cancelIncomingBlocks(ctx context.Context, rcv blksRecv) {
func (s *Session) cancelIncoming(ctx context.Context, rcv rcvFrom) {
// We've received the blocks so we can cancel any outstanding wants for them
ks := make([]cid.Cid, 0, len(rcv.blks))
for _, b := range rcv.blks {
if s.cidIsWanted(b.Cid()) {
ks = append(ks, b.Cid())
wanted := make([]cid.Cid, 0, len(rcv.ks))
for _, k := range rcv.ks {
if s.cidIsWanted(k) {
wanted = append(wanted, k)
}
}
s.pm.RecordCancels(ks)
s.wm.CancelWants(s.ctx, ks, nil, s.id)
s.pm.RecordCancels(wanted)
s.wm.CancelWants(s.ctx, wanted, nil, s.id)
}
func (s *Session) handleIncomingBlocks(ctx context.Context, rcv blksRecv) {
func (s *Session) handleIncoming(ctx context.Context, rcv rcvFrom) {
s.idleTick.Stop()
// Process the received blocks
s.receiveBlocks(ctx, rcv.blks)
s.processIncoming(ctx, rcv.ks)
s.resetIdleTick()
}
......@@ -359,7 +361,6 @@ func (s *Session) randomLiveWant() cid.Cid {
}
func (s *Session) handleShutdown() {
s.idleTick.Stop()
s.notif.Shutdown()
live := make([]cid.Cid, 0, len(s.liveWants))
for c := range s.liveWants {
......@@ -376,9 +377,8 @@ func (s *Session) cidIsWanted(c cid.Cid) bool {
return ok
}
func (s *Session) receiveBlocks(ctx context.Context, blocks []blocks.Block) {
for _, blk := range blocks {
c := blk.Cid()
func (s *Session) processIncoming(ctx context.Context, ks []cid.Cid) {
for _, c := range ks {
if s.cidIsWanted(c) {
// If the block CID was in the live wants queue, remove it
tval, ok := s.liveWants[c]
......@@ -395,8 +395,6 @@ func (s *Session) receiveBlocks(ctx context.Context, blocks []blocks.Block) {
// that have occurred since the last new block
s.consecutiveTicks = 0
s.notif.Publish(blk)
// Keep track of CIDs we've successfully fetched
s.pastWants.Push(c)
}
......@@ -417,23 +415,19 @@ func (s *Session) receiveBlocks(ctx context.Context, blocks []blocks.Block) {
}
}
func (s *Session) updateReceiveCounters(ctx context.Context, rcv blksRecv) {
ks := make([]cid.Cid, len(rcv.blks))
for _, blk := range rcv.blks {
func (s *Session) updateReceiveCounters(ctx context.Context, rcv rcvFrom) {
for _, k := range rcv.ks {
// Inform the request splitter of unique / duplicate blocks
if s.cidIsWanted(blk.Cid()) {
if s.cidIsWanted(k) {
s.srs.RecordUniqueBlock()
} else if s.pastWants.Has(blk.Cid()) {
} else if s.pastWants.Has(k) {
s.srs.RecordDuplicateBlock()
}
ks = append(ks, blk.Cid())
}
// Record response (to be able to time latency)
if len(ks) > 0 {
s.pm.RecordPeerResponse(rcv.from, ks)
if len(rcv.ks) > 0 {
s.pm.RecordPeerResponse(rcv.from, rcv.ks)
}
}
......
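
With the global notifier, a session learns about arrivals in two steps, as the test comments below spell out: Bitswap calls ReceiveFrom with the keys (driving want-cancellation and duplicate accounting) and separately publishes the block data on the shared notifier (completing GetBlocks for waiting callers). In sketch form, the caller-side contract is:

// What Bitswap does for each wanted block it receives (sketch):
session.ReceiveFrom(from, []cid.Cid{blk.Cid()}) // keys: cancel wants, update stats
notif.Publish(blk)                              // data: deliver to GetBlocks channels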
......@@ -6,6 +6,7 @@ import (
"testing"
"time"
notifications "github.com/ipfs/go-bitswap/notifications"
bssd "github.com/ipfs/go-bitswap/sessiondata"
"github.com/ipfs/go-bitswap/testutil"
blocks "github.com/ipfs/go-block-format"
......@@ -92,8 +93,10 @@ func TestSessionGetBlocks(t *testing.T) {
fwm := &fakeWantManager{wantReqs, cancelReqs}
fpm := &fakePeerManager{}
frs := &fakeRequestSplitter{}
notif := notifications.New()
defer notif.Shutdown()
id := testutil.GenerateSessionID()
session := New(ctx, id, fwm, fpm, frs, time.Second, delay.Fixed(time.Minute))
session := New(ctx, id, fwm, fpm, frs, notif, time.Second, delay.Fixed(time.Minute))
blockGenerator := blocksutil.NewBlockGenerator()
blks := blockGenerator.Blocks(broadcastLiveWantsLimit * 2)
var cids []cid.Cid
......@@ -122,7 +125,13 @@ func TestSessionGetBlocks(t *testing.T) {
var newBlockReqs []wantReq
var receivedBlocks []blocks.Block
for i, p := range peers {
session.ReceiveBlocksFrom(p, []blocks.Block{blks[testutil.IndexOf(blks, receivedWantReq.cids[i])]})
// simulate what bitswap does on receiving a message:
// - calls ReceiveFrom() on session
// - publishes block to pubsub channel
blk := blks[testutil.IndexOf(blks, receivedWantReq.cids[i])]
session.ReceiveFrom(p, []cid.Cid{blk.Cid()})
notif.Publish(blk)
select {
case cancelBlock := <-cancelReqs:
newCancelReqs = append(newCancelReqs, cancelBlock)
......@@ -178,7 +187,13 @@ func TestSessionGetBlocks(t *testing.T) {
// receive remaining blocks
for i, p := range peers {
session.ReceiveBlocksFrom(p, []blocks.Block{blks[testutil.IndexOf(blks, newCidsRequested[i])]})
// simulate what bitswap does on receiving a message:
// - calls ReceiveFrom() on session
// - publishes block to pubsub channel
blk := blks[testutil.IndexOf(blks, newCidsRequested[i])]
session.ReceiveFrom(p, []cid.Cid{blk.Cid()})
notif.Publish(blk)
receivedBlock := <-getBlocksCh
receivedBlocks = append(receivedBlocks, receivedBlock)
cancelBlock := <-cancelReqs
......@@ -207,8 +222,10 @@ func TestSessionFindMorePeers(t *testing.T) {
fwm := &fakeWantManager{wantReqs, cancelReqs}
fpm := &fakePeerManager{findMorePeersRequested: make(chan cid.Cid, 1)}
frs := &fakeRequestSplitter{}
notif := notifications.New()
defer notif.Shutdown()
id := testutil.GenerateSessionID()
session := New(ctx, id, fwm, fpm, frs, time.Second, delay.Fixed(time.Minute))
session := New(ctx, id, fwm, fpm, frs, notif, time.Second, delay.Fixed(time.Minute))
session.SetBaseTickDelay(200 * time.Microsecond)
blockGenerator := blocksutil.NewBlockGenerator()
blks := blockGenerator.Blocks(broadcastLiveWantsLimit * 2)
......@@ -233,7 +250,13 @@ func TestSessionFindMorePeers(t *testing.T) {
// or there will be no tick set -- time precision on Windows in Go is in the
// millisecond range
p := testutil.GeneratePeers(1)[0]
session.ReceiveBlocksFrom(p, []blocks.Block{blks[0]})
// simulate what bitswap does on receiving a message:
// - calls ReceiveFrom() on session
// - publishes block to pubsub channel
blk := blks[0]
session.ReceiveFrom(p, []cid.Cid{blk.Cid()})
notif.Publish(blk)
select {
case <-cancelReqs:
case <-ctx.Done():
......@@ -279,9 +302,11 @@ func TestSessionFailingToGetFirstBlock(t *testing.T) {
fwm := &fakeWantManager{wantReqs, cancelReqs}
fpm := &fakePeerManager{findMorePeersRequested: make(chan cid.Cid, 1)}
frs := &fakeRequestSplitter{}
notif := notifications.New()
defer notif.Shutdown()
id := testutil.GenerateSessionID()
session := New(ctx, id, fwm, fpm, frs, 10*time.Millisecond, delay.Fixed(100*time.Millisecond))
session := New(ctx, id, fwm, fpm, frs, notif, 10*time.Millisecond, delay.Fixed(100*time.Millisecond))
blockGenerator := blocksutil.NewBlockGenerator()
blks := blockGenerator.Blocks(4)
var cids []cid.Cid
......@@ -391,3 +416,45 @@ func TestSessionFailingToGetFirstBlock(t *testing.T) {
t.Fatal("Did not rebroadcast to find more peers")
}
}
func TestSessionCtxCancelClosesGetBlocksChannel(t *testing.T) {
wantReqs := make(chan wantReq, 1)
cancelReqs := make(chan wantReq, 1)
fwm := &fakeWantManager{wantReqs, cancelReqs}
fpm := &fakePeerManager{}
frs := &fakeRequestSplitter{}
notif := notifications.New()
defer notif.Shutdown()
id := testutil.GenerateSessionID()
// Create a new session with its own context
sessctx, sesscancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
session := New(sessctx, id, fwm, fpm, frs, notif, time.Second, delay.Fixed(time.Minute))
timerCtx, timerCancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer timerCancel()
// Request a block with a new context
blockGenerator := blocksutil.NewBlockGenerator()
blks := blockGenerator.Blocks(1)
getctx, getcancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer getcancel()
getBlocksCh, err := session.GetBlocks(getctx, []cid.Cid{blks[0].Cid()})
if err != nil {
t.Fatal("error getting blocks")
}
// Cancel the session context
sesscancel()
// Expect the GetBlocks() channel to be closed
select {
case _, ok := <-getBlocksCh:
if ok {
t.Fatal("expected channel to be closed but was not closed")
}
case <-timerCtx.Done():
t.Fatal("expected channel to be closed before timeout")
}
}
......@@ -5,10 +5,10 @@ import (
"sync"
"time"
blocks "github.com/ipfs/go-block-format"
cid "github.com/ipfs/go-cid"
delay "github.com/ipfs/go-ipfs-delay"
notifications "github.com/ipfs/go-bitswap/notifications"
bssession "github.com/ipfs/go-bitswap/session"
exchange "github.com/ipfs/go-ipfs-exchange-interface"
peer "github.com/libp2p/go-libp2p-core/peer"
......@@ -18,7 +18,7 @@ import (
type Session interface {
exchange.Fetcher
InterestedIn(cid.Cid) bool
ReceiveBlocksFrom(peer.ID, []blocks.Block)
ReceiveFrom(peer.ID, []cid.Cid)
}
type sesTrk struct {
......@@ -28,7 +28,7 @@ type sesTrk struct {
}
// SessionFactory generates a new session for the SessionManager to track.
type SessionFactory func(ctx context.Context, id uint64, pm bssession.PeerManager, srs bssession.RequestSplitter, provSearchDelay time.Duration, rebroadcastDelay delay.D) Session
type SessionFactory func(ctx context.Context, id uint64, pm bssession.PeerManager, srs bssession.RequestSplitter, notif notifications.PubSub, provSearchDelay time.Duration, rebroadcastDelay delay.D) Session
// RequestSplitterFactory generates a new request splitter for a session.
type RequestSplitterFactory func(ctx context.Context) bssession.RequestSplitter
......@@ -43,6 +43,7 @@ type SessionManager struct {
sessionFactory SessionFactory
peerManagerFactory PeerManagerFactory
requestSplitterFactory RequestSplitterFactory
notif notifications.PubSub
// Sessions
sessLk sync.Mutex
......@@ -54,12 +55,14 @@ type SessionManager struct {
}
// New creates a new SessionManager.
func New(ctx context.Context, sessionFactory SessionFactory, peerManagerFactory PeerManagerFactory, requestSplitterFactory RequestSplitterFactory) *SessionManager {
func New(ctx context.Context, sessionFactory SessionFactory, peerManagerFactory PeerManagerFactory,
requestSplitterFactory RequestSplitterFactory, notif notifications.PubSub) *SessionManager {
return &SessionManager{
ctx: ctx,
sessionFactory: sessionFactory,
peerManagerFactory: peerManagerFactory,
requestSplitterFactory: requestSplitterFactory,
notif: notif,
}
}
......@@ -73,7 +76,7 @@ func (sm *SessionManager) NewSession(ctx context.Context,
pm := sm.peerManagerFactory(sessionctx, id)
srs := sm.requestSplitterFactory(sessionctx)
session := sm.sessionFactory(sessionctx, id, pm, srs, provSearchDelay, rebroadcastDelay)
session := sm.sessionFactory(sessionctx, id, pm, srs, sm.notif, provSearchDelay, rebroadcastDelay)
tracked := sesTrk{session, pm, srs}
sm.sessLk.Lock()
sm.sessions = append(sm.sessions, tracked)
......@@ -111,20 +114,20 @@ func (sm *SessionManager) GetNextSessionID() uint64 {
return sm.sessID
}
// ReceiveBlocksFrom receives blocks from a peer and dispatches to interested
// ReceiveFrom receives block keys from a peer and dispatches to interested
// sessions.
func (sm *SessionManager) ReceiveBlocksFrom(from peer.ID, blks []blocks.Block) {
func (sm *SessionManager) ReceiveFrom(from peer.ID, ks []cid.Cid) {
sm.sessLk.Lock()
defer sm.sessLk.Unlock()
// Only give each session the blocks / dups that it is interested in
for _, s := range sm.sessions {
sessBlks := make([]blocks.Block, 0, len(blks))
for _, b := range blks {
if s.session.InterestedIn(b.Cid()) {
sessBlks = append(sessBlks, b)
sessKs := make([]cid.Cid, 0, len(ks))
for _, k := range ks {
if s.session.InterestedIn(k) {
sessKs = append(sessKs, k)
}
}
s.session.ReceiveBlocksFrom(from, sessBlks)
s.session.ReceiveFrom(from, sessKs)
}
}
......@@ -7,6 +7,7 @@ import (
delay "github.com/ipfs/go-ipfs-delay"
notifications "github.com/ipfs/go-bitswap/notifications"
bssession "github.com/ipfs/go-bitswap/session"
bssd "github.com/ipfs/go-bitswap/sessiondata"
"github.com/ipfs/go-bitswap/testutil"
......@@ -18,10 +19,11 @@ import (
type fakeSession struct {
interested []cid.Cid
blks []blocks.Block
ks []cid.Cid
id uint64
pm *fakePeerManager
srs *fakeRequestSplitter
notif notifications.PubSub
}
func (*fakeSession) GetBlock(context.Context, cid.Cid) (blocks.Block, error) {
......@@ -38,8 +40,8 @@ func (fs *fakeSession) InterestedIn(c cid.Cid) bool {
}
return false
}
func (fs *fakeSession) ReceiveBlocksFrom(p peer.ID, blks []blocks.Block) {
fs.blks = append(fs.blks, blks...)
func (fs *fakeSession) ReceiveFrom(p peer.ID, ks []cid.Cid) {
fs.ks = append(fs.ks, ks...)
}
type fakePeerManager struct {
......@@ -67,6 +69,7 @@ func sessionFactory(ctx context.Context,
id uint64,
pm bssession.PeerManager,
srs bssession.RequestSplitter,
notif notifications.PubSub,
provSearchDelay time.Duration,
rebroadcastDelay delay.D) Session {
return &fakeSession{
......@@ -74,6 +77,7 @@ func sessionFactory(ctx context.Context,
id: id,
pm: pm.(*fakePeerManager),
srs: srs.(*fakeRequestSplitter),
notif: notif,
}
}
......@@ -86,17 +90,13 @@ func requestSplitterFactory(ctx context.Context) bssession.RequestSplitter {
}
func cmpSessionCids(s *fakeSession, cids []cid.Cid) bool {
return cmpBlockCids(s.blks, cids)
}
func cmpBlockCids(blks []blocks.Block, cids []cid.Cid) bool {
if len(blks) != len(cids) {
if len(s.ks) != len(cids) {
return false
}
for _, b := range blks {
for _, bk := range s.ks {
has := false
for _, c := range cids {
if c == b.Cid() {
if c == bk {
has = true
}
}
......@@ -111,7 +111,9 @@ func TestAddingSessions(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
sm := New(ctx, sessionFactory, peerManagerFactory, requestSplitterFactory)
notif := notifications.New()
defer notif.Shutdown()
sm := New(ctx, sessionFactory, peerManagerFactory, requestSplitterFactory, notif)
p := peer.ID(123)
block := blocks.NewBlock([]byte("block"))
......@@ -135,10 +137,10 @@ func TestAddingSessions(t *testing.T) {
thirdSession.id != secondSession.id+2 {
t.Fatal("session does not have correct id set")
}
sm.ReceiveBlocksFrom(p, []blocks.Block{block})
if len(firstSession.blks) == 0 ||
len(secondSession.blks) == 0 ||
len(thirdSession.blks) == 0 {
sm.ReceiveFrom(p, []cid.Cid{block.Cid()})
if len(firstSession.ks) == 0 ||
len(secondSession.ks) == 0 ||
len(thirdSession.ks) == 0 {
t.Fatal("should have received blocks but didn't")
}
}
......@@ -147,7 +149,9 @@ func TestReceivingBlocksWhenNotInterested(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
sm := New(ctx, sessionFactory, peerManagerFactory, requestSplitterFactory)
notif := notifications.New()
defer notif.Shutdown()
sm := New(ctx, sessionFactory, peerManagerFactory, requestSplitterFactory, notif)
p := peer.ID(123)
blks := testutil.GenerateBlocksOfSize(3, 1024)
......@@ -163,7 +167,7 @@ func TestReceivingBlocksWhenNotInterested(t *testing.T) {
nextInterestedIn = []cid.Cid{}
thirdSession := sm.NewSession(ctx, time.Second, delay.Fixed(time.Minute)).(*fakeSession)
sm.ReceiveBlocksFrom(p, []blocks.Block{blks[0], blks[1]})
sm.ReceiveFrom(p, []cid.Cid{blks[0].Cid(), blks[1].Cid()})
if !cmpSessionCids(firstSession, []cid.Cid{cids[0], cids[1]}) ||
!cmpSessionCids(secondSession, []cid.Cid{cids[0]}) ||
......@@ -175,7 +179,9 @@ func TestReceivingBlocksWhenNotInterested(t *testing.T) {
func TestRemovingPeersWhenManagerContextCancelled(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
sm := New(ctx, sessionFactory, peerManagerFactory, requestSplitterFactory)
notif := notifications.New()
defer notif.Shutdown()
sm := New(ctx, sessionFactory, peerManagerFactory, requestSplitterFactory, notif)
p := peer.ID(123)
block := blocks.NewBlock([]byte("block"))
......@@ -188,10 +194,10 @@ func TestRemovingPeersWhenManagerContextCancelled(t *testing.T) {
cancel()
// wait for sessions to get removed
time.Sleep(10 * time.Millisecond)
sm.ReceiveBlocksFrom(p, []blocks.Block{block})
if len(firstSession.blks) > 0 ||
len(secondSession.blks) > 0 ||
len(thirdSession.blks) > 0 {
sm.ReceiveFrom(p, []cid.Cid{block.Cid()})
if len(firstSession.ks) > 0 ||
len(secondSession.ks) > 0 ||
len(thirdSession.ks) > 0 {
t.Fatal("received blocks for sessions after manager is shutdown")
}
}
......@@ -200,7 +206,9 @@ func TestRemovingPeersWhenSessionContextCancelled(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
sm := New(ctx, sessionFactory, peerManagerFactory, requestSplitterFactory)
notif := notifications.New()
defer notif.Shutdown()
sm := New(ctx, sessionFactory, peerManagerFactory, requestSplitterFactory, notif)
p := peer.ID(123)
block := blocks.NewBlock([]byte("block"))
......@@ -214,10 +222,10 @@ func TestRemovingPeersWhenSessionContextCancelled(t *testing.T) {
sessionCancel()
// wait for sessions to get removed
time.Sleep(10 * time.Millisecond)
sm.ReceiveBlocksFrom(p, []blocks.Block{block})
if len(firstSession.blks) == 0 ||
len(secondSession.blks) > 0 ||
len(thirdSession.blks) == 0 {
sm.ReceiveFrom(p, []cid.Cid{block.Cid()})
if len(firstSession.ks) == 0 ||
len(secondSession.ks) > 0 ||
len(thirdSession.ks) == 0 {
t.Fatal("received blocks for sessions that are canceled")
}
}