diff --git a/testgen/main.go b/testgen/main.go index 3d03ab31f3e7901f22cd1264c9a08e4311930c92..acf5be4423a2a85bfc7d7e9e4600ecff9c7d7707 100644 --- a/testgen/main.go +++ b/testgen/main.go @@ -12,6 +12,7 @@ func main() { types.SimpleTypeTwo{}, types.DeferredContainer{}, types.FixedArrays{}, + types.ThingWithSomeTime{}, ); err != nil { panic(err) } diff --git a/testing/cbor_gen.go b/testing/cbor_gen.go index 186f9a2d9cc4aa7338dcfe3cf1905b42a66319bd..b3e2e14528a1422c18b8073c43afe40702a5d4e0 100644 --- a/testing/cbor_gen.go +++ b/testing/cbor_gen.go @@ -945,3 +945,111 @@ func (t *FixedArrays) UnmarshalCBOR(r io.Reader) error { return nil } + +var lengthBufThingWithSomeTime = []byte{131} + +func (t *ThingWithSomeTime) MarshalCBOR(w io.Writer) error { + if t == nil { + _, err := w.Write(cbg.CborNull) + return err + } + if _, err := w.Write(lengthBufThingWithSomeTime); err != nil { + return err + } + + scratch := make([]byte, 9) + + // t.When (typegen.CborTime) (struct) + if err := t.When.MarshalCBOR(w); err != nil { + return err + } + + // t.Stuff (int64) (int64) + if t.Stuff >= 0 { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.Stuff)); err != nil { + return err + } + } else { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajNegativeInt, uint64(-t.Stuff-1)); err != nil { + return err + } + } + + // t.CatName (string) (string) + if len(t.CatName) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.CatName was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.CatName))); err != nil { + return err + } + if _, err := io.WriteString(w, string(t.CatName)); err != nil { + return err + } + return nil +} + +func (t *ThingWithSomeTime) UnmarshalCBOR(r io.Reader) error { + *t = ThingWithSomeTime{} + + br := cbg.GetPeeker(r) + scratch := make([]byte, 8) + + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajArray { + return fmt.Errorf("cbor input should be of type array") + } + + if extra != 3 { + return fmt.Errorf("cbor input had wrong number of fields") + } + + // t.When (typegen.CborTime) (struct) + + { + + if err := t.When.UnmarshalCBOR(br); err != nil { + return xerrors.Errorf("unmarshaling t.When: %w", err) + } + + } + // t.Stuff (int64) (int64) + { + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) + var extraI int64 + if err != nil { + return err + } + switch maj { + case cbg.MajUnsignedInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 positive overflow") + } + case cbg.MajNegativeInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 negative oveflow") + } + extraI = -1 - extraI + default: + return fmt.Errorf("wrong type for int64 field: %d", maj) + } + + t.Stuff = int64(extraI) + } + // t.CatName (string) (string) + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + t.CatName = string(sval) + } + return nil +} diff --git a/testing/roundtrip_test.go b/testing/roundtrip_test.go index 642e539b6392d653c3c91e889089d9cb034b6f38..4bf92e9b1d6fff2f21706c71cd8c7a09c073b51b 100644 --- a/testing/roundtrip_test.go +++ b/testing/roundtrip_test.go @@ -6,6 +6,7 @@ import ( "reflect" "testing" "testing/quick" + "time" "github.com/google/go-cmp/cmp" cbg "github.com/whyrusleeping/cbor-gen" @@ -115,3 +116,34 @@ func TestFixedArrays(t *testing.T) { recepticle := &FixedArrays{} testValueRoundtrip(t, zero, recepticle) } + +func TestTimeIsh(t *testing.T) { + val := &ThingWithSomeTime{ + When: cbg.CborTime(time.Now()), + Stuff: 1234, + CatName: "hank", + } + + buf := new(bytes.Buffer) + if err := val.MarshalCBOR(buf); err != nil { + t.Fatal(err) + } + + out := ThingWithSomeTime{} + if err := out.UnmarshalCBOR(buf); err != nil { + t.Fatal(err) + } + + if out.When.Time().UnixNano() != val.When.Time().UnixNano() { + t.Fatal("time didnt round trip properly", out.When.Time(), val.When.Time()) + } + + if out.Stuff != val.Stuff { + t.Fatal("no") + } + + if out.CatName != val.CatName { + t.Fatal("no") + } + +} diff --git a/testing/types.go b/testing/types.go index 623ac6045c0d458731621e0072146ee72175316b..58a8d2c2decadac932cde36fc4f8432469f5b935 100644 --- a/testing/types.go +++ b/testing/types.go @@ -55,6 +55,12 @@ type FixedArrays struct { Uint64 [20]uint64 } +type ThingWithSomeTime struct { + When cbg.CborTime + Stuff int64 + CatName string +} + // Do not add fields to this type. type NeedScratchForMap struct { Thing bool diff --git a/utils.go b/utils.go index be71c37f718f807d816aaaaef57f017710a45b31..6be5c7728bd72b9d2397ea524b844e74edd398ec 100644 --- a/utils.go +++ b/utils.go @@ -9,6 +9,7 @@ import ( "io" "io/ioutil" "math" + "time" cid "github.com/ipfs/go-cid" ) @@ -691,3 +692,50 @@ func (ci *CborInt) UnmarshalCBOR(r io.Reader) error { *ci = CborInt(extraI) return nil } + +type CborTime time.Time + +func (ct *CborTime) MarshalCBOR(w io.Writer) error { + b, err := (*time.Time)(ct).MarshalBinary() + if err != nil { + return err + } + + if err := CborWriteHeader(w, MajByteString, uint64(len(b))); err != nil { + return err + } + + if _, err := w.Write(b); err != nil { + return err + } + + return nil +} + +func (ct *CborTime) UnmarshalCBOR(r io.Reader) error { + t, l, err := CborReadHeader(r) + if err != nil { + return err + } + + if t != MajByteString { + return fmt.Errorf("CborTime expects to find a byte array (got %d)", t) + } + + buf := make([]byte, l) + if _, err := io.ReadFull(r, buf); err != nil { + return err + } + + tm := time.Time{} + if err := tm.UnmarshalBinary(buf); err != nil { + return err + } + + *ct = (CborTime)(tm) + return nil +} + +func (ct CborTime) Time() time.Time { + return (time.Time)(ct) +}