Commit 785ee8c8 authored by Steven Allen's avatar Steven Allen

fix: handle nil peer IDs

Feels like Java all over again.

fixes #87
parent 416a1c1a
...@@ -207,6 +207,9 @@ func (ab *dsAddrBook) Close() error { ...@@ -207,6 +207,9 @@ func (ab *dsAddrBook) Close() error {
// //
// If the cache argument is true, the record is inserted in the cache when loaded from the datastore. // If the cache argument is true, the record is inserted in the cache when loaded from the datastore.
func (ab *dsAddrBook) loadRecord(id peer.ID, cache bool, update bool) (pr *addrsRecord, err error) { func (ab *dsAddrBook) loadRecord(id peer.ID, cache bool, update bool) (pr *addrsRecord, err error) {
if err := id.Validate(); err != nil {
return nil, err
}
if e, ok := ab.cache.Get(id); ok { if e, ok := ab.cache.Get(id); ok {
pr = e.(*addrsRecord) pr = e.(*addrsRecord)
pr.Lock() pr.Lock()
...@@ -421,6 +424,11 @@ func (ab *dsAddrBook) AddrStream(ctx context.Context, p peer.ID) <-chan ma.Multi ...@@ -421,6 +424,11 @@ func (ab *dsAddrBook) AddrStream(ctx context.Context, p peer.ID) <-chan ma.Multi
// ClearAddrs will delete all known addresses for a peer ID. // ClearAddrs will delete all known addresses for a peer ID.
func (ab *dsAddrBook) ClearAddrs(p peer.ID) { func (ab *dsAddrBook) ClearAddrs(p peer.ID) {
if err := p.Validate(); err != nil {
// nothing to do
return
}
ab.cache.Remove(p) ab.cache.Remove(p)
key := addrBookBase.ChildString(b32.RawStdEncoding.EncodeToString([]byte(p))) key := addrBookBase.ChildString(b32.RawStdEncoding.EncodeToString([]byte(p)))
......
...@@ -41,6 +41,9 @@ func NewPeerMetadata(_ context.Context, store ds.Datastore, _ Options) (*dsPeerM ...@@ -41,6 +41,9 @@ func NewPeerMetadata(_ context.Context, store ds.Datastore, _ Options) (*dsPeerM
} }
func (pm *dsPeerMetadata) Get(p peer.ID, key string) (interface{}, error) { func (pm *dsPeerMetadata) Get(p peer.ID, key string) (interface{}, error) {
if err := p.Validate(); err != nil {
return nil, err
}
k := pmBase.ChildString(base32.RawStdEncoding.EncodeToString([]byte(p))).ChildString(key) k := pmBase.ChildString(base32.RawStdEncoding.EncodeToString([]byte(p))).ChildString(key)
value, err := pm.ds.Get(k) value, err := pm.ds.Get(k)
if err != nil { if err != nil {
...@@ -58,6 +61,9 @@ func (pm *dsPeerMetadata) Get(p peer.ID, key string) (interface{}, error) { ...@@ -58,6 +61,9 @@ func (pm *dsPeerMetadata) Get(p peer.ID, key string) (interface{}, error) {
} }
func (pm *dsPeerMetadata) Put(p peer.ID, key string, val interface{}) error { func (pm *dsPeerMetadata) Put(p peer.ID, key string, val interface{}) error {
if err := p.Validate(); err != nil {
return err
}
k := pmBase.ChildString(base32.RawStdEncoding.EncodeToString([]byte(p))).ChildString(key) k := pmBase.ChildString(base32.RawStdEncoding.EncodeToString([]byte(p))).ChildString(key)
var buf pool.Buffer var buf pool.Buffer
if err := gob.NewEncoder(&buf).Encode(&val); err != nil { if err := gob.NewEncoder(&buf).Encode(&val); err != nil {
......
...@@ -39,6 +39,10 @@ func NewProtoBook(meta pstore.PeerMetadata) *dsProtoBook { ...@@ -39,6 +39,10 @@ func NewProtoBook(meta pstore.PeerMetadata) *dsProtoBook {
} }
func (pb *dsProtoBook) SetProtocols(p peer.ID, protos ...string) error { func (pb *dsProtoBook) SetProtocols(p peer.ID, protos ...string) error {
if err := p.Validate(); err != nil {
return err
}
s := pb.segments.get(p) s := pb.segments.get(p)
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
...@@ -52,6 +56,10 @@ func (pb *dsProtoBook) SetProtocols(p peer.ID, protos ...string) error { ...@@ -52,6 +56,10 @@ func (pb *dsProtoBook) SetProtocols(p peer.ID, protos ...string) error {
} }
func (pb *dsProtoBook) AddProtocols(p peer.ID, protos ...string) error { func (pb *dsProtoBook) AddProtocols(p peer.ID, protos ...string) error {
if err := p.Validate(); err != nil {
return err
}
s := pb.segments.get(p) s := pb.segments.get(p)
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
...@@ -69,6 +77,10 @@ func (pb *dsProtoBook) AddProtocols(p peer.ID, protos ...string) error { ...@@ -69,6 +77,10 @@ func (pb *dsProtoBook) AddProtocols(p peer.ID, protos ...string) error {
} }
func (pb *dsProtoBook) GetProtocols(p peer.ID) ([]string, error) { func (pb *dsProtoBook) GetProtocols(p peer.ID) ([]string, error) {
if err := p.Validate(); err != nil {
return nil, err
}
s := pb.segments.get(p) s := pb.segments.get(p)
s.RLock() s.RLock()
defer s.RUnlock() defer s.RUnlock()
...@@ -87,6 +99,10 @@ func (pb *dsProtoBook) GetProtocols(p peer.ID) ([]string, error) { ...@@ -87,6 +99,10 @@ func (pb *dsProtoBook) GetProtocols(p peer.ID) ([]string, error) {
} }
func (pb *dsProtoBook) SupportsProtocols(p peer.ID, protos ...string) ([]string, error) { func (pb *dsProtoBook) SupportsProtocols(p peer.ID, protos ...string) ([]string, error) {
if err := p.Validate(); err != nil {
return nil, err
}
s := pb.segments.get(p) s := pb.segments.get(p)
s.RLock() s.RLock()
defer s.RUnlock() defer s.RUnlock()
...@@ -107,6 +123,10 @@ func (pb *dsProtoBook) SupportsProtocols(p peer.ID, protos ...string) ([]string, ...@@ -107,6 +123,10 @@ func (pb *dsProtoBook) SupportsProtocols(p peer.ID, protos ...string) ([]string,
} }
func (pb *dsProtoBook) RemoveProtocols(p peer.ID, protos ...string) error { func (pb *dsProtoBook) RemoveProtocols(p peer.ID, protos ...string) error {
if err := p.Validate(); err != nil {
return err
}
s := pb.segments.get(p) s := pb.segments.get(p)
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
......
...@@ -196,6 +196,11 @@ func (mab *memoryAddrBook) ConsumePeerRecord(recordEnvelope *record.Envelope, tt ...@@ -196,6 +196,11 @@ func (mab *memoryAddrBook) ConsumePeerRecord(recordEnvelope *record.Envelope, tt
} }
func (mab *memoryAddrBook) addAddrs(p peer.ID, addrs []ma.Multiaddr, ttl time.Duration, signed bool) { func (mab *memoryAddrBook) addAddrs(p peer.ID, addrs []ma.Multiaddr, ttl time.Duration, signed bool) {
if err := p.Validate(); err != nil {
log.Warningf("tried to set addrs for invalid peer ID %s: %s", p, err)
return
}
// if ttl is zero, exit. nothing to do. // if ttl is zero, exit. nothing to do.
if ttl <= 0 { if ttl <= 0 {
return return
...@@ -244,12 +249,22 @@ func (mab *memoryAddrBook) addAddrs(p peer.ID, addrs []ma.Multiaddr, ttl time.Du ...@@ -244,12 +249,22 @@ func (mab *memoryAddrBook) addAddrs(p peer.ID, addrs []ma.Multiaddr, ttl time.Du
// SetAddr calls mgr.SetAddrs(p, addr, ttl) // SetAddr calls mgr.SetAddrs(p, addr, ttl)
func (mab *memoryAddrBook) SetAddr(p peer.ID, addr ma.Multiaddr, ttl time.Duration) { func (mab *memoryAddrBook) SetAddr(p peer.ID, addr ma.Multiaddr, ttl time.Duration) {
if err := p.Validate(); err != nil {
log.Warningf("tried to set addrs for invalid peer ID %s: %s", p, err)
return
}
mab.SetAddrs(p, []ma.Multiaddr{addr}, ttl) mab.SetAddrs(p, []ma.Multiaddr{addr}, ttl)
} }
// SetAddrs sets the ttl on addresses. This clears any TTL there previously. // SetAddrs sets the ttl on addresses. This clears any TTL there previously.
// This is used when we receive the best estimate of the validity of an address. // This is used when we receive the best estimate of the validity of an address.
func (mab *memoryAddrBook) SetAddrs(p peer.ID, addrs []ma.Multiaddr, ttl time.Duration) { func (mab *memoryAddrBook) SetAddrs(p peer.ID, addrs []ma.Multiaddr, ttl time.Duration) {
if err := p.Validate(); err != nil {
log.Warningf("tried to set addrs for invalid peer ID %s: %s", p, err)
return
}
s := mab.segments.get(p) s := mab.segments.get(p)
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
...@@ -287,6 +302,11 @@ func (mab *memoryAddrBook) SetAddrs(p peer.ID, addrs []ma.Multiaddr, ttl time.Du ...@@ -287,6 +302,11 @@ func (mab *memoryAddrBook) SetAddrs(p peer.ID, addrs []ma.Multiaddr, ttl time.Du
// UpdateAddrs updates the addresses associated with the given peer that have // UpdateAddrs updates the addresses associated with the given peer that have
// the given oldTTL to have the given newTTL. // the given oldTTL to have the given newTTL.
func (mab *memoryAddrBook) UpdateAddrs(p peer.ID, oldTTL time.Duration, newTTL time.Duration) { func (mab *memoryAddrBook) UpdateAddrs(p peer.ID, oldTTL time.Duration, newTTL time.Duration) {
if err := p.Validate(); err != nil {
log.Warningf("tried to set addrs for invalid peer ID %s: %s", p, err)
return
}
s := mab.segments.get(p) s := mab.segments.get(p)
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
...@@ -310,6 +330,11 @@ func (mab *memoryAddrBook) UpdateAddrs(p peer.ID, oldTTL time.Duration, newTTL t ...@@ -310,6 +330,11 @@ func (mab *memoryAddrBook) UpdateAddrs(p peer.ID, oldTTL time.Duration, newTTL t
// Addrs returns all known (and valid) addresses for a given peer // Addrs returns all known (and valid) addresses for a given peer
func (mab *memoryAddrBook) Addrs(p peer.ID) []ma.Multiaddr { func (mab *memoryAddrBook) Addrs(p peer.ID) []ma.Multiaddr {
if err := p.Validate(); err != nil {
// invalid peer ID = no addrs
return nil
}
s := mab.segments.get(p) s := mab.segments.get(p)
s.RLock() s.RLock()
defer s.RUnlock() defer s.RUnlock()
...@@ -336,6 +361,11 @@ func validAddrs(amap map[string]*expiringAddr) []ma.Multiaddr { ...@@ -336,6 +361,11 @@ func validAddrs(amap map[string]*expiringAddr) []ma.Multiaddr {
// given peer id, if one exists. // given peer id, if one exists.
// Returns nil if no signed PeerRecord exists for the peer. // Returns nil if no signed PeerRecord exists for the peer.
func (mab *memoryAddrBook) GetPeerRecord(p peer.ID) *record.Envelope { func (mab *memoryAddrBook) GetPeerRecord(p peer.ID) *record.Envelope {
if err := p.Validate(); err != nil {
// invalid peer ID = no addrs
return nil
}
s := mab.segments.get(p) s := mab.segments.get(p)
s.RLock() s.RLock()
defer s.RUnlock() defer s.RUnlock()
...@@ -356,6 +386,11 @@ func (mab *memoryAddrBook) GetPeerRecord(p peer.ID) *record.Envelope { ...@@ -356,6 +386,11 @@ func (mab *memoryAddrBook) GetPeerRecord(p peer.ID) *record.Envelope {
// ClearAddrs removes all previously stored addresses // ClearAddrs removes all previously stored addresses
func (mab *memoryAddrBook) ClearAddrs(p peer.ID) { func (mab *memoryAddrBook) ClearAddrs(p peer.ID) {
if err := p.Validate(); err != nil {
// nothing to clear
return
}
s := mab.segments.get(p) s := mab.segments.get(p)
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
...@@ -367,6 +402,13 @@ func (mab *memoryAddrBook) ClearAddrs(p peer.ID) { ...@@ -367,6 +402,13 @@ func (mab *memoryAddrBook) ClearAddrs(p peer.ID) {
// AddrStream returns a channel on which all new addresses discovered for a // AddrStream returns a channel on which all new addresses discovered for a
// given peer ID will be published. // given peer ID will be published.
func (mab *memoryAddrBook) AddrStream(ctx context.Context, p peer.ID) <-chan ma.Multiaddr { func (mab *memoryAddrBook) AddrStream(ctx context.Context, p peer.ID) <-chan ma.Multiaddr {
if err := p.Validate(); err != nil {
log.Warningf("tried to get addrs for invalid peer ID %s: %s", p, err)
ch := make(chan ma.Multiaddr)
close(ch)
return ch
}
s := mab.segments.get(p) s := mab.segments.get(p)
s.RLock() s.RLock()
defer s.RUnlock() defer s.RUnlock()
......
...@@ -35,6 +35,9 @@ func NewPeerMetadata() *memoryPeerMetadata { ...@@ -35,6 +35,9 @@ func NewPeerMetadata() *memoryPeerMetadata {
} }
func (ps *memoryPeerMetadata) Put(p peer.ID, key string, val interface{}) error { func (ps *memoryPeerMetadata) Put(p peer.ID, key string, val interface{}) error {
if err := p.Validate(); err != nil {
return err
}
ps.dslock.Lock() ps.dslock.Lock()
defer ps.dslock.Unlock() defer ps.dslock.Unlock()
if vals, ok := val.(string); ok && internKeys[key] { if vals, ok := val.(string); ok && internKeys[key] {
...@@ -49,6 +52,9 @@ func (ps *memoryPeerMetadata) Put(p peer.ID, key string, val interface{}) error ...@@ -49,6 +52,9 @@ func (ps *memoryPeerMetadata) Put(p peer.ID, key string, val interface{}) error
} }
func (ps *memoryPeerMetadata) Get(p peer.ID, key string) (interface{}, error) { func (ps *memoryPeerMetadata) Get(p peer.ID, key string) (interface{}, error) {
if err := p.Validate(); err != nil {
return nil, err
}
ps.dslock.RLock() ps.dslock.RLock()
defer ps.dslock.RUnlock() defer ps.dslock.RUnlock()
i, ok := ps.ds[metakey{p, key}] i, ok := ps.ds[metakey{p, key}]
......
...@@ -67,6 +67,10 @@ func (pb *memoryProtoBook) internProtocol(proto string) string { ...@@ -67,6 +67,10 @@ func (pb *memoryProtoBook) internProtocol(proto string) string {
} }
func (pb *memoryProtoBook) SetProtocols(p peer.ID, protos ...string) error { func (pb *memoryProtoBook) SetProtocols(p peer.ID, protos ...string) error {
if err := p.Validate(); err != nil {
return err
}
s := pb.segments.get(p) s := pb.segments.get(p)
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
...@@ -82,6 +86,10 @@ func (pb *memoryProtoBook) SetProtocols(p peer.ID, protos ...string) error { ...@@ -82,6 +86,10 @@ func (pb *memoryProtoBook) SetProtocols(p peer.ID, protos ...string) error {
} }
func (pb *memoryProtoBook) AddProtocols(p peer.ID, protos ...string) error { func (pb *memoryProtoBook) AddProtocols(p peer.ID, protos ...string) error {
if err := p.Validate(); err != nil {
return err
}
s := pb.segments.get(p) s := pb.segments.get(p)
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
...@@ -100,6 +108,10 @@ func (pb *memoryProtoBook) AddProtocols(p peer.ID, protos ...string) error { ...@@ -100,6 +108,10 @@ func (pb *memoryProtoBook) AddProtocols(p peer.ID, protos ...string) error {
} }
func (pb *memoryProtoBook) GetProtocols(p peer.ID) ([]string, error) { func (pb *memoryProtoBook) GetProtocols(p peer.ID) ([]string, error) {
if err := p.Validate(); err != nil {
return nil, err
}
s := pb.segments.get(p) s := pb.segments.get(p)
s.RLock() s.RLock()
defer s.RUnlock() defer s.RUnlock()
...@@ -113,6 +125,10 @@ func (pb *memoryProtoBook) GetProtocols(p peer.ID) ([]string, error) { ...@@ -113,6 +125,10 @@ func (pb *memoryProtoBook) GetProtocols(p peer.ID) ([]string, error) {
} }
func (pb *memoryProtoBook) RemoveProtocols(p peer.ID, protos ...string) error { func (pb *memoryProtoBook) RemoveProtocols(p peer.ID, protos ...string) error {
if err := p.Validate(); err != nil {
return err
}
s := pb.segments.get(p) s := pb.segments.get(p)
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
...@@ -130,6 +146,10 @@ func (pb *memoryProtoBook) RemoveProtocols(p peer.ID, protos ...string) error { ...@@ -130,6 +146,10 @@ func (pb *memoryProtoBook) RemoveProtocols(p peer.ID, protos ...string) error {
} }
func (pb *memoryProtoBook) SupportsProtocols(p peer.ID, protos ...string) ([]string, error) { func (pb *memoryProtoBook) SupportsProtocols(p peer.ID, protos ...string) ([]string, error) {
if err := p.Validate(); err != nil {
return nil, err
}
s := pb.segments.get(p) s := pb.segments.get(p)
s.RLock() s.RLock()
defer s.RUnlock() defer s.RUnlock()
......
...@@ -279,6 +279,29 @@ func testPeerstoreProtoStore(ps pstore.Peerstore) func(t *testing.T) { ...@@ -279,6 +279,29 @@ func testPeerstoreProtoStore(ps pstore.Peerstore) func(t *testing.T) {
if !reflect.DeepEqual(supported, protos[2:]) { if !reflect.DeepEqual(supported, protos[2:]) {
t.Fatal("expected only one protocol to remain") t.Fatal("expected only one protocol to remain")
} }
// test bad peer IDs
badp := peer.ID("")
err = ps.AddProtocols(badp, protos...)
if err == nil {
t.Fatal("expected error when using a bad peer ID")
}
_, err = ps.GetProtocols(badp)
if err == nil || err == pstore.ErrNotFound {
t.Fatal("expected error when using a bad peer ID")
}
_, err = ps.SupportsProtocols(badp, "q", "w", "a", "y", "b")
if err == nil || err == pstore.ErrNotFound {
t.Fatal("expected error when using a bad peer ID")
}
err = ps.RemoveProtocols(badp)
if err == nil || err == pstore.ErrNotFound {
t.Fatal("expected error when using a bad peer ID")
}
} }
} }
...@@ -309,6 +332,10 @@ func testBasicPeerstore(ps pstore.Peerstore) func(t *testing.T) { ...@@ -309,6 +332,10 @@ func testBasicPeerstore(ps pstore.Peerstore) func(t *testing.T) {
if !pinfo.Addrs[0].Equal(addrs[0]) { if !pinfo.Addrs[0].Equal(addrs[0]) {
t.Fatal("stored wrong address") t.Fatal("stored wrong address")
} }
// should fail silently...
ps.AddAddrs("", addrs, pstore.PermanentAddrTTL)
ps.Addrs("")
} }
} }
...@@ -355,6 +382,12 @@ func testMetadata(ps pstore.Peerstore) func(t *testing.T) { ...@@ -355,6 +382,12 @@ func testMetadata(ps pstore.Peerstore) func(t *testing.T) {
continue continue
} }
} }
if err := ps.Put("", "foobar", "thing"); err == nil {
t.Errorf("expected error for bad peer ID")
}
if _, err := ps.Get("", "foobar"); err == nil || err == pstore.ErrNotFound {
t.Errorf("expected error for bad peer ID")
}
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment