diff --git a/dial_sync.go b/dial_sync.go index 48b77899f2f2354ae588fd1723e20b7c12b0628b..69739e54ec9e281f8f2c48005f9d9f072497ea11 100644 --- a/dial_sync.go +++ b/dial_sync.go @@ -53,22 +53,38 @@ func (ad *activeDial) incref() { func (ad *activeDial) decref() { ad.refCntLk.Lock() - defer ad.refCntLk.Unlock() ad.refCnt-- - if ad.refCnt <= 0 { - ad.cancel() + maybeZero := (ad.refCnt <= 0) + ad.refCntLk.Unlock() + + // make sure to always take locks in correct order. + if maybeZero { ad.ds.dialsLk.Lock() - delete(ad.ds.dials, ad.id) + ad.refCntLk.Lock() + // check again after lock swap drop to make sure nobody else called incref + // in between locks + if ad.refCnt <= 0 { + ad.cancel() + delete(ad.ds.dials, ad.id) + } + ad.refCntLk.Unlock() ad.ds.dialsLk.Unlock() } } -func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) { +func (ad *activeDial) start(ctx context.Context) { + ad.conn, ad.err = ad.ds.dialFunc(ctx, ad.id) + close(ad.waitch) + ad.cancel() +} + +func (ds *DialSync) getActiveDial(p peer.ID) *activeDial { ds.dialsLk.Lock() + defer ds.dialsLk.Unlock() actd, ok := ds.dials[p] if !ok { - ctx, cancel := context.WithCancel(context.Background()) + adctx, cancel := context.WithCancel(context.Background()) actd = &activeDial{ id: p, cancel: cancel, @@ -77,15 +93,15 @@ func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) { } ds.dials[p] = actd - go func(ctx context.Context, p peer.ID, ad *activeDial) { - ad.conn, ad.err = ds.dialFunc(ctx, p) - close(ad.waitch) - ad.cancel() - }(ctx, p, actd) + go actd.start(adctx) } + // increase ref count before dropping dialsLk actd.incref() - ds.dialsLk.Unlock() - return actd.wait(ctx) + return actd +} + +func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) { + return ds.getActiveDial(p).wait(ctx) } diff --git a/dial_sync_test.go b/dial_sync_test.go index ca81a9c87209594b5a270ba85c19fc6ac10cb803..0d70e226afcbe1c6cdacd92f6fea4cb5ad0362e0 100644 --- a/dial_sync_test.go +++ b/dial_sync_test.go @@ -201,3 +201,27 @@ func TestFailFirst(t *testing.T) { t.Fatal("should have gotten a 'real' conn back") } } + +func TestStressActiveDial(t *testing.T) { + ds := NewDialSync(func(ctx context.Context, p peer.ID) (*Conn, error) { + return nil, nil + }) + + wg := sync.WaitGroup{} + + pid := peer.ID("foo") + + makeDials := func() { + for i := 0; i < 10000; i++ { + ds.DialLock(context.Background(), pid) + } + wg.Done() + } + + for i := 0; i < 100; i++ { + wg.Add(1) + go makeDials() + } + + wg.Wait() +}