diff --git a/swarm_dial.go b/swarm_dial.go index 2d3255b4a674a21cb2e61b102d6ac3d174a35db0..513ce1c617e4591d9602cb3f064215fec37bcdd0 100644 --- a/swarm_dial.go +++ b/swarm_dial.go @@ -320,6 +320,7 @@ func (s *Swarm) dialWorkerLoop(ctx context.Context, p peer.ID, reqch <-chan Dial conn *Conn err error requests []int + dialed bool } reqno := 0 @@ -454,6 +455,9 @@ loop: requests[reqno] = pr for _, ad := range tojoin { + if !ad.dialed { + ad.ctx = s.mergeDialContexts(ad.ctx, req.Ctx) + } ad.requests = append(ad.requests, reqno) } @@ -490,6 +494,7 @@ loop: continue } + ad.dialed = true dialed = true last = i active++ @@ -581,6 +586,18 @@ func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) ([]ma.Multiaddr, er return goodAddrs, nil } +func (s *Swarm) mergeDialContexts(a, b context.Context) context.Context { + dialCtx := a + + if simConnect, reason := network.GetSimultaneousConnect(b); simConnect { + if simConnect, _ := network.GetSimultaneousConnect(a); !simConnect { + dialCtx = network.WithSimultaneousConnect(dialCtx, reason) + } + } + + return dialCtx +} + func (s *Swarm) dialNextAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr, resch chan dialResult) error { // check the dial backoff if forceDirect, _ := network.GetForceDirectDial(ctx); !forceDirect {