mirror of
https://github.com/protocolbuffers/protobuf-go.git
synced 2025-03-09 22:13:27 +00:00
reflect/protoregistry: add (*Types).Register{Message,Enum,Extension}
Add type-safe methods to register message, enum, and extension types. Deprecate the NewTypes function and the (*Types).Register method. Add (*File).RegisterFile and deprecate the NewFiles function and the (*File).Register method. Updates golang/protobuf#963 Change-Id: Ie89e77526e0874539e9bd929ca0ba8d758e65a6e Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/199898 Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
This commit is contained in:
parent
b76294a1a6
commit
c826885a2a
@ -157,7 +157,7 @@ func New(req *pluginpb.CodeGeneratorRequest, opts *Options) (*Plugin, error) {
|
|||||||
gen := &Plugin{
|
gen := &Plugin{
|
||||||
Request: req,
|
Request: req,
|
||||||
FilesByPath: make(map[string]*File),
|
FilesByPath: make(map[string]*File),
|
||||||
fileReg: protoregistry.NewFiles(),
|
fileReg: new(protoregistry.Files),
|
||||||
enumsByName: make(map[protoreflect.FullName]*Enum),
|
enumsByName: make(map[protoreflect.FullName]*Enum),
|
||||||
messagesByName: make(map[protoreflect.FullName]*Message),
|
messagesByName: make(map[protoreflect.FullName]*Message),
|
||||||
opts: opts,
|
opts: opts,
|
||||||
@ -440,7 +440,7 @@ func newFile(gen *Plugin, p *descriptorpb.FileDescriptorProto, packageName GoPac
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid FileDescriptorProto %q: %v", p.GetName(), err)
|
return nil, fmt.Errorf("invalid FileDescriptorProto %q: %v", p.GetName(), err)
|
||||||
}
|
}
|
||||||
if err := gen.fileReg.Register(desc); err != nil {
|
if err := gen.fileReg.RegisterFile(desc); err != nil {
|
||||||
return nil, fmt.Errorf("cannot register descriptor %q: %v", p.GetName(), err)
|
return nil, fmt.Errorf("cannot register descriptor %q: %v", p.GetName(), err)
|
||||||
}
|
}
|
||||||
f := &File{
|
f := &File{
|
||||||
|
@ -44,7 +44,7 @@ type Builder struct {
|
|||||||
FileRegistry interface {
|
FileRegistry interface {
|
||||||
FindFileByPath(string) (protoreflect.FileDescriptor, error)
|
FindFileByPath(string) (protoreflect.FileDescriptor, error)
|
||||||
FindDescriptorByName(pref.FullName) (pref.Descriptor, error)
|
FindDescriptorByName(pref.FullName) (pref.Descriptor, error)
|
||||||
Register(...pref.FileDescriptor) error
|
RegisterFile(pref.FileDescriptor) error
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -107,7 +107,7 @@ func (db Builder) Build() (out Out) {
|
|||||||
out.Extensions = fd.allExtensions
|
out.Extensions = fd.allExtensions
|
||||||
out.Services = fd.allServices
|
out.Services = fd.allServices
|
||||||
|
|
||||||
if err := db.FileRegistry.Register(fd); err != nil {
|
if err := db.FileRegistry.RegisterFile(fd); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
return out
|
return out
|
||||||
|
@ -108,7 +108,9 @@ type Builder struct {
|
|||||||
// TypeRegistry is the registry to register each type descriptor.
|
// TypeRegistry is the registry to register each type descriptor.
|
||||||
// If nil, it uses protoregistry.GlobalTypes.
|
// If nil, it uses protoregistry.GlobalTypes.
|
||||||
TypeRegistry interface {
|
TypeRegistry interface {
|
||||||
Register(...preg.Type) error
|
RegisterMessage(pref.MessageType) error
|
||||||
|
RegisterEnum(pref.EnumType) error
|
||||||
|
RegisterExtension(pref.ExtensionType) error
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -149,7 +151,7 @@ func (tb Builder) Build() (out Out) {
|
|||||||
Desc: &fbOut.Enums[i],
|
Desc: &fbOut.Enums[i],
|
||||||
}
|
}
|
||||||
// Register enum types.
|
// Register enum types.
|
||||||
if err := tb.TypeRegistry.Register(&tb.EnumInfos[i]); err != nil {
|
if err := tb.TypeRegistry.RegisterEnum(&tb.EnumInfos[i]); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -170,7 +172,7 @@ func (tb Builder) Build() (out Out) {
|
|||||||
tb.MessageInfos[i].Desc = &fbOut.Messages[i]
|
tb.MessageInfos[i].Desc = &fbOut.Messages[i]
|
||||||
|
|
||||||
// Register message types.
|
// Register message types.
|
||||||
if err := tb.TypeRegistry.Register(&tb.MessageInfos[i]); err != nil {
|
if err := tb.TypeRegistry.RegisterMessage(&tb.MessageInfos[i]); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -232,7 +234,7 @@ func (tb Builder) Build() (out Out) {
|
|||||||
pimpl.InitExtensionInfo(&tb.ExtensionInfos[i], &fbOut.Extensions[i], goType)
|
pimpl.InitExtensionInfo(&tb.ExtensionInfos[i], &fbOut.Extensions[i], goType)
|
||||||
|
|
||||||
// Register extension types.
|
// Register extension types.
|
||||||
if err := tb.TypeRegistry.Register(&tb.ExtensionInfos[i]); err != nil {
|
if err := tb.TypeRegistry.RegisterExtension(&tb.ExtensionInfos[i]); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -274,7 +276,7 @@ type (
|
|||||||
fileRegistry interface {
|
fileRegistry interface {
|
||||||
FindFileByPath(string) (pref.FileDescriptor, error)
|
FindFileByPath(string) (pref.FileDescriptor, error)
|
||||||
FindDescriptorByName(pref.FullName) (pref.Descriptor, error)
|
FindDescriptorByName(pref.FullName) (pref.Descriptor, error)
|
||||||
Register(...pref.FileDescriptor) error
|
RegisterFile(pref.FileDescriptor) error
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -70,4 +70,4 @@ type resolverOnly struct {
|
|||||||
*protoregistry.Files
|
*protoregistry.Files
|
||||||
}
|
}
|
||||||
|
|
||||||
func (resolverOnly) Register(...protoreflect.FileDescriptor) error { return nil }
|
func (resolverOnly) Register(protoreflect.FileDescriptor) error { return nil }
|
||||||
|
@ -52,8 +52,8 @@ var legacyFD = func() []byte {
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
mt := pimpl.Export{}.MessageTypeOf((*LegacyTestMessage)(nil))
|
mt := pimpl.Export{}.MessageTypeOf((*LegacyTestMessage)(nil))
|
||||||
preg.GlobalFiles.Register(mt.Descriptor().ParentFile())
|
preg.GlobalFiles.RegisterFile(mt.Descriptor().ParentFile())
|
||||||
preg.GlobalTypes.Register(mt)
|
preg.GlobalTypes.RegisterMessage(mt)
|
||||||
}
|
}
|
||||||
|
|
||||||
func mustMakeExtensionType(fileDesc, extDesc string, t reflect.Type, r pdesc.Resolver) pref.ExtensionType {
|
func mustMakeExtensionType(fileDesc, extDesc string, t reflect.Type, r pdesc.Resolver) pref.ExtensionType {
|
||||||
@ -82,7 +82,7 @@ var (
|
|||||||
testMessageV1Desc = pimpl.Export{}.MessageDescriptorOf((*proto2_20180125.Message_ChildMessage)(nil))
|
testMessageV1Desc = pimpl.Export{}.MessageDescriptorOf((*proto2_20180125.Message_ChildMessage)(nil))
|
||||||
testMessageV2Desc = enumMessagesType.Desc
|
testMessageV2Desc = enumMessagesType.Desc
|
||||||
|
|
||||||
depReg = preg.NewFiles(
|
depReg = newFileRegistry(
|
||||||
testParentDesc.ParentFile(),
|
testParentDesc.ParentFile(),
|
||||||
testEnumV1Desc.ParentFile(),
|
testEnumV1Desc.ParentFile(),
|
||||||
testMessageV1Desc.ParentFile(),
|
testMessageV1Desc.ParentFile(),
|
||||||
|
@ -990,7 +990,7 @@ var enumMessagesType = pimpl.MessageInfo{GoReflectType: reflect.TypeOf(new(EnumM
|
|||||||
{name:"F7Entry" field:[{name:"key" number:1 label:LABEL_OPTIONAL type:TYPE_STRING}, {name:"value" number:2 label:LABEL_OPTIONAL type:TYPE_ENUM type_name:".EnumProto3"}] options:{map_entry:true}},
|
{name:"F7Entry" field:[{name:"key" number:1 label:LABEL_OPTIONAL type:TYPE_STRING}, {name:"value" number:2 label:LABEL_OPTIONAL type:TYPE_ENUM type_name:".EnumProto3"}] options:{map_entry:true}},
|
||||||
{name:"F8Entry" field:[{name:"key" number:1 label:LABEL_OPTIONAL type:TYPE_STRING}, {name:"value" number:2 label:LABEL_OPTIONAL type:TYPE_MESSAGE type_name:".ScalarProto3"}] options:{map_entry:true}}
|
{name:"F8Entry" field:[{name:"key" number:1 label:LABEL_OPTIONAL type:TYPE_STRING}, {name:"value" number:2 label:LABEL_OPTIONAL type:TYPE_MESSAGE type_name:".ScalarProto3"}] options:{map_entry:true}}
|
||||||
]
|
]
|
||||||
`, protoregistry.NewFiles(
|
`, newFileRegistry(
|
||||||
EnumProto2(0).Descriptor().ParentFile(),
|
EnumProto2(0).Descriptor().ParentFile(),
|
||||||
EnumProto3(0).Descriptor().ParentFile(),
|
EnumProto3(0).Descriptor().ParentFile(),
|
||||||
((*ScalarProto2)(nil)).ProtoReflect().Descriptor().ParentFile(),
|
((*ScalarProto2)(nil)).ProtoReflect().Descriptor().ParentFile(),
|
||||||
@ -999,6 +999,14 @@ var enumMessagesType = pimpl.MessageInfo{GoReflectType: reflect.TypeOf(new(EnumM
|
|||||||
)),
|
)),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newFileRegistry(files ...pref.FileDescriptor) *protoregistry.Files {
|
||||||
|
r := new(protoregistry.Files)
|
||||||
|
for _, file := range files {
|
||||||
|
r.RegisterFile(file)
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
func (m *EnumMessages) ProtoReflect() pref.Message { return enumMessagesType.MessageOf(m) }
|
func (m *EnumMessages) ProtoReflect() pref.Message { return enumMessagesType.MessageOf(m) }
|
||||||
|
|
||||||
func (*EnumMessages) XXX_OneofWrappers() []interface{} {
|
func (*EnumMessages) XXX_OneofWrappers() []interface{} {
|
||||||
|
@ -61,7 +61,7 @@ func RegisterFile(s string, d []byte) {
|
|||||||
|
|
||||||
func RegisterType(m Message, s string) {
|
func RegisterType(m Message, s string) {
|
||||||
mt := protoimpl.X.LegacyMessageTypeOf(m, protoreflect.FullName(s))
|
mt := protoimpl.X.LegacyMessageTypeOf(m, protoreflect.FullName(s))
|
||||||
if err := protoregistry.GlobalTypes.Register(mt); err != nil {
|
if err := protoregistry.GlobalTypes.RegisterMessage(mt); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -75,7 +75,7 @@ func RegisterEnum(string, map[int32]string, map[string]int32) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func RegisterExtension(d *ExtensionDesc) {
|
func RegisterExtension(d *ExtensionDesc) {
|
||||||
if err := protoregistry.GlobalTypes.Register(d); err != nil {
|
if err := protoregistry.GlobalTypes.RegisterExtension(d); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -916,7 +916,7 @@ func TestNewFile(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dependency %d: unexpected NewFile() error: %v", i, err)
|
t.Fatalf("dependency %d: unexpected NewFile() error: %v", i, err)
|
||||||
}
|
}
|
||||||
if err := r.Register(f); err != nil {
|
if err := r.RegisterFile(f); err != nil {
|
||||||
t.Fatalf("dependency %d: unexpected Register() error: %v", i, err)
|
t.Fatalf("dependency %d: unexpected Register() error: %v", i, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -75,20 +75,37 @@ type packageDescriptor struct {
|
|||||||
|
|
||||||
// NewFiles returns a registry initialized with the provided set of files.
|
// NewFiles returns a registry initialized with the provided set of files.
|
||||||
// Files with a namespace conflict with an pre-existing file are not registered.
|
// Files with a namespace conflict with an pre-existing file are not registered.
|
||||||
|
//
|
||||||
|
// Deprecated: Use Register.
|
||||||
func NewFiles(files ...protoreflect.FileDescriptor) *Files {
|
func NewFiles(files ...protoreflect.FileDescriptor) *Files {
|
||||||
r := new(Files)
|
r := new(Files)
|
||||||
r.Register(files...) // ignore errors; first takes precedence
|
for _, file := range files {
|
||||||
|
r.RegisterFile(file) // ignore errors; first takes precedence
|
||||||
|
}
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register registers the provided list of file descriptors.
|
// Register registers the provided list of file descriptors.
|
||||||
//
|
//
|
||||||
// If any descriptor within a file conflicts with the descriptor of any
|
// Deprecated: Use RegisterFile.
|
||||||
|
func (r *Files) Register(files ...protoreflect.FileDescriptor) error {
|
||||||
|
var firstErr error
|
||||||
|
for _, file := range files {
|
||||||
|
if err := r.RegisterFile(file); err != nil && firstErr == nil {
|
||||||
|
firstErr = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return firstErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterFile registers the provided file descriptor.
|
||||||
|
//
|
||||||
|
// If any descriptor within the file conflicts with the descriptor of any
|
||||||
// previously registered file (e.g., two enums with the same full name),
|
// previously registered file (e.g., two enums with the same full name),
|
||||||
// then that file is not registered and an error is returned.
|
// then the file is not registered and an error is returned.
|
||||||
//
|
//
|
||||||
// It is permitted for multiple files to have the same file path.
|
// It is permitted for multiple files to have the same file path.
|
||||||
func (r *Files) Register(files ...protoreflect.FileDescriptor) error {
|
func (r *Files) RegisterFile(file protoreflect.FileDescriptor) error {
|
||||||
if r == GlobalFiles {
|
if r == GlobalFiles {
|
||||||
globalMutex.Lock()
|
globalMutex.Lock()
|
||||||
defer globalMutex.Unlock()
|
defer globalMutex.Unlock()
|
||||||
@ -99,32 +116,23 @@ func (r *Files) Register(files ...protoreflect.FileDescriptor) error {
|
|||||||
}
|
}
|
||||||
r.filesByPath = make(map[string]protoreflect.FileDescriptor)
|
r.filesByPath = make(map[string]protoreflect.FileDescriptor)
|
||||||
}
|
}
|
||||||
var firstErr error
|
path := file.Path()
|
||||||
for _, file := range files {
|
|
||||||
if err := r.registerFile(file); err != nil && firstErr == nil {
|
|
||||||
firstErr = err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return firstErr
|
|
||||||
}
|
|
||||||
func (r *Files) registerFile(fd protoreflect.FileDescriptor) error {
|
|
||||||
path := fd.Path()
|
|
||||||
if prev := r.filesByPath[path]; prev != nil {
|
if prev := r.filesByPath[path]; prev != nil {
|
||||||
err := errors.New("file %q is already registered", fd.Path())
|
err := errors.New("file %q is already registered", file.Path())
|
||||||
err = amendErrorWithCaller(err, prev, fd)
|
err = amendErrorWithCaller(err, prev, file)
|
||||||
if r == GlobalFiles && ignoreConflict(fd, err) {
|
if r == GlobalFiles && ignoreConflict(file, err) {
|
||||||
err = nil
|
err = nil
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for name := fd.Package(); name != ""; name = name.Parent() {
|
for name := file.Package(); name != ""; name = name.Parent() {
|
||||||
switch prev := r.descsByName[name]; prev.(type) {
|
switch prev := r.descsByName[name]; prev.(type) {
|
||||||
case nil, *packageDescriptor:
|
case nil, *packageDescriptor:
|
||||||
default:
|
default:
|
||||||
err := errors.New("file %q has a package name conflict over %v", fd.Path(), name)
|
err := errors.New("file %q has a package name conflict over %v", file.Path(), name)
|
||||||
err = amendErrorWithCaller(err, prev, fd)
|
err = amendErrorWithCaller(err, prev, file)
|
||||||
if r == GlobalFiles && ignoreConflict(fd, err) {
|
if r == GlobalFiles && ignoreConflict(file, err) {
|
||||||
err = nil
|
err = nil
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
@ -132,11 +140,11 @@ func (r *Files) registerFile(fd protoreflect.FileDescriptor) error {
|
|||||||
}
|
}
|
||||||
var err error
|
var err error
|
||||||
var hasConflict bool
|
var hasConflict bool
|
||||||
rangeTopLevelDescriptors(fd, func(d protoreflect.Descriptor) {
|
rangeTopLevelDescriptors(file, func(d protoreflect.Descriptor) {
|
||||||
if prev := r.descsByName[d.FullName()]; prev != nil {
|
if prev := r.descsByName[d.FullName()]; prev != nil {
|
||||||
hasConflict = true
|
hasConflict = true
|
||||||
err = errors.New("file %q has a name conflict over %v", fd.Path(), d.FullName())
|
err = errors.New("file %q has a name conflict over %v", file.Path(), d.FullName())
|
||||||
err = amendErrorWithCaller(err, prev, fd)
|
err = amendErrorWithCaller(err, prev, file)
|
||||||
if r == GlobalFiles && ignoreConflict(d, err) {
|
if r == GlobalFiles && ignoreConflict(d, err) {
|
||||||
err = nil
|
err = nil
|
||||||
}
|
}
|
||||||
@ -146,17 +154,17 @@ func (r *Files) registerFile(fd protoreflect.FileDescriptor) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for name := fd.Package(); name != ""; name = name.Parent() {
|
for name := file.Package(); name != ""; name = name.Parent() {
|
||||||
if r.descsByName[name] == nil {
|
if r.descsByName[name] == nil {
|
||||||
r.descsByName[name] = &packageDescriptor{}
|
r.descsByName[name] = &packageDescriptor{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
p := r.descsByName[fd.Package()].(*packageDescriptor)
|
p := r.descsByName[file.Package()].(*packageDescriptor)
|
||||||
p.files = append(p.files, fd)
|
p.files = append(p.files, file)
|
||||||
rangeTopLevelDescriptors(fd, func(d protoreflect.Descriptor) {
|
rangeTopLevelDescriptors(file, func(d protoreflect.Descriptor) {
|
||||||
r.descsByName[d.FullName()] = d
|
r.descsByName[d.FullName()] = d
|
||||||
})
|
})
|
||||||
r.filesByPath[path] = fd
|
r.filesByPath[path] = file
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -361,6 +369,10 @@ func rangeTopLevelDescriptors(fd protoreflect.FileDescriptor, f func(protoreflec
|
|||||||
}
|
}
|
||||||
|
|
||||||
// A Type is a protoreflect.EnumType, protoreflect.MessageType, or protoreflect.ExtensionType.
|
// A Type is a protoreflect.EnumType, protoreflect.MessageType, or protoreflect.ExtensionType.
|
||||||
|
//
|
||||||
|
// Deprecated: Do not use.
|
||||||
|
//
|
||||||
|
// TODO: Remove.
|
||||||
type Type interface{}
|
type Type interface{}
|
||||||
|
|
||||||
// MessageTypeResolver is an interface for looking up messages.
|
// MessageTypeResolver is an interface for looking up messages.
|
||||||
@ -443,13 +455,15 @@ type Types struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type (
|
type (
|
||||||
typesByName map[protoreflect.FullName]Type
|
typesByName map[protoreflect.FullName]interface{}
|
||||||
extensionsByMessage map[protoreflect.FullName]extensionsByNumber
|
extensionsByMessage map[protoreflect.FullName]extensionsByNumber
|
||||||
extensionsByNumber map[protoreflect.FieldNumber]protoreflect.ExtensionType
|
extensionsByNumber map[protoreflect.FieldNumber]protoreflect.ExtensionType
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewTypes returns a registry initialized with the provided set of types.
|
// NewTypes returns a registry initialized with the provided set of types.
|
||||||
// If there are conflicts, the first one takes precedence.
|
// If there are conflicts, the first one takes precedence.
|
||||||
|
//
|
||||||
|
// Deprecated: Use RegisterMessage, RegisterEnum, or RegisterExtension.
|
||||||
func NewTypes(typs ...Type) *Types {
|
func NewTypes(typs ...Type) *Types {
|
||||||
r := new(Types)
|
r := new(Types)
|
||||||
r.Register(typs...) // ignore errors; first takes precedence
|
r.Register(typs...) // ignore errors; first takes precedence
|
||||||
@ -458,88 +472,109 @@ func NewTypes(typs ...Type) *Types {
|
|||||||
|
|
||||||
// Register registers the provided list of descriptor types.
|
// Register registers the provided list of descriptor types.
|
||||||
//
|
//
|
||||||
// If a registration conflict occurs for enum, message, or extension types
|
// Deprecated: Use RegisterMessage, RegisterEnum, or RegisterExtension.
|
||||||
// (e.g., two different types have the same full name),
|
|
||||||
// then the first type takes precedence and an error is returned.
|
|
||||||
func (r *Types) Register(typs ...Type) error {
|
func (r *Types) Register(typs ...Type) error {
|
||||||
|
var firstErr error
|
||||||
|
for _, typ := range typs {
|
||||||
|
var err error
|
||||||
|
switch t := typ.(type) {
|
||||||
|
case protoreflect.EnumType:
|
||||||
|
err = r.RegisterEnum(t)
|
||||||
|
case protoreflect.MessageType:
|
||||||
|
err = r.RegisterMessage(t)
|
||||||
|
case protoreflect.ExtensionType:
|
||||||
|
err = r.RegisterExtension(t)
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("invalid type: %T", t))
|
||||||
|
}
|
||||||
|
if firstErr == nil {
|
||||||
|
firstErr = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return firstErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterMessage registers the provided message type.
|
||||||
|
//
|
||||||
|
// If a naming conflict occurs, the type is not registered and an error is returned.
|
||||||
|
func (r *Types) RegisterMessage(mt protoreflect.MessageType) error {
|
||||||
if r == GlobalTypes {
|
if r == GlobalTypes {
|
||||||
globalMutex.Lock()
|
globalMutex.Lock()
|
||||||
defer globalMutex.Unlock()
|
defer globalMutex.Unlock()
|
||||||
}
|
}
|
||||||
var firstErr error
|
|
||||||
typeLoop:
|
|
||||||
for _, typ := range typs {
|
|
||||||
switch typ.(type) {
|
|
||||||
case protoreflect.EnumType, protoreflect.MessageType, protoreflect.ExtensionType:
|
|
||||||
// Check for conflicts in typesByName.
|
|
||||||
var desc protoreflect.Descriptor
|
|
||||||
var pcnt *int
|
|
||||||
switch t := typ.(type) {
|
|
||||||
case protoreflect.EnumType:
|
|
||||||
desc = t.Descriptor()
|
|
||||||
pcnt = &r.numEnums
|
|
||||||
case protoreflect.MessageType:
|
|
||||||
desc = t.Descriptor()
|
|
||||||
pcnt = &r.numMessages
|
|
||||||
case protoreflect.ExtensionType:
|
|
||||||
desc = t.TypeDescriptor()
|
|
||||||
pcnt = &r.numExtensions
|
|
||||||
default:
|
|
||||||
panic(fmt.Sprintf("invalid type: %T", t))
|
|
||||||
}
|
|
||||||
name := desc.FullName()
|
|
||||||
if prev := r.typesByName[name]; prev != nil {
|
|
||||||
err := errors.New("%v %v is already registered", typeName(typ), name)
|
|
||||||
err = amendErrorWithCaller(err, prev, typ)
|
|
||||||
if r == GlobalTypes && ignoreConflict(desc, err) {
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
if firstErr == nil {
|
|
||||||
firstErr = err
|
|
||||||
}
|
|
||||||
continue typeLoop
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for conflicts in extensionsByMessage.
|
if err := r.register("message", mt.Descriptor(), mt); err != nil {
|
||||||
if xt, _ := typ.(protoreflect.ExtensionType); xt != nil {
|
return err
|
||||||
xd := xt.TypeDescriptor()
|
}
|
||||||
field := xd.Number()
|
r.numMessages++
|
||||||
message := xd.ContainingMessage().FullName()
|
return nil
|
||||||
if prev := r.extensionsByMessage[message][field]; prev != nil {
|
}
|
||||||
err := errors.New("extension number %d is already registered on message %v", field, message)
|
|
||||||
err = amendErrorWithCaller(err, prev, typ)
|
|
||||||
if r == GlobalTypes && ignoreConflict(xd, err) {
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
if firstErr == nil {
|
|
||||||
firstErr = err
|
|
||||||
}
|
|
||||||
continue typeLoop
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update extensionsByMessage.
|
// RegisterEnum registers the provided enum type.
|
||||||
if r.extensionsByMessage == nil {
|
//
|
||||||
r.extensionsByMessage = make(extensionsByMessage)
|
// If a naming conflict occurs, the type is not registered and an error is returned.
|
||||||
}
|
func (r *Types) RegisterEnum(et protoreflect.EnumType) error {
|
||||||
if r.extensionsByMessage[message] == nil {
|
if r == GlobalTypes {
|
||||||
r.extensionsByMessage[message] = make(extensionsByNumber)
|
globalMutex.Lock()
|
||||||
}
|
defer globalMutex.Unlock()
|
||||||
r.extensionsByMessage[message][field] = xt
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Update typesByName and the count.
|
if err := r.register("enum", et.Descriptor(), et); err != nil {
|
||||||
if r.typesByName == nil {
|
return err
|
||||||
r.typesByName = make(typesByName)
|
}
|
||||||
}
|
r.numEnums++
|
||||||
r.typesByName[name] = typ
|
return nil
|
||||||
(*pcnt)++
|
}
|
||||||
default:
|
|
||||||
if firstErr == nil {
|
// RegisterExtension registers the provided extension type.
|
||||||
firstErr = errors.New("invalid type: %v", typeName(typ))
|
//
|
||||||
}
|
// If a naming conflict occurs, the type is not registered and an error is returned.
|
||||||
|
func (r *Types) RegisterExtension(xt protoreflect.ExtensionType) error {
|
||||||
|
if r == GlobalTypes {
|
||||||
|
globalMutex.Lock()
|
||||||
|
defer globalMutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
xd := xt.TypeDescriptor()
|
||||||
|
field := xd.Number()
|
||||||
|
message := xd.ContainingMessage().FullName()
|
||||||
|
if prev := r.extensionsByMessage[message][field]; prev != nil {
|
||||||
|
err := errors.New("extension number %d is already registered on message %v", field, message)
|
||||||
|
err = amendErrorWithCaller(err, prev, xt)
|
||||||
|
if !(r == GlobalTypes && ignoreConflict(xd, err)) {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return firstErr
|
|
||||||
|
if err := r.register("extension", xt.TypeDescriptor(), xt); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if r.extensionsByMessage == nil {
|
||||||
|
r.extensionsByMessage = make(extensionsByMessage)
|
||||||
|
}
|
||||||
|
if r.extensionsByMessage[message] == nil {
|
||||||
|
r.extensionsByMessage[message] = make(extensionsByNumber)
|
||||||
|
}
|
||||||
|
r.extensionsByMessage[message][field] = xt
|
||||||
|
r.numExtensions++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Types) register(kind string, desc protoreflect.Descriptor, typ interface{}) error {
|
||||||
|
name := desc.FullName()
|
||||||
|
prev := r.typesByName[name]
|
||||||
|
if prev != nil {
|
||||||
|
err := errors.New("%v %v is already registered", kind, name)
|
||||||
|
err = amendErrorWithCaller(err, prev, typ)
|
||||||
|
if !(r == GlobalTypes && ignoreConflict(desc, err)) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if r.typesByName == nil {
|
||||||
|
r.typesByName = make(typesByName)
|
||||||
|
}
|
||||||
|
r.typesByName[name] = typ
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// FindEnumByName looks up an enum by its full name.
|
// FindEnumByName looks up an enum by its full name.
|
||||||
|
@ -282,7 +282,7 @@ func TestFiles(t *testing.T) {
|
|||||||
t.Run("", func(t *testing.T) {
|
t.Run("", func(t *testing.T) {
|
||||||
var files preg.Files
|
var files preg.Files
|
||||||
for i, tc := range tt.files {
|
for i, tc := range tt.files {
|
||||||
gotErr := files.Register(tc.inFile)
|
gotErr := files.RegisterFile(tc.inFile)
|
||||||
if ((gotErr == nil) != (tc.wantErr == "")) || !strings.Contains(fmt.Sprint(gotErr), tc.wantErr) {
|
if ((gotErr == nil) != (tc.wantErr == "")) || !strings.Contains(fmt.Sprint(gotErr), tc.wantErr) {
|
||||||
t.Errorf("file %d, Register() = %v, want %v", i, gotErr, tc.wantErr)
|
t.Errorf("file %d, Register() = %v, want %v", i, gotErr, tc.wantErr)
|
||||||
}
|
}
|
||||||
@ -332,8 +332,17 @@ func TestTypes(t *testing.T) {
|
|||||||
xt1 := testpb.E_StringField
|
xt1 := testpb.E_StringField
|
||||||
xt2 := testpb.E_Message4_MessageField
|
xt2 := testpb.E_Message4_MessageField
|
||||||
registry := new(preg.Types)
|
registry := new(preg.Types)
|
||||||
if err := registry.Register(mt1, et1, xt1, xt2); err != nil {
|
if err := registry.RegisterMessage(mt1); err != nil {
|
||||||
t.Fatalf("registry.Register() returns unexpected error: %v", err)
|
t.Fatalf("registry.RegisterMessage(%v) returns unexpected error: %v", mt1.Descriptor().FullName(), err)
|
||||||
|
}
|
||||||
|
if err := registry.RegisterEnum(et1); err != nil {
|
||||||
|
t.Fatalf("registry.RegisterEnum(%v) returns unexpected error: %v", et1.Descriptor().FullName(), err)
|
||||||
|
}
|
||||||
|
if err := registry.RegisterExtension(xt1); err != nil {
|
||||||
|
t.Fatalf("registry.RegisterExtension(%v) returns unexpected error: %v", xt1.TypeDescriptor().FullName(), err)
|
||||||
|
}
|
||||||
|
if err := registry.RegisterExtension(xt2); err != nil {
|
||||||
|
t.Fatalf("registry.RegisterExtension(%v) returns unexpected error: %v", xt2.TypeDescriptor().FullName(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("FindMessageByName", func(t *testing.T) {
|
t.Run("FindMessageByName", func(t *testing.T) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user