Commit 78386f0e authored by hannahhoward's avatar hannahhoward

feat(wantlist): differentiate types

Seperate want list into differentiated types - session tracking and regular

fix #13
parent 472a8ab9
......@@ -97,8 +97,8 @@ func New(parent context.Context, network bsnet.BitSwapNetwork,
return nil
})
peerQueueFactory := func(p peer.ID) bspm.PeerQueue {
return bsmq.New(p, network)
peerQueueFactory := func(ctx context.Context, p peer.ID) bspm.PeerQueue {
return bsmq.New(ctx, p, network)
}
wm := bswm.New(ctx)
......
......@@ -2,7 +2,6 @@ package messagequeue
import (
"context"
"sync"
"time"
bsmsg "github.com/ipfs/go-bitswap/message"
......@@ -23,68 +22,72 @@ type MessageNetwork interface {
NewMessageSender(context.Context, peer.ID) (bsnet.MessageSender, error)
}
type request interface {
handle(mq *MessageQueue)
}
// MessageQueue implements queue of want messages to send to peers.
type MessageQueue struct {
p peer.ID
outlk sync.Mutex
out bsmsg.BitSwapMessage
ctx context.Context
p peer.ID
network MessageNetwork
wl *wantlist.ThreadSafe
sender bsnet.MessageSender
newRequests chan request
outgoingMessages chan bsmsg.BitSwapMessage
done chan struct{}
// do not touch out of run loop
wl *wantlist.SessionTrackedWantlist
nextMessage bsmsg.BitSwapMessage
sender bsnet.MessageSender
}
type messageRequest struct {
entries []*bsmsg.Entry
ses uint64
}
work chan struct{}
done chan struct{}
type wantlistRequest struct {
wl *wantlist.SessionTrackedWantlist
}
// New creats a new MessageQueue.
func New(p peer.ID, network MessageNetwork) *MessageQueue {
func New(ctx context.Context, p peer.ID, network MessageNetwork) *MessageQueue {
return &MessageQueue{
done: make(chan struct{}),
work: make(chan struct{}, 1),
wl: wantlist.NewThreadSafe(),
network: network,
p: p,
ctx: ctx,
wl: wantlist.NewSessionTrackedWantlist(),
network: network,
p: p,
newRequests: make(chan request, 16),
outgoingMessages: make(chan bsmsg.BitSwapMessage),
done: make(chan struct{}),
}
}
// AddMessage adds new entries to an outgoing message for a given session.
func (mq *MessageQueue) AddMessage(entries []*bsmsg.Entry, ses uint64) {
if !mq.addEntries(entries, ses) {
return
}
select {
case mq.work <- struct{}{}:
default:
case mq.newRequests <- &messageRequest{entries, ses}:
case <-mq.ctx.Done():
}
}
// AddWantlist adds a complete session tracked want list to a message queue
func (mq *MessageQueue) AddWantlist(initialEntries []*wantlist.Entry) {
if len(initialEntries) > 0 {
if mq.out == nil {
mq.out = bsmsg.New(false)
}
func (mq *MessageQueue) AddWantlist(initialWants *wantlist.SessionTrackedWantlist) {
wl := wantlist.NewSessionTrackedWantlist()
initialWants.CopyWants(wl)
for _, e := range initialEntries {
for k := range e.SesTrk {
mq.wl.AddEntry(e, k)
}
mq.out.AddEntry(e.Cid, e.Priority)
}
select {
case mq.work <- struct{}{}:
default:
}
select {
case mq.newRequests <- &wantlistRequest{wl}:
case <-mq.ctx.Done():
}
}
// Startup starts the processing of messages, and creates an initial message
// based on the given initial wantlist.
func (mq *MessageQueue) Startup(ctx context.Context) {
go mq.runQueue(ctx)
func (mq *MessageQueue) Startup() {
go mq.runQueue()
go mq.sendMessages()
}
// Shutdown stops the processing of messages for a message queue.
......@@ -92,17 +95,26 @@ func (mq *MessageQueue) Shutdown() {
close(mq.done)
}
func (mq *MessageQueue) runQueue(ctx context.Context) {
func (mq *MessageQueue) runQueue() {
outgoingMessages := func() chan bsmsg.BitSwapMessage {
if mq.nextMessage == nil {
return nil
}
return mq.outgoingMessages
}
for {
select {
case <-mq.work: // there is work to be done
mq.doWork(ctx)
case newRequest := <-mq.newRequests:
newRequest.handle(mq)
case outgoingMessages() <- mq.nextMessage:
mq.nextMessage = nil
case <-mq.done:
if mq.sender != nil {
mq.sender.Close()
}
return
case <-ctx.Done():
case <-mq.ctx.Done():
if mq.sender != nil {
mq.sender.Reset()
}
......@@ -111,63 +123,77 @@ func (mq *MessageQueue) runQueue(ctx context.Context) {
}
}
func (mq *MessageQueue) addEntries(entries []*bsmsg.Entry, ses uint64) bool {
var work bool
mq.outlk.Lock()
defer mq.outlk.Unlock()
// if we have no message held allocate a new one
if mq.out == nil {
mq.out = bsmsg.New(false)
func (mr *messageRequest) handle(mq *MessageQueue) {
mq.addEntries(mr.entries, mr.ses)
}
func (wr *wantlistRequest) handle(mq *MessageQueue) {
initialWants := wr.wl
initialWants.CopyWants(mq.wl)
if initialWants.Len() > 0 {
if mq.nextMessage == nil {
mq.nextMessage = bsmsg.New(false)
}
for _, e := range initialWants.Entries() {
mq.nextMessage.AddEntry(e.Cid, e.Priority)
}
}
}
// TODO: add a msg.Combine(...) method
// otherwise, combine the one we are holding with the
// one passed in
func (mq *MessageQueue) addEntries(entries []*bsmsg.Entry, ses uint64) {
for _, e := range entries {
if e.Cancel {
if mq.wl.Remove(e.Cid, ses) {
work = true
mq.out.Cancel(e.Cid)
if mq.nextMessage == nil {
mq.nextMessage = bsmsg.New(false)
}
mq.nextMessage.Cancel(e.Cid)
}
} else {
if mq.wl.Add(e.Cid, e.Priority, ses) {
work = true
mq.out.AddEntry(e.Cid, e.Priority)
if mq.nextMessage == nil {
mq.nextMessage = bsmsg.New(false)
}
mq.nextMessage.AddEntry(e.Cid, e.Priority)
}
}
}
return work
}
func (mq *MessageQueue) doWork(ctx context.Context) {
wlm := mq.extractOutgoingMessage()
if wlm == nil || wlm.Empty() {
return
func (mq *MessageQueue) sendMessages() {
for {
select {
case nextMessage := <-mq.outgoingMessages:
mq.sendMessage(nextMessage)
case <-mq.done:
return
case <-mq.ctx.Done():
return
}
}
}
func (mq *MessageQueue) sendMessage(message bsmsg.BitSwapMessage) {
// NB: only open a stream if we actually have data to send
err := mq.initializeSender(ctx)
err := mq.initializeSender()
if err != nil {
log.Infof("cant open message sender to peer %s: %s", mq.p, err)
// TODO: cant connect, what now?
return
}
// send wantlist updates
for i := 0; i < maxRetries; i++ { // try to send this message until we fail.
if mq.attemptSendAndRecovery(ctx, wlm) {
if mq.attemptSendAndRecovery(message) {
return
}
}
}
func (mq *MessageQueue) initializeSender(ctx context.Context) error {
func (mq *MessageQueue) initializeSender() error {
if mq.sender != nil {
return nil
}
nsender, err := openSender(ctx, mq.network, mq.p)
nsender, err := openSender(mq.ctx, mq.network, mq.p)
if err != nil {
return err
}
......@@ -175,8 +201,8 @@ func (mq *MessageQueue) initializeSender(ctx context.Context) error {
return nil
}
func (mq *MessageQueue) attemptSendAndRecovery(ctx context.Context, wlm bsmsg.BitSwapMessage) bool {
err := mq.sender.SendMsg(ctx, wlm)
func (mq *MessageQueue) attemptSendAndRecovery(message bsmsg.BitSwapMessage) bool {
err := mq.sender.SendMsg(mq.ctx, message)
if err == nil {
return true
}
......@@ -188,14 +214,14 @@ func (mq *MessageQueue) attemptSendAndRecovery(ctx context.Context, wlm bsmsg.Bi
select {
case <-mq.done:
return true
case <-ctx.Done():
case <-mq.ctx.Done():
return true
case <-time.After(time.Millisecond * 100):
// wait 100ms in case disconnect notifications are still propogating
log.Warning("SendMsg errored but neither 'done' nor context.Done() were set")
}
err = mq.initializeSender(ctx)
err = mq.initializeSender()
if err != nil {
log.Infof("couldnt open sender again after SendMsg(%s) failed: %s", mq.p, err)
// TODO(why): what do we do now?
......@@ -215,15 +241,6 @@ func (mq *MessageQueue) attemptSendAndRecovery(ctx context.Context, wlm bsmsg.Bi
return false
}
func (mq *MessageQueue) extractOutgoingMessage() bsmsg.BitSwapMessage {
// grab outgoing message
mq.outlk.Lock()
wlm := mq.out
mq.out = nil
mq.outlk.Unlock()
return wlm
}
func openSender(ctx context.Context, network MessageNetwork, p peer.ID) (bsnet.MessageSender, error) {
// allow ten minutes for connections this includes looking them up in the
// dht dialing them, and handshaking
......
......@@ -27,7 +27,6 @@ func (fmn *fakeMessageNetwork) NewMessageSender(context.Context, peer.ID) (bsnet
return fmn.messageSender, nil
}
return nil, fmn.messageSenderError
}
type fakeMessageSender struct {
......@@ -77,12 +76,12 @@ func TestStartupAndShutdown(t *testing.T) {
fakeSender := &fakeMessageSender{nil, fullClosedChan, resetChan, messagesSent}
fakenet := &fakeMessageNetwork{nil, nil, fakeSender}
peerID := testutil.GeneratePeers(1)[0]
messageQueue := New(peerID, fakenet)
messageQueue := New(ctx, peerID, fakenet)
ses := testutil.GenerateSessionID()
wl := testutil.GenerateWantlist(10, ses)
messageQueue.Startup(ctx)
messageQueue.AddWantlist(wl.Entries())
messageQueue.Startup()
messageQueue.AddWantlist(wl)
messages := collectMessages(ctx, t, messagesSent, 10*time.Millisecond)
if len(messages) != 1 {
t.Fatal("wrong number of messages were sent for initial wants")
......@@ -119,11 +118,11 @@ func TestSendingMessagesDeduped(t *testing.T) {
fakeSender := &fakeMessageSender{nil, fullClosedChan, resetChan, messagesSent}
fakenet := &fakeMessageNetwork{nil, nil, fakeSender}
peerID := testutil.GeneratePeers(1)[0]
messageQueue := New(peerID, fakenet)
messageQueue := New(ctx, peerID, fakenet)
ses1 := testutil.GenerateSessionID()
ses2 := testutil.GenerateSessionID()
entries := testutil.GenerateMessageEntries(10, false)
messageQueue.Startup(ctx)
messageQueue.Startup()
messageQueue.AddMessage(entries, ses1)
messageQueue.AddMessage(entries, ses2)
......@@ -142,13 +141,13 @@ func TestSendingMessagesPartialDupe(t *testing.T) {
fakeSender := &fakeMessageSender{nil, fullClosedChan, resetChan, messagesSent}
fakenet := &fakeMessageNetwork{nil, nil, fakeSender}
peerID := testutil.GeneratePeers(1)[0]
messageQueue := New(peerID, fakenet)
messageQueue := New(ctx, peerID, fakenet)
ses1 := testutil.GenerateSessionID()
ses2 := testutil.GenerateSessionID()
entries := testutil.GenerateMessageEntries(10, false)
moreEntries := testutil.GenerateMessageEntries(5, false)
secondEntries := append(entries[5:], moreEntries...)
messageQueue.Startup(ctx)
messageQueue.Startup()
messageQueue.AddMessage(entries, ses1)
messageQueue.AddMessage(secondEntries, ses2)
......
......@@ -20,13 +20,13 @@ var (
// PeerQueue provides a queer of messages to be sent for a single peer.
type PeerQueue interface {
AddMessage(entries []*bsmsg.Entry, ses uint64)
Startup(ctx context.Context)
AddWantlist(initialEntries []*wantlist.Entry)
Startup()
AddWantlist(initialWants *wantlist.SessionTrackedWantlist)
Shutdown()
}
// PeerQueueFactory provides a function that will create a PeerQueue.
type PeerQueueFactory func(p peer.ID) PeerQueue
type PeerQueueFactory func(ctx context.Context, p peer.ID) PeerQueue
type peerMessage interface {
handle(pm *PeerManager)
......@@ -69,13 +69,13 @@ func (pm *PeerManager) ConnectedPeers() []peer.ID {
// Connected is called to add a new peer to the pool, and send it an initial set
// of wants.
func (pm *PeerManager) Connected(p peer.ID, initialEntries []*wantlist.Entry) {
func (pm *PeerManager) Connected(p peer.ID, initialWants *wantlist.SessionTrackedWantlist) {
pm.peerQueuesLk.Lock()
pq := pm.getOrCreate(p)
if pq.refcnt == 0 {
pq.pq.AddWantlist(initialEntries)
pq.pq.AddWantlist(initialWants)
}
pq.refcnt++
......@@ -128,8 +128,8 @@ func (pm *PeerManager) SendMessage(entries []*bsmsg.Entry, targets []peer.ID, fr
func (pm *PeerManager) getOrCreate(p peer.ID) *peerQueueInstance {
pqi, ok := pm.peerQueues[p]
if !ok {
pq := pm.createPeerQueue(p)
pq.Startup(pm.ctx)
pq := pm.createPeerQueue(pm.ctx, p)
pq.Startup()
pqi = &peerQueueInstance{0, pq}
pm.peerQueues[p] = pqi
}
......
......@@ -24,15 +24,15 @@ type fakePeer struct {
messagesSent chan messageSent
}
func (fp *fakePeer) Startup(ctx context.Context) {}
func (fp *fakePeer) Shutdown() {}
func (fp *fakePeer) Startup() {}
func (fp *fakePeer) Shutdown() {}
func (fp *fakePeer) AddMessage(entries []*bsmsg.Entry, ses uint64) {
fp.messagesSent <- messageSent{fp.p, entries, ses}
}
func (fp *fakePeer) AddWantlist(initialEntries []*wantlist.Entry) {}
func (fp *fakePeer) AddWantlist(initialWants *wantlist.SessionTrackedWantlist) {}
func makePeerQueueFactory(messagesSent chan messageSent) PeerQueueFactory {
return func(p peer.ID) PeerQueue {
return func(ctx context.Context, p peer.ID) PeerQueue {
return &fakePeer{
p: p,
messagesSent: messagesSent,
......
......@@ -39,8 +39,8 @@ func GenerateCids(n int) []cid.Cid {
}
// GenerateWantlist makes a populated wantlist.
func GenerateWantlist(n int, ses uint64) *wantlist.ThreadSafe {
wl := wantlist.NewThreadSafe()
func GenerateWantlist(n int, ses uint64) *wantlist.SessionTrackedWantlist {
wl := wantlist.NewSessionTrackedWantlist()
for i := 0; i < n; i++ {
prioritySeq++
entry := wantlist.NewRefEntry(blockGenerator.Next().Cid(), prioritySeq)
......
// package wantlist implements an object for bitswap that contains the keys
// Package wantlist implements an object for bitswap that contains the keys
// that a given peer wants.
package wantlist
import (
"sort"
"sync"
cid "github.com/ipfs/go-cid"
)
type ThreadSafe struct {
lk sync.RWMutex
set map[cid.Cid]*Entry
type SessionTrackedWantlist struct {
set map[cid.Cid]*sessionTrackedEntry
}
// not threadsafe
type Wantlist struct {
set map[cid.Cid]*Entry
}
......@@ -23,17 +20,20 @@ type Entry struct {
Cid cid.Cid
Priority int
SesTrk map[uint64]struct{}
// Trash in a book-keeping field
Trash bool
}
type sessionTrackedEntry struct {
*Entry
sesTrk map[uint64]struct{}
}
// NewRefEntry creates a new reference tracked wantlist entry.
func NewRefEntry(c cid.Cid, p int) *Entry {
return &Entry{
Cid: c,
Priority: p,
SesTrk: make(map[uint64]struct{}),
}
}
......@@ -43,9 +43,9 @@ func (es entrySlice) Len() int { return len(es) }
func (es entrySlice) Swap(i, j int) { es[i], es[j] = es[j], es[i] }
func (es entrySlice) Less(i, j int) bool { return es[i].Priority > es[j].Priority }
func NewThreadSafe() *ThreadSafe {
return &ThreadSafe{
set: make(map[cid.Cid]*Entry),
func NewSessionTrackedWantlist() *SessionTrackedWantlist {
return &SessionTrackedWantlist{
set: make(map[cid.Cid]*sessionTrackedEntry),
}
}
......@@ -63,33 +63,31 @@ func New() *Wantlist {
// TODO: think through priority changes here
// Add returns true if the cid did not exist in the wantlist before this call
// (even if it was under a different session).
func (w *ThreadSafe) Add(c cid.Cid, priority int, ses uint64) bool {
w.lk.Lock()
defer w.lk.Unlock()
func (w *SessionTrackedWantlist) Add(c cid.Cid, priority int, ses uint64) bool {
if e, ok := w.set[c]; ok {
e.SesTrk[ses] = struct{}{}
e.sesTrk[ses] = struct{}{}
return false
}
w.set[c] = &Entry{
Cid: c,
Priority: priority,
SesTrk: map[uint64]struct{}{ses: struct{}{}},
w.set[c] = &sessionTrackedEntry{
Entry: &Entry{Cid: c, Priority: priority},
sesTrk: map[uint64]struct{}{ses: struct{}{}},
}
return true
}
// AddEntry adds given Entry to the wantlist. For more information see Add method.
func (w *ThreadSafe) AddEntry(e *Entry, ses uint64) bool {
w.lk.Lock()
defer w.lk.Unlock()
func (w *SessionTrackedWantlist) AddEntry(e *Entry, ses uint64) bool {
if ex, ok := w.set[e.Cid]; ok {
ex.SesTrk[ses] = struct{}{}
ex.sesTrk[ses] = struct{}{}
return false
}
w.set[e.Cid] = e
e.SesTrk[ses] = struct{}{}
w.set[e.Cid] = &sessionTrackedEntry{
Entry: e,
sesTrk: map[uint64]struct{}{ses: struct{}{}},
}
return true
}
......@@ -97,16 +95,14 @@ func (w *ThreadSafe) AddEntry(e *Entry, ses uint64) bool {
// 'true' is returned if this call to Remove removed the final session ID
// tracking the cid. (meaning true will be returned iff this call caused the
// value of 'Contains(c)' to change from true to false)
func (w *ThreadSafe) Remove(c cid.Cid, ses uint64) bool {
w.lk.Lock()
defer w.lk.Unlock()
func (w *SessionTrackedWantlist) Remove(c cid.Cid, ses uint64) bool {
e, ok := w.set[c]
if !ok {
return false
}
delete(e.SesTrk, ses)
if len(e.SesTrk) == 0 {
delete(e.sesTrk, ses)
if len(e.sesTrk) == 0 {
delete(w.set, c)
return true
}
......@@ -115,35 +111,40 @@ func (w *ThreadSafe) Remove(c cid.Cid, ses uint64) bool {
// Contains returns true if the given cid is in the wantlist tracked by one or
// more sessions.
func (w *ThreadSafe) Contains(k cid.Cid) (*Entry, bool) {
w.lk.RLock()
defer w.lk.RUnlock()
func (w *SessionTrackedWantlist) Contains(k cid.Cid) (*Entry, bool) {
e, ok := w.set[k]
return e, ok
if !ok {
return nil, false
}
return e.Entry, true
}
func (w *ThreadSafe) Entries() []*Entry {
w.lk.RLock()
defer w.lk.RUnlock()
func (w *SessionTrackedWantlist) Entries() []*Entry {
es := make([]*Entry, 0, len(w.set))
for _, e := range w.set {
es = append(es, e)
es = append(es, e.Entry)
}
return es
}
func (w *ThreadSafe) SortedEntries() []*Entry {
func (w *SessionTrackedWantlist) SortedEntries() []*Entry {
es := w.Entries()
sort.Sort(entrySlice(es))
return es
}
func (w *ThreadSafe) Len() int {
w.lk.RLock()
defer w.lk.RUnlock()
func (w *SessionTrackedWantlist) Len() int {
return len(w.set)
}
func (w *SessionTrackedWantlist) CopyWants(to *SessionTrackedWantlist) {
for _, e := range w.set {
for k := range e.sesTrk {
to.AddEntry(e.Entry, k)
}
}
}
func (w *Wantlist) Len() int {
return len(w.set)
}
......
......@@ -82,8 +82,8 @@ func TestBasicWantlist(t *testing.T) {
}
}
func TestSesRefWantlist(t *testing.T) {
wl := NewThreadSafe()
func TestSessionTrackedWantlist(t *testing.T) {
wl := NewSessionTrackedWantlist()
if !wl.Add(testcids[0], 5, 1) {
t.Fatal("should have added")
......
......@@ -24,7 +24,7 @@ const (
// managed by the WantManager.
type PeerHandler interface {
Disconnected(p peer.ID)
Connected(p peer.ID, initialEntries []*wantlist.Entry)
Connected(p peer.ID, initialWants *wantlist.SessionTrackedWantlist)
SendMessage(entries []*bsmsg.Entry, targets []peer.ID, from uint64)
}
......@@ -42,8 +42,8 @@ type WantManager struct {
wantMessages chan wantMessage
// synchronized by Run loop, only touch inside there
wl *wantlist.ThreadSafe
bcwl *wantlist.ThreadSafe
wl *wantlist.SessionTrackedWantlist
bcwl *wantlist.SessionTrackedWantlist
ctx context.Context
cancel func()
......@@ -59,8 +59,8 @@ func New(ctx context.Context) *WantManager {
"Number of items in wantlist.").Gauge()
return &WantManager{
wantMessages: make(chan wantMessage, 10),
wl: wantlist.NewThreadSafe(),
bcwl: wantlist.NewThreadSafe(),
wl: wantlist.NewSessionTrackedWantlist(),
bcwl: wantlist.NewSessionTrackedWantlist(),
ctx: ctx,
cancel: cancel,
wantlistGauge: wantlistGauge,
......@@ -274,7 +274,7 @@ type connectedMessage struct {
}
func (cm *connectedMessage) handle(wm *WantManager) {
wm.peerHandler.Connected(cm.p, wm.bcwl.Entries())
wm.peerHandler.Connected(cm.p, wm.bcwl)
}
type disconnectedMessage struct {
......
......@@ -25,8 +25,8 @@ func (fph *fakePeerHandler) SendMessage(entries []*bsmsg.Entry, targets []peer.I
fph.lk.Unlock()
}
func (fph *fakePeerHandler) Connected(p peer.ID, initialEntries []*wantlist.Entry) {}
func (fph *fakePeerHandler) Disconnected(p peer.ID) {}
func (fph *fakePeerHandler) Connected(p peer.ID, initialWants *wantlist.SessionTrackedWantlist) {}
func (fph *fakePeerHandler) Disconnected(p peer.ID) {}
func (fph *fakePeerHandler) getLastWantSet() wantSet {
fph.lk.Lock()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment