Commit 2ee7bf0a authored by vyzo's avatar vyzo

make dialWorker return an error for self dials and responsible for spawning the loop

parent de528f18
...@@ -13,7 +13,7 @@ import ( ...@@ -13,7 +13,7 @@ import (
var errDialCanceled = errors.New("dial was aborted internally, likely due to https://git.io/Je2wW") var errDialCanceled = errors.New("dial was aborted internally, likely due to https://git.io/Je2wW")
// DialFunc is the type of function expected by DialSync. // DialFunc is the type of function expected by DialSync.
type DialWorkerFunc func(context.Context, peer.ID, <-chan DialRequest) type DialWorkerFunc func(context.Context, peer.ID, <-chan DialRequest) error
// NewDialSync constructs a new DialSync // NewDialSync constructs a new DialSync
func NewDialSync(worker DialWorkerFunc) *DialSync { func NewDialSync(worker DialWorkerFunc) *DialSync {
...@@ -79,7 +79,7 @@ func (ad *activeDial) dial(ctx context.Context, p peer.ID) (*Conn, error) { ...@@ -79,7 +79,7 @@ func (ad *activeDial) dial(ctx context.Context, p peer.ID) (*Conn, error) {
} }
} }
func (ds *DialSync) getActiveDial(p peer.ID) *activeDial { func (ds *DialSync) getActiveDial(p peer.ID) (*activeDial, error) {
ds.dialsLk.Lock() ds.dialsLk.Lock()
defer ds.dialsLk.Unlock() defer ds.dialsLk.Unlock()
...@@ -99,20 +99,27 @@ func (ds *DialSync) getActiveDial(p peer.ID) *activeDial { ...@@ -99,20 +99,27 @@ func (ds *DialSync) getActiveDial(p peer.ID) *activeDial {
} }
ds.dials[p] = actd ds.dials[p] = actd
go ds.dialWorker(adctx, p, actd.reqch) err := ds.dialWorker(adctx, p, actd.reqch)
if err != nil {
cancel()
return nil, err
}
} }
// increase ref count before dropping dialsLk // increase ref count before dropping dialsLk
actd.refCnt++ actd.refCnt++
return actd return actd, nil
} }
// DialLock initiates a dial to the given peer if there are none in progress // DialLock initiates a dial to the given peer if there are none in progress
// then waits for the dial to that peer to complete. // then waits for the dial to that peer to complete.
func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) { func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) {
ad := ds.getActiveDial(p) ad, err := ds.getActiveDial(p)
defer ad.decref() if err != nil {
return nil, err
}
defer ad.decref()
return ad.dial(ctx, p) return ad.dial(ctx, p)
} }
...@@ -16,7 +16,7 @@ func getMockDialFunc() (DialWorkerFunc, func(), context.Context, <-chan struct{} ...@@ -16,7 +16,7 @@ func getMockDialFunc() (DialWorkerFunc, func(), context.Context, <-chan struct{}
dfcalls := make(chan struct{}, 512) // buffer it large enough that we won't care dfcalls := make(chan struct{}, 512) // buffer it large enough that we won't care
dialctx, cancel := context.WithCancel(context.Background()) dialctx, cancel := context.WithCancel(context.Background())
ch := make(chan struct{}) ch := make(chan struct{})
f := func(ctx context.Context, p peer.ID, reqch <-chan DialRequest) { f := func(ctx context.Context, p peer.ID, reqch <-chan DialRequest) error {
dfcalls <- struct{}{} dfcalls <- struct{}{}
go func() { go func() {
defer cancel() defer cancel()
...@@ -39,6 +39,7 @@ func getMockDialFunc() (DialWorkerFunc, func(), context.Context, <-chan struct{} ...@@ -39,6 +39,7 @@ func getMockDialFunc() (DialWorkerFunc, func(), context.Context, <-chan struct{}
} }
} }
}() }()
return nil
} }
o := new(sync.Once) o := new(sync.Once)
...@@ -188,25 +189,28 @@ func TestDialSyncAllCancel(t *testing.T) { ...@@ -188,25 +189,28 @@ func TestDialSyncAllCancel(t *testing.T) {
func TestFailFirst(t *testing.T) { func TestFailFirst(t *testing.T) {
var count int var count int
f := func(ctx context.Context, p peer.ID, reqch <-chan DialRequest) { f := func(ctx context.Context, p peer.ID, reqch <-chan DialRequest) error {
for { go func() {
select { for {
case req, ok := <-reqch: select {
if !ok { case req, ok := <-reqch:
return if !ok {
} return
}
if count > 0 { if count > 0 {
req.Resch <- DialResponse{Conn: new(Conn)} req.Resch <- DialResponse{Conn: new(Conn)}
} else { } else {
req.Resch <- DialResponse{Err: fmt.Errorf("gophers ate the modem")} req.Resch <- DialResponse{Err: fmt.Errorf("gophers ate the modem")}
} }
count++ count++
case <-ctx.Done(): case <-ctx.Done():
return return
}
} }
} }()
return nil
} }
ds := NewDialSync(f) ds := NewDialSync(f)
...@@ -232,19 +236,22 @@ func TestFailFirst(t *testing.T) { ...@@ -232,19 +236,22 @@ func TestFailFirst(t *testing.T) {
} }
func TestStressActiveDial(t *testing.T) { func TestStressActiveDial(t *testing.T) {
ds := NewDialSync(func(ctx context.Context, p peer.ID, reqch <-chan DialRequest) { ds := NewDialSync(func(ctx context.Context, p peer.ID, reqch <-chan DialRequest) error {
for { go func() {
select { for {
case req, ok := <-reqch: select {
if !ok { case req, ok := <-reqch:
if !ok {
return
}
req.Resch <- DialResponse{}
case <-ctx.Done():
return return
} }
req.Resch <- DialResponse{}
case <-ctx.Done():
return
} }
} }()
return nil
}) })
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
......
...@@ -307,21 +307,13 @@ type dialComplete struct { ...@@ -307,21 +307,13 @@ type dialComplete struct {
} }
// dialWorker is an active dial goroutine that synchronizes and executes concurrent dials // dialWorker is an active dial goroutine that synchronizes and executes concurrent dials
func (s *Swarm) dialWorker(ctx context.Context, p peer.ID, reqch <-chan DialRequest) { func (s *Swarm) dialWorker(ctx context.Context, p peer.ID, reqch <-chan DialRequest) error {
if p == s.local { if p == s.local {
for { return ErrDialToSelf
select {
case req, ok := <-reqch:
if !ok {
return
}
req.Resch <- DialResponse{Err: ErrDialToSelf}
}
}
} }
s.dialWorkerLoop(ctx, p, reqch) go s.dialWorkerLoop(ctx, p, reqch)
return nil
} }
func (s *Swarm) dialWorkerLoop(ctx context.Context, p peer.ID, reqch <-chan DialRequest) { func (s *Swarm) dialWorkerLoop(ctx context.Context, p peer.ID, reqch <-chan DialRequest) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment