Commit da42c385 authored by dignifiedquire's avatar dignifiedquire Committed by Steven Allen

fix: do not allocate when comparing keys

parent 3390f7d1
......@@ -363,7 +363,9 @@ func KeyEqual(k1, k2 Key) bool {
return true
}
b1, err1 := k1.Bytes()
b2, err2 := k2.Bytes()
return subtle.ConstantTimeCompare(b1, b2) == 1 && err1 == err2
if k1.Type() != k2.Type() {
return false
}
return k1.Equals(k2)
}
......@@ -101,18 +101,19 @@ func testKeyEncoding(t *testing.T, sk PrivKey) {
}
func testKeyEquals(t *testing.T, k Key) {
kb, err := k.Bytes()
if err != nil {
t.Fatal(err)
}
// kb, err := k.Raw()
// if err != nil {
// t.Fatal(err)
// }
if !KeyEqual(k, k) {
t.Fatal("Key not equal to itself.")
}
if !KeyEqual(k, testkey(kb)) {
t.Fatal("Key not equal to key with same bytes.")
}
// bad test, relies on deep internals..
// if !KeyEqual(k, testkey(kb)) {
// t.Fatal("Key not equal to key with same bytes.")
// }
sk, pk, err := test.RandTestKeyPair(RSA, 512)
if err != nil {
......@@ -143,7 +144,20 @@ func (pk testkey) Raw() ([]byte, error) {
}
func (pk testkey) Equals(k Key) bool {
return KeyEqual(pk, k)
if pk.Type() != k.Type() {
return false
}
a, err := pk.Raw()
if err != nil {
return false
}
b, err := k.Raw()
if err != nil {
return false
}
return bytes.Equal(a, b)
}
func TestUnknownCurveErrors(t *testing.T) {
......
......@@ -3,6 +3,7 @@
package crypto
import (
"bytes"
"crypto"
"crypto/rand"
"crypto/rsa"
......@@ -63,7 +64,21 @@ func (pk *RsaPublicKey) Raw() ([]byte, error) {
// Equals checks whether this key is equal to another
func (pk *RsaPublicKey) Equals(k Key) bool {
return KeyEqual(pk, k)
// make sure this is an rsa public key
other, ok := (k).(*RsaPublicKey)
if !ok {
a, err := pk.Raw()
if err != nil {
return false
}
b, err := k.Raw()
if err != nil {
return false
}
return bytes.Equal(a, b)
}
return pk.k.N.Cmp(other.k.N) == 0 && pk.k.E == other.k.E
}
// Sign returns a signature of the input data
......@@ -93,7 +108,35 @@ func (sk *RsaPrivateKey) Raw() ([]byte, error) {
// Equals checks whether this key is equal to another
func (sk *RsaPrivateKey) Equals(k Key) bool {
return KeyEqual(sk, k)
// make sure this is an rsa public key
other, ok := (k).(*RsaPrivateKey)
if !ok {
a, err := sk.Raw()
if err != nil {
return false
}
b, err := k.Raw()
if err != nil {
return false
}
return bytes.Equal(a, b)
}
a := sk.sk
b := other.sk
if a.PublicKey.N.Cmp(b.PublicKey.N) != 0 {
return false
}
if a.PublicKey.E != b.PublicKey.E {
return false
}
if a.D.Cmp(b.D) != 0 {
return false
}
return true
}
// UnmarshalRsaPrivateKey returns a private key from the input x509 bytes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment