Commit f9912a4f authored by Steven Allen's avatar Steven Allen

Fix import handling

Unfortunately, while goimports works in many cases, it won't always work. This
change handles imports internally.
parent fdf2ae9b
......@@ -3,15 +3,24 @@ package typegen
import (
"fmt"
"io"
"math/big"
"reflect"
"strings"
"text/template"
"github.com/ipfs/go-cid"
)
const MaxLength = 8192
const ByteArrayMaxLen = 2 << 20
var (
cidType = reflect.TypeOf(cid.Cid{})
bigIntType = reflect.TypeOf(big.Int{})
deferredType = reflect.TypeOf(Deferred{})
)
func doTemplate(w io.Writer, info interface{}, templ string) error {
t := template.Must(template.New("").
Funcs(template.FuncMap{
......@@ -28,10 +37,19 @@ func doTemplate(w io.Writer, info interface{}, templ string) error {
return t.Execute(w, info)
}
func PrintHeaderAndUtilityMethods(w io.Writer, pkg string) error {
func PrintHeaderAndUtilityMethods(w io.Writer, pkg string, typeInfos []*GenTypeInfo) error {
var imports []Import
for _, gti := range typeInfos {
imports = append(imports, gti.Imports()...)
}
imports = append(imports, defaultImports...)
imports = dedupImports(imports)
data := struct {
Package string
}{pkg}
Imports []Import
}{pkg, imports}
return doTemplate(w, data, `// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT.
package {{ .Package }}
......@@ -39,8 +57,9 @@ package {{ .Package }}
import (
"fmt"
"io"
cbg "github.com/whyrusleeping/cbor-gen"
xerrors "golang.org/x/xerrors"
{{ range .Imports }}{{ .Name }} "{{ .PkgPath }}"
{{ end }}
)
......@@ -69,13 +88,14 @@ func typeName(pkg string, t reflect.Type) string {
case reflect.Map:
return "map[" + typeName(pkg, t.Key()) + "]" + typeName(pkg, t.Elem())
default:
name := t.String()
if t.PkgPath() == "github.com/whyrusleeping/cbor-gen" {
name = "cbg." + strings.TrimPrefix(name, "typegen.")
} else {
name = strings.TrimPrefix(name, pkg+".")
pkgPath := t.PkgPath()
if pkgPath == "" {
// It's a built-in.
return t.String()
} else if pkgPath == pkg {
return t.Name()
}
return name
return fmt.Sprintf("%s.%s", resolvePkgName(pkgPath, t.String()), t.Name())
}
}
......@@ -100,6 +120,25 @@ type GenTypeInfo struct {
Fields []Field
}
func (gti *GenTypeInfo) Imports() []Import {
var imports []Import
for _, f := range gti.Fields {
switch f.Type.Kind() {
case reflect.Struct:
if !f.Pointer && f.Type != bigIntType {
continue
}
if f.Type == cidType {
continue
}
case reflect.Bool:
continue
}
imports = append(imports, ImportsForType(f.Pkg, f.Type)...)
}
return imports
}
func (gti *GenTypeInfo) NeedsScratch() bool {
for _, f := range gti.Fields {
switch f.Type.Kind() {
......@@ -113,11 +152,7 @@ func (gti *GenTypeInfo) NeedsScratch() bool {
return true
case reflect.Struct:
fname := f.Type.PkgPath() + "." + f.Type.Name()
switch fname {
case "math/big.Int":
return true
case "github.com/ipfs/go-cid.Cid":
if f.Type == bigIntType || f.Type == cidType {
return true
}
// nope
......@@ -132,9 +167,11 @@ func nameIsExported(name string) bool {
return strings.ToUpper(name[0:1]) == name[0:1]
}
func ParseTypeInfo(pkg string, i interface{}) (*GenTypeInfo, error) {
func ParseTypeInfo(i interface{}) (*GenTypeInfo, error) {
t := reflect.TypeOf(i)
pkg := t.PkgPath()
out := GenTypeInfo{
Name: t.Name(),
}
......@@ -208,9 +245,8 @@ func emitCborMarshalStringField(w io.Writer, f Field) error {
`)
}
func emitCborMarshalStructField(w io.Writer, f Field) error {
fname := f.Type.PkgPath() + "." + f.Type.Name()
switch fname {
case "math/big.Int":
switch f.Type {
case bigIntType:
return doTemplate(w, f, `
{
if err := cbg.CborWriteHeader(w, cbg.MajTag, 2); err != nil {
......@@ -230,7 +266,7 @@ func emitCborMarshalStructField(w io.Writer, f Field) error {
}
`)
case "github.com/ipfs/go-cid.Cid":
case cidType:
return doTemplate(w, f, `
{{ if .Pointer }}
if {{ .Name }} == nil {
......@@ -403,9 +439,8 @@ func emitCborMarshalSliceField(w io.Writer, f Field) error {
default:
return fmt.Errorf("do not yet support slices of %s yet", e.Kind())
case reflect.Struct:
fname := e.PkgPath() + "." + e.Name()
switch fname {
case "github.com/ipfs/go-cid.Cid":
switch e {
case cidType:
err := doTemplate(w, f, `
if err := cbg.WriteCidBuf(scratch, w, v); err != nil {
return xerrors.Errorf("failed writing cid field {{ .Name }}: %w", err)
......@@ -548,10 +583,8 @@ func emitCborUnmarshalStringField(w io.Writer, f Field) error {
}
func emitCborUnmarshalStructField(w io.Writer, f Field) error {
fname := f.Type.PkgPath() + "." + f.Type.Name()
switch fname {
case "math/big.Int":
switch f.Type {
case bigIntType:
return doTemplate(w, f, `
maj, extra, err = {{ ReadHeader "br" }}
if err != nil {
......@@ -585,7 +618,7 @@ func emitCborUnmarshalStructField(w io.Writer, f Field) error {
{{ .Name }} = big.NewInt(0)
}
`)
case "github.com/ipfs/go-cid.Cid":
case cidType:
return doTemplate(w, f, `
{
{{ if .Pointer }}
......@@ -610,7 +643,7 @@ func emitCborUnmarshalStructField(w io.Writer, f Field) error {
{{ end }}
}
`)
case "github.com/whyrusleeping/cbor-gen.Deferred":
case deferredType:
return doTemplate(w, f, `
{
{{ if .Pointer }}
......@@ -1059,12 +1092,7 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error {
}
// Generates 'tuple representation' cbor encoders for the given type
func GenTupleEncodersForType(inpkg string, i interface{}, w io.Writer) error {
gti, err := ParseTypeInfo(inpkg, i)
if err != nil {
return err
}
func GenTupleEncodersForType(gti *GenTypeInfo, w io.Writer) error {
if err := emitCborMarshalStructTuple(w, gti); err != nil {
return err
}
......@@ -1251,12 +1279,7 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error {
}
// Generates 'tuple representation' cbor encoders for the given type
func GenMapEncodersForType(inpkg string, i interface{}, w io.Writer) error {
gti, err := ParseTypeInfo(inpkg, i)
if err != nil {
return err
}
func GenMapEncodersForType(gti *GenTypeInfo, w io.Writer) error {
if err := emitCborMarshalStructMap(w, gti); err != nil {
return err
}
......
package typegen
import (
"fmt"
"reflect"
"sort"
"strings"
"sync"
)
var (
knownPackageNamesMu sync.Mutex
pkgNameToPkgPath = make(map[string]string)
pkgPathToPkgName = make(map[string]string)
defaultImports = []Import{
{Name: "cbg", PkgPath: "github.com/whyrusleeping/cbor-gen"},
{Name: "xerrors", PkgPath: "golang.org/x/xerrors"},
}
)
func init() {
for _, imp := range defaultImports {
if was, conflict := pkgNameToPkgPath[imp.Name]; conflict {
panic(fmt.Sprintf("reused pkg name %s for %s and %s", imp.Name, imp.PkgPath, was))
}
if _, conflict := pkgPathToPkgName[imp.Name]; conflict {
panic(fmt.Sprintf("duplicate default import %s", imp.PkgPath))
}
pkgNameToPkgPath[imp.Name] = imp.PkgPath
pkgPathToPkgName[imp.PkgPath] = imp.Name
}
}
func resolvePkgName(path, typeName string) string {
parts := strings.Split(typeName, ".")
if len(parts) != 2 {
panic(fmt.Sprintf("expected type to have a package name: %s", typeName))
}
defaultName := parts[0]
knownPackageNamesMu.Lock()
defer knownPackageNamesMu.Unlock()
// Check for a known name and use it.
if name, ok := pkgPathToPkgName[path]; ok {
return name
}
// Allocate a name.
for i := 0; ; i++ {
tryName := defaultName
if i > 0 {
tryName = fmt.Sprintf("%s%d", defaultName, i)
}
if _, taken := pkgNameToPkgPath[tryName]; !taken {
pkgNameToPkgPath[tryName] = path
pkgPathToPkgName[path] = tryName
return tryName
}
}
}
type Import struct {
Name, PkgPath string
}
func ImportsForType(currPkg string, t reflect.Type) []Import {
switch t.Kind() {
case reflect.Array, reflect.Slice, reflect.Ptr:
return ImportsForType(currPkg, t.Elem())
case reflect.Map:
return dedupImports(append(ImportsForType(currPkg, t.Key()), ImportsForType(currPkg, t.Elem())...))
default:
path := t.PkgPath()
if path == "" || path == currPkg {
// built-in or in current package.
return nil
}
return []Import{{PkgPath: path, Name: resolvePkgName(path, t.String())}}
}
}
func dedupImports(imps []Import) []Import {
impSet := make(map[string]string, len(imps))
for _, imp := range imps {
impSet[imp.PkgPath] = imp.Name
}
deduped := make([]Import, 0, len(imps))
for pkg, name := range impSet {
deduped = append(deduped, Import{Name: name, PkgPath: pkg})
}
sort.Slice(deduped, func(i, j int) bool {
return deduped[i].PkgPath < deduped[j].PkgPath
})
return deduped
}
......@@ -4,7 +4,6 @@ import (
"bytes"
"go/format"
"os"
"os/exec"
"golang.org/x/xerrors"
)
......@@ -12,12 +11,21 @@ import (
func WriteTupleEncodersToFile(fname, pkg string, types ...interface{}) error {
buf := new(bytes.Buffer)
if err := PrintHeaderAndUtilityMethods(buf, pkg); err != nil {
typeInfos := make([]*GenTypeInfo, len(types))
for i, t := range types {
gti, err := ParseTypeInfo(t)
if err != nil {
return xerrors.Errorf("failed to parse type info: %w", err)
}
typeInfos[i] = gti
}
if err := PrintHeaderAndUtilityMethods(buf, pkg, typeInfos); err != nil {
return xerrors.Errorf("failed to write header: %w", err)
}
for _, t := range types {
if err := GenTupleEncodersForType(pkg, t, buf); err != nil {
for _, t := range typeInfos {
if err := GenTupleEncodersForType(t, buf); err != nil {
return xerrors.Errorf("failed to generate encoders: %w", err)
}
}
......@@ -39,22 +47,27 @@ func WriteTupleEncodersToFile(fname, pkg string, types ...interface{}) error {
}
_ = fi.Close()
if err := exec.Command("goimports", "-w", fname).Run(); err != nil {
return err
}
return nil
}
func WriteMapEncodersToFile(fname, pkg string, types ...interface{}) error {
buf := new(bytes.Buffer)
if err := PrintHeaderAndUtilityMethods(buf, pkg); err != nil {
typeInfos := make([]*GenTypeInfo, len(types))
for i, t := range types {
gti, err := ParseTypeInfo(t)
if err != nil {
return xerrors.Errorf("failed to parse type info: %w", err)
}
typeInfos[i] = gti
}
if err := PrintHeaderAndUtilityMethods(buf, pkg, typeInfos); err != nil {
return xerrors.Errorf("failed to write header: %w", err)
}
for _, t := range types {
if err := GenMapEncodersForType(pkg, t, buf); err != nil {
for _, t := range typeInfos {
if err := GenMapEncodersForType(t, buf); err != nil {
return xerrors.Errorf("failed to generate encoders: %w", err)
}
}
......@@ -76,9 +89,5 @@ func WriteMapEncodersToFile(fname, pkg string, types ...interface{}) error {
}
_ = fi.Close()
if err := exec.Command("goimports", "-w", fname).Run(); err != nil {
return err
}
return nil
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment