Commit 56115933 authored by Michael Muré's avatar Michael Muré

pin: fix a too aggressive refactor and connect some contexts

parent ba91b68a
...@@ -281,7 +281,9 @@ func (p *pinner) isInternalPin(c cid.Cid) bool { ...@@ -281,7 +281,9 @@ func (p *pinner) isInternalPin(c cid.Cid) bool {
// IsPinned returns whether or not the given key is pinned // IsPinned returns whether or not the given key is pinned
// and an explanation of why its pinned // and an explanation of why its pinned
func (p *pinner) IsPinned(ctx context.Context, c cid.Cid) (string, bool, error) { func (p *pinner) IsPinned(ctx context.Context, c cid.Cid) (string, bool, error) {
return p.isPinnedWithType(c, Any) p.lock.RLock()
defer p.lock.RUnlock()
return p.isPinnedWithType(ctx, c, Any)
} }
// IsPinnedWithType returns whether or not the given cid is pinned with the // IsPinnedWithType returns whether or not the given cid is pinned with the
...@@ -289,12 +291,12 @@ func (p *pinner) IsPinned(ctx context.Context, c cid.Cid) (string, bool, error) ...@@ -289,12 +291,12 @@ func (p *pinner) IsPinned(ctx context.Context, c cid.Cid) (string, bool, error)
func (p *pinner) IsPinnedWithType(ctx context.Context, c cid.Cid, mode Mode) (string, bool, error) { func (p *pinner) IsPinnedWithType(ctx context.Context, c cid.Cid, mode Mode) (string, bool, error) {
p.lock.RLock() p.lock.RLock()
defer p.lock.RUnlock() defer p.lock.RUnlock()
return p.isPinnedWithType(c, mode) return p.isPinnedWithType(ctx, c, mode)
} }
// isPinnedWithType is the implementation of IsPinnedWithType that does not lock. // isPinnedWithType is the implementation of IsPinnedWithType that does not lock.
// intended for use by other pinned methods that already take locks // intended for use by other pinned methods that already take locks
func (p *pinner) isPinnedWithType(c cid.Cid, mode Mode) (string, bool, error) { func (p *pinner) isPinnedWithType(ctx context.Context, c cid.Cid, mode Mode) (string, bool, error) {
switch mode { switch mode {
case Any, Direct, Indirect, Recursive, Internal: case Any, Direct, Indirect, Recursive, Internal:
default: default:
...@@ -326,7 +328,7 @@ func (p *pinner) isPinnedWithType(c cid.Cid, mode Mode) (string, bool, error) { ...@@ -326,7 +328,7 @@ func (p *pinner) isPinnedWithType(c cid.Cid, mode Mode) (string, bool, error) {
// Default is Indirect // Default is Indirect
visitedSet := cid.NewSet() visitedSet := cid.NewSet()
for _, rc := range p.recursePin.Keys() { for _, rc := range p.recursePin.Keys() {
has, err := hasChild(p.dserv, rc, c, visitedSet.Visit) has, err := hasChild(ctx, p.dserv, rc, c, visitedSet.Visit)
if err != nil { if err != nil {
return "", false, err return "", false, err
} }
...@@ -361,7 +363,7 @@ func (p *pinner) CheckIfPinned(ctx context.Context, cids ...cid.Cid) ([]Pinned, ...@@ -361,7 +363,7 @@ func (p *pinner) CheckIfPinned(ctx context.Context, cids ...cid.Cid) ([]Pinned,
// Now walk all recursive pins to check for indirect pins // Now walk all recursive pins to check for indirect pins
var checkChildren func(cid.Cid, cid.Cid) error var checkChildren func(cid.Cid, cid.Cid) error
checkChildren = func(rk, parentKey cid.Cid) error { checkChildren = func(rk, parentKey cid.Cid) error {
links, err := ipld.GetLinks(context.TODO(), p.dserv, parentKey) links, err := ipld.GetLinks(ctx, p.dserv, parentKey)
if err != nil { if err != nil {
return err return err
} }
...@@ -607,8 +609,8 @@ func (p *pinner) PinWithMode(c cid.Cid, mode Mode) { ...@@ -607,8 +609,8 @@ func (p *pinner) PinWithMode(c cid.Cid, mode Mode) {
// hasChild recursively looks for a Cid among the children of a root Cid. // hasChild recursively looks for a Cid among the children of a root Cid.
// The visit function can be used to shortcut already-visited branches. // The visit function can be used to shortcut already-visited branches.
func hasChild(ng ipld.NodeGetter, root cid.Cid, child cid.Cid, visit func(cid.Cid) bool) (bool, error) { func hasChild(ctx context.Context, ng ipld.NodeGetter, root cid.Cid, child cid.Cid, visit func(cid.Cid) bool) (bool, error) {
links, err := ipld.GetLinks(context.TODO(), ng, root) links, err := ipld.GetLinks(ctx, ng, root)
if err != nil { if err != nil {
return false, err return false, err
} }
...@@ -618,7 +620,7 @@ func hasChild(ng ipld.NodeGetter, root cid.Cid, child cid.Cid, visit func(cid.Ci ...@@ -618,7 +620,7 @@ func hasChild(ng ipld.NodeGetter, root cid.Cid, child cid.Cid, visit func(cid.Ci
return true, nil return true, nil
} }
if visit(c) { if visit(c) {
has, err := hasChild(ng, c, child, visit) has, err := hasChild(ctx, ng, c, child, visit)
if err != nil { if err != nil {
return false, err return false, err
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment