diff --git a/internal/session/session.go b/internal/session/session.go index 7a0d23b366029e55b53012deaa4460bce3f43fd4..7b2953f951e077b7680e0f913a105104283ace67 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -65,6 +65,8 @@ type SessionPeerManager interface { Peers() []peer.ID // Whether there are any peers in the session HasPeers() bool + // Protect connection from being pruned by the connection manager + ProtectConnection(peer.ID) } // ProviderFinder is used to find providers for a given key diff --git a/internal/session/session_test.go b/internal/session/session_test.go index 028ee46e2096a53f1a2ea34bc5b8025645dc0f71..08bc9f88b397c0d48abb27b3b203fedc3aa25c9d 100644 --- a/internal/session/session_test.go +++ b/internal/session/session_test.go @@ -56,16 +56,49 @@ func newFakeSessionPeerManager() *bsspm.SessionPeerManager { return bsspm.New(1, newFakePeerTagger()) } +func newFakePeerTagger() *fakePeerTagger { + return &fakePeerTagger{ + protectedPeers: make(map[peer.ID]map[string]struct{}), + } +} + type fakePeerTagger struct { + lk sync.Mutex + protectedPeers map[peer.ID]map[string]struct{} } -func newFakePeerTagger() *fakePeerTagger { - return &fakePeerTagger{} +func (fpt *fakePeerTagger) TagPeer(p peer.ID, tag string, val int) {} +func (fpt *fakePeerTagger) UntagPeer(p peer.ID, tag string) {} + +func (fpt *fakePeerTagger) Protect(p peer.ID, tag string) { + fpt.lk.Lock() + defer fpt.lk.Unlock() + + tags, ok := fpt.protectedPeers[p] + if !ok { + tags = make(map[string]struct{}) + fpt.protectedPeers[p] = tags + } + tags[tag] = struct{}{} } -func (fpt *fakePeerTagger) TagPeer(p peer.ID, tag string, val int) { +func (fpt *fakePeerTagger) Unprotect(p peer.ID, tag string) bool { + fpt.lk.Lock() + defer fpt.lk.Unlock() + + if tags, ok := fpt.protectedPeers[p]; ok { + delete(tags, tag) + return len(tags) > 0 + } + + return false } -func (fpt *fakePeerTagger) UntagPeer(p peer.ID, tag string) { + +func (fpt *fakePeerTagger) isProtected(p peer.ID) bool { + fpt.lk.Lock() + defer fpt.lk.Unlock() + + return len(fpt.protectedPeers[p]) > 0 } type fakeProviderFinder struct { diff --git a/internal/session/sessionwantsender.go b/internal/session/sessionwantsender.go index 036a7e9107b9a24d3478cc6e994ca5d5115f1c72..95439a9bf40df7c317fd56dc5f60bb08f4bc0179 100644 --- a/internal/session/sessionwantsender.go +++ b/internal/session/sessionwantsender.go @@ -379,6 +379,11 @@ func (sws *sessionWantSender) processUpdates(updates []update) []cid.Cid { // Inform the peer tracker that this peer was the first to send // us the block sws.peerRspTrkr.receivedBlockFrom(upd.from) + + // Protect the connection to this peer so that we can ensure + // that the connection doesn't get pruned by the connection + // manager + sws.spm.ProtectConnection(upd.from) } delete(sws.peerConsecutiveDontHaves, upd.from) } diff --git a/internal/session/sessionwantsender_test.go b/internal/session/sessionwantsender_test.go index a36eb432e4998adba7c5d8576a2c02c36e742ff3..806112f55f0d09619a064c951dcc1179d4ccda49 100644 --- a/internal/session/sessionwantsender_test.go +++ b/internal/session/sessionwantsender_test.go @@ -8,6 +8,7 @@ import ( bsbpm "github.com/ipfs/go-bitswap/internal/blockpresencemanager" bspm "github.com/ipfs/go-bitswap/internal/peermanager" + bsspm "github.com/ipfs/go-bitswap/internal/sessionpeermanager" "github.com/ipfs/go-bitswap/internal/testutil" cid "github.com/ipfs/go-cid" peer "github.com/libp2p/go-libp2p-core/peer" @@ -374,6 +375,62 @@ func TestRegisterSessionWithPeerManager(t *testing.T) { } } +func TestProtectConnFirstPeerToSendWantedBlock(t *testing.T) { + cids := testutil.GenerateCids(2) + peers := testutil.GeneratePeers(3) + peerA := peers[0] + peerB := peers[1] + peerC := peers[2] + sid := uint64(1) + pm := newMockPeerManager() + fpt := newFakePeerTagger() + fpm := bsspm.New(1, fpt) + swc := newMockSessionMgr() + bpm := bsbpm.New() + onSend := func(peer.ID, []cid.Cid, []cid.Cid) {} + onPeersExhausted := func([]cid.Cid) {} + spm := newSessionWantSender(sid, pm, fpm, swc, bpm, onSend, onPeersExhausted) + defer spm.Shutdown() + + go spm.Run() + + // add cid0 + spm.Add(cids[:1]) + + // peerA: block cid0 + spm.Update(peerA, cids[:1], nil, nil) + + // Wait for processing to complete + time.Sleep(10 * time.Millisecond) + + // Expect peer A to be protected as it was first to send the block + if !fpt.isProtected(peerA) { + t.Fatal("Expected first peer to send block to have protected connection") + } + + // peerB: block cid0 + spm.Update(peerB, cids[:1], nil, nil) + + // Wait for processing to complete + time.Sleep(10 * time.Millisecond) + + // Expect peer B not to be protected as it was not first to send the block + if fpt.isProtected(peerB) { + t.Fatal("Expected peer not to be protected") + } + + // peerC: block cid1 + spm.Update(peerC, cids[1:], nil, nil) + + // Wait for processing to complete + time.Sleep(10 * time.Millisecond) + + // Expect peer C not to be protected as we didn't want the block it sent + if fpt.isProtected(peerC) { + t.Fatal("Expected peer not to be protected") + } +} + func TestPeerUnavailable(t *testing.T) { cids := testutil.GenerateCids(2) peers := testutil.GeneratePeers(2) diff --git a/internal/sessionmanager/sessionmanager_test.go b/internal/sessionmanager/sessionmanager_test.go index 3be1f9b557b2814d5f119131817ec4c331df80dd..fb8445f1e5d54136a5f0553678a09aa39e7158ff 100644 --- a/internal/sessionmanager/sessionmanager_test.go +++ b/internal/sessionmanager/sessionmanager_test.go @@ -51,12 +51,13 @@ func (fs *fakeSession) Shutdown() { type fakeSesPeerManager struct { } -func (*fakeSesPeerManager) Peers() []peer.ID { return nil } -func (*fakeSesPeerManager) PeersDiscovered() bool { return false } -func (*fakeSesPeerManager) Shutdown() {} -func (*fakeSesPeerManager) AddPeer(peer.ID) bool { return false } -func (*fakeSesPeerManager) RemovePeer(peer.ID) bool { return false } -func (*fakeSesPeerManager) HasPeers() bool { return false } +func (*fakeSesPeerManager) Peers() []peer.ID { return nil } +func (*fakeSesPeerManager) PeersDiscovered() bool { return false } +func (*fakeSesPeerManager) Shutdown() {} +func (*fakeSesPeerManager) AddPeer(peer.ID) bool { return false } +func (*fakeSesPeerManager) RemovePeer(peer.ID) bool { return false } +func (*fakeSesPeerManager) HasPeers() bool { return false } +func (*fakeSesPeerManager) ProtectConnection(peer.ID) {} type fakePeerManager struct { lk sync.Mutex diff --git a/internal/sessionpeermanager/sessionpeermanager.go b/internal/sessionpeermanager/sessionpeermanager.go index 499aa830bc562be34604666ed56306d6b23ce9e1..db46691b9442ede94e2f20ae4bd711afe8c53dab 100644 --- a/internal/sessionpeermanager/sessionpeermanager.go +++ b/internal/sessionpeermanager/sessionpeermanager.go @@ -21,6 +21,8 @@ const ( type PeerTagger interface { TagPeer(peer.ID, string, int) UntagPeer(p peer.ID, tag string) + Protect(peer.ID, string) + Unprotect(peer.ID, string) bool } // SessionPeerManager keeps track of peers for a session, and takes care of @@ -67,6 +69,18 @@ func (spm *SessionPeerManager) AddPeer(p peer.ID) bool { return true } +// Protect connection to this peer from being pruned by the connection manager +func (spm *SessionPeerManager) ProtectConnection(p peer.ID) { + spm.plk.Lock() + defer spm.plk.Unlock() + + if _, ok := spm.peers[p]; !ok { + return + } + + spm.tagger.Protect(p, spm.tag) +} + // RemovePeer removes the peer from the SessionPeerManager. // Returns true if the peer was removed, false if it did not exist. func (spm *SessionPeerManager) RemovePeer(p peer.ID) bool { @@ -79,6 +93,7 @@ func (spm *SessionPeerManager) RemovePeer(p peer.ID) bool { delete(spm.peers, p) spm.tagger.UntagPeer(p, spm.tag) + spm.tagger.Unprotect(p, spm.tag) log.Debugw("Bitswap: removed peer from session", "session", spm.id, "peer", p, "peerCount", len(spm.peers)) return true @@ -130,5 +145,6 @@ func (spm *SessionPeerManager) Shutdown() { // connections to those peers for p := range spm.peers { spm.tagger.UntagPeer(p, spm.tag) + spm.tagger.Unprotect(p, spm.tag) } } diff --git a/internal/sessionpeermanager/sessionpeermanager_test.go b/internal/sessionpeermanager/sessionpeermanager_test.go index e3c1c4ab46a24285fb2587f972696fff53826d12..746333c22772de3be01e9630fc82d6a9c66bb5b3 100644 --- a/internal/sessionpeermanager/sessionpeermanager_test.go +++ b/internal/sessionpeermanager/sessionpeermanager_test.go @@ -9,9 +9,16 @@ import ( ) type fakePeerTagger struct { - lk sync.Mutex - taggedPeers []peer.ID - wait sync.WaitGroup + lk sync.Mutex + taggedPeers []peer.ID + protectedPeers map[peer.ID]map[string]struct{} + wait sync.WaitGroup +} + +func newFakePeerTagger() *fakePeerTagger { + return &fakePeerTagger{ + protectedPeers: make(map[peer.ID]map[string]struct{}), + } } func (fpt *fakePeerTagger) TagPeer(p peer.ID, tag string, n int) { @@ -36,6 +43,40 @@ func (fpt *fakePeerTagger) UntagPeer(p peer.ID, tag string) { } } +func (fpt *fakePeerTagger) Protect(p peer.ID, tag string) { + fpt.lk.Lock() + defer fpt.lk.Unlock() + + tags, ok := fpt.protectedPeers[p] + if !ok { + tags = make(map[string]struct{}) + fpt.protectedPeers[p] = tags + } + tags[tag] = struct{}{} +} + +func (fpt *fakePeerTagger) Unprotect(p peer.ID, tag string) bool { + fpt.lk.Lock() + defer fpt.lk.Unlock() + + if tags, ok := fpt.protectedPeers[p]; ok { + delete(tags, tag) + if len(tags) == 0 { + delete(fpt.protectedPeers, p) + } + return len(tags) > 0 + } + + return false +} + +func (fpt *fakePeerTagger) isProtected(p peer.ID) bool { + fpt.lk.Lock() + defer fpt.lk.Unlock() + + return len(fpt.protectedPeers[p]) > 0 +} + func TestAddPeers(t *testing.T) { peers := testutil.GeneratePeers(2) spm := New(1, &fakePeerTagger{}) @@ -208,9 +249,35 @@ func TestPeerTagging(t *testing.T) { } } +func TestProtectConnection(t *testing.T) { + peers := testutil.GeneratePeers(1) + peerA := peers[0] + fpt := newFakePeerTagger() + spm := New(1, fpt) + + // Should not protect connection if peer hasn't been added yet + spm.ProtectConnection(peerA) + if fpt.isProtected(peerA) { + t.Fatal("Expected peer not to be protected") + } + + // Once peer is added, should be able to protect connection + spm.AddPeer(peerA) + spm.ProtectConnection(peerA) + if !fpt.isProtected(peerA) { + t.Fatal("Expected peer to be protected") + } + + // Removing peer should unprotect connection + spm.RemovePeer(peerA) + if fpt.isProtected(peerA) { + t.Fatal("Expected peer to be unprotected") + } +} + func TestShutdown(t *testing.T) { peers := testutil.GeneratePeers(2) - fpt := &fakePeerTagger{} + fpt := newFakePeerTagger() spm := New(1, fpt) spm.AddPeer(peers[0]) @@ -219,9 +286,17 @@ func TestShutdown(t *testing.T) { t.Fatal("Expected to have tagged two peers") } + spm.ProtectConnection(peers[0]) + if !fpt.isProtected(peers[0]) { + t.Fatal("Expected peer to be protected") + } + spm.Shutdown() if len(fpt.taggedPeers) != 0 { t.Fatal("Expected to have untagged all peers") } + if len(fpt.protectedPeers) != 0 { + t.Fatal("Expected to have unprotected all peers") + } }