diff --git a/internal/messagequeue/messagequeue.go b/internal/messagequeue/messagequeue.go index daf8664bfb92ffbaf964849bb567807e26e4b4ee..ca6f7c3bc6bc8353c3cc59ae415f88970558220b 100644 --- a/internal/messagequeue/messagequeue.go +++ b/internal/messagequeue/messagequeue.go @@ -9,6 +9,7 @@ import ( bsmsg "github.com/ipfs/go-bitswap/message" pb "github.com/ipfs/go-bitswap/message/pb" bsnet "github.com/ipfs/go-bitswap/network" + "github.com/ipfs/go-bitswap/wantlist" bswl "github.com/ipfs/go-bitswap/wantlist" cid "github.com/ipfs/go-cid" logging "github.com/ipfs/go-log" @@ -80,41 +81,44 @@ type MessageQueue struct { msg bsmsg.BitSwapMessage } -// recallWantlist keeps a list of pending wants, and a list of all wants that -// have ever been requested +// recallWantlist keeps a list of pending wants and a list of sent wants type recallWantlist struct { - // The list of all wants that have been requested, including wants that - // have been sent and wants that have not yet been sent - allWants *bswl.Wantlist // The list of wants that have not yet been sent pending *bswl.Wantlist + // The list of wants that have been sent + sent *bswl.Wantlist } func newRecallWantList() recallWantlist { return recallWantlist{ - allWants: bswl.New(), - pending: bswl.New(), + pending: bswl.New(), + sent: bswl.New(), } } -// Add want to both the pending list and the list of all wants +// Add want to the pending list func (r *recallWantlist) Add(c cid.Cid, priority int32, wtype pb.Message_Wantlist_WantType) { - r.allWants.Add(c, priority, wtype) r.pending.Add(c, priority, wtype) } -// Remove wants from both the pending list and the list of all wants +// Remove wants from both the pending list and the list of sent wants func (r *recallWantlist) Remove(c cid.Cid) { - r.allWants.Remove(c) + r.sent.Remove(c) r.pending.Remove(c) } -// Remove wants by type from both the pending list and the list of all wants +// Remove wants by type from both the pending list and the list of sent wants func (r *recallWantlist) RemoveType(c cid.Cid, wtype pb.Message_Wantlist_WantType) { - r.allWants.RemoveType(c, wtype) + r.sent.RemoveType(c, wtype) r.pending.RemoveType(c, wtype) } +// Sent moves the want from the pending to the sent list +func (r *recallWantlist) Sent(e bsmsg.Entry) { + r.pending.RemoveType(e.Cid, e.WantType) + r.sent.Add(e.Cid, e.Priority, e.WantType) +} + type peerConn struct { p peer.ID network MessageNetwork @@ -251,15 +255,29 @@ func (mq *MessageQueue) AddCancels(cancelKs []cid.Cid) { mq.wllock.Lock() defer mq.wllock.Unlock() + workReady := false + // Remove keys from broadcast and peer wants, and add to cancels for _, c := range cancelKs { + // Check if a want for the key was sent + _, wasSentBcst := mq.bcstWants.sent.Contains(c) + _, wasSentPeer := mq.peerWants.sent.Contains(c) + + // Remove the want from tracking wantlists mq.bcstWants.Remove(c) mq.peerWants.Remove(c) - mq.cancels.Add(c) + + // Only send a cancel if a want was sent + if wasSentBcst || wasSentPeer { + mq.cancels.Add(c) + workReady = true + } } // Schedule a message send - mq.signalWorkReady() + if workReady { + mq.signalWorkReady() + } } // SetRebroadcastInterval sets a new interval on which to rebroadcast the full wantlist @@ -366,13 +384,13 @@ func (mq *MessageQueue) transferRebroadcastWants() bool { defer mq.wllock.Unlock() // Check if there are any wants to rebroadcast - if mq.bcstWants.allWants.Len() == 0 && mq.peerWants.allWants.Len() == 0 { + if mq.bcstWants.sent.Len() == 0 && mq.peerWants.sent.Len() == 0 { return false } - // Copy all wants into pending wants lists - mq.bcstWants.pending.Absorb(mq.bcstWants.allWants) - mq.peerWants.pending.Absorb(mq.peerWants.allWants) + // Copy sent wants into pending wants lists + mq.bcstWants.pending.Absorb(mq.bcstWants.sent) + mq.peerWants.pending.Absorb(mq.peerWants.sent) return true } @@ -405,7 +423,7 @@ func (mq *MessageQueue) sendMessage() { mq.dhTimeoutMgr.Start() // Convert want lists to a Bitswap Message - message := mq.extractOutgoingMessage(mq.sender.SupportsHave()) + message, onSent := mq.extractOutgoingMessage(mq.sender.SupportsHave()) // After processing the message, clear out its fields to save memory defer mq.msg.Reset(false) @@ -421,7 +439,7 @@ func (mq *MessageQueue) sendMessage() { for i := 0; i < maxRetries; i++ { if mq.attemptSendAndRecovery(message) { // We were able to send successfully. - mq.onMessageSent(wantlist) + onSent(wantlist) mq.simulateDontHaveWithTimeout(wantlist) @@ -452,7 +470,7 @@ func (mq *MessageQueue) simulateDontHaveWithTimeout(wantlist []bsmsg.Entry) { // Unlikely, but just in case check that the block hasn't been // received in the interim c := entry.Cid - if _, ok := mq.peerWants.allWants.Contains(c); ok { + if _, ok := mq.peerWants.sent.Contains(c); ok { wants = append(wants, c) } } @@ -522,7 +540,7 @@ func (mq *MessageQueue) pendingWorkCount() int { } // Convert the lists of wants into a Bitswap message -func (mq *MessageQueue) extractOutgoingMessage(supportsHave bool) bsmsg.BitSwapMessage { +func (mq *MessageQueue) extractOutgoingMessage(supportsHave bool) (bsmsg.BitSwapMessage, func([]bsmsg.Entry)) { mq.wllock.Lock() defer mq.wllock.Unlock() @@ -572,19 +590,35 @@ func (mq *MessageQueue) extractOutgoingMessage(supportsHave bool) bsmsg.BitSwapM mq.cancels.Remove(c) } - return mq.msg -} + // Called when the message has been successfully sent. + onMessageSent := func(wantlist []bsmsg.Entry) { + bcst := keysToSet(bcstEntries) + prws := keysToSet(peerEntries) -// Called when the message has been successfully sent. -func (mq *MessageQueue) onMessageSent(wantlist []bsmsg.Entry) { - // Remove the sent keys from the broadcast and regular wantlists. - mq.wllock.Lock() - defer mq.wllock.Unlock() + mq.wllock.Lock() + defer mq.wllock.Unlock() - for _, e := range wantlist { - mq.bcstWants.pending.Remove(e.Cid) - mq.peerWants.pending.RemoveType(e.Cid, e.WantType) + // Move the keys from pending to sent + for _, e := range wantlist { + if _, ok := bcst[e.Cid]; ok { + mq.bcstWants.Sent(e) + } + if _, ok := prws[e.Cid]; ok { + mq.peerWants.Sent(e) + } + } + } + + return mq.msg, onMessageSent +} + +// Convert wantlist entries into a set of cids +func keysToSet(wl []wantlist.Entry) map[cid.Cid]struct{} { + set := make(map[cid.Cid]struct{}, len(wl)) + for _, e := range wl { + set[e.Cid] = struct{}{} } + return set } func (mq *MessageQueue) initializeSender() error { diff --git a/internal/messagequeue/messagequeue_test.go b/internal/messagequeue/messagequeue_test.go index 059534057a4e05ecd76f9c190c3888f71dc643e4..49c1033d697d169a8b08cda2117c83817026aa1a 100644 --- a/internal/messagequeue/messagequeue_test.go +++ b/internal/messagequeue/messagequeue_test.go @@ -319,18 +319,22 @@ func TestCancelOverridesPendingWants(t *testing.T) { fakenet := &fakeMessageNetwork{nil, nil, fakeSender} peerID := testutil.GeneratePeers(1)[0] messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb) + wantHaves := testutil.GenerateCids(2) wantBlocks := testutil.GenerateCids(2) + cancels := []cid.Cid{wantBlocks[0], wantHaves[0]} messageQueue.Startup() messageQueue.AddWants(wantBlocks, wantHaves) - messageQueue.AddCancels([]cid.Cid{wantBlocks[0], wantHaves[0]}) + messageQueue.AddCancels(cancels) messages := collectMessages(ctx, t, messagesSent, 10*time.Millisecond) - if totalEntriesLength(messages) != len(wantHaves)+len(wantBlocks) { + if totalEntriesLength(messages) != len(wantHaves)+len(wantBlocks)-len(cancels) { t.Fatal("Wrong message count") } + // Cancelled 1 want-block and 1 want-have before they were sent + // so that leaves 1 want-block and 1 want-have wb, wh, cl := filterWantTypes(messages[0]) if len(wb) != 1 || !wb[0].Equals(wantBlocks[1]) { t.Fatal("Expected 1 want-block") @@ -338,6 +342,20 @@ func TestCancelOverridesPendingWants(t *testing.T) { if len(wh) != 1 || !wh[0].Equals(wantHaves[1]) { t.Fatal("Expected 1 want-have") } + // Cancelled wants before they were sent, so no cancel should be sent + // to the network + if len(cl) != 0 { + t.Fatal("Expected no cancels") + } + + // Cancel the remaining want-blocks and want-haves + cancels = append(wantHaves, wantBlocks...) + messageQueue.AddCancels(cancels) + messages = collectMessages(ctx, t, messagesSent, 10*time.Millisecond) + + // The remaining 2 cancels should be sent to the network as they are for + // wants that were sent to the network + _, _, cl = filterWantTypes(messages[0]) if len(cl) != 2 { t.Fatal("Expected 2 cancels") } @@ -353,26 +371,41 @@ func TestWantOverridesPendingCancels(t *testing.T) { fakenet := &fakeMessageNetwork{nil, nil, fakeSender} peerID := testutil.GeneratePeers(1)[0] messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb) - cancels := testutil.GenerateCids(3) + + cids := testutil.GenerateCids(3) + wantBlocks := cids[:1] + wantHaves := cids[1:] messageQueue.Startup() - messageQueue.AddCancels(cancels) - messageQueue.AddWants([]cid.Cid{cancels[0]}, []cid.Cid{cancels[1]}) + + // Add 1 want-block and 2 want-haves + messageQueue.AddWants(wantBlocks, wantHaves) + messages := collectMessages(ctx, t, messagesSent, 10*time.Millisecond) + if totalEntriesLength(messages) != len(wantBlocks)+len(wantHaves) { + t.Fatal("Wrong message count", totalEntriesLength(messages)) + } - if totalEntriesLength(messages) != len(cancels) { - t.Fatal("Wrong message count") + // Cancel existing wants + messageQueue.AddCancels(cids) + // Override one cancel with a want-block (before cancel is sent to network) + messageQueue.AddWants(cids[:1], []cid.Cid{}) + + messages = collectMessages(ctx, t, messagesSent, 10*time.Millisecond) + if totalEntriesLength(messages) != 3 { + t.Fatal("Wrong message count", totalEntriesLength(messages)) } + // Should send 1 want-block and 2 cancels wb, wh, cl := filterWantTypes(messages[0]) - if len(wb) != 1 || !wb[0].Equals(cancels[0]) { + if len(wb) != 1 { t.Fatal("Expected 1 want-block") } - if len(wh) != 1 || !wh[0].Equals(cancels[1]) { - t.Fatal("Expected 1 want-have") + if len(wh) != 0 { + t.Fatal("Expected 0 want-have") } - if len(cl) != 1 || !cl[0].Equals(cancels[2]) { - t.Fatal("Expected 1 cancel") + if len(cl) != 2 { + t.Fatal("Expected 2 cancels") } } diff --git a/internal/sessionwantlist/sessionwantlist.go b/internal/sessionwantlist/sessionwantlist.go index d98147396c2b1f8c4379dc4d3c3fea75e7874e97..05c1433671ad1fa9f28a5571eeabcb70eb1dcbce 100644 --- a/internal/sessionwantlist/sessionwantlist.go +++ b/internal/sessionwantlist/sessionwantlist.go @@ -6,6 +6,7 @@ import ( cid "github.com/ipfs/go-cid" ) +// The SessionWantList keeps track of which sessions want a CID type SessionWantlist struct { sync.RWMutex wants map[cid.Cid]map[uint64]struct{} @@ -17,6 +18,7 @@ func NewSessionWantlist() *SessionWantlist { } } +// The given session wants the keys func (swl *SessionWantlist) Add(ks []cid.Cid, ses uint64) { swl.Lock() defer swl.Unlock() @@ -29,6 +31,8 @@ func (swl *SessionWantlist) Add(ks []cid.Cid, ses uint64) { } } +// Remove the keys for all sessions. +// Called when blocks are received. func (swl *SessionWantlist) RemoveKeys(ks []cid.Cid) { swl.Lock() defer swl.Unlock() @@ -38,6 +42,8 @@ func (swl *SessionWantlist) RemoveKeys(ks []cid.Cid) { } } +// Remove the session's wants, and return wants that are no longer wanted by +// any session. func (swl *SessionWantlist) RemoveSession(ses uint64) []cid.Cid { swl.Lock() defer swl.Unlock() @@ -54,6 +60,7 @@ func (swl *SessionWantlist) RemoveSession(ses uint64) []cid.Cid { return deletedKs } +// Remove the session's wants func (swl *SessionWantlist) RemoveSessionKeys(ses uint64, ks []cid.Cid) { swl.Lock() defer swl.Unlock() @@ -68,6 +75,7 @@ func (swl *SessionWantlist) RemoveSessionKeys(ses uint64, ks []cid.Cid) { } } +// All keys wanted by all sessions func (swl *SessionWantlist) Keys() []cid.Cid { swl.RLock() defer swl.RUnlock() @@ -79,6 +87,7 @@ func (swl *SessionWantlist) Keys() []cid.Cid { return ks } +// All sessions that want the given keys func (swl *SessionWantlist) SessionsFor(ks []cid.Cid) []uint64 { swl.RLock() defer swl.RUnlock() @@ -97,6 +106,7 @@ func (swl *SessionWantlist) SessionsFor(ks []cid.Cid) []uint64 { return ses } +// Filter for keys that at least one session wants func (swl *SessionWantlist) Has(ks []cid.Cid) *cid.Set { swl.RLock() defer swl.RUnlock() @@ -110,6 +120,7 @@ func (swl *SessionWantlist) Has(ks []cid.Cid) *cid.Set { return has } +// Filter for keys that the given session wants func (swl *SessionWantlist) SessionHas(ses uint64, ks []cid.Cid) *cid.Set { swl.RLock() defer swl.RUnlock()