Unverified Commit e12d796b authored by Rod Vagg's avatar Rod Vagg

untangle unmarshal from pb

parent 9f788466
...@@ -3,14 +3,16 @@ package dagpb ...@@ -3,14 +3,16 @@ package dagpb
import ( import (
"io" "io"
"github.com/ipfs/go-cid"
ipld "github.com/ipld/go-ipld-prime" ipld "github.com/ipld/go-ipld-prime"
cidlink "github.com/ipld/go-ipld-prime/linking/cid" cidlink "github.com/ipld/go-ipld-prime/linking/cid"
pb "github.com/rvagg/go-dagpb/pb" "github.com/polydawn/refmt/shared"
"golang.org/x/xerrors"
) )
func Unmarshal(na ipld.NodeAssembler, reader io.Reader) error { var ErrIntOverflow = xerrors.Errorf("protobuf: varint overflow")
var curLink ipld.MapAssembler
func Unmarshal(na ipld.NodeAssembler, in io.Reader) error {
ma, err := na.BeginMap(2) ma, err := na.BeginMap(2)
if err != nil { if err != nil {
return err return err
...@@ -23,9 +25,31 @@ func Unmarshal(na ipld.NodeAssembler, reader io.Reader) error { ...@@ -23,9 +25,31 @@ func Unmarshal(na ipld.NodeAssembler, reader io.Reader) error {
return err return err
} }
tokenReceiver := func(tok pb.Token) error { haveData := false
switch tok.Type { reader := shared.NewReader(in)
case pb.TypeData: for {
_, err := reader.Readn1()
if err == io.EOF {
break
}
reader.Unreadn1()
fieldNum, wireType, err := decodeKey(reader)
if err != nil {
return err
}
if wireType != 2 {
return xerrors.Errorf("protobuf: (PBNode) invalid wireType, expected 2, got %d", wireType)
}
if fieldNum == 1 {
if haveData {
return xerrors.Errorf("protobuf: (PBNode) duplicate Data section")
}
var chunk []byte
if chunk, err = decodeBytes(reader); err != nil {
return err
}
if err := links.Finish(); err != nil { if err := links.Finish(); err != nil {
return err return err
} }
...@@ -33,50 +57,178 @@ func Unmarshal(na ipld.NodeAssembler, reader io.Reader) error { ...@@ -33,50 +57,178 @@ func Unmarshal(na ipld.NodeAssembler, reader io.Reader) error {
if err := ma.AssembleKey().AssignString("Data"); err != nil { if err := ma.AssembleKey().AssignString("Data"); err != nil {
return err return err
} }
if err := ma.AssembleValue().AssignBytes(tok.Bytes); err != nil { if err := ma.AssembleValue().AssignBytes(chunk); err != nil {
return err
}
haveData = true
} else if fieldNum == 2 {
if haveData {
return xerrors.Errorf("protobuf: (PBNode) invalid order, found Data before Links content")
}
bytesLen, err := decodeVarint(reader)
if err != nil {
return err
}
curLink, err := links.AssembleValue().BeginMap(3)
if err != nil {
return err
}
if err = unmarshalLink(reader, int(bytesLen), curLink); err != nil {
return err return err
} }
case pb.TypeLink:
case pb.TypeLinkEnd:
if err := curLink.Finish(); err != nil { if err := curLink.Finish(); err != nil {
return err return err
} }
case pb.TypeHash: } else {
curLink, err = links.AssembleValue().BeginMap(3) return xerrors.Errorf("protobuf: (PBNode) invalid fieldNumber, expected 1 or 2, got %d", fieldNum)
}
}
if links != nil {
if err := links.Finish(); err != nil {
return err
}
}
return ma.Finish()
}
func unmarshalLink(reader shared.SlickReader, length int, ma ipld.MapAssembler) error {
haveHash := false
haveName := false
haveTsize := false
startOffset := reader.NumRead()
for {
readBytes := reader.NumRead() - startOffset
if readBytes == length {
break
} else if readBytes > length {
return xerrors.Errorf("protobuf: (PBLink) bad length for link")
}
fieldNum, wireType, err := decodeKey(reader)
if err != nil { if err != nil {
return err return err
} }
if err := curLink.AssembleKey().AssignString("Hash"); err != nil {
if fieldNum == 1 {
if haveHash {
return xerrors.Errorf("protobuf: (PBLink) duplicate Hash section")
}
if haveName {
return xerrors.Errorf("protobuf: (PBLink) invalid order, found Name before Hash")
}
if haveTsize {
return xerrors.Errorf("protobuf: (PBLink) invalid order, found Tsize before Hash")
}
if wireType != 2 {
return xerrors.Errorf("protobuf: (PBLink) wrong wireType (%d) for Hash", wireType)
}
var chunk []byte
if chunk, err = decodeBytes(reader); err != nil {
return err return err
} }
if err := curLink.AssembleValue().AssignLink(cidlink.Link{*tok.Cid}); err != nil { var c cid.Cid
if _, c, err = cid.CidFromBytes(chunk); err != nil {
return xerrors.Errorf("invalid Hash field found in link, expected CID (%v)", err)
}
if err := ma.AssembleKey().AssignString("Hash"); err != nil {
return err return err
} }
case pb.TypeName: if err := ma.AssembleValue().AssignLink(cidlink.Link{Cid: c}); err != nil {
if err := curLink.AssembleKey().AssignString("Name"); err != nil {
return err return err
} }
if err := curLink.AssembleValue().AssignString(string(tok.Bytes)); err != nil { haveHash = true
} else if fieldNum == 2 {
if haveName {
return xerrors.Errorf("protobuf: (PBLink) duplicate Name section")
}
if haveTsize {
return xerrors.Errorf("protobuf: (PBLink) invalid order, found Tsize before Name")
}
if wireType != 2 {
return xerrors.Errorf("protobuf: (PBLink) wrong wireType (%d) for Name", wireType)
}
var chunk []byte
if chunk, err = decodeBytes(reader); err != nil {
return err return err
} }
case pb.TypeTSize: if err := ma.AssembleKey().AssignString("Name"); err != nil {
if err := curLink.AssembleKey().AssignString("Tsize"); err != nil {
return err return err
} }
if err := curLink.AssembleValue().AssignInt(int(tok.Int)); err != nil { if err := ma.AssembleValue().AssignString(string(chunk)); err != nil {
return err return err
} }
haveName = true
} else if fieldNum == 3 {
if haveTsize {
return xerrors.Errorf("protobuf: (PBLink) duplicate Tsize section")
} }
return nil if wireType != 0 {
return xerrors.Errorf("protobuf: (PBLink) wrong wireType (%d) for Tsize", wireType)
} }
if err := pb.Unmarshal(reader, tokenReceiver); err != nil { var v uint64
if v, err = decodeVarint(reader); err != nil {
return err return err
} }
if links != nil { if err := ma.AssembleKey().AssignString("Tsize"); err != nil {
if err := links.Finish(); err != nil { return err
}
if err := ma.AssembleValue().AssignInt(int(v)); err != nil {
return err return err
} }
haveTsize = true
} else {
return xerrors.Errorf("protobuf: (PBLink) invalid fieldNumber, expected 1, 2 or 3, got %d", fieldNum)
} }
return ma.Finish() }
if !haveHash {
return xerrors.Errorf("invalid Hash field found in link, expected CID")
}
return nil
}
func decodeKey(reader shared.SlickReader) (int, int, error) {
var wire uint64
var err error
if wire, err = decodeVarint(reader); err != nil {
return 0, 0, err
}
fieldNum := int(wire >> 3)
wireType := int(wire & 0x7)
return fieldNum, wireType, nil
}
func decodeBytes(reader shared.SlickReader) ([]byte, error) {
bytesLen, err := decodeVarint(reader)
if err != nil {
return nil, err
}
byts, err := reader.Readn(int(bytesLen))
if err != nil {
return nil, xerrors.Errorf("protobuf: unexpected read error: %w", err)
}
return byts, nil
}
func decodeVarint(reader shared.SlickReader) (uint64, error) {
var v uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflow
}
b, err := reader.Readn1()
if err != nil {
return 0, xerrors.Errorf("protobuf: unexpected read error: %w", err)
}
v |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
return v, nil
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment