Unverified Commit 5084f006 authored by Hannah Howard's avatar Hannah Howard Committed by GitHub

Merge pull request #48 from ipfs/feat/add-default-validation

Add a default validation policy
parents 5f97d3e8 6835fb7e
......@@ -8,6 +8,8 @@ import (
"testing"
"time"
ipldfree "github.com/ipld/go-ipld-prime/impl/free"
cidlink "github.com/ipld/go-ipld-prime/linking/cid"
blocks "github.com/ipfs/go-block-format"
......@@ -21,6 +23,7 @@ import (
"github.com/ipfs/go-graphsync/testutil"
ipld "github.com/ipld/go-ipld-prime"
ipldselector "github.com/ipld/go-ipld-prime/traversal/selector"
"github.com/ipld/go-ipld-prime/traversal/selector/builder"
"github.com/libp2p/go-libp2p-core/peer"
mocknet "github.com/libp2p/go-libp2p/p2p/net/mock"
mh "github.com/multiformats/go-multihash"
......@@ -165,16 +168,12 @@ func TestMakeRequestToNetwork(t *testing.T) {
blockChainLength := 100
blockChain := setupBlockChain(ctx, t, storer, bridge, 100, blockChainLength)
spec, err := bridge.BuildSelector(func(ssb ipldbridge.SelectorSpecBuilder) ipldbridge.SelectorSpec {
return ssb.ExploreRecursive(ipldselector.RecursionLimitDepth(blockChainLength),
ssb.ExploreFields(func(efsb ipldbridge.ExploreFieldsSpecBuilder) {
efsb.Insert("Parents", ssb.ExploreAll(
ssb.ExploreRecursiveEdge()))
}))
})
if err != nil {
t.Fatal("Failed creating selector")
}
ssb := builder.NewSelectorSpecBuilder(ipldfree.NodeBuilder())
spec := ssb.ExploreRecursive(ipldselector.RecursionLimitDepth(blockChainLength),
ssb.ExploreFields(func(efsb ipldbridge.ExploreFieldsSpecBuilder) {
efsb.Insert("Parents", ssb.ExploreAll(
ssb.ExploreRecursiveEdge()))
})).Node()
extensionData := testutil.RandomBytes(100)
extensionName := graphsync.ExtensionName("AppleSauce/McGee")
......@@ -262,16 +261,12 @@ func TestSendResponseToIncomingRequest(t *testing.T) {
blockChainLength := 100
blockChain := setupBlockChain(ctx, t, storer, bridge, 100, blockChainLength)
spec, err := bridge.BuildSelector(func(ssb ipldbridge.SelectorSpecBuilder) ipldbridge.SelectorSpec {
return ssb.ExploreRecursive(ipldselector.RecursionLimitDepth(blockChainLength),
ssb.ExploreFields(func(efsb ipldbridge.ExploreFieldsSpecBuilder) {
efsb.Insert("Parents", ssb.ExploreAll(
ssb.ExploreRecursiveEdge()))
}))
})
if err != nil {
t.Fatal("Failed creating selector")
}
ssb := builder.NewSelectorSpecBuilder(ipldfree.NodeBuilder())
spec := ssb.ExploreRecursive(ipldselector.RecursionLimitDepth(blockChainLength),
ssb.ExploreFields(func(efsb ipldbridge.ExploreFieldsSpecBuilder) {
efsb.Insert("Parents", ssb.ExploreAll(
ssb.ExploreRecursiveEdge()))
})).Node()
selectorData, err := bridge.EncodeNode(spec)
if err != nil {
......@@ -360,16 +355,12 @@ func TestGraphsyncRoundTrip(t *testing.T) {
// initialize graphsync on second node to response to requests
New(ctx, gsnet2, bridge2, loader2, storer2)
spec, err := bridge1.BuildSelector(func(ssb ipldbridge.SelectorSpecBuilder) ipldbridge.SelectorSpec {
return ssb.ExploreRecursive(ipldselector.RecursionLimitDepth(blockChainLength),
ssb.ExploreFields(func(efsb ipldbridge.ExploreFieldsSpecBuilder) {
efsb.Insert("Parents", ssb.ExploreAll(
ssb.ExploreRecursiveEdge()))
}))
})
if err != nil {
t.Fatal("Failed creating selector")
}
ssb := builder.NewSelectorSpecBuilder(ipldfree.NodeBuilder())
spec := ssb.ExploreRecursive(ipldselector.RecursionLimitDepth(blockChainLength),
ssb.ExploreFields(func(efsb ipldbridge.ExploreFieldsSpecBuilder) {
efsb.Insert("Parents", ssb.ExploreAll(
ssb.ExploreRecursiveEdge()))
})).Node()
progressChan, errChan := requestor.Request(ctx, host2.ID(), blockChain.tipLink, spec)
......@@ -458,16 +449,12 @@ func TestRoundTripLargeBlocksSlowNetwork(t *testing.T) {
// initialize graphsync on second node to response to requests
New(ctx, gsnet2, bridge2, loader2, storer2)
spec, err := bridge1.BuildSelector(func(ssb ipldbridge.SelectorSpecBuilder) ipldbridge.SelectorSpec {
return ssb.ExploreRecursive(ipldselector.RecursionLimitDepth(blockChainLength),
ssb.ExploreFields(func(efsb ipldbridge.ExploreFieldsSpecBuilder) {
efsb.Insert("Parents", ssb.ExploreAll(
ssb.ExploreRecursiveEdge()))
}))
})
if err != nil {
t.Fatal("Failed creating selector")
}
ssb := builder.NewSelectorSpecBuilder(ipldfree.NodeBuilder())
spec := ssb.ExploreRecursive(ipldselector.RecursionLimitDepth(blockChainLength),
ssb.ExploreFields(func(efsb ipldbridge.ExploreFieldsSpecBuilder) {
efsb.Insert("Parents", ssb.ExploreAll(
ssb.ExploreRecursiveEdge()))
})).Node()
progressChan, errChan := requestor.Request(ctx, host2.ID(), blockChain.tipLink, spec)
......
......@@ -11,7 +11,6 @@ import (
free "github.com/ipld/go-ipld-prime/impl/free"
ipldtraversal "github.com/ipld/go-ipld-prime/traversal"
ipldselector "github.com/ipld/go-ipld-prime/traversal/selector"
selectorbuilder "github.com/ipld/go-ipld-prime/traversal/selector/builder"
)
// TraversalConfig is an alias from ipld, in case it's renamed/moved.
......@@ -49,18 +48,6 @@ func (rb *ipldBridge) BuildNode(buildFn func(NodeBuilder) ipld.Node) (ipld.Node,
return node, nil
}
func (rb *ipldBridge) BuildSelector(buildFn func(SelectorSpecBuilder) SelectorSpec) (ipld.Node, error) {
var node ipld.Node
err := fluent.Recover(func() {
ssb := selectorbuilder.NewSelectorSpecBuilder(free.NodeBuilder())
node = buildFn(ssb).Node()
})
if err != nil {
return nil, err
}
return node, nil
}
func (rb *ipldBridge) Traverse(ctx context.Context, loader Loader, root ipld.Link, s Selector, fn AdvVisitFn) error {
node, err := root.Load(ctx, LinkContext{}, free.NodeBuilder(), loader)
if err != nil {
......@@ -74,6 +61,10 @@ func (rb *ipldBridge) Traverse(ctx context.Context, loader Loader, root ipld.Lin
}.WalkAdv(node, s, fn)
}
func (rb *ipldBridge) WalkMatching(node ipld.Node, s Selector, fn VisitFn) error {
return ipldtraversal.WalkMatching(node, s, fn)
}
func (rb *ipldBridge) EncodeNode(node ipld.Node) ([]byte, error) {
var buffer bytes.Buffer
err := dagcbor.Encoder(node, &buffer)
......
......@@ -28,6 +28,9 @@ type Storer = ipld.Storer
// StoreCommitter is an alias from ipld, in case it's renamed/moved.
type StoreCommitter = ipld.StoreCommitter
// VisitFn is an alias from ipld, in case it's renamed/moved
type VisitFn = ipldtraversal.VisitFn
// AdvVisitFn is an alias from ipld, in case it's renamed/moved.
type AdvVisitFn = ipldtraversal.AdvVisitFn
......@@ -77,10 +80,6 @@ type IPLDBridge interface {
// interface
BuildNode(func(NodeBuilder) ipld.Node) (ipld.Node, error)
// BuildSelector provides a mechanism to build selector nodes quickly with
// ipld's SelectorSpecBuilder
BuildSelector(func(SelectorSpecBuilder) SelectorSpec) (ipld.Node, error)
// EncodeNode encodes an IPLD Node to bytes for network transfer.
EncodeNode(ipld.Node) ([]byte, error)
......@@ -95,4 +94,7 @@ type IPLDBridge interface {
// and the given link loader. The given visit function will be called for each node
// visited.
Traverse(ctx context.Context, loader Loader, root ipld.Link, s Selector, fn AdvVisitFn) error
// WalkMatching is a wrapper around direct selector traversal
WalkMatching(node ipld.Node, s Selector, fn VisitFn) error
}
......@@ -4,12 +4,12 @@ import (
"context"
"time"
cid "github.com/ipfs/go-cid"
"github.com/ipfs/go-graphsync"
"github.com/ipfs/go-graphsync/ipldbridge"
gsmsg "github.com/ipfs/go-graphsync/message"
"github.com/ipfs/go-graphsync/responsemanager/loader"
"github.com/ipfs/go-graphsync/responsemanager/peerresponsemanager"
"github.com/ipfs/go-graphsync/responsemanager/selectorvalidator"
"github.com/ipfs/go-peertaskqueue/peertask"
ipld "github.com/ipld/go-ipld-prime"
cidlink "github.com/ipld/go-ipld-prime/linking/cid"
......@@ -18,14 +18,14 @@ import (
const (
maxInProcessRequests = 6
maxRecursionDepth = 100
thawSpeed = time.Millisecond * 100
)
type inProgressResponseStatus struct {
ctx context.Context
cancelFn func()
root cid.Cid
selector []byte
request gsmsg.GraphSyncRequest
}
type responseKey struct {
......@@ -34,9 +34,8 @@ type responseKey struct {
}
type responseTaskData struct {
ctx context.Context
root cid.Cid
selector []byte
ctx context.Context
request gsmsg.GraphSyncRequest
}
// QueryQueue is an interface that can receive new selector query tasks
......@@ -163,7 +162,7 @@ func (rm *ResponseManager) processQueriesWorker() {
case <-rm.ctx.Done():
return
}
rm.executeQuery(taskData.ctx, key.p, key.requestID, taskData.root, taskData.selector)
rm.executeQuery(taskData.ctx, key.p, taskData.request)
select {
case rm.messages <- &finishResponseRequest{key}:
case <-rm.ctx.Done():
......@@ -181,28 +180,31 @@ func noopVisitor(tp ipldbridge.TraversalProgress, n ipld.Node, tr ipldbridge.Tra
func (rm *ResponseManager) executeQuery(ctx context.Context,
p peer.ID,
requestID graphsync.RequestID,
root cid.Cid,
selectorBytes []byte) {
request gsmsg.GraphSyncRequest) {
peerResponseSender := rm.peerManager.SenderForPeer(p)
selectorSpec, err := rm.ipldBridge.DecodeNode(selectorBytes)
selectorSpec, err := rm.ipldBridge.DecodeNode(request.Selector())
if err != nil {
peerResponseSender.FinishWithError(request.ID(), graphsync.RequestFailedUnknown)
return
}
err = selectorvalidator.ValidateSelector(rm.ipldBridge, selectorSpec, maxRecursionDepth)
if err != nil {
peerResponseSender.FinishWithError(requestID, graphsync.RequestFailedUnknown)
peerResponseSender.FinishWithError(request.ID(), graphsync.RequestFailedUnknown)
return
}
rootLink := cidlink.Link{Cid: root}
rootLink := cidlink.Link{Cid: request.Root()}
selector, err := rm.ipldBridge.ParseSelector(selectorSpec)
if err != nil {
peerResponseSender.FinishWithError(requestID, graphsync.RequestFailedUnknown)
peerResponseSender.FinishWithError(request.ID(), graphsync.RequestFailedUnknown)
return
}
wrappedLoader := loader.WrapLoader(rm.loader, requestID, peerResponseSender)
wrappedLoader := loader.WrapLoader(rm.loader, request.ID(), peerResponseSender)
err = rm.ipldBridge.Traverse(ctx, wrappedLoader, rootLink, selector, noopVisitor)
if err != nil {
peerResponseSender.FinishWithError(requestID, graphsync.RequestFailedUnknown)
peerResponseSender.FinishWithError(request.ID(), graphsync.RequestFailedUnknown)
return
}
peerResponseSender.FinishRequest(requestID)
peerResponseSender.FinishRequest(request.ID())
}
// Startup starts processing for the WantManager.
......@@ -246,8 +248,7 @@ func (prm *processRequestMessage) handle(rm *ResponseManager) {
inProgressResponseStatus{
ctx: ctx,
cancelFn: cancelFn,
root: request.Root(),
selector: request.Selector(),
request: request,
}
rm.queryQueue.PushBlock(prm.p, peertask.Task{Identifier: key, Priority: int(request.Priority())})
select {
......@@ -268,7 +269,7 @@ func (rdr *responseDataRequest) handle(rm *ResponseManager) {
response, ok := rm.inProgressResponses[rdr.key]
var taskData *responseTaskData
if ok {
taskData = &responseTaskData{response.ctx, response.root, response.selector}
taskData = &responseTaskData{response.ctx, response.request}
} else {
taskData = nil
}
......
package selectorvalidator
import (
"errors"
"github.com/ipfs/go-graphsync/ipldbridge"
ipld "github.com/ipld/go-ipld-prime"
ipldfree "github.com/ipld/go-ipld-prime/impl/free"
"github.com/ipld/go-ipld-prime/traversal"
"github.com/ipld/go-ipld-prime/traversal/selector"
"github.com/ipld/go-ipld-prime/traversal/selector/builder"
)
var (
// ErrInvalidLimit means this type of recursive selector limit is not supported by default
// -- to prevent DDOS attacks
ErrInvalidLimit = errors.New("unsupported recursive selector limit")
)
// ValidateSelector applies the default selector validation policy to a selector
// on an incoming request -- which by default is to limit recursive selectors
// to a fixed depth
func ValidateSelector(bridge ipldbridge.IPLDBridge, node ipld.Node, maxAcceptedDepth int) error {
ssb := builder.NewSelectorSpecBuilder(ipldfree.NodeBuilder())
// this selector is a selector for traversing selectors...
// it traverses the various selector types looking for recursion limit fields
// and matches them
s, err := ssb.ExploreRecursive(selector.RecursionLimitNone(), ssb.ExploreFields(func(efsb builder.ExploreFieldsSpecBuilder) {
efsb.Insert(selector.SelectorKey_ExploreRecursive, ssb.ExploreFields(func(efsb builder.ExploreFieldsSpecBuilder) {
efsb.Insert(selector.SelectorKey_Limit, ssb.Matcher())
efsb.Insert(selector.SelectorKey_Sequence, ssb.ExploreRecursiveEdge())
}))
efsb.Insert(selector.SelectorKey_ExploreFields, ssb.ExploreFields(func(efsb builder.ExploreFieldsSpecBuilder) {
efsb.Insert(selector.SelectorKey_Fields, ssb.ExploreAll(ssb.ExploreRecursiveEdge()))
}))
efsb.Insert(selector.SelectorKey_ExploreUnion, ssb.ExploreAll(ssb.ExploreRecursiveEdge()))
efsb.Insert(selector.SelectorKey_ExploreAll, ssb.ExploreFields(func(efsb builder.ExploreFieldsSpecBuilder) {
efsb.Insert(selector.SelectorKey_Next, ssb.ExploreRecursiveEdge())
}))
efsb.Insert(selector.SelectorKey_ExploreIndex, ssb.ExploreFields(func(efsb builder.ExploreFieldsSpecBuilder) {
efsb.Insert(selector.SelectorKey_Next, ssb.ExploreRecursiveEdge())
}))
efsb.Insert(selector.SelectorKey_ExploreRange, ssb.ExploreFields(func(efsb builder.ExploreFieldsSpecBuilder) {
efsb.Insert(selector.SelectorKey_Next, ssb.ExploreRecursiveEdge())
}))
efsb.Insert(selector.SelectorKey_ExploreConditional, ssb.ExploreFields(func(efsb builder.ExploreFieldsSpecBuilder) {
efsb.Insert(selector.SelectorKey_Next, ssb.ExploreRecursiveEdge())
}))
})).Selector()
if err != nil {
return err
}
return bridge.WalkMatching(node, s, func(progress traversal.Progress, visited ipld.Node) error {
if visited.ReprKind() != ipld.ReprKind_Map || visited.Length() != 1 {
return ErrInvalidLimit
}
kn, v, _ := visited.MapIterator().Next()
kstr, _ := kn.AsString()
switch kstr {
case selector.SelectorKey_LimitDepth:
maxDepthValue, err := v.AsInt()
if err != nil {
return ErrInvalidLimit
}
if maxDepthValue > maxAcceptedDepth {
return ErrInvalidLimit
}
return nil
case selector.SelectorKey_LimitNone:
return ErrInvalidLimit
default:
return ErrInvalidLimit
}
})
}
package selectorvalidator
import (
"testing"
ipld "github.com/ipld/go-ipld-prime"
ipldfree "github.com/ipld/go-ipld-prime/impl/free"
"github.com/ipfs/go-graphsync/ipldbridge"
"github.com/ipld/go-ipld-prime/traversal/selector"
"github.com/ipld/go-ipld-prime/traversal/selector/builder"
)
func TestValidateSelector(t *testing.T) {
bridge := ipldbridge.NewIPLDBridge()
ssb := builder.NewSelectorSpecBuilder(ipldfree.NodeBuilder())
successBase := ssb.ExploreRecursive(selector.RecursionLimitDepth(80), ssb.ExploreRecursiveEdge())
failBase := ssb.ExploreRecursive(selector.RecursionLimitDepth(120), ssb.ExploreRecursiveEdge())
failNoneBase := ssb.ExploreRecursive(selector.RecursionLimitNone(), ssb.ExploreRecursiveEdge())
verifyOutcomes := func(t *testing.T, success ipld.Node, fail ipld.Node, failNone ipld.Node) {
err := ValidateSelector(bridge, success, 100)
if err != nil {
t.Fatal("valid selector returned error")
}
err = ValidateSelector(bridge, fail, 100)
if err != ErrInvalidLimit {
t.Fatal("selector should have failed on invalid limit")
}
err = ValidateSelector(bridge, failNone, 100)
if err != ErrInvalidLimit {
t.Fatal("selector should have failed on invalid limit")
}
}
t.Run("ExploreRecursive", func(t *testing.T) {
success := successBase.Node()
fail := failBase.Node()
failNone := failNoneBase.Node()
verifyOutcomes(t, success, fail, failNone)
})
t.Run("ExploreAll", func(t *testing.T) {
success := ssb.ExploreAll(successBase).Node()
fail := ssb.ExploreAll(failBase).Node()
failNone := ssb.ExploreAll(failNoneBase).Node()
verifyOutcomes(t, success, fail, failNone)
})
t.Run("ExploreIndex", func(t *testing.T) {
success := ssb.ExploreIndex(0, successBase).Node()
fail := ssb.ExploreIndex(0, failBase).Node()
failNone := ssb.ExploreIndex(0, failNoneBase).Node()
verifyOutcomes(t, success, fail, failNone)
})
t.Run("ExploreRange", func(t *testing.T) {
success := ssb.ExploreRange(0, 10, successBase).Node()
fail := ssb.ExploreRange(0, 10, failBase).Node()
failNone := ssb.ExploreRange(0, 10, failNoneBase).Node()
verifyOutcomes(t, success, fail, failNone)
})
t.Run("ExploreUnion", func(t *testing.T) {
success := ssb.ExploreUnion(successBase, successBase).Node()
fail := ssb.ExploreUnion(successBase, failBase).Node()
failNone := ssb.ExploreUnion(successBase, failNoneBase).Node()
verifyOutcomes(t, success, fail, failNone)
})
t.Run("ExploreFields", func(t *testing.T) {
success := ssb.ExploreFields(func(efsb builder.ExploreFieldsSpecBuilder) {
efsb.Insert("apples", successBase)
efsb.Insert("oranges", successBase)
}).Node()
fail := ssb.ExploreFields(func(efsb builder.ExploreFieldsSpecBuilder) {
efsb.Insert("apples", successBase)
efsb.Insert("oranges", failBase)
}).Node()
failNone := ssb.ExploreFields(func(efsb builder.ExploreFieldsSpecBuilder) {
efsb.Insert("apples", successBase)
efsb.Insert("oranges", failNoneBase)
}).Node()
verifyOutcomes(t, success, fail, failNone)
})
t.Run("nested ExploreRecursive", func(t *testing.T) {
success := ssb.ExploreRecursive(
selector.RecursionLimitDepth(10),
ssb.ExploreUnion(
ssb.ExploreAll(ssb.ExploreRecursiveEdge()),
ssb.ExploreIndex(0, successBase),
),
).Node()
fail := ssb.ExploreRecursive(
selector.RecursionLimitDepth(10),
ssb.ExploreUnion(
ssb.ExploreAll(ssb.ExploreRecursiveEdge()),
ssb.ExploreIndex(0, failBase),
),
).Node()
failNone := ssb.ExploreRecursive(
selector.RecursionLimitDepth(10),
ssb.ExploreUnion(
ssb.ExploreAll(ssb.ExploreRecursiveEdge()),
ssb.ExploreIndex(0, failNoneBase),
),
).Node()
verifyOutcomes(t, success, fail, failNone)
})
}
......@@ -14,7 +14,6 @@ import (
"github.com/ipld/go-ipld-prime/fluent"
free "github.com/ipld/go-ipld-prime/impl/free"
cidlink "github.com/ipld/go-ipld-prime/linking/cid"
selectorbuilder "github.com/ipld/go-ipld-prime/traversal/selector/builder"
multihash "github.com/multiformats/go-multihash"
)
......@@ -52,18 +51,6 @@ func (mb *mockIPLDBridge) BuildNode(buildFn func(ipldbridge.NodeBuilder) ipld.No
return node, nil
}
func (mb *mockIPLDBridge) BuildSelector(buildFn func(ipldbridge.SelectorSpecBuilder) ipldbridge.SelectorSpec) (ipld.Node, error) {
var node ipld.Node
err := fluent.Recover(func() {
ssb := selectorbuilder.NewSelectorSpecBuilder(free.NodeBuilder())
node = buildFn(ssb).Node()
})
if err != nil {
return nil, err
}
return node, nil
}
func (mb *mockIPLDBridge) EncodeNode(node ipld.Node) ([]byte, error) {
spec, ok := node.(*mockSelectorSpec)
if ok {
......@@ -125,6 +112,10 @@ func (mb *mockIPLDBridge) Traverse(ctx context.Context, loader ipldbridge.Loader
return nil
}
func (mb *mockIPLDBridge) WalkMatching(node ipld.Node, s ipldbridge.Selector, fn ipldbridge.VisitFn) error {
return nil
}
func loadNode(lnk cid.Cid, loader ipldbridge.Loader) (ipld.Node, error) {
r, err := loader(cidlink.Link{Cid: lnk}, ipldbridge.LinkContext{})
if err != nil {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment