dial_sync.go 2.73 KB
Newer Older
1 2 3 4 5 6
package swarm

import (
	"context"
	"sync"

tavit ohanian's avatar
tavit ohanian committed
7 8
	"gitlab.dms3.io/p2p/go-p2p-core/network"
	"gitlab.dms3.io/p2p/go-p2p-core/peer"
9 10
)

11
// DialWorerFunc is used by DialSync to spawn a new dial worker
12
type dialWorkerFunc func(context.Context, peer.ID, <-chan dialRequest) error
13

14 15
// newDialSync constructs a new DialSync
func newDialSync(worker dialWorkerFunc) *DialSync {
16
	return &DialSync{
17 18
		dials:      make(map[peer.ID]*activeDial),
		dialWorker: worker,
19 20 21
	}
}

Steven Allen's avatar
Steven Allen committed
22 23
// DialSync is a dial synchronization helper that ensures that at most one dial
// to any given peer is active at any given time.
24
type DialSync struct {
25 26
	dials      map[peer.ID]*activeDial
	dialsLk    sync.Mutex
27
	dialWorker dialWorkerFunc
28 29 30
}

type activeDial struct {
31 32
	id     peer.ID
	refCnt int
33

34 35
	ctx    context.Context
	cancel func()
36

37
	reqch chan dialRequest
38

39
	ds *DialSync
40 41 42
}

func (ad *activeDial) decref() {
43
	ad.ds.dialsLk.Lock()
44
	ad.refCnt--
45 46 47 48
	if ad.refCnt == 0 {
		ad.cancel()
		close(ad.reqch)
		delete(ad.ds.dials, ad.id)
49
	}
50
	ad.ds.dialsLk.Unlock()
51 52
}

53 54 55 56 57 58 59 60 61 62
func (ad *activeDial) dial(ctx context.Context, p peer.ID) (*Conn, error) {
	dialCtx := ad.ctx

	if forceDirect, reason := network.GetForceDirectDial(ctx); forceDirect {
		dialCtx = network.WithForceDirectDial(dialCtx, reason)
	}
	if simConnect, reason := network.GetSimultaneousConnect(ctx); simConnect {
		dialCtx = network.WithSimultaneousConnect(dialCtx, reason)
	}

63
	resch := make(chan dialResponse, 1)
64
	select {
65
	case ad.reqch <- dialRequest{ctx: dialCtx, resch: resch}:
66 67 68 69 70 71
	case <-ctx.Done():
		return nil, ctx.Err()
	}

	select {
	case res := <-resch:
72
		return res.conn, res.err
73 74
	case <-ctx.Done():
		return nil, ctx.Err()
75
	}
76 77
}

78
func (ds *DialSync) getActiveDial(p peer.ID) (*activeDial, error) {
79
	ds.dialsLk.Lock()
80
	defer ds.dialsLk.Unlock()
81 82 83

	actd, ok := ds.dials[p]
	if !ok {
84 85 86 87 88
		// This code intentionally uses the background context. Otherwise, if the first call
		// to Dial is canceled, subsequent dial calls will also be canceled.
		// XXX: this also breaks direct connection logic. We will need to pipe the
		// information through some other way.
		adctx, cancel := context.WithCancel(context.Background())
89 90
		actd = &activeDial{
			id:     p,
91
			ctx:    adctx,
92
			cancel: cancel,
93
			reqch:  make(chan dialRequest),
94 95 96
			ds:     ds,
		}

97 98 99 100 101
		err := ds.dialWorker(adctx, p, actd.reqch)
		if err != nil {
			cancel()
			return nil, err
		}
102 103

		ds.dials[p] = actd
104 105
	}

106
	// increase ref count before dropping dialsLk
vyzo's avatar
vyzo committed
107
	actd.refCnt++
108

109
	return actd, nil
110
}
111

Steven Allen's avatar
Steven Allen committed
112 113
// DialLock initiates a dial to the given peer if there are none in progress
// then waits for the dial to that peer to complete.
114
func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) {
115 116 117 118
	ad, err := ds.getActiveDial(p)
	if err != nil {
		return nil, err
	}
Steven Allen's avatar
Steven Allen committed
119

120
	defer ad.decref()
121
	return ad.dial(ctx, p)
Steven Allen's avatar
Steven Allen committed
122
}