Commit 34e49ba4 authored by Jeromy's avatar Jeromy

make sure to always take locks in correct order

parent a636a14b
...@@ -36,6 +36,7 @@ type activeDial struct { ...@@ -36,6 +36,7 @@ type activeDial struct {
} }
func (dr *activeDial) wait(ctx context.Context) (*Conn, error) { func (dr *activeDial) wait(ctx context.Context) (*Conn, error) {
dr.incref()
defer dr.decref() defer dr.decref()
select { select {
case <-dr.waitch: case <-dr.waitch:
...@@ -53,22 +54,38 @@ func (ad *activeDial) incref() { ...@@ -53,22 +54,38 @@ func (ad *activeDial) incref() {
func (ad *activeDial) decref() { func (ad *activeDial) decref() {
ad.refCntLk.Lock() ad.refCntLk.Lock()
defer ad.refCntLk.Unlock()
ad.refCnt-- ad.refCnt--
maybeZero := (ad.refCnt <= 0)
ad.refCntLk.Unlock()
// make sure to always take locks in correct order.
if maybeZero {
ad.ds.dialsLk.Lock()
ad.refCntLk.Lock()
// check again after lock swap drop to make sure nobody else called incref
// in between locks
if ad.refCnt <= 0 { if ad.refCnt <= 0 {
ad.cancel() ad.cancel()
ad.ds.dialsLk.Lock()
delete(ad.ds.dials, ad.id) delete(ad.ds.dials, ad.id)
}
ad.ds.dialsLk.Unlock() ad.ds.dialsLk.Unlock()
ad.refCntLk.Unlock()
} }
} }
func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) { func (ad *activeDial) start(ctx context.Context) {
ad.conn, ad.err = ad.ds.dialFunc(ctx, ad.id)
close(ad.waitch)
ad.cancel()
}
func (ds *DialSync) getActiveDial(p peer.ID) *activeDial {
ds.dialsLk.Lock() ds.dialsLk.Lock()
defer ds.dialsLk.Unlock()
actd, ok := ds.dials[p] actd, ok := ds.dials[p]
if !ok { if !ok {
ctx, cancel := context.WithCancel(context.Background()) adctx, cancel := context.WithCancel(context.Background())
actd = &activeDial{ actd = &activeDial{
id: p, id: p,
cancel: cancel, cancel: cancel,
...@@ -77,15 +94,12 @@ func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) { ...@@ -77,15 +94,12 @@ func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) {
} }
ds.dials[p] = actd ds.dials[p] = actd
go func(ctx context.Context, p peer.ID, ad *activeDial) { go actd.start(adctx)
ad.conn, ad.err = ds.dialFunc(ctx, p)
close(ad.waitch)
ad.cancel()
}(ctx, p, actd)
} }
actd.incref() return actd
ds.dialsLk.Unlock() }
return actd.wait(ctx) func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) {
return ds.getActiveDial(p).wait(ctx)
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment