testing/protocmp: add Message.Unwrap

The Unwrap method returns the original concrete message value.
In theory this allows users to mutate the original message when the
cmp documentation says that all options should be mutation free.
If users want to disregard this documented restriction, they can
already do so in a number of different ways.

Updates #1347

Change-Id: I65225681ab5dbce0763a140fd02666a4ab542a04
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/340489
Trust: Joe Tsai <joetsai@digital-static.net>
Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
Joe Tsai 2021-08-06 11:27:51 -07:00 committed by Damien Neil
parent 05be61fde3
commit 5aec41b480
4 changed files with 49 additions and 31 deletions

View File

@ -68,7 +68,7 @@ func (m reflectMessage) Range(f func(fd protoreflect.FieldDescriptor, v protoref
}
// Range over populated extension fields.
for _, xd := range m[messageTypeKey].(messageType).xds {
for _, xd := range m[messageTypeKey].(messageMeta).xds {
if m.Has(xd) && !f(xd, m.Get(xd)) {
return
}
@ -91,7 +91,7 @@ func (m reflectMessage) Get(fd protoreflect.FieldDescriptor) protoreflect.Value
return protoreflect.ValueOfMap(reflectMap{})
case fd.Message() != nil:
return protoreflect.ValueOfMessage(reflectMessage{
messageTypeKey: messageType{md: m.Descriptor()},
messageTypeKey: messageMeta{md: fd.Message()},
})
default:
return fd.Default()

View File

@ -297,11 +297,11 @@ func (f *nameFilters) filterFieldName(m Message, k string) bool {
return true // treat missing fields as already filtered
}
var fd protoreflect.FieldDescriptor
switch mt := m[messageTypeKey].(messageType); {
switch mm := m[messageTypeKey].(messageMeta); {
case protoreflect.Name(k).IsValid():
fd = mt.md.Fields().ByTextName(k)
fd = mm.md.Fields().ByTextName(k)
default:
fd = mt.xds[k]
fd = mm.xds[k]
}
if fd != nil {
return f.names[fd.FullName()]
@ -376,11 +376,11 @@ func isDefaultScalar(m Message, k string) bool {
}
var fd protoreflect.FieldDescriptor
switch mt := m[messageTypeKey].(messageType); {
switch mm := m[messageTypeKey].(messageMeta); {
case protoreflect.Name(k).IsValid():
fd = mt.md.Fields().ByTextName(k)
fd = mm.md.Fields().ByTextName(k)
default:
fd = mt.xds[k]
fd = mm.xds[k]
}
if fd == nil || !fd.Default().IsValid() {
return false

View File

@ -68,20 +68,28 @@ func (e Enum) String() string {
}
const (
messageTypeKey = "@type"
// messageTypeKey indicates the protobuf message type.
// The value type is always messageMeta.
// From the public API, it presents itself as only the type, but the
// underlying data structure holds arbitrary metadata about the message.
messageTypeKey = "@type"
// messageInvalidKey indicates that the message is invalid.
// The value is always the boolean "true".
messageInvalidKey = "@invalid"
)
type messageType struct {
type messageMeta struct {
m proto.Message
md protoreflect.MessageDescriptor
xds map[string]protoreflect.ExtensionDescriptor
}
func (t messageType) String() string {
func (t messageMeta) String() string {
return string(t.md.FullName())
}
func (t1 messageType) Equal(t2 messageType) bool {
func (t1 messageMeta) Equal(t2 messageMeta) bool {
return t1.md.FullName() == t2.md.FullName()
}
@ -109,11 +117,18 @@ func (t1 messageType) Equal(t2 messageType) bool {
// Message values must not be created by or mutated by users.
type Message map[string]interface{}
// Unwrap returns the original message value.
// It returns nil if this Message was not constructed from another message.
func (m Message) Unwrap() proto.Message {
mm, _ := m[messageTypeKey].(messageMeta)
return mm.m
}
// Descriptor return the message descriptor.
// It returns nil for a zero Message value.
func (m Message) Descriptor() protoreflect.MessageDescriptor {
mt, _ := m[messageTypeKey].(messageType)
return mt.md
mm, _ := m[messageTypeKey].(messageMeta)
return mm.md
}
// ProtoReflect returns a reflective view of m.
@ -201,7 +216,7 @@ func Transform(...option) cmp.Option {
case m == nil:
return nil
case !m.IsValid():
return Message{messageTypeKey: messageType{md: m.Descriptor()}, messageInvalidKey: true}
return Message{messageTypeKey: messageMeta{m: m.Interface(), md: m.Descriptor()}, messageInvalidKey: true}
default:
return transformMessage(m)
}
@ -218,7 +233,7 @@ func isMessageType(t reflect.Type) bool {
func transformMessage(m protoreflect.Message) Message {
mx := Message{}
mt := messageType{md: m.Descriptor(), xds: make(map[string]protoreflect.FieldDescriptor)}
mt := messageMeta{m: m.Interface(), md: m.Descriptor(), xds: make(map[string]protoreflect.FieldDescriptor)}
// Handle known and extension fields.
m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {

View File

@ -40,7 +40,7 @@ func TestTransform(t *testing.T) {
OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{A: proto.Int32(5)},
},
want: Message{
messageTypeKey: messageTypeOf(&testpb.TestAllTypes{}),
messageTypeKey: messageMetaOf(&testpb.TestAllTypes{}),
"optional_bool": bool(false),
"optional_int32": int32(-32),
"optional_int64": int64(-64),
@ -51,7 +51,7 @@ func TestTransform(t *testing.T) {
"optional_string": string("string"),
"optional_bytes": []byte("bytes"),
"optional_nested_enum": enumOf(testpb.TestAllTypes_NEG),
"optional_nested_message": Message{messageTypeKey: messageTypeOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(5)},
"optional_nested_message": Message{messageTypeKey: messageMetaOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(5)},
},
}, {
in: &testpb.TestAllTypes{
@ -74,7 +74,7 @@ func TestTransform(t *testing.T) {
},
},
want: Message{
messageTypeKey: messageTypeOf(&testpb.TestAllTypes{}),
messageTypeKey: messageMetaOf(&testpb.TestAllTypes{}),
"repeated_bool": []bool{false, true},
"repeated_int32": []int32{32, -32},
"repeated_int64": []int64{64, -64},
@ -89,8 +89,8 @@ func TestTransform(t *testing.T) {
enumOf(testpb.TestAllTypes_BAR),
},
"repeated_nested_message": []Message{
{messageTypeKey: messageTypeOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(5)},
{messageTypeKey: messageTypeOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(-5)},
{messageTypeKey: messageMetaOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(5)},
{messageTypeKey: messageMetaOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(-5)},
},
},
}, {
@ -112,7 +112,7 @@ func TestTransform(t *testing.T) {
},
},
want: Message{
messageTypeKey: messageTypeOf(&testpb.TestAllTypes{}),
messageTypeKey: messageMetaOf(&testpb.TestAllTypes{}),
"map_bool_bool": map[bool]bool{true: false},
"map_int32_int32": map[int32]int32{-32: 32},
"map_int64_int64": map[int64]int64{-64: 64},
@ -126,7 +126,7 @@ func TestTransform(t *testing.T) {
"k": enumOf(testpb.TestAllTypes_FOO),
},
"map_string_nested_message": map[string]Message{
"k": {messageTypeKey: messageTypeOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(5)},
"k": {messageTypeKey: messageMetaOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(5)},
},
},
}, {
@ -146,7 +146,7 @@ func TestTransform(t *testing.T) {
return m
}(),
want: Message{
messageTypeKey: messageTypeOf(&testpb.TestAllExtensions{}),
messageTypeKey: messageMetaOf(&testpb.TestAllExtensions{}),
"[goproto.proto.test.optional_bool]": bool(false),
"[goproto.proto.test.optional_int32]": int32(-32),
"[goproto.proto.test.optional_int64]": int64(-64),
@ -157,7 +157,7 @@ func TestTransform(t *testing.T) {
"[goproto.proto.test.optional_string]": string("string"),
"[goproto.proto.test.optional_bytes]": []byte("bytes"),
"[goproto.proto.test.optional_nested_enum]": enumOf(testpb.TestAllTypes_NEG),
"[goproto.proto.test.optional_nested_message]": Message{messageTypeKey: messageTypeOf(&testpb.TestAllExtensions_NestedMessage{}), "a": int32(5)},
"[goproto.proto.test.optional_nested_message]": Message{messageTypeKey: messageMetaOf(&testpb.TestAllExtensions_NestedMessage{}), "a": int32(5)},
},
}, {
in: func() proto.Message {
@ -182,7 +182,7 @@ func TestTransform(t *testing.T) {
return m
}(),
want: Message{
messageTypeKey: messageTypeOf(&testpb.TestAllExtensions{}),
messageTypeKey: messageMetaOf(&testpb.TestAllExtensions{}),
"[goproto.proto.test.repeated_bool]": []bool{false, true},
"[goproto.proto.test.repeated_int32]": []int32{32, -32},
"[goproto.proto.test.repeated_int64]": []int64{64, -64},
@ -197,8 +197,8 @@ func TestTransform(t *testing.T) {
enumOf(testpb.TestAllTypes_BAR),
},
"[goproto.proto.test.repeated_nested_message]": []Message{
{messageTypeKey: messageTypeOf(&testpb.TestAllExtensions_NestedMessage{}), "a": int32(5)},
{messageTypeKey: messageTypeOf(&testpb.TestAllExtensions_NestedMessage{}), "a": int32(-5)},
{messageTypeKey: messageMetaOf(&testpb.TestAllExtensions_NestedMessage{}), "a": int32(5)},
{messageTypeKey: messageMetaOf(&testpb.TestAllExtensions_NestedMessage{}), "a": int32(-5)},
},
},
}, {
@ -229,7 +229,7 @@ func TestTransform(t *testing.T) {
return m
}(),
want: Message{
messageTypeKey: messageTypeOf(&testpb.TestAllTypes{}),
messageTypeKey: messageMetaOf(&testpb.TestAllTypes{}),
"50000": protoreflect.RawFields(protopack.Message{protopack.Tag{Number: 50000, Type: protopack.VarintType}, protopack.Uvarint(100)}.Marshal()),
"50001": protoreflect.RawFields(protopack.Message{protopack.Tag{Number: 50001, Type: protopack.Fixed32Type}, protopack.Uint32(200)}.Marshal()),
"50002": protoreflect.RawFields(protopack.Message{protopack.Tag{Number: 50002, Type: protopack.Fixed64Type}, protopack.Uint64(300)}.Marshal()),
@ -258,6 +258,9 @@ func TestTransform(t *testing.T) {
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Errorf("Transform() mismatch (-want +got):\n%v", diff)
}
if got.Unwrap() != tt.in {
t.Errorf("got.Unwrap() = %p, want %p", got.Unwrap(), tt.in)
}
})
}
}
@ -266,6 +269,6 @@ func enumOf(e protoreflect.Enum) Enum {
return Enum{e.Number(), e.Descriptor()}
}
func messageTypeOf(m protoreflect.ProtoMessage) messageType {
return messageType{md: m.ProtoReflect().Descriptor()}
func messageMetaOf(m protoreflect.ProtoMessage) messageMeta {
return messageMeta{m: m, md: m.ProtoReflect().Descriptor()}
}