Commit 444f47d7 authored by Juan Batiz-Benet's avatar Juan Batiz-Benet

mock2: link map fixes

parent 59d1426c
......@@ -19,8 +19,8 @@ type link struct {
sync.RWMutex
}
func newLink(mn *mocknet) *link {
return &link{mock: mn, opts: mn.linkDefaults}
func newLink(mn *mocknet, opts LinkOptions) *link {
return &link{mock: mn, opts: opts}
}
func (l *link) newConnPair() (*conn, *conn) {
......
......@@ -164,18 +164,21 @@ func (mn *mocknet) validate(n inet.Network) (*peernet, error) {
}
func (mn *mocknet) LinkNets(n1, n2 inet.Network) (Link, error) {
mn.Lock()
defer mn.Unlock()
mn.RLock()
n1r, err1 := mn.validate(n1)
n2r, err2 := mn.validate(n1)
ld := mn.linkDefaults
mn.RUnlock()
if _, err := mn.validate(n1); err != nil {
return nil, err
if err1 != nil {
return nil, err1
}
if _, err := mn.validate(n2); err != nil {
return nil, err
if err2 != nil {
return nil, err2
}
l := newLink(mn)
l := newLink(mn, ld)
l.nets = append(l.nets, n1r, n2r)
mn.addLink(l)
return l, nil
}
......@@ -209,13 +212,31 @@ func (mn *mocknet) UnlinkNets(n1, n2 inet.Network) error {
return mn.UnlinkPeers(n1.LocalPeer(), n2.LocalPeer())
}
// get from the links map. and lazily contruct.
func (mn *mocknet) linksMapGet(p1, p2 peer.Peer) *map[*link]struct{} {
l1, found := mn.links[pid(p1)]
if !found {
mn.links[pid(p1)] = map[peerID]map[*link]struct{}{}
l1 = mn.links[pid(p1)] // so we make sure it's there.
}
l2, found := l1[pid(p2)]
if !found {
m := map[*link]struct{}{}
l1[pid(p2)] = m
l2 = l1[pid(p2)]
}
return &l2
}
func (mn *mocknet) addLink(l *link) {
mn.Lock()
defer mn.Unlock()
n1, n2 := l.nets[0], l.nets[1]
mn.links[pid(n1.peer)][pid(n2.peer)][l] = struct{}{}
mn.links[pid(n2.peer)][pid(n1.peer)][l] = struct{}{}
(*mn.linksMapGet(n1.peer, n2.peer))[l] = struct{}{}
(*mn.linksMapGet(n2.peer, n1.peer))[l] = struct{}{}
}
func (mn *mocknet) removeLink(l *link) {
......@@ -223,8 +244,8 @@ func (mn *mocknet) removeLink(l *link) {
defer mn.Unlock()
n1, n2 := l.nets[0], l.nets[1]
delete(mn.links[pid(n1.peer)][pid(n2.peer)], l)
delete(mn.links[pid(n2.peer)][pid(n1.peer)], l)
delete(*mn.linksMapGet(n1.peer, n2.peer), l)
delete(*mn.linksMapGet(n2.peer, n1.peer), l)
}
func (mn *mocknet) ConnectAll() error {
......@@ -263,16 +284,7 @@ func (mn *mocknet) LinksBetweenPeers(p1, p2 peer.Peer) []Link {
mn.RLock()
defer mn.RUnlock()
ls1, found := mn.links[pid(p1)]
if !found {
return nil
}
ls2, found := ls1[pid(p2)]
if !found {
return nil
}
ls2 := *mn.linksMapGet(p1, p2)
cp := make([]Link, 0, len(ls2))
for l := range ls2 {
cp = append(cp, l)
......
package mocknet
import (
"bytes"
"io"
"math/rand"
"sync"
"testing"
inet "github.com/jbenet/go-ipfs/net"
context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
)
// func TestNetworkSetup(t *testing.T) {
// p1 := testutil.RandPeer()
// p2 := testutil.RandPeer()
// p3 := testutil.RandPeer()
// peers := []peer.Peer{p1, p2, p3}
// nets, err := MakeNetworks(context.Background(), peers)
// if err != nil {
// t.Fatal(err)
// }
// // check things
// if len(nets) != 3 {
// t.Error("nets must be 3")
// }
// for i, n := range nets {
// if n.local != peers[i] {
// t.Error("peer mismatch")
// }
// if len(n.conns) != len(nets) {
// t.Error("conn mismatch")
// }
// for _, c := range n.conns {
// if c.remote.conns[n.local] == nil {
// t.Error("conn other side fail")
// }
// if c.remote.conns[n.local].remote.local != n.local {
// t.Error("conn other side fail")
// }
// }
// }
// }
func TestStreams(t *testing.T) {
mn, err := FullMeshConnected(context.Background(), 3)
if err != nil {
t.Fatal(err)
}
handler := func(s inet.Stream) {
go func() {
b := make([]byte, 4)
if _, err := io.ReadFull(s, b); err != nil {
panic(err)
}
if !bytes.Equal(b, []byte("beep")) {
panic("bytes mismatch")
}
if _, err := s.Write([]byte("boop")); err != nil {
panic(err)
}
s.Close()
}()
}
nets := mn.Nets()
for _, n := range nets {
n.SetHandler(inet.ProtocolDHT, handler)
}
s, err := nets[0].NewStream(inet.ProtocolDHT, nets[1].LocalPeer())
if err != nil {
t.Fatal(err)
}
if _, err := s.Write([]byte("beep")); err != nil {
panic(err)
}
b := make([]byte, 4)
if _, err := io.ReadFull(s, b); err != nil {
panic(err)
}
if !bytes.Equal(b, []byte("boop")) {
panic("bytes mismatch 2")
}
}
func makePinger(st string, n int) func(inet.Stream) {
return func(s inet.Stream) {
go func() {
defer s.Close()
for i := 0; i < n; i++ {
b := make([]byte, 4+len(st))
if _, err := s.Write([]byte("ping" + st)); err != nil {
panic(err)
}
if _, err := io.ReadFull(s, b); err != nil {
panic(err)
}
if !bytes.Equal(b, []byte("pong"+st)) {
panic("bytes mismatch")
}
}
}()
}
}
func makePonger(st string) func(inet.Stream) {
return func(s inet.Stream) {
go func() {
defer s.Close()
for {
b := make([]byte, 4+len(st))
if _, err := io.ReadFull(s, b); err != nil {
if err == io.EOF {
return
}
panic(err)
}
if !bytes.Equal(b, []byte("ping"+st)) {
panic("bytes mismatch")
}
if _, err := s.Write([]byte("pong" + st)); err != nil {
panic(err)
}
}
}()
}
}
func TestStreamsStress(t *testing.T) {
mn, err := FullMeshConnected(context.Background(), 100)
if err != nil {
t.Fatal(err)
}
protos := []inet.ProtocolID{
inet.ProtocolDHT,
inet.ProtocolBitswap,
inet.ProtocolDiag,
}
nets := mn.Nets()
for _, n := range nets {
for _, p := range protos {
n.SetHandler(p, makePonger(string(p)))
}
}
var wg sync.WaitGroup
for i := 0; i < 1000; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
from := rand.Intn(len(nets))
to := rand.Intn(len(nets))
p := rand.Intn(3)
proto := protos[p]
log.Debug("%d (%s) %d (%s) %d (%s)", from, nets[from], to, nets[to], p, protos[p])
s, err := nets[from].NewStream(protos[p], nets[to].LocalPeer())
if err != nil {
panic(err)
}
log.Infof("%d start pinging", i)
makePinger(string(proto), rand.Intn(100))(s)
log.Infof("%d done pinging", i)
}(i)
}
wg.Done()
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment