Commit d921f5c9 authored by Steven Allen

add a per-dial transport-level dial timeout

parent e52c417a
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
peer "github.com/libp2p/go-libp2p-peer" peer "github.com/libp2p/go-libp2p-peer"
pstore "github.com/libp2p/go-libp2p-peerstore" pstore "github.com/libp2p/go-libp2p-peerstore"
swarmt "github.com/libp2p/go-libp2p-swarm/testing" swarmt "github.com/libp2p/go-libp2p-swarm/testing"
transport "github.com/libp2p/go-libp2p-transport"
testutil "github.com/libp2p/go-testutil" testutil "github.com/libp2p/go-testutil"
ci "github.com/libp2p/go-testutil/ci" ci "github.com/libp2p/go-testutil/ci"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
...@@ -20,7 +21,7 @@ import ( ...@@ -20,7 +21,7 @@ import (
) )
func init() { func init() {
DialTimeout = time.Second transport.DialTimeout = time.Second
} }
func closeSwarms(swarms []*Swarm) { func closeSwarms(swarms []*Swarm) {
...@@ -190,11 +191,11 @@ func TestDialWait(t *testing.T) { ...@@ -190,11 +191,11 @@ func TestDialWait(t *testing.T) {
} }
duration := time.Since(before) duration := time.Since(before)
if duration < DialTimeout*DialAttempts { if duration < transport.DialTimeout*DialAttempts {
t.Error("< DialTimeout * DialAttempts not being respected", duration, DialTimeout*DialAttempts) t.Error("< transport.DialTimeout * DialAttempts not being respected", duration, transport.DialTimeout*DialAttempts)
} }
if duration > 2*DialTimeout*DialAttempts { if duration > 2*transport.DialTimeout*DialAttempts {
t.Error("> 2*DialTimeout * DialAttempts not being respected", duration, 2*DialTimeout*DialAttempts) t.Error("> 2*transport.DialTimeout * DialAttempts not being respected", duration, 2*transport.DialTimeout*DialAttempts)
} }
if !s1.Backoff().Backoff(s2p) { if !s1.Backoff().Backoff(s2p) {
...@@ -278,8 +279,8 @@ func TestDialBackoff(t *testing.T) { ...@@ -278,8 +279,8 @@ func TestDialBackoff(t *testing.T) {
s3done := dialOfflineNode(s3p, N) s3done := dialOfflineNode(s3p, N)
// when all dials should be done by: // when all dials should be done by:
dialTimeout1x := time.After(DialTimeout) dialTimeout1x := time.After(transport.DialTimeout)
dialTimeout10Ax := time.After(DialTimeout * 2 * 10) // DialAttempts * 10) dialTimeout10Ax := time.After(transport.DialTimeout * 2 * 10) // DialAttempts * 10)
// 2) all dials should hang // 2) all dials should hang
select { select {
...@@ -361,8 +362,8 @@ func TestDialBackoff(t *testing.T) { ...@@ -361,8 +362,8 @@ func TestDialBackoff(t *testing.T) {
s3done := dialOfflineNode(s3p, N) s3done := dialOfflineNode(s3p, N)
// when all dials should be done by: // when all dials should be done by:
dialTimeout1x := time.After(DialTimeout) dialTimeout1x := time.After(transport.DialTimeout)
dialTimeout10Ax := time.After(DialTimeout * 2 * 10) // DialAttempts * 10) dialTimeout10Ax := time.After(transport.DialTimeout * 2 * 10) // DialAttempts * 10)
// 7) s3 dials should all return immediately (except 1) // 7) s3 dials should all return immediately (except 1)
for i := 0; i < N-1; i++ { for i := 0; i < N-1; i++ {
...@@ -441,11 +442,11 @@ func TestDialBackoffClears(t *testing.T) { ...@@ -441,11 +442,11 @@ func TestDialBackoffClears(t *testing.T) {
} }
duration := time.Since(before) duration := time.Since(before)
if duration < DialTimeout*DialAttempts { if duration < transport.DialTimeout*DialAttempts {
t.Error("< DialTimeout * DialAttempts not being respected", duration, DialTimeout*DialAttempts) t.Error("< transport.DialTimeout * DialAttempts not being respected", duration, transport.DialTimeout*DialAttempts)
} }
if duration > 2*DialTimeout*DialAttempts { if duration > 2*transport.DialTimeout*DialAttempts {
t.Error("> 2*DialTimeout * DialAttempts not being respected", duration, 2*DialTimeout*DialAttempts) t.Error("> 2*transport.DialTimeout * DialAttempts not being respected", duration, 2*transport.DialTimeout*DialAttempts)
} }
if !s1.Backoff().Backoff(s2.LocalPeer()) { if !s1.Backoff().Backoff(s2.LocalPeer()) {
......
...@@ -21,11 +21,6 @@ import ( ...@@ -21,11 +21,6 @@ import (
mafilter "github.com/whyrusleeping/multiaddr-filter" mafilter "github.com/whyrusleeping/multiaddr-filter"
) )
// DialTimeout is the maximum duration a Dial is allowed to take.
// This includes the time spent waiting in dial limiter, between dialing the raw
// network connection, protocol selection as well the handshake, if applicable.
var DialTimeout = 60 * time.Second
// DialTimeoutLocal is the maximum duration a Dial to local network address // DialTimeoutLocal is the maximum duration a Dial to local network address
// is allowed to take. // is allowed to take.
// This includes the time between dialing the raw network connection, // This includes the time between dialing the raw network connection,
...@@ -42,6 +37,11 @@ var ErrSwarmClosed = errors.New("swarm closed") ...@@ -42,6 +37,11 @@ var ErrSwarmClosed = errors.New("swarm closed")
// transport is misbehaving. // transport is misbehaving.
var ErrAddrFiltered = errors.New("address filtered") var ErrAddrFiltered = errors.New("address filtered")
// DialTimeout is the maximum duration a Dial is allowed to take.
// This includes the time between dialing the raw network connection,
// protocol selection as well as the handshake, if applicable.
var DialTimeout = 60 * time.Second
// Swarm is a connection muxer, allowing connections to other peers to // Swarm is a connection muxer, allowing connections to other peers to
// be opened and closed, while still using the same Chan for all // be opened and closed, while still using the same Chan for all
// communication. The Chan sends/receives Messages, which note the // communication. The Chan sends/receives Messages, which note the
......
...@@ -221,9 +221,6 @@ func (s *Swarm) doDial(ctx context.Context, p peer.ID) (*Conn, error) { ...@@ -221,9 +221,6 @@ func (s *Swarm) doDial(ctx context.Context, p peer.ID) (*Conn, error) {
return c, nil return c, nil
} }
ctx, cancel := context.WithTimeout(ctx, DialTimeout)
defer cancel()
logdial := lgbl.Dial("swarm", s.LocalPeer(), p, nil, nil) logdial := lgbl.Dial("swarm", s.LocalPeer(), p, nil, nil)
// ok, we have been charged to dial! let's do it. // ok, we have been charged to dial! let's do it.
...@@ -259,6 +256,9 @@ func (s *Swarm) canDial(addr ma.Multiaddr) bool { ...@@ -259,6 +256,9 @@ func (s *Swarm) canDial(addr ma.Multiaddr) bool {
// dial is the actual swarm's dial logic, gated by Dial. // dial is the actual swarm's dial logic, gated by Dial.
func (s *Swarm) dial(ctx context.Context, p peer.ID) (*Conn, error) { func (s *Swarm) dial(ctx context.Context, p peer.ID) (*Conn, error) {
ctx, cancel := context.WithTimeout(ctx, DialTimeout)
defer cancel()
var logdial = lgbl.Dial("swarm", s.LocalPeer(), p, nil, nil) var logdial = lgbl.Dial("swarm", s.LocalPeer(), p, nil, nil)
if p == s.local { if p == s.local {
log.Event(ctx, "swarmDialDoDialSelf", logdial) log.Event(ctx, "swarmDialDoDialSelf", logdial)
...@@ -398,12 +398,15 @@ func (s *Swarm) dialAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr) (tra ...@@ -398,12 +398,15 @@ func (s *Swarm) dialAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr) (tra
} }
log.Debugf("%s swarm dialing %s %s", s.local, p, addr) log.Debugf("%s swarm dialing %s %s", s.local, p, addr)
transport := s.TransportForDialing(addr) tpt := s.TransportForDialing(addr)
if transport == nil { if tpt == nil {
return nil, ErrNoTransport return nil, ErrNoTransport
} }
connC, err := transport.Dial(ctx, addr, p) ctx, cancel := context.WithTimeout(ctx, transport.DialTimeout)
defer cancel()
connC, err := tpt.Dial(ctx, addr, p)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s --> %s dial attempt failed: %s", s.local, p, err) return nil, fmt.Errorf("%s --> %s dial attempt failed: %s", s.local, p, err)
} }
...@@ -411,7 +414,7 @@ func (s *Swarm) dialAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr) (tra ...@@ -411,7 +414,7 @@ func (s *Swarm) dialAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr) (tra
// Trust the transport? Yeah... right. // Trust the transport? Yeah... right.
if connC.RemotePeer() != p { if connC.RemotePeer() != p {
connC.Close() connC.Close()
err = fmt.Errorf("BUG in transport %T: tried to dial %s, dialed %s", p, connC.RemotePeer(), transport) err = fmt.Errorf("BUG in transport %T: tried to dial %s, dialed %s", p, connC.RemotePeer(), tpt)
log.Error(err) log.Error(err)
return nil, err return nil, err
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment