diff --git a/cmd/protoc-gen-go/main.go b/cmd/protoc-gen-go/main.go index 79738958..e144aa61 100644 --- a/cmd/protoc-gen-go/main.go +++ b/cmd/protoc-gen-go/main.go @@ -25,7 +25,7 @@ func main() { } func genFile(gen *protogen.Plugin, f *protogen.File) { - g := gen.NewGeneratedFile(strings.TrimSuffix(f.Desc.GetName(), ".proto") + ".pb.go") + g := gen.NewGeneratedFile(strings.TrimSuffix(f.Desc.GetName(), ".proto")+".pb.go", f.GoImportPath) g.P("// Code generated by protoc-gen-go. DO NOT EDIT.") g.P("// source: ", f.Desc.GetName()) g.P() diff --git a/go.mod b/go.mod index 6aa64ad9..638151cc 100644 --- a/go.mod +++ b/go.mod @@ -5,4 +5,5 @@ require ( github.com/google/go-cmp v0.2.0 golang.org/x/net v0.0.0-20180821023952-922f4815f713 // indirect golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f // indirect + golang.org/x/tools v0.0.0-20180904205237-0aa4b8830f48 ) diff --git a/go.sum b/go.sum index 1bfafa6e..51e6c99d 100644 --- a/go.sum +++ b/go.sum @@ -6,3 +6,5 @@ golang.org/x/net v0.0.0-20180821023952-922f4815f713 h1:rMJUcaDGbG+X967I4zGKCq5la golang.org/x/net v0.0.0-20180821023952-922f4815f713/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/tools v0.0.0-20180904205237-0aa4b8830f48 h1:PIz+xUHW4G/jqfFWeKhQ96ZV/t2HDsXfWj923rV0bZY= +golang.org/x/tools v0.0.0-20180904205237-0aa4b8830f48/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/protogen/names.go b/protogen/names.go index b97c47d5..ea3d0577 100644 --- a/protogen/names.go +++ b/protogen/names.go @@ -1,6 +1,7 @@ package protogen import ( + "fmt" "go/token" "strconv" "strings" @@ -8,8 +9,13 @@ import ( "unicode/utf8" ) -// A GoIdent is a Go identifier. -type GoIdent string +// A GoIdent is a Go identifier, consisting of a name and import path. +type GoIdent struct { + GoName string + GoImportPath GoImportPath +} + +func (id GoIdent) String() string { return fmt.Sprintf("%q.%v", id.GoImportPath, id.GoName) } // A GoImportPath is the import path of a Go package. e.g., "google.golang.org/genproto/protobuf". type GoImportPath string @@ -64,7 +70,7 @@ func baseName(name string) string { // but it's so remote we're prepared to pretend it's nonexistent - since the // C++ generator lowercases names, it's extremely unlikely to have two fields // with different capitalizations. -func camelCase(s string) GoIdent { +func camelCase(s string) string { if s == "" { return "" } @@ -102,7 +108,7 @@ func camelCase(s string) GoIdent { } } } - return GoIdent(t) + return string(t) } // Is c an ASCII lower-case letter? diff --git a/protogen/names_test.go b/protogen/names_test.go index 021e71a1..05e698e1 100644 --- a/protogen/names_test.go +++ b/protogen/names_test.go @@ -8,8 +8,7 @@ import "testing" func TestCamelCase(t *testing.T) { tests := []struct { - in string - want GoIdent + in, want string }{ {"one", "One"}, {"one_two", "OneTwo"}, diff --git a/protogen/protogen.go b/protogen/protogen.go index f499a8d7..2d65ee03 100644 --- a/protogen/protogen.go +++ b/protogen/protogen.go @@ -20,11 +20,14 @@ import ( "io/ioutil" "os" "path/filepath" + "sort" + "strconv" "strings" "github.com/golang/protobuf/proto" descpb "github.com/golang/protobuf/protoc-gen-go/descriptor" pluginpb "github.com/golang/protobuf/protoc-gen-go/plugin" + "golang.org/x/tools/go/ast/astutil" ) // Run executes a function as a protoc plugin. @@ -168,7 +171,7 @@ func (gen *Plugin) Response() *pluginpb.CodeGeneratorResponse { } } resp.File = append(resp.File, &pluginpb.CodeGeneratorResponse_File{ - Name: proto.String(gf.path), + Name: proto.String(gf.filename), Content: proto.String(string(content)), }) } @@ -185,16 +188,17 @@ func (gen *Plugin) FileByName(name string) (f *File, ok bool) { type File struct { Desc *descpb.FileDescriptorProto // TODO: protoreflect.FileDescriptor - Messages []*Message // top-level message declartions - Generate bool // true if we should generate code for this file + GoImportPath GoImportPath // import path of this file's Go package + Messages []*Message // top-level message declarations + Generate bool // true if we should generate code for this file } func newFile(gen *Plugin, p *descpb.FileDescriptorProto) *File { f := &File{ Desc: p, } - for _, d := range p.MessageType { - f.Messages = append(f.Messages, newMessage(gen, nil, d)) + for i, mdesc := range p.MessageType { + f.Messages = append(f.Messages, newMessage(gen, f, nil, mdesc, i)) } return f } @@ -207,30 +211,40 @@ type Message struct { Messages []*Message // nested message declarations } -func newMessage(gen *Plugin, parent *Message, p *descpb.DescriptorProto) *Message { +func newMessage(gen *Plugin, f *File, parent *Message, p *descpb.DescriptorProto, index int) *Message { m := &Message{ - Desc: p, - GoIdent: camelCase(p.GetName()), + Desc: p, + GoIdent: GoIdent{ + GoName: camelCase(p.GetName()), + GoImportPath: f.GoImportPath, + }, } if parent != nil { - m.GoIdent = parent.GoIdent + "_" + m.GoIdent + m.GoIdent.GoName = parent.GoIdent.GoName + "_" + m.GoIdent.GoName } - for _, nested := range p.GetNestedType() { - m.Messages = append(m.Messages, newMessage(gen, m, nested)) + for i, nested := range p.GetNestedType() { + m.Messages = append(m.Messages, newMessage(gen, f, m, nested, i)) } return m } // A GeneratedFile is a generated file. type GeneratedFile struct { - path string - buf bytes.Buffer + filename string + goImportPath GoImportPath + buf bytes.Buffer + packageNames map[GoImportPath]GoPackageName + usedPackageNames map[GoPackageName]bool } -// NewGeneratedFile creates a new generated file with the given path. -func (gen *Plugin) NewGeneratedFile(path string) *GeneratedFile { +// NewGeneratedFile creates a new generated file with the given filename +// and import path. +func (gen *Plugin) NewGeneratedFile(filename string, goImportPath GoImportPath) *GeneratedFile { g := &GeneratedFile{ - path: path, + filename: filename, + goImportPath: goImportPath, + packageNames: make(map[GoImportPath]GoPackageName), + usedPackageNames: make(map[GoPackageName]bool), } gen.genFiles = append(gen.genFiles, g) return g @@ -243,11 +257,33 @@ func (gen *Plugin) NewGeneratedFile(path string) *GeneratedFile { // TODO: .meta file annotations. func (g *GeneratedFile) P(v ...interface{}) { for _, x := range v { - fmt.Fprint(&g.buf, x) + switch x := x.(type) { + case GoIdent: + if x.GoImportPath != g.goImportPath { + fmt.Fprint(&g.buf, g.goPackageName(x.GoImportPath)) + fmt.Fprint(&g.buf, ".") + } + fmt.Fprint(&g.buf, x.GoName) + default: + fmt.Fprint(&g.buf, x) + } } fmt.Fprintln(&g.buf) } +func (g *GeneratedFile) goPackageName(importPath GoImportPath) GoPackageName { + if name, ok := g.packageNames[importPath]; ok { + return name + } + name := cleanPackageName(baseName(string(importPath))) + for i, orig := 1, name; g.usedPackageNames[name]; i++ { + name = orig + GoPackageName(strconv.Itoa(i)) + } + g.packageNames[importPath] = name + g.usedPackageNames[name] = true + return name +} + // Write implements io.Writer. func (g *GeneratedFile) Write(p []byte) (n int, err error) { return g.buf.Write(p) @@ -255,7 +291,7 @@ func (g *GeneratedFile) Write(p []byte) (n int, err error) { // Content returns the contents of the generated file. func (g *GeneratedFile) Content() ([]byte, error) { - if !strings.HasSuffix(g.path, ".go") { + if !strings.HasSuffix(g.filename, ".go") { return g.buf.Bytes(), nil } @@ -272,13 +308,24 @@ func (g *GeneratedFile) Content() ([]byte, error) { for line := 1; s.Scan(); line++ { fmt.Fprintf(&src, "%5d\t%s\n", line, s.Bytes()) } - return nil, fmt.Errorf("%v: unparsable Go source: %v\n%v", g.path, err, src.String()) + return nil, fmt.Errorf("%v: unparsable Go source: %v\n%v", g.filename, err, src.String()) } + + // Add imports. + var importPaths []string + for importPath := range g.packageNames { + importPaths = append(importPaths, string(importPath)) + } + sort.Strings(importPaths) + for _, importPath := range importPaths { + astutil.AddNamedImport(fset, ast, string(g.packageNames[GoImportPath(importPath)]), importPath) + } + var out bytes.Buffer if err = (&printer.Config{Mode: printer.TabIndent | printer.UseSpaces, Tabwidth: 8}).Fprint(&out, fset, ast); err != nil { - return nil, fmt.Errorf("%v: can not reformat Go source: %v", g.path, err) + return nil, fmt.Errorf("%v: can not reformat Go source: %v", g.filename, err) } - // TODO: Patch annotation locations. + // TODO: Annotations. return out.Bytes(), nil } diff --git a/protogen/protogen_test.go b/protogen/protogen_test.go index 1d23cc07..4da5b3ce 100644 --- a/protogen/protogen_test.go +++ b/protogen/protogen_test.go @@ -45,6 +45,57 @@ func TestFiles(t *testing.T) { } } +func TestImports(t *testing.T) { + gen, err := New(&pluginpb.CodeGeneratorRequest{}) + if err != nil { + t.Fatal(err) + } + g := gen.NewGeneratedFile("foo.go", "golang.org/x/foo") + g.P("package foo") + g.P() + for _, importPath := range []GoImportPath{ + "golang.org/x/foo", + // Multiple references to the same package. + "golang.org/x/bar", + "golang.org/x/bar", + // Reference to a different package with the same basename. + "golang.org/y/bar", + "golang.org/x/baz", + } { + g.P("var _ = ", GoIdent{GoName: "X", GoImportPath: importPath}, " // ", importPath) + } + want := `package foo + +import ( + bar "golang.org/x/bar" + bar1 "golang.org/y/bar" + baz "golang.org/x/baz" +) + +var _ = X // "golang.org/x/foo" +var _ = bar.X // "golang.org/x/bar" +var _ = bar.X // "golang.org/x/bar" +var _ = bar1.X // "golang.org/y/bar" +var _ = baz.X // "golang.org/x/baz" +` + got, err := g.Content() + if err != nil { + t.Fatalf("g.Content() = %v", err) + } + if want != string(got) { + t.Fatalf(`want: +========== +%v +========== + +got: +========== +%v +==========`, + want, string(got)) + } +} + // makeRequest returns a CodeGeneratorRequest for the given protoc inputs. // // It does this by running protoc with the current binary as the protoc-gen-go @@ -86,7 +137,7 @@ func makeRequest(t *testing.T, args ...string) *pluginpb.CodeGeneratorRequest { func init() { if os.Getenv("RUN_AS_PROTOC_PLUGIN") != "" { Run(func(p *Plugin) error { - g := p.NewGeneratedFile("request") + g := p.NewGeneratedFile("request", "") return proto.MarshalText(g, p.Request) }) os.Exit(0)