Unverified Commit 6728add5 authored by Steven Allen's avatar Steven Allen Committed by GitHub

Merge pull request #345 from ipfs/fix/cancel-leak

fix: in message queue only send cancel if want was sent
parents d44a5f67 4800d07d
......@@ -9,6 +9,7 @@ import (
bsmsg "github.com/ipfs/go-bitswap/message"
pb "github.com/ipfs/go-bitswap/message/pb"
bsnet "github.com/ipfs/go-bitswap/network"
"github.com/ipfs/go-bitswap/wantlist"
bswl "github.com/ipfs/go-bitswap/wantlist"
cid "github.com/ipfs/go-cid"
logging "github.com/ipfs/go-log"
......@@ -80,41 +81,44 @@ type MessageQueue struct {
msg bsmsg.BitSwapMessage
}
// recallWantlist keeps a list of pending wants, and a list of all wants that
// have ever been requested
// recallWantlist keeps a list of pending wants and a list of sent wants
type recallWantlist struct {
// The list of all wants that have been requested, including wants that
// have been sent and wants that have not yet been sent
allWants *bswl.Wantlist
// The list of wants that have not yet been sent
pending *bswl.Wantlist
// The list of wants that have been sent
sent *bswl.Wantlist
}
func newRecallWantList() recallWantlist {
return recallWantlist{
allWants: bswl.New(),
pending: bswl.New(),
pending: bswl.New(),
sent: bswl.New(),
}
}
// Add want to both the pending list and the list of all wants
// Add want to the pending list
func (r *recallWantlist) Add(c cid.Cid, priority int32, wtype pb.Message_Wantlist_WantType) {
r.allWants.Add(c, priority, wtype)
r.pending.Add(c, priority, wtype)
}
// Remove wants from both the pending list and the list of all wants
// Remove wants from both the pending list and the list of sent wants
func (r *recallWantlist) Remove(c cid.Cid) {
r.allWants.Remove(c)
r.sent.Remove(c)
r.pending.Remove(c)
}
// Remove wants by type from both the pending list and the list of all wants
// Remove wants by type from both the pending list and the list of sent wants
func (r *recallWantlist) RemoveType(c cid.Cid, wtype pb.Message_Wantlist_WantType) {
r.allWants.RemoveType(c, wtype)
r.sent.RemoveType(c, wtype)
r.pending.RemoveType(c, wtype)
}
// Sent moves the want from the pending to the sent list
func (r *recallWantlist) Sent(e bsmsg.Entry) {
r.pending.RemoveType(e.Cid, e.WantType)
r.sent.Add(e.Cid, e.Priority, e.WantType)
}
type peerConn struct {
p peer.ID
network MessageNetwork
......@@ -251,15 +255,29 @@ func (mq *MessageQueue) AddCancels(cancelKs []cid.Cid) {
mq.wllock.Lock()
defer mq.wllock.Unlock()
workReady := false
// Remove keys from broadcast and peer wants, and add to cancels
for _, c := range cancelKs {
// Check if a want for the key was sent
_, wasSentBcst := mq.bcstWants.sent.Contains(c)
_, wasSentPeer := mq.peerWants.sent.Contains(c)
// Remove the want from tracking wantlists
mq.bcstWants.Remove(c)
mq.peerWants.Remove(c)
mq.cancels.Add(c)
// Only send a cancel if a want was sent
if wasSentBcst || wasSentPeer {
mq.cancels.Add(c)
workReady = true
}
}
// Schedule a message send
mq.signalWorkReady()
if workReady {
mq.signalWorkReady()
}
}
// SetRebroadcastInterval sets a new interval on which to rebroadcast the full wantlist
......@@ -366,13 +384,13 @@ func (mq *MessageQueue) transferRebroadcastWants() bool {
defer mq.wllock.Unlock()
// Check if there are any wants to rebroadcast
if mq.bcstWants.allWants.Len() == 0 && mq.peerWants.allWants.Len() == 0 {
if mq.bcstWants.sent.Len() == 0 && mq.peerWants.sent.Len() == 0 {
return false
}
// Copy all wants into pending wants lists
mq.bcstWants.pending.Absorb(mq.bcstWants.allWants)
mq.peerWants.pending.Absorb(mq.peerWants.allWants)
// Copy sent wants into pending wants lists
mq.bcstWants.pending.Absorb(mq.bcstWants.sent)
mq.peerWants.pending.Absorb(mq.peerWants.sent)
return true
}
......@@ -405,7 +423,7 @@ func (mq *MessageQueue) sendMessage() {
mq.dhTimeoutMgr.Start()
// Convert want lists to a Bitswap Message
message := mq.extractOutgoingMessage(mq.sender.SupportsHave())
message, onSent := mq.extractOutgoingMessage(mq.sender.SupportsHave())
// After processing the message, clear out its fields to save memory
defer mq.msg.Reset(false)
......@@ -421,7 +439,7 @@ func (mq *MessageQueue) sendMessage() {
for i := 0; i < maxRetries; i++ {
if mq.attemptSendAndRecovery(message) {
// We were able to send successfully.
mq.onMessageSent(wantlist)
onSent(wantlist)
mq.simulateDontHaveWithTimeout(wantlist)
......@@ -452,7 +470,7 @@ func (mq *MessageQueue) simulateDontHaveWithTimeout(wantlist []bsmsg.Entry) {
// Unlikely, but just in case check that the block hasn't been
// received in the interim
c := entry.Cid
if _, ok := mq.peerWants.allWants.Contains(c); ok {
if _, ok := mq.peerWants.sent.Contains(c); ok {
wants = append(wants, c)
}
}
......@@ -522,7 +540,7 @@ func (mq *MessageQueue) pendingWorkCount() int {
}
// Convert the lists of wants into a Bitswap message
func (mq *MessageQueue) extractOutgoingMessage(supportsHave bool) bsmsg.BitSwapMessage {
func (mq *MessageQueue) extractOutgoingMessage(supportsHave bool) (bsmsg.BitSwapMessage, func([]bsmsg.Entry)) {
mq.wllock.Lock()
defer mq.wllock.Unlock()
......@@ -572,19 +590,35 @@ func (mq *MessageQueue) extractOutgoingMessage(supportsHave bool) bsmsg.BitSwapM
mq.cancels.Remove(c)
}
return mq.msg
}
// Called when the message has been successfully sent.
onMessageSent := func(wantlist []bsmsg.Entry) {
bcst := keysToSet(bcstEntries)
prws := keysToSet(peerEntries)
// Called when the message has been successfully sent.
func (mq *MessageQueue) onMessageSent(wantlist []bsmsg.Entry) {
// Remove the sent keys from the broadcast and regular wantlists.
mq.wllock.Lock()
defer mq.wllock.Unlock()
mq.wllock.Lock()
defer mq.wllock.Unlock()
for _, e := range wantlist {
mq.bcstWants.pending.Remove(e.Cid)
mq.peerWants.pending.RemoveType(e.Cid, e.WantType)
// Move the keys from pending to sent
for _, e := range wantlist {
if _, ok := bcst[e.Cid]; ok {
mq.bcstWants.Sent(e)
}
if _, ok := prws[e.Cid]; ok {
mq.peerWants.Sent(e)
}
}
}
return mq.msg, onMessageSent
}
// Convert wantlist entries into a set of cids
func keysToSet(wl []wantlist.Entry) map[cid.Cid]struct{} {
set := make(map[cid.Cid]struct{}, len(wl))
for _, e := range wl {
set[e.Cid] = struct{}{}
}
return set
}
func (mq *MessageQueue) initializeSender() error {
......
......@@ -319,18 +319,22 @@ func TestCancelOverridesPendingWants(t *testing.T) {
fakenet := &fakeMessageNetwork{nil, nil, fakeSender}
peerID := testutil.GeneratePeers(1)[0]
messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb)
wantHaves := testutil.GenerateCids(2)
wantBlocks := testutil.GenerateCids(2)
cancels := []cid.Cid{wantBlocks[0], wantHaves[0]}
messageQueue.Startup()
messageQueue.AddWants(wantBlocks, wantHaves)
messageQueue.AddCancels([]cid.Cid{wantBlocks[0], wantHaves[0]})
messageQueue.AddCancels(cancels)
messages := collectMessages(ctx, t, messagesSent, 10*time.Millisecond)
if totalEntriesLength(messages) != len(wantHaves)+len(wantBlocks) {
if totalEntriesLength(messages) != len(wantHaves)+len(wantBlocks)-len(cancels) {
t.Fatal("Wrong message count")
}
// Cancelled 1 want-block and 1 want-have before they were sent
// so that leaves 1 want-block and 1 want-have
wb, wh, cl := filterWantTypes(messages[0])
if len(wb) != 1 || !wb[0].Equals(wantBlocks[1]) {
t.Fatal("Expected 1 want-block")
......@@ -338,6 +342,20 @@ func TestCancelOverridesPendingWants(t *testing.T) {
if len(wh) != 1 || !wh[0].Equals(wantHaves[1]) {
t.Fatal("Expected 1 want-have")
}
// Cancelled wants before they were sent, so no cancel should be sent
// to the network
if len(cl) != 0 {
t.Fatal("Expected no cancels")
}
// Cancel the remaining want-blocks and want-haves
cancels = append(wantHaves, wantBlocks...)
messageQueue.AddCancels(cancels)
messages = collectMessages(ctx, t, messagesSent, 10*time.Millisecond)
// The remaining 2 cancels should be sent to the network as they are for
// wants that were sent to the network
_, _, cl = filterWantTypes(messages[0])
if len(cl) != 2 {
t.Fatal("Expected 2 cancels")
}
......@@ -353,26 +371,41 @@ func TestWantOverridesPendingCancels(t *testing.T) {
fakenet := &fakeMessageNetwork{nil, nil, fakeSender}
peerID := testutil.GeneratePeers(1)[0]
messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb)
cancels := testutil.GenerateCids(3)
cids := testutil.GenerateCids(3)
wantBlocks := cids[:1]
wantHaves := cids[1:]
messageQueue.Startup()
messageQueue.AddCancels(cancels)
messageQueue.AddWants([]cid.Cid{cancels[0]}, []cid.Cid{cancels[1]})
// Add 1 want-block and 2 want-haves
messageQueue.AddWants(wantBlocks, wantHaves)
messages := collectMessages(ctx, t, messagesSent, 10*time.Millisecond)
if totalEntriesLength(messages) != len(wantBlocks)+len(wantHaves) {
t.Fatal("Wrong message count", totalEntriesLength(messages))
}
if totalEntriesLength(messages) != len(cancels) {
t.Fatal("Wrong message count")
// Cancel existing wants
messageQueue.AddCancels(cids)
// Override one cancel with a want-block (before cancel is sent to network)
messageQueue.AddWants(cids[:1], []cid.Cid{})
messages = collectMessages(ctx, t, messagesSent, 10*time.Millisecond)
if totalEntriesLength(messages) != 3 {
t.Fatal("Wrong message count", totalEntriesLength(messages))
}
// Should send 1 want-block and 2 cancels
wb, wh, cl := filterWantTypes(messages[0])
if len(wb) != 1 || !wb[0].Equals(cancels[0]) {
if len(wb) != 1 {
t.Fatal("Expected 1 want-block")
}
if len(wh) != 1 || !wh[0].Equals(cancels[1]) {
t.Fatal("Expected 1 want-have")
if len(wh) != 0 {
t.Fatal("Expected 0 want-have")
}
if len(cl) != 1 || !cl[0].Equals(cancels[2]) {
t.Fatal("Expected 1 cancel")
if len(cl) != 2 {
t.Fatal("Expected 2 cancels")
}
}
......
......@@ -6,6 +6,7 @@ import (
cid "github.com/ipfs/go-cid"
)
// The SessionWantList keeps track of which sessions want a CID
type SessionWantlist struct {
sync.RWMutex
wants map[cid.Cid]map[uint64]struct{}
......@@ -17,6 +18,7 @@ func NewSessionWantlist() *SessionWantlist {
}
}
// The given session wants the keys
func (swl *SessionWantlist) Add(ks []cid.Cid, ses uint64) {
swl.Lock()
defer swl.Unlock()
......@@ -29,6 +31,8 @@ func (swl *SessionWantlist) Add(ks []cid.Cid, ses uint64) {
}
}
// Remove the keys for all sessions.
// Called when blocks are received.
func (swl *SessionWantlist) RemoveKeys(ks []cid.Cid) {
swl.Lock()
defer swl.Unlock()
......@@ -38,6 +42,8 @@ func (swl *SessionWantlist) RemoveKeys(ks []cid.Cid) {
}
}
// Remove the session's wants, and return wants that are no longer wanted by
// any session.
func (swl *SessionWantlist) RemoveSession(ses uint64) []cid.Cid {
swl.Lock()
defer swl.Unlock()
......@@ -54,6 +60,7 @@ func (swl *SessionWantlist) RemoveSession(ses uint64) []cid.Cid {
return deletedKs
}
// Remove the session's wants
func (swl *SessionWantlist) RemoveSessionKeys(ses uint64, ks []cid.Cid) {
swl.Lock()
defer swl.Unlock()
......@@ -68,6 +75,7 @@ func (swl *SessionWantlist) RemoveSessionKeys(ses uint64, ks []cid.Cid) {
}
}
// All keys wanted by all sessions
func (swl *SessionWantlist) Keys() []cid.Cid {
swl.RLock()
defer swl.RUnlock()
......@@ -79,6 +87,7 @@ func (swl *SessionWantlist) Keys() []cid.Cid {
return ks
}
// All sessions that want the given keys
func (swl *SessionWantlist) SessionsFor(ks []cid.Cid) []uint64 {
swl.RLock()
defer swl.RUnlock()
......@@ -97,6 +106,7 @@ func (swl *SessionWantlist) SessionsFor(ks []cid.Cid) []uint64 {
return ses
}
// Filter for keys that at least one session wants
func (swl *SessionWantlist) Has(ks []cid.Cid) *cid.Set {
swl.RLock()
defer swl.RUnlock()
......@@ -110,6 +120,7 @@ func (swl *SessionWantlist) Has(ks []cid.Cid) *cid.Set {
return has
}
// Filter for keys that the given session wants
func (swl *SessionWantlist) SessionHas(ses uint64, ks []cid.Cid) *cid.Set {
swl.RLock()
defer swl.RUnlock()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment