Commit 63d6ee6d authored by Juan Batiz-Benet's avatar Juan Batiz-Benet

multiconn in swarm

parent 58fdcad9
...@@ -99,7 +99,7 @@ func GenerateEKeyPair(curveName string) ([]byte, GenSharedKey, error) { ...@@ -99,7 +99,7 @@ func GenerateEKeyPair(curveName string) ([]byte, GenSharedKey, error) {
} }
pubKey := elliptic.Marshal(curve, x, y) pubKey := elliptic.Marshal(curve, x, y)
log.Debug("GenerateEKeyPair %d", len(pubKey)) // log.Debug("GenerateEKeyPair %d", len(pubKey))
done := func(theirPub []byte) ([]byte, error) { done := func(theirPub []byte) ([]byte, error) {
// Verify and unpack node's public key. // Verify and unpack node's public key.
......
...@@ -38,7 +38,7 @@ type MultiConn struct { ...@@ -38,7 +38,7 @@ type MultiConn struct {
} }
// NewMultiConn constructs a new connection // NewMultiConn constructs a new connection
func NewMultiConn(ctx context.Context, local, remote *peer.Peer, conns []Conn) (Conn, error) { func NewMultiConn(ctx context.Context, local, remote *peer.Peer, conns []Conn) (*MultiConn, error) {
c := &MultiConn{ c := &MultiConn{
local: local, local: local,
...@@ -53,13 +53,10 @@ func NewMultiConn(ctx context.Context, local, remote *peer.Peer, conns []Conn) ( ...@@ -53,13 +53,10 @@ func NewMultiConn(ctx context.Context, local, remote *peer.Peer, conns []Conn) (
// must happen before Adds / fanOut // must happen before Adds / fanOut
c.ContextCloser = NewContextCloser(ctx, c.close) c.ContextCloser = NewContextCloser(ctx, c.close)
log.Info("adding %d...", len(conns))
if conns != nil && len(conns) > 0 { if conns != nil && len(conns) > 0 {
c.Add(conns...) c.Add(conns...)
} }
go c.fanOut() go c.fanOut()
log.Info("newMultiConn: %v to %v", local, remote)
return c, nil return c, nil
} }
...@@ -72,6 +69,9 @@ func (c *MultiConn) Add(conns ...Conn) { ...@@ -72,6 +69,9 @@ func (c *MultiConn) Add(conns ...Conn) {
log.Info("MultiConn: adding %s", c2) log.Info("MultiConn: adding %s", c2)
if c.LocalPeer() != c2.LocalPeer() || c.RemotePeer() != c2.RemotePeer() { if c.LocalPeer() != c2.LocalPeer() || c.RemotePeer() != c2.RemotePeer() {
log.Error("%s", c2) log.Error("%s", c2)
c.Unlock() // ok to unlock (to log). panicing.
log.Error("%s", c)
c.Lock() // gotta relock to avoid lock panic from deferring.
panic("connection addresses mismatch") panic("connection addresses mismatch")
} }
...@@ -102,12 +102,12 @@ func (c *MultiConn) Remove(conns ...Conn) { ...@@ -102,12 +102,12 @@ func (c *MultiConn) Remove(conns ...Conn) {
} }
// close all in parallel, but wait for all to be done closing. // close all in parallel, but wait for all to be done closing.
CloseConns(conns) CloseConns(conns...)
} }
// CloseConns closes multiple connections in parallel, and waits for all // CloseConns closes multiple connections in parallel, and waits for all
// to finish closing. // to finish closing.
func CloseConns(conns []Conn) { func CloseConns(conns ...Conn) {
var wg sync.WaitGroup var wg sync.WaitGroup
for _, child := range conns { for _, child := range conns {
...@@ -204,7 +204,7 @@ func (c *MultiConn) close() error { ...@@ -204,7 +204,7 @@ func (c *MultiConn) close() error {
c.RUnlock() c.RUnlock()
// close underlying connections // close underlying connections
CloseConns(conns) CloseConns(conns...)
return nil return nil
} }
......
...@@ -150,7 +150,7 @@ func setupMultiConns(t *testing.T, ctx context.Context) (a, b *MultiConn) { ...@@ -150,7 +150,7 @@ func setupMultiConns(t *testing.T, ctx context.Context) (a, b *MultiConn) {
p2l.Close() p2l.Close()
log.Info("did you make multiconns?") log.Info("did you make multiconns?")
return c1.(*MultiConn), c2.(*MultiConn) return c1, c2
} }
func TestMulticonnSend(t *testing.T) { func TestMulticonnSend(t *testing.T) {
......
...@@ -36,7 +36,7 @@ func (s *Swarm) listen() error { ...@@ -36,7 +36,7 @@ func (s *Swarm) listen() error {
// Listen for new connections on the given multiaddr // Listen for new connections on the given multiaddr
func (s *Swarm) connListen(maddr ma.Multiaddr) error { func (s *Swarm) connListen(maddr ma.Multiaddr) error {
list, err := conn.Listen(s.ctx, maddr, s.local, s.peers) list, err := conn.Listen(s.Context(), maddr, s.local, s.peers)
if err != nil { if err != nil {
return err return err
} }
...@@ -50,13 +50,19 @@ func (s *Swarm) connListen(maddr ma.Multiaddr) error { ...@@ -50,13 +50,19 @@ func (s *Swarm) connListen(maddr ma.Multiaddr) error {
s.listeners = append(s.listeners, list) s.listeners = append(s.listeners, list)
// Accept and handle new connections on this listener until it errors // Accept and handle new connections on this listener until it errors
// this listener is a child.
s.Children().Add(1)
go func() { go func() {
defer s.Children().Done()
for { for {
select { select {
case <-s.ctx.Done(): case <-s.Closing():
return return
case conn := <-list.Accept(): case conn := <-list.Accept():
// handler also a child.
s.Children().Add(1)
go s.handleIncomingConn(conn) go s.handleIncomingConn(conn)
} }
} }
...@@ -67,6 +73,8 @@ func (s *Swarm) connListen(maddr ma.Multiaddr) error { ...@@ -67,6 +73,8 @@ func (s *Swarm) connListen(maddr ma.Multiaddr) error {
// Handle getting ID from this peer, handshake, and adding it into the map // Handle getting ID from this peer, handshake, and adding it into the map
func (s *Swarm) handleIncomingConn(nconn conn.Conn) { func (s *Swarm) handleIncomingConn(nconn conn.Conn) {
// this handler is a child. added by caller.
defer s.Children().Done()
// Setup the new connection // Setup the new connection
_, err := s.connSetup(nconn) _, err := s.connSetup(nconn)
...@@ -77,7 +85,7 @@ func (s *Swarm) handleIncomingConn(nconn conn.Conn) { ...@@ -77,7 +85,7 @@ func (s *Swarm) handleIncomingConn(nconn conn.Conn) {
} }
// connSetup adds the passed in connection to its peerMap and starts // connSetup adds the passed in connection to its peerMap and starts
// the fanIn routine for that connection // the fanInSingle routine for that connection
func (s *Swarm) connSetup(c conn.Conn) (conn.Conn, error) { func (s *Swarm) connSetup(c conn.Conn) (conn.Conn, error) {
if c == nil { if c == nil {
return nil, errors.New("Tried to start nil connection.") return nil, errors.New("Tried to start nil connection.")
...@@ -93,28 +101,44 @@ func (s *Swarm) connSetup(c conn.Conn) (conn.Conn, error) { ...@@ -93,28 +101,44 @@ func (s *Swarm) connSetup(c conn.Conn) (conn.Conn, error) {
// add to conns // add to conns
s.connsLock.Lock() s.connsLock.Lock()
if c2, ok := s.conns[c.RemotePeer().Key()]; ok {
log.Debug("Conn already open!") mc, ok := s.conns[c.RemotePeer().Key()]
if !ok {
// multiconn doesn't exist, make a new one.
conns := []conn.Conn{c}
mc, err := conn.NewMultiConn(s.Context(), s.local, c.RemotePeer(), conns)
if err != nil {
log.Error("error creating multiconn: %s", err)
c.Close()
return nil, err
}
s.conns[c.RemotePeer().Key()] = mc
s.connsLock.Unlock() s.connsLock.Unlock()
c.Close() log.Debug("added new multiconn: %s", mc)
return c2, nil // not error anymore, use existing conn. } else {
// return ErrAlreadyOpen s.connsLock.Unlock() // unlock before adding new conn
mc.Add(c)
log.Debug("multiconn found: %s", mc)
} }
s.conns[c.RemotePeer().Key()] = c
log.Debug("Added conn to map!") log.Debug("multiconn added new conn %s", c)
s.connsLock.Unlock()
// kick off reader goroutine // kick off reader goroutine
go s.fanIn(c) go s.fanInSingle(c)
return c, nil return c, nil
} }
// Handles the unwrapping + sending of messages to the right connection. // Handles the unwrapping + sending of messages to the right connection.
func (s *Swarm) fanOut() { func (s *Swarm) fanOut() {
s.Children().Add(1)
defer s.Children().Done()
for { for {
select { select {
case <-s.ctx.Done(): case <-s.Closing():
return // told to close. return // told to close.
case msg, ok := <-s.Outgoing: case msg, ok := <-s.Outgoing:
...@@ -127,9 +151,9 @@ func (s *Swarm) fanOut() { ...@@ -127,9 +151,9 @@ func (s *Swarm) fanOut() {
s.connsLock.RUnlock() s.connsLock.RUnlock()
if !found { if !found {
e := fmt.Errorf("Sent msg to peer without open conn: %v", e := fmt.Errorf("Sent msg to peer without open conn: %v", msg.Peer())
msg.Peer)
s.errChan <- e s.errChan <- e
log.Error("%s", e)
continue continue
} }
...@@ -143,30 +167,37 @@ func (s *Swarm) fanOut() { ...@@ -143,30 +167,37 @@ func (s *Swarm) fanOut() {
// Handles the receiving + wrapping of messages, per conn. // Handles the receiving + wrapping of messages, per conn.
// Consider using reflect.Select with one goroutine instead of n. // Consider using reflect.Select with one goroutine instead of n.
func (s *Swarm) fanIn(c conn.Conn) { func (s *Swarm) fanInSingle(c conn.Conn) {
s.Children().Add(1)
c.Children().Add(1) // child of Conn as well.
// cleanup all data associated with this child Connection.
defer func() {
// remove it from the map.
s.connsLock.Lock()
delete(s.conns, c.RemotePeer().Key())
s.connsLock.Unlock()
s.Children().Done()
c.Children().Done() // child of Conn as well.
}()
for { for {
select { select {
case <-s.ctx.Done(): case <-s.Closing(): // Swarm closing
// close Conn. return
c.Close()
goto out case <-c.Closing(): // Conn closing
return
case data, ok := <-c.In(): case data, ok := <-c.In():
if !ok { if !ok {
e := fmt.Errorf("Error retrieving from conn: %v", c.RemotePeer()) return // channel closed.
s.errChan <- e
goto out
} }
// log.Debug("[peer: %s] Received message [from = %s]", s.local, c.Peer) // log.Debug("[peer: %s] Received message [from = %s]", s.local, c.Peer)
s.Incoming <- msg.New(c.RemotePeer(), data) s.Incoming <- msg.New(c.RemotePeer(), data)
} }
} }
out:
s.connsLock.Lock()
delete(s.conns, c.RemotePeer().Key())
s.connsLock.Unlock()
} }
// Commenting out because it's platform specific // Commenting out because it's platform specific
......
...@@ -32,7 +32,6 @@ func TestSimultOpen(t *testing.T) { ...@@ -32,7 +32,6 @@ func TestSimultOpen(t *testing.T) {
if _, err := s.Dial(cp); err != nil { if _, err := s.Dial(cp); err != nil {
t.Fatal("error swarm dialing to peer", err) t.Fatal("error swarm dialing to peer", err)
} }
log.Info("done?!?")
wg.Done() wg.Done()
} }
......
...@@ -56,48 +56,42 @@ type Swarm struct { ...@@ -56,48 +56,42 @@ type Swarm struct {
errChan chan error errChan chan error
// conns are the open connections the swarm is handling. // conns are the open connections the swarm is handling.
conns conn.Map // these are MultiConns, which multiplex multiple separate underlying Conns.
conns conn.MultiConnMap
connsLock sync.RWMutex connsLock sync.RWMutex
// listeners for each network address // listeners for each network address
listeners []conn.Listener listeners []conn.Listener
// cancel is an internal function used to stop the Swarm's processing. // ContextCloser
cancel context.CancelFunc conn.ContextCloser
ctx context.Context
} }
// NewSwarm constructs a Swarm, with a Chan. // NewSwarm constructs a Swarm, with a Chan.
func NewSwarm(ctx context.Context, local *peer.Peer, ps peer.Peerstore) (*Swarm, error) { func NewSwarm(ctx context.Context, local *peer.Peer, ps peer.Peerstore) (*Swarm, error) {
s := &Swarm{ s := &Swarm{
Pipe: msg.NewPipe(10), Pipe: msg.NewPipe(10),
conns: conn.Map{}, conns: conn.MultiConnMap{},
local: local, local: local,
peers: ps, peers: ps,
errChan: make(chan error, 100), errChan: make(chan error, 100),
} }
s.ctx, s.cancel = context.WithCancel(ctx) // ContextCloser for proper child management.
s.ContextCloser = conn.NewContextCloser(ctx, s.close)
go s.fanOut() go s.fanOut()
return s, s.listen() return s, s.listen()
} }
// Close stops a swarm. // close stops a swarm. It's the underlying function called by ContextCloser
func (s *Swarm) Close() error { func (s *Swarm) close() error {
if s.cancel == nil {
return errors.New("Swarm already closed.")
}
// issue cancel for the context
s.cancel()
// set cancel to nil to prevent calling Close again, and signal to Listeners
s.cancel = nil
// close listeners // close listeners
for _, list := range s.listeners { for _, list := range s.listeners {
list.Close() list.Close()
} }
// close connections
conn.CloseConns(s.Connections()...)
return nil return nil
} }
...@@ -132,7 +126,7 @@ func (s *Swarm) Dial(peer *peer.Peer) (conn.Conn, error) { ...@@ -132,7 +126,7 @@ func (s *Swarm) Dial(peer *peer.Peer) (conn.Conn, error) {
Peerstore: s.peers, Peerstore: s.peers,
} }
c, err = d.Dial(s.ctx, "tcp", peer) c, err = d.Dial(s.Context(), "tcp", peer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -158,6 +152,19 @@ func (s *Swarm) GetConnection(pid peer.ID) conn.Conn { ...@@ -158,6 +152,19 @@ func (s *Swarm) GetConnection(pid peer.ID) conn.Conn {
return c return c
} }
// Connections returns a slice of all connections.
func (s *Swarm) Connections() []conn.Conn {
s.connsLock.RLock()
conns := make([]conn.Conn, 0, len(s.conns))
for _, c := range s.conns {
conns = append(conns, c)
}
s.connsLock.RUnlock()
return conns
}
// CloseConnection removes a given peer from swarm + closes the connection // CloseConnection removes a given peer from swarm + closes the connection
func (s *Swarm) CloseConnection(p *peer.Peer) error { func (s *Swarm) CloseConnection(p *peer.Peer) error {
c := s.GetConnection(p.ID) c := s.GetConnection(p.ID)
......
...@@ -85,7 +85,11 @@ func SubtestSwarm(t *testing.T, addrs []string, MsgNum int) { ...@@ -85,7 +85,11 @@ func SubtestSwarm(t *testing.T, addrs []string, MsgNum int) {
var wg sync.WaitGroup var wg sync.WaitGroup
connect := func(s *Swarm, dst *peer.Peer) { connect := func(s *Swarm, dst *peer.Peer) {
// copy for other peer // copy for other peer
cp := &peer.Peer{ID: dst.ID}
cp, err := s.peers.Get(dst.ID)
if err != nil {
cp = &peer.Peer{ID: dst.ID}
}
cp.AddAddress(dst.Addresses[0]) cp.AddAddress(dst.Addresses[0])
log.Info("SWARM TEST: %s dialing %s", s.local, dst) log.Info("SWARM TEST: %s dialing %s", s.local, dst)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment