Commit 843391e6 authored by hannahhoward

feat(ProviderQueryManager): integrate in sessions

Integrate the ProviderQueryManager into the SessionPeerManager and bitswap in general

re #52, re #49
parent 1f2b49ef
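
In outline, the change narrows what a session's peer manager depends on: instead of the full PeerNetwork it now receives a peer tagger (backed by the network's connection manager) and a provider finder (the new, shared ProviderQueryManager). A condensed sketch of the wiring, taken from the Bitswap constructor hunks below:

	// inside Bitswap's New (condensed; see the hunks below)
	pqm := bspqm.New(ctx, network)

	sessionPeerManagerFactory := func(ctx context.Context, id uint64) bssession.PeerManager {
		// sessions no longer get the whole network: only a peer tagger
		// (the network's connection manager) and a provider finder (pqm)
		return bsspm.New(ctx, id, network.ConnectionManager(), pqm)
	}

	// pqm is stored on the Bitswap struct and started alongside the
	// other managers (bs.pqm.Startup() in the startup hunk)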
@@ -18,6 +18,7 @@ import (
bsnet "github.com/ipfs/go-bitswap/network"
notifications "github.com/ipfs/go-bitswap/notifications"
bspm "github.com/ipfs/go-bitswap/peermanager"
bspqm "github.com/ipfs/go-bitswap/providerquerymanager"
bssession "github.com/ipfs/go-bitswap/session"
bssm "github.com/ipfs/go-bitswap/sessionmanager"
bsspm "github.com/ipfs/go-bitswap/sessionpeermanager"
@@ -105,11 +106,13 @@ func New(parent context.Context, network bsnet.BitSwapNetwork,
}
wm := bswm.New(ctx)
pqm := bspqm.New(ctx, network)
sessionFactory := func(ctx context.Context, id uint64, pm bssession.PeerManager, srs bssession.RequestSplitter) bssm.Session {
return bssession.New(ctx, id, wm, pm, srs)
}
sessionPeerManagerFactory := func(ctx context.Context, id uint64) bssession.PeerManager {
return bsspm.New(ctx, id, network)
return bsspm.New(ctx, id, network.ConnectionManager(), pqm)
}
sessionRequestSplitterFactory := func(ctx context.Context) bssession.RequestSplitter {
return bssrs.New(ctx)
@@ -125,6 +128,7 @@ func New(parent context.Context, network bsnet.BitSwapNetwork,
newBlocks: make(chan cid.Cid, HasBlockBufferSize),
provideKeys: make(chan cid.Cid, provideKeysBufferSize),
wm: wm,
pqm: pqm,
pm: bspm.New(ctx, peerQueueFactory),
sm: bssm.New(ctx, sessionFactory, sessionPeerManagerFactory, sessionRequestSplitterFactory),
counters: new(counters),
@@ -136,6 +140,7 @@ func New(parent context.Context, network bsnet.BitSwapNetwork,
bs.wm.SetDelegate(bs.pm)
bs.pm.Startup()
bs.wm.Startup()
bs.pqm.Startup()
network.SetDelegate(bs)
// Start up bitswaps async worker routines
@@ -161,6 +166,9 @@ type Bitswap struct {
// the wantlist tracks global wants for bitswap
wm *bswm.WantManager
// the provider query manager manages requests to find providers
pqm *bspqm.ProviderQueryManager
// the engine is the bit of logic that decides who to send which blocks to
engine *decision.Engine
@@ -8,7 +8,6 @@ import (
logging "github.com/ipfs/go-log"
cid "github.com/ipfs/go-cid"
ifconnmgr "github.com/libp2p/go-libp2p-interface-connmgr"
peer "github.com/libp2p/go-libp2p-peer"
)
@@ -19,11 +18,15 @@ const (
reservePeers = 2
)
// PeerNetwork is an interface for finding providers and managing connections
type PeerNetwork interface {
ConnectionManager() ifconnmgr.ConnManager
ConnectTo(context.Context, peer.ID) error
FindProvidersAsync(context.Context, cid.Cid, int) <-chan peer.ID
// PeerTagger is an interface for tagging peers with metadata
type PeerTagger interface {
TagPeer(peer.ID, string, int)
UntagPeer(p peer.ID, tag string)
}
// PeerProviderFinder is an interface for finding providers
type PeerProviderFinder interface {
FindProvidersAsync(context.Context, cid.Cid, uint64) <-chan peer.ID
}
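
In Bitswap proper these two interfaces are satisfied by the network's ConnManager (PeerTagger) and the ProviderQueryManager (PeerProviderFinder), as wired in the constructor hunks above; any other implementations work just as well. For illustration, a minimal in-memory PeerProviderFinder (hypothetical, similar in spirit to the test fake further down) could look like this:

	// staticProviderFinder is an illustration-only PeerProviderFinder backed by a
	// fixed peer list. It assumes the same imports used elsewhere in this diff:
	// context, cid "github.com/ipfs/go-cid", peer "github.com/libp2p/go-libp2p-peer".
	type staticProviderFinder struct {
		providers []peer.ID
	}

	func (s *staticProviderFinder) FindProvidersAsync(ctx context.Context, c cid.Cid, ses uint64) <-chan peer.ID {
		out := make(chan peer.ID)
		go func() {
			defer close(out)
			for _, p := range s.providers {
				select {
				case out <- p:
				case <-ctx.Done():
					return
				}
			}
		}()
		return out
	}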
type peerMessage interface {
@@ -33,9 +36,11 @@ type peerMessage interface {
// SessionPeerManager tracks and manages peers for a session, and provides
// the best ones to the session
type SessionPeerManager struct {
ctx context.Context
network PeerNetwork
tag string
ctx context.Context
tagger PeerTagger
providerFinder PeerProviderFinder
tag string
id uint64
peerMessages chan peerMessage
@@ -46,12 +51,14 @@ type SessionPeerManager struct {
}
// New creates a new SessionPeerManager
func New(ctx context.Context, id uint64, network PeerNetwork) *SessionPeerManager {
func New(ctx context.Context, id uint64, tagger PeerTagger, providerFinder PeerProviderFinder) *SessionPeerManager {
spm := &SessionPeerManager{
ctx: ctx,
network: network,
peerMessages: make(chan peerMessage, 16),
activePeers: make(map[peer.ID]bool),
id: id,
ctx: ctx,
tagger: tagger,
providerFinder: providerFinder,
peerMessages: make(chan peerMessage, 16),
activePeers: make(map[peer.ID]bool),
}
spm.tag = fmt.Sprint("bs-ses-", id)
@@ -101,24 +108,13 @@ func (spm *SessionPeerManager) GetOptimizedPeers() []peer.ID {
// providers for the given Cid
func (spm *SessionPeerManager) FindMorePeers(ctx context.Context, c cid.Cid) {
go func(k cid.Cid) {
// TODO: have a task queue setup for this to:
// - rate limit
// - manage timeouts
// - ensure two 'findprovs' calls for the same block don't run concurrently
// - share peers between sessions based on interest set
for p := range spm.network.FindProvidersAsync(ctx, k, 10) {
go func(p peer.ID) {
// TODO: Also use context from spm.
err := spm.network.ConnectTo(ctx, p)
if err != nil {
log.Debugf("failed to connect to provider %s: %s", p, err)
}
select {
case spm.peerMessages <- &peerFoundMessage{p}:
case <-ctx.Done():
case <-spm.ctx.Done():
}
}(p)
for p := range spm.providerFinder.FindProvidersAsync(ctx, k, spm.id) {
select {
case spm.peerMessages <- &peerFoundMessage{p}:
case <-ctx.Done():
case <-spm.ctx.Done():
}
}
}(c)
}
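
From a session's perspective the API is unchanged: it asks for more peers and later reads back the optimized set. Connecting to providers, together with the rate limiting and deduplication the removed TODO called for, now belongs to the ProviderQueryManager rather than to this goroutine. A usage sketch mirroring the tests below (variable names hypothetical; the constructor arguments match the Bitswap wiring above):

	spm := New(ctx, sessionID, network.ConnectionManager(), providerQueryManager)
	spm.FindMorePeers(ctx, blockCid)
	// providers arrive asynchronously via the finder and are tagged as they
	// are handled; the best-known peers for the session can then be read back:
	peers := spm.GetOptimizedPeers()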
@@ -136,8 +132,7 @@ func (spm *SessionPeerManager) run(ctx context.Context) {
}
func (spm *SessionPeerManager) tagPeer(p peer.ID) {
cmgr := spm.network.ConnectionManager()
cmgr.TagPeer(p, spm.tag, 10)
spm.tagger.TagPeer(p, spm.tag, 10)
}
func (spm *SessionPeerManager) insertOptimizedPeer(p peer.ID) {
@@ -223,8 +218,7 @@ func (prm *peerReqMessage) handle(spm *SessionPeerManager) {
}
func (spm *SessionPeerManager) handleShutdown() {
cmgr := spm.network.ConnectionManager()
for p := range spm.activePeers {
cmgr.UntagPeer(p, spm.tag)
spm.tagger.UntagPeer(p, spm.tag)
}
}
@@ -2,7 +2,6 @@ package sessionpeermanager
import (
"context"
"errors"
"math/rand"
"sync"
"testing"
@@ -11,35 +10,19 @@ import (
"github.com/ipfs/go-bitswap/testutil"
cid "github.com/ipfs/go-cid"
ifconnmgr "github.com/libp2p/go-libp2p-interface-connmgr"
inet "github.com/libp2p/go-libp2p-net"
peer "github.com/libp2p/go-libp2p-peer"
)
type fakePeerNetwork struct {
peers []peer.ID
connManager ifconnmgr.ConnManager
completed chan struct{}
connect chan struct{}
type fakePeerProviderFinder struct {
peers []peer.ID
completed chan struct{}
}
func (fpn *fakePeerNetwork) ConnectionManager() ifconnmgr.ConnManager {
return fpn.connManager
}
func (fpn *fakePeerNetwork) ConnectTo(ctx context.Context, p peer.ID) error {
select {
case fpn.connect <- struct{}{}:
return nil
case <-ctx.Done():
return errors.New("Timeout Occurred")
}
}
func (fpn *fakePeerNetwork) FindProvidersAsync(ctx context.Context, c cid.Cid, num int) <-chan peer.ID {
func (fppf *fakePeerProviderFinder) FindProvidersAsync(ctx context.Context, c cid.Cid, ses uint64) <-chan peer.ID {
peerCh := make(chan peer.ID)
go func() {
for _, p := range fpn.peers {
for _, p := range fppf.peers {
select {
case peerCh <- p:
case <-ctx.Done():
@@ -50,52 +33,48 @@ func (fpn *fakePeerNetwork) FindProvidersAsync(ctx context.Context, c cid.Cid, n
close(peerCh)
select {
case fpn.completed <- struct{}{}:
case fppf.completed <- struct{}{}:
case <-ctx.Done():
}
}()
return peerCh
}
type fakeConnManager struct {
type fakePeerTagger struct {
taggedPeers []peer.ID
wait sync.WaitGroup
}
func (fcm *fakeConnManager) TagPeer(p peer.ID, tag string, n int) {
fcm.wait.Add(1)
fcm.taggedPeers = append(fcm.taggedPeers, p)
func (fpt *fakePeerTagger) TagPeer(p peer.ID, tag string, n int) {
fpt.wait.Add(1)
fpt.taggedPeers = append(fpt.taggedPeers, p)
}
func (fcm *fakeConnManager) UntagPeer(p peer.ID, tag string) {
defer fcm.wait.Done()
for i := 0; i < len(fcm.taggedPeers); i++ {
if fcm.taggedPeers[i] == p {
fcm.taggedPeers[i] = fcm.taggedPeers[len(fcm.taggedPeers)-1]
fcm.taggedPeers = fcm.taggedPeers[:len(fcm.taggedPeers)-1]
func (fpt *fakePeerTagger) UntagPeer(p peer.ID, tag string) {
defer fpt.wait.Done()
for i := 0; i < len(fpt.taggedPeers); i++ {
if fpt.taggedPeers[i] == p {
fpt.taggedPeers[i] = fpt.taggedPeers[len(fpt.taggedPeers)-1]
fpt.taggedPeers = fpt.taggedPeers[:len(fpt.taggedPeers)-1]
return
}
}
}
func (*fakeConnManager) GetTagInfo(p peer.ID) *ifconnmgr.TagInfo { return nil }
func (*fakeConnManager) TrimOpenConns(ctx context.Context) {}
func (*fakeConnManager) Notifee() inet.Notifiee { return nil }
func TestFindingMorePeers(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
completed := make(chan struct{})
connect := make(chan struct{})
peers := testutil.GeneratePeers(5)
fcm := &fakeConnManager{}
fpn := &fakePeerNetwork{peers, fcm, completed, connect}
fpt := &fakePeerTagger{}
fppf := &fakePeerProviderFinder{peers, completed}
c := testutil.GenerateCids(1)[0]
id := testutil.GenerateSessionID()
sessionPeerManager := New(ctx, id, fpn)
sessionPeerManager := New(ctx, id, fpt, fppf)
findCtx, findCancel := context.WithTimeout(ctx, 10*time.Millisecond)
defer findCancel()
@@ -105,13 +84,6 @@ func TestFindingMorePeers(t *testing.T) {
case <-findCtx.Done():
t.Fatal("Did not finish finding providers")
}
for range peers {
select {
case <-connect:
case <-findCtx.Done():
t.Fatal("Did not connect to peer")
}
}
time.Sleep(2 * time.Millisecond)
sessionPeers := sessionPeerManager.GetOptimizedPeers()
@@ -123,7 +95,7 @@ func TestFindingMorePeers(t *testing.T) {
t.Fatal("incorrect peer found through finding providers")
}
}
if len(fcm.taggedPeers) != len(peers) {
if len(fpt.taggedPeers) != len(peers) {
t.Fatal("Peers were not tagged!")
}
}
@@ -133,12 +105,12 @@ func TestRecordingReceivedBlocks(t *testing.T) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
p := testutil.GeneratePeers(1)[0]
fcm := &fakeConnManager{}
fpn := &fakePeerNetwork{nil, fcm, nil, nil}
fpt := &fakePeerTagger{}
fppf := &fakePeerProviderFinder{}
c := testutil.GenerateCids(1)[0]
id := testutil.GenerateSessionID()
sessionPeerManager := New(ctx, id, fpn)
sessionPeerManager := New(ctx, id, fpt, fppf)
sessionPeerManager.RecordPeerResponse(p, c)
time.Sleep(10 * time.Millisecond)
sessionPeers := sessionPeerManager.GetOptimizedPeers()
@@ -148,7 +120,7 @@ func TestRecordingReceivedBlocks(t *testing.T) {
if sessionPeers[0] != p {
t.Fatal("incorrect peer added on receive")
}
if len(fcm.taggedPeers) != 1 {
if len(fpt.taggedPeers) != 1 {
t.Fatal("Peers was not tagged!")
}
}
@@ -159,12 +131,11 @@ func TestOrderingPeers(t *testing.T) {
defer cancel()
peers := testutil.GeneratePeers(100)
completed := make(chan struct{})
connect := make(chan struct{})
fcm := &fakeConnManager{}
fpn := &fakePeerNetwork{peers, fcm, completed, connect}
fpt := &fakePeerTagger{}
fppf := &fakePeerProviderFinder{peers, completed}
c := testutil.GenerateCids(1)
id := testutil.GenerateSessionID()
sessionPeerManager := New(ctx, id, fpn)
sessionPeerManager := New(ctx, id, fpt, fppf)
// add all peers to session
sessionPeerManager.FindMorePeers(ctx, c[0])
@@ -173,13 +144,6 @@ func TestOrderingPeers(t *testing.T) {
case <-ctx.Done():
t.Fatal("Did not finish finding providers")
}
for range peers {
select {
case <-connect:
case <-ctx.Done():
t.Fatal("Did not connect to peer")
}
}
time.Sleep(2 * time.Millisecond)
// record broadcast
@@ -237,13 +201,12 @@ func TestUntaggingPeers(t *testing.T) {
defer cancel()
peers := testutil.GeneratePeers(5)
completed := make(chan struct{})
connect := make(chan struct{})
fcm := &fakeConnManager{}
fpn := &fakePeerNetwork{peers, fcm, completed, connect}
fpt := &fakePeerTagger{}
fppf := &fakePeerProviderFinder{peers, completed}
c := testutil.GenerateCids(1)[0]
id := testutil.GenerateSessionID()
sessionPeerManager := New(ctx, id, fpn)
sessionPeerManager := New(ctx, id, fpt, fppf)
sessionPeerManager.FindMorePeers(ctx, c)
select {
@@ -251,22 +214,15 @@ func TestUntaggingPeers(t *testing.T) {
case <-ctx.Done():
t.Fatal("Did not finish finding providers")
}
for range peers {
select {
case <-connect:
case <-ctx.Done():
t.Fatal("Did not connect to peer")
}
}
time.Sleep(2 * time.Millisecond)
if len(fcm.taggedPeers) != len(peers) {
if len(fpt.taggedPeers) != len(peers) {
t.Fatal("Peers were not tagged!")
}
<-ctx.Done()
fcm.wait.Wait()
fpt.wait.Wait()
if len(fcm.taggedPeers) != 0 {
if len(fpt.taggedPeers) != 0 {
t.Fatal("Peers were not untagged!")
}
}