cleaner parallelism

parent 02d310fd
......@@ -4,7 +4,6 @@ package dual
import (
"context"
"fmt"
"sync"
"github.com/ipfs/go-cid"
......@@ -15,6 +14,8 @@ import (
"github.com/libp2p/go-libp2p-core/routing"
dht "github.com/libp2p/go-libp2p-kad-dht"
helper "github.com/libp2p/go-libp2p-routing-helpers"
"github.com/hashicorp/go-multierror"
)
// DHT implements the routing interface to provide two concrete DHT implementationts for use
......@@ -24,8 +25,8 @@ type DHT struct {
LAN *dht.IpfsDHT
}
// DefaultLanExtension is used to differentiate local protocol requests from those on the WAN DHT.
const DefaultLanExtension protocol.ID = "/lan"
// LanExtension is used to differentiate local protocol requests from those on the WAN DHT.
const LanExtension protocol.ID = "/lan"
// Assert that IPFS assumptions about interfaces aren't broken. These aren't a
// guarantee, but we can use them to aid refactoring.
......@@ -55,7 +56,7 @@ func New(ctx context.Context, h host.Host, options ...dht.Option) (*DHT, error)
// Unless overridden by user supplied options, the LAN DHT should default
// to 'AutoServer' mode.
lanOpts := append(options,
dht.ProtocolExtension(DefaultLanExtension),
dht.ProtocolExtension(LanExtension),
dht.QueryFilter(dht.PrivateQueryFilter),
dht.RoutingTableFilter(dht.PrivateRoutingTableFilter),
)
......@@ -73,7 +74,7 @@ func New(ctx context.Context, h host.Host, options ...dht.Option) (*DHT, error)
// Close closes the DHT context.
func (dht *DHT) Close() error {
return mergeErrors(dht.WAN.Close(), dht.LAN.Close())
return multierror.Append(dht.WAN.Close(), dht.LAN.Close()).ErrorOrNil()
}
func (dht *DHT) activeWAN() bool {
......@@ -99,21 +100,18 @@ func (dht *DHT) FindProvidersAsync(ctx context.Context, key cid.Cid, count int)
defer close(outCh)
found := make(map[peer.ID]struct{}, count)
nch := 2
var pi peer.AddrInfo
for nch > 0 && count > 0 {
for count > 0 && (wanCh != nil || lanCh != nil) {
var ok bool
select {
case pi, ok = <-wanCh:
if !ok {
wanCh = nil
nch--
continue
}
case pi, ok = <-lanCh:
if !ok {
lanCh = nil
nch--
continue
}
}
......@@ -155,18 +153,7 @@ func (dht *DHT) FindPeer(ctx context.Context, pid peer.ID) (peer.AddrInfo, error
return peer.AddrInfo{
ID: pid,
Addrs: append(wanInfo.Addrs, lanInfo.Addrs...),
}, mergeErrors(wanErr, lanErr)
}
func mergeErrors(a, b error) error {
if a == nil && b == nil {
return nil
} else if a != nil && b != nil {
return fmt.Errorf("%v, %v", a, b)
} else if a != nil {
return a
}
return b
}, multierror.Append(wanErr, lanErr).ErrorOrNil()
}
// Bootstrap allows callers to hint to the routing system to get into a
......@@ -174,7 +161,7 @@ func mergeErrors(a, b error) error {
func (dht *DHT) Bootstrap(ctx context.Context) error {
erra := dht.WAN.Bootstrap(ctx)
errb := dht.LAN.Bootstrap(ctx)
return mergeErrors(erra, errb)
return multierror.Append(erra, errb).ErrorOrNil()
}
// PutValue adds value corresponding to given Key.
......@@ -190,43 +177,24 @@ func (d *DHT) GetValue(ctx context.Context, key string, opts ...routing.Option)
reqCtx, cncl := context.WithCancel(ctx)
defer cncl()
resChan := make(chan []byte)
defer close(resChan)
errChan := make(chan error)
defer close(errChan)
runner := func(impl *dht.IpfsDHT, valCh chan []byte, errCh chan error) {
val, err := impl.GetValue(reqCtx, key, opts...)
if err != nil {
errCh <- err
return
}
valCh <- val
}
go runner(d.WAN, resChan, errChan)
go runner(d.LAN, resChan, errChan)
var err error
var val []byte
select {
case val = <-resChan:
cncl()
case err = <-errChan:
}
var lanVal []byte
var lanErr error
var lanWaiter sync.WaitGroup
lanWaiter.Add(1)
go func() {
defer lanWaiter.Done()
lanVal, lanErr = d.LAN.GetValue(reqCtx, key, opts...)
}()
// Drain or wait for the slower runner
select {
case secondVal := <-resChan:
if val == nil {
val = secondVal
}
case secondErr := <-errChan:
if err != nil {
err = mergeErrors(err, secondErr)
} else if val == nil {
err = secondErr
wanVal, wanErr := d.WAN.GetValue(ctx, key, opts...)
if wanErr != nil {
lanWaiter.Wait()
if lanErr != nil {
return nil, multierror.Append(wanErr, lanErr).ErrorOrNil()
}
return lanVal, nil
}
return val, err
return wanVal, nil
}
// SearchValue searches for better values from this value
......@@ -236,45 +204,7 @@ func (dht *DHT) SearchValue(ctx context.Context, key string, opts ...routing.Opt
}
// GetPublicKey returns the public key for the given peer.
func (d *DHT) GetPublicKey(ctx context.Context, pid peer.ID) (ci.PubKey, error) {
reqCtx, cncl := context.WithCancel(ctx)
defer cncl()
resChan := make(chan ci.PubKey)
defer close(resChan)
errChan := make(chan error)
defer close(errChan)
runner := func(impl *dht.IpfsDHT, valCh chan ci.PubKey, errCh chan error) {
val, err := impl.GetPublicKey(reqCtx, pid)
if err != nil {
errCh <- err
return
}
valCh <- val
}
go runner(d.WAN, resChan, errChan)
go runner(d.LAN, resChan, errChan)
var err error
var val ci.PubKey
select {
case val = <-resChan:
cncl()
case err = <-errChan:
}
// Drain or wait for the slower runner
select {
case secondVal := <-resChan:
if val == nil {
val = secondVal
}
case secondErr := <-errChan:
if err != nil {
err = mergeErrors(err, secondErr)
} else if val == nil {
err = secondErr
}
}
return val, err
func (dht *DHT) GetPublicKey(ctx context.Context, pid peer.ID) (ci.PubKey, error) {
p := helper.Parallel{Routers: []routing.Routing{dht.WAN, dht.LAN}, Validator: dht.WAN.Validator}
return p.GetPublicKey(ctx, pid)
}
......@@ -65,7 +65,7 @@ func setupDHTWithFilters(ctx context.Context, t *testing.T, options ...dht.Optio
lanOpts := []dht.Option{
dht.NamespacedValidator("v", blankValidator{}),
dht.ProtocolPrefix("/test"),
dht.ProtocolExtension(DefaultLanExtension),
dht.ProtocolExtension(LanExtension),
dht.DisableAutoRefresh(),
dht.RoutingTableFilter(lanFilter),
dht.Mode(dht.ModeServer),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment