Commit 8437d27d authored by Juan Batiz-Benet's avatar Juan Batiz-Benet

fixed udp (was totally broken)

parent e3b33eb2
......@@ -7,6 +7,7 @@ import (
"os"
"strconv"
"syscall"
"time"
sockaddrnet "github.com/jbenet/go-sockaddr/net"
)
......@@ -17,9 +18,57 @@ const (
filePrefix = "port."
)
// Wrapper around the socket system call that marks the returned file
// descriptor as nonblocking and close-on-exec.
func socket(family, socktype, protocol int) (fd int, err error) {
syscall.ForkLock.RLock()
fd, err = syscall.Socket(family, socktype, protocol)
if err == nil {
syscall.CloseOnExec(fd)
}
syscall.ForkLock.RUnlock()
if err != nil {
return -1, err
}
// set non-blocking until after connect, because we cant poll using runtime :(
// if err = syscall.SetNonblock(fd, true); err != nil {
// syscall.Close(fd)
// return -1, err
// }
if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, soReuseAddr, 1); err != nil {
// fmt.Println("reuse addr failed")
syscall.Close(fd)
return -1, err
}
if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, soReusePort, 1); err != nil {
// fmt.Println("reuse port failed")
syscall.Close(fd)
return -1, err
}
// set setLinger to 5 as reusing exact same (srcip:srcport, dstip:dstport)
// will otherwise fail on connect.
if err = setLinger(fd, 5); err != nil {
// fmt.Println("linger failed")
syscall.Close(fd)
return -1, err
}
return fd, nil
}
func dial(dialer net.Dialer, netw, addr string) (c net.Conn, err error) {
var (
family, fd int
fd int
lfamily int
rfamily int
socktype int
lprotocol int
rprotocol int
file *os.File
remoteSockaddr syscall.Sockaddr
localSockaddr syscall.Sockaddr
......@@ -37,114 +86,271 @@ func dial(dialer net.Dialer, netw, addr string) (c net.Conn, err error) {
return nil, ErrUnsupportedProtocol
}
switch dialer.LocalAddr.(type) {
case *net.TCPAddr, *net.UDPAddr:
default:
return nil, ErrUnsupportedProtocol
}
family = sockaddrnet.NetAddrAF(netAddr)
localSockaddr = sockaddrnet.NetAddrToSockaddr(dialer.LocalAddr)
remoteSockaddr = sockaddrnet.NetAddrToSockaddr(netAddr)
if fd, err = syscall.Socket(family, syscall.SOCK_STREAM, syscall.IPPROTO_TCP); err != nil {
// fmt.Println("tcp socket failed")
return nil, err
}
rfamily = sockaddrnet.NetAddrAF(netAddr)
rprotocol = sockaddrnet.NetAddrIPPROTO(netAddr)
socktype = sockaddrnet.NetAddrSOCK(netAddr)
if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, soReuseAddr, 1); err != nil {
// fmt.Println("reuse addr failed")
return nil, err
}
if dialer.LocalAddr != nil {
switch dialer.LocalAddr.(type) {
case *net.TCPAddr, *net.UDPAddr:
default:
return nil, ErrUnsupportedProtocol
}
if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, soReusePort, 1); err != nil {
// fmt.Println("reuse port failed")
return nil, err
// check family and protocols match.
lfamily = sockaddrnet.NetAddrAF(dialer.LocalAddr)
lprotocol = sockaddrnet.NetAddrIPPROTO(dialer.LocalAddr)
if lfamily != rfamily && lprotocol != rfamily {
return nil, &net.AddrError{Err: "unexpected address type", Addr: netAddr.String()}
}
}
if localSockaddr != nil {
// try to connect in a loop for EADDRINUSE errors.
start := time.Now()
for {
if fd, err = socket(rfamily, socktype, rprotocol); err != nil {
return nil, err
}
if err = syscall.Bind(fd, localSockaddr); err != nil {
// fmt.Println("bind failed")
syscall.Close(fd)
return nil, err
}
if err = connect(fd, remoteSockaddr); err != nil {
syscall.Close(fd)
if err == syscall.EADDRINUSE {
// if we've waited longer than 2 seconds bail.
if time.Now().Sub(start) > 2*time.Second {
return nil, err
}
// otherwise, wait a bit and try again
<-time.After(20 * time.Microsecond)
continue
}
// fmt.Println("connect failed", localSockaddr, err)
return nil, err
}
break
}
// Set backlog size to the maximum
if err = syscall.Connect(fd, remoteSockaddr); err != nil {
// fmt.Println("connect failed")
return nil, err
if rprotocol == syscall.IPPROTO_TCP {
// by default golang/net sets TCP no delay to true.
if err = setNoDelay(fd, true); err != nil {
// fmt.Println("set no delay failed")
syscall.Close(fd)
return nil, err
}
}
// File Name get be nil
file = os.NewFile(uintptr(fd), filePrefix+strconv.Itoa(os.Getpid()))
if c, err = net.FileConn(file); err != nil {
if err = syscall.SetNonblock(fd, true); err != nil {
syscall.Close(fd)
return nil, err
}
switch socktype {
case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET:
// File Name get be nil
file = os.NewFile(uintptr(fd), filePrefix+strconv.Itoa(os.Getpid()))
if c, err = net.FileConn(file); err != nil {
// fmt.Println("fileconn failed")
syscall.Close(fd)
return nil, err
}
case syscall.SOCK_DGRAM:
// File Name get be nil
file = os.NewFile(uintptr(fd), filePrefix+strconv.Itoa(os.Getpid()))
if c, err = net.FileConn(file); err != nil {
// fmt.Println("fileconn failed")
syscall.Close(fd)
return nil, err
}
}
if err = file.Close(); err != nil {
// fmt.Println("file close failed")
syscall.Close(fd)
return nil, err
}
return c, err
}
func listen(netw, addr string) (l net.Listener, err error) {
func listen(netw, addr string) (fd int, err error) {
var (
family, fd int
file *os.File
sockaddr syscall.Sockaddr
family int
socktype int
protocol int
sockaddr syscall.Sockaddr
)
netAddr, err := ResolveAddr(netw, addr)
if err != nil {
// fmt.Println("resolve addr failed")
return nil, err
return -1, err
}
switch netAddr.(type) {
case *net.TCPAddr, *net.UDPAddr:
default:
return nil, ErrUnsupportedProtocol
return -1, ErrUnsupportedProtocol
}
family = sockaddrnet.NetAddrAF(netAddr)
protocol = sockaddrnet.NetAddrIPPROTO(netAddr)
sockaddr = sockaddrnet.NetAddrToSockaddr(netAddr)
socktype = sockaddrnet.NetAddrSOCK(netAddr)
if fd, err = syscall.Socket(family, syscall.SOCK_STREAM, syscall.IPPROTO_TCP); err != nil {
// fmt.Println("socket failed")
return nil, err
if fd, err = socket(family, socktype, protocol); err != nil {
return -1, err
}
if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, soReusePort, 1); err != nil {
// fmt.Println("setsockopt reuseport failed")
return nil, err
if err = syscall.Bind(fd, sockaddr); err != nil {
// fmt.Println("bind failed")
syscall.Close(fd)
return -1, err
}
if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, soReuseAddr, 1); err != nil {
// fmt.Println("setsockopt reuseaddr failed")
return nil, err
if protocol == syscall.IPPROTO_TCP {
// by default golang/net sets TCP no delay to true.
if err = setNoDelay(fd, true); err != nil {
// fmt.Println("set no delay failed")
syscall.Close(fd)
return -1, err
}
}
if err = syscall.Bind(fd, sockaddr); err != nil {
// fmt.Println("bind failed")
if err = syscall.SetNonblock(fd, true); err != nil {
syscall.Close(fd)
return -1, err
}
return fd, nil
}
func listenStream(netw, addr string) (l net.Listener, err error) {
var (
file *os.File
)
fd, err := listen(netw, addr)
if err != nil {
return nil, err
}
// Set backlog size to the maximum
if err = syscall.Listen(fd, syscall.SOMAXCONN); err != nil {
// fmt.Println("listen failed")
syscall.Close(fd)
return nil, err
}
// File Name get be nil
file = os.NewFile(uintptr(fd), filePrefix+strconv.Itoa(os.Getpid()))
if l, err = net.FileListener(file); err != nil {
// fmt.Println("filelistener failed")
syscall.Close(fd)
return nil, err
}
if err = file.Close(); err != nil {
// fmt.Println("file close failed")
syscall.Close(fd)
return nil, err
}
return l, err
}
func listenPacket(netw, addr string) (p net.PacketConn, err error) {
var (
file *os.File
)
fd, err := listen(netw, addr)
if err != nil {
return nil, err
}
file = os.NewFile(uintptr(fd), filePrefix+strconv.Itoa(os.Getpid()))
if p, err = net.FilePacketConn(file); err != nil {
// fmt.Println("filelistener failed")
syscall.Close(fd)
return nil, err
}
if err = file.Close(); err != nil {
// fmt.Println("file close failed")
syscall.Close(fd)
return nil, err
}
return p, err
}
func listenUDP(netw, addr string) (c net.Conn, err error) {
var (
file *os.File
)
fd, err := listen(netw, addr)
if err != nil {
return nil, err
}
file = os.NewFile(uintptr(fd), filePrefix+strconv.Itoa(os.Getpid()))
if c, err = net.FileConn(file); err != nil {
// fmt.Println("filelistener failed")
syscall.Close(fd)
return nil, err
}
if err = file.Close(); err != nil {
// fmt.Println("file close failed")
syscall.Close(fd)
return nil, err
}
return c, err
}
func connect(fd int, ra syscall.Sockaddr) error {
switch err := syscall.Connect(fd, ra); err {
case syscall.EINPROGRESS, syscall.EALREADY, syscall.EINTR:
case nil, syscall.EISCONN:
return nil
default:
return err
}
var err error
for {
// if err := fd.pd.WaitWrite(); err != nil {
// return err
// }
// i'd use the above fd.pd.WaitWrite to poll io correctly, just like net sockets...
// but of course, it uses fucking runtime_* functions that _cannot_ be used by
// non-go-stdlib source... seriously guys, what kind of bullshit is that!?
<-time.After(20 * time.Microsecond)
var nerr int
nerr, err = syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_ERROR)
if err != nil {
return err
}
switch err = syscall.Errno(nerr); err {
case syscall.EINPROGRESS, syscall.EALREADY, syscall.EINTR:
case syscall.Errno(0), syscall.EISCONN:
return nil
default:
return err
}
}
}
......@@ -27,13 +27,20 @@ import (
var ErrUnsupportedProtocol = errors.New("protocol not yet supported")
// ErrReuseFailed is returned if a reuse attempt was unsuccessful.
var ErrReuseFailed = errors.New("protocol not yet supported")
var ErrReuseFailed = errors.New("reuse failed")
// Listen listens at the given network and address. see net.Listen
// Returns a net.Listener created from a file discriptor for a socket
// with SO_REUSEPORT and SO_REUSEADDR option set.
func Listen(network, address string) (net.Listener, error) {
return listen(network, address)
return listenStream(network, address)
}
// ListenPacket listens at the given network and address. see net.ListenPacket
// Returns a net.Listener created from a file discriptor for a socket
// with SO_REUSEPORT and SO_REUSEADDR option set.
func ListenPacket(network, address string) (net.PacketConn, error) {
return listenPacket(network, address)
}
// Dial dials the given network and address. see net.Dialer.Dial
......
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin dragonfly freebsd linux netbsd openbsd solaris windows
package reuseport
import (
"os"
"syscall"
)
func boolint(b bool) int {
if b {
return 1
}
return 0
}
func setNoDelay(fd int, noDelay bool) error {
return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd, syscall.IPPROTO_TCP, syscall.TCP_NODELAY, boolint(noDelay)))
}
func setLinger(fd int, sec int) error {
var l syscall.Linger
if sec >= 0 {
l.Onoff = 1
l.Linger = int32(sec)
} else {
l.Onoff = 0
l.Linger = 0
}
return os.NewSyscallError("setsockopt", syscall.SetsockoptLinger(fd, syscall.SOL_SOCKET, syscall.SO_LINGER, &l))
}
......@@ -14,6 +14,19 @@ func echo(c net.Conn) {
c.Close()
}
func packetEcho(c net.PacketConn) {
buf := make([]byte, 65536)
for {
n, addr, err := c.ReadFrom(buf)
if err != nil {
return
}
if _, err := c.WriteTo(buf[:n], addr); err != nil {
return
}
}
}
func acceptAndEcho(l net.Listener) {
for {
c, err := l.Accept()
......@@ -28,25 +41,18 @@ func CI() bool {
return os.Getenv("TRAVIS") == "true"
}
func TestListenSamePort(t *testing.T) {
func TestStreamListenSamePort(t *testing.T) {
// any ports
any := [][]string{
[]string{"tcp", "0.0.0.0:0", "0.0.0.0:0"},
[]string{"tcp4", "0.0.0.0:0", "0.0.0.0:0"},
[]string{"tcp6", "[::]:0", "[::]:0"},
[]string{"udp", "0.0.0.0:0", "0.0.0.0:0"},
[]string{"udp4", "0.0.0.0:0", "0.0.0.0:0"},
[]string{"udp6", "[::]:0", "[::]:0"},
[]string{"tcp", "0.0.0.0:0"},
[]string{"tcp4", "0.0.0.0:0"},
[]string{"tcp6", "[::]:0"},
[]string{"tcp", "127.0.0.1:0"},
[]string{"tcp", "[::1]:0"},
[]string{"tcp4", "127.0.0.1:0"},
[]string{"tcp6", "[::1]:0"},
[]string{"udp", "127.0.0.1:0"},
[]string{"udp", "[::1]:0"},
[]string{"udp4", "127.0.0.1:0"},
[]string{"udp6", "[::1]:0"},
}
// specific ports. off in CI
......@@ -55,10 +61,6 @@ func TestListenSamePort(t *testing.T) {
[]string{"tcp", "[::1]:5557"},
[]string{"tcp4", "127.0.0.1:5558"},
[]string{"tcp6", "[::1]:5559"},
[]string{"udp", "127.0.0.1:5560"},
[]string{"udp", "[::1]:5561"},
[]string{"udp4", "127.0.0.1:5562"},
[]string{"udp6", "[::1]:5563"},
}
testCases := any
......@@ -105,31 +107,91 @@ func TestListenSamePort(t *testing.T) {
}
}
func TestListenDialSamePort(t *testing.T) {
func TestPacketListenSamePort(t *testing.T) {
// any ports
any := [][]string{
[]string{"udp", "0.0.0.0:0"},
[]string{"udp4", "0.0.0.0:0"},
[]string{"udp6", "[::]:0"},
[]string{"udp", "127.0.0.1:0"},
[]string{"udp", "[::1]:0"},
[]string{"udp4", "127.0.0.1:0"},
[]string{"udp6", "[::1]:0"},
}
// specific ports. off in CI
specific := [][]string{
[]string{"udp", "127.0.0.1:5560"},
[]string{"udp", "[::1]:5561"},
[]string{"udp4", "127.0.0.1:5562"},
[]string{"udp6", "[::1]:5563"},
}
testCases := any
if !CI() {
testCases = append(testCases, specific...)
}
for _, tcase := range testCases {
network := tcase[0]
addr := tcase[1]
t.Log("testing", network, addr)
l1, err := ListenPacket(network, addr)
if err != nil {
t.Fatal(err)
continue
}
defer l1.Close()
t.Log("listening", l1.LocalAddr())
l2, err := ListenPacket(l1.LocalAddr().Network(), l1.LocalAddr().String())
if err != nil {
t.Fatal(err)
continue
}
defer l2.Close()
t.Log("listening", l2.LocalAddr())
l3, err := ListenPacket(l2.LocalAddr().Network(), l2.LocalAddr().String())
if err != nil {
t.Fatal(err)
continue
}
defer l3.Close()
t.Log("listening", l3.LocalAddr())
if l1.LocalAddr().String() != l2.LocalAddr().String() {
t.Fatal("addrs should match", l1.LocalAddr(), l2.LocalAddr())
}
if l1.LocalAddr().String() != l3.LocalAddr().String() {
t.Fatal("addrs should match", l1.LocalAddr(), l3.LocalAddr())
}
}
}
func TestStreamListenDialSamePort(t *testing.T) {
any := [][]string{
[]string{"tcp", "0.0.0.0:0", "0.0.0.0:0"},
[]string{"tcp4", "0.0.0.0:0", "0.0.0.0:0"},
[]string{"tcp6", "[::]:0", "[::]:0"},
[]string{"udp", "0.0.0.0:0", "0.0.0.0:0"},
[]string{"udp4", "0.0.0.0:0", "0.0.0.0:0"},
[]string{"udp6", "[::]:0", "[::]:0"},
[]string{"tcp", "127.0.0.1:0", "127.0.0.1:0"},
[]string{"tcp4", "127.0.0.1:0", "127.0.0.1:0"},
[]string{"tcp6", "[::1]:0", "[::1]:0"},
[]string{"udp", "127.0.0.1:0", "127.0.0.1:0"},
[]string{"udp4", "127.0.0.1:0", "127.0.0.1:0"},
[]string{"udp6", "[::1]:0", "[::1]:0"},
}
specific := [][]string{
[]string{"tcp", "127.0.0.1:5570", "127.0.0.1:5571"},
[]string{"tcp4", "127.0.0.1:5572", "127.0.0.1:5573"},
[]string{"tcp6", "[::1]:5573", "[::1]:5574"},
[]string{"udp", "127.0.0.1:5670", "127.0.0.1:5671"},
[]string{"udp4", "127.0.0.1:5672", "127.0.0.1:5673"},
[]string{"udp6", "[::1]:5673", "[::1]:5674"},
[]string{"tcp", "127.0.0.1:0", "127.0.0.1:5571"},
[]string{"tcp4", "127.0.0.1:0", "127.0.0.1:5573"},
[]string{"tcp6", "[::1]:0", "[::1]:5574"},
[]string{"tcp", "127.0.0.1:5570", "127.0.0.1:0"},
[]string{"tcp4", "127.0.0.1:5572", "127.0.0.1:0"},
[]string{"tcp6", "[::1]:5573", "[::1]:0"},
}
testCases := any
......@@ -168,7 +230,7 @@ func TestListenDialSamePort(t *testing.T) {
continue
}
defer c1.Close()
t.Log("dialed", c1.LocalAddr(), c1.RemoteAddr())
t.Log("dialed", c1, c1.LocalAddr(), c1.RemoteAddr())
if getPort(l1.Addr()) != getPort(c1.LocalAddr()) {
t.Fatal("addrs should match", l1.Addr(), c1.LocalAddr())
......@@ -190,6 +252,90 @@ func TestListenDialSamePort(t *testing.T) {
continue
}
if !bytes.Equal(hello1, hello2) {
t.Fatal("echo failed", string(hello1), "!=", string(hello2))
}
t.Log("echoed", string(hello2))
c1.Close()
}
}
func TestPacketListenDialSamePort(t *testing.T) {
any := [][]string{
[]string{"udp", "0.0.0.0:0", "0.0.0.0:0"},
[]string{"udp4", "0.0.0.0:0", "0.0.0.0:0"},
[]string{"udp6", "[::]:0", "[::]:0"},
[]string{"udp", "127.0.0.1:0", "127.0.0.1:0"},
[]string{"udp4", "127.0.0.1:0", "127.0.0.1:0"},
[]string{"udp6", "[::1]:0", "[::1]:0"},
}
specific := [][]string{
[]string{"udp", "127.0.0.1:5670", "127.0.0.1:5671"},
[]string{"udp4", "127.0.0.1:5672", "127.0.0.1:5673"},
[]string{"udp6", "[::1]:5673", "[::1]:5674"},
}
testCases := any
if !CI() {
testCases = append(testCases, specific...)
}
for _, tcase := range testCases {
t.Log("testing", tcase)
network := tcase[0]
addr1 := tcase[1]
addr2 := tcase[2]
l1, err := ListenPacket(network, addr1)
if err != nil {
t.Fatal(err)
continue
}
defer l1.Close()
t.Log("listening", l1.LocalAddr())
l2, err := ListenPacket(network, addr2)
if err != nil {
t.Fatal(err)
continue
}
defer l2.Close()
t.Log("listening", l2.LocalAddr())
go packetEcho(l1)
go packetEcho(l2)
c1, err := Dial(network, l1.LocalAddr().String(), l2.LocalAddr().String())
if err != nil {
t.Fatal(err)
continue
}
defer c1.Close()
t.Log("dialed", c1.LocalAddr(), c1.RemoteAddr())
if getPort(l1.LocalAddr()) != getPort(c1.LocalAddr()) {
t.Fatal("addrs should match", l1.LocalAddr(), c1.LocalAddr())
}
if getPort(l2.LocalAddr()) != getPort(c1.RemoteAddr()) {
t.Fatal("addrs should match", l2.LocalAddr(), c1.RemoteAddr())
}
hello1 := []byte("hello world")
hello2 := make([]byte, len(hello1))
if _, err := c1.Write(hello1); err != nil {
t.Fatal(err)
continue
}
if _, err := c1.Read(hello2); err != nil {
t.Fatal(err)
continue
}
if !bytes.Equal(hello1, hello2) {
t.Fatal("echo failed", string(hello1), "!=", string(hello2))
}
......@@ -217,5 +363,12 @@ func TestUnixNotSupported(t *testing.T) {
}
func getPort(a net.Addr) string {
return strings.Split(a.String(), ":")[1]
if a == nil {
return ""
}
s := strings.Split(a.String(), ":")
if len(s) > 1 {
return s[1]
}
return ""
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment