Unverified Commit 97741ed0 authored by Steven Allen's avatar Steven Allen Committed by GitHub

Merge pull request #122 from dirkmc/fix/select-best-rec-on-put

Disallow overwriting new records with older records on DHT PUT
parents 2bab49ef 78082710
package dht
import (
"bytes"
"context"
"errors"
"fmt"
"math/rand"
"sort"
......@@ -20,6 +22,7 @@ import (
peer "github.com/libp2p/go-libp2p-peer"
pstore "github.com/libp2p/go-libp2p-peerstore"
record "github.com/libp2p/go-libp2p-record"
routing "github.com/libp2p/go-libp2p-routing"
bhost "github.com/libp2p/go-libp2p/p2p/host/basic"
ci "github.com/libp2p/go-testutil/ci"
travisci "github.com/libp2p/go-testutil/ci/travis"
......@@ -188,6 +191,79 @@ func TestValueGetSet(t *testing.T) {
}
}
func TestValueSetInvalid(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
dhtA := setupDHT(ctx, t, false)
dhtB := setupDHT(ctx, t, false)
defer dhtA.Close()
defer dhtB.Close()
defer dhtA.host.Close()
defer dhtB.host.Close()
vf := func(r *record.ValidationRecord) error {
if bytes.Compare(r.Value, []byte("expired")) == 0 {
return errors.New("expired")
}
return nil
}
nulsel := func(k string, bs [][]byte) (int, error) {
index := -1
for i, b := range bs {
if bytes.Compare(b, []byte("newer")) == 0 {
index = i
} else if bytes.Compare(b, []byte("valid")) == 0 {
if index == -1 {
index = i
}
}
}
if index == -1 {
return -1, errors.New("no rec found")
}
return index, nil
}
dhtA.Validator["v"] = vf
dhtB.Validator["v"] = vf
dhtA.Selector["v"] = nulsel
dhtB.Selector["v"] = nulsel
connect(t, ctx, dhtA, dhtB)
testSetGet := func(val string, exp string, experr error) {
ctxT, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
err := dhtA.PutValue(ctxT, "/v/hello", []byte(val))
if err != nil {
t.Fatal(err)
}
ctxT, cancel = context.WithTimeout(ctx, time.Second*2)
defer cancel()
valb, err := dhtB.GetValue(ctxT, "/v/hello")
if err != experr {
t.Fatalf("Set/Get %v: Expected %v error but got %v", val, experr, err)
}
if err == nil {
if string(valb) != exp {
t.Fatalf("Expected '%v' got '%s'", exp, string(valb))
}
}
}
// Expired records should not be returned
testSetGet("expired", "", routing.ErrNotFound)
// Valid record should be returned
testSetGet("valid", "valid", nil)
// Newer record should supersede previous record
testSetGet("newer", "newer", nil)
// Attempt to set older record again should be ignored
testSetGet("valid", "newer", nil)
}
func TestInvalidMessageSenderTracking(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
......
......@@ -172,11 +172,33 @@ func (dht *IpfsDHT) handlePutValue(ctx context.Context, p peer.ID, pmes *pb.Mess
}
cleanRecord(rec)
// Make sure the record is valid (not expired, valid signature etc)
if err = dht.Validator.VerifyRecord(rec); err != nil {
log.Warningf("Bad dht record in PUT from: %s. %s", p.Pretty(), err)
return nil, err
}
// Make sure the new record is "better" than the record we have locally.
// This prevents a record with for example a lower sequence number from
// overwriting a record with a higher sequence number.
existing, err := dht.getRecordFromDatastore(dskey)
if err != nil {
return nil, err
}
if existing != nil {
recs := [][]byte{rec.GetValue(), existing.GetValue()}
i, err := dht.Selector.BestRecord(pmes.GetKey(), recs)
if err != nil {
log.Warningf("Bad dht record in PUT from %s: %s", p.Pretty(), err)
return nil, err
}
if i != 0 {
log.Infof("DHT record in PUT from %s is older than existing record. Ignoring", p.Pretty())
return nil, errors.New("old record")
}
}
// record the time we receive every record
rec.TimeReceived = proto.String(u.FormatRFC3339(time.Now()))
......@@ -190,6 +212,42 @@ func (dht *IpfsDHT) handlePutValue(ctx context.Context, p peer.ID, pmes *pb.Mess
return pmes, err
}
func (dht *IpfsDHT) getRecordFromDatastore(dskey ds.Key) (*recpb.Record, error) {
reci, err := dht.datastore.Get(dskey)
if err == ds.ErrNotFound {
return nil, nil
}
if err != nil {
log.Warningf("Got error retrieving record with key %s from datastore: %s", dskey, err)
return nil, err
}
byt, ok := reci.([]byte)
if !ok {
// Bad data in datastore, log it but don't return an error, we'll just overwrite it
log.Warningf("Value stored in datastore with key %s is not []byte", dskey)
return nil, nil
}
rec := new(recpb.Record)
err = proto.Unmarshal(byt, rec)
if err != nil {
// Bad data in datastore, log it but don't return an error, we'll just overwrite it
log.Warningf("Bad record data stored in datastore with key %s: could not unmarshal record", dskey)
return nil, nil
}
err = dht.Validator.VerifyRecord(rec)
if err != nil {
// Invalid record in datastore, probably expired but don't return an error,
// we'll just overwrite it
log.Debugf("Local record verify failed: %s (discarded)", err)
return nil, nil
}
return rec, nil
}
func (dht *IpfsDHT) handlePing(_ context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) {
log.Debugf("%s Responding to ping from %s!\n", dht.self, p)
return pmes, nil
......
......@@ -8,7 +8,9 @@ import (
"sync"
"time"
proto "github.com/gogo/protobuf/proto"
cid "github.com/ipfs/go-cid"
u "github.com/ipfs/go-ipfs-util"
logging "github.com/ipfs/go-log"
pb "github.com/libp2p/go-libp2p-kad-dht/pb"
kb "github.com/libp2p/go-libp2p-kbucket"
......@@ -45,6 +47,7 @@ func (dht *IpfsDHT) PutValue(ctx context.Context, key string, value []byte) (err
log.Debugf("PutValue %s", key)
rec := record.MakePutRecord(key, value)
rec.TimeReceived = proto.String(u.FormatRFC3339(time.Now()))
err = dht.putLocal(key, rec)
if err != nil {
return err
......@@ -101,6 +104,9 @@ func (dht *IpfsDHT) GetValue(ctx context.Context, key string) (_ []byte, err err
recs = append(recs, v.Val)
}
}
if len(recs) == 0 {
return nil, routing.ErrNotFound
}
i, err := dht.Selector.BestRecord(key, recs)
if err != nil {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment