Unverified Commit 5ddf5deb authored by Steven Allen's avatar Steven Allen Committed by GitHub

Merge pull request #19 from libp2p/fix/close-on-err

improve correctness of closing connections on failure
parents 773b63c9 8467c1e7
...@@ -104,6 +104,7 @@ var _ = Describe("Listener", func() { ...@@ -104,6 +104,7 @@ var _ = Describe("Listener", func() {
It("accepts a single connection", func() { It("accepts a single connection", func() {
ln := createListener(defaultUpgrader) ln := createListener(defaultUpgrader)
defer ln.Close()
cconn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(1)) cconn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(1))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
sconn, err := ln.Accept() sconn, err := ln.Accept()
...@@ -113,6 +114,7 @@ var _ = Describe("Listener", func() { ...@@ -113,6 +114,7 @@ var _ = Describe("Listener", func() {
It("accepts multiple connections", func() { It("accepts multiple connections", func() {
ln := createListener(defaultUpgrader) ln := createListener(defaultUpgrader)
defer ln.Close()
const num = 10 const num = 10
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
cconn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(1)) cconn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(1))
...@@ -127,11 +129,15 @@ var _ = Describe("Listener", func() { ...@@ -127,11 +129,15 @@ var _ = Describe("Listener", func() {
const timeout = 200 * time.Millisecond const timeout = 200 * time.Millisecond
tpt.AcceptTimeout = timeout tpt.AcceptTimeout = timeout
ln := createListener(defaultUpgrader) ln := createListener(defaultUpgrader)
defer ln.Close()
conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2)) conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2))
Expect(err).ToNot(HaveOccurred()) if !Expect(err).ToNot(HaveOccurred()) {
return
}
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
defer conn.Close()
str, err := conn.OpenStream() str, err := conn.OpenStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
// start a Read. It will block until the connection is closed // start a Read. It will block until the connection is closed
...@@ -151,10 +157,16 @@ var _ = Describe("Listener", func() { ...@@ -151,10 +157,16 @@ var _ = Describe("Listener", func() {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
_, _ = ln.Accept() conn, err := ln.Accept()
if !Expect(err).To(HaveOccurred()) {
conn.Close()
}
close(done) close(done)
}() }()
_, _ = dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2)) conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2))
if !Expect(err).To(HaveOccurred()) {
conn.Close()
}
Consistently(done).ShouldNot(BeClosed()) Consistently(done).ShouldNot(BeClosed())
// make the goroutine return // make the goroutine return
ln.Close() ln.Close()
...@@ -178,6 +190,7 @@ var _ = Describe("Listener", func() { ...@@ -178,6 +190,7 @@ var _ = Describe("Listener", func() {
if err != nil { if err != nil {
return return
} }
conn.Close()
accepted <- conn accepted <- conn
} }
}() }()
...@@ -187,8 +200,14 @@ var _ = Describe("Listener", func() { ...@@ -187,8 +200,14 @@ var _ = Describe("Listener", func() {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
_, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2)) conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2))
Expect(err).ToNot(HaveOccurred()) if Expect(err).ToNot(HaveOccurred()) {
stream, err := conn.AcceptStream() // wait for conn to be accepted.
if !Expect(err).To(HaveOccurred()) {
stream.Close()
}
conn.Close()
}
wg.Done() wg.Done()
}() }()
} }
...@@ -201,29 +220,40 @@ var _ = Describe("Listener", func() { ...@@ -201,29 +220,40 @@ var _ = Describe("Listener", func() {
It("stops setting up when the more than AcceptQueueLength connections are waiting to get accepted", func() { It("stops setting up when the more than AcceptQueueLength connections are waiting to get accepted", func() {
ln := createListener(defaultUpgrader) ln := createListener(defaultUpgrader)
defer ln.Close()
// setup AcceptQueueLength connections, but don't accept any of them // setup AcceptQueueLength connections, but don't accept any of them
dialed := make(chan struct{}, 10*st.AcceptQueueLength) // used as a thread-safe counter dialed := make(chan tpt.Conn, 10*st.AcceptQueueLength) // used as a thread-safe counter
for i := 0; i < st.AcceptQueueLength; i++ { for i := 0; i < st.AcceptQueueLength; i++ {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
_, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2)) conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
dialed <- struct{}{} dialed <- conn
}() }()
} }
Eventually(dialed).Should(HaveLen(st.AcceptQueueLength)) Eventually(dialed).Should(HaveLen(st.AcceptQueueLength))
// dial a new connection. This connection should not complete setup, since the queue is full // dial a new connection. This connection should not complete setup, since the queue is full
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
_, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2)) conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
dialed <- struct{}{} dialed <- conn
}() }()
Consistently(dialed).Should(HaveLen(st.AcceptQueueLength)) Consistently(dialed).Should(HaveLen(st.AcceptQueueLength))
// accept a single connection. Now the new connection should be set up, and fill the queue again // accept a single connection. Now the new connection should be set up, and fill the queue again
_, err := ln.Accept() conn, err := ln.Accept()
Expect(err).ToNot(HaveOccurred()) if Expect(err).ToNot(HaveOccurred()) {
conn.Close()
}
Eventually(dialed).Should(HaveLen(st.AcceptQueueLength + 1)) Eventually(dialed).Should(HaveLen(st.AcceptQueueLength + 1))
// Cleanup
for i := 0; i < st.AcceptQueueLength+1; i++ {
if c := <-dialed; c != nil {
c.Close()
}
}
}) })
}) })
...@@ -233,9 +263,12 @@ var _ = Describe("Listener", func() { ...@@ -233,9 +263,12 @@ var _ = Describe("Listener", func() {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
_, err := ln.Accept() conn, err := ln.Accept()
Expect(err).To(HaveOccurred()) if Expect(err).To(HaveOccurred()) {
Expect(err.Error()).To(ContainSubstring("use of closed network connection")) Expect(err.Error()).To(ContainSubstring("use of closed network connection"))
} else {
conn.Close()
}
close(done) close(done)
}() }()
Consistently(done).ShouldNot(BeClosed()) Consistently(done).ShouldNot(BeClosed())
...@@ -246,15 +279,20 @@ var _ = Describe("Listener", func() { ...@@ -246,15 +279,20 @@ var _ = Describe("Listener", func() {
It("doesn't accept new connections when it is closed", func() { It("doesn't accept new connections when it is closed", func() {
ln := createListener(defaultUpgrader) ln := createListener(defaultUpgrader)
Expect(ln.Close()).To(Succeed()) Expect(ln.Close()).To(Succeed())
_, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(1)) conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(1))
Expect(err).To(HaveOccurred()) if !Expect(err).To(HaveOccurred()) {
conn.Close()
}
}) })
It("closes incoming connections that have not yet been accepted", func() { It("closes incoming connections that have not yet been accepted", func() {
ln := createListener(defaultUpgrader) ln := createListener(defaultUpgrader)
conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2)) conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2))
if !Expect(err).ToNot(HaveOccurred()) {
ln.Close()
return
}
Expect(conn.IsClosed()).To(BeFalse()) Expect(conn.IsClosed()).To(BeFalse())
Expect(err).ToNot(HaveOccurred())
Expect(ln.Close()).To(Succeed()) Expect(ln.Close()).To(Succeed())
Eventually(conn.IsClosed).Should(BeTrue()) Eventually(conn.IsClosed).Should(BeTrue())
}) })
......
...@@ -89,7 +89,7 @@ func (u *Upgrader) upgrade(ctx context.Context, t transport.Transport, maconn ma ...@@ -89,7 +89,7 @@ func (u *Upgrader) upgrade(ctx context.Context, t transport.Transport, maconn ma
} }
smconn, err := u.setupMuxer(ctx, sconn, p) smconn, err := u.setupMuxer(ctx, sconn, p)
if err != nil { if err != nil {
conn.Close() sconn.Close()
return nil, fmt.Errorf("failed to negotiate security stream multiplexer: %s", err) return nil, fmt.Errorf("failed to negotiate security stream multiplexer: %s", err)
} }
return &transportConn{ return &transportConn{
...@@ -122,6 +122,10 @@ func (u *Upgrader) setupMuxer(ctx context.Context, conn net.Conn, p peer.ID) (sm ...@@ -122,6 +122,10 @@ func (u *Upgrader) setupMuxer(ctx context.Context, conn net.Conn, p peer.ID) (sm
case <-done: case <-done:
return smconn, err return smconn, err
case <-ctx.Done(): case <-ctx.Done():
// interrupt this process
conn.Close()
// wait to finish
<-done
return nil, ctx.Err() return nil, ctx.Err()
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment