diff --git a/routing/dht/dht.go b/routing/dht/dht.go index ea8e1d861eef61e9467e75abbad52c3ac196f2da..b00ae0a4d5f20b297cf4ed2690e82475b7cefb11 100644 --- a/routing/dht/dht.go +++ b/routing/dht/dht.go @@ -496,6 +496,45 @@ out: dht.network.Send(mes) } +func (dht *IpfsDHT) getValueOrPeers(p *peer.Peer, key u.Key, timeout time.Duration, level int) ([]byte, []*peer.Peer, error) { + pmes, err := dht.getValueSingle(p, key, timeout, level) + if err != nil { + return nil, nil, u.WrapError(err, "getValue Error") + } + + if pmes.GetSuccess() { + if pmes.Value == nil { // We were given provider[s] + val, err := dht.getFromPeerList(key, timeout, pmes.GetPeers(), level) + if err != nil { + return nil, nil, err + } + return val, nil, nil + } + + // Success! We were given the value + return pmes.GetValue(), nil, nil + } else { + // We were given a closer node + var peers []*peer.Peer + for _, pb := range pmes.GetPeers() { + addr, err := ma.NewMultiaddr(pb.GetAddr()) + if err != nil { + u.PErr(err.Error()) + continue + } + + np, err := dht.network.GetConnection(peer.ID(pb.GetId()), addr) + if err != nil { + u.PErr(err.Error()) + continue + } + + peers = append(peers, np) + } + return nil, peers, nil + } +} + // getValueSingle simply performs the get value RPC with the given parameters func (dht *IpfsDHT) getValueSingle(p *peer.Peer, key u.Key, timeout time.Duration, level int) (*PBDHTMessage, error) { pmes := DHTMessage{ diff --git a/routing/dht/dht_logger.go b/routing/dht/dht_logger.go index c363add7bbc1d2e9c4a2e5ae87734fae340dca69..c892959f06f6bb88f4ab100d39307fe9a4c1419a 100644 --- a/routing/dht/dht_logger.go +++ b/routing/dht/dht_logger.go @@ -31,8 +31,8 @@ func (l *logDhtRpc) EndLog() { func (l *logDhtRpc) Print() { b, err := json.Marshal(l) if err != nil { - u.POut(err.Error()) + u.DOut(err.Error()) } else { - u.POut(string(b)) + u.DOut(string(b)) } } diff --git a/routing/dht/ext_test.go b/routing/dht/ext_test.go index fbf52a26371f884d746526b329f70c4e146a6d47..490c9f493e02aaf4b05aad2ce19abc64257ad5a1 100644 --- a/routing/dht/ext_test.go +++ b/routing/dht/ext_test.go @@ -88,13 +88,9 @@ func TestGetFailures(t *testing.T) { d.Update(other) // This one should time out - _, err := d.GetValue(u.Key("test"), time.Millisecond*5) + _, err := d.GetValue(u.Key("test"), time.Millisecond*10) if err != nil { - nerr, ok := err.(*u.IpfsError) - if !ok { - t.Fatal("Got different error than we expected.") - } - if nerr.Inner != u.ErrTimeout { + if err != u.ErrTimeout { t.Fatal("Got different error than we expected.") } } else { @@ -119,10 +115,10 @@ func TestGetFailures(t *testing.T) { }) // This one should fail with NotFound - _, err = d.GetValue(u.Key("test"), time.Millisecond*5) + _, err = d.GetValue(u.Key("test"), time.Millisecond*1000) if err != nil { if err != u.ErrNotFound { - t.Fatal("Expected ErrNotFound, got: %s", err) + t.Fatalf("Expected ErrNotFound, got: %s", err) } } else { t.Fatal("expected error, got none.") diff --git a/routing/dht/routing.go b/routing/dht/routing.go index 9923961d14fe14864d90b77ce9b8616c24a0678e..2ecd8ba4598aa1844404e1682a3a045afd25c61f 100644 --- a/routing/dht/routing.go +++ b/routing/dht/routing.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "math/rand" + "sync" "time" proto "code.google.com/p/goprotobuf/proto" @@ -56,6 +57,30 @@ func (s *IpfsDHT) PutValue(key u.Key, value []byte) { } } +// A counter for incrementing a variable across multiple threads +type counter struct { + n int + mut sync.RWMutex +} + +func (c *counter) Increment() { + c.mut.Lock() + c.n++ + c.mut.Unlock() +} + +func (c *counter) Decrement() { + c.mut.Lock() + c.n-- + c.mut.Unlock() +} + +func (c *counter) Size() int { + c.mut.RLock() + defer c.mut.RUnlock() + return c.n +} + // GetValue searches for the value corresponding to given Key. // If the search does not succeed, a multiaddr string of a closer peer is // returned along with util.ErrSearchIncomplete @@ -65,7 +90,6 @@ func (s *IpfsDHT) GetValue(key u.Key, timeout time.Duration) ([]byte, error) { ll.EndLog() ll.Print() }() - route_level := 0 // If we have it local, dont bother doing an RPC! // NOTE: this might not be what we want to do... @@ -76,54 +100,90 @@ func (s *IpfsDHT) GetValue(key u.Key, timeout time.Duration) ([]byte, error) { return val, nil } - p := s.routes[route_level].NearestPeer(kb.ConvertKey(key)) - if p == nil { + route_level := 0 + closest := s.routes[route_level].NearestPeers(kb.ConvertKey(key), PoolSize) + if closest == nil || len(closest) == 0 { return nil, kb.ErrLookupFailure } - for route_level < len(s.routes) && p != nil { - ll.RpcCount++ - pmes, err := s.getValueSingle(p, key, timeout, route_level) - if err != nil { - return nil, u.WrapError(err, "getValue Error") - } + val_chan := make(chan []byte) + npeer_chan := make(chan *peer.Peer, 30) + proc_peer := make(chan *peer.Peer, 30) + err_chan := make(chan error) + after := time.After(timeout) - if pmes.GetSuccess() { - if pmes.Value == nil { // We were given provider[s] - ll.RpcCount++ - return s.getFromPeerList(key, timeout, pmes.GetPeers(), route_level) - } + for _, p := range closest { + npeer_chan <- p + } - // Success! We were given the value - ll.Success = true - return pmes.GetValue(), nil - } else { - // We were given a closer node - closers := pmes.GetPeers() - if len(closers) > 0 { - if peer.ID(closers[0].GetId()).Equal(s.self.ID) { - u.DOut("Got myself back as a closer peer.") - return nil, u.ErrNotFound + c := counter{} + + // This limit value is referred to as k in the kademlia paper + limit := 20 + count := 0 + go func() { + for { + select { + case p := <-npeer_chan: + count++ + if count >= limit { + break } - maddr, err := ma.NewMultiaddr(closers[0].GetAddr()) - if err != nil { - // ??? Move up route level??? - panic("not yet implemented") + c.Increment() + proc_peer <- p + default: + if c.Size() == 0 { + err_chan <- u.ErrNotFound } + } + } + }() - np, err := s.network.GetConnection(peer.ID(closers[0].GetId()), maddr) + process := func() { + for { + select { + case p, ok := <-proc_peer: + if !ok || p == nil { + c.Decrement() + return + } + val, peers, err := s.getValueOrPeers(p, key, timeout/4, route_level) if err != nil { - u.PErr("[%s] Failed to connect to: %s", s.self.ID.Pretty(), closers[0].GetAddr()) - route_level++ + u.DErr(err.Error()) + c.Decrement() continue } - p = np - } else { - route_level++ + if val != nil { + val_chan <- val + c.Decrement() + return + } + + for _, np := range peers { + // TODO: filter out peers that arent closer + npeer_chan <- np + } + c.Decrement() } } } - return nil, u.ErrNotFound + + concurFactor := 3 + for i := 0; i < concurFactor; i++ { + go process() + } + + select { + case val := <-val_chan: + close(npeer_chan) + return val, nil + case err := <-err_chan: + close(npeer_chan) + return nil, err + case <-after: + close(npeer_chan) + return nil, u.ErrTimeout + } } // Value provider layer of indirection.