Commit cdf4113c authored by Steven Allen's avatar Steven Allen

Optimize Deferred

We use this quite frequently so it should be fast.

Note: this removes the depth restriction because the algorithm is no longer recursive.
parent 3c783b99
...@@ -14,6 +14,7 @@ import ( ...@@ -14,6 +14,7 @@ import (
) )
const maxCidLength = 100 const maxCidLength = 100
const maxHeaderSize = 9
func ScanForLinks(br io.Reader, cb func(cid.Cid)) error { func ScanForLinks(br io.Reader, cb func(cid.Cid)) error {
buf := make([]byte, maxCidLength) buf := make([]byte, maxCidLength)
...@@ -104,14 +105,6 @@ type CBORMarshaler interface { ...@@ -104,14 +105,6 @@ type CBORMarshaler interface {
type Deferred struct { type Deferred struct {
Raw []byte Raw []byte
nestedLevel int
}
func (d *Deferred) Child() Deferred {
return Deferred{
Raw: nil,
nestedLevel: d.nestedLevel + 1,
}
} }
func (d *Deferred) MarshalCBOR(w io.Writer) error { func (d *Deferred) MarshalCBOR(w io.Writer) error {
...@@ -127,80 +120,66 @@ func (d *Deferred) MarshalCBOR(w io.Writer) error { ...@@ -127,80 +120,66 @@ func (d *Deferred) MarshalCBOR(w io.Writer) error {
} }
func (d *Deferred) UnmarshalCBOR(br io.Reader) error { func (d *Deferred) UnmarshalCBOR(br io.Reader) error {
// TODO: theres a more efficient way to implement this method, but for now // Reuse any existing buffers.
// this is fine reusedBuf := d.Raw[:0]
maj, extra, err := CborReadHeader(br) d.Raw = nil
buf := bytes.NewBuffer(reusedBuf)
// Allocate some scratch space.
scratch := make([]byte, maxHeaderSize)
// Algorithm:
//
// 1. We start off expecting to read one element.
// 2. If we see a tag, we expect to read one more element so we increment "remaining".
// 3. If see an array, we expect to read "extra" elements so we add "extra" to "remaining".
// 4. If see a map, we expect to read "2*extra" elements so we add "2*extra" to "remaining".
// 5. While "remaining" is non-zero, read more elements.
// define this once so we don't keep allocating it.
limitedReader := io.LimitedReader{R: br}
for remaining := uint64(1); remaining > 0; remaining-- {
maj, extra, err := CborReadHeaderBuf(br, scratch)
if err != nil { if err != nil {
return err return err
} }
header := CborEncodeMajorType(maj, extra) if err := WriteMajorTypeHeaderBuf(scratch, buf, maj, extra); err != nil {
return err
switch maj {
case MajTag, MajArray, MajMap:
if d.nestedLevel >= MaxLength {
return maxLengthError
}
} }
switch maj { switch maj {
case MajUnsignedInt, MajNegativeInt, MajOther: case MajUnsignedInt, MajNegativeInt, MajOther:
d.Raw = header // nothing fancy to do
return nil
case MajByteString, MajTextString: case MajByteString, MajTextString:
if extra > ByteArrayMaxLen { if extra > ByteArrayMaxLen {
return maxLengthError return maxLengthError
} }
buf := make([]byte, int(extra)+len(header)) // Copy the bytes
copy(buf, header) limitedReader.N = int64(extra)
if _, err := io.ReadFull(br, buf[len(header):]); err != nil { buf.Grow(int(extra))
if n, err := buf.ReadFrom(&limitedReader); err != nil {
return err return err
} else if n < int64(extra) {
return io.ErrUnexpectedEOF
} }
d.Raw = buf
return nil
case MajTag: case MajTag:
sub := d.Child() remaining++
if err := sub.UnmarshalCBOR(br); err != nil {
return err
}
d.Raw = append(header, sub.Raw...)
if len(d.Raw) > ByteArrayMaxLen {
return maxLengthError
}
return nil
case MajArray: case MajArray:
d.Raw = header if extra > MaxLength {
for i := 0; i < int(extra); i++ {
sub := d.Child()
if err := sub.UnmarshalCBOR(br); err != nil {
return err
}
d.Raw = append(d.Raw, sub.Raw...)
if len(d.Raw) > ByteArrayMaxLen {
return maxLengthError return maxLengthError
} }
} remaining += extra
return nil
case MajMap: case MajMap:
d.Raw = header if extra > MaxLength {
sub := d.Child()
for i := 0; i < int(extra*2); i++ {
sub.Raw = sub.Raw[:0]
if err := sub.UnmarshalCBOR(br); err != nil {
return err
}
d.Raw = append(d.Raw, sub.Raw...)
if len(d.Raw) > ByteArrayMaxLen {
return maxLengthError return maxLengthError
} }
} remaining += extra * 2
return nil
default: default:
return fmt.Errorf("unhandled deferred cbor type: %d", maj) return fmt.Errorf("unhandled deferred cbor type: %d", maj)
} }
}
d.Raw = buf.Bytes()
return nil
} }
func readByte(r io.Reader) (byte, error) { func readByte(r io.Reader) (byte, error) {
......
...@@ -36,18 +36,3 @@ func TestDeferredMaxLengthSingle(t *testing.T) { ...@@ -36,18 +36,3 @@ func TestDeferredMaxLengthSingle(t *testing.T) {
t.Fatal("deferred: allowed more than the maximum allocation supported") t.Fatal("deferred: allowed more than the maximum allocation supported")
} }
} }
func TestDeferredMaxLengthRecursive(t *testing.T) {
var header bytes.Buffer
for i := 0; i < MaxLength+1; i++ {
if err := WriteMajorTypeHeader(&header, MajTag, 0); err != nil {
t.Fatal("failed to write header")
}
}
var deferred Deferred
err := deferred.UnmarshalCBOR(&header)
if err != maxLengthError {
t.Fatal("deferred: allowed more than the maximum number of elements")
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment