Commit c1219303 authored by Juan Batiz-Benet's avatar Juan Batiz-Benet

fixed muxer errors

parent 2507680d
......@@ -2,6 +2,7 @@ package mux
import (
"errors"
"sync"
msg "github.com/jbenet/go-ipfs/net/message"
u "github.com/jbenet/go-ipfs/util"
......@@ -30,6 +31,8 @@ type Muxer struct {
// cancel is the function to stop the Muxer
cancel context.CancelFunc
ctx context.Context
wg sync.WaitGroup
*msg.Pipe
}
......@@ -58,11 +61,14 @@ func (m *Muxer) Start(ctx context.Context) error {
}
// make a cancellable context.
ctx, m.cancel = context.WithCancel(ctx)
m.ctx, m.cancel = context.WithCancel(ctx)
m.wg = sync.WaitGroup{}
go m.handleIncomingMessages(ctx)
m.wg.Add(1)
go m.handleIncomingMessages()
for pid, proto := range m.Protocols {
go m.handleOutgoingMessages(ctx, pid, proto)
m.wg.Add(1)
go m.handleOutgoingMessages(pid, proto)
}
return nil
......@@ -70,8 +76,15 @@ func (m *Muxer) Start(ctx context.Context) error {
// Stop stops muxer activity.
func (m *Muxer) Stop() {
if m.cancel == nil {
panic("muxer stopped twice.")
}
// issue cancel, and wipe func.
m.cancel()
m.cancel = context.CancelFunc(nil)
// wait for everything to wind down.
m.wg.Wait()
}
// AddProtocol adds a Protocol with given ProtocolID to the Muxer.
......@@ -86,7 +99,8 @@ func (m *Muxer) AddProtocol(p Protocol, pid ProtocolID) error {
// handleIncoming consumes the messages on the m.Incoming channel and
// routes them appropriately (to the protocols).
func (m *Muxer) handleIncomingMessages(ctx context.Context) {
func (m *Muxer) handleIncomingMessages() {
defer m.wg.Done()
for {
if m == nil {
......@@ -98,16 +112,16 @@ func (m *Muxer) handleIncomingMessages(ctx context.Context) {
if !more {
return
}
go m.handleIncomingMessage(ctx, msg)
go m.handleIncomingMessage(msg)
case <-ctx.Done():
case <-m.ctx.Done():
return
}
}
}
// handleIncomingMessage routes message to the appropriate protocol.
func (m *Muxer) handleIncomingMessage(ctx context.Context, m1 msg.NetMessage) {
func (m *Muxer) handleIncomingMessage(m1 msg.NetMessage) {
data, pid, err := unwrapData(m1.Data())
if err != nil {
......@@ -124,31 +138,33 @@ func (m *Muxer) handleIncomingMessage(ctx context.Context, m1 msg.NetMessage) {
select {
case proto.GetPipe().Incoming <- m2:
case <-ctx.Done():
u.PErr("%v\n", ctx.Err())
case <-m.ctx.Done():
u.PErr("%v\n", m.ctx.Err())
return
}
}
// handleOutgoingMessages consumes the messages on the proto.Outgoing channel,
// wraps them and sends them out.
func (m *Muxer) handleOutgoingMessages(ctx context.Context, pid ProtocolID, proto Protocol) {
func (m *Muxer) handleOutgoingMessages(pid ProtocolID, proto Protocol) {
defer m.wg.Done()
for {
select {
case msg, more := <-proto.GetPipe().Outgoing:
if !more {
return
}
go m.handleOutgoingMessage(ctx, pid, msg)
go m.handleOutgoingMessage(pid, msg)
case <-ctx.Done():
case <-m.ctx.Done():
return
}
}
}
// handleOutgoingMessage wraps out a message and sends it out the
func (m *Muxer) handleOutgoingMessage(ctx context.Context, pid ProtocolID, m1 msg.NetMessage) {
func (m *Muxer) handleOutgoingMessage(pid ProtocolID, m1 msg.NetMessage) {
data, err := wrapData(m1.Data(), pid)
if err != nil {
u.PErr("muxer serializing error: %v\n", err)
......@@ -158,7 +174,7 @@ func (m *Muxer) handleOutgoingMessage(ctx context.Context, pid ProtocolID, m1 ms
m2 := msg.New(m1.Peer(), data)
select {
case m.GetPipe().Outgoing <- m2:
case <-ctx.Done():
case <-m.ctx.Done():
return
}
}
......
......@@ -229,13 +229,13 @@ func TestStopping(t *testing.T) {
mux1.Start(context.Background())
// test outgoing p1
for _, s := range []string{"foo", "bar", "baz"} {
for _, s := range []string{"foo1", "bar1", "baz1"} {
p1.Outgoing <- msg.New(peer1, []byte(s))
testWrappedMsg(t, <-mux1.Outgoing, pid1, []byte(s))
}
// test incoming p1
for _, s := range []string{"foo", "bar", "baz"} {
for _, s := range []string{"foo2", "bar2", "baz2"} {
d, err := wrapData([]byte(s), pid1)
if err != nil {
t.Error(err)
......@@ -250,17 +250,17 @@ func TestStopping(t *testing.T) {
}
// test outgoing p1
for _, s := range []string{"foo", "bar", "baz"} {
for _, s := range []string{"foo3", "bar3", "baz3"} {
p1.Outgoing <- msg.New(peer1, []byte(s))
select {
case <-mux1.Outgoing:
t.Error("should not have received anything.")
case m := <-mux1.Outgoing:
t.Errorf("should not have received anything. Got: %v", string(m.Data()))
case <-time.After(time.Millisecond):
}
}
// test incoming p1
for _, s := range []string{"foo", "bar", "baz"} {
for _, s := range []string{"foo4", "bar4", "baz4"} {
d, err := wrapData([]byte(s), pid1)
if err != nil {
t.Error(err)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment