Commit 78386f0e authored by hannahhoward's avatar hannahhoward

feat(wantlist): differentiate types

Seperate want list into differentiated types - session tracking and regular

fix #13
parent 472a8ab9
...@@ -97,8 +97,8 @@ func New(parent context.Context, network bsnet.BitSwapNetwork, ...@@ -97,8 +97,8 @@ func New(parent context.Context, network bsnet.BitSwapNetwork,
return nil return nil
}) })
peerQueueFactory := func(p peer.ID) bspm.PeerQueue { peerQueueFactory := func(ctx context.Context, p peer.ID) bspm.PeerQueue {
return bsmq.New(p, network) return bsmq.New(ctx, p, network)
} }
wm := bswm.New(ctx) wm := bswm.New(ctx)
......
...@@ -2,7 +2,6 @@ package messagequeue ...@@ -2,7 +2,6 @@ package messagequeue
import ( import (
"context" "context"
"sync"
"time" "time"
bsmsg "github.com/ipfs/go-bitswap/message" bsmsg "github.com/ipfs/go-bitswap/message"
...@@ -23,68 +22,72 @@ type MessageNetwork interface { ...@@ -23,68 +22,72 @@ type MessageNetwork interface {
NewMessageSender(context.Context, peer.ID) (bsnet.MessageSender, error) NewMessageSender(context.Context, peer.ID) (bsnet.MessageSender, error)
} }
type request interface {
handle(mq *MessageQueue)
}
// MessageQueue implements queue of want messages to send to peers. // MessageQueue implements queue of want messages to send to peers.
type MessageQueue struct { type MessageQueue struct {
p peer.ID ctx context.Context
p peer.ID
outlk sync.Mutex
out bsmsg.BitSwapMessage
network MessageNetwork network MessageNetwork
wl *wantlist.ThreadSafe
sender bsnet.MessageSender newRequests chan request
outgoingMessages chan bsmsg.BitSwapMessage
done chan struct{}
// do not touch out of run loop
wl *wantlist.SessionTrackedWantlist
nextMessage bsmsg.BitSwapMessage
sender bsnet.MessageSender
}
type messageRequest struct {
entries []*bsmsg.Entry
ses uint64
}
work chan struct{} type wantlistRequest struct {
done chan struct{} wl *wantlist.SessionTrackedWantlist
} }
// New creats a new MessageQueue. // New creats a new MessageQueue.
func New(p peer.ID, network MessageNetwork) *MessageQueue { func New(ctx context.Context, p peer.ID, network MessageNetwork) *MessageQueue {
return &MessageQueue{ return &MessageQueue{
done: make(chan struct{}), ctx: ctx,
work: make(chan struct{}, 1), wl: wantlist.NewSessionTrackedWantlist(),
wl: wantlist.NewThreadSafe(), network: network,
network: network, p: p,
p: p, newRequests: make(chan request, 16),
outgoingMessages: make(chan bsmsg.BitSwapMessage),
done: make(chan struct{}),
} }
} }
// AddMessage adds new entries to an outgoing message for a given session. // AddMessage adds new entries to an outgoing message for a given session.
func (mq *MessageQueue) AddMessage(entries []*bsmsg.Entry, ses uint64) { func (mq *MessageQueue) AddMessage(entries []*bsmsg.Entry, ses uint64) {
if !mq.addEntries(entries, ses) {
return
}
select { select {
case mq.work <- struct{}{}: case mq.newRequests <- &messageRequest{entries, ses}:
default: case <-mq.ctx.Done():
} }
} }
// AddWantlist adds a complete session tracked want list to a message queue // AddWantlist adds a complete session tracked want list to a message queue
func (mq *MessageQueue) AddWantlist(initialEntries []*wantlist.Entry) { func (mq *MessageQueue) AddWantlist(initialWants *wantlist.SessionTrackedWantlist) {
if len(initialEntries) > 0 { wl := wantlist.NewSessionTrackedWantlist()
if mq.out == nil { initialWants.CopyWants(wl)
mq.out = bsmsg.New(false)
}
for _, e := range initialEntries { select {
for k := range e.SesTrk { case mq.newRequests <- &wantlistRequest{wl}:
mq.wl.AddEntry(e, k) case <-mq.ctx.Done():
}
mq.out.AddEntry(e.Cid, e.Priority)
}
select {
case mq.work <- struct{}{}:
default:
}
} }
} }
// Startup starts the processing of messages, and creates an initial message // Startup starts the processing of messages, and creates an initial message
// based on the given initial wantlist. // based on the given initial wantlist.
func (mq *MessageQueue) Startup(ctx context.Context) { func (mq *MessageQueue) Startup() {
go mq.runQueue(ctx) go mq.runQueue()
go mq.sendMessages()
} }
// Shutdown stops the processing of messages for a message queue. // Shutdown stops the processing of messages for a message queue.
...@@ -92,17 +95,26 @@ func (mq *MessageQueue) Shutdown() { ...@@ -92,17 +95,26 @@ func (mq *MessageQueue) Shutdown() {
close(mq.done) close(mq.done)
} }
func (mq *MessageQueue) runQueue(ctx context.Context) { func (mq *MessageQueue) runQueue() {
outgoingMessages := func() chan bsmsg.BitSwapMessage {
if mq.nextMessage == nil {
return nil
}
return mq.outgoingMessages
}
for { for {
select { select {
case <-mq.work: // there is work to be done case newRequest := <-mq.newRequests:
mq.doWork(ctx) newRequest.handle(mq)
case outgoingMessages() <- mq.nextMessage:
mq.nextMessage = nil
case <-mq.done: case <-mq.done:
if mq.sender != nil { if mq.sender != nil {
mq.sender.Close() mq.sender.Close()
} }
return return
case <-ctx.Done(): case <-mq.ctx.Done():
if mq.sender != nil { if mq.sender != nil {
mq.sender.Reset() mq.sender.Reset()
} }
...@@ -111,63 +123,77 @@ func (mq *MessageQueue) runQueue(ctx context.Context) { ...@@ -111,63 +123,77 @@ func (mq *MessageQueue) runQueue(ctx context.Context) {
} }
} }
func (mq *MessageQueue) addEntries(entries []*bsmsg.Entry, ses uint64) bool { func (mr *messageRequest) handle(mq *MessageQueue) {
var work bool mq.addEntries(mr.entries, mr.ses)
mq.outlk.Lock() }
defer mq.outlk.Unlock()
// if we have no message held allocate a new one func (wr *wantlistRequest) handle(mq *MessageQueue) {
if mq.out == nil { initialWants := wr.wl
mq.out = bsmsg.New(false) initialWants.CopyWants(mq.wl)
if initialWants.Len() > 0 {
if mq.nextMessage == nil {
mq.nextMessage = bsmsg.New(false)
}
for _, e := range initialWants.Entries() {
mq.nextMessage.AddEntry(e.Cid, e.Priority)
}
} }
}
// TODO: add a msg.Combine(...) method func (mq *MessageQueue) addEntries(entries []*bsmsg.Entry, ses uint64) {
// otherwise, combine the one we are holding with the
// one passed in
for _, e := range entries { for _, e := range entries {
if e.Cancel { if e.Cancel {
if mq.wl.Remove(e.Cid, ses) { if mq.wl.Remove(e.Cid, ses) {
work = true if mq.nextMessage == nil {
mq.out.Cancel(e.Cid) mq.nextMessage = bsmsg.New(false)
}
mq.nextMessage.Cancel(e.Cid)
} }
} else { } else {
if mq.wl.Add(e.Cid, e.Priority, ses) { if mq.wl.Add(e.Cid, e.Priority, ses) {
work = true if mq.nextMessage == nil {
mq.out.AddEntry(e.Cid, e.Priority) mq.nextMessage = bsmsg.New(false)
}
mq.nextMessage.AddEntry(e.Cid, e.Priority)
} }
} }
} }
return work
} }
func (mq *MessageQueue) doWork(ctx context.Context) { func (mq *MessageQueue) sendMessages() {
for {
wlm := mq.extractOutgoingMessage() select {
if wlm == nil || wlm.Empty() { case nextMessage := <-mq.outgoingMessages:
return mq.sendMessage(nextMessage)
case <-mq.done:
return
case <-mq.ctx.Done():
return
}
} }
}
func (mq *MessageQueue) sendMessage(message bsmsg.BitSwapMessage) {
// NB: only open a stream if we actually have data to send err := mq.initializeSender()
err := mq.initializeSender(ctx)
if err != nil { if err != nil {
log.Infof("cant open message sender to peer %s: %s", mq.p, err) log.Infof("cant open message sender to peer %s: %s", mq.p, err)
// TODO: cant connect, what now? // TODO: cant connect, what now?
return return
} }
// send wantlist updates
for i := 0; i < maxRetries; i++ { // try to send this message until we fail. for i := 0; i < maxRetries; i++ { // try to send this message until we fail.
if mq.attemptSendAndRecovery(ctx, wlm) { if mq.attemptSendAndRecovery(message) {
return return
} }
} }
} }
func (mq *MessageQueue) initializeSender(ctx context.Context) error { func (mq *MessageQueue) initializeSender() error {
if mq.sender != nil { if mq.sender != nil {
return nil return nil
} }
nsender, err := openSender(ctx, mq.network, mq.p) nsender, err := openSender(mq.ctx, mq.network, mq.p)
if err != nil { if err != nil {
return err return err
} }
...@@ -175,8 +201,8 @@ func (mq *MessageQueue) initializeSender(ctx context.Context) error { ...@@ -175,8 +201,8 @@ func (mq *MessageQueue) initializeSender(ctx context.Context) error {
return nil return nil
} }
func (mq *MessageQueue) attemptSendAndRecovery(ctx context.Context, wlm bsmsg.BitSwapMessage) bool { func (mq *MessageQueue) attemptSendAndRecovery(message bsmsg.BitSwapMessage) bool {
err := mq.sender.SendMsg(ctx, wlm) err := mq.sender.SendMsg(mq.ctx, message)
if err == nil { if err == nil {
return true return true
} }
...@@ -188,14 +214,14 @@ func (mq *MessageQueue) attemptSendAndRecovery(ctx context.Context, wlm bsmsg.Bi ...@@ -188,14 +214,14 @@ func (mq *MessageQueue) attemptSendAndRecovery(ctx context.Context, wlm bsmsg.Bi
select { select {
case <-mq.done: case <-mq.done:
return true return true
case <-ctx.Done(): case <-mq.ctx.Done():
return true return true
case <-time.After(time.Millisecond * 100): case <-time.After(time.Millisecond * 100):
// wait 100ms in case disconnect notifications are still propogating // wait 100ms in case disconnect notifications are still propogating
log.Warning("SendMsg errored but neither 'done' nor context.Done() were set") log.Warning("SendMsg errored but neither 'done' nor context.Done() were set")
} }
err = mq.initializeSender(ctx) err = mq.initializeSender()
if err != nil { if err != nil {
log.Infof("couldnt open sender again after SendMsg(%s) failed: %s", mq.p, err) log.Infof("couldnt open sender again after SendMsg(%s) failed: %s", mq.p, err)
// TODO(why): what do we do now? // TODO(why): what do we do now?
...@@ -215,15 +241,6 @@ func (mq *MessageQueue) attemptSendAndRecovery(ctx context.Context, wlm bsmsg.Bi ...@@ -215,15 +241,6 @@ func (mq *MessageQueue) attemptSendAndRecovery(ctx context.Context, wlm bsmsg.Bi
return false return false
} }
func (mq *MessageQueue) extractOutgoingMessage() bsmsg.BitSwapMessage {
// grab outgoing message
mq.outlk.Lock()
wlm := mq.out
mq.out = nil
mq.outlk.Unlock()
return wlm
}
func openSender(ctx context.Context, network MessageNetwork, p peer.ID) (bsnet.MessageSender, error) { func openSender(ctx context.Context, network MessageNetwork, p peer.ID) (bsnet.MessageSender, error) {
// allow ten minutes for connections this includes looking them up in the // allow ten minutes for connections this includes looking them up in the
// dht dialing them, and handshaking // dht dialing them, and handshaking
......
...@@ -27,7 +27,6 @@ func (fmn *fakeMessageNetwork) NewMessageSender(context.Context, peer.ID) (bsnet ...@@ -27,7 +27,6 @@ func (fmn *fakeMessageNetwork) NewMessageSender(context.Context, peer.ID) (bsnet
return fmn.messageSender, nil return fmn.messageSender, nil
} }
return nil, fmn.messageSenderError return nil, fmn.messageSenderError
} }
type fakeMessageSender struct { type fakeMessageSender struct {
...@@ -77,12 +76,12 @@ func TestStartupAndShutdown(t *testing.T) { ...@@ -77,12 +76,12 @@ func TestStartupAndShutdown(t *testing.T) {
fakeSender := &fakeMessageSender{nil, fullClosedChan, resetChan, messagesSent} fakeSender := &fakeMessageSender{nil, fullClosedChan, resetChan, messagesSent}
fakenet := &fakeMessageNetwork{nil, nil, fakeSender} fakenet := &fakeMessageNetwork{nil, nil, fakeSender}
peerID := testutil.GeneratePeers(1)[0] peerID := testutil.GeneratePeers(1)[0]
messageQueue := New(peerID, fakenet) messageQueue := New(ctx, peerID, fakenet)
ses := testutil.GenerateSessionID() ses := testutil.GenerateSessionID()
wl := testutil.GenerateWantlist(10, ses) wl := testutil.GenerateWantlist(10, ses)
messageQueue.Startup(ctx) messageQueue.Startup()
messageQueue.AddWantlist(wl.Entries()) messageQueue.AddWantlist(wl)
messages := collectMessages(ctx, t, messagesSent, 10*time.Millisecond) messages := collectMessages(ctx, t, messagesSent, 10*time.Millisecond)
if len(messages) != 1 { if len(messages) != 1 {
t.Fatal("wrong number of messages were sent for initial wants") t.Fatal("wrong number of messages were sent for initial wants")
...@@ -119,11 +118,11 @@ func TestSendingMessagesDeduped(t *testing.T) { ...@@ -119,11 +118,11 @@ func TestSendingMessagesDeduped(t *testing.T) {
fakeSender := &fakeMessageSender{nil, fullClosedChan, resetChan, messagesSent} fakeSender := &fakeMessageSender{nil, fullClosedChan, resetChan, messagesSent}
fakenet := &fakeMessageNetwork{nil, nil, fakeSender} fakenet := &fakeMessageNetwork{nil, nil, fakeSender}
peerID := testutil.GeneratePeers(1)[0] peerID := testutil.GeneratePeers(1)[0]
messageQueue := New(peerID, fakenet) messageQueue := New(ctx, peerID, fakenet)
ses1 := testutil.GenerateSessionID() ses1 := testutil.GenerateSessionID()
ses2 := testutil.GenerateSessionID() ses2 := testutil.GenerateSessionID()
entries := testutil.GenerateMessageEntries(10, false) entries := testutil.GenerateMessageEntries(10, false)
messageQueue.Startup(ctx) messageQueue.Startup()
messageQueue.AddMessage(entries, ses1) messageQueue.AddMessage(entries, ses1)
messageQueue.AddMessage(entries, ses2) messageQueue.AddMessage(entries, ses2)
...@@ -142,13 +141,13 @@ func TestSendingMessagesPartialDupe(t *testing.T) { ...@@ -142,13 +141,13 @@ func TestSendingMessagesPartialDupe(t *testing.T) {
fakeSender := &fakeMessageSender{nil, fullClosedChan, resetChan, messagesSent} fakeSender := &fakeMessageSender{nil, fullClosedChan, resetChan, messagesSent}
fakenet := &fakeMessageNetwork{nil, nil, fakeSender} fakenet := &fakeMessageNetwork{nil, nil, fakeSender}
peerID := testutil.GeneratePeers(1)[0] peerID := testutil.GeneratePeers(1)[0]
messageQueue := New(peerID, fakenet) messageQueue := New(ctx, peerID, fakenet)
ses1 := testutil.GenerateSessionID() ses1 := testutil.GenerateSessionID()
ses2 := testutil.GenerateSessionID() ses2 := testutil.GenerateSessionID()
entries := testutil.GenerateMessageEntries(10, false) entries := testutil.GenerateMessageEntries(10, false)
moreEntries := testutil.GenerateMessageEntries(5, false) moreEntries := testutil.GenerateMessageEntries(5, false)
secondEntries := append(entries[5:], moreEntries...) secondEntries := append(entries[5:], moreEntries...)
messageQueue.Startup(ctx) messageQueue.Startup()
messageQueue.AddMessage(entries, ses1) messageQueue.AddMessage(entries, ses1)
messageQueue.AddMessage(secondEntries, ses2) messageQueue.AddMessage(secondEntries, ses2)
......
...@@ -20,13 +20,13 @@ var ( ...@@ -20,13 +20,13 @@ var (
// PeerQueue provides a queer of messages to be sent for a single peer. // PeerQueue provides a queer of messages to be sent for a single peer.
type PeerQueue interface { type PeerQueue interface {
AddMessage(entries []*bsmsg.Entry, ses uint64) AddMessage(entries []*bsmsg.Entry, ses uint64)
Startup(ctx context.Context) Startup()
AddWantlist(initialEntries []*wantlist.Entry) AddWantlist(initialWants *wantlist.SessionTrackedWantlist)
Shutdown() Shutdown()
} }
// PeerQueueFactory provides a function that will create a PeerQueue. // PeerQueueFactory provides a function that will create a PeerQueue.
type PeerQueueFactory func(p peer.ID) PeerQueue type PeerQueueFactory func(ctx context.Context, p peer.ID) PeerQueue
type peerMessage interface { type peerMessage interface {
handle(pm *PeerManager) handle(pm *PeerManager)
...@@ -69,13 +69,13 @@ func (pm *PeerManager) ConnectedPeers() []peer.ID { ...@@ -69,13 +69,13 @@ func (pm *PeerManager) ConnectedPeers() []peer.ID {
// Connected is called to add a new peer to the pool, and send it an initial set // Connected is called to add a new peer to the pool, and send it an initial set
// of wants. // of wants.
func (pm *PeerManager) Connected(p peer.ID, initialEntries []*wantlist.Entry) { func (pm *PeerManager) Connected(p peer.ID, initialWants *wantlist.SessionTrackedWantlist) {
pm.peerQueuesLk.Lock() pm.peerQueuesLk.Lock()
pq := pm.getOrCreate(p) pq := pm.getOrCreate(p)
if pq.refcnt == 0 { if pq.refcnt == 0 {
pq.pq.AddWantlist(initialEntries) pq.pq.AddWantlist(initialWants)
} }
pq.refcnt++ pq.refcnt++
...@@ -128,8 +128,8 @@ func (pm *PeerManager) SendMessage(entries []*bsmsg.Entry, targets []peer.ID, fr ...@@ -128,8 +128,8 @@ func (pm *PeerManager) SendMessage(entries []*bsmsg.Entry, targets []peer.ID, fr
func (pm *PeerManager) getOrCreate(p peer.ID) *peerQueueInstance { func (pm *PeerManager) getOrCreate(p peer.ID) *peerQueueInstance {
pqi, ok := pm.peerQueues[p] pqi, ok := pm.peerQueues[p]
if !ok { if !ok {
pq := pm.createPeerQueue(p) pq := pm.createPeerQueue(pm.ctx, p)
pq.Startup(pm.ctx) pq.Startup()
pqi = &peerQueueInstance{0, pq} pqi = &peerQueueInstance{0, pq}
pm.peerQueues[p] = pqi pm.peerQueues[p] = pqi
} }
......
...@@ -24,15 +24,15 @@ type fakePeer struct { ...@@ -24,15 +24,15 @@ type fakePeer struct {
messagesSent chan messageSent messagesSent chan messageSent
} }
func (fp *fakePeer) Startup(ctx context.Context) {} func (fp *fakePeer) Startup() {}
func (fp *fakePeer) Shutdown() {} func (fp *fakePeer) Shutdown() {}
func (fp *fakePeer) AddMessage(entries []*bsmsg.Entry, ses uint64) { func (fp *fakePeer) AddMessage(entries []*bsmsg.Entry, ses uint64) {
fp.messagesSent <- messageSent{fp.p, entries, ses} fp.messagesSent <- messageSent{fp.p, entries, ses}
} }
func (fp *fakePeer) AddWantlist(initialEntries []*wantlist.Entry) {} func (fp *fakePeer) AddWantlist(initialWants *wantlist.SessionTrackedWantlist) {}
func makePeerQueueFactory(messagesSent chan messageSent) PeerQueueFactory { func makePeerQueueFactory(messagesSent chan messageSent) PeerQueueFactory {
return func(p peer.ID) PeerQueue { return func(ctx context.Context, p peer.ID) PeerQueue {
return &fakePeer{ return &fakePeer{
p: p, p: p,
messagesSent: messagesSent, messagesSent: messagesSent,
......
...@@ -39,8 +39,8 @@ func GenerateCids(n int) []cid.Cid { ...@@ -39,8 +39,8 @@ func GenerateCids(n int) []cid.Cid {
} }
// GenerateWantlist makes a populated wantlist. // GenerateWantlist makes a populated wantlist.
func GenerateWantlist(n int, ses uint64) *wantlist.ThreadSafe { func GenerateWantlist(n int, ses uint64) *wantlist.SessionTrackedWantlist {
wl := wantlist.NewThreadSafe() wl := wantlist.NewSessionTrackedWantlist()
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
prioritySeq++ prioritySeq++
entry := wantlist.NewRefEntry(blockGenerator.Next().Cid(), prioritySeq) entry := wantlist.NewRefEntry(blockGenerator.Next().Cid(), prioritySeq)
......
// package wantlist implements an object for bitswap that contains the keys // Package wantlist implements an object for bitswap that contains the keys
// that a given peer wants. // that a given peer wants.
package wantlist package wantlist
import ( import (
"sort" "sort"
"sync"
cid "github.com/ipfs/go-cid" cid "github.com/ipfs/go-cid"
) )
type ThreadSafe struct { type SessionTrackedWantlist struct {
lk sync.RWMutex set map[cid.Cid]*sessionTrackedEntry
set map[cid.Cid]*Entry
} }
// not threadsafe
type Wantlist struct { type Wantlist struct {
set map[cid.Cid]*Entry set map[cid.Cid]*Entry
} }
...@@ -23,17 +20,20 @@ type Entry struct { ...@@ -23,17 +20,20 @@ type Entry struct {
Cid cid.Cid Cid cid.Cid
Priority int Priority int
SesTrk map[uint64]struct{}
// Trash in a book-keeping field // Trash in a book-keeping field
Trash bool Trash bool
} }
type sessionTrackedEntry struct {
*Entry
sesTrk map[uint64]struct{}
}
// NewRefEntry creates a new reference tracked wantlist entry. // NewRefEntry creates a new reference tracked wantlist entry.
func NewRefEntry(c cid.Cid, p int) *Entry { func NewRefEntry(c cid.Cid, p int) *Entry {
return &Entry{ return &Entry{
Cid: c, Cid: c,
Priority: p, Priority: p,
SesTrk: make(map[uint64]struct{}),
} }
} }
...@@ -43,9 +43,9 @@ func (es entrySlice) Len() int { return len(es) } ...@@ -43,9 +43,9 @@ func (es entrySlice) Len() int { return len(es) }
func (es entrySlice) Swap(i, j int) { es[i], es[j] = es[j], es[i] } func (es entrySlice) Swap(i, j int) { es[i], es[j] = es[j], es[i] }
func (es entrySlice) Less(i, j int) bool { return es[i].Priority > es[j].Priority } func (es entrySlice) Less(i, j int) bool { return es[i].Priority > es[j].Priority }
func NewThreadSafe() *ThreadSafe { func NewSessionTrackedWantlist() *SessionTrackedWantlist {
return &ThreadSafe{ return &SessionTrackedWantlist{
set: make(map[cid.Cid]*Entry), set: make(map[cid.Cid]*sessionTrackedEntry),
} }
} }
...@@ -63,33 +63,31 @@ func New() *Wantlist { ...@@ -63,33 +63,31 @@ func New() *Wantlist {
// TODO: think through priority changes here // TODO: think through priority changes here
// Add returns true if the cid did not exist in the wantlist before this call // Add returns true if the cid did not exist in the wantlist before this call
// (even if it was under a different session). // (even if it was under a different session).
func (w *ThreadSafe) Add(c cid.Cid, priority int, ses uint64) bool { func (w *SessionTrackedWantlist) Add(c cid.Cid, priority int, ses uint64) bool {
w.lk.Lock()
defer w.lk.Unlock()
if e, ok := w.set[c]; ok { if e, ok := w.set[c]; ok {
e.SesTrk[ses] = struct{}{} e.sesTrk[ses] = struct{}{}
return false return false
} }
w.set[c] = &Entry{ w.set[c] = &sessionTrackedEntry{
Cid: c, Entry: &Entry{Cid: c, Priority: priority},
Priority: priority, sesTrk: map[uint64]struct{}{ses: struct{}{}},
SesTrk: map[uint64]struct{}{ses: struct{}{}},
} }
return true return true
} }
// AddEntry adds given Entry to the wantlist. For more information see Add method. // AddEntry adds given Entry to the wantlist. For more information see Add method.
func (w *ThreadSafe) AddEntry(e *Entry, ses uint64) bool { func (w *SessionTrackedWantlist) AddEntry(e *Entry, ses uint64) bool {
w.lk.Lock()
defer w.lk.Unlock()
if ex, ok := w.set[e.Cid]; ok { if ex, ok := w.set[e.Cid]; ok {
ex.SesTrk[ses] = struct{}{} ex.sesTrk[ses] = struct{}{}
return false return false
} }
w.set[e.Cid] = e w.set[e.Cid] = &sessionTrackedEntry{
e.SesTrk[ses] = struct{}{} Entry: e,
sesTrk: map[uint64]struct{}{ses: struct{}{}},
}
return true return true
} }
...@@ -97,16 +95,14 @@ func (w *ThreadSafe) AddEntry(e *Entry, ses uint64) bool { ...@@ -97,16 +95,14 @@ func (w *ThreadSafe) AddEntry(e *Entry, ses uint64) bool {
// 'true' is returned if this call to Remove removed the final session ID // 'true' is returned if this call to Remove removed the final session ID
// tracking the cid. (meaning true will be returned iff this call caused the // tracking the cid. (meaning true will be returned iff this call caused the
// value of 'Contains(c)' to change from true to false) // value of 'Contains(c)' to change from true to false)
func (w *ThreadSafe) Remove(c cid.Cid, ses uint64) bool { func (w *SessionTrackedWantlist) Remove(c cid.Cid, ses uint64) bool {
w.lk.Lock()
defer w.lk.Unlock()
e, ok := w.set[c] e, ok := w.set[c]
if !ok { if !ok {
return false return false
} }
delete(e.SesTrk, ses) delete(e.sesTrk, ses)
if len(e.SesTrk) == 0 { if len(e.sesTrk) == 0 {
delete(w.set, c) delete(w.set, c)
return true return true
} }
...@@ -115,35 +111,40 @@ func (w *ThreadSafe) Remove(c cid.Cid, ses uint64) bool { ...@@ -115,35 +111,40 @@ func (w *ThreadSafe) Remove(c cid.Cid, ses uint64) bool {
// Contains returns true if the given cid is in the wantlist tracked by one or // Contains returns true if the given cid is in the wantlist tracked by one or
// more sessions. // more sessions.
func (w *ThreadSafe) Contains(k cid.Cid) (*Entry, bool) { func (w *SessionTrackedWantlist) Contains(k cid.Cid) (*Entry, bool) {
w.lk.RLock()
defer w.lk.RUnlock()
e, ok := w.set[k] e, ok := w.set[k]
return e, ok if !ok {
return nil, false
}
return e.Entry, true
} }
func (w *ThreadSafe) Entries() []*Entry { func (w *SessionTrackedWantlist) Entries() []*Entry {
w.lk.RLock()
defer w.lk.RUnlock()
es := make([]*Entry, 0, len(w.set)) es := make([]*Entry, 0, len(w.set))
for _, e := range w.set { for _, e := range w.set {
es = append(es, e) es = append(es, e.Entry)
} }
return es return es
} }
func (w *ThreadSafe) SortedEntries() []*Entry { func (w *SessionTrackedWantlist) SortedEntries() []*Entry {
es := w.Entries() es := w.Entries()
sort.Sort(entrySlice(es)) sort.Sort(entrySlice(es))
return es return es
} }
func (w *ThreadSafe) Len() int { func (w *SessionTrackedWantlist) Len() int {
w.lk.RLock()
defer w.lk.RUnlock()
return len(w.set) return len(w.set)
} }
func (w *SessionTrackedWantlist) CopyWants(to *SessionTrackedWantlist) {
for _, e := range w.set {
for k := range e.sesTrk {
to.AddEntry(e.Entry, k)
}
}
}
func (w *Wantlist) Len() int { func (w *Wantlist) Len() int {
return len(w.set) return len(w.set)
} }
......
...@@ -82,8 +82,8 @@ func TestBasicWantlist(t *testing.T) { ...@@ -82,8 +82,8 @@ func TestBasicWantlist(t *testing.T) {
} }
} }
func TestSesRefWantlist(t *testing.T) { func TestSessionTrackedWantlist(t *testing.T) {
wl := NewThreadSafe() wl := NewSessionTrackedWantlist()
if !wl.Add(testcids[0], 5, 1) { if !wl.Add(testcids[0], 5, 1) {
t.Fatal("should have added") t.Fatal("should have added")
......
...@@ -24,7 +24,7 @@ const ( ...@@ -24,7 +24,7 @@ const (
// managed by the WantManager. // managed by the WantManager.
type PeerHandler interface { type PeerHandler interface {
Disconnected(p peer.ID) Disconnected(p peer.ID)
Connected(p peer.ID, initialEntries []*wantlist.Entry) Connected(p peer.ID, initialWants *wantlist.SessionTrackedWantlist)
SendMessage(entries []*bsmsg.Entry, targets []peer.ID, from uint64) SendMessage(entries []*bsmsg.Entry, targets []peer.ID, from uint64)
} }
...@@ -42,8 +42,8 @@ type WantManager struct { ...@@ -42,8 +42,8 @@ type WantManager struct {
wantMessages chan wantMessage wantMessages chan wantMessage
// synchronized by Run loop, only touch inside there // synchronized by Run loop, only touch inside there
wl *wantlist.ThreadSafe wl *wantlist.SessionTrackedWantlist
bcwl *wantlist.ThreadSafe bcwl *wantlist.SessionTrackedWantlist
ctx context.Context ctx context.Context
cancel func() cancel func()
...@@ -59,8 +59,8 @@ func New(ctx context.Context) *WantManager { ...@@ -59,8 +59,8 @@ func New(ctx context.Context) *WantManager {
"Number of items in wantlist.").Gauge() "Number of items in wantlist.").Gauge()
return &WantManager{ return &WantManager{
wantMessages: make(chan wantMessage, 10), wantMessages: make(chan wantMessage, 10),
wl: wantlist.NewThreadSafe(), wl: wantlist.NewSessionTrackedWantlist(),
bcwl: wantlist.NewThreadSafe(), bcwl: wantlist.NewSessionTrackedWantlist(),
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
wantlistGauge: wantlistGauge, wantlistGauge: wantlistGauge,
...@@ -274,7 +274,7 @@ type connectedMessage struct { ...@@ -274,7 +274,7 @@ type connectedMessage struct {
} }
func (cm *connectedMessage) handle(wm *WantManager) { func (cm *connectedMessage) handle(wm *WantManager) {
wm.peerHandler.Connected(cm.p, wm.bcwl.Entries()) wm.peerHandler.Connected(cm.p, wm.bcwl)
} }
type disconnectedMessage struct { type disconnectedMessage struct {
......
...@@ -25,8 +25,8 @@ func (fph *fakePeerHandler) SendMessage(entries []*bsmsg.Entry, targets []peer.I ...@@ -25,8 +25,8 @@ func (fph *fakePeerHandler) SendMessage(entries []*bsmsg.Entry, targets []peer.I
fph.lk.Unlock() fph.lk.Unlock()
} }
func (fph *fakePeerHandler) Connected(p peer.ID, initialEntries []*wantlist.Entry) {} func (fph *fakePeerHandler) Connected(p peer.ID, initialWants *wantlist.SessionTrackedWantlist) {}
func (fph *fakePeerHandler) Disconnected(p peer.ID) {} func (fph *fakePeerHandler) Disconnected(p peer.ID) {}
func (fph *fakePeerHandler) getLastWantSet() wantSet { func (fph *fakePeerHandler) getLastWantSet() wantSet {
fph.lk.Lock() fph.lk.Lock()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment