diff --git a/dial_sync.go b/dial_sync.go index 2efdf067b60c62bc4cc5a8fb41c1b76695aac4cd..54a21067b7bf4aebad081bda34d86e02e5c12f44 100644 --- a/dial_sync.go +++ b/dial_sync.go @@ -13,7 +13,7 @@ import ( var errDialCanceled = errors.New("dial was aborted internally, likely due to https://git.io/Je2wW") // DialFunc is the type of function expected by DialSync. -type DialWorkerFunc func(context.Context, peer.ID, <-chan dialRequest) +type DialWorkerFunc func(context.Context, peer.ID, <-chan DialRequest) // NewDialSync constructs a new DialSync func NewDialSync(worker DialWorkerFunc) *DialSync { @@ -38,7 +38,7 @@ type activeDial struct { ctx context.Context cancel func() - reqch chan dialRequest + reqch chan DialRequest ds *DialSync } @@ -68,16 +68,16 @@ func (ad *activeDial) dial(ctx context.Context, p peer.ID) (*Conn, error) { dialCtx = network.WithSimultaneousConnect(dialCtx, reason) } - resch := make(chan dialResponse, 1) + resch := make(chan DialResponse, 1) select { - case ad.reqch <- dialRequest{ctx: dialCtx, resch: resch}: + case ad.reqch <- DialRequest{Ctx: dialCtx, Resch: resch}: case <-ctx.Done(): return nil, ctx.Err() } select { case res := <-resch: - return res.conn, res.err + return res.Conn, res.Err case <-ctx.Done(): return nil, ctx.Err() } @@ -98,7 +98,7 @@ func (ds *DialSync) getActiveDial(p peer.ID) *activeDial { id: p, ctx: adctx, cancel: cancel, - reqch: make(chan dialRequest), + reqch: make(chan DialRequest), ds: ds, } ds.dials[p] = actd diff --git a/dial_sync_test.go b/dial_sync_test.go index 485d1a3171a47832c508b586e4ee57dc47ae08b0..ef7458a554700501c0c4422cac10617a5b486e68 100644 --- a/dial_sync_test.go +++ b/dial_sync_test.go @@ -12,19 +12,33 @@ import ( "github.com/libp2p/go-libp2p-core/peer" ) -func getMockDialFunc() (DialFunc, func(), context.Context, <-chan struct{}) { +func getMockDialFunc() (DialWorkerFunc, func(), context.Context, <-chan struct{}) { dfcalls := make(chan struct{}, 512) // buffer it large enough that we won't care dialctx, cancel := context.WithCancel(context.Background()) ch := make(chan struct{}) - f := func(ctx context.Context, p peer.ID) (*Conn, error) { + f := func(ctx context.Context, p peer.ID, reqch <-chan DialRequest) { dfcalls <- struct{}{} - defer cancel() - select { - case <-ch: - return new(Conn), nil - case <-ctx.Done(): - return nil, ctx.Err() - } + go func() { + defer cancel() + for { + select { + case req, ok := <-reqch: + if !ok { + return + } + + select { + case <-ch: + req.Resch <- DialResponse{Conn: new(Conn)} + case <-ctx.Done(): + req.Resch <- DialResponse{Err: ctx.Err()} + return + } + case <-ctx.Done(): + return + } + } + }() } o := new(sync.Once) @@ -174,12 +188,25 @@ func TestDialSyncAllCancel(t *testing.T) { func TestFailFirst(t *testing.T) { var count int - f := func(ctx context.Context, p peer.ID) (*Conn, error) { - if count > 0 { - return new(Conn), nil + f := func(ctx context.Context, p peer.ID, reqch <-chan DialRequest) { + for { + select { + case req, ok := <-reqch: + if !ok { + return + } + + if count > 0 { + req.Resch <- DialResponse{Conn: new(Conn)} + } else { + req.Resch <- DialResponse{Err: fmt.Errorf("gophers ate the modem")} + } + count++ + + case <-ctx.Done(): + return + } } - count++ - return nil, fmt.Errorf("gophers ate the modem") } ds := NewDialSync(f) @@ -205,8 +232,19 @@ func TestFailFirst(t *testing.T) { } func TestStressActiveDial(t *testing.T) { - ds := NewDialSync(func(ctx context.Context, p peer.ID) (*Conn, error) { - return nil, nil + ds := NewDialSync(func(ctx context.Context, p peer.ID, reqch <-chan DialRequest) { + for { + select { + case req, ok := <-reqch: + if !ok { + return + } + + req.Resch <- DialResponse{} + case <-ctx.Done(): + return + } + } }) wg := sync.WaitGroup{} diff --git a/swarm_dial.go b/swarm_dial.go index 428845f3442573417b046d91908425f667cfca93..ae4c87f1ad215e5f68158281e51a648a7bb7c3a6 100644 --- a/swarm_dial.go +++ b/swarm_dial.go @@ -290,14 +290,14 @@ func (s *Swarm) dialPeer(ctx context.Context, p peer.ID) (*Conn, error) { // lo and behold, The Dialer // TODO explain how all this works ////////////////////////////////////////////////////////////////////////////////// -type dialRequest struct { - ctx context.Context - resch chan dialResponse +type DialRequest struct { + Ctx context.Context + Resch chan DialResponse } -type dialResponse struct { - conn *Conn - err error +type DialResponse struct { + Conn *Conn + Err error } type dialComplete struct { @@ -307,7 +307,7 @@ type dialComplete struct { } // dialWorker is an active dial goroutine that synchronizes and executes concurrent dials -func (s *Swarm) dialWorker(ctx context.Context, p peer.ID, reqch <-chan dialRequest) { +func (s *Swarm) dialWorker(ctx context.Context, p peer.ID, reqch <-chan DialRequest) { if p == s.local { for { select { @@ -316,7 +316,7 @@ func (s *Swarm) dialWorker(ctx context.Context, p peer.ID, reqch <-chan dialRequ return } - req.resch <- dialResponse{err: ErrDialToSelf} + req.Resch <- DialResponse{Err: ErrDialToSelf} } } } @@ -324,11 +324,11 @@ func (s *Swarm) dialWorker(ctx context.Context, p peer.ID, reqch <-chan dialRequ s.dialWorkerLoop(ctx, p, reqch) } -func (s *Swarm) dialWorkerLoop(ctx context.Context, p peer.ID, reqch <-chan dialRequest) { +func (s *Swarm) dialWorkerLoop(ctx context.Context, p peer.ID, reqch <-chan DialRequest) { defer s.limiter.clearAllPeerDials(p) type pendRequest struct { - req dialRequest // the original request + req DialRequest // the original request err *DialError // dial error accumulator addrs map[ma.Multiaddr]struct{} // pending addr dials } @@ -368,15 +368,15 @@ loop: return } - c := s.bestAcceptableConnToPeer(req.ctx, p) + c := s.bestAcceptableConnToPeer(req.Ctx, p) if c != nil { - req.resch <- dialResponse{conn: c} + req.Resch <- DialResponse{Conn: c} continue loop } - addrs, err := s.addrsForDial(req.ctx, p) + addrs, err := s.addrsForDial(req.Ctx, p) if err != nil { - req.resch <- dialResponse{err: err} + req.Resch <- DialResponse{Err: err} continue loop } @@ -408,7 +408,7 @@ loop: if ad.conn != nil { // dial to this addr was successful, complete the request - req.resch <- dialResponse{conn: ad.conn} + req.Resch <- DialResponse{Conn: ad.conn} continue loop } @@ -424,7 +424,7 @@ loop: if len(todial) == 0 && len(tojoin) == 0 { // all request applicable addrs have been dialed, we must have errored - req.resch <- dialResponse{err: pr.err} + req.Resch <- DialResponse{Err: pr.err} continue loop } @@ -438,7 +438,7 @@ loop: if len(todial) > 0 { for _, a := range todial { - pending[a] = &addrDial{ctx: req.ctx, requests: []int{reqno}} + pending[a] = &addrDial{ctx: req.Ctx, requests: []int{reqno}} } nextDial = append(nextDial, todial...) @@ -492,7 +492,7 @@ loop: continue } - pr.req.resch <- dialResponse{conn: res.conn} + pr.req.Resch <- DialResponse{Conn: res.conn} delete(requests, reqno) } @@ -513,7 +513,7 @@ loop: delete(pr.addrs, res.addr) if len(pr.addrs) == 0 { // all addrs have erred, dispatch dial error - pr.req.resch <- dialResponse{err: pr.err} + pr.req.Resch <- DialResponse{Err: pr.err} delete(requests, reqno) } }