Unverified Commit 28fa9fbb authored by Will Scott's avatar Will Scott

Merge branch 'master' of github.com:libp2p/go-libp2p-kad-dht into feat/dual

parents 85d5de75 796b95bc
...@@ -3,9 +3,13 @@ package dht ...@@ -3,9 +3,13 @@ package dht
import ( import (
"bytes" "bytes"
"net" "net"
"sync"
"time"
"github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peer"
"github.com/google/gopacket/routing"
netroute "github.com/libp2p/go-netroute" netroute "github.com/libp2p/go-netroute"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
...@@ -64,10 +68,42 @@ func PrivateQueryFilter(dht *IpfsDHT, ai peer.AddrInfo) bool { ...@@ -64,10 +68,42 @@ func PrivateQueryFilter(dht *IpfsDHT, ai peer.AddrInfo) bool {
var _ QueryFilterFunc = PrivateQueryFilter var _ QueryFilterFunc = PrivateQueryFilter
// We call this very frequently but routes can technically change at runtime.
// Cache it for two minutes.
const routerCacheTime = 2 * time.Minute
var routerCache struct {
sync.RWMutex
router routing.Router
expires time.Time
}
func getCachedRouter() routing.Router {
routerCache.RLock()
router := routerCache.router
expires := routerCache.expires
routerCache.RUnlock()
if time.Now().Before(expires) {
return router
}
routerCache.Lock()
defer routerCache.Unlock()
now := time.Now()
if now.Before(routerCache.expires) {
return router
}
routerCache.router, _ = netroute.New()
routerCache.expires = now.Add(routerCacheTime)
return router
}
// PrivateRoutingTableFilter allows a peer to be added to the routing table if the connections to that peer indicate // PrivateRoutingTableFilter allows a peer to be added to the routing table if the connections to that peer indicate
// that it is on a private network // that it is on a private network
func PrivateRoutingTableFilter(dht *IpfsDHT, conns []network.Conn) bool { func PrivateRoutingTableFilter(dht *IpfsDHT, conns []network.Conn) bool {
router, _ := netroute.New() router := getCachedRouter()
myAdvertisedIPs := make([]net.IP, 0) myAdvertisedIPs := make([]net.IP, 0)
for _, a := range dht.Host().Addrs() { for _, a := range dht.Host().Addrs() {
if manet.IsPublicAddr(a) && !isRelayAddr(a) { if manet.IsPublicAddr(a) && !isRelayAddr(a) {
......
package dht package dht
import ( import (
"context"
"net"
"testing" "testing"
ic "github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/multiformats/go-multiaddr" "github.com/multiformats/go-multiaddr"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr-net"
) )
func TestIsRelay(t *testing.T) { func TestIsRelay(t *testing.T) {
...@@ -21,3 +28,39 @@ func TestIsRelay(t *testing.T) { ...@@ -21,3 +28,39 @@ func TestIsRelay(t *testing.T) {
} }
} }
type mockConn struct {
local peer.AddrInfo
remote peer.AddrInfo
}
func (m *mockConn) Close() error { return nil }
func (m *mockConn) NewStream() (network.Stream, error) { return nil, nil }
func (m *mockConn) GetStreams() []network.Stream { return []network.Stream{} }
func (m *mockConn) Stat() network.Stat { return network.Stat{Direction: network.DirOutbound} }
func (m *mockConn) LocalMultiaddr() ma.Multiaddr { return m.local.Addrs[0] }
func (m *mockConn) RemoteMultiaddr() ma.Multiaddr { return m.remote.Addrs[0] }
func (m *mockConn) LocalPeer() peer.ID { return m.local.ID }
func (m *mockConn) LocalPrivateKey() ic.PrivKey { return nil }
func (m *mockConn) RemotePeer() peer.ID { return m.remote.ID }
func (m *mockConn) RemotePublicKey() ic.PubKey { return nil }
func TestFilterCaching(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
d := setupDHT(ctx, t, true)
remote, _ := manet.FromIP(net.IPv4(8, 8, 8, 8))
if PrivateRoutingTableFilter(d, []network.Conn{&mockConn{
local: d.Host().Peerstore().PeerInfo(d.Host().ID()),
remote: peer.AddrInfo{ID: "", Addrs: []ma.Multiaddr{remote}},
}}) {
t.Fatal("filter should prevent public remote peers.")
}
r1 := getCachedRouter()
r2 := getCachedRouter()
if r1 != r2 {
t.Fatal("router should be returned multiple times.")
}
}
...@@ -1934,3 +1934,27 @@ func TestInvalidKeys(t *testing.T) { ...@@ -1934,3 +1934,27 @@ func TestInvalidKeys(t *testing.T) {
t.Fatal("expected to have failed") t.Fatal("expected to have failed")
} }
} }
func TestRoutingFilter(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
nDHTs := 2
dhts := setupDHTS(t, ctx, nDHTs)
defer func() {
for i := 0; i < nDHTs; i++ {
dhts[i].Close()
defer dhts[i].host.Close()
}
}()
dhts[0].routingTablePeerFilter = PublicRoutingTableFilter
connectNoSync(t, ctx, dhts[0], dhts[1])
wait(t, ctx, dhts[1], dhts[0])
select {
case <-ctx.Done():
t.Fatal(ctx.Err())
case <-time.After(time.Millisecond * 200):
}
}
...@@ -178,11 +178,11 @@ func handleLocalReachabilityChangedEvent(dht *IpfsDHT, e event.EvtLocalReachabil ...@@ -178,11 +178,11 @@ func handleLocalReachabilityChangedEvent(dht *IpfsDHT, e event.EvtLocalReachabil
// routing table // routing table
func (dht *IpfsDHT) validRTPeer(p peer.ID) (bool, error) { func (dht *IpfsDHT) validRTPeer(p peer.ID) (bool, error) {
protos, err := dht.peerstore.SupportsProtocols(p, protocol.ConvertToStrings(dht.protocols)...) protos, err := dht.peerstore.SupportsProtocols(p, protocol.ConvertToStrings(dht.protocols)...)
if err != nil { if len(protos) == 0 || err != nil {
return false, err return false, err
} }
return len(protos) > 0, nil return dht.routingTablePeerFilter == nil || dht.routingTablePeerFilter(dht, dht.Host().Network().ConnsToPeer(p)), nil
} }
func (nn *subscriberNotifee) Disconnected(n network.Network, v network.Conn) { func (nn *subscriberNotifee) Disconnected(n network.Network, v network.Conn) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment