Commit 9d50a8c4 authored by vyzo's avatar vyzo

implement dial worker for synchronizing simultaneous dials

parent 8930f293
......@@ -5,6 +5,7 @@ import (
"errors"
"sync"
"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
)
......@@ -12,88 +13,74 @@ import (
var errDialCanceled = errors.New("dial was aborted internally, likely due to https://git.io/Je2wW")
// DialFunc is the type of function expected by DialSync.
type DialFunc func(context.Context, peer.ID) (*Conn, error)
type DialWorkerFunc func(context.Context, peer.ID, <-chan dialRequest)
// NewDialSync constructs a new DialSync
func NewDialSync(dfn DialFunc) *DialSync {
func NewDialSync(worker DialWorkerFunc) *DialSync {
return &DialSync{
dials: make(map[peer.ID]*activeDial),
dialFunc: dfn,
dials: make(map[peer.ID]*activeDial),
dialWorker: worker,
}
}
// DialSync is a dial synchronization helper that ensures that at most one dial
// to any given peer is active at any given time.
type DialSync struct {
dials map[peer.ID]*activeDial
dialsLk sync.Mutex
dialFunc DialFunc
dials map[peer.ID]*activeDial
dialsLk sync.Mutex
dialWorker DialWorkerFunc
}
type activeDial struct {
id peer.ID
refCnt int
refCntLk sync.Mutex
cancel func()
id peer.ID
refCnt int
err error
conn *Conn
waitch chan struct{}
ctx context.Context
cancel func()
ds *DialSync
}
reqch chan dialRequest
func (ad *activeDial) wait(ctx context.Context) (*Conn, error) {
defer ad.decref()
select {
case <-ad.waitch:
return ad.conn, ad.err
case <-ctx.Done():
return nil, ctx.Err()
}
ds *DialSync
}
func (ad *activeDial) incref() {
ad.refCntLk.Lock()
defer ad.refCntLk.Unlock()
ad.refCnt++
}
func (ad *activeDial) decref() {
ad.refCntLk.Lock()
ad.ds.dialsLk.Lock()
ad.refCnt--
maybeZero := (ad.refCnt <= 0)
ad.refCntLk.Unlock()
// make sure to always take locks in correct order.
if maybeZero {
ad.ds.dialsLk.Lock()
ad.refCntLk.Lock()
// check again after lock swap drop to make sure nobody else called incref
// in between locks
if ad.refCnt <= 0 {
ad.cancel()
delete(ad.ds.dials, ad.id)
}
ad.refCntLk.Unlock()
ad.ds.dialsLk.Unlock()
if ad.refCnt == 0 {
ad.cancel()
close(ad.reqch)
delete(ad.ds.dials, ad.id)
}
ad.ds.dialsLk.Unlock()
}
func (ad *activeDial) start(ctx context.Context) {
ad.conn, ad.err = ad.ds.dialFunc(ctx, ad.id)
// This isn't the user's context so we should fix the error.
switch ad.err {
case context.Canceled:
// The dial was canceled with `CancelDial`.
ad.err = errDialCanceled
case context.DeadlineExceeded:
// We hit an internal timeout, not a context timeout.
ad.err = ErrDialTimeout
func (ad *activeDial) dial(ctx context.Context, p peer.ID) (*Conn, error) {
dialCtx := ad.ctx
if forceDirect, reason := network.GetForceDirectDial(ctx); forceDirect {
dialCtx = network.WithForceDirectDial(dialCtx, reason)
}
if simConnect, reason := network.GetSimultaneousConnect(ctx); simConnect {
dialCtx = network.WithSimultaneousConnect(dialCtx, reason)
}
resch := make(chan dialResponse, 1)
select {
case ad.reqch <- dialRequest{ctx: dialCtx, resch: resch}:
case <-ctx.Done():
return nil, ctx.Err()
}
select {
case res := <-resch:
return res.conn, res.err
case <-ctx.Done():
return nil, ctx.Err()
}
close(ad.waitch)
ad.cancel()
}
func (ds *DialSync) getActiveDial(p peer.ID) *activeDial {
......@@ -109,13 +96,14 @@ func (ds *DialSync) getActiveDial(p peer.ID) *activeDial {
adctx, cancel := context.WithCancel(context.Background())
actd = &activeDial{
id: p,
ctx: adctx,
cancel: cancel,
waitch: make(chan struct{}),
reqch: make(chan dialRequest),
ds: ds,
}
ds.dials[p] = actd
go actd.start(adctx)
go ds.dialWorker(adctx, p, actd.reqch)
}
// increase ref count before dropping dialsLk
......@@ -127,14 +115,8 @@ func (ds *DialSync) getActiveDial(p peer.ID) *activeDial {
// DialLock initiates a dial to the given peer if there are none in progress
// then waits for the dial to that peer to complete.
func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) {
return ds.getActiveDial(p).wait(ctx)
}
ad := ds.getActiveDial(p)
defer ad.decref()
// CancelDial cancels all in-progress dials to the given peer.
func (ds *DialSync) CancelDial(p peer.ID) {
ds.dialsLk.Lock()
defer ds.dialsLk.Unlock()
if ad, ok := ds.dials[p]; ok {
ad.cancel()
}
return ad.dial(ctx, p)
}
......@@ -121,7 +121,7 @@ func NewSwarm(ctx context.Context, local peer.ID, peers peerstore.Peerstore, bwc
}
}
s.dsync = NewDialSync(s.doDial)
s.dsync = NewDialSync(s.dialWorker)
s.limiter = newDialLimiter(s.dialAddr, s.IsFdConsumingAddr)
s.proc = goprocessctx.WithContext(ctx)
s.ctx = goprocessctx.OnClosingContext(s.proc)
......@@ -259,12 +259,6 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn,
c.notifyLk.Lock()
s.conns.Unlock()
// We have a connection now. Cancel all other in-progress dials.
// This should be fast, no reason to wait till later.
if dir == network.DirOutbound {
s.dsync.CancelDial(p)
}
s.notifyAll(func(f network.Notifiee) {
f.Connected(s, c)
})
......
......@@ -14,7 +14,6 @@ import (
addrutil "github.com/libp2p/go-addr-util"
lgbl "github.com/libp2p/go-libp2p-loggables"
logging "github.com/ipfs/go-log"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)
......@@ -58,6 +57,12 @@ var (
ErrGaterDisallowedConnection = errors.New("gater disallows connection to peer")
)
var (
DelayDialPrivateAddr = 5 * time.Millisecond
DelayDialPublicAddr = 50 * time.Millisecond
DelayDialRelayAddr = 100 * time.Millisecond
)
// DialAttempts governs how many times a goroutine will try to dial a given peer.
// Note: this is down to one, as we have _too many dials_ atm. To add back in,
// add loop back in Dial(.)
......@@ -281,39 +286,306 @@ func (s *Swarm) dialPeer(ctx context.Context, p peer.ID) (*Conn, error) {
return nil, err
}
// doDial is an ugly shim method to retain all the logging and backoff logic
// of the old dialsync code
func (s *Swarm) doDial(ctx context.Context, p peer.ID) (*Conn, error) {
// Short circuit.
// By the time we take the dial lock, we may already *have* a connection
// to the peer.
c := s.bestAcceptableConnToPeer(ctx, p)
if c != nil {
return c, nil
///////////////////////////////////////////////////////////////////////////////////
// lo and behold, The Dialer
// TODO explain how all this works
//////////////////////////////////////////////////////////////////////////////////
type dialRequest struct {
ctx context.Context
resch chan dialResponse
}
type dialResponse struct {
conn *Conn
err error
}
type dialComplete struct {
addr ma.Multiaddr
conn *Conn
err error
}
// dialWorker is an active dial goroutine that synchronizes and executes concurrent dials
func (s *Swarm) dialWorker(ctx context.Context, p peer.ID, reqch <-chan dialRequest) {
if p == s.local {
for {
select {
case req, ok := <-reqch:
if !ok {
return
}
req.resch <- dialResponse{err: ErrDialToSelf}
}
}
}
logdial := lgbl.Dial("swarm", s.LocalPeer(), p, nil, nil)
s.dialWorkerLoop(ctx, p, reqch)
}
// ok, we have been charged to dial! let's do it.
// if it succeeds, dial will add the conn to the swarm itself.
defer log.EventBegin(ctx, "swarmDialAttemptStart", logdial).Done()
func (s *Swarm) dialWorkerLoop(ctx context.Context, p peer.ID, reqch <-chan dialRequest) {
defer s.limiter.clearAllPeerDials(p)
conn, err := s.dial(ctx, p)
if err != nil {
conn := s.bestAcceptableConnToPeer(ctx, p)
if conn != nil {
// Hm? What error?
// Could have canceled the dial because we received a
// connection or some other random reason.
// Just ignore the error and return the connection.
log.Debugf("ignoring dial error because we already have a connection: %s", err)
return conn, nil
type pendRequest struct {
req dialRequest // the original request
err *DialError // dial error accumulator
addrs map[ma.Multiaddr]struct{} // pending addr dials
}
type addrDial struct {
ctx context.Context
conn *Conn
err error
requests []int
}
reqno := 0
requests := make(map[int]*pendRequest)
pending := make(map[ma.Multiaddr]*addrDial)
var triggerDial <-chan time.Time
var nextDial []ma.Multiaddr
active := 0
done := false
resch := make(chan dialComplete)
loop:
for {
select {
case req, ok := <-reqch:
if !ok {
// request channel has been closed, wait for pending dials to complete
if active > 0 {
done = true
reqch = nil
triggerDial = nil
continue loop
}
// no active dials, we are done
return
}
c := s.bestAcceptableConnToPeer(req.ctx, p)
if c != nil {
req.resch <- dialResponse{conn: c}
continue loop
}
addrs, err := s.addrsForDial(req.ctx, p)
if err != nil {
req.resch <- dialResponse{err: err}
continue loop
}
// at this point, len(addrs) > 0 or else it would be error from addrsForDial
// ranke them to process in order
addrs = s.rankAddrs(addrs)
// create the pending request object
pr := &pendRequest{
req: req,
err: &DialError{Peer: p},
addrs: make(map[ma.Multiaddr]struct{}),
}
for _, a := range addrs {
pr.addrs[a] = struct{}{}
}
// check if any of the addrs has been successfully dialed and accumulate
// errors from complete dials while collecting new addrs to dial/join
var todial []ma.Multiaddr
var tojoin []*addrDial
for _, a := range addrs {
ad, ok := pending[a]
if !ok {
todial = append(todial, a)
continue
}
if ad.conn != nil {
// dial to this addr was successful, complete the request
req.resch <- dialResponse{conn: ad.conn}
continue loop
}
if ad.err != nil {
// dial to this addr errored, accumulate the error
pr.err.recordErr(a, ad.err)
delete(pr.addrs, a)
}
// dial is still pending, add to the join list
tojoin = append(tojoin, ad)
}
if len(todial) == 0 && len(tojoin) == 0 {
// all request applicable addrs have been dialed, we must have errored
req.resch <- dialResponse{err: pr.err}
continue loop
}
// the request has some pending or new dials, track it and schedule new dials
reqno++
requests[reqno] = pr
for _, ad := range tojoin {
ad.requests = append(ad.requests, reqno)
}
if len(todial) > 0 {
for _, a := range todial {
pending[a] = &addrDial{ctx: req.ctx, requests: []int{reqno}}
}
nextDial = append(nextDial, todial...)
nextDial = s.rankAddrs(nextDial)
if triggerDial == nil {
trigger := make(chan time.Time)
close(trigger)
triggerDial = trigger
}
}
case <-triggerDial:
if len(nextDial) == 0 {
triggerDial = nil
continue loop
}
next := nextDial[0]
nextDial = nextDial[1:]
// spawn the next dial
ad := pending[next]
go s.dialNextAddr(ad.ctx, p, next, resch)
active++
// select an appropriate delay for the next dial trigger
delay := s.delayForNextDial(next)
triggerDial = time.After(delay)
case res := <-resch:
active--
if done && active == 0 {
return
}
ad := pending[res.addr]
ad.conn = res.conn
ad.err = res.err
dialRequests := ad.requests
ad.requests = nil
if res.conn != nil {
// we got a connection, dispatch to still pending requests
for _, reqno := range dialRequests {
pr, ok := requests[reqno]
if !ok {
// it has already dispatched a connection
continue
}
pr.req.resch <- dialResponse{conn: res.conn}
delete(requests, reqno)
}
continue loop
}
// it must be an error, accumulate it and dispatch dial error if the request has tried all addrs
for _, reqno := range dialRequests {
pr, ok := requests[reqno]
if !ok {
// has already been dispatched
continue
}
// accumulate the error
pr.err.recordErr(res.addr, res.err)
delete(pr.addrs, res.addr)
if len(pr.addrs) == 0 {
// all addrs have erred, dispatch dial error
pr.req.resch <- dialResponse{err: pr.err}
delete(requests, reqno)
}
}
}
}
}
// ok, we failed.
return nil, err
func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) ([]ma.Multiaddr, error) {
peerAddrs := s.peers.Addrs(p)
if len(peerAddrs) == 0 {
return nil, ErrNoAddresses
}
goodAddrs := s.filterKnownUndialables(p, peerAddrs)
if forceDirect, _ := network.GetForceDirectDial(ctx); forceDirect {
goodAddrs = addrutil.FilterAddrs(goodAddrs, s.nonProxyAddr)
}
return conn, nil
if len(goodAddrs) == 0 {
return nil, ErrNoGoodAddresses
}
return goodAddrs, nil
}
func (s *Swarm) dialNextAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr, resch chan dialComplete) {
// check the dial backoff
if forceDirect, _ := network.GetForceDirectDial(ctx); !forceDirect {
if s.backf.Backoff(p, addr) {
resch <- dialComplete{addr: addr, err: ErrDialBackoff}
return
}
}
// start the dial
dresch := make(chan dialResult)
s.limitedDial(ctx, p, addr, dresch)
select {
case res := <-dresch:
if res.Err != nil {
if res.Err != context.Canceled {
s.backf.AddBackoff(p, addr)
}
resch <- dialComplete{addr: addr, err: res.Err}
return
}
conn, err := s.addConn(res.Conn, network.DirOutbound)
if err != nil {
res.Conn.Close()
resch <- dialComplete{addr: addr, err: err}
return
}
resch <- dialComplete{addr: addr, conn: conn}
case <-ctx.Done():
resch <- dialComplete{addr: addr, err: ctx.Err()}
}
}
func (s *Swarm) delayForNextDial(addr ma.Multiaddr) time.Duration {
if _, err := addr.ValueForProtocol(ma.P_CIRCUIT); err == nil {
return DelayDialRelayAddr
}
if manet.IsPrivateAddr(addr) {
return DelayDialPrivateAddr
}
return DelayDialPublicAddr
}
func (s *Swarm) canDial(addr ma.Multiaddr) bool {
......@@ -365,80 +637,6 @@ func (s *Swarm) rankAddrs(addrs []ma.Multiaddr) []ma.Multiaddr {
return append(append(append(localUdpAddrs, othersUdp...), fds...), relays...)
}
// dial is the actual swarm's dial logic, gated by Dial.
func (s *Swarm) dial(ctx context.Context, p peer.ID) (*Conn, error) {
forceDirect, _ := network.GetForceDirectDial(ctx)
var logdial = lgbl.Dial("swarm", s.LocalPeer(), p, nil, nil)
if p == s.local {
log.Event(ctx, "swarmDialDoDialSelf", logdial)
return nil, ErrDialToSelf
}
defer log.EventBegin(ctx, "swarmDialDo", logdial).Done()
logdial["dial"] = "failure" // start off with failure. set to "success" at the end.
sk := s.peers.PrivKey(s.local)
logdial["encrypted"] = sk != nil // log whether this will be an encrypted dial or not.
if sk == nil {
// fine for sk to be nil, just log.
log.Debug("Dial not given PrivateKey, so WILL NOT SECURE conn.")
}
//////
peerAddrs := s.peers.Addrs(p)
if len(peerAddrs) == 0 {
return nil, &DialError{Peer: p, Cause: ErrNoAddresses}
}
goodAddrs := s.filterKnownUndialables(p, peerAddrs)
if forceDirect {
goodAddrs = addrutil.FilterAddrs(goodAddrs, s.nonProxyAddr)
}
if len(goodAddrs) == 0 {
return nil, &DialError{Peer: p, Cause: ErrNoGoodAddresses}
}
if !forceDirect {
/////// Check backoff andnRank addresses
var nonBackoff bool
for _, a := range goodAddrs {
// skip addresses in back-off
if !s.backf.Backoff(p, a) {
nonBackoff = true
}
}
if !nonBackoff {
return nil, ErrDialBackoff
}
}
connC, dialErr := s.dialAddrs(ctx, p, s.rankAddrs(goodAddrs))
if dialErr != nil {
logdial["error"] = dialErr.Cause.Error()
switch dialErr.Cause {
case context.Canceled, context.DeadlineExceeded:
// Always prefer the context errors as we rely on being
// able to check them.
//
// Removing this will BREAK backoff (causing us to
// backoff when canceling dials).
return nil, dialErr.Cause
}
return nil, dialErr
}
logdial["conn"] = logging.Metadata{
"localAddr": connC.LocalMultiaddr(),
"remoteAddr": connC.RemoteMultiaddr(),
}
swarmC, err := s.addConn(connC, network.DirOutbound)
if err != nil {
logdial["error"] = err.Error()
connC.Close() // close the connection. didn't work out :(
return nil, &DialError{Peer: p, Cause: err}
}
logdial["dial"] = "success"
return swarmC, nil
}
// filterKnownUndialables takes a list of multiaddrs, and removes those
// that we definitely don't want to dial: addresses configured to be blocked,
// IPv6 link-local addresses, addresses without a dial-capable transport,
......@@ -466,98 +664,6 @@ func (s *Swarm) filterKnownUndialables(p peer.ID, addrs []ma.Multiaddr) []ma.Mul
)
}
func (s *Swarm) dialAddrs(ctx context.Context, p peer.ID, remoteAddrs []ma.Multiaddr) (transport.CapableConn, *DialError) {
/*
This slice-to-chan code is temporary, the peerstore can currently provide
a channel as an interface for receiving addresses, but more thought
needs to be put into the execution. For now, this allows us to use
the improved rate limiter, while maintaining the outward behaviour
that we previously had (halting a dial when we run out of addrs)
*/
var remoteAddrChan chan ma.Multiaddr
if len(remoteAddrs) > 0 {
remoteAddrChan = make(chan ma.Multiaddr, len(remoteAddrs))
for i := range remoteAddrs {
remoteAddrChan <- remoteAddrs[i]
}
close(remoteAddrChan)
}
log.Debugf("%s swarm dialing %s", s.local, p)
ctx, cancel := context.WithCancel(ctx)
defer cancel() // cancel work when we exit func
// use a single response type instead of errs and conns, reduces complexity *a ton*
respch := make(chan dialResult)
err := &DialError{Peer: p}
defer s.limiter.clearAllPeerDials(p)
var active int
dialLoop:
for remoteAddrChan != nil || active > 0 {
// Check for context cancellations and/or responses first.
select {
case <-ctx.Done():
break dialLoop
case resp := <-respch:
active--
if resp.Err != nil {
// Errors are normal, lots of dials will fail
if resp.Err != context.Canceled {
s.backf.AddBackoff(p, resp.Addr)
}
log.Infof("got error on dial: %s", resp.Err)
err.recordErr(resp.Addr, resp.Err)
} else if resp.Conn != nil {
return resp.Conn, nil
}
// We got a result, try again from the top.
continue
default:
}
// Now, attempt to dial.
select {
case addr, ok := <-remoteAddrChan:
if !ok {
remoteAddrChan = nil
continue
}
s.limitedDial(ctx, p, addr, respch)
active++
case <-ctx.Done():
break dialLoop
case resp := <-respch:
active--
if resp.Err != nil {
// Errors are normal, lots of dials will fail
if resp.Err != context.Canceled {
s.backf.AddBackoff(p, resp.Addr)
}
log.Infof("got error on dial: %s", resp.Err)
err.recordErr(resp.Addr, resp.Err)
} else if resp.Conn != nil {
return resp.Conn, nil
}
}
}
if ctxErr := ctx.Err(); ctxErr != nil {
err.Cause = ctxErr
} else if len(err.DialErrors) == 0 {
err.Cause = network.ErrNoRemoteAddrs
} else {
err.Cause = ErrAllDialsFailed
}
return nil, err
}
// limitedDial will start a dial to the given peer when
// it is able, respecting the various different types of rate
// limiting that occur without using extra goroutines per addr
......@@ -570,6 +676,7 @@ func (s *Swarm) limitedDial(ctx context.Context, p peer.ID, a ma.Multiaddr, resp
})
}
// dialAddr is the actual dial for an addr, indirectly invoked through the limiter
func (s *Swarm) dialAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr) (transport.CapableConn, error) {
// Just to double check. Costs nothing.
if s.local == p {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment