Commit 8d8a1dc7 authored by Juan Batiz-Benet

Merge pull request #473 from jbenet/dht-test-providers

dht fixes
parents 32589ad4 0938471d
......@@ -50,32 +50,44 @@ func (d *Dialer) Dial(ctx context.Context, raddr ma.Multiaddr, remote peer.ID) (
return nil, err
}
select {
case <-ctx.Done():
maconn.Close()
return nil, ctx.Err()
default:
}
var connOut Conn
var errOut error
done := make(chan struct{})
// do it async to ensure we respect the context's Done
go func() {
defer func() { done <- struct{}{} }()
c, err := newSingleConn(ctx, d.LocalPeer, remote, maconn)
if err != nil {
errOut = err
return
}
c, err := newSingleConn(ctx, d.LocalPeer, remote, maconn)
if err != nil {
return nil, err
}
if d.PrivateKey == nil {
log.Warning("dialer %s dialing INSECURELY %s at %s!", d, remote, raddr)
connOut = c
return
}
c2, err := newSecureConn(ctx, d.PrivateKey, c)
if err != nil {
errOut = err
c.Close()
return
}
if d.PrivateKey == nil {
log.Warning("dialer %s dialing INSECURELY %s at %s!", d, remote, raddr)
return c, nil
}
connOut = c2
}()
select {
case <-ctx.Done():
c.Close()
maconn.Close()
return nil, ctx.Err()
default:
case <-done:
// whew, finished.
}
// return c, nil
return newSecureConn(ctx, d.PrivateKey, c)
return connOut, errOut
}
// MultiaddrProtocolsMatch returns whether two multiaddrs match in protocol stacks.
......
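For readers skimming this hunk: the new Dial moves the blocking connection setup into a goroutine and then selects on a done channel alongside ctx.Done(). Below is a minimal, self-contained sketch of that pattern; dialBlocking is a made-up placeholder for newSingleConn/newSecureConn, and it uses the standard library context rather than the vendored package, so treat it as an illustration, not the actual implementation. As in the real code, the underlying work is not cancelled; the caller is merely released, which is why Dial still closes maconn when the context fires.

package main

import (
    "context"
    "fmt"
    "time"
)

// dialBlocking stands in for the blocking setup work (newSingleConn + newSecureConn).
func dialBlocking() (string, error) {
    time.Sleep(50 * time.Millisecond)
    return "secure-conn", nil
}

// dialWithContext mirrors the shape of the new Dial: run the blocking work in a
// goroutine, then select on the done channel and the context.
func dialWithContext(ctx context.Context) (string, error) {
    var connOut string
    var errOut error
    done := make(chan struct{})
    go func() {
        defer close(done)
        connOut, errOut = dialBlocking()
    }()
    select {
    case <-done:
        return connOut, errOut
    case <-ctx.Done():
        return "", ctx.Err()
    }
}

func main() {
    ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
    defer cancel()
    c, err := dialWithContext(ctx)
    fmt.Println(c, err) // with a 10ms deadline and a 50ms dial, expect: "" context deadline exceeded
}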
......@@ -109,7 +109,7 @@ func Listen(ctx context.Context, addr ma.Multiaddr, local peer.ID, sk ic.PrivKey
}
l.cg.SetTeardown(l.teardown)
log.Infof("swarm listening on %s\n", l.Multiaddr())
log.Infof("swarm listening on %s", l.Multiaddr())
log.Event(ctx, "swarmListen", l)
return l, nil
}
......@@ -38,10 +38,11 @@ func NewIDService(n Network) *IDService {
func (ids *IDService) IdentifyConn(c Conn) {
ids.currmu.Lock()
if _, found := ids.currid[c]; found {
if wait, found := ids.currid[c]; found {
ids.currmu.Unlock()
log.Debugf("IdentifyConn called twice on: %s", c)
return // already identifying it.
<-wait // already identifying it. wait for it.
return
}
ids.currid[c] = make(chan struct{})
ids.currmu.Unlock()
......@@ -50,10 +51,11 @@ func (ids *IDService) IdentifyConn(c Conn) {
if err != nil {
log.Error("network: unable to open initial stream for %s", ProtocolIdentify)
log.Event(ids.Network.CtxGroup().Context(), "IdentifyOpenFailed", c.RemotePeer())
}
} else {
// ok give the response to our handler.
ids.ResponseHandler(s)
// ok give the response to our handler.
ids.ResponseHandler(s)
}
ids.currmu.Lock()
ch, found := ids.currid[c]
......
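The IdentifyConn change above turns "return if already identifying" into "wait for the in-flight identification". A stripped-down sketch of that map-of-wait-channels idiom follows; the names (identifier, identifyOnce, string keys standing in for Conn) are illustrative only.

package main

import (
    "fmt"
    "sync"
    "time"
)

type identifier struct {
    mu   sync.Mutex
    curr map[string]chan struct{} // in-flight identifications, keyed by connection
}

// identifyOnce runs doIdentify at most once per conn; concurrent callers block
// until the first call finishes, mirroring the <-wait in IdentifyConn.
func (ids *identifier) identifyOnce(conn string) {
    ids.mu.Lock()
    if wait, found := ids.curr[conn]; found {
        ids.mu.Unlock()
        <-wait // already identifying it. wait for it.
        return
    }
    ch := make(chan struct{})
    ids.curr[conn] = ch
    ids.mu.Unlock()

    doIdentify(conn)

    ids.mu.Lock()
    delete(ids.curr, conn)
    ids.mu.Unlock()
    close(ch) // release everyone who was waiting
}

func doIdentify(conn string) { time.Sleep(10 * time.Millisecond) }

func main() {
    ids := &identifier{curr: map[string]chan struct{}{}}
    var wg sync.WaitGroup
    for i := 0; i < 3; i++ {
        wg.Add(1)
        go func() { defer wg.Done(); ids.identifyOnce("conn-A") }()
    }
    wg.Wait()
    fmt.Println("all callers returned after a single identification")
}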
......@@ -82,15 +82,6 @@ type Network interface {
// If ProtocolID is "", writes no header.
NewStream(ProtocolID, peer.ID) (Stream, error)
// Peers returns the peers connected
Peers() []peer.ID
// Conns returns the connections in this Network
Conns() []Conn
// ConnsToPeer returns the connections in this Network for given peer.
ConnsToPeer(p peer.ID) []Conn
// BandwidthTotals returns the total number of bytes passed through
// the network since it was instantiated
BandwidthTotals() (uint64, uint64)
......@@ -133,6 +124,15 @@ type Dialer interface {
// Connectedness returns a state signaling connection capabilities
Connectedness(peer.ID) Connectedness
// Peers returns the peers connected
Peers() []peer.ID
// Conns returns the connections in this Network
Conns() []Conn
// ConnsToPeer returns the connections in this Network for given peer.
ConnsToPeer(p peer.ID) []Conn
}
// Connectedness signals the capacity for a connection with a given node.
......
......@@ -148,7 +148,19 @@ func (n *network) DialPeer(ctx context.Context, p peer.ID) error {
}
// identify the connection before returning.
n.ids.IdentifyConn((*conn_)(sc))
done := make(chan struct{})
go func() {
n.ids.IdentifyConn((*conn_)(sc))
close(done)
}()
// respect the context's Done
select {
case <-done:
case <-ctx.Done():
return ctx.Err()
}
log.Debugf("network for %s finished dialing %s", n.local, p)
return nil
}
......
......@@ -248,15 +248,21 @@ func TestConnHandler(t *testing.T) {
<-time.After(time.Millisecond)
// should've gotten 5 by now.
close(gotconn)
swarms[0].SetConnHandler(nil)
expect := 4
actual := 0
for _ = range gotconn {
actual++
for i := 0; i < expect; i++ {
select {
case <-time.After(time.Second):
t.Fatal("failed to get connections")
case <-gotconn:
}
}
if actual != expect {
t.Fatal("should have connected to %d swarms. got: %d", actual, expect)
select {
case <-gotconn:
t.Fatalf("should have connected to %d swarms", expect)
default:
}
}
......@@ -28,6 +28,10 @@ var log = eventlog.Logger("dht")
const doPinging = false
// NumBootstrapQueries defines the number of random dht queries to do to
// collect members of the routing table.
const NumBootstrapQueries = 5
// TODO. SEE https://github.com/jbenet/node-ipfs/blob/master/submodules/ipfs-dht/index.js
// IpfsDHT is an implementation of Kademlia with Coral and S/Kademlia modifications.
......@@ -361,25 +365,20 @@ func (dht *IpfsDHT) PingRoutine(t time.Duration) {
}
// Bootstrap builds up list of peers by requesting random peer IDs
func (dht *IpfsDHT) Bootstrap(ctx context.Context) {
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
id := make([]byte, 16)
rand.Read(id)
pi, err := dht.FindPeer(ctx, peer.ID(id))
if err != nil {
// NOTE: this is not an error. this is expected!
log.Errorf("Bootstrap peer error: %s", err)
}
func (dht *IpfsDHT) Bootstrap(ctx context.Context, queries int) {
// bootstrap sequentially, as results will compound
for i := 0; i < NumBootstrapQueries; i++ {
id := make([]byte, 16)
rand.Read(id)
pi, err := dht.FindPeer(ctx, peer.ID(id))
if err == routing.ErrNotFound {
// this isn't an error. this is precisely what we expect.
} else if err != nil {
log.Errorf("Bootstrap peer error: %s", err)
} else {
// woah, we got a peer under a random id? it _cannot_ be valid.
log.Errorf("dht seemingly found a peer at a random bootstrap id (%s)...", pi)
}()
}
}
wg.Wait()
}
......@@ -7,6 +7,7 @@ import (
inet "github.com/jbenet/go-ipfs/net"
peer "github.com/jbenet/go-ipfs/peer"
pb "github.com/jbenet/go-ipfs/routing/dht/pb"
ctxutil "github.com/jbenet/go-ipfs/util/ctx"
context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
ggio "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/gogoprotobuf/io"
......@@ -21,18 +22,21 @@ func (dht *IpfsDHT) handleNewMessage(s inet.Stream) {
defer s.Close()
ctx := dht.Context()
r := ggio.NewDelimitedReader(s, inet.MessageSizeMax)
w := ggio.NewDelimitedWriter(s)
cr := ctxutil.NewReader(ctx, s) // ok to use. we defer close stream in this func
cw := ctxutil.NewWriter(ctx, s) // ok to use. we defer close stream in this func
r := ggio.NewDelimitedReader(cr, inet.MessageSizeMax)
w := ggio.NewDelimitedWriter(cw)
mPeer := s.Conn().RemotePeer()
// receive msg
pmes := new(pb.Message)
if err := r.ReadMsg(pmes); err != nil {
log.Error("Error unmarshaling data")
log.Errorf("Error unmarshaling data: %s", err)
return
}
// update the peer (on valid msgs only)
dht.Update(ctx, mPeer)
dht.updateFromMessage(ctx, mPeer, pmes)
log.Event(ctx, "foo", dht.self, mPeer, pmes)
......@@ -76,8 +80,10 @@ func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message
}
defer s.Close()
r := ggio.NewDelimitedReader(s, inet.MessageSizeMax)
w := ggio.NewDelimitedWriter(s)
cr := ctxutil.NewReader(ctx, s) // ok to use. we defer close stream in this func
cw := ctxutil.NewWriter(ctx, s) // ok to use. we defer close stream in this func
r := ggio.NewDelimitedReader(cr, inet.MessageSizeMax)
w := ggio.NewDelimitedWriter(cw)
start := time.Now()
......@@ -98,6 +104,9 @@ func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message
return nil, errors.New("no response to request")
}
// update the peer (on valid msgs only)
dht.updateFromMessage(ctx, p, rpmes)
dht.peerstore.RecordLatency(p, time.Since(start))
log.Event(ctx, "dhtReceivedMessage", dht.self, p, rpmes)
return rpmes, nil
......@@ -113,7 +122,8 @@ func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message
}
defer s.Close()
w := ggio.NewDelimitedWriter(s)
cw := ctxutil.NewWriter(ctx, s) // ok to use. we defer close stream in this func
w := ggio.NewDelimitedWriter(cw)
log.Debugf("%s writing", dht.self)
if err := w.WriteMsg(pmes); err != nil {
......@@ -123,3 +133,8 @@ func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message
log.Debugf("%s done", dht.self)
return nil
}
func (dht *IpfsDHT) updateFromMessage(ctx context.Context, p peer.ID, mes *pb.Message) error {
dht.Update(ctx, p)
return nil
}
......@@ -2,7 +2,9 @@ package dht
import (
"bytes"
"fmt"
"sort"
"sync"
"testing"
"time"
......@@ -15,10 +17,22 @@ import (
// ci "github.com/jbenet/go-ipfs/crypto"
inet "github.com/jbenet/go-ipfs/net"
peer "github.com/jbenet/go-ipfs/peer"
routing "github.com/jbenet/go-ipfs/routing"
u "github.com/jbenet/go-ipfs/util"
testutil "github.com/jbenet/go-ipfs/util/testutil"
)
var testCaseValues = map[u.Key][]byte{}
func init() {
testCaseValues["hello"] = []byte("world")
for i := 0; i < 100; i++ {
k := fmt.Sprintf("%d -- key", i)
v := fmt.Sprintf("%d -- value", i)
testCaseValues[u.Key(k)] = []byte(v)
}
}
func setupDHT(ctx context.Context, t *testing.T, addr ma.Multiaddr) *IpfsDHT {
sk, pk, err := testutil.RandKeyPair(512)
......@@ -78,6 +92,27 @@ func connect(t *testing.T, ctx context.Context, a, b *IpfsDHT) {
}
}
func bootstrap(t *testing.T, ctx context.Context, dhts []*IpfsDHT) {
ctx, cancel := context.WithCancel(ctx)
rounds := 1
for i := 0; i < rounds; i++ {
log.Debugf("bootstrapping round %d/%d\n", i, rounds)
// tried async. sequential fares much better. compare:
// 100 async https://gist.github.com/jbenet/56d12f0578d5f34810b2
// 100 sync https://gist.github.com/jbenet/6c59e7c15426e48aaedd
// probably because results compound
for _, dht := range dhts {
log.Debugf("bootstrapping round %d/%d -- %s\n", i, rounds, dht.self)
dht.Bootstrap(ctx, 3)
}
}
cancel()
}
func TestPing(t *testing.T) {
// t.Skip("skipping test to debug another")
ctx := context.Background()
......@@ -174,37 +209,208 @@ func TestProvides(t *testing.T) {
connect(t, ctx, dhts[1], dhts[2])
connect(t, ctx, dhts[1], dhts[3])
err := dhts[3].putLocal(u.Key("hello"), []byte("world"))
if err != nil {
t.Fatal(err)
for k, v := range testCaseValues {
log.Debugf("adding local values for %s = %s", k, v)
err := dhts[3].putLocal(k, v)
if err != nil {
t.Fatal(err)
}
bits, err := dhts[3].getLocal(k)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(bits, v) {
t.Fatal("didn't store the right bits (%s, %s)", k, v)
}
}
bits, err := dhts[3].getLocal(u.Key("hello"))
if err != nil && bytes.Equal(bits, []byte("world")) {
t.Fatal(err)
for k := range testCaseValues {
log.Debugf("announcing provider for %s", k)
if err := dhts[3].Provide(ctx, k); err != nil {
t.Fatal(err)
}
}
err = dhts[3].Provide(ctx, u.Key("hello"))
if err != nil {
t.Fatal(err)
// what is this timeout for? was 60ms before.
time.Sleep(time.Millisecond * 6)
n := 0
for k := range testCaseValues {
n = (n + 1) % 3
log.Debugf("getting providers for %s from %d", k, n)
ctxT, _ := context.WithTimeout(ctx, time.Second)
provchan := dhts[n].FindProvidersAsync(ctxT, k, 1)
select {
case prov := <-provchan:
if prov.ID == "" {
t.Fatal("Got back nil provider")
}
if prov.ID != dhts[3].self {
t.Fatal("Got back wrong provider")
}
case <-ctxT.Done():
t.Fatal("Did not get a provider back.")
}
}
}
func TestBootstrap(t *testing.T) {
if testing.Short() {
t.SkipNow()
}
ctx := context.Background()
nDHTs := 15
_, _, dhts := setupDHTS(ctx, nDHTs, t)
defer func() {
for i := 0; i < nDHTs; i++ {
dhts[i].Close()
defer dhts[i].network.Close()
}
}()
t.Logf("connecting %d dhts in a ring", nDHTs)
for i := 0; i < nDHTs; i++ {
connect(t, ctx, dhts[i], dhts[(i+1)%len(dhts)])
}
<-time.After(100 * time.Millisecond)
t.Logf("bootstrapping them so they find each other", nDHTs)
ctxT, _ := context.WithTimeout(ctx, 5*time.Second)
bootstrap(t, ctxT, dhts)
if u.Debug {
// the routing tables should be full now. let's inspect them.
<-time.After(5 * time.Second)
t.Logf("checking routing table of %d", nDHTs)
for _, dht := range dhts {
fmt.Printf("checking routing table of %s\n", dht.self)
dht.routingTable.Print()
fmt.Println("")
}
}
// test "well-formed-ness" (>= 3 peers in every routing table)
for _, dht := range dhts {
rtlen := dht.routingTable.Size()
if rtlen < 4 {
t.Errorf("routing table for %s only has %d peers", dht.self, rtlen)
}
}
}
func TestProvidesMany(t *testing.T) {
t.Skip("this test doesn't work")
// t.Skip("skipping test to debug another")
ctx := context.Background()
nDHTs := 40
_, _, dhts := setupDHTS(ctx, nDHTs, t)
defer func() {
for i := 0; i < nDHTs; i++ {
dhts[i].Close()
defer dhts[i].network.Close()
}
}()
t.Logf("connecting %d dhts in a ring", nDHTs)
for i := 0; i < nDHTs; i++ {
connect(t, ctx, dhts[i], dhts[(i+1)%len(dhts)])
}
<-time.After(100 * time.Millisecond)
t.Logf("bootstrapping them so they find each other", nDHTs)
ctxT, _ := context.WithTimeout(ctx, 5*time.Second)
bootstrap(t, ctxT, dhts)
if u.Debug {
// the routing tables should be full now. let's inspect them.
<-time.After(5 * time.Second)
t.Logf("checking routing table of %d", nDHTs)
for _, dht := range dhts {
fmt.Printf("checking routing table of %s\n", dht.self)
dht.routingTable.Print()
fmt.Println("")
}
}
var providers = map[u.Key]peer.ID{}
d := 0
for k, v := range testCaseValues {
d = (d + 1) % len(dhts)
dht := dhts[d]
providers[k] = dht.self
t.Logf("adding local values for %s = %s (on %s)", k, v, dht.self)
err := dht.putLocal(k, v)
if err != nil {
t.Fatal(err)
}
bits, err := dht.getLocal(k)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(bits, v) {
t.Fatal("didn't store the right bits (%s, %s)", k, v)
}
t.Logf("announcing provider for %s", k)
if err := dht.Provide(ctx, k); err != nil {
t.Fatal(err)
}
}
// what is this timeout for? was 60ms before.
time.Sleep(time.Millisecond * 6)
ctxT, _ := context.WithTimeout(ctx, time.Second)
provchan := dhts[0].FindProvidersAsync(ctxT, u.Key("hello"), 1)
errchan := make(chan error)
select {
case prov := <-provchan:
if prov.ID == "" {
t.Fatal("Got back nil provider")
ctxT, _ = context.WithTimeout(ctx, 5*time.Second)
var wg sync.WaitGroup
getProvider := func(dht *IpfsDHT, k u.Key) {
defer wg.Done()
expected := providers[k]
provchan := dht.FindProvidersAsync(ctxT, k, 1)
select {
case prov := <-provchan:
actual := prov.ID
if actual == "" {
errchan <- fmt.Errorf("Got back nil provider (%s at %s)", k, dht.self)
} else if actual != expected {
errchan <- fmt.Errorf("Got back wrong provider (%s != %s) (%s at %s)",
expected, actual, k, dht.self)
}
case <-ctxT.Done():
errchan <- fmt.Errorf("Did not get a provider back (%s at %s)", k, dht.self)
}
if prov.ID != dhts[3].self {
t.Fatal("Got back nil provider")
}
for k := range testCaseValues {
// everyone should be able to find it...
for _, dht := range dhts {
log.Debugf("getting providers for %s at %s", k, dht.self)
wg.Add(1)
go getProvider(dht, k)
}
case <-ctxT.Done():
t.Fatal("Did not get a provider back.")
}
// close errchan once all getProvider goroutines finish, so the error-printing loop below can exit
go func() {
wg.Wait()
close(errchan)
}()
for err := range errchan {
t.Error(err)
}
}
......@@ -291,18 +497,20 @@ func TestLayeredGet(t *testing.T) {
t.Fatal(err)
}
time.Sleep(time.Millisecond * 60)
time.Sleep(time.Millisecond * 6)
t.Log("interface was changed. GetValue should not use providers.")
ctxT, _ := context.WithTimeout(ctx, time.Second)
val, err := dhts[0].GetValue(ctxT, u.Key("/v/hello"))
if err != nil {
t.Fatal(err)
if err != routing.ErrNotFound {
t.Error(err)
}
if string(val) != "world" {
t.Fatal("Got incorrect value.")
if string(val) == "world" {
t.Error("should not get value.")
}
if len(val) > 0 && string(val) != "world" {
t.Error("worse, there's a value and its not even the right one.")
}
}
func TestFindPeer(t *testing.T) {
......
......@@ -73,7 +73,7 @@ func TestGetFailures(t *testing.T) {
})
// This one should fail with NotFound
ctx2, _ := context.WithTimeout(context.Background(), time.Second)
ctx2, _ := context.WithTimeout(context.Background(), 3*time.Second)
_, err = d.GetValue(ctx2, u.Key("test"))
if err != nil {
if err != routing.ErrNotFound {
......
......@@ -148,7 +148,7 @@ func (dht *IpfsDHT) handleFindPeer(ctx context.Context, p peer.ID, pmes *pb.Mess
}
if closest == nil {
log.Errorf("handleFindPeer: could not find anything.")
log.Debugf("handleFindPeer: could not find anything.")
return resp, nil
}
......
......@@ -12,6 +12,7 @@ import (
todoctr "github.com/jbenet/go-ipfs/util/todocounter"
context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
ctxgroup "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-ctxgroup"
)
var maxQueryConcurrency = AlphaValue
......@@ -78,9 +79,8 @@ type dhtQueryRunner struct {
// peersRemaining is a counter of peers remaining (toQuery + processing)
peersRemaining todoctr.Counter
// context
ctx context.Context
cancel context.CancelFunc
// context group
cg ctxgroup.ContextGroup
// result
result *dhtQueryResult
......@@ -93,16 +93,13 @@ type dhtQueryRunner struct {
}
func newQueryRunner(ctx context.Context, q *dhtQuery) *dhtQueryRunner {
ctx, cancel := context.WithCancel(ctx)
return &dhtQueryRunner{
ctx: ctx,
cancel: cancel,
query: q,
peersToQuery: queue.NewChanQueue(ctx, queue.NewXORDistancePQ(q.key)),
peersRemaining: todoctr.NewSyncCounter(),
peersSeen: peer.Set{},
rateLimit: make(chan struct{}, q.concurrency),
cg: ctxgroup.WithContext(ctx),
}
}
......@@ -120,11 +117,13 @@ func (r *dhtQueryRunner) Run(peers []peer.ID) (*dhtQueryResult, error) {
// add all the peers we got first.
for _, p := range peers {
r.addPeerToQuery(p, "") // don't have access to self here...
r.addPeerToQuery(r.cg.Context(), p, "") // don't have access to self here...
}
// go do this thing.
go r.spawnWorkers()
// do it as a child func to make sure Run exits
// ONLY AFTER spawn workers has exited.
r.cg.AddChildFunc(r.spawnWorkers)
// so workers are working.
......@@ -133,7 +132,7 @@ func (r *dhtQueryRunner) Run(peers []peer.ID) (*dhtQueryResult, error) {
select {
case <-r.peersRemaining.Done():
r.cancel() // ran all and nothing. cancel all outstanding workers.
r.cg.Close()
r.RLock()
defer r.RUnlock()
......@@ -141,10 +140,10 @@ func (r *dhtQueryRunner) Run(peers []peer.ID) (*dhtQueryResult, error) {
err = r.errs[0]
}
case <-r.ctx.Done():
case <-r.cg.Closed():
r.RLock()
defer r.RUnlock()
err = r.ctx.Err()
err = r.cg.Context().Err() // collect the error.
}
if r.result != nil && r.result.success {
......@@ -154,7 +153,7 @@ func (r *dhtQueryRunner) Run(peers []peer.ID) (*dhtQueryResult, error) {
return nil, err
}
func (r *dhtQueryRunner) addPeerToQuery(next peer.ID, benchmark peer.ID) {
func (r *dhtQueryRunner) addPeerToQuery(ctx context.Context, next peer.ID, benchmark peer.ID) {
// if new peer is ourselves...
if next == r.query.dialer.LocalPeer() {
return
......@@ -180,43 +179,48 @@ func (r *dhtQueryRunner) addPeerToQuery(next peer.ID, benchmark peer.ID) {
r.peersSeen[next] = struct{}{}
r.Unlock()
log.Debugf("adding peer to query: %v\n", next)
log.Debugf("adding peer to query: %v", next)
// do this after unlocking to prevent possible deadlocks.
r.peersRemaining.Increment(1)
select {
case r.peersToQuery.EnqChan <- next:
case <-r.ctx.Done():
case <-ctx.Done():
}
}
func (r *dhtQueryRunner) spawnWorkers() {
func (r *dhtQueryRunner) spawnWorkers(parent ctxgroup.ContextGroup) {
for {
select {
case <-r.peersRemaining.Done():
return
case <-r.ctx.Done():
case <-r.cg.Closing():
return
case p, more := <-r.peersToQuery.DeqChan:
if !more {
return // channel closed.
}
log.Debugf("spawning worker for: %v\n", p)
go r.queryPeer(p)
log.Debugf("spawning worker for: %v", p)
// do it as a child func to make sure Run exits
// ONLY AFTER spawn workers has exited.
parent.AddChildFunc(func(cg ctxgroup.ContextGroup) {
r.queryPeer(cg, p)
})
}
}
}
func (r *dhtQueryRunner) queryPeer(p peer.ID) {
func (r *dhtQueryRunner) queryPeer(cg ctxgroup.ContextGroup, p peer.ID) {
log.Debugf("spawned worker for: %v", p)
// make sure we rate limit concurrency.
select {
case <-r.rateLimit:
case <-r.ctx.Done():
case <-cg.Closing():
r.peersRemaining.Decrement(1)
return
}
......@@ -233,17 +237,22 @@ func (r *dhtQueryRunner) queryPeer(p peer.ID) {
}()
// make sure we're connected to the peer.
err := r.query.dialer.DialPeer(r.ctx, p)
if err != nil {
log.Debugf("ERROR worker for: %v -- err connecting: %v", p, err)
r.Lock()
r.errs = append(r.errs, err)
r.Unlock()
return
if conns := r.query.dialer.ConnsToPeer(p); len(conns) == 0 {
log.Infof("worker for: %v -- not connected. dial start", p)
if err := r.query.dialer.DialPeer(cg.Context(), p); err != nil {
log.Debugf("ERROR worker for: %v -- err connecting: %v", p, err)
r.Lock()
r.errs = append(r.errs, err)
r.Unlock()
return
}
log.Infof("worker for: %v -- not connected. dial success!", p)
}
// finally, run the query against this peer
res, err := r.query.qfunc(r.ctx, p)
res, err := r.query.qfunc(cg.Context(), p)
if err != nil {
log.Debugf("ERROR worker for: %v %v", p, err)
......@@ -256,14 +265,20 @@ func (r *dhtQueryRunner) queryPeer(p peer.ID) {
r.Lock()
r.result = res
r.Unlock()
r.cancel() // signal to everyone that we're done.
go r.cg.Close() // signal to everyone that we're done.
// must be async, as we're one of the children, and Close blocks.
} else if len(res.closerPeers) > 0 {
log.Debugf("PEERS CLOSER -- worker for: %v (%d closer peers)", p, len(res.closerPeers))
for _, next := range res.closerPeers {
// add their addresses to the dialer's peerstore
conns := r.query.dialer.ConnsToPeer(next.ID)
if len(conns) == 0 {
log.Infof("PEERS CLOSER -- worker for %v FOUND NEW PEER: %s %s", p, next.ID, next.Addrs)
}
r.query.dialer.Peerstore().AddAddresses(next.ID, next.Addrs)
r.addPeerToQuery(next.ID, p)
r.addPeerToQuery(cg.Context(), next.ID, p)
log.Debugf("PEERS CLOSER -- worker for: %v added %v (%v)", p, next.ID, next.Addrs)
}
} else {
......
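The query runner now owns a ContextGroup and registers workers via AddChildFunc, so Run can only return after every worker has exited. A hedged sketch of that shape, using only the ctxgroup calls that appear in this diff (WithContext, AddChildFunc, Closing, Close); the worker bodies are placeholders.

package main

import (
    "fmt"
    "time"

    context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
    ctxgroup "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-ctxgroup"
)

func main() {
    cg := ctxgroup.WithContext(context.Background())

    // Each child added with AddChildFunc keeps the group open until its func
    // returns, which is how Run now guarantees it exits only after spawnWorkers
    // (and the workers it spawned) have exited.
    for i := 0; i < 3; i++ {
        i := i
        cg.AddChildFunc(func(child ctxgroup.ContextGroup) {
            select {
            case <-time.After(time.Duration(i+1) * 10 * time.Millisecond):
                fmt.Println("worker", i, "finished")
            case <-child.Closing():
                fmt.Println("worker", i, "asked to stop")
            }
        })
    }

    cg.Close() // blocks until all children return -- why queryPeer calls r.cg.Close() in a goroutine
    fmt.Println("all workers done")
}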
......@@ -223,8 +223,16 @@ func (rt *RoutingTable) ListPeers() []peer.ID {
func (rt *RoutingTable) Print() {
fmt.Printf("Routing Table, bs = %d, Max latency = %d\n", rt.bucketsize, rt.maxLatency)
rt.tabLock.RLock()
peers := rt.ListPeers()
for i, p := range peers {
fmt.Printf("%d) %s %s\n", i, p.Pretty(), rt.metrics.LatencyEWMA(p).String())
for i, b := range rt.Buckets {
fmt.Printf("\tbucket: %d\n", i)
b.lk.RLock()
for e := b.list.Front(); e != nil; e = e.Next() {
p := e.Value.(peer.ID)
fmt.Printf("\t\t- %s %s\n", p.Pretty(), rt.metrics.LatencyEWMA(p).String())
}
b.lk.RUnlock()
}
rt.tabLock.RUnlock()
}
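Print now walks each bucket's container/list directly. For reference, a tiny standalone walk with the same Front/Next/type-assertion shape (a string stands in for peer.ID):

package main

import (
    "container/list"
    "fmt"
)

func main() {
    l := list.New()
    l.PushBack("QmPeerA")
    l.PushBack("QmPeerB")
    for e := l.Front(); e != nil; e = e.Next() {
        p := e.Value.(string) // elements are stored as interface{}; assert back to the element type
        fmt.Println("-", p)
    }
}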
package ctxutil
import (
"io"
context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
)
type ioret struct {
n int
err error
}
type Writer interface {
io.Writer
}
type ctxWriter struct {
w io.Writer
ctx context.Context
}
// NewWriter wraps a writer to make it respect given Context.
// If there is a blocking write, the returned Writer will return
// whenever the context is cancelled (the return values are n=0
// and err=ctx.Err().)
//
// Note well: this wrapper DOES NOT ACTUALLY cancel the underlying
// write-- there is no way to do that with the standard go io
// interface. So the read and write _will_ happen or hang. So, use
// this sparingly, make sure to cancel the read or write as necessary
// (e.g. closing a connection whose context is up, etc.)
//
// Furthermore, in order to protect your memory from being read
// _after_ you've cancelled the context, this io.Writer will
// first make a **copy** of the buffer.
func NewWriter(ctx context.Context, w io.Writer) *ctxWriter {
if ctx == nil {
ctx = context.Background()
}
return &ctxWriter{ctx: ctx, w: w}
}
func (w *ctxWriter) Write(buf []byte) (int, error) {
buf2 := make([]byte, len(buf))
copy(buf2, buf)
c := make(chan ioret, 1)
go func() {
n, err := w.w.Write(buf2)
c <- ioret{n, err}
close(c)
}()
select {
case r := <-c:
return r.n, r.err
case <-w.ctx.Done():
return 0, w.ctx.Err()
}
}
type Reader interface {
io.Reader
}
type ctxReader struct {
r io.Reader
ctx context.Context
}
// NewReader wraps a reader to make it respect given Context.
// If there is a blocking read, the returned Reader will return
// whenever the context is cancelled (the return values are n=0
// and err=ctx.Err().)
//
// Note well: this wrapper DOES NOT ACTUALLY cancel the underlying
// read-- there is no way to do that with the standard go io
// interface. So the read and write _will_ happen or hang. So, use
// this sparingly, make sure to cancel the read or write as necessary
// (e.g. closing a connection whose context is up, etc.)
//
// Furthermore, in order to protect your memory from being written to
// _after_ you've cancelled the context, this io.Reader will
// allocate a buffer of the same size, and **copy** into the client's
// buffer only if the read succeeds in time.
func NewReader(ctx context.Context, r io.Reader) *ctxReader {
return &ctxReader{ctx: ctx, r: r}
}
func (r *ctxReader) Read(buf []byte) (int, error) {
buf2 := make([]byte, len(buf))
c := make(chan ioret, 1)
go func() {
n, err := r.r.Read(buf2)
c <- ioret{n, err}
close(c)
}()
select {
case ret := <-c:
copy(buf, buf2)
return ret.n, ret.err
case <-r.ctx.Done():
return 0, r.ctx.Err()
}
}
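A short usage sketch for the wrappers above, mirroring how dht_net.go now wraps stream I/O: give a reader a context deadline so the caller is released even if the underlying Read never completes. The ctxutil import path is the one added in dht_net.go and the context path is the vendored one used throughout this repo; the rest is illustrative.

package main

import (
    "fmt"
    "io"
    "time"

    context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
    ctxutil "github.com/jbenet/go-ipfs/util/ctx"
)

func main() {
    piper, _ := io.Pipe() // nothing is ever written, so a bare Read would block forever

    ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
    defer cancel()

    r := ctxutil.NewReader(ctx, piper)
    buf := make([]byte, 16)
    n, err := r.Read(buf) // returns once the deadline fires, even though the pipe Read is still pending
    fmt.Println(n, err)   // expect: 0 context deadline exceeded
}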
package ctxutil
import (
"bytes"
"io"
"testing"
"time"
context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
)
func TestReader(t *testing.T) {
buf := []byte("abcdef")
buf2 := make([]byte, 3)
r := NewReader(context.Background(), bytes.NewReader(buf))
// read first half
n, err := r.Read(buf2)
if n != 3 {
t.Error("n should be 3")
}
if err != nil {
t.Error("should have no error")
}
if string(buf2) != string(buf[:3]) {
t.Error("incorrect contents")
}
// read second half
n, err = r.Read(buf2)
if n != 3 {
t.Error("n should be 3")
}
if err != nil {
t.Error("should have no error")
}
if string(buf2) != string(buf[3:6]) {
t.Error("incorrect contents")
}
// read more.
n, err = r.Read(buf2)
if n != 0 {
t.Error("n should be 0", n)
}
if err != io.EOF {
t.Error("should be EOF", err)
}
}
func TestWriter(t *testing.T) {
var buf bytes.Buffer
w := NewWriter(context.Background(), &buf)
// write three
n, err := w.Write([]byte("abc"))
if n != 3 {
t.Error("n should be 3")
}
if err != nil {
t.Error("should have no error")
}
if string(buf.Bytes()) != string("abc") {
t.Error("incorrect contents")
}
// write three more
n, err = w.Write([]byte("def"))
if n != 3 {
t.Error("n should be 3")
}
if err != nil {
t.Error("should have no error")
}
if string(buf.Bytes()) != string("abcdef") {
t.Error("incorrect contents")
}
}
func TestReaderCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
piper, pipew := io.Pipe()
r := NewReader(ctx, piper)
buf := make([]byte, 10)
done := make(chan ioret)
go func() {
n, err := r.Read(buf)
done <- ioret{n, err}
}()
pipew.Write([]byte("abcdefghij"))
select {
case ret := <-done:
if ret.n != 10 {
t.Error("ret.n should be 10", ret.n)
}
if ret.err != nil {
t.Error("ret.err should be nil", ret.err)
}
if string(buf) != "abcdefghij" {
t.Error("read contents differ")
}
case <-time.After(20 * time.Millisecond):
t.Fatal("failed to read")
}
go func() {
n, err := r.Read(buf)
done <- ioret{n, err}
}()
cancel()
select {
case ret := <-done:
if ret.n != 0 {
t.Error("ret.n should be 0", ret.n)
}
if ret.err == nil {
t.Error("ret.err should be ctx error", ret.err)
}
case <-time.After(20 * time.Millisecond):
t.Fatal("failed to stop reading after cancel")
}
}
func TestWriterCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
piper, pipew := io.Pipe()
w := NewWriter(ctx, pipew)
buf := make([]byte, 10)
done := make(chan ioret)
go func() {
n, err := w.Write([]byte("abcdefghij"))
done <- ioret{n, err}
}()
piper.Read(buf)
select {
case ret := <-done:
if ret.n != 10 {
t.Error("ret.n should be 10", ret.n)
}
if ret.err != nil {
t.Error("ret.err should be nil", ret.err)
}
if string(buf) != "abcdefghij" {
t.Error("write contents differ")
}
case <-time.After(20 * time.Millisecond):
t.Fatal("failed to write")
}
go func() {
n, err := w.Write([]byte("abcdefghij"))
done <- ioret{n, err}
}()
cancel()
select {
case ret := <-done:
if ret.n != 0 {
t.Error("ret.n should be 0", ret.n)
}
if ret.err == nil {
t.Error("ret.err should be ctx error", ret.err)
}
case <-time.After(20 * time.Millisecond):
t.Fatal("failed to stop writing after cancel")
}
}
func TestReadPostCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
piper, pipew := io.Pipe()
r := NewReader(ctx, piper)
buf := make([]byte, 10)
done := make(chan ioret)
go func() {
n, err := r.Read(buf)
done <- ioret{n, err}
}()
cancel()
select {
case ret := <-done:
if ret.n != 0 {
t.Error("ret.n should be 0", ret.n)
}
if ret.err == nil {
t.Error("ret.err should be ctx error", ret.err)
}
case <-time.After(20 * time.Millisecond):
t.Fatal("failed to stop reading after cancel")
}
pipew.Write([]byte("abcdefghij"))
if !bytes.Equal(buf, make([]byte, len(buf))) {
t.Fatal("buffer should have not been written to")
}
}
func TestWritePostCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
piper, pipew := io.Pipe()
w := NewWriter(ctx, pipew)
buf := []byte("abcdefghij")
buf2 := make([]byte, 10)
done := make(chan ioret)
go func() {
n, err := w.Write(buf)
done <- ioret{n, err}
}()
piper.Read(buf2)
select {
case ret := <-done:
if ret.n != 10 {
t.Error("ret.n should be 10", ret.n)
}
if ret.err != nil {
t.Error("ret.err should be nil", ret.err)
}
if string(buf2) != "abcdefghij" {
t.Error("write contents differ")
}
case <-time.After(20 * time.Millisecond):
t.Fatal("failed to write")
}
go func() {
n, err := w.Write(buf)
done <- ioret{n, err}
}()
cancel()
select {
case ret := <-done:
if ret.n != 0 {
t.Error("ret.n should be 0", ret.n)
}
if ret.err == nil {
t.Error("ret.err should be ctx error", ret.err)
}
case <-time.After(20 * time.Millisecond):
t.Fatal("failed to stop writing after cancel")
}
copy(buf, []byte("aaaaaaaaaa"))
piper.Read(buf2)
if string(buf2) == "aaaaaaaaaa" {
t.Error("buffer was read from after ctx cancel")
} else if string(buf2) != "abcdefghij" {
t.Error("write contents differ from expected")
}
}
......@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"io"
"sync"
"testing"
ci "github.com/jbenet/go-ipfs/crypto"
......@@ -49,17 +50,24 @@ func RandLocalTCPAddress() ma.Multiaddr {
// most ports above 10000 aren't in use by long running processes, so yay.
// (maybe there should be a range of "loopback" ports that are guaranteed
// to be open for the process, but naturally can only talk to self.)
if lastPort == 0 {
lastPort = 10000 + SeededRand.Intn(50000)
lastPort.Lock()
if lastPort.port == 0 {
lastPort.port = 10000 + SeededRand.Intn(50000)
}
lastPort++
port := lastPort.port
lastPort.port++
lastPort.Unlock()
addr := fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", lastPort)
addr := fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", port)
maddr, _ := ma.NewMultiaddr(addr)
return maddr
}
var lastPort = 0
var lastPort = struct {
port int
sync.Mutex
}{}
// PeerNetParams is a struct to bundle together the four things
// you need to run a connection with a peer: id, 2keys, and addr.
......
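The testutil change replaces a bare int counter with a mutex-guarded struct. Below is a small standalone sketch of that guarded-counter shape under concurrency; the names are illustrative, and running it with the race detector (e.g. go run -race) shows why the unguarded version was flagged.

package main

import (
    "fmt"
    "sync"
)

var lastPort = struct {
    port int
    sync.Mutex
}{}

// nextPort hands out a unique port number even when called from many goroutines.
func nextPort() int {
    lastPort.Lock()
    if lastPort.port == 0 {
        lastPort.port = 10000
    }
    p := lastPort.port
    lastPort.port++
    lastPort.Unlock()
    return p
}

func main() {
    var wg sync.WaitGroup
    seen := make(chan int, 100)
    for i := 0; i < 100; i++ {
        wg.Add(1)
        go func() { defer wg.Done(); seen <- nextPort() }()
    }
    wg.Wait()
    close(seen)
    uniq := map[int]bool{}
    for p := range seen {
        uniq[p] = true
    }
    fmt.Println("unique ports handed out:", len(uniq)) // 100: no duplicates under concurrency
}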