encoding: unify MessageSet extension handling logic

This CL unifies common MessageSet logic in prototext and protojson
into the messageset package. While we are at it, also enable
MessageSet support only if the proto1_legacy build flag is enabled.

Change-Id: I1a7d475e8bb1dad61ecd286df45e4239e5bef072
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/185898
Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
Joe Tsai 2019-07-11 18:23:08 -07:00
parent af57087245
commit 5ae10aa9f0
9 changed files with 110 additions and 76 deletions

View File

@ -12,7 +12,9 @@ import (
"strings" "strings"
"google.golang.org/protobuf/internal/encoding/json" "google.golang.org/protobuf/internal/encoding/json"
"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/errors" "google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/internal/pragma" "google.golang.org/protobuf/internal/pragma"
"google.golang.org/protobuf/internal/set" "google.golang.org/protobuf/internal/set"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
@ -136,13 +138,14 @@ func (o UnmarshalOptions) unmarshalMessage(m pref.Message, skipTypeURL bool) err
// unmarshalFields unmarshals the fields into the given protoreflect.Message. // unmarshalFields unmarshals the fields into the given protoreflect.Message.
func (o UnmarshalOptions) unmarshalFields(m pref.Message, skipTypeURL bool) error { func (o UnmarshalOptions) unmarshalFields(m pref.Message, skipTypeURL bool) error {
messageDesc := m.Descriptor()
if !flags.Proto1Legacy && messageset.IsMessageSet(messageDesc) {
return errors.New("no support for proto1 MessageSets")
}
var seenNums set.Ints var seenNums set.Ints
var seenOneofs set.Ints var seenOneofs set.Ints
messageDesc := m.Descriptor()
fieldDescs := messageDesc.Fields() fieldDescs := messageDesc.Fields()
Loop:
for { for {
// Read field name. // Read field name.
jval, err := o.decoder.Read() jval, err := o.decoder.Read()
@ -153,7 +156,7 @@ Loop:
default: default:
return unexpectedJSONError{jval} return unexpectedJSONError{jval}
case json.EndObject: case json.EndObject:
break Loop return nil
case json.Name: case json.Name:
// Continue below. // Continue below.
} }
@ -243,8 +246,6 @@ Loop:
} }
} }
} }
return nil
} }
// findExtension returns protoreflect.ExtensionType from the resolver if found. // findExtension returns protoreflect.ExtensionType from the resolver if found.
@ -253,13 +254,7 @@ func (o UnmarshalOptions) findExtension(xtName pref.FullName) (pref.ExtensionTyp
if err == nil { if err == nil {
return xt, nil return xt, nil
} }
return messageset.FindMessageSetExtension(o.Resolver, xtName)
// Check if this is a MessageSet extension field.
xt, err = o.Resolver.FindExtensionByName(xtName + ".message_set_extension")
if err == nil && isMessageSetExtension(xt) {
return xt, nil
}
return nil, protoregistry.NotFound
} }
func isKnownValue(fd pref.FieldDescriptor) bool { func isKnownValue(fd pref.FieldDescriptor) bool {

View File

@ -1296,6 +1296,17 @@ func TestUnmarshal(t *testing.T) {
inputMessage: &pb2.Extensions{}, inputMessage: &pb2.Extensions{},
inputText: `{ "[pb2.invalid_message_field]": true }`, inputText: `{ "[pb2.invalid_message_field]": true }`,
wantErr: true, wantErr: true,
}, {
desc: "extensions of repeated field contains null",
inputMessage: &pb2.Extensions{},
inputText: `{
"[pb2.ExtensionsContainer.rpt_ext_nested]": [
{"optString": "one"},
null,
{"optString": "three"}
],
}`,
wantErr: true,
}, { }, {
desc: "MessageSet", desc: "MessageSet",
inputMessage: &pb2.MessageSet{}, inputMessage: &pb2.MessageSet{},
@ -1323,17 +1334,7 @@ func TestUnmarshal(t *testing.T) {
}) })
return m return m
}(), }(),
}, { skip: !flags.Proto1Legacy,
desc: "extensions of repeated field contains null",
inputMessage: &pb2.Extensions{},
inputText: `{
"[pb2.ExtensionsContainer.rpt_ext_nested]": [
{"optString": "one"},
null,
{"optString": "three"}
],
}`,
wantErr: true,
}, { }, {
desc: "not real MessageSet 1", desc: "not real MessageSet 1",
inputMessage: &pb2.FakeMessageSet{}, inputMessage: &pb2.FakeMessageSet{},
@ -1349,6 +1350,7 @@ func TestUnmarshal(t *testing.T) {
}) })
return m return m
}(), }(),
skip: !flags.Proto1Legacy,
}, { }, {
desc: "not real MessageSet 2", desc: "not real MessageSet 2",
inputMessage: &pb2.FakeMessageSet{}, inputMessage: &pb2.FakeMessageSet{},
@ -1358,6 +1360,7 @@ func TestUnmarshal(t *testing.T) {
} }
}`, }`,
wantErr: true, wantErr: true,
skip: !flags.Proto1Legacy,
}, { }, {
desc: "not real MessageSet 3", desc: "not real MessageSet 3",
inputMessage: &pb2.MessageSet{}, inputMessage: &pb2.MessageSet{},
@ -1373,6 +1376,7 @@ func TestUnmarshal(t *testing.T) {
}) })
return m return m
}(), }(),
skip: !flags.Proto1Legacy,
}, { }, {
desc: "Empty", desc: "Empty",
inputMessage: &emptypb.Empty{}, inputMessage: &emptypb.Empty{},

View File

@ -10,6 +10,9 @@ import (
"sort" "sort"
"google.golang.org/protobuf/internal/encoding/json" "google.golang.org/protobuf/internal/encoding/json"
"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/internal/pragma" "google.golang.org/protobuf/internal/pragma"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
pref "google.golang.org/protobuf/reflect/protoreflect" pref "google.golang.org/protobuf/reflect/protoreflect"
@ -84,8 +87,13 @@ func (o MarshalOptions) marshalMessage(m pref.Message) error {
// marshalFields marshals the fields in the given protoreflect.Message. // marshalFields marshals the fields in the given protoreflect.Message.
func (o MarshalOptions) marshalFields(m pref.Message) error { func (o MarshalOptions) marshalFields(m pref.Message) error {
messageDesc := m.Descriptor()
if !flags.Proto1Legacy && messageset.IsMessageSet(messageDesc) {
return errors.New("no support for proto1 MessageSets")
}
// Marshal out known fields. // Marshal out known fields.
fieldDescs := m.Descriptor().Fields() fieldDescs := messageDesc.Fields()
for i := 0; i < fieldDescs.Len(); i++ { for i := 0; i < fieldDescs.Len(); i++ {
fd := fieldDescs.Get(i) fd := fieldDescs.Get(i)
if !m.Has(fd) { if !m.Has(fd) {
@ -257,10 +265,10 @@ func (o MarshalOptions) marshalExtensions(m pref.Message) error {
return true return true
} }
// If extended type is a MessageSet, set field name to be the message type name. // For MessageSet extensions, the name used is the parent message.
name := fd.FullName() name := fd.FullName()
if isMessageSetExtension(fd) { if messageset.IsMessageSetExtension(fd) {
name = fd.Message().FullName() name = name.Parent()
} }
// Use [name] format for JSON field name. // Use [name] format for JSON field name.
@ -291,19 +299,3 @@ func (o MarshalOptions) marshalExtensions(m pref.Message) error {
} }
return nil return nil
} }
// isMessageSetExtension reports whether extension extends a message set.
func isMessageSetExtension(fd pref.FieldDescriptor) bool {
if fd.Name() != "message_set_extension" {
return false
}
md := fd.Message()
if md == nil {
return false
}
if fd.FullName().Parent() != md.FullName() {
return false
}
xmd, ok := fd.ContainingMessage().(interface{ IsMessageSet() bool })
return ok && xmd.IsMessageSet()
}

View File

@ -12,6 +12,7 @@ import (
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/internal/encoding/pack" "google.golang.org/protobuf/internal/encoding/pack"
"google.golang.org/protobuf/internal/flags"
pimpl "google.golang.org/protobuf/internal/impl" pimpl "google.golang.org/protobuf/internal/impl"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
preg "google.golang.org/protobuf/reflect/protoregistry" preg "google.golang.org/protobuf/reflect/protoregistry"
@ -40,6 +41,7 @@ func TestMarshal(t *testing.T) {
input proto.Message input proto.Message
want string want string
wantErr bool // TODO: Verify error message substring. wantErr bool // TODO: Verify error message substring.
skip bool
}{{ }{{
desc: "proto2 optional scalars not set", desc: "proto2 optional scalars not set",
input: &pb2.Scalars{}, input: &pb2.Scalars{},
@ -1038,6 +1040,7 @@ func TestMarshal(t *testing.T) {
"optString": "not a messageset extension" "optString": "not a messageset extension"
} }
}`, }`,
skip: !flags.Proto1Legacy,
}, { }, {
desc: "not real MessageSet 1", desc: "not real MessageSet 1",
input: func() proto.Message { input: func() proto.Message {
@ -1052,6 +1055,7 @@ func TestMarshal(t *testing.T) {
"optString": "not a messageset extension" "optString": "not a messageset extension"
} }
}`, }`,
skip: !flags.Proto1Legacy,
}, { }, {
desc: "not real MessageSet 2", desc: "not real MessageSet 2",
input: func() proto.Message { input: func() proto.Message {
@ -1066,6 +1070,7 @@ func TestMarshal(t *testing.T) {
"optString": "another not a messageset extension" "optString": "another not a messageset extension"
} }
}`, }`,
skip: !flags.Proto1Legacy,
}, { }, {
desc: "BoolValue empty", desc: "BoolValue empty",
input: &wrapperspb.BoolValue{}, input: &wrapperspb.BoolValue{},
@ -1898,6 +1903,9 @@ func TestMarshal(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
tt := tt tt := tt
if tt.skip {
continue
}
t.Run(tt.desc, func(t *testing.T) { t.Run(tt.desc, func(t *testing.T) {
// Use 2-space indentation on all MarshalOptions. // Use 2-space indentation on all MarshalOptions.
tt.mo.Indent = " " tt.mo.Indent = " "

View File

@ -9,9 +9,11 @@ import (
"strings" "strings"
"unicode/utf8" "unicode/utf8"
"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/encoding/text" "google.golang.org/protobuf/internal/encoding/text"
"google.golang.org/protobuf/internal/errors" "google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/fieldnum" "google.golang.org/protobuf/internal/fieldnum"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/internal/pragma" "google.golang.org/protobuf/internal/pragma"
"google.golang.org/protobuf/internal/set" "google.golang.org/protobuf/internal/set"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
@ -74,17 +76,18 @@ func (o UnmarshalOptions) Unmarshal(b []byte, m proto.Message) error {
// unmarshalMessage unmarshals a [][2]text.Value message into the given protoreflect.Message. // unmarshalMessage unmarshals a [][2]text.Value message into the given protoreflect.Message.
func (o UnmarshalOptions) unmarshalMessage(tmsg [][2]text.Value, m pref.Message) error { func (o UnmarshalOptions) unmarshalMessage(tmsg [][2]text.Value, m pref.Message) error {
messageDesc := m.Descriptor() messageDesc := m.Descriptor()
if !flags.Proto1Legacy && messageset.IsMessageSet(messageDesc) {
return errors.New("no support for proto1 MessageSets")
}
// Handle expanded Any message. // Handle expanded Any message.
if messageDesc.FullName() == "google.protobuf.Any" && isExpandedAny(tmsg) { if messageDesc.FullName() == "google.protobuf.Any" && isExpandedAny(tmsg) {
return o.unmarshalAny(tmsg[0], m) return o.unmarshalAny(tmsg[0], m)
} }
fieldDescs := messageDesc.Fields()
reservedNames := messageDesc.ReservedNames()
var seenNums set.Ints var seenNums set.Ints
var seenOneofs set.Ints var seenOneofs set.Ints
fieldDescs := messageDesc.Fields()
for _, tfield := range tmsg { for _, tfield := range tmsg {
tkey := tfield[0] tkey := tfield[0]
tval := tfield[1] tval := tfield[1]
@ -128,7 +131,7 @@ func (o UnmarshalOptions) unmarshalMessage(tmsg [][2]text.Value, m pref.Message)
if fd == nil { if fd == nil {
// Ignore reserved names. // Ignore reserved names.
if reservedNames.Has(name) { if messageDesc.ReservedNames().Has(name) {
continue continue
} }
// TODO: Can provide option to ignore unknown message fields. // TODO: Can provide option to ignore unknown message fields.
@ -193,13 +196,7 @@ func (o UnmarshalOptions) findExtension(xtName pref.FullName) (pref.ExtensionTyp
if err == nil { if err == nil {
return xt, nil return xt, nil
} }
return messageset.FindMessageSetExtension(o.Resolver, xtName)
// Check if this is a MessageSet extension field.
xt, err = o.Resolver.FindExtensionByName(xtName + ".message_set_extension")
if err == nil && isMessageSetExtension(xt) {
return xt, nil
}
return nil, protoregistry.NotFound
} }
// unmarshalSingular unmarshals given text.Value into the non-repeated field. // unmarshalSingular unmarshals given text.Value into the non-repeated field.

View File

@ -1310,6 +1310,7 @@ opt_int32: 42
}) })
return m return m
}(), }(),
skip: !flags.Proto1Legacy,
}, { }, {
desc: "not real MessageSet 1", desc: "not real MessageSet 1",
inputMessage: &pb2.FakeMessageSet{}, inputMessage: &pb2.FakeMessageSet{},
@ -1325,6 +1326,7 @@ opt_int32: 42
}) })
return m return m
}(), }(),
skip: !flags.Proto1Legacy,
}, { }, {
desc: "not real MessageSet 2", desc: "not real MessageSet 2",
inputMessage: &pb2.FakeMessageSet{}, inputMessage: &pb2.FakeMessageSet{},
@ -1334,6 +1336,7 @@ opt_int32: 42
} }
`, `,
wantErr: true, wantErr: true,
skip: !flags.Proto1Legacy,
}, { }, {
desc: "not real MessageSet 3", desc: "not real MessageSet 3",
inputMessage: &pb2.MessageSet{}, inputMessage: &pb2.MessageSet{},
@ -1348,6 +1351,7 @@ opt_int32: 42
}) })
return m return m
}(), }(),
skip: !flags.Proto1Legacy,
}, { }, {
desc: "Any not expanded", desc: "Any not expanded",
inputMessage: &anypb.Any{}, inputMessage: &anypb.Any{},

View File

@ -9,10 +9,12 @@ import (
"sort" "sort"
"unicode/utf8" "unicode/utf8"
"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/encoding/text" "google.golang.org/protobuf/internal/encoding/text"
"google.golang.org/protobuf/internal/encoding/wire" "google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/errors" "google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/fieldnum" "google.golang.org/protobuf/internal/fieldnum"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/internal/mapsort" "google.golang.org/protobuf/internal/mapsort"
"google.golang.org/protobuf/internal/pragma" "google.golang.org/protobuf/internal/pragma"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
@ -72,8 +74,10 @@ func (o MarshalOptions) Marshal(m proto.Message) ([]byte, error) {
// marshalMessage converts a protoreflect.Message to a text.Value. // marshalMessage converts a protoreflect.Message to a text.Value.
func (o MarshalOptions) marshalMessage(m pref.Message) (text.Value, error) { func (o MarshalOptions) marshalMessage(m pref.Message) (text.Value, error) {
var msgFields [][2]text.Value
messageDesc := m.Descriptor() messageDesc := m.Descriptor()
if !flags.Proto1Legacy && messageset.IsMessageSet(messageDesc) {
return text.Value{}, errors.New("no support for proto1 MessageSets")
}
// Handle Any expansion. // Handle Any expansion.
if messageDesc.FullName() == "google.protobuf.Any" { if messageDesc.FullName() == "google.protobuf.Any" {
@ -85,6 +89,7 @@ func (o MarshalOptions) marshalMessage(m pref.Message) (text.Value, error) {
} }
// Handle known fields. // Handle known fields.
var msgFields [][2]text.Value
fieldDescs := messageDesc.Fields() fieldDescs := messageDesc.Fields()
size := fieldDescs.Len() size := fieldDescs.Len()
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
@ -253,10 +258,10 @@ func (o MarshalOptions) appendExtensions(msgFields [][2]text.Value, m pref.Messa
return true return true
} }
// If extended type is a MessageSet, set field name to be the message type name. // For MessageSet extensions, the name used is the parent message.
name := fd.FullName() name := fd.FullName()
if isMessageSetExtension(fd) { if messageset.IsMessageSetExtension(fd) {
name = fd.Message().FullName() name = name.Parent()
} }
// Use string type to produce [name] format. // Use string type to produce [name] format.
@ -279,22 +284,6 @@ func (o MarshalOptions) appendExtensions(msgFields [][2]text.Value, m pref.Messa
return append(msgFields, entries...), nil return append(msgFields, entries...), nil
} }
// isMessageSetExtension reports whether extension extends a message set.
func isMessageSetExtension(fd pref.FieldDescriptor) bool {
if fd.Name() != "message_set_extension" {
return false
}
md := fd.Message()
if md == nil {
return false
}
if fd.FullName().Parent() != md.FullName() {
return false
}
xmd, ok := fd.ContainingMessage().(interface{ IsMessageSet() bool })
return ok && xmd.IsMessageSet()
}
// appendUnknown parses the given []byte and appends field(s) into the given fields slice. // appendUnknown parses the given []byte and appends field(s) into the given fields slice.
// This function assumes proper encoding in the given []byte. // This function assumes proper encoding in the given []byte.
func appendUnknown(fields [][2]text.Value, b []byte) [][2]text.Value { func appendUnknown(fields [][2]text.Value, b []byte) [][2]text.Value {

View File

@ -12,6 +12,7 @@ import (
"google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/encoding/prototext"
"google.golang.org/protobuf/internal/detrand" "google.golang.org/protobuf/internal/detrand"
"google.golang.org/protobuf/internal/encoding/pack" "google.golang.org/protobuf/internal/encoding/pack"
"google.golang.org/protobuf/internal/flags"
pimpl "google.golang.org/protobuf/internal/impl" pimpl "google.golang.org/protobuf/internal/impl"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
preg "google.golang.org/protobuf/reflect/protoregistry" preg "google.golang.org/protobuf/reflect/protoregistry"
@ -39,6 +40,7 @@ func TestMarshal(t *testing.T) {
input proto.Message input proto.Message
want string want string
wantErr bool // TODO: Verify error message content. wantErr bool // TODO: Verify error message content.
skip bool
}{{ }{{
desc: "proto2 optional scalars not set", desc: "proto2 optional scalars not set",
input: &pb2.Scalars{}, input: &pb2.Scalars{},
@ -1082,6 +1084,7 @@ opt_int32: 42
opt_string: "not a messageset extension" opt_string: "not a messageset extension"
} }
`, `,
skip: !flags.Proto1Legacy,
}, { }, {
desc: "not real MessageSet 1", desc: "not real MessageSet 1",
input: func() proto.Message { input: func() proto.Message {
@ -1095,6 +1098,7 @@ opt_int32: 42
opt_string: "not a messageset extension" opt_string: "not a messageset extension"
} }
`, `,
skip: !flags.Proto1Legacy,
}, { }, {
desc: "not real MessageSet 2", desc: "not real MessageSet 2",
input: func() proto.Message { input: func() proto.Message {
@ -1108,6 +1112,7 @@ opt_int32: 42
opt_string: "another not a messageset extension" opt_string: "another not a messageset extension"
} }
`, `,
skip: !flags.Proto1Legacy,
}, { }, {
desc: "Any not expanded", desc: "Any not expanded",
mo: prototext.MarshalOptions{ mo: prototext.MarshalOptions{
@ -1201,6 +1206,9 @@ value: "\x80"
for _, tt := range tests { for _, tt := range tests {
tt := tt tt := tt
if tt.skip {
continue
}
t.Run(tt.desc, func(t *testing.T) { t.Run(tt.desc, func(t *testing.T) {
// Use 2-space indentation on all MarshalOptions. // Use 2-space indentation on all MarshalOptions.
tt.mo.Indent = " " tt.mo.Indent = " "

View File

@ -9,6 +9,7 @@ import (
"google.golang.org/protobuf/internal/encoding/wire" "google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/errors" "google.golang.org/protobuf/internal/errors"
pref "google.golang.org/protobuf/reflect/protoreflect" pref "google.golang.org/protobuf/reflect/protoreflect"
preg "google.golang.org/protobuf/reflect/protoregistry"
) )
// The MessageSet wire format is equivalent to a message defiend as follows, // The MessageSet wire format is equivalent to a message defiend as follows,
@ -28,12 +29,48 @@ const (
FieldMessage = wire.Number(3) FieldMessage = wire.Number(3)
) )
// ExtensionName is the field name for extensions of MessageSet.
//
// A valid MessageSet extension must be of the form:
// message MyMessage {
// extend proto2.bridge.MessageSet {
// optional MyMessage message_set_extension = 1234;
// }
// ...
// }
const ExtensionName = "message_set_extension"
// IsMessageSet returns whether the message uses the MessageSet wire format. // IsMessageSet returns whether the message uses the MessageSet wire format.
func IsMessageSet(md pref.MessageDescriptor) bool { func IsMessageSet(md pref.MessageDescriptor) bool {
xmd, ok := md.(interface{ IsMessageSet() bool }) xmd, ok := md.(interface{ IsMessageSet() bool })
return ok && xmd.IsMessageSet() return ok && xmd.IsMessageSet()
} }
// IsMessageSetExtension reports this field extends a MessageSet.
func IsMessageSetExtension(fd pref.FieldDescriptor) bool {
if fd.Name() != ExtensionName {
return false
}
if fd.FullName().Parent() != fd.Message().FullName() {
return false
}
return IsMessageSet(fd.ContainingMessage())
}
// FindMessageSetExtension locates a MessageSet extension field by name.
// In text and JSON formats, the extension name used is the message itself.
// The extension field name is derived by appending ExtensionName.
func FindMessageSetExtension(r preg.ExtensionTypeResolver, s pref.FullName) (pref.ExtensionType, error) {
xt, err := r.FindExtensionByName(s.Append(ExtensionName))
if err != nil {
return nil, err
}
if !IsMessageSetExtension(xt) {
return nil, preg.NotFound
}
return xt, nil
}
// SizeField returns the size of a MessageSet item field containing an extension // SizeField returns the size of a MessageSet item field containing an extension
// with the given field number, not counting the contents of the message subfield. // with the given field number, not counting the contents of the message subfield.
func SizeField(num wire.Number) int { func SizeField(num wire.Number) int {