Commit c9dcb025 authored by Kevin Atkinson's avatar Kevin Atkinson

Report progress during 'pin add'.

License: MIT
Signed-off-by: default avatarKevin Atkinson <k@kevina.org>
parent 5dc5456c
......@@ -139,8 +139,21 @@ func (n *dagService) Remove(nd node.Node) error {
}
// FetchGraph fetches all nodes that are children of the given node
func FetchGraph(ctx context.Context, c *cid.Cid, serv DAGService) error {
return EnumerateChildrenAsync(ctx, serv, c, cid.NewSet().Visit)
func FetchGraph(ctx context.Context, root *cid.Cid, serv DAGService) error {
v, _ := ctx.Value("progress").(*ProgressTracker)
if v == nil {
return EnumerateChildrenAsync(ctx, serv, root, cid.NewSet().Visit)
}
set := cid.NewSet()
visit := func(c *cid.Cid) bool {
if set.Visit(c) {
v.Increment()
return true
} else {
return false
}
}
return EnumerateChildrenAsync(ctx, serv, root, visit)
}
// FindLinks searches this nodes links for the given key,
......@@ -389,6 +402,27 @@ func EnumerateChildren(ctx context.Context, ds LinkService, root *cid.Cid, visit
return nil
}
type ProgressTracker struct {
Total int
lk sync.Mutex
}
func (p *ProgressTracker) DeriveContext(ctx context.Context) context.Context {
return context.WithValue(ctx, "progress", p)
}
func (p *ProgressTracker) Increment() {
p.lk.Lock()
defer p.lk.Unlock()
p.Total++
}
func (p *ProgressTracker) Value() int {
p.lk.Lock()
defer p.lk.Unlock()
return p.Total
}
// FetchGraphConcurrency is total number of concurrent fetches that
// 'fetchNodes' will start at a time
var FetchGraphConcurrency = 8
......
......@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"io/ioutil"
"math/rand"
"strings"
"sync"
"testing"
......@@ -547,3 +548,80 @@ func TestEnumerateAsyncFailsNotFound(t *testing.T) {
t.Fatal("this should have failed")
}
}
func TestProgressIndicator(t *testing.T) {
testProgressIndicator(t, 5)
}
func TestProgressIndicatorNoChildren(t *testing.T) {
testProgressIndicator(t, 0)
}
func testProgressIndicator(t *testing.T, depth int) {
ds := dstest.Mock()
top, numChildren := mkDag(ds, depth)
v := new(ProgressTracker)
ctx := v.DeriveContext(context.Background())
err := FetchGraph(ctx, top, ds)
if err != nil {
t.Fatal(err)
}
if v.Value() != numChildren+1 {
t.Errorf("wrong number of children reported in progress indicator, expected %d, got %d",
numChildren+1, v.Value())
}
}
func mkDag(ds DAGService, depth int) (*cid.Cid, int) {
totalChildren := 0
f := func() *ProtoNode {
p := new(ProtoNode)
buf := make([]byte, 16)
rand.Read(buf)
p.SetData(buf)
_, err := ds.Add(p)
if err != nil {
panic(err)
}
return p
}
for i := 0; i < depth; i++ {
thisf := f
f = func() *ProtoNode {
pn := mkNodeWithChildren(thisf, 10)
_, err := ds.Add(pn)
if err != nil {
panic(err)
}
totalChildren += 10
return pn
}
}
nd := f()
c, err := ds.Add(nd)
if err != nil {
panic(err)
}
return c, totalChildren
}
func mkNodeWithChildren(getChild func() *ProtoNode, width int) *ProtoNode {
cur := new(ProtoNode)
for i := 0; i < width; i++ {
c := getChild()
if err := cur.AddNodeLinkClean(fmt.Sprint(i), c); err != nil {
panic(err)
}
}
return cur
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment