reflect/protoregistry: centralize MessageSet extension resolution logic

Centralize the MessageSet extension resolution logic in the registry.
This avoids needless replication of this exact logic in multiple places
(for JSON and text) and elsewhere.

Change-Id: I70bfea899e295e8c589f418965bf0dd099f93628
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/240077
Reviewed-by: Herbie Ong <herbie@google.com>
This commit is contained in:
Joe Tsai 2020-06-25 17:35:32 -07:00
parent 1726b83dc4
commit b78321453d
6 changed files with 34 additions and 47 deletions

View File

@ -170,7 +170,7 @@ func (d decoder) unmarshalFields(m pref.Message, skipTypeURL bool) error {
if strings.HasPrefix(name, "[") && strings.HasSuffix(name, "]") {
// Only extension names are in [name] format.
extName := pref.FullName(name[1 : len(name)-1])
extType, err := d.findExtension(extName)
extType, err := d.opts.Resolver.FindExtensionByName(extName)
if err != nil && err != protoregistry.NotFound {
return d.newError(tok.Pos(), "unable to resolve %s: %v", tok.RawString(), err)
}
@ -257,15 +257,6 @@ func (d decoder) unmarshalFields(m pref.Message, skipTypeURL bool) error {
}
}
// findExtension returns protoreflect.ExtensionType from the resolver if found.
func (d decoder) findExtension(xtName pref.FullName) (pref.ExtensionType, error) {
xt, err := d.opts.Resolver.FindExtensionByName(xtName)
if err == nil {
return xt, nil
}
return messageset.FindMessageSetExtension(d.opts.Resolver, xtName)
}
func isKnownValue(fd pref.FieldDescriptor) bool {
md := fd.Message()
return md != nil && md.FullName() == genid.Value_message_fullname

View File

@ -1403,7 +1403,7 @@ func TestUnmarshal(t *testing.T) {
"optString": "not a messageset extension"
}
}`,
wantErr: `unknown field "[pb2.FakeMessageSetExtension]"`,
wantErr: `unable to resolve "[pb2.FakeMessageSetExtension]": found wrong type`,
skip: !flags.ProtoLegacy,
}, {
desc: "not real MessageSet 3",

View File

@ -172,7 +172,7 @@ func (d decoder) unmarshalMessage(m pref.Message, checkDelims bool) error {
case text.TypeName:
// Handle extensions only. This code path is not for Any.
xt, xtErr = d.findExtension(pref.FullName(tok.TypeName()))
xt, xtErr = d.opts.Resolver.FindExtensionByName(pref.FullName(tok.TypeName()))
case text.FieldNumber:
isFieldNumberName = true
@ -269,15 +269,6 @@ func (d decoder) unmarshalMessage(m pref.Message, checkDelims bool) error {
return nil
}
// findExtension returns protoreflect.ExtensionType from the Resolver if found.
func (d decoder) findExtension(xtName pref.FullName) (pref.ExtensionType, error) {
xt, err := d.opts.Resolver.FindExtensionByName(xtName)
if err == nil {
return xt, nil
}
return messageset.FindMessageSetExtension(d.opts.Resolver, xtName)
}
// unmarshalSingular unmarshals a non-repeated field value specified by the
// given FieldDescriptor.
func (d decoder) unmarshalSingular(fd pref.FieldDescriptor, m pref.Message) error {

View File

@ -1508,7 +1508,7 @@ opt_int32: 42
opt_string: "not a messageset extension"
}
`,
wantErr: "unknown field: [pb2.FakeMessageSetExtension]",
wantErr: `unable to resolve [[pb2.FakeMessageSetExtension]]: found wrong type`,
skip: !flags.ProtoLegacy,
}, {
desc: "not real MessageSet 3",

View File

@ -11,7 +11,6 @@ import (
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/errors"
pref "google.golang.org/protobuf/reflect/protoreflect"
preg "google.golang.org/protobuf/reflect/protoregistry"
)
// The MessageSet wire format is equivalent to a message defiend as follows,
@ -48,33 +47,17 @@ func IsMessageSet(md pref.MessageDescriptor) bool {
return ok && xmd.IsMessageSet()
}
// IsMessageSetExtension reports this field extends a MessageSet.
// IsMessageSetExtension reports this field properly extends a MessageSet.
func IsMessageSetExtension(fd pref.FieldDescriptor) bool {
if fd.Name() != ExtensionName {
switch {
case fd.Name() != ExtensionName:
return false
case !IsMessageSet(fd.ContainingMessage()):
return false
case fd.FullName().Parent() != fd.Message().FullName():
return false
}
if fd.FullName().Parent() != fd.Message().FullName() {
return false
}
return IsMessageSet(fd.ContainingMessage())
}
// FindMessageSetExtension locates a MessageSet extension field by name.
// In text and JSON formats, the extension name used is the message itself.
// The extension field name is derived by appending ExtensionName.
func FindMessageSetExtension(r preg.ExtensionTypeResolver, s pref.FullName) (pref.ExtensionType, error) {
name := s.Append(ExtensionName)
xt, err := r.FindExtensionByName(name)
if err != nil {
if err == preg.NotFound {
return nil, err
}
return nil, errors.Wrap(err, "%q", name)
}
if !IsMessageSetExtension(xt.TypeDescriptor()) {
return nil, preg.NotFound
}
return xt, nil
return true
}
// SizeField returns the size of a MessageSet item field containing an extension

View File

@ -21,7 +21,9 @@ import (
"strings"
"sync"
"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/reflect/protoreflect"
)
@ -613,6 +615,26 @@ func (r *Types) FindExtensionByName(field protoreflect.FullName) (protoreflect.E
if xt, _ := v.(protoreflect.ExtensionType); xt != nil {
return xt, nil
}
// MessageSet extensions are special in that the name of the extension
// is the name of the message type used to extend the MessageSet.
// This naming scheme is used by text and JSON serialization.
//
// This feature is protected by the ProtoLegacy flag since MessageSets
// are a proto1 feature that is long deprecated.
if flags.ProtoLegacy {
if _, ok := v.(protoreflect.MessageType); ok {
field := field.Append(messageset.ExtensionName)
if v := r.typesByName[field]; v != nil {
if xt, _ := v.(protoreflect.ExtensionType); xt != nil {
if messageset.IsMessageSetExtension(xt.TypeDescriptor()) {
return xt, nil
}
}
}
}
}
return nil, errors.New("found wrong type: got %v, want extension", typeName(v))
}
return nil, NotFound