Unverified Commit 958ddffe authored by Whyrusleeping's avatar Whyrusleeping Committed by GitHub

Merge pull request #38 from Stebalien/feat/imports

Fix import handling
parents 4fed7092 f9912a4f
......@@ -3,15 +3,24 @@ package typegen
import (
"fmt"
"io"
"math/big"
"reflect"
"strings"
"text/template"
"github.com/ipfs/go-cid"
)
const MaxLength = 8192
const ByteArrayMaxLen = 2 << 20
var (
cidType = reflect.TypeOf(cid.Cid{})
bigIntType = reflect.TypeOf(big.Int{})
deferredType = reflect.TypeOf(Deferred{})
)
func doTemplate(w io.Writer, info interface{}, templ string) error {
t := template.Must(template.New("").
Funcs(template.FuncMap{
......@@ -28,10 +37,19 @@ func doTemplate(w io.Writer, info interface{}, templ string) error {
return t.Execute(w, info)
}
func PrintHeaderAndUtilityMethods(w io.Writer, pkg string) error {
func PrintHeaderAndUtilityMethods(w io.Writer, pkg string, typeInfos []*GenTypeInfo) error {
var imports []Import
for _, gti := range typeInfos {
imports = append(imports, gti.Imports()...)
}
imports = append(imports, defaultImports...)
imports = dedupImports(imports)
data := struct {
Package string
}{pkg}
Imports []Import
}{pkg, imports}
return doTemplate(w, data, `// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT.
package {{ .Package }}
......@@ -39,8 +57,9 @@ package {{ .Package }}
import (
"fmt"
"io"
cbg "github.com/whyrusleeping/cbor-gen"
xerrors "golang.org/x/xerrors"
{{ range .Imports }}{{ .Name }} "{{ .PkgPath }}"
{{ end }}
)
......@@ -69,13 +88,14 @@ func typeName(pkg string, t reflect.Type) string {
case reflect.Map:
return "map[" + typeName(pkg, t.Key()) + "]" + typeName(pkg, t.Elem())
default:
name := t.String()
if t.PkgPath() == "github.com/whyrusleeping/cbor-gen" {
name = "cbg." + strings.TrimPrefix(name, "typegen.")
} else {
name = strings.TrimPrefix(name, pkg+".")
pkgPath := t.PkgPath()
if pkgPath == "" {
// It's a built-in.
return t.String()
} else if pkgPath == pkg {
return t.Name()
}
return name
return fmt.Sprintf("%s.%s", resolvePkgName(pkgPath, t.String()), t.Name())
}
}
......@@ -100,6 +120,25 @@ type GenTypeInfo struct {
Fields []Field
}
func (gti *GenTypeInfo) Imports() []Import {
var imports []Import
for _, f := range gti.Fields {
switch f.Type.Kind() {
case reflect.Struct:
if !f.Pointer && f.Type != bigIntType {
continue
}
if f.Type == cidType {
continue
}
case reflect.Bool:
continue
}
imports = append(imports, ImportsForType(f.Pkg, f.Type)...)
}
return imports
}
func (gti *GenTypeInfo) NeedsScratch() bool {
for _, f := range gti.Fields {
switch f.Type.Kind() {
......@@ -113,11 +152,7 @@ func (gti *GenTypeInfo) NeedsScratch() bool {
return true
case reflect.Struct:
fname := f.Type.PkgPath() + "." + f.Type.Name()
switch fname {
case "math/big.Int":
return true
case "github.com/ipfs/go-cid.Cid":
if f.Type == bigIntType || f.Type == cidType {
return true
}
// nope
......@@ -132,9 +167,11 @@ func nameIsExported(name string) bool {
return strings.ToUpper(name[0:1]) == name[0:1]
}
func ParseTypeInfo(pkg string, i interface{}) (*GenTypeInfo, error) {
func ParseTypeInfo(i interface{}) (*GenTypeInfo, error) {
t := reflect.TypeOf(i)
pkg := t.PkgPath()
out := GenTypeInfo{
Name: t.Name(),
}
......@@ -208,9 +245,8 @@ func emitCborMarshalStringField(w io.Writer, f Field) error {
`)
}
func emitCborMarshalStructField(w io.Writer, f Field) error {
fname := f.Type.PkgPath() + "." + f.Type.Name()
switch fname {
case "math/big.Int":
switch f.Type {
case bigIntType:
return doTemplate(w, f, `
{
if err := cbg.CborWriteHeader(w, cbg.MajTag, 2); err != nil {
......@@ -230,7 +266,7 @@ func emitCborMarshalStructField(w io.Writer, f Field) error {
}
`)
case "github.com/ipfs/go-cid.Cid":
case cidType:
return doTemplate(w, f, `
{{ if .Pointer }}
if {{ .Name }} == nil {
......@@ -403,9 +439,8 @@ func emitCborMarshalSliceField(w io.Writer, f Field) error {
default:
return fmt.Errorf("do not yet support slices of %s yet", e.Kind())
case reflect.Struct:
fname := e.PkgPath() + "." + e.Name()
switch fname {
case "github.com/ipfs/go-cid.Cid":
switch e {
case cidType:
err := doTemplate(w, f, `
if err := cbg.WriteCidBuf(scratch, w, v); err != nil {
return xerrors.Errorf("failed writing cid field {{ .Name }}: %w", err)
......@@ -548,10 +583,8 @@ func emitCborUnmarshalStringField(w io.Writer, f Field) error {
}
func emitCborUnmarshalStructField(w io.Writer, f Field) error {
fname := f.Type.PkgPath() + "." + f.Type.Name()
switch fname {
case "math/big.Int":
switch f.Type {
case bigIntType:
return doTemplate(w, f, `
maj, extra, err = {{ ReadHeader "br" }}
if err != nil {
......@@ -585,7 +618,7 @@ func emitCborUnmarshalStructField(w io.Writer, f Field) error {
{{ .Name }} = big.NewInt(0)
}
`)
case "github.com/ipfs/go-cid.Cid":
case cidType:
return doTemplate(w, f, `
{
{{ if .Pointer }}
......@@ -610,7 +643,7 @@ func emitCborUnmarshalStructField(w io.Writer, f Field) error {
{{ end }}
}
`)
case "github.com/whyrusleeping/cbor-gen.Deferred":
case deferredType:
return doTemplate(w, f, `
{
{{ if .Pointer }}
......@@ -1059,12 +1092,7 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error {
}
// Generates 'tuple representation' cbor encoders for the given type
func GenTupleEncodersForType(inpkg string, i interface{}, w io.Writer) error {
gti, err := ParseTypeInfo(inpkg, i)
if err != nil {
return err
}
func GenTupleEncodersForType(gti *GenTypeInfo, w io.Writer) error {
if err := emitCborMarshalStructTuple(w, gti); err != nil {
return err
}
......@@ -1251,12 +1279,7 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error {
}
// Generates 'tuple representation' cbor encoders for the given type
func GenMapEncodersForType(inpkg string, i interface{}, w io.Writer) error {
gti, err := ParseTypeInfo(inpkg, i)
if err != nil {
return err
}
func GenMapEncodersForType(gti *GenTypeInfo, w io.Writer) error {
if err := emitCborMarshalStructMap(w, gti); err != nil {
return err
}
......
package typegen
import (
"fmt"
"reflect"
"sort"
"strings"
"sync"
)
var (
knownPackageNamesMu sync.Mutex
pkgNameToPkgPath = make(map[string]string)
pkgPathToPkgName = make(map[string]string)
defaultImports = []Import{
{Name: "cbg", PkgPath: "github.com/whyrusleeping/cbor-gen"},
{Name: "xerrors", PkgPath: "golang.org/x/xerrors"},
}
)
func init() {
for _, imp := range defaultImports {
if was, conflict := pkgNameToPkgPath[imp.Name]; conflict {
panic(fmt.Sprintf("reused pkg name %s for %s and %s", imp.Name, imp.PkgPath, was))
}
if _, conflict := pkgPathToPkgName[imp.Name]; conflict {
panic(fmt.Sprintf("duplicate default import %s", imp.PkgPath))
}
pkgNameToPkgPath[imp.Name] = imp.PkgPath
pkgPathToPkgName[imp.PkgPath] = imp.Name
}
}
func resolvePkgName(path, typeName string) string {
parts := strings.Split(typeName, ".")
if len(parts) != 2 {
panic(fmt.Sprintf("expected type to have a package name: %s", typeName))
}
defaultName := parts[0]
knownPackageNamesMu.Lock()
defer knownPackageNamesMu.Unlock()
// Check for a known name and use it.
if name, ok := pkgPathToPkgName[path]; ok {
return name
}
// Allocate a name.
for i := 0; ; i++ {
tryName := defaultName
if i > 0 {
tryName = fmt.Sprintf("%s%d", defaultName, i)
}
if _, taken := pkgNameToPkgPath[tryName]; !taken {
pkgNameToPkgPath[tryName] = path
pkgPathToPkgName[path] = tryName
return tryName
}
}
}
type Import struct {
Name, PkgPath string
}
func ImportsForType(currPkg string, t reflect.Type) []Import {
switch t.Kind() {
case reflect.Array, reflect.Slice, reflect.Ptr:
return ImportsForType(currPkg, t.Elem())
case reflect.Map:
return dedupImports(append(ImportsForType(currPkg, t.Key()), ImportsForType(currPkg, t.Elem())...))
default:
path := t.PkgPath()
if path == "" || path == currPkg {
// built-in or in current package.
return nil
}
return []Import{{PkgPath: path, Name: resolvePkgName(path, t.String())}}
}
}
func dedupImports(imps []Import) []Import {
impSet := make(map[string]string, len(imps))
for _, imp := range imps {
impSet[imp.PkgPath] = imp.Name
}
deduped := make([]Import, 0, len(imps))
for pkg, name := range impSet {
deduped = append(deduped, Import{Name: name, PkgPath: pkg})
}
sort.Slice(deduped, func(i, j int) bool {
return deduped[i].PkgPath < deduped[j].PkgPath
})
return deduped
}
......@@ -4,7 +4,6 @@ import (
"bytes"
"go/format"
"os"
"os/exec"
"golang.org/x/xerrors"
)
......@@ -12,12 +11,21 @@ import (
func WriteTupleEncodersToFile(fname, pkg string, types ...interface{}) error {
buf := new(bytes.Buffer)
if err := PrintHeaderAndUtilityMethods(buf, pkg); err != nil {
typeInfos := make([]*GenTypeInfo, len(types))
for i, t := range types {
gti, err := ParseTypeInfo(t)
if err != nil {
return xerrors.Errorf("failed to parse type info: %w", err)
}
typeInfos[i] = gti
}
if err := PrintHeaderAndUtilityMethods(buf, pkg, typeInfos); err != nil {
return xerrors.Errorf("failed to write header: %w", err)
}
for _, t := range types {
if err := GenTupleEncodersForType(pkg, t, buf); err != nil {
for _, t := range typeInfos {
if err := GenTupleEncodersForType(t, buf); err != nil {
return xerrors.Errorf("failed to generate encoders: %w", err)
}
}
......@@ -39,22 +47,27 @@ func WriteTupleEncodersToFile(fname, pkg string, types ...interface{}) error {
}
_ = fi.Close()
if err := exec.Command("goimports", "-w", fname).Run(); err != nil {
return err
}
return nil
}
func WriteMapEncodersToFile(fname, pkg string, types ...interface{}) error {
buf := new(bytes.Buffer)
if err := PrintHeaderAndUtilityMethods(buf, pkg); err != nil {
typeInfos := make([]*GenTypeInfo, len(types))
for i, t := range types {
gti, err := ParseTypeInfo(t)
if err != nil {
return xerrors.Errorf("failed to parse type info: %w", err)
}
typeInfos[i] = gti
}
if err := PrintHeaderAndUtilityMethods(buf, pkg, typeInfos); err != nil {
return xerrors.Errorf("failed to write header: %w", err)
}
for _, t := range types {
if err := GenMapEncodersForType(pkg, t, buf); err != nil {
for _, t := range typeInfos {
if err := GenMapEncodersForType(t, buf); err != nil {
return xerrors.Errorf("failed to generate encoders: %w", err)
}
}
......@@ -76,9 +89,5 @@ func WriteMapEncodersToFile(fname, pkg string, types ...interface{}) error {
}
_ = fi.Close()
if err := exec.Command("goimports", "-w", fname).Run(); err != nil {
return err
}
return nil
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment