Unverified Commit 70bced72 authored by Steven Allen's avatar Steven Allen Committed by GitHub

Merge pull request #406 from ipfs/feat/protect-conns

If peer is first to send a block to session, protect connection
parents 1910e213 b38f4513
......@@ -65,6 +65,8 @@ type SessionPeerManager interface {
Peers() []peer.ID
// Whether there are any peers in the session
HasPeers() bool
// Protect connection from being pruned by the connection manager
ProtectConnection(peer.ID)
}
// ProviderFinder is used to find providers for a given key
......
......@@ -56,16 +56,49 @@ func newFakeSessionPeerManager() *bsspm.SessionPeerManager {
return bsspm.New(1, newFakePeerTagger())
}
func newFakePeerTagger() *fakePeerTagger {
return &fakePeerTagger{
protectedPeers: make(map[peer.ID]map[string]struct{}),
}
}
type fakePeerTagger struct {
lk sync.Mutex
protectedPeers map[peer.ID]map[string]struct{}
}
func newFakePeerTagger() *fakePeerTagger {
return &fakePeerTagger{}
func (fpt *fakePeerTagger) TagPeer(p peer.ID, tag string, val int) {}
func (fpt *fakePeerTagger) UntagPeer(p peer.ID, tag string) {}
func (fpt *fakePeerTagger) Protect(p peer.ID, tag string) {
fpt.lk.Lock()
defer fpt.lk.Unlock()
tags, ok := fpt.protectedPeers[p]
if !ok {
tags = make(map[string]struct{})
fpt.protectedPeers[p] = tags
}
tags[tag] = struct{}{}
}
func (fpt *fakePeerTagger) TagPeer(p peer.ID, tag string, val int) {
func (fpt *fakePeerTagger) Unprotect(p peer.ID, tag string) bool {
fpt.lk.Lock()
defer fpt.lk.Unlock()
if tags, ok := fpt.protectedPeers[p]; ok {
delete(tags, tag)
return len(tags) > 0
}
return false
}
func (fpt *fakePeerTagger) UntagPeer(p peer.ID, tag string) {
func (fpt *fakePeerTagger) isProtected(p peer.ID) bool {
fpt.lk.Lock()
defer fpt.lk.Unlock()
return len(fpt.protectedPeers[p]) > 0
}
type fakeProviderFinder struct {
......
......@@ -379,6 +379,11 @@ func (sws *sessionWantSender) processUpdates(updates []update) []cid.Cid {
// Inform the peer tracker that this peer was the first to send
// us the block
sws.peerRspTrkr.receivedBlockFrom(upd.from)
// Protect the connection to this peer so that we can ensure
// that the connection doesn't get pruned by the connection
// manager
sws.spm.ProtectConnection(upd.from)
}
delete(sws.peerConsecutiveDontHaves, upd.from)
}
......
......@@ -8,6 +8,7 @@ import (
bsbpm "github.com/ipfs/go-bitswap/internal/blockpresencemanager"
bspm "github.com/ipfs/go-bitswap/internal/peermanager"
bsspm "github.com/ipfs/go-bitswap/internal/sessionpeermanager"
"github.com/ipfs/go-bitswap/internal/testutil"
cid "github.com/ipfs/go-cid"
peer "github.com/libp2p/go-libp2p-core/peer"
......@@ -374,6 +375,62 @@ func TestRegisterSessionWithPeerManager(t *testing.T) {
}
}
func TestProtectConnFirstPeerToSendWantedBlock(t *testing.T) {
cids := testutil.GenerateCids(2)
peers := testutil.GeneratePeers(3)
peerA := peers[0]
peerB := peers[1]
peerC := peers[2]
sid := uint64(1)
pm := newMockPeerManager()
fpt := newFakePeerTagger()
fpm := bsspm.New(1, fpt)
swc := newMockSessionMgr()
bpm := bsbpm.New()
onSend := func(peer.ID, []cid.Cid, []cid.Cid) {}
onPeersExhausted := func([]cid.Cid) {}
spm := newSessionWantSender(sid, pm, fpm, swc, bpm, onSend, onPeersExhausted)
defer spm.Shutdown()
go spm.Run()
// add cid0
spm.Add(cids[:1])
// peerA: block cid0
spm.Update(peerA, cids[:1], nil, nil)
// Wait for processing to complete
time.Sleep(10 * time.Millisecond)
// Expect peer A to be protected as it was first to send the block
if !fpt.isProtected(peerA) {
t.Fatal("Expected first peer to send block to have protected connection")
}
// peerB: block cid0
spm.Update(peerB, cids[:1], nil, nil)
// Wait for processing to complete
time.Sleep(10 * time.Millisecond)
// Expect peer B not to be protected as it was not first to send the block
if fpt.isProtected(peerB) {
t.Fatal("Expected peer not to be protected")
}
// peerC: block cid1
spm.Update(peerC, cids[1:], nil, nil)
// Wait for processing to complete
time.Sleep(10 * time.Millisecond)
// Expect peer C not to be protected as we didn't want the block it sent
if fpt.isProtected(peerC) {
t.Fatal("Expected peer not to be protected")
}
}
func TestPeerUnavailable(t *testing.T) {
cids := testutil.GenerateCids(2)
peers := testutil.GeneratePeers(2)
......
......@@ -51,12 +51,13 @@ func (fs *fakeSession) Shutdown() {
type fakeSesPeerManager struct {
}
func (*fakeSesPeerManager) Peers() []peer.ID { return nil }
func (*fakeSesPeerManager) PeersDiscovered() bool { return false }
func (*fakeSesPeerManager) Shutdown() {}
func (*fakeSesPeerManager) AddPeer(peer.ID) bool { return false }
func (*fakeSesPeerManager) RemovePeer(peer.ID) bool { return false }
func (*fakeSesPeerManager) HasPeers() bool { return false }
func (*fakeSesPeerManager) Peers() []peer.ID { return nil }
func (*fakeSesPeerManager) PeersDiscovered() bool { return false }
func (*fakeSesPeerManager) Shutdown() {}
func (*fakeSesPeerManager) AddPeer(peer.ID) bool { return false }
func (*fakeSesPeerManager) RemovePeer(peer.ID) bool { return false }
func (*fakeSesPeerManager) HasPeers() bool { return false }
func (*fakeSesPeerManager) ProtectConnection(peer.ID) {}
type fakePeerManager struct {
lk sync.Mutex
......
......@@ -21,6 +21,8 @@ const (
type PeerTagger interface {
TagPeer(peer.ID, string, int)
UntagPeer(p peer.ID, tag string)
Protect(peer.ID, string)
Unprotect(peer.ID, string) bool
}
// SessionPeerManager keeps track of peers for a session, and takes care of
......@@ -67,6 +69,18 @@ func (spm *SessionPeerManager) AddPeer(p peer.ID) bool {
return true
}
// Protect connection to this peer from being pruned by the connection manager
func (spm *SessionPeerManager) ProtectConnection(p peer.ID) {
spm.plk.Lock()
defer spm.plk.Unlock()
if _, ok := spm.peers[p]; !ok {
return
}
spm.tagger.Protect(p, spm.tag)
}
// RemovePeer removes the peer from the SessionPeerManager.
// Returns true if the peer was removed, false if it did not exist.
func (spm *SessionPeerManager) RemovePeer(p peer.ID) bool {
......@@ -79,6 +93,7 @@ func (spm *SessionPeerManager) RemovePeer(p peer.ID) bool {
delete(spm.peers, p)
spm.tagger.UntagPeer(p, spm.tag)
spm.tagger.Unprotect(p, spm.tag)
log.Debugw("Bitswap: removed peer from session", "session", spm.id, "peer", p, "peerCount", len(spm.peers))
return true
......@@ -130,5 +145,6 @@ func (spm *SessionPeerManager) Shutdown() {
// connections to those peers
for p := range spm.peers {
spm.tagger.UntagPeer(p, spm.tag)
spm.tagger.Unprotect(p, spm.tag)
}
}
......@@ -9,9 +9,16 @@ import (
)
type fakePeerTagger struct {
lk sync.Mutex
taggedPeers []peer.ID
wait sync.WaitGroup
lk sync.Mutex
taggedPeers []peer.ID
protectedPeers map[peer.ID]map[string]struct{}
wait sync.WaitGroup
}
func newFakePeerTagger() *fakePeerTagger {
return &fakePeerTagger{
protectedPeers: make(map[peer.ID]map[string]struct{}),
}
}
func (fpt *fakePeerTagger) TagPeer(p peer.ID, tag string, n int) {
......@@ -36,6 +43,40 @@ func (fpt *fakePeerTagger) UntagPeer(p peer.ID, tag string) {
}
}
func (fpt *fakePeerTagger) Protect(p peer.ID, tag string) {
fpt.lk.Lock()
defer fpt.lk.Unlock()
tags, ok := fpt.protectedPeers[p]
if !ok {
tags = make(map[string]struct{})
fpt.protectedPeers[p] = tags
}
tags[tag] = struct{}{}
}
func (fpt *fakePeerTagger) Unprotect(p peer.ID, tag string) bool {
fpt.lk.Lock()
defer fpt.lk.Unlock()
if tags, ok := fpt.protectedPeers[p]; ok {
delete(tags, tag)
if len(tags) == 0 {
delete(fpt.protectedPeers, p)
}
return len(tags) > 0
}
return false
}
func (fpt *fakePeerTagger) isProtected(p peer.ID) bool {
fpt.lk.Lock()
defer fpt.lk.Unlock()
return len(fpt.protectedPeers[p]) > 0
}
func TestAddPeers(t *testing.T) {
peers := testutil.GeneratePeers(2)
spm := New(1, &fakePeerTagger{})
......@@ -208,9 +249,35 @@ func TestPeerTagging(t *testing.T) {
}
}
func TestProtectConnection(t *testing.T) {
peers := testutil.GeneratePeers(1)
peerA := peers[0]
fpt := newFakePeerTagger()
spm := New(1, fpt)
// Should not protect connection if peer hasn't been added yet
spm.ProtectConnection(peerA)
if fpt.isProtected(peerA) {
t.Fatal("Expected peer not to be protected")
}
// Once peer is added, should be able to protect connection
spm.AddPeer(peerA)
spm.ProtectConnection(peerA)
if !fpt.isProtected(peerA) {
t.Fatal("Expected peer to be protected")
}
// Removing peer should unprotect connection
spm.RemovePeer(peerA)
if fpt.isProtected(peerA) {
t.Fatal("Expected peer to be unprotected")
}
}
func TestShutdown(t *testing.T) {
peers := testutil.GeneratePeers(2)
fpt := &fakePeerTagger{}
fpt := newFakePeerTagger()
spm := New(1, fpt)
spm.AddPeer(peers[0])
......@@ -219,9 +286,17 @@ func TestShutdown(t *testing.T) {
t.Fatal("Expected to have tagged two peers")
}
spm.ProtectConnection(peers[0])
if !fpt.isProtected(peers[0]) {
t.Fatal("Expected peer to be protected")
}
spm.Shutdown()
if len(fpt.taggedPeers) != 0 {
t.Fatal("Expected to have untagged all peers")
}
if len(fpt.protectedPeers) != 0 {
t.Fatal("Expected to have unprotected all peers")
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment