From 88c64b7274ecdcb473394ebb3b3493426cb0fdc0 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 8 Feb 2021 17:06:32 +0800 Subject: [PATCH] prevent dialing addresses that we're listening on It's impossible to run two nodes on the same IP:port, so we know for sure that any dial to an address that we're listening on will fail (as the peer IDs won't match). In practice, this will be most useful for preventing dials to localhost for nodes that are listening on the default port. --- swarm_dial.go | 2 +- swarm_test.go | 31 +++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/swarm_dial.go b/swarm_dial.go index 4dbb1f2..30a9ab0 100644 --- a/swarm_dial.go +++ b/swarm_dial.go @@ -440,7 +440,7 @@ func (s *Swarm) filterKnownUndialables(p peer.ID, addrs []ma.Multiaddr) []ma.Mul for _, addr := range lisAddrs { protos := addr.Protocols() // we're only sure about filtering out /ip4 and /ip6 addresses, so far - if len(protos) == 2 && (protos[0].Code == ma.P_IP4 || protos[0].Code == ma.P_IP6) { + if protos[0].Code == ma.P_IP4 || protos[0].Code == ma.P_IP6 { ourAddrs = append(ourAddrs, addr) } } diff --git a/swarm_test.go b/swarm_test.go index bd2c844..9b1e9c4 100644 --- a/swarm_test.go +++ b/swarm_test.go @@ -3,8 +3,10 @@ package swarm_test import ( "bytes" "context" + "errors" "fmt" "io" + "strings" "sync" "testing" "time" @@ -20,6 +22,7 @@ import ( logging "github.com/ipfs/go-log" ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" "github.com/stretchr/testify/require" ) @@ -457,3 +460,31 @@ func TestCloseWithOpenStreams(t *testing.T) { t.Fatal(err) } } + +func TestPreventDialListenAddr(t *testing.T) { + s := GenSwarm(t, context.Background(), OptDialOnly) + if err := s.Listen(ma.StringCast("/ip4/0.0.0.0/udp/0/quic")); err != nil { + t.Fatal(err) + } + addrs, err := s.InterfaceListenAddresses() + if err != nil { + t.Fatal(err) + } + var addr ma.Multiaddr + for _, a := range addrs { + _, s, err := manet.DialArgs(a) + if err != nil { + t.Fatal(err) + } + if strings.Split(s, ":")[0] == "127.0.0.1" { + addr = a + break + } + } + remote := peer.ID("foobar") + s.Peerstore().AddAddr(remote, addr, time.Hour) + _, err = s.DialPeer(context.Background(), remote) + if !errors.Is(err, ErrNoGoodAddresses) { + t.Fatal("expected dial to fail: %w", err) + } +} -- GitLab