Commit 4fe1dd9b authored by Juan Batiz-Benet's avatar Juan Batiz-Benet

net: have an explicit IdentifyConn on dial

- Make sure we call IdentifyConn on dialed out conns
- we wait until the identify is **done** before return
- on listening case, we can also wait.
- tests now make sure dial does wait.
- tests now make sure we can wait on listening case.
parent d9961893
package net
import (
"sync"
handshake "github.com/jbenet/go-ipfs/net/handshake"
pb "github.com/jbenet/go-ipfs/net/handshake/pb"
......@@ -18,14 +20,54 @@ import (
// * Our public Listen Addresses
type IDService struct {
Network Network
// connections undergoing identification
// for wait purposes
currid map[Conn]chan struct{}
currmu sync.RWMutex
}
func NewIDService(n Network) *IDService {
s := &IDService{Network: n}
s := &IDService{
Network: n,
currid: make(map[Conn]chan struct{}),
}
n.SetHandler(ProtocolIdentify, s.RequestHandler)
return s
}
func (ids *IDService) IdentifyConn(c Conn) {
ids.currmu.Lock()
if _, found := ids.currid[c]; found {
ids.currmu.Unlock()
log.Debugf("IdentifyConn called twice on: %s", c)
return // already identifying it.
}
ids.currid[c] = make(chan struct{})
ids.currmu.Unlock()
s, err := c.NewStreamWithProtocol(ProtocolIdentify)
if err != nil {
log.Error("network: unable to open initial stream for %s", ProtocolIdentify)
log.Event(ids.Network.CtxGroup().Context(), "IdentifyOpenFailed", c.RemotePeer())
}
// ok give the response to our handler.
ids.ResponseHandler(s)
ids.currmu.Lock()
ch, found := ids.currid[c]
delete(ids.currid, c)
ids.currmu.Unlock()
if !found {
log.Errorf("IdentifyConn failed to find channel (programmer error) for %s", c)
return
}
close(ch) // release everyone waiting.
}
func (ids *IDService) RequestHandler(s Stream) {
defer s.Close()
c := s.Conn()
......@@ -101,6 +143,7 @@ func (ids *IDService) consumeMessage(mes *pb.Handshake3, c Conn) {
// update our peerstore with the addresses.
ids.Network.Peerstore().AddAddresses(p, lmaddrs)
log.Debugf("%s received listen addrs for %s: %s", c.LocalPeer(), c.RemotePeer(), lmaddrs)
// get protocol versions
pv := *mes.H1.ProtocolVersion
......@@ -108,3 +151,23 @@ func (ids *IDService) consumeMessage(mes *pb.Handshake3, c Conn) {
ids.Network.Peerstore().Put(p, "ProtocolVersion", pv)
ids.Network.Peerstore().Put(p, "AgentVersion", av)
}
// IdentifyWait returns a channel which will be closed once
// "ProtocolIdentify" (handshake3) finishes on given conn.
// This happens async so the connection can start to be used
// even if handshake3 knowledge is not necesary.
// Users **MUST** call IdentifyWait _after_ IdentifyConn
func (ids *IDService) IdentifyWait(c Conn) <-chan struct{} {
ids.currmu.Lock()
ch, found := ids.currid[c]
ids.currmu.Unlock()
if found {
return ch
}
// if not found, it means we are already done identifying it, or
// haven't even started. either way, return a new channel closed.
ch = make(chan struct{})
close(ch)
return ch
}
......@@ -32,7 +32,7 @@ func DivulgeAddresses(a, b inet.Network) {
b.Peerstore().AddAddresses(id, addrs)
}
func TestIDService(t *testing.T) {
func subtestIDService(t *testing.T, postDialWait time.Duration) {
// the generated networks should have the id service wired in.
ctx := context.Background()
......@@ -55,16 +55,26 @@ func TestIDService(t *testing.T) {
t.Fatalf("Failed to dial:", err)
}
// this is shitty. dial should wait for connecting to end
<-time.After(100 * time.Millisecond)
// we need to wait here if Dial returns before ID service is finished.
if postDialWait > 0 {
<-time.After(postDialWait)
}
// the IDService should be opened automatically, by the network.
// what we should see now is that both peers know about each others listen addresses.
testKnowsAddrs(t, n1, n2p, n2.Peerstore().Addresses(n2p)) // has them
testKnowsAddrs(t, n2, n1p, n1.Peerstore().Addresses(n1p)) // has them
testHasProtocolVersions(t, n1, n2p)
// now, this wait we do have to do. it's the wait for the Listening side
// to be done identifying the connection.
c := n2.ConnsToPeer(n1.LocalPeer())
if len(c) < 1 {
t.Fatal("should have connection by now at least.")
}
<-n2.IdentifyProtocol().IdentifyWait(c[0])
// and the protocol versions.
testHasProtocolVersions(t, n1, n2p)
testKnowsAddrs(t, n2, n1p, n1.Peerstore().Addresses(n1p)) // has them
testHasProtocolVersions(t, n2, n1p)
}
......@@ -82,18 +92,39 @@ func testKnowsAddrs(t *testing.T, n inet.Network, p peer.ID, expected []ma.Multi
for _, addr := range expected {
if _, found := have[addr.String()]; !found {
t.Errorf("%s did not have addr for %s: %s", n.LocalPeer(), p, addr)
panic("ahhhhhhh")
// panic("ahhhhhhh")
}
}
}
func testHasProtocolVersions(t *testing.T, n inet.Network, p peer.ID) {
v, err := n.Peerstore().Get(p, "ProtocolVersion")
if v == nil {
t.Error("no protocol version")
return
}
if v.(string) != handshake.IpfsVersion.String() {
t.Fatal("protocol mismatch", err)
t.Error("protocol mismatch", err)
}
v, err = n.Peerstore().Get(p, "AgentVersion")
if v.(string) != handshake.ClientVersion {
t.Fatal("agent version mismatch", err)
t.Error("agent version mismatch", err)
}
}
// TestIDServiceWait gives the ID service 100ms to finish after dialing
// this is becasue it used to be concurrent. Now, Dial wait till the
// id service is done.
func TestIDServiceWait(t *testing.T) {
N := 3
for i := 0; i < N; i++ {
subtestIDService(t, 100*time.Millisecond)
}
}
func TestIDServiceNoWait(t *testing.T) {
N := 3
for i := 0; i < N; i++ {
subtestIDService(t, 0)
}
}
......@@ -88,6 +88,9 @@ type Network interface {
// Conns returns the connections in this Netowrk
Conns() []Conn
// ConnsToPeer returns the connections in this Netowrk for given peer.
ConnsToPeer(p peer.ID) []Conn
// BandwidthTotals returns the total number of bytes passed through
// the network since it was instantiated
BandwidthTotals() (uint64, uint64)
......@@ -102,6 +105,11 @@ type Network interface {
// CtxGroup returns the network's contextGroup
CtxGroup() ctxgroup.ContextGroup
// IdentifyProtocol returns the instance of the object running the Identify
// Protocol. This is what runs the ifps handshake-- this should be removed
// if this abstracted out to its own package.
IdentifyProtocol() *IDService
}
// Dialer represents a service that can dial out to peers
......
......@@ -29,6 +29,7 @@ type peernet struct {
// needed to implement inet.Network
mux inet.Mux
ids *inet.IDService
cg ctxgroup.ContextGroup
sync.RWMutex
......@@ -61,6 +62,11 @@ func newPeernet(ctx context.Context, m *mocknet, k ic.PrivKey,
}
n.cg.SetTeardown(n.teardown)
// setup a conn handler that immediately "asks the other side about them"
// this is ProtocolIdentify.
n.ids = inet.NewIDService(n)
return n, nil
}
......@@ -158,6 +164,10 @@ func (pn *peernet) remoteOpenedConn(c *conn) {
// addConn constructs and adds a connection
// to given remote peer over given link
func (pn *peernet) addConn(c *conn) {
// run the Identify protocol/handshake.
pn.ids.IdentifyConn(c)
pn.Lock()
cs, found := pn.connsByPeer[c.RemotePeer()]
if !found {
......@@ -327,3 +337,7 @@ func (pn *peernet) NewStream(pr inet.ProtocolID, p peer.ID) (inet.Stream, error)
func (pn *peernet) SetHandler(p inet.ProtocolID, h inet.StreamHandler) {
pn.mux.SetHandler(p, h)
}
func (pn *peernet) IdentifyProtocol() *inet.IDService {
return pn.ids
}
......@@ -129,21 +129,21 @@ func NewNetwork(ctx context.Context, listen []ma.Multiaddr, local peer.ID,
func (n *network) newConnHandler(c *swarm.Conn) {
cc := (*conn_)(c)
s, err := cc.NewStreamWithProtocol(ProtocolIdentify)
if err != nil {
log.Error("network: unable to open initial stream for %s", ProtocolIdentify)
log.Event(n.CtxGroup().Context(), "IdentifyOpenFailed", c.RemotePeer())
}
// ok give the response to our handler.
n.ids.ResponseHandler(s)
n.ids.IdentifyConn(cc)
}
// DialPeer attempts to establish a connection to a given peer.
// Respects the context.
func (n *network) DialPeer(ctx context.Context, p peer.ID) error {
_, err := n.swarm.Dial(ctx, p)
return err
sc, err := n.swarm.Dial(ctx, p)
if err != nil {
return err
}
// identify the connection before returning.
n.ids.IdentifyConn((*conn_)(sc))
log.Debugf("network for %s finished dialing %s", n.local, p)
return nil
}
func (n *network) Protocols() []ProtocolID {
......@@ -185,6 +185,16 @@ func (n *network) Conns() []Conn {
return out
}
// ConnsToPeer returns the connections in this Netowrk for given peer.
func (n *network) ConnsToPeer(p peer.ID) []Conn {
conns1 := n.swarm.ConnectionsToPeer(p)
out := make([]Conn, len(conns1))
for i, c := range conns1 {
out[i] = (*conn_)(c)
}
return out
}
// ClosePeer connection to peer
func (n *network) ClosePeer(p peer.ID) error {
return n.swarm.CloseConnection(p)
......@@ -254,6 +264,10 @@ func (n *network) SetHandler(p ProtocolID, h StreamHandler) {
n.mux.SetHandler(p, h)
}
func (n *network) IdentifyProtocol() *IDService {
return n.ids
}
func WriteProtocolHeader(pr ProtocolID, s Stream) error {
if pr != "" { // only write proper protocol headers
if err := WriteLengthPrefix(s, string(pr)); err != nil {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment