Commit 40e2379d authored by dignifiedquire's avatar dignifiedquire

refactor: simplify and split into multiple files

parent 08a7704d
package websocket
import (
"fmt"
"net"
"net/url"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr-net"
ws "golang.org/x/net/websocket"
)
func ConvertWebsocketMultiaddrToNetAddr(maddr ma.Multiaddr) (net.Addr, error) {
_, host, err := manet.DialArgs(maddr)
if err != nil {
return nil, err
}
a := &ws.Addr{
URL: &url.URL{
Host: host,
},
}
return a, nil
}
func ParseWebsocketNetAddr(a net.Addr) (ma.Multiaddr, error) {
wsa, ok := a.(*ws.Addr)
if !ok {
return nil, fmt.Errorf("not a websocket address")
}
tcpaddr, err := net.ResolveTCPAddr("tcp", wsa.Host)
if err != nil {
return nil, err
}
tcpma, err := manet.FromNetAddr(tcpaddr)
if err != nil {
return nil, err
}
wsma, err := ma.NewMultiaddr("/ws")
if err != nil {
return nil, err
}
return tcpma.Encapsulate(wsma), nil
}
func parseMultiaddr(a ma.Multiaddr) (string, error) {
_, host, err := manet.DialArgs(a)
if err != nil {
return "", err
}
return "ws://" + host, nil
}
package websocket
import (
"fmt"
"testing"
ma "github.com/multiformats/go-multiaddr"
)
func TestMultiaddrParsing(t *testing.T) {
addr, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/5555/ws")
if err != nil {
t.Fatal(err)
}
res, err := parseMultiaddr(addr)
if err != nil {
t.Fatal(err)
}
if res != "ws://127.0.0.1:5555" {
t.Fatal(fmt.Errorf("%s != ws://127.0.0.1:5555", res))
}
}
package websocket
import (
"net"
"time"
ws "github.com/gorilla/websocket"
)
var _ net.Conn = (*Conn)(nil)
// Conn implements net.Conn interface for gorilla/websocket.
type Conn struct {
*ws.Conn
DefaultMessageType int
done func()
}
func (c *Conn) Read(b []byte) (n int, err error) {
_, r, err := c.Conn.NextReader()
if err != nil {
return 0, err
}
return r.Read(b)
}
func (c *Conn) Write(b []byte) (n int, err error) {
if err := c.Conn.WriteMessage(c.DefaultMessageType, b); err != nil {
return 0, err
}
return len(b), nil
}
func (c *Conn) Close() error {
if c.done != nil {
c.done()
}
return c.Conn.Close()
}
func (c *Conn) LocalAddr() net.Addr {
return c.Conn.LocalAddr()
}
func (c *Conn) RemoteAddr() net.Addr {
return c.Conn.RemoteAddr()
}
func (c *Conn) SetDeadline(t time.Time) error {
if err := c.SetReadDeadline(t); err != nil {
return err
}
return c.SetWriteDeadline(t)
}
func (c *Conn) SetReadDeadline(t time.Time) error {
return c.Conn.SetReadDeadline(t)
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
return c.Conn.SetWriteDeadline(t)
}
// NewConn creates a Conn given a regular gorilla/websocket Conn.
func NewConn(raw *ws.Conn, done func()) *Conn {
return &Conn{
Conn: raw,
DefaultMessageType: ws.BinaryMessage,
done: done,
}
}
package websocket
import (
"context"
ws "github.com/gorilla/websocket"
tpt "github.com/libp2p/go-libp2p-transport"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr-net"
)
type dialer struct{}
func (d *dialer) Dial(raddr ma.Multiaddr) (tpt.Conn, error) {
return d.DialContext(context.Background(), raddr)
}
func (d *dialer) DialContext(ctx context.Context, raddr ma.Multiaddr) (tpt.Conn, error) {
wsurl, err := parseMultiaddr(raddr)
if err != nil {
return nil, err
}
wscon, _, err := ws.DefaultDialer.Dial(wsurl, nil)
if err != nil {
return nil, err
}
mnc, err := manet.WrapNetConn(NewConn(wscon, nil))
if err != nil {
return nil, err
}
return &wsConn{
Conn: mnc,
}, nil
}
func (d *dialer) Matches(a ma.Multiaddr) bool {
return WsFmt.Matches(a)
}
package websocket
import (
"context"
"fmt"
"net/http"
"net/url"
tpt "github.com/libp2p/go-libp2p-transport"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr-net"
)
type wsConn struct {
manet.Conn
t tpt.Transport
}
var _ tpt.Conn = (*wsConn)(nil)
func (c *wsConn) Transport() tpt.Transport {
return c.t
}
type listener struct {
manet.Listener
incoming chan *Conn
tpt tpt.Transport
origin *url.URL
}
func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
http.Error(w, "Failed to upgrade websocket", 400)
return
}
ctx, cancel := context.WithCancel(context.Background())
l.incoming <- NewConn(c, cancel)
// wait until conn gets closed, otherwise the handler closes it early
<-ctx.Done()
}
func (l *listener) Accept() (tpt.Conn, error) {
c, ok := <-l.incoming
if !ok {
return nil, fmt.Errorf("listener is closed")
}
mnc, err := manet.WrapNetConn(c)
if err != nil {
return nil, err
}
return &wsConn{
Conn: mnc,
t: l.tpt,
}, nil
}
func (l *listener) Multiaddr() ma.Multiaddr {
wsma, err := ma.NewMultiaddr("/ws")
if err != nil {
panic(err)
}
return l.Listener.Multiaddr().Encapsulate(wsma)
}
// Package websocket implements a websocket based transport for go-libp2p.
package websocket
import (
"context"
"fmt"
"net"
"net/http"
"net/url"
"time"
wsGorilla "github.com/gorilla/websocket"
ws "github.com/gorilla/websocket"
tpt "github.com/libp2p/go-libp2p-transport"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr-net"
mafmt "github.com/whyrusleeping/mafmt"
ws "golang.org/x/net/websocket"
)
// WsProtocol is the multiaddr protocol definition for this transport.
var WsProtocol = ma.Protocol{
Code: 477,
Name: "ws",
VCode: ma.CodeToVarint(477),
}
// WsFmt is multiaddr formatter for WsProtocol
var WsFmt = mafmt.And(mafmt.TCP, mafmt.Base(WsProtocol.Code))
// WsCodec is the multiaddr-net codec definition for the websocket transport
var WsCodec = &manet.NetCodec{
NetAddrNetworks: []string{"websocket"},
ProtocolName: "ws",
......@@ -32,7 +33,7 @@ var WsCodec = &manet.NetCodec{
}
// Default gorilla upgrader
var upgrader = wsGorilla.Upgrader{}
var upgrader = ws.Upgrader{}
func init() {
err := ma.AddProtocol(WsProtocol)
......@@ -43,46 +44,11 @@ func init() {
manet.RegisterNetCodec(WsCodec)
}
func ConvertWebsocketMultiaddrToNetAddr(maddr ma.Multiaddr) (net.Addr, error) {
_, host, err := manet.DialArgs(maddr)
if err != nil {
return nil, err
}
a := &ws.Addr{
URL: &url.URL{
Host: host,
},
}
return a, nil
}
func ParseWebsocketNetAddr(a net.Addr) (ma.Multiaddr, error) {
wsa, ok := a.(*ws.Addr)
if !ok {
return nil, fmt.Errorf("not a websocket address")
}
tcpaddr, err := net.ResolveTCPAddr("tcp", wsa.Host)
if err != nil {
return nil, err
}
tcpma, err := manet.FromNetAddr(tcpaddr)
if err != nil {
return nil, err
}
wsma, err := ma.NewMultiaddr("/ws")
if err != nil {
return nil, err
}
return tcpma.Encapsulate(wsma), nil
}
// WebsocketTransport is the actual go-libp2p transport
type WebsocketTransport struct{}
var _ tpt.Transport = (*WebsocketTransport)(nil)
func (t *WebsocketTransport) Matches(a ma.Multiaddr) bool {
return WsFmt.Matches(a)
}
......@@ -91,79 +57,6 @@ func (t *WebsocketTransport) Dialer(_ ma.Multiaddr, opts ...tpt.DialOpt) (tpt.Di
return &dialer{}, nil
}
type dialer struct{}
func parseMultiaddr(a ma.Multiaddr) (string, error) {
_, host, err := manet.DialArgs(a)
if err != nil {
return "", err
}
return "ws://" + host, nil
}
func (d *dialer) Dial(raddr ma.Multiaddr) (tpt.Conn, error) {
return d.DialContext(context.Background(), raddr)
}
func (d *dialer) DialContext(ctx context.Context, raddr ma.Multiaddr) (tpt.Conn, error) {
wsurl, err := parseMultiaddr(raddr)
if err != nil {
return nil, err
}
// TODO: figure out origins, probably don't work for us
// header := http.Header{}
// header.Set("Origin", "http://127.0.0.1:0/")
wscon, _, err := wsGorilla.DefaultDialer.Dial(wsurl, nil)
if err != nil {
return nil, err
}
mnc, err := manet.WrapNetConn(NewGorillaNetConn(wscon))
if err != nil {
return nil, err
}
return &wsConn{
Conn: mnc,
}, nil
}
func (d *dialer) Matches(a ma.Multiaddr) bool {
return WsFmt.Matches(a)
}
type wsConn struct {
manet.Conn
t tpt.Transport
}
func (c *wsConn) Transport() tpt.Transport {
return c.t
}
type listener struct {
manet.Listener
incoming chan *conn
tpt tpt.Transport
origin *url.URL
}
type conn struct {
*GorillaNetConn
done func()
}
func (c *conn) Close() error {
c.done()
return c.GorillaNetConn.Close()
}
func (t *WebsocketTransport) Listen(a ma.Multiaddr) (tpt.Listener, error) {
list, err := manet.Listen(a)
if err != nil {
......@@ -185,112 +78,8 @@ func (t *WebsocketTransport) Listen(a ma.Multiaddr) (tpt.Listener, error) {
func (t *WebsocketTransport) wrapListener(l manet.Listener, origin *url.URL) *listener {
return &listener{
Listener: l,
incoming: make(chan *conn),
incoming: make(chan *Conn),
tpt: t,
origin: origin,
}
}
func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
http.Error(w, "Failed to upgrade websocket", 400)
return
}
ctx, cancel := context.WithCancel(context.Background())
wrapped := NewGorillaNetConn(c)
l.incoming <- &conn{
GorillaNetConn: wrapped,
done: cancel,
}
// wait until conn gets closed, otherwise the handler closes it early
<-ctx.Done()
}
func (l *listener) Accept() (tpt.Conn, error) {
c, ok := <-l.incoming
if !ok {
return nil, fmt.Errorf("listener is closed")
}
mnc, err := manet.WrapNetConn(c)
if err != nil {
return nil, err
}
return &wsConn{
Conn: mnc,
t: l.tpt,
}, nil
}
func (l *listener) Multiaddr() ma.Multiaddr {
wsma, err := ma.NewMultiaddr("/ws")
if err != nil {
panic(err)
}
return l.Listener.Multiaddr().Encapsulate(wsma)
}
var _ tpt.Transport = (*WebsocketTransport)(nil)
type GorillaNetConn struct {
Inner *wsGorilla.Conn
DefaultMessageType int
}
func (c *GorillaNetConn) Read(b []byte) (n int, err error) {
_, r, err := c.Inner.NextReader()
if err != nil {
return 0, err
}
return r.Read(b)
}
func (c *GorillaNetConn) Write(b []byte) (n int, err error) {
if err := c.Inner.WriteMessage(c.DefaultMessageType, b); err != nil {
return 0, err
}
return len(b), nil
}
func (c *GorillaNetConn) Close() error {
return c.Inner.Close()
}
func (c *GorillaNetConn) LocalAddr() net.Addr {
return c.Inner.LocalAddr()
}
func (c *GorillaNetConn) RemoteAddr() net.Addr {
return c.Inner.RemoteAddr()
}
func (c *GorillaNetConn) SetDeadline(t time.Time) error {
if err := c.SetReadDeadline(t); err != nil {
return err
}
return c.SetWriteDeadline(t)
}
func (c *GorillaNetConn) SetReadDeadline(t time.Time) error {
return c.Inner.SetReadDeadline(t)
}
func (c *GorillaNetConn) SetWriteDeadline(t time.Time) error {
return c.Inner.SetWriteDeadline(t)
}
func NewGorillaNetConn(raw *wsGorilla.Conn) *GorillaNetConn {
return &GorillaNetConn{
Inner: raw,
DefaultMessageType: wsGorilla.BinaryMessage,
}
}
......@@ -7,18 +7,6 @@ import (
ma "github.com/multiformats/go-multiaddr"
)
func TestMultiaddrParsing(t *testing.T) {
addr, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/5555/ws")
if err != nil {
t.Fatal(err)
}
_, err = parseMultiaddr(addr)
if err != nil {
t.Fatal(err)
}
}
func TestWebsocketListen(t *testing.T) {
zero, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/0/ws")
if err != nil {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment