proto: make one test more general

Tweak the "nested unknown extension" test case's resolver to not depend
on the exact message being tested. Useful for if/when we want to run
these tests on other message implementations.

Change-Id: Id1722afd8e094ddb59cb3e5440f7994c20cfa681
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/217760
Reviewed-by: Joe Tsai <joetsai@google.com>
This commit is contained in:
Damien Neil 2020-02-04 12:58:17 -08:00
parent f68f17085a
commit 1c33e1125a

View File

@ -10,6 +10,7 @@ import (
"google.golang.org/protobuf/internal/impl"
"google.golang.org/protobuf/internal/protobuild"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
legacypb "google.golang.org/protobuf/internal/testprotos/legacy"
@ -1389,12 +1390,17 @@ var testValidMessages = []testProto{
desc: "nested unknown extension",
unmarshalOptions: proto.UnmarshalOptions{
DiscardUnknown: true,
Resolver: func() protoregistry.ExtensionTypeResolver {
types := &protoregistry.Types{}
types.RegisterExtension(testpb.E_OptionalNestedMessage)
types.RegisterExtension(testpb.E_OptionalInt32)
return types
}(),
Resolver: filterResolver{
filter: func(name protoreflect.FullName) bool {
switch name.Name() {
case "optional_nested_message",
"optional_int32":
return true
}
return false
},
resolver: protoregistry.GlobalTypes,
},
},
decodeTo: makeMessages(protobuild.Message{
"optional_nested_message": protobuild.Message{
@ -1847,3 +1853,26 @@ var testInvalidMessages = []testProto{
}.Marshal(),
},
}
type filterResolver struct {
filter func(name protoreflect.FullName) bool
resolver protoregistry.ExtensionTypeResolver
}
func (f filterResolver) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
if !f.filter(field) {
return nil, protoregistry.NotFound
}
return f.resolver.FindExtensionByName(field)
}
func (f filterResolver) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
xt, err := f.resolver.FindExtensionByNumber(message, field)
if err != nil {
return nil, err
}
if !f.filter(xt.TypeDescriptor().FullName()) {
return nil, protoregistry.NotFound
}
return xt, nil
}