Unverified Commit 6728add5 authored by Steven Allen's avatar Steven Allen Committed by GitHub

Merge pull request #345 from ipfs/fix/cancel-leak

fix: in message queue only send cancel if want was sent
parents d44a5f67 4800d07d
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
bsmsg "github.com/ipfs/go-bitswap/message" bsmsg "github.com/ipfs/go-bitswap/message"
pb "github.com/ipfs/go-bitswap/message/pb" pb "github.com/ipfs/go-bitswap/message/pb"
bsnet "github.com/ipfs/go-bitswap/network" bsnet "github.com/ipfs/go-bitswap/network"
"github.com/ipfs/go-bitswap/wantlist"
bswl "github.com/ipfs/go-bitswap/wantlist" bswl "github.com/ipfs/go-bitswap/wantlist"
cid "github.com/ipfs/go-cid" cid "github.com/ipfs/go-cid"
logging "github.com/ipfs/go-log" logging "github.com/ipfs/go-log"
...@@ -80,41 +81,44 @@ type MessageQueue struct { ...@@ -80,41 +81,44 @@ type MessageQueue struct {
msg bsmsg.BitSwapMessage msg bsmsg.BitSwapMessage
} }
// recallWantlist keeps a list of pending wants, and a list of all wants that // recallWantlist keeps a list of pending wants and a list of sent wants
// have ever been requested
type recallWantlist struct { type recallWantlist struct {
// The list of all wants that have been requested, including wants that
// have been sent and wants that have not yet been sent
allWants *bswl.Wantlist
// The list of wants that have not yet been sent // The list of wants that have not yet been sent
pending *bswl.Wantlist pending *bswl.Wantlist
// The list of wants that have been sent
sent *bswl.Wantlist
} }
func newRecallWantList() recallWantlist { func newRecallWantList() recallWantlist {
return recallWantlist{ return recallWantlist{
allWants: bswl.New(), pending: bswl.New(),
pending: bswl.New(), sent: bswl.New(),
} }
} }
// Add want to both the pending list and the list of all wants // Add want to the pending list
func (r *recallWantlist) Add(c cid.Cid, priority int32, wtype pb.Message_Wantlist_WantType) { func (r *recallWantlist) Add(c cid.Cid, priority int32, wtype pb.Message_Wantlist_WantType) {
r.allWants.Add(c, priority, wtype)
r.pending.Add(c, priority, wtype) r.pending.Add(c, priority, wtype)
} }
// Remove wants from both the pending list and the list of all wants // Remove wants from both the pending list and the list of sent wants
func (r *recallWantlist) Remove(c cid.Cid) { func (r *recallWantlist) Remove(c cid.Cid) {
r.allWants.Remove(c) r.sent.Remove(c)
r.pending.Remove(c) r.pending.Remove(c)
} }
// Remove wants by type from both the pending list and the list of all wants // Remove wants by type from both the pending list and the list of sent wants
func (r *recallWantlist) RemoveType(c cid.Cid, wtype pb.Message_Wantlist_WantType) { func (r *recallWantlist) RemoveType(c cid.Cid, wtype pb.Message_Wantlist_WantType) {
r.allWants.RemoveType(c, wtype) r.sent.RemoveType(c, wtype)
r.pending.RemoveType(c, wtype) r.pending.RemoveType(c, wtype)
} }
// Sent moves the want from the pending to the sent list
func (r *recallWantlist) Sent(e bsmsg.Entry) {
r.pending.RemoveType(e.Cid, e.WantType)
r.sent.Add(e.Cid, e.Priority, e.WantType)
}
type peerConn struct { type peerConn struct {
p peer.ID p peer.ID
network MessageNetwork network MessageNetwork
...@@ -251,15 +255,29 @@ func (mq *MessageQueue) AddCancels(cancelKs []cid.Cid) { ...@@ -251,15 +255,29 @@ func (mq *MessageQueue) AddCancels(cancelKs []cid.Cid) {
mq.wllock.Lock() mq.wllock.Lock()
defer mq.wllock.Unlock() defer mq.wllock.Unlock()
workReady := false
// Remove keys from broadcast and peer wants, and add to cancels // Remove keys from broadcast and peer wants, and add to cancels
for _, c := range cancelKs { for _, c := range cancelKs {
// Check if a want for the key was sent
_, wasSentBcst := mq.bcstWants.sent.Contains(c)
_, wasSentPeer := mq.peerWants.sent.Contains(c)
// Remove the want from tracking wantlists
mq.bcstWants.Remove(c) mq.bcstWants.Remove(c)
mq.peerWants.Remove(c) mq.peerWants.Remove(c)
mq.cancels.Add(c)
// Only send a cancel if a want was sent
if wasSentBcst || wasSentPeer {
mq.cancels.Add(c)
workReady = true
}
} }
// Schedule a message send // Schedule a message send
mq.signalWorkReady() if workReady {
mq.signalWorkReady()
}
} }
// SetRebroadcastInterval sets a new interval on which to rebroadcast the full wantlist // SetRebroadcastInterval sets a new interval on which to rebroadcast the full wantlist
...@@ -366,13 +384,13 @@ func (mq *MessageQueue) transferRebroadcastWants() bool { ...@@ -366,13 +384,13 @@ func (mq *MessageQueue) transferRebroadcastWants() bool {
defer mq.wllock.Unlock() defer mq.wllock.Unlock()
// Check if there are any wants to rebroadcast // Check if there are any wants to rebroadcast
if mq.bcstWants.allWants.Len() == 0 && mq.peerWants.allWants.Len() == 0 { if mq.bcstWants.sent.Len() == 0 && mq.peerWants.sent.Len() == 0 {
return false return false
} }
// Copy all wants into pending wants lists // Copy sent wants into pending wants lists
mq.bcstWants.pending.Absorb(mq.bcstWants.allWants) mq.bcstWants.pending.Absorb(mq.bcstWants.sent)
mq.peerWants.pending.Absorb(mq.peerWants.allWants) mq.peerWants.pending.Absorb(mq.peerWants.sent)
return true return true
} }
...@@ -405,7 +423,7 @@ func (mq *MessageQueue) sendMessage() { ...@@ -405,7 +423,7 @@ func (mq *MessageQueue) sendMessage() {
mq.dhTimeoutMgr.Start() mq.dhTimeoutMgr.Start()
// Convert want lists to a Bitswap Message // Convert want lists to a Bitswap Message
message := mq.extractOutgoingMessage(mq.sender.SupportsHave()) message, onSent := mq.extractOutgoingMessage(mq.sender.SupportsHave())
// After processing the message, clear out its fields to save memory // After processing the message, clear out its fields to save memory
defer mq.msg.Reset(false) defer mq.msg.Reset(false)
...@@ -421,7 +439,7 @@ func (mq *MessageQueue) sendMessage() { ...@@ -421,7 +439,7 @@ func (mq *MessageQueue) sendMessage() {
for i := 0; i < maxRetries; i++ { for i := 0; i < maxRetries; i++ {
if mq.attemptSendAndRecovery(message) { if mq.attemptSendAndRecovery(message) {
// We were able to send successfully. // We were able to send successfully.
mq.onMessageSent(wantlist) onSent(wantlist)
mq.simulateDontHaveWithTimeout(wantlist) mq.simulateDontHaveWithTimeout(wantlist)
...@@ -452,7 +470,7 @@ func (mq *MessageQueue) simulateDontHaveWithTimeout(wantlist []bsmsg.Entry) { ...@@ -452,7 +470,7 @@ func (mq *MessageQueue) simulateDontHaveWithTimeout(wantlist []bsmsg.Entry) {
// Unlikely, but just in case check that the block hasn't been // Unlikely, but just in case check that the block hasn't been
// received in the interim // received in the interim
c := entry.Cid c := entry.Cid
if _, ok := mq.peerWants.allWants.Contains(c); ok { if _, ok := mq.peerWants.sent.Contains(c); ok {
wants = append(wants, c) wants = append(wants, c)
} }
} }
...@@ -522,7 +540,7 @@ func (mq *MessageQueue) pendingWorkCount() int { ...@@ -522,7 +540,7 @@ func (mq *MessageQueue) pendingWorkCount() int {
} }
// Convert the lists of wants into a Bitswap message // Convert the lists of wants into a Bitswap message
func (mq *MessageQueue) extractOutgoingMessage(supportsHave bool) bsmsg.BitSwapMessage { func (mq *MessageQueue) extractOutgoingMessage(supportsHave bool) (bsmsg.BitSwapMessage, func([]bsmsg.Entry)) {
mq.wllock.Lock() mq.wllock.Lock()
defer mq.wllock.Unlock() defer mq.wllock.Unlock()
...@@ -572,19 +590,35 @@ func (mq *MessageQueue) extractOutgoingMessage(supportsHave bool) bsmsg.BitSwapM ...@@ -572,19 +590,35 @@ func (mq *MessageQueue) extractOutgoingMessage(supportsHave bool) bsmsg.BitSwapM
mq.cancels.Remove(c) mq.cancels.Remove(c)
} }
return mq.msg // Called when the message has been successfully sent.
} onMessageSent := func(wantlist []bsmsg.Entry) {
bcst := keysToSet(bcstEntries)
prws := keysToSet(peerEntries)
// Called when the message has been successfully sent. mq.wllock.Lock()
func (mq *MessageQueue) onMessageSent(wantlist []bsmsg.Entry) { defer mq.wllock.Unlock()
// Remove the sent keys from the broadcast and regular wantlists.
mq.wllock.Lock()
defer mq.wllock.Unlock()
for _, e := range wantlist { // Move the keys from pending to sent
mq.bcstWants.pending.Remove(e.Cid) for _, e := range wantlist {
mq.peerWants.pending.RemoveType(e.Cid, e.WantType) if _, ok := bcst[e.Cid]; ok {
mq.bcstWants.Sent(e)
}
if _, ok := prws[e.Cid]; ok {
mq.peerWants.Sent(e)
}
}
}
return mq.msg, onMessageSent
}
// Convert wantlist entries into a set of cids
func keysToSet(wl []wantlist.Entry) map[cid.Cid]struct{} {
set := make(map[cid.Cid]struct{}, len(wl))
for _, e := range wl {
set[e.Cid] = struct{}{}
} }
return set
} }
func (mq *MessageQueue) initializeSender() error { func (mq *MessageQueue) initializeSender() error {
......
...@@ -319,18 +319,22 @@ func TestCancelOverridesPendingWants(t *testing.T) { ...@@ -319,18 +319,22 @@ func TestCancelOverridesPendingWants(t *testing.T) {
fakenet := &fakeMessageNetwork{nil, nil, fakeSender} fakenet := &fakeMessageNetwork{nil, nil, fakeSender}
peerID := testutil.GeneratePeers(1)[0] peerID := testutil.GeneratePeers(1)[0]
messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb) messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb)
wantHaves := testutil.GenerateCids(2) wantHaves := testutil.GenerateCids(2)
wantBlocks := testutil.GenerateCids(2) wantBlocks := testutil.GenerateCids(2)
cancels := []cid.Cid{wantBlocks[0], wantHaves[0]}
messageQueue.Startup() messageQueue.Startup()
messageQueue.AddWants(wantBlocks, wantHaves) messageQueue.AddWants(wantBlocks, wantHaves)
messageQueue.AddCancels([]cid.Cid{wantBlocks[0], wantHaves[0]}) messageQueue.AddCancels(cancels)
messages := collectMessages(ctx, t, messagesSent, 10*time.Millisecond) messages := collectMessages(ctx, t, messagesSent, 10*time.Millisecond)
if totalEntriesLength(messages) != len(wantHaves)+len(wantBlocks) { if totalEntriesLength(messages) != len(wantHaves)+len(wantBlocks)-len(cancels) {
t.Fatal("Wrong message count") t.Fatal("Wrong message count")
} }
// Cancelled 1 want-block and 1 want-have before they were sent
// so that leaves 1 want-block and 1 want-have
wb, wh, cl := filterWantTypes(messages[0]) wb, wh, cl := filterWantTypes(messages[0])
if len(wb) != 1 || !wb[0].Equals(wantBlocks[1]) { if len(wb) != 1 || !wb[0].Equals(wantBlocks[1]) {
t.Fatal("Expected 1 want-block") t.Fatal("Expected 1 want-block")
...@@ -338,6 +342,20 @@ func TestCancelOverridesPendingWants(t *testing.T) { ...@@ -338,6 +342,20 @@ func TestCancelOverridesPendingWants(t *testing.T) {
if len(wh) != 1 || !wh[0].Equals(wantHaves[1]) { if len(wh) != 1 || !wh[0].Equals(wantHaves[1]) {
t.Fatal("Expected 1 want-have") t.Fatal("Expected 1 want-have")
} }
// Cancelled wants before they were sent, so no cancel should be sent
// to the network
if len(cl) != 0 {
t.Fatal("Expected no cancels")
}
// Cancel the remaining want-blocks and want-haves
cancels = append(wantHaves, wantBlocks...)
messageQueue.AddCancels(cancels)
messages = collectMessages(ctx, t, messagesSent, 10*time.Millisecond)
// The remaining 2 cancels should be sent to the network as they are for
// wants that were sent to the network
_, _, cl = filterWantTypes(messages[0])
if len(cl) != 2 { if len(cl) != 2 {
t.Fatal("Expected 2 cancels") t.Fatal("Expected 2 cancels")
} }
...@@ -353,26 +371,41 @@ func TestWantOverridesPendingCancels(t *testing.T) { ...@@ -353,26 +371,41 @@ func TestWantOverridesPendingCancels(t *testing.T) {
fakenet := &fakeMessageNetwork{nil, nil, fakeSender} fakenet := &fakeMessageNetwork{nil, nil, fakeSender}
peerID := testutil.GeneratePeers(1)[0] peerID := testutil.GeneratePeers(1)[0]
messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb) messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb)
cancels := testutil.GenerateCids(3)
cids := testutil.GenerateCids(3)
wantBlocks := cids[:1]
wantHaves := cids[1:]
messageQueue.Startup() messageQueue.Startup()
messageQueue.AddCancels(cancels)
messageQueue.AddWants([]cid.Cid{cancels[0]}, []cid.Cid{cancels[1]}) // Add 1 want-block and 2 want-haves
messageQueue.AddWants(wantBlocks, wantHaves)
messages := collectMessages(ctx, t, messagesSent, 10*time.Millisecond) messages := collectMessages(ctx, t, messagesSent, 10*time.Millisecond)
if totalEntriesLength(messages) != len(wantBlocks)+len(wantHaves) {
t.Fatal("Wrong message count", totalEntriesLength(messages))
}
if totalEntriesLength(messages) != len(cancels) { // Cancel existing wants
t.Fatal("Wrong message count") messageQueue.AddCancels(cids)
// Override one cancel with a want-block (before cancel is sent to network)
messageQueue.AddWants(cids[:1], []cid.Cid{})
messages = collectMessages(ctx, t, messagesSent, 10*time.Millisecond)
if totalEntriesLength(messages) != 3 {
t.Fatal("Wrong message count", totalEntriesLength(messages))
} }
// Should send 1 want-block and 2 cancels
wb, wh, cl := filterWantTypes(messages[0]) wb, wh, cl := filterWantTypes(messages[0])
if len(wb) != 1 || !wb[0].Equals(cancels[0]) { if len(wb) != 1 {
t.Fatal("Expected 1 want-block") t.Fatal("Expected 1 want-block")
} }
if len(wh) != 1 || !wh[0].Equals(cancels[1]) { if len(wh) != 0 {
t.Fatal("Expected 1 want-have") t.Fatal("Expected 0 want-have")
} }
if len(cl) != 1 || !cl[0].Equals(cancels[2]) { if len(cl) != 2 {
t.Fatal("Expected 1 cancel") t.Fatal("Expected 2 cancels")
} }
} }
......
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
cid "github.com/ipfs/go-cid" cid "github.com/ipfs/go-cid"
) )
// The SessionWantList keeps track of which sessions want a CID
type SessionWantlist struct { type SessionWantlist struct {
sync.RWMutex sync.RWMutex
wants map[cid.Cid]map[uint64]struct{} wants map[cid.Cid]map[uint64]struct{}
...@@ -17,6 +18,7 @@ func NewSessionWantlist() *SessionWantlist { ...@@ -17,6 +18,7 @@ func NewSessionWantlist() *SessionWantlist {
} }
} }
// The given session wants the keys
func (swl *SessionWantlist) Add(ks []cid.Cid, ses uint64) { func (swl *SessionWantlist) Add(ks []cid.Cid, ses uint64) {
swl.Lock() swl.Lock()
defer swl.Unlock() defer swl.Unlock()
...@@ -29,6 +31,8 @@ func (swl *SessionWantlist) Add(ks []cid.Cid, ses uint64) { ...@@ -29,6 +31,8 @@ func (swl *SessionWantlist) Add(ks []cid.Cid, ses uint64) {
} }
} }
// Remove the keys for all sessions.
// Called when blocks are received.
func (swl *SessionWantlist) RemoveKeys(ks []cid.Cid) { func (swl *SessionWantlist) RemoveKeys(ks []cid.Cid) {
swl.Lock() swl.Lock()
defer swl.Unlock() defer swl.Unlock()
...@@ -38,6 +42,8 @@ func (swl *SessionWantlist) RemoveKeys(ks []cid.Cid) { ...@@ -38,6 +42,8 @@ func (swl *SessionWantlist) RemoveKeys(ks []cid.Cid) {
} }
} }
// Remove the session's wants, and return wants that are no longer wanted by
// any session.
func (swl *SessionWantlist) RemoveSession(ses uint64) []cid.Cid { func (swl *SessionWantlist) RemoveSession(ses uint64) []cid.Cid {
swl.Lock() swl.Lock()
defer swl.Unlock() defer swl.Unlock()
...@@ -54,6 +60,7 @@ func (swl *SessionWantlist) RemoveSession(ses uint64) []cid.Cid { ...@@ -54,6 +60,7 @@ func (swl *SessionWantlist) RemoveSession(ses uint64) []cid.Cid {
return deletedKs return deletedKs
} }
// Remove the session's wants
func (swl *SessionWantlist) RemoveSessionKeys(ses uint64, ks []cid.Cid) { func (swl *SessionWantlist) RemoveSessionKeys(ses uint64, ks []cid.Cid) {
swl.Lock() swl.Lock()
defer swl.Unlock() defer swl.Unlock()
...@@ -68,6 +75,7 @@ func (swl *SessionWantlist) RemoveSessionKeys(ses uint64, ks []cid.Cid) { ...@@ -68,6 +75,7 @@ func (swl *SessionWantlist) RemoveSessionKeys(ses uint64, ks []cid.Cid) {
} }
} }
// All keys wanted by all sessions
func (swl *SessionWantlist) Keys() []cid.Cid { func (swl *SessionWantlist) Keys() []cid.Cid {
swl.RLock() swl.RLock()
defer swl.RUnlock() defer swl.RUnlock()
...@@ -79,6 +87,7 @@ func (swl *SessionWantlist) Keys() []cid.Cid { ...@@ -79,6 +87,7 @@ func (swl *SessionWantlist) Keys() []cid.Cid {
return ks return ks
} }
// All sessions that want the given keys
func (swl *SessionWantlist) SessionsFor(ks []cid.Cid) []uint64 { func (swl *SessionWantlist) SessionsFor(ks []cid.Cid) []uint64 {
swl.RLock() swl.RLock()
defer swl.RUnlock() defer swl.RUnlock()
...@@ -97,6 +106,7 @@ func (swl *SessionWantlist) SessionsFor(ks []cid.Cid) []uint64 { ...@@ -97,6 +106,7 @@ func (swl *SessionWantlist) SessionsFor(ks []cid.Cid) []uint64 {
return ses return ses
} }
// Filter for keys that at least one session wants
func (swl *SessionWantlist) Has(ks []cid.Cid) *cid.Set { func (swl *SessionWantlist) Has(ks []cid.Cid) *cid.Set {
swl.RLock() swl.RLock()
defer swl.RUnlock() defer swl.RUnlock()
...@@ -110,6 +120,7 @@ func (swl *SessionWantlist) Has(ks []cid.Cid) *cid.Set { ...@@ -110,6 +120,7 @@ func (swl *SessionWantlist) Has(ks []cid.Cid) *cid.Set {
return has return has
} }
// Filter for keys that the given session wants
func (swl *SessionWantlist) SessionHas(ses uint64, ks []cid.Cid) *cid.Set { func (swl *SessionWantlist) SessionHas(ses uint64, ks []cid.Cid) *cid.Set {
swl.RLock() swl.RLock()
defer swl.RUnlock() defer swl.RUnlock()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment