sessionpeermanager.go 2.79 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
package sessionpeermanager

import (
	"context"
	"fmt"

	cid "github.com/ipfs/go-cid"
	ifconnmgr "github.com/libp2p/go-libp2p-interface-connmgr"
	peer "github.com/libp2p/go-libp2p-peer"
)

type PeerNetwork interface {
	ConnectionManager() ifconnmgr.ConnManager
	FindProvidersAsync(context.Context, cid.Cid, int) <-chan peer.ID
}

// SessionPeerManager tracks the peers participating in a single session.
// Peer state is owned by the goroutine that New starts (run); the exported
// methods talk to it over channels rather than touching that state directly.
type SessionPeerManager struct {
	// Set once in New: the manager's lifetime, the network it talks to, and
	// the connection-manager tag ("bs-ses-<id>") applied to active peers.
	ctx     context.Context
	network PeerNetwork
	tag     string

	// Mailboxes for the run loop: newPeers carries peers to add, peerReqs
	// carries reply channels for requests for the current peer list.
	newPeers chan peer.ID
	peerReqs chan chan []peer.ID

	// Owned by the run loop; do not read or write from any other goroutine.
	// activePeers de-duplicates; activePeersArr keeps insertion order.
	activePeers    map[peer.ID]struct{}
	activePeersArr []peer.ID
}

// New builds a SessionPeerManager for the session with the given id and
// starts its run loop in a new goroutine; the loop exits when ctx is done.
// Peers added to the session are tagged in the connection manager with
// "bs-ses-<id>".
func New(ctx context.Context, id uint64, network PeerNetwork) *SessionPeerManager {
	mgr := &SessionPeerManager{
		ctx:         ctx,
		network:     network,
		tag:         fmt.Sprintf("bs-ses-%d", id),
		newPeers:    make(chan peer.ID, 16),
		peerReqs:    make(chan chan []peer.ID),
		activePeers: make(map[peer.ID]struct{}),
	}
	go mgr.run(ctx)
	return mgr
}

// RecordPeerResponse notes that peer p answered for block k. Currently it only
// queues p on newPeers so the run loop adds it to the active set; k is unused.
// It blocks until p is queued or the manager's context is done.
func (spm *SessionPeerManager) RecordPeerResponse(p peer.ID, k cid.Cid) {
	done := spm.ctx.Done()
	select {
	case <-done:
	case spm.newPeers <- p:
	}
}

// RecordPeerRequests is a hook for noting that keys ks were requested from
// peers p. It is currently a no-op; the intent is to use it for per-peer
// latency tracking.
func (spm *SessionPeerManager) RecordPeerRequests(p []peer.ID, ks []cid.Cid) {
	// Deliberately empty until latency tracking is implemented.
}

// GetOptimizedPeers asks the run loop for the session's peers and returns
// them. For now every active peer is returned; the name leaves room for
// future ordering or subsetting. It returns nil if the manager's context is
// done before the run loop answers.
func (spm *SessionPeerManager) GetOptimizedPeers() []peer.ID {
	// Buffered so the run loop can always complete its reply without blocking,
	// even if this caller has already given up because ctx was cancelled.
	// With an unbuffered channel the run loop could block on the send forever
	// and never reach its shutdown path.
	resp := make(chan []peer.ID, 1)
	select {
	case spm.peerReqs <- resp:
	case <-spm.ctx.Done():
		return nil
	}

	select {
	case peers := <-resp:
		return peers
	case <-spm.ctx.Done():
		return nil
	}
}

// FindMorePeers starts a background provider search for c and returns
// immediately. Each provider found is handed to the run loop to be added as
// an active peer. The search uses ctx; providers that arrive after ctx or the
// manager's own context is done are dropped instead of blocking.
func (spm *SessionPeerManager) FindMorePeers(ctx context.Context, c cid.Cid) {
	go func(k cid.Cid) {
		// TODO: have a task queue setup for this to:
		// - rate limit
		// - manage timeouts
		// - ensure two 'findprovs' calls for the same block don't run concurrently
		// - share peers between sessions based on interest set
		for p := range spm.network.FindProvidersAsync(ctx, k, 10) {
			// A bare send would block forever once the run loop has exited
			// (nothing drains newPeers then), leaking this goroutine. Keep
			// ranging instead of returning so the provider stream is still
			// consumed until it closes.
			select {
			case spm.newPeers <- p:
			case <-ctx.Done():
			case <-spm.ctx.Done():
			}
		}
	}(c)
}

// run is the manager's event loop and the only goroutine that touches
// activePeers and activePeersArr. It adds peers arriving on newPeers, answers
// peer-list requests arriving on peerReqs, and when ctx is done untags every
// active peer and returns.
func (spm *SessionPeerManager) run(ctx context.Context) {
	for {
		select {
		case p := <-spm.newPeers:
			spm.addActivePeer(p)
		case resp := <-spm.peerReqs:
			// Reply with a copy. This loop keeps appending to activePeersArr,
			// so sharing its backing array would let a caller that appends to
			// or modifies the result corrupt (and race with) our state.
			// append onto nil keeps the original nil-when-empty result.
			peers := append([]peer.ID(nil), spm.activePeersArr...)
			// The requester may stop waiting once ctx is cancelled; never
			// block on the reply past that point, or shutdown would not run.
			select {
			case resp <- peers:
			case <-ctx.Done():
			}
		case <-ctx.Done():
			spm.handleShutdown()
			return
		}
	}
}
// addActivePeer records p as active the first time it is seen. It appends p
// to the ordered list and tags it in the connection manager with the session
// tag and value 10. Repeat calls for the same peer do nothing. Only the run
// loop may call this.
func (spm *SessionPeerManager) addActivePeer(p peer.ID) {
	if _, seen := spm.activePeers[p]; seen {
		return
	}
	spm.activePeers[p] = struct{}{}
	spm.activePeersArr = append(spm.activePeersArr, p)
	spm.network.ConnectionManager().TagPeer(p, spm.tag, 10)
}

// handleShutdown removes the session tag from every active peer, in the order
// the peers were added. The run loop calls it when its context is done.
func (spm *SessionPeerManager) handleShutdown() {
	mgr := spm.network.ConnectionManager()
	for i := range spm.activePeersArr {
		mgr.UntagPeer(spm.activePeersArr[i], spm.tag)
	}
}