Commit 0cdcc614 authored by Steven Allen's avatar Steven Allen

feat: close transports that implement io.Closer

This way, transports with shared resources (e.g., reused sockets) can clean them
up.

fixes https://github.com/libp2p/go-libp2p/issues/999
parent 3563ed1f
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io"
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
...@@ -176,6 +177,27 @@ func (s *Swarm) teardown() error { ...@@ -176,6 +177,27 @@ func (s *Swarm) teardown() error {
// Wait for everything to finish. // Wait for everything to finish.
s.refs.Wait() s.refs.Wait()
// Now close out any transports (if necessary). Do this after closing
// all connections/listeners.
s.transports.Lock()
transports := s.transports.m
s.transports.m = nil
s.transports.Unlock()
var wg sync.WaitGroup
for _, t := range transports {
if closer, ok := t.(io.Closer); ok {
wg.Add(1)
go func(c io.Closer) {
defer wg.Done()
if err := closer.Close(); err != nil {
log.Errorf("error when closing down transport %T: %s", c, err)
}
}(closer)
}
}
wg.Wait()
return nil return nil
} }
......
...@@ -40,7 +40,17 @@ func (s *Swarm) Listen(addrs ...ma.Multiaddr) error { ...@@ -40,7 +40,17 @@ func (s *Swarm) Listen(addrs ...ma.Multiaddr) error {
func (s *Swarm) AddListenAddr(a ma.Multiaddr) error { func (s *Swarm) AddListenAddr(a ma.Multiaddr) error {
tpt := s.TransportForListening(a) tpt := s.TransportForListening(a)
if tpt == nil { if tpt == nil {
return ErrNoTransport // TransportForListening will return nil if either:
// 1. No transport has been registered.
// 2. We're closed (so we've nulled out the transport map.
//
// Distinguish between these two cases to avoid confusing users.
select {
case <-s.proc.Closing():
return ErrSwarmClosed
default:
return ErrNoTransport
}
} }
list, err := tpt.Listen(a) list, err := tpt.Listen(a)
......
...@@ -20,7 +20,10 @@ func (s *Swarm) TransportForDialing(a ma.Multiaddr) transport.Transport { ...@@ -20,7 +20,10 @@ func (s *Swarm) TransportForDialing(a ma.Multiaddr) transport.Transport {
s.transports.RLock() s.transports.RLock()
defer s.transports.RUnlock() defer s.transports.RUnlock()
if len(s.transports.m) == 0 { if len(s.transports.m) == 0 {
log.Error("you have no transports configured") // make sure we're not just shutting down.
if s.transports.m != nil {
log.Error("you have no transports configured")
}
return nil return nil
} }
...@@ -48,7 +51,10 @@ func (s *Swarm) TransportForListening(a ma.Multiaddr) transport.Transport { ...@@ -48,7 +51,10 @@ func (s *Swarm) TransportForListening(a ma.Multiaddr) transport.Transport {
s.transports.RLock() s.transports.RLock()
defer s.transports.RUnlock() defer s.transports.RUnlock()
if len(s.transports.m) == 0 { if len(s.transports.m) == 0 {
log.Error("you have no transports configured") // make sure we're not just shutting down.
if s.transports.m != nil {
log.Error("you have no transports configured")
}
return nil return nil
} }
...@@ -77,6 +83,9 @@ func (s *Swarm) AddTransport(t transport.Transport) error { ...@@ -77,6 +83,9 @@ func (s *Swarm) AddTransport(t transport.Transport) error {
s.transports.Lock() s.transports.Lock()
defer s.transports.Unlock() defer s.transports.Unlock()
if s.transports.m == nil {
return ErrSwarmClosed
}
var registered []string var registered []string
for _, p := range protocols { for _, p := range protocols {
if _, ok := s.transports.m[p]; ok { if _, ok := s.transports.m[p]; ok {
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"testing" "testing"
swarm "github.com/libp2p/go-libp2p-swarm"
swarmt "github.com/libp2p/go-libp2p-swarm/testing" swarmt "github.com/libp2p/go-libp2p-swarm/testing"
"github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peer"
...@@ -14,6 +15,7 @@ import ( ...@@ -14,6 +15,7 @@ import (
type dummyTransport struct { type dummyTransport struct {
protocols []int protocols []int
proxy bool proxy bool
closed bool
} }
func (dt *dummyTransport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (transport.CapableConn, error) { func (dt *dummyTransport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (transport.CapableConn, error) {
...@@ -35,13 +37,44 @@ func (dt *dummyTransport) Proxy() bool { ...@@ -35,13 +37,44 @@ func (dt *dummyTransport) Proxy() bool {
func (dt *dummyTransport) Protocols() []int { func (dt *dummyTransport) Protocols() []int {
return dt.protocols return dt.protocols
} }
func (dt *dummyTransport) Close() error {
dt.closed = true
return nil
}
func TestUselessTransport(t *testing.T) { func TestUselessTransport(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
swarm := swarmt.GenSwarm(t, ctx) s := swarmt.GenSwarm(t, ctx)
err := swarm.AddTransport(new(dummyTransport)) err := s.AddTransport(new(dummyTransport))
if err == nil { if err == nil {
t.Fatal("adding a transport that supports no protocols should have failed") t.Fatal("adding a transport that supports no protocols should have failed")
} }
} }
func TestTransportClose(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s := swarmt.GenSwarm(t, ctx)
tpt := &dummyTransport{protocols: []int{1}}
if err := s.AddTransport(tpt); err != nil {
t.Fatal(err)
}
_ = s.Close()
if !tpt.closed {
t.Fatal("expected transport to be closed")
}
}
func TestTransportAfterClose(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s := swarmt.GenSwarm(t, ctx)
s.Close()
tpt := &dummyTransport{protocols: []int{1}}
if err := s.AddTransport(tpt); err != swarm.ErrSwarmClosed {
t.Fatal("expected swarm closed error, got: ", err)
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment