Commit 9a4415d1 authored by Steven Allen's avatar Steven Allen

use a fallback basicEquals function everywhere

This also ensures we check that the types are equal, even if we're comparing
directly with `k1.Equals(k2)` instead of `KeyEquals(k1, k2)`.
parent 2df9672e
...@@ -119,7 +119,7 @@ func (ePriv *ECDSAPrivateKey) Raw() ([]byte, error) { ...@@ -119,7 +119,7 @@ func (ePriv *ECDSAPrivateKey) Raw() ([]byte, error) {
func (ePriv *ECDSAPrivateKey) Equals(o Key) bool { func (ePriv *ECDSAPrivateKey) Equals(o Key) bool {
oPriv, ok := o.(*ECDSAPrivateKey) oPriv, ok := o.(*ECDSAPrivateKey)
if !ok { if !ok {
return false return basicEquals(ePriv, o)
} }
return ePriv.priv.D.Cmp(oPriv.priv.D) == 0 return ePriv.priv.D.Cmp(oPriv.priv.D) == 0
...@@ -163,7 +163,7 @@ func (ePub ECDSAPublicKey) Raw() ([]byte, error) { ...@@ -163,7 +163,7 @@ func (ePub ECDSAPublicKey) Raw() ([]byte, error) {
func (ePub *ECDSAPublicKey) Equals(o Key) bool { func (ePub *ECDSAPublicKey) Equals(o Key) bool {
oPub, ok := o.(*ECDSAPublicKey) oPub, ok := o.(*ECDSAPublicKey)
if !ok { if !ok {
return false return basicEquals(ePub, o)
} }
return ePub.pub.X != nil && ePub.pub.Y != nil && oPub.pub.X != nil && oPub.pub.Y != nil && return ePub.pub.X != nil && ePub.pub.Y != nil && oPub.pub.X != nil && oPub.pub.Y != nil &&
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
package crypto package crypto
import ( import (
"bytes"
"crypto/elliptic" "crypto/elliptic"
"crypto/hmac" "crypto/hmac"
"crypto/rand" "crypto/rand"
...@@ -363,9 +364,21 @@ func KeyEqual(k1, k2 Key) bool { ...@@ -363,9 +364,21 @@ func KeyEqual(k1, k2 Key) bool {
return true return true
} }
return k1.Equals(k2)
}
func basicEquals(k1, k2 Key) bool {
if k1.Type() != k2.Type() { if k1.Type() != k2.Type() {
return false return false
} }
return k1.Equals(k2) a, err := k1.Raw()
if err != nil {
return false
}
b, err := k2.Raw()
if err != nil {
return false
}
return bytes.Equal(a, b)
} }
...@@ -3,8 +3,6 @@ ...@@ -3,8 +3,6 @@
package crypto package crypto
import ( import (
"bytes"
pb "github.com/libp2p/go-libp2p-core/crypto/pb" pb "github.com/libp2p/go-libp2p-core/crypto/pb"
openssl "github.com/libp2p/go-openssl" openssl "github.com/libp2p/go-openssl"
...@@ -65,16 +63,7 @@ func (pk *opensslPublicKey) Raw() ([]byte, error) { ...@@ -65,16 +63,7 @@ func (pk *opensslPublicKey) Raw() ([]byte, error) {
func (pk *opensslPublicKey) Equals(k Key) bool { func (pk *opensslPublicKey) Equals(k Key) bool {
k0, ok := k.(*RsaPublicKey) k0, ok := k.(*RsaPublicKey)
if !ok { if !ok {
a, err := pk.Raw() return basicEquals(pk, k)
if err != nil {
return false
}
b, err := k.Raw()
if err != nil {
return false
}
return bytes.Equal(a, b)
} }
return pk.key.Equal(k0.opensslPublicKey.key) return pk.key.Equal(k0.opensslPublicKey.key)
...@@ -112,16 +101,7 @@ func (sk *opensslPrivateKey) Raw() ([]byte, error) { ...@@ -112,16 +101,7 @@ func (sk *opensslPrivateKey) Raw() ([]byte, error) {
func (sk *opensslPrivateKey) Equals(k Key) bool { func (sk *opensslPrivateKey) Equals(k Key) bool {
k0, ok := k.(*RsaPrivateKey) k0, ok := k.(*RsaPrivateKey)
if !ok { if !ok {
a, err := sk.Raw() return basicEquals(sk, k)
if err != nil {
return false
}
b, err := k.Raw()
if err != nil {
return false
}
return bytes.Equal(a, b)
} }
return sk.key.Equal(k0.opensslPrivateKey.key) return sk.key.Equal(k0.opensslPrivateKey.key)
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
package crypto package crypto
import ( import (
"bytes"
"crypto" "crypto"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
...@@ -67,15 +66,7 @@ func (pk *RsaPublicKey) Equals(k Key) bool { ...@@ -67,15 +66,7 @@ func (pk *RsaPublicKey) Equals(k Key) bool {
// make sure this is an rsa public key // make sure this is an rsa public key
other, ok := (k).(*RsaPublicKey) other, ok := (k).(*RsaPublicKey)
if !ok { if !ok {
a, err := pk.Raw() return basicEquals(pk, k)
if err != nil {
return false
}
b, err := k.Raw()
if err != nil {
return false
}
return bytes.Equal(a, b)
} }
return pk.k.N.Cmp(other.k.N) == 0 && pk.k.E == other.k.E return pk.k.N.Cmp(other.k.N) == 0 && pk.k.E == other.k.E
...@@ -111,15 +102,7 @@ func (sk *RsaPrivateKey) Equals(k Key) bool { ...@@ -111,15 +102,7 @@ func (sk *RsaPrivateKey) Equals(k Key) bool {
// make sure this is an rsa public key // make sure this is an rsa public key
other, ok := (k).(*RsaPrivateKey) other, ok := (k).(*RsaPrivateKey)
if !ok { if !ok {
a, err := sk.Raw() return basicEquals(sk, k)
if err != nil {
return false
}
b, err := k.Raw()
if err != nil {
return false
}
return bytes.Equal(a, b)
} }
a := sk.sk a := sk.sk
......
...@@ -66,7 +66,7 @@ func (k *Secp256k1PrivateKey) Raw() ([]byte, error) { ...@@ -66,7 +66,7 @@ func (k *Secp256k1PrivateKey) Raw() ([]byte, error) {
func (k *Secp256k1PrivateKey) Equals(o Key) bool { func (k *Secp256k1PrivateKey) Equals(o Key) bool {
sk, ok := o.(*Secp256k1PrivateKey) sk, ok := o.(*Secp256k1PrivateKey)
if !ok { if !ok {
return false return basicEquals(k, o)
} }
return k.D.Cmp(sk.D) == 0 return k.D.Cmp(sk.D) == 0
...@@ -107,7 +107,7 @@ func (k *Secp256k1PublicKey) Raw() ([]byte, error) { ...@@ -107,7 +107,7 @@ func (k *Secp256k1PublicKey) Raw() ([]byte, error) {
func (k *Secp256k1PublicKey) Equals(o Key) bool { func (k *Secp256k1PublicKey) Equals(o Key) bool {
sk, ok := o.(*Secp256k1PublicKey) sk, ok := o.(*Secp256k1PublicKey)
if !ok { if !ok {
return false return basicEquals(k, o)
} }
return (*btcec.PublicKey)(k).IsEqual((*btcec.PublicKey)(sk)) return (*btcec.PublicKey)(k).IsEqual((*btcec.PublicKey)(sk))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment