encoding: unify MessageSet extension handling logic

This CL unifies common MessageSet logic in prototext and protojson
into the messageset package. While we are at it, also enable
MessageSet support only if the proto1_legacy build flag is enabled.

Change-Id: I1a7d475e8bb1dad61ecd286df45e4239e5bef072
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/185898
Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
Joe Tsai 2019-07-11 18:23:08 -07:00
parent af57087245
commit 5ae10aa9f0
9 changed files with 110 additions and 76 deletions

View File

@ -12,7 +12,9 @@ import (
"strings"
"google.golang.org/protobuf/internal/encoding/json"
"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/internal/pragma"
"google.golang.org/protobuf/internal/set"
"google.golang.org/protobuf/proto"
@ -136,13 +138,14 @@ func (o UnmarshalOptions) unmarshalMessage(m pref.Message, skipTypeURL bool) err
// unmarshalFields unmarshals the fields into the given protoreflect.Message.
func (o UnmarshalOptions) unmarshalFields(m pref.Message, skipTypeURL bool) error {
messageDesc := m.Descriptor()
if !flags.Proto1Legacy && messageset.IsMessageSet(messageDesc) {
return errors.New("no support for proto1 MessageSets")
}
var seenNums set.Ints
var seenOneofs set.Ints
messageDesc := m.Descriptor()
fieldDescs := messageDesc.Fields()
Loop:
for {
// Read field name.
jval, err := o.decoder.Read()
@ -153,7 +156,7 @@ Loop:
default:
return unexpectedJSONError{jval}
case json.EndObject:
break Loop
return nil
case json.Name:
// Continue below.
}
@ -243,8 +246,6 @@ Loop:
}
}
}
return nil
}
// findExtension returns protoreflect.ExtensionType from the resolver if found.
@ -253,13 +254,7 @@ func (o UnmarshalOptions) findExtension(xtName pref.FullName) (pref.ExtensionTyp
if err == nil {
return xt, nil
}
// Check if this is a MessageSet extension field.
xt, err = o.Resolver.FindExtensionByName(xtName + ".message_set_extension")
if err == nil && isMessageSetExtension(xt) {
return xt, nil
}
return nil, protoregistry.NotFound
return messageset.FindMessageSetExtension(o.Resolver, xtName)
}
func isKnownValue(fd pref.FieldDescriptor) bool {

View File

@ -1296,6 +1296,17 @@ func TestUnmarshal(t *testing.T) {
inputMessage: &pb2.Extensions{},
inputText: `{ "[pb2.invalid_message_field]": true }`,
wantErr: true,
}, {
desc: "extensions of repeated field contains null",
inputMessage: &pb2.Extensions{},
inputText: `{
"[pb2.ExtensionsContainer.rpt_ext_nested]": [
{"optString": "one"},
null,
{"optString": "three"}
],
}`,
wantErr: true,
}, {
desc: "MessageSet",
inputMessage: &pb2.MessageSet{},
@ -1323,17 +1334,7 @@ func TestUnmarshal(t *testing.T) {
})
return m
}(),
}, {
desc: "extensions of repeated field contains null",
inputMessage: &pb2.Extensions{},
inputText: `{
"[pb2.ExtensionsContainer.rpt_ext_nested]": [
{"optString": "one"},
null,
{"optString": "three"}
],
}`,
wantErr: true,
skip: !flags.Proto1Legacy,
}, {
desc: "not real MessageSet 1",
inputMessage: &pb2.FakeMessageSet{},
@ -1349,6 +1350,7 @@ func TestUnmarshal(t *testing.T) {
})
return m
}(),
skip: !flags.Proto1Legacy,
}, {
desc: "not real MessageSet 2",
inputMessage: &pb2.FakeMessageSet{},
@ -1358,6 +1360,7 @@ func TestUnmarshal(t *testing.T) {
}
}`,
wantErr: true,
skip: !flags.Proto1Legacy,
}, {
desc: "not real MessageSet 3",
inputMessage: &pb2.MessageSet{},
@ -1373,6 +1376,7 @@ func TestUnmarshal(t *testing.T) {
})
return m
}(),
skip: !flags.Proto1Legacy,
}, {
desc: "Empty",
inputMessage: &emptypb.Empty{},

View File

@ -10,6 +10,9 @@ import (
"sort"
"google.golang.org/protobuf/internal/encoding/json"
"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/internal/pragma"
"google.golang.org/protobuf/proto"
pref "google.golang.org/protobuf/reflect/protoreflect"
@ -84,8 +87,13 @@ func (o MarshalOptions) marshalMessage(m pref.Message) error {
// marshalFields marshals the fields in the given protoreflect.Message.
func (o MarshalOptions) marshalFields(m pref.Message) error {
messageDesc := m.Descriptor()
if !flags.Proto1Legacy && messageset.IsMessageSet(messageDesc) {
return errors.New("no support for proto1 MessageSets")
}
// Marshal out known fields.
fieldDescs := m.Descriptor().Fields()
fieldDescs := messageDesc.Fields()
for i := 0; i < fieldDescs.Len(); i++ {
fd := fieldDescs.Get(i)
if !m.Has(fd) {
@ -257,10 +265,10 @@ func (o MarshalOptions) marshalExtensions(m pref.Message) error {
return true
}
// If extended type is a MessageSet, set field name to be the message type name.
// For MessageSet extensions, the name used is the parent message.
name := fd.FullName()
if isMessageSetExtension(fd) {
name = fd.Message().FullName()
if messageset.IsMessageSetExtension(fd) {
name = name.Parent()
}
// Use [name] format for JSON field name.
@ -291,19 +299,3 @@ func (o MarshalOptions) marshalExtensions(m pref.Message) error {
}
return nil
}
// isMessageSetExtension reports whether extension extends a message set.
func isMessageSetExtension(fd pref.FieldDescriptor) bool {
if fd.Name() != "message_set_extension" {
return false
}
md := fd.Message()
if md == nil {
return false
}
if fd.FullName().Parent() != md.FullName() {
return false
}
xmd, ok := fd.ContainingMessage().(interface{ IsMessageSet() bool })
return ok && xmd.IsMessageSet()
}

View File

@ -12,6 +12,7 @@ import (
"github.com/google/go-cmp/cmp"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/internal/encoding/pack"
"google.golang.org/protobuf/internal/flags"
pimpl "google.golang.org/protobuf/internal/impl"
"google.golang.org/protobuf/proto"
preg "google.golang.org/protobuf/reflect/protoregistry"
@ -40,6 +41,7 @@ func TestMarshal(t *testing.T) {
input proto.Message
want string
wantErr bool // TODO: Verify error message substring.
skip bool
}{{
desc: "proto2 optional scalars not set",
input: &pb2.Scalars{},
@ -1038,6 +1040,7 @@ func TestMarshal(t *testing.T) {
"optString": "not a messageset extension"
}
}`,
skip: !flags.Proto1Legacy,
}, {
desc: "not real MessageSet 1",
input: func() proto.Message {
@ -1052,6 +1055,7 @@ func TestMarshal(t *testing.T) {
"optString": "not a messageset extension"
}
}`,
skip: !flags.Proto1Legacy,
}, {
desc: "not real MessageSet 2",
input: func() proto.Message {
@ -1066,6 +1070,7 @@ func TestMarshal(t *testing.T) {
"optString": "another not a messageset extension"
}
}`,
skip: !flags.Proto1Legacy,
}, {
desc: "BoolValue empty",
input: &wrapperspb.BoolValue{},
@ -1898,6 +1903,9 @@ func TestMarshal(t *testing.T) {
for _, tt := range tests {
tt := tt
if tt.skip {
continue
}
t.Run(tt.desc, func(t *testing.T) {
// Use 2-space indentation on all MarshalOptions.
tt.mo.Indent = " "

View File

@ -9,9 +9,11 @@ import (
"strings"
"unicode/utf8"
"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/encoding/text"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/fieldnum"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/internal/pragma"
"google.golang.org/protobuf/internal/set"
"google.golang.org/protobuf/proto"
@ -74,17 +76,18 @@ func (o UnmarshalOptions) Unmarshal(b []byte, m proto.Message) error {
// unmarshalMessage unmarshals a [][2]text.Value message into the given protoreflect.Message.
func (o UnmarshalOptions) unmarshalMessage(tmsg [][2]text.Value, m pref.Message) error {
messageDesc := m.Descriptor()
if !flags.Proto1Legacy && messageset.IsMessageSet(messageDesc) {
return errors.New("no support for proto1 MessageSets")
}
// Handle expanded Any message.
if messageDesc.FullName() == "google.protobuf.Any" && isExpandedAny(tmsg) {
return o.unmarshalAny(tmsg[0], m)
}
fieldDescs := messageDesc.Fields()
reservedNames := messageDesc.ReservedNames()
var seenNums set.Ints
var seenOneofs set.Ints
fieldDescs := messageDesc.Fields()
for _, tfield := range tmsg {
tkey := tfield[0]
tval := tfield[1]
@ -128,7 +131,7 @@ func (o UnmarshalOptions) unmarshalMessage(tmsg [][2]text.Value, m pref.Message)
if fd == nil {
// Ignore reserved names.
if reservedNames.Has(name) {
if messageDesc.ReservedNames().Has(name) {
continue
}
// TODO: Can provide option to ignore unknown message fields.
@ -193,13 +196,7 @@ func (o UnmarshalOptions) findExtension(xtName pref.FullName) (pref.ExtensionTyp
if err == nil {
return xt, nil
}
// Check if this is a MessageSet extension field.
xt, err = o.Resolver.FindExtensionByName(xtName + ".message_set_extension")
if err == nil && isMessageSetExtension(xt) {
return xt, nil
}
return nil, protoregistry.NotFound
return messageset.FindMessageSetExtension(o.Resolver, xtName)
}
// unmarshalSingular unmarshals given text.Value into the non-repeated field.

View File

@ -1310,6 +1310,7 @@ opt_int32: 42
})
return m
}(),
skip: !flags.Proto1Legacy,
}, {
desc: "not real MessageSet 1",
inputMessage: &pb2.FakeMessageSet{},
@ -1325,6 +1326,7 @@ opt_int32: 42
})
return m
}(),
skip: !flags.Proto1Legacy,
}, {
desc: "not real MessageSet 2",
inputMessage: &pb2.FakeMessageSet{},
@ -1334,6 +1336,7 @@ opt_int32: 42
}
`,
wantErr: true,
skip: !flags.Proto1Legacy,
}, {
desc: "not real MessageSet 3",
inputMessage: &pb2.MessageSet{},
@ -1348,6 +1351,7 @@ opt_int32: 42
})
return m
}(),
skip: !flags.Proto1Legacy,
}, {
desc: "Any not expanded",
inputMessage: &anypb.Any{},

View File

@ -9,10 +9,12 @@ import (
"sort"
"unicode/utf8"
"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/encoding/text"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/fieldnum"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/internal/mapsort"
"google.golang.org/protobuf/internal/pragma"
"google.golang.org/protobuf/proto"
@ -72,8 +74,10 @@ func (o MarshalOptions) Marshal(m proto.Message) ([]byte, error) {
// marshalMessage converts a protoreflect.Message to a text.Value.
func (o MarshalOptions) marshalMessage(m pref.Message) (text.Value, error) {
var msgFields [][2]text.Value
messageDesc := m.Descriptor()
if !flags.Proto1Legacy && messageset.IsMessageSet(messageDesc) {
return text.Value{}, errors.New("no support for proto1 MessageSets")
}
// Handle Any expansion.
if messageDesc.FullName() == "google.protobuf.Any" {
@ -85,6 +89,7 @@ func (o MarshalOptions) marshalMessage(m pref.Message) (text.Value, error) {
}
// Handle known fields.
var msgFields [][2]text.Value
fieldDescs := messageDesc.Fields()
size := fieldDescs.Len()
for i := 0; i < size; i++ {
@ -253,10 +258,10 @@ func (o MarshalOptions) appendExtensions(msgFields [][2]text.Value, m pref.Messa
return true
}
// If extended type is a MessageSet, set field name to be the message type name.
// For MessageSet extensions, the name used is the parent message.
name := fd.FullName()
if isMessageSetExtension(fd) {
name = fd.Message().FullName()
if messageset.IsMessageSetExtension(fd) {
name = name.Parent()
}
// Use string type to produce [name] format.
@ -279,22 +284,6 @@ func (o MarshalOptions) appendExtensions(msgFields [][2]text.Value, m pref.Messa
return append(msgFields, entries...), nil
}
// isMessageSetExtension reports whether extension extends a message set.
func isMessageSetExtension(fd pref.FieldDescriptor) bool {
if fd.Name() != "message_set_extension" {
return false
}
md := fd.Message()
if md == nil {
return false
}
if fd.FullName().Parent() != md.FullName() {
return false
}
xmd, ok := fd.ContainingMessage().(interface{ IsMessageSet() bool })
return ok && xmd.IsMessageSet()
}
// appendUnknown parses the given []byte and appends field(s) into the given fields slice.
// This function assumes proper encoding in the given []byte.
func appendUnknown(fields [][2]text.Value, b []byte) [][2]text.Value {

View File

@ -12,6 +12,7 @@ import (
"google.golang.org/protobuf/encoding/prototext"
"google.golang.org/protobuf/internal/detrand"
"google.golang.org/protobuf/internal/encoding/pack"
"google.golang.org/protobuf/internal/flags"
pimpl "google.golang.org/protobuf/internal/impl"
"google.golang.org/protobuf/proto"
preg "google.golang.org/protobuf/reflect/protoregistry"
@ -39,6 +40,7 @@ func TestMarshal(t *testing.T) {
input proto.Message
want string
wantErr bool // TODO: Verify error message content.
skip bool
}{{
desc: "proto2 optional scalars not set",
input: &pb2.Scalars{},
@ -1082,6 +1084,7 @@ opt_int32: 42
opt_string: "not a messageset extension"
}
`,
skip: !flags.Proto1Legacy,
}, {
desc: "not real MessageSet 1",
input: func() proto.Message {
@ -1095,6 +1098,7 @@ opt_int32: 42
opt_string: "not a messageset extension"
}
`,
skip: !flags.Proto1Legacy,
}, {
desc: "not real MessageSet 2",
input: func() proto.Message {
@ -1108,6 +1112,7 @@ opt_int32: 42
opt_string: "another not a messageset extension"
}
`,
skip: !flags.Proto1Legacy,
}, {
desc: "Any not expanded",
mo: prototext.MarshalOptions{
@ -1201,6 +1206,9 @@ value: "\x80"
for _, tt := range tests {
tt := tt
if tt.skip {
continue
}
t.Run(tt.desc, func(t *testing.T) {
// Use 2-space indentation on all MarshalOptions.
tt.mo.Indent = " "

View File

@ -9,6 +9,7 @@ import (
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/errors"
pref "google.golang.org/protobuf/reflect/protoreflect"
preg "google.golang.org/protobuf/reflect/protoregistry"
)
// The MessageSet wire format is equivalent to a message defiend as follows,
@ -28,12 +29,48 @@ const (
FieldMessage = wire.Number(3)
)
// ExtensionName is the field name for extensions of MessageSet.
//
// A valid MessageSet extension must be of the form:
// message MyMessage {
// extend proto2.bridge.MessageSet {
// optional MyMessage message_set_extension = 1234;
// }
// ...
// }
const ExtensionName = "message_set_extension"
// IsMessageSet returns whether the message uses the MessageSet wire format.
func IsMessageSet(md pref.MessageDescriptor) bool {
xmd, ok := md.(interface{ IsMessageSet() bool })
return ok && xmd.IsMessageSet()
}
// IsMessageSetExtension reports this field extends a MessageSet.
func IsMessageSetExtension(fd pref.FieldDescriptor) bool {
if fd.Name() != ExtensionName {
return false
}
if fd.FullName().Parent() != fd.Message().FullName() {
return false
}
return IsMessageSet(fd.ContainingMessage())
}
// FindMessageSetExtension locates a MessageSet extension field by name.
// In text and JSON formats, the extension name used is the message itself.
// The extension field name is derived by appending ExtensionName.
func FindMessageSetExtension(r preg.ExtensionTypeResolver, s pref.FullName) (pref.ExtensionType, error) {
xt, err := r.FindExtensionByName(s.Append(ExtensionName))
if err != nil {
return nil, err
}
if !IsMessageSetExtension(xt) {
return nil, preg.NotFound
}
return xt, nil
}
// SizeField returns the size of a MessageSet item field containing an extension
// with the given field number, not counting the contents of the message subfield.
func SizeField(num wire.Number) int {