rework the graph walking functions with functional options

This commit:
- reduce the API to 2 simpler functions
- add consistent and clear control over if the root should be visited
- add control over the concurrency factor
parent 625228c5
......@@ -162,7 +162,7 @@ func FetchGraph(ctx context.Context, root cid.Cid, serv ipld.DAGService) error {
}
// FetchGraphWithDepthLimit fetches all nodes that are children to the given
// node down to the given depth. maxDetph=0 means "only fetch root",
// node down to the given depth. maxDepth=0 means "only fetch root",
// maxDepth=1 means "fetch root and its direct children" and so on...
// maxDepth=-1 means unlimited.
func FetchGraphWithDepthLimit(ctx context.Context, root cid.Cid, depthLim int, serv ipld.DAGService) error {
......@@ -195,9 +195,10 @@ func FetchGraphWithDepthLimit(ctx context.Context, root cid.Cid, depthLim int, s
return false
}
// If we have a ProgressTracker, we wrap the visit function to handle it
v, _ := ctx.Value(progressContextKey).(*ProgressTracker)
if v == nil {
return WalkParallelDepth(ctx, GetLinksDirect(ng), root, 0, visit)
return WalkDepth(ctx, GetLinksDirect(ng), root, visit, Concurrent(), WithRoot())
}
visitProgress := func(c cid.Cid, depth int) bool {
......@@ -207,7 +208,7 @@ func FetchGraphWithDepthLimit(ctx context.Context, root cid.Cid, depthLim int, s
}
return false
}
return WalkParallelDepth(ctx, GetLinksDirect(ng), root, 0, visitProgress)
return WalkDepth(ctx, GetLinksDirect(ng), root, visitProgress, Concurrent(), WithRoot())
}
// GetMany gets many nodes from the DAG at once.
......@@ -281,21 +282,77 @@ func GetLinksWithDAG(ng ipld.NodeGetter) GetLinks {
}
}
// defaultConcurrentFetch is the default maximum number of concurrent fetches
// that 'fetchNodes' will start at a time
const defaultConcurrentFetch = 32
// WalkOptions represent the parameters of a graph walking algorithm
type WalkOptions struct {
WithRoot bool
IgnoreBadBlock bool
Concurrency int
}
// WalkOption is a setter for WalkOptions
type WalkOption func(*WalkOptions)
// WithRoot is a WalkOption indicating that the root node should be visited
func WithRoot() WalkOption {
return func(walkOptions *WalkOptions) {
walkOptions.WithRoot = true
}
}
// Concurrent is a WalkOption indicating that node fetching should be done in
// parallel, with the default concurrency factor.
// NOTE: When using that option, the walk order is *not* guarantee.
// NOTE: It *does not* make multiple concurrent calls to the passed `visit` function.
func Concurrent() WalkOption {
return func(walkOptions *WalkOptions) {
walkOptions.Concurrency = defaultConcurrentFetch
}
}
// Concurrency is a WalkOption indicating that node fetching should be done in
// parallel, with a specific concurrency factor.
// NOTE: When using that option, the walk order is *not* guarantee.
// NOTE: It *does not* make multiple concurrent calls to the passed `visit` function.
func Concurrency(worker int) WalkOption {
return func(walkOptions *WalkOptions) {
walkOptions.Concurrency = worker
}
}
// WalkGraph will walk the dag in order (depth first) starting at the given root.
func Walk(ctx context.Context, getLinks GetLinks, root cid.Cid, visit func(cid.Cid) bool) error {
func Walk(ctx context.Context, getLinks GetLinks, c cid.Cid, visit func(cid.Cid) bool, options ...WalkOption) error {
visitDepth := func(c cid.Cid, depth int) bool {
return visit(c)
}
return WalkDepth(ctx, getLinks, root, 0, visitDepth)
return WalkDepth(ctx, getLinks, c, visitDepth, options...)
}
// WalkDepth walks the dag starting at the given root and passes the current
// depth to a given visit function. The visit function can be used to limit DAG
// exploration.
func WalkDepth(ctx context.Context, getLinks GetLinks, root cid.Cid, depth int, visit func(cid.Cid, int) bool) error {
if !visit(root, depth) {
return nil
func WalkDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, visit func(cid.Cid, int) bool, options ...WalkOption) error {
opts := &WalkOptions{}
for _, opt := range options {
opt(opts)
}
if opts.Concurrency > 1 {
return parallelWalkDepth(ctx, getLinks, c, visit, opts)
} else {
return sequentialWalkDepth(ctx, getLinks, c, 0, visit, opts)
}
}
func sequentialWalkDepth(ctx context.Context, getLinks GetLinks, root cid.Cid, depth int, visit func(cid.Cid, int) bool, options *WalkOptions) error {
if depth != 0 || options.WithRoot {
if !visit(root, depth) {
return nil
}
}
links, err := getLinks(ctx, root)
......@@ -304,7 +361,7 @@ func WalkDepth(ctx context.Context, getLinks GetLinks, root cid.Cid, depth int,
}
for _, lnk := range links {
if err := WalkDepth(ctx, getLinks, lnk.Cid, depth+1, visit); err != nil {
if err := sequentialWalkDepth(ctx, getLinks, lnk.Cid, depth+1, visit, options); err != nil {
return err
}
}
......@@ -337,27 +394,7 @@ func (p *ProgressTracker) Value() int {
return p.Total
}
// FetchGraphConcurrency is total number of concurrent fetches that
// 'fetchNodes' will start at a time
var FetchGraphConcurrency = 32
// WalkParallel is equivalent to Walk *except* that it explores multiple paths
// in parallel.
//
// NOTE: It *does not* make multiple concurrent calls to the passed `visit` function.
func WalkParallel(ctx context.Context, getLinks GetLinks, c cid.Cid, visit func(cid.Cid) bool) error {
visitDepth := func(c cid.Cid, depth int) bool {
return visit(c)
}
return WalkParallelDepth(ctx, getLinks, c, 0, visitDepth)
}
// WalkParallelDepth is equivalent to WalkDepth *except* that it fetches
// children in parallel.
//
// NOTE: It *does not* make multiple concurrent calls to the passed `visit` function.
func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startDepth int, visit func(cid.Cid, int) bool) error {
func parallelWalkDepth(ctx context.Context, getLinks GetLinks, root cid.Cid, visit func(cid.Cid, int) bool, options *WalkOptions) error {
type cidDepth struct {
cid cid.Cid
depth int
......@@ -372,14 +409,14 @@ func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startD
out := make(chan *linksDepth)
done := make(chan struct{})
var setlk sync.Mutex
var visitlk sync.Mutex
var wg sync.WaitGroup
errChan := make(chan error)
fetchersCtx, cancel := context.WithCancel(ctx)
defer wg.Wait()
defer cancel()
for i := 0; i < FetchGraphConcurrency; i++ {
for i := 0; i < options.Concurrency; i++ {
wg.Add(1)
go func() {
defer wg.Done()
......@@ -387,9 +424,16 @@ func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startD
ci := cdepth.cid
depth := cdepth.depth
setlk.Lock()
shouldVisit := visit(ci, depth)
setlk.Unlock()
var shouldVisit bool
// bypass the root if needed
if depth != 0 || options.WithRoot {
visitlk.Lock()
shouldVisit = visit(ci, depth)
visitlk.Unlock()
} else {
shouldVisit = true
}
if shouldVisit {
links, err := getLinks(ctx, ci)
......@@ -422,20 +466,21 @@ func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startD
defer close(feed)
send := feed
var todobuffer []*cidDepth
var todoQueue []*cidDepth
var inProgress int
next := &cidDepth{
cid: c,
depth: startDepth,
cid: root,
depth: 0,
}
for {
select {
case send <- next:
inProgress++
if len(todobuffer) > 0 {
next = todobuffer[0]
todobuffer = todobuffer[1:]
if len(todoQueue) > 0 {
next = todoQueue[0]
todoQueue = todoQueue[1:]
} else {
next = nil
send = nil
......@@ -456,7 +501,7 @@ func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startD
next = cd
send = feed
} else {
todobuffer = append(todobuffer, cd)
todoQueue = append(todoQueue, cd)
}
}
case err := <-errChan:
......@@ -466,7 +511,6 @@ func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startD
return ctx.Err()
}
}
}
var _ ipld.LinkGetter = &dagService{}
......
......@@ -203,8 +203,11 @@ func makeTestDAG(t *testing.T, read io.Reader, ds ipld.DAGService) ipld.Node {
// Add a root referencing all created nodes
root := NodeWithData(nil)
for _, n := range nodes {
root.AddNodeLink(n.Cid().String(), n)
err := ds.Add(ctx, n)
err := root.AddNodeLink(n.Cid().String(), n)
if err != nil {
t.Fatal(err)
}
err = ds.Add(ctx, n)
if err != nil {
t.Fatal(err)
}
......@@ -383,7 +386,7 @@ func TestFetchGraphWithDepthLimit(t *testing.T) {
}
err = WalkDepth(context.Background(), offlineDS.GetLinks, root.Cid(), 0, visitF)
err = WalkDepth(context.Background(), offlineDS.GetLinks, root.Cid(), visitF, WithRoot())
if err != nil {
t.Fatal(err)
}
......@@ -736,7 +739,7 @@ func TestEnumerateAsyncFailsNotFound(t *testing.T) {
}
cset := cid.NewSet()
err = WalkParallel(ctx, GetLinksDirect(ds), parent.Cid(), cset.Visit)
err = Walk(ctx, GetLinksDirect(ds), parent.Cid(), cset.Visit)
if err == nil {
t.Fatal("this should have failed")
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment