diff --git a/internal/session/session_test.go b/internal/session/session_test.go index b6aa5b5eeded361712cf3982b601ae726641081a..08bc9f88b397c0d48abb27b3b203fedc3aa25c9d 100644 --- a/internal/session/session_test.go +++ b/internal/session/session_test.go @@ -94,12 +94,11 @@ func (fpt *fakePeerTagger) Unprotect(p peer.ID, tag string) bool { return false } -func (fpt *fakePeerTagger) isProtected(p peer.ID, tag string) bool { +func (fpt *fakePeerTagger) isProtected(p peer.ID) bool { fpt.lk.Lock() defer fpt.lk.Unlock() - _, ok := fpt.protectedPeers[p][tag] - return ok + return len(fpt.protectedPeers[p]) > 0 } type fakeProviderFinder struct { diff --git a/internal/session/sessionwantsender_test.go b/internal/session/sessionwantsender_test.go index 08c465bf71e89a6d076b556cd06e0b158b5d1d15..806112f55f0d09619a064c951dcc1179d4ccda49 100644 --- a/internal/session/sessionwantsender_test.go +++ b/internal/session/sessionwantsender_test.go @@ -2,7 +2,6 @@ package session import ( "context" - "fmt" "sync" "testing" "time" @@ -383,7 +382,6 @@ func TestProtectConnFirstPeerToSendWantedBlock(t *testing.T) { peerB := peers[1] peerC := peers[2] sid := uint64(1) - sidStr := fmt.Sprintf("%d", sid) pm := newMockPeerManager() fpt := newFakePeerTagger() fpm := bsspm.New(1, fpt) @@ -406,7 +404,7 @@ func TestProtectConnFirstPeerToSendWantedBlock(t *testing.T) { time.Sleep(10 * time.Millisecond) // Expect peer A to be protected as it was first to send the block - if !fpt.isProtected(peerA, sidStr) { + if !fpt.isProtected(peerA) { t.Fatal("Expected first peer to send block to have protected connection") } @@ -417,7 +415,7 @@ func TestProtectConnFirstPeerToSendWantedBlock(t *testing.T) { time.Sleep(10 * time.Millisecond) // Expect peer B not to be protected as it was not first to send the block - if fpt.isProtected(peerB, sidStr) { + if fpt.isProtected(peerB) { t.Fatal("Expected peer not to be protected") } @@ -428,7 +426,7 @@ func TestProtectConnFirstPeerToSendWantedBlock(t *testing.T) { time.Sleep(10 * time.Millisecond) // Expect peer C not to be protected as we didn't want the block it sent - if fpt.isProtected(peerC, sidStr) { + if fpt.isProtected(peerC) { t.Fatal("Expected peer not to be protected") } } diff --git a/internal/sessionpeermanager/sessionpeermanager.go b/internal/sessionpeermanager/sessionpeermanager.go index e5442d5c4e4245e1e892e81d82c564858912ff46..db46691b9442ede94e2f20ae4bd711afe8c53dab 100644 --- a/internal/sessionpeermanager/sessionpeermanager.go +++ b/internal/sessionpeermanager/sessionpeermanager.go @@ -78,7 +78,7 @@ func (spm *SessionPeerManager) ProtectConnection(p peer.ID) { return } - spm.tagger.Protect(p, spm.protectedTag()) + spm.tagger.Protect(p, spm.tag) } // RemovePeer removes the peer from the SessionPeerManager. @@ -93,7 +93,7 @@ func (spm *SessionPeerManager) RemovePeer(p peer.ID) bool { delete(spm.peers, p) spm.tagger.UntagPeer(p, spm.tag) - spm.tagger.Unprotect(p, spm.protectedTag()) + spm.tagger.Unprotect(p, spm.tag) log.Debugw("Bitswap: removed peer from session", "session", spm.id, "peer", p, "peerCount", len(spm.peers)) return true @@ -145,10 +145,6 @@ func (spm *SessionPeerManager) Shutdown() { // connections to those peers for p := range spm.peers { spm.tagger.UntagPeer(p, spm.tag) - spm.tagger.Unprotect(p, spm.protectedTag()) + spm.tagger.Unprotect(p, spm.tag) } } - -func (spm *SessionPeerManager) protectedTag() string { - return fmt.Sprintf("%d", spm.id) -} diff --git a/internal/sessionpeermanager/sessionpeermanager_test.go b/internal/sessionpeermanager/sessionpeermanager_test.go index 7bb36b342f60fd98157618dc82cf536376a78a6c..746333c22772de3be01e9630fc82d6a9c66bb5b3 100644 --- a/internal/sessionpeermanager/sessionpeermanager_test.go +++ b/internal/sessionpeermanager/sessionpeermanager_test.go @@ -1,7 +1,6 @@ package sessionpeermanager import ( - "fmt" "sync" "testing" @@ -71,6 +70,13 @@ func (fpt *fakePeerTagger) Unprotect(p peer.ID, tag string) bool { return false } +func (fpt *fakePeerTagger) isProtected(p peer.ID) bool { + fpt.lk.Lock() + defer fpt.lk.Unlock() + + return len(fpt.protectedPeers[p]) > 0 +} + func TestAddPeers(t *testing.T) { peers := testutil.GeneratePeers(2) spm := New(1, &fakePeerTagger{}) @@ -247,26 +253,24 @@ func TestProtectConnection(t *testing.T) { peers := testutil.GeneratePeers(1) peerA := peers[0] fpt := newFakePeerTagger() - sid := 1 - sidstr := fmt.Sprintf("%d", sid) spm := New(1, fpt) // Should not protect connection if peer hasn't been added yet spm.ProtectConnection(peerA) - if _, ok := fpt.protectedPeers[peerA][sidstr]; ok { + if fpt.isProtected(peerA) { t.Fatal("Expected peer not to be protected") } // Once peer is added, should be able to protect connection spm.AddPeer(peerA) spm.ProtectConnection(peerA) - if _, ok := fpt.protectedPeers[peerA][sidstr]; !ok { + if !fpt.isProtected(peerA) { t.Fatal("Expected peer to be protected") } // Removing peer should unprotect connection spm.RemovePeer(peerA) - if _, ok := fpt.protectedPeers[peerA][sidstr]; ok { + if fpt.isProtected(peerA) { t.Fatal("Expected peer to be unprotected") } } @@ -274,9 +278,7 @@ func TestProtectConnection(t *testing.T) { func TestShutdown(t *testing.T) { peers := testutil.GeneratePeers(2) fpt := newFakePeerTagger() - sid := uint64(1) - sidstr := fmt.Sprintf("%d", sid) - spm := New(sid, fpt) + spm := New(1, fpt) spm.AddPeer(peers[0]) spm.AddPeer(peers[1]) @@ -285,7 +287,7 @@ func TestShutdown(t *testing.T) { } spm.ProtectConnection(peers[0]) - if _, ok := fpt.protectedPeers[peers[0]][sidstr]; !ok { + if !fpt.isProtected(peers[0]) { t.Fatal("Expected peer to be protected") }