From 1bca6d9b7d9696b10a482f58688d4eeaaf974b8a Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Tue, 25 Apr 2023 16:37:46 -0700 Subject: [PATCH] types/dynamicpb: add NewTypes Add a function to construct a dynamic type registry from a protoregistry.Files. The NewTypes constructor takes a concrete Files to permit future improvements based on changes to Files. (For example, we might add a Files.FindExtensionByNumber method, which Types could take advantage of.) Fixes golang/protobuf#1216 Change-Id: I61edba0a94528829d40f69fad773ccb5912859e0 Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/489316 Run-TryBot: Damien Neil Reviewed-by: Lasse Folger Reviewed-by: Joseph Tsai --- types/dynamicpb/types.go | 177 ++++++++++++++++++++++++++++++++++ types/dynamicpb/types_test.go | 174 +++++++++++++++++++++++++++++++++ 2 files changed, 351 insertions(+) create mode 100644 types/dynamicpb/types.go create mode 100644 types/dynamicpb/types_test.go diff --git a/types/dynamicpb/types.go b/types/dynamicpb/types.go new file mode 100644 index 00000000..5a8010f1 --- /dev/null +++ b/types/dynamicpb/types.go @@ -0,0 +1,177 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dynamicpb + +import ( + "fmt" + "strings" + "sync" + "sync/atomic" + + "google.golang.org/protobuf/internal/errors" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" +) + +type extField struct { + name protoreflect.FullName + number protoreflect.FieldNumber +} + +// A Types is a collection of dynamically constructed descriptors. +// Its methods are safe for concurrent use. +// +// Types implements protoregistry.MessageTypeResolver and protoregistry.ExtensionTypeResolver. +// A Types may be used as a proto.UnmarshalOptions.Resolver. +type Types struct { + files *protoregistry.Files + + extMu sync.Mutex + atomicExtFiles uint64 + extensionsByMessage map[extField]protoreflect.ExtensionDescriptor +} + +// NewTypes creates a new Types registry with the provided files. +// The Files registry is retained, and changes to Files will be reflected in Types. +// It is not safe to concurrently change the Files while calling Types methods. +func NewTypes(f *protoregistry.Files) *Types { + return &Types{ + files: f, + } +} + +// FindEnumByName looks up an enum by its full name; +// e.g., "google.protobuf.Field.Kind". +// +// This returns (nil, protoregistry.NotFound) if not found. +func (t *Types) FindEnumByName(name protoreflect.FullName) (protoreflect.EnumType, error) { + d, err := t.files.FindDescriptorByName(name) + if err != nil { + return nil, err + } + ed, ok := d.(protoreflect.EnumDescriptor) + if !ok { + return nil, errors.New("found wrong type: got %v, want enum", descName(d)) + } + return NewEnumType(ed), nil +} + +// FindExtensionByName looks up an extension field by the field's full name. +// Note that this is the full name of the field as determined by +// where the extension is declared and is unrelated to the full name of the +// message being extended. +// +// This returns (nil, protoregistry.NotFound) if not found. +func (t *Types) FindExtensionByName(name protoreflect.FullName) (protoreflect.ExtensionType, error) { + d, err := t.files.FindDescriptorByName(name) + if err != nil { + return nil, err + } + xd, ok := d.(protoreflect.ExtensionDescriptor) + if !ok { + return nil, errors.New("found wrong type: got %v, want extension", descName(d)) + } + return NewExtensionType(xd), nil +} + +// FindExtensionByNumber looks up an extension field by the field number +// within some parent message, identified by full name. +// +// This returns (nil, protoregistry.NotFound) if not found. +func (t *Types) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) { + // Construct the extension number map lazily, since not every user will need it. + // Update the map if new files are added to the registry. + if atomic.LoadUint64(&t.atomicExtFiles) != uint64(t.files.NumFiles()) { + t.updateExtensions() + } + xd := t.extensionsByMessage[extField{message, field}] + if xd == nil { + return nil, protoregistry.NotFound + } + return NewExtensionType(xd), nil +} + +// FindMessageByName looks up a message by its full name; +// e.g. "google.protobuf.Any". +// +// This returns (nil, protoregistry.NotFound) if not found. +func (t *Types) FindMessageByName(name protoreflect.FullName) (protoreflect.MessageType, error) { + d, err := t.files.FindDescriptorByName(name) + if err != nil { + return nil, err + } + md, ok := d.(protoreflect.MessageDescriptor) + if !ok { + return nil, errors.New("found wrong type: got %v, want message", descName(d)) + } + return NewMessageType(md), nil +} + +// FindMessageByURL looks up a message by a URL identifier. +// See documentation on google.protobuf.Any.type_url for the URL format. +// +// This returns (nil, protoregistry.NotFound) if not found. +func (t *Types) FindMessageByURL(url string) (protoreflect.MessageType, error) { + // This function is similar to FindMessageByName but + // truncates anything before and including '/' in the URL. + message := protoreflect.FullName(url) + if i := strings.LastIndexByte(url, '/'); i >= 0 { + message = message[i+len("/"):] + } + return t.FindMessageByName(message) +} + +func (t *Types) updateExtensions() { + t.extMu.Lock() + defer t.extMu.Unlock() + if atomic.LoadUint64(&t.atomicExtFiles) == uint64(t.files.NumFiles()) { + return + } + defer atomic.StoreUint64(&t.atomicExtFiles, uint64(t.files.NumFiles())) + t.files.RangeFiles(func(fd protoreflect.FileDescriptor) bool { + t.registerExtensions(fd.Extensions()) + t.registerExtensionsInMessages(fd.Messages()) + return true + }) +} + +func (t *Types) registerExtensionsInMessages(mds protoreflect.MessageDescriptors) { + count := mds.Len() + for i := 0; i < count; i++ { + md := mds.Get(i) + t.registerExtensions(md.Extensions()) + t.registerExtensionsInMessages(md.Messages()) + } +} + +func (t *Types) registerExtensions(xds protoreflect.ExtensionDescriptors) { + count := xds.Len() + for i := 0; i < count; i++ { + xd := xds.Get(i) + field := xd.Number() + message := xd.ContainingMessage().FullName() + if t.extensionsByMessage == nil { + t.extensionsByMessage = make(map[extField]protoreflect.ExtensionDescriptor) + } + t.extensionsByMessage[extField{message, field}] = xd + } +} + +func descName(d protoreflect.Descriptor) string { + switch d.(type) { + case protoreflect.EnumDescriptor: + return "enum" + case protoreflect.EnumValueDescriptor: + return "enum value" + case protoreflect.MessageDescriptor: + return "message" + case protoreflect.ExtensionDescriptor: + return "extension" + case protoreflect.ServiceDescriptor: + return "service" + default: + return fmt.Sprintf("%T", d) + } +} diff --git a/types/dynamicpb/types_test.go b/types/dynamicpb/types_test.go new file mode 100644 index 00000000..1878f794 --- /dev/null +++ b/types/dynamicpb/types_test.go @@ -0,0 +1,174 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dynamicpb_test + +import ( + "strings" + "testing" + + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/types/descriptorpb" + "google.golang.org/protobuf/types/dynamicpb" + + registrypb "google.golang.org/protobuf/internal/testprotos/registry" +) + +var _ protoregistry.ExtensionTypeResolver = &dynamicpb.Types{} +var _ protoregistry.MessageTypeResolver = &dynamicpb.Types{} + +func newTestTypes() *dynamicpb.Types { + files := &protoregistry.Files{} + files.RegisterFile(registrypb.File_internal_testprotos_registry_test_proto) + return dynamicpb.NewTypes(files) +} + +func TestDynamicTypesTypeMismatch(t *testing.T) { + types := newTestTypes() + const messageName = "testprotos.Message1" + const enumName = "testprotos.Enum1" + + _, err := types.FindEnumByName(messageName) + want := "found wrong type: got message, want enum" + if err == nil || !strings.Contains(err.Error(), want) { + t.Errorf("types.FindEnumByName(%q) = _, %q, want %q", messageName, err, want) + } + + _, err = types.FindMessageByName(enumName) + want = "found wrong type: got enum, want message" + if err == nil || !strings.Contains(err.Error(), want) { + t.Errorf("types.FindMessageByName(%q) = _, %q, want %q", messageName, err, want) + } + + _, err = types.FindExtensionByName(enumName) + want = "found wrong type: got enum, want extension" + if err == nil || !strings.Contains(err.Error(), want) { + t.Errorf("types.FindExtensionByName(%q) = _, %q, want %q", messageName, err, want) + } +} + +func TestDynamicTypesEnumNotFound(t *testing.T) { + types := newTestTypes() + for _, name := range []protoreflect.FullName{ + "Enum1", + "testprotos.DoesNotExist", + } { + _, err := types.FindEnumByName(name) + if err != protoregistry.NotFound { + t.Errorf("types.FindEnumByName(%q) = _, %v; want protoregistry.NotFound", name, err) + } + } +} + +func TestDynamicTypesFindEnumByName(t *testing.T) { + types := newTestTypes() + name := protoreflect.FullName("testprotos.Enum1") + et, err := types.FindEnumByName(name) + if err != nil { + t.Fatalf("types.FindEnumByName(%q) = %v", name, err) + } + if got, want := et.Descriptor().FullName(), name; got != want { + t.Fatalf("types.FindEnumByName(%q).Descriptor().FullName() = %q, want %q", name, got, want) + } +} + +func TestDynamicTypesMessageNotFound(t *testing.T) { + types := newTestTypes() + for _, name := range []protoreflect.FullName{ + "Message1", + "testprotos.DoesNotExist", + } { + _, err := types.FindMessageByName(name) + if err != protoregistry.NotFound { + t.Errorf("types.FindMessageByName(%q) = _, %v; want protoregistry.NotFound", name, err) + } + } +} + +func TestDynamicTypesFindMessageByName(t *testing.T) { + types := newTestTypes() + name := protoreflect.FullName("testprotos.Message1") + mt, err := types.FindMessageByName(name) + if err != nil { + t.Fatalf("types.FindMessageByName(%q) = %v", name, err) + } + if got, want := mt.Descriptor().FullName(), name; got != want { + t.Fatalf("types.FindMessageByName(%q).Descriptor().FullName() = %q, want %q", name, got, want) + } +} + +func TestDynamicTypesExtensionNotFound(t *testing.T) { + types := newTestTypes() + for _, name := range []protoreflect.FullName{ + "string_field", + "testprotos.DoesNotExist", + } { + _, err := types.FindExtensionByName(name) + if err != protoregistry.NotFound { + t.Errorf("types.FindExtensionByName(%q) = _, %v; want protoregistry.NotFound", name, err) + } + } + messageName := protoreflect.FullName("testprotos.Message1") + if _, err := types.FindExtensionByNumber(messageName, 100); err != protoregistry.NotFound { + t.Errorf("types.FindExtensionByNumber(%q, 100) = _, %v; want protoregistry.NotFound", messageName, 100) + } +} + +func TestDynamicTypesFindExtensionByNameOrNumber(t *testing.T) { + types := newTestTypes() + messageName := protoreflect.FullName("testprotos.Message1") + mt, err := types.FindMessageByName(messageName) + if err != nil { + t.Fatalf("types.FindMessageByName(%q) = %v", messageName, err) + } + for _, extensionName := range []protoreflect.FullName{ + "testprotos.string_field", + "testprotos.Message4.message_field", + } { + xt, err := types.FindExtensionByName(extensionName) + if err != nil { + t.Fatalf("types.FindExtensionByName(%q) = %v", extensionName, err) + } + if got, want := xt.TypeDescriptor().FullName(), extensionName; got != want { + t.Fatalf("types.FindExtensionByName(%q).TypeDescriptor().FullName() = %q, want %q", extensionName, got, want) + } + if got, want := xt.TypeDescriptor().ContainingMessage(), mt.Descriptor(); got != want { + t.Fatalf("xt.TypeDescriptor().ContainingMessage() = %q, want %q", got.FullName(), want.FullName()) + } + number := xt.TypeDescriptor().Number() + xt2, err := types.FindExtensionByNumber(messageName, number) + if err != nil { + t.Fatalf("types.FindExtensionByNumber(%q, %v) = %v", messageName, number, err) + } + if xt != xt2 { + t.Fatalf("FindExtensionByName returned a differet extension than FindExtensionByNumber") + } + } +} + +func TestDynamicTypesFilesChangeAfterCreation(t *testing.T) { + files := &protoregistry.Files{} + files.RegisterFile(descriptorpb.File_google_protobuf_descriptor_proto) + types := dynamicpb.NewTypes(files) + + // Not found: Files registry does not contain this file. + const message = "testprotos.Message1" + const number = 11 + if _, err := types.FindMessageByName(message); err != protoregistry.NotFound { + t.Errorf("types.FindMessageByName(%q) = %v, want protoregistry.NotFound", message, err) + } + if _, err := types.FindExtensionByNumber(message, number); err != protoregistry.NotFound { + t.Errorf("types.FindExtensionByNumber(%q, %v) = %v, want protoregistry.NotFound", message, number, err) + } + + // Found: Add the file to the registry and recheck. + files.RegisterFile(registrypb.File_internal_testprotos_registry_test_proto) + if _, err := types.FindMessageByName(message); err != nil { + t.Errorf("types.FindMessageByName(%q) = %v, want nil", message, err) + } + if _, err := types.FindExtensionByNumber(message, number); err != nil { + t.Errorf("types.FindExtensionByNumber(%q, %v) = %v, want nil", message, number, err) + } +}