diff --git a/messagequeue/messagequeue.go b/messagequeue/messagequeue.go index 2b8f5f7cf467f477f4e0ab50f43e92379dc8d49e..a71425085f164d41469701f254bb25f3745641ec 100644 --- a/messagequeue/messagequeue.go +++ b/messagequeue/messagequeue.go @@ -14,7 +14,10 @@ import ( var log = logging.Logger("bitswap") -const maxRetries = 10 +const ( + defaultRebroadcastInterval = 30 * time.Second + maxRetries = 10 +) // MessageNetwork is any network that can connect peers and generate a message // sender. @@ -33,21 +36,25 @@ type MessageQueue struct { done chan struct{} // do not touch out of run loop - wl *wantlist.SessionTrackedWantlist - nextMessage bsmsg.BitSwapMessage - nextMessageLk sync.RWMutex - sender bsnet.MessageSender + wl *wantlist.SessionTrackedWantlist + nextMessage bsmsg.BitSwapMessage + nextMessageLk sync.RWMutex + sender bsnet.MessageSender + rebroadcastIntervalLk sync.RWMutex + rebroadcastInterval time.Duration + rebroadcastTimer *time.Timer } // New creats a new MessageQueue. func New(ctx context.Context, p peer.ID, network MessageNetwork) *MessageQueue { return &MessageQueue{ - ctx: ctx, - wl: wantlist.NewSessionTrackedWantlist(), - network: network, - p: p, - outgoingWork: make(chan struct{}, 1), - done: make(chan struct{}), + ctx: ctx, + wl: wantlist.NewSessionTrackedWantlist(), + network: network, + p: p, + outgoingWork: make(chan struct{}, 1), + done: make(chan struct{}), + rebroadcastInterval: defaultRebroadcastInterval, } } @@ -64,27 +71,26 @@ func (mq *MessageQueue) AddMessage(entries []bsmsg.Entry, ses uint64) { // AddWantlist adds a complete session tracked want list to a message queue func (mq *MessageQueue) AddWantlist(initialWants *wantlist.SessionTrackedWantlist) { - mq.nextMessageLk.Lock() - defer mq.nextMessageLk.Unlock() - initialWants.CopyWants(mq.wl) - if initialWants.Len() > 0 { - if mq.nextMessage == nil { - mq.nextMessage = bsmsg.New(false) - } - for _, e := range initialWants.Entries() { - mq.nextMessage.AddEntry(e.Cid, e.Priority) - } - select { - case mq.outgoingWork <- struct{}{}: - default: - } + mq.addWantlist() +} + +// SetRebroadcastInterval sets a new interval on which to rebroadcast the full wantlist +func (mq *MessageQueue) SetRebroadcastInterval(delay time.Duration) { + mq.rebroadcastIntervalLk.Lock() + mq.rebroadcastInterval = delay + if mq.rebroadcastTimer != nil { + mq.rebroadcastTimer.Reset(delay) } + mq.rebroadcastIntervalLk.Unlock() } // Startup starts the processing of messages, and creates an initial message // based on the given initial wantlist. func (mq *MessageQueue) Startup() { + mq.rebroadcastIntervalLk.RLock() + mq.rebroadcastTimer = time.NewTimer(mq.rebroadcastInterval) + mq.rebroadcastIntervalLk.RUnlock() go mq.runQueue() } @@ -96,6 +102,8 @@ func (mq *MessageQueue) Shutdown() { func (mq *MessageQueue) runQueue() { for { select { + case <-mq.rebroadcastTimer.C: + mq.rebroadcastWantlist() case <-mq.outgoingWork: mq.sendMessage() case <-mq.done: @@ -112,6 +120,33 @@ func (mq *MessageQueue) runQueue() { } } +func (mq *MessageQueue) addWantlist() { + + mq.nextMessageLk.Lock() + defer mq.nextMessageLk.Unlock() + + if mq.wl.Len() > 0 { + if mq.nextMessage == nil { + mq.nextMessage = bsmsg.New(false) + } + for _, e := range mq.wl.Entries() { + mq.nextMessage.AddEntry(e.Cid, e.Priority) + } + select { + case mq.outgoingWork <- struct{}{}: + default: + } + } +} + +func (mq *MessageQueue) rebroadcastWantlist() { + mq.rebroadcastIntervalLk.RLock() + mq.rebroadcastTimer.Reset(mq.rebroadcastInterval) + mq.rebroadcastIntervalLk.RUnlock() + + mq.addWantlist() +} + func (mq *MessageQueue) addEntries(entries []bsmsg.Entry, ses uint64) bool { var work bool mq.nextMessageLk.Lock() diff --git a/messagequeue/messagequeue_test.go b/messagequeue/messagequeue_test.go index aeb903ddc981d0705ac4fdbcf804316737c31510..146f21124d4b2f1a4c538614a81259e4585a6106 100644 --- a/messagequeue/messagequeue_test.go +++ b/messagequeue/messagequeue_test.go @@ -158,3 +158,40 @@ func TestSendingMessagesPartialDupe(t *testing.T) { } } + +func TestWantlistRebroadcast(t *testing.T) { + + ctx := context.Background() + messagesSent := make(chan bsmsg.BitSwapMessage) + resetChan := make(chan struct{}, 1) + fullClosedChan := make(chan struct{}, 1) + fakeSender := &fakeMessageSender{nil, fullClosedChan, resetChan, messagesSent} + fakenet := &fakeMessageNetwork{nil, nil, fakeSender} + peerID := testutil.GeneratePeers(1)[0] + messageQueue := New(ctx, peerID, fakenet) + ses := testutil.GenerateSessionID() + wl := testutil.GenerateWantlist(10, ses) + + messageQueue.Startup() + messageQueue.AddWantlist(wl) + messages := collectMessages(ctx, t, messagesSent, 10*time.Millisecond) + if len(messages) != 1 { + t.Fatal("wrong number of messages were sent for initial wants") + } + + messageQueue.SetRebroadcastInterval(5 * time.Millisecond) + messages = collectMessages(ctx, t, messagesSent, 8*time.Millisecond) + if len(messages) != 1 { + t.Fatal("wrong number of messages were rebroadcast") + } + + firstMessage := messages[0] + if len(firstMessage.Wantlist()) != wl.Len() { + t.Fatal("did not add all wants to want list") + } + for _, entry := range firstMessage.Wantlist() { + if entry.Cancel { + t.Fatal("initial add sent cancel entry when it should not have") + } + } +}