From 1c33e1125a5f1e2500ced9568c31ffad13d752ec Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Tue, 4 Feb 2020 12:58:17 -0800 Subject: [PATCH] proto: make one test more general Tweak the "nested unknown extension" test case's resolver to not depend on the exact message being tested. Useful for if/when we want to run these tests on other message implementations. Change-Id: Id1722afd8e094ddb59cb3e5440f7994c20cfa681 Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/217760 Reviewed-by: Joe Tsai --- proto/testmessages_test.go | 41 ++++++++++++++++++++++++++++++++------ 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/proto/testmessages_test.go b/proto/testmessages_test.go index 045cb8fb..d980fd15 100644 --- a/proto/testmessages_test.go +++ b/proto/testmessages_test.go @@ -10,6 +10,7 @@ import ( "google.golang.org/protobuf/internal/impl" "google.golang.org/protobuf/internal/protobuild" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" legacypb "google.golang.org/protobuf/internal/testprotos/legacy" @@ -1389,12 +1390,17 @@ var testValidMessages = []testProto{ desc: "nested unknown extension", unmarshalOptions: proto.UnmarshalOptions{ DiscardUnknown: true, - Resolver: func() protoregistry.ExtensionTypeResolver { - types := &protoregistry.Types{} - types.RegisterExtension(testpb.E_OptionalNestedMessage) - types.RegisterExtension(testpb.E_OptionalInt32) - return types - }(), + Resolver: filterResolver{ + filter: func(name protoreflect.FullName) bool { + switch name.Name() { + case "optional_nested_message", + "optional_int32": + return true + } + return false + }, + resolver: protoregistry.GlobalTypes, + }, }, decodeTo: makeMessages(protobuild.Message{ "optional_nested_message": protobuild.Message{ @@ -1847,3 +1853,26 @@ var testInvalidMessages = []testProto{ }.Marshal(), }, } + +type filterResolver struct { + filter func(name protoreflect.FullName) bool + resolver protoregistry.ExtensionTypeResolver +} + +func (f filterResolver) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) { + if !f.filter(field) { + return nil, protoregistry.NotFound + } + return f.resolver.FindExtensionByName(field) +} + +func (f filterResolver) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) { + xt, err := f.resolver.FindExtensionByNumber(message, field) + if err != nil { + return nil, err + } + if !f.filter(xt.TypeDescriptor().FullName()) { + return nil, protoregistry.NotFound + } + return xt, nil +}