Commit 2e742fb9 authored by vyzo's avatar vyzo

fix dial_sync tests

parent 201a8d14
......@@ -13,7 +13,7 @@ import (
var errDialCanceled = errors.New("dial was aborted internally, likely due to https://git.io/Je2wW")
// DialFunc is the type of function expected by DialSync.
type DialWorkerFunc func(context.Context, peer.ID, <-chan dialRequest)
type DialWorkerFunc func(context.Context, peer.ID, <-chan DialRequest)
// NewDialSync constructs a new DialSync
func NewDialSync(worker DialWorkerFunc) *DialSync {
......@@ -38,7 +38,7 @@ type activeDial struct {
ctx context.Context
cancel func()
reqch chan dialRequest
reqch chan DialRequest
ds *DialSync
}
......@@ -68,16 +68,16 @@ func (ad *activeDial) dial(ctx context.Context, p peer.ID) (*Conn, error) {
dialCtx = network.WithSimultaneousConnect(dialCtx, reason)
}
resch := make(chan dialResponse, 1)
resch := make(chan DialResponse, 1)
select {
case ad.reqch <- dialRequest{ctx: dialCtx, resch: resch}:
case ad.reqch <- DialRequest{Ctx: dialCtx, Resch: resch}:
case <-ctx.Done():
return nil, ctx.Err()
}
select {
case res := <-resch:
return res.conn, res.err
return res.Conn, res.Err
case <-ctx.Done():
return nil, ctx.Err()
}
......@@ -98,7 +98,7 @@ func (ds *DialSync) getActiveDial(p peer.ID) *activeDial {
id: p,
ctx: adctx,
cancel: cancel,
reqch: make(chan dialRequest),
reqch: make(chan DialRequest),
ds: ds,
}
ds.dials[p] = actd
......
......@@ -12,20 +12,34 @@ import (
"github.com/libp2p/go-libp2p-core/peer"
)
func getMockDialFunc() (DialFunc, func(), context.Context, <-chan struct{}) {
func getMockDialFunc() (DialWorkerFunc, func(), context.Context, <-chan struct{}) {
dfcalls := make(chan struct{}, 512) // buffer it large enough that we won't care
dialctx, cancel := context.WithCancel(context.Background())
ch := make(chan struct{})
f := func(ctx context.Context, p peer.ID) (*Conn, error) {
f := func(ctx context.Context, p peer.ID, reqch <-chan DialRequest) {
dfcalls <- struct{}{}
go func() {
defer cancel()
for {
select {
case req, ok := <-reqch:
if !ok {
return
}
select {
case <-ch:
return new(Conn), nil
req.Resch <- DialResponse{Conn: new(Conn)}
case <-ctx.Done():
req.Resch <- DialResponse{Err: ctx.Err()}
return
}
case <-ctx.Done():
return nil, ctx.Err()
return
}
}
}()
}
o := new(sync.Once)
......@@ -174,12 +188,25 @@ func TestDialSyncAllCancel(t *testing.T) {
func TestFailFirst(t *testing.T) {
var count int
f := func(ctx context.Context, p peer.ID) (*Conn, error) {
f := func(ctx context.Context, p peer.ID, reqch <-chan DialRequest) {
for {
select {
case req, ok := <-reqch:
if !ok {
return
}
if count > 0 {
return new(Conn), nil
req.Resch <- DialResponse{Conn: new(Conn)}
} else {
req.Resch <- DialResponse{Err: fmt.Errorf("gophers ate the modem")}
}
count++
return nil, fmt.Errorf("gophers ate the modem")
case <-ctx.Done():
return
}
}
}
ds := NewDialSync(f)
......@@ -205,8 +232,19 @@ func TestFailFirst(t *testing.T) {
}
func TestStressActiveDial(t *testing.T) {
ds := NewDialSync(func(ctx context.Context, p peer.ID) (*Conn, error) {
return nil, nil
ds := NewDialSync(func(ctx context.Context, p peer.ID, reqch <-chan DialRequest) {
for {
select {
case req, ok := <-reqch:
if !ok {
return
}
req.Resch <- DialResponse{}
case <-ctx.Done():
return
}
}
})
wg := sync.WaitGroup{}
......
......@@ -290,14 +290,14 @@ func (s *Swarm) dialPeer(ctx context.Context, p peer.ID) (*Conn, error) {
// lo and behold, The Dialer
// TODO explain how all this works
//////////////////////////////////////////////////////////////////////////////////
type dialRequest struct {
ctx context.Context
resch chan dialResponse
type DialRequest struct {
Ctx context.Context
Resch chan DialResponse
}
type dialResponse struct {
conn *Conn
err error
type DialResponse struct {
Conn *Conn
Err error
}
type dialComplete struct {
......@@ -307,7 +307,7 @@ type dialComplete struct {
}
// dialWorker is an active dial goroutine that synchronizes and executes concurrent dials
func (s *Swarm) dialWorker(ctx context.Context, p peer.ID, reqch <-chan dialRequest) {
func (s *Swarm) dialWorker(ctx context.Context, p peer.ID, reqch <-chan DialRequest) {
if p == s.local {
for {
select {
......@@ -316,7 +316,7 @@ func (s *Swarm) dialWorker(ctx context.Context, p peer.ID, reqch <-chan dialRequ
return
}
req.resch <- dialResponse{err: ErrDialToSelf}
req.Resch <- DialResponse{Err: ErrDialToSelf}
}
}
}
......@@ -324,11 +324,11 @@ func (s *Swarm) dialWorker(ctx context.Context, p peer.ID, reqch <-chan dialRequ
s.dialWorkerLoop(ctx, p, reqch)
}
func (s *Swarm) dialWorkerLoop(ctx context.Context, p peer.ID, reqch <-chan dialRequest) {
func (s *Swarm) dialWorkerLoop(ctx context.Context, p peer.ID, reqch <-chan DialRequest) {
defer s.limiter.clearAllPeerDials(p)
type pendRequest struct {
req dialRequest // the original request
req DialRequest // the original request
err *DialError // dial error accumulator
addrs map[ma.Multiaddr]struct{} // pending addr dials
}
......@@ -368,15 +368,15 @@ loop:
return
}
c := s.bestAcceptableConnToPeer(req.ctx, p)
c := s.bestAcceptableConnToPeer(req.Ctx, p)
if c != nil {
req.resch <- dialResponse{conn: c}
req.Resch <- DialResponse{Conn: c}
continue loop
}
addrs, err := s.addrsForDial(req.ctx, p)
addrs, err := s.addrsForDial(req.Ctx, p)
if err != nil {
req.resch <- dialResponse{err: err}
req.Resch <- DialResponse{Err: err}
continue loop
}
......@@ -408,7 +408,7 @@ loop:
if ad.conn != nil {
// dial to this addr was successful, complete the request
req.resch <- dialResponse{conn: ad.conn}
req.Resch <- DialResponse{Conn: ad.conn}
continue loop
}
......@@ -424,7 +424,7 @@ loop:
if len(todial) == 0 && len(tojoin) == 0 {
// all request applicable addrs have been dialed, we must have errored
req.resch <- dialResponse{err: pr.err}
req.Resch <- DialResponse{Err: pr.err}
continue loop
}
......@@ -438,7 +438,7 @@ loop:
if len(todial) > 0 {
for _, a := range todial {
pending[a] = &addrDial{ctx: req.ctx, requests: []int{reqno}}
pending[a] = &addrDial{ctx: req.Ctx, requests: []int{reqno}}
}
nextDial = append(nextDial, todial...)
......@@ -492,7 +492,7 @@ loop:
continue
}
pr.req.resch <- dialResponse{conn: res.conn}
pr.req.Resch <- DialResponse{Conn: res.conn}
delete(requests, reqno)
}
......@@ -513,7 +513,7 @@ loop:
delete(pr.addrs, res.addr)
if len(pr.addrs) == 0 {
// all addrs have erred, dispatch dial error
pr.req.resch <- dialResponse{err: pr.err}
pr.req.Resch <- DialResponse{Err: pr.err}
delete(requests, reqno)
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment