Commit 1aaea78d authored by Marten Seemann's avatar Marten Seemann

derive and save the server's pub key in tls.Config.VerifyPeerCertificate

parent 9ecd0944
...@@ -69,7 +69,10 @@ func NewIdentity( ...@@ -69,7 +69,10 @@ func NewIdentity(
// ConfigForPeer creates a new tls.Config that verifies the peers certificate chain. // ConfigForPeer creates a new tls.Config that verifies the peers certificate chain.
// It should be used to create a new tls.Config before dialing. // It should be used to create a new tls.Config before dialing.
func (i *Identity) ConfigForPeer(remote peer.ID) *tls.Config { func (i *Identity) ConfigForPeer(
remote peer.ID,
verifiedPeerCallback func(ic.PubKey),
) *tls.Config {
// We need to check the peer ID in the VerifyPeerCertificate callback. // We need to check the peer ID in the VerifyPeerCertificate callback.
// The tls.Config it is also used for listening, and we might also have concurrent dials. // The tls.Config it is also used for listening, and we might also have concurrent dials.
// Clone it so we can check for the specific peer ID we're dialing here. // Clone it so we can check for the specific peer ID we're dialing here.
...@@ -92,16 +95,12 @@ func (i *Identity) ConfigForPeer(remote peer.ID) *tls.Config { ...@@ -92,16 +95,12 @@ func (i *Identity) ConfigForPeer(remote peer.ID) *tls.Config {
if !remote.MatchesPublicKey(pubKey) { if !remote.MatchesPublicKey(pubKey) {
return errors.New("peer IDs don't match") return errors.New("peer IDs don't match")
} }
verifiedPeerCallback(pubKey)
return nil return nil
} }
return conf return conf
} }
// KeyFromChain takes a chain of x509.Certificates and returns the peer's public key.
func KeyFromChain(chain []*x509.Certificate) (ic.PubKey, error) {
return getRemotePubKey(chain)
}
// getRemotePubKey derives the remote's public key from the certificate chain. // getRemotePubKey derives the remote's public key from the certificate chain.
func getRemotePubKey(chain []*x509.Certificate) (ic.PubKey, error) { func getRemotePubKey(chain []*x509.Certificate) (ic.PubKey, error) {
if len(chain) != 1 { if len(chain) != 1 {
......
...@@ -3,6 +3,7 @@ package libp2ptls ...@@ -3,6 +3,7 @@ package libp2ptls
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors"
"net" "net"
"os" "os"
"sync" "sync"
...@@ -29,8 +30,8 @@ type Transport struct { ...@@ -29,8 +30,8 @@ type Transport struct {
localPeer peer.ID localPeer peer.ID
privKey ci.PrivKey privKey ci.PrivKey
incomingMutex sync.Mutex activeMutex sync.Mutex
incoming map[net.Conn]ic.PubKey active map[net.Conn]ic.PubKey
} }
// New creates a TLS encrypted transport // New creates a TLS encrypted transport
...@@ -42,9 +43,14 @@ func New(key ci.PrivKey) (*Transport, error) { ...@@ -42,9 +43,14 @@ func New(key ci.PrivKey) (*Transport, error) {
t := &Transport{ t := &Transport{
localPeer: id, localPeer: id,
privKey: key, privKey: key,
incoming: make(map[net.Conn]ic.PubKey), active: make(map[net.Conn]ic.PubKey),
} }
identity, err := NewIdentity(key, t.verifiedPeer)
identity, err := NewIdentity(key, func(conn net.Conn, pubKey ic.PubKey) {
t.activeMutex.Lock()
t.active[conn] = pubKey
t.activeMutex.Unlock()
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -57,10 +63,10 @@ var _ cs.Transport = &Transport{} ...@@ -57,10 +63,10 @@ var _ cs.Transport = &Transport{}
// SecureInbound runs the TLS handshake as a server. // SecureInbound runs the TLS handshake as a server.
func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (cs.Conn, error) { func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (cs.Conn, error) {
defer func() { defer func() {
t.incomingMutex.Lock() t.activeMutex.Lock()
// only contains this connection if we successfully derived the client's key // only contains this connection if we successfully derived the client's key
delete(t.incoming, insecure) delete(t.active, insecure)
t.incomingMutex.Unlock() t.activeMutex.Unlock()
}() }()
serv := tls.Server(insecure, t.identity.Config) serv := tls.Server(insecure, t.identity.Config)
...@@ -75,7 +81,12 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (cs.Co ...@@ -75,7 +81,12 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (cs.Co
// If the handshake fails, the server will close the connection. The client will // If the handshake fails, the server will close the connection. The client will
// notice this after 1 RTT when calling Read. // notice this after 1 RTT when calling Read.
func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (cs.Conn, error) { func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (cs.Conn, error) {
cl := tls.Client(insecure, t.identity.ConfigForPeer(p)) verifiedCallback := func(pubKey ic.PubKey) {
t.activeMutex.Lock()
t.active[insecure] = pubKey
t.activeMutex.Unlock()
}
cl := tls.Client(insecure, t.identity.ConfigForPeer(p, verifiedCallback))
return t.handshake(ctx, insecure, cl) return t.handshake(ctx, insecure, cl)
} }
...@@ -120,26 +131,15 @@ func (t *Transport) handshake( ...@@ -120,26 +131,15 @@ func (t *Transport) handshake(
return conn, nil return conn, nil
} }
func (t *Transport) verifiedPeer(conn net.Conn, pubKey ic.PubKey) {
t.incomingMutex.Lock()
t.incoming[conn] = pubKey
t.incomingMutex.Unlock()
}
func (t *Transport) setupConn(insecure net.Conn, tlsConn *tls.Conn) (cs.Conn, error) { func (t *Transport) setupConn(insecure net.Conn, tlsConn *tls.Conn) (cs.Conn, error) {
t.incomingMutex.Lock() t.activeMutex.Lock()
remotePubKey := t.incoming[insecure] remotePubKey := t.active[insecure]
t.incomingMutex.Unlock() t.activeMutex.Unlock()
// This case only occurs for the client.
// Servers already determined the client's key in the VerifyPeerCertificate callback.
if remotePubKey == nil { if remotePubKey == nil {
var err error return nil, errors.New("go-libp2p-tls BUG: expected remote pub key to be set")
remotePubKey, err = KeyFromChain(tlsConn.ConnectionState().PeerCertificates)
if err != nil {
return nil, err
}
} }
remotePeerID, err := peer.IDFromPublicKey(remotePubKey) remotePeerID, err := peer.IDFromPublicKey(remotePubKey)
if err != nil { if err != nil {
return nil, err return nil, err
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment