dial_sync.go 2.88 KB
Newer Older
1 2 3 4
package swarm

import (
	"context"
5
	"errors"
6 7
	"sync"

8
	"github.com/libp2p/go-libp2p-core/network"
9
	"github.com/libp2p/go-libp2p-core/peer"
10 11
)

12 13 14
// errDialCanceled is a sentinel error reporting that a dial was aborted
// internally. Its message links to https://git.io/Je2wW, presumably the
// tracking issue for the bug mentioned in the TODO below. It is not referenced
// in the visible part of this file.
//
// TODO: change this text when we fix the bug
var errDialCanceled = errors.New("dial was aborted internally, likely due to https://git.io/Je2wW")

15 16
// DialWorkerFunc is used by DialSync to spawn a new dial worker.
//
// DialSync calls it once when it creates a new active dial for a peer, passing
// the context of that active dial, the target peer and the receive side of the
// channel on which dial requests are delivered. The call is made synchronously
// while the DialSync lock is held; a non-nil error aborts the creation of the
// active dial and is returned to the caller of DialLock. When the last
// reference to the active dial is released, the context is cancelled and the
// request channel is closed.
type DialWorkerFunc func(context.Context, peer.ID, <-chan dialRequest) error
17

Steven Allen's avatar
Steven Allen committed
18
// NewDialSync constructs a new DialSync
19
func NewDialSync(worker DialWorkerFunc) *DialSync {
20
	return &DialSync{
21 22
		dials:      make(map[peer.ID]*activeDial),
		dialWorker: worker,
23 24 25
	}
}

Steven Allen's avatar
Steven Allen committed
26 27
// DialSync is a dial synchronization helper that ensures that at most one dial
// to any given peer is active at any given time.
28
type DialSync struct {
29 30 31
	dials      map[peer.ID]*activeDial
	dialsLk    sync.Mutex
	dialWorker DialWorkerFunc
32 33 34
}

type activeDial struct {
35 36
	id     peer.ID
	refCnt int
37

38 39
	ctx    context.Context
	cancel func()
40

41
	reqch chan dialRequest
42

43
	ds *DialSync
44 45 46
}

func (ad *activeDial) decref() {
47
	ad.ds.dialsLk.Lock()
48
	ad.refCnt--
49 50 51 52
	if ad.refCnt == 0 {
		ad.cancel()
		close(ad.reqch)
		delete(ad.ds.dials, ad.id)
53
	}
54
	ad.ds.dialsLk.Unlock()
55 56
}

57 58 59 60 61 62 63 64 65 66
func (ad *activeDial) dial(ctx context.Context, p peer.ID) (*Conn, error) {
	dialCtx := ad.ctx

	if forceDirect, reason := network.GetForceDirectDial(ctx); forceDirect {
		dialCtx = network.WithForceDirectDial(dialCtx, reason)
	}
	if simConnect, reason := network.GetSimultaneousConnect(ctx); simConnect {
		dialCtx = network.WithSimultaneousConnect(dialCtx, reason)
	}

67
	resch := make(chan dialResponse, 1)
68
	select {
69
	case ad.reqch <- dialRequest{ctx: dialCtx, resch: resch}:
70 71 72 73 74 75
	case <-ctx.Done():
		return nil, ctx.Err()
	}

	select {
	case res := <-resch:
76
		return res.conn, res.err
77 78
	case <-ctx.Done():
		return nil, ctx.Err()
79
	}
80 81
}

82
func (ds *DialSync) getActiveDial(p peer.ID) (*activeDial, error) {
83
	ds.dialsLk.Lock()
84
	defer ds.dialsLk.Unlock()
85 86 87

	actd, ok := ds.dials[p]
	if !ok {
88 89 90 91 92
		// This code intentionally uses the background context. Otherwise, if the first call
		// to Dial is canceled, subsequent dial calls will also be canceled.
		// XXX: this also breaks direct connection logic. We will need to pipe the
		// information through some other way.
		adctx, cancel := context.WithCancel(context.Background())
93 94
		actd = &activeDial{
			id:     p,
95
			ctx:    adctx,
96
			cancel: cancel,
97
			reqch:  make(chan dialRequest),
98 99 100
			ds:     ds,
		}

101 102 103 104 105
		err := ds.dialWorker(adctx, p, actd.reqch)
		if err != nil {
			cancel()
			return nil, err
		}
106 107

		ds.dials[p] = actd
108 109
	}

110
	// increase ref count before dropping dialsLk
vyzo's avatar
vyzo committed
111
	actd.refCnt++
112

113
	return actd, nil
114
}
115

Steven Allen's avatar
Steven Allen committed
116 117
// DialLock initiates a dial to the given peer if there are none in progress
// then waits for the dial to that peer to complete.
118
func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) {
119 120 121 122
	ad, err := ds.getActiveDial(p)
	if err != nil {
		return nil, err
	}
Steven Allen's avatar
Steven Allen committed
123

124
	defer ad.decref()
125
	return ad.dial(ctx, p)
Steven Allen's avatar
Steven Allen committed
126
}