Commit 2e742fb9 authored by vyzo's avatar vyzo

fix dial_sync tests

parent 201a8d14
...@@ -13,7 +13,7 @@ import ( ...@@ -13,7 +13,7 @@ import (
var errDialCanceled = errors.New("dial was aborted internally, likely due to https://git.io/Je2wW") var errDialCanceled = errors.New("dial was aborted internally, likely due to https://git.io/Je2wW")
// DialFunc is the type of function expected by DialSync. // DialFunc is the type of function expected by DialSync.
type DialWorkerFunc func(context.Context, peer.ID, <-chan dialRequest) type DialWorkerFunc func(context.Context, peer.ID, <-chan DialRequest)
// NewDialSync constructs a new DialSync // NewDialSync constructs a new DialSync
func NewDialSync(worker DialWorkerFunc) *DialSync { func NewDialSync(worker DialWorkerFunc) *DialSync {
...@@ -38,7 +38,7 @@ type activeDial struct { ...@@ -38,7 +38,7 @@ type activeDial struct {
ctx context.Context ctx context.Context
cancel func() cancel func()
reqch chan dialRequest reqch chan DialRequest
ds *DialSync ds *DialSync
} }
...@@ -68,16 +68,16 @@ func (ad *activeDial) dial(ctx context.Context, p peer.ID) (*Conn, error) { ...@@ -68,16 +68,16 @@ func (ad *activeDial) dial(ctx context.Context, p peer.ID) (*Conn, error) {
dialCtx = network.WithSimultaneousConnect(dialCtx, reason) dialCtx = network.WithSimultaneousConnect(dialCtx, reason)
} }
resch := make(chan dialResponse, 1) resch := make(chan DialResponse, 1)
select { select {
case ad.reqch <- dialRequest{ctx: dialCtx, resch: resch}: case ad.reqch <- DialRequest{Ctx: dialCtx, Resch: resch}:
case <-ctx.Done(): case <-ctx.Done():
return nil, ctx.Err() return nil, ctx.Err()
} }
select { select {
case res := <-resch: case res := <-resch:
return res.conn, res.err return res.Conn, res.Err
case <-ctx.Done(): case <-ctx.Done():
return nil, ctx.Err() return nil, ctx.Err()
} }
...@@ -98,7 +98,7 @@ func (ds *DialSync) getActiveDial(p peer.ID) *activeDial { ...@@ -98,7 +98,7 @@ func (ds *DialSync) getActiveDial(p peer.ID) *activeDial {
id: p, id: p,
ctx: adctx, ctx: adctx,
cancel: cancel, cancel: cancel,
reqch: make(chan dialRequest), reqch: make(chan DialRequest),
ds: ds, ds: ds,
} }
ds.dials[p] = actd ds.dials[p] = actd
......
...@@ -12,20 +12,34 @@ import ( ...@@ -12,20 +12,34 @@ import (
"github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peer"
) )
func getMockDialFunc() (DialFunc, func(), context.Context, <-chan struct{}) { func getMockDialFunc() (DialWorkerFunc, func(), context.Context, <-chan struct{}) {
dfcalls := make(chan struct{}, 512) // buffer it large enough that we won't care dfcalls := make(chan struct{}, 512) // buffer it large enough that we won't care
dialctx, cancel := context.WithCancel(context.Background()) dialctx, cancel := context.WithCancel(context.Background())
ch := make(chan struct{}) ch := make(chan struct{})
f := func(ctx context.Context, p peer.ID) (*Conn, error) { f := func(ctx context.Context, p peer.ID, reqch <-chan DialRequest) {
dfcalls <- struct{}{} dfcalls <- struct{}{}
go func() {
defer cancel() defer cancel()
for {
select {
case req, ok := <-reqch:
if !ok {
return
}
select { select {
case <-ch: case <-ch:
return new(Conn), nil req.Resch <- DialResponse{Conn: new(Conn)}
case <-ctx.Done():
req.Resch <- DialResponse{Err: ctx.Err()}
return
}
case <-ctx.Done(): case <-ctx.Done():
return nil, ctx.Err() return
} }
} }
}()
}
o := new(sync.Once) o := new(sync.Once)
...@@ -174,12 +188,25 @@ func TestDialSyncAllCancel(t *testing.T) { ...@@ -174,12 +188,25 @@ func TestDialSyncAllCancel(t *testing.T) {
func TestFailFirst(t *testing.T) { func TestFailFirst(t *testing.T) {
var count int var count int
f := func(ctx context.Context, p peer.ID) (*Conn, error) { f := func(ctx context.Context, p peer.ID, reqch <-chan DialRequest) {
for {
select {
case req, ok := <-reqch:
if !ok {
return
}
if count > 0 { if count > 0 {
return new(Conn), nil req.Resch <- DialResponse{Conn: new(Conn)}
} else {
req.Resch <- DialResponse{Err: fmt.Errorf("gophers ate the modem")}
} }
count++ count++
return nil, fmt.Errorf("gophers ate the modem")
case <-ctx.Done():
return
}
}
} }
ds := NewDialSync(f) ds := NewDialSync(f)
...@@ -205,8 +232,19 @@ func TestFailFirst(t *testing.T) { ...@@ -205,8 +232,19 @@ func TestFailFirst(t *testing.T) {
} }
func TestStressActiveDial(t *testing.T) { func TestStressActiveDial(t *testing.T) {
ds := NewDialSync(func(ctx context.Context, p peer.ID) (*Conn, error) { ds := NewDialSync(func(ctx context.Context, p peer.ID, reqch <-chan DialRequest) {
return nil, nil for {
select {
case req, ok := <-reqch:
if !ok {
return
}
req.Resch <- DialResponse{}
case <-ctx.Done():
return
}
}
}) })
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
......
...@@ -290,14 +290,14 @@ func (s *Swarm) dialPeer(ctx context.Context, p peer.ID) (*Conn, error) { ...@@ -290,14 +290,14 @@ func (s *Swarm) dialPeer(ctx context.Context, p peer.ID) (*Conn, error) {
// lo and behold, The Dialer // lo and behold, The Dialer
// TODO explain how all this works // TODO explain how all this works
////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////
type dialRequest struct { type DialRequest struct {
ctx context.Context Ctx context.Context
resch chan dialResponse Resch chan DialResponse
} }
type dialResponse struct { type DialResponse struct {
conn *Conn Conn *Conn
err error Err error
} }
type dialComplete struct { type dialComplete struct {
...@@ -307,7 +307,7 @@ type dialComplete struct { ...@@ -307,7 +307,7 @@ type dialComplete struct {
} }
// dialWorker is an active dial goroutine that synchronizes and executes concurrent dials // dialWorker is an active dial goroutine that synchronizes and executes concurrent dials
func (s *Swarm) dialWorker(ctx context.Context, p peer.ID, reqch <-chan dialRequest) { func (s *Swarm) dialWorker(ctx context.Context, p peer.ID, reqch <-chan DialRequest) {
if p == s.local { if p == s.local {
for { for {
select { select {
...@@ -316,7 +316,7 @@ func (s *Swarm) dialWorker(ctx context.Context, p peer.ID, reqch <-chan dialRequ ...@@ -316,7 +316,7 @@ func (s *Swarm) dialWorker(ctx context.Context, p peer.ID, reqch <-chan dialRequ
return return
} }
req.resch <- dialResponse{err: ErrDialToSelf} req.Resch <- DialResponse{Err: ErrDialToSelf}
} }
} }
} }
...@@ -324,11 +324,11 @@ func (s *Swarm) dialWorker(ctx context.Context, p peer.ID, reqch <-chan dialRequ ...@@ -324,11 +324,11 @@ func (s *Swarm) dialWorker(ctx context.Context, p peer.ID, reqch <-chan dialRequ
s.dialWorkerLoop(ctx, p, reqch) s.dialWorkerLoop(ctx, p, reqch)
} }
func (s *Swarm) dialWorkerLoop(ctx context.Context, p peer.ID, reqch <-chan dialRequest) { func (s *Swarm) dialWorkerLoop(ctx context.Context, p peer.ID, reqch <-chan DialRequest) {
defer s.limiter.clearAllPeerDials(p) defer s.limiter.clearAllPeerDials(p)
type pendRequest struct { type pendRequest struct {
req dialRequest // the original request req DialRequest // the original request
err *DialError // dial error accumulator err *DialError // dial error accumulator
addrs map[ma.Multiaddr]struct{} // pending addr dials addrs map[ma.Multiaddr]struct{} // pending addr dials
} }
...@@ -368,15 +368,15 @@ loop: ...@@ -368,15 +368,15 @@ loop:
return return
} }
c := s.bestAcceptableConnToPeer(req.ctx, p) c := s.bestAcceptableConnToPeer(req.Ctx, p)
if c != nil { if c != nil {
req.resch <- dialResponse{conn: c} req.Resch <- DialResponse{Conn: c}
continue loop continue loop
} }
addrs, err := s.addrsForDial(req.ctx, p) addrs, err := s.addrsForDial(req.Ctx, p)
if err != nil { if err != nil {
req.resch <- dialResponse{err: err} req.Resch <- DialResponse{Err: err}
continue loop continue loop
} }
...@@ -408,7 +408,7 @@ loop: ...@@ -408,7 +408,7 @@ loop:
if ad.conn != nil { if ad.conn != nil {
// dial to this addr was successful, complete the request // dial to this addr was successful, complete the request
req.resch <- dialResponse{conn: ad.conn} req.Resch <- DialResponse{Conn: ad.conn}
continue loop continue loop
} }
...@@ -424,7 +424,7 @@ loop: ...@@ -424,7 +424,7 @@ loop:
if len(todial) == 0 && len(tojoin) == 0 { if len(todial) == 0 && len(tojoin) == 0 {
// all request applicable addrs have been dialed, we must have errored // all request applicable addrs have been dialed, we must have errored
req.resch <- dialResponse{err: pr.err} req.Resch <- DialResponse{Err: pr.err}
continue loop continue loop
} }
...@@ -438,7 +438,7 @@ loop: ...@@ -438,7 +438,7 @@ loop:
if len(todial) > 0 { if len(todial) > 0 {
for _, a := range todial { for _, a := range todial {
pending[a] = &addrDial{ctx: req.ctx, requests: []int{reqno}} pending[a] = &addrDial{ctx: req.Ctx, requests: []int{reqno}}
} }
nextDial = append(nextDial, todial...) nextDial = append(nextDial, todial...)
...@@ -492,7 +492,7 @@ loop: ...@@ -492,7 +492,7 @@ loop:
continue continue
} }
pr.req.resch <- dialResponse{conn: res.conn} pr.req.Resch <- DialResponse{Conn: res.conn}
delete(requests, reqno) delete(requests, reqno)
} }
...@@ -513,7 +513,7 @@ loop: ...@@ -513,7 +513,7 @@ loop:
delete(pr.addrs, res.addr) delete(pr.addrs, res.addr)
if len(pr.addrs) == 0 { if len(pr.addrs) == 0 {
// all addrs have erred, dispatch dial error // all addrs have erred, dispatch dial error
pr.req.resch <- dialResponse{err: pr.err} pr.req.Resch <- DialResponse{Err: pr.err}
delete(requests, reqno) delete(requests, reqno)
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment