Commit fd5c0318 authored by Dirk McCormick's avatar Dirk McCormick

feat: allow unmarshalling of object with same fields + more fields than marshalled object

parent 169e9d70
......@@ -64,6 +64,7 @@ import (
var _ = xerrors.Errorf
var _ = cid.Undef
`)
}
......@@ -1269,7 +1270,8 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error {
return doTemplate(w, gti, `
default:
return fmt.Errorf("unknown struct field %d: '%s'", i, name)
// Field doesn't exist on this type, so ignore it
cbg.ScanForLinks(r, func(cid.Cid){})
}
}
......
......@@ -16,6 +16,7 @@ var (
defaultImports = []Import{
{Name: "cbg", PkgPath: "github.com/whyrusleeping/cbor-gen"},
{Name: "xerrors", PkgPath: "golang.org/x/xerrors"},
{Name: "cid", PkgPath: "github.com/ipfs/go-cid"},
}
)
......
......@@ -20,6 +20,8 @@ func main() {
if err := cbg.WriteMapEncodersToFile("testing/cbor_map_gen.go", "testing",
types.SimpleTypeTree{},
types.NeedScratchForMap{},
types.SimpleStructV1{},
types.SimpleStructV2{},
); err != nil {
panic(err)
}
......
......@@ -6,11 +6,13 @@ import (
"fmt"
"io"
cid "github.com/ipfs/go-cid"
cbg "github.com/whyrusleeping/cbor-gen"
xerrors "golang.org/x/xerrors"
)
var _ = xerrors.Errorf
var _ = cid.Undef
var lengthBufSignedArray = []byte{129}
......
......@@ -6,11 +6,13 @@ import (
"fmt"
"io"
cid "github.com/ipfs/go-cid"
cbg "github.com/whyrusleeping/cbor-gen"
xerrors "golang.org/x/xerrors"
)
var _ = xerrors.Errorf
var _ = cid.Undef
func (t *SimpleTypeTree) MarshalCBOR(w io.Writer) error {
if t == nil {
......@@ -402,7 +404,8 @@ func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error {
}
default:
return fmt.Errorf("unknown struct field %d: '%s'", i, name)
// Field doesn't exist on this type, so ignore it
cbg.ScanForLinks(r, func(cid.Cid) {})
}
}
......@@ -490,7 +493,588 @@ func (t *NeedScratchForMap) UnmarshalCBOR(r io.Reader) error {
}
default:
return fmt.Errorf("unknown struct field %d: '%s'", i, name)
// Field doesn't exist on this type, so ignore it
cbg.ScanForLinks(r, func(cid.Cid) {})
}
}
return nil
}
func (t *SimpleStructV1) MarshalCBOR(w io.Writer) error {
if t == nil {
_, err := w.Write(cbg.CborNull)
return err
}
if _, err := w.Write([]byte{164}); err != nil {
return err
}
scratch := make([]byte, 9)
// t.OldStr (string) (string)
if len("OldStr") > cbg.MaxLength {
return xerrors.Errorf("Value in field \"OldStr\" was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("OldStr"))); err != nil {
return err
}
if _, err := io.WriteString(w, string("OldStr")); err != nil {
return err
}
if len(t.OldStr) > cbg.MaxLength {
return xerrors.Errorf("Value in field t.OldStr was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.OldStr))); err != nil {
return err
}
if _, err := io.WriteString(w, string(t.OldStr)); err != nil {
return err
}
// t.OldBytes ([]uint8) (slice)
if len("OldBytes") > cbg.MaxLength {
return xerrors.Errorf("Value in field \"OldBytes\" was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("OldBytes"))); err != nil {
return err
}
if _, err := io.WriteString(w, string("OldBytes")); err != nil {
return err
}
if len(t.OldBytes) > cbg.ByteArrayMaxLen {
return xerrors.Errorf("Byte array in field t.OldBytes was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajByteString, uint64(len(t.OldBytes))); err != nil {
return err
}
if _, err := w.Write(t.OldBytes[:]); err != nil {
return err
}
// t.OldNum (uint64) (uint64)
if len("OldNum") > cbg.MaxLength {
return xerrors.Errorf("Value in field \"OldNum\" was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("OldNum"))); err != nil {
return err
}
if _, err := io.WriteString(w, string("OldNum")); err != nil {
return err
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.OldNum)); err != nil {
return err
}
// t.OldPtr (cid.Cid) (struct)
if len("OldPtr") > cbg.MaxLength {
return xerrors.Errorf("Value in field \"OldPtr\" was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("OldPtr"))); err != nil {
return err
}
if _, err := io.WriteString(w, string("OldPtr")); err != nil {
return err
}
if t.OldPtr == nil {
if _, err := w.Write(cbg.CborNull); err != nil {
return err
}
} else {
if err := cbg.WriteCidBuf(scratch, w, *t.OldPtr); err != nil {
return xerrors.Errorf("failed to write cid field t.OldPtr: %w", err)
}
}
return nil
}
func (t *SimpleStructV1) UnmarshalCBOR(r io.Reader) error {
*t = SimpleStructV1{}
br := cbg.GetPeeker(r)
scratch := make([]byte, 8)
maj, extra, err := cbg.CborReadHeaderBuf(br, scratch)
if err != nil {
return err
}
if maj != cbg.MajMap {
return fmt.Errorf("cbor input should be of type map")
}
if extra > cbg.MaxLength {
return fmt.Errorf("SimpleStructV1: map struct too large (%d)", extra)
}
var name string
n := extra
for i := uint64(0); i < n; i++ {
{
sval, err := cbg.ReadStringBuf(br, scratch)
if err != nil {
return err
}
name = string(sval)
}
switch name {
// t.OldStr (string) (string)
case "OldStr":
{
sval, err := cbg.ReadStringBuf(br, scratch)
if err != nil {
return err
}
t.OldStr = string(sval)
}
// t.OldBytes ([]uint8) (slice)
case "OldBytes":
maj, extra, err = cbg.CborReadHeaderBuf(br, scratch)
if err != nil {
return err
}
if extra > cbg.ByteArrayMaxLen {
return fmt.Errorf("t.OldBytes: byte array too large (%d)", extra)
}
if maj != cbg.MajByteString {
return fmt.Errorf("expected byte array")
}
if extra > 0 {
t.OldBytes = make([]uint8, extra)
}
if _, err := io.ReadFull(br, t.OldBytes[:]); err != nil {
return err
}
// t.OldNum (uint64) (uint64)
case "OldNum":
{
maj, extra, err = cbg.CborReadHeaderBuf(br, scratch)
if err != nil {
return err
}
if maj != cbg.MajUnsignedInt {
return fmt.Errorf("wrong type for uint64 field")
}
t.OldNum = uint64(extra)
}
// t.OldPtr (cid.Cid) (struct)
case "OldPtr":
{
b, err := br.ReadByte()
if err != nil {
return err
}
if b != cbg.CborNull[0] {
if err := br.UnreadByte(); err != nil {
return err
}
c, err := cbg.ReadCid(br)
if err != nil {
return xerrors.Errorf("failed to read cid field t.OldPtr: %w", err)
}
t.OldPtr = &c
}
}
default:
// Field doesn't exist on this type, so ignore it
cbg.ScanForLinks(r, func(cid.Cid) {})
}
}
return nil
}
func (t *SimpleStructV2) MarshalCBOR(w io.Writer) error {
if t == nil {
_, err := w.Write(cbg.CborNull)
return err
}
if _, err := w.Write([]byte{168}); err != nil {
return err
}
scratch := make([]byte, 9)
// t.OldStr (string) (string)
if len("OldStr") > cbg.MaxLength {
return xerrors.Errorf("Value in field \"OldStr\" was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("OldStr"))); err != nil {
return err
}
if _, err := io.WriteString(w, string("OldStr")); err != nil {
return err
}
if len(t.OldStr) > cbg.MaxLength {
return xerrors.Errorf("Value in field t.OldStr was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.OldStr))); err != nil {
return err
}
if _, err := io.WriteString(w, string(t.OldStr)); err != nil {
return err
}
// t.NewStr (string) (string)
if len("NewStr") > cbg.MaxLength {
return xerrors.Errorf("Value in field \"NewStr\" was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("NewStr"))); err != nil {
return err
}
if _, err := io.WriteString(w, string("NewStr")); err != nil {
return err
}
if len(t.NewStr) > cbg.MaxLength {
return xerrors.Errorf("Value in field t.NewStr was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.NewStr))); err != nil {
return err
}
if _, err := io.WriteString(w, string(t.NewStr)); err != nil {
return err
}
// t.OldBytes ([]uint8) (slice)
if len("OldBytes") > cbg.MaxLength {
return xerrors.Errorf("Value in field \"OldBytes\" was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("OldBytes"))); err != nil {
return err
}
if _, err := io.WriteString(w, string("OldBytes")); err != nil {
return err
}
if len(t.OldBytes) > cbg.ByteArrayMaxLen {
return xerrors.Errorf("Byte array in field t.OldBytes was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajByteString, uint64(len(t.OldBytes))); err != nil {
return err
}
if _, err := w.Write(t.OldBytes[:]); err != nil {
return err
}
// t.NewBytes ([]uint8) (slice)
if len("NewBytes") > cbg.MaxLength {
return xerrors.Errorf("Value in field \"NewBytes\" was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("NewBytes"))); err != nil {
return err
}
if _, err := io.WriteString(w, string("NewBytes")); err != nil {
return err
}
if len(t.NewBytes) > cbg.ByteArrayMaxLen {
return xerrors.Errorf("Byte array in field t.NewBytes was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajByteString, uint64(len(t.NewBytes))); err != nil {
return err
}
if _, err := w.Write(t.NewBytes[:]); err != nil {
return err
}
// t.OldNum (uint64) (uint64)
if len("OldNum") > cbg.MaxLength {
return xerrors.Errorf("Value in field \"OldNum\" was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("OldNum"))); err != nil {
return err
}
if _, err := io.WriteString(w, string("OldNum")); err != nil {
return err
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.OldNum)); err != nil {
return err
}
// t.NewNum (uint64) (uint64)
if len("NewNum") > cbg.MaxLength {
return xerrors.Errorf("Value in field \"NewNum\" was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("NewNum"))); err != nil {
return err
}
if _, err := io.WriteString(w, string("NewNum")); err != nil {
return err
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.NewNum)); err != nil {
return err
}
// t.OldPtr (cid.Cid) (struct)
if len("OldPtr") > cbg.MaxLength {
return xerrors.Errorf("Value in field \"OldPtr\" was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("OldPtr"))); err != nil {
return err
}
if _, err := io.WriteString(w, string("OldPtr")); err != nil {
return err
}
if t.OldPtr == nil {
if _, err := w.Write(cbg.CborNull); err != nil {
return err
}
} else {
if err := cbg.WriteCidBuf(scratch, w, *t.OldPtr); err != nil {
return xerrors.Errorf("failed to write cid field t.OldPtr: %w", err)
}
}
// t.NewPtr (cid.Cid) (struct)
if len("NewPtr") > cbg.MaxLength {
return xerrors.Errorf("Value in field \"NewPtr\" was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("NewPtr"))); err != nil {
return err
}
if _, err := io.WriteString(w, string("NewPtr")); err != nil {
return err
}
if t.NewPtr == nil {
if _, err := w.Write(cbg.CborNull); err != nil {
return err
}
} else {
if err := cbg.WriteCidBuf(scratch, w, *t.NewPtr); err != nil {
return xerrors.Errorf("failed to write cid field t.NewPtr: %w", err)
}
}
return nil
}
func (t *SimpleStructV2) UnmarshalCBOR(r io.Reader) error {
*t = SimpleStructV2{}
br := cbg.GetPeeker(r)
scratch := make([]byte, 8)
maj, extra, err := cbg.CborReadHeaderBuf(br, scratch)
if err != nil {
return err
}
if maj != cbg.MajMap {
return fmt.Errorf("cbor input should be of type map")
}
if extra > cbg.MaxLength {
return fmt.Errorf("SimpleStructV2: map struct too large (%d)", extra)
}
var name string
n := extra
for i := uint64(0); i < n; i++ {
{
sval, err := cbg.ReadStringBuf(br, scratch)
if err != nil {
return err
}
name = string(sval)
}
switch name {
// t.OldStr (string) (string)
case "OldStr":
{
sval, err := cbg.ReadStringBuf(br, scratch)
if err != nil {
return err
}
t.OldStr = string(sval)
}
// t.NewStr (string) (string)
case "NewStr":
{
sval, err := cbg.ReadStringBuf(br, scratch)
if err != nil {
return err
}
t.NewStr = string(sval)
}
// t.OldBytes ([]uint8) (slice)
case "OldBytes":
maj, extra, err = cbg.CborReadHeaderBuf(br, scratch)
if err != nil {
return err
}
if extra > cbg.ByteArrayMaxLen {
return fmt.Errorf("t.OldBytes: byte array too large (%d)", extra)
}
if maj != cbg.MajByteString {
return fmt.Errorf("expected byte array")
}
if extra > 0 {
t.OldBytes = make([]uint8, extra)
}
if _, err := io.ReadFull(br, t.OldBytes[:]); err != nil {
return err
}
// t.NewBytes ([]uint8) (slice)
case "NewBytes":
maj, extra, err = cbg.CborReadHeaderBuf(br, scratch)
if err != nil {
return err
}
if extra > cbg.ByteArrayMaxLen {
return fmt.Errorf("t.NewBytes: byte array too large (%d)", extra)
}
if maj != cbg.MajByteString {
return fmt.Errorf("expected byte array")
}
if extra > 0 {
t.NewBytes = make([]uint8, extra)
}
if _, err := io.ReadFull(br, t.NewBytes[:]); err != nil {
return err
}
// t.OldNum (uint64) (uint64)
case "OldNum":
{
maj, extra, err = cbg.CborReadHeaderBuf(br, scratch)
if err != nil {
return err
}
if maj != cbg.MajUnsignedInt {
return fmt.Errorf("wrong type for uint64 field")
}
t.OldNum = uint64(extra)
}
// t.NewNum (uint64) (uint64)
case "NewNum":
{
maj, extra, err = cbg.CborReadHeaderBuf(br, scratch)
if err != nil {
return err
}
if maj != cbg.MajUnsignedInt {
return fmt.Errorf("wrong type for uint64 field")
}
t.NewNum = uint64(extra)
}
// t.OldPtr (cid.Cid) (struct)
case "OldPtr":
{
b, err := br.ReadByte()
if err != nil {
return err
}
if b != cbg.CborNull[0] {
if err := br.UnreadByte(); err != nil {
return err
}
c, err := cbg.ReadCid(br)
if err != nil {
return xerrors.Errorf("failed to read cid field t.OldPtr: %w", err)
}
t.OldPtr = &c
}
}
// t.NewPtr (cid.Cid) (struct)
case "NewPtr":
{
b, err := br.ReadByte()
if err != nil {
return err
}
if b != cbg.CborNull[0] {
if err := br.UnreadByte(); err != nil {
return err
}
c, err := cbg.ReadCid(br)
if err != nil {
return xerrors.Errorf("failed to read cid field t.NewPtr: %w", err)
}
t.NewPtr = &c
}
}
default:
// Field doesn't exist on this type, so ignore it
cbg.ScanForLinks(r, func(cid.Cid) {})
}
}
......
......@@ -3,6 +3,7 @@ package testing
import (
"bytes"
"encoding/json"
"github.com/ipfs/go-cid"
"math/rand"
"reflect"
"testing"
......@@ -162,3 +163,95 @@ func TestTimeIsh(t *testing.T) {
}
}
func TestLessToMoreFieldsRoundTrip(t *testing.T) {
dummyCid, _ := cid.Parse("bafkqaaa")
obj := &SimpleStructV1{
OldStr: "hello",
OldBytes: []byte("bytes"),
OldNum: 10,
OldPtr: &dummyCid,
}
buf := new(bytes.Buffer)
if err := obj.MarshalCBOR(buf); err != nil {
t.Fatal("failed marshaling", err)
}
enc := buf.Bytes()
nobj := SimpleStructV2{}
if err := nobj.UnmarshalCBOR(bytes.NewReader(enc)); err != nil {
t.Logf("got bad bytes: %x", enc)
t.Fatal("failed to round trip object: ", err)
}
if obj.OldStr != nobj.OldStr {
t.Fatal("mismatch ", obj.OldStr, " != ", nobj.OldStr)
}
if nobj.NewStr != "" {
t.Fatal("expected field to be zero value")
}
if obj.OldNum != nobj.OldNum {
t.Fatal("mismatch ", obj.OldNum, " != ", nobj.OldNum)
}
if nobj.NewNum != 0 {
t.Fatal("expected field to be zero value")
}
if !bytes.Equal(obj.OldBytes, nobj.OldBytes) {
t.Fatal("mismatch ", obj.OldBytes, " != ", nobj.OldBytes)
}
if nobj.NewBytes != nil {
t.Fatal("expected field to be zero value")
}
if *obj.OldPtr != *nobj.OldPtr {
t.Fatal("mismatch ", obj.OldPtr, " != ", nobj.OldPtr)
}
if nobj.NewPtr != nil {
t.Fatal("expected field to be zero value")
}
}
func TestMoreToLessFieldsRoundTrip(t *testing.T) {
dummyCid1, _ := cid.Parse("bafkqaaa")
dummyCid2, _ := cid.Parse("bafkqaab")
obj := &SimpleStructV2{
OldStr: "oldstr",
NewStr: "newstr",
OldBytes: []byte("oldbytes"),
NewBytes: []byte("newbytes"),
OldNum: 10,
NewNum: 11,
OldPtr: &dummyCid1,
NewPtr: &dummyCid2,
}
buf := new(bytes.Buffer)
if err := obj.MarshalCBOR(buf); err != nil {
t.Fatal("failed marshaling", err)
}
enc := buf.Bytes()
nobj := SimpleStructV1{}
if err := nobj.UnmarshalCBOR(bytes.NewReader(enc)); err != nil {
t.Logf("got bad bytes: %x", enc)
t.Fatal("failed to round trip object: ", err)
}
if obj.OldStr != nobj.OldStr {
t.Fatal("mismatch", obj.OldStr, " != ", nobj.OldStr)
}
if obj.OldNum != nobj.OldNum {
t.Fatal("mismatch ", obj.OldNum, " != ", nobj.OldNum)
}
if !bytes.Equal(obj.OldBytes, nobj.OldBytes) {
t.Fatal("mismatch ", obj.OldBytes, " != ", nobj.OldBytes)
}
if *obj.OldPtr != *nobj.OldPtr {
t.Fatal("mismatch ", obj.OldPtr, " != ", nobj.OldPtr)
}
}
package testing
import (
"github.com/ipfs/go-cid"
cbg "github.com/whyrusleeping/cbor-gen"
)
......@@ -43,6 +44,27 @@ type SimpleTypeTree struct {
NotPizza *uint64
}
type SimpleStructV1 struct {
OldStr string
OldBytes []byte
OldNum uint64
OldPtr *cid.Cid
}
type SimpleStructV2 struct {
OldStr string
NewStr string
OldBytes []byte
NewBytes []byte
OldNum uint64
NewNum uint64
OldPtr *cid.Cid
NewPtr *cid.Cid
}
type DeferredContainer struct {
Stuff *SimpleTypeOne
Deferred *cbg.Deferred
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment