Commit 70e5c88f authored by Jeromy's avatar Jeromy Committed by Jeromy

don't fail promises that already succeeded

License: MIT
Signed-off-by: default avatarJeromy <jeromyj@gmail.com>
parent 1535a6a9
...@@ -176,9 +176,8 @@ func GetNodes(ctx context.Context, ds DAGService, keys []key.Key) []NodeGetter { ...@@ -176,9 +176,8 @@ func GetNodes(ctx context.Context, ds DAGService, keys []key.Key) []NodeGetter {
} }
promises := make([]NodeGetter, len(keys)) promises := make([]NodeGetter, len(keys))
sendChans := make([]chan<- *Node, len(keys))
for i := range keys { for i := range keys {
promises[i], sendChans[i] = newNodePromise(ctx) promises[i] = newNodePromise(ctx)
} }
dedupedKeys := dedupeKeys(keys) dedupedKeys := dedupeKeys(keys)
...@@ -199,7 +198,9 @@ func GetNodes(ctx context.Context, ds DAGService, keys []key.Key) []NodeGetter { ...@@ -199,7 +198,9 @@ func GetNodes(ctx context.Context, ds DAGService, keys []key.Key) []NodeGetter {
} }
if opt.Err != nil { if opt.Err != nil {
log.Error("error fetching: ", opt.Err) for _, p := range promises {
p.Fail(opt.Err)
}
return return
} }
...@@ -214,7 +215,7 @@ func GetNodes(ctx context.Context, ds DAGService, keys []key.Key) []NodeGetter { ...@@ -214,7 +215,7 @@ func GetNodes(ctx context.Context, ds DAGService, keys []key.Key) []NodeGetter {
is := FindLinks(keys, k, 0) is := FindLinks(keys, k, 0)
for _, i := range is { for _, i := range is {
count++ count++
sendChans[i] <- nd promises[i].Send(nd)
} }
case <-ctx.Done(): case <-ctx.Done():
return return
...@@ -237,18 +238,18 @@ func dedupeKeys(ks []key.Key) []key.Key { ...@@ -237,18 +238,18 @@ func dedupeKeys(ks []key.Key) []key.Key {
return out return out
} }
func newNodePromise(ctx context.Context) (NodeGetter, chan<- *Node) { func newNodePromise(ctx context.Context) NodeGetter {
ch := make(chan *Node, 1)
return &nodePromise{ return &nodePromise{
recv: ch, recv: make(chan *Node, 1),
ctx: ctx, ctx: ctx,
err: make(chan error, 1), err: make(chan error, 1),
}, ch }
} }
type nodePromise struct { type nodePromise struct {
cache *Node cache *Node
recv <-chan *Node clk sync.Mutex
recv chan *Node
ctx context.Context ctx context.Context
err chan error err chan error
} }
...@@ -260,20 +261,49 @@ type nodePromise struct { ...@@ -260,20 +261,49 @@ type nodePromise struct {
type NodeGetter interface { type NodeGetter interface {
Get(context.Context) (*Node, error) Get(context.Context) (*Node, error)
Fail(err error) Fail(err error)
Send(*Node)
} }
func (np *nodePromise) Fail(err error) { func (np *nodePromise) Fail(err error) {
np.clk.Lock()
v := np.cache
np.clk.Unlock()
// if promise has a value, don't fail it
if v != nil {
return
}
np.err <- err np.err <- err
} }
func (np *nodePromise) Get(ctx context.Context) (*Node, error) { func (np *nodePromise) Send(nd *Node) {
var already bool
np.clk.Lock()
if np.cache != nil { if np.cache != nil {
return np.cache, nil already = true
}
np.cache = nd
np.clk.Unlock()
if already {
panic("sending twice to the same promise is an error!")
}
np.recv <- nd
}
func (np *nodePromise) Get(ctx context.Context) (*Node, error) {
np.clk.Lock()
c := np.cache
np.clk.Unlock()
if c != nil {
return c, nil
} }
select { select {
case blk := <-np.recv: case nd := <-np.recv:
np.cache = blk return nd, nil
case <-np.ctx.Done(): case <-np.ctx.Done():
return nil, np.ctx.Err() return nil, np.ctx.Err()
case <-ctx.Done(): case <-ctx.Done():
...@@ -281,7 +311,6 @@ func (np *nodePromise) Get(ctx context.Context) (*Node, error) { ...@@ -281,7 +311,6 @@ func (np *nodePromise) Get(ctx context.Context) (*Node, error) {
case err := <-np.err: case err := <-np.err:
return nil, err return nil, err
} }
return np.cache, nil
} }
type Batch struct { type Batch struct {
......
...@@ -20,6 +20,7 @@ import ( ...@@ -20,6 +20,7 @@ import (
imp "github.com/ipfs/go-ipfs/importer" imp "github.com/ipfs/go-ipfs/importer"
chunk "github.com/ipfs/go-ipfs/importer/chunk" chunk "github.com/ipfs/go-ipfs/importer/chunk"
. "github.com/ipfs/go-ipfs/merkledag" . "github.com/ipfs/go-ipfs/merkledag"
dstest "github.com/ipfs/go-ipfs/merkledag/test"
"github.com/ipfs/go-ipfs/pin" "github.com/ipfs/go-ipfs/pin"
uio "github.com/ipfs/go-ipfs/unixfs/io" uio "github.com/ipfs/go-ipfs/unixfs/io"
u "gx/ipfs/QmZNVWh8LLjAavuQ2JXuFmuYH3C11xo988vSgp7UQrTRj1/go-ipfs-util" u "gx/ipfs/QmZNVWh8LLjAavuQ2JXuFmuYH3C11xo988vSgp7UQrTRj1/go-ipfs-util"
...@@ -323,3 +324,46 @@ func TestEnumerateChildren(t *testing.T) { ...@@ -323,3 +324,46 @@ func TestEnumerateChildren(t *testing.T) {
traverse(root) traverse(root)
} }
func TestFetchFailure(t *testing.T) {
ds := dstest.Mock()
ds_bad := dstest.Mock()
top := new(Node)
for i := 0; i < 10; i++ {
nd := &Node{Data: []byte{byte('a' + i)}}
_, err := ds.Add(nd)
if err != nil {
t.Fatal(err)
}
err = top.AddNodeLinkClean(fmt.Sprintf("AA%d", i), nd)
if err != nil {
t.Fatal(err)
}
}
for i := 0; i < 10; i++ {
nd := &Node{Data: []byte{'f', 'a' + byte(i)}}
_, err := ds_bad.Add(nd)
if err != nil {
t.Fatal(err)
}
err = top.AddNodeLinkClean(fmt.Sprintf("BB%d", i), nd)
if err != nil {
t.Fatal(err)
}
}
getters := GetDAG(context.Background(), ds, top)
for i, getter := range getters {
_, err := getter.Get(context.Background())
if err != nil && i < 10 {
t.Fatal(err)
}
if err == nil && i >= 10 {
t.Fatal("should have failed request")
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment