diff --git a/message/message.go b/message/message.go index 7ede57f87248959ba613aa6317b38385f6b51aed..9a166c942e582af51702ccc1775f171e91998d26 100644 --- a/message/message.go +++ b/message/message.go @@ -50,7 +50,7 @@ type Exportable interface { type impl struct { full bool - wantlist map[string]Entry + wantlist map[string]*Entry blocks map[string]blocks.Block } @@ -61,7 +61,7 @@ func New(full bool) BitSwapMessage { func newMsg(full bool) *impl { return &impl{ blocks: make(map[string]blocks.Block), - wantlist: make(map[string]Entry), + wantlist: make(map[string]*Entry), full: full, } } @@ -122,7 +122,7 @@ func (m *impl) Empty() bool { func (m *impl) Wantlist() []Entry { out := make([]Entry, 0, len(m.wantlist)) for _, e := range m.wantlist { - out = append(out, e) + out = append(out, *e) } return out } @@ -151,7 +151,7 @@ func (m *impl) addEntry(c *cid.Cid, priority int, cancel bool) { e.Priority = priority e.Cancel = cancel } else { - m.wantlist[k] = Entry{ + m.wantlist[k] = &Entry{ Entry: &wantlist.Entry{ Cid: c, Priority: priority, diff --git a/notifications/notifications.go b/notifications/notifications.go index ba5b379ec8bb773519da686b5575e9f4455e89d8..9a6f10b525610a1f65f80d10029243810faa4a2d 100644 --- a/notifications/notifications.go +++ b/notifications/notifications.go @@ -2,6 +2,7 @@ package notifications import ( "context" + "sync" blocks "gx/ipfs/Qmej7nf81hi2x2tvjRBF3mcp74sQyuDH4VMYDGd1YtXjb2/go-block-format" @@ -18,18 +19,43 @@ type PubSub interface { } func New() PubSub { - return &impl{*pubsub.New(bufferSize)} + return &impl{ + wrapped: *pubsub.New(bufferSize), + cancel: make(chan struct{}), + } } type impl struct { wrapped pubsub.PubSub + + // These two fields make up a shutdown "lock". + // We need them as calling, e.g., `Unsubscribe` after calling `Shutdown` + // blocks forever and fixing this in pubsub would be rather invasive. + cancel chan struct{} + wg sync.WaitGroup } func (ps *impl) Publish(block blocks.Block) { + ps.wg.Add(1) + defer ps.wg.Done() + + select { + case <-ps.cancel: + // Already shutdown, bail. + return + default: + } + ps.wrapped.Pub(block, block.Cid().KeyString()) } +// Not safe to call more than once. func (ps *impl) Shutdown() { + // Interrupt in-progress subscriptions. + close(ps.cancel) + // Wait for them to finish. + ps.wg.Wait() + // shutdown the pubsub. ps.wrapped.Shutdown() } @@ -44,12 +70,34 @@ func (ps *impl) Subscribe(ctx context.Context, keys ...*cid.Cid) <-chan blocks.B close(blocksCh) return blocksCh } + + // prevent shutdown + ps.wg.Add(1) + + // check if shutdown *after* preventing shutdowns. + select { + case <-ps.cancel: + // abort, allow shutdown to continue. + ps.wg.Done() + close(blocksCh) + return blocksCh + default: + } + ps.wrapped.AddSubOnceEach(valuesCh, toStrings(keys)...) go func() { - defer close(blocksCh) - defer ps.wrapped.Unsub(valuesCh) // with a len(keys) buffer, this is an optimization + defer func() { + ps.wrapped.Unsub(valuesCh) + close(blocksCh) + + // Unblock shutdown. + ps.wg.Done() + }() + for { select { + case <-ps.cancel: + return case <-ctx.Done(): return case val, ok := <-valuesCh: @@ -61,6 +109,8 @@ func (ps *impl) Subscribe(ctx context.Context, keys ...*cid.Cid) <-chan blocks.B return } select { + case <-ps.cancel: + return case <-ctx.Done(): return case blocksCh <- block: // continue diff --git a/notifications/notifications_test.go b/notifications/notifications_test.go index 0377c307d6a4cf42056e4b61ac1b0c8cbf612c2d..a70a0755a0038d3ccdb2fb134630f8aee7b33f15 100644 --- a/notifications/notifications_test.go +++ b/notifications/notifications_test.go @@ -100,6 +100,25 @@ func TestDuplicateSubscribe(t *testing.T) { assertBlocksEqual(t, e1, r2) } +func TestShutdownBeforeUnsubscribe(t *testing.T) { + e1 := blocks.NewBlock([]byte("1")) + + n := New() + ctx, cancel := context.WithCancel(context.Background()) + ch := n.Subscribe(ctx, e1.Cid()) // no keys provided + n.Shutdown() + cancel() + + select { + case _, ok := <-ch: + if ok { + t.Fatal("channel should have been closed") + } + default: + t.Fatal("channel should have been closed") + } +} + func TestSubscribeIsANoopWhenCalledWithNoKeys(t *testing.T) { n := New() defer n.Shutdown() diff --git a/session.go b/session.go index 07444ad36f4b45e99510662cf1dfc2b6040c9cf5..937376723e2e7219311568df875d8d00a17ffa82 100644 --- a/session.go +++ b/session.go @@ -83,6 +83,15 @@ func (bs *Bitswap) NewSession(ctx context.Context) *Session { } func (bs *Bitswap) removeSession(s *Session) { + s.notif.Shutdown() + + live := make([]*cid.Cid, 0, len(s.liveWants)) + for c := range s.liveWants { + cs, _ := cid.Cast([]byte(c)) + live = append(live, cs) + } + bs.CancelWants(live, s.id) + bs.sessLk.Lock() defer bs.sessLk.Unlock() for i := 0; i < len(bs.sessions); i++ { @@ -270,8 +279,9 @@ func (s *Session) receiveBlock(ctx context.Context, blk blocks.Block) { } func (s *Session) wantBlocks(ctx context.Context, ks []*cid.Cid) { + now := time.Now() for _, c := range ks { - s.liveWants[c.KeyString()] = time.Now() + s.liveWants[c.KeyString()] = now } s.bs.wm.WantBlocks(ctx, ks, s.activePeersArr, s.id) } diff --git a/session_test.go b/session_test.go index 6458904548ab7a524a3696c7e68d304b297b0ff8..2fe4672b06c67d126adb138fc6c631a5b693376b 100644 --- a/session_test.go +++ b/session_test.go @@ -285,3 +285,36 @@ func TestMultipleSessions(t *testing.T) { } _ = blkch } + +func TestWantlistClearsOnCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + vnet := getVirtualNetwork() + sesgen := NewTestSessionGenerator(vnet) + defer sesgen.Close() + bgen := blocksutil.NewBlockGenerator() + + blks := bgen.Blocks(10) + var cids []*cid.Cid + for _, blk := range blks { + cids = append(cids, blk.Cid()) + } + + inst := sesgen.Instances(1) + + a := inst[0] + + ctx1, cancel1 := context.WithCancel(ctx) + ses := a.Exchange.NewSession(ctx1) + + _, err := ses.GetBlocks(ctx, cids) + if err != nil { + t.Fatal(err) + } + cancel1() + + if len(a.Exchange.GetWantlist()) > 0 { + t.Fatal("expected empty wantlist") + } +} diff --git a/testnet/virtual.go b/testnet/virtual.go index bcb00d14e7f854005c2019d17ef7d2fb82af102b..7e7ee185ce9432881c2f1cf695bb417951da9e19 100644 --- a/testnet/virtual.go +++ b/testnet/virtual.go @@ -4,6 +4,7 @@ import ( "context" "errors" "sync" + "time" bsmsg "github.com/ipfs/go-ipfs/exchange/bitswap/message" bsnet "github.com/ipfs/go-ipfs/exchange/bitswap/network" @@ -22,7 +23,7 @@ var log = logging.Logger("bstestnet") func VirtualNetwork(rs mockrouting.Server, d delay.D) Network { return &network{ - clients: make(map[peer.ID]bsnet.Receiver), + clients: make(map[peer.ID]*receiverQueue), delay: d, routingserver: rs, conns: make(map[string]struct{}), @@ -31,12 +32,28 @@ func VirtualNetwork(rs mockrouting.Server, d delay.D) Network { type network struct { mu sync.Mutex - clients map[peer.ID]bsnet.Receiver + clients map[peer.ID]*receiverQueue routingserver mockrouting.Server delay delay.D conns map[string]struct{} } +type message struct { + from peer.ID + msg bsmsg.BitSwapMessage + shouldSend time.Time +} + +// receiverQueue queues up a set of messages to be sent, and sends them *in +// order* with their delays respected as much as sending them in order allows +// for +type receiverQueue struct { + receiver bsnet.Receiver + queue []*message + active bool + lk sync.Mutex +} + func (n *network) Adapter(p testutil.Identity) bsnet.BitSwapNetwork { n.mu.Lock() defer n.mu.Unlock() @@ -46,7 +63,7 @@ func (n *network) Adapter(p testutil.Identity) bsnet.BitSwapNetwork { network: n, routing: n.routingserver.Client(p), } - n.clients[p.ID()] = client + n.clients[p.ID()] = &receiverQueue{receiver: client} return client } @@ -64,7 +81,7 @@ func (n *network) SendMessage( ctx context.Context, from peer.ID, to peer.ID, - message bsmsg.BitSwapMessage) error { + mes bsmsg.BitSwapMessage) error { n.mu.Lock() defer n.mu.Unlock() @@ -77,7 +94,12 @@ func (n *network) SendMessage( // nb: terminate the context since the context wouldn't actually be passed // over the network in a real scenario - go n.deliver(receiver, from, message) + msg := &message{ + from: from, + msg: mes, + shouldSend: time.Now().Add(n.delay.Get()), + } + receiver.enqueue(msg) return nil } @@ -191,11 +213,38 @@ func (nc *networkClient) ConnectTo(_ context.Context, p peer.ID) error { // TODO: add handling for disconnects - otherClient.PeerConnected(nc.local) + otherClient.receiver.PeerConnected(nc.local) nc.Receiver.PeerConnected(p) return nil } +func (rq *receiverQueue) enqueue(m *message) { + rq.lk.Lock() + defer rq.lk.Unlock() + rq.queue = append(rq.queue, m) + if !rq.active { + rq.active = true + go rq.process() + } +} + +func (rq *receiverQueue) process() { + for { + rq.lk.Lock() + if len(rq.queue) == 0 { + rq.active = false + rq.lk.Unlock() + return + } + m := rq.queue[0] + rq.queue = rq.queue[1:] + rq.lk.Unlock() + + time.Sleep(time.Until(m.shouldSend)) + rq.receiver.ReceiveMessage(context.TODO(), m.from, m.msg) + } +} + func tagForPeers(a, b peer.ID) string { if a < b { return string(a + b)