Commit f9912a4f authored by Steven Allen's avatar Steven Allen

Fix import handling

Unfortunately, while goimports works in many cases, it won't always work. This
change handles imports internally.
parent fdf2ae9b
...@@ -3,15 +3,24 @@ package typegen ...@@ -3,15 +3,24 @@ package typegen
import ( import (
"fmt" "fmt"
"io" "io"
"math/big"
"reflect" "reflect"
"strings" "strings"
"text/template" "text/template"
"github.com/ipfs/go-cid"
) )
const MaxLength = 8192 const MaxLength = 8192
const ByteArrayMaxLen = 2 << 20 const ByteArrayMaxLen = 2 << 20
var (
cidType = reflect.TypeOf(cid.Cid{})
bigIntType = reflect.TypeOf(big.Int{})
deferredType = reflect.TypeOf(Deferred{})
)
func doTemplate(w io.Writer, info interface{}, templ string) error { func doTemplate(w io.Writer, info interface{}, templ string) error {
t := template.Must(template.New(""). t := template.Must(template.New("").
Funcs(template.FuncMap{ Funcs(template.FuncMap{
...@@ -28,10 +37,19 @@ func doTemplate(w io.Writer, info interface{}, templ string) error { ...@@ -28,10 +37,19 @@ func doTemplate(w io.Writer, info interface{}, templ string) error {
return t.Execute(w, info) return t.Execute(w, info)
} }
func PrintHeaderAndUtilityMethods(w io.Writer, pkg string) error { func PrintHeaderAndUtilityMethods(w io.Writer, pkg string, typeInfos []*GenTypeInfo) error {
var imports []Import
for _, gti := range typeInfos {
imports = append(imports, gti.Imports()...)
}
imports = append(imports, defaultImports...)
imports = dedupImports(imports)
data := struct { data := struct {
Package string Package string
}{pkg} Imports []Import
}{pkg, imports}
return doTemplate(w, data, `// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. return doTemplate(w, data, `// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT.
package {{ .Package }} package {{ .Package }}
...@@ -39,8 +57,9 @@ package {{ .Package }} ...@@ -39,8 +57,9 @@ package {{ .Package }}
import ( import (
"fmt" "fmt"
"io" "io"
cbg "github.com/whyrusleeping/cbor-gen"
xerrors "golang.org/x/xerrors" {{ range .Imports }}{{ .Name }} "{{ .PkgPath }}"
{{ end }}
) )
...@@ -69,13 +88,14 @@ func typeName(pkg string, t reflect.Type) string { ...@@ -69,13 +88,14 @@ func typeName(pkg string, t reflect.Type) string {
case reflect.Map: case reflect.Map:
return "map[" + typeName(pkg, t.Key()) + "]" + typeName(pkg, t.Elem()) return "map[" + typeName(pkg, t.Key()) + "]" + typeName(pkg, t.Elem())
default: default:
name := t.String() pkgPath := t.PkgPath()
if t.PkgPath() == "github.com/whyrusleeping/cbor-gen" { if pkgPath == "" {
name = "cbg." + strings.TrimPrefix(name, "typegen.") // It's a built-in.
} else { return t.String()
name = strings.TrimPrefix(name, pkg+".") } else if pkgPath == pkg {
return t.Name()
} }
return name return fmt.Sprintf("%s.%s", resolvePkgName(pkgPath, t.String()), t.Name())
} }
} }
...@@ -100,6 +120,25 @@ type GenTypeInfo struct { ...@@ -100,6 +120,25 @@ type GenTypeInfo struct {
Fields []Field Fields []Field
} }
func (gti *GenTypeInfo) Imports() []Import {
var imports []Import
for _, f := range gti.Fields {
switch f.Type.Kind() {
case reflect.Struct:
if !f.Pointer && f.Type != bigIntType {
continue
}
if f.Type == cidType {
continue
}
case reflect.Bool:
continue
}
imports = append(imports, ImportsForType(f.Pkg, f.Type)...)
}
return imports
}
func (gti *GenTypeInfo) NeedsScratch() bool { func (gti *GenTypeInfo) NeedsScratch() bool {
for _, f := range gti.Fields { for _, f := range gti.Fields {
switch f.Type.Kind() { switch f.Type.Kind() {
...@@ -113,11 +152,7 @@ func (gti *GenTypeInfo) NeedsScratch() bool { ...@@ -113,11 +152,7 @@ func (gti *GenTypeInfo) NeedsScratch() bool {
return true return true
case reflect.Struct: case reflect.Struct:
fname := f.Type.PkgPath() + "." + f.Type.Name() if f.Type == bigIntType || f.Type == cidType {
switch fname {
case "math/big.Int":
return true
case "github.com/ipfs/go-cid.Cid":
return true return true
} }
// nope // nope
...@@ -132,9 +167,11 @@ func nameIsExported(name string) bool { ...@@ -132,9 +167,11 @@ func nameIsExported(name string) bool {
return strings.ToUpper(name[0:1]) == name[0:1] return strings.ToUpper(name[0:1]) == name[0:1]
} }
func ParseTypeInfo(pkg string, i interface{}) (*GenTypeInfo, error) { func ParseTypeInfo(i interface{}) (*GenTypeInfo, error) {
t := reflect.TypeOf(i) t := reflect.TypeOf(i)
pkg := t.PkgPath()
out := GenTypeInfo{ out := GenTypeInfo{
Name: t.Name(), Name: t.Name(),
} }
...@@ -208,9 +245,8 @@ func emitCborMarshalStringField(w io.Writer, f Field) error { ...@@ -208,9 +245,8 @@ func emitCborMarshalStringField(w io.Writer, f Field) error {
`) `)
} }
func emitCborMarshalStructField(w io.Writer, f Field) error { func emitCborMarshalStructField(w io.Writer, f Field) error {
fname := f.Type.PkgPath() + "." + f.Type.Name() switch f.Type {
switch fname { case bigIntType:
case "math/big.Int":
return doTemplate(w, f, ` return doTemplate(w, f, `
{ {
if err := cbg.CborWriteHeader(w, cbg.MajTag, 2); err != nil { if err := cbg.CborWriteHeader(w, cbg.MajTag, 2); err != nil {
...@@ -230,7 +266,7 @@ func emitCborMarshalStructField(w io.Writer, f Field) error { ...@@ -230,7 +266,7 @@ func emitCborMarshalStructField(w io.Writer, f Field) error {
} }
`) `)
case "github.com/ipfs/go-cid.Cid": case cidType:
return doTemplate(w, f, ` return doTemplate(w, f, `
{{ if .Pointer }} {{ if .Pointer }}
if {{ .Name }} == nil { if {{ .Name }} == nil {
...@@ -403,9 +439,8 @@ func emitCborMarshalSliceField(w io.Writer, f Field) error { ...@@ -403,9 +439,8 @@ func emitCborMarshalSliceField(w io.Writer, f Field) error {
default: default:
return fmt.Errorf("do not yet support slices of %s yet", e.Kind()) return fmt.Errorf("do not yet support slices of %s yet", e.Kind())
case reflect.Struct: case reflect.Struct:
fname := e.PkgPath() + "." + e.Name() switch e {
switch fname { case cidType:
case "github.com/ipfs/go-cid.Cid":
err := doTemplate(w, f, ` err := doTemplate(w, f, `
if err := cbg.WriteCidBuf(scratch, w, v); err != nil { if err := cbg.WriteCidBuf(scratch, w, v); err != nil {
return xerrors.Errorf("failed writing cid field {{ .Name }}: %w", err) return xerrors.Errorf("failed writing cid field {{ .Name }}: %w", err)
...@@ -548,10 +583,8 @@ func emitCborUnmarshalStringField(w io.Writer, f Field) error { ...@@ -548,10 +583,8 @@ func emitCborUnmarshalStringField(w io.Writer, f Field) error {
} }
func emitCborUnmarshalStructField(w io.Writer, f Field) error { func emitCborUnmarshalStructField(w io.Writer, f Field) error {
fname := f.Type.PkgPath() + "." + f.Type.Name() switch f.Type {
case bigIntType:
switch fname {
case "math/big.Int":
return doTemplate(w, f, ` return doTemplate(w, f, `
maj, extra, err = {{ ReadHeader "br" }} maj, extra, err = {{ ReadHeader "br" }}
if err != nil { if err != nil {
...@@ -585,7 +618,7 @@ func emitCborUnmarshalStructField(w io.Writer, f Field) error { ...@@ -585,7 +618,7 @@ func emitCborUnmarshalStructField(w io.Writer, f Field) error {
{{ .Name }} = big.NewInt(0) {{ .Name }} = big.NewInt(0)
} }
`) `)
case "github.com/ipfs/go-cid.Cid": case cidType:
return doTemplate(w, f, ` return doTemplate(w, f, `
{ {
{{ if .Pointer }} {{ if .Pointer }}
...@@ -610,7 +643,7 @@ func emitCborUnmarshalStructField(w io.Writer, f Field) error { ...@@ -610,7 +643,7 @@ func emitCborUnmarshalStructField(w io.Writer, f Field) error {
{{ end }} {{ end }}
} }
`) `)
case "github.com/whyrusleeping/cbor-gen.Deferred": case deferredType:
return doTemplate(w, f, ` return doTemplate(w, f, `
{ {
{{ if .Pointer }} {{ if .Pointer }}
...@@ -1059,12 +1092,7 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error { ...@@ -1059,12 +1092,7 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error {
} }
// Generates 'tuple representation' cbor encoders for the given type // Generates 'tuple representation' cbor encoders for the given type
func GenTupleEncodersForType(inpkg string, i interface{}, w io.Writer) error { func GenTupleEncodersForType(gti *GenTypeInfo, w io.Writer) error {
gti, err := ParseTypeInfo(inpkg, i)
if err != nil {
return err
}
if err := emitCborMarshalStructTuple(w, gti); err != nil { if err := emitCborMarshalStructTuple(w, gti); err != nil {
return err return err
} }
...@@ -1251,12 +1279,7 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error { ...@@ -1251,12 +1279,7 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error {
} }
// Generates 'tuple representation' cbor encoders for the given type // Generates 'tuple representation' cbor encoders for the given type
func GenMapEncodersForType(inpkg string, i interface{}, w io.Writer) error { func GenMapEncodersForType(gti *GenTypeInfo, w io.Writer) error {
gti, err := ParseTypeInfo(inpkg, i)
if err != nil {
return err
}
if err := emitCborMarshalStructMap(w, gti); err != nil { if err := emitCborMarshalStructMap(w, gti); err != nil {
return err return err
} }
......
package typegen
import (
"fmt"
"reflect"
"sort"
"strings"
"sync"
)
var (
knownPackageNamesMu sync.Mutex
pkgNameToPkgPath = make(map[string]string)
pkgPathToPkgName = make(map[string]string)
defaultImports = []Import{
{Name: "cbg", PkgPath: "github.com/whyrusleeping/cbor-gen"},
{Name: "xerrors", PkgPath: "golang.org/x/xerrors"},
}
)
func init() {
for _, imp := range defaultImports {
if was, conflict := pkgNameToPkgPath[imp.Name]; conflict {
panic(fmt.Sprintf("reused pkg name %s for %s and %s", imp.Name, imp.PkgPath, was))
}
if _, conflict := pkgPathToPkgName[imp.Name]; conflict {
panic(fmt.Sprintf("duplicate default import %s", imp.PkgPath))
}
pkgNameToPkgPath[imp.Name] = imp.PkgPath
pkgPathToPkgName[imp.PkgPath] = imp.Name
}
}
func resolvePkgName(path, typeName string) string {
parts := strings.Split(typeName, ".")
if len(parts) != 2 {
panic(fmt.Sprintf("expected type to have a package name: %s", typeName))
}
defaultName := parts[0]
knownPackageNamesMu.Lock()
defer knownPackageNamesMu.Unlock()
// Check for a known name and use it.
if name, ok := pkgPathToPkgName[path]; ok {
return name
}
// Allocate a name.
for i := 0; ; i++ {
tryName := defaultName
if i > 0 {
tryName = fmt.Sprintf("%s%d", defaultName, i)
}
if _, taken := pkgNameToPkgPath[tryName]; !taken {
pkgNameToPkgPath[tryName] = path
pkgPathToPkgName[path] = tryName
return tryName
}
}
}
type Import struct {
Name, PkgPath string
}
func ImportsForType(currPkg string, t reflect.Type) []Import {
switch t.Kind() {
case reflect.Array, reflect.Slice, reflect.Ptr:
return ImportsForType(currPkg, t.Elem())
case reflect.Map:
return dedupImports(append(ImportsForType(currPkg, t.Key()), ImportsForType(currPkg, t.Elem())...))
default:
path := t.PkgPath()
if path == "" || path == currPkg {
// built-in or in current package.
return nil
}
return []Import{{PkgPath: path, Name: resolvePkgName(path, t.String())}}
}
}
func dedupImports(imps []Import) []Import {
impSet := make(map[string]string, len(imps))
for _, imp := range imps {
impSet[imp.PkgPath] = imp.Name
}
deduped := make([]Import, 0, len(imps))
for pkg, name := range impSet {
deduped = append(deduped, Import{Name: name, PkgPath: pkg})
}
sort.Slice(deduped, func(i, j int) bool {
return deduped[i].PkgPath < deduped[j].PkgPath
})
return deduped
}
...@@ -4,7 +4,6 @@ import ( ...@@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"go/format" "go/format"
"os" "os"
"os/exec"
"golang.org/x/xerrors" "golang.org/x/xerrors"
) )
...@@ -12,12 +11,21 @@ import ( ...@@ -12,12 +11,21 @@ import (
func WriteTupleEncodersToFile(fname, pkg string, types ...interface{}) error { func WriteTupleEncodersToFile(fname, pkg string, types ...interface{}) error {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
if err := PrintHeaderAndUtilityMethods(buf, pkg); err != nil { typeInfos := make([]*GenTypeInfo, len(types))
for i, t := range types {
gti, err := ParseTypeInfo(t)
if err != nil {
return xerrors.Errorf("failed to parse type info: %w", err)
}
typeInfos[i] = gti
}
if err := PrintHeaderAndUtilityMethods(buf, pkg, typeInfos); err != nil {
return xerrors.Errorf("failed to write header: %w", err) return xerrors.Errorf("failed to write header: %w", err)
} }
for _, t := range types { for _, t := range typeInfos {
if err := GenTupleEncodersForType(pkg, t, buf); err != nil { if err := GenTupleEncodersForType(t, buf); err != nil {
return xerrors.Errorf("failed to generate encoders: %w", err) return xerrors.Errorf("failed to generate encoders: %w", err)
} }
} }
...@@ -39,22 +47,27 @@ func WriteTupleEncodersToFile(fname, pkg string, types ...interface{}) error { ...@@ -39,22 +47,27 @@ func WriteTupleEncodersToFile(fname, pkg string, types ...interface{}) error {
} }
_ = fi.Close() _ = fi.Close()
if err := exec.Command("goimports", "-w", fname).Run(); err != nil {
return err
}
return nil return nil
} }
func WriteMapEncodersToFile(fname, pkg string, types ...interface{}) error { func WriteMapEncodersToFile(fname, pkg string, types ...interface{}) error {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
if err := PrintHeaderAndUtilityMethods(buf, pkg); err != nil { typeInfos := make([]*GenTypeInfo, len(types))
for i, t := range types {
gti, err := ParseTypeInfo(t)
if err != nil {
return xerrors.Errorf("failed to parse type info: %w", err)
}
typeInfos[i] = gti
}
if err := PrintHeaderAndUtilityMethods(buf, pkg, typeInfos); err != nil {
return xerrors.Errorf("failed to write header: %w", err) return xerrors.Errorf("failed to write header: %w", err)
} }
for _, t := range types { for _, t := range typeInfos {
if err := GenMapEncodersForType(pkg, t, buf); err != nil { if err := GenMapEncodersForType(t, buf); err != nil {
return xerrors.Errorf("failed to generate encoders: %w", err) return xerrors.Errorf("failed to generate encoders: %w", err)
} }
} }
...@@ -76,9 +89,5 @@ func WriteMapEncodersToFile(fname, pkg string, types ...interface{}) error { ...@@ -76,9 +89,5 @@ func WriteMapEncodersToFile(fname, pkg string, types ...interface{}) error {
} }
_ = fi.Close() _ = fi.Close()
if err := exec.Command("goimports", "-w", fname).Run(); err != nil {
return err
}
return nil return nil
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment