Commit bdccb20e authored by Dirk McCormick's avatar Dirk McCormick

fix: simplify message queue shutdown

parent c26bd59d
...@@ -56,6 +56,7 @@ type MessageNetwork interface { ...@@ -56,6 +56,7 @@ type MessageNetwork interface {
// MessageQueue implements queue of want messages to send to peers. // MessageQueue implements queue of want messages to send to peers.
type MessageQueue struct { type MessageQueue struct {
ctx context.Context ctx context.Context
shutdown func()
p peer.ID p peer.ID
network MessageNetwork network MessageNetwork
dhTimeoutMgr DontHaveTimeoutManager dhTimeoutMgr DontHaveTimeoutManager
...@@ -63,7 +64,6 @@ type MessageQueue struct { ...@@ -63,7 +64,6 @@ type MessageQueue struct {
sendErrorBackoff time.Duration sendErrorBackoff time.Duration
outgoingWork chan time.Time outgoingWork chan time.Time
done chan struct{}
// Take lock whenever any of these variables are modified // Take lock whenever any of these variables are modified
wllock sync.Mutex wllock sync.Mutex
...@@ -170,8 +170,10 @@ func New(ctx context.Context, p peer.ID, network MessageNetwork, onDontHaveTimeo ...@@ -170,8 +170,10 @@ func New(ctx context.Context, p peer.ID, network MessageNetwork, onDontHaveTimeo
func newMessageQueue(ctx context.Context, p peer.ID, network MessageNetwork, func newMessageQueue(ctx context.Context, p peer.ID, network MessageNetwork,
maxMsgSize int, sendErrorBackoff time.Duration, dhTimeoutMgr DontHaveTimeoutManager) *MessageQueue { maxMsgSize int, sendErrorBackoff time.Duration, dhTimeoutMgr DontHaveTimeoutManager) *MessageQueue {
ctx, cancel := context.WithCancel(ctx)
mq := &MessageQueue{ mq := &MessageQueue{
ctx: ctx, ctx: ctx,
shutdown: cancel,
p: p, p: p,
network: network, network: network,
dhTimeoutMgr: dhTimeoutMgr, dhTimeoutMgr: dhTimeoutMgr,
...@@ -180,7 +182,6 @@ func newMessageQueue(ctx context.Context, p peer.ID, network MessageNetwork, ...@@ -180,7 +182,6 @@ func newMessageQueue(ctx context.Context, p peer.ID, network MessageNetwork,
peerWants: newRecallWantList(), peerWants: newRecallWantList(),
cancels: cid.NewSet(), cancels: cid.NewSet(),
outgoingWork: make(chan time.Time, 1), outgoingWork: make(chan time.Time, 1),
done: make(chan struct{}),
rebroadcastInterval: defaultRebroadcastInterval, rebroadcastInterval: defaultRebroadcastInterval,
sendErrorBackoff: sendErrorBackoff, sendErrorBackoff: sendErrorBackoff,
priority: maxPriority, priority: maxPriority,
...@@ -301,12 +302,17 @@ func (mq *MessageQueue) Startup() { ...@@ -301,12 +302,17 @@ func (mq *MessageQueue) Startup() {
// Shutdown stops the processing of messages for a message queue. // Shutdown stops the processing of messages for a message queue.
func (mq *MessageQueue) Shutdown() { func (mq *MessageQueue) Shutdown() {
close(mq.done) mq.shutdown()
} }
func (mq *MessageQueue) onShutdown() { func (mq *MessageQueue) onShutdown() {
// Shut down the DONT_HAVE timeout manager // Shut down the DONT_HAVE timeout manager
mq.dhTimeoutMgr.Shutdown() mq.dhTimeoutMgr.Shutdown()
// Reset the streamMessageSender
if mq.sender != nil {
_ = mq.sender.Reset()
}
} }
func (mq *MessageQueue) runQueue() { func (mq *MessageQueue) runQueue() {
...@@ -352,17 +358,7 @@ func (mq *MessageQueue) runQueue() { ...@@ -352,17 +358,7 @@ func (mq *MessageQueue) runQueue() {
// in sendMessageDebounce. Send immediately. // in sendMessageDebounce. Send immediately.
workScheduled = time.Time{} workScheduled = time.Time{}
mq.sendIfReady() mq.sendIfReady()
case <-mq.done:
if mq.sender != nil {
mq.sender.Close()
}
return
case <-mq.ctx.Done(): case <-mq.ctx.Done():
if mq.sender != nil {
// TODO: should I call sender.Close() here also to stop
// and in progress connection?
_ = mq.sender.Reset()
}
return return
} }
} }
......
...@@ -82,17 +82,15 @@ func (fp *fakeDontHaveTimeoutMgr) pendingCount() int { ...@@ -82,17 +82,15 @@ func (fp *fakeDontHaveTimeoutMgr) pendingCount() int {
type fakeMessageSender struct { type fakeMessageSender struct {
lk sync.Mutex lk sync.Mutex
fullClosed chan<- struct{}
reset chan<- struct{} reset chan<- struct{}
messagesSent chan<- []bsmsg.Entry messagesSent chan<- []bsmsg.Entry
supportsHave bool supportsHave bool
} }
func newFakeMessageSender(fullClosed chan<- struct{}, reset chan<- struct{}, func newFakeMessageSender(reset chan<- struct{},
messagesSent chan<- []bsmsg.Entry, supportsHave bool) *fakeMessageSender { messagesSent chan<- []bsmsg.Entry, supportsHave bool) *fakeMessageSender {
return &fakeMessageSender{ return &fakeMessageSender{
fullClosed: fullClosed,
reset: reset, reset: reset,
messagesSent: messagesSent, messagesSent: messagesSent,
supportsHave: supportsHave, supportsHave: supportsHave,
...@@ -106,7 +104,7 @@ func (fms *fakeMessageSender) SendMsg(ctx context.Context, msg bsmsg.BitSwapMess ...@@ -106,7 +104,7 @@ func (fms *fakeMessageSender) SendMsg(ctx context.Context, msg bsmsg.BitSwapMess
fms.messagesSent <- msg.Wantlist() fms.messagesSent <- msg.Wantlist()
return nil return nil
} }
func (fms *fakeMessageSender) Close() error { fms.fullClosed <- struct{}{}; return nil } func (fms *fakeMessageSender) Close() error { return nil }
func (fms *fakeMessageSender) Reset() error { fms.reset <- struct{}{}; return nil } func (fms *fakeMessageSender) Reset() error { fms.reset <- struct{}{}; return nil }
func (fms *fakeMessageSender) SupportsHave() bool { return fms.supportsHave } func (fms *fakeMessageSender) SupportsHave() bool { return fms.supportsHave }
...@@ -141,8 +139,7 @@ func TestStartupAndShutdown(t *testing.T) { ...@@ -141,8 +139,7 @@ func TestStartupAndShutdown(t *testing.T) {
ctx := context.Background() ctx := context.Background()
messagesSent := make(chan []bsmsg.Entry) messagesSent := make(chan []bsmsg.Entry)
resetChan := make(chan struct{}, 1) resetChan := make(chan struct{}, 1)
fullClosedChan := make(chan struct{}, 1) fakeSender := newFakeMessageSender(resetChan, messagesSent, true)
fakeSender := newFakeMessageSender(fullClosedChan, resetChan, messagesSent, true)
fakenet := &fakeMessageNetwork{nil, nil, fakeSender} fakenet := &fakeMessageNetwork{nil, nil, fakeSender}
peerID := testutil.GeneratePeers(1)[0] peerID := testutil.GeneratePeers(1)[0]
messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb) messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb)
...@@ -170,11 +167,9 @@ func TestStartupAndShutdown(t *testing.T) { ...@@ -170,11 +167,9 @@ func TestStartupAndShutdown(t *testing.T) {
timeoutctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond) timeoutctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond)
defer cancel() defer cancel()
select { select {
case <-fullClosedChan:
case <-resetChan: case <-resetChan:
t.Fatal("message sender should have been closed but was reset")
case <-timeoutctx.Done(): case <-timeoutctx.Done():
t.Fatal("message sender should have been closed but wasn't") t.Fatal("message sender should have been reset but wasn't")
} }
} }
...@@ -182,8 +177,7 @@ func TestSendingMessagesDeduped(t *testing.T) { ...@@ -182,8 +177,7 @@ func TestSendingMessagesDeduped(t *testing.T) {
ctx := context.Background() ctx := context.Background()
messagesSent := make(chan []bsmsg.Entry) messagesSent := make(chan []bsmsg.Entry)
resetChan := make(chan struct{}, 1) resetChan := make(chan struct{}, 1)
fullClosedChan := make(chan struct{}, 1) fakeSender := newFakeMessageSender(resetChan, messagesSent, true)
fakeSender := newFakeMessageSender(fullClosedChan, resetChan, messagesSent, true)
fakenet := &fakeMessageNetwork{nil, nil, fakeSender} fakenet := &fakeMessageNetwork{nil, nil, fakeSender}
peerID := testutil.GeneratePeers(1)[0] peerID := testutil.GeneratePeers(1)[0]
messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb) messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb)
...@@ -204,8 +198,7 @@ func TestSendingMessagesPartialDupe(t *testing.T) { ...@@ -204,8 +198,7 @@ func TestSendingMessagesPartialDupe(t *testing.T) {
ctx := context.Background() ctx := context.Background()
messagesSent := make(chan []bsmsg.Entry) messagesSent := make(chan []bsmsg.Entry)
resetChan := make(chan struct{}, 1) resetChan := make(chan struct{}, 1)
fullClosedChan := make(chan struct{}, 1) fakeSender := newFakeMessageSender(resetChan, messagesSent, true)
fakeSender := newFakeMessageSender(fullClosedChan, resetChan, messagesSent, true)
fakenet := &fakeMessageNetwork{nil, nil, fakeSender} fakenet := &fakeMessageNetwork{nil, nil, fakeSender}
peerID := testutil.GeneratePeers(1)[0] peerID := testutil.GeneratePeers(1)[0]
messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb) messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb)
...@@ -226,8 +219,7 @@ func TestSendingMessagesPriority(t *testing.T) { ...@@ -226,8 +219,7 @@ func TestSendingMessagesPriority(t *testing.T) {
ctx := context.Background() ctx := context.Background()
messagesSent := make(chan []bsmsg.Entry) messagesSent := make(chan []bsmsg.Entry)
resetChan := make(chan struct{}, 1) resetChan := make(chan struct{}, 1)
fullClosedChan := make(chan struct{}, 1) fakeSender := newFakeMessageSender(resetChan, messagesSent, true)
fakeSender := newFakeMessageSender(fullClosedChan, resetChan, messagesSent, true)
fakenet := &fakeMessageNetwork{nil, nil, fakeSender} fakenet := &fakeMessageNetwork{nil, nil, fakeSender}
peerID := testutil.GeneratePeers(1)[0] peerID := testutil.GeneratePeers(1)[0]
messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb) messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb)
...@@ -294,8 +286,7 @@ func TestCancelOverridesPendingWants(t *testing.T) { ...@@ -294,8 +286,7 @@ func TestCancelOverridesPendingWants(t *testing.T) {
ctx := context.Background() ctx := context.Background()
messagesSent := make(chan []bsmsg.Entry) messagesSent := make(chan []bsmsg.Entry)
resetChan := make(chan struct{}, 1) resetChan := make(chan struct{}, 1)
fullClosedChan := make(chan struct{}, 1) fakeSender := newFakeMessageSender(resetChan, messagesSent, true)
fakeSender := newFakeMessageSender(fullClosedChan, resetChan, messagesSent, true)
fakenet := &fakeMessageNetwork{nil, nil, fakeSender} fakenet := &fakeMessageNetwork{nil, nil, fakeSender}
peerID := testutil.GeneratePeers(1)[0] peerID := testutil.GeneratePeers(1)[0]
messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb) messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb)
...@@ -345,8 +336,7 @@ func TestWantOverridesPendingCancels(t *testing.T) { ...@@ -345,8 +336,7 @@ func TestWantOverridesPendingCancels(t *testing.T) {
ctx := context.Background() ctx := context.Background()
messagesSent := make(chan []bsmsg.Entry) messagesSent := make(chan []bsmsg.Entry)
resetChan := make(chan struct{}, 1) resetChan := make(chan struct{}, 1)
fullClosedChan := make(chan struct{}, 1) fakeSender := newFakeMessageSender(resetChan, messagesSent, true)
fakeSender := newFakeMessageSender(fullClosedChan, resetChan, messagesSent, true)
fakenet := &fakeMessageNetwork{nil, nil, fakeSender} fakenet := &fakeMessageNetwork{nil, nil, fakeSender}
peerID := testutil.GeneratePeers(1)[0] peerID := testutil.GeneratePeers(1)[0]
messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb) messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb)
...@@ -392,8 +382,7 @@ func TestWantlistRebroadcast(t *testing.T) { ...@@ -392,8 +382,7 @@ func TestWantlistRebroadcast(t *testing.T) {
ctx := context.Background() ctx := context.Background()
messagesSent := make(chan []bsmsg.Entry) messagesSent := make(chan []bsmsg.Entry)
resetChan := make(chan struct{}, 1) resetChan := make(chan struct{}, 1)
fullClosedChan := make(chan struct{}, 1) fakeSender := newFakeMessageSender(resetChan, messagesSent, true)
fakeSender := newFakeMessageSender(fullClosedChan, resetChan, messagesSent, true)
fakenet := &fakeMessageNetwork{nil, nil, fakeSender} fakenet := &fakeMessageNetwork{nil, nil, fakeSender}
peerID := testutil.GeneratePeers(1)[0] peerID := testutil.GeneratePeers(1)[0]
messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb) messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb)
...@@ -488,8 +477,7 @@ func TestSendingLargeMessages(t *testing.T) { ...@@ -488,8 +477,7 @@ func TestSendingLargeMessages(t *testing.T) {
ctx := context.Background() ctx := context.Background()
messagesSent := make(chan []bsmsg.Entry) messagesSent := make(chan []bsmsg.Entry)
resetChan := make(chan struct{}, 1) resetChan := make(chan struct{}, 1)
fullClosedChan := make(chan struct{}, 1) fakeSender := newFakeMessageSender(resetChan, messagesSent, true)
fakeSender := newFakeMessageSender(fullClosedChan, resetChan, messagesSent, true)
fakenet := &fakeMessageNetwork{nil, nil, fakeSender} fakenet := &fakeMessageNetwork{nil, nil, fakeSender}
dhtm := &fakeDontHaveTimeoutMgr{} dhtm := &fakeDontHaveTimeoutMgr{}
peerID := testutil.GeneratePeers(1)[0] peerID := testutil.GeneratePeers(1)[0]
...@@ -518,8 +506,7 @@ func TestSendToPeerThatDoesntSupportHave(t *testing.T) { ...@@ -518,8 +506,7 @@ func TestSendToPeerThatDoesntSupportHave(t *testing.T) {
ctx := context.Background() ctx := context.Background()
messagesSent := make(chan []bsmsg.Entry) messagesSent := make(chan []bsmsg.Entry)
resetChan := make(chan struct{}, 1) resetChan := make(chan struct{}, 1)
fullClosedChan := make(chan struct{}, 1) fakeSender := newFakeMessageSender(resetChan, messagesSent, false)
fakeSender := newFakeMessageSender(fullClosedChan, resetChan, messagesSent, false)
fakenet := &fakeMessageNetwork{nil, nil, fakeSender} fakenet := &fakeMessageNetwork{nil, nil, fakeSender}
peerID := testutil.GeneratePeers(1)[0] peerID := testutil.GeneratePeers(1)[0]
...@@ -573,8 +560,7 @@ func TestSendToPeerThatDoesntSupportHaveMonitorsTimeouts(t *testing.T) { ...@@ -573,8 +560,7 @@ func TestSendToPeerThatDoesntSupportHaveMonitorsTimeouts(t *testing.T) {
ctx := context.Background() ctx := context.Background()
messagesSent := make(chan []bsmsg.Entry) messagesSent := make(chan []bsmsg.Entry)
resetChan := make(chan struct{}, 1) resetChan := make(chan struct{}, 1)
fullClosedChan := make(chan struct{}, 1) fakeSender := newFakeMessageSender(resetChan, messagesSent, false)
fakeSender := newFakeMessageSender(fullClosedChan, resetChan, messagesSent, false)
fakenet := &fakeMessageNetwork{nil, nil, fakeSender} fakenet := &fakeMessageNetwork{nil, nil, fakeSender}
peerID := testutil.GeneratePeers(1)[0] peerID := testutil.GeneratePeers(1)[0]
...@@ -624,8 +610,7 @@ func BenchmarkMessageQueue(b *testing.B) { ...@@ -624,8 +610,7 @@ func BenchmarkMessageQueue(b *testing.B) {
createQueue := func() *MessageQueue { createQueue := func() *MessageQueue {
messagesSent := make(chan []bsmsg.Entry) messagesSent := make(chan []bsmsg.Entry)
resetChan := make(chan struct{}, 1) resetChan := make(chan struct{}, 1)
fullClosedChan := make(chan struct{}, 1) fakeSender := newFakeMessageSender(resetChan, messagesSent, true)
fakeSender := newFakeMessageSender(fullClosedChan, resetChan, messagesSent, true)
fakenet := &fakeMessageNetwork{nil, nil, fakeSender} fakenet := &fakeMessageNetwork{nil, nil, fakeSender}
dhtm := &fakeDontHaveTimeoutMgr{} dhtm := &fakeDontHaveTimeoutMgr{}
peerID := testutil.GeneratePeers(1)[0] peerID := testutil.GeneratePeers(1)[0]
......
...@@ -94,7 +94,6 @@ type streamMessageSender struct { ...@@ -94,7 +94,6 @@ type streamMessageSender struct {
stream network.Stream stream network.Stream
bsnet *impl bsnet *impl
opts *MessageSenderOpts opts *MessageSenderOpts
done chan struct{}
} }
// Open a stream to the remote peer // Open a stream to the remote peer
...@@ -107,13 +106,6 @@ func (s *streamMessageSender) Connect(ctx context.Context) (network.Stream, erro ...@@ -107,13 +106,6 @@ func (s *streamMessageSender) Connect(ctx context.Context) (network.Stream, erro
return nil, err return nil, err
} }
// Check if the sender has been closed
select {
case <-s.done:
return nil, nil
default:
}
stream, err := s.bsnet.newStreamToPeer(ctx, s.to) stream, err := s.bsnet.newStreamToPeer(ctx, s.to)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -135,7 +127,6 @@ func (s *streamMessageSender) Reset() error { ...@@ -135,7 +127,6 @@ func (s *streamMessageSender) Reset() error {
// Close the stream // Close the stream
func (s *streamMessageSender) Close() error { func (s *streamMessageSender) Close() error {
close(s.done)
return helpers.FullClose(s.stream) return helpers.FullClose(s.stream)
} }
...@@ -172,8 +163,6 @@ func (s *streamMessageSender) multiAttempt(ctx context.Context, fn func(context. ...@@ -172,8 +163,6 @@ func (s *streamMessageSender) multiAttempt(ctx context.Context, fn func(context.
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil return nil
case <-s.done:
return nil
default: default:
} }
...@@ -195,8 +184,6 @@ func (s *streamMessageSender) multiAttempt(ctx context.Context, fn func(context. ...@@ -195,8 +184,6 @@ func (s *streamMessageSender) multiAttempt(ctx context.Context, fn func(context.
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil return nil
case <-s.done:
return nil
case <-time.After(s.opts.SendErrorBackoff): case <-time.After(s.opts.SendErrorBackoff):
// wait a short time in case disconnect notifications are still propagating // wait a short time in case disconnect notifications are still propagating
log.Infof("send message to %s failed but context was not Done: %s", s.to, err) log.Infof("send message to %s failed but context was not Done: %s", s.to, err)
...@@ -286,7 +273,6 @@ func (bsnet *impl) NewMessageSender(ctx context.Context, p peer.ID, opts *Messag ...@@ -286,7 +273,6 @@ func (bsnet *impl) NewMessageSender(ctx context.Context, p peer.ID, opts *Messag
to: p, to: p,
bsnet: bsnet, bsnet: bsnet,
opts: opts, opts: opts,
done: make(chan struct{}),
} }
err := sender.multiAttempt(ctx, func(fnctx context.Context) error { err := sender.multiAttempt(ctx, func(fnctx context.Context) error {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment