Commit fa2f7eea authored by Marten Seemann's avatar Marten Seemann

return the context cancelation error

parent 653fbe64
...@@ -43,30 +43,23 @@ var _ cs.Transport = &Transport{} ...@@ -43,30 +43,23 @@ var _ cs.Transport = &Transport{}
// SecureInbound runs the TLS handshake as a server. // SecureInbound runs the TLS handshake as a server.
func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (cs.Conn, error) { func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (cs.Conn, error) {
serv := tls.Server(insecure, t.identity.Config) serv := tls.Server(insecure, t.identity.Config)
return t.handshake(ctx, insecure, serv)
// There's no way to pass a context to tls.Conn.Handshake().
// See https://github.com/golang/go/issues/18482.
// Close the connection instead.
done := make(chan struct{})
defer close(done)
go func() {
select {
case <-done:
case <-ctx.Done():
insecure.Close()
}
}()
if err := serv.Handshake(); err != nil {
return nil, err
}
return t.setupConn(serv)
} }
// SecureOutbound runs the TLS handshake as a client. // SecureOutbound runs the TLS handshake as a client.
func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (cs.Conn, error) { func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (cs.Conn, error) {
cl := tls.Client(insecure, t.identity.ConfigForPeer(p)) cl := tls.Client(insecure, t.identity.ConfigForPeer(p))
return t.handshake(ctx, insecure, cl)
}
func (t *Transport) handshake(
ctx context.Context,
// in Go 1.10, we need to close the underlying net.Conn
// in Go 1.11 this was fixed, and tls.Conn.Close() works as well
insecure net.Conn,
tlsConn *tls.Conn,
) (cs.Conn, error) {
errChan := make(chan error, 2)
// There's no way to pass a context to tls.Conn.Handshake(). // There's no way to pass a context to tls.Conn.Handshake().
// See https://github.com/golang/go/issues/18482. // See https://github.com/golang/go/issues/18482.
// Close the connection instead. // Close the connection instead.
...@@ -76,14 +69,23 @@ func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p pee ...@@ -76,14 +69,23 @@ func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p pee
select { select {
case <-done: case <-done:
case <-ctx.Done(): case <-ctx.Done():
errChan <- ctx.Err()
insecure.Close() insecure.Close()
} }
}() }()
if err := cl.Handshake(); err != nil { if err := tlsConn.Handshake(); err != nil {
return nil, err // if the context was canceled, return the context error
errChan <- err
return nil, <-errChan
}
conn, err := t.setupConn(tlsConn)
if err != nil {
// if the context was canceled, return the context error
errChan <- err
return nil, <-errChan
} }
return t.setupConn(cl) return conn, nil
} }
func (t *Transport) setupConn(tlsConn *tls.Conn) (cs.Conn, error) { func (t *Transport) setupConn(tlsConn *tls.Conn) (cs.Conn, error) {
......
...@@ -112,8 +112,7 @@ var _ = Describe("Transport", func() { ...@@ -112,8 +112,7 @@ var _ = Describe("Transport", func() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
cancel() cancel()
_, err = clientTransport.SecureOutbound(ctx, clientInsecureConn, serverID) _, err = clientTransport.SecureOutbound(ctx, clientInsecureConn, serverID)
Expect(err).To(HaveOccurred()) Expect(err).To(MatchError(context.Canceled))
Expect(err.Error()).To(ContainSubstring("use of closed network connection"))
}) })
It("fails when the context of the incoming connection is canceled", func() { It("fails when the context of the incoming connection is canceled", func() {
...@@ -129,8 +128,7 @@ var _ = Describe("Transport", func() { ...@@ -129,8 +128,7 @@ var _ = Describe("Transport", func() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
cancel() cancel()
_, err := serverTransport.SecureInbound(ctx, serverInsecureConn) _, err := serverTransport.SecureInbound(ctx, serverInsecureConn)
Expect(err).To(HaveOccurred()) Expect(err).To(MatchError(context.Canceled))
Expect(err.Error()).To(ContainSubstring("use of closed network connection"))
}() }()
_, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID)
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment