From ba0f59c33ca033cb497b0a5837ada652f84c9e31 Mon Sep 17 00:00:00 2001 From: Dirk McCormick Date: Wed, 3 Jun 2020 16:00:15 -0400 Subject: [PATCH] feat: protect connection for session peers that are first to send block --- internal/session/session.go | 2 + internal/session/session_test.go | 36 ++++++++-- internal/session/sessionwantsender.go | 5 ++ internal/session/sessionwantsender_test.go | 59 +++++++++++++++++ .../sessionmanager/sessionmanager_test.go | 13 ++-- .../sessionpeermanager/sessionpeermanager.go | 15 +++++ .../sessionpeermanager_test.go | 66 ++++++++++++++++++- 7 files changed, 182 insertions(+), 14 deletions(-) diff --git a/internal/session/session.go b/internal/session/session.go index 7a0d23b..7b2953f 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -65,6 +65,8 @@ type SessionPeerManager interface { Peers() []peer.ID // Whether there are any peers in the session HasPeers() bool + // Protect connection from being pruned by the connection manager + ProtectConnection(peer.ID) } // ProviderFinder is used to find providers for a given key diff --git a/internal/session/session_test.go b/internal/session/session_test.go index 028ee46..e553bb8 100644 --- a/internal/session/session_test.go +++ b/internal/session/session_test.go @@ -56,16 +56,42 @@ func newFakeSessionPeerManager() *bsspm.SessionPeerManager { return bsspm.New(1, newFakePeerTagger()) } -type fakePeerTagger struct { +func newFakePeerTagger() *fakePeerTagger { + return &fakePeerTagger{ + protectedPeers: make(map[peer.ID]map[string]struct{}), + } } -func newFakePeerTagger() *fakePeerTagger { - return &fakePeerTagger{} +type fakePeerTagger struct { + lk sync.Mutex + protectedPeers map[peer.ID]map[string]struct{} } -func (fpt *fakePeerTagger) TagPeer(p peer.ID, tag string, val int) { +func (fpt *fakePeerTagger) TagPeer(p peer.ID, tag string, val int) {} +func (fpt *fakePeerTagger) UntagPeer(p peer.ID, tag string) {} + +func (fpt *fakePeerTagger) Protect(p peer.ID, tag string) { + fpt.lk.Lock() + defer fpt.lk.Unlock() + + tags, ok := fpt.protectedPeers[p] + if !ok { + tags = make(map[string]struct{}) + fpt.protectedPeers[p] = tags + } + tags[tag] = struct{}{} } -func (fpt *fakePeerTagger) UntagPeer(p peer.ID, tag string) { + +func (fpt *fakePeerTagger) Unprotect(p peer.ID, tag string) bool { + fpt.lk.Lock() + defer fpt.lk.Unlock() + + if tags, ok := fpt.protectedPeers[p]; ok { + delete(tags, tag) + return len(tags) > 0 + } + + return false } type fakeProviderFinder struct { diff --git a/internal/session/sessionwantsender.go b/internal/session/sessionwantsender.go index 036a7e9..95439a9 100644 --- a/internal/session/sessionwantsender.go +++ b/internal/session/sessionwantsender.go @@ -379,6 +379,11 @@ func (sws *sessionWantSender) processUpdates(updates []update) []cid.Cid { // Inform the peer tracker that this peer was the first to send // us the block sws.peerRspTrkr.receivedBlockFrom(upd.from) + + // Protect the connection to this peer so that we can ensure + // that the connection doesn't get pruned by the connection + // manager + sws.spm.ProtectConnection(upd.from) } delete(sws.peerConsecutiveDontHaves, upd.from) } diff --git a/internal/session/sessionwantsender_test.go b/internal/session/sessionwantsender_test.go index a36eb43..de73c56 100644 --- a/internal/session/sessionwantsender_test.go +++ b/internal/session/sessionwantsender_test.go @@ -2,12 +2,14 @@ package session import ( "context" + "fmt" "sync" "testing" "time" bsbpm "github.com/ipfs/go-bitswap/internal/blockpresencemanager" bspm "github.com/ipfs/go-bitswap/internal/peermanager" + bsspm "github.com/ipfs/go-bitswap/internal/sessionpeermanager" "github.com/ipfs/go-bitswap/internal/testutil" cid "github.com/ipfs/go-cid" peer "github.com/libp2p/go-libp2p-core/peer" @@ -374,6 +376,63 @@ func TestRegisterSessionWithPeerManager(t *testing.T) { } } +func TestProtectConnFirstPeerToSendWantedBlock(t *testing.T) { + cids := testutil.GenerateCids(2) + peers := testutil.GeneratePeers(3) + peerA := peers[0] + peerB := peers[1] + peerC := peers[2] + sid := uint64(1) + sidStr := fmt.Sprintf("%d", sid) + pm := newMockPeerManager() + fpt := newFakePeerTagger() + fpm := bsspm.New(1, fpt) + swc := newMockSessionMgr() + bpm := bsbpm.New() + onSend := func(peer.ID, []cid.Cid, []cid.Cid) {} + onPeersExhausted := func([]cid.Cid) {} + spm := newSessionWantSender(sid, pm, fpm, swc, bpm, onSend, onPeersExhausted) + defer spm.Shutdown() + + go spm.Run() + + // add cid0 + spm.Add(cids[:1]) + + // peerA: block cid0 + spm.Update(peerA, cids[:1], nil, nil) + + // Wait for processing to complete + time.Sleep(10 * time.Millisecond) + + // Expect peer A to be protected as it was first to send the block + if _, ok := fpt.protectedPeers[peerA][sidStr]; !ok { + t.Fatal("Expected first peer to send block to have protected connection") + } + + // peerB: block cid0 + spm.Update(peerB, cids[:1], nil, nil) + + // Wait for processing to complete + time.Sleep(10 * time.Millisecond) + + // Expect peer B not to be protected as it was not first to send the block + if _, ok := fpt.protectedPeers[peerB][sidStr]; ok { + t.Fatal("Expected peer not to be protected") + } + + // peerC: block cid1 + spm.Update(peerC, cids[1:], nil, nil) + + // Wait for processing to complete + time.Sleep(10 * time.Millisecond) + + // Expect peer C not to be protected as we didn't want the block it sent + if _, ok := fpt.protectedPeers[peerC][sidStr]; ok { + t.Fatal("Expected peer not to be protected") + } +} + func TestPeerUnavailable(t *testing.T) { cids := testutil.GenerateCids(2) peers := testutil.GeneratePeers(2) diff --git a/internal/sessionmanager/sessionmanager_test.go b/internal/sessionmanager/sessionmanager_test.go index 3be1f9b..fb8445f 100644 --- a/internal/sessionmanager/sessionmanager_test.go +++ b/internal/sessionmanager/sessionmanager_test.go @@ -51,12 +51,13 @@ func (fs *fakeSession) Shutdown() { type fakeSesPeerManager struct { } -func (*fakeSesPeerManager) Peers() []peer.ID { return nil } -func (*fakeSesPeerManager) PeersDiscovered() bool { return false } -func (*fakeSesPeerManager) Shutdown() {} -func (*fakeSesPeerManager) AddPeer(peer.ID) bool { return false } -func (*fakeSesPeerManager) RemovePeer(peer.ID) bool { return false } -func (*fakeSesPeerManager) HasPeers() bool { return false } +func (*fakeSesPeerManager) Peers() []peer.ID { return nil } +func (*fakeSesPeerManager) PeersDiscovered() bool { return false } +func (*fakeSesPeerManager) Shutdown() {} +func (*fakeSesPeerManager) AddPeer(peer.ID) bool { return false } +func (*fakeSesPeerManager) RemovePeer(peer.ID) bool { return false } +func (*fakeSesPeerManager) HasPeers() bool { return false } +func (*fakeSesPeerManager) ProtectConnection(peer.ID) {} type fakePeerManager struct { lk sync.Mutex diff --git a/internal/sessionpeermanager/sessionpeermanager.go b/internal/sessionpeermanager/sessionpeermanager.go index 499aa83..1ad144d 100644 --- a/internal/sessionpeermanager/sessionpeermanager.go +++ b/internal/sessionpeermanager/sessionpeermanager.go @@ -21,6 +21,8 @@ const ( type PeerTagger interface { TagPeer(peer.ID, string, int) UntagPeer(p peer.ID, tag string) + Protect(peer.ID, string) + Unprotect(peer.ID, string) bool } // SessionPeerManager keeps track of peers for a session, and takes care of @@ -67,6 +69,18 @@ func (spm *SessionPeerManager) AddPeer(p peer.ID) bool { return true } +// Protect connection to this peer from being pruned by the connection manager +func (spm *SessionPeerManager) ProtectConnection(p peer.ID) { + spm.plk.Lock() + defer spm.plk.Unlock() + + if _, ok := spm.peers[p]; !ok { + return + } + + spm.tagger.Protect(p, fmt.Sprintf("%d", spm.id)) +} + // RemovePeer removes the peer from the SessionPeerManager. // Returns true if the peer was removed, false if it did not exist. func (spm *SessionPeerManager) RemovePeer(p peer.ID) bool { @@ -79,6 +93,7 @@ func (spm *SessionPeerManager) RemovePeer(p peer.ID) bool { delete(spm.peers, p) spm.tagger.UntagPeer(p, spm.tag) + spm.tagger.Unprotect(p, fmt.Sprintf("%d", spm.id)) log.Debugw("Bitswap: removed peer from session", "session", spm.id, "peer", p, "peerCount", len(spm.peers)) return true diff --git a/internal/sessionpeermanager/sessionpeermanager_test.go b/internal/sessionpeermanager/sessionpeermanager_test.go index e3c1c4a..ba3a342 100644 --- a/internal/sessionpeermanager/sessionpeermanager_test.go +++ b/internal/sessionpeermanager/sessionpeermanager_test.go @@ -1,6 +1,7 @@ package sessionpeermanager import ( + "fmt" "sync" "testing" @@ -9,9 +10,16 @@ import ( ) type fakePeerTagger struct { - lk sync.Mutex - taggedPeers []peer.ID - wait sync.WaitGroup + lk sync.Mutex + taggedPeers []peer.ID + protectedPeers map[peer.ID]map[string]struct{} + wait sync.WaitGroup +} + +func newFakePeerTagger() *fakePeerTagger { + return &fakePeerTagger{ + protectedPeers: make(map[peer.ID]map[string]struct{}), + } } func (fpt *fakePeerTagger) TagPeer(p peer.ID, tag string, n int) { @@ -36,6 +44,30 @@ func (fpt *fakePeerTagger) UntagPeer(p peer.ID, tag string) { } } +func (fpt *fakePeerTagger) Protect(p peer.ID, tag string) { + fpt.lk.Lock() + defer fpt.lk.Unlock() + + tags, ok := fpt.protectedPeers[p] + if !ok { + tags = make(map[string]struct{}) + fpt.protectedPeers[p] = tags + } + tags[tag] = struct{}{} +} + +func (fpt *fakePeerTagger) Unprotect(p peer.ID, tag string) bool { + fpt.lk.Lock() + defer fpt.lk.Unlock() + + if tags, ok := fpt.protectedPeers[p]; ok { + delete(tags, tag) + return len(tags) > 0 + } + + return false +} + func TestAddPeers(t *testing.T) { peers := testutil.GeneratePeers(2) spm := New(1, &fakePeerTagger{}) @@ -208,6 +240,34 @@ func TestPeerTagging(t *testing.T) { } } +func TestProtectConnection(t *testing.T) { + peers := testutil.GeneratePeers(1) + peerA := peers[0] + fpt := newFakePeerTagger() + sid := 1 + sidstr := fmt.Sprintf("%d", sid) + spm := New(1, fpt) + + // Should not protect connection if peer hasn't been added yet + spm.ProtectConnection(peerA) + if _, ok := fpt.protectedPeers[peerA][sidstr]; ok { + t.Fatal("Expected peer not to be protected") + } + + // Once peer is added, should be able to protect connection + spm.AddPeer(peerA) + spm.ProtectConnection(peerA) + if _, ok := fpt.protectedPeers[peerA][sidstr]; !ok { + t.Fatal("Expected peer to be protected") + } + + // Removing peer should unprotect connection + spm.RemovePeer(peerA) + if _, ok := fpt.protectedPeers[peerA][sidstr]; ok { + t.Fatal("Expected peer to be unprotected") + } +} + func TestShutdown(t *testing.T) { peers := testutil.GeneratePeers(2) fpt := &fakePeerTagger{} -- GitLab