Commit b00063cb authored by Jeromy's avatar Jeromy

refactor and clean up dagreader

parent 50433b18
...@@ -25,8 +25,8 @@ type DAGService interface { ...@@ -25,8 +25,8 @@ type DAGService interface {
// GetDAG returns, in order, all the single leve child // GetDAG returns, in order, all the single leve child
// nodes of the passed in node. // nodes of the passed in node.
GetDAG(context.Context, *Node) <-chan *Node GetDAG(context.Context, *Node) []NodeGetter
GetNodes(context.Context, []u.Key) <-chan *Node GetNodes(context.Context, []u.Key) []NodeGetter
} }
func NewDAGService(bs *bserv.BlockService) DAGService { func NewDAGService(bs *bserv.BlockService) DAGService {
...@@ -168,7 +168,7 @@ func FindLinks(links []u.Key, k u.Key, start int) []int { ...@@ -168,7 +168,7 @@ func FindLinks(links []u.Key, k u.Key, start int) []int {
// GetDAG will fill out all of the links of the given Node. // GetDAG will fill out all of the links of the given Node.
// It returns a channel of nodes, which the caller can receive // It returns a channel of nodes, which the caller can receive
// all the child nodes of 'root' on, in proper order. // all the child nodes of 'root' on, in proper order.
func (ds *dagService) GetDAG(ctx context.Context, root *Node) <-chan *Node { func (ds *dagService) GetDAG(ctx context.Context, root *Node) []NodeGetter {
var keys []u.Key var keys []u.Key
for _, lnk := range root.Links { for _, lnk := range root.Links {
keys = append(keys, u.Key(lnk.Hash)) keys = append(keys, u.Key(lnk.Hash))
...@@ -177,46 +177,69 @@ func (ds *dagService) GetDAG(ctx context.Context, root *Node) <-chan *Node { ...@@ -177,46 +177,69 @@ func (ds *dagService) GetDAG(ctx context.Context, root *Node) <-chan *Node {
return ds.GetNodes(ctx, keys) return ds.GetNodes(ctx, keys)
} }
func (ds *dagService) GetNodes(ctx context.Context, keys []u.Key) <-chan *Node { func (ds *dagService) GetNodes(ctx context.Context, keys []u.Key) []NodeGetter {
sig := make(chan *Node) promises := make([]NodeGetter, len(keys))
sendChans := make([]chan<- *Node, len(keys))
for i, _ := range keys {
promises[i], sendChans[i] = newNodePromise(ctx)
}
go func() { go func() {
defer close(sig)
blkchan := ds.Blocks.GetBlocks(ctx, keys) blkchan := ds.Blocks.GetBlocks(ctx, keys)
nodes := make([]*Node, len(keys))
next := 0
for { for {
select { select {
case blk, ok := <-blkchan: case blk, ok := <-blkchan:
if !ok { if !ok {
if next < len(nodes) {
log.Errorf("Did not receive correct number of nodes!")
}
return return
} }
nd, err := Decoded(blk.Data) nd, err := Decoded(blk.Data)
if err != nil { if err != nil {
// NB: can occur in normal situations, with improperly formatted // NB: can happen with improperly formatted input data
// input data
log.Error("Got back bad block!") log.Error("Got back bad block!")
break return
} }
is := FindLinks(keys, blk.Key(), next) is := FindLinks(keys, blk.Key(), 0)
for _, i := range is { for _, i := range is {
nodes[i] = nd sendChans[i] <- nd
} }
for ; next < len(nodes) && nodes[next] != nil; next++ {
select {
case sig <- nodes[next]:
case <-ctx.Done(): case <-ctx.Done():
return return
} }
} }
case <-ctx.Done(): }()
return return promises
}
func newNodePromise(ctx context.Context) (NodeGetter, chan<- *Node) {
ch := make(chan *Node, 1)
return &nodePromise{
recv: ch,
ctx: ctx,
}, ch
}
type nodePromise struct {
cache *Node
recv <-chan *Node
ctx context.Context
}
type NodeGetter interface {
Get() (*Node, error)
}
func (np *nodePromise) Get() (*Node, error) {
if np.cache != nil {
return np.cache, nil
} }
select {
case blk := <-np.recv:
np.cache = blk
case <-np.ctx.Done():
return nil, np.ctx.Err()
} }
}() return np.cache, nil
return sig
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment