Commit 021bda15 authored by Steven Allen's avatar Steven Allen

Switch to non-buffering byte reader

Buffering could lead to reading over the end of the object, corrupting the next object.

This patch also gets rid of "PeekByte" and uses the standard ReadByte/UnreadByte
interfaces. That way, we can avoid wrapping the byte reader in the happy path,
saving some overhead.
parent 6a3894a6
...@@ -589,16 +589,14 @@ func emitCborUnmarshalStructField(w io.Writer, f Field) error { ...@@ -589,16 +589,14 @@ func emitCborUnmarshalStructField(w io.Writer, f Field) error {
return doTemplate(w, f, ` return doTemplate(w, f, `
{ {
{{ if .Pointer }} {{ if .Pointer }}
pb, err := br.PeekByte() b, err := br.ReadByte()
if err != nil { if err != nil {
return err return err
} }
if pb == cbg.CborNull[0] { if b != cbg.CborNull[0] {
var nbuf [1]byte if err := br.UnreadByte(); err != nil {
if _, err := br.Read(nbuf[:]); err != nil {
return err return err
} }
} else {
{{ end }} {{ end }}
c, err := cbg.ReadCid(br) c, err := cbg.ReadCid(br)
if err != nil { if err != nil {
...@@ -628,16 +626,14 @@ func emitCborUnmarshalStructField(w io.Writer, f Field) error { ...@@ -628,16 +626,14 @@ func emitCborUnmarshalStructField(w io.Writer, f Field) error {
return doTemplate(w, f, ` return doTemplate(w, f, `
{ {
{{ if .Pointer }} {{ if .Pointer }}
pb, err := br.PeekByte() b, err := br.ReadByte()
if err != nil { if err != nil {
return err return err
} }
if pb == cbg.CborNull[0] { if b != cbg.CborNull[0] {
var nbuf [1]byte if err := br.UnreadByte(); err != nil {
if _, err := br.Read(nbuf[:]); err != nil {
return err return err
} }
} else {
{{ .Name }} = new({{ .TypeName }}) {{ .Name }} = new({{ .TypeName }})
if err := {{ .Name }}.UnmarshalCBOR(br); err != nil { if err := {{ .Name }}.UnmarshalCBOR(br); err != nil {
return xerrors.Errorf("unmarshaling {{ .Name }} pointer: %w", err) return xerrors.Errorf("unmarshaling {{ .Name }} pointer: %w", err)
...@@ -685,16 +681,14 @@ func emitCborUnmarshalUint64Field(w io.Writer, f Field) error { ...@@ -685,16 +681,14 @@ func emitCborUnmarshalUint64Field(w io.Writer, f Field) error {
return doTemplate(w, f, ` return doTemplate(w, f, `
{ {
{{ if .Pointer }} {{ if .Pointer }}
pb, err := br.PeekByte() b, err := br.ReadByte()
if err != nil { if err != nil {
return err return err
} }
if pb == cbg.CborNull[0] { if b != cbg.CborNull[0] {
var nbuf [1]byte if err := br.UnreadByte(); err != nil {
if _, err := br.Read(nbuf[:]); err != nil {
return err return err
} }
} else {
maj, extra, err = {{ ReadHeader "br" }} maj, extra, err = {{ ReadHeader "br" }}
if err != nil { if err != nil {
return err return err
......
package typegen
import (
"bufio"
"io"
)
// BytePeeker combines the Reader and ByteScanner interfaces.
type BytePeeker interface {
io.Reader
io.ByteScanner
}
func GetPeeker(r io.Reader) BytePeeker {
if r, ok := r.(BytePeeker); ok {
return r
}
return &peeker{reader: r}
}
// peeker is a non-buffering BytePeeker.
type peeker struct {
reader io.Reader
peekState int
lastByte byte
}
const (
peekEmpty = iota
peekSet
peekUnread
)
func (p *peeker) Read(buf []byte) (n int, err error) {
// Read "nothing". I.e., read an error, maybe.
if len(buf) == 0 {
// There's something pending in the
if p.peekState == peekUnread {
return 0, nil
}
return p.reader.Read(nil)
}
if p.peekState == peekUnread {
buf[0] = p.lastByte
n, err = p.reader.Read(buf[1:])
n += 1
} else {
n, err = p.reader.Read(buf)
}
if n > 0 {
p.peekState = peekSet
p.lastByte = buf[n-1]
}
return n, err
}
func (p *peeker) ReadByte() (byte, error) {
if p.peekState == peekUnread {
p.peekState = peekSet
return p.lastByte, nil
}
var buf [1]byte
n, err := p.reader.Read(buf[:])
if n == 0 {
return 0, err
}
b := buf[0]
p.lastByte = b
p.peekState = peekSet
return b, err
}
func (p *peeker) UnreadByte() error {
if p.peekState != peekSet {
return bufio.ErrInvalidUnreadByte
}
p.peekState = peekUnread
return nil
}
package typegen
import (
"bufio"
"bytes"
"io"
"testing"
)
func TestPeeker(t *testing.T) {
buf := bytes.NewBuffer([]byte{0, 1, 2, 3})
p := peeker{reader: buf}
n, err := p.Read(nil)
if err != nil {
t.Fatal(err)
}
if n != 0 {
t.Fatal(err)
}
err = p.UnreadByte()
if err != bufio.ErrInvalidUnreadByte {
t.Fatal(err)
}
// read 2 bytes
var out [2]byte
n, err = p.Read(out[:])
if err != nil {
t.Fatal(err)
}
if n != 2 {
t.Fatalf("expected 2 bytes, got %d", n)
}
if !bytes.Equal(out[:], []byte{0, 1}) {
t.Fatalf("unexpected output")
}
// unread that last byte and read it again.
err = p.UnreadByte()
if err != nil {
t.Fatal(err)
}
b, err := p.ReadByte()
if err != nil {
t.Fatal(err)
}
if b != 1 {
t.Fatal("expected 1")
}
// unread that last byte then read 2
err = p.UnreadByte()
if err != nil {
t.Fatal(err)
}
n, err = p.Read(out[:])
if err != nil {
t.Fatal(err)
}
if n != 2 {
t.Fatalf("expected 2 bytes, got %d", n)
}
if !bytes.Equal(out[:], []byte{1, 2}) {
t.Fatalf("unexpected output")
}
// read another byte
b, err = p.ReadByte()
if err != nil {
t.Fatal(err)
}
if b != 3 {
t.Fatal("expected 1")
}
// Should read eof at end.
n, err = p.Read(out[:])
if err != io.EOF {
t.Fatal(err)
}
if n != 0 {
t.Fatal("should have been at end")
}
// should unread eof
err = p.UnreadByte()
if err != nil {
t.Fatal(err)
}
_, err = p.Read(nil)
if err != nil {
t.Fatal(err)
}
b, err = p.ReadByte()
if err != nil {
t.Fatal(err)
}
if b != 3 {
t.Fatal("expected 1")
}
}
...@@ -427,16 +427,14 @@ func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error { ...@@ -427,16 +427,14 @@ func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error {
{ {
pb, err := br.PeekByte() b, err := br.ReadByte()
if err != nil { if err != nil {
return err return err
} }
if pb == cbg.CborNull[0] { if b != cbg.CborNull[0] {
var nbuf [1]byte if err := br.UnreadByte(); err != nil {
if _, err := br.Read(nbuf[:]); err != nil {
return err return err
} }
} else {
t.Stuff = new(SimpleTypeTwo) t.Stuff = new(SimpleTypeTwo)
if err := t.Stuff.UnmarshalCBOR(br); err != nil { if err := t.Stuff.UnmarshalCBOR(br); err != nil {
return xerrors.Errorf("unmarshaling t.Stuff pointer: %w", err) return xerrors.Errorf("unmarshaling t.Stuff pointer: %w", err)
...@@ -617,16 +615,14 @@ func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error { ...@@ -617,16 +615,14 @@ func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error {
{ {
pb, err := br.PeekByte() b, err := br.ReadByte()
if err != nil { if err != nil {
return err return err
} }
if pb == cbg.CborNull[0] { if b != cbg.CborNull[0] {
var nbuf [1]byte if err := br.UnreadByte(); err != nil {
if _, err := br.Read(nbuf[:]); err != nil {
return err return err
} }
} else {
maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) maj, extra, err = cbg.CborReadHeaderBuf(br, scratch)
if err != nil { if err != nil {
return err return err
...@@ -643,16 +639,14 @@ func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error { ...@@ -643,16 +639,14 @@ func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error {
{ {
pb, err := br.PeekByte() b, err := br.ReadByte()
if err != nil { if err != nil {
return err return err
} }
if pb == cbg.CborNull[0] { if b != cbg.CborNull[0] {
var nbuf [1]byte if err := br.UnreadByte(); err != nil {
if _, err := br.Read(nbuf[:]); err != nil {
return err return err
} }
} else {
maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) maj, extra, err = cbg.CborReadHeaderBuf(br, scratch)
if err != nil { if err != nil {
return err return err
...@@ -753,16 +747,14 @@ func (t *DeferredContainer) UnmarshalCBOR(r io.Reader) error { ...@@ -753,16 +747,14 @@ func (t *DeferredContainer) UnmarshalCBOR(r io.Reader) error {
{ {
pb, err := br.PeekByte() b, err := br.ReadByte()
if err != nil { if err != nil {
return err return err
} }
if pb == cbg.CborNull[0] { if b != cbg.CborNull[0] {
var nbuf [1]byte if err := br.UnreadByte(); err != nil {
if _, err := br.Read(nbuf[:]); err != nil {
return err return err
} }
} else {
t.Stuff = new(SimpleTypeOne) t.Stuff = new(SimpleTypeOne)
if err := t.Stuff.UnmarshalCBOR(br); err != nil { if err := t.Stuff.UnmarshalCBOR(br); err != nil {
return xerrors.Errorf("unmarshaling t.Stuff pointer: %w", err) return xerrors.Errorf("unmarshaling t.Stuff pointer: %w", err)
......
...@@ -221,16 +221,14 @@ func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error { ...@@ -221,16 +221,14 @@ func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error {
{ {
pb, err := br.PeekByte() b, err := br.ReadByte()
if err != nil { if err != nil {
return err return err
} }
if pb == cbg.CborNull[0] { if b != cbg.CborNull[0] {
var nbuf [1]byte if err := br.UnreadByte(); err != nil {
if _, err := br.Read(nbuf[:]); err != nil {
return err return err
} }
} else {
t.Stuff = new(SimpleTypeTree) t.Stuff = new(SimpleTypeTree)
if err := t.Stuff.UnmarshalCBOR(br); err != nil { if err := t.Stuff.UnmarshalCBOR(br); err != nil {
return xerrors.Errorf("unmarshaling t.Stuff pointer: %w", err) return xerrors.Errorf("unmarshaling t.Stuff pointer: %w", err)
...@@ -243,16 +241,14 @@ func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error { ...@@ -243,16 +241,14 @@ func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error {
{ {
pb, err := br.PeekByte() b, err := br.ReadByte()
if err != nil { if err != nil {
return err return err
} }
if pb == cbg.CborNull[0] { if b != cbg.CborNull[0] {
var nbuf [1]byte if err := br.UnreadByte(); err != nil {
if _, err := br.Read(nbuf[:]); err != nil {
return err return err
} }
} else {
t.Stufff = new(SimpleTypeTwo) t.Stufff = new(SimpleTypeTwo)
if err := t.Stufff.UnmarshalCBOR(br); err != nil { if err := t.Stufff.UnmarshalCBOR(br); err != nil {
return xerrors.Errorf("unmarshaling t.Stufff pointer: %w", err) return xerrors.Errorf("unmarshaling t.Stufff pointer: %w", err)
...@@ -384,16 +380,14 @@ func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error { ...@@ -384,16 +380,14 @@ func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error {
{ {
pb, err := br.PeekByte() b, err := br.ReadByte()
if err != nil { if err != nil {
return err return err
} }
if pb == cbg.CborNull[0] { if b != cbg.CborNull[0] {
var nbuf [1]byte if err := br.UnreadByte(); err != nil {
if _, err := br.Read(nbuf[:]); err != nil {
return err return err
} }
} else {
maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) maj, extra, err = cbg.CborReadHeaderBuf(br, scratch)
if err != nil { if err != nil {
return err return err
......
package typegen package typegen
import ( import (
"bufio"
"bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
...@@ -203,57 +201,6 @@ func (d *Deferred) UnmarshalCBOR(br io.Reader) error { ...@@ -203,57 +201,6 @@ func (d *Deferred) UnmarshalCBOR(br io.Reader) error {
} }
} }
// this is a bit gnarly i should just switch to taking in a byte array at the top level
type BytePeeker interface {
io.Reader
PeekByte() (byte, error)
}
type peeker struct {
io.Reader
}
func (p *peeker) PeekByte() (byte, error) {
switch r := p.Reader.(type) {
case *bytes.Reader:
b, err := r.ReadByte()
if err != nil {
return 0, err
}
return b, r.UnreadByte()
case *bytes.Buffer:
b, err := r.ReadByte()
if err != nil {
return 0, err
}
return b, r.UnreadByte()
case *bufio.Reader:
o, err := r.Peek(1)
if err != nil {
return 0, err
}
return o[0], nil
default:
panic("invariant violated")
}
}
func GetPeeker(r io.Reader) BytePeeker {
switch r := r.(type) {
case *bytes.Reader:
return &peeker{r}
case *bytes.Buffer:
return &peeker{r}
case *bufio.Reader:
return &peeker{r}
case *peeker:
return r
default:
return &peeker{bufio.NewReaderSize(r, 16)}
}
}
func readByte(r io.Reader) (byte, error) { func readByte(r io.Reader) (byte, error) {
if br, ok := r.(io.ByteReader); ok { if br, ok := r.(io.ByteReader); ok {
return br.ReadByte() return br.ReadByte()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment