Commit d4191c4d authored by hannahhoward's avatar hannahhoward

feat(peermanager): move refcnt

Move refcnt tracking from the messagequeue to the peermanager, where it's relevant
parent 434e0f41
...@@ -34,8 +34,6 @@ type MessageQueue struct { ...@@ -34,8 +34,6 @@ type MessageQueue struct {
sender bsnet.MessageSender sender bsnet.MessageSender
refcnt int
work chan struct{} work chan struct{}
done chan struct{} done chan struct{}
} }
...@@ -48,27 +46,9 @@ func New(p peer.ID, network MessageNetwork) *MessageQueue { ...@@ -48,27 +46,9 @@ func New(p peer.ID, network MessageNetwork) *MessageQueue {
wl: wantlist.NewThreadSafe(), wl: wantlist.NewThreadSafe(),
network: network, network: network,
p: p, p: p,
refcnt: 0,
} }
} }
// RefCount returns the number of open connections for this queue.
func (mq *MessageQueue) RefCount() int {
return mq.refcnt
}
// RefIncrement increments the refcount for a message queue.
func (mq *MessageQueue) RefIncrement() {
mq.refcnt++
}
// RefDecrement decrements the refcount for a message queue and returns true
// if the refcount is now 0.
func (mq *MessageQueue) RefDecrement() bool {
mq.refcnt--
return mq.refcnt > 0
}
// AddMessage adds new entries to an outgoing message for a given session. // AddMessage adds new entries to an outgoing message for a given session.
func (mq *MessageQueue) AddMessage(entries []*bsmsg.Entry, ses uint64) { func (mq *MessageQueue) AddMessage(entries []*bsmsg.Entry, ses uint64) {
if !mq.addEntries(entries, ses) { if !mq.addEntries(entries, ses) {
......
...@@ -19,9 +19,6 @@ var ( ...@@ -19,9 +19,6 @@ var (
// PeerQueue provides a queer of messages to be sent for a single peer. // PeerQueue provides a queer of messages to be sent for a single peer.
type PeerQueue interface { type PeerQueue interface {
RefIncrement()
RefDecrement() bool
RefCount() int
AddMessage(entries []*bsmsg.Entry, ses uint64) AddMessage(entries []*bsmsg.Entry, ses uint64)
Startup(ctx context.Context) Startup(ctx context.Context)
AddWantlist(initialEntries []*wantlist.Entry) AddWantlist(initialEntries []*wantlist.Entry)
...@@ -35,10 +32,15 @@ type peerMessage interface { ...@@ -35,10 +32,15 @@ type peerMessage interface {
handle(pm *PeerManager) handle(pm *PeerManager)
} }
type peerQueueInstance struct {
refcnt int
pq PeerQueue
}
// PeerManager manages a pool of peers and sends messages to peers in the pool. // PeerManager manages a pool of peers and sends messages to peers in the pool.
type PeerManager struct { type PeerManager struct {
// peerQueues -- interact through internal utility functions get/set/remove/iterate // peerQueues -- interact through internal utility functions get/set/remove/iterate
peerQueues map[peer.ID]PeerQueue peerQueues map[peer.ID]*peerQueueInstance
peerQueuesLk sync.RWMutex peerQueuesLk sync.RWMutex
createPeerQueue PeerQueueFactory createPeerQueue PeerQueueFactory
...@@ -48,7 +50,7 @@ type PeerManager struct { ...@@ -48,7 +50,7 @@ type PeerManager struct {
// New creates a new PeerManager, given a context and a peerQueueFactory. // New creates a new PeerManager, given a context and a peerQueueFactory.
func New(ctx context.Context, createPeerQueue PeerQueueFactory) *PeerManager { func New(ctx context.Context, createPeerQueue PeerQueueFactory) *PeerManager {
return &PeerManager{ return &PeerManager{
peerQueues: make(map[peer.ID]PeerQueue), peerQueues: make(map[peer.ID]*peerQueueInstance),
createPeerQueue: createPeerQueue, createPeerQueue: createPeerQueue,
ctx: ctx, ctx: ctx,
} }
...@@ -68,12 +70,17 @@ func (pm *PeerManager) ConnectedPeers() []peer.ID { ...@@ -68,12 +70,17 @@ func (pm *PeerManager) ConnectedPeers() []peer.ID {
// Connected is called to add a new peer to the pool, and send it an initial set // Connected is called to add a new peer to the pool, and send it an initial set
// of wants. // of wants.
func (pm *PeerManager) Connected(p peer.ID, initialEntries []*wantlist.Entry) { func (pm *PeerManager) Connected(p peer.ID, initialEntries []*wantlist.Entry) {
mq := pm.getOrCreate(p) pm.peerQueuesLk.Lock()
pq := pm.getOrCreate(p)
if mq.RefCount() == 0 { if pq.refcnt == 0 {
mq.AddWantlist(initialEntries) pq.pq.AddWantlist(initialEntries)
} }
mq.RefIncrement()
pq.refcnt++
pm.peerQueuesLk.Unlock()
} }
// Disconnected is called to remove a peer from the pool. // Disconnected is called to remove a peer from the pool.
...@@ -81,7 +88,13 @@ func (pm *PeerManager) Disconnected(p peer.ID) { ...@@ -81,7 +88,13 @@ func (pm *PeerManager) Disconnected(p peer.ID) {
pm.peerQueuesLk.Lock() pm.peerQueuesLk.Lock()
pq, ok := pm.peerQueues[p] pq, ok := pm.peerQueues[p]
if !ok || pq.RefDecrement() { if !ok {
pm.peerQueuesLk.Unlock()
return
}
pq.refcnt--
if pq.refcnt > 0 {
pm.peerQueuesLk.Unlock() pm.peerQueuesLk.Unlock()
return return
} }
...@@ -89,7 +102,7 @@ func (pm *PeerManager) Disconnected(p peer.ID) { ...@@ -89,7 +102,7 @@ func (pm *PeerManager) Disconnected(p peer.ID) {
delete(pm.peerQueues, p) delete(pm.peerQueues, p)
pm.peerQueuesLk.Unlock() pm.peerQueuesLk.Unlock()
pq.Shutdown() pq.pq.Shutdown()
} }
...@@ -99,25 +112,26 @@ func (pm *PeerManager) SendMessage(entries []*bsmsg.Entry, targets []peer.ID, fr ...@@ -99,25 +112,26 @@ func (pm *PeerManager) SendMessage(entries []*bsmsg.Entry, targets []peer.ID, fr
if len(targets) == 0 { if len(targets) == 0 {
pm.peerQueuesLk.RLock() pm.peerQueuesLk.RLock()
for _, p := range pm.peerQueues { for _, p := range pm.peerQueues {
p.AddMessage(entries, from) p.pq.AddMessage(entries, from)
} }
pm.peerQueuesLk.RUnlock() pm.peerQueuesLk.RUnlock()
} else { } else {
for _, t := range targets { for _, t := range targets {
p := pm.getOrCreate(t) pm.peerQueuesLk.Lock()
p.AddMessage(entries, from) pqi := pm.getOrCreate(t)
pm.peerQueuesLk.Unlock()
pqi.pq.AddMessage(entries, from)
} }
} }
} }
func (pm *PeerManager) getOrCreate(p peer.ID) PeerQueue { func (pm *PeerManager) getOrCreate(p peer.ID) *peerQueueInstance {
pm.peerQueuesLk.Lock() pqi, ok := pm.peerQueues[p]
pq, ok := pm.peerQueues[p]
if !ok { if !ok {
pq = pm.createPeerQueue(p) pq := pm.createPeerQueue(p)
pq.Startup(pm.ctx) pq.Startup(pm.ctx)
pm.peerQueues[p] = pq pqi = &peerQueueInstance{0, pq}
pm.peerQueues[p] = pqi
} }
pm.peerQueuesLk.Unlock() return pqi
return pq
} }
...@@ -20,19 +20,13 @@ type messageSent struct { ...@@ -20,19 +20,13 @@ type messageSent struct {
} }
type fakePeer struct { type fakePeer struct {
refcnt int
p peer.ID p peer.ID
messagesSent chan messageSent messagesSent chan messageSent
} }
func (fp *fakePeer) Startup(ctx context.Context) {} func (fp *fakePeer) Startup(ctx context.Context) {}
func (fp *fakePeer) Shutdown() {} func (fp *fakePeer) Shutdown() {}
func (fp *fakePeer) RefCount() int { return fp.refcnt }
func (fp *fakePeer) RefIncrement() { fp.refcnt++ }
func (fp *fakePeer) RefDecrement() bool {
fp.refcnt--
return fp.refcnt > 0
}
func (fp *fakePeer) AddMessage(entries []*bsmsg.Entry, ses uint64) { func (fp *fakePeer) AddMessage(entries []*bsmsg.Entry, ses uint64) {
fp.messagesSent <- messageSent{fp.p, entries, ses} fp.messagesSent <- messageSent{fp.p, entries, ses}
} }
...@@ -41,7 +35,6 @@ func makePeerQueueFactory(messagesSent chan messageSent) PeerQueueFactory { ...@@ -41,7 +35,6 @@ func makePeerQueueFactory(messagesSent chan messageSent) PeerQueueFactory {
return func(p peer.ID) PeerQueue { return func(p peer.ID) PeerQueue {
return &fakePeer{ return &fakePeer{
p: p, p: p,
refcnt: 0,
messagesSent: messagesSent, messagesSent: messagesSent,
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment