Commit 6e49586c authored by Jeromy Johnson's avatar Jeromy Johnson Committed by GitHub

Merge pull request #4 from libp2p/fix/lock-ordering

fix lock ordering race condition in dial sync code
parents a636a14b 511d7a82
...@@ -53,22 +53,38 @@ func (ad *activeDial) incref() { ...@@ -53,22 +53,38 @@ func (ad *activeDial) incref() {
func (ad *activeDial) decref() { func (ad *activeDial) decref() {
ad.refCntLk.Lock() ad.refCntLk.Lock()
defer ad.refCntLk.Unlock()
ad.refCnt-- ad.refCnt--
maybeZero := (ad.refCnt <= 0)
ad.refCntLk.Unlock()
// make sure to always take locks in correct order.
if maybeZero {
ad.ds.dialsLk.Lock()
ad.refCntLk.Lock()
// check again after lock swap drop to make sure nobody else called incref
// in between locks
if ad.refCnt <= 0 { if ad.refCnt <= 0 {
ad.cancel() ad.cancel()
ad.ds.dialsLk.Lock()
delete(ad.ds.dials, ad.id) delete(ad.ds.dials, ad.id)
}
ad.refCntLk.Unlock()
ad.ds.dialsLk.Unlock() ad.ds.dialsLk.Unlock()
} }
} }
func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) { func (ad *activeDial) start(ctx context.Context) {
ad.conn, ad.err = ad.ds.dialFunc(ctx, ad.id)
close(ad.waitch)
ad.cancel()
}
func (ds *DialSync) getActiveDial(p peer.ID) *activeDial {
ds.dialsLk.Lock() ds.dialsLk.Lock()
defer ds.dialsLk.Unlock()
actd, ok := ds.dials[p] actd, ok := ds.dials[p]
if !ok { if !ok {
ctx, cancel := context.WithCancel(context.Background()) adctx, cancel := context.WithCancel(context.Background())
actd = &activeDial{ actd = &activeDial{
id: p, id: p,
cancel: cancel, cancel: cancel,
...@@ -77,15 +93,15 @@ func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) { ...@@ -77,15 +93,15 @@ func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) {
} }
ds.dials[p] = actd ds.dials[p] = actd
go func(ctx context.Context, p peer.ID, ad *activeDial) { go actd.start(adctx)
ad.conn, ad.err = ds.dialFunc(ctx, p)
close(ad.waitch)
ad.cancel()
}(ctx, p, actd)
} }
// increase ref count before dropping dialsLk
actd.incref() actd.incref()
ds.dialsLk.Unlock()
return actd.wait(ctx) return actd
}
func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) {
return ds.getActiveDial(p).wait(ctx)
} }
...@@ -201,3 +201,27 @@ func TestFailFirst(t *testing.T) { ...@@ -201,3 +201,27 @@ func TestFailFirst(t *testing.T) {
t.Fatal("should have gotten a 'real' conn back") t.Fatal("should have gotten a 'real' conn back")
} }
} }
func TestStressActiveDial(t *testing.T) {
ds := NewDialSync(func(ctx context.Context, p peer.ID) (*Conn, error) {
return nil, nil
})
wg := sync.WaitGroup{}
pid := peer.ID("foo")
makeDials := func() {
for i := 0; i < 10000; i++ {
ds.DialLock(context.Background(), pid)
}
wg.Done()
}
for i := 0; i < 100; i++ {
wg.Add(1)
go makeDials()
}
wg.Wait()
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment