Unverified Commit 472a8ab9 authored by Steven Allen's avatar Steven Allen Committed by GitHub

Merge pull request #72 from ipfs/bugs/racy-wantlist-handling

fix(wantlist): remove races on setup
parents 722239f1 d4191c4d
......@@ -132,7 +132,6 @@ func New(parent context.Context, network bsnet.BitSwapNetwork,
}
bs.wm.SetDelegate(bs.pm)
bs.pm.Startup()
bs.wm.Startup()
bs.pqm.Startup()
network.SetDelegate(bs)
......@@ -361,14 +360,13 @@ func (bs *Bitswap) updateReceiveCounters(b blocks.Block) {
// Connected/Disconnected warns bitswap about peer connections.
func (bs *Bitswap) PeerConnected(p peer.ID) {
initialWants := bs.wm.CurrentBroadcastWants()
bs.pm.Connected(p, initialWants)
bs.wm.Connected(p)
bs.engine.PeerConnected(p)
}
// Connected/Disconnected warns bitswap about peer connections.
func (bs *Bitswap) PeerDisconnected(p peer.ID) {
bs.pm.Disconnected(p)
bs.wm.Disconnected(p)
bs.engine.PeerDisconnected(p)
}
......
......@@ -14,6 +14,8 @@ import (
var log = logging.Logger("bitswap")
const maxRetries = 10
// MessageNetwork is any network that can connect peers and generate a message
// sender.
type MessageNetwork interface {
......@@ -32,8 +34,6 @@ type MessageQueue struct {
sender bsnet.MessageSender
refcnt int
work chan struct{}
done chan struct{}
}
......@@ -46,22 +46,9 @@ func New(p peer.ID, network MessageNetwork) *MessageQueue {
wl: wantlist.NewThreadSafe(),
network: network,
p: p,
refcnt: 1,
}
}
// RefIncrement increments the refcount for a message queue.
func (mq *MessageQueue) RefIncrement() {
mq.refcnt++
}
// RefDecrement decrements the refcount for a message queue and returns true
// if the refcount is now 0.
func (mq *MessageQueue) RefDecrement() bool {
mq.refcnt--
return mq.refcnt > 0
}
// AddMessage adds new entries to an outgoing message for a given session.
func (mq *MessageQueue) AddMessage(entries []*bsmsg.Entry, ses uint64) {
if !mq.addEntries(entries, ses) {
......@@ -73,24 +60,31 @@ func (mq *MessageQueue) AddMessage(entries []*bsmsg.Entry, ses uint64) {
}
}
// Startup starts the processing of messages, and creates an initial message
// based on the given initial wantlist.
func (mq *MessageQueue) Startup(ctx context.Context, initialEntries []*wantlist.Entry) {
// new peer, we will want to give them our full wantlist
// AddWantlist adds a complete session tracked want list to a message queue
func (mq *MessageQueue) AddWantlist(initialEntries []*wantlist.Entry) {
if len(initialEntries) > 0 {
fullwantlist := bsmsg.New(true)
if mq.out == nil {
mq.out = bsmsg.New(false)
}
for _, e := range initialEntries {
for k := range e.SesTrk {
mq.wl.AddEntry(e, k)
}
fullwantlist.AddEntry(e.Cid, e.Priority)
mq.out.AddEntry(e.Cid, e.Priority)
}
select {
case mq.work <- struct{}{}:
default:
}
mq.out = fullwantlist
mq.work <- struct{}{}
}
go mq.runQueue(ctx)
}
// Startup starts the processing of messages, and creates an initial message
// based on the given initial wantlist.
func (mq *MessageQueue) Startup(ctx context.Context) {
go mq.runQueue(ctx)
}
// Shutdown stops the processing of messages for a message queue.
......@@ -162,7 +156,7 @@ func (mq *MessageQueue) doWork(ctx context.Context) {
}
// send wantlist updates
for { // try to send this message until we fail.
for i := 0; i < maxRetries; i++ { // try to send this message until we fail.
if mq.attemptSendAndRecovery(ctx, wlm) {
return
}
......
......@@ -25,9 +25,9 @@ func (fmn *fakeMessageNetwork) ConnectTo(context.Context, peer.ID) error {
func (fmn *fakeMessageNetwork) NewMessageSender(context.Context, peer.ID) (bsnet.MessageSender, error) {
if fmn.messageSenderError == nil {
return fmn.messageSender, nil
} else {
return nil, fmn.messageSenderError
}
return nil, fmn.messageSenderError
}
type fakeMessageSender struct {
......@@ -81,8 +81,8 @@ func TestStartupAndShutdown(t *testing.T) {
ses := testutil.GenerateSessionID()
wl := testutil.GenerateWantlist(10, ses)
messageQueue.Startup(ctx, wl.Entries())
messageQueue.Startup(ctx)
messageQueue.AddWantlist(wl.Entries())
messages := collectMessages(ctx, t, messagesSent, 10*time.Millisecond)
if len(messages) != 1 {
t.Fatal("wrong number of messages were sent for initial wants")
......@@ -123,7 +123,7 @@ func TestSendingMessagesDeduped(t *testing.T) {
ses1 := testutil.GenerateSessionID()
ses2 := testutil.GenerateSessionID()
entries := testutil.GenerateMessageEntries(10, false)
messageQueue.Startup(ctx, nil)
messageQueue.Startup(ctx)
messageQueue.AddMessage(entries, ses1)
messageQueue.AddMessage(entries, ses2)
......@@ -148,7 +148,7 @@ func TestSendingMessagesPartialDupe(t *testing.T) {
entries := testutil.GenerateMessageEntries(10, false)
moreEntries := testutil.GenerateMessageEntries(5, false)
secondEntries := append(entries[5:], moreEntries...)
messageQueue.Startup(ctx, nil)
messageQueue.Startup(ctx)
messageQueue.AddMessage(entries, ses1)
messageQueue.AddMessage(secondEntries, ses2)
......
......@@ -2,6 +2,7 @@ package peermanager
import (
"context"
"sync"
bsmsg "github.com/ipfs/go-bitswap/message"
wantlist "github.com/ipfs/go-bitswap/wantlist"
......@@ -18,10 +19,9 @@ var (
// PeerQueue provides a queer of messages to be sent for a single peer.
type PeerQueue interface {
RefIncrement()
RefDecrement() bool
AddMessage(entries []*bsmsg.Entry, ses uint64)
Startup(ctx context.Context, initialEntries []*wantlist.Entry)
Startup(ctx context.Context)
AddWantlist(initialEntries []*wantlist.Entry)
Shutdown()
}
......@@ -32,179 +32,106 @@ type peerMessage interface {
handle(pm *PeerManager)
}
type peerQueueInstance struct {
refcnt int
pq PeerQueue
}
// PeerManager manages a pool of peers and sends messages to peers in the pool.
type PeerManager struct {
// sync channel for Run loop
peerMessages chan peerMessage
// synchronized by Run loop, only touch inside there
peerQueues map[peer.ID]PeerQueue
// peerQueues -- interact through internal utility functions get/set/remove/iterate
peerQueues map[peer.ID]*peerQueueInstance
peerQueuesLk sync.RWMutex
createPeerQueue PeerQueueFactory
ctx context.Context
cancel func()
}
// New creates a new PeerManager, given a context and a peerQueueFactory.
func New(ctx context.Context, createPeerQueue PeerQueueFactory) *PeerManager {
ctx, cancel := context.WithCancel(ctx)
return &PeerManager{
peerMessages: make(chan peerMessage, 10),
peerQueues: make(map[peer.ID]PeerQueue),
peerQueues: make(map[peer.ID]*peerQueueInstance),
createPeerQueue: createPeerQueue,
ctx: ctx,
cancel: cancel,
}
}
// ConnectedPeers returns a list of peers this PeerManager is managing.
func (pm *PeerManager) ConnectedPeers() []peer.ID {
resp := make(chan []peer.ID, 1)
select {
case pm.peerMessages <- &getPeersMessage{resp}:
case <-pm.ctx.Done():
return nil
}
select {
case peers := <-resp:
return peers
case <-pm.ctx.Done():
return nil
pm.peerQueuesLk.RLock()
defer pm.peerQueuesLk.RUnlock()
peers := make([]peer.ID, 0, len(pm.peerQueues))
for p := range pm.peerQueues {
peers = append(peers, p)
}
return peers
}
// Connected is called to add a new peer to the pool, and send it an initial set
// of wants.
func (pm *PeerManager) Connected(p peer.ID, initialEntries []*wantlist.Entry) {
select {
case pm.peerMessages <- &connectPeerMessage{p, initialEntries}:
case <-pm.ctx.Done():
}
}
// Disconnected is called to remove a peer from the pool.
func (pm *PeerManager) Disconnected(p peer.ID) {
select {
case pm.peerMessages <- &disconnectPeerMessage{p}:
case <-pm.ctx.Done():
}
}
// SendMessage is called to send a message to all or some peers in the pool;
// if targets is nil, it sends to all.
func (pm *PeerManager) SendMessage(entries []*bsmsg.Entry, targets []peer.ID, from uint64) {
select {
case pm.peerMessages <- &sendPeerMessage{entries: entries, targets: targets, from: from}:
case <-pm.ctx.Done():
}
}
// Startup enables the run loop for the PeerManager - no processing will occur
// if startup is not called.
func (pm *PeerManager) Startup() {
go pm.run()
}
// Shutdown shutsdown processing for the PeerManager.
func (pm *PeerManager) Shutdown() {
pm.cancel()
}
func (pm *PeerManager) run() {
for {
select {
case message := <-pm.peerMessages:
message.handle(pm)
case <-pm.ctx.Done():
return
}
}
}
type sendPeerMessage struct {
entries []*bsmsg.Entry
targets []peer.ID
from uint64
}
pm.peerQueuesLk.Lock()
func (s *sendPeerMessage) handle(pm *PeerManager) {
pm.sendMessage(s)
}
type connectPeerMessage struct {
p peer.ID
initialEntries []*wantlist.Entry
}
func (c *connectPeerMessage) handle(pm *PeerManager) {
pm.startPeerHandler(c.p, c.initialEntries)
}
type disconnectPeerMessage struct {
p peer.ID
}
pq := pm.getOrCreate(p)
func (dc *disconnectPeerMessage) handle(pm *PeerManager) {
pm.stopPeerHandler(dc.p)
}
type getPeersMessage struct {
peerResp chan<- []peer.ID
}
func (gp *getPeersMessage) handle(pm *PeerManager) {
pm.getPeers(gp.peerResp)
}
func (pm *PeerManager) getPeers(peerResp chan<- []peer.ID) {
peers := make([]peer.ID, 0, len(pm.peerQueues))
for p := range pm.peerQueues {
peers = append(peers, p)
if pq.refcnt == 0 {
pq.pq.AddWantlist(initialEntries)
}
peerResp <- peers
}
func (pm *PeerManager) startPeerHandler(p peer.ID, initialEntries []*wantlist.Entry) PeerQueue {
mq, ok := pm.peerQueues[p]
if ok {
mq.RefIncrement()
return nil
}
pq.refcnt++
mq = pm.createPeerQueue(p)
pm.peerQueues[p] = mq
mq.Startup(pm.ctx, initialEntries)
return mq
pm.peerQueuesLk.Unlock()
}
func (pm *PeerManager) stopPeerHandler(p peer.ID) {
// Disconnected is called to remove a peer from the pool.
func (pm *PeerManager) Disconnected(p peer.ID) {
pm.peerQueuesLk.Lock()
pq, ok := pm.peerQueues[p]
if !ok {
// TODO: log error?
pm.peerQueuesLk.Unlock()
return
}
if pq.RefDecrement() {
pq.refcnt--
if pq.refcnt > 0 {
pm.peerQueuesLk.Unlock()
return
}
pq.Shutdown()
delete(pm.peerQueues, p)
pm.peerQueuesLk.Unlock()
pq.pq.Shutdown()
}
func (pm *PeerManager) sendMessage(ms *sendPeerMessage) {
if len(ms.targets) == 0 {
// SendMessage is called to send a message to all or some peers in the pool;
// if targets is nil, it sends to all.
func (pm *PeerManager) SendMessage(entries []*bsmsg.Entry, targets []peer.ID, from uint64) {
if len(targets) == 0 {
pm.peerQueuesLk.RLock()
for _, p := range pm.peerQueues {
p.AddMessage(ms.entries, ms.from)
p.pq.AddMessage(entries, from)
}
pm.peerQueuesLk.RUnlock()
} else {
for _, t := range ms.targets {
p, ok := pm.peerQueues[t]
if !ok {
log.Infof("tried sending wantlist change to non-partner peer: %s", t)
continue
}
p.AddMessage(ms.entries, ms.from)
for _, t := range targets {
pm.peerQueuesLk.Lock()
pqi := pm.getOrCreate(t)
pm.peerQueuesLk.Unlock()
pqi.pq.AddMessage(entries, from)
}
}
}
func (pm *PeerManager) getOrCreate(p peer.ID) *peerQueueInstance {
pqi, ok := pm.peerQueues[p]
if !ok {
pq := pm.createPeerQueue(p)
pq.Startup(pm.ctx)
pqi = &peerQueueInstance{0, pq}
pm.peerQueues[p] = pqi
}
return pqi
}
......@@ -20,27 +20,21 @@ type messageSent struct {
}
type fakePeer struct {
refcnt int
p peer.ID
messagesSent chan messageSent
}
func (fp *fakePeer) Startup(ctx context.Context, initialEntries []*wantlist.Entry) {}
func (fp *fakePeer) Shutdown() {}
func (fp *fakePeer) RefIncrement() { fp.refcnt++ }
func (fp *fakePeer) RefDecrement() bool {
fp.refcnt--
return fp.refcnt > 0
}
func (fp *fakePeer) Startup(ctx context.Context) {}
func (fp *fakePeer) Shutdown() {}
func (fp *fakePeer) AddMessage(entries []*bsmsg.Entry, ses uint64) {
fp.messagesSent <- messageSent{fp.p, entries, ses}
}
func (fp *fakePeer) AddWantlist(initialEntries []*wantlist.Entry) {}
func makePeerQueueFactory(messagesSent chan messageSent) PeerQueueFactory {
return func(p peer.ID) PeerQueue {
return &fakePeer{
p: p,
refcnt: 1,
messagesSent: messagesSent,
}
}
......@@ -79,7 +73,6 @@ func TestAddingAndRemovingPeers(t *testing.T) {
tp := testutil.GeneratePeers(5)
peer1, peer2, peer3, peer4, peer5 := tp[0], tp[1], tp[2], tp[3], tp[4]
peerManager := New(ctx, peerQueueFactory)
peerManager.Startup()
peerManager.Connected(peer1, nil)
peerManager.Connected(peer2, nil)
......@@ -118,14 +111,13 @@ func TestAddingAndRemovingPeers(t *testing.T) {
func TestSendingMessagesToPeers(t *testing.T) {
ctx := context.Background()
messagesSent := make(chan messageSent)
messagesSent := make(chan messageSent, 16)
peerQueueFactory := makePeerQueueFactory(messagesSent)
tp := testutil.GeneratePeers(5)
peer1, peer2, peer3, peer4, peer5 := tp[0], tp[1], tp[2], tp[3], tp[4]
peerManager := New(ctx, peerQueueFactory)
peerManager.Startup()
peerManager.Connected(peer1, nil)
peerManager.Connected(peer2, nil)
......@@ -159,7 +151,7 @@ func TestSendingMessagesToPeers(t *testing.T) {
peersReceived = collectAndCheckMessages(
ctx, t, messagesSent, entries, ses, 10*time.Millisecond)
if len(peersReceived) != 2 {
if len(peersReceived) != 3 {
t.Fatal("Incorrect number of peers received messages")
}
......@@ -173,7 +165,7 @@ func TestSendingMessagesToPeers(t *testing.T) {
t.Fatal("Peers received message but should not have")
}
if testutil.ContainsPeer(peersReceived, peer4) {
t.Fatal("Peers targeted received message but was not connected")
if !testutil.ContainsPeer(peersReceived, peer4) {
t.Fatal("Peer should have autoconnected on message send")
}
}
......@@ -20,9 +20,11 @@ const (
maxPriority = math.MaxInt32
)
// WantSender sends changes out to the network as they get added to the wantlist
// PeerHandler sends changes out to the network as they get added to the wantlist
// managed by the WantManager.
type WantSender interface {
type PeerHandler interface {
Disconnected(p peer.ID)
Connected(p peer.ID, initialEntries []*wantlist.Entry)
SendMessage(entries []*bsmsg.Entry, targets []peer.ID, from uint64)
}
......@@ -46,7 +48,7 @@ type WantManager struct {
ctx context.Context
cancel func()
wantSender WantSender
peerHandler PeerHandler
wantlistGauge metrics.Gauge
}
......@@ -66,8 +68,8 @@ func New(ctx context.Context) *WantManager {
}
// SetDelegate specifies who will send want changes out to the internet.
func (wm *WantManager) SetDelegate(wantSender WantSender) {
wm.wantSender = wantSender
func (wm *WantManager) SetDelegate(peerHandler PeerHandler) {
wm.peerHandler = peerHandler
}
// WantBlocks adds the given cids to the wantlist, tracked by the given session.
......@@ -145,6 +147,22 @@ func (wm *WantManager) WantCount() int {
}
}
// Connected is called when a new peer is connected
func (wm *WantManager) Connected(p peer.ID) {
select {
case wm.wantMessages <- &connectedMessage{p}:
case <-wm.ctx.Done():
}
}
// Disconnected is called when a peer is disconnected
func (wm *WantManager) Disconnected(p peer.ID) {
select {
case wm.wantMessages <- &disconnectedMessage{p}:
case <-wm.ctx.Done():
}
}
// Startup starts processing for the WantManager.
func (wm *WantManager) Startup() {
go wm.run()
......@@ -214,7 +232,7 @@ func (ws *wantSet) handle(wm *WantManager) {
}
// broadcast those wantlist changes
wm.wantSender.SendMessage(ws.entries, ws.targets, ws.from)
wm.peerHandler.SendMessage(ws.entries, ws.targets, ws.from)
}
type isWantedMessage struct {
......@@ -250,3 +268,19 @@ type wantCountMessage struct {
func (wcm *wantCountMessage) handle(wm *WantManager) {
wcm.resp <- wm.wl.Len()
}
type connectedMessage struct {
p peer.ID
}
func (cm *connectedMessage) handle(wm *WantManager) {
wm.peerHandler.Connected(cm.p, wm.bcwl.Entries())
}
type disconnectedMessage struct {
p peer.ID
}
func (dm *disconnectedMessage) handle(wm *WantManager) {
wm.peerHandler.Disconnected(dm.p)
}
......@@ -7,35 +7,39 @@ import (
"testing"
"github.com/ipfs/go-bitswap/testutil"
wantlist "github.com/ipfs/go-bitswap/wantlist"
bsmsg "github.com/ipfs/go-bitswap/message"
"github.com/ipfs/go-cid"
"github.com/libp2p/go-libp2p-peer"
)
type fakeWantSender struct {
type fakePeerHandler struct {
lk sync.RWMutex
lastWantSet wantSet
}
func (fws *fakeWantSender) SendMessage(entries []*bsmsg.Entry, targets []peer.ID, from uint64) {
fws.lk.Lock()
fws.lastWantSet = wantSet{entries, targets, from}
fws.lk.Unlock()
func (fph *fakePeerHandler) SendMessage(entries []*bsmsg.Entry, targets []peer.ID, from uint64) {
fph.lk.Lock()
fph.lastWantSet = wantSet{entries, targets, from}
fph.lk.Unlock()
}
func (fws *fakeWantSender) getLastWantSet() wantSet {
fws.lk.Lock()
defer fws.lk.Unlock()
return fws.lastWantSet
func (fph *fakePeerHandler) Connected(p peer.ID, initialEntries []*wantlist.Entry) {}
func (fph *fakePeerHandler) Disconnected(p peer.ID) {}
func (fph *fakePeerHandler) getLastWantSet() wantSet {
fph.lk.Lock()
defer fph.lk.Unlock()
return fph.lastWantSet
}
func setupTestFixturesAndInitialWantList() (
context.Context, *fakeWantSender, *WantManager, []cid.Cid, []cid.Cid, []peer.ID, uint64, uint64) {
context.Context, *fakePeerHandler, *WantManager, []cid.Cid, []cid.Cid, []peer.ID, uint64, uint64) {
ctx := context.Background()
// setup fixtures
wantSender := &fakeWantSender{}
wantSender := &fakePeerHandler{}
wantManager := New(ctx)
keys := testutil.GenerateCids(10)
otherKeys := testutil.GenerateCids(5)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment