// transport_test.go: Ginkgo specs for the libp2ptls security transport.
// Originally committed by Marten Seemann.

package libp2ptls

import (
	"context"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io"
	"math/big"
	mrand "math/rand"
	"net"
	"time"

	cs "github.com/libp2p/go-conn-security"
	ic "github.com/libp2p/go-libp2p-crypto"
	peer "github.com/libp2p/go-libp2p-peer"
	. "github.com/onsi/ginkgo"
	. "github.com/onsi/gomega"
	"github.com/onsi/gomega/gbytes"
)

// transform describes one way of corrupting an Identity's certificate
// configuration; each entry drives a pair of "invalid cert" specs.
type transform struct {
	// name is a human-readable description, appended to the spec title.
	name string
	// remoteErr is the error that the side validating the chain gets.
	remoteErr string
	// apply modifies the given identity so that it presents a bad certificate.
	apply func(*Identity)
}

var _ = Describe("Transport", func() {
	var (
		serverKey, clientKey ic.PrivKey
		serverID, clientID   peer.ID
	)

	createPeer := func() (peer.ID, ic.PrivKey) {
Marten Seemann's avatar
Marten Seemann committed
39 40
		var priv ic.PrivKey
		if mrand.Int()%2 == 0 {
Marten Seemann's avatar
Marten Seemann committed
41
			fmt.Fprintf(GinkgoWriter, " using an ECDSA key: ")
Marten Seemann's avatar
Marten Seemann committed
42 43 44 45
			var err error
			priv, _, err = ic.GenerateECDSAKeyPair(rand.Reader)
			Expect(err).ToNot(HaveOccurred())
		} else {
Marten Seemann's avatar
Marten Seemann committed
46
			fmt.Fprintf(GinkgoWriter, " using an RSA key: ")
Marten Seemann's avatar
Marten Seemann committed
47 48 49 50
			var err error
			priv, _, err = ic.GenerateRSAKeyPair(1024, rand.Reader)
			Expect(err).ToNot(HaveOccurred())
		}
Marten Seemann's avatar
Marten Seemann committed
51 52
		id, err := peer.IDFromPrivateKey(priv)
		Expect(err).ToNot(HaveOccurred())
Marten Seemann's avatar
Marten Seemann committed
53
		fmt.Fprintln(GinkgoWriter, id.Pretty())
Marten Seemann's avatar
Marten Seemann committed
54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
		return id, priv
	}

	connect := func() (net.Conn, net.Conn) {
		ln, err := net.Listen("tcp", "localhost:0")
		Expect(err).ToNot(HaveOccurred())
		defer ln.Close()
		serverConnChan := make(chan net.Conn)
		go func() {
			defer GinkgoRecover()
			conn, err := ln.Accept()
			Expect(err).ToNot(HaveOccurred())
			serverConnChan <- conn
		}()
		conn, err := net.Dial("tcp", ln.Addr().String())
		Expect(err).ToNot(HaveOccurred())
		return conn, <-serverConnChan
	}

	BeforeEach(func() {
Marten Seemann's avatar
Marten Seemann committed
74
		fmt.Fprintf(GinkgoWriter, "Initializing a server")
Marten Seemann's avatar
Marten Seemann committed
75
		serverID, serverKey = createPeer()
Marten Seemann's avatar
Marten Seemann committed
76
		fmt.Fprintf(GinkgoWriter, "Initializing a client")
Marten Seemann's avatar
Marten Seemann committed
77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117
		clientID, clientKey = createPeer()
	})

	It("handshakes", func() {
		clientTransport, err := New(clientKey)
		Expect(err).ToNot(HaveOccurred())
		serverTransport, err := New(serverKey)
		Expect(err).ToNot(HaveOccurred())

		clientInsecureConn, serverInsecureConn := connect()

		serverConnChan := make(chan cs.Conn)
		go func() {
			defer GinkgoRecover()
			serverConn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn)
			Expect(err).ToNot(HaveOccurred())
			serverConnChan <- serverConn
		}()
		clientConn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID)
		Expect(err).ToNot(HaveOccurred())
		var serverConn cs.Conn
		Eventually(serverConnChan).Should(Receive(&serverConn))
		defer clientConn.Close()
		defer serverConn.Close()
		Expect(clientConn.LocalPeer()).To(Equal(clientID))
		Expect(serverConn.LocalPeer()).To(Equal(serverID))
		Expect(clientConn.LocalPrivateKey()).To(Equal(clientKey))
		Expect(serverConn.LocalPrivateKey()).To(Equal(serverKey))
		Expect(clientConn.RemotePeer()).To(Equal(serverID))
		Expect(serverConn.RemotePeer()).To(Equal(clientID))
		Expect(clientConn.RemotePublicKey()).To(Equal(serverKey.GetPublic()))
		Expect(serverConn.RemotePublicKey()).To(Equal(clientKey.GetPublic()))
		// exchange some data
		_, err = serverConn.Write([]byte("foobar"))
		Expect(err).ToNot(HaveOccurred())
		b := make([]byte, 6)
		_, err = clientConn.Read(b)
		Expect(err).ToNot(HaveOccurred())
		Expect(string(b)).To(Equal("foobar"))
	})

118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
	It("fails when the context of the outgoing connection is canceled", func() {
		clientTransport, err := New(clientKey)
		Expect(err).ToNot(HaveOccurred())
		serverTransport, err := New(serverKey)
		Expect(err).ToNot(HaveOccurred())

		clientInsecureConn, serverInsecureConn := connect()

		go func() {
			defer GinkgoRecover()
			_, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn)
			Expect(err).To(HaveOccurred())
		}()
		ctx, cancel := context.WithCancel(context.Background())
		cancel()
		_, err = clientTransport.SecureOutbound(ctx, clientInsecureConn, serverID)
134
		Expect(err).To(MatchError(context.Canceled))
135 136 137 138 139 140 141 142 143 144 145 146 147 148 149
	})

	It("fails when the context of the incoming connection is canceled", func() {
		clientTransport, err := New(clientKey)
		Expect(err).ToNot(HaveOccurred())
		serverTransport, err := New(serverKey)
		Expect(err).ToNot(HaveOccurred())

		clientInsecureConn, serverInsecureConn := connect()

		go func() {
			defer GinkgoRecover()
			ctx, cancel := context.WithCancel(context.Background())
			cancel()
			_, err := serverTransport.SecureInbound(ctx, serverInsecureConn)
150
			Expect(err).To(MatchError(context.Canceled))
151 152 153 154 155
		}()
		_, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID)
		Expect(err).To(HaveOccurred())
	})

Marten Seemann's avatar
Marten Seemann committed
156
	It("fails if the peer ID doesn't match", func() {
Marten Seemann's avatar
Marten Seemann committed
157
		fmt.Fprintf(GinkgoWriter, "Creating another peer")
Marten Seemann's avatar
Marten Seemann committed
158 159 160 161 162 163 164 165 166
		thirdPartyID, _ := createPeer()

		serverTransport, err := New(serverKey)
		Expect(err).ToNot(HaveOccurred())
		clientTransport, err := New(clientKey)
		Expect(err).ToNot(HaveOccurred())

		clientInsecureConn, serverInsecureConn := connect()

Marten Seemann's avatar
Marten Seemann committed
167
		done := make(chan struct{})
Marten Seemann's avatar
Marten Seemann committed
168 169 170 171 172
		go func() {
			defer GinkgoRecover()
			_, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn)
			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("tls: bad certificate"))
Marten Seemann's avatar
Marten Seemann committed
173
			close(done)
Marten Seemann's avatar
Marten Seemann committed
174 175 176 177
		}()
		// dial, but expect the wrong peer ID
		_, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, thirdPartyID)
		Expect(err).To(MatchError("peer IDs don't match"))
Marten Seemann's avatar
Marten Seemann committed
178
		Eventually(done).Should(BeClosed())
Marten Seemann's avatar
Marten Seemann committed
179 180
	})

181 182
	Context("invalid certificates", func() {
		invalidateCertChain := func(identity *Identity) {
183
			switch identity.config.Certificates[0].PrivateKey.(type) {
184 185 186
			case *rsa.PrivateKey:
				key, err := rsa.GenerateKey(rand.Reader, 1024)
				Expect(err).ToNot(HaveOccurred())
187
				identity.config.Certificates[0].PrivateKey = key
188 189 190
			case *ecdsa.PrivateKey:
				key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
				Expect(err).ToNot(HaveOccurred())
191
				identity.config.Certificates[0].PrivateKey = key
192 193 194 195
			default:
				Fail("unexpected private key type")
			}
		}
Marten Seemann's avatar
Marten Seemann committed
196

197 198 199 200 201 202 203 204 205 206 207 208
		twoCerts := func(identity *Identity) {
			tmpl := &x509.Certificate{SerialNumber: big.NewInt(1)}
			key1, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
			Expect(err).ToNot(HaveOccurred())
			key2, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
			Expect(err).ToNot(HaveOccurred())
			cert1DER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key1.Public(), key1)
			Expect(err).ToNot(HaveOccurred())
			cert1, err := x509.ParseCertificate(cert1DER)
			Expect(err).ToNot(HaveOccurred())
			cert2DER, err := x509.CreateCertificate(rand.Reader, tmpl, cert1, key2.Public(), key2)
			Expect(err).ToNot(HaveOccurred())
209
			identity.config.Certificates = []tls.Certificate{{
210 211 212 213
				Certificate: [][]byte{cert2DER, cert1DER},
				PrivateKey:  key2,
			}}
		}
Marten Seemann's avatar
Marten Seemann committed
214

215 216 217 218 219 220 221 222 223 224
		expiredCert := func(identity *Identity) {
			tmpl := &x509.Certificate{
				SerialNumber: big.NewInt(1),
				NotBefore:    time.Now().Add(-time.Hour),
				NotAfter:     time.Now().Add(-time.Minute),
			}
			key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
			Expect(err).ToNot(HaveOccurred())
			cert, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key.Public(), key)
			Expect(err).ToNot(HaveOccurred())
225
			identity.config.Certificates = []tls.Certificate{{
226 227 228 229
				Certificate: [][]byte{cert},
				PrivateKey:  key,
			}}
		}
Marten Seemann's avatar
Marten Seemann committed
230

231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247
		transforms := []transform{
			{
				name:      "private key used in the TLS handshake doesn't match the public key in the cert",
				apply:     invalidateCertChain,
				remoteErr: "tls: invalid certificate signature",
			},
			{
				name:      "certificate chain contains 2 certs",
				apply:     twoCerts,
				remoteErr: "expected one certificates in the chain",
			},
			{
				name:      "cert is expired",
				apply:     expiredCert,
				remoteErr: "certificate verification failed: x509: certificate has expired or is not yet valid",
			},
		}
Marten Seemann's avatar
Marten Seemann committed
248

249 250
		for i := range transforms {
			t := transforms[i]
Marten Seemann's avatar
Marten Seemann committed
251

252 253 254 255 256 257
			It(fmt.Sprintf("fails if the client presents an invalid cert: %s", t.name), func() {
				serverTransport, err := New(serverKey)
				Expect(err).ToNot(HaveOccurred())
				clientTransport, err := New(clientKey)
				Expect(err).ToNot(HaveOccurred())
				t.apply(clientTransport.identity)
Marten Seemann's avatar
Marten Seemann committed
258

259
				clientInsecureConn, serverInsecureConn := connect()
Marten Seemann's avatar
Marten Seemann committed
260

261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303
				done := make(chan struct{})
				go func() {
					defer GinkgoRecover()
					_, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn)
					Expect(err).To(MatchError(t.remoteErr))
					close(done)
				}()

				conn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID)
				Expect(err).ToNot(HaveOccurred())
				_, err = gbytes.TimeoutReader(conn, time.Second).Read([]byte{0})
				Expect(err).To(Or(
					// if the certificate's public key doesn't match the private key used for signing
					MatchError("remote error: tls: error decrypting message"),
					// all other errors
					MatchError("remote error: tls: bad certificate"),
				))
				Eventually(done).Should(BeClosed())
			})

			It(fmt.Sprintf("fails if the server presents an invalid cert: %s", t.name), func() {
				serverTransport, err := New(serverKey)
				Expect(err).ToNot(HaveOccurred())
				t.apply(serverTransport.identity)
				clientTransport, err := New(clientKey)
				Expect(err).ToNot(HaveOccurred())

				clientInsecureConn, serverInsecureConn := connect()

				done := make(chan struct{})
				go func() {
					defer GinkgoRecover()
					_, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn)
					Expect(err).To(HaveOccurred())
					Expect(err.Error()).To(ContainSubstring("remote error: tls:"))
					close(done)
				}()

				_, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID)
				Expect(err).To(MatchError(t.remoteErr))
				Eventually(done).Should(BeClosed())
			})
		}
Marten Seemann's avatar
Marten Seemann committed
304 305
	})
})