Commit a109f8da authored by Steven Allen's avatar Steven Allen

expose methods from underlying connection types

This sucks but I can't think of a better way to do this. We really do want to
expose these features and doing so through type assertions is very go-like.
parent 8792ba0a
...@@ -28,8 +28,64 @@ type Conn interface { ...@@ -28,8 +28,64 @@ type Conn interface {
RemoteMultiaddr() ma.Multiaddr RemoteMultiaddr() ma.Multiaddr
} }
// WrapNetConn wraps a net.Conn object with a Multiaddr type halfOpen interface {
// friendly Conn. net.Conn
CloseRead() error
CloseWrite() error
}
func wrap(nconn net.Conn, laddr, raddr ma.Multiaddr) Conn {
endpts := maEndpoints{
laddr: laddr,
raddr: raddr,
}
// This sucks. However, it's the only way to reliably expose the
// underlying methods. This way, users that need access to, e.g.,
// CloseRead and CloseWrite, can do so via type assertions.
switch nconn := nconn.(type) {
case *net.TCPConn:
return &struct {
*net.TCPConn
maEndpoints
}{nconn, endpts}
case *net.UDPConn:
return &struct {
*net.UDPConn
maEndpoints
}{nconn, endpts}
case *net.IPConn:
return &struct {
*net.IPConn
maEndpoints
}{nconn, endpts}
case *net.UnixConn:
return &struct {
*net.UnixConn
maEndpoints
}{nconn, endpts}
case halfOpen:
return &struct {
halfOpen
maEndpoints
}{nconn, endpts}
default:
return &struct {
net.Conn
maEndpoints
}{nconn, endpts}
}
}
// WrapNetConn wraps a net.Conn object with a Multiaddr friendly Conn.
//
// This function does it's best to avoid "hiding" methods exposed by the wrapped
// type. Guarantees:
//
// * If the wrapped connection exposes the "half-open" closer methods
// (CloseWrite, CloseRead), these will be available on the wrapped connection
// via type assertions.
// * If the wrapped connection is a UnixConn, IPConn, TCPConn, or UDPConn, all
// methods on these wrapped connections will be available via type assertions.
func WrapNetConn(nconn net.Conn) (Conn, error) { func WrapNetConn(nconn net.Conn) (Conn, error) {
if nconn == nil { if nconn == nil {
return nil, fmt.Errorf("failed to convert nconn.LocalAddr: nil") return nil, fmt.Errorf("failed to convert nconn.LocalAddr: nil")
...@@ -45,30 +101,23 @@ func WrapNetConn(nconn net.Conn) (Conn, error) { ...@@ -45,30 +101,23 @@ func WrapNetConn(nconn net.Conn) (Conn, error) {
return nil, fmt.Errorf("failed to convert nconn.RemoteAddr: %s", err) return nil, fmt.Errorf("failed to convert nconn.RemoteAddr: %s", err)
} }
return &maConn{ return wrap(nconn, laddr, raddr), nil
Conn: nconn,
laddr: laddr,
raddr: raddr,
}, nil
} }
// maConn implements the Conn interface. It's a thin wrapper type maEndpoints struct {
// around a net.Conn
type maConn struct {
net.Conn
laddr ma.Multiaddr laddr ma.Multiaddr
raddr ma.Multiaddr raddr ma.Multiaddr
} }
// LocalMultiaddr returns the local address associated with // LocalMultiaddr returns the local address associated with
// this connection // this connection
func (c *maConn) LocalMultiaddr() ma.Multiaddr { func (c *maEndpoints) LocalMultiaddr() ma.Multiaddr {
return c.laddr return c.laddr
} }
// RemoteMultiaddr returns the remote address associated with // RemoteMultiaddr returns the remote address associated with
// this connection // this connection
func (c *maConn) RemoteMultiaddr() ma.Multiaddr { func (c *maEndpoints) RemoteMultiaddr() ma.Multiaddr {
return c.raddr return c.raddr
} }
...@@ -135,12 +184,7 @@ func (d *Dialer) DialContext(ctx context.Context, remote ma.Multiaddr) (Conn, er ...@@ -135,12 +184,7 @@ func (d *Dialer) DialContext(ctx context.Context, remote ma.Multiaddr) (Conn, er
return nil, err return nil, err
} }
} }
return wrap(nconn, local, remote), nil
return &maConn{
Conn: nconn,
laddr: local,
raddr: remote,
}, nil
} }
// Dial connects to a remote address. It uses an underlying net.Conn, // Dial connects to a remote address. It uses an underlying net.Conn,
...@@ -204,11 +248,7 @@ func (l *maListener) Accept() (Conn, error) { ...@@ -204,11 +248,7 @@ func (l *maListener) Accept() (Conn, error) {
return nil, fmt.Errorf("failed to convert connn.RemoteAddr: %s", err) return nil, fmt.Errorf("failed to convert connn.RemoteAddr: %s", err)
} }
return &maConn{ return wrap(nconn, l.laddr, raddr), nil
Conn: nconn,
laddr: l.laddr,
raddr: raddr,
}, nil
} }
// Multiaddr returns the listener's (local) Multiaddr. // Multiaddr returns the listener's (local) Multiaddr.
......
...@@ -407,12 +407,14 @@ func TestWrapNetConn(t *testing.T) { ...@@ -407,12 +407,14 @@ func TestWrapNetConn(t *testing.T) {
defer wg.Done() defer wg.Done()
cB, err := listener.Accept() cB, err := listener.Accept()
checkErr(err, "failed to accept") checkErr(err, "failed to accept")
_ = cB.(halfOpen)
cB.Close() cB.Close()
}() }()
cA, err := net.Dial("tcp", listener.Addr().String()) cA, err := net.Dial("tcp", listener.Addr().String())
checkErr(err, "failed to dial") checkErr(err, "failed to dial")
defer cA.Close() defer cA.Close()
_ = cA.(halfOpen)
lmaddr, err := FromNetAddr(cA.LocalAddr()) lmaddr, err := FromNetAddr(cA.LocalAddr())
checkErr(err, "failed to get local addr") checkErr(err, "failed to get local addr")
...@@ -422,6 +424,8 @@ func TestWrapNetConn(t *testing.T) { ...@@ -422,6 +424,8 @@ func TestWrapNetConn(t *testing.T) {
mcA, err := WrapNetConn(cA) mcA, err := WrapNetConn(cA)
checkErr(err, "failed to wrap conn") checkErr(err, "failed to wrap conn")
_ = mcA.(halfOpen)
if mcA.LocalAddr().String() != cA.LocalAddr().String() { if mcA.LocalAddr().String() != cA.LocalAddr().String() {
t.Error("wrapped conn local addr differs") t.Error("wrapped conn local addr differs")
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment