diff --git a/swarm_conn.go b/swarm_conn.go index bd531f273fa270791a52d6c4a21da3d5a21c61c2..946936cfc6ff4bb469f7713afed943c81d3f42e5 100644 --- a/swarm_conn.go +++ b/swarm_conn.go @@ -87,12 +87,10 @@ func (c *Conn) doClose() { }() } -func (c *Conn) removeStream(s *Stream) bool { +func (c *Conn) removeStream(s *Stream) { c.streams.Lock() - _, has := c.streams.m[s] delete(c.streams.m, s) c.streams.Unlock() - return has } // listens for new streams. diff --git a/swarm_stream.go b/swarm_stream.go index 74ac62cb05bb4979243297f7648f04b73c946acb..c09b712a527148fd378ce26653c718097061a00e 100644 --- a/swarm_stream.go +++ b/swarm_stream.go @@ -22,6 +22,8 @@ type Stream struct { stream mux.MuxedStream conn *Conn + closeOnce sync.Once + notifyLk sync.Mutex protocol atomic.Value @@ -76,7 +78,7 @@ func (s *Stream) Write(p []byte) (int, error) { // resources. func (s *Stream) Close() error { err := s.stream.Close() - s.remove() + s.closeOnce.Do(s.remove) return err } @@ -84,7 +86,7 @@ func (s *Stream) Close() error { // associated resources. func (s *Stream) Reset() error { err := s.stream.Reset() - s.remove() + s.closeOnce.Do(s.remove) return err } @@ -102,9 +104,7 @@ func (s *Stream) CloseRead() error { } func (s *Stream) remove() { - if !s.conn.removeStream(s) { - return - } + s.conn.removeStream(s) // We *must* do this in a goroutine. This can be called during a // an open notification and will block until that notification is done. diff --git a/swarm_test.go b/swarm_test.go index 81e1412f1105dab8f74e7f43f6bcb19f58bee873..bd2c844fca1084b2e4266c27493c7fa532e086ad 100644 --- a/swarm_test.go +++ b/swarm_test.go @@ -440,3 +440,20 @@ func TestNoDial(t *testing.T) { t.Fatal("should have failed with ErrNoConn") } } + +func TestCloseWithOpenStreams(t *testing.T) { + ctx := context.Background() + swarms := makeSwarms(ctx, t, 2) + connectSwarms(t, ctx, swarms) + + s, err := swarms[0].NewStream(ctx, swarms[1].LocalPeer()) + if err != nil { + t.Fatal(err) + } + defer s.Close() + // close swarm before stream. + err = swarms[0].Close() + if err != nil { + t.Fatal(err) + } +}