From cdf4113c4e00da989a7b6ab39e49aa2d73649d83 Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Tue, 4 Aug 2020 20:57:19 -0700 Subject: [PATCH] Optimize Deferred We use this quite frequently so it should be fast. Note: this removes the depth restriction because the algorithm is no longer recursive. --- utils.go | 119 +++++++++++++++++++++----------------------------- utils_test.go | 15 ------- 2 files changed, 49 insertions(+), 85 deletions(-) diff --git a/utils.go b/utils.go index 4e3aacf..100a7b9 100644 --- a/utils.go +++ b/utils.go @@ -14,6 +14,7 @@ import ( ) const maxCidLength = 100 +const maxHeaderSize = 9 func ScanForLinks(br io.Reader, cb func(cid.Cid)) error { buf := make([]byte, maxCidLength) @@ -103,15 +104,7 @@ type CBORMarshaler interface { } type Deferred struct { - Raw []byte - nestedLevel int -} - -func (d *Deferred) Child() Deferred { - return Deferred{ - Raw: nil, - nestedLevel: d.nestedLevel + 1, - } + Raw []byte } func (d *Deferred) MarshalCBOR(w io.Writer) error { @@ -127,80 +120,66 @@ func (d *Deferred) MarshalCBOR(w io.Writer) error { } func (d *Deferred) UnmarshalCBOR(br io.Reader) error { - // TODO: theres a more efficient way to implement this method, but for now - // this is fine - maj, extra, err := CborReadHeader(br) - if err != nil { - return err - } - header := CborEncodeMajorType(maj, extra) - - switch maj { - case MajTag, MajArray, MajMap: - if d.nestedLevel >= MaxLength { - return maxLengthError - } - } - - switch maj { - case MajUnsignedInt, MajNegativeInt, MajOther: - d.Raw = header - return nil - case MajByteString, MajTextString: - if extra > ByteArrayMaxLen { - return maxLengthError - } - buf := make([]byte, int(extra)+len(header)) - copy(buf, header) - if _, err := io.ReadFull(br, buf[len(header):]); err != nil { + // Reuse any existing buffers. + reusedBuf := d.Raw[:0] + d.Raw = nil + buf := bytes.NewBuffer(reusedBuf) + + // Allocate some scratch space. + scratch := make([]byte, maxHeaderSize) + + // Algorithm: + // + // 1. 
We start off expecting to read one element. + // 2. If we see a tag, we expect to read one more element so we increment "remaining". + // 3. If we see an array, we expect to read "extra" elements so we add "extra" to "remaining". + // 4. If we see a map, we expect to read "2*extra" elements so we add "2*extra" to "remaining". + // 5. While "remaining" is non-zero, read more elements. + + // define this once so we don't keep allocating it. + limitedReader := io.LimitedReader{R: br} + for remaining := uint64(1); remaining > 0; remaining-- { + maj, extra, err := CborReadHeaderBuf(br, scratch) + if err != nil { return err } - - d.Raw = buf - - return nil - case MajTag: - sub := d.Child() - if err := sub.UnmarshalCBOR(br); err != nil { + if err := WriteMajorTypeHeaderBuf(scratch, buf, maj, extra); err != nil { return err } - d.Raw = append(header, sub.Raw...) - if len(d.Raw) > ByteArrayMaxLen { - return maxLengthError - } - return nil - case MajArray: - d.Raw = header - for i := 0; i < int(extra); i++ { - sub := d.Child() - if err := sub.UnmarshalCBOR(br); err != nil { - return err - } - - d.Raw = append(d.Raw, sub.Raw...) - if len(d.Raw) > ByteArrayMaxLen { + switch maj { + case MajUnsignedInt, MajNegativeInt, MajOther: + // nothing fancy to do + case MajByteString, MajTextString: + if extra > ByteArrayMaxLen { return maxLengthError } - } - return nil - case MajMap: - d.Raw = header - sub := d.Child() - for i := 0; i < int(extra*2); i++ { - sub.Raw = sub.Raw[:0] - if err := sub.UnmarshalCBOR(br); err != nil { + // Copy the bytes + limitedReader.N = int64(extra) + buf.Grow(int(extra)) + if n, err := buf.ReadFrom(&limitedReader); err != nil { return err + } else if n < int64(extra) { + return io.ErrUnexpectedEOF } - d.Raw = append(d.Raw, sub.Raw...) 
- if len(d.Raw) > ByteArrayMaxLen { + case MajTag: + remaining++ + case MajArray: + if extra > MaxLength { return maxLengthError } + remaining += extra + case MajMap: + if extra > MaxLength { + return maxLengthError + } + remaining += extra * 2 + default: + return fmt.Errorf("unhandled deferred cbor type: %d", maj) } - return nil - default: - return fmt.Errorf("unhandled deferred cbor type: %d", maj) } + d.Raw = buf.Bytes() + return nil } func readByte(r io.Reader) (byte, error) { diff --git a/utils_test.go b/utils_test.go index daf8ca1..f220cfd 100644 --- a/utils_test.go +++ b/utils_test.go @@ -36,18 +36,3 @@ func TestDeferredMaxLengthSingle(t *testing.T) { t.Fatal("deferred: allowed more than the maximum allocation supported") } } - -func TestDeferredMaxLengthRecursive(t *testing.T) { - var header bytes.Buffer - for i := 0; i < MaxLength+1; i++ { - if err := WriteMajorTypeHeader(&header, MajTag, 0); err != nil { - t.Fatal("failed to write header") - } - } - - var deferred Deferred - err := deferred.UnmarshalCBOR(&header) - if err != maxLengthError { - t.Fatal("deferred: allowed more than the maximum number of elements") - } -} -- GitLab