Commit 9eb769d7 authored by Jeromy's avatar Jeromy

rewrite enumerate children async to be less fragile

License: MIT
Signed-off-by: default avatarJeromy <why@ipfs.io>
parent 49557764
...@@ -389,103 +389,92 @@ func EnumerateChildren(ctx context.Context, ds LinkService, root *cid.Cid, visit ...@@ -389,103 +389,92 @@ func EnumerateChildren(ctx context.Context, ds LinkService, root *cid.Cid, visit
return nil return nil
} }
func EnumerateChildrenAsync(ctx context.Context, ds DAGService, c *cid.Cid, visit func(*cid.Cid) bool) error { // FetchGraphConcurrency is total number of concurrent fetches that
toprocess := make(chan []*cid.Cid, 8) // 'fetchNodes' will start at a time
nodes := make(chan *NodeOption, 8) var FetchGraphConcurrency = 8
ctx, cancel := context.WithCancel(ctx)
defer cancel()
defer close(toprocess)
go fetchNodes(ctx, ds, toprocess, nodes) func EnumerateChildrenAsync(ctx context.Context, ds DAGService, c *cid.Cid, visit func(*cid.Cid) bool) error {
if !visit(c) {
return nil
}
root, err := ds.Get(ctx, c) root, err := ds.Get(ctx, c)
if err != nil { if err != nil {
return err return err
} }
nodes <- &NodeOption{Node: root} feed := make(chan node.Node)
live := 1 out := make(chan *NodeOption)
done := make(chan struct{})
for {
select { var setlk sync.Mutex
case opt, ok := <-nodes:
if !ok { for i := 0; i < FetchGraphConcurrency; i++ {
return nil go func() {
} for n := range feed {
links := n.Links()
if opt.Err != nil { cids := make([]*cid.Cid, 0, len(links))
return opt.Err for _, l := range links {
} setlk.Lock()
unseen := visit(l.Cid)
nd := opt.Node setlk.Unlock()
if unseen {
// a node has been fetched cids = append(cids, l.Cid)
live-- }
var cids []*cid.Cid
for _, lnk := range nd.Links() {
c := lnk.Cid
if visit(c) {
live++
cids = append(cids, c)
} }
}
if live == 0 {
return nil
}
if len(cids) > 0 { for nopt := range ds.GetMany(ctx, cids) {
select {
case out <- nopt:
case <-ctx.Done():
return
}
}
select { select {
case toprocess <- cids: case done <- struct{}{}:
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err()
} }
} }
case <-ctx.Done(): }()
return ctx.Err()
}
} }
} defer close(feed)
// FetchGraphConcurrency is total number of concurrenct fetches that send := feed
// 'fetchNodes' will start at a time var todobuffer []node.Node
var FetchGraphConcurrency = 8 var inProgress int
func fetchNodes(ctx context.Context, ds DAGService, in <-chan []*cid.Cid, out chan<- *NodeOption) {
var wg sync.WaitGroup
defer func() {
// wait for all 'get' calls to complete so we don't accidentally send
// on a closed channel
wg.Wait()
close(out)
}()
rateLimit := make(chan struct{}, FetchGraphConcurrency) next := root
for {
select {
case send <- next:
inProgress++
if len(todobuffer) > 0 {
next = todobuffer[0]
todobuffer = todobuffer[1:]
} else {
next = nil
send = nil
}
case <-done:
inProgress--
if inProgress == 0 && next == nil {
return nil
}
case nc := <-out:
if nc.Err != nil {
return nc.Err
}
get := func(ks []*cid.Cid) { if next == nil {
defer wg.Done() next = nc.Node
defer func() { send = feed
<-rateLimit } else {
}() todobuffer = append(todobuffer, nc.Node)
nodes := ds.GetMany(ctx, ks)
for opt := range nodes {
select {
case out <- opt:
case <-ctx.Done():
return
} }
}
}
for ks := range in {
select {
case rateLimit <- struct{}{}:
case <-ctx.Done(): case <-ctx.Done():
return return ctx.Err()
} }
wg.Add(1)
go get(ks)
} }
} }
...@@ -504,3 +504,46 @@ func TestCidRawDoesnNeedData(t *testing.T) { ...@@ -504,3 +504,46 @@ func TestCidRawDoesnNeedData(t *testing.T) {
t.Fatal("raw node shouldn't have any links") t.Fatal("raw node shouldn't have any links")
} }
} }
func TestEnumerateAsyncFailsNotFound(t *testing.T) {
a := NodeWithData([]byte("foo1"))
b := NodeWithData([]byte("foo2"))
c := NodeWithData([]byte("foo3"))
d := NodeWithData([]byte("foo4"))
ds := dstest.Mock()
for _, n := range []node.Node{a, b, c} {
_, err := ds.Add(n)
if err != nil {
t.Fatal(err)
}
}
parent := new(ProtoNode)
if err := parent.AddNodeLinkClean("a", a); err != nil {
t.Fatal(err)
}
if err := parent.AddNodeLinkClean("b", b); err != nil {
t.Fatal(err)
}
if err := parent.AddNodeLinkClean("c", c); err != nil {
t.Fatal(err)
}
if err := parent.AddNodeLinkClean("d", d); err != nil {
t.Fatal(err)
}
pcid, err := ds.Add(parent)
if err != nil {
t.Fatal(err)
}
cset := cid.NewSet()
err = EnumerateChildrenAsync(context.Background(), ds, pcid, cset.Visit)
if err == nil {
t.Fatal("this should have failed")
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment