Unverified Commit f37d2929 authored by Łukasz Magiera's avatar Łukasz Magiera Committed by GitHub

Merge pull request #50 from dirkmc/feat/round-trip-extra-fields

feat: allow unmarshaling of struct with more fields than marshaled struct
parents 169e9d70 a77f48b5
...@@ -57,6 +57,7 @@ package {{ .Package }} ...@@ -57,6 +57,7 @@ package {{ .Package }}
import ( import (
"fmt" "fmt"
"io" "io"
"sort"
{{ range .Imports }}{{ .Name }} "{{ .PkgPath }}" {{ range .Imports }}{{ .Name }} "{{ .PkgPath }}"
{{ end }} {{ end }}
...@@ -64,6 +65,8 @@ import ( ...@@ -64,6 +65,8 @@ import (
var _ = xerrors.Errorf var _ = xerrors.Errorf
var _ = cid.Undef
var _ = sort.Sort
`) `)
} }
...@@ -1269,7 +1272,8 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error { ...@@ -1269,7 +1272,8 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error {
return doTemplate(w, gti, ` return doTemplate(w, gti, `
default: default:
return fmt.Errorf("unknown struct field %d: '%s'", i, name) // Field doesn't exist on this type, so ignore it
cbg.ScanForLinks(r, func(cid.Cid){})
} }
} }
......
...@@ -16,6 +16,7 @@ var ( ...@@ -16,6 +16,7 @@ var (
defaultImports = []Import{ defaultImports = []Import{
{Name: "cbg", PkgPath: "github.com/whyrusleeping/cbor-gen"}, {Name: "cbg", PkgPath: "github.com/whyrusleeping/cbor-gen"},
{Name: "xerrors", PkgPath: "golang.org/x/xerrors"}, {Name: "xerrors", PkgPath: "golang.org/x/xerrors"},
{Name: "cid", PkgPath: "github.com/ipfs/go-cid"},
} }
) )
......
...@@ -20,6 +20,8 @@ func main() { ...@@ -20,6 +20,8 @@ func main() {
if err := cbg.WriteMapEncodersToFile("testing/cbor_map_gen.go", "testing", if err := cbg.WriteMapEncodersToFile("testing/cbor_map_gen.go", "testing",
types.SimpleTypeTree{}, types.SimpleTypeTree{},
types.NeedScratchForMap{}, types.NeedScratchForMap{},
types.SimpleStructV1{},
types.SimpleStructV2{},
); err != nil { ); err != nil {
panic(err) panic(err)
} }
......
...@@ -5,12 +5,16 @@ package testing ...@@ -5,12 +5,16 @@ package testing
import ( import (
"fmt" "fmt"
"io" "io"
"sort"
cid "github.com/ipfs/go-cid"
cbg "github.com/whyrusleeping/cbor-gen" cbg "github.com/whyrusleeping/cbor-gen"
xerrors "golang.org/x/xerrors" xerrors "golang.org/x/xerrors"
) )
var _ = xerrors.Errorf var _ = xerrors.Errorf
var _ = cid.Undef
var _ = sort.Sort
var lengthBufSignedArray = []byte{129} var lengthBufSignedArray = []byte{129}
......
This diff is collapsed.
...@@ -3,6 +3,7 @@ package testing ...@@ -3,6 +3,7 @@ package testing
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"github.com/ipfs/go-cid"
"math/rand" "math/rand"
"reflect" "reflect"
"testing" "testing"
...@@ -162,3 +163,155 @@ func TestTimeIsh(t *testing.T) { ...@@ -162,3 +163,155 @@ func TestTimeIsh(t *testing.T) {
} }
} }
func TestLessToMoreFieldsRoundTrip(t *testing.T) {
dummyCid, _ := cid.Parse("bafkqaaa")
simpleTypeOne := SimpleTypeOne{
Foo: "foo",
Value: 1,
Binary: []byte("bin"),
Signed: -1,
NString: "namedstr",
}
obj := &SimpleStructV1{
OldStr: "hello",
OldBytes: []byte("bytes"),
OldNum: 10,
OldPtr: &dummyCid,
OldMap: map[string]SimpleTypeOne{"first": simpleTypeOne},
OldArray: []SimpleTypeOne{simpleTypeOne},
OldStruct: simpleTypeOne,
}
buf := new(bytes.Buffer)
if err := obj.MarshalCBOR(buf); err != nil {
t.Fatal("failed marshaling", err)
}
enc := buf.Bytes()
nobj := SimpleStructV2{}
if err := nobj.UnmarshalCBOR(bytes.NewReader(enc)); err != nil {
t.Logf("got bad bytes: %x", enc)
t.Fatal("failed to round trip object: ", err)
}
if obj.OldStr != nobj.OldStr {
t.Fatal("mismatch ", obj.OldStr, " != ", nobj.OldStr)
}
if nobj.NewStr != "" {
t.Fatal("expected field to be zero value")
}
if obj.OldNum != nobj.OldNum {
t.Fatal("mismatch ", obj.OldNum, " != ", nobj.OldNum)
}
if nobj.NewNum != 0 {
t.Fatal("expected field to be zero value")
}
if !bytes.Equal(obj.OldBytes, nobj.OldBytes) {
t.Fatal("mismatch ", obj.OldBytes, " != ", nobj.OldBytes)
}
if nobj.NewBytes != nil {
t.Fatal("expected field to be zero value")
}
if *obj.OldPtr != *nobj.OldPtr {
t.Fatal("mismatch ", obj.OldPtr, " != ", nobj.OldPtr)
}
if nobj.NewPtr != nil {
t.Fatal("expected field to be zero value")
}
if !cmp.Equal(obj.OldMap, nobj.OldMap) {
t.Fatal("mismatch map marshal / unmarshal")
}
if len(nobj.NewMap) != 0 {
t.Fatal("expected field to be zero value")
}
if !cmp.Equal(obj.OldArray, nobj.OldArray) {
t.Fatal("mismatch array marshal / unmarshal")
}
if len(nobj.NewArray) != 0 {
t.Fatal("expected field to be zero value")
}
if !cmp.Equal(obj.OldStruct, nobj.OldStruct) {
t.Fatal("mismatch struct marshal / unmarshal")
}
if !cmp.Equal(nobj.NewStruct, SimpleTypeOne{}) {
t.Fatal("expected field to be zero value")
}
}
func TestMoreToLessFieldsRoundTrip(t *testing.T) {
dummyCid1, _ := cid.Parse("bafkqaaa")
dummyCid2, _ := cid.Parse("bafkqaab")
simpleType1 := SimpleTypeOne{
Foo: "foo",
Value: 1,
Binary: []byte("bin"),
Signed: -1,
NString: "namedstr",
}
simpleType2 := SimpleTypeOne{
Foo: "bar",
Value: 2,
Binary: []byte("bin2"),
Signed: -2,
NString: "namedstr2",
}
obj := &SimpleStructV2{
OldStr: "oldstr",
NewStr: "newstr",
OldBytes: []byte("oldbytes"),
NewBytes: []byte("newbytes"),
OldNum: 10,
NewNum: 11,
OldPtr: &dummyCid1,
NewPtr: &dummyCid2,
OldMap: map[string]SimpleTypeOne{"foo": simpleType1},
NewMap: map[string]SimpleTypeOne{"bar": simpleType2},
OldArray: []SimpleTypeOne{simpleType1},
NewArray: []SimpleTypeOne{simpleType1, simpleType2},
OldStruct: simpleType1,
NewStruct: simpleType2,
}
buf := new(bytes.Buffer)
if err := obj.MarshalCBOR(buf); err != nil {
t.Fatal("failed marshaling", err)
}
enc := buf.Bytes()
nobj := SimpleStructV1{}
if err := nobj.UnmarshalCBOR(bytes.NewReader(enc)); err != nil {
t.Logf("got bad bytes: %x", enc)
t.Fatal("failed to round trip object: ", err)
}
if obj.OldStr != nobj.OldStr {
t.Fatal("mismatch", obj.OldStr, " != ", nobj.OldStr)
}
if obj.OldNum != nobj.OldNum {
t.Fatal("mismatch ", obj.OldNum, " != ", nobj.OldNum)
}
if !bytes.Equal(obj.OldBytes, nobj.OldBytes) {
t.Fatal("mismatch ", obj.OldBytes, " != ", nobj.OldBytes)
}
if *obj.OldPtr != *nobj.OldPtr {
t.Fatal("mismatch ", obj.OldPtr, " != ", nobj.OldPtr)
}
if !cmp.Equal(obj.OldMap, nobj.OldMap) {
t.Fatal("mismatch map marshal / unmarshal")
}
if !cmp.Equal(obj.OldArray, nobj.OldArray) {
t.Fatal("mismatch array marshal / unmarshal")
}
if !cmp.Equal(obj.OldStruct, nobj.OldStruct) {
t.Fatal("mismatch struct marshal / unmarshal")
}
}
package testing package testing
import ( import (
"github.com/ipfs/go-cid"
cbg "github.com/whyrusleeping/cbor-gen" cbg "github.com/whyrusleeping/cbor-gen"
) )
...@@ -43,6 +44,39 @@ type SimpleTypeTree struct { ...@@ -43,6 +44,39 @@ type SimpleTypeTree struct {
NotPizza *uint64 NotPizza *uint64
} }
type SimpleStructV1 struct {
OldStr string
OldBytes []byte
OldNum uint64
OldPtr *cid.Cid
OldMap map[string]SimpleTypeOne
OldArray []SimpleTypeOne
OldStruct SimpleTypeOne
}
type SimpleStructV2 struct {
OldStr string
NewStr string
OldBytes []byte
NewBytes []byte
OldNum uint64
NewNum uint64
OldPtr *cid.Cid
NewPtr *cid.Cid
OldMap map[string]SimpleTypeOne
NewMap map[string]SimpleTypeOne
OldArray []SimpleTypeOne
NewArray []SimpleTypeOne
OldStruct SimpleTypeOne
NewStruct SimpleTypeOne
}
type DeferredContainer struct { type DeferredContainer struct {
Stuff *SimpleTypeOne Stuff *SimpleTypeOne
Deferred *cbg.Deferred Deferred *cbg.Deferred
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment