diff --git a/traversal/selector/exploreRecursive.go b/traversal/selector/exploreRecursive.go index edd9645b98f497333d6a72e353eef67856567d96..be61b529b86630339b2ccb9f55ee81c7ea83676a 100644 --- a/traversal/selector/exploreRecursive.go +++ b/traversal/selector/exploreRecursive.go @@ -45,17 +45,58 @@ func (s ExploreRecursive) Interests() []PathSegment { // Explore returns the node's selector for all fields func (s ExploreRecursive) Explore(n ipld.Node, p PathSegment) Selector { nextSelector := s.current.Explore(n, p) + maxDepth := s.maxDepth if nextSelector == nil { return nil } - _, ok := nextSelector.(ExploreRecursiveEdge) - if !ok { - return ExploreRecursive{s.sequence, nextSelector, s.maxDepth} + if !s.hasRecursiveEdge(nextSelector) { + return ExploreRecursive{s.sequence, nextSelector, maxDepth} } - if s.maxDepth < 2 { - return nil + if maxDepth < 2 { + return s.replaceRecursiveEdge(nextSelector, nil) + } + return ExploreRecursive{s.sequence, s.replaceRecursiveEdge(nextSelector, s.sequence), s.maxDepth - 1} +} + +func (s ExploreRecursive) hasRecursiveEdge(nextSelector Selector) bool { + _, isRecursiveEdge := nextSelector.(ExploreRecursiveEdge) + if isRecursiveEdge { + return true + } + exploreUnion, isUnion := nextSelector.(ExploreUnion) + if isUnion { + for _, selector := range exploreUnion.Members { + if s.hasRecursiveEdge(selector) { + return true + } + } + } + return false +} + +func (s ExploreRecursive) replaceRecursiveEdge(nextSelector Selector, replacement Selector) Selector { + _, isRecursiveEdge := nextSelector.(ExploreRecursiveEdge) + if isRecursiveEdge { + return replacement + } + exploreUnion, isUnion := nextSelector.(ExploreUnion) + if isUnion { + replacementMembers := make([]Selector, 0, len(exploreUnion.Members)) + for _, selector := range exploreUnion.Members { + newSelector := s.replaceRecursiveEdge(selector, replacement) + if newSelector != nil { + replacementMembers = append(replacementMembers, newSelector) + } + } + if len(replacementMembers) == 0 { + return nil + } + if len(replacementMembers) == 1 { + return replacementMembers[0] + } + return ExploreUnion{replacementMembers} } - return ExploreRecursive{s.sequence, s.sequence, s.maxDepth - 1} + return nextSelector } // Decide always returns false because this is not a matcher diff --git a/traversal/selector/exploreRecursive_test.go b/traversal/selector/exploreRecursive_test.go index 04f3686ecaae174e1afdd68ba1c9b8201cde0e55..330682268125fc0d017c193cf40286ca78126c39 100644 --- a/traversal/selector/exploreRecursive_test.go +++ b/traversal/selector/exploreRecursive_test.go @@ -277,4 +277,44 @@ func TestExploreRecursiveExplore(t *testing.T) { Wish(t, rs, ShouldEqual, ExploreRecursive{subTree, ExploreRecursive{sideSelector, sideSelector, maxDepth - 2}, maxDepth - 1}) Wish(t, err, ShouldEqual, nil) }) + t.Run("exploring should work with explore union and recursion", func(t *testing.T) { + parentsSelector := ExploreUnion{[]Selector{ExploreAll{Matcher{}}, ExploreIndex{recursiveEdge, [1]PathSegment{PathSegmentInt{0}}}}} + subTree := ExploreFields{map[string]Selector{"Parents": parentsSelector}, []PathSegment{PathSegmentString{S: "Parents"}}} + rs = ExploreRecursive{subTree, subTree, maxDepth} + nodeString := `{ + "Parents": [ + { + "Parents": [ + { + "Parents": [ + { + "Parents": [] + } + ] + } + ] + } + ] + } + ` + rn, err := dagjson.Decoder(ipldfree.NodeBuilder(), bytes.NewBufferString(nodeString)) + Wish(t, err, ShouldEqual, nil) + rs = rs.Explore(rn, PathSegmentString{S: "Parents"}) + rn, err = rn.TraverseField("Parents") + Wish(t, rs, ShouldEqual, ExploreRecursive{subTree, parentsSelector, maxDepth}) + Wish(t, err, ShouldEqual, nil) + rs = rs.Explore(rn, PathSegmentInt{I: 0}) + rn, err = rn.TraverseIndex(0) + Wish(t, rs, ShouldEqual, ExploreRecursive{subTree, ExploreUnion{[]Selector{Matcher{}, subTree}}, maxDepth - 1}) + Wish(t, err, ShouldEqual, nil) + rs = rs.Explore(rn, PathSegmentString{S: "Parents"}) + + rn, err = rn.TraverseField("Parents") + Wish(t, rs, ShouldEqual, ExploreRecursive{subTree, parentsSelector, maxDepth - 1}) + Wish(t, err, ShouldEqual, nil) + rs = rs.Explore(rn, PathSegmentInt{I: 0}) + rn, err = rn.TraverseIndex(0) + Wish(t, rs, ShouldEqual, ExploreRecursive{subTree, ExploreUnion{[]Selector{Matcher{}, subTree}}, maxDepth - 2}) + Wish(t, err, ShouldEqual, nil) + }) }