package mux

import (
	"errors"
	"sync"

	msg "github.com/jbenet/go-ipfs/net/message"
	u "github.com/jbenet/go-ipfs/util"

	context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
	proto "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto"
)

// Protocol objects produce + consume raw data. They are added to the Muxer
// with a ProtocolID, which is added to outgoing payloads. Muxer properly
// encapsulates and decapsulates when interfacing with its Protocols. The
// Protocols do not encounter their ProtocolID.
type Protocol interface {
	// GetPipe returns the pipe the Muxer uses to exchange messages with
	// this Protocol.
	GetPipe() *msg.Pipe
}

// ProtocolMap maps ProtocolIDs to Protocols.
type ProtocolMap map[ProtocolID]Protocol

// Muxer is a simple multiplexor that reads + writes to Incoming and Outgoing
// channels. It multiplexes various protocols, wrapping and unwrapping data
// with a ProtocolID.
type Muxer struct {
	// Protocols are the multiplexed services.
	Protocols ProtocolMap

	// cancel is the function to stop the Muxer. It is non-nil only while
	// the Muxer is running (between Start and Stop).
	cancel context.CancelFunc

	// ctx is the cancellable context created by Start; every Muxer
	// goroutine watches it to know when to exit.
	ctx context.Context

	// wg tracks every goroutine spawned by the Muxer so that Stop can wait
	// for all of them to finish.
	wg sync.WaitGroup

	// Pipe is the Muxer's own (multiplexed) side of the connection.
	*msg.Pipe
}

// NewMuxer constructs a muxer given a protocol map.
func NewMuxer(mp ProtocolMap) *Muxer {
	return &Muxer{
		Protocols: mp,
		Pipe:      msg.NewPipe(10),
	}
}

// GetPipe implements the Protocol interface
func (m *Muxer) GetPipe() *msg.Pipe {
	return m.Pipe
}

// Start kicks off the Muxer goroutines: one that routes incoming messages,
// plus one per Protocol registered at the time of the call that forwards
// that Protocol's outgoing messages. It returns an error if the Muxer is
// already running, and panics on a nil receiver.
func (m *Muxer) Start(ctx context.Context) error {
	if m == nil {
		panic("nix muxer")
	}

	if m.cancel != nil {
		return errors.New("Muxer already started.")
	}

	// make a cancellable context.
	m.ctx, m.cancel = context.WithCancel(ctx)
	m.wg = sync.WaitGroup{}

	m.wg.Add(1)
	go m.handleIncomingMessages()
	for pid, proto := range m.Protocols {
		m.wg.Add(1)
		go m.handleOutgoingMessages(pid, proto)
	}

	return nil
}

// Stop stops muxer activity and blocks until the goroutines tracked by the
// Muxer have exited. It panics if the Muxer is not currently started.
func (m *Muxer) Stop() {
	if m.cancel == nil {
		panic("muxer stopped twice.")
	}

	// issue cancel, and wipe func.
	m.cancel()
	m.cancel = context.CancelFunc(nil)

	// wait for everything to wind down.
	m.wg.Wait()
}

// AddProtocol adds a Protocol with given ProtocolID to the Muxer. It returns
// an error if the ProtocolID is already taken.
//
// NOTE(review): this only records the Protocol in the map; it does not start
// an outgoing handler for it (Start does that for the protocols present when
// it is called), and the map is not guarded against concurrent access.
func (m *Muxer) AddProtocol(p Protocol, pid ProtocolID) error {
	if _, found := m.Protocols[pid]; found {
		return errors.New("Another protocol already using this ProtocolID")
	}

	// A Muxer built with NewMuxer(nil) has a nil map; writing to it would
	// panic, so allocate it on first use.
	if m.Protocols == nil {
		m.Protocols = ProtocolMap{}
	}

	m.Protocols[pid] = p
	return nil
}

// handleIncomingMessages consumes the messages on the m.Incoming channel and
// routes them appropriately (to the protocols). It returns when the channel
// is closed or the Muxer's context is cancelled.
func (m *Muxer) handleIncomingMessages() {
	// Check before the deferred Done, which already dereferences m.
	if m == nil {
		panic("nil muxer")
	}
	defer m.wg.Done()

	for {
		select {
		case in, more := <-m.Incoming:
			if !more {
				return
			}

			// Track the per-message goroutine so Stop really waits for it.
			// Adding here is safe: this goroutine still holds its own
			// WaitGroup count, so the counter cannot be zero.
			m.wg.Add(1)
			go func(nm msg.NetMessage) {
				defer m.wg.Done()
				m.handleIncomingMessage(nm)
			}(in)

		case <-m.ctx.Done():
			return
		}
	}
}

// handleIncomingMessage routes message to the appropriate protocol. Messages
// that fail to unwrap, or that carry an unregistered ProtocolID, are logged
// and dropped.
func (m *Muxer) handleIncomingMessage(m1 msg.NetMessage) {
	data, pid, err := unwrapData(m1.Data())
	if err != nil {
		u.PErr("muxer de-serializing error: %v\n", err)
		return
	}

	m2 := msg.New(m1.Peer(), data)
	proto, found := m.Protocols[pid]
	if !found {
		u.PErr("muxer unknown protocol %v\n", pid)
		return
	}

	// Deliver to the protocol, but give up if the Muxer is stopped first.
	select {
	case proto.GetPipe().Incoming <- m2:
	case <-m.ctx.Done():
		u.PErr("%v\n", m.ctx.Err())
		return
	}
}

// handleOutgoingMessages consumes the messages on the proto.Outgoing channel,
// wraps them and sends them out.
func (m *Muxer) handleOutgoingMessages(pid ProtocolID, proto Protocol) { defer m.wg.Done() for { select { case msg, more := <-proto.GetPipe().Outgoing: if !more { return } go m.handleOutgoingMessage(pid, msg) case <-m.ctx.Done(): return } } } // handleOutgoingMessage wraps out a message and sends it out the func (m *Muxer) handleOutgoingMessage(pid ProtocolID, m1 msg.NetMessage) { data, err := wrapData(m1.Data(), pid) if err != nil { u.PErr("muxer serializing error: %v\n", err) return } m2 := msg.New(m1.Peer(), data) select { case m.GetPipe().Outgoing <- m2: case <-m.ctx.Done(): return } } func wrapData(data []byte, pid ProtocolID) ([]byte, error) { // Marshal pbm := new(PBProtocolMessage) pbm.ProtocolID = &pid pbm.Data = data b, err := proto.Marshal(pbm) if err != nil { return nil, err } return b, nil } func unwrapData(data []byte) ([]byte, ProtocolID, error) { // Unmarshal pbm := new(PBProtocolMessage) err := proto.Unmarshal(data, pbm) if err != nil { return nil, 0, err } return pbm.GetData(), pbm.GetProtocolID(), nil }