From c2a228f650c2a8d62ca71661d0e53d96c9cbed10 Mon Sep 17 00:00:00 2001
From: Juan Batiz-Benet <juan@benet.ai>
Date: Sun, 19 Oct 2014 02:26:10 -0700
Subject: [PATCH] use ContextCloser better (listener fix)

---
 net/conn/conn.go             | 29 ++++++++++++++++-------------
 net/conn/conn_test.go        |  3 +++
 net/conn/dial_test.go        |  1 +
 net/conn/handshake.go        |  2 +-
 net/conn/secure_conn_test.go |  3 +++
 5 files changed, 24 insertions(+), 14 deletions(-)

diff --git a/net/conn/conn.go b/net/conn/conn.go
index 5d67c1e22..a7157a796 100644
--- a/net/conn/conn.go
+++ b/net/conn/conn.go
@@ -65,8 +65,16 @@ func newSingleConn(ctx context.Context, local, remote *peer.Peer,
 	log.Info("newSingleConn: %v to %v", local, remote)
 
 	// setup the various io goroutines
-	go conn.msgio.outgoing.WriteTo(maconn)
-	go conn.msgio.incoming.ReadFrom(maconn, MaxMessageSize)
+	go func() {
+		conn.Children().Add(1)
+		conn.msgio.outgoing.WriteTo(maconn)
+		conn.Children().Done()
+	}()
+	go func() {
+		conn.Children().Add(1)
+		conn.msgio.incoming.ReadFrom(maconn, MaxMessageSize)
+		conn.Children().Done()
+	}()
 
 	// version handshake
 	ctxT, _ := context.WithTimeout(ctx, HandshakeTimeout)
@@ -216,16 +224,9 @@ func (l *listener) close() error {
 	return l.Listener.Close()
 }
 
-func (l *listener) isClosed() bool {
-	select {
-	case <-l.Closed():
-		return true
-	default:
-		return false
-	}
-}
-
 func (l *listener) listen() {
+	l.Children().Add(1)
+	defer l.Children().Done()
 
 	// handle at most chansize concurrent handshakes
 	sem := make(chan struct{}, l.chansize)
@@ -254,9 +255,11 @@ func (l *listener) listen() {
 		maconn, err := l.Listener.Accept()
 		if err != nil {
 
-			// if cancel is nil we're closed.
-			if l.isClosed() {
+			// if closing, we should exit.
+			select {
+			case <-l.Closing():
 				return // done.
+			default:
 			}
 
 			log.Error("Failed to accept connection: %v", err)
diff --git a/net/conn/conn_test.go b/net/conn/conn_test.go
index f86192a8d..86da6875b 100644
--- a/net/conn/conn_test.go
+++ b/net/conn/conn_test.go
@@ -13,6 +13,7 @@ import (
 )
 
 func TestClose(t *testing.T) {
+	// t.Skip("Skipping in favor of another test")
 
 	ctx, cancel := context.WithCancel(context.Background())
 	c1, c2 := setupConn(t, ctx, "/ip4/127.0.0.1/tcp/1234", "/ip4/127.0.0.1/tcp/2345")
@@ -45,6 +46,7 @@ func TestClose(t *testing.T) {
 }
 
 func TestCancel(t *testing.T) {
+	// t.Skip("Skipping in favor of another test")
 
 	ctx, cancel := context.WithCancel(context.Background())
 	c1, c2 := setupConn(t, ctx, "/ip4/127.0.0.1/tcp/1234", "/ip4/127.0.0.1/tcp/2345")
@@ -78,6 +80,7 @@ func TestCancel(t *testing.T) {
 }
 
 func TestCloseLeak(t *testing.T) {
+	// t.Skip("Skipping in favor of another test")
 
 	var wg sync.WaitGroup
 
diff --git a/net/conn/dial_test.go b/net/conn/dial_test.go
index 09f99d799..f5942d1f9 100644
--- a/net/conn/dial_test.go
+++ b/net/conn/dial_test.go
@@ -93,6 +93,7 @@ func setupConn(t *testing.T, ctx context.Context, a1, a2 string) (a, b Conn) {
 }
 
 func TestDialer(t *testing.T) {
+	// t.Skip("Skipping in favor of another test")
 
 	p1, err := setupPeer("/ip4/127.0.0.1/tcp/1234")
 	if err != nil {
diff --git a/net/conn/handshake.go b/net/conn/handshake.go
index 093233522..633c8d5f7 100644
--- a/net/conn/handshake.go
+++ b/net/conn/handshake.go
@@ -31,7 +31,7 @@ func VersionHandshake(ctx context.Context, c Conn) error {
 	case <-ctx.Done():
 		return ctx.Err()
 
-	case <-c.Closed():
+	case <-c.Closing():
 		return errors.New("remote closed connection during version exchange")
 
 	case data, ok := <-c.In():
diff --git a/net/conn/secure_conn_test.go b/net/conn/secure_conn_test.go
index 4e6db1ea4..5a78870d0 100644
--- a/net/conn/secure_conn_test.go
+++ b/net/conn/secure_conn_test.go
@@ -29,6 +29,7 @@ func setupSecureConn(t *testing.T, c Conn) Conn {
 }
 
 func TestSecureClose(t *testing.T) {
+	// t.Skip("Skipping in favor of another test")
 
 	ctx, cancel := context.WithCancel(context.Background())
 	c1, c2 := setupConn(t, ctx, "/ip4/127.0.0.1/tcp/1234", "/ip4/127.0.0.1/tcp/2345")
@@ -64,6 +65,7 @@ func TestSecureClose(t *testing.T) {
 }
 
 func TestSecureCancel(t *testing.T) {
+	// t.Skip("Skipping in favor of another test")
 
 	ctx, cancel := context.WithCancel(context.Background())
 	c1, c2 := setupConn(t, ctx, "/ip4/127.0.0.1/tcp/1234", "/ip4/127.0.0.1/tcp/2345")
@@ -100,6 +102,7 @@ func TestSecureCancel(t *testing.T) {
 }
 
 func TestSecureCloseLeak(t *testing.T) {
+	// t.Skip("Skipping in favor of another test")
 
 	var wg sync.WaitGroup
 
-- 
GitLab