Commit b38f4513 authored by Dirk McCormick's avatar Dirk McCormick

fix: ensure unique tag for session connection protection

parent a38d8a9c
...@@ -94,12 +94,11 @@ func (fpt *fakePeerTagger) Unprotect(p peer.ID, tag string) bool { ...@@ -94,12 +94,11 @@ func (fpt *fakePeerTagger) Unprotect(p peer.ID, tag string) bool {
return false return false
} }
func (fpt *fakePeerTagger) isProtected(p peer.ID, tag string) bool { func (fpt *fakePeerTagger) isProtected(p peer.ID) bool {
fpt.lk.Lock() fpt.lk.Lock()
defer fpt.lk.Unlock() defer fpt.lk.Unlock()
_, ok := fpt.protectedPeers[p][tag] return len(fpt.protectedPeers[p]) > 0
return ok
} }
type fakeProviderFinder struct { type fakeProviderFinder struct {
......
...@@ -2,7 +2,6 @@ package session ...@@ -2,7 +2,6 @@ package session
import ( import (
"context" "context"
"fmt"
"sync" "sync"
"testing" "testing"
"time" "time"
...@@ -383,7 +382,6 @@ func TestProtectConnFirstPeerToSendWantedBlock(t *testing.T) { ...@@ -383,7 +382,6 @@ func TestProtectConnFirstPeerToSendWantedBlock(t *testing.T) {
peerB := peers[1] peerB := peers[1]
peerC := peers[2] peerC := peers[2]
sid := uint64(1) sid := uint64(1)
sidStr := fmt.Sprintf("%d", sid)
pm := newMockPeerManager() pm := newMockPeerManager()
fpt := newFakePeerTagger() fpt := newFakePeerTagger()
fpm := bsspm.New(1, fpt) fpm := bsspm.New(1, fpt)
...@@ -406,7 +404,7 @@ func TestProtectConnFirstPeerToSendWantedBlock(t *testing.T) { ...@@ -406,7 +404,7 @@ func TestProtectConnFirstPeerToSendWantedBlock(t *testing.T) {
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
// Expect peer A to be protected as it was first to send the block // Expect peer A to be protected as it was first to send the block
if !fpt.isProtected(peerA, sidStr) { if !fpt.isProtected(peerA) {
t.Fatal("Expected first peer to send block to have protected connection") t.Fatal("Expected first peer to send block to have protected connection")
} }
...@@ -417,7 +415,7 @@ func TestProtectConnFirstPeerToSendWantedBlock(t *testing.T) { ...@@ -417,7 +415,7 @@ func TestProtectConnFirstPeerToSendWantedBlock(t *testing.T) {
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
// Expect peer B not to be protected as it was not first to send the block // Expect peer B not to be protected as it was not first to send the block
if fpt.isProtected(peerB, sidStr) { if fpt.isProtected(peerB) {
t.Fatal("Expected peer not to be protected") t.Fatal("Expected peer not to be protected")
} }
...@@ -428,7 +426,7 @@ func TestProtectConnFirstPeerToSendWantedBlock(t *testing.T) { ...@@ -428,7 +426,7 @@ func TestProtectConnFirstPeerToSendWantedBlock(t *testing.T) {
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
// Expect peer C not to be protected as we didn't want the block it sent // Expect peer C not to be protected as we didn't want the block it sent
if fpt.isProtected(peerC, sidStr) { if fpt.isProtected(peerC) {
t.Fatal("Expected peer not to be protected") t.Fatal("Expected peer not to be protected")
} }
} }
......
...@@ -78,7 +78,7 @@ func (spm *SessionPeerManager) ProtectConnection(p peer.ID) { ...@@ -78,7 +78,7 @@ func (spm *SessionPeerManager) ProtectConnection(p peer.ID) {
return return
} }
spm.tagger.Protect(p, spm.protectedTag()) spm.tagger.Protect(p, spm.tag)
} }
// RemovePeer removes the peer from the SessionPeerManager. // RemovePeer removes the peer from the SessionPeerManager.
...@@ -93,7 +93,7 @@ func (spm *SessionPeerManager) RemovePeer(p peer.ID) bool { ...@@ -93,7 +93,7 @@ func (spm *SessionPeerManager) RemovePeer(p peer.ID) bool {
delete(spm.peers, p) delete(spm.peers, p)
spm.tagger.UntagPeer(p, spm.tag) spm.tagger.UntagPeer(p, spm.tag)
spm.tagger.Unprotect(p, spm.protectedTag()) spm.tagger.Unprotect(p, spm.tag)
log.Debugw("Bitswap: removed peer from session", "session", spm.id, "peer", p, "peerCount", len(spm.peers)) log.Debugw("Bitswap: removed peer from session", "session", spm.id, "peer", p, "peerCount", len(spm.peers))
return true return true
...@@ -145,10 +145,6 @@ func (spm *SessionPeerManager) Shutdown() { ...@@ -145,10 +145,6 @@ func (spm *SessionPeerManager) Shutdown() {
// connections to those peers // connections to those peers
for p := range spm.peers { for p := range spm.peers {
spm.tagger.UntagPeer(p, spm.tag) spm.tagger.UntagPeer(p, spm.tag)
spm.tagger.Unprotect(p, spm.protectedTag()) spm.tagger.Unprotect(p, spm.tag)
} }
} }
func (spm *SessionPeerManager) protectedTag() string {
return fmt.Sprintf("%d", spm.id)
}
package sessionpeermanager package sessionpeermanager
import ( import (
"fmt"
"sync" "sync"
"testing" "testing"
...@@ -71,6 +70,13 @@ func (fpt *fakePeerTagger) Unprotect(p peer.ID, tag string) bool { ...@@ -71,6 +70,13 @@ func (fpt *fakePeerTagger) Unprotect(p peer.ID, tag string) bool {
return false return false
} }
func (fpt *fakePeerTagger) isProtected(p peer.ID) bool {
fpt.lk.Lock()
defer fpt.lk.Unlock()
return len(fpt.protectedPeers[p]) > 0
}
func TestAddPeers(t *testing.T) { func TestAddPeers(t *testing.T) {
peers := testutil.GeneratePeers(2) peers := testutil.GeneratePeers(2)
spm := New(1, &fakePeerTagger{}) spm := New(1, &fakePeerTagger{})
...@@ -247,26 +253,24 @@ func TestProtectConnection(t *testing.T) { ...@@ -247,26 +253,24 @@ func TestProtectConnection(t *testing.T) {
peers := testutil.GeneratePeers(1) peers := testutil.GeneratePeers(1)
peerA := peers[0] peerA := peers[0]
fpt := newFakePeerTagger() fpt := newFakePeerTagger()
sid := 1
sidstr := fmt.Sprintf("%d", sid)
spm := New(1, fpt) spm := New(1, fpt)
// Should not protect connection if peer hasn't been added yet // Should not protect connection if peer hasn't been added yet
spm.ProtectConnection(peerA) spm.ProtectConnection(peerA)
if _, ok := fpt.protectedPeers[peerA][sidstr]; ok { if fpt.isProtected(peerA) {
t.Fatal("Expected peer not to be protected") t.Fatal("Expected peer not to be protected")
} }
// Once peer is added, should be able to protect connection // Once peer is added, should be able to protect connection
spm.AddPeer(peerA) spm.AddPeer(peerA)
spm.ProtectConnection(peerA) spm.ProtectConnection(peerA)
if _, ok := fpt.protectedPeers[peerA][sidstr]; !ok { if !fpt.isProtected(peerA) {
t.Fatal("Expected peer to be protected") t.Fatal("Expected peer to be protected")
} }
// Removing peer should unprotect connection // Removing peer should unprotect connection
spm.RemovePeer(peerA) spm.RemovePeer(peerA)
if _, ok := fpt.protectedPeers[peerA][sidstr]; ok { if fpt.isProtected(peerA) {
t.Fatal("Expected peer to be unprotected") t.Fatal("Expected peer to be unprotected")
} }
} }
...@@ -274,9 +278,7 @@ func TestProtectConnection(t *testing.T) { ...@@ -274,9 +278,7 @@ func TestProtectConnection(t *testing.T) {
func TestShutdown(t *testing.T) { func TestShutdown(t *testing.T) {
peers := testutil.GeneratePeers(2) peers := testutil.GeneratePeers(2)
fpt := newFakePeerTagger() fpt := newFakePeerTagger()
sid := uint64(1) spm := New(1, fpt)
sidstr := fmt.Sprintf("%d", sid)
spm := New(sid, fpt)
spm.AddPeer(peers[0]) spm.AddPeer(peers[0])
spm.AddPeer(peers[1]) spm.AddPeer(peers[1])
...@@ -285,7 +287,7 @@ func TestShutdown(t *testing.T) { ...@@ -285,7 +287,7 @@ func TestShutdown(t *testing.T) {
} }
spm.ProtectConnection(peers[0]) spm.ProtectConnection(peers[0])
if _, ok := fpt.protectedPeers[peers[0]][sidstr]; !ok { if !fpt.isProtected(peers[0]) {
t.Fatal("Expected peer to be protected") t.Fatal("Expected peer to be protected")
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment