diff --git a/dial_sync.go b/dial_sync.go
index 50f3a69821e443b2cfa3469c36f4881a41fa00cd..2efdf067b60c62bc4cc5a8fb41c1b76695aac4cd 100644
--- a/dial_sync.go
+++ b/dial_sync.go
@@ -5,6 +5,7 @@ import (
 	"errors"
 	"sync"
 
+	"github.com/libp2p/go-libp2p-core/network"
 	"github.com/libp2p/go-libp2p-core/peer"
 )
 
@@ -12,88 +13,74 @@ import (
 var errDialCanceled = errors.New("dial was aborted internally, likely due to https://git.io/Je2wW")
 
-// DialFunc is the type of function expected by DialSync.
-type DialFunc func(context.Context, peer.ID) (*Conn, error)
+// DialWorkerFunc is the type of function expected by DialSync; it runs the
+// dial worker for a peer, serving dial requests received on the given channel.
+type DialWorkerFunc func(context.Context, peer.ID, <-chan dialRequest)
 
 // NewDialSync constructs a new DialSync
-func NewDialSync(dfn DialFunc) *DialSync {
+func NewDialSync(worker DialWorkerFunc) *DialSync {
 	return &DialSync{
-		dials:    make(map[peer.ID]*activeDial),
-		dialFunc: dfn,
+		dials:      make(map[peer.ID]*activeDial),
+		dialWorker: worker,
 	}
 }
 
 // DialSync is a dial synchronization helper that ensures that at most one dial
 // to any given peer is active at any given time.
 type DialSync struct {
-	dials    map[peer.ID]*activeDial
-	dialsLk  sync.Mutex
-	dialFunc DialFunc
+	dials      map[peer.ID]*activeDial
+	dialsLk    sync.Mutex
+	dialWorker DialWorkerFunc
 }
 
 type activeDial struct {
-	id       peer.ID
-	refCnt   int
-	refCntLk sync.Mutex
-	cancel   func()
+	id     peer.ID
+	refCnt int
 
-	err    error
-	conn   *Conn
-	waitch chan struct{}
+	ctx    context.Context
+	cancel func()
 
-	ds *DialSync
-}
+	reqch chan dialRequest
 
-func (ad *activeDial) wait(ctx context.Context) (*Conn, error) {
-	defer ad.decref()
-	select {
-	case <-ad.waitch:
-		return ad.conn, ad.err
-	case <-ctx.Done():
-		return nil, ctx.Err()
-	}
+	ds *DialSync
 }
 
 func (ad *activeDial) incref() {
-	ad.refCntLk.Lock()
-	defer ad.refCntLk.Unlock()
 	ad.refCnt++
 }
 
 func (ad *activeDial) decref() {
-	ad.refCntLk.Lock()
+	ad.ds.dialsLk.Lock()
 	ad.refCnt--
-	maybeZero := (ad.refCnt <= 0)
-	ad.refCntLk.Unlock()
-
-	// make sure to always take locks in correct order.
-	if maybeZero {
-		ad.ds.dialsLk.Lock()
-		ad.refCntLk.Lock()
-		// check again after lock swap drop to make sure nobody else called incref
-		// in between locks
-		if ad.refCnt <= 0 {
-			ad.cancel()
-			delete(ad.ds.dials, ad.id)
-		}
-		ad.refCntLk.Unlock()
-		ad.ds.dialsLk.Unlock()
+	if ad.refCnt == 0 {
+		ad.cancel()
+		close(ad.reqch)
+		delete(ad.ds.dials, ad.id)
 	}
+	ad.ds.dialsLk.Unlock()
}
 
-func (ad *activeDial) start(ctx context.Context) {
-	ad.conn, ad.err = ad.ds.dialFunc(ctx, ad.id)
-
-	// This isn't the user's context so we should fix the error.
-	switch ad.err {
-	case context.Canceled:
-		// The dial was canceled with `CancelDial`.
-		ad.err = errDialCanceled
-	case context.DeadlineExceeded:
-		// We hit an internal timeout, not a context timeout.
-	ad.err = ErrDialTimeout
+// dial sends a dial request to the peer's dial worker, forwarding the relevant
+// flags from the caller's context, and waits for the response.
+func (ad *activeDial) dial(ctx context.Context, p peer.ID) (*Conn, error) {
+	dialCtx := ad.ctx
+
+	if forceDirect, reason := network.GetForceDirectDial(ctx); forceDirect {
+		dialCtx = network.WithForceDirectDial(dialCtx, reason)
+	}
+	if simConnect, reason := network.GetSimultaneousConnect(ctx); simConnect {
+		dialCtx = network.WithSimultaneousConnect(dialCtx, reason)
+	}
+
+	resch := make(chan dialResponse, 1)
+	select {
+	case ad.reqch <- dialRequest{ctx: dialCtx, resch: resch}:
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	}
+
+	select {
+	case res := <-resch:
+		return res.conn, res.err
+	case <-ctx.Done():
+		return nil, ctx.Err()
 	}
-	close(ad.waitch)
-	ad.cancel()
 }
 
 func (ds *DialSync) getActiveDial(p peer.ID) *activeDial {
@@ -109,13 +96,14 @@ func (ds *DialSync) getActiveDial(p peer.ID) *activeDial {
 	adctx, cancel := context.WithCancel(context.Background())
 	actd = &activeDial{
 		id:     p,
+		ctx:    adctx,
 		cancel: cancel,
-		waitch: make(chan struct{}),
+		reqch:  make(chan dialRequest),
 		ds:     ds,
 	}
 	ds.dials[p] = actd
 
-	go actd.start(adctx)
+	go ds.dialWorker(adctx, p, actd.reqch)
 	}
 
 	// increase ref count before dropping dialsLk
@@ -127,14 +115,8 @@ func (ds *DialSync) getActiveDial(p peer.ID) *activeDial {
 // DialLock initiates a dial to the given peer if there are none in progress
 // then waits for the dial to that peer to complete.
 func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) {
-	return ds.getActiveDial(p).wait(ctx)
-}
+	ad := ds.getActiveDial(p)
+	defer ad.decref()
 
-// CancelDial cancels all in-progress dials to the given peer.
-func (ds *DialSync) CancelDial(p peer.ID) {
-	ds.dialsLk.Lock()
-	defer ds.dialsLk.Unlock()
-	if ad, ok := ds.dials[p]; ok {
-		ad.cancel()
-	}
+	return ad.dial(ctx, p)
 }
diff --git a/swarm.go b/swarm.go
index 3b3cf832280256fb865266218633ed90c04f0c94..c57c563c251b6f4ac9cfef0836e4cdb59c52b2e1 100644
--- a/swarm.go
+++ b/swarm.go
@@ -121,7 +121,7 @@ func NewSwarm(ctx context.Context, local peer.ID, peers peerstore.Peerstore, bwc
 		}
 	}
 
-	s.dsync = NewDialSync(s.doDial)
+	s.dsync = NewDialSync(s.dialWorker)
 	s.limiter = newDialLimiter(s.dialAddr, s.IsFdConsumingAddr)
 	s.proc = goprocessctx.WithContext(ctx)
 	s.ctx = goprocessctx.OnClosingContext(s.proc)
@@ -259,12 +259,6 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn,
 	c.notifyLk.Lock()
 	s.conns.Unlock()
 
-	// We have a connection now. Cancel all other in-progress dials.
-	// This should be fast, no reason to wait till later.
-	if dir == network.DirOutbound {
-		s.dsync.CancelDial(p)
-	}
-
 	s.notifyAll(func(f network.Notifiee) {
 		f.Connected(s, c)
 	})
diff --git a/swarm_dial.go b/swarm_dial.go
index ccf33c2e4c2c21fbd6acca1a9378d74e764f48dc..cc739791dc00d67921944dfa5fc72484cb62b5ce 100644
--- a/swarm_dial.go
+++ b/swarm_dial.go
@@ -14,7 +14,6 @@ import (
 	addrutil "github.com/libp2p/go-addr-util"
 	lgbl "github.com/libp2p/go-libp2p-loggables"
 
-	logging "github.com/ipfs/go-log"
 	ma "github.com/multiformats/go-multiaddr"
 	manet "github.com/multiformats/go-multiaddr/net"
 )
@@ -58,6 +57,12 @@ var (
 	ErrGaterDisallowedConnection = errors.New("gater disallows connection to peer")
 )
 
+// Delays to wait after dialing an addr of a given class before the next queued
+// dial to the peer is triggered.
+var (
+	DelayDialPrivateAddr = 5 * time.Millisecond
+	DelayDialPublicAddr  = 50 * time.Millisecond
+	DelayDialRelayAddr   = 100 * time.Millisecond
+)
+
 // DialAttempts governs how many times a goroutine will try to dial a given peer.
 // Note: this is down to one, as we have _too many dials_ atm. To add back in,
 // add loop back in Dial(.)
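
Taken together, the new dial_sync.go flow is: DialLock obtains the per-peer activeDial (spawning its dial worker on first use), and every concurrent caller then performs a request/response handshake with that single worker over reqch. Below is a minimal sketch of the DialWorkerFunc contract, with a hypothetical toyWorker standing in for the real Swarm.dialWorker introduced in the next hunk; note that resch is buffered, so a worker's reply never blocks even if the caller has already given up.

func toyWorker(ctx context.Context, p peer.ID, reqch <-chan dialRequest) {
	for {
		select {
		case req, ok := <-reqch:
			if !ok {
				// DialSync closed reqch: the last DialLock caller
				// released its reference, so the worker exits
				return
			}
			// a real worker would dial p here; each request must be
			// answered exactly once on its response channel
			req.resch <- dialResponse{err: errors.New("toyWorker: no dial logic")}
		case <-ctx.Done():
			// the activeDial's context was canceled
			return
		}
	}
}

With this contract, NewDialSync(toyWorker) would wire the worker in, and ds.DialLock(ctx, p) would funnel all concurrent dials for p through the one goroutine; this single-worker funnel is what replaces the old waitch broadcast and the CancelDial path removed above.
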
@@ -281,39 +286,306 @@ func (s *Swarm) dialPeer(ctx context.Context, p peer.ID) (*Conn, error) {
 	return nil, err
 }
 
-// doDial is an ugly shim method to retain all the logging and backoff logic
-// of the old dialsync code
-func (s *Swarm) doDial(ctx context.Context, p peer.ID) (*Conn, error) {
-	// Short circuit.
-	// By the time we take the dial lock, we may already *have* a connection
-	// to the peer.
-	c := s.bestAcceptableConnToPeer(ctx, p)
-	if c != nil {
-		return c, nil
+///////////////////////////////////////////////////////////////////////////////////
+// lo and behold, The Dialer
+//
+// dialWorker is a per-peer goroutine, spawned by DialSync, that serves dial
+// requests arriving on reqch. Concurrent requests for the same peer share
+// their address dials: each addr is dialed at most once, dials are triggered
+// one at a time with a class-dependent delay in between, the first successful
+// connection is dispatched to every request waiting on it, and errors are
+// accumulated per request until all of its addrs have been tried.
+//////////////////////////////////////////////////////////////////////////////////
+type dialRequest struct {
+	ctx   context.Context
+	resch chan dialResponse
+}
+
+type dialResponse struct {
+	conn *Conn
+	err  error
+}
+
+type dialComplete struct {
+	addr ma.Multiaddr
+	conn *Conn
+	err  error
+}
+
+// dialWorker is an active dial goroutine that synchronizes and executes concurrent dials
+func (s *Swarm) dialWorker(ctx context.Context, p peer.ID, reqch <-chan dialRequest) {
+	if p == s.local {
+		for {
+			select {
+			case req, ok := <-reqch:
+				if !ok {
+					return
+				}
+
+				req.resch <- dialResponse{err: ErrDialToSelf}
+			}
+		}
 	}
 
-	logdial := lgbl.Dial("swarm", s.LocalPeer(), p, nil, nil)
+	s.dialWorkerLoop(ctx, p, reqch)
+}
 
-	// ok, we have been charged to dial! let's do it.
-	// if it succeeds, dial will add the conn to the swarm itself.
-	defer log.EventBegin(ctx, "swarmDialAttemptStart", logdial).Done()
+func (s *Swarm) dialWorkerLoop(ctx context.Context, p peer.ID, reqch <-chan dialRequest) {
+	defer s.limiter.clearAllPeerDials(p)
 
-	conn, err := s.dial(ctx, p)
-	if err != nil {
-		conn := s.bestAcceptableConnToPeer(ctx, p)
-		if conn != nil {
-			// Hm? What error?
-			// Could have canceled the dial because we received a
-			// connection or some other random reason.
-			// Just ignore the error and return the connection.
-			log.Debugf("ignoring dial error because we already have a connection: %s", err)
-			return conn, nil
+
+	type pendRequest struct {
+		req   dialRequest               // the original request
+		err   *DialError                // dial error accumulator
+		addrs map[ma.Multiaddr]struct{} // pending addr dials
+	}
+
+	type addrDial struct {
+		ctx      context.Context
+		conn     *Conn
+		err      error
+		requests []int
+	}
+
+	reqno := 0
+	requests := make(map[int]*pendRequest)
+	pending := make(map[ma.Multiaddr]*addrDial)
+
+	var triggerDial <-chan time.Time
+	var nextDial []ma.Multiaddr
+	active := 0
+	done := false
+
+	resch := make(chan dialComplete)
+
+loop:
+	for {
+		select {
+		case req, ok := <-reqch:
+			if !ok {
+				// the request channel has been closed; wait for pending dials to complete
+				if active > 0 {
+					done = true
+					reqch = nil
+					triggerDial = nil
+					continue loop
+				}
+
+				// no active dials, we are done
+				return
+			}
+
+			c := s.bestAcceptableConnToPeer(req.ctx, p)
+			if c != nil {
+				req.resch <- dialResponse{conn: c}
+				continue loop
+			}
+
+			addrs, err := s.addrsForDial(req.ctx, p)
+			if err != nil {
+				req.resch <- dialResponse{err: err}
+				continue loop
+			}
+
+			// at this point, len(addrs) > 0, or addrsForDial would have returned an error;
+			// rank them so that they are dialed in preference order
+			addrs = s.rankAddrs(addrs)
+
+			// create the pending request object
+			pr := &pendRequest{
+				req:   req,
+				err:   &DialError{Peer: p},
+				addrs: make(map[ma.Multiaddr]struct{}),
+			}
+			for _, a := range addrs {
+				pr.addrs[a] = struct{}{}
+			}
+
+			// check whether any of the addrs has already been dialed successfully,
+			// accumulate errors from completed dials, and collect new addrs to dial or join
+			var todial []ma.Multiaddr
+			var tojoin []*addrDial
+
+			for _, a := range addrs {
+				ad, ok := pending[a]
+				if !ok {
+					todial = append(todial, a)
+					continue
+				}
+
+				if ad.conn != nil {
+					// dial to this addr was successful, complete the request
+					req.resch <- dialResponse{conn: ad.conn}
+					continue loop
+				}
+
+				if ad.err != nil {
+					// dial to this addr errored, accumulate the error
+					pr.err.recordErr(a, ad.err)
+					delete(pr.addrs, a)
+					continue
+				}
+
+				// dial to this addr is still pending, add it to the join list
+				tojoin = append(tojoin, ad)
+			}
+
+			if len(todial) == 0 && len(tojoin) == 0 {
+				// every addr relevant to this request has already been dialed;
+				// all of them must have errored
+				req.resch <- dialResponse{err: pr.err}
+				continue loop
+			}
+
+			// the request has some pending or new dials, track it and schedule the new dials
+			reqno++
+			requests[reqno] = pr
+
+			for _, ad := range tojoin {
+				ad.requests = append(ad.requests, reqno)
+			}
+
+			if len(todial) > 0 {
+				for _, a := range todial {
+					pending[a] = &addrDial{ctx: req.ctx, requests: []int{reqno}}
+				}
+
+				nextDial = append(nextDial, todial...)
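+				// the queue is re-ranked below so that addrs added by this request
+				// are still dialed in preference order relative to those already queued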
+				nextDial = s.rankAddrs(nextDial)
+
+				if triggerDial == nil {
+					// a receive from a closed channel completes immediately,
+					// so this triggers the first dial without delay
+					trigger := make(chan time.Time)
+					close(trigger)
+					triggerDial = trigger
+				}
+			}
+
+		case <-triggerDial:
+			if len(nextDial) == 0 {
+				triggerDial = nil
+				continue loop
+			}
+
+			next := nextDial[0]
+			nextDial = nextDial[1:]
+
+			// spawn the next dial
+			ad := pending[next]
+			go s.dialNextAddr(ad.ctx, p, next, resch)
+			active++
+
+			// select an appropriate delay for the next dial trigger
+			delay := s.delayForNextDial(next)
+			triggerDial = time.After(delay)
+
+		case res := <-resch:
+			active--
+
+			if done && active == 0 {
+				return
+			}
+
+			ad := pending[res.addr]
+			ad.conn = res.conn
+			ad.err = res.err
+
+			dialRequests := ad.requests
+			ad.requests = nil
+
+			if res.conn != nil {
+				// we got a connection, dispatch it to the still pending requests
+				for _, reqno := range dialRequests {
+					pr, ok := requests[reqno]
+					if !ok {
+						// this request was already answered by another dial
+						continue
+					}
+
+					pr.req.resch <- dialResponse{conn: res.conn}
+					delete(requests, reqno)
+				}
+
+				continue loop
+			}
+
+			// the dial errored; accumulate the error and dispatch a dial error to
+			// every request that has now tried all of its addrs
+			for _, reqno := range dialRequests {
+				pr, ok := requests[reqno]
+				if !ok {
+					// this request was already answered by another dial
+					continue
+				}
+
+				// accumulate the error
+				pr.err.recordErr(res.addr, res.err)
+
+				delete(pr.addrs, res.addr)
+				if len(pr.addrs) == 0 {
+					// all addrs have errored, dispatch the dial error
+					pr.req.resch <- dialResponse{err: pr.err}
+					delete(requests, reqno)
+				}
+			}
+		}
+	}
+}
 
-	// ok, we failed.
-	return nil, err
+func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) ([]ma.Multiaddr, error) {
+	peerAddrs := s.peers.Addrs(p)
+	if len(peerAddrs) == 0 {
+		return nil, ErrNoAddresses
+	}
+
+	goodAddrs := s.filterKnownUndialables(p, peerAddrs)
+	if forceDirect, _ := network.GetForceDirectDial(ctx); forceDirect {
+		goodAddrs = addrutil.FilterAddrs(goodAddrs, s.nonProxyAddr)
 	}
-	return conn, nil
+
+	if len(goodAddrs) == 0 {
+		return nil, ErrNoGoodAddresses
+	}
+
+	return goodAddrs, nil
+}
+
+// dialNextAddr dials the next addr for p, sending the result on resch
+func (s *Swarm) dialNextAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr, resch chan dialComplete) {
+	// check the dial backoff
+	if forceDirect, _ := network.GetForceDirectDial(ctx); !forceDirect {
+		if s.backf.Backoff(p, addr) {
+			resch <- dialComplete{addr: addr, err: ErrDialBackoff}
+			return
+		}
+	}
+
+	// start the dial
+	dresch := make(chan dialResult)
+	s.limitedDial(ctx, p, addr, dresch)
+	select {
+	case res := <-dresch:
+		if res.Err != nil {
+			if res.Err != context.Canceled {
+				s.backf.AddBackoff(p, addr)
+			}
+
+			resch <- dialComplete{addr: addr, err: res.Err}
+			return
+		}
+
+		conn, err := s.addConn(res.Conn, network.DirOutbound)
+		if err != nil {
+			res.Conn.Close()
+			resch <- dialComplete{addr: addr, err: err}
+			return
+		}
+
+		resch <- dialComplete{addr: addr, conn: conn}
+
+	case <-ctx.Done():
+		resch <- dialComplete{addr: addr, err: ctx.Err()}
+	}
+}
+
+// delayForNextDial returns how long to wait before triggering the next dial,
+// based on the class of the addr that was just dialed
+func (s *Swarm) delayForNextDial(addr ma.Multiaddr) time.Duration {
+	if _, err := addr.ValueForProtocol(ma.P_CIRCUIT); err == nil {
+		return DelayDialRelayAddr
+	}
+
+	if manet.IsPrivateAddr(addr) {
+		return DelayDialPrivateAddr
+	}
+
+	return DelayDialPublicAddr
 }
 
 func (s *Swarm) canDial(addr ma.Multiaddr) bool {
@@ -365,80 +637,6 @@ func (s *Swarm) rankAddrs(addrs []ma.Multiaddr) []ma.Multiaddr {
 	return append(append(append(localUdpAddrs, othersUdp...), fds...), relays...)
 }
 
-// dial is the actual swarm's dial logic, gated by Dial.
-func (s *Swarm) dial(ctx context.Context, p peer.ID) (*Conn, error) { - forceDirect, _ := network.GetForceDirectDial(ctx) - var logdial = lgbl.Dial("swarm", s.LocalPeer(), p, nil, nil) - if p == s.local { - log.Event(ctx, "swarmDialDoDialSelf", logdial) - return nil, ErrDialToSelf - } - defer log.EventBegin(ctx, "swarmDialDo", logdial).Done() - logdial["dial"] = "failure" // start off with failure. set to "success" at the end. - - sk := s.peers.PrivKey(s.local) - logdial["encrypted"] = sk != nil // log whether this will be an encrypted dial or not. - if sk == nil { - // fine for sk to be nil, just log. - log.Debug("Dial not given PrivateKey, so WILL NOT SECURE conn.") - } - - ////// - peerAddrs := s.peers.Addrs(p) - if len(peerAddrs) == 0 { - return nil, &DialError{Peer: p, Cause: ErrNoAddresses} - } - goodAddrs := s.filterKnownUndialables(p, peerAddrs) - if forceDirect { - goodAddrs = addrutil.FilterAddrs(goodAddrs, s.nonProxyAddr) - } - if len(goodAddrs) == 0 { - return nil, &DialError{Peer: p, Cause: ErrNoGoodAddresses} - } - - if !forceDirect { - /////// Check backoff andnRank addresses - var nonBackoff bool - for _, a := range goodAddrs { - // skip addresses in back-off - if !s.backf.Backoff(p, a) { - nonBackoff = true - } - } - if !nonBackoff { - return nil, ErrDialBackoff - } - } - - connC, dialErr := s.dialAddrs(ctx, p, s.rankAddrs(goodAddrs)) - if dialErr != nil { - logdial["error"] = dialErr.Cause.Error() - switch dialErr.Cause { - case context.Canceled, context.DeadlineExceeded: - // Always prefer the context errors as we rely on being - // able to check them. - // - // Removing this will BREAK backoff (causing us to - // backoff when canceling dials). - return nil, dialErr.Cause - } - return nil, dialErr - } - logdial["conn"] = logging.Metadata{ - "localAddr": connC.LocalMultiaddr(), - "remoteAddr": connC.RemoteMultiaddr(), - } - swarmC, err := s.addConn(connC, network.DirOutbound) - if err != nil { - logdial["error"] = err.Error() - connC.Close() // close the connection. didn't work out :( - return nil, &DialError{Peer: p, Cause: err} - } - - logdial["dial"] = "success" - return swarmC, nil -} - // filterKnownUndialables takes a list of multiaddrs, and removes those // that we definitely don't want to dial: addresses configured to be blocked, // IPv6 link-local addresses, addresses without a dial-capable transport, @@ -466,98 +664,6 @@ func (s *Swarm) filterKnownUndialables(p peer.ID, addrs []ma.Multiaddr) []ma.Mul ) } -func (s *Swarm) dialAddrs(ctx context.Context, p peer.ID, remoteAddrs []ma.Multiaddr) (transport.CapableConn, *DialError) { - /* - This slice-to-chan code is temporary, the peerstore can currently provide - a channel as an interface for receiving addresses, but more thought - needs to be put into the execution. 
For now, this allows us to use - the improved rate limiter, while maintaining the outward behaviour - that we previously had (halting a dial when we run out of addrs) - */ - var remoteAddrChan chan ma.Multiaddr - if len(remoteAddrs) > 0 { - remoteAddrChan = make(chan ma.Multiaddr, len(remoteAddrs)) - for i := range remoteAddrs { - remoteAddrChan <- remoteAddrs[i] - } - close(remoteAddrChan) - } - - log.Debugf("%s swarm dialing %s", s.local, p) - - ctx, cancel := context.WithCancel(ctx) - defer cancel() // cancel work when we exit func - - // use a single response type instead of errs and conns, reduces complexity *a ton* - respch := make(chan dialResult) - err := &DialError{Peer: p} - - defer s.limiter.clearAllPeerDials(p) - - var active int -dialLoop: - for remoteAddrChan != nil || active > 0 { - // Check for context cancellations and/or responses first. - select { - case <-ctx.Done(): - break dialLoop - case resp := <-respch: - active-- - if resp.Err != nil { - // Errors are normal, lots of dials will fail - if resp.Err != context.Canceled { - s.backf.AddBackoff(p, resp.Addr) - } - - log.Infof("got error on dial: %s", resp.Err) - err.recordErr(resp.Addr, resp.Err) - } else if resp.Conn != nil { - return resp.Conn, nil - } - - // We got a result, try again from the top. - continue - default: - } - - // Now, attempt to dial. - select { - case addr, ok := <-remoteAddrChan: - if !ok { - remoteAddrChan = nil - continue - } - - s.limitedDial(ctx, p, addr, respch) - active++ - case <-ctx.Done(): - break dialLoop - case resp := <-respch: - active-- - if resp.Err != nil { - // Errors are normal, lots of dials will fail - if resp.Err != context.Canceled { - s.backf.AddBackoff(p, resp.Addr) - } - - log.Infof("got error on dial: %s", resp.Err) - err.recordErr(resp.Addr, resp.Err) - } else if resp.Conn != nil { - return resp.Conn, nil - } - } - } - - if ctxErr := ctx.Err(); ctxErr != nil { - err.Cause = ctxErr - } else if len(err.DialErrors) == 0 { - err.Cause = network.ErrNoRemoteAddrs - } else { - err.Cause = ErrAllDialsFailed - } - return nil, err -} - // limitedDial will start a dial to the given peer when // it is able, respecting the various different types of rate // limiting that occur without using extra goroutines per addr @@ -570,6 +676,7 @@ func (s *Swarm) limitedDial(ctx context.Context, p peer.ID, a ma.Multiaddr, resp }) } +// dialAddr is the actual dial for an addr, indirectly invoked through the limiter func (s *Swarm) dialAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr) (transport.CapableConn, error) { // Just to double check. Costs nothing. if s.local == p {