merkledag.go 14.9 KB
Newer Older
1
// Package merkledag implements the DMS3 Merkle DAG data structures.
Juan Batiz-Benet's avatar
Juan Batiz-Benet committed
2 3 4
package merkledag

import (
5
	"context"
6
	"fmt"
7
	"sync"
Jeromy's avatar
Jeromy committed
8

9 10 11 12 13
	blocks "gitlab.dms3.io/dms3/go-block-format"
	bserv "gitlab.dms3.io/dms3/go-blockservice"
	cid "gitlab.dms3.io/dms3/go-cid"
	ldcbor "gitlab.dms3.io/dms3/go-ld-cbor"
	ld "gitlab.dms3.io/dms3/go-ld-format"
Juan Batiz-Benet's avatar
Juan Batiz-Benet committed
14 15
)

16 17
// TODO: We should move these registrations elsewhere. Really, most of the LD
// functionality should go in a `go-ld` repo but that will take a lot of work
18 19
// and design.
func init() {
20 21 22
	// Register a block decoder for each CID codec this package handles.
	// NOTE(review): presumably ld.Decode (used by Get and getNodesFromBG in
	// this file) looks decoders up in this registry by the block's codec —
	// confirm against go-ld-format.
	ld.Register(cid.DagProtobuf, DecodeProtobufBlock) // DagProtobuf codec
	ld.Register(cid.Raw, DecodeRawBlock)              // Raw codec
	ld.Register(cid.DagCBOR, ldcbor.DecodeBlock)      // DagCBOR codec, decoded by go-ld-cbor
23 24
}

25 26 27 28 29
// contextKey is an unexported string type for this package's context keys,
// so they cannot collide with keys defined in other packages.
type contextKey string

// progressContextKey is the context key under which a *ProgressTracker is
// stored and later looked up during graph fetches.
const progressContextKey = contextKey("progress")

30
// NewDAGService constructs a new DAGService (using the default implementation).
31
// Note that the default implementation is also an ld.LinkGetter.
32
func NewDAGService(bs bserv.BlockService) *dagService {
33
	return &dagService{Blocks: bs}
Jeromy's avatar
Jeromy committed
34 35
}

36
// dagService is an DMS3 Merkle DAG service.
Juan Batiz-Benet's avatar
go lint  
Juan Batiz-Benet committed
37 38
// - the root is virtual (like a forest)
// - stores nodes' data in a BlockService
39 40
// TODO: should cache Nodes that are in memory, and be
//       able to free some of them when vm pressure is high
41
type dagService struct {
42
	Blocks bserv.BlockService
43 44
}

45
// Add adds a node to the dagService, storing the block in the BlockService
46
func (n *dagService) Add(ctx context.Context, nd ld.Node) error {
47
	if n == nil { // FIXME remove this assertion. protect with constructor invariant
48
		return fmt.Errorf("dagService is nil")
49 50
	}

51
	return n.Blocks.AddBlock(nd)
52 53
}

54
func (n *dagService) AddMany(ctx context.Context, nds []ld.Node) error {
55 56 57
	blks := make([]blocks.Block, len(nds))
	for i, nd := range nds {
		blks[i] = nd
58
	}
59
	return n.Blocks.AddBlocks(blks)
60 61
}

62
// Get retrieves a node from the dagService, fetching the block in the BlockService
63
func (n *dagService) Get(ctx context.Context, c cid.Cid) (ld.Node, error) {
64
	if n == nil {
65
		return nil, fmt.Errorf("dagService is nil")
66
	}
Jeromy's avatar
Jeromy committed
67

68 69
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()
70

Jeromy's avatar
Jeromy committed
71
	b, err := n.Blocks.GetBlock(ctx, c)
72
	if err != nil {
73
		if err == bserv.ErrNotFound {
74
			return nil, ld.ErrNotFound
75
		}
Łukasz Magiera's avatar
Łukasz Magiera committed
76
		return nil, fmt.Errorf("failed to get block for %s: %v", c, err)
77 78
	}

79
	return ld.Decode(b)
80
}
Jeromy's avatar
Jeromy committed
81

82 83
// GetLinks return the links for the node, the node doesn't necessarily have
// to exist locally.
84
func (n *dagService) GetLinks(ctx context.Context, c cid.Cid) ([]*ld.Link, error) {
85 86 87
	if c.Type() == cid.Raw {
		return nil, nil
	}
88 89 90 91
	node, err := n.Get(ctx, c)
	if err != nil {
		return nil, err
	}
92
	return node.Links(), nil
93 94
}

95
func (n *dagService) Remove(ctx context.Context, c cid.Cid) error {
96
	return n.Blocks.DeleteBlock(c)
97 98
}

