diff --git a/gen.go b/gen.go index 67f31bdbc324ca4c6bf803bfbfb52165b571f936..c572c0d5e420a7b24db35d452314448ef07560eb 100644 --- a/gen.go +++ b/gen.go @@ -589,16 +589,14 @@ func emitCborUnmarshalStructField(w io.Writer, f Field) error { return doTemplate(w, f, ` { {{ if .Pointer }} - pb, err := br.PeekByte() + b, err := br.ReadByte() if err != nil { return err } - if pb == cbg.CborNull[0] { - var nbuf [1]byte - if _, err := br.Read(nbuf[:]); err != nil { + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { return err } - } else { {{ end }} c, err := cbg.ReadCid(br) if err != nil { @@ -628,16 +626,14 @@ func emitCborUnmarshalStructField(w io.Writer, f Field) error { return doTemplate(w, f, ` { {{ if .Pointer }} - pb, err := br.PeekByte() + b, err := br.ReadByte() if err != nil { return err } - if pb == cbg.CborNull[0] { - var nbuf [1]byte - if _, err := br.Read(nbuf[:]); err != nil { + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { return err } - } else { {{ .Name }} = new({{ .TypeName }}) if err := {{ .Name }}.UnmarshalCBOR(br); err != nil { return xerrors.Errorf("unmarshaling {{ .Name }} pointer: %w", err) @@ -685,16 +681,14 @@ func emitCborUnmarshalUint64Field(w io.Writer, f Field) error { return doTemplate(w, f, ` { {{ if .Pointer }} - pb, err := br.PeekByte() + b, err := br.ReadByte() if err != nil { return err } - if pb == cbg.CborNull[0] { - var nbuf [1]byte - if _, err := br.Read(nbuf[:]); err != nil { + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { return err } - } else { maj, extra, err = {{ ReadHeader "br" }} if err != nil { return err diff --git a/peeker.go b/peeker.go new file mode 100644 index 0000000000000000000000000000000000000000..2b0658cc923dd58231169c32009f303fbb4155a4 --- /dev/null +++ b/peeker.go @@ -0,0 +1,80 @@ +package typegen + +import ( + "bufio" + "io" +) + +// BytePeeker combines the Reader and ByteScanner interfaces. +type BytePeeker interface { + io.Reader + io.ByteScanner +} + +func GetPeeker(r io.Reader) BytePeeker { + if r, ok := r.(BytePeeker); ok { + return r + } + return &peeker{reader: r} +} + +// peeker is a non-buffering BytePeeker. +type peeker struct { + reader io.Reader + peekState int + lastByte byte +} + +const ( + peekEmpty = iota + peekSet + peekUnread +) + +func (p *peeker) Read(buf []byte) (n int, err error) { + // Read "nothing". I.e., read an error, maybe. + if len(buf) == 0 { + // There's something pending in the + if p.peekState == peekUnread { + return 0, nil + } + return p.reader.Read(nil) + } + + if p.peekState == peekUnread { + buf[0] = p.lastByte + n, err = p.reader.Read(buf[1:]) + n += 1 + } else { + n, err = p.reader.Read(buf) + } + if n > 0 { + p.peekState = peekSet + p.lastByte = buf[n-1] + } + return n, err +} + +func (p *peeker) ReadByte() (byte, error) { + if p.peekState == peekUnread { + p.peekState = peekSet + return p.lastByte, nil + } + var buf [1]byte + n, err := p.reader.Read(buf[:]) + if n == 0 { + return 0, err + } + b := buf[0] + p.lastByte = b + p.peekState = peekSet + return b, err +} + +func (p *peeker) UnreadByte() error { + if p.peekState != peekSet { + return bufio.ErrInvalidUnreadByte + } + p.peekState = peekUnread + return nil +} diff --git a/peeker_test.go b/peeker_test.go new file mode 100644 index 0000000000000000000000000000000000000000..17df76305b65c8a5d3c3b248866b3cd20124ff0d --- /dev/null +++ b/peeker_test.go @@ -0,0 +1,103 @@ +package typegen + +import ( + "bufio" + "bytes" + "io" + "testing" +) + +func TestPeeker(t *testing.T) { + buf := bytes.NewBuffer([]byte{0, 1, 2, 3}) + p := peeker{reader: buf} + n, err := p.Read(nil) + if err != nil { + t.Fatal(err) + } + if n != 0 { + t.Fatal(err) + } + + err = p.UnreadByte() + if err != bufio.ErrInvalidUnreadByte { + t.Fatal(err) + } + + // read 2 bytes + var out [2]byte + n, err = p.Read(out[:]) + if err != nil { + t.Fatal(err) + } + if n != 2 { + t.Fatalf("expected 2 bytes, got %d", n) + } + if !bytes.Equal(out[:], []byte{0, 1}) { + t.Fatalf("unexpected output") + } + + // unread that last byte and read it again. + err = p.UnreadByte() + if err != nil { + t.Fatal(err) + } + b, err := p.ReadByte() + if err != nil { + t.Fatal(err) + } + if b != 1 { + t.Fatal("expected 1") + } + + // unread that last byte then read 2 + err = p.UnreadByte() + if err != nil { + t.Fatal(err) + } + n, err = p.Read(out[:]) + if err != nil { + t.Fatal(err) + } + if n != 2 { + t.Fatalf("expected 2 bytes, got %d", n) + } + if !bytes.Equal(out[:], []byte{1, 2}) { + t.Fatalf("unexpected output") + } + + // read another byte + b, err = p.ReadByte() + if err != nil { + t.Fatal(err) + } + if b != 3 { + t.Fatal("expected 1") + } + + // Should read eof at end. + n, err = p.Read(out[:]) + if err != io.EOF { + t.Fatal(err) + } + if n != 0 { + t.Fatal("should have been at end") + } + // should unread eof + err = p.UnreadByte() + if err != nil { + t.Fatal(err) + } + + _, err = p.Read(nil) + if err != nil { + t.Fatal(err) + } + + b, err = p.ReadByte() + if err != nil { + t.Fatal(err) + } + if b != 3 { + t.Fatal("expected 1") + } +} diff --git a/testing/bench_test.go b/testing/bench_test.go index 840c2cd08e8209b025bb50a1fd9b31afdfaff483..88216ed210c6b355b45a8752b4df50d212693252 100644 --- a/testing/bench_test.go +++ b/testing/bench_test.go @@ -2,7 +2,6 @@ package testing import ( "bytes" - "encoding/hex" "io" "io/ioutil" "math/rand" @@ -34,21 +33,28 @@ func BenchmarkMarshaling(b *testing.B) { } func BenchmarkUnmarshaling(b *testing.B) { - hx := "8989f68080807859f099a586f0908093f1af9fb6f3a0ad82e8aaa0efbfbdf1b88688f29d8aaeecacabf0a4be94f19295a0f19b9081f0b6bf8ff3ad83a6f09e9ca7f2be8a8bf187a8a0f280a8b3f4899a9bf181afb1f0bca2b0f1b5ab9bf2b0a3ac80f6f68384787af2b1a082f3b99d98f1adb9b6f3b9868df29fbbb0f3858791f3b5b39df2b68e92f2a9bb9af282b4b2f18fba9cf294b8b2f3a0a1a9f0aab1aaf28cb994f09796aef195bc90f488be81e0a59af2928183f1a0a4abf393bbbae39d8df28fb287f0bf8fa6f0b79a89f188babcf395b0b1f29ebab7f2b0a091f29db48f1b527ef13ee4f5321a403b2130370299eeb937847848f287b8b9f1ad9e90f1b1b9bbf18ebc91f3908583f0be9ab3f2aca8abf0a8acadf380a7abf293a8aaf1b2b6a6f3b89587f3809fadf3a39f97f3a8b48cf3b299bff19cab9df28399a01b374797708d2015d3401b6bfb7066c509754c8478a1f0ab8f96f287988df297aea7f3afa699f3859788f2a2b2b8f2b681a6f29a95a4f382978cf396b183e2acbdf39cbdb5f0b99b94f1a2baaaf1ba89b0f3a8a7bbf397bdabf3af8c83f1b38ebef0beb1a0f3939f83f0b9ad90f1acb597f0b49eb0f29ab3a3f480808ef39b878ae5989ff0a7b789f48981b6f281aba6f2a9ad88f09fb395f0aa95adf0a1a59ff38a8d97f397b7b0eebfa9f2a5ab87f2afa7b8f0b992b81bab566703ac0b139c401b463b0320db277de1841b73dd7cd1861ff4561bc0256739761d28dd1b39c9019ac37c08721b2f08fbf368bf7f94813b075c40eb7f66e0488078b4f288898bf486bc90e9b8bdf180ab8ff39db1a7f0b1afb8f3ab9fa6f1b19182f189bdaff3bf80a5f1a18fb4f39c99a9f0ba839af2adb88fe4bd9df39bba8bf28ebf9ef2b3a783f0b6b395f197be84f3a1998af1b0898bf3b0b08ef1b49b94f094b59df19dbfa6f2aa8494f48ba0b2e28181f1a08999f2b3ac81eaa689f1bb80bbf2ae918bf0a19397f1a19d9cf3b095b5f1b4baa2f0b7ad92f3ab8c8ef38fab92f489b499f18d9899f0b5bcb5f2a3b6a5f2a1acb1831be5bdbd1384238b4b1b8a95991fbf9ca8d11baf61be2ac6477c7d1ba1d9dac0cecd182d1b4175138c8c7fbb4e8384785ef0a7adaef39b8a8ff1b79bacc693f1948ebef0938383f48aa6abf09ab684f1ba8c89f188b091f18ab2b5f1ac8484f2b7b089f18b97bdf1838aacf397ad98f0b9a8aff394a2a3f39eb6bff09ab8bef39189bef18f89aaf3aca982f29381901b7f0bf8763d569f3b403b75c40e5c6163108084786af1879fa8f2b4af9ef3a8b3b6f3b0be93f0aba9a8f0a1a698f3b6a7a7e6adbef1a8849bf28087a3f3b89f82f38caab6f0b7b09ff1bf938ff0a0b1aff2b79691f0a5b29bf4858896f484a5abf393bbbbf3a2b8bdf29393a6eba180f1a1b3b0f29da098f1b09ca7f3bda2901b6f088c64a0854512401b564b5898ca46ac958467f29c9188eea5941bb7b58825b1edf1ee403b6dcbd95c52f6ca33" + r := rand.New(rand.NewSource(123456)) + val, ok := quick.Value(reflect.TypeOf(SimpleTypeTwo{}), r) + if !ok { + b.Fatal("failed to construct type") + } - d, err := hex.DecodeString(hx) - if err != nil { + tt := val.Interface().(SimpleTypeTwo) + + buf := new(bytes.Buffer) + if err := tt.MarshalCBOR(buf); err != nil { b.Fatal(err) } - buf := bytes.NewReader(d) + reader := bytes.NewReader(buf.Bytes()) + b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - buf.Seek(0, io.SeekStart) + reader.Seek(0, io.SeekStart) var tt SimpleTypeTwo - if err := tt.UnmarshalCBOR(buf); err != nil { + if err := tt.UnmarshalCBOR(reader); err != nil { b.Fatal(err) } } @@ -69,11 +75,44 @@ func BenchmarkLinkScan(b *testing.B) { b.Fatal(err) } + reader := bytes.NewReader(buf.Bytes()) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + reader.Seek(0, io.SeekStart) + if err := cbg.ScanForLinks(reader, func(cid.Cid) {}); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkDeferred(b *testing.B) { + r := rand.New(rand.NewSource(123456)) + val, ok := quick.Value(reflect.TypeOf(SimpleTypeTwo{}), r) + if !ok { + b.Fatal("failed to construct type") + } + + tt := val.Interface().(SimpleTypeTwo) + + buf := new(bytes.Buffer) + if err := tt.MarshalCBOR(buf); err != nil { + b.Fatal(err) + } + + var ( + deferred cbg.Deferred + reader = bytes.NewReader(buf.Bytes()) + ) + b.ReportAllocs() + b.ResetTimer() for i := 0; i < b.N; i++ { - if err := cbg.ScanForLinks(bytes.NewReader(buf.Bytes()), func(cid.Cid) {}); err != nil { + reader.Seek(0, io.SeekStart) + if err := deferred.UnmarshalCBOR(reader); err != nil { b.Fatal(err) } } diff --git a/testing/cbor_gen.go b/testing/cbor_gen.go index f9d365579c714ab7c2d75044777f325238475ea1..186f9a2d9cc4aa7338dcfe3cf1905b42a66319bd 100644 --- a/testing/cbor_gen.go +++ b/testing/cbor_gen.go @@ -427,16 +427,14 @@ func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error { { - pb, err := br.PeekByte() + b, err := br.ReadByte() if err != nil { return err } - if pb == cbg.CborNull[0] { - var nbuf [1]byte - if _, err := br.Read(nbuf[:]); err != nil { + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { return err } - } else { t.Stuff = new(SimpleTypeTwo) if err := t.Stuff.UnmarshalCBOR(br); err != nil { return xerrors.Errorf("unmarshaling t.Stuff pointer: %w", err) @@ -617,16 +615,14 @@ func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error { { - pb, err := br.PeekByte() + b, err := br.ReadByte() if err != nil { return err } - if pb == cbg.CborNull[0] { - var nbuf [1]byte - if _, err := br.Read(nbuf[:]); err != nil { + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { return err } - } else { maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err @@ -643,16 +639,14 @@ func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error { { - pb, err := br.PeekByte() + b, err := br.ReadByte() if err != nil { return err } - if pb == cbg.CborNull[0] { - var nbuf [1]byte - if _, err := br.Read(nbuf[:]); err != nil { + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { return err } - } else { maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err @@ -753,16 +747,14 @@ func (t *DeferredContainer) UnmarshalCBOR(r io.Reader) error { { - pb, err := br.PeekByte() + b, err := br.ReadByte() if err != nil { return err } - if pb == cbg.CborNull[0] { - var nbuf [1]byte - if _, err := br.Read(nbuf[:]); err != nil { + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { return err } - } else { t.Stuff = new(SimpleTypeOne) if err := t.Stuff.UnmarshalCBOR(br); err != nil { return xerrors.Errorf("unmarshaling t.Stuff pointer: %w", err) diff --git a/testing/cbor_map_gen.go b/testing/cbor_map_gen.go index a29385e3570fd427050da6a4f5a6157933c163f5..1016c5de9e123c69e778af702ef3c4f50147b5a9 100644 --- a/testing/cbor_map_gen.go +++ b/testing/cbor_map_gen.go @@ -221,16 +221,14 @@ func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error { { - pb, err := br.PeekByte() + b, err := br.ReadByte() if err != nil { return err } - if pb == cbg.CborNull[0] { - var nbuf [1]byte - if _, err := br.Read(nbuf[:]); err != nil { + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { return err } - } else { t.Stuff = new(SimpleTypeTree) if err := t.Stuff.UnmarshalCBOR(br); err != nil { return xerrors.Errorf("unmarshaling t.Stuff pointer: %w", err) @@ -243,16 +241,14 @@ func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error { { - pb, err := br.PeekByte() + b, err := br.ReadByte() if err != nil { return err } - if pb == cbg.CborNull[0] { - var nbuf [1]byte - if _, err := br.Read(nbuf[:]); err != nil { + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { return err } - } else { t.Stufff = new(SimpleTypeTwo) if err := t.Stufff.UnmarshalCBOR(br); err != nil { return xerrors.Errorf("unmarshaling t.Stufff pointer: %w", err) @@ -384,16 +380,14 @@ func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error { { - pb, err := br.PeekByte() + b, err := br.ReadByte() if err != nil { return err } - if pb == cbg.CborNull[0] { - var nbuf [1]byte - if _, err := br.Read(nbuf[:]); err != nil { + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { return err } - } else { maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err diff --git a/utils.go b/utils.go index 858b253f495c32bbb7b9300a504bba4077f06394..230771606ca4bede9a8e6c602e177195cf3b387f 100644 --- a/utils.go +++ b/utils.go @@ -14,69 +14,58 @@ import ( ) const maxCidLength = 100 +const maxHeaderSize = 9 func ScanForLinks(br io.Reader, cb func(cid.Cid)) error { - buf := make([]byte, maxCidLength) - return scanForLinksRec(br, cb, buf) -} - -func scanForLinksRec(br io.Reader, cb func(cid.Cid), scratch []byte) error { - maj, extra, err := CborReadHeaderBuf(br, scratch) - if err != nil { - return err - } - - switch maj { - case MajUnsignedInt, MajNegativeInt, MajOther: - case MajByteString, MajTextString: - _, err := io.CopyN(ioutil.Discard, br, int64(extra)) + scratch := make([]byte, maxCidLength) + for remaining := uint64(1); remaining > 0; remaining-- { + maj, extra, err := CborReadHeaderBuf(br, scratch) if err != nil { return err } - case MajTag: - if extra == 42 { - maj, extra, err = CborReadHeaderBuf(br, scratch) - if err != nil { - return err - } - if maj != MajByteString { - return fmt.Errorf("expected cbor type 'byte string' in input") - } - - if extra > maxCidLength { - return fmt.Errorf("string in cbor input too long") - } - - if _, err := io.ReadAtLeast(br, scratch[:extra], int(extra)); err != nil { - return err - } - - c, err := cid.Cast(scratch[1:extra]) + switch maj { + case MajUnsignedInt, MajNegativeInt, MajOther: + case MajByteString, MajTextString: + _, err := io.CopyN(ioutil.Discard, br, int64(extra)) if err != nil { return err } - cb(c) - - } else { - if err := scanForLinksRec(br, cb, scratch); err != nil { - return err + case MajTag: + if extra == 42 { + maj, extra, err = CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + + if maj != MajByteString { + return fmt.Errorf("expected cbor type 'byte string' in input") + } + + if extra > maxCidLength { + return fmt.Errorf("string in cbor input too long") + } + + if _, err := io.ReadAtLeast(br, scratch[:extra], int(extra)); err != nil { + return err + } + + c, err := cid.Cast(scratch[1:extra]) + if err != nil { + return err + } + cb(c) + + } else { + remaining++ } + case MajArray: + remaining += extra + case MajMap: + remaining += (extra * 2) + default: + return fmt.Errorf("unhandled cbor type: %d", maj) } - case MajArray: - for i := 0; i < int(extra); i++ { - if err := scanForLinksRec(br, cb, scratch); err != nil { - return err - } - } - case MajMap: - for i := 0; i < int(extra*2); i++ { - if err := scanForLinksRec(br, cb, scratch); err != nil { - return err - } - } - default: - return fmt.Errorf("unhandled cbor type: %d", maj) } return nil } @@ -103,15 +92,7 @@ type CBORMarshaler interface { } type Deferred struct { - Raw []byte - nestedLevel int -} - -func (d *Deferred) Child() Deferred { - return Deferred{ - Raw: nil, - nestedLevel: d.nestedLevel + 1, - } + Raw []byte } func (d *Deferred) MarshalCBOR(w io.Writer) error { @@ -127,136 +108,82 @@ func (d *Deferred) MarshalCBOR(w io.Writer) error { } func (d *Deferred) UnmarshalCBOR(br io.Reader) error { - // TODO: theres a more efficient way to implement this method, but for now - // this is fine - maj, extra, err := CborReadHeader(br) - if err != nil { - return err - } - header := CborEncodeMajorType(maj, extra) - - switch maj { - case MajTag, MajArray, MajMap: - if d.nestedLevel >= MaxLength { - return maxLengthError - } - } - - switch maj { - case MajUnsignedInt, MajNegativeInt, MajOther: - d.Raw = header - return nil - case MajByteString, MajTextString: - if extra > ByteArrayMaxLen { - return maxLengthError - } - buf := make([]byte, int(extra)+len(header)) - copy(buf, header) - if _, err := io.ReadFull(br, buf[len(header):]); err != nil { + // Reuse any existing buffers. + reusedBuf := d.Raw[:0] + d.Raw = nil + buf := bytes.NewBuffer(reusedBuf) + + // Allocate some scratch space. + scratch := make([]byte, maxHeaderSize) + + // Algorithm: + // + // 1. We start off expecting to read one element. + // 2. If we see a tag, we expect to read one more element so we increment "remaining". + // 3. If see an array, we expect to read "extra" elements so we add "extra" to "remaining". + // 4. If see a map, we expect to read "2*extra" elements so we add "2*extra" to "remaining". + // 5. While "remaining" is non-zero, read more elements. + + // define this once so we don't keep allocating it. + limitedReader := io.LimitedReader{R: br} + for remaining := uint64(1); remaining > 0; remaining-- { + maj, extra, err := CborReadHeaderBuf(br, scratch) + if err != nil { return err } - - d.Raw = buf - - return nil - case MajTag: - sub := d.Child() - if err := sub.UnmarshalCBOR(br); err != nil { + if err := WriteMajorTypeHeaderBuf(scratch, buf, maj, extra); err != nil { return err } - d.Raw = append(header, sub.Raw...) - if len(d.Raw) > ByteArrayMaxLen { - return maxLengthError - } - return nil - case MajArray: - d.Raw = header - for i := 0; i < int(extra); i++ { - sub := d.Child() - if err := sub.UnmarshalCBOR(br); err != nil { - return err - } - - d.Raw = append(d.Raw, sub.Raw...) - if len(d.Raw) > ByteArrayMaxLen { + switch maj { + case MajUnsignedInt, MajNegativeInt, MajOther: + // nothing fancy to do + case MajByteString, MajTextString: + if extra > ByteArrayMaxLen { return maxLengthError } - } - return nil - case MajMap: - d.Raw = header - sub := d.Child() - for i := 0; i < int(extra*2); i++ { - sub.Raw = sub.Raw[:0] - if err := sub.UnmarshalCBOR(br); err != nil { + // Copy the bytes + limitedReader.N = int64(extra) + buf.Grow(int(extra)) + if n, err := buf.ReadFrom(&limitedReader); err != nil { return err + } else if n < int64(extra) { + return io.ErrUnexpectedEOF } - d.Raw = append(d.Raw, sub.Raw...) - if len(d.Raw) > ByteArrayMaxLen { + case MajTag: + remaining++ + case MajArray: + if extra > MaxLength { return maxLengthError } + remaining += extra + case MajMap: + if extra > MaxLength { + return maxLengthError + } + remaining += extra * 2 + default: + return fmt.Errorf("unhandled deferred cbor type: %d", maj) } - return nil - default: - return fmt.Errorf("unhandled deferred cbor type: %d", maj) - } -} - -// this is a bit gnarly i should just switch to taking in a byte array at the top level -type BytePeeker interface { - io.Reader - PeekByte() (byte, error) -} - -type peeker struct { - io.Reader -} - -func (p *peeker) PeekByte() (byte, error) { - switch r := p.Reader.(type) { - case *bytes.Reader: - b, err := r.ReadByte() - if err != nil { - return 0, err - } - return b, r.UnreadByte() - case *bytes.Buffer: - b, err := r.ReadByte() - if err != nil { - return 0, err - } - return b, r.UnreadByte() - case *bufio.Reader: - o, err := r.Peek(1) - if err != nil { - return 0, err - } - - return o[0], nil - default: - panic("invariant violated") } + d.Raw = buf.Bytes() + return nil } -func GetPeeker(r io.Reader) BytePeeker { +func readByte(r io.Reader) (byte, error) { + // try to cast to a concrete type, it's much faster than casting to an + // interface. switch r := r.(type) { - case *bytes.Reader: - return &peeker{r} case *bytes.Buffer: - return &peeker{r} + return r.ReadByte() + case *bytes.Reader: + return r.ReadByte() case *bufio.Reader: - return &peeker{r} + return r.ReadByte() case *peeker: - return r - default: - return &peeker{bufio.NewReaderSize(r, 16)} - } -} - -func readByte(r io.Reader) (byte, error) { - if br, ok := r.(io.ByteReader); ok { - return br.ReadByte() + return r.ReadByte() + case io.ByteReader: + return r.ReadByte() } var buf [1]byte _, err := io.ReadFull(r, buf[:1]) diff --git a/utils_test.go b/utils_test.go index daf8ca10508a6e753d42f8c9faddfa27ac90d764..f220cfd194e68d9045cdefe01dc37a48760f73c1 100644 --- a/utils_test.go +++ b/utils_test.go @@ -36,18 +36,3 @@ func TestDeferredMaxLengthSingle(t *testing.T) { t.Fatal("deferred: allowed more than the maximum allocation supported") } } - -func TestDeferredMaxLengthRecursive(t *testing.T) { - var header bytes.Buffer - for i := 0; i < MaxLength+1; i++ { - if err := WriteMajorTypeHeader(&header, MajTag, 0); err != nil { - t.Fatal("failed to write header") - } - } - - var deferred Deferred - err := deferred.UnmarshalCBOR(&header) - if err != maxLengthError { - t.Fatal("deferred: allowed more than the maximum number of elements") - } -}