From c1219303a092161960094dbb666946f1f28d14d0 Mon Sep 17 00:00:00 2001 From: Juan Batiz-Benet <juan@benet.ai> Date: Fri, 26 Sep 2014 03:57:34 -0700 Subject: [PATCH] fixed muxer errors --- net/mux/mux.go | 44 ++++++++++++++++++++++++++++++-------------- net/mux/mux_test.go | 12 ++++++------ 2 files changed, 36 insertions(+), 20 deletions(-) diff --git a/net/mux/mux.go b/net/mux/mux.go index 57cfe3343..e02a926d3 100644 --- a/net/mux/mux.go +++ b/net/mux/mux.go @@ -2,6 +2,7 @@ package mux import ( "errors" + "sync" msg "github.com/jbenet/go-ipfs/net/message" u "github.com/jbenet/go-ipfs/util" @@ -30,6 +31,8 @@ type Muxer struct { // cancel is the function to stop the Muxer cancel context.CancelFunc + ctx context.Context + wg sync.WaitGroup *msg.Pipe } @@ -58,11 +61,14 @@ func (m *Muxer) Start(ctx context.Context) error { } // make a cancellable context. - ctx, m.cancel = context.WithCancel(ctx) + m.ctx, m.cancel = context.WithCancel(ctx) + m.wg = sync.WaitGroup{} - go m.handleIncomingMessages(ctx) + m.wg.Add(1) + go m.handleIncomingMessages() for pid, proto := range m.Protocols { - go m.handleOutgoingMessages(ctx, pid, proto) + m.wg.Add(1) + go m.handleOutgoingMessages(pid, proto) } return nil @@ -70,8 +76,15 @@ func (m *Muxer) Start(ctx context.Context) error { // Stop stops muxer activity. func (m *Muxer) Stop() { + if m.cancel == nil { + panic("muxer stopped twice.") + } + // issue cancel, and wipe func. m.cancel() m.cancel = context.CancelFunc(nil) + + // wait for everything to wind down. + m.wg.Wait() } // AddProtocol adds a Protocol with given ProtocolID to the Muxer. @@ -86,7 +99,8 @@ func (m *Muxer) AddProtocol(p Protocol, pid ProtocolID) error { // handleIncoming consumes the messages on the m.Incoming channel and // routes them appropriately (to the protocols). -func (m *Muxer) handleIncomingMessages(ctx context.Context) { +func (m *Muxer) handleIncomingMessages() { + defer m.wg.Done() for { if m == nil { @@ -98,16 +112,16 @@ func (m *Muxer) handleIncomingMessages(ctx context.Context) { if !more { return } - go m.handleIncomingMessage(ctx, msg) + go m.handleIncomingMessage(msg) - case <-ctx.Done(): + case <-m.ctx.Done(): return } } } // handleIncomingMessage routes message to the appropriate protocol. -func (m *Muxer) handleIncomingMessage(ctx context.Context, m1 msg.NetMessage) { +func (m *Muxer) handleIncomingMessage(m1 msg.NetMessage) { data, pid, err := unwrapData(m1.Data()) if err != nil { @@ -124,31 +138,33 @@ func (m *Muxer) handleIncomingMessage(ctx context.Context, m1 msg.NetMessage) { select { case proto.GetPipe().Incoming <- m2: - case <-ctx.Done(): - u.PErr("%v\n", ctx.Err()) + case <-m.ctx.Done(): + u.PErr("%v\n", m.ctx.Err()) return } } // handleOutgoingMessages consumes the messages on the proto.Outgoing channel, // wraps them and sends them out. -func (m *Muxer) handleOutgoingMessages(ctx context.Context, pid ProtocolID, proto Protocol) { +func (m *Muxer) handleOutgoingMessages(pid ProtocolID, proto Protocol) { + defer m.wg.Done() + for { select { case msg, more := <-proto.GetPipe().Outgoing: if !more { return } - go m.handleOutgoingMessage(ctx, pid, msg) + go m.handleOutgoingMessage(pid, msg) - case <-ctx.Done(): + case <-m.ctx.Done(): return } } } // handleOutgoingMessage wraps out a message and sends it out the -func (m *Muxer) handleOutgoingMessage(ctx context.Context, pid ProtocolID, m1 msg.NetMessage) { +func (m *Muxer) handleOutgoingMessage(pid ProtocolID, m1 msg.NetMessage) { data, err := wrapData(m1.Data(), pid) if err != nil { u.PErr("muxer serializing error: %v\n", err) @@ -158,7 +174,7 @@ func (m *Muxer) handleOutgoingMessage(ctx context.Context, pid ProtocolID, m1 ms m2 := msg.New(m1.Peer(), data) select { case m.GetPipe().Outgoing <- m2: - case <-ctx.Done(): + case <-m.ctx.Done(): return } } diff --git a/net/mux/mux_test.go b/net/mux/mux_test.go index 6aeeda28c..17606bf93 100644 --- a/net/mux/mux_test.go +++ b/net/mux/mux_test.go @@ -229,13 +229,13 @@ func TestStopping(t *testing.T) { mux1.Start(context.Background()) // test outgoing p1 - for _, s := range []string{"foo", "bar", "baz"} { + for _, s := range []string{"foo1", "bar1", "baz1"} { p1.Outgoing <- msg.New(peer1, []byte(s)) testWrappedMsg(t, <-mux1.Outgoing, pid1, []byte(s)) } // test incoming p1 - for _, s := range []string{"foo", "bar", "baz"} { + for _, s := range []string{"foo2", "bar2", "baz2"} { d, err := wrapData([]byte(s), pid1) if err != nil { t.Error(err) @@ -250,17 +250,17 @@ func TestStopping(t *testing.T) { } // test outgoing p1 - for _, s := range []string{"foo", "bar", "baz"} { + for _, s := range []string{"foo3", "bar3", "baz3"} { p1.Outgoing <- msg.New(peer1, []byte(s)) select { - case <-mux1.Outgoing: - t.Error("should not have received anything.") + case m := <-mux1.Outgoing: + t.Errorf("should not have received anything. Got: %v", string(m.Data())) case <-time.After(time.Millisecond): } } // test incoming p1 - for _, s := range []string{"foo", "bar", "baz"} { + for _, s := range []string{"foo4", "bar4", "baz4"} { d, err := wrapData([]byte(s), pid1) if err != nil { t.Error(err) -- GitLab