From 829e0e1a206763af3e0985d77a0747465e4f3857 Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Thu, 13 Aug 2020 16:53:13 -0700 Subject: [PATCH] Add a validate function. This function is equivalent to Deferred.UnmarshalCBOR, just more efficient because it doesn't copy. --- validate.go | 59 ++++++++++++++++++++++++++++++++++++++++++++++++ validate_test.go | 42 ++++++++++++++++++++++++++++++++++ 2 files changed, 101 insertions(+) create mode 100644 validate.go create mode 100644 validate_test.go diff --git a/validate.go b/validate.go new file mode 100644 index 0000000..04a5e26 --- /dev/null +++ b/validate.go @@ -0,0 +1,59 @@ +package typegen + +import ( + "bytes" + "fmt" + "io" +) + +// ValidateCBOR validates that a byte array is a single valid CBOR object. +func ValidateCBOR(b []byte) error { + // The code here is basically identical to the previous function, it + // just doesn't copy. + + br := bytes.NewReader(b) + + // Allocate some scratch space. + scratch := make([]byte, maxHeaderSize) + + for remaining := uint64(1); remaining > 0; remaining-- { + maj, extra, err := CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + + switch maj { + case MajUnsignedInt, MajNegativeInt, MajOther: + // nothing fancy to do + case MajByteString, MajTextString: + if extra > ByteArrayMaxLen { + return maxLengthError + } + if uint64(br.Len()) < extra { + return io.ErrUnexpectedEOF + } + + if _, err := br.Seek(int64(extra), io.SeekCurrent); err != nil { + return err + } + case MajTag: + remaining++ + case MajArray: + if extra > MaxLength { + return maxLengthError + } + remaining += extra + case MajMap: + if extra > MaxLength { + return maxLengthError + } + remaining += extra * 2 + default: + return fmt.Errorf("unhandled deferred cbor type: %d", maj) + } + } + if br.Len() > 0 { + return fmt.Errorf("unexpected %d unread bytes", br.Len()) + } + return nil +} diff --git a/validate_test.go b/validate_test.go new file mode 100644 index 0000000..6bf0f9e --- /dev/null +++ b/validate_test.go @@ -0,0 +1,42 @@ +package typegen + +import ( + "bytes" + "testing" +) + +func TestValidateShort(t *testing.T) { + var buf bytes.Buffer + if err := WriteMajorTypeHeader(&buf, MajByteString, 100); err != nil { + t.Fatal("failed to write header") + } + + if err := ValidateCBOR(buf.Bytes()); err == nil { + t.Fatal("expected an error checking truncated cbor") + } +} + +func TestValidateDouble(t *testing.T) { + var buf bytes.Buffer + if err := WriteBool(&buf, false); err != nil { + t.Fatal(err) + } + if err := WriteBool(&buf, false); err != nil { + t.Fatal(err) + } + + if err := ValidateCBOR(buf.Bytes()); err == nil { + t.Fatal("expected an error checking cbor with two objects") + } +} + +func TestValidate(t *testing.T) { + var buf bytes.Buffer + if err := WriteBool(&buf, false); err != nil { + t.Fatal(err) + } + + if err := ValidateCBOR(buf.Bytes()); err != nil { + t.Fatal(err) + } +} -- GitLab