Commit 16248280 authored by vyzo's avatar vyzo

make DialRequest and DialResponse private

parent 27f6c394
...@@ -12,8 +12,8 @@ import ( ...@@ -12,8 +12,8 @@ import (
// TODO: change this text when we fix the bug // TODO: change this text when we fix the bug
var errDialCanceled = errors.New("dial was aborted internally, likely due to https://git.io/Je2wW") var errDialCanceled = errors.New("dial was aborted internally, likely due to https://git.io/Je2wW")
// DialFunc is the type of function expected by DialSync. // DialWorerFunc is used by DialSync to spawn a new dial worker
type DialWorkerFunc func(context.Context, peer.ID, <-chan DialRequest) error type DialWorkerFunc func(context.Context, peer.ID, <-chan dialRequest) error
// NewDialSync constructs a new DialSync // NewDialSync constructs a new DialSync
func NewDialSync(worker DialWorkerFunc) *DialSync { func NewDialSync(worker DialWorkerFunc) *DialSync {
...@@ -38,7 +38,7 @@ type activeDial struct { ...@@ -38,7 +38,7 @@ type activeDial struct {
ctx context.Context ctx context.Context
cancel func() cancel func()
reqch chan DialRequest reqch chan dialRequest
ds *DialSync ds *DialSync
} }
...@@ -64,16 +64,16 @@ func (ad *activeDial) dial(ctx context.Context, p peer.ID) (*Conn, error) { ...@@ -64,16 +64,16 @@ func (ad *activeDial) dial(ctx context.Context, p peer.ID) (*Conn, error) {
dialCtx = network.WithSimultaneousConnect(dialCtx, reason) dialCtx = network.WithSimultaneousConnect(dialCtx, reason)
} }
resch := make(chan DialResponse, 1) resch := make(chan dialResponse, 1)
select { select {
case ad.reqch <- DialRequest{Ctx: dialCtx, Resch: resch}: case ad.reqch <- dialRequest{ctx: dialCtx, resch: resch}:
case <-ctx.Done(): case <-ctx.Done():
return nil, ctx.Err() return nil, ctx.Err()
} }
select { select {
case res := <-resch: case res := <-resch:
return res.Conn, res.Err return res.conn, res.err
case <-ctx.Done(): case <-ctx.Done():
return nil, ctx.Err() return nil, ctx.Err()
} }
...@@ -94,7 +94,7 @@ func (ds *DialSync) getActiveDial(p peer.ID) (*activeDial, error) { ...@@ -94,7 +94,7 @@ func (ds *DialSync) getActiveDial(p peer.ID) (*activeDial, error) {
id: p, id: p,
ctx: adctx, ctx: adctx,
cancel: cancel, cancel: cancel,
reqch: make(chan DialRequest), reqch: make(chan dialRequest),
ds: ds, ds: ds,
} }
......
package swarm_test package swarm
import ( import (
"context" "context"
...@@ -7,8 +7,6 @@ import ( ...@@ -7,8 +7,6 @@ import (
"testing" "testing"
"time" "time"
. "github.com/libp2p/go-libp2p-swarm"
"github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peer"
) )
...@@ -16,7 +14,7 @@ func getMockDialFunc() (DialWorkerFunc, func(), context.Context, <-chan struct{} ...@@ -16,7 +14,7 @@ func getMockDialFunc() (DialWorkerFunc, func(), context.Context, <-chan struct{}
dfcalls := make(chan struct{}, 512) // buffer it large enough that we won't care dfcalls := make(chan struct{}, 512) // buffer it large enough that we won't care
dialctx, cancel := context.WithCancel(context.Background()) dialctx, cancel := context.WithCancel(context.Background())
ch := make(chan struct{}) ch := make(chan struct{})
f := func(ctx context.Context, p peer.ID, reqch <-chan DialRequest) error { f := func(ctx context.Context, p peer.ID, reqch <-chan dialRequest) error {
dfcalls <- struct{}{} dfcalls <- struct{}{}
go func() { go func() {
defer cancel() defer cancel()
...@@ -29,9 +27,9 @@ func getMockDialFunc() (DialWorkerFunc, func(), context.Context, <-chan struct{} ...@@ -29,9 +27,9 @@ func getMockDialFunc() (DialWorkerFunc, func(), context.Context, <-chan struct{}
select { select {
case <-ch: case <-ch:
req.Resch <- DialResponse{Conn: new(Conn)} req.resch <- dialResponse{conn: new(Conn)}
case <-ctx.Done(): case <-ctx.Done():
req.Resch <- DialResponse{Err: ctx.Err()} req.resch <- dialResponse{err: ctx.Err()}
return return
} }
case <-ctx.Done(): case <-ctx.Done():
...@@ -189,7 +187,7 @@ func TestDialSyncAllCancel(t *testing.T) { ...@@ -189,7 +187,7 @@ func TestDialSyncAllCancel(t *testing.T) {
func TestFailFirst(t *testing.T) { func TestFailFirst(t *testing.T) {
var count int var count int
f := func(ctx context.Context, p peer.ID, reqch <-chan DialRequest) error { f := func(ctx context.Context, p peer.ID, reqch <-chan dialRequest) error {
go func() { go func() {
for { for {
select { select {
...@@ -199,9 +197,9 @@ func TestFailFirst(t *testing.T) { ...@@ -199,9 +197,9 @@ func TestFailFirst(t *testing.T) {
} }
if count > 0 { if count > 0 {
req.Resch <- DialResponse{Conn: new(Conn)} req.resch <- dialResponse{conn: new(Conn)}
} else { } else {
req.Resch <- DialResponse{Err: fmt.Errorf("gophers ate the modem")} req.resch <- dialResponse{err: fmt.Errorf("gophers ate the modem")}
} }
count++ count++
...@@ -236,7 +234,7 @@ func TestFailFirst(t *testing.T) { ...@@ -236,7 +234,7 @@ func TestFailFirst(t *testing.T) {
} }
func TestStressActiveDial(t *testing.T) { func TestStressActiveDial(t *testing.T) {
ds := NewDialSync(func(ctx context.Context, p peer.ID, reqch <-chan DialRequest) error { ds := NewDialSync(func(ctx context.Context, p peer.ID, reqch <-chan dialRequest) error {
go func() { go func() {
for { for {
select { select {
...@@ -245,7 +243,7 @@ func TestStressActiveDial(t *testing.T) { ...@@ -245,7 +243,7 @@ func TestStressActiveDial(t *testing.T) {
return return
} }
req.Resch <- DialResponse{} req.resch <- dialResponse{}
case <-ctx.Done(): case <-ctx.Done():
return return
} }
......
...@@ -285,18 +285,18 @@ func (s *Swarm) dialPeer(ctx context.Context, p peer.ID) (*Conn, error) { ...@@ -285,18 +285,18 @@ func (s *Swarm) dialPeer(ctx context.Context, p peer.ID) (*Conn, error) {
// TODO explain how all this works // TODO explain how all this works
////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////
type DialRequest struct { type dialRequest struct {
Ctx context.Context ctx context.Context
Resch chan DialResponse resch chan dialResponse
} }
type DialResponse struct { type dialResponse struct {
Conn *Conn conn *Conn
Err error err error
} }
// dialWorker is an active dial goroutine that synchronizes and executes concurrent dials // dialWorker is an active dial goroutine that synchronizes and executes concurrent dials
func (s *Swarm) dialWorker(ctx context.Context, p peer.ID, reqch <-chan DialRequest) error { func (s *Swarm) dialWorker(ctx context.Context, p peer.ID, reqch <-chan dialRequest) error {
if p == s.local { if p == s.local {
return ErrDialToSelf return ErrDialToSelf
} }
...@@ -305,11 +305,11 @@ func (s *Swarm) dialWorker(ctx context.Context, p peer.ID, reqch <-chan DialRequ ...@@ -305,11 +305,11 @@ func (s *Swarm) dialWorker(ctx context.Context, p peer.ID, reqch <-chan DialRequ
return nil return nil
} }
func (s *Swarm) dialWorkerLoop(ctx context.Context, p peer.ID, reqch <-chan DialRequest) { func (s *Swarm) dialWorkerLoop(ctx context.Context, p peer.ID, reqch <-chan dialRequest) {
defer s.limiter.clearAllPeerDials(p) defer s.limiter.clearAllPeerDials(p)
type pendRequest struct { type pendRequest struct {
req DialRequest // the original request req dialRequest // the original request
err *DialError // dial error accumulator err *DialError // dial error accumulator
addrs map[ma.Multiaddr]struct{} // pending addr dials addrs map[ma.Multiaddr]struct{} // pending addr dials
} }
...@@ -344,11 +344,11 @@ func (s *Swarm) dialWorkerLoop(ctx context.Context, p peer.ID, reqch <-chan Dial ...@@ -344,11 +344,11 @@ func (s *Swarm) dialWorkerLoop(ctx context.Context, p peer.ID, reqch <-chan Dial
// all addrs have erred, dispatch dial error // all addrs have erred, dispatch dial error
// but first do a last one check in case an acceptable connection has landed from // but first do a last one check in case an acceptable connection has landed from
// a simultaneous dial that started later and added new acceptable addrs // a simultaneous dial that started later and added new acceptable addrs
c := s.bestAcceptableConnToPeer(pr.req.Ctx, p) c := s.bestAcceptableConnToPeer(pr.req.ctx, p)
if c != nil { if c != nil {
pr.req.Resch <- DialResponse{Conn: c} pr.req.resch <- dialResponse{conn: c}
} else { } else {
pr.req.Resch <- DialResponse{Err: pr.err} pr.req.resch <- dialResponse{err: pr.err}
} }
delete(requests, reqno) delete(requests, reqno)
} }
...@@ -390,15 +390,15 @@ loop: ...@@ -390,15 +390,15 @@ loop:
return return
} }
c := s.bestAcceptableConnToPeer(req.Ctx, p) c := s.bestAcceptableConnToPeer(req.ctx, p)
if c != nil { if c != nil {
req.Resch <- DialResponse{Conn: c} req.resch <- dialResponse{conn: c}
continue loop continue loop
} }
addrs, err := s.addrsForDial(req.Ctx, p) addrs, err := s.addrsForDial(req.ctx, p)
if err != nil { if err != nil {
req.Resch <- DialResponse{Err: err} req.resch <- dialResponse{err: err}
continue loop continue loop
} }
...@@ -430,7 +430,7 @@ loop: ...@@ -430,7 +430,7 @@ loop:
if ad.conn != nil { if ad.conn != nil {
// dial to this addr was successful, complete the request // dial to this addr was successful, complete the request
req.Resch <- DialResponse{Conn: ad.conn} req.resch <- dialResponse{conn: ad.conn}
continue loop continue loop
} }
...@@ -447,7 +447,7 @@ loop: ...@@ -447,7 +447,7 @@ loop:
if len(todial) == 0 && len(tojoin) == 0 { if len(todial) == 0 && len(tojoin) == 0 {
// all request applicable addrs have been dialed, we must have errored // all request applicable addrs have been dialed, we must have errored
req.Resch <- DialResponse{Err: pr.err} req.resch <- dialResponse{err: pr.err}
continue loop continue loop
} }
...@@ -457,14 +457,14 @@ loop: ...@@ -457,14 +457,14 @@ loop:
for _, ad := range tojoin { for _, ad := range tojoin {
if !ad.dialed { if !ad.dialed {
ad.ctx = s.mergeDialContexts(ad.ctx, req.Ctx) ad.ctx = s.mergeDialContexts(ad.ctx, req.ctx)
} }
ad.requests = append(ad.requests, reqno) ad.requests = append(ad.requests, reqno)
} }
if len(todial) > 0 { if len(todial) > 0 {
for _, a := range todial { for _, a := range todial {
pending[a] = &addrDial{addr: a, ctx: req.Ctx, requests: []int{reqno}} pending[a] = &addrDial{addr: a, ctx: req.ctx, requests: []int{reqno}}
} }
nextDial = append(nextDial, todial...) nextDial = append(nextDial, todial...)
...@@ -550,7 +550,7 @@ loop: ...@@ -550,7 +550,7 @@ loop:
continue continue
} }
pr.req.Resch <- DialResponse{Conn: conn} pr.req.resch <- dialResponse{conn: conn}
delete(requests, reqno) delete(requests, reqno)
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment