Commit cdf4113c authored by Steven Allen's avatar Steven Allen

Optimize Deferred

We use this quite frequently so it should be fast.

Note: this removes the depth restriction because the algorithm is no longer recursive.
parent 3c783b99
......@@ -14,6 +14,7 @@ import (
)
const maxCidLength = 100
const maxHeaderSize = 9
func ScanForLinks(br io.Reader, cb func(cid.Cid)) error {
buf := make([]byte, maxCidLength)
......@@ -103,15 +104,7 @@ type CBORMarshaler interface {
}
type Deferred struct {
Raw []byte
nestedLevel int
}
func (d *Deferred) Child() Deferred {
return Deferred{
Raw: nil,
nestedLevel: d.nestedLevel + 1,
}
Raw []byte
}
func (d *Deferred) MarshalCBOR(w io.Writer) error {
......@@ -127,80 +120,66 @@ func (d *Deferred) MarshalCBOR(w io.Writer) error {
}
func (d *Deferred) UnmarshalCBOR(br io.Reader) error {
// TODO: theres a more efficient way to implement this method, but for now
// this is fine
maj, extra, err := CborReadHeader(br)
if err != nil {
return err
}
header := CborEncodeMajorType(maj, extra)
switch maj {
case MajTag, MajArray, MajMap:
if d.nestedLevel >= MaxLength {
return maxLengthError
}
}
switch maj {
case MajUnsignedInt, MajNegativeInt, MajOther:
d.Raw = header
return nil
case MajByteString, MajTextString:
if extra > ByteArrayMaxLen {
return maxLengthError
}
buf := make([]byte, int(extra)+len(header))
copy(buf, header)
if _, err := io.ReadFull(br, buf[len(header):]); err != nil {
// Reuse any existing buffers.
reusedBuf := d.Raw[:0]
d.Raw = nil
buf := bytes.NewBuffer(reusedBuf)
// Allocate some scratch space.
scratch := make([]byte, maxHeaderSize)
// Algorithm:
//
// 1. We start off expecting to read one element.
// 2. If we see a tag, we expect to read one more element so we increment "remaining".
// 3. If see an array, we expect to read "extra" elements so we add "extra" to "remaining".
// 4. If see a map, we expect to read "2*extra" elements so we add "2*extra" to "remaining".
// 5. While "remaining" is non-zero, read more elements.
// define this once so we don't keep allocating it.
limitedReader := io.LimitedReader{R: br}
for remaining := uint64(1); remaining > 0; remaining-- {
maj, extra, err := CborReadHeaderBuf(br, scratch)
if err != nil {
return err
}
d.Raw = buf
return nil
case MajTag:
sub := d.Child()
if err := sub.UnmarshalCBOR(br); err != nil {
if err := WriteMajorTypeHeaderBuf(scratch, buf, maj, extra); err != nil {
return err
}
d.Raw = append(header, sub.Raw...)
if len(d.Raw) > ByteArrayMaxLen {
return maxLengthError
}
return nil
case MajArray:
d.Raw = header
for i := 0; i < int(extra); i++ {
sub := d.Child()
if err := sub.UnmarshalCBOR(br); err != nil {
return err
}
d.Raw = append(d.Raw, sub.Raw...)
if len(d.Raw) > ByteArrayMaxLen {
switch maj {
case MajUnsignedInt, MajNegativeInt, MajOther:
// nothing fancy to do
case MajByteString, MajTextString:
if extra > ByteArrayMaxLen {
return maxLengthError
}
}
return nil
case MajMap:
d.Raw = header
sub := d.Child()
for i := 0; i < int(extra*2); i++ {
sub.Raw = sub.Raw[:0]
if err := sub.UnmarshalCBOR(br); err != nil {
// Copy the bytes
limitedReader.N = int64(extra)
buf.Grow(int(extra))
if n, err := buf.ReadFrom(&limitedReader); err != nil {
return err
} else if n < int64(extra) {
return io.ErrUnexpectedEOF
}
d.Raw = append(d.Raw, sub.Raw...)
if len(d.Raw) > ByteArrayMaxLen {
case MajTag:
remaining++
case MajArray:
if extra > MaxLength {
return maxLengthError
}
remaining += extra
case MajMap:
if extra > MaxLength {
return maxLengthError
}
remaining += extra * 2
default:
return fmt.Errorf("unhandled deferred cbor type: %d", maj)
}
return nil
default:
return fmt.Errorf("unhandled deferred cbor type: %d", maj)
}
d.Raw = buf.Bytes()
return nil
}
func readByte(r io.Reader) (byte, error) {
......
......@@ -36,18 +36,3 @@ func TestDeferredMaxLengthSingle(t *testing.T) {
t.Fatal("deferred: allowed more than the maximum allocation supported")
}
}
func TestDeferredMaxLengthRecursive(t *testing.T) {
var header bytes.Buffer
for i := 0; i < MaxLength+1; i++ {
if err := WriteMajorTypeHeader(&header, MajTag, 0); err != nil {
t.Fatal("failed to write header")
}
}
var deferred Deferred
err := deferred.UnmarshalCBOR(&header)
if err != maxLengthError {
t.Fatal("deferred: allowed more than the maximum number of elements")
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment