dial_sync.go 2.47 KB
Newer Older
1 2 3 4 5 6
package swarm

import (
	"context"
	"sync"

7
	"github.com/libp2p/go-libp2p-core/peer"
8 9
)

Steven Allen's avatar
Steven Allen committed
10
// DialFunc is the type of function expected by DialSync.
11 12
type DialFunc func(context.Context, peer.ID) (*Conn, error)

Steven Allen's avatar
Steven Allen committed
13
// NewDialSync constructs a new DialSync
14 15 16 17 18 19 20
func NewDialSync(dfn DialFunc) *DialSync {
	return &DialSync{
		dials:    make(map[peer.ID]*activeDial),
		dialFunc: dfn,
	}
}

Steven Allen's avatar
Steven Allen committed
21 22
// DialSync is a dial synchronization helper that ensures that at most one dial
// to any given peer is active at any given time.
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41
type DialSync struct {
	dials    map[peer.ID]*activeDial
	dialsLk  sync.Mutex
	dialFunc DialFunc
}

// activeDial is one in-flight dial to a single peer, shared by every caller
// that obtained it from DialSync.getActiveDial while it was registered.
type activeDial struct {
	// id is the peer being dialed. refCnt counts callers currently holding
	// this dial (incremented in getActiveDial, decremented in decref) and is
	// guarded by refCntLk. cancel cancels the context handed to dialFunc.
	id       peer.ID
	refCnt   int
	refCntLk sync.Mutex
	cancel   func()

	// Dial result: start writes conn and err and then closes waitch, so conn
	// and err must only be read after a receive from waitch has completed.
	err    error
	conn   *Conn
	waitch chan struct{}

	// ds is the DialSync this dial is registered with.
	ds *DialSync
}

Steven Allen's avatar
Steven Allen committed
42 43
func (ad *activeDial) wait(ctx context.Context) (*Conn, error) {
	defer ad.decref()
44
	select {
Steven Allen's avatar
Steven Allen committed
45 46
	case <-ad.waitch:
		return ad.conn, ad.err
47 48 49 50 51 52 53 54 55 56 57 58 59 60
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}

// incref records one more caller holding this dial. getActiveDial calls it
// while holding dialsLk, which is what lets decref re-check the count safely.
func (ad *activeDial) incref() {
	ad.refCntLk.Lock()
	ad.refCnt += 1
	ad.refCntLk.Unlock()
}

func (ad *activeDial) decref() {
	ad.refCntLk.Lock()
	ad.refCnt--
61 62 63 64 65
	maybeZero := (ad.refCnt <= 0)
	ad.refCntLk.Unlock()

	// make sure to always take locks in correct order.
	if maybeZero {
66
		ad.ds.dialsLk.Lock()
67 68 69 70 71 72 73 74
		ad.refCntLk.Lock()
		// check again after lock swap drop to make sure nobody else called incref
		// in between locks
		if ad.refCnt <= 0 {
			ad.cancel()
			delete(ad.ds.dials, ad.id)
		}
		ad.refCntLk.Unlock()
Jeromy's avatar
Jeromy committed
75
		ad.ds.dialsLk.Unlock()
76 77 78
	}
}

79 80 81 82 83 84 85
// start runs the real dial on ctx and publishes the outcome. conn and err are
// stored before waitch is closed, so any waiter woken by the close sees them.
// Afterwards the dial's context is cancelled to release its resources.
func (ad *activeDial) start(ctx context.Context) {
	conn, err := ad.ds.dialFunc(ctx, ad.id)
	ad.conn, ad.err = conn, err

	close(ad.waitch)
	ad.cancel()
}

func (ds *DialSync) getActiveDial(p peer.ID) *activeDial {
86
	ds.dialsLk.Lock()
87
	defer ds.dialsLk.Unlock()
88 89 90

	actd, ok := ds.dials[p]
	if !ok {
91
		adctx, cancel := context.WithCancel(context.Background())
92 93 94 95 96 97 98 99
		actd = &activeDial{
			id:     p,
			cancel: cancel,
			waitch: make(chan struct{}),
			ds:     ds,
		}
		ds.dials[p] = actd

100
		go actd.start(adctx)
101 102
	}

103 104 105
	// increase ref count before dropping dialsLk
	actd.incref()

106 107
	return actd
}
108

Steven Allen's avatar
Steven Allen committed
109 110
// DialLock initiates a dial to the given peer if there are none in progress
// then waits for the dial to that peer to complete.
111 112
func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) {
	return ds.getActiveDial(p).wait(ctx)
113
}
Steven Allen's avatar
Steven Allen committed
114 115 116 117 118 119 120 121 122

// CancelDial cancels the context of the dial currently registered for p, if
// there is one. It does nothing when no dial to p is registered.
func (ds *DialSync) CancelDial(p peer.ID) {
	ds.dialsLk.Lock()
	defer ds.dialsLk.Unlock()

	ad, registered := ds.dials[p]
	if !registered {
		return
	}
	ad.cancel()
}