Commit d94deae4 authored by Erin Swenson-Healey's avatar Erin Swenson-Healey

make DHT protocols pluggable

parent c0d3351b
......@@ -34,9 +34,6 @@ import (
var log = logging.Logger("dht")
var ProtocolDHT protocol.ID = "/ipfs/kad/1.0.0"
var ProtocolDHTOld protocol.ID = "/ipfs/dht"
// NumBootstrapQueries defines the number of random dht queries to do to
// collect members of the routing table.
const NumBootstrapQueries = 5
......@@ -64,6 +61,8 @@ type IpfsDHT struct {
smlk sync.Mutex
plk sync.Mutex
protocols []protocol.ID // DHT protocols
}
// New creates a new DHT with the specified host and options.
......@@ -72,7 +71,7 @@ func New(ctx context.Context, h host.Host, options ...opts.Option) (*IpfsDHT, er
if err := cfg.Apply(append([]opts.Option{opts.Defaults}, options...)...); err != nil {
return nil, err
}
dht := makeDHT(ctx, h, cfg.Datastore)
dht := makeDHT(ctx, h, cfg.Datastore, cfg.Protocols)
// register for network notifs.
dht.host.Network().Notify((*netNotifiee)(dht))
......@@ -87,8 +86,9 @@ func New(ctx context.Context, h host.Host, options ...opts.Option) (*IpfsDHT, er
dht.Validator = cfg.Validator
if !cfg.Client {
h.SetStreamHandler(ProtocolDHT, dht.handleNewStream)
h.SetStreamHandler(ProtocolDHTOld, dht.handleNewStream)
for _, p := range cfg.Protocols {
h.SetStreamHandler(p, dht.handleNewStream)
}
}
return dht, nil
}
......@@ -116,7 +116,7 @@ func NewDHTClient(ctx context.Context, h host.Host, dstore ds.Batching) *IpfsDHT
return dht
}
func makeDHT(ctx context.Context, h host.Host, dstore ds.Batching) *IpfsDHT {
func makeDHT(ctx context.Context, h host.Host, dstore ds.Batching, protocols []protocol.ID) *IpfsDHT {
rt := kb.NewRoutingTable(KValue, kb.ConvertPeerID(h.ID()), time.Minute, h.Peerstore())
cmgr := h.ConnManager()
......@@ -137,6 +137,7 @@ func makeDHT(ctx context.Context, h host.Host, dstore ds.Batching) *IpfsDHT {
providers: providers.NewProviderManager(ctx, h.ID(), dstore),
birth: time.Now(),
routingTable: rt,
protocols: protocols,
}
}
......@@ -389,6 +390,15 @@ func (dht *IpfsDHT) Close() error {
return dht.proc.Close()
}
func (dht *IpfsDHT) protocolStrs() []string {
pstrs := make([]string, len(dht.protocols))
for _, proto := range dht.protocols {
pstrs = append(pstrs, string(proto))
}
return pstrs
}
func mkDsKey(s string) ds.Key {
return ds.NewKey(base32.RawStdEncoding.EncodeToString([]byte(s)))
}
......@@ -190,7 +190,7 @@ func (ms *messageSender) prep() error {
return nil
}
nstr, err := ms.dht.host.NewStream(ms.dht.ctx, ms.p, ProtocolDHT, ProtocolDHTOld)
nstr, err := ms.dht.host.NewStream(ms.dht.ctx, ms.p, ms.dht.protocols...)
if err != nil {
return err
}
......
......@@ -13,6 +13,7 @@ import (
opts "github.com/libp2p/go-libp2p-kad-dht/opts"
pb "github.com/libp2p/go-libp2p-kad-dht/pb"
"github.com/libp2p/go-libp2p-protocol"
cid "github.com/ipfs/go-cid"
u "github.com/ipfs/go-ipfs-util"
......@@ -1075,3 +1076,39 @@ func TestFindClosestPeers(t *testing.T) {
t.Fatalf("got wrong number of peers (got %d, expected %d)", len(out), KValue)
}
}
func TestGetSetPluggedProtocol(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
os := []opts.Option{
opts.Protocols([]protocol.ID{"/esh/dht"}),
opts.Client(false),
opts.NamespacedValidator("v", blankValidator{}),
}
dhtA, err := New(ctx, bhost.New(netutil.GenSwarmNetwork(t, ctx)), os...)
if err != nil {
t.Fatal(err)
}
dhtB, err := New(ctx, bhost.New(netutil.GenSwarmNetwork(t, ctx)), os...)
if err != nil {
t.Fatal(err)
}
connect(t, ctx, dhtA, dhtB)
if err := dhtA.PutValue(ctx, "/v/cat", []byte("meow")); err != nil {
t.Fatal(err)
}
value, err := dhtB.GetValue(ctx, "/v/cat")
if err != nil {
t.Fatal(err)
}
if string(value) != "meow" {
t.Fatalf("Expected 'meow' got '%s'", string(value))
}
}
......@@ -36,7 +36,7 @@ func TestGetFailures(t *testing.T) {
d.Update(ctx, hosts[1].ID())
// Reply with failures to every message
hosts[1].SetStreamHandler(ProtocolDHT, func(s inet.Stream) {
hosts[1].SetStreamHandler(d.protocols[0], func(s inet.Stream) {
s.Close()
})
......@@ -58,7 +58,7 @@ func TestGetFailures(t *testing.T) {
t.Log("Timeout test passed.")
// Reply with failures to every message
hosts[1].SetStreamHandler(ProtocolDHT, func(s inet.Stream) {
hosts[1].SetStreamHandler(d.protocols[0], func(s inet.Stream) {
defer s.Close()
pbr := ggio.NewDelimitedReader(s, inet.MessageSizeMax)
......@@ -110,7 +110,7 @@ func TestGetFailures(t *testing.T) {
Record: rec,
}
s, err := hosts[1].NewStream(context.Background(), hosts[0].ID(), ProtocolDHT)
s, err := hosts[1].NewStream(context.Background(), hosts[0].ID(), d.protocols[0])
if err != nil {
t.Fatal(err)
}
......@@ -160,7 +160,7 @@ func TestNotFound(t *testing.T) {
// Reply with random peers to every message
for _, host := range hosts {
host := host // shadow loop var
host.SetStreamHandler(ProtocolDHT, func(s inet.Stream) {
host.SetStreamHandler(d.protocols[0], func(s inet.Stream) {
defer s.Close()
pbr := ggio.NewDelimitedReader(s, inet.MessageSizeMax)
......@@ -239,7 +239,7 @@ func TestLessThanKResponses(t *testing.T) {
// Reply with random peers to every message
for _, host := range hosts {
host := host // shadow loop var
host.SetStreamHandler(ProtocolDHT, func(s inet.Stream) {
host.SetStreamHandler(d.protocols[0], func(s inet.Stream) {
defer s.Close()
pbr := ggio.NewDelimitedReader(s, inet.MessageSizeMax)
......@@ -305,7 +305,7 @@ func TestMultipleQueries(t *testing.T) {
// It would be nice to be able to just get a value and succeed but then
// we'd need to deal with selectors and validators...
hosts[1].SetStreamHandler(ProtocolDHT, func(s inet.Stream) {
hosts[1].SetStreamHandler(d.protocols[0], func(s inet.Stream) {
defer s.Close()
pbr := ggio.NewDelimitedReader(s, inet.MessageSizeMax)
......
......@@ -9,8 +9,6 @@ import (
// netNotifiee defines methods to be used with the IpfsDHT
type netNotifiee IpfsDHT
var dhtProtocols = []string{string(ProtocolDHT), string(ProtocolDHTOld)}
func (nn *netNotifiee) DHT() *IpfsDHT {
return (*IpfsDHT)(nn)
}
......@@ -24,7 +22,7 @@ func (nn *netNotifiee) Connected(n inet.Network, v inet.Conn) {
}
p := v.RemotePeer()
protos, err := dht.peerstore.SupportsProtocols(p, dhtProtocols...)
protos, err := dht.peerstore.SupportsProtocols(p, dht.protocolStrs()...)
if err == nil && len(protos) != 0 {
// We lock here for consistency with the lock in testConnection.
// This probably isn't necessary because (dis)connect
......@@ -57,7 +55,7 @@ func (nn *netNotifiee) testConnection(v inet.Conn) {
}
defer s.Close()
selected, err := mstream.SelectOneOf(dhtProtocols, s)
selected, err := mstream.SelectOneOf(dht.protocolStrs(), s)
if err != nil {
// Doesn't support the protocol
return
......
......@@ -5,14 +5,20 @@ import (
ds "github.com/ipfs/go-datastore"
dssync "github.com/ipfs/go-datastore/sync"
"github.com/libp2p/go-libp2p-protocol"
record "github.com/libp2p/go-libp2p-record"
)
var ProtocolDHT protocol.ID = "/ipfs/kad/1.0.0"
var ProtocolDHTOld protocol.ID = "/ipfs/dht"
var DefaultProtocols = []protocol.ID{ProtocolDHT, ProtocolDHTOld}
// Options is a structure containing all the options that can be used when constructing a DHT.
type Options struct {
Datastore ds.Batching
Validator record.Validator
Client bool
Protocols []protocol.ID
}
// Apply applies the given options to this Option
......@@ -35,6 +41,7 @@ var Defaults = func(o *Options) error {
"pk": record.PublicKeyValidator{},
}
o.Datastore = dssync.MutexWrap(ds.NewMapDatastore())
o.Protocols = DefaultProtocols
return nil
}
......@@ -85,3 +92,13 @@ func NamespacedValidator(ns string, v record.Validator) Option {
return nil
}
}
// Protocols sets the protocols for the DHT
//
// Defaults to dht.DefaultProtocols
func Protocols(protocols []protocol.ID) Option {
return func(o *Options) error {
o.Protocols = protocols
return nil
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment