testing/prototest: refactor prototest API

For consistency with other options types in the protobuf module, make
the test function a method of the options.

Drop the ExtensionTypes option and just look up the extension types to
test with in the provided resolver.

Change-Id: I7918bd10b7c003e4af56d27521d30218653d5b4d
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/219142
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
This commit is contained in:
Damien Neil 2020-02-11 16:43:13 -08:00
parent fb5fde41cd
commit 56786dc5df
3 changed files with 42 additions and 44 deletions

View File

@ -17,47 +17,42 @@ import (
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/proto"
pref "google.golang.org/protobuf/reflect/protoreflect"
preg "google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/reflect/protoregistry"
)
// TODO: Test invalid field descriptors or oneof descriptors.
// TODO: This should test the functionality that can be provided by fast-paths.
// MessageOptions configure message tests.
type MessageOptions struct {
// ExtensionTypes is a list of types to test with.
//
// If nil, TestMessage will look for extension types in the global registry.
ExtensionTypes []pref.ExtensionType
// Resolver is used for looking up types when unmarshaling extension fields.
// Message tests a message implemention.
type Message struct {
// Resolver is used to determine the list of extension fields to test with.
// If nil, this defaults to using protoregistry.GlobalTypes.
Resolver interface {
preg.ExtensionTypeResolver
FindExtensionByName(field pref.FullName) (pref.ExtensionType, error)
FindExtensionByNumber(message pref.FullName, field pref.FieldNumber) (pref.ExtensionType, error)
RangeExtensionsByMessage(message pref.FullName, f func(pref.ExtensionType) bool)
}
}
// TODO(blocks): TestMessage should not take in MessageOptions,
// but have a MessageOptions.Test method instead.
// Test performs tests on a MessageType implementation.
func (test Message) Test(t testing.TB, mt pref.MessageType) {
testType(t, mt)
// TestMessage runs the provided m through a series of tests
// exercising the protobuf reflection API.
func TestMessage(t testing.TB, m proto.Message, opts MessageOptions) {
testType(t, m)
md := m.ProtoReflect().Descriptor()
m1 := m.ProtoReflect().New()
md := mt.Descriptor()
m1 := mt.New()
for i := 0; i < md.Fields().Len(); i++ {
fd := md.Fields().Get(i)
testField(t, m1, fd)
}
if opts.ExtensionTypes == nil {
preg.GlobalTypes.RangeExtensionsByMessage(md.FullName(), func(e pref.ExtensionType) bool {
opts.ExtensionTypes = append(opts.ExtensionTypes, e)
return true
})
if test.Resolver == nil {
test.Resolver = protoregistry.GlobalTypes
}
for _, xt := range opts.ExtensionTypes {
var extTypes []pref.ExtensionType
test.Resolver.RangeExtensionsByMessage(md.FullName(), func(e pref.ExtensionType) bool {
extTypes = append(extTypes, e)
return true
})
for _, xt := range extTypes {
testField(t, m1, xt.TypeDescriptor())
}
for i := 0; i < md.Oneofs().Len(); i++ {
@ -66,9 +61,9 @@ func TestMessage(t testing.TB, m proto.Message, opts MessageOptions) {
testUnknown(t, m1)
// Test round-trip marshal/unmarshal.
m2 := m.ProtoReflect().New().Interface()
m2 := mt.New().Interface()
populateMessage(m2.ProtoReflect(), 1, nil)
for _, xt := range opts.ExtensionTypes {
for _, xt := range extTypes {
m2.ProtoReflect().Set(xt.TypeDescriptor(), newValue(m2.ProtoReflect(), xt.TypeDescriptor(), 1, nil))
}
b, err := proto.MarshalOptions{
@ -77,10 +72,10 @@ func TestMessage(t testing.TB, m proto.Message, opts MessageOptions) {
if err != nil {
t.Errorf("Marshal() = %v, want nil\n%v", err, prototext.Format(m2))
}
m3 := m.ProtoReflect().New().Interface()
m3 := mt.New().Interface()
if err := (proto.UnmarshalOptions{
AllowPartial: true,
Resolver: opts.Resolver,
Resolver: test.Resolver,
}.Unmarshal(b, m3)); err != nil {
t.Errorf("Unmarshal() = %v, want nil\n%v", err, prototext.Format(m2))
}
@ -89,7 +84,8 @@ func TestMessage(t testing.TB, m proto.Message, opts MessageOptions) {
}
}
func testType(t testing.TB, m proto.Message) {
func testType(t testing.TB, mt pref.MessageType) {
m := mt.New().Interface()
want := reflect.TypeOf(m)
if got := reflect.TypeOf(m.ProtoReflect().Interface()); got != want {
t.Errorf("type mismatch: reflect.TypeOf(m) != reflect.TypeOf(m.ProtoReflect().Interface()): %v != %v", got, want)

View File

@ -38,7 +38,7 @@ func Test(t *testing.T) {
for _, m := range ms {
t.Run(fmt.Sprintf("%T", m), func(t *testing.T) {
prototest.TestMessage(t, m, prototest.MessageOptions{})
prototest.Message{}.Test(t, m.ProtoReflect().Type())
})
}
}

View File

@ -23,24 +23,20 @@ func TestConformance(t *testing.T) {
(*test3pb.TestAllTypes)(nil),
(*testpb.TestAllExtensions)(nil),
} {
prototest.TestMessage(t, dynamicpb.NewMessage(message.ProtoReflect().Descriptor()), prototest.MessageOptions{})
mt := dynamicpb.NewMessageType(message.ProtoReflect().Descriptor())
prototest.Message{}.Test(t, mt)
}
}
func TestDynamicExtensions(t *testing.T) {
file, err := preg.GlobalFiles.FindFileByPath("internal/testprotos/test/ext.proto")
if err != nil {
t.Fatal(err)
for _, message := range []proto.Message{
(*testpb.TestAllExtensions)(nil),
} {
mt := dynamicpb.NewMessageType(message.ProtoReflect().Descriptor())
prototest.Message{
Resolver: extResolver{},
}.Test(t, mt)
}
md := (&testpb.TestAllExtensions{}).ProtoReflect().Descriptor()
opts := prototest.MessageOptions{
Resolver: extResolver{},
}
for i := 0; i < file.Extensions().Len(); i++ {
opts.ExtensionTypes = append(opts.ExtensionTypes, dynamicpb.NewExtensionType(file.Extensions().Get(i)))
}
prototest.TestMessage(t, dynamicpb.NewMessage(md), opts)
}
type extResolver struct{}
@ -60,3 +56,9 @@ func (extResolver) FindExtensionByNumber(message pref.FullName, field pref.Field
}
return dynamicpb.NewExtensionType(xt.TypeDescriptor().Descriptor()), nil
}
func (extResolver) RangeExtensionsByMessage(message pref.FullName, f func(pref.ExtensionType) bool) {
preg.GlobalTypes.RangeExtensionsByMessage(message, func(xt pref.ExtensionType) bool {
return f(dynamicpb.NewExtensionType(xt.TypeDescriptor().Descriptor()))
})
}