Commit 66bcda37 authored by Whyrusleeping's avatar Whyrusleeping Committed by GitHub

Merge pull request #4658 from ipfs/fix/session-cleanup

shutdown notifications engine when closing a bitswap session
parents c9341aeb 0df75e41
...@@ -50,7 +50,7 @@ type Exportable interface { ...@@ -50,7 +50,7 @@ type Exportable interface {
type impl struct { type impl struct {
full bool full bool
wantlist map[string]Entry wantlist map[string]*Entry
blocks map[string]blocks.Block blocks map[string]blocks.Block
} }
...@@ -61,7 +61,7 @@ func New(full bool) BitSwapMessage { ...@@ -61,7 +61,7 @@ func New(full bool) BitSwapMessage {
func newMsg(full bool) *impl { func newMsg(full bool) *impl {
return &impl{ return &impl{
blocks: make(map[string]blocks.Block), blocks: make(map[string]blocks.Block),
wantlist: make(map[string]Entry), wantlist: make(map[string]*Entry),
full: full, full: full,
} }
} }
...@@ -122,7 +122,7 @@ func (m *impl) Empty() bool { ...@@ -122,7 +122,7 @@ func (m *impl) Empty() bool {
func (m *impl) Wantlist() []Entry { func (m *impl) Wantlist() []Entry {
out := make([]Entry, 0, len(m.wantlist)) out := make([]Entry, 0, len(m.wantlist))
for _, e := range m.wantlist { for _, e := range m.wantlist {
out = append(out, e) out = append(out, *e)
} }
return out return out
} }
...@@ -151,7 +151,7 @@ func (m *impl) addEntry(c *cid.Cid, priority int, cancel bool) { ...@@ -151,7 +151,7 @@ func (m *impl) addEntry(c *cid.Cid, priority int, cancel bool) {
e.Priority = priority e.Priority = priority
e.Cancel = cancel e.Cancel = cancel
} else { } else {
m.wantlist[k] = Entry{ m.wantlist[k] = &Entry{
Entry: &wantlist.Entry{ Entry: &wantlist.Entry{
Cid: c, Cid: c,
Priority: priority, Priority: priority,
......
...@@ -2,6 +2,7 @@ package notifications ...@@ -2,6 +2,7 @@ package notifications
import ( import (
"context" "context"
"sync"
blocks "gx/ipfs/Qmej7nf81hi2x2tvjRBF3mcp74sQyuDH4VMYDGd1YtXjb2/go-block-format" blocks "gx/ipfs/Qmej7nf81hi2x2tvjRBF3mcp74sQyuDH4VMYDGd1YtXjb2/go-block-format"
...@@ -18,18 +19,43 @@ type PubSub interface { ...@@ -18,18 +19,43 @@ type PubSub interface {
} }
func New() PubSub { func New() PubSub {
return &impl{*pubsub.New(bufferSize)} return &impl{
wrapped: *pubsub.New(bufferSize),
cancel: make(chan struct{}),
}
} }
type impl struct { type impl struct {
wrapped pubsub.PubSub wrapped pubsub.PubSub
// These two fields make up a shutdown "lock".
// We need them as calling, e.g., `Unsubscribe` after calling `Shutdown`
// blocks forever and fixing this in pubsub would be rather invasive.
cancel chan struct{}
wg sync.WaitGroup
} }
func (ps *impl) Publish(block blocks.Block) { func (ps *impl) Publish(block blocks.Block) {
ps.wg.Add(1)
defer ps.wg.Done()
select {
case <-ps.cancel:
// Already shutdown, bail.
return
default:
}
ps.wrapped.Pub(block, block.Cid().KeyString()) ps.wrapped.Pub(block, block.Cid().KeyString())
} }
// Not safe to call more than once.
func (ps *impl) Shutdown() { func (ps *impl) Shutdown() {
// Interrupt in-progress subscriptions.
close(ps.cancel)
// Wait for them to finish.
ps.wg.Wait()
// shutdown the pubsub.
ps.wrapped.Shutdown() ps.wrapped.Shutdown()
} }
...@@ -44,12 +70,34 @@ func (ps *impl) Subscribe(ctx context.Context, keys ...*cid.Cid) <-chan blocks.B ...@@ -44,12 +70,34 @@ func (ps *impl) Subscribe(ctx context.Context, keys ...*cid.Cid) <-chan blocks.B
close(blocksCh) close(blocksCh)
return blocksCh return blocksCh
} }
// prevent shutdown
ps.wg.Add(1)
// check if shutdown *after* preventing shutdowns.
select {
case <-ps.cancel:
// abort, allow shutdown to continue.
ps.wg.Done()
close(blocksCh)
return blocksCh
default:
}
ps.wrapped.AddSubOnceEach(valuesCh, toStrings(keys)...) ps.wrapped.AddSubOnceEach(valuesCh, toStrings(keys)...)
go func() { go func() {
defer close(blocksCh) defer func() {
defer ps.wrapped.Unsub(valuesCh) // with a len(keys) buffer, this is an optimization ps.wrapped.Unsub(valuesCh)
close(blocksCh)
// Unblock shutdown.
ps.wg.Done()
}()
for { for {
select { select {
case <-ps.cancel:
return
case <-ctx.Done(): case <-ctx.Done():
return return
case val, ok := <-valuesCh: case val, ok := <-valuesCh:
...@@ -61,6 +109,8 @@ func (ps *impl) Subscribe(ctx context.Context, keys ...*cid.Cid) <-chan blocks.B ...@@ -61,6 +109,8 @@ func (ps *impl) Subscribe(ctx context.Context, keys ...*cid.Cid) <-chan blocks.B
return return
} }
select { select {
case <-ps.cancel:
return
case <-ctx.Done(): case <-ctx.Done():
return return
case blocksCh <- block: // continue case blocksCh <- block: // continue
......
...@@ -100,6 +100,25 @@ func TestDuplicateSubscribe(t *testing.T) { ...@@ -100,6 +100,25 @@ func TestDuplicateSubscribe(t *testing.T) {
assertBlocksEqual(t, e1, r2) assertBlocksEqual(t, e1, r2)
} }
func TestShutdownBeforeUnsubscribe(t *testing.T) {
e1 := blocks.NewBlock([]byte("1"))
n := New()
ctx, cancel := context.WithCancel(context.Background())
ch := n.Subscribe(ctx, e1.Cid()) // no keys provided
n.Shutdown()
cancel()
select {
case _, ok := <-ch:
if ok {
t.Fatal("channel should have been closed")
}
default:
t.Fatal("channel should have been closed")
}
}
func TestSubscribeIsANoopWhenCalledWithNoKeys(t *testing.T) { func TestSubscribeIsANoopWhenCalledWithNoKeys(t *testing.T) {
n := New() n := New()
defer n.Shutdown() defer n.Shutdown()
......
...@@ -83,6 +83,15 @@ func (bs *Bitswap) NewSession(ctx context.Context) *Session { ...@@ -83,6 +83,15 @@ func (bs *Bitswap) NewSession(ctx context.Context) *Session {
} }
func (bs *Bitswap) removeSession(s *Session) { func (bs *Bitswap) removeSession(s *Session) {
s.notif.Shutdown()
live := make([]*cid.Cid, 0, len(s.liveWants))
for c := range s.liveWants {
cs, _ := cid.Cast([]byte(c))
live = append(live, cs)
}
bs.CancelWants(live, s.id)
bs.sessLk.Lock() bs.sessLk.Lock()
defer bs.sessLk.Unlock() defer bs.sessLk.Unlock()
for i := 0; i < len(bs.sessions); i++ { for i := 0; i < len(bs.sessions); i++ {
...@@ -270,8 +279,9 @@ func (s *Session) receiveBlock(ctx context.Context, blk blocks.Block) { ...@@ -270,8 +279,9 @@ func (s *Session) receiveBlock(ctx context.Context, blk blocks.Block) {
} }
func (s *Session) wantBlocks(ctx context.Context, ks []*cid.Cid) { func (s *Session) wantBlocks(ctx context.Context, ks []*cid.Cid) {
now := time.Now()
for _, c := range ks { for _, c := range ks {
s.liveWants[c.KeyString()] = time.Now() s.liveWants[c.KeyString()] = now
} }
s.bs.wm.WantBlocks(ctx, ks, s.activePeersArr, s.id) s.bs.wm.WantBlocks(ctx, ks, s.activePeersArr, s.id)
} }
......
...@@ -285,3 +285,36 @@ func TestMultipleSessions(t *testing.T) { ...@@ -285,3 +285,36 @@ func TestMultipleSessions(t *testing.T) {
} }
_ = blkch _ = blkch
} }
func TestWantlistClearsOnCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
vnet := getVirtualNetwork()
sesgen := NewTestSessionGenerator(vnet)
defer sesgen.Close()
bgen := blocksutil.NewBlockGenerator()
blks := bgen.Blocks(10)
var cids []*cid.Cid
for _, blk := range blks {
cids = append(cids, blk.Cid())
}
inst := sesgen.Instances(1)
a := inst[0]
ctx1, cancel1 := context.WithCancel(ctx)
ses := a.Exchange.NewSession(ctx1)
_, err := ses.GetBlocks(ctx, cids)
if err != nil {
t.Fatal(err)
}
cancel1()
if len(a.Exchange.GetWantlist()) > 0 {
t.Fatal("expected empty wantlist")
}
}
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"sync" "sync"
"time"
bsmsg "github.com/ipfs/go-ipfs/exchange/bitswap/message" bsmsg "github.com/ipfs/go-ipfs/exchange/bitswap/message"
bsnet "github.com/ipfs/go-ipfs/exchange/bitswap/network" bsnet "github.com/ipfs/go-ipfs/exchange/bitswap/network"
...@@ -22,7 +23,7 @@ var log = logging.Logger("bstestnet") ...@@ -22,7 +23,7 @@ var log = logging.Logger("bstestnet")
func VirtualNetwork(rs mockrouting.Server, d delay.D) Network { func VirtualNetwork(rs mockrouting.Server, d delay.D) Network {
return &network{ return &network{
clients: make(map[peer.ID]bsnet.Receiver), clients: make(map[peer.ID]*receiverQueue),
delay: d, delay: d,
routingserver: rs, routingserver: rs,
conns: make(map[string]struct{}), conns: make(map[string]struct{}),
...@@ -31,12 +32,28 @@ func VirtualNetwork(rs mockrouting.Server, d delay.D) Network { ...@@ -31,12 +32,28 @@ func VirtualNetwork(rs mockrouting.Server, d delay.D) Network {
type network struct { type network struct {
mu sync.Mutex mu sync.Mutex
clients map[peer.ID]bsnet.Receiver clients map[peer.ID]*receiverQueue
routingserver mockrouting.Server routingserver mockrouting.Server
delay delay.D delay delay.D
conns map[string]struct{} conns map[string]struct{}
} }
type message struct {
from peer.ID
msg bsmsg.BitSwapMessage
shouldSend time.Time
}
// receiverQueue queues up a set of messages to be sent, and sends them *in
// order* with their delays respected as much as sending them in order allows
// for
type receiverQueue struct {
receiver bsnet.Receiver
queue []*message
active bool
lk sync.Mutex
}
func (n *network) Adapter(p testutil.Identity) bsnet.BitSwapNetwork { func (n *network) Adapter(p testutil.Identity) bsnet.BitSwapNetwork {
n.mu.Lock() n.mu.Lock()
defer n.mu.Unlock() defer n.mu.Unlock()
...@@ -46,7 +63,7 @@ func (n *network) Adapter(p testutil.Identity) bsnet.BitSwapNetwork { ...@@ -46,7 +63,7 @@ func (n *network) Adapter(p testutil.Identity) bsnet.BitSwapNetwork {
network: n, network: n,
routing: n.routingserver.Client(p), routing: n.routingserver.Client(p),
} }
n.clients[p.ID()] = client n.clients[p.ID()] = &receiverQueue{receiver: client}
return client return client
} }
...@@ -64,7 +81,7 @@ func (n *network) SendMessage( ...@@ -64,7 +81,7 @@ func (n *network) SendMessage(
ctx context.Context, ctx context.Context,
from peer.ID, from peer.ID,
to peer.ID, to peer.ID,
message bsmsg.BitSwapMessage) error { mes bsmsg.BitSwapMessage) error {
n.mu.Lock() n.mu.Lock()
defer n.mu.Unlock() defer n.mu.Unlock()
...@@ -77,7 +94,12 @@ func (n *network) SendMessage( ...@@ -77,7 +94,12 @@ func (n *network) SendMessage(
// nb: terminate the context since the context wouldn't actually be passed // nb: terminate the context since the context wouldn't actually be passed
// over the network in a real scenario // over the network in a real scenario
go n.deliver(receiver, from, message) msg := &message{
from: from,
msg: mes,
shouldSend: time.Now().Add(n.delay.Get()),
}
receiver.enqueue(msg)
return nil return nil
} }
...@@ -191,11 +213,38 @@ func (nc *networkClient) ConnectTo(_ context.Context, p peer.ID) error { ...@@ -191,11 +213,38 @@ func (nc *networkClient) ConnectTo(_ context.Context, p peer.ID) error {
// TODO: add handling for disconnects // TODO: add handling for disconnects
otherClient.PeerConnected(nc.local) otherClient.receiver.PeerConnected(nc.local)
nc.Receiver.PeerConnected(p) nc.Receiver.PeerConnected(p)
return nil return nil
} }
func (rq *receiverQueue) enqueue(m *message) {
rq.lk.Lock()
defer rq.lk.Unlock()
rq.queue = append(rq.queue, m)
if !rq.active {
rq.active = true
go rq.process()
}
}
func (rq *receiverQueue) process() {
for {
rq.lk.Lock()
if len(rq.queue) == 0 {
rq.active = false
rq.lk.Unlock()
return
}
m := rq.queue[0]
rq.queue = rq.queue[1:]
rq.lk.Unlock()
time.Sleep(time.Until(m.shouldSend))
rq.receiver.ReceiveMessage(context.TODO(), m.from, m.msg)
}
}
func tagForPeers(a, b peer.ID) string { func tagForPeers(a, b peer.ID) string {
if a < b { if a < b {
return string(a + b) return string(a + b)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment