Commit 98608888 authored by Jeromy's avatar Jeromy

get implementation according to kademlia spec.

parent 3454da1a
......@@ -496,6 +496,45 @@ out:
dht.network.Send(mes)
}
func (dht *IpfsDHT) getValueOrPeers(p *peer.Peer, key u.Key, timeout time.Duration, level int) ([]byte, []*peer.Peer, error) {
pmes, err := dht.getValueSingle(p, key, timeout, level)
if err != nil {
return nil, nil, u.WrapError(err, "getValue Error")
}
if pmes.GetSuccess() {
if pmes.Value == nil { // We were given provider[s]
val, err := dht.getFromPeerList(key, timeout, pmes.GetPeers(), level)
if err != nil {
return nil, nil, err
}
return val, nil, nil
}
// Success! We were given the value
return pmes.GetValue(), nil, nil
} else {
// We were given a closer node
var peers []*peer.Peer
for _, pb := range pmes.GetPeers() {
addr, err := ma.NewMultiaddr(pb.GetAddr())
if err != nil {
u.PErr(err.Error())
continue
}
np, err := dht.network.GetConnection(peer.ID(pb.GetId()), addr)
if err != nil {
u.PErr(err.Error())
continue
}
peers = append(peers, np)
}
return nil, peers, nil
}
}
// getValueSingle simply performs the get value RPC with the given parameters
func (dht *IpfsDHT) getValueSingle(p *peer.Peer, key u.Key, timeout time.Duration, level int) (*PBDHTMessage, error) {
pmes := DHTMessage{
......
......@@ -31,8 +31,8 @@ func (l *logDhtRpc) EndLog() {
func (l *logDhtRpc) Print() {
b, err := json.Marshal(l)
if err != nil {
u.POut(err.Error())
u.DOut(err.Error())
} else {
u.POut(string(b))
u.DOut(string(b))
}
}
......@@ -88,13 +88,9 @@ func TestGetFailures(t *testing.T) {
d.Update(other)
// This one should time out
_, err := d.GetValue(u.Key("test"), time.Millisecond*5)
_, err := d.GetValue(u.Key("test"), time.Millisecond*10)
if err != nil {
nerr, ok := err.(*u.IpfsError)
if !ok {
t.Fatal("Got different error than we expected.")
}
if nerr.Inner != u.ErrTimeout {
if err != u.ErrTimeout {
t.Fatal("Got different error than we expected.")
}
} else {
......@@ -119,10 +115,10 @@ func TestGetFailures(t *testing.T) {
})
// This one should fail with NotFound
_, err = d.GetValue(u.Key("test"), time.Millisecond*5)
_, err = d.GetValue(u.Key("test"), time.Millisecond*1000)
if err != nil {
if err != u.ErrNotFound {
t.Fatal("Expected ErrNotFound, got: %s", err)
t.Fatalf("Expected ErrNotFound, got: %s", err)
}
} else {
t.Fatal("expected error, got none.")
......
......@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"math/rand"
"sync"
"time"
proto "code.google.com/p/goprotobuf/proto"
......@@ -56,6 +57,30 @@ func (s *IpfsDHT) PutValue(key u.Key, value []byte) {
}
}
// A counter for incrementing a variable across multiple threads
type counter struct {
n int
mut sync.RWMutex
}
func (c *counter) Increment() {
c.mut.Lock()
c.n++
c.mut.Unlock()
}
func (c *counter) Decrement() {
c.mut.Lock()
c.n--
c.mut.Unlock()
}
func (c *counter) Size() int {
c.mut.RLock()
defer c.mut.RUnlock()
return c.n
}
// GetValue searches for the value corresponding to given Key.
// If the search does not succeed, a multiaddr string of a closer peer is
// returned along with util.ErrSearchIncomplete
......@@ -65,7 +90,6 @@ func (s *IpfsDHT) GetValue(key u.Key, timeout time.Duration) ([]byte, error) {
ll.EndLog()
ll.Print()
}()
route_level := 0
// If we have it local, dont bother doing an RPC!
// NOTE: this might not be what we want to do...
......@@ -76,54 +100,90 @@ func (s *IpfsDHT) GetValue(key u.Key, timeout time.Duration) ([]byte, error) {
return val, nil
}
p := s.routes[route_level].NearestPeer(kb.ConvertKey(key))
if p == nil {
route_level := 0
closest := s.routes[route_level].NearestPeers(kb.ConvertKey(key), PoolSize)
if closest == nil || len(closest) == 0 {
return nil, kb.ErrLookupFailure
}
for route_level < len(s.routes) && p != nil {
ll.RpcCount++
pmes, err := s.getValueSingle(p, key, timeout, route_level)
if err != nil {
return nil, u.WrapError(err, "getValue Error")
}
val_chan := make(chan []byte)
npeer_chan := make(chan *peer.Peer, 30)
proc_peer := make(chan *peer.Peer, 30)
err_chan := make(chan error)
after := time.After(timeout)
if pmes.GetSuccess() {
if pmes.Value == nil { // We were given provider[s]
ll.RpcCount++
return s.getFromPeerList(key, timeout, pmes.GetPeers(), route_level)
}
for _, p := range closest {
npeer_chan <- p
}
// Success! We were given the value
ll.Success = true
return pmes.GetValue(), nil
} else {
// We were given a closer node
closers := pmes.GetPeers()
if len(closers) > 0 {
if peer.ID(closers[0].GetId()).Equal(s.self.ID) {
u.DOut("Got myself back as a closer peer.")
return nil, u.ErrNotFound
c := counter{}
// This limit value is referred to as k in the kademlia paper
limit := 20
count := 0
go func() {
for {
select {
case p := <-npeer_chan:
count++
if count >= limit {
break
}
maddr, err := ma.NewMultiaddr(closers[0].GetAddr())
if err != nil {
// ??? Move up route level???
panic("not yet implemented")
c.Increment()
proc_peer <- p
default:
if c.Size() == 0 {
err_chan <- u.ErrNotFound
}
}
}
}()
np, err := s.network.GetConnection(peer.ID(closers[0].GetId()), maddr)
process := func() {
for {
select {
case p, ok := <-proc_peer:
if !ok || p == nil {
c.Decrement()
return
}
val, peers, err := s.getValueOrPeers(p, key, timeout/4, route_level)
if err != nil {
u.PErr("[%s] Failed to connect to: %s", s.self.ID.Pretty(), closers[0].GetAddr())
route_level++
u.DErr(err.Error())
c.Decrement()
continue
}
p = np
} else {
route_level++
if val != nil {
val_chan <- val
c.Decrement()
return
}
for _, np := range peers {
// TODO: filter out peers that arent closer
npeer_chan <- np
}
c.Decrement()
}
}
}
return nil, u.ErrNotFound
concurFactor := 3
for i := 0; i < concurFactor; i++ {
go process()
}
select {
case val := <-val_chan:
close(npeer_chan)
return val, nil
case err := <-err_chan:
close(npeer_chan)
return nil, err
case <-after:
close(npeer_chan)
return nil, u.ErrTimeout
}
}
// Value provider layer of indirection.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment