mirror of
https://github.com/protocolbuffers/protobuf-go.git
synced 2025-03-09 22:13:27 +00:00
reflect/protoregistry: centralize MessageSet extension resolution logic
Centralize the MessageSet extension resolution logic in the registry. This avoids needless replication of this exact logic in multiple places (for JSON and text) and elsewhere. Change-Id: I70bfea899e295e8c589f418965bf0dd099f93628 Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/240077 Reviewed-by: Herbie Ong <herbie@google.com>
This commit is contained in:
parent
1726b83dc4
commit
b78321453d
@ -170,7 +170,7 @@ func (d decoder) unmarshalFields(m pref.Message, skipTypeURL bool) error {
|
||||
if strings.HasPrefix(name, "[") && strings.HasSuffix(name, "]") {
|
||||
// Only extension names are in [name] format.
|
||||
extName := pref.FullName(name[1 : len(name)-1])
|
||||
extType, err := d.findExtension(extName)
|
||||
extType, err := d.opts.Resolver.FindExtensionByName(extName)
|
||||
if err != nil && err != protoregistry.NotFound {
|
||||
return d.newError(tok.Pos(), "unable to resolve %s: %v", tok.RawString(), err)
|
||||
}
|
||||
@ -257,15 +257,6 @@ func (d decoder) unmarshalFields(m pref.Message, skipTypeURL bool) error {
|
||||
}
|
||||
}
|
||||
|
||||
// findExtension returns protoreflect.ExtensionType from the resolver if found.
|
||||
func (d decoder) findExtension(xtName pref.FullName) (pref.ExtensionType, error) {
|
||||
xt, err := d.opts.Resolver.FindExtensionByName(xtName)
|
||||
if err == nil {
|
||||
return xt, nil
|
||||
}
|
||||
return messageset.FindMessageSetExtension(d.opts.Resolver, xtName)
|
||||
}
|
||||
|
||||
func isKnownValue(fd pref.FieldDescriptor) bool {
|
||||
md := fd.Message()
|
||||
return md != nil && md.FullName() == genid.Value_message_fullname
|
||||
|
@ -1403,7 +1403,7 @@ func TestUnmarshal(t *testing.T) {
|
||||
"optString": "not a messageset extension"
|
||||
}
|
||||
}`,
|
||||
wantErr: `unknown field "[pb2.FakeMessageSetExtension]"`,
|
||||
wantErr: `unable to resolve "[pb2.FakeMessageSetExtension]": found wrong type`,
|
||||
skip: !flags.ProtoLegacy,
|
||||
}, {
|
||||
desc: "not real MessageSet 3",
|
||||
|
@ -172,7 +172,7 @@ func (d decoder) unmarshalMessage(m pref.Message, checkDelims bool) error {
|
||||
|
||||
case text.TypeName:
|
||||
// Handle extensions only. This code path is not for Any.
|
||||
xt, xtErr = d.findExtension(pref.FullName(tok.TypeName()))
|
||||
xt, xtErr = d.opts.Resolver.FindExtensionByName(pref.FullName(tok.TypeName()))
|
||||
|
||||
case text.FieldNumber:
|
||||
isFieldNumberName = true
|
||||
@ -269,15 +269,6 @@ func (d decoder) unmarshalMessage(m pref.Message, checkDelims bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// findExtension returns protoreflect.ExtensionType from the Resolver if found.
|
||||
func (d decoder) findExtension(xtName pref.FullName) (pref.ExtensionType, error) {
|
||||
xt, err := d.opts.Resolver.FindExtensionByName(xtName)
|
||||
if err == nil {
|
||||
return xt, nil
|
||||
}
|
||||
return messageset.FindMessageSetExtension(d.opts.Resolver, xtName)
|
||||
}
|
||||
|
||||
// unmarshalSingular unmarshals a non-repeated field value specified by the
|
||||
// given FieldDescriptor.
|
||||
func (d decoder) unmarshalSingular(fd pref.FieldDescriptor, m pref.Message) error {
|
||||
|
@ -1508,7 +1508,7 @@ opt_int32: 42
|
||||
opt_string: "not a messageset extension"
|
||||
}
|
||||
`,
|
||||
wantErr: "unknown field: [pb2.FakeMessageSetExtension]",
|
||||
wantErr: `unable to resolve [[pb2.FakeMessageSetExtension]]: found wrong type`,
|
||||
skip: !flags.ProtoLegacy,
|
||||
}, {
|
||||
desc: "not real MessageSet 3",
|
||||
|
@ -11,7 +11,6 @@ import (
|
||||
"google.golang.org/protobuf/encoding/protowire"
|
||||
"google.golang.org/protobuf/internal/errors"
|
||||
pref "google.golang.org/protobuf/reflect/protoreflect"
|
||||
preg "google.golang.org/protobuf/reflect/protoregistry"
|
||||
)
|
||||
|
||||
// The MessageSet wire format is equivalent to a message defiend as follows,
|
||||
@ -48,33 +47,17 @@ func IsMessageSet(md pref.MessageDescriptor) bool {
|
||||
return ok && xmd.IsMessageSet()
|
||||
}
|
||||
|
||||
// IsMessageSetExtension reports this field extends a MessageSet.
|
||||
// IsMessageSetExtension reports this field properly extends a MessageSet.
|
||||
func IsMessageSetExtension(fd pref.FieldDescriptor) bool {
|
||||
if fd.Name() != ExtensionName {
|
||||
switch {
|
||||
case fd.Name() != ExtensionName:
|
||||
return false
|
||||
case !IsMessageSet(fd.ContainingMessage()):
|
||||
return false
|
||||
case fd.FullName().Parent() != fd.Message().FullName():
|
||||
return false
|
||||
}
|
||||
if fd.FullName().Parent() != fd.Message().FullName() {
|
||||
return false
|
||||
}
|
||||
return IsMessageSet(fd.ContainingMessage())
|
||||
}
|
||||
|
||||
// FindMessageSetExtension locates a MessageSet extension field by name.
|
||||
// In text and JSON formats, the extension name used is the message itself.
|
||||
// The extension field name is derived by appending ExtensionName.
|
||||
func FindMessageSetExtension(r preg.ExtensionTypeResolver, s pref.FullName) (pref.ExtensionType, error) {
|
||||
name := s.Append(ExtensionName)
|
||||
xt, err := r.FindExtensionByName(name)
|
||||
if err != nil {
|
||||
if err == preg.NotFound {
|
||||
return nil, err
|
||||
}
|
||||
return nil, errors.Wrap(err, "%q", name)
|
||||
}
|
||||
if !IsMessageSetExtension(xt.TypeDescriptor()) {
|
||||
return nil, preg.NotFound
|
||||
}
|
||||
return xt, nil
|
||||
return true
|
||||
}
|
||||
|
||||
// SizeField returns the size of a MessageSet item field containing an extension
|
||||
|
@ -21,7 +21,9 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"google.golang.org/protobuf/internal/encoding/messageset"
|
||||
"google.golang.org/protobuf/internal/errors"
|
||||
"google.golang.org/protobuf/internal/flags"
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
)
|
||||
|
||||
@ -613,6 +615,26 @@ func (r *Types) FindExtensionByName(field protoreflect.FullName) (protoreflect.E
|
||||
if xt, _ := v.(protoreflect.ExtensionType); xt != nil {
|
||||
return xt, nil
|
||||
}
|
||||
|
||||
// MessageSet extensions are special in that the name of the extension
|
||||
// is the name of the message type used to extend the MessageSet.
|
||||
// This naming scheme is used by text and JSON serialization.
|
||||
//
|
||||
// This feature is protected by the ProtoLegacy flag since MessageSets
|
||||
// are a proto1 feature that is long deprecated.
|
||||
if flags.ProtoLegacy {
|
||||
if _, ok := v.(protoreflect.MessageType); ok {
|
||||
field := field.Append(messageset.ExtensionName)
|
||||
if v := r.typesByName[field]; v != nil {
|
||||
if xt, _ := v.(protoreflect.ExtensionType); xt != nil {
|
||||
if messageset.IsMessageSetExtension(xt.TypeDescriptor()) {
|
||||
return xt, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("found wrong type: got %v, want extension", typeName(v))
|
||||
}
|
||||
return nil, NotFound
|
||||
|
Loading…
x
Reference in New Issue
Block a user