impl_unix.go 7.97 KB
Newer Older
1 2 3 4 5 6 7 8 9
// +build darwin freebsd dragonfly netbsd openbsd linux

package reuseport

import (
	"net"
	"os"
	"strconv"
	"syscall"
10
	"time"
11 12 13 14 15 16 17 18

	sockaddrnet "github.com/jbenet/go-sockaddr/net"
)

// filePrefix is the name prefix given to the *os.File wrappers that this
// package creates around raw socket descriptors; the current process ID is
// appended to it to form the full file name.
const filePrefix = "port."

19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
// Wrapper around the socket system call that marks the returned file
// descriptor as nonblocking and close-on-exec.
func socket(family, socktype, protocol int) (fd int, err error) {
	syscall.ForkLock.RLock()
	fd, err = syscall.Socket(family, socktype, protocol)
	if err == nil {
		syscall.CloseOnExec(fd)
	}
	syscall.ForkLock.RUnlock()

	if err != nil {
		return -1, err
	}

	// set non-blocking until after connect, because we cant poll using runtime :(
	// if err = syscall.SetNonblock(fd, true); err != nil {
	// 	syscall.Close(fd)
	// 	return -1, err
	// }

	if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, soReuseAddr, 1); err != nil {
		// fmt.Println("reuse addr failed")
		syscall.Close(fd)
		return -1, err
	}

	if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, soReusePort, 1); err != nil {
		// fmt.Println("reuse port failed")
		syscall.Close(fd)
		return -1, err
	}

	// set setLinger to 5 as reusing exact same (srcip:srcport, dstip:dstport)
	// will otherwise fail on connect.
	if err = setLinger(fd, 5); err != nil {
		// fmt.Println("linger failed")
		syscall.Close(fd)
		return -1, err
	}

	return fd, nil
}

62 63
func dial(dialer net.Dialer, netw, addr string) (c net.Conn, err error) {
	var (
64 65 66 67 68 69
		fd             int
		lfamily        int
		rfamily        int
		socktype       int
		lprotocol      int
		rprotocol      int
70 71 72 73 74
		file           *os.File
		remoteSockaddr syscall.Sockaddr
		localSockaddr  syscall.Sockaddr
	)

Juan Batiz-Benet's avatar
Juan Batiz-Benet committed
75
	netAddr, err := ResolveAddr(netw, addr)
76
	if err != nil {
Juan Batiz-Benet's avatar
Juan Batiz-Benet committed
77
		// fmt.Println("resolve addr failed")
78 79 80 81 82 83 84 85 86 87 88 89
		return nil, err
	}

	switch netAddr.(type) {
	case *net.TCPAddr, *net.UDPAddr:
	default:
		return nil, ErrUnsupportedProtocol
	}

	localSockaddr = sockaddrnet.NetAddrToSockaddr(dialer.LocalAddr)
	remoteSockaddr = sockaddrnet.NetAddrToSockaddr(netAddr)

90 91 92
	rfamily = sockaddrnet.NetAddrAF(netAddr)
	rprotocol = sockaddrnet.NetAddrIPPROTO(netAddr)
	socktype = sockaddrnet.NetAddrSOCK(netAddr)
93

94 95 96 97 98 99
	if dialer.LocalAddr != nil {
		switch dialer.LocalAddr.(type) {
		case *net.TCPAddr, *net.UDPAddr:
		default:
			return nil, ErrUnsupportedProtocol
		}
100

101 102 103 104 105 106
		// check family and protocols match.
		lfamily = sockaddrnet.NetAddrAF(dialer.LocalAddr)
		lprotocol = sockaddrnet.NetAddrIPPROTO(dialer.LocalAddr)
		if lfamily != rfamily && lprotocol != rfamily {
			return nil, &net.AddrError{Err: "unexpected address type", Addr: netAddr.String()}
		}
107 108
	}

Juan Batiz-Benet's avatar
Juan Batiz-Benet committed
109 110 111 112 113 114
	// look at dialTCP in http://golang.org/src/net/tcpsock_posix.go  .... !
	// here we just try again 3 times.
	for i := 0; i < 3; i++ {
		if fd, err = socket(rfamily, socktype, rprotocol); err != nil {
			return nil, err
		}
115

Juan Batiz-Benet's avatar
Juan Batiz-Benet committed
116 117 118 119 120 121 122 123 124 125 126 127
		if err = syscall.Bind(fd, localSockaddr); err != nil {
			// fmt.Println("bind failed")
			syscall.Close(fd)
			return nil, err
		}
		if err = connect(fd, remoteSockaddr); err != nil {
			syscall.Close(fd)
			// fmt.Println("connect failed", localSockaddr, err)
			continue // try again.
		}

		break
Juan Batiz-Benet's avatar
Juan Batiz-Benet committed
128
	}
Juan Batiz-Benet's avatar
Juan Batiz-Benet committed
129
	if err != nil {
Juan Batiz-Benet's avatar
Juan Batiz-Benet committed
130
		return nil, err
131 132
	}

133 134 135 136 137 138 139
	if rprotocol == syscall.IPPROTO_TCP {
		//  by default golang/net sets TCP no delay to true.
		if err = setNoDelay(fd, true); err != nil {
			// fmt.Println("set no delay failed")
			syscall.Close(fd)
			return nil, err
		}
140 141
	}

142 143
	if err = syscall.SetNonblock(fd, true); err != nil {
		syscall.Close(fd)
144 145 146
		return nil, err
	}

147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168
	switch socktype {
	case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET:

		// File Name get be nil
		file = os.NewFile(uintptr(fd), filePrefix+strconv.Itoa(os.Getpid()))
		if c, err = net.FileConn(file); err != nil {
			// fmt.Println("fileconn failed")
			syscall.Close(fd)
			return nil, err
		}

	case syscall.SOCK_DGRAM:

		// File Name get be nil
		file = os.NewFile(uintptr(fd), filePrefix+strconv.Itoa(os.Getpid()))
		if c, err = net.FileConn(file); err != nil {
			// fmt.Println("fileconn failed")
			syscall.Close(fd)
			return nil, err
		}
	}

169
	if err = file.Close(); err != nil {
170 171
		// fmt.Println("file close failed")
		syscall.Close(fd)
172 173 174 175 176 177
		return nil, err
	}

	return c, err
}

