// Copyright 2018 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // Package internal_gengo is internal to the protobuf module. package internal_gengo import ( "bytes" "compress/gzip" "crypto/sha256" "encoding/hex" "fmt" "math" "sort" "strconv" "strings" "github.com/golang/protobuf/proto" descpb "github.com/golang/protobuf/protoc-gen-go/descriptor" "github.com/golang/protobuf/v2/protogen" "github.com/golang/protobuf/v2/reflect/protoreflect" ) // generatedCodeVersion indicates a version of the generated code. // It is incremented whenever an incompatibility between the generated code and // proto package is introduced; the generated code references // a constant, proto.ProtoPackageIsVersionN (where N is generatedCodeVersion). const generatedCodeVersion = 2 const protoPackage = "github.com/golang/protobuf/proto" type fileInfo struct { *protogen.File descriptorVar string // var containing the gzipped FileDescriptorProto allEnums []*protogen.Enum allMessages []*protogen.Message allExtensions []*protogen.Extension } // GenerateFile generates the contents of a .pb.go file. func GenerateFile(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile) { f := &fileInfo{ File: file, } // The different order for enums and extensions is to match the output // of the previous implementation. // // TODO: Eventually make this consistent. f.allEnums = append(f.allEnums, f.File.Enums...) walkMessages(f.Messages, func(message *protogen.Message) { f.allMessages = append(f.allMessages, message) f.allEnums = append(f.allEnums, message.Enums...) f.allExtensions = append(f.allExtensions, message.Extensions...) }) f.allExtensions = append(f.allExtensions, f.File.Extensions...) // Determine the name of the var holding the file descriptor: // // fileDescriptor_ filenameHash := sha256.Sum256([]byte(f.Desc.Path())) f.descriptorVar = fmt.Sprintf("fileDescriptor_%s", hex.EncodeToString(filenameHash[:8])) g.P("// Code generated by protoc-gen-go. DO NOT EDIT.") if f.Proto.GetOptions().GetDeprecated() { g.P("// ", f.Desc.Path(), " is a deprecated file.") } else { g.P("// source: ", f.Desc.Path()) } g.P() const filePackageField = 2 // FileDescriptorProto.package g.PrintLeadingComments(protogen.Location{ SourceFile: f.Proto.GetName(), Path: []int32{filePackageField}, }) g.P() g.P("package ", f.GoPackageName) g.P() // These references are not necessary, since we automatically add // all necessary imports before formatting the generated file. // // This section exists to generate output more consistent with // the previous version of protoc-gen-go, to make it easier to // detect unintended variations. // // TODO: Eventually remove this. g.P("// Reference imports to suppress errors if they are not otherwise used.") g.P("var _ = ", protogen.GoIdent{GoImportPath: protoPackage, GoName: "Marshal"}) g.P("var _ = ", protogen.GoIdent{GoImportPath: "fmt", GoName: "Errorf"}) g.P("var _ = ", protogen.GoIdent{GoImportPath: "math", GoName: "Inf"}) g.P() g.P("// This is a compile-time assertion to ensure that this generated file") g.P("// is compatible with the proto package it is being compiled against.") g.P("// A compilation error at this line likely means your copy of the") g.P("// proto package needs to be updated.") g.P("const _ = ", protogen.GoIdent{ GoImportPath: protoPackage, GoName: fmt.Sprintf("ProtoPackageIsVersion%d", generatedCodeVersion), }, "// please upgrade the proto package") g.P() for i, imps := 0, f.Desc.Imports(); i < imps.Len(); i++ { genImport(gen, g, f, imps.Get(i)) } for _, enum := range f.allEnums { genEnum(gen, g, f, enum) } for _, message := range f.allMessages { genMessage(gen, g, f, message) } for _, extension := range f.Extensions { genExtension(gen, g, f, extension) } genInitFunction(gen, g, f) genFileDescriptor(gen, g, f) } // walkMessages calls f on each message and all of its descendants. func walkMessages(messages []*protogen.Message, f func(*protogen.Message)) { for _, m := range messages { f(m) walkMessages(m.Messages, f) } } func genImport(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo, imp protoreflect.FileImport) { impFile, ok := gen.FileByName(imp.Path()) if !ok { return } if impFile.GoImportPath == f.GoImportPath { // Don't generate imports or aliases for types in the same Go package. return } // Generate imports for all non-weak dependencies, even if they are not // referenced, because other code and tools depend on having the // full transitive closure of protocol buffer types in the binary. if !imp.IsWeak { g.Import(impFile.GoImportPath) } if !imp.IsPublic { return } // TODO: An alternate approach to generating public imports might be // to generate the imported file contents, parse it, and extract all // exported identifiers from the AST to build a list of forwarding // declarations. // // TODO: Consider whether this should generate recursive aliases. e.g., // if a.proto publicly imports b.proto publicly imports c.proto, should // a.pb.go contain aliases for symbols defined in c.proto? var enums []*protogen.Enum enums = append(enums, impFile.Enums...) walkMessages(impFile.Messages, func(message *protogen.Message) { if message.Desc.IsMapEntry() { return } enums = append(enums, message.Enums...) for _, field := range message.Fields { if !fieldHasDefault(field) { continue } defVar := protogen.GoIdent{ GoImportPath: message.GoIdent.GoImportPath, GoName: "Default_" + message.GoIdent.GoName + "_" + field.GoName, } decl := "const" if field.Desc.Kind() == protoreflect.BytesKind { decl = "var" } g.P(decl, " ", defVar.GoName, " = ", defVar) } g.P("// ", message.GoIdent.GoName, " from public import ", imp.Path()) g.P("type ", message.GoIdent.GoName, " = ", message.GoIdent) for _, oneof := range message.Oneofs { for _, field := range oneof.Fields { typ := fieldOneofType(field) g.P("type ", typ.GoName, " = ", typ) } } g.P() }) for _, enum := range enums { g.P("// ", enum.GoIdent.GoName, " from public import ", imp.Path()) g.P("type ", enum.GoIdent.GoName, " = ", enum.GoIdent) g.P("var ", enum.GoIdent.GoName, "_name = ", enum.GoIdent, "_name") g.P("var ", enum.GoIdent.GoName, "_value = ", enum.GoIdent, "_value") g.P() for _, value := range enum.Values { g.P("const ", value.GoIdent.GoName, " = ", enum.GoIdent.GoName, "(", value.GoIdent, ")") } } for _, ext := range impFile.Extensions { ident := extensionVar(impFile, ext) g.P("var ", ident.GoName, " = ", ident) g.P() } g.P() } func genFileDescriptor(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo) { // Trim the source_code_info from the descriptor. // Marshal and gzip it. descProto := proto.Clone(f.Proto).(*descpb.FileDescriptorProto) descProto.SourceCodeInfo = nil b, err := proto.Marshal(descProto) if err != nil { gen.Error(err) return } var buf bytes.Buffer w, _ := gzip.NewWriterLevel(&buf, gzip.BestCompression) w.Write(b) w.Close() b = buf.Bytes() g.P("func init() { proto.RegisterFile(", strconv.Quote(f.Desc.Path()), ", ", f.descriptorVar, ") }") g.P() g.P("var ", f.descriptorVar, " = []byte{") g.P("// ", len(b), " bytes of a gzipped FileDescriptorProto") for len(b) > 0 { n := 16 if n > len(b) { n = len(b) } s := "" for _, c := range b[:n] { s += fmt.Sprintf("0x%02x,", c) } g.P(s) b = b[n:] } g.P("}") g.P() } func genEnum(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo, enum *protogen.Enum) { g.PrintLeadingComments(enum.Location) g.Annotate(enum.GoIdent.GoName, enum.Location) g.P("type ", enum.GoIdent, " int32", deprecationComment(enum.Desc.Options().(*descpb.EnumOptions).GetDeprecated())) g.P("const (") for _, value := range enum.Values { g.PrintLeadingComments(value.Location) g.Annotate(value.GoIdent.GoName, value.Location) g.P(value.GoIdent, " ", enum.GoIdent, " = ", value.Desc.Number(), deprecationComment(value.Desc.Options().(*descpb.EnumValueOptions).GetDeprecated())) } g.P(")") g.P() nameMap := enum.GoIdent.GoName + "_name" g.P("var ", nameMap, " = map[int32]string{") generated := make(map[protoreflect.EnumNumber]bool) for _, value := range enum.Values { duplicate := "" if _, present := generated[value.Desc.Number()]; present { duplicate = "// Duplicate value: " } g.P(duplicate, value.Desc.Number(), ": ", strconv.Quote(string(value.Desc.Name())), ",") generated[value.Desc.Number()] = true } g.P("}") g.P() valueMap := enum.GoIdent.GoName + "_value" g.P("var ", valueMap, " = map[string]int32{") for _, value := range enum.Values { g.P(strconv.Quote(string(value.Desc.Name())), ": ", value.Desc.Number(), ",") } g.P("}") g.P() if enum.Desc.Syntax() != protoreflect.Proto3 { g.P("func (x ", enum.GoIdent, ") Enum() *", enum.GoIdent, " {") g.P("p := new(", enum.GoIdent, ")") g.P("*p = x") g.P("return p") g.P("}") g.P() } g.P("func (x ", enum.GoIdent, ") String() string {") g.P("return ", protogen.GoIdent{GoImportPath: protoPackage, GoName: "EnumName"}, "(", enum.GoIdent, "_name, int32(x))") g.P("}") g.P() if enum.Desc.Syntax() != protoreflect.Proto3 { g.P("func (x *", enum.GoIdent, ") UnmarshalJSON(data []byte) error {") g.P("value, err := ", protogen.GoIdent{GoImportPath: protoPackage, GoName: "UnmarshalJSONEnum"}, "(", enum.GoIdent, `_value, data, "`, enum.GoIdent, `")`) g.P("if err != nil {") g.P("return err") g.P("}") g.P("*x = ", enum.GoIdent, "(value)") g.P("return nil") g.P("}") g.P() } var indexes []string for i := 1; i < len(enum.Location.Path); i += 2 { indexes = append(indexes, strconv.Itoa(int(enum.Location.Path[i]))) } g.P("func (", enum.GoIdent, ") EnumDescriptor() ([]byte, []int) {") g.P("return ", f.descriptorVar, ", []int{", strings.Join(indexes, ","), "}") g.P("}") g.P() genWellKnownType(g, "", enum.GoIdent, enum.Desc) } // enumRegistryName returns the name used to register an enum with the proto // package registry. // // Confusingly, this is .. This probably should have // been the full name of the proto enum type instead, but changing it at this // point would require thought. func enumRegistryName(enum *protogen.Enum) string { // Find the FileDescriptor for this enum. var desc protoreflect.Descriptor = enum.Desc for { p, ok := desc.Parent() if !ok { break } desc = p } fdesc := desc.(protoreflect.FileDescriptor) if fdesc.Package() == "" { return enum.GoIdent.GoName } return string(fdesc.Package()) + "." + enum.GoIdent.GoName } func genMessage(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo, message *protogen.Message) { if message.Desc.IsMapEntry() { return } hasComment := g.PrintLeadingComments(message.Location) if message.Desc.Options().(*descpb.MessageOptions).GetDeprecated() { if hasComment { g.P("//") } g.P(deprecationComment(true)) } g.Annotate(message.GoIdent.GoName, message.Location) g.P("type ", message.GoIdent, " struct {") for _, field := range message.Fields { if field.OneofType != nil { // It would be a bit simpler to iterate over the oneofs below, // but generating the field here keeps the contents of the Go // struct in the same order as the contents of the source // .proto file. if field == field.OneofType.Fields[0] { genOneofField(gen, g, f, message, field.OneofType) } continue } g.PrintLeadingComments(field.Location) goType, pointer := fieldGoType(g, field) if pointer { goType = "*" + goType } tags := []string{ fmt.Sprintf("protobuf:%q", fieldProtobufTag(field)), fmt.Sprintf("json:%q", fieldJSONTag(field)), } if field.Desc.IsMap() { key := field.MessageType.Fields[0] val := field.MessageType.Fields[1] tags = append(tags, fmt.Sprintf("protobuf_key:%q", fieldProtobufTag(key)), fmt.Sprintf("protobuf_val:%q", fieldProtobufTag(val)), ) } g.Annotate(message.GoIdent.GoName+"."+field.GoName, field.Location) g.P(field.GoName, " ", goType, " `", strings.Join(tags, " "), "`", deprecationComment(field.Desc.Options().(*descpb.FieldOptions).GetDeprecated())) } g.P("XXX_NoUnkeyedLiteral struct{} `json:\"-\"`") if message.Desc.ExtensionRanges().Len() > 0 { var tags []string if message.Desc.Options().(*descpb.MessageOptions).GetMessageSetWireFormat() { tags = append(tags, `protobuf_messageset:"1"`) } tags = append(tags, `json:"-"`) g.P(protogen.GoIdent{ GoImportPath: protoPackage, GoName: "XXX_InternalExtensions", }, " `", strings.Join(tags, " "), "`") } // TODO XXX_InternalExtensions g.P("XXX_unrecognized []byte `json:\"-\"`") g.P("XXX_sizecache int32 `json:\"-\"`") g.P("}") g.P() // Reset g.P("func (m *", message.GoIdent, ") Reset() { *m = ", message.GoIdent, "{} }") // String g.P("func (m *", message.GoIdent, ") String() string { return ", protogen.GoIdent{ GoImportPath: protoPackage, GoName: "CompactTextString", }, "(m) }") // ProtoMessage g.P("func (*", message.GoIdent, ") ProtoMessage() {}") // Descriptor var indexes []string for i := 1; i < len(message.Location.Path); i += 2 { indexes = append(indexes, strconv.Itoa(int(message.Location.Path[i]))) } g.P("func (*", message.GoIdent, ") Descriptor() ([]byte, []int) {") g.P("return ", f.descriptorVar, ", []int{", strings.Join(indexes, ","), "}") g.P("}") g.P() // ExtensionRangeArray if extranges := message.Desc.ExtensionRanges(); extranges.Len() > 0 { if message.Desc.Options().(*descpb.MessageOptions).GetMessageSetWireFormat() { g.P("func (m *", message.GoIdent, ") MarshalJSON() ([]byte, error) {") g.P("return ", protogen.GoIdent{ GoImportPath: protoPackage, GoName: "MarshalMessageSetJSON", }, "(&m.XXX_InternalExtensions)") g.P("}") g.P("func (m *", message.GoIdent, ") UnmarshalJSON(buf []byte) error {") g.P("return ", protogen.GoIdent{ GoImportPath: protoPackage, GoName: "UnmarshalMessageSetJSON", }, "(buf, &m.XXX_InternalExtensions)") g.P("}") g.P() } protoExtRange := protogen.GoIdent{ GoImportPath: protoPackage, GoName: "ExtensionRange", } extRangeVar := "extRange_" + message.GoIdent.GoName g.P("var ", extRangeVar, " = []", protoExtRange, " {") for i := 0; i < extranges.Len(); i++ { r := extranges.Get(i) g.P("{Start:", r[0], ", End:", r[1]-1 /* inclusive */, "},") } g.P("}") g.P() g.P("func (*", message.GoIdent, ") ExtensionRangeArray() []", protoExtRange, " {") g.P("return ", extRangeVar) g.P("}") g.P() } genWellKnownType(g, "*", message.GoIdent, message.Desc) // Table-driven proto support. // // TODO: It does not scale to keep adding another method for every // operation on protos that we want to switch over to using the // table-driven approach. Instead, we should only add a single method // that allows getting access to the *InternalMessageInfo struct and then // calling Unmarshal, Marshal, Merge, Size, and Discard directly on that. messageInfoVar := "xxx_messageInfo_" + message.GoIdent.GoName // XXX_Unmarshal g.P("func (m *", message.GoIdent, ") XXX_Unmarshal(b []byte) error {") g.P("return ", messageInfoVar, ".Unmarshal(m, b)") g.P("}") // XXX_Marshal g.P("func (m *", message.GoIdent, ") XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {") g.P("return ", messageInfoVar, ".Marshal(b, m, deterministic)") g.P("}") // XXX_Merge g.P("func (m *", message.GoIdent, ") XXX_Merge(src proto.Message) {") g.P(messageInfoVar, ".Merge(m, src)") g.P("}") // XXX_Size g.P("func (m *", message.GoIdent, ") XXX_Size() int {") g.P("return ", messageInfoVar, ".Size(m)") g.P("}") // XXX_DiscardUnknown g.P("func (m *", message.GoIdent, ") XXX_DiscardUnknown() {") g.P(messageInfoVar, ".DiscardUnknown(m)") g.P("}") g.P() g.P("var ", messageInfoVar, " ", protogen.GoIdent{ GoImportPath: protoPackage, GoName: "InternalMessageInfo", }) g.P() // Constants and vars holding the default values of fields. for _, field := range message.Fields { if !fieldHasDefault(field) { continue } defVarName := "Default_" + message.GoIdent.GoName + "_" + field.GoName def := field.Desc.Default() switch field.Desc.Kind() { case protoreflect.StringKind: g.P("const ", defVarName, " string = ", strconv.Quote(def.String())) case protoreflect.BytesKind: g.P("var ", defVarName, " []byte = []byte(", strconv.Quote(string(def.Bytes())), ")") case protoreflect.EnumKind: evalueDesc := field.Desc.DefaultEnumValue() enum := field.EnumType evalue := enum.Values[evalueDesc.Index()] g.P("const ", defVarName, " ", field.EnumType.GoIdent, " = ", evalue.GoIdent) case protoreflect.FloatKind, protoreflect.DoubleKind: // Floating point numbers need extra handling for -Inf/Inf/NaN. f := field.Desc.Default().Float() goType := "float64" if field.Desc.Kind() == protoreflect.FloatKind { goType = "float32" } // funcCall returns a call to a function in the math package, // possibly converting the result to float32. funcCall := func(fn, param string) string { s := g.QualifiedGoIdent(protogen.GoIdent{ GoImportPath: "math", GoName: fn, }) + param if goType != "float64" { s = goType + "(" + s + ")" } return s } switch { case math.IsInf(f, -1): g.P("var ", defVarName, " ", goType, " = ", funcCall("Inf", "(-1)")) case math.IsInf(f, 1): g.P("var ", defVarName, " ", goType, " = ", funcCall("Inf", "(1)")) case math.IsNaN(f): g.P("var ", defVarName, " ", goType, " = ", funcCall("NaN", "()")) default: g.P("const ", defVarName, " ", goType, " = ", field.Desc.Default().Interface()) } default: goType, _ := fieldGoType(g, field) g.P("const ", defVarName, " ", goType, " = ", def.Interface()) } } g.P() // Getters. for _, field := range message.Fields { if field.OneofType != nil { if field == field.OneofType.Fields[0] { genOneofTypes(gen, g, f, message, field.OneofType) } } goType, pointer := fieldGoType(g, field) defaultValue := fieldDefaultValue(g, message, field) if field.Desc.Options().(*descpb.FieldOptions).GetDeprecated() { g.P(deprecationComment(true)) } g.Annotate(message.GoIdent.GoName+".Get"+field.GoName, field.Location) g.P("func (m *", message.GoIdent, ") Get", field.GoName, "() ", goType, " {") if field.OneofType != nil { g.P("if x, ok := m.Get", field.OneofType.GoName, "().(*", fieldOneofType(field), "); ok {") g.P("return x.", field.GoName) g.P("}") } else { if field.Desc.Syntax() == protoreflect.Proto3 || defaultValue == "nil" { g.P("if m != nil {") } else { g.P("if m != nil && m.", field.GoName, " != nil {") } star := "" if pointer { star = "*" } g.P("return ", star, " m.", field.GoName) g.P("}") } g.P("return ", defaultValue) g.P("}") g.P() } if len(message.Oneofs) > 0 { genOneofFuncs(gen, g, f, message) } for _, extension := range message.Extensions { genExtension(gen, g, f, extension) } } // fieldGoType returns the Go type used for a field. // // If it returns pointer=true, the struct field is a pointer to the type. func fieldGoType(g *protogen.GeneratedFile, field *protogen.Field) (goType string, pointer bool) { pointer = true switch field.Desc.Kind() { case protoreflect.BoolKind: goType = "bool" case protoreflect.EnumKind: goType = g.QualifiedGoIdent(field.EnumType.GoIdent) case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: goType = "int32" case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: goType = "uint32" case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: goType = "int64" case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: goType = "uint64" case protoreflect.FloatKind: goType = "float32" case protoreflect.DoubleKind: goType = "float64" case protoreflect.StringKind: goType = "string" case protoreflect.BytesKind: goType = "[]byte" pointer = false case protoreflect.MessageKind, protoreflect.GroupKind: if field.Desc.IsMap() { keyType, _ := fieldGoType(g, field.MessageType.Fields[0]) valType, _ := fieldGoType(g, field.MessageType.Fields[1]) return fmt.Sprintf("map[%v]%v", keyType, valType), false } goType = "*" + g.QualifiedGoIdent(field.MessageType.GoIdent) pointer = false } if field.Desc.Cardinality() == protoreflect.Repeated { goType = "[]" + goType pointer = false } // Extension fields always have pointer type, even when defined in a proto3 file. if field.Desc.Syntax() == protoreflect.Proto3 && field.Desc.ExtendedType() == nil { pointer = false } return goType, pointer } func fieldProtobufTag(field *protogen.Field) string { var tag []string // wire type tag = append(tag, wireTypes[field.Desc.Kind()]) // field number tag = append(tag, strconv.Itoa(int(field.Desc.Number()))) // cardinality switch field.Desc.Cardinality() { case protoreflect.Optional: tag = append(tag, "opt") case protoreflect.Required: tag = append(tag, "req") case protoreflect.Repeated: tag = append(tag, "rep") } if field.Desc.IsPacked() { tag = append(tag, "packed") } // TODO: packed // name name := string(field.Desc.Name()) if field.Desc.Kind() == protoreflect.GroupKind { // The name of the FieldDescriptor for a group field is // lowercased. To find the original capitalization, we // look in the field's MessageType. name = string(field.MessageType.Desc.Name()) } tag = append(tag, "name="+name) // JSON name if jsonName := field.Desc.JSONName(); jsonName != "" && jsonName != name { tag = append(tag, "json="+jsonName) } // proto3 // The previous implementation does not tag extension fields as proto3, // even when the field is defined in a proto3 file. Match that behavior // for consistency. if field.Desc.Syntax() == protoreflect.Proto3 && field.Desc.ExtendedType() == nil { tag = append(tag, "proto3") } // enum if field.Desc.Kind() == protoreflect.EnumKind { tag = append(tag, "enum="+enumRegistryName(field.EnumType)) } // oneof if field.Desc.OneofType() != nil { tag = append(tag, "oneof") } // default value // This must appear last in the tag, since commas in strings aren't escaped. if field.Desc.HasDefault() { var def string switch field.Desc.Kind() { case protoreflect.BoolKind: if field.Desc.Default().Bool() { def = "1" } else { def = "0" } case protoreflect.BytesKind: // Preserve protoc-gen-go's historical output of escaped bytes. // This behavior is buggy, but fixing it makes it impossible to // distinguish between the escaped and unescaped forms. // // To match the exact output of protoc, this is identical to the // CEscape function in strutil.cc of the protoc source code. var b []byte for _, c := range field.Desc.Default().Bytes() { switch c { case '\n': b = append(b, `\n`...) case '\r': b = append(b, `\r`...) case '\t': b = append(b, `\t`...) case '"': b = append(b, `\"`...) case '\'': b = append(b, `\'`...) case '\\': b = append(b, `\\`...) default: if c >= 0x20 && c <= 0x7e { b = append(b, c) } else { b = append(b, fmt.Sprintf(`\%03o`, c)...) } } } def = string(b) case protoreflect.FloatKind, protoreflect.DoubleKind: f := field.Desc.Default().Float() switch { case math.IsInf(f, -1): def = "-inf" case math.IsInf(f, 1): def = "inf" case math.IsNaN(f): def = "nan" default: def = fmt.Sprint(field.Desc.Default().Interface()) } default: def = fmt.Sprint(field.Desc.Default().Interface()) } tag = append(tag, "def="+def) } return strings.Join(tag, ",") } func fieldDefaultValue(g *protogen.GeneratedFile, message *protogen.Message, field *protogen.Field) string { if field.Desc.Cardinality() == protoreflect.Repeated { return "nil" } if fieldHasDefault(field) { defVarName := "Default_" + message.GoIdent.GoName + "_" + field.GoName if field.Desc.Kind() == protoreflect.BytesKind { return "append([]byte(nil), " + defVarName + "...)" } return defVarName } switch field.Desc.Kind() { case protoreflect.BoolKind: return "false" case protoreflect.StringKind: return `""` case protoreflect.MessageKind, protoreflect.GroupKind, protoreflect.BytesKind: return "nil" case protoreflect.EnumKind: return g.QualifiedGoIdent(field.EnumType.Values[0].GoIdent) default: return "0" } } // fieldHasDefault returns true if we consider a field to have a default value. // // For consistency with the previous generator, it returns false for fields with // [default=""], preventing the generation of a default const or var for these // fields. // // TODO: Drop this special case. func fieldHasDefault(field *protogen.Field) bool { if !field.Desc.HasDefault() { return false } switch field.Desc.Kind() { case protoreflect.StringKind: return field.Desc.Default().String() != "" case protoreflect.BytesKind: return len(field.Desc.Default().Bytes()) > 0 } return true } var wireTypes = map[protoreflect.Kind]string{ protoreflect.BoolKind: "varint", protoreflect.EnumKind: "varint", protoreflect.Int32Kind: "varint", protoreflect.Sint32Kind: "zigzag32", protoreflect.Uint32Kind: "varint", protoreflect.Int64Kind: "varint", protoreflect.Sint64Kind: "zigzag64", protoreflect.Uint64Kind: "varint", protoreflect.Sfixed32Kind: "fixed32", protoreflect.Fixed32Kind: "fixed32", protoreflect.FloatKind: "fixed32", protoreflect.Sfixed64Kind: "fixed64", protoreflect.Fixed64Kind: "fixed64", protoreflect.DoubleKind: "fixed64", protoreflect.StringKind: "bytes", protoreflect.BytesKind: "bytes", protoreflect.MessageKind: "bytes", protoreflect.GroupKind: "group", } func fieldJSONTag(field *protogen.Field) string { return string(field.Desc.Name()) + ",omitempty" } func genExtension(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo, extension *protogen.Extension) { // Special case for proto2 message sets: If this extension is extending // proto2.bridge.MessageSet, and its final name component is "message_set_extension", // then drop that last component. // // TODO: This should be implemented in the text formatter rather than the generator. // In addition, the situation for when to apply this special case is implemented // differently in other languages: // https://github.com/google/protobuf/blob/aff10976/src/google/protobuf/text_format.cc#L1560 name := extension.Desc.FullName() if n, ok := isExtensionMessageSetElement(extension); ok { name = n } g.P("var ", extensionVar(f.File, extension), " = &", protogen.GoIdent{ GoImportPath: protoPackage, GoName: "ExtensionDesc", }, "{") g.P("ExtendedType: (*", extension.ExtendedType.GoIdent, ")(nil),") goType, pointer := fieldGoType(g, extension) if pointer { goType = "*" + goType } g.P("ExtensionType: (", goType, ")(nil),") g.P("Field: ", extension.Desc.Number(), ",") g.P("Name: ", strconv.Quote(string(name)), ",") g.P("Tag: ", strconv.Quote(fieldProtobufTag(extension)), ",") g.P("Filename: ", strconv.Quote(f.Desc.Path()), ",") g.P("}") g.P() } // isExtensionMessageSetELement returns the adjusted name of an extension // which extends proto2.bridge.MessageSet. func isExtensionMessageSetElement(extension *protogen.Extension) (name protoreflect.FullName, ok bool) { opts := extension.ExtendedType.Desc.Options().(*descpb.MessageOptions) if !opts.GetMessageSetWireFormat() || extension.Desc.Name() != "message_set_extension" { return "", false } if extension.ParentMessage == nil { // This case shouldn't be given special handling at all--we're // only supposed to drop the ".message_set_extension" for // extensions defined within a message (i.e., the extension // takes the message's name). // // This matches the behavior of the v1 generator, however. // // TODO: See if we can drop this case. name = extension.Desc.FullName() name = name[:len(name)-len("message_set_extension")] return name, true } return extension.Desc.FullName().Parent(), true } // extensionVar returns the var holding the ExtensionDesc for an extension. func extensionVar(f *protogen.File, extension *protogen.Extension) protogen.GoIdent { name := "E_" if extension.ParentMessage != nil { name += extension.ParentMessage.GoIdent.GoName + "_" } name += extension.GoName return protogen.GoIdent{ GoImportPath: f.GoImportPath, GoName: name, } } // genInitFunction generates an init function that registers the types in the // generated file with the proto package. func genInitFunction(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo) { if len(f.allMessages) == 0 && len(f.allEnums) == 0 && len(f.allExtensions) == 0 { return } g.P("func init() {") for _, enum := range f.allEnums { name := enum.GoIdent.GoName g.P(protogen.GoIdent{ GoImportPath: protoPackage, GoName: "RegisterEnum", }, fmt.Sprintf("(%q, %s_name, %s_value)", enumRegistryName(enum), name, name)) } for _, message := range f.allMessages { if message.Desc.IsMapEntry() { continue } for _, extension := range message.Extensions { genRegisterExtension(gen, g, f, extension) } name := message.GoIdent.GoName g.P(protogen.GoIdent{ GoImportPath: protoPackage, GoName: "RegisterType", }, fmt.Sprintf("((*%s)(nil), %q)", name, message.Desc.FullName())) // Types of map fields, sorted by the name of the field message type. var mapFields []*protogen.Field for _, field := range message.Fields { if field.Desc.IsMap() { mapFields = append(mapFields, field) } } sort.Slice(mapFields, func(i, j int) bool { ni := mapFields[i].MessageType.Desc.FullName() nj := mapFields[j].MessageType.Desc.FullName() return ni < nj }) for _, field := range mapFields { typeName := string(field.MessageType.Desc.FullName()) goType, _ := fieldGoType(g, field) g.P(protogen.GoIdent{ GoImportPath: protoPackage, GoName: "RegisterMapType", }, fmt.Sprintf("((%v)(nil), %q)", goType, typeName)) } } for _, extension := range f.Extensions { genRegisterExtension(gen, g, f, extension) } g.P("}") g.P() } func genRegisterExtension(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo, extension *protogen.Extension) { g.P(protogen.GoIdent{ GoImportPath: protoPackage, GoName: "RegisterExtension", }, "(", extensionVar(f.File, extension), ")") if name, ok := isExtensionMessageSetElement(extension); ok { goType, pointer := fieldGoType(g, extension) if pointer { goType = "*" + goType } g.P(protogen.GoIdent{ GoImportPath: protoPackage, GoName: "RegisterMessageSetType", }, "((", goType, ")(nil), ", extension.Desc.Number(), ",", strconv.Quote(string(name)), ")") } } // deprecationComment returns a standard deprecation comment if deprecated is true. func deprecationComment(deprecated bool) string { if !deprecated { return "" } return "// Deprecated: Do not use." } func genWellKnownType(g *protogen.GeneratedFile, ptr string, ident protogen.GoIdent, desc protoreflect.Descriptor) { if wellKnownTypes[desc.FullName()] { g.P("func (", ptr, ident, `) XXX_WellKnownType() string { return "`, desc.Name(), `" }`) g.P() } } // Names of messages and enums for which we will generate XXX_WellKnownType methods. var wellKnownTypes = map[protoreflect.FullName]bool{ "google.protobuf.Any": true, "google.protobuf.Duration": true, "google.protobuf.Empty": true, "google.protobuf.Struct": true, "google.protobuf.Timestamp": true, "google.protobuf.BoolValue": true, "google.protobuf.BytesValue": true, "google.protobuf.DoubleValue": true, "google.protobuf.FloatValue": true, "google.protobuf.Int32Value": true, "google.protobuf.Int64Value": true, "google.protobuf.ListValue": true, "google.protobuf.NullValue": true, "google.protobuf.StringValue": true, "google.protobuf.UInt32Value": true, "google.protobuf.UInt64Value": true, "google.protobuf.Value": true, }