99 100 101 102 103
// RemoveMany removes multiple nodes from the DAG. It will likely be faster than
// removing them individually.
//
// This operation is not atomic. If it returns an error, some nodes may or may
// not have been removed.
104
func (n *dagService) RemoveMany(ctx context.Context, cids []cid.Cid) error {
105 106 107 108 109 110 111
	// TODO(#4608): make this batch all the way down.
	for _, c := range cids {
		if err := n.Blocks.DeleteBlock(c); err != nil {
			return err
		}
	}
	return nil
Jeromy's avatar
Jeromy committed
112 113
}

114 115 116
// GetLinksDirect creates a function to get the links for a node, from
// the node, bypassing the LinkService.  If the node does not exist
// locally (and can not be retrieved) an error will be returned.
117 118
func GetLinksDirect(serv ld.NodeGetter) GetLinks {
	return func(ctx context.Context, c cid.Cid) ([]*ld.Link, error) {
119
		nd, err := serv.Get(ctx, c)
120
		if err != nil {
Jeromy's avatar
Jeromy committed
121
			if err == bserv.ErrNotFound {
122
				err = ld.ErrNotFound
Jeromy's avatar
Jeromy committed
123
			}
124 125
			return nil, err
		}
126
		return nd.Links(), nil
127 128 129
	}
}

130 131 132 133
// sesGetter is an ld.NodeGetter whose block fetches all go through a single
// blockservice session (created by dagService.Session).
type sesGetter struct {
	// bs is the session used for every block fetch.
	bs *bserv.Session
}

134
// Get gets a single node from the DAG.
135
func (sg *sesGetter) Get(ctx context.Context, c cid.Cid) (ld.Node, error) {
136
	blk, err := sg.bs.GetBlock(ctx, c)
Jeromy's avatar
Jeromy committed
137 138
	switch err {
	case bserv.ErrNotFound:
139
		return nil, ld.ErrNotFound
Jeromy's avatar
Jeromy committed
140
	default:
141
		return nil, err
Jeromy's avatar
Jeromy committed
142 143
	case nil:
		// noop
144 145
	}

146
	return ld.Decode(blk)
147 148
}

149
// GetMany gets many nodes at once, batching the request if possible.
150
func (sg *sesGetter) GetMany(ctx context.Context, keys []cid.Cid) <-chan *ld.NodeOption {
151 152 153
	return getNodesFromBG(ctx, sg.bs, keys)
}

Jeromy's avatar
Jeromy committed
154
// Session returns a NodeGetter using a new session for block fetches.
155
func (n *dagService) Session(ctx context.Context) ld.NodeGetter {
156
	return &sesGetter{bserv.NewSession(ctx, n.Blocks)}
Jeromy's avatar
Jeromy committed
157 158
}

159
// FetchGraph fetches all nodes that are children of the given node
160
func FetchGraph(ctx context.Context, root cid.Cid, serv ld.DAGService) error {
161 162 163 164
	return FetchGraphWithDepthLimit(ctx, root, -1, serv)
}

// FetchGraphWithDepthLimit fetches all nodes that are children to the given
165
// node down to the given depth. maxDepth=0 means "only fetch root",
166 167
// maxDepth=1 means "fetch root and its direct children" and so on...
// maxDepth=-1 means unlimited.
168 169
func FetchGraphWithDepthLimit(ctx context.Context, root cid.Cid, depthLim int, serv ld.DAGService) error {
	var ng ld.NodeGetter = NewSession(ctx, serv)
170

Steven Allen's avatar
Steven Allen committed
171
	set := make(map[cid.Cid]int)
172 173 174 175 176 177 178 179

	// Visit function returns true when:
	// * The element is not in the set and we're not over depthLim
	// * The element is in the set but recorded depth is deeper
	//   than currently seen (if we find it higher in the tree we'll need
	//   to explore deeper than before).
	// depthLim = -1 means we only return true if the element is not in the
	// set.
180
	visit := func(c cid.Cid, depth int) bool {
Steven Allen's avatar
Steven Allen committed
181
		oldDepth, ok := set[c]
182 183 184 185 186 187

		if (ok && depthLim < 0) || (depthLim >= 0 && depth > depthLim) {
			return false
		}

		if !ok || oldDepth > depth {
Steven Allen's avatar
Steven Allen committed
188
			set[c] = depth
189 190 191 192 193
			return true
		}
		return false
	}

194
	// If we have a ProgressTracker, we wrap the visit function to handle it
195
	v, _ := ctx.Value(progressContextKey).(*ProgressTracker)
196
	if v == nil {
197
		return WalkDepth(ctx, GetLinksDirect(ng), root, visit, Concurrent())
198
	}
199

200
	visitProgress := func(c cid.Cid, depth int) bool {
201
		if visit(c, depth) {
202 203 204
			v.Increment()
			return true
		}
205
		return false
206
	}
207
	return WalkDepth(ctx, GetLinksDirect(ng), root, visitProgress, Concurrent())
Jeromy's avatar
Jeromy committed
208
}
209

210 211 212 213 214
// GetMany gets many nodes from the DAG at once.
//
// This method may not return all requested nodes (and may or may not return an
// error indicating that it failed to do so. It is up to the caller to verify
// that it received all nodes.
215
func (n *dagService) GetMany(ctx context.Context, keys []cid.Cid) <-chan *ld.NodeOption {
216
	return getNodesFromBG(ctx, n.Blocks, keys)
217 218
}

219
func dedupKeys(keys []cid.Cid) []cid.Cid {
Steven Allen's avatar
Steven Allen committed
220 221 222 223 224 225 226 227 228 229
	set := cid.NewSet()
	for _, c := range keys {
		set.Add(c)
	}
	if set.Len() == len(keys) {
		return keys
	}
	return set.Keys()
}

230
func getNodesFromBG(ctx context.Context, bs bserv.BlockGetter, keys []cid.Cid) <-chan *ld.NodeOption {
Steven Allen's avatar
Steven Allen committed
231 232
	keys = dedupKeys(keys)

233
	out := make(chan *ld.NodeOption, len(keys))
234
	blocks := bs.GetBlocks(ctx, keys)
235 236
	var count int

237 238 239 240 241 242
	go func() {
		defer close(out)
		for {
			select {
			case b, ok := <-blocks:
				if !ok {
243
					if count != len(keys) {
244
						out <- &ld.NodeOption{Err: fmt.Errorf("failed to fetch all nodes")}
245
					}
246 247
					return
				}
Jeromy's avatar
Jeromy committed
248

249
				nd, err := ld.Decode(b)
250
				if err != nil {
251
					out <- &ld.NodeOption{Err: err}
252 253
					return
				}
Jeromy's avatar
Jeromy committed
254

255
				out <- &ld.NodeOption{Node: nd}
Jeromy's avatar
Jeromy committed
256 257
				count++

258
			case <-ctx.Done():
259
				out <- &ld.NodeOption{Err: ctx.Err()}
Jeromy's avatar
Jeromy committed
260
				return
261 262 263
			}
		}
	}()
264
	return out
265 266
}

267
// GetLinks is the type of function passed to the EnumerateChildren function(s)
268 269
// for getting the children of an LD node.
type GetLinks func(context.Context, cid.Cid) ([]*ld.Link, error)
Jeromy's avatar
Jeromy committed
270

271
// GetLinksWithDAG returns a GetLinks function that tries to use the given
272
// NodeGetter as a LinkGetter to get the children of a given LD node. This may
273 274
// allow us to traverse the DAG without actually loading and parsing the node in
// question (if we already have the links cached).
275 276 277
func GetLinksWithDAG(ng ld.NodeGetter) GetLinks {
	return func(ctx context.Context, c cid.Cid) ([]*ld.Link, error) {
		return ld.GetLinks(ctx, ng, c)
Jeromy's avatar
Jeromy committed
278
	}
279
}
280

281 282 283 284
// defaultConcurrentFetch is the number of concurrent fetcher goroutines a
// walk starts when the Concurrent option is given.
const defaultConcurrentFetch = 32

285 286
// walkOptions represent the parameters of a graph walking algorithm
type walkOptions struct {
287
	SkipRoot     bool
288
	Concurrency  int
289
	ErrorHandler func(c cid.Cid, err error) error
290 291
}

292 293 294 295 296 297 298 299 300 301 302 303
// WalkOption is a setter for walkOptions.
type WalkOption func(*walkOptions)

// addHandler chains handler after any ErrorHandler already installed.
//
// The previously installed handler runs first. If it resolves the error
// (returns nil), the walk continues and handler is not consulted. Otherwise
// handler receives the, possibly transformed, error and its result is final.
func (wo *walkOptions) addHandler(handler func(c cid.Cid, err error) error) {
	prev := wo.ErrorHandler
	if prev == nil {
		wo.ErrorHandler = handler
		return
	}
	// Capture prev by value. Reading wo.ErrorHandler inside the closure would
	// see the closure itself and recurse forever.
	wo.ErrorHandler = func(c cid.Cid, err error) error {
		if err = prev(c, err); err == nil {
			return nil
		}
		return handler(c, err)
	}
}
304

305 306
// SkipRoot is a WalkOption indicating that the root node should skipped
func SkipRoot() WalkOption {
307
	return func(walkOptions *walkOptions) {
308
		walkOptions.SkipRoot = true
309 310 311 312 313 314 315 316
	}
}

// Concurrent is a WalkOption indicating that node fetching should be done in
// parallel, with the default concurrency factor.
// NOTE: When using that option, the walk order is *not* guarantee.
// NOTE: It *does not* make multiple concurrent calls to the passed `visit` function.
func Concurrent() WalkOption {
317
	return func(walkOptions *walkOptions) {
318 319 320 321 322 323 324 325 326
		walkOptions.Concurrency = defaultConcurrentFetch
	}
}

// Concurrency is a WalkOption indicating that node fetching should be done in
// parallel, with a specific concurrency factor.
// NOTE: When using that option, the walk order is *not* guarantee.
// NOTE: It *does not* make multiple concurrent calls to the passed `visit` function.
func Concurrency(worker int) WalkOption {
327
	return func(walkOptions *walkOptions) {
328 329 330 331
		walkOptions.Concurrency = worker
	}
}

332 333 334
// IgnoreErrors is a WalkOption indicating that the walk should attempt to
// continue even when an error occur.
func IgnoreErrors() WalkOption {
335 336 337 338 339 340 341 342 343 344 345 346
	return func(walkOptions *walkOptions) {
		walkOptions.addHandler(func(c cid.Cid, err error) error {
			return nil
		})
	}
}

// IgnoreMissing is a WalkOption indicating that the walk should continue when
// a node is missing.
func IgnoreMissing() WalkOption {
	return func(walkOptions *walkOptions) {
		walkOptions.addHandler(func(c cid.Cid, err error) error {
347
			if err == ld.ErrNotFound {
348 349 350 351 352 353 354 355 356 357 358 359
				return nil
			}
			return err
		})
	}
}

// OnMissing is a WalkOption adding a callback that will be triggered on a missing
// node.
func OnMissing(callback func(c cid.Cid)) WalkOption {
	return func(walkOptions *walkOptions) {
		walkOptions.addHandler(func(c cid.Cid, err error) error {
360
			if err == ld.ErrNotFound {
361 362 363 364 365 366 367 368 369 370 371 372
				callback(c)
			}
			return err
		})
	}
}

// OnError is a WalkOption adding a custom error handler.
// If this handler return a nil error, the walk will continue.
func OnError(handler func(c cid.Cid, err error) error) WalkOption {
	return func(walkOptions *walkOptions) {
		walkOptions.addHandler(handler)
373 374 375
	}
}

376
// WalkGraph will walk the dag in order (depth first) starting at the given root.
377
func Walk(ctx context.Context, getLinks GetLinks, c cid.Cid, visit func(cid.Cid) bool, options ...WalkOption) error {
378
	visitDepth := func(c cid.Cid, depth int) bool {
379 380 381
		return visit(c)
	}

382
	return WalkDepth(ctx, getLinks, c, visitDepth, options...)
383 384
}

385 386 387
// WalkDepth walks the dag starting at the given root and passes the current
// depth to a given visit function. The visit function can be used to limit DAG
// exploration.
388
func WalkDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, visit func(cid.Cid, int) bool, options ...WalkOption) error {
389
	opts := &walkOptions{}
390 391 392 393 394 395 396 397 398 399 400
	for _, opt := range options {
		opt(opts)
	}

	if opts.Concurrency > 1 {
		return parallelWalkDepth(ctx, getLinks, c, visit, opts)
	} else {
		return sequentialWalkDepth(ctx, getLinks, c, 0, visit, opts)
	}
}

401
func sequentialWalkDepth(ctx context.Context, getLinks GetLinks, root cid.Cid, depth int, visit func(cid.Cid, int) bool, options *walkOptions) error {
402
	if !(options.SkipRoot && depth == 0) {
403 404 405
		if !visit(root, depth) {
			return nil
		}
406 407
	}

408
	links, err := getLinks(ctx, root)
409 410 411 412
	if err != nil && options.ErrorHandler != nil {
		err = options.ErrorHandler(root, err)
	}
	if err != nil {
413 414
		return err
	}
415

416
	for _, lnk := range links {
417
		if err := sequentialWalkDepth(ctx, getLinks, lnk.Cid, depth+1, visit, options); err != nil {
418
			return err
419 420 421 422
		}
	}
	return nil
}
Jeromy's avatar
Jeromy committed
423

424
// ProgressTracker is used to show progress when fetching nodes.
425 426 427 428 429
type ProgressTracker struct {
	Total int
	lk    sync.Mutex
}

430 431
// DeriveContext returns a new context with value "progress" derived from
// the given one.
432
func (p *ProgressTracker) DeriveContext(ctx context.Context) context.Context {
433
	return context.WithValue(ctx, progressContextKey, p)
434 435
}

436
// Increment adds one to the total progress.
437 438 439 440 441 442
func (p *ProgressTracker) Increment() {
	p.lk.Lock()
	defer p.lk.Unlock()
	p.Total++
}

443
// Value returns the current progress.
444 445 446 447 448 449
func (p *ProgressTracker) Value() int {
	p.lk.Lock()
	defer p.lk.Unlock()
	return p.Total
}

450
func parallelWalkDepth(ctx context.Context, getLinks GetLinks, root cid.Cid, visit func(cid.Cid, int) bool, options *walkOptions) error {
451
	type cidDepth struct {
452
		cid   cid.Cid
453 454 455 456
		depth int
	}

	type linksDepth struct {
457
		links []*ld.Link
458 459 460
		depth int
	}

461 462
	feed := make(chan cidDepth)
	out := make(chan linksDepth)
463 464
	done := make(chan struct{})

465
	var visitlk sync.Mutex
466
	var wg sync.WaitGroup
467

468 469
	errChan := make(chan error)
	fetchersCtx, cancel := context.WithCancel(ctx)
470
	defer wg.Wait()
471
	defer cancel()
472
	for i := 0; i < options.Concurrency; i++ {
473
		wg.Add(1)
474
		go func() {
475
			defer wg.Done()
476 477 478 479
			for cdepth := range feed {
				ci := cdepth.cid
				depth := cdepth.depth

480 481 482
				var shouldVisit bool

				// bypass the root if needed
483
				if !(options.SkipRoot && depth == 0) {
484 485 486 487 488 489
					visitlk.Lock()
					shouldVisit = visit(ci, depth)
					visitlk.Unlock()
				} else {
					shouldVisit = true
				}
490

491
				if shouldVisit {
492
					links, err := getLinks(ctx, ci)
493 494 495 496
					if err != nil && options.ErrorHandler != nil {
						err = options.ErrorHandler(root, err)
					}
					if err != nil {
Steven Allen's avatar
Steven Allen committed
497 498 499 500
						select {
						case errChan <- err:
						case <-fetchersCtx.Done():
						}
501 502 503
						return
					}

504
					outLinks := linksDepth{
505 506 507 508
						links: links,
						depth: depth + 1,
					}

509
					select {
510
					case out <- outLinks:
511
					case <-fetchersCtx.Done():
512 513 514
						return
					}
				}
Jeromy's avatar
Jeromy committed
515
				select {
516
				case done <- struct{}{}:
517
				case <-fetchersCtx.Done():
Jeromy's avatar
Jeromy committed
518 519
				}
			}
520
		}()
Jeromy's avatar
Jeromy committed
521
	}
522
	defer close(feed)
Jeromy's avatar
Jeromy committed
523

524
	send := feed
525
	var todoQueue []cidDepth
526
	var inProgress int
Jeromy's avatar
Jeromy committed
527

528
	next := cidDepth{
529 530
		cid:   root,
		depth: 0,
531
	}
532

533 534 535 536
	for {
		select {
		case send <- next:
			inProgress++
537 538 539
			if len(todoQueue) > 0 {
				next = todoQueue[0]
				todoQueue = todoQueue[1:]
540
			} else {
541
				next = cidDepth{}
542 543 544 545
				send = nil
			}
		case <-done:
			inProgress--
546
			if inProgress == 0 && !next.cid.Defined() {
547 548
				return nil
			}
549 550
		case linksDepth := <-out:
			for _, lnk := range linksDepth.links {
551
				cd := cidDepth{
552 553 554 555
					cid:   lnk.Cid,
					depth: linksDepth.depth,
				}

556
				if !next.cid.Defined() {
557
					next = cd
558 559
					send = feed
				} else {
560
					todoQueue = append(todoQueue, cd)
561
				}
Jeromy's avatar
Jeromy committed
562
			}
563 564
		case err := <-errChan:
			return err
565

566
		case <-ctx.Done():
567
			return ctx.Err()
568
		}
Jeromy's avatar
Jeromy committed
569 570
	}
}
571

572 573 574 575
// Compile-time checks that the implementations satisfy the ld interfaces.
var (
	_ ld.LinkGetter = (*dagService)(nil)
	_ ld.NodeGetter = (*dagService)(nil)
	_ ld.NodeGetter = (*sesGetter)(nil)
	_ ld.DAGService = (*dagService)(nil)
)