Commit 2e8087cb authored by Steven Allen's avatar Steven Allen Committed by Marten Seemann

make peer verification use a channel

parent 1aaea78d
...@@ -8,7 +8,6 @@ import ( ...@@ -8,7 +8,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"math/big" "math/big"
"net"
"time" "time"
ic "github.com/libp2p/go-libp2p-crypto" ic "github.com/libp2p/go-libp2p-crypto"
...@@ -20,66 +19,55 @@ const certValidityPeriod = 180 * 24 * time.Hour ...@@ -20,66 +19,55 @@ const certValidityPeriod = 180 * 24 * time.Hour
// Identity is used to secure connections // Identity is used to secure connections
type Identity struct { type Identity struct {
*tls.Config config tls.Config
} }
// NewIdentity creates a new identity // NewIdentity creates a new identity
func NewIdentity( func NewIdentity(privKey ic.PrivKey) (*Identity, error) {
privKey ic.PrivKey,
verifiedPeerCallback func(net.Conn, ic.PubKey),
) (*Identity, error) {
key, cert, err := keyToCertificate(privKey) key, cert, err := keyToCertificate(privKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
conf := &tls.Config{ return &Identity{
MinVersion: tls.VersionTLS13, config: tls.Config{
InsecureSkipVerify: true, // This is not insecure here. We will verify the cert chain ourselves. MinVersion: tls.VersionTLS13,
ClientAuth: tls.RequireAnyClientCert, InsecureSkipVerify: true, // This is not insecure here. We will verify the cert chain ourselves.
Certificates: []tls.Certificate{{ ClientAuth: tls.RequireAnyClientCert,
Certificate: [][]byte{cert.Raw}, Certificates: []tls.Certificate{{
PrivateKey: key, Certificate: [][]byte{cert.Raw},
}}, PrivateKey: key,
} }},
// When receiving the ClientHello, create a new tls.Config. VerifyPeerCertificate: func(_ [][]byte, _ [][]*x509.Certificate) error {
// This new config has a VerifyPeerCertificate set, which calls the verifiedPeerCallback panic("tls config not specialized for peer")
// when we derived the remote's public key from its certificate chain. },
conf.GetConfigForClient = func(ch *tls.ClientHelloInfo) (*tls.Config, error) { },
c := conf.Clone() }, nil
c.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { }
chain := make([]*x509.Certificate, len(rawCerts))
for i := 0; i < len(rawCerts); i++ { // ConfigForAny is a short-hand for ConfigForPeer("").
cert, err := x509.ParseCertificate(rawCerts[i]) func (i *Identity) ConfigForAny() (*tls.Config, <-chan ic.PubKey) {
if err != nil { return i.ConfigForPeer("")
return err
}
chain[i] = cert
}
pubKey, err := getRemotePubKey(chain)
if err != nil {
return err
}
verifiedPeerCallback(ch.Conn, pubKey)
return nil
}
return c, nil
}
return &Identity{conf}, nil
} }
// ConfigForPeer creates a new tls.Config that verifies the peers certificate chain. // ConfigForPeer creates a new single-use tls.Config that verifies the peer's
// It should be used to create a new tls.Config before dialing. // certificate chain and returns the peer's public key via the channel. If the
// peer ID is empty, the returned config will accept any peer.
//
// It should be used to create a new tls.Config before securing either an
// incoming or outgoing connection.
func (i *Identity) ConfigForPeer( func (i *Identity) ConfigForPeer(
remote peer.ID, remote peer.ID,
verifiedPeerCallback func(ic.PubKey), ) (*tls.Config, <-chan ic.PubKey) {
) *tls.Config { keyCh := make(chan ic.PubKey, 1)
// We need to check the peer ID in the VerifyPeerCertificate callback. // We need to check the peer ID in the VerifyPeerCertificate callback.
// The tls.Config it is also used for listening, and we might also have concurrent dials. // The tls.Config it is also used for listening, and we might also have concurrent dials.
// Clone it so we can check for the specific peer ID we're dialing here. // Clone it so we can check for the specific peer ID we're dialing here.
conf := i.Config.Clone() conf := i.config.Clone()
// We're using InsecureSkipVerify, so the verifiedChains parameter will always be empty. // We're using InsecureSkipVerify, so the verifiedChains parameter will always be empty.
// We need to parse the certificates ourselves from the raw certs. // We need to parse the certificates ourselves from the raw certs.
conf.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { conf.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
defer close(keyCh)
chain := make([]*x509.Certificate, len(rawCerts)) chain := make([]*x509.Certificate, len(rawCerts))
for i := 0; i < len(rawCerts); i++ { for i := 0; i < len(rawCerts); i++ {
cert, err := x509.ParseCertificate(rawCerts[i]) cert, err := x509.ParseCertificate(rawCerts[i])
...@@ -88,17 +76,18 @@ func (i *Identity) ConfigForPeer( ...@@ -88,17 +76,18 @@ func (i *Identity) ConfigForPeer(
} }
chain[i] = cert chain[i] = cert
} }
pubKey, err := getRemotePubKey(chain) pubKey, err := getRemotePubKey(chain)
if err != nil { if err != nil {
return err return err
} }
if !remote.MatchesPublicKey(pubKey) { if remote != "" && !remote.MatchesPublicKey(pubKey) {
return errors.New("peer IDs don't match") return errors.New("peer IDs don't match")
} }
verifiedPeerCallback(pubKey) keyCh <- pubKey
return nil return nil
} }
return conf return conf, keyCh
} }
// getRemotePubKey derives the remote's public key from the certificate chain. // getRemotePubKey derives the remote's public key from the certificate chain.
......
...@@ -6,7 +6,6 @@ import ( ...@@ -6,7 +6,6 @@ import (
"errors" "errors"
"net" "net"
"os" "os"
"sync"
cs "github.com/libp2p/go-conn-security" cs "github.com/libp2p/go-conn-security"
ci "github.com/libp2p/go-libp2p-crypto" ci "github.com/libp2p/go-libp2p-crypto"
...@@ -29,9 +28,6 @@ type Transport struct { ...@@ -29,9 +28,6 @@ type Transport struct {
localPeer peer.ID localPeer peer.ID
privKey ci.PrivKey privKey ci.PrivKey
activeMutex sync.Mutex
active map[net.Conn]ic.PubKey
} }
// New creates a TLS encrypted transport // New creates a TLS encrypted transport
...@@ -43,14 +39,9 @@ func New(key ci.PrivKey) (*Transport, error) { ...@@ -43,14 +39,9 @@ func New(key ci.PrivKey) (*Transport, error) {
t := &Transport{ t := &Transport{
localPeer: id, localPeer: id,
privKey: key, privKey: key,
active: make(map[net.Conn]ic.PubKey),
} }
identity, err := NewIdentity(key, func(conn net.Conn, pubKey ic.PubKey) { identity, err := NewIdentity(key)
t.activeMutex.Lock()
t.active[conn] = pubKey
t.activeMutex.Unlock()
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -62,15 +53,8 @@ var _ cs.Transport = &Transport{} ...@@ -62,15 +53,8 @@ var _ cs.Transport = &Transport{}
// SecureInbound runs the TLS handshake as a server. // SecureInbound runs the TLS handshake as a server.
func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (cs.Conn, error) { func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (cs.Conn, error) {
defer func() { config, keyCh := t.identity.ConfigForAny()
t.activeMutex.Lock() return t.handshake(ctx, tls.Server(insecure, config), keyCh)
// only contains this connection if we successfully derived the client's key
delete(t.active, insecure)
t.activeMutex.Unlock()
}()
serv := tls.Server(insecure, t.identity.Config)
return t.handshake(ctx, insecure, serv)
} }
// SecureOutbound runs the TLS handshake as a client. // SecureOutbound runs the TLS handshake as a client.
...@@ -81,19 +65,14 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (cs.Co ...@@ -81,19 +65,14 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (cs.Co
// If the handshake fails, the server will close the connection. The client will // If the handshake fails, the server will close the connection. The client will
// notice this after 1 RTT when calling Read. // notice this after 1 RTT when calling Read.
func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (cs.Conn, error) { func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (cs.Conn, error) {
verifiedCallback := func(pubKey ic.PubKey) { config, keyCh := t.identity.ConfigForPeer(p)
t.activeMutex.Lock() return t.handshake(ctx, tls.Client(insecure, config), keyCh)
t.active[insecure] = pubKey
t.activeMutex.Unlock()
}
cl := tls.Client(insecure, t.identity.ConfigForPeer(p, verifiedCallback))
return t.handshake(ctx, insecure, cl)
} }
func (t *Transport) handshake( func (t *Transport) handshake(
ctx context.Context, ctx context.Context,
insecure net.Conn,
tlsConn *tls.Conn, tlsConn *tls.Conn,
keyCh <-chan ci.PubKey,
) (cs.Conn, error) { ) (cs.Conn, error) {
// There's no way to pass a context to tls.Conn.Handshake(). // There's no way to pass a context to tls.Conn.Handshake().
// See https://github.com/golang/go/issues/18482. // See https://github.com/golang/go/issues/18482.
...@@ -120,7 +99,15 @@ func (t *Transport) handshake( ...@@ -120,7 +99,15 @@ func (t *Transport) handshake(
} }
return nil, err return nil, err
} }
conn, err := t.setupConn(insecure, tlsConn)
// Should be ready by this point, don't block.
var remotePubKey ic.PubKey
select {
case remotePubKey = <-keyCh:
default:
}
conn, err := t.setupConn(tlsConn, remotePubKey)
if err != nil { if err != nil {
// if the context was canceled, return the context error // if the context was canceled, return the context error
if ctxErr := ctx.Err(); ctxErr != nil { if ctxErr := ctx.Err(); ctxErr != nil {
...@@ -131,11 +118,7 @@ func (t *Transport) handshake( ...@@ -131,11 +118,7 @@ func (t *Transport) handshake(
return conn, nil return conn, nil
} }
func (t *Transport) setupConn(insecure net.Conn, tlsConn *tls.Conn) (cs.Conn, error) { func (t *Transport) setupConn(tlsConn *tls.Conn, remotePubKey ic.PubKey) (cs.Conn, error) {
t.activeMutex.Lock()
remotePubKey := t.active[insecure]
t.activeMutex.Unlock()
if remotePubKey == nil { if remotePubKey == nil {
return nil, errors.New("go-libp2p-tls BUG: expected remote pub key to be set") return nil, errors.New("go-libp2p-tls BUG: expected remote pub key to be set")
} }
......
...@@ -180,15 +180,15 @@ var _ = Describe("Transport", func() { ...@@ -180,15 +180,15 @@ var _ = Describe("Transport", func() {
Context("invalid certificates", func() { Context("invalid certificates", func() {
invalidateCertChain := func(identity *Identity) { invalidateCertChain := func(identity *Identity) {
switch identity.Config.Certificates[0].PrivateKey.(type) { switch identity.config.Certificates[0].PrivateKey.(type) {
case *rsa.PrivateKey: case *rsa.PrivateKey:
key, err := rsa.GenerateKey(rand.Reader, 1024) key, err := rsa.GenerateKey(rand.Reader, 1024)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
identity.Config.Certificates[0].PrivateKey = key identity.config.Certificates[0].PrivateKey = key
case *ecdsa.PrivateKey: case *ecdsa.PrivateKey:
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
identity.Config.Certificates[0].PrivateKey = key identity.config.Certificates[0].PrivateKey = key
default: default:
Fail("unexpected private key type") Fail("unexpected private key type")
} }
...@@ -206,7 +206,7 @@ var _ = Describe("Transport", func() { ...@@ -206,7 +206,7 @@ var _ = Describe("Transport", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
cert2DER, err := x509.CreateCertificate(rand.Reader, tmpl, cert1, key2.Public(), key2) cert2DER, err := x509.CreateCertificate(rand.Reader, tmpl, cert1, key2.Public(), key2)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
identity.Config.Certificates = []tls.Certificate{{ identity.config.Certificates = []tls.Certificate{{
Certificate: [][]byte{cert2DER, cert1DER}, Certificate: [][]byte{cert2DER, cert1DER},
PrivateKey: key2, PrivateKey: key2,
}} }}
...@@ -222,7 +222,7 @@ var _ = Describe("Transport", func() { ...@@ -222,7 +222,7 @@ var _ = Describe("Transport", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
cert, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key.Public(), key) cert, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key.Public(), key)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
identity.Config.Certificates = []tls.Certificate{{ identity.config.Certificates = []tls.Certificate{{
Certificate: [][]byte{cert}, Certificate: [][]byte{cert},
PrivateKey: key, PrivateKey: key,
}} }}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment