Commit 11bbc4e8 authored by Marten Seemann's avatar Marten Seemann

implement the new handshake

parent 8b16d92c
......@@ -29,7 +29,7 @@ before_install:
script:
# some tests are randomized. Run them a few times.
- for i in `seq 1 3`; do
- for i in `seq 1 10`; do
ginkgo -r -v --cover --randomizeAllSpecs --randomizeSuites --trace --progress;
done
......
package libp2ptls
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"errors"
"fmt"
"math/big"
"time"
crypto "github.com/libp2p/go-libp2p-crypto"
ic "github.com/libp2p/go-libp2p-crypto"
pb "github.com/libp2p/go-libp2p-crypto/pb"
peer "github.com/libp2p/go-libp2p-peer"
)
const certValidityPeriod = 180 * 24 * time.Hour
const certValidityPeriod = 100 * 365 * 24 * time.Hour // ~100 years
var extensionID = getPrefixedExtensionID([]int{1, 1})
type signedKey struct {
PubKey []byte
Signature []byte
}
// Identity is used to secure connections
type Identity struct {
......@@ -24,7 +34,7 @@ type Identity struct {
// NewIdentity creates a new identity
func NewIdentity(privKey ic.PrivKey) (*Identity, error) {
key, cert, err := keyToCertificate(privKey)
cert, err := keyToCertificate(privKey)
if err != nil {
return nil, err
}
......@@ -33,10 +43,7 @@ func NewIdentity(privKey ic.PrivKey) (*Identity, error) {
MinVersion: tls.VersionTLS13,
InsecureSkipVerify: true, // This is not insecure here. We will verify the cert chain ourselves.
ClientAuth: tls.RequireAnyClientCert,
Certificates: []tls.Certificate{{
Certificate: [][]byte{cert.Raw},
PrivateKey: key,
}},
Certificates: []tls.Certificate{*cert},
VerifyPeerCertificate: func(_ [][]byte, _ [][]*x509.Certificate) error {
panic("tls config not specialized for peer")
},
......@@ -95,70 +102,95 @@ func getRemotePubKey(chain []*x509.Certificate) (ic.PubKey, error) {
if len(chain) != 1 {
return nil, errors.New("expected one certificates in the chain")
}
cert := chain[0]
pool := x509.NewCertPool()
pool.AddCert(chain[0])
if _, err := chain[0].Verify(x509.VerifyOptions{Roots: pool}); err != nil {
pool.AddCert(cert)
if _, err := cert.Verify(x509.VerifyOptions{Roots: pool}); err != nil {
// If we return an x509 error here, it will be sent on the wire.
// Wrap the error to avoid that.
return nil, fmt.Errorf("certificate verification failed: %s", err)
}
remotePubKey, err := x509.MarshalPKIXPublicKey(chain[0].PublicKey)
var found bool
var keyExt pkix.Extension
// find the libp2p key extension, skipping all unknown extensions
for _, ext := range cert.Extensions {
if extensionIDEqual(ext.Id, extensionID) {
keyExt = ext
found = true
break
}
}
if !found {
return nil, errors.New("expected certificate to contain the key extension")
}
var sk signedKey
if _, err := asn1.Unmarshal(keyExt.Value, &sk); err != nil {
return nil, fmt.Errorf("unmarshalling signed certificate failed: %s", err)
}
pubKey, err := crypto.UnmarshalPublicKey(sk.PubKey)
if err != nil {
return nil, fmt.Errorf("unmarshalling public key failed: %s", err)
}
certKeyPub, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
if err != nil {
return nil, err
}
switch chain[0].PublicKeyAlgorithm {
case x509.RSA:
return ic.UnmarshalRsaPublicKey(remotePubKey)
case x509.ECDSA:
return ic.UnmarshalECDSAPublicKey(remotePubKey)
default:
return nil, fmt.Errorf("unexpected public key algorithm: %d", chain[0].PublicKeyAlgorithm)
valid, err := pubKey.Verify(certKeyPub, sk.Signature)
if err != nil {
return nil, fmt.Errorf("signature verification failed: %s", err)
}
if !valid {
return nil, errors.New("signature invalid")
}
return pubKey, nil
}
func keyToCertificate(sk ic.PrivKey) (crypto.PrivateKey, *x509.Certificate, error) {
sn, err := rand.Int(rand.Reader, big.NewInt(1<<62))
func keyToCertificate(sk ic.PrivKey) (*tls.Certificate, error) {
certKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, nil, err
}
tmpl := &x509.Certificate{
SerialNumber: sn,
NotBefore: time.Now().Add(-24 * time.Hour),
NotAfter: time.Now().Add(certValidityPeriod),
return nil, err
}
var privateKey crypto.PrivateKey
var publicKey crypto.PublicKey
raw, err := sk.Raw()
keyBytes, err := crypto.MarshalPublicKey(sk.GetPublic())
if err != nil {
return nil, nil, err
return nil, err
}
switch sk.Type() {
case pb.KeyType_RSA:
k, err := x509.ParsePKCS1PrivateKey(raw)
if err != nil {
return nil, nil, err
}
publicKey = &k.PublicKey
privateKey = k
case pb.KeyType_ECDSA:
k, err := x509.ParseECPrivateKey(raw)
if err != nil {
return nil, nil, err
}
publicKey = &k.PublicKey
privateKey = k
// TODO: add support for Ed25519
default:
return nil, nil, errors.New("unsupported key type for TLS")
certKeyPub, err := x509.MarshalPKIXPublicKey(certKey.Public())
if err != nil {
return nil, err
}
signature, err := sk.Sign(certKeyPub)
if err != nil {
return nil, err
}
value, err := asn1.Marshal(signedKey{
PubKey: keyBytes,
Signature: signature,
})
if err != nil {
return nil, err
}
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, publicKey, privateKey)
sn, err := rand.Int(rand.Reader, big.NewInt(1<<62))
if err != nil {
return nil, nil, err
return nil, err
}
tmpl := &x509.Certificate{
SerialNumber: sn,
NotBefore: time.Time{},
NotAfter: time.Now().Add(certValidityPeriod),
// after calling CreateCertificate, these will end up in Certificate.Extensions
ExtraExtensions: []pkix.Extension{
{Id: extensionID, Value: value},
},
}
cert, err := x509.ParseCertificate(certDER)
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, certKey.Public(), certKey)
if err != nil {
return nil, nil, err
return nil, err
}
return privateKey, cert, nil
return &tls.Certificate{
Certificate: [][]byte{certDER},
PrivateKey: certKey,
}, nil
}
package libp2ptls
// TODO: get an assigment for a valid OID
var extensionPrefix = []int{1, 3, 6, 1, 4, 1, 123456789}
// getPrefixedExtensionID returns an Object Identifier
// that can be used in x509 Certificates.
func getPrefixedExtensionID(suffix []int) []int {
return append(extensionPrefix, suffix...)
}
// extensionIDEqual compares two extension IDs.
func extensionIDEqual(a, b []int) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
package libp2ptls
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Extensions", func() {
It("generates a prefixed extension ID", func() {
Expect(getPrefixedExtensionID([]int{13, 37})).To(Equal([]int{1, 3, 6, 1, 4, 1, 123456789, 13, 37}))
})
It("compares extension IDs", func() {
Expect(extensionIDEqual([]int{1, 2, 3, 4}, []int{1, 2, 3, 4})).To(BeTrue())
Expect(extensionIDEqual([]int{1, 2, 3, 4}, []int{1, 2, 3})).To(BeFalse())
Expect(extensionIDEqual([]int{1, 2, 3}, []int{1, 2, 3, 4})).To(BeFalse())
Expect(extensionIDEqual([]int{1, 2, 3, 4}, []int{4, 3, 2, 1})).To(BeFalse())
})
})
......@@ -2,12 +2,15 @@ package libp2ptls
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"fmt"
"math/big"
mrand "math/rand"
......@@ -15,6 +18,7 @@ import (
"time"
"github.com/onsi/gomega/gbytes"
"github.com/onsi/gomega/types"
cs "github.com/libp2p/go-conn-security"
ci "github.com/libp2p/go-libp2p-crypto"
......@@ -26,7 +30,7 @@ import (
type transform struct {
name string
apply func(*Identity)
remoteErr string // the error that the side validating the chain gets
remoteErr types.GomegaMatcher // the error that the side validating the chain gets
}
var _ = Describe("Transport", func() {
......@@ -37,17 +41,22 @@ var _ = Describe("Transport", func() {
createPeer := func() (peer.ID, ci.PrivKey) {
var priv ci.PrivKey
if mrand.Int()%2 == 0 {
var err error
switch mrand.Int() % 4 {
case 0:
fmt.Fprintf(GinkgoWriter, " using an ECDSA key: ")
var err error
priv, _, err = ci.GenerateECDSAKeyPair(rand.Reader)
Expect(err).ToNot(HaveOccurred())
} else {
case 1:
fmt.Fprintf(GinkgoWriter, " using an RSA key: ")
var err error
priv, _, err = ci.GenerateRSAKeyPair(1024, rand.Reader)
Expect(err).ToNot(HaveOccurred())
case 2:
fmt.Fprintf(GinkgoWriter, " using an Ed25519 key: ")
priv, _, err = ci.GenerateEd25519Key(rand.Reader)
case 3:
fmt.Fprintf(GinkgoWriter, " using an Ed25519 key: ")
priv, _, err = ci.GenerateSecp256k1Key(rand.Reader)
}
Expect(err).ToNot(HaveOccurred())
id, err := peer.IDFromPrivateKey(priv)
Expect(err).ToNot(HaveOccurred())
fmt.Fprintln(GinkgoWriter, id.Pretty())
......@@ -212,37 +221,149 @@ var _ = Describe("Transport", func() {
}}
}
getCertWithKey := func(key crypto.Signer, tmpl *x509.Certificate) tls.Certificate {
cert, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key.Public(), key)
Expect(err).ToNot(HaveOccurred())
return tls.Certificate{
Certificate: [][]byte{cert},
PrivateKey: key,
}
}
getCert := func(tmpl *x509.Certificate) tls.Certificate {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
Expect(err).ToNot(HaveOccurred())
return getCertWithKey(key, tmpl)
}
expiredCert := func(identity *Identity) {
tmpl := &x509.Certificate{
cert := getCert(&x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(-time.Minute),
}
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
})
identity.config.Certificates = []tls.Certificate{cert}
}
noKeyExtension := func(identity *Identity) {
cert := getCert(&x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
})
identity.config.Certificates = []tls.Certificate{cert}
}
unparseableKeyExtension := func(identity *Identity) {
cert := getCert(&x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
ExtraExtensions: []pkix.Extension{
{Id: extensionID, Value: []byte("foobar")},
},
})
identity.config.Certificates = []tls.Certificate{cert}
}
unparseableKey := func(identity *Identity) {
data, err := asn1.Marshal(signedKey{PubKey: []byte("foobar")})
Expect(err).ToNot(HaveOccurred())
cert, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key.Public(), key)
cert := getCert(&x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
ExtraExtensions: []pkix.Extension{
{Id: extensionID, Value: data},
},
})
identity.config.Certificates = []tls.Certificate{cert}
}
tooShortSignature := func(identity *Identity) {
key, _, err := ci.GenerateSecp256k1Key(rand.Reader)
Expect(err).ToNot(HaveOccurred())
identity.config.Certificates = []tls.Certificate{{
Certificate: [][]byte{cert},
PrivateKey: key,
}}
keyBytes, err := key.GetPublic().Bytes()
Expect(err).ToNot(HaveOccurred())
data, err := asn1.Marshal(signedKey{
PubKey: keyBytes,
Signature: []byte("foobar"),
})
Expect(err).ToNot(HaveOccurred())
cert := getCert(&x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
ExtraExtensions: []pkix.Extension{
{Id: extensionID, Value: data},
},
})
identity.config.Certificates = []tls.Certificate{cert}
}
invalidSignature := func(identity *Identity) {
key, _, err := ci.GenerateSecp256k1Key(rand.Reader)
Expect(err).ToNot(HaveOccurred())
keyBytes, err := key.GetPublic().Bytes()
Expect(err).ToNot(HaveOccurred())
signature, err := key.Sign([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
data, err := asn1.Marshal(signedKey{
PubKey: keyBytes,
Signature: signature,
})
Expect(err).ToNot(HaveOccurred())
cert := getCert(&x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
ExtraExtensions: []pkix.Extension{
{Id: extensionID, Value: data},
},
})
identity.config.Certificates = []tls.Certificate{cert}
}
transforms := []transform{
{
name: "private key used in the TLS handshake doesn't match the public key in the cert",
apply: invalidateCertChain,
remoteErr: "tls: invalid certificate signature",
remoteErr: Equal("tls: invalid certificate signature"),
},
{
name: "certificate chain contains 2 certs",
apply: twoCerts,
remoteErr: "expected one certificates in the chain",
remoteErr: Equal("expected one certificates in the chain"),
},
{
name: "cert is expired",
apply: expiredCert,
remoteErr: "certificate verification failed: x509: certificate has expired or is not yet valid",
remoteErr: Equal("certificate verification failed: x509: certificate has expired or is not yet valid"),
},
{
name: "cert doesn't have the key extension",
apply: noKeyExtension,
remoteErr: Equal("expected certificate to contain the key extension"),
},
{
name: "key extension not parseable",
apply: unparseableKeyExtension,
remoteErr: ContainSubstring("asn1"),
},
{
name: "key protobuf not parseable",
apply: unparseableKey,
remoteErr: ContainSubstring("unmarshalling public key failed: proto:"),
},
{
name: "signature is malformed",
apply: tooShortSignature,
remoteErr: ContainSubstring("signature verification failed:"),
},
{
name: "signature is invalid",
apply: invalidSignature,
remoteErr: Equal("signature invalid"),
},
}
......@@ -262,7 +383,8 @@ var _ = Describe("Transport", func() {
go func() {
defer GinkgoRecover()
_, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn)
Expect(err).To(MatchError(t.remoteErr))
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(t.remoteErr)
close(done)
}()
......@@ -297,7 +419,8 @@ var _ = Describe("Transport", func() {
}()
_, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID)
Expect(err).To(MatchError(t.remoteErr))
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(t.remoteErr)
Eventually(done).Should(BeClosed())
})
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment