Commit b38f4513 authored by Dirk McCormick's avatar Dirk McCormick

fix: ensure unique tag for session connection protection

parent a38d8a9c
......@@ -94,12 +94,11 @@ func (fpt *fakePeerTagger) Unprotect(p peer.ID, tag string) bool {
return false
}
func (fpt *fakePeerTagger) isProtected(p peer.ID, tag string) bool {
func (fpt *fakePeerTagger) isProtected(p peer.ID) bool {
fpt.lk.Lock()
defer fpt.lk.Unlock()
_, ok := fpt.protectedPeers[p][tag]
return ok
return len(fpt.protectedPeers[p]) > 0
}
type fakeProviderFinder struct {
......
......@@ -2,7 +2,6 @@ package session
import (
"context"
"fmt"
"sync"
"testing"
"time"
......@@ -383,7 +382,6 @@ func TestProtectConnFirstPeerToSendWantedBlock(t *testing.T) {
peerB := peers[1]
peerC := peers[2]
sid := uint64(1)
sidStr := fmt.Sprintf("%d", sid)
pm := newMockPeerManager()
fpt := newFakePeerTagger()
fpm := bsspm.New(1, fpt)
......@@ -406,7 +404,7 @@ func TestProtectConnFirstPeerToSendWantedBlock(t *testing.T) {
time.Sleep(10 * time.Millisecond)
// Expect peer A to be protected as it was first to send the block
if !fpt.isProtected(peerA, sidStr) {
if !fpt.isProtected(peerA) {
t.Fatal("Expected first peer to send block to have protected connection")
}
......@@ -417,7 +415,7 @@ func TestProtectConnFirstPeerToSendWantedBlock(t *testing.T) {
time.Sleep(10 * time.Millisecond)
// Expect peer B not to be protected as it was not first to send the block
if fpt.isProtected(peerB, sidStr) {
if fpt.isProtected(peerB) {
t.Fatal("Expected peer not to be protected")
}
......@@ -428,7 +426,7 @@ func TestProtectConnFirstPeerToSendWantedBlock(t *testing.T) {
time.Sleep(10 * time.Millisecond)
// Expect peer C not to be protected as we didn't want the block it sent
if fpt.isProtected(peerC, sidStr) {
if fpt.isProtected(peerC) {
t.Fatal("Expected peer not to be protected")
}
}
......
......@@ -78,7 +78,7 @@ func (spm *SessionPeerManager) ProtectConnection(p peer.ID) {
return
}
spm.tagger.Protect(p, spm.protectedTag())
spm.tagger.Protect(p, spm.tag)
}
// RemovePeer removes the peer from the SessionPeerManager.
......@@ -93,7 +93,7 @@ func (spm *SessionPeerManager) RemovePeer(p peer.ID) bool {
delete(spm.peers, p)
spm.tagger.UntagPeer(p, spm.tag)
spm.tagger.Unprotect(p, spm.protectedTag())
spm.tagger.Unprotect(p, spm.tag)
log.Debugw("Bitswap: removed peer from session", "session", spm.id, "peer", p, "peerCount", len(spm.peers))
return true
......@@ -145,10 +145,6 @@ func (spm *SessionPeerManager) Shutdown() {
// connections to those peers
for p := range spm.peers {
spm.tagger.UntagPeer(p, spm.tag)
spm.tagger.Unprotect(p, spm.protectedTag())
spm.tagger.Unprotect(p, spm.tag)
}
}
func (spm *SessionPeerManager) protectedTag() string {
return fmt.Sprintf("%d", spm.id)
}
package sessionpeermanager
import (
"fmt"
"sync"
"testing"
......@@ -71,6 +70,13 @@ func (fpt *fakePeerTagger) Unprotect(p peer.ID, tag string) bool {
return false
}
func (fpt *fakePeerTagger) isProtected(p peer.ID) bool {
fpt.lk.Lock()
defer fpt.lk.Unlock()
return len(fpt.protectedPeers[p]) > 0
}
func TestAddPeers(t *testing.T) {
peers := testutil.GeneratePeers(2)
spm := New(1, &fakePeerTagger{})
......@@ -247,26 +253,24 @@ func TestProtectConnection(t *testing.T) {
peers := testutil.GeneratePeers(1)
peerA := peers[0]
fpt := newFakePeerTagger()
sid := 1
sidstr := fmt.Sprintf("%d", sid)
spm := New(1, fpt)
// Should not protect connection if peer hasn't been added yet
spm.ProtectConnection(peerA)
if _, ok := fpt.protectedPeers[peerA][sidstr]; ok {
if fpt.isProtected(peerA) {
t.Fatal("Expected peer not to be protected")
}
// Once peer is added, should be able to protect connection
spm.AddPeer(peerA)
spm.ProtectConnection(peerA)
if _, ok := fpt.protectedPeers[peerA][sidstr]; !ok {
if !fpt.isProtected(peerA) {
t.Fatal("Expected peer to be protected")
}
// Removing peer should unprotect connection
spm.RemovePeer(peerA)
if _, ok := fpt.protectedPeers[peerA][sidstr]; ok {
if fpt.isProtected(peerA) {
t.Fatal("Expected peer to be unprotected")
}
}
......@@ -274,9 +278,7 @@ func TestProtectConnection(t *testing.T) {
func TestShutdown(t *testing.T) {
peers := testutil.GeneratePeers(2)
fpt := newFakePeerTagger()
sid := uint64(1)
sidstr := fmt.Sprintf("%d", sid)
spm := New(sid, fpt)
spm := New(1, fpt)
spm.AddPeer(peers[0])
spm.AddPeer(peers[1])
......@@ -285,7 +287,7 @@ func TestShutdown(t *testing.T) {
}
spm.ProtectConnection(peers[0])
if _, ok := fpt.protectedPeers[peers[0]][sidstr]; !ok {
if !fpt.isProtected(peers[0]) {
t.Fatal("Expected peer to be protected")
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment