Unverified Commit 70bced72 authored by Steven Allen's avatar Steven Allen Committed by GitHub

Merge pull request #406 from ipfs/feat/protect-conns

If peer is first to send a block to session, protect connection
parents 1910e213 b38f4513
...@@ -65,6 +65,8 @@ type SessionPeerManager interface { ...@@ -65,6 +65,8 @@ type SessionPeerManager interface {
Peers() []peer.ID Peers() []peer.ID
// Whether there are any peers in the session // Whether there are any peers in the session
HasPeers() bool HasPeers() bool
// Protect connection from being pruned by the connection manager
ProtectConnection(peer.ID)
} }
// ProviderFinder is used to find providers for a given key // ProviderFinder is used to find providers for a given key
......
...@@ -56,16 +56,49 @@ func newFakeSessionPeerManager() *bsspm.SessionPeerManager { ...@@ -56,16 +56,49 @@ func newFakeSessionPeerManager() *bsspm.SessionPeerManager {
return bsspm.New(1, newFakePeerTagger()) return bsspm.New(1, newFakePeerTagger())
} }
func newFakePeerTagger() *fakePeerTagger {
return &fakePeerTagger{
protectedPeers: make(map[peer.ID]map[string]struct{}),
}
}
type fakePeerTagger struct { type fakePeerTagger struct {
lk sync.Mutex
protectedPeers map[peer.ID]map[string]struct{}
} }
func newFakePeerTagger() *fakePeerTagger { func (fpt *fakePeerTagger) TagPeer(p peer.ID, tag string, val int) {}
return &fakePeerTagger{} func (fpt *fakePeerTagger) UntagPeer(p peer.ID, tag string) {}
func (fpt *fakePeerTagger) Protect(p peer.ID, tag string) {
fpt.lk.Lock()
defer fpt.lk.Unlock()
tags, ok := fpt.protectedPeers[p]
if !ok {
tags = make(map[string]struct{})
fpt.protectedPeers[p] = tags
}
tags[tag] = struct{}{}
} }
func (fpt *fakePeerTagger) TagPeer(p peer.ID, tag string, val int) { func (fpt *fakePeerTagger) Unprotect(p peer.ID, tag string) bool {
fpt.lk.Lock()
defer fpt.lk.Unlock()
if tags, ok := fpt.protectedPeers[p]; ok {
delete(tags, tag)
return len(tags) > 0
}
return false
} }
func (fpt *fakePeerTagger) UntagPeer(p peer.ID, tag string) {
func (fpt *fakePeerTagger) isProtected(p peer.ID) bool {
fpt.lk.Lock()
defer fpt.lk.Unlock()
return len(fpt.protectedPeers[p]) > 0
} }
type fakeProviderFinder struct { type fakeProviderFinder struct {
......
...@@ -379,6 +379,11 @@ func (sws *sessionWantSender) processUpdates(updates []update) []cid.Cid { ...@@ -379,6 +379,11 @@ func (sws *sessionWantSender) processUpdates(updates []update) []cid.Cid {
// Inform the peer tracker that this peer was the first to send // Inform the peer tracker that this peer was the first to send
// us the block // us the block
sws.peerRspTrkr.receivedBlockFrom(upd.from) sws.peerRspTrkr.receivedBlockFrom(upd.from)
// Protect the connection to this peer so that we can ensure
// that the connection doesn't get pruned by the connection
// manager
sws.spm.ProtectConnection(upd.from)
} }
delete(sws.peerConsecutiveDontHaves, upd.from) delete(sws.peerConsecutiveDontHaves, upd.from)
} }
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
bsbpm "github.com/ipfs/go-bitswap/internal/blockpresencemanager" bsbpm "github.com/ipfs/go-bitswap/internal/blockpresencemanager"
bspm "github.com/ipfs/go-bitswap/internal/peermanager" bspm "github.com/ipfs/go-bitswap/internal/peermanager"
bsspm "github.com/ipfs/go-bitswap/internal/sessionpeermanager"
"github.com/ipfs/go-bitswap/internal/testutil" "github.com/ipfs/go-bitswap/internal/testutil"
cid "github.com/ipfs/go-cid" cid "github.com/ipfs/go-cid"
peer "github.com/libp2p/go-libp2p-core/peer" peer "github.com/libp2p/go-libp2p-core/peer"
...@@ -374,6 +375,62 @@ func TestRegisterSessionWithPeerManager(t *testing.T) { ...@@ -374,6 +375,62 @@ func TestRegisterSessionWithPeerManager(t *testing.T) {
} }
} }
func TestProtectConnFirstPeerToSendWantedBlock(t *testing.T) {
cids := testutil.GenerateCids(2)
peers := testutil.GeneratePeers(3)
peerA := peers[0]
peerB := peers[1]
peerC := peers[2]
sid := uint64(1)
pm := newMockPeerManager()
fpt := newFakePeerTagger()
fpm := bsspm.New(1, fpt)
swc := newMockSessionMgr()
bpm := bsbpm.New()
onSend := func(peer.ID, []cid.Cid, []cid.Cid) {}
onPeersExhausted := func([]cid.Cid) {}
spm := newSessionWantSender(sid, pm, fpm, swc, bpm, onSend, onPeersExhausted)
defer spm.Shutdown()
go spm.Run()
// add cid0
spm.Add(cids[:1])
// peerA: block cid0
spm.Update(peerA, cids[:1], nil, nil)
// Wait for processing to complete
time.Sleep(10 * time.Millisecond)
// Expect peer A to be protected as it was first to send the block
if !fpt.isProtected(peerA) {
t.Fatal("Expected first peer to send block to have protected connection")
}
// peerB: block cid0
spm.Update(peerB, cids[:1], nil, nil)
// Wait for processing to complete
time.Sleep(10 * time.Millisecond)
// Expect peer B not to be protected as it was not first to send the block
if fpt.isProtected(peerB) {
t.Fatal("Expected peer not to be protected")
}
// peerC: block cid1
spm.Update(peerC, cids[1:], nil, nil)
// Wait for processing to complete
time.Sleep(10 * time.Millisecond)
// Expect peer C not to be protected as we didn't want the block it sent
if fpt.isProtected(peerC) {
t.Fatal("Expected peer not to be protected")
}
}
func TestPeerUnavailable(t *testing.T) { func TestPeerUnavailable(t *testing.T) {
cids := testutil.GenerateCids(2) cids := testutil.GenerateCids(2)
peers := testutil.GeneratePeers(2) peers := testutil.GeneratePeers(2)
......
...@@ -51,12 +51,13 @@ func (fs *fakeSession) Shutdown() { ...@@ -51,12 +51,13 @@ func (fs *fakeSession) Shutdown() {
type fakeSesPeerManager struct { type fakeSesPeerManager struct {
} }
func (*fakeSesPeerManager) Peers() []peer.ID { return nil } func (*fakeSesPeerManager) Peers() []peer.ID { return nil }
func (*fakeSesPeerManager) PeersDiscovered() bool { return false } func (*fakeSesPeerManager) PeersDiscovered() bool { return false }
func (*fakeSesPeerManager) Shutdown() {} func (*fakeSesPeerManager) Shutdown() {}
func (*fakeSesPeerManager) AddPeer(peer.ID) bool { return false } func (*fakeSesPeerManager) AddPeer(peer.ID) bool { return false }
func (*fakeSesPeerManager) RemovePeer(peer.ID) bool { return false } func (*fakeSesPeerManager) RemovePeer(peer.ID) bool { return false }
func (*fakeSesPeerManager) HasPeers() bool { return false } func (*fakeSesPeerManager) HasPeers() bool { return false }
func (*fakeSesPeerManager) ProtectConnection(peer.ID) {}
type fakePeerManager struct { type fakePeerManager struct {
lk sync.Mutex lk sync.Mutex
......
...@@ -21,6 +21,8 @@ const ( ...@@ -21,6 +21,8 @@ const (
type PeerTagger interface { type PeerTagger interface {
TagPeer(peer.ID, string, int) TagPeer(peer.ID, string, int)
UntagPeer(p peer.ID, tag string) UntagPeer(p peer.ID, tag string)
Protect(peer.ID, string)
Unprotect(peer.ID, string) bool
} }
// SessionPeerManager keeps track of peers for a session, and takes care of // SessionPeerManager keeps track of peers for a session, and takes care of
...@@ -67,6 +69,18 @@ func (spm *SessionPeerManager) AddPeer(p peer.ID) bool { ...@@ -67,6 +69,18 @@ func (spm *SessionPeerManager) AddPeer(p peer.ID) bool {
return true return true
} }
// Protect connection to this peer from being pruned by the connection manager
func (spm *SessionPeerManager) ProtectConnection(p peer.ID) {
spm.plk.Lock()
defer spm.plk.Unlock()
if _, ok := spm.peers[p]; !ok {
return
}
spm.tagger.Protect(p, spm.tag)
}
// RemovePeer removes the peer from the SessionPeerManager. // RemovePeer removes the peer from the SessionPeerManager.
// Returns true if the peer was removed, false if it did not exist. // Returns true if the peer was removed, false if it did not exist.
func (spm *SessionPeerManager) RemovePeer(p peer.ID) bool { func (spm *SessionPeerManager) RemovePeer(p peer.ID) bool {
...@@ -79,6 +93,7 @@ func (spm *SessionPeerManager) RemovePeer(p peer.ID) bool { ...@@ -79,6 +93,7 @@ func (spm *SessionPeerManager) RemovePeer(p peer.ID) bool {
delete(spm.peers, p) delete(spm.peers, p)
spm.tagger.UntagPeer(p, spm.tag) spm.tagger.UntagPeer(p, spm.tag)
spm.tagger.Unprotect(p, spm.tag)
log.Debugw("Bitswap: removed peer from session", "session", spm.id, "peer", p, "peerCount", len(spm.peers)) log.Debugw("Bitswap: removed peer from session", "session", spm.id, "peer", p, "peerCount", len(spm.peers))
return true return true
...@@ -130,5 +145,6 @@ func (spm *SessionPeerManager) Shutdown() { ...@@ -130,5 +145,6 @@ func (spm *SessionPeerManager) Shutdown() {
// connections to those peers // connections to those peers
for p := range spm.peers { for p := range spm.peers {
spm.tagger.UntagPeer(p, spm.tag) spm.tagger.UntagPeer(p, spm.tag)
spm.tagger.Unprotect(p, spm.tag)
} }
} }
...@@ -9,9 +9,16 @@ import ( ...@@ -9,9 +9,16 @@ import (
) )
type fakePeerTagger struct { type fakePeerTagger struct {
lk sync.Mutex lk sync.Mutex
taggedPeers []peer.ID taggedPeers []peer.ID
wait sync.WaitGroup protectedPeers map[peer.ID]map[string]struct{}
wait sync.WaitGroup
}
func newFakePeerTagger() *fakePeerTagger {
return &fakePeerTagger{
protectedPeers: make(map[peer.ID]map[string]struct{}),
}
} }
func (fpt *fakePeerTagger) TagPeer(p peer.ID, tag string, n int) { func (fpt *fakePeerTagger) TagPeer(p peer.ID, tag string, n int) {
...@@ -36,6 +43,40 @@ func (fpt *fakePeerTagger) UntagPeer(p peer.ID, tag string) { ...@@ -36,6 +43,40 @@ func (fpt *fakePeerTagger) UntagPeer(p peer.ID, tag string) {
} }
} }
func (fpt *fakePeerTagger) Protect(p peer.ID, tag string) {
fpt.lk.Lock()
defer fpt.lk.Unlock()
tags, ok := fpt.protectedPeers[p]
if !ok {
tags = make(map[string]struct{})
fpt.protectedPeers[p] = tags
}
tags[tag] = struct{}{}
}
func (fpt *fakePeerTagger) Unprotect(p peer.ID, tag string) bool {
fpt.lk.Lock()
defer fpt.lk.Unlock()
if tags, ok := fpt.protectedPeers[p]; ok {
delete(tags, tag)
if len(tags) == 0 {
delete(fpt.protectedPeers, p)
}
return len(tags) > 0
}
return false
}
func (fpt *fakePeerTagger) isProtected(p peer.ID) bool {
fpt.lk.Lock()
defer fpt.lk.Unlock()
return len(fpt.protectedPeers[p]) > 0
}
func TestAddPeers(t *testing.T) { func TestAddPeers(t *testing.T) {
peers := testutil.GeneratePeers(2) peers := testutil.GeneratePeers(2)
spm := New(1, &fakePeerTagger{}) spm := New(1, &fakePeerTagger{})
...@@ -208,9 +249,35 @@ func TestPeerTagging(t *testing.T) { ...@@ -208,9 +249,35 @@ func TestPeerTagging(t *testing.T) {
} }
} }
func TestProtectConnection(t *testing.T) {
peers := testutil.GeneratePeers(1)
peerA := peers[0]
fpt := newFakePeerTagger()
spm := New(1, fpt)
// Should not protect connection if peer hasn't been added yet
spm.ProtectConnection(peerA)
if fpt.isProtected(peerA) {
t.Fatal("Expected peer not to be protected")
}
// Once peer is added, should be able to protect connection
spm.AddPeer(peerA)
spm.ProtectConnection(peerA)
if !fpt.isProtected(peerA) {
t.Fatal("Expected peer to be protected")
}
// Removing peer should unprotect connection
spm.RemovePeer(peerA)
if fpt.isProtected(peerA) {
t.Fatal("Expected peer to be unprotected")
}
}
func TestShutdown(t *testing.T) { func TestShutdown(t *testing.T) {
peers := testutil.GeneratePeers(2) peers := testutil.GeneratePeers(2)
fpt := &fakePeerTagger{} fpt := newFakePeerTagger()
spm := New(1, fpt) spm := New(1, fpt)
spm.AddPeer(peers[0]) spm.AddPeer(peers[0])
...@@ -219,9 +286,17 @@ func TestShutdown(t *testing.T) { ...@@ -219,9 +286,17 @@ func TestShutdown(t *testing.T) {
t.Fatal("Expected to have tagged two peers") t.Fatal("Expected to have tagged two peers")
} }
spm.ProtectConnection(peers[0])
if !fpt.isProtected(peers[0]) {
t.Fatal("Expected peer to be protected")
}
spm.Shutdown() spm.Shutdown()
if len(fpt.taggedPeers) != 0 { if len(fpt.taggedPeers) != 0 {
t.Fatal("Expected to have untagged all peers") t.Fatal("Expected to have untagged all peers")
} }
if len(fpt.protectedPeers) != 0 {
t.Fatal("Expected to have unprotected all peers")
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment