dial_sync.go 2.81 KB
Newer Older
1 2 3 4
package swarm

import (
	"context"
5
	"errors"
6 7
	"sync"

8
	"github.com/libp2p/go-libp2p-core/network"
9
	"github.com/libp2p/go-libp2p-core/peer"
10 11
)

12 13 14
// errDialCanceled is the error value whose message reports that a dial was
// aborted internally.
// NOTE(review): nothing in the visible part of this file returns it; confirm
// where it is used before changing or removing it.
// TODO: change this text when we fix the bug
var errDialCanceled = errors.New("dial was aborted internally, likely due to https://git.io/Je2wW")

Steven Allen's avatar
Steven Allen committed
15
// DialFunc is the type of function expected by DialSync.
vyzo's avatar
vyzo committed
16
type DialWorkerFunc func(context.Context, peer.ID, <-chan DialRequest)
17

Steven Allen's avatar
Steven Allen committed
18
// NewDialSync constructs a new DialSync
19
func NewDialSync(worker DialWorkerFunc) *DialSync {
20
	return &DialSync{
21 22
		dials:      make(map[peer.ID]*activeDial),
		dialWorker: worker,
23 24 25
	}
}

Steven Allen's avatar
Steven Allen committed
26 27
// DialSync is a dial synchronization helper that ensures that at most one dial
// to any given peer is active at any given time.
28
type DialSync struct {
29 30 31
	dials      map[peer.ID]*activeDial
	dialsLk    sync.Mutex
	dialWorker DialWorkerFunc
32 33 34
}

type activeDial struct {
35 36
	id     peer.ID
	refCnt int
37

38 39
	ctx    context.Context
	cancel func()
40

vyzo's avatar
vyzo committed
41
	reqch chan DialRequest
42

43
	ds *DialSync
44 45 46 47 48 49 50
}

// incref records one more user of this active dial. It takes no lock itself;
// getActiveDial calls it while still holding ds.dialsLk.
func (ad *activeDial) incref() {
	ad.refCnt += 1
}

func (ad *activeDial) decref() {
51
	ad.ds.dialsLk.Lock()
52
	ad.refCnt--
53 54 55 56
	if ad.refCnt == 0 {
		ad.cancel()
		close(ad.reqch)
		delete(ad.ds.dials, ad.id)
57
	}
58
	ad.ds.dialsLk.Unlock()
59 60
}

61 62 63 64 65 66 67 68 69 70
func (ad *activeDial) dial(ctx context.Context, p peer.ID) (*Conn, error) {
	dialCtx := ad.ctx

	if forceDirect, reason := network.GetForceDirectDial(ctx); forceDirect {
		dialCtx = network.WithForceDirectDial(dialCtx, reason)
	}
	if simConnect, reason := network.GetSimultaneousConnect(ctx); simConnect {
		dialCtx = network.WithSimultaneousConnect(dialCtx, reason)
	}

vyzo's avatar
vyzo committed
71
	resch := make(chan DialResponse, 1)
72
	select {
vyzo's avatar
vyzo committed
73
	case ad.reqch <- DialRequest{Ctx: dialCtx, Resch: resch}:
74 75 76 77 78 79
	case <-ctx.Done():
		return nil, ctx.Err()
	}

	select {
	case res := <-resch:
vyzo's avatar
vyzo committed
80
		return res.Conn, res.Err
81 82
	case <-ctx.Done():
		return nil, ctx.Err()
83
	}
84 85
}

86
func (ds *DialSync) getActiveDial(p peer.ID) *activeDial {
87
	ds.dialsLk.Lock()
88
	defer ds.dialsLk.Unlock()
89 90 91

	actd, ok := ds.dials[p]
	if !ok {
92 93 94 95 96
		// This code intentionally uses the background context. Otherwise, if the first call
		// to Dial is canceled, subsequent dial calls will also be canceled.
		// XXX: this also breaks direct connection logic. We will need to pipe the
		// information through some other way.
		adctx, cancel := context.WithCancel(context.Background())
97 98
		actd = &activeDial{
			id:     p,
99
			ctx:    adctx,
100
			cancel: cancel,
vyzo's avatar
vyzo committed
101
			reqch:  make(chan DialRequest),
102 103 104 105
			ds:     ds,
		}
		ds.dials[p] = actd

106
		go ds.dialWorker(adctx, p, actd.reqch)
107 108
	}

109 110 111
	// increase ref count before dropping dialsLk
	actd.incref()

112 113
	return actd
}
114

Steven Allen's avatar
Steven Allen committed
115 116
// DialLock initiates a dial to the given peer if there are none in progress
// then waits for the dial to that peer to complete.
117
func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) {
118 119
	ad := ds.getActiveDial(p)
	defer ad.decref()
Steven Allen's avatar
Steven Allen committed
120

121
	return ad.dial(ctx, p)
Steven Allen's avatar
Steven Allen committed
122
}