Commit cdf4113c authored by Steven Allen's avatar Steven Allen

Optimize Deferred

We use this quite frequently so it should be fast.

Note: this removes the depth restriction because the algorithm is no longer recursive.
parent 3c783b99
...@@ -14,6 +14,7 @@ import ( ...@@ -14,6 +14,7 @@ import (
) )
const maxCidLength = 100 const maxCidLength = 100
const maxHeaderSize = 9
func ScanForLinks(br io.Reader, cb func(cid.Cid)) error { func ScanForLinks(br io.Reader, cb func(cid.Cid)) error {
buf := make([]byte, maxCidLength) buf := make([]byte, maxCidLength)
...@@ -103,15 +104,7 @@ type CBORMarshaler interface { ...@@ -103,15 +104,7 @@ type CBORMarshaler interface {
} }
type Deferred struct { type Deferred struct {
Raw []byte Raw []byte
nestedLevel int
}
func (d *Deferred) Child() Deferred {
return Deferred{
Raw: nil,
nestedLevel: d.nestedLevel + 1,
}
} }
func (d *Deferred) MarshalCBOR(w io.Writer) error { func (d *Deferred) MarshalCBOR(w io.Writer) error {
...@@ -127,80 +120,66 @@ func (d *Deferred) MarshalCBOR(w io.Writer) error { ...@@ -127,80 +120,66 @@ func (d *Deferred) MarshalCBOR(w io.Writer) error {
} }
func (d *Deferred) UnmarshalCBOR(br io.Reader) error { func (d *Deferred) UnmarshalCBOR(br io.Reader) error {
// TODO: theres a more efficient way to implement this method, but for now // Reuse any existing buffers.
// this is fine reusedBuf := d.Raw[:0]
maj, extra, err := CborReadHeader(br) d.Raw = nil
if err != nil { buf := bytes.NewBuffer(reusedBuf)
return err
} // Allocate some scratch space.
header := CborEncodeMajorType(maj, extra) scratch := make([]byte, maxHeaderSize)
switch maj { // Algorithm:
case MajTag, MajArray, MajMap: //
if d.nestedLevel >= MaxLength { // 1. We start off expecting to read one element.
return maxLengthError // 2. If we see a tag, we expect to read one more element so we increment "remaining".
} // 3. If see an array, we expect to read "extra" elements so we add "extra" to "remaining".
} // 4. If see a map, we expect to read "2*extra" elements so we add "2*extra" to "remaining".
// 5. While "remaining" is non-zero, read more elements.
switch maj {
case MajUnsignedInt, MajNegativeInt, MajOther: // define this once so we don't keep allocating it.
d.Raw = header limitedReader := io.LimitedReader{R: br}
return nil for remaining := uint64(1); remaining > 0; remaining-- {
case MajByteString, MajTextString: maj, extra, err := CborReadHeaderBuf(br, scratch)
if extra > ByteArrayMaxLen { if err != nil {
return maxLengthError
}
buf := make([]byte, int(extra)+len(header))
copy(buf, header)
if _, err := io.ReadFull(br, buf[len(header):]); err != nil {
return err return err
} }
if err := WriteMajorTypeHeaderBuf(scratch, buf, maj, extra); err != nil {
d.Raw = buf
return nil
case MajTag:
sub := d.Child()
if err := sub.UnmarshalCBOR(br); err != nil {
return err return err
} }
d.Raw = append(header, sub.Raw...) switch maj {
if len(d.Raw) > ByteArrayMaxLen { case MajUnsignedInt, MajNegativeInt, MajOther:
return maxLengthError // nothing fancy to do
} case MajByteString, MajTextString:
return nil if extra > ByteArrayMaxLen {
case MajArray:
d.Raw = header
for i := 0; i < int(extra); i++ {
sub := d.Child()
if err := sub.UnmarshalCBOR(br); err != nil {
return err
}
d.Raw = append(d.Raw, sub.Raw...)
if len(d.Raw) > ByteArrayMaxLen {
return maxLengthError return maxLengthError
} }
} // Copy the bytes
return nil limitedReader.N = int64(extra)
case MajMap: buf.Grow(int(extra))
d.Raw = header if n, err := buf.ReadFrom(&limitedReader); err != nil {
sub := d.Child()
for i := 0; i < int(extra*2); i++ {
sub.Raw = sub.Raw[:0]
if err := sub.UnmarshalCBOR(br); err != nil {
return err return err
} else if n < int64(extra) {
return io.ErrUnexpectedEOF
} }
d.Raw = append(d.Raw, sub.Raw...) case MajTag:
if len(d.Raw) > ByteArrayMaxLen { remaining++
case MajArray:
if extra > MaxLength {
return maxLengthError return maxLengthError
} }
remaining += extra
case MajMap:
if extra > MaxLength {
return maxLengthError
}
remaining += extra * 2
default:
return fmt.Errorf("unhandled deferred cbor type: %d", maj)
} }
return nil
default:
return fmt.Errorf("unhandled deferred cbor type: %d", maj)
} }
d.Raw = buf.Bytes()
return nil
} }
func readByte(r io.Reader) (byte, error) { func readByte(r io.Reader) (byte, error) {
......
...@@ -36,18 +36,3 @@ func TestDeferredMaxLengthSingle(t *testing.T) { ...@@ -36,18 +36,3 @@ func TestDeferredMaxLengthSingle(t *testing.T) {
t.Fatal("deferred: allowed more than the maximum allocation supported") t.Fatal("deferred: allowed more than the maximum allocation supported")
} }
} }
func TestDeferredMaxLengthRecursive(t *testing.T) {
var header bytes.Buffer
for i := 0; i < MaxLength+1; i++ {
if err := WriteMajorTypeHeader(&header, MajTag, 0); err != nil {
t.Fatal("failed to write header")
}
}
var deferred Deferred
err := deferred.UnmarshalCBOR(&header)
if err != maxLengthError {
t.Fatal("deferred: allowed more than the maximum number of elements")
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment