diff --git a/dial_sync.go b/dial_sync.go index e5f25478cbb393d25c8c1b854443012aed441c63..32196cb2d5d6f06050c5badb7296975832063057 100644 --- a/dial_sync.go +++ b/dial_sync.go @@ -13,7 +13,7 @@ import ( var errDialCanceled = errors.New("dial was aborted internally, likely due to https://git.io/Je2wW") // DialFunc is the type of function expected by DialSync. -type DialWorkerFunc func(context.Context, peer.ID, <-chan DialRequest) +type DialWorkerFunc func(context.Context, peer.ID, <-chan DialRequest) error // NewDialSync constructs a new DialSync func NewDialSync(worker DialWorkerFunc) *DialSync { @@ -79,7 +79,7 @@ func (ad *activeDial) dial(ctx context.Context, p peer.ID) (*Conn, error) { } } -func (ds *DialSync) getActiveDial(p peer.ID) *activeDial { +func (ds *DialSync) getActiveDial(p peer.ID) (*activeDial, error) { ds.dialsLk.Lock() defer ds.dialsLk.Unlock() @@ -99,20 +99,27 @@ func (ds *DialSync) getActiveDial(p peer.ID) *activeDial { } ds.dials[p] = actd - go ds.dialWorker(adctx, p, actd.reqch) + err := ds.dialWorker(adctx, p, actd.reqch) + if err != nil { + cancel() + return nil, err + } } // increase ref count before dropping dialsLk actd.refCnt++ - return actd + return actd, nil } // DialLock initiates a dial to the given peer if there are none in progress // then waits for the dial to that peer to complete. func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) { - ad := ds.getActiveDial(p) - defer ad.decref() + ad, err := ds.getActiveDial(p) + if err != nil { + return nil, err + } + defer ad.decref() return ad.dial(ctx, p) } diff --git a/dial_sync_test.go b/dial_sync_test.go index ef7458a554700501c0c4422cac10617a5b486e68..f1a9f8a539bdb5b7ba1b6126d3311108fba228b6 100644 --- a/dial_sync_test.go +++ b/dial_sync_test.go @@ -16,7 +16,7 @@ func getMockDialFunc() (DialWorkerFunc, func(), context.Context, <-chan struct{} dfcalls := make(chan struct{}, 512) // buffer it large enough that we won't care dialctx, cancel := context.WithCancel(context.Background()) ch := make(chan struct{}) - f := func(ctx context.Context, p peer.ID, reqch <-chan DialRequest) { + f := func(ctx context.Context, p peer.ID, reqch <-chan DialRequest) error { dfcalls <- struct{}{} go func() { defer cancel() @@ -39,6 +39,7 @@ func getMockDialFunc() (DialWorkerFunc, func(), context.Context, <-chan struct{} } } }() + return nil } o := new(sync.Once) @@ -188,25 +189,28 @@ func TestDialSyncAllCancel(t *testing.T) { func TestFailFirst(t *testing.T) { var count int - f := func(ctx context.Context, p peer.ID, reqch <-chan DialRequest) { - for { - select { - case req, ok := <-reqch: - if !ok { - return - } + f := func(ctx context.Context, p peer.ID, reqch <-chan DialRequest) error { + go func() { + for { + select { + case req, ok := <-reqch: + if !ok { + return + } - if count > 0 { - req.Resch <- DialResponse{Conn: new(Conn)} - } else { - req.Resch <- DialResponse{Err: fmt.Errorf("gophers ate the modem")} - } - count++ + if count > 0 { + req.Resch <- DialResponse{Conn: new(Conn)} + } else { + req.Resch <- DialResponse{Err: fmt.Errorf("gophers ate the modem")} + } + count++ - case <-ctx.Done(): - return + case <-ctx.Done(): + return + } } - } + }() + return nil } ds := NewDialSync(f) @@ -232,19 +236,22 @@ func TestFailFirst(t *testing.T) { } func TestStressActiveDial(t *testing.T) { - ds := NewDialSync(func(ctx context.Context, p peer.ID, reqch <-chan DialRequest) { - for { - select { - case req, ok := <-reqch: - if !ok { + ds := NewDialSync(func(ctx context.Context, p peer.ID, reqch <-chan DialRequest) error { + go func() { + for { + select { + case req, ok := <-reqch: + if !ok { + return + } + + req.Resch <- DialResponse{} + case <-ctx.Done(): return } - - req.Resch <- DialResponse{} - case <-ctx.Done(): - return } - } + }() + return nil }) wg := sync.WaitGroup{} diff --git a/swarm_dial.go b/swarm_dial.go index 9b10db165bb49ccd10aaf3c7fa5f59de52f8aed4..77cb8a6bc70e7e4585022c9c61e938e12028c79c 100644 --- a/swarm_dial.go +++ b/swarm_dial.go @@ -307,21 +307,13 @@ type dialComplete struct { } // dialWorker is an active dial goroutine that synchronizes and executes concurrent dials -func (s *Swarm) dialWorker(ctx context.Context, p peer.ID, reqch <-chan DialRequest) { +func (s *Swarm) dialWorker(ctx context.Context, p peer.ID, reqch <-chan DialRequest) error { if p == s.local { - for { - select { - case req, ok := <-reqch: - if !ok { - return - } - - req.Resch <- DialResponse{Err: ErrDialToSelf} - } - } + return ErrDialToSelf } - s.dialWorkerLoop(ctx, p, reqch) + go s.dialWorkerLoop(ctx, p, reqch) + return nil } func (s *Swarm) dialWorkerLoop(ctx context.Context, p peer.ID, reqch <-chan DialRequest) {