diff --git a/cmd/protoc-gen-go/main.go b/cmd/protoc-gen-go/main.go index d91e7a1c..c243e009 100644 --- a/cmd/protoc-gen-go/main.go +++ b/cmd/protoc-gen-go/main.go @@ -11,6 +11,7 @@ import ( "compress/gzip" "crypto/sha256" "encoding/hex" + "flag" "fmt" "strconv" "strings" @@ -24,7 +25,13 @@ import ( const protoPackage = "github.com/golang/protobuf/proto" func main() { - protogen.Run(func(gen *protogen.Plugin) error { + var flags flag.FlagSet + // TODO: Decide what to do for backwards compatibility with plugins=grpc. + flags.String("plugins", "", "") + opts := &protogen.Options{ + ParamFunc: flags.Set, + } + protogen.Run(opts, func(gen *protogen.Plugin) error { for _, f := range gen.Files { if !f.Generate { continue diff --git a/protogen/protogen.go b/protogen/protogen.go index ada17bad..ee4c7358 100644 --- a/protogen/protogen.go +++ b/protogen/protogen.go @@ -41,14 +41,16 @@ import ( // // If a failure occurs while reading or writing, Run prints an error to // os.Stderr and calls os.Exit(1). -func Run(f func(*Plugin) error) { - if err := run(f); err != nil { +// +// Passing a nil options is equivalent to passing a zero-valued one. +func Run(opts *Options, f func(*Plugin) error) { + if err := run(opts, f); err != nil { fmt.Fprintf(os.Stderr, "%s: %v\n", filepath.Base(os.Args[0]), err) os.Exit(1) } } -func run(f func(*Plugin) error) error { +func run(opts *Options, f func(*Plugin) error) error { in, err := ioutil.ReadAll(os.Stdin) if err != nil { return err @@ -57,7 +59,7 @@ func run(f func(*Plugin) error) error { if err := proto.Unmarshal(in, req); err != nil { return err } - gen, err := New(req) + gen, err := New(req, opts) if err != nil { return err } @@ -98,15 +100,47 @@ type Plugin struct { err error } +// Options are optional parameters to New. +type Options struct { + // If ParamFunc is non-nil, it will be called with each unknown + // generator parameter. + // + // Plugins for protoc can accept parameters from the command line, + // passed in the --<lang>_out protoc, separated from the output + // directory with a colon; e.g., + // + // --go_out=<param1>=<value1>,<param2>=<value2>:<output_directory> + // + // Parameters passed in this fashion as a comma-separated list of + // key=value pairs will be passed to the ParamFunc. + // + // The (flag.FlagSet).Set method matches this function signature, + // so parameters can be converted into flags as in the following: + // + // var flags flag.FlagSet + // value := flags.Bool("param", false, "") + // opts := &protogen.Options{ + // ParamFunc: flags.Set, + // } + // protogen.Run(opts, func(p *protogen.Plugin) error { + // if *value { ... } + // }) + ParamFunc func(name, value string) error +} + // New returns a new Plugin. -func New(req *pluginpb.CodeGeneratorRequest) (*Plugin, error) { +// +// Passing a nil Options is equivalent to passing a zero-valued one. +func New(req *pluginpb.CodeGeneratorRequest, opts *Options) (*Plugin, error) { + if opts == nil { + opts = &Options{} + } gen := &Plugin{ Request: req, filesByName: make(map[string]*File), fileReg: protoregistry.NewFiles(), } - // TODO: Figure out how to pass parameters to the generator. packageNames := make(map[string]GoPackageName) // filename -> package name importPaths := make(map[string]GoImportPath) // filename -> import path var packageImportPath GoImportPath @@ -132,15 +166,18 @@ func New(req *pluginpb.CodeGeneratorRequest) (*Plugin, error) { default: return nil, fmt.Errorf(`unknown path type %q: want "import" or "source_relative"`, value) } - case "plugins": - // TODO case "annotate_code": // TODO default: - if param[0] != 'M' { - return nil, fmt.Errorf("unknown parameter %q", param) + if param[0] == 'M' { + importPaths[param[1:]] = GoImportPath(value) + continue + } + if opts.ParamFunc != nil { + if err := opts.ParamFunc(param, value); err != nil { + return nil, err + } } - importPaths[param[1:]] = GoImportPath(value) } } diff --git a/protogen/protogen_test.go b/protogen/protogen_test.go index 05d0bf29..f35068ae 100644 --- a/protogen/protogen_test.go +++ b/protogen/protogen_test.go @@ -5,6 +5,7 @@ package protogen import ( + "flag" "fmt" "io/ioutil" "os" @@ -18,8 +19,45 @@ import ( pluginpb "github.com/golang/protobuf/protoc-gen-go/plugin" ) +func TestPluginParameters(t *testing.T) { + var flags flag.FlagSet + value := flags.Int("integer", 0, "") + opts := &Options{ + ParamFunc: flags.Set, + } + const params = "integer=2" + _, err := New(&pluginpb.CodeGeneratorRequest{ + Parameter: proto.String(params), + }, opts) + if err != nil { + t.Errorf("New(generator parameters %q): %v", params, err) + } + if *value != 2 { + t.Errorf("New(generator parameters %q): integer=%v, want 2", params, *value) + } +} + +func TestPluginParameterErrors(t *testing.T) { + for _, parameter := range []string{ + "unknown=1", + "boolean=error", + } { + var flags flag.FlagSet + flags.Bool("boolean", false, "") + opts := &Options{ + ParamFunc: flags.Set, + } + _, err := New(&pluginpb.CodeGeneratorRequest{ + Parameter: proto.String(parameter), + }, opts) + if err == nil { + t.Errorf("New(generator parameters %q): want error, got nil", parameter) + } + } +} + func TestFiles(t *testing.T) { - gen, err := New(makeRequest(t, "testdata/go_package/no_go_package_import.proto")) + gen, err := New(makeRequest(t, "testdata/go_package/no_go_package_import.proto"), nil) if err != nil { t.Fatal(err) } @@ -144,7 +182,7 @@ TEST: %v if test.generate { req.FileToGenerate = []string{filename} } - gen, err := New(req) + gen, err := New(req, nil) if err != nil { t.Errorf("%vNew(req) = %v", context, err) continue @@ -182,7 +220,7 @@ func TestPackageNameInference(t *testing.T) { }, }, FileToGenerate: []string{"dir/file1.proto", "dir/file2.proto"}, - }) + }, nil) if err != nil { t.Fatalf("New(req) = %v", err) } @@ -212,14 +250,14 @@ func TestInconsistentPackageNames(t *testing.T) { }, }, FileToGenerate: []string{"dir/file1.proto", "dir/file2.proto"}, - }) + }, nil) if err == nil { t.Fatalf("inconsistent package names for the same import path: New(req) = nil, want error") } } func TestImports(t *testing.T) { - gen, err := New(&pluginpb.CodeGeneratorRequest{}) + gen, err := New(&pluginpb.CodeGeneratorRequest{}, nil) if err != nil { t.Fatal(err) } @@ -309,7 +347,7 @@ func makeRequest(t *testing.T, args ...string) *pluginpb.CodeGeneratorRequest { func init() { if os.Getenv("RUN_AS_PROTOC_PLUGIN") != "" { - Run(func(p *Plugin) error { + Run(nil, func(p *Plugin) error { g := p.NewGeneratedFile("request", "") return proto.MarshalText(g, p.Request) })