From c1219303a092161960094dbb666946f1f28d14d0 Mon Sep 17 00:00:00 2001
From: Juan Batiz-Benet <juan@benet.ai>
Date: Fri, 26 Sep 2014 03:57:34 -0700
Subject: [PATCH] fixed muxer errors

---
 net/mux/mux.go      | 44 ++++++++++++++++++++++++++++++--------------
 net/mux/mux_test.go | 12 ++++++------
 2 files changed, 36 insertions(+), 20 deletions(-)

diff --git a/net/mux/mux.go b/net/mux/mux.go
index 57cfe3343..e02a926d3 100644
--- a/net/mux/mux.go
+++ b/net/mux/mux.go
@@ -2,6 +2,7 @@ package mux
 
 import (
 	"errors"
+	"sync"
 
 	msg "github.com/jbenet/go-ipfs/net/message"
 	u "github.com/jbenet/go-ipfs/util"
@@ -30,6 +31,8 @@ type Muxer struct {
 
 	// cancel is the function to stop the Muxer
 	cancel context.CancelFunc
+	ctx    context.Context
+	wg     sync.WaitGroup
 
 	*msg.Pipe
 }
@@ -58,11 +61,14 @@ func (m *Muxer) Start(ctx context.Context) error {
 	}
 
 	// make a cancellable context.
-	ctx, m.cancel = context.WithCancel(ctx)
+	m.ctx, m.cancel = context.WithCancel(ctx)
+	m.wg = sync.WaitGroup{}
 
-	go m.handleIncomingMessages(ctx)
+	m.wg.Add(1)
+	go m.handleIncomingMessages()
 	for pid, proto := range m.Protocols {
-		go m.handleOutgoingMessages(ctx, pid, proto)
+		m.wg.Add(1)
+		go m.handleOutgoingMessages(pid, proto)
 	}
 
 	return nil
@@ -70,8 +76,15 @@ func (m *Muxer) Start(ctx context.Context) error {
 
 // Stop stops muxer activity.
 func (m *Muxer) Stop() {
+	if m.cancel == nil {
+		panic("muxer stopped twice.")
+	}
+	// issue cancel, and wipe func.
 	m.cancel()
 	m.cancel = context.CancelFunc(nil)
+
+	// wait for everything to wind down.
+	m.wg.Wait()
 }
 
 // AddProtocol adds a Protocol with given ProtocolID to the Muxer.
@@ -86,7 +99,8 @@ func (m *Muxer) AddProtocol(p Protocol, pid ProtocolID) error {
 
 // handleIncoming consumes the messages on the m.Incoming channel and
 // routes them appropriately (to the protocols).
-func (m *Muxer) handleIncomingMessages(ctx context.Context) {
+func (m *Muxer) handleIncomingMessages() {
+	defer m.wg.Done()
 
 	for {
 		if m == nil {
@@ -98,16 +112,16 @@ func (m *Muxer) handleIncomingMessages(ctx context.Context) {
 			if !more {
 				return
 			}
-			go m.handleIncomingMessage(ctx, msg)
+			go m.handleIncomingMessage(msg)
 
-		case <-ctx.Done():
+		case <-m.ctx.Done():
 			return
 		}
 	}
 }
 
 // handleIncomingMessage routes message to the appropriate protocol.
-func (m *Muxer) handleIncomingMessage(ctx context.Context, m1 msg.NetMessage) {
+func (m *Muxer) handleIncomingMessage(m1 msg.NetMessage) {
 
 	data, pid, err := unwrapData(m1.Data())
 	if err != nil {
@@ -124,31 +138,33 @@ func (m *Muxer) handleIncomingMessage(ctx context.Context, m1 msg.NetMessage) {
 
 	select {
 	case proto.GetPipe().Incoming <- m2:
-	case <-ctx.Done():
-		u.PErr("%v\n", ctx.Err())
+	case <-m.ctx.Done():
+		u.PErr("%v\n", m.ctx.Err())
 		return
 	}
 }
 
 // handleOutgoingMessages consumes the messages on the proto.Outgoing channel,
 // wraps them and sends them out.
-func (m *Muxer) handleOutgoingMessages(ctx context.Context, pid ProtocolID, proto Protocol) {
+func (m *Muxer) handleOutgoingMessages(pid ProtocolID, proto Protocol) {
+	defer m.wg.Done()
+
 	for {
 		select {
 		case msg, more := <-proto.GetPipe().Outgoing:
 			if !more {
 				return
 			}
-			go m.handleOutgoingMessage(ctx, pid, msg)
+			go m.handleOutgoingMessage(pid, msg)
 
-		case <-ctx.Done():
+		case <-m.ctx.Done():
 			return
 		}
 	}
 }
 
 // handleOutgoingMessage wraps out a message and sends it out the
-func (m *Muxer) handleOutgoingMessage(ctx context.Context, pid ProtocolID, m1 msg.NetMessage) {
+func (m *Muxer) handleOutgoingMessage(pid ProtocolID, m1 msg.NetMessage) {
 	data, err := wrapData(m1.Data(), pid)
 	if err != nil {
 		u.PErr("muxer serializing error: %v\n", err)
@@ -158,7 +174,7 @@ func (m *Muxer) handleOutgoingMessage(ctx context.Context, pid ProtocolID, m1 ms
 	m2 := msg.New(m1.Peer(), data)
 	select {
 	case m.GetPipe().Outgoing <- m2:
-	case <-ctx.Done():
+	case <-m.ctx.Done():
 		return
 	}
 }
diff --git a/net/mux/mux_test.go b/net/mux/mux_test.go
index 6aeeda28c..17606bf93 100644
--- a/net/mux/mux_test.go
+++ b/net/mux/mux_test.go
@@ -229,13 +229,13 @@ func TestStopping(t *testing.T) {
 	mux1.Start(context.Background())
 
 	// test outgoing p1
-	for _, s := range []string{"foo", "bar", "baz"} {
+	for _, s := range []string{"foo1", "bar1", "baz1"} {
 		p1.Outgoing <- msg.New(peer1, []byte(s))
 		testWrappedMsg(t, <-mux1.Outgoing, pid1, []byte(s))
 	}
 
 	// test incoming p1
-	for _, s := range []string{"foo", "bar", "baz"} {
+	for _, s := range []string{"foo2", "bar2", "baz2"} {
 		d, err := wrapData([]byte(s), pid1)
 		if err != nil {
 			t.Error(err)
@@ -250,17 +250,17 @@ func TestStopping(t *testing.T) {
 	}
 
 	// test outgoing p1
-	for _, s := range []string{"foo", "bar", "baz"} {
+	for _, s := range []string{"foo3", "bar3", "baz3"} {
 		p1.Outgoing <- msg.New(peer1, []byte(s))
 		select {
-		case <-mux1.Outgoing:
-			t.Error("should not have received anything.")
+		case m := <-mux1.Outgoing:
+			t.Errorf("should not have received anything. Got: %v", string(m.Data()))
 		case <-time.After(time.Millisecond):
 		}
 	}
 
 	// test incoming p1
-	for _, s := range []string{"foo", "bar", "baz"} {
+	for _, s := range []string{"foo4", "bar4", "baz4"} {
 		d, err := wrapData([]byte(s), pid1)
 		if err != nil {
 			t.Error(err)
-- 
GitLab