Commit 843391e6 authored by hannahhoward's avatar hannahhoward

feat(ProviderQueryManager): integrate in sessions

Integrate the ProviderQueryManager into the SessionPeerManager and bitswap in general

re #52, re #49
parent 1f2b49ef
...@@ -18,6 +18,7 @@ import ( ...@@ -18,6 +18,7 @@ import (
bsnet "github.com/ipfs/go-bitswap/network" bsnet "github.com/ipfs/go-bitswap/network"
notifications "github.com/ipfs/go-bitswap/notifications" notifications "github.com/ipfs/go-bitswap/notifications"
bspm "github.com/ipfs/go-bitswap/peermanager" bspm "github.com/ipfs/go-bitswap/peermanager"
bspqm "github.com/ipfs/go-bitswap/providerquerymanager"
bssession "github.com/ipfs/go-bitswap/session" bssession "github.com/ipfs/go-bitswap/session"
bssm "github.com/ipfs/go-bitswap/sessionmanager" bssm "github.com/ipfs/go-bitswap/sessionmanager"
bsspm "github.com/ipfs/go-bitswap/sessionpeermanager" bsspm "github.com/ipfs/go-bitswap/sessionpeermanager"
...@@ -105,11 +106,13 @@ func New(parent context.Context, network bsnet.BitSwapNetwork, ...@@ -105,11 +106,13 @@ func New(parent context.Context, network bsnet.BitSwapNetwork,
} }
wm := bswm.New(ctx) wm := bswm.New(ctx)
pqm := bspqm.New(ctx, network)
sessionFactory := func(ctx context.Context, id uint64, pm bssession.PeerManager, srs bssession.RequestSplitter) bssm.Session { sessionFactory := func(ctx context.Context, id uint64, pm bssession.PeerManager, srs bssession.RequestSplitter) bssm.Session {
return bssession.New(ctx, id, wm, pm, srs) return bssession.New(ctx, id, wm, pm, srs)
} }
sessionPeerManagerFactory := func(ctx context.Context, id uint64) bssession.PeerManager { sessionPeerManagerFactory := func(ctx context.Context, id uint64) bssession.PeerManager {
return bsspm.New(ctx, id, network) return bsspm.New(ctx, id, network.ConnectionManager(), pqm)
} }
sessionRequestSplitterFactory := func(ctx context.Context) bssession.RequestSplitter { sessionRequestSplitterFactory := func(ctx context.Context) bssession.RequestSplitter {
return bssrs.New(ctx) return bssrs.New(ctx)
...@@ -125,6 +128,7 @@ func New(parent context.Context, network bsnet.BitSwapNetwork, ...@@ -125,6 +128,7 @@ func New(parent context.Context, network bsnet.BitSwapNetwork,
newBlocks: make(chan cid.Cid, HasBlockBufferSize), newBlocks: make(chan cid.Cid, HasBlockBufferSize),
provideKeys: make(chan cid.Cid, provideKeysBufferSize), provideKeys: make(chan cid.Cid, provideKeysBufferSize),
wm: wm, wm: wm,
pqm: pqm,
pm: bspm.New(ctx, peerQueueFactory), pm: bspm.New(ctx, peerQueueFactory),
sm: bssm.New(ctx, sessionFactory, sessionPeerManagerFactory, sessionRequestSplitterFactory), sm: bssm.New(ctx, sessionFactory, sessionPeerManagerFactory, sessionRequestSplitterFactory),
counters: new(counters), counters: new(counters),
...@@ -136,6 +140,7 @@ func New(parent context.Context, network bsnet.BitSwapNetwork, ...@@ -136,6 +140,7 @@ func New(parent context.Context, network bsnet.BitSwapNetwork,
bs.wm.SetDelegate(bs.pm) bs.wm.SetDelegate(bs.pm)
bs.pm.Startup() bs.pm.Startup()
bs.wm.Startup() bs.wm.Startup()
bs.pqm.Startup()
network.SetDelegate(bs) network.SetDelegate(bs)
// Start up bitswaps async worker routines // Start up bitswaps async worker routines
...@@ -161,6 +166,9 @@ type Bitswap struct { ...@@ -161,6 +166,9 @@ type Bitswap struct {
// the wantlist tracks global wants for bitswap // the wantlist tracks global wants for bitswap
wm *bswm.WantManager wm *bswm.WantManager
// the provider query manager manages requests to find providers
pqm *bspqm.ProviderQueryManager
// the engine is the bit of logic that decides who to send which blocks to // the engine is the bit of logic that decides who to send which blocks to
engine *decision.Engine engine *decision.Engine
......
...@@ -8,7 +8,6 @@ import ( ...@@ -8,7 +8,6 @@ import (
logging "github.com/ipfs/go-log" logging "github.com/ipfs/go-log"
cid "github.com/ipfs/go-cid" cid "github.com/ipfs/go-cid"
ifconnmgr "github.com/libp2p/go-libp2p-interface-connmgr"
peer "github.com/libp2p/go-libp2p-peer" peer "github.com/libp2p/go-libp2p-peer"
) )
...@@ -19,11 +18,15 @@ const ( ...@@ -19,11 +18,15 @@ const (
reservePeers = 2 reservePeers = 2
) )
// PeerNetwork is an interface for finding providers and managing connections // PeerTagger is an interface for tagging peers with metadata
type PeerNetwork interface { type PeerTagger interface {
ConnectionManager() ifconnmgr.ConnManager TagPeer(peer.ID, string, int)
ConnectTo(context.Context, peer.ID) error UntagPeer(p peer.ID, tag string)
FindProvidersAsync(context.Context, cid.Cid, int) <-chan peer.ID }
// PeerProviderFinder is an interface for finding providers
type PeerProviderFinder interface {
FindProvidersAsync(context.Context, cid.Cid, uint64) <-chan peer.ID
} }
type peerMessage interface { type peerMessage interface {
...@@ -34,8 +37,10 @@ type peerMessage interface { ...@@ -34,8 +37,10 @@ type peerMessage interface {
// the best ones to the session // the best ones to the session
type SessionPeerManager struct { type SessionPeerManager struct {
ctx context.Context ctx context.Context
network PeerNetwork tagger PeerTagger
providerFinder PeerProviderFinder
tag string tag string
id uint64
peerMessages chan peerMessage peerMessages chan peerMessage
...@@ -46,10 +51,12 @@ type SessionPeerManager struct { ...@@ -46,10 +51,12 @@ type SessionPeerManager struct {
} }
// New creates a new SessionPeerManager // New creates a new SessionPeerManager
func New(ctx context.Context, id uint64, network PeerNetwork) *SessionPeerManager { func New(ctx context.Context, id uint64, tagger PeerTagger, providerFinder PeerProviderFinder) *SessionPeerManager {
spm := &SessionPeerManager{ spm := &SessionPeerManager{
id: id,
ctx: ctx, ctx: ctx,
network: network, tagger: tagger,
providerFinder: providerFinder,
peerMessages: make(chan peerMessage, 16), peerMessages: make(chan peerMessage, 16),
activePeers: make(map[peer.ID]bool), activePeers: make(map[peer.ID]bool),
} }
...@@ -101,24 +108,13 @@ func (spm *SessionPeerManager) GetOptimizedPeers() []peer.ID { ...@@ -101,24 +108,13 @@ func (spm *SessionPeerManager) GetOptimizedPeers() []peer.ID {
// providers for the given Cid // providers for the given Cid
func (spm *SessionPeerManager) FindMorePeers(ctx context.Context, c cid.Cid) { func (spm *SessionPeerManager) FindMorePeers(ctx context.Context, c cid.Cid) {
go func(k cid.Cid) { go func(k cid.Cid) {
// TODO: have a task queue setup for this to: for p := range spm.providerFinder.FindProvidersAsync(ctx, k, spm.id) {
// - rate limit
// - manage timeouts
// - ensure two 'findprovs' calls for the same block don't run concurrently
// - share peers between sessions based on interest set
for p := range spm.network.FindProvidersAsync(ctx, k, 10) {
go func(p peer.ID) {
// TODO: Also use context from spm.
err := spm.network.ConnectTo(ctx, p)
if err != nil {
log.Debugf("failed to connect to provider %s: %s", p, err)
}
select { select {
case spm.peerMessages <- &peerFoundMessage{p}: case spm.peerMessages <- &peerFoundMessage{p}:
case <-ctx.Done(): case <-ctx.Done():
case <-spm.ctx.Done(): case <-spm.ctx.Done():
} }
}(p)
} }
}(c) }(c)
} }
...@@ -136,8 +132,7 @@ func (spm *SessionPeerManager) run(ctx context.Context) { ...@@ -136,8 +132,7 @@ func (spm *SessionPeerManager) run(ctx context.Context) {
} }
func (spm *SessionPeerManager) tagPeer(p peer.ID) { func (spm *SessionPeerManager) tagPeer(p peer.ID) {
cmgr := spm.network.ConnectionManager() spm.tagger.TagPeer(p, spm.tag, 10)
cmgr.TagPeer(p, spm.tag, 10)
} }
func (spm *SessionPeerManager) insertOptimizedPeer(p peer.ID) { func (spm *SessionPeerManager) insertOptimizedPeer(p peer.ID) {
...@@ -223,8 +218,7 @@ func (prm *peerReqMessage) handle(spm *SessionPeerManager) { ...@@ -223,8 +218,7 @@ func (prm *peerReqMessage) handle(spm *SessionPeerManager) {
} }
func (spm *SessionPeerManager) handleShutdown() { func (spm *SessionPeerManager) handleShutdown() {
cmgr := spm.network.ConnectionManager()
for p := range spm.activePeers { for p := range spm.activePeers {
cmgr.UntagPeer(p, spm.tag) spm.tagger.UntagPeer(p, spm.tag)
} }
} }
...@@ -2,7 +2,6 @@ package sessionpeermanager ...@@ -2,7 +2,6 @@ package sessionpeermanager
import ( import (
"context" "context"
"errors"
"math/rand" "math/rand"
"sync" "sync"
"testing" "testing"
...@@ -11,35 +10,19 @@ import ( ...@@ -11,35 +10,19 @@ import (
"github.com/ipfs/go-bitswap/testutil" "github.com/ipfs/go-bitswap/testutil"
cid "github.com/ipfs/go-cid" cid "github.com/ipfs/go-cid"
ifconnmgr "github.com/libp2p/go-libp2p-interface-connmgr"
inet "github.com/libp2p/go-libp2p-net"
peer "github.com/libp2p/go-libp2p-peer" peer "github.com/libp2p/go-libp2p-peer"
) )
type fakePeerNetwork struct { type fakePeerProviderFinder struct {
peers []peer.ID peers []peer.ID
connManager ifconnmgr.ConnManager
completed chan struct{} completed chan struct{}
connect chan struct{}
} }
func (fpn *fakePeerNetwork) ConnectionManager() ifconnmgr.ConnManager { func (fppf *fakePeerProviderFinder) FindProvidersAsync(ctx context.Context, c cid.Cid, ses uint64) <-chan peer.ID {
return fpn.connManager
}
func (fpn *fakePeerNetwork) ConnectTo(ctx context.Context, p peer.ID) error {
select {
case fpn.connect <- struct{}{}:
return nil
case <-ctx.Done():
return errors.New("Timeout Occurred")
}
}
func (fpn *fakePeerNetwork) FindProvidersAsync(ctx context.Context, c cid.Cid, num int) <-chan peer.ID {
peerCh := make(chan peer.ID) peerCh := make(chan peer.ID)
go func() { go func() {
for _, p := range fpn.peers {
for _, p := range fppf.peers {
select { select {
case peerCh <- p: case peerCh <- p:
case <-ctx.Done(): case <-ctx.Done():
...@@ -50,52 +33,48 @@ func (fpn *fakePeerNetwork) FindProvidersAsync(ctx context.Context, c cid.Cid, n ...@@ -50,52 +33,48 @@ func (fpn *fakePeerNetwork) FindProvidersAsync(ctx context.Context, c cid.Cid, n
close(peerCh) close(peerCh)
select { select {
case fpn.completed <- struct{}{}: case fppf.completed <- struct{}{}:
case <-ctx.Done(): case <-ctx.Done():
} }
}() }()
return peerCh return peerCh
} }
type fakeConnManager struct { type fakePeerTagger struct {
taggedPeers []peer.ID taggedPeers []peer.ID
wait sync.WaitGroup wait sync.WaitGroup
} }
func (fcm *fakeConnManager) TagPeer(p peer.ID, tag string, n int) { func (fpt *fakePeerTagger) TagPeer(p peer.ID, tag string, n int) {
fcm.wait.Add(1) fpt.wait.Add(1)
fcm.taggedPeers = append(fcm.taggedPeers, p) fpt.taggedPeers = append(fpt.taggedPeers, p)
} }
func (fcm *fakeConnManager) UntagPeer(p peer.ID, tag string) { func (fpt *fakePeerTagger) UntagPeer(p peer.ID, tag string) {
defer fcm.wait.Done() defer fpt.wait.Done()
for i := 0; i < len(fcm.taggedPeers); i++ {
if fcm.taggedPeers[i] == p { for i := 0; i < len(fpt.taggedPeers); i++ {
fcm.taggedPeers[i] = fcm.taggedPeers[len(fcm.taggedPeers)-1] if fpt.taggedPeers[i] == p {
fcm.taggedPeers = fcm.taggedPeers[:len(fcm.taggedPeers)-1] fpt.taggedPeers[i] = fpt.taggedPeers[len(fpt.taggedPeers)-1]
fpt.taggedPeers = fpt.taggedPeers[:len(fpt.taggedPeers)-1]
return return
} }
} }
} }
func (*fakeConnManager) GetTagInfo(p peer.ID) *ifconnmgr.TagInfo { return nil }
func (*fakeConnManager) TrimOpenConns(ctx context.Context) {}
func (*fakeConnManager) Notifee() inet.Notifiee { return nil }
func TestFindingMorePeers(t *testing.T) { func TestFindingMorePeers(t *testing.T) {
ctx := context.Background() ctx := context.Background()
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
completed := make(chan struct{}) completed := make(chan struct{})
connect := make(chan struct{})
peers := testutil.GeneratePeers(5) peers := testutil.GeneratePeers(5)
fcm := &fakeConnManager{} fpt := &fakePeerTagger{}
fpn := &fakePeerNetwork{peers, fcm, completed, connect} fppf := &fakePeerProviderFinder{peers, completed}
c := testutil.GenerateCids(1)[0] c := testutil.GenerateCids(1)[0]
id := testutil.GenerateSessionID() id := testutil.GenerateSessionID()
sessionPeerManager := New(ctx, id, fpn) sessionPeerManager := New(ctx, id, fpt, fppf)
findCtx, findCancel := context.WithTimeout(ctx, 10*time.Millisecond) findCtx, findCancel := context.WithTimeout(ctx, 10*time.Millisecond)
defer findCancel() defer findCancel()
...@@ -105,13 +84,6 @@ func TestFindingMorePeers(t *testing.T) { ...@@ -105,13 +84,6 @@ func TestFindingMorePeers(t *testing.T) {
case <-findCtx.Done(): case <-findCtx.Done():
t.Fatal("Did not finish finding providers") t.Fatal("Did not finish finding providers")
} }
for range peers {
select {
case <-connect:
case <-findCtx.Done():
t.Fatal("Did not connect to peer")
}
}
time.Sleep(2 * time.Millisecond) time.Sleep(2 * time.Millisecond)
sessionPeers := sessionPeerManager.GetOptimizedPeers() sessionPeers := sessionPeerManager.GetOptimizedPeers()
...@@ -123,7 +95,7 @@ func TestFindingMorePeers(t *testing.T) { ...@@ -123,7 +95,7 @@ func TestFindingMorePeers(t *testing.T) {
t.Fatal("incorrect peer found through finding providers") t.Fatal("incorrect peer found through finding providers")
} }
} }
if len(fcm.taggedPeers) != len(peers) { if len(fpt.taggedPeers) != len(peers) {
t.Fatal("Peers were not tagged!") t.Fatal("Peers were not tagged!")
} }
} }
...@@ -133,12 +105,12 @@ func TestRecordingReceivedBlocks(t *testing.T) { ...@@ -133,12 +105,12 @@ func TestRecordingReceivedBlocks(t *testing.T) {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
p := testutil.GeneratePeers(1)[0] p := testutil.GeneratePeers(1)[0]
fcm := &fakeConnManager{} fpt := &fakePeerTagger{}
fpn := &fakePeerNetwork{nil, fcm, nil, nil} fppf := &fakePeerProviderFinder{}
c := testutil.GenerateCids(1)[0] c := testutil.GenerateCids(1)[0]
id := testutil.GenerateSessionID() id := testutil.GenerateSessionID()
sessionPeerManager := New(ctx, id, fpn) sessionPeerManager := New(ctx, id, fpt, fppf)
sessionPeerManager.RecordPeerResponse(p, c) sessionPeerManager.RecordPeerResponse(p, c)
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
sessionPeers := sessionPeerManager.GetOptimizedPeers() sessionPeers := sessionPeerManager.GetOptimizedPeers()
...@@ -148,7 +120,7 @@ func TestRecordingReceivedBlocks(t *testing.T) { ...@@ -148,7 +120,7 @@ func TestRecordingReceivedBlocks(t *testing.T) {
if sessionPeers[0] != p { if sessionPeers[0] != p {
t.Fatal("incorrect peer added on receive") t.Fatal("incorrect peer added on receive")
} }
if len(fcm.taggedPeers) != 1 { if len(fpt.taggedPeers) != 1 {
t.Fatal("Peers was not tagged!") t.Fatal("Peers was not tagged!")
} }
} }
...@@ -159,12 +131,11 @@ func TestOrderingPeers(t *testing.T) { ...@@ -159,12 +131,11 @@ func TestOrderingPeers(t *testing.T) {
defer cancel() defer cancel()
peers := testutil.GeneratePeers(100) peers := testutil.GeneratePeers(100)
completed := make(chan struct{}) completed := make(chan struct{})
connect := make(chan struct{}) fpt := &fakePeerTagger{}
fcm := &fakeConnManager{} fppf := &fakePeerProviderFinder{peers, completed}
fpn := &fakePeerNetwork{peers, fcm, completed, connect}
c := testutil.GenerateCids(1) c := testutil.GenerateCids(1)
id := testutil.GenerateSessionID() id := testutil.GenerateSessionID()
sessionPeerManager := New(ctx, id, fpn) sessionPeerManager := New(ctx, id, fpt, fppf)
// add all peers to session // add all peers to session
sessionPeerManager.FindMorePeers(ctx, c[0]) sessionPeerManager.FindMorePeers(ctx, c[0])
...@@ -173,13 +144,6 @@ func TestOrderingPeers(t *testing.T) { ...@@ -173,13 +144,6 @@ func TestOrderingPeers(t *testing.T) {
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("Did not finish finding providers") t.Fatal("Did not finish finding providers")
} }
for range peers {
select {
case <-connect:
case <-ctx.Done():
t.Fatal("Did not connect to peer")
}
}
time.Sleep(2 * time.Millisecond) time.Sleep(2 * time.Millisecond)
// record broadcast // record broadcast
...@@ -237,13 +201,12 @@ func TestUntaggingPeers(t *testing.T) { ...@@ -237,13 +201,12 @@ func TestUntaggingPeers(t *testing.T) {
defer cancel() defer cancel()
peers := testutil.GeneratePeers(5) peers := testutil.GeneratePeers(5)
completed := make(chan struct{}) completed := make(chan struct{})
connect := make(chan struct{}) fpt := &fakePeerTagger{}
fcm := &fakeConnManager{} fppf := &fakePeerProviderFinder{peers, completed}
fpn := &fakePeerNetwork{peers, fcm, completed, connect}
c := testutil.GenerateCids(1)[0] c := testutil.GenerateCids(1)[0]
id := testutil.GenerateSessionID() id := testutil.GenerateSessionID()
sessionPeerManager := New(ctx, id, fpn) sessionPeerManager := New(ctx, id, fpt, fppf)
sessionPeerManager.FindMorePeers(ctx, c) sessionPeerManager.FindMorePeers(ctx, c)
select { select {
...@@ -251,22 +214,15 @@ func TestUntaggingPeers(t *testing.T) { ...@@ -251,22 +214,15 @@ func TestUntaggingPeers(t *testing.T) {
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("Did not finish finding providers") t.Fatal("Did not finish finding providers")
} }
for range peers {
select {
case <-connect:
case <-ctx.Done():
t.Fatal("Did not connect to peer")
}
}
time.Sleep(2 * time.Millisecond) time.Sleep(2 * time.Millisecond)
if len(fcm.taggedPeers) != len(peers) { if len(fpt.taggedPeers) != len(peers) {
t.Fatal("Peers were not tagged!") t.Fatal("Peers were not tagged!")
} }
<-ctx.Done() <-ctx.Done()
fcm.wait.Wait() fpt.wait.Wait()
if len(fcm.taggedPeers) != 0 { if len(fpt.taggedPeers) != 0 {
t.Fatal("Peers were not untagged!") t.Fatal("Peers were not untagged!")
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment