Commit 26f7ee2f authored by Kevin Atkinson's avatar Kevin Atkinson

Refactor EnumerateChildrenAsync to take in a function to get the links.

For now it is always called with the helper function GetLinksDirect to
avoid any change in behaviour.

License: MIT
Signed-off-by: default avatarKevin Atkinson <k@kevina.org>
parent 090bf568
...@@ -138,11 +138,23 @@ func (n *dagService) Remove(nd node.Node) error { ...@@ -138,11 +138,23 @@ func (n *dagService) Remove(nd node.Node) error {
return n.Blocks.DeleteBlock(nd) return n.Blocks.DeleteBlock(nd)
} }
// get the links for a node, from the node, bypassing the
// LinkService
func GetLinksDirect(serv DAGService) GetLinks {
return func(ctx context.Context, c *cid.Cid) ([]*node.Link, error) {
node, err := serv.Get(ctx, c)
if err != nil {
return nil, err
}
return node.Links(), nil
}
}
// FetchGraph fetches all nodes that are children of the given node // FetchGraph fetches all nodes that are children of the given node
func FetchGraph(ctx context.Context, root *cid.Cid, serv DAGService) error { func FetchGraph(ctx context.Context, root *cid.Cid, serv DAGService) error {
v, _ := ctx.Value("progress").(*ProgressTracker) v, _ := ctx.Value("progress").(*ProgressTracker)
if v == nil { if v == nil {
return EnumerateChildrenAsync(ctx, serv, root, cid.NewSet().Visit) return EnumerateChildrenAsync(ctx, GetLinksDirect(serv), root, cid.NewSet().Visit)
} }
set := cid.NewSet() set := cid.NewSet()
visit := func(c *cid.Cid) bool { visit := func(c *cid.Cid) bool {
...@@ -153,7 +165,7 @@ func FetchGraph(ctx context.Context, root *cid.Cid, serv DAGService) error { ...@@ -153,7 +165,7 @@ func FetchGraph(ctx context.Context, root *cid.Cid, serv DAGService) error {
return false return false
} }
} }
return EnumerateChildrenAsync(ctx, serv, root, visit) return EnumerateChildrenAsync(ctx, GetLinksDirect(serv), root, visit)
} }
// FindLinks searches this nodes links for the given key, // FindLinks searches this nodes links for the given key,
...@@ -380,10 +392,11 @@ func (t *Batch) Commit() error { ...@@ -380,10 +392,11 @@ func (t *Batch) Commit() error {
return err return err
} }
type GetLinks func(context.Context, *cid.Cid) ([]*node.Link, error)
// EnumerateChildren will walk the dag below the given root node and add all // EnumerateChildren will walk the dag below the given root node and add all
// unseen children to the passed in set. // unseen children to the passed in set.
// TODO: parallelize to avoid disk latency perf hits? // TODO: parallelize to avoid disk latency perf hits?
type GetLinks func(context.Context, *cid.Cid) ([]*node.Link, error)
func EnumerateChildren(ctx context.Context, getLinks GetLinks, root *cid.Cid, visit func(*cid.Cid) bool) error { func EnumerateChildren(ctx context.Context, getLinks GetLinks, root *cid.Cid, visit func(*cid.Cid) bool) error {
links, err := getLinks(ctx, root) links, err := getLinks(ctx, root)
if err != nil { if err != nil {
...@@ -426,9 +439,9 @@ func (p *ProgressTracker) Value() int { ...@@ -426,9 +439,9 @@ func (p *ProgressTracker) Value() int {
// 'fetchNodes' will start at a time // 'fetchNodes' will start at a time
var FetchGraphConcurrency = 8 var FetchGraphConcurrency = 8
func EnumerateChildrenAsync(ctx context.Context, ds DAGService, c *cid.Cid, visit func(*cid.Cid) bool) error { func EnumerateChildrenAsync(ctx context.Context, getLinks GetLinks, c *cid.Cid, visit func(*cid.Cid) bool) error {
feed := make(chan *cid.Cid) feed := make(chan *cid.Cid)
out := make(chan node.Node) out := make(chan []*node.Link)
done := make(chan struct{}) done := make(chan struct{})
var setlk sync.Mutex var setlk sync.Mutex
...@@ -441,7 +454,7 @@ func EnumerateChildrenAsync(ctx context.Context, ds DAGService, c *cid.Cid, visi ...@@ -441,7 +454,7 @@ func EnumerateChildrenAsync(ctx context.Context, ds DAGService, c *cid.Cid, visi
for i := 0; i < FetchGraphConcurrency; i++ { for i := 0; i < FetchGraphConcurrency; i++ {
go func() { go func() {
for ic := range feed { for ic := range feed {
n, err := ds.Get(ctx, ic) links, err := getLinks(ctx, ic)
if err != nil { if err != nil {
errChan <- err errChan <- err
return return
...@@ -453,7 +466,7 @@ func EnumerateChildrenAsync(ctx context.Context, ds DAGService, c *cid.Cid, visi ...@@ -453,7 +466,7 @@ func EnumerateChildrenAsync(ctx context.Context, ds DAGService, c *cid.Cid, visi
if unseen { if unseen {
select { select {
case out <- n: case out <- links:
case <-fetchersCtx.Done(): case <-fetchersCtx.Done():
return return
} }
...@@ -488,8 +501,8 @@ func EnumerateChildrenAsync(ctx context.Context, ds DAGService, c *cid.Cid, visi ...@@ -488,8 +501,8 @@ func EnumerateChildrenAsync(ctx context.Context, ds DAGService, c *cid.Cid, visi
if inProgress == 0 && next == nil { if inProgress == 0 && next == nil {
return nil return nil
} }
case nd := <-out: case links := <-out:
for _, lnk := range nd.Links() { for _, lnk := range links {
if next == nil { if next == nil {
next = lnk.Cid next = lnk.Cid
send = feed send = feed
......
...@@ -543,7 +543,7 @@ func TestEnumerateAsyncFailsNotFound(t *testing.T) { ...@@ -543,7 +543,7 @@ func TestEnumerateAsyncFailsNotFound(t *testing.T) {
} }
cset := cid.NewSet() cset := cid.NewSet()
err = EnumerateChildrenAsync(context.Background(), ds, pcid, cset.Visit) err = EnumerateChildrenAsync(context.Background(), GetLinksDirect(ds), pcid, cset.Visit)
if err == nil { if err == nil {
t.Fatal("this should have failed") t.Fatal("this should have failed")
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment