Commit 868c0fd5 authored by whyrusleeping's avatar whyrusleeping

add a helper for roundtripping time.time objects

parent 534767cc
......@@ -12,6 +12,7 @@ func main() {
types.SimpleTypeTwo{},
types.DeferredContainer{},
types.FixedArrays{},
types.ThingWithSomeTime{},
); err != nil {
panic(err)
}
......
......@@ -945,3 +945,111 @@ func (t *FixedArrays) UnmarshalCBOR(r io.Reader) error {
return nil
}
var lengthBufThingWithSomeTime = []byte{131}
func (t *ThingWithSomeTime) MarshalCBOR(w io.Writer) error {
if t == nil {
_, err := w.Write(cbg.CborNull)
return err
}
if _, err := w.Write(lengthBufThingWithSomeTime); err != nil {
return err
}
scratch := make([]byte, 9)
// t.When (typegen.CborTime) (struct)
if err := t.When.MarshalCBOR(w); err != nil {
return err
}
// t.Stuff (int64) (int64)
if t.Stuff >= 0 {
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.Stuff)); err != nil {
return err
}
} else {
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajNegativeInt, uint64(-t.Stuff-1)); err != nil {
return err
}
}
// t.CatName (string) (string)
if len(t.CatName) > cbg.MaxLength {
return xerrors.Errorf("Value in field t.CatName was too long")
}
if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.CatName))); err != nil {
return err
}
if _, err := io.WriteString(w, string(t.CatName)); err != nil {
return err
}
return nil
}
func (t *ThingWithSomeTime) UnmarshalCBOR(r io.Reader) error {
*t = ThingWithSomeTime{}
br := cbg.GetPeeker(r)
scratch := make([]byte, 8)
maj, extra, err := cbg.CborReadHeaderBuf(br, scratch)
if err != nil {
return err
}
if maj != cbg.MajArray {
return fmt.Errorf("cbor input should be of type array")
}
if extra != 3 {
return fmt.Errorf("cbor input had wrong number of fields")
}
// t.When (typegen.CborTime) (struct)
{
if err := t.When.UnmarshalCBOR(br); err != nil {
return xerrors.Errorf("unmarshaling t.When: %w", err)
}
}
// t.Stuff (int64) (int64)
{
maj, extra, err := cbg.CborReadHeaderBuf(br, scratch)
var extraI int64
if err != nil {
return err
}
switch maj {
case cbg.MajUnsignedInt:
extraI = int64(extra)
if extraI < 0 {
return fmt.Errorf("int64 positive overflow")
}
case cbg.MajNegativeInt:
extraI = int64(extra)
if extraI < 0 {
return fmt.Errorf("int64 negative oveflow")
}
extraI = -1 - extraI
default:
return fmt.Errorf("wrong type for int64 field: %d", maj)
}
t.Stuff = int64(extraI)
}
// t.CatName (string) (string)
{
sval, err := cbg.ReadStringBuf(br, scratch)
if err != nil {
return err
}
t.CatName = string(sval)
}
return nil
}
......@@ -6,6 +6,7 @@ import (
"reflect"
"testing"
"testing/quick"
"time"
"github.com/google/go-cmp/cmp"
cbg "github.com/whyrusleeping/cbor-gen"
......@@ -115,3 +116,34 @@ func TestFixedArrays(t *testing.T) {
recepticle := &FixedArrays{}
testValueRoundtrip(t, zero, recepticle)
}
func TestTimeIsh(t *testing.T) {
val := &ThingWithSomeTime{
When: cbg.CborTime(time.Now()),
Stuff: 1234,
CatName: "hank",
}
buf := new(bytes.Buffer)
if err := val.MarshalCBOR(buf); err != nil {
t.Fatal(err)
}
out := ThingWithSomeTime{}
if err := out.UnmarshalCBOR(buf); err != nil {
t.Fatal(err)
}
if out.When.Time().UnixNano() != val.When.Time().UnixNano() {
t.Fatal("time didnt round trip properly", out.When.Time(), val.When.Time())
}
if out.Stuff != val.Stuff {
t.Fatal("no")
}
if out.CatName != val.CatName {
t.Fatal("no")
}
}
......@@ -55,6 +55,12 @@ type FixedArrays struct {
Uint64 [20]uint64
}
type ThingWithSomeTime struct {
When cbg.CborTime
Stuff int64
CatName string
}
// Do not add fields to this type.
type NeedScratchForMap struct {
Thing bool
......
......@@ -9,6 +9,7 @@ import (
"io"
"io/ioutil"
"math"
"time"
cid "github.com/ipfs/go-cid"
)
......@@ -691,3 +692,50 @@ func (ci *CborInt) UnmarshalCBOR(r io.Reader) error {
*ci = CborInt(extraI)
return nil
}
type CborTime time.Time
func (ct *CborTime) MarshalCBOR(w io.Writer) error {
b, err := (*time.Time)(ct).MarshalBinary()
if err != nil {
return err
}
if err := CborWriteHeader(w, MajByteString, uint64(len(b))); err != nil {
return err
}
if _, err := w.Write(b); err != nil {
return err
}
return nil
}
func (ct *CborTime) UnmarshalCBOR(r io.Reader) error {
t, l, err := CborReadHeader(r)
if err != nil {
return err
}
if t != MajByteString {
return fmt.Errorf("CborTime expects to find a byte array (got %d)", t)
}
buf := make([]byte, l)
if _, err := io.ReadFull(r, buf); err != nil {
return err
}
tm := time.Time{}
if err := tm.UnmarshalBinary(buf); err != nil {
return err
}
*ct = (CborTime)(tm)
return nil
}
func (ct CborTime) Time() time.Time {
return (time.Time)(ct)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment