diff --git a/net/id.go b/net/id.go index 40cedd0345404ffe8166210a4947abf991b4fcdd..7ee08356c0fac6bb298c882c6efcde0d24909400 100644 --- a/net/id.go +++ b/net/id.go @@ -1,6 +1,8 @@ package net import ( + "sync" + handshake "github.com/jbenet/go-ipfs/net/handshake" pb "github.com/jbenet/go-ipfs/net/handshake/pb" @@ -18,14 +20,54 @@ import ( // * Our public Listen Addresses type IDService struct { Network Network + + // connections undergoing identification + // for wait purposes + currid map[Conn]chan struct{} + currmu sync.RWMutex } func NewIDService(n Network) *IDService { - s := &IDService{Network: n} + s := &IDService{ + Network: n, + currid: make(map[Conn]chan struct{}), + } n.SetHandler(ProtocolIdentify, s.RequestHandler) return s } +func (ids *IDService) IdentifyConn(c Conn) { + ids.currmu.Lock() + if _, found := ids.currid[c]; found { + ids.currmu.Unlock() + log.Debugf("IdentifyConn called twice on: %s", c) + return // already identifying it. + } + ids.currid[c] = make(chan struct{}) + ids.currmu.Unlock() + + s, err := c.NewStreamWithProtocol(ProtocolIdentify) + if err != nil { + log.Error("network: unable to open initial stream for %s", ProtocolIdentify) + log.Event(ids.Network.CtxGroup().Context(), "IdentifyOpenFailed", c.RemotePeer()) + } + + // ok give the response to our handler. + ids.ResponseHandler(s) + + ids.currmu.Lock() + ch, found := ids.currid[c] + delete(ids.currid, c) + ids.currmu.Unlock() + + if !found { + log.Errorf("IdentifyConn failed to find channel (programmer error) for %s", c) + return + } + + close(ch) // release everyone waiting. +} + func (ids *IDService) RequestHandler(s Stream) { defer s.Close() c := s.Conn() @@ -101,6 +143,7 @@ func (ids *IDService) consumeMessage(mes *pb.Handshake3, c Conn) { // update our peerstore with the addresses. ids.Network.Peerstore().AddAddresses(p, lmaddrs) + log.Debugf("%s received listen addrs for %s: %s", c.LocalPeer(), c.RemotePeer(), lmaddrs) // get protocol versions pv := *mes.H1.ProtocolVersion @@ -108,3 +151,23 @@ func (ids *IDService) consumeMessage(mes *pb.Handshake3, c Conn) { ids.Network.Peerstore().Put(p, "ProtocolVersion", pv) ids.Network.Peerstore().Put(p, "AgentVersion", av) } + +// IdentifyWait returns a channel which will be closed once +// "ProtocolIdentify" (handshake3) finishes on given conn. +// This happens async so the connection can start to be used +// even if handshake3 knowledge is not necesary. +// Users **MUST** call IdentifyWait _after_ IdentifyConn +func (ids *IDService) IdentifyWait(c Conn) <-chan struct{} { + ids.currmu.Lock() + ch, found := ids.currid[c] + ids.currmu.Unlock() + if found { + return ch + } + + // if not found, it means we are already done identifying it, or + // haven't even started. either way, return a new channel closed. + ch = make(chan struct{}) + close(ch) + return ch +} diff --git a/net/id_test.go b/net/id_test.go index 14deb319fc1110aff909fad1d7443fa26236696e..70ae10e4523584392e011f6be06cd30500a1ff08 100644 --- a/net/id_test.go +++ b/net/id_test.go @@ -32,7 +32,7 @@ func DivulgeAddresses(a, b inet.Network) { b.Peerstore().AddAddresses(id, addrs) } -func TestIDService(t *testing.T) { +func subtestIDService(t *testing.T, postDialWait time.Duration) { // the generated networks should have the id service wired in. ctx := context.Background() @@ -55,16 +55,26 @@ func TestIDService(t *testing.T) { t.Fatalf("Failed to dial:", err) } - // this is shitty. dial should wait for connecting to end - <-time.After(100 * time.Millisecond) + // we need to wait here if Dial returns before ID service is finished. + if postDialWait > 0 { + <-time.After(postDialWait) + } // the IDService should be opened automatically, by the network. // what we should see now is that both peers know about each others listen addresses. testKnowsAddrs(t, n1, n2p, n2.Peerstore().Addresses(n2p)) // has them - testKnowsAddrs(t, n2, n1p, n1.Peerstore().Addresses(n1p)) // has them + testHasProtocolVersions(t, n1, n2p) + + // now, this wait we do have to do. it's the wait for the Listening side + // to be done identifying the connection. + c := n2.ConnsToPeer(n1.LocalPeer()) + if len(c) < 1 { + t.Fatal("should have connection by now at least.") + } + <-n2.IdentifyProtocol().IdentifyWait(c[0]) // and the protocol versions. - testHasProtocolVersions(t, n1, n2p) + testKnowsAddrs(t, n2, n1p, n1.Peerstore().Addresses(n1p)) // has them testHasProtocolVersions(t, n2, n1p) } @@ -82,18 +92,39 @@ func testKnowsAddrs(t *testing.T, n inet.Network, p peer.ID, expected []ma.Multi for _, addr := range expected { if _, found := have[addr.String()]; !found { t.Errorf("%s did not have addr for %s: %s", n.LocalPeer(), p, addr) - panic("ahhhhhhh") + // panic("ahhhhhhh") } } } func testHasProtocolVersions(t *testing.T, n inet.Network, p peer.ID) { v, err := n.Peerstore().Get(p, "ProtocolVersion") + if v == nil { + t.Error("no protocol version") + return + } if v.(string) != handshake.IpfsVersion.String() { - t.Fatal("protocol mismatch", err) + t.Error("protocol mismatch", err) } v, err = n.Peerstore().Get(p, "AgentVersion") if v.(string) != handshake.ClientVersion { - t.Fatal("agent version mismatch", err) + t.Error("agent version mismatch", err) + } +} + +// TestIDServiceWait gives the ID service 100ms to finish after dialing +// this is becasue it used to be concurrent. Now, Dial wait till the +// id service is done. +func TestIDServiceWait(t *testing.T) { + N := 3 + for i := 0; i < N; i++ { + subtestIDService(t, 100*time.Millisecond) + } +} + +func TestIDServiceNoWait(t *testing.T) { + N := 3 + for i := 0; i < N; i++ { + subtestIDService(t, 0) } } diff --git a/net/interface.go b/net/interface.go index 2f6d933e2d9bf7137d42c6368159e99e54d7bb2a..74354e5cd2a4da71a89d19aa4003b7cb07c3f5e4 100644 --- a/net/interface.go +++ b/net/interface.go @@ -88,6 +88,9 @@ type Network interface { // Conns returns the connections in this Netowrk Conns() []Conn + // ConnsToPeer returns the connections in this Netowrk for given peer. + ConnsToPeer(p peer.ID) []Conn + // BandwidthTotals returns the total number of bytes passed through // the network since it was instantiated BandwidthTotals() (uint64, uint64) @@ -102,6 +105,11 @@ type Network interface { // CtxGroup returns the network's contextGroup CtxGroup() ctxgroup.ContextGroup + + // IdentifyProtocol returns the instance of the object running the Identify + // Protocol. This is what runs the ifps handshake-- this should be removed + // if this abstracted out to its own package. + IdentifyProtocol() *IDService } // Dialer represents a service that can dial out to peers diff --git a/net/mock/mock_peernet.go b/net/mock/mock_peernet.go index 3190152f6663c415510758b74485c42bcfc0b877..ae10417e596ae1e16660828cfa3c2c938cfc4244 100644 --- a/net/mock/mock_peernet.go +++ b/net/mock/mock_peernet.go @@ -29,6 +29,7 @@ type peernet struct { // needed to implement inet.Network mux inet.Mux + ids *inet.IDService cg ctxgroup.ContextGroup sync.RWMutex @@ -61,6 +62,11 @@ func newPeernet(ctx context.Context, m *mocknet, k ic.PrivKey, } n.cg.SetTeardown(n.teardown) + + // setup a conn handler that immediately "asks the other side about them" + // this is ProtocolIdentify. + n.ids = inet.NewIDService(n) + return n, nil } @@ -158,6 +164,10 @@ func (pn *peernet) remoteOpenedConn(c *conn) { // addConn constructs and adds a connection // to given remote peer over given link func (pn *peernet) addConn(c *conn) { + + // run the Identify protocol/handshake. + pn.ids.IdentifyConn(c) + pn.Lock() cs, found := pn.connsByPeer[c.RemotePeer()] if !found { @@ -327,3 +337,7 @@ func (pn *peernet) NewStream(pr inet.ProtocolID, p peer.ID) (inet.Stream, error) func (pn *peernet) SetHandler(p inet.ProtocolID, h inet.StreamHandler) { pn.mux.SetHandler(p, h) } + +func (pn *peernet) IdentifyProtocol() *inet.IDService { + return pn.ids +} diff --git a/net/net.go b/net/net.go index f481ee53a253bbe9d6e45a986d4393bcad1bce88..9fda9944e9e05a92c545df41420fa3872edb77a9 100644 --- a/net/net.go +++ b/net/net.go @@ -129,21 +129,21 @@ func NewNetwork(ctx context.Context, listen []ma.Multiaddr, local peer.ID, func (n *network) newConnHandler(c *swarm.Conn) { cc := (*conn_)(c) - s, err := cc.NewStreamWithProtocol(ProtocolIdentify) - if err != nil { - log.Error("network: unable to open initial stream for %s", ProtocolIdentify) - log.Event(n.CtxGroup().Context(), "IdentifyOpenFailed", c.RemotePeer()) - } - - // ok give the response to our handler. - n.ids.ResponseHandler(s) + n.ids.IdentifyConn(cc) } // DialPeer attempts to establish a connection to a given peer. // Respects the context. func (n *network) DialPeer(ctx context.Context, p peer.ID) error { - _, err := n.swarm.Dial(ctx, p) - return err + sc, err := n.swarm.Dial(ctx, p) + if err != nil { + return err + } + + // identify the connection before returning. + n.ids.IdentifyConn((*conn_)(sc)) + log.Debugf("network for %s finished dialing %s", n.local, p) + return nil } func (n *network) Protocols() []ProtocolID { @@ -185,6 +185,16 @@ func (n *network) Conns() []Conn { return out } +// ConnsToPeer returns the connections in this Netowrk for given peer. +func (n *network) ConnsToPeer(p peer.ID) []Conn { + conns1 := n.swarm.ConnectionsToPeer(p) + out := make([]Conn, len(conns1)) + for i, c := range conns1 { + out[i] = (*conn_)(c) + } + return out +} + // ClosePeer connection to peer func (n *network) ClosePeer(p peer.ID) error { return n.swarm.CloseConnection(p) @@ -254,6 +264,10 @@ func (n *network) SetHandler(p ProtocolID, h StreamHandler) { n.mux.SetHandler(p, h) } +func (n *network) IdentifyProtocol() *IDService { + return n.ids +} + func WriteProtocolHeader(pr ProtocolID, s Stream) error { if pr != "" { // only write proper protocol headers if err := WriteLengthPrefix(s, string(pr)); err != nil {