Commit c7e7afca authored by Dirk McCormick's avatar Dirk McCormick

fix: ensure conns are unprotected on shutdown

parent ba0f59c3
...@@ -78,7 +78,7 @@ func (spm *SessionPeerManager) ProtectConnection(p peer.ID) { ...@@ -78,7 +78,7 @@ func (spm *SessionPeerManager) ProtectConnection(p peer.ID) {
return return
} }
spm.tagger.Protect(p, fmt.Sprintf("%d", spm.id)) spm.tagger.Protect(p, spm.protectedTag())
} }
// RemovePeer removes the peer from the SessionPeerManager. // RemovePeer removes the peer from the SessionPeerManager.
...@@ -93,7 +93,7 @@ func (spm *SessionPeerManager) RemovePeer(p peer.ID) bool { ...@@ -93,7 +93,7 @@ func (spm *SessionPeerManager) RemovePeer(p peer.ID) bool {
delete(spm.peers, p) delete(spm.peers, p)
spm.tagger.UntagPeer(p, spm.tag) spm.tagger.UntagPeer(p, spm.tag)
spm.tagger.Unprotect(p, fmt.Sprintf("%d", spm.id)) spm.tagger.Unprotect(p, spm.protectedTag())
log.Debugw("Bitswap: removed peer from session", "session", spm.id, "peer", p, "peerCount", len(spm.peers)) log.Debugw("Bitswap: removed peer from session", "session", spm.id, "peer", p, "peerCount", len(spm.peers))
return true return true
...@@ -145,5 +145,10 @@ func (spm *SessionPeerManager) Shutdown() { ...@@ -145,5 +145,10 @@ func (spm *SessionPeerManager) Shutdown() {
// connections to those peers // connections to those peers
for p := range spm.peers { for p := range spm.peers {
spm.tagger.UntagPeer(p, spm.tag) spm.tagger.UntagPeer(p, spm.tag)
spm.tagger.Unprotect(p, spm.protectedTag())
} }
} }
func (spm *SessionPeerManager) protectedTag() string {
return fmt.Sprintf("%d", spm.id)
}
...@@ -62,6 +62,9 @@ func (fpt *fakePeerTagger) Unprotect(p peer.ID, tag string) bool { ...@@ -62,6 +62,9 @@ func (fpt *fakePeerTagger) Unprotect(p peer.ID, tag string) bool {
if tags, ok := fpt.protectedPeers[p]; ok { if tags, ok := fpt.protectedPeers[p]; ok {
delete(tags, tag) delete(tags, tag)
if len(tags) == 0 {
delete(fpt.protectedPeers, p)
}
return len(tags) > 0 return len(tags) > 0
} }
...@@ -270,8 +273,10 @@ func TestProtectConnection(t *testing.T) { ...@@ -270,8 +273,10 @@ func TestProtectConnection(t *testing.T) {
func TestShutdown(t *testing.T) { func TestShutdown(t *testing.T) {
peers := testutil.GeneratePeers(2) peers := testutil.GeneratePeers(2)
fpt := &fakePeerTagger{} fpt := newFakePeerTagger()
spm := New(1, fpt) sid := uint64(1)
sidstr := fmt.Sprintf("%d", sid)
spm := New(sid, fpt)
spm.AddPeer(peers[0]) spm.AddPeer(peers[0])
spm.AddPeer(peers[1]) spm.AddPeer(peers[1])
...@@ -279,9 +284,17 @@ func TestShutdown(t *testing.T) { ...@@ -279,9 +284,17 @@ func TestShutdown(t *testing.T) {
t.Fatal("Expected to have tagged two peers") t.Fatal("Expected to have tagged two peers")
} }
spm.ProtectConnection(peers[0])
if _, ok := fpt.protectedPeers[peers[0]][sidstr]; !ok {
t.Fatal("Expected peer to be protected")
}
spm.Shutdown() spm.Shutdown()
if len(fpt.taggedPeers) != 0 { if len(fpt.taggedPeers) != 0 {
t.Fatal("Expected to have untagged all peers") t.Fatal("Expected to have untagged all peers")
} }
if len(fpt.protectedPeers) != 0 {
t.Fatal("Expected to have unprotected all peers")
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment