Commit 8d08e1e3 authored by Juan Batiz-Benet's avatar Juan Batiz-Benet

reuseport: respect dialer timeout

parent 786406bd
......@@ -160,7 +160,7 @@
},
{
"ImportPath": "github.com/jbenet/go-reuseport",
"Rev": "a2e454f12a99b8898c41f9dcebae6c35dc3efa3a"
"Rev": "6924153aded2d61c89a83c8f0738ed4e8df9191f"
},
{
"ImportPath": "github.com/jbenet/go-sockaddr/net",
......
......@@ -18,3 +18,15 @@ func ResolveAddr(network, address string) (net.Addr, error) {
return net.ResolveUnixAddr(network, address)
}
}
// conn is a struct that stores a raddr to get around:
// * https://github.com/golang/go/issues/9661#issuecomment-71043147
// * https://gist.github.com/jbenet/5c191d698fe9ec58c49d
type conn struct {
net.Conn
raddr net.Addr
}
func (c *conn) RemoteAddr() net.Addr {
return c.raddr
}
......@@ -9,6 +9,7 @@ import (
"syscall"
"time"
poll "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-reuseport/poll"
sockaddrnet "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-sockaddr/net"
)
......@@ -30,20 +31,18 @@ func socket(family, socktype, protocol int) (fd int, err error) {
return -1, err
}
// set non-blocking until after connect, because we cant poll using runtime :(
// cant set it until after connect
// if err = syscall.SetNonblock(fd, true); err != nil {
// syscall.Close(fd)
// return -1, err
// }
if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, soReuseAddr, 1); err != nil {
// fmt.Println("reuse addr failed")
syscall.Close(fd)
return -1, err
}
if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, soReusePort, 1); err != nil {
// fmt.Println("reuse port failed")
syscall.Close(fd)
return -1, err
}
......@@ -51,7 +50,6 @@ func socket(family, socktype, protocol int) (fd int, err error) {
// set setLinger to 5 as reusing exact same (srcip:srcport, dstip:dstport)
// will otherwise fail on connect.
if err = setLinger(fd, 5); err != nil {
// fmt.Println("linger failed")
syscall.Close(fd)
return -1, err
}
......@@ -68,13 +66,13 @@ func dial(dialer net.Dialer, netw, addr string) (c net.Conn, err error) {
lprotocol int
rprotocol int
file *os.File
deadline time.Time
remoteSockaddr syscall.Sockaddr
localSockaddr syscall.Sockaddr
)
netAddr, err := ResolveAddr(netw, addr)
if err != nil {
// fmt.Println("resolve addr failed")
return nil, err
}
......@@ -84,6 +82,13 @@ func dial(dialer net.Dialer, netw, addr string) (c net.Conn, err error) {
return nil, ErrUnsupportedProtocol
}
switch {
case !dialer.Deadline.IsZero():
deadline = dialer.Deadline
case dialer.Timeout != 0:
deadline = time.Now().Add(dialer.Timeout)
}
localSockaddr = sockaddrnet.NetAddrToSockaddr(dialer.LocalAddr)
remoteSockaddr = sockaddrnet.NetAddrToSockaddr(netAddr)
......@@ -109,18 +114,29 @@ func dial(dialer net.Dialer, netw, addr string) (c net.Conn, err error) {
// look at dialTCP in http://golang.org/src/net/tcpsock_posix.go .... !
// here we just try again 3 times.
for i := 0; i < 3; i++ {
if !deadline.IsZero() && deadline.Before(time.Now()) {
err = errTimeout
break
}
if fd, err = socket(rfamily, socktype, rprotocol); err != nil {
return nil, err
}
if err = syscall.Bind(fd, localSockaddr); err != nil {
// fmt.Println("bind failed")
if localSockaddr != nil {
if err = syscall.Bind(fd, localSockaddr); err != nil {
syscall.Close(fd)
return nil, err
}
}
if err = syscall.SetNonblock(fd, true); err != nil {
syscall.Close(fd)
return nil, err
}
if err = connect(fd, remoteSockaddr); err != nil {
if err = connect(fd, remoteSockaddr, deadline); err != nil {
syscall.Close(fd)
// fmt.Println("connect failed", localSockaddr, err)
continue // try again.
}
......@@ -133,48 +149,40 @@ func dial(dialer net.Dialer, netw, addr string) (c net.Conn, err error) {
if rprotocol == syscall.IPPROTO_TCP {
// by default golang/net sets TCP no delay to true.
if err = setNoDelay(fd, true); err != nil {
// fmt.Println("set no delay failed")
syscall.Close(fd)
return nil, err
}
}
if err = syscall.SetNonblock(fd, true); err != nil {
// File Name get be nil
file = os.NewFile(uintptr(fd), filePrefix+strconv.Itoa(os.Getpid()))
if c, err = net.FileConn(file); err != nil {
syscall.Close(fd)
return nil, err
}
switch socktype {
case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET:
// File Name get be nil
file = os.NewFile(uintptr(fd), filePrefix+strconv.Itoa(os.Getpid()))
if c, err = net.FileConn(file); err != nil {
// fmt.Println("fileconn failed")
syscall.Close(fd)
return nil, err
}
case syscall.SOCK_DGRAM:
// File Name get be nil
file = os.NewFile(uintptr(fd), filePrefix+strconv.Itoa(os.Getpid()))
if c, err = net.FileConn(file); err != nil {
// fmt.Println("fileconn failed")
syscall.Close(fd)
return nil, err
}
}
if err = file.Close(); err != nil {
// fmt.Println("file close failed")
syscall.Close(fd)
return nil, err
}
// c = wrapConnWithRemoteAddr(c, netAddr)
return c, err
}
// there's a rare case where dial returns successfully but for some reason the
// RemoteAddr is not yet set. So, since we know what raddr should be, we just
// wrap it. This is not ideal in that sometimes getpeername() may return a
// different addr. But until this is fixed, best way to do it.
// * https://gist.github.com/jbenet/5c191d698fe9ec58c49d
// * https://github.com/golang/go/issues/9661#issuecomment-71043147
func wrapConnWithRemoteAddr(c net.Conn, raddr net.Addr) net.Conn {
if c.RemoteAddr() == nil {
return &conn{Conn: c, raddr: raddr}
}
return c // it's fine, no need to wrap.
}
func listen(netw, addr string) (fd int, err error) {
var (
family int
......@@ -185,7 +193,6 @@ func listen(netw, addr string) (fd int, err error) {
netAddr, err := ResolveAddr(netw, addr)
if err != nil {
// fmt.Println("resolve addr failed")
return -1, err
}
......@@ -205,7 +212,6 @@ func listen(netw, addr string) (fd int, err error) {
}
if err = syscall.Bind(fd, sockaddr); err != nil {
// fmt.Println("bind failed")
syscall.Close(fd)
return -1, err
}
......@@ -213,7 +219,6 @@ func listen(netw, addr string) (fd int, err error) {
if protocol == syscall.IPPROTO_TCP {
// by default golang/net sets TCP no delay to true.
if err = setNoDelay(fd, true); err != nil {
// fmt.Println("set no delay failed")
syscall.Close(fd)
return -1, err
}
......@@ -239,20 +244,17 @@ func listenStream(netw, addr string) (l net.Listener, err error) {
// Set backlog size to the maximum
if err = syscall.Listen(fd, syscall.SOMAXCONN); err != nil {
// fmt.Println("listen failed")
syscall.Close(fd)
return nil, err
}
file = os.NewFile(uintptr(fd), filePrefix+strconv.Itoa(os.Getpid()))
if l, err = net.FileListener(file); err != nil {
// fmt.Println("filelistener failed")
syscall.Close(fd)
return nil, err
}
if err = file.Close(); err != nil {
// fmt.Println("file close failed")
syscall.Close(fd)
return nil, err
}
......@@ -272,13 +274,11 @@ func listenPacket(netw, addr string) (p net.PacketConn, err error) {
file = os.NewFile(uintptr(fd), filePrefix+strconv.Itoa(os.Getpid()))
if p, err = net.FilePacketConn(file); err != nil {
// fmt.Println("filelistener failed")
syscall.Close(fd)
return nil, err
}
if err = file.Close(); err != nil {
// fmt.Println("file close failed")
syscall.Close(fd)
return nil, err
}
......@@ -298,13 +298,11 @@ func listenUDP(netw, addr string) (c net.Conn, err error) {
file = os.NewFile(uintptr(fd), filePrefix+strconv.Itoa(os.Getpid()))
if c, err = net.FileConn(file); err != nil {
// fmt.Println("filelistener failed")
syscall.Close(fd)
return nil, err
}
if err = file.Close(); err != nil {
// fmt.Println("file close failed")
syscall.Close(fd)
return nil, err
}
......@@ -313,26 +311,36 @@ func listenUDP(netw, addr string) (c net.Conn, err error) {
}
// this is close to the connect() function inside stdlib/net
func connect(fd int, ra syscall.Sockaddr) error {
func connect(fd int, ra syscall.Sockaddr, deadline time.Time) error {
switch err := syscall.Connect(fd, ra); err {
case syscall.EINPROGRESS, syscall.EALREADY, syscall.EINTR:
case nil, syscall.EISCONN:
if !deadline.IsZero() && deadline.Before(time.Now()) {
return errTimeout
}
return nil
default:
return err
}
var err error
start := time.Now()
poller, err := poll.New(fd)
if err != nil {
return err
}
for {
if err = poller.WaitWrite(deadline); err != nil {
return err
}
// if err := fd.pd.WaitWrite(); err != nil {
// return err
// }
// i'd use the above fd.pd.WaitWrite to poll io correctly, just like net sockets...
// but of course, it uses fucking runtime_* functions that _cannot_ be used by
// non-go-stdlib source... seriously guys, what kind of bullshit is that!?
// but of course, it uses the damn runtime_* functions that _cannot_ be used by
// non-go-stdlib source... seriously guys, this is not nice.
// we're relegated to using syscall.Select (what nightmare that is) or using
// a simple but totally bogus time-based wait. garbage.
// a simple but totally bogus time-based wait. such garbage.
var nerr int
nerr, err = syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_ERROR)
if err != nil {
......@@ -340,14 +348,22 @@ func connect(fd int, ra syscall.Sockaddr) error {
}
switch err = syscall.Errno(nerr); err {
case syscall.EINPROGRESS, syscall.EALREADY, syscall.EINTR:
if time.Now().Sub(start) > time.Second {
return err
}
<-time.After(20 * time.Microsecond)
continue
case syscall.Errno(0), syscall.EISCONN:
if !deadline.IsZero() && deadline.Before(time.Now()) {
return errTimeout
}
return nil
default:
return err
}
}
}
var errTimeout = &timeoutError{}
type timeoutError struct{}
func (e *timeoutError) Error() string { return "i/o timeout" }
func (e *timeoutError) Timeout() bool { return true }
func (e *timeoutError) Temporary() bool { return true }
......@@ -91,21 +91,20 @@ type Dialer struct {
// Returns a net.Conn created from a file discriptor for a socket
// with SO_REUSEPORT and SO_REUSEADDR option set.
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
c, err := dial(d.D, network, address)
if err != nil {
return nil, err
if !available() {
return nil, syscall.Errno(syscall.ENOPROTOOPT)
}
// there's a rare case where dial returns successfully but for some reason the
// RemoteAddr is not yet set. We wait here a while until it is, and if too long
// passes, we fail. This is horrendous.
for start := time.Now(); c.RemoteAddr() == nil; {
if time.Now().Sub(start) > (time.Millisecond * 500) {
c.Close()
return nil, ErrReuseFailed
}
return dial(d.D, network, address)
}
<-time.After(20 * time.Microsecond)
func (d *Dialer) deadline(def time.Duration) time.Time {
switch {
case !d.D.Deadline.IsZero():
return d.D.Deadline
case d.D.Timeout != 0:
return time.Now().Add(d.D.Timeout)
default:
return time.Now().Add(def)
}
return c, nil
}
package poll
var errTimeout = &timeoutError{}
type timeoutError struct{}
func (e *timeoutError) Error() string { return "i/o timeout" }
func (e *timeoutError) Timeout() bool { return true }
func (e *timeoutError) Temporary() bool { return true }
// +build darwin freebsd dragonfly netbsd openbsd
package poll
import (
"syscall"
"time"
)
type Poller struct {
kqfd int
event syscall.Kevent_t
}
func New(fd int) (p *Poller, err error) {
p = &Poller{}
p.kqfd, err = syscall.Kqueue()
if p.kqfd == -1 || err != nil {
return nil, err
}
p.event = syscall.Kevent_t{
Ident: uint64(fd),
Filter: syscall.EVFILT_WRITE,
Flags: syscall.EV_ADD | syscall.EV_ENABLE | syscall.EV_ONESHOT,
Fflags: 0,
Data: 0,
Udata: nil,
}
return p, nil
}
func (p *Poller) Close() error {
return syscall.Close(p.kqfd)
}
func (p *Poller) WaitWrite(deadline time.Time) error {
// setup timeout
var timeout *syscall.Timespec
if !deadline.IsZero() {
d := deadline.Sub(time.Now())
t := syscall.NsecToTimespec(d.Nanoseconds())
timeout = &t
}
// wait on kevent
events := make([]syscall.Kevent_t, 1)
n, err := syscall.Kevent(p.kqfd, []syscall.Kevent_t{p.event}, events, timeout)
if err != nil {
return err
}
if n < 1 {
return errTimeout
}
return nil
}
// +build linux
package poll
import (
"syscall"
"time"
)
type Poller struct {
epfd int
event syscall.EpollEvent
events [32]syscall.EpollEvent
}
func New(fd int) (p *Poller, err error) {
p = &Poller{}
if p.epfd, err = syscall.EpollCreate1(0); err != nil {
return nil, err
}
p.event.Events = syscall.EPOLLOUT
p.event.Fd = int32(fd)
if err = syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_ADD, fd, &p.event); err != nil {
p.Close()
return nil, err
}
return p, nil
}
func (p *Poller) Close() error {
return syscall.Close(p.epfd)
}
func (p *Poller) WaitWrite(deadline time.Time) error {
msec := -1
if !deadline.IsZero() {
d := deadline.Sub(time.Now())
msec = int(d.Nanoseconds() / 1000000) // ms!? omg...
}
n, err := syscall.EpollWait(p.epfd, p.events[:], msec)
if err != nil {
return err
}
if n < 1 {
return errTimeout
}
return nil
}
// +build windows plan9
package poll
import (
"errors"
)
func WaitWrite(fd int) error {
return errors.New("platform not supported")
}
......@@ -2,11 +2,15 @@ package reuseport
import (
"bytes"
"errors"
"fmt"
"io"
"net"
"os"
"strings"
"sync"
"testing"
"time"
)
func echo(c net.Conn) {
......@@ -226,7 +230,7 @@ func TestStreamListenDialSamePort(t *testing.T) {
c1, err := Dial(network, l1.Addr().String(), l2.Addr().String())
if err != nil {
t.Fatal(err)
t.Fatal(err, network, l1.Addr().String(), l2.Addr().String())
continue
}
defer c1.Close()
......@@ -260,6 +264,120 @@ func TestStreamListenDialSamePort(t *testing.T) {
}
}
func TestStreamListenDialSamePortStressManyMsgs(t *testing.T) {
testCases := [][]string{
[]string{"tcp", "127.0.0.1:0"},
[]string{"tcp4", "127.0.0.1:0"},
[]string{"tcp6", "[::]:0"},
}
for _, tcase := range testCases {
subestStreamListenDialSamePortStress(t, tcase[0], tcase[1], 2, 1000)
}
}
func TestStreamListenDialSamePortStressManyNodes(t *testing.T) {
testCases := [][]string{
[]string{"tcp", "127.0.0.1:0"},
[]string{"tcp4", "127.0.0.1:0"},
[]string{"tcp6", "[::]:0"},
}
for _, tcase := range testCases {
subestStreamListenDialSamePortStress(t, tcase[0], tcase[1], 50, 1)
}
}
func TestStreamListenDialSamePortStressManyMsgsManyNodes(t *testing.T) {
testCases := [][]string{
[]string{"tcp", "127.0.0.1:0"},
[]string{"tcp4", "127.0.0.1:0"},
[]string{"tcp6", "[::]:0"},
}
for _, tcase := range testCases {
subestStreamListenDialSamePortStress(t, tcase[0], tcase[1], 50, 100)
}
}
func subestStreamListenDialSamePortStress(t *testing.T, network, addr string, nodes int, msgs int) {
t.Logf("testing %s:%s %d nodes %d msgs", network, addr, nodes, msgs)
var ls []net.Listener
for i := 0; i < nodes; i++ {
l, err := Listen(network, addr)
if err != nil {
t.Fatal(err)
}
defer l.Close()
go acceptAndEcho(l)
ls = append(ls, l)
t.Logf("listening %s", l.Addr())
}
// connect them all
var cs []net.Conn
for i := 0; i < nodes; i++ {
for j := 0; j < i; j++ {
if i == j {
continue // cannot do self.
}
ia := ls[i].Addr().String()
ja := ls[j].Addr().String()
c, err := Dial(network, ia, ja)
if err != nil {
t.Fatal(network, ia, ja, err)
}
defer c.Close()
cs = append(cs, c)
t.Logf("dialed %s --> %s", c.LocalAddr(), c.RemoteAddr())
}
}
errs := make(chan error)
send := func(c net.Conn, buf []byte) {
if _, err := c.Write(buf); err != nil {
errs <- err
}
}
recv := func(c net.Conn, buf []byte) {
buf2 := make([]byte, len(buf))
if _, err := c.Read(buf2); err != nil {
errs <- err
}
if !bytes.Equal(buf, buf2) {
errs <- fmt.Errorf("recv failure: %s <--> %s -- %s %s", c.RemoteAddr(), c.LocalAddr(), buf, buf2)
}
}
t.Logf("sending %d msgs per conn", msgs)
go func() {
var wg sync.WaitGroup
for _, c := range cs {
wg.Add(1)
go func(c net.Conn) {
defer wg.Done()
for i := 0; i < msgs; i++ {
msg := []byte(fmt.Sprintf("message %d", i))
send(c, msg)
recv(c, msg)
}
}(c)
}
wg.Wait()
close(errs)
}()
for err := range errs {
if err != nil {
t.Error(err)
}
}
}
func TestPacketListenDialSamePort(t *testing.T) {
any := [][]string{
......@@ -343,6 +461,68 @@ func TestPacketListenDialSamePort(t *testing.T) {
}
}
func TestDialRespectsTimeout(t *testing.T) {
testCases := [][]string{
[]string{"tcp", "127.0.0.1:6780", "1.2.3.4:6781"},
[]string{"tcp4", "127.0.0.1:6782", "1.2.3.4:6783"},
[]string{"tcp6", "[::1]:6784", "[::2]:6785"},
}
timeout := 50 * time.Millisecond
for _, tcase := range testCases {
network := tcase[0]
laddr := tcase[1]
raddr := tcase[2]
// l, err := Listen(network, raddr)
// if err != nil {
// t.Error("without a listener it wont work")
// continue
// }
// defer l.Close()
nladdr, err := ResolveAddr(network, laddr)
if err != nil {
t.Error("failed to resolve addr", network, laddr, err)
continue
}
t.Log("testing", network, nladdr, raddr)
d := Dialer{
D: net.Dialer{
LocalAddr: nil,
Timeout: timeout,
},
}
errs := make(chan error)
go func() {
c, err := d.Dial(network, raddr)
if err == nil {
errs <- errors.New("should've not connected")
c.Close()
return
}
close(errs) // success!
}()
ErrDrain:
select {
case <-time.After(5 * time.Second):
t.Fatal("took too long")
case err, more := <-errs:
if !more {
break
}
t.Error(err)
goto ErrDrain
}
}
}
func TestUnixNotSupported(t *testing.T) {
testCases := [][]string{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment