dial_sync.go 3.18 KB
Newer Older
1 2 3 4
package swarm

import (
	"context"
5
	"errors"
6 7
	"sync"

8
	"github.com/libp2p/go-libp2p-core/peer"
9 10
)

11 12 13
// errDialCanceled is substituted for context.Canceled when a dial function
// returns that error: the dial runs on an internal context rather than the
// caller's, so the cancellation did not come from the caller.
//
// TODO: change this text when we fix the bug
var errDialCanceled = errors.New("dial was aborted internally, likely due to https://git.io/Je2wW")

Steven Allen's avatar
Steven Allen committed
14
// DialFunc is the type of function expected by DialSync.
15 16
type DialFunc func(context.Context, peer.ID) (*Conn, error)

Steven Allen's avatar
Steven Allen committed
17
// NewDialSync constructs a new DialSync
18 19 20 21 22 23 24
func NewDialSync(dfn DialFunc) *DialSync {
	return &DialSync{
		dials:    make(map[peer.ID]*activeDial),
		dialFunc: dfn,
	}
}

Steven Allen's avatar
Steven Allen committed
25 26
// DialSync is a dial synchronization helper that ensures that at most one dial
// to any given peer is active at any given time.
27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
type DialSync struct {
	dials    map[peer.ID]*activeDial
	dialsLk  sync.Mutex
	dialFunc DialFunc
}

// activeDial tracks one in-flight dial to a single peer together with the
// number of callers currently interested in its outcome.
type activeDial struct {
	// id is the peer being dialed.
	id peer.ID
	// ds is the DialSync whose table this dial is registered in.
	ds *DialSync

	// refCntLk guards refCnt.
	refCntLk sync.Mutex
	// refCnt is the number of outstanding references taken via incref and
	// not yet released via decref.
	refCnt int

	// cancel cancels the context the dial runs under.
	cancel func()

	// waitch is closed by start once conn and err hold the dial's result.
	waitch chan struct{}
	conn   *Conn
	err    error
}

Steven Allen's avatar
Steven Allen committed
46 47
func (ad *activeDial) wait(ctx context.Context) (*Conn, error) {
	defer ad.decref()
48
	select {
Steven Allen's avatar
Steven Allen committed
49 50
	case <-ad.waitch:
		return ad.conn, ad.err
51 52 53 54 55 56 57 58 59 60 61 62 63 64
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}

// incref records one more reference to this dial.
func (ad *activeDial) incref() {
	ad.refCntLk.Lock()
	ad.refCnt++
	ad.refCntLk.Unlock()
}

func (ad *activeDial) decref() {
	ad.refCntLk.Lock()
	ad.refCnt--
65 66 67 68 69
	maybeZero := (ad.refCnt <= 0)
	ad.refCntLk.Unlock()

	// make sure to always take locks in correct order.
	if maybeZero {
70
		ad.ds.dialsLk.Lock()
71 72 73 74 75 76 77 78
		ad.refCntLk.Lock()
		// check again after lock swap drop to make sure nobody else called incref
		// in between locks
		if ad.refCnt <= 0 {
			ad.cancel()
			delete(ad.ds.dials, ad.id)
		}
		ad.refCntLk.Unlock()
Jeromy's avatar
Jeromy committed
79
		ad.ds.dialsLk.Unlock()
80 81 82
	}
}

83 84
func (ad *activeDial) start(ctx context.Context) {
	ad.conn, ad.err = ad.ds.dialFunc(ctx, ad.id)
85 86 87 88 89 90 91 92 93 94

	// This isn't the user's context so we should fix the error.
	switch ad.err {
	case context.Canceled:
		// The dial was canceled with `CancelDial`.
		ad.err = errDialCanceled
	case context.DeadlineExceeded:
		// We hit an internal timeout, not a context timeout.
		ad.err = ErrDialTimeout
	}
95 96 97 98
	close(ad.waitch)
	ad.cancel()
}

99
func (ds *DialSync) getActiveDial(p peer.ID) *activeDial {
100
	ds.dialsLk.Lock()
101
	defer ds.dialsLk.Unlock()
102 103 104

	actd, ok := ds.dials[p]
	if !ok {
105 106 107 108 109
		// This code intentionally uses the background context. Otherwise, if the first call
		// to Dial is canceled, subsequent dial calls will also be canceled.
		// XXX: this also breaks direct connection logic. We will need to pipe the
		// information through some other way.
		adctx, cancel := context.WithCancel(context.Background())
110 111 112 113 114 115 116 117
		actd = &activeDial{
			id:     p,
			cancel: cancel,
			waitch: make(chan struct{}),
			ds:     ds,
		}
		ds.dials[p] = actd

118
		go actd.start(adctx)
119 120
	}

121 122 123
	// increase ref count before dropping dialsLk
	actd.incref()

124 125
	return actd
}
126

Steven Allen's avatar
Steven Allen committed
127 128
// DialLock initiates a dial to the given peer if there are none in progress
// then waits for the dial to that peer to complete.
129
func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) {
130
	return ds.getActiveDial(p).wait(ctx)
131
}
Steven Allen's avatar
Steven Allen committed
132 133 134 135 136 137 138 139 140

// CancelDial cancels the in-progress dial to p, if there is one. It does
// nothing when no dial to p is registered.
func (ds *DialSync) CancelDial(p peer.ID) {
	ds.dialsLk.Lock()
	defer ds.dialsLk.Unlock()

	ad, ok := ds.dials[p]
	if !ok {
		return
	}
	ad.cancel()
}