From c7e7afca3f78a56d19088cb5023f0b5e0379daed Mon Sep 17 00:00:00 2001 From: Dirk McCormick Date: Wed, 3 Jun 2020 16:10:34 -0400 Subject: [PATCH] fix: ensure conns are unprotected on shutdown --- .../sessionpeermanager/sessionpeermanager.go | 9 +++++++-- .../sessionpeermanager_test.go | 17 +++++++++++++++-- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/internal/sessionpeermanager/sessionpeermanager.go b/internal/sessionpeermanager/sessionpeermanager.go index 1ad144d..e5442d5 100644 --- a/internal/sessionpeermanager/sessionpeermanager.go +++ b/internal/sessionpeermanager/sessionpeermanager.go @@ -78,7 +78,7 @@ func (spm *SessionPeerManager) ProtectConnection(p peer.ID) { return } - spm.tagger.Protect(p, fmt.Sprintf("%d", spm.id)) + spm.tagger.Protect(p, spm.protectedTag()) } // RemovePeer removes the peer from the SessionPeerManager. @@ -93,7 +93,7 @@ func (spm *SessionPeerManager) RemovePeer(p peer.ID) bool { delete(spm.peers, p) spm.tagger.UntagPeer(p, spm.tag) - spm.tagger.Unprotect(p, fmt.Sprintf("%d", spm.id)) + spm.tagger.Unprotect(p, spm.protectedTag()) log.Debugw("Bitswap: removed peer from session", "session", spm.id, "peer", p, "peerCount", len(spm.peers)) return true @@ -145,5 +145,10 @@ func (spm *SessionPeerManager) Shutdown() { // connections to those peers for p := range spm.peers { spm.tagger.UntagPeer(p, spm.tag) + spm.tagger.Unprotect(p, spm.protectedTag()) } } + +func (spm *SessionPeerManager) protectedTag() string { + return fmt.Sprintf("%d", spm.id) +} diff --git a/internal/sessionpeermanager/sessionpeermanager_test.go b/internal/sessionpeermanager/sessionpeermanager_test.go index ba3a342..7bb36b3 100644 --- a/internal/sessionpeermanager/sessionpeermanager_test.go +++ b/internal/sessionpeermanager/sessionpeermanager_test.go @@ -62,6 +62,9 @@ func (fpt *fakePeerTagger) Unprotect(p peer.ID, tag string) bool { if tags, ok := fpt.protectedPeers[p]; ok { delete(tags, tag) + if len(tags) == 0 { + delete(fpt.protectedPeers, p) + } return len(tags) > 0 } @@ -270,8 +273,10 @@ func TestProtectConnection(t *testing.T) { func TestShutdown(t *testing.T) { peers := testutil.GeneratePeers(2) - fpt := &fakePeerTagger{} - spm := New(1, fpt) + fpt := newFakePeerTagger() + sid := uint64(1) + sidstr := fmt.Sprintf("%d", sid) + spm := New(sid, fpt) spm.AddPeer(peers[0]) spm.AddPeer(peers[1]) @@ -279,9 +284,17 @@ func TestShutdown(t *testing.T) { t.Fatal("Expected to have tagged two peers") } + spm.ProtectConnection(peers[0]) + if _, ok := fpt.protectedPeers[peers[0]][sidstr]; !ok { + t.Fatal("Expected peer to be protected") + } + spm.Shutdown() if len(fpt.taggedPeers) != 0 { t.Fatal("Expected to have untagged all peers") } + if len(fpt.protectedPeers) != 0 { + t.Fatal("Expected to have unprotected all peers") + } } -- GitLab