diff --git a/swarm.go b/swarm.go index 25f45901780e105ac21f646b9d36ce7a9c202fe7..7dd957f230e51ad98d162fe4056e3ac356d34cec 100644 --- a/swarm.go +++ b/swarm.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io" "strings" "sync" "sync/atomic" @@ -176,6 +177,27 @@ func (s *Swarm) teardown() error { // Wait for everything to finish. s.refs.Wait() + // Now close out any transports (if necessary). Do this after closing + // all connections/listeners. + s.transports.Lock() + transports := s.transports.m + s.transports.m = nil + s.transports.Unlock() + + var wg sync.WaitGroup + for _, t := range transports { + if closer, ok := t.(io.Closer); ok { + wg.Add(1) + go func(c io.Closer) { + defer wg.Done() + if err := closer.Close(); err != nil { + log.Errorf("error when closing down transport %T: %s", c, err) + } + }(closer) + } + } + wg.Wait() + return nil } diff --git a/swarm_listen.go b/swarm_listen.go index ab5c42a3455991d4622b1f41edd3085265d14d8e..5bd1015d9964a14082dc8bc824c7e78691a77d39 100644 --- a/swarm_listen.go +++ b/swarm_listen.go @@ -40,7 +40,17 @@ func (s *Swarm) Listen(addrs ...ma.Multiaddr) error { func (s *Swarm) AddListenAddr(a ma.Multiaddr) error { tpt := s.TransportForListening(a) if tpt == nil { - return ErrNoTransport + // TransportForListening will return nil if either: + // 1. No transport has been registered. + // 2. We're closed (so we've nulled out the transport map. + // + // Distinguish between these two cases to avoid confusing users. + select { + case <-s.proc.Closing(): + return ErrSwarmClosed + default: + return ErrNoTransport + } } list, err := tpt.Listen(a) diff --git a/swarm_transport.go b/swarm_transport.go index 307bfe641f1a9abe7ebade7ef44cc5fb57bc3d01..21728ac3b5114b9a99c7f04a4ec1447131722503 100644 --- a/swarm_transport.go +++ b/swarm_transport.go @@ -20,7 +20,10 @@ func (s *Swarm) TransportForDialing(a ma.Multiaddr) transport.Transport { s.transports.RLock() defer s.transports.RUnlock() if len(s.transports.m) == 0 { - log.Error("you have no transports configured") + // make sure we're not just shutting down. + if s.transports.m != nil { + log.Error("you have no transports configured") + } return nil } @@ -48,7 +51,10 @@ func (s *Swarm) TransportForListening(a ma.Multiaddr) transport.Transport { s.transports.RLock() defer s.transports.RUnlock() if len(s.transports.m) == 0 { - log.Error("you have no transports configured") + // make sure we're not just shutting down. + if s.transports.m != nil { + log.Error("you have no transports configured") + } return nil } @@ -77,6 +83,9 @@ func (s *Swarm) AddTransport(t transport.Transport) error { s.transports.Lock() defer s.transports.Unlock() + if s.transports.m == nil { + return ErrSwarmClosed + } var registered []string for _, p := range protocols { if _, ok := s.transports.m[p]; ok { diff --git a/transport_test.go b/transport_test.go index f6090a6ef88524473d2525cbf5b19f75e0841b63..82225840b9b0947cf38e4ec3adbb9d920b828ba7 100644 --- a/transport_test.go +++ b/transport_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + swarm "github.com/libp2p/go-libp2p-swarm" swarmt "github.com/libp2p/go-libp2p-swarm/testing" "github.com/libp2p/go-libp2p-core/peer" @@ -14,6 +15,7 @@ import ( type dummyTransport struct { protocols []int proxy bool + closed bool } func (dt *dummyTransport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (transport.CapableConn, error) { @@ -35,13 +37,44 @@ func (dt *dummyTransport) Proxy() bool { func (dt *dummyTransport) Protocols() []int { return dt.protocols } +func (dt *dummyTransport) Close() error { + dt.closed = true + return nil +} func TestUselessTransport(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - swarm := swarmt.GenSwarm(t, ctx) - err := swarm.AddTransport(new(dummyTransport)) + s := swarmt.GenSwarm(t, ctx) + err := s.AddTransport(new(dummyTransport)) if err == nil { t.Fatal("adding a transport that supports no protocols should have failed") } } + +func TestTransportClose(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s := swarmt.GenSwarm(t, ctx) + tpt := &dummyTransport{protocols: []int{1}} + if err := s.AddTransport(tpt); err != nil { + t.Fatal(err) + } + _ = s.Close() + if !tpt.closed { + t.Fatal("expected transport to be closed") + } + +} + +func TestTransportAfterClose(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s := swarmt.GenSwarm(t, ctx) + s.Close() + + tpt := &dummyTransport{protocols: []int{1}} + if err := s.AddTransport(tpt); err != swarm.ErrSwarmClosed { + t.Fatal("expected swarm closed error, got: ", err) + } +}