178
func listen(netw, addr string) (fd int, err error) {
179
	var (
180 181 182 183
		family   int
		socktype int
		protocol int
		sockaddr syscall.Sockaddr
184 185
	)

Juan Batiz-Benet's avatar
Juan Batiz-Benet committed
186
	netAddr, err := ResolveAddr(netw, addr)
187
	if err != nil {
Juan Batiz-Benet's avatar
Juan Batiz-Benet committed
188
		// fmt.Println("resolve addr failed")
189
		return -1, err
190 191 192 193 194
	}

	switch netAddr.(type) {
	case *net.TCPAddr, *net.UDPAddr:
	default:
195
		return -1, ErrUnsupportedProtocol
196 197 198
	}

	family = sockaddrnet.NetAddrAF(netAddr)
199
	protocol = sockaddrnet.NetAddrIPPROTO(netAddr)
200
	sockaddr = sockaddrnet.NetAddrToSockaddr(netAddr)
201
	socktype = sockaddrnet.NetAddrSOCK(netAddr)
202

203 204
	if fd, err = socket(family, socktype, protocol); err != nil {
		return -1, err
205 206
	}

207 208 209 210
	if err = syscall.Bind(fd, sockaddr); err != nil {
		// fmt.Println("bind failed")
		syscall.Close(fd)
		return -1, err
211 212
	}

213 214 215 216 217 218 219
	if protocol == syscall.IPPROTO_TCP {
		//  by default golang/net sets TCP no delay to true.
		if err = setNoDelay(fd, true); err != nil {
			// fmt.Println("set no delay failed")
			syscall.Close(fd)
			return -1, err
		}
220 221
	}

222 223 224 225 226 227 228 229 230 231 232 233 234 235 236
	if err = syscall.SetNonblock(fd, true); err != nil {
		syscall.Close(fd)
		return -1, err
	}

	return fd, nil
}

func listenStream(netw, addr string) (l net.Listener, err error) {
	var (
		file *os.File
	)

	fd, err := listen(netw, addr)
	if err != nil {
237 238 239 240 241
		return nil, err
	}

	// Set backlog size to the maximum
	if err = syscall.Listen(fd, syscall.SOMAXCONN); err != nil {
Juan Batiz-Benet's avatar
Juan Batiz-Benet committed
242
		// fmt.Println("listen failed")
243
		syscall.Close(fd)
244 245 246 247 248
		return nil, err
	}

	file = os.NewFile(uintptr(fd), filePrefix+strconv.Itoa(os.Getpid()))
	if l, err = net.FileListener(file); err != nil {
249 250
		// fmt.Println("filelistener failed")
		syscall.Close(fd)
251 252 253 254
		return nil, err
	}

	if err = file.Close(); err != nil {
255 256
		// fmt.Println("file close failed")
		syscall.Close(fd)
257 258 259 260 261
		return nil, err
	}

	return l, err
}
262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314

// listenPacket binds a socket to addr (see listen) and wraps it in a
// net.PacketConn via net.FilePacketConn.
//
// On any error every descriptor created along the way is released and
// (nil, err) is returned.
func listenPacket(netw, addr string) (p net.PacketConn, err error) {
	fd, err := listen(netw, addr)
	if err != nil {
		return nil, err
	}

	// Once wrapped, file owns fd, so fd is released only through file.Close.
	// A direct syscall.Close would risk a double close.
	// net.FilePacketConn works on a duplicate, so closing file leaves p usable.
	file := os.NewFile(uintptr(fd), filePrefix+strconv.Itoa(os.Getpid()))
	if p, err = net.FilePacketConn(file); err != nil {
		file.Close() // best effort; report the FilePacketConn error
		return nil, err
	}
	if err = file.Close(); err != nil {
		p.Close() // don't leak the duplicated descriptor held by p
		return nil, err
	}
	return p, nil
}

// listenUDP binds a socket to addr (see listen) and returns it as a
// net.Conn via net.FileConn. The socket is bound but not connected.
//
// On any error every descriptor created along the way is released and
// (nil, err) is returned.
func listenUDP(netw, addr string) (c net.Conn, err error) {
	fd, err := listen(netw, addr)
	if err != nil {
		return nil, err
	}

	// Once wrapped, file owns fd, so fd is released only through file.Close.
	// A direct syscall.Close would risk a double close.
	// net.FileConn works on a duplicate, so closing file leaves c usable.
	file := os.NewFile(uintptr(fd), filePrefix+strconv.Itoa(os.Getpid()))
	if c, err = net.FileConn(file); err != nil {
		file.Close() // best effort; report the FileConn error
		return nil, err
	}
	if err = file.Close(); err != nil {
		c.Close() // don't leak the duplicated descriptor held by c
		return nil, err
	}
	return c, nil
}

// connect connects fd to ra. It is modelled on the connect logic of the
// standard library's net package, but without the runtime poller, which code
// outside the standard library cannot use.
//
// If connect(2) reports that the connection is still being established
// (EINPROGRESS, EALREADY or EINTR), connect polls the socket state until one
// of three things happens:
//   - the socket is connected: nil is returned;
//   - a socket error is reported: that error is returned;
//   - about one second has elapsed: syscall.ETIMEDOUT is returned.
//
// EISCONN from connect(2) is treated as success.
func connect(fd int, ra syscall.Sockaddr) error {
	switch err := syscall.Connect(fd, ra); err {
	case syscall.EINPROGRESS, syscall.EALREADY, syscall.EINTR:
		// The connection is being established asynchronously; wait below.
	case nil, syscall.EISCONN:
		return nil
	default:
		return err
	}

	// Without the runtime poller we cannot block until fd becomes writable,
	// so poll its state at a short interval, bounded by pollTimeout.
	const (
		pollInterval = 20 * time.Microsecond
		pollTimeout  = time.Second
	)
	start := time.Now()
	for {
		nerr, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_ERROR)
		if err != nil {
			return err
		}
		switch serr := syscall.Errno(nerr); serr {
		case syscall.EISCONN:
			return nil
		case syscall.Errno(0):
			// SO_ERROR == 0 only means "no error pending"; the handshake
			// may still be in flight. As the standard library does, confirm
			// with getpeername, which succeeds only on a connected socket.
			if _, perr := syscall.Getpeername(fd); perr == nil {
				return nil
			}
		case syscall.EINPROGRESS, syscall.EALREADY, syscall.EINTR:
			// Still in progress; keep waiting.
		default:
			return serr
		}
		if time.Since(start) > pollTimeout {
			return syscall.ETIMEDOUT
		}
		time.Sleep(pollInterval)
	}
}