protogen: automatic handling of imports

The GoIdent type is now a tuple of import path and name. Generated files
have an associated import path. Writing a GoIdent to a generated file
qualifies the name if the identifier is from a different package.
All necessary imports are automatically added to generated Go files.

Change-Id: I839e0b7aa8ec967ce178aea4ffb960b62779cf74
Reviewed-on: https://go-review.googlesource.com/133635
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
This commit is contained in:
Damien Neil 2018-08-23 14:39:30 -07:00
parent 23ddbd1430
commit d901677135
7 changed files with 135 additions and 29 deletions

View File

@ -25,7 +25,7 @@ func main() {
}
func genFile(gen *protogen.Plugin, f *protogen.File) {
g := gen.NewGeneratedFile(strings.TrimSuffix(f.Desc.GetName(), ".proto") + ".pb.go")
g := gen.NewGeneratedFile(strings.TrimSuffix(f.Desc.GetName(), ".proto")+".pb.go", f.GoImportPath)
g.P("// Code generated by protoc-gen-go. DO NOT EDIT.")
g.P("// source: ", f.Desc.GetName())
g.P()

1
go.mod
View File

@ -5,4 +5,5 @@ require (
github.com/google/go-cmp v0.2.0
golang.org/x/net v0.0.0-20180821023952-922f4815f713 // indirect
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f // indirect
golang.org/x/tools v0.0.0-20180904205237-0aa4b8830f48
)

2
go.sum
View File

@ -6,3 +6,5 @@ golang.org/x/net v0.0.0-20180821023952-922f4815f713 h1:rMJUcaDGbG+X967I4zGKCq5la
golang.org/x/net v0.0.0-20180821023952-922f4815f713/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/tools v0.0.0-20180904205237-0aa4b8830f48 h1:PIz+xUHW4G/jqfFWeKhQ96ZV/t2HDsXfWj923rV0bZY=
golang.org/x/tools v0.0.0-20180904205237-0aa4b8830f48/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@ -1,6 +1,7 @@
package protogen
import (
"fmt"
"go/token"
"strconv"
"strings"
@ -8,8 +9,13 @@ import (
"unicode/utf8"
)
// A GoIdent is a Go identifier.
type GoIdent string
// A GoIdent is a Go identifier, consisting of a name and import path.
type GoIdent struct {
GoName string
GoImportPath GoImportPath
}
func (id GoIdent) String() string { return fmt.Sprintf("%q.%v", id.GoImportPath, id.GoName) }
// A GoImportPath is the import path of a Go package. e.g., "google.golang.org/genproto/protobuf".
type GoImportPath string
@ -64,7 +70,7 @@ func baseName(name string) string {
// but it's so remote we're prepared to pretend it's nonexistent - since the
// C++ generator lowercases names, it's extremely unlikely to have two fields
// with different capitalizations.
func camelCase(s string) GoIdent {
func camelCase(s string) string {
if s == "" {
return ""
}
@ -102,7 +108,7 @@ func camelCase(s string) GoIdent {
}
}
}
return GoIdent(t)
return string(t)
}
// Is c an ASCII lower-case letter?

View File

@ -8,8 +8,7 @@ import "testing"
func TestCamelCase(t *testing.T) {
tests := []struct {
in string
want GoIdent
in, want string
}{
{"one", "One"},
{"one_two", "OneTwo"},

View File

@ -20,11 +20,14 @@ import (
"io/ioutil"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"github.com/golang/protobuf/proto"
descpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
pluginpb "github.com/golang/protobuf/protoc-gen-go/plugin"
"golang.org/x/tools/go/ast/astutil"
)
// Run executes a function as a protoc plugin.
@ -168,7 +171,7 @@ func (gen *Plugin) Response() *pluginpb.CodeGeneratorResponse {
}
}
resp.File = append(resp.File, &pluginpb.CodeGeneratorResponse_File{
Name: proto.String(gf.path),
Name: proto.String(gf.filename),
Content: proto.String(string(content)),
})
}
@ -185,7 +188,8 @@ func (gen *Plugin) FileByName(name string) (f *File, ok bool) {
type File struct {
Desc *descpb.FileDescriptorProto // TODO: protoreflect.FileDescriptor
Messages []*Message // top-level message declartions
GoImportPath GoImportPath // import path of this file's Go package
Messages []*Message // top-level message declarations
Generate bool // true if we should generate code for this file
}
@ -193,8 +197,8 @@ func newFile(gen *Plugin, p *descpb.FileDescriptorProto) *File {
f := &File{
Desc: p,
}
for _, d := range p.MessageType {
f.Messages = append(f.Messages, newMessage(gen, nil, d))
for i, mdesc := range p.MessageType {
f.Messages = append(f.Messages, newMessage(gen, f, nil, mdesc, i))
}
return f
}
@ -207,30 +211,40 @@ type Message struct {
Messages []*Message // nested message declarations
}
func newMessage(gen *Plugin, parent *Message, p *descpb.DescriptorProto) *Message {
func newMessage(gen *Plugin, f *File, parent *Message, p *descpb.DescriptorProto, index int) *Message {
m := &Message{
Desc: p,
GoIdent: camelCase(p.GetName()),
GoIdent: GoIdent{
GoName: camelCase(p.GetName()),
GoImportPath: f.GoImportPath,
},
}
if parent != nil {
m.GoIdent = parent.GoIdent + "_" + m.GoIdent
m.GoIdent.GoName = parent.GoIdent.GoName + "_" + m.GoIdent.GoName
}
for _, nested := range p.GetNestedType() {
m.Messages = append(m.Messages, newMessage(gen, m, nested))
for i, nested := range p.GetNestedType() {
m.Messages = append(m.Messages, newMessage(gen, f, m, nested, i))
}
return m
}
// A GeneratedFile is a generated file.
type GeneratedFile struct {
path string
filename string
goImportPath GoImportPath
buf bytes.Buffer
packageNames map[GoImportPath]GoPackageName
usedPackageNames map[GoPackageName]bool
}
// NewGeneratedFile creates a new generated file with the given path.
func (gen *Plugin) NewGeneratedFile(path string) *GeneratedFile {
// NewGeneratedFile creates a new generated file with the given filename
// and import path.
func (gen *Plugin) NewGeneratedFile(filename string, goImportPath GoImportPath) *GeneratedFile {
g := &GeneratedFile{
path: path,
filename: filename,
goImportPath: goImportPath,
packageNames: make(map[GoImportPath]GoPackageName),
usedPackageNames: make(map[GoPackageName]bool),
}
gen.genFiles = append(gen.genFiles, g)
return g
@ -243,11 +257,33 @@ func (gen *Plugin) NewGeneratedFile(path string) *GeneratedFile {
// TODO: .meta file annotations.
func (g *GeneratedFile) P(v ...interface{}) {
for _, x := range v {
switch x := x.(type) {
case GoIdent:
if x.GoImportPath != g.goImportPath {
fmt.Fprint(&g.buf, g.goPackageName(x.GoImportPath))
fmt.Fprint(&g.buf, ".")
}
fmt.Fprint(&g.buf, x.GoName)
default:
fmt.Fprint(&g.buf, x)
}
}
fmt.Fprintln(&g.buf)
}
func (g *GeneratedFile) goPackageName(importPath GoImportPath) GoPackageName {
if name, ok := g.packageNames[importPath]; ok {
return name
}
name := cleanPackageName(baseName(string(importPath)))
for i, orig := 1, name; g.usedPackageNames[name]; i++ {
name = orig + GoPackageName(strconv.Itoa(i))
}
g.packageNames[importPath] = name
g.usedPackageNames[name] = true
return name
}
// Write implements io.Writer.
func (g *GeneratedFile) Write(p []byte) (n int, err error) {
return g.buf.Write(p)
@ -255,7 +291,7 @@ func (g *GeneratedFile) Write(p []byte) (n int, err error) {
// Content returns the contents of the generated file.
func (g *GeneratedFile) Content() ([]byte, error) {
if !strings.HasSuffix(g.path, ".go") {
if !strings.HasSuffix(g.filename, ".go") {
return g.buf.Bytes(), nil
}
@ -272,13 +308,24 @@ func (g *GeneratedFile) Content() ([]byte, error) {
for line := 1; s.Scan(); line++ {
fmt.Fprintf(&src, "%5d\t%s\n", line, s.Bytes())
}
return nil, fmt.Errorf("%v: unparsable Go source: %v\n%v", g.path, err, src.String())
return nil, fmt.Errorf("%v: unparsable Go source: %v\n%v", g.filename, err, src.String())
}
// Add imports.
var importPaths []string
for importPath := range g.packageNames {
importPaths = append(importPaths, string(importPath))
}
sort.Strings(importPaths)
for _, importPath := range importPaths {
astutil.AddNamedImport(fset, ast, string(g.packageNames[GoImportPath(importPath)]), importPath)
}
var out bytes.Buffer
if err = (&printer.Config{Mode: printer.TabIndent | printer.UseSpaces, Tabwidth: 8}).Fprint(&out, fset, ast); err != nil {
return nil, fmt.Errorf("%v: can not reformat Go source: %v", g.path, err)
return nil, fmt.Errorf("%v: can not reformat Go source: %v", g.filename, err)
}
// TODO: Patch annotation locations.
// TODO: Annotations.
return out.Bytes(), nil
}

View File

@ -45,6 +45,57 @@ func TestFiles(t *testing.T) {
}
}
func TestImports(t *testing.T) {
gen, err := New(&pluginpb.CodeGeneratorRequest{})
if err != nil {
t.Fatal(err)
}
g := gen.NewGeneratedFile("foo.go", "golang.org/x/foo")
g.P("package foo")
g.P()
for _, importPath := range []GoImportPath{
"golang.org/x/foo",
// Multiple references to the same package.
"golang.org/x/bar",
"golang.org/x/bar",
// Reference to a different package with the same basename.
"golang.org/y/bar",
"golang.org/x/baz",
} {
g.P("var _ = ", GoIdent{GoName: "X", GoImportPath: importPath}, " // ", importPath)
}
want := `package foo
import (
bar "golang.org/x/bar"
bar1 "golang.org/y/bar"
baz "golang.org/x/baz"
)
var _ = X // "golang.org/x/foo"
var _ = bar.X // "golang.org/x/bar"
var _ = bar.X // "golang.org/x/bar"
var _ = bar1.X // "golang.org/y/bar"
var _ = baz.X // "golang.org/x/baz"
`
got, err := g.Content()
if err != nil {
t.Fatalf("g.Content() = %v", err)
}
if want != string(got) {
t.Fatalf(`want:
==========
%v
==========
got:
==========
%v
==========`,
want, string(got))
}
}
// makeRequest returns a CodeGeneratorRequest for the given protoc inputs.
//
// It does this by running protoc with the current binary as the protoc-gen-go
@ -86,7 +137,7 @@ func makeRequest(t *testing.T, args ...string) *pluginpb.CodeGeneratorRequest {
func init() {
if os.Getenv("RUN_AS_PROTOC_PLUGIN") != "" {
Run(func(p *Plugin) error {
g := p.NewGeneratedFile("request")
g := p.NewGeneratedFile("request", "")
return proto.MarshalText(g, p.Request)
})
os.Exit(0)