diff --git a/compiler/protogen/protogen.go b/compiler/protogen/protogen.go index 5dbdaf96..8ffccca2 100644 --- a/compiler/protogen/protogen.go +++ b/compiler/protogen/protogen.go @@ -157,7 +157,7 @@ func New(req *pluginpb.CodeGeneratorRequest, opts *Options) (*Plugin, error) { gen := &Plugin{ Request: req, FilesByPath: make(map[string]*File), - fileReg: protoregistry.NewFiles(), + fileReg: new(protoregistry.Files), enumsByName: make(map[protoreflect.FullName]*Enum), messagesByName: make(map[protoreflect.FullName]*Message), opts: opts, @@ -440,7 +440,7 @@ func newFile(gen *Plugin, p *descriptorpb.FileDescriptorProto, packageName GoPac if err != nil { return nil, fmt.Errorf("invalid FileDescriptorProto %q: %v", p.GetName(), err) } - if err := gen.fileReg.Register(desc); err != nil { + if err := gen.fileReg.RegisterFile(desc); err != nil { return nil, fmt.Errorf("cannot register descriptor %q: %v", p.GetName(), err) } f := &File{ diff --git a/internal/filedesc/build.go b/internal/filedesc/build.go index 37194552..f4af7ad8 100644 --- a/internal/filedesc/build.go +++ b/internal/filedesc/build.go @@ -44,7 +44,7 @@ type Builder struct { FileRegistry interface { FindFileByPath(string) (protoreflect.FileDescriptor, error) FindDescriptorByName(pref.FullName) (pref.Descriptor, error) - Register(...pref.FileDescriptor) error + RegisterFile(pref.FileDescriptor) error } } @@ -107,7 +107,7 @@ func (db Builder) Build() (out Out) { out.Extensions = fd.allExtensions out.Services = fd.allServices - if err := db.FileRegistry.Register(fd); err != nil { + if err := db.FileRegistry.RegisterFile(fd); err != nil { panic(err) } return out diff --git a/internal/filetype/build.go b/internal/filetype/build.go index f421a622..0a0dd35d 100644 --- a/internal/filetype/build.go +++ b/internal/filetype/build.go @@ -108,7 +108,9 @@ type Builder struct { // TypeRegistry is the registry to register each type descriptor. // If nil, it uses protoregistry.GlobalTypes. TypeRegistry interface { - Register(...preg.Type) error + RegisterMessage(pref.MessageType) error + RegisterEnum(pref.EnumType) error + RegisterExtension(pref.ExtensionType) error } } @@ -149,7 +151,7 @@ func (tb Builder) Build() (out Out) { Desc: &fbOut.Enums[i], } // Register enum types. - if err := tb.TypeRegistry.Register(&tb.EnumInfos[i]); err != nil { + if err := tb.TypeRegistry.RegisterEnum(&tb.EnumInfos[i]); err != nil { panic(err) } } @@ -170,7 +172,7 @@ func (tb Builder) Build() (out Out) { tb.MessageInfos[i].Desc = &fbOut.Messages[i] // Register message types. - if err := tb.TypeRegistry.Register(&tb.MessageInfos[i]); err != nil { + if err := tb.TypeRegistry.RegisterMessage(&tb.MessageInfos[i]); err != nil { panic(err) } } @@ -232,7 +234,7 @@ func (tb Builder) Build() (out Out) { pimpl.InitExtensionInfo(&tb.ExtensionInfos[i], &fbOut.Extensions[i], goType) // Register extension types. - if err := tb.TypeRegistry.Register(&tb.ExtensionInfos[i]); err != nil { + if err := tb.TypeRegistry.RegisterExtension(&tb.ExtensionInfos[i]); err != nil { panic(err) } } @@ -274,7 +276,7 @@ type ( fileRegistry interface { FindFileByPath(string) (pref.FileDescriptor, error) FindDescriptorByName(pref.FullName) (pref.Descriptor, error) - Register(...pref.FileDescriptor) error + RegisterFile(pref.FileDescriptor) error } ) diff --git a/internal/impl/legacy_file.go b/internal/impl/legacy_file.go index bccaedef..b61a1354 100644 --- a/internal/impl/legacy_file.go +++ b/internal/impl/legacy_file.go @@ -70,4 +70,4 @@ type resolverOnly struct { *protoregistry.Files } -func (resolverOnly) Register(...protoreflect.FileDescriptor) error { return nil } +func (resolverOnly) Register(protoreflect.FileDescriptor) error { return nil } diff --git a/internal/impl/legacy_test.go b/internal/impl/legacy_test.go index f55c0d3c..8e63ba6c 100644 --- a/internal/impl/legacy_test.go +++ b/internal/impl/legacy_test.go @@ -52,8 +52,8 @@ var legacyFD = func() []byte { func init() { mt := pimpl.Export{}.MessageTypeOf((*LegacyTestMessage)(nil)) - preg.GlobalFiles.Register(mt.Descriptor().ParentFile()) - preg.GlobalTypes.Register(mt) + preg.GlobalFiles.RegisterFile(mt.Descriptor().ParentFile()) + preg.GlobalTypes.RegisterMessage(mt) } func mustMakeExtensionType(fileDesc, extDesc string, t reflect.Type, r pdesc.Resolver) pref.ExtensionType { @@ -82,7 +82,7 @@ var ( testMessageV1Desc = pimpl.Export{}.MessageDescriptorOf((*proto2_20180125.Message_ChildMessage)(nil)) testMessageV2Desc = enumMessagesType.Desc - depReg = preg.NewFiles( + depReg = newFileRegistry( testParentDesc.ParentFile(), testEnumV1Desc.ParentFile(), testMessageV1Desc.ParentFile(), diff --git a/internal/impl/message_reflect_test.go b/internal/impl/message_reflect_test.go index c5a4a6a5..82d5f81d 100644 --- a/internal/impl/message_reflect_test.go +++ b/internal/impl/message_reflect_test.go @@ -990,7 +990,7 @@ var enumMessagesType = pimpl.MessageInfo{GoReflectType: reflect.TypeOf(new(EnumM {name:"F7Entry" field:[{name:"key" number:1 label:LABEL_OPTIONAL type:TYPE_STRING}, {name:"value" number:2 label:LABEL_OPTIONAL type:TYPE_ENUM type_name:".EnumProto3"}] options:{map_entry:true}}, {name:"F8Entry" field:[{name:"key" number:1 label:LABEL_OPTIONAL type:TYPE_STRING}, {name:"value" number:2 label:LABEL_OPTIONAL type:TYPE_MESSAGE type_name:".ScalarProto3"}] options:{map_entry:true}} ] - `, protoregistry.NewFiles( + `, newFileRegistry( EnumProto2(0).Descriptor().ParentFile(), EnumProto3(0).Descriptor().ParentFile(), ((*ScalarProto2)(nil)).ProtoReflect().Descriptor().ParentFile(), @@ -999,6 +999,14 @@ var enumMessagesType = pimpl.MessageInfo{GoReflectType: reflect.TypeOf(new(EnumM )), } +func newFileRegistry(files ...pref.FileDescriptor) *protoregistry.Files { + r := new(protoregistry.Files) + for _, file := range files { + r.RegisterFile(file) + } + return r +} + func (m *EnumMessages) ProtoReflect() pref.Message { return enumMessagesType.MessageOf(m) } func (*EnumMessages) XXX_OneofWrappers() []interface{} { diff --git a/internal/protolegacy/proto.go b/internal/protolegacy/proto.go index ffca87a2..b9f023a7 100644 --- a/internal/protolegacy/proto.go +++ b/internal/protolegacy/proto.go @@ -61,7 +61,7 @@ func RegisterFile(s string, d []byte) { func RegisterType(m Message, s string) { mt := protoimpl.X.LegacyMessageTypeOf(m, protoreflect.FullName(s)) - if err := protoregistry.GlobalTypes.Register(mt); err != nil { + if err := protoregistry.GlobalTypes.RegisterMessage(mt); err != nil { panic(err) } } @@ -75,7 +75,7 @@ func RegisterEnum(string, map[int32]string, map[string]int32) { } func RegisterExtension(d *ExtensionDesc) { - if err := protoregistry.GlobalTypes.Register(d); err != nil { + if err := protoregistry.GlobalTypes.RegisterExtension(d); err != nil { panic(err) } } diff --git a/reflect/protodesc/file_test.go b/reflect/protodesc/file_test.go index 38e8b4c6..9ee6c8d7 100644 --- a/reflect/protodesc/file_test.go +++ b/reflect/protodesc/file_test.go @@ -916,7 +916,7 @@ func TestNewFile(t *testing.T) { if err != nil { t.Fatalf("dependency %d: unexpected NewFile() error: %v", i, err) } - if err := r.Register(f); err != nil { + if err := r.RegisterFile(f); err != nil { t.Fatalf("dependency %d: unexpected Register() error: %v", i, err) } } diff --git a/reflect/protoregistry/registry.go b/reflect/protoregistry/registry.go index 4689f6be..116bd683 100644 --- a/reflect/protoregistry/registry.go +++ b/reflect/protoregistry/registry.go @@ -75,20 +75,37 @@ type packageDescriptor struct { // NewFiles returns a registry initialized with the provided set of files. // Files with a namespace conflict with an pre-existing file are not registered. +// +// Deprecated: Use Register. func NewFiles(files ...protoreflect.FileDescriptor) *Files { r := new(Files) - r.Register(files...) // ignore errors; first takes precedence + for _, file := range files { + r.RegisterFile(file) // ignore errors; first takes precedence + } return r } // Register registers the provided list of file descriptors. // -// If any descriptor within a file conflicts with the descriptor of any +// Deprecated: Use RegisterFile. +func (r *Files) Register(files ...protoreflect.FileDescriptor) error { + var firstErr error + for _, file := range files { + if err := r.RegisterFile(file); err != nil && firstErr == nil { + firstErr = err + } + } + return firstErr +} + +// RegisterFile registers the provided file descriptor. +// +// If any descriptor within the file conflicts with the descriptor of any // previously registered file (e.g., two enums with the same full name), -// then that file is not registered and an error is returned. +// then the file is not registered and an error is returned. // // It is permitted for multiple files to have the same file path. -func (r *Files) Register(files ...protoreflect.FileDescriptor) error { +func (r *Files) RegisterFile(file protoreflect.FileDescriptor) error { if r == GlobalFiles { globalMutex.Lock() defer globalMutex.Unlock() @@ -99,32 +116,23 @@ func (r *Files) Register(files ...protoreflect.FileDescriptor) error { } r.filesByPath = make(map[string]protoreflect.FileDescriptor) } - var firstErr error - for _, file := range files { - if err := r.registerFile(file); err != nil && firstErr == nil { - firstErr = err - } - } - return firstErr -} -func (r *Files) registerFile(fd protoreflect.FileDescriptor) error { - path := fd.Path() + path := file.Path() if prev := r.filesByPath[path]; prev != nil { - err := errors.New("file %q is already registered", fd.Path()) - err = amendErrorWithCaller(err, prev, fd) - if r == GlobalFiles && ignoreConflict(fd, err) { + err := errors.New("file %q is already registered", file.Path()) + err = amendErrorWithCaller(err, prev, file) + if r == GlobalFiles && ignoreConflict(file, err) { err = nil } return err } - for name := fd.Package(); name != ""; name = name.Parent() { + for name := file.Package(); name != ""; name = name.Parent() { switch prev := r.descsByName[name]; prev.(type) { case nil, *packageDescriptor: default: - err := errors.New("file %q has a package name conflict over %v", fd.Path(), name) - err = amendErrorWithCaller(err, prev, fd) - if r == GlobalFiles && ignoreConflict(fd, err) { + err := errors.New("file %q has a package name conflict over %v", file.Path(), name) + err = amendErrorWithCaller(err, prev, file) + if r == GlobalFiles && ignoreConflict(file, err) { err = nil } return err @@ -132,11 +140,11 @@ func (r *Files) registerFile(fd protoreflect.FileDescriptor) error { } var err error var hasConflict bool - rangeTopLevelDescriptors(fd, func(d protoreflect.Descriptor) { + rangeTopLevelDescriptors(file, func(d protoreflect.Descriptor) { if prev := r.descsByName[d.FullName()]; prev != nil { hasConflict = true - err = errors.New("file %q has a name conflict over %v", fd.Path(), d.FullName()) - err = amendErrorWithCaller(err, prev, fd) + err = errors.New("file %q has a name conflict over %v", file.Path(), d.FullName()) + err = amendErrorWithCaller(err, prev, file) if r == GlobalFiles && ignoreConflict(d, err) { err = nil } @@ -146,17 +154,17 @@ func (r *Files) registerFile(fd protoreflect.FileDescriptor) error { return err } - for name := fd.Package(); name != ""; name = name.Parent() { + for name := file.Package(); name != ""; name = name.Parent() { if r.descsByName[name] == nil { r.descsByName[name] = &packageDescriptor{} } } - p := r.descsByName[fd.Package()].(*packageDescriptor) - p.files = append(p.files, fd) - rangeTopLevelDescriptors(fd, func(d protoreflect.Descriptor) { + p := r.descsByName[file.Package()].(*packageDescriptor) + p.files = append(p.files, file) + rangeTopLevelDescriptors(file, func(d protoreflect.Descriptor) { r.descsByName[d.FullName()] = d }) - r.filesByPath[path] = fd + r.filesByPath[path] = file return nil } @@ -361,6 +369,10 @@ func rangeTopLevelDescriptors(fd protoreflect.FileDescriptor, f func(protoreflec } // A Type is a protoreflect.EnumType, protoreflect.MessageType, or protoreflect.ExtensionType. +// +// Deprecated: Do not use. +// +// TODO: Remove. type Type interface{} // MessageTypeResolver is an interface for looking up messages. @@ -443,13 +455,15 @@ type Types struct { } type ( - typesByName map[protoreflect.FullName]Type + typesByName map[protoreflect.FullName]interface{} extensionsByMessage map[protoreflect.FullName]extensionsByNumber extensionsByNumber map[protoreflect.FieldNumber]protoreflect.ExtensionType ) // NewTypes returns a registry initialized with the provided set of types. // If there are conflicts, the first one takes precedence. +// +// Deprecated: Use RegisterMessage, RegisterEnum, or RegisterExtension. func NewTypes(typs ...Type) *Types { r := new(Types) r.Register(typs...) // ignore errors; first takes precedence @@ -458,88 +472,109 @@ func NewTypes(typs ...Type) *Types { // Register registers the provided list of descriptor types. // -// If a registration conflict occurs for enum, message, or extension types -// (e.g., two different types have the same full name), -// then the first type takes precedence and an error is returned. +// Deprecated: Use RegisterMessage, RegisterEnum, or RegisterExtension. func (r *Types) Register(typs ...Type) error { + var firstErr error + for _, typ := range typs { + var err error + switch t := typ.(type) { + case protoreflect.EnumType: + err = r.RegisterEnum(t) + case protoreflect.MessageType: + err = r.RegisterMessage(t) + case protoreflect.ExtensionType: + err = r.RegisterExtension(t) + default: + panic(fmt.Sprintf("invalid type: %T", t)) + } + if firstErr == nil { + firstErr = err + } + } + return firstErr +} + +// RegisterMessage registers the provided message type. +// +// If a naming conflict occurs, the type is not registered and an error is returned. +func (r *Types) RegisterMessage(mt protoreflect.MessageType) error { if r == GlobalTypes { globalMutex.Lock() defer globalMutex.Unlock() } - var firstErr error -typeLoop: - for _, typ := range typs { - switch typ.(type) { - case protoreflect.EnumType, protoreflect.MessageType, protoreflect.ExtensionType: - // Check for conflicts in typesByName. - var desc protoreflect.Descriptor - var pcnt *int - switch t := typ.(type) { - case protoreflect.EnumType: - desc = t.Descriptor() - pcnt = &r.numEnums - case protoreflect.MessageType: - desc = t.Descriptor() - pcnt = &r.numMessages - case protoreflect.ExtensionType: - desc = t.TypeDescriptor() - pcnt = &r.numExtensions - default: - panic(fmt.Sprintf("invalid type: %T", t)) - } - name := desc.FullName() - if prev := r.typesByName[name]; prev != nil { - err := errors.New("%v %v is already registered", typeName(typ), name) - err = amendErrorWithCaller(err, prev, typ) - if r == GlobalTypes && ignoreConflict(desc, err) { - err = nil - } - if firstErr == nil { - firstErr = err - } - continue typeLoop - } - // Check for conflicts in extensionsByMessage. - if xt, _ := typ.(protoreflect.ExtensionType); xt != nil { - xd := xt.TypeDescriptor() - field := xd.Number() - message := xd.ContainingMessage().FullName() - if prev := r.extensionsByMessage[message][field]; prev != nil { - err := errors.New("extension number %d is already registered on message %v", field, message) - err = amendErrorWithCaller(err, prev, typ) - if r == GlobalTypes && ignoreConflict(xd, err) { - err = nil - } - if firstErr == nil { - firstErr = err - } - continue typeLoop - } + if err := r.register("message", mt.Descriptor(), mt); err != nil { + return err + } + r.numMessages++ + return nil +} - // Update extensionsByMessage. - if r.extensionsByMessage == nil { - r.extensionsByMessage = make(extensionsByMessage) - } - if r.extensionsByMessage[message] == nil { - r.extensionsByMessage[message] = make(extensionsByNumber) - } - r.extensionsByMessage[message][field] = xt - } +// RegisterEnum registers the provided enum type. +// +// If a naming conflict occurs, the type is not registered and an error is returned. +func (r *Types) RegisterEnum(et protoreflect.EnumType) error { + if r == GlobalTypes { + globalMutex.Lock() + defer globalMutex.Unlock() + } - // Update typesByName and the count. - if r.typesByName == nil { - r.typesByName = make(typesByName) - } - r.typesByName[name] = typ - (*pcnt)++ - default: - if firstErr == nil { - firstErr = errors.New("invalid type: %v", typeName(typ)) - } + if err := r.register("enum", et.Descriptor(), et); err != nil { + return err + } + r.numEnums++ + return nil +} + +// RegisterExtension registers the provided extension type. +// +// If a naming conflict occurs, the type is not registered and an error is returned. +func (r *Types) RegisterExtension(xt protoreflect.ExtensionType) error { + if r == GlobalTypes { + globalMutex.Lock() + defer globalMutex.Unlock() + } + + xd := xt.TypeDescriptor() + field := xd.Number() + message := xd.ContainingMessage().FullName() + if prev := r.extensionsByMessage[message][field]; prev != nil { + err := errors.New("extension number %d is already registered on message %v", field, message) + err = amendErrorWithCaller(err, prev, xt) + if !(r == GlobalTypes && ignoreConflict(xd, err)) { + return err } } - return firstErr + + if err := r.register("extension", xt.TypeDescriptor(), xt); err != nil { + return err + } + if r.extensionsByMessage == nil { + r.extensionsByMessage = make(extensionsByMessage) + } + if r.extensionsByMessage[message] == nil { + r.extensionsByMessage[message] = make(extensionsByNumber) + } + r.extensionsByMessage[message][field] = xt + r.numExtensions++ + return nil +} + +func (r *Types) register(kind string, desc protoreflect.Descriptor, typ interface{}) error { + name := desc.FullName() + prev := r.typesByName[name] + if prev != nil { + err := errors.New("%v %v is already registered", kind, name) + err = amendErrorWithCaller(err, prev, typ) + if !(r == GlobalTypes && ignoreConflict(desc, err)) { + return err + } + } + if r.typesByName == nil { + r.typesByName = make(typesByName) + } + r.typesByName[name] = typ + return nil } // FindEnumByName looks up an enum by its full name. diff --git a/reflect/protoregistry/registry_test.go b/reflect/protoregistry/registry_test.go index 8d70243d..22fbdad6 100644 --- a/reflect/protoregistry/registry_test.go +++ b/reflect/protoregistry/registry_test.go @@ -282,7 +282,7 @@ func TestFiles(t *testing.T) { t.Run("", func(t *testing.T) { var files preg.Files for i, tc := range tt.files { - gotErr := files.Register(tc.inFile) + gotErr := files.RegisterFile(tc.inFile) if ((gotErr == nil) != (tc.wantErr == "")) || !strings.Contains(fmt.Sprint(gotErr), tc.wantErr) { t.Errorf("file %d, Register() = %v, want %v", i, gotErr, tc.wantErr) } @@ -332,8 +332,17 @@ func TestTypes(t *testing.T) { xt1 := testpb.E_StringField xt2 := testpb.E_Message4_MessageField registry := new(preg.Types) - if err := registry.Register(mt1, et1, xt1, xt2); err != nil { - t.Fatalf("registry.Register() returns unexpected error: %v", err) + if err := registry.RegisterMessage(mt1); err != nil { + t.Fatalf("registry.RegisterMessage(%v) returns unexpected error: %v", mt1.Descriptor().FullName(), err) + } + if err := registry.RegisterEnum(et1); err != nil { + t.Fatalf("registry.RegisterEnum(%v) returns unexpected error: %v", et1.Descriptor().FullName(), err) + } + if err := registry.RegisterExtension(xt1); err != nil { + t.Fatalf("registry.RegisterExtension(%v) returns unexpected error: %v", xt1.TypeDescriptor().FullName(), err) + } + if err := registry.RegisterExtension(xt2); err != nil { + t.Fatalf("registry.RegisterExtension(%v) returns unexpected error: %v", xt2.TypeDescriptor().FullName(), err) } t.Run("FindMessageByName", func(t *testing.T) {