Commit f03b69ca authored by Cole Brown's avatar Cole Brown

Add support for unix sockets

parent ef212b51
......@@ -3,6 +3,7 @@ package manet
import (
"fmt"
"net"
"path/filepath"
ma "github.com/multiformats/go-multiaddr"
madns "github.com/multiformats/go-multiaddr-dns"
......@@ -61,6 +62,8 @@ func parseBasicNetMaddr(maddr ma.Multiaddr) (net.Addr, error) {
return net.ResolveUDPAddr(network, host)
case "ip", "ip4", "ip6":
return net.ResolveIPAddr(network, host)
case "unix":
return net.ResolveUnixAddr(network, host)
}
return nil, fmt.Errorf("network not supported: %s", network)
......@@ -137,6 +140,10 @@ func DialArgs(m ma.Multiaddr) (string, string, error) {
hostname = true
ip = c.Value()
return true
case ma.P_UNIX:
network = "unix"
ip = c.Value()
return false
}
case "ip4":
switch c.Protocol().Code {
......@@ -184,6 +191,8 @@ func DialArgs(m ma.Multiaddr) (string, string, error) {
return network, ip + ":" + port, nil
}
return network, "[" + ip + "]" + ":" + port, nil
case "unix":
return network, ip, nil
default:
return "", "", fmt.Errorf("%s is not a 'thin waist' address", m)
}
......@@ -248,3 +257,12 @@ func parseIPPlusNetAddr(a net.Addr) (ma.Multiaddr, error) {
}
return FromIP(ac.IP)
}
func parseUnixNetAddr(a net.Addr) (ma.Multiaddr, error) {
ac, ok := a.(*net.UnixAddr)
if !ok {
return nil, errIncorrectNetAddr
}
cleaned := filepath.Clean(ac.Name)
return ma.NewComponent("unix", cleaned)
}
......@@ -167,7 +167,7 @@ func (d *Dialer) DialContext(ctx context.Context, remote ma.Multiaddr) (Conn, er
// ok, Dial!
var nconn net.Conn
switch rnet {
case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6", "unix":
nconn, err = d.Dialer.DialContext(ctx, rnet, rnaddr)
if err != nil {
return nil, err
......@@ -178,7 +178,9 @@ func (d *Dialer) DialContext(ctx context.Context, remote ma.Multiaddr) (Conn, er
// get local address (pre-specified or assigned within net.Conn)
local := d.LocalAddr
if local == nil {
// This block helps us avoid parsing addresses in transports (such as unix
// sockets) that don't have local addresses when dialing out.
if local == nil && nconn.LocalAddr().String() != "" {
local, err = FromNetAddr(nconn.LocalAddr())
if err != nil {
return nil, err
......@@ -243,9 +245,14 @@ func (l *maListener) Accept() (Conn, error) {
return nil, err
}
raddr, err := FromNetAddr(nconn.RemoteAddr())
if err != nil {
return nil, fmt.Errorf("failed to convert connn.RemoteAddr: %s", err)
var raddr ma.Multiaddr
// This block protects us in transports (i.e. unix sockets) that don't have
// remote addresses for inbound connections.
if nconn.RemoteAddr().String() != "" {
raddr, err = FromNetAddr(nconn.RemoteAddr())
if err != nil {
return nil, fmt.Errorf("failed to convert conn.RemoteAddr: %s", err)
}
}
return wrap(nconn, l.laddr, raddr), nil
......
......@@ -3,9 +3,13 @@ package manet
import (
"bytes"
"fmt"
"io/ioutil"
"net"
"os"
"path/filepath"
"sync"
"testing"
"time"
ma "github.com/multiformats/go-multiaddr"
)
......@@ -75,6 +79,62 @@ func TestDial(t *testing.T) {
wg.Wait()
}
func TestUnixSockets(t *testing.T) {
dir, err := ioutil.TempDir(os.TempDir(), "manettest")
if err != nil {
t.Fatal(err)
}
path := filepath.Join(dir, "listen.sock")
maddr := newMultiaddr(t, "/unix/"+path)
listener, err := Listen(maddr)
if err != nil {
t.Fatal(err)
}
payload := []byte("hello")
// listen
done := make(chan struct{}, 1)
go func() {
conn, err := listener.Accept()
if err != nil {
t.Fatal(err)
}
defer conn.Close()
buf := make([]byte, 1024)
n, err := conn.Read(buf)
if err != nil {
t.Fatal(err)
}
if n != len(payload) {
t.Fatal("failed to read appropriate number of bytes")
}
if !bytes.Equal(buf[0:n], payload) {
t.Fatal("payload did not match")
}
done <- struct{}{}
}()
// dial
conn, err := Dial(maddr)
if err != nil {
t.Fatal(err)
}
n, err := conn.Write(payload)
if err != nil {
t.Fatal(err)
}
if n != len(payload) {
t.Fatal("failed to write appropriate number of bytes")
}
select {
case <-done:
case <-time.After(1 * time.Second):
t.Fatal("timed out waiting for read")
}
}
func TestListen(t *testing.T) {
maddr := newMultiaddr(t, "/ip4/127.0.0.1/tcp/4322")
......
......@@ -21,8 +21,9 @@ func init() {
defaultCodecs.RegisterFromNetAddr(parseUDPNetAddr, "udp", "udp4", "udp6")
defaultCodecs.RegisterFromNetAddr(parseIPNetAddr, "ip", "ip4", "ip6")
defaultCodecs.RegisterFromNetAddr(parseIPPlusNetAddr, "ip+net")
defaultCodecs.RegisterFromNetAddr(parseUnixNetAddr, "unix")
defaultCodecs.RegisterToNetAddr(parseBasicNetMaddr, "tcp", "udp", "ip6", "ip4")
defaultCodecs.RegisterToNetAddr(parseBasicNetMaddr, "tcp", "udp", "ip6", "ip4", "unix")
}
// CodecMap holds a map of NetCodecs indexed by their Protocol ID
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment