diff --git a/encoding/protojson/decode.go b/encoding/protojson/decode.go index 9bf4e8c1..5ba9ebf2 100644 --- a/encoding/protojson/decode.go +++ b/encoding/protojson/decode.go @@ -170,7 +170,7 @@ func (d decoder) unmarshalFields(m pref.Message, skipTypeURL bool) error { if strings.HasPrefix(name, "[") && strings.HasSuffix(name, "]") { // Only extension names are in [name] format. extName := pref.FullName(name[1 : len(name)-1]) - extType, err := d.findExtension(extName) + extType, err := d.opts.Resolver.FindExtensionByName(extName) if err != nil && err != protoregistry.NotFound { return d.newError(tok.Pos(), "unable to resolve %s: %v", tok.RawString(), err) } @@ -257,15 +257,6 @@ func (d decoder) unmarshalFields(m pref.Message, skipTypeURL bool) error { } } -// findExtension returns protoreflect.ExtensionType from the resolver if found. -func (d decoder) findExtension(xtName pref.FullName) (pref.ExtensionType, error) { - xt, err := d.opts.Resolver.FindExtensionByName(xtName) - if err == nil { - return xt, nil - } - return messageset.FindMessageSetExtension(d.opts.Resolver, xtName) -} - func isKnownValue(fd pref.FieldDescriptor) bool { md := fd.Message() return md != nil && md.FullName() == genid.Value_message_fullname diff --git a/encoding/protojson/decode_test.go b/encoding/protojson/decode_test.go index 5e3dcffb..4791f650 100644 --- a/encoding/protojson/decode_test.go +++ b/encoding/protojson/decode_test.go @@ -1403,7 +1403,7 @@ func TestUnmarshal(t *testing.T) { "optString": "not a messageset extension" } }`, - wantErr: `unknown field "[pb2.FakeMessageSetExtension]"`, + wantErr: `unable to resolve "[pb2.FakeMessageSetExtension]": found wrong type`, skip: !flags.ProtoLegacy, }, { desc: "not real MessageSet 3", diff --git a/encoding/prototext/decode.go b/encoding/prototext/decode.go index cab95a42..8cce1e06 100644 --- a/encoding/prototext/decode.go +++ b/encoding/prototext/decode.go @@ -172,7 +172,7 @@ func (d decoder) unmarshalMessage(m pref.Message, checkDelims bool) error { case text.TypeName: // Handle extensions only. This code path is not for Any. - xt, xtErr = d.findExtension(pref.FullName(tok.TypeName())) + xt, xtErr = d.opts.Resolver.FindExtensionByName(pref.FullName(tok.TypeName())) case text.FieldNumber: isFieldNumberName = true @@ -269,15 +269,6 @@ func (d decoder) unmarshalMessage(m pref.Message, checkDelims bool) error { return nil } -// findExtension returns protoreflect.ExtensionType from the Resolver if found. -func (d decoder) findExtension(xtName pref.FullName) (pref.ExtensionType, error) { - xt, err := d.opts.Resolver.FindExtensionByName(xtName) - if err == nil { - return xt, nil - } - return messageset.FindMessageSetExtension(d.opts.Resolver, xtName) -} - // unmarshalSingular unmarshals a non-repeated field value specified by the // given FieldDescriptor. func (d decoder) unmarshalSingular(fd pref.FieldDescriptor, m pref.Message) error { diff --git a/encoding/prototext/decode_test.go b/encoding/prototext/decode_test.go index dceded16..441dcb24 100644 --- a/encoding/prototext/decode_test.go +++ b/encoding/prototext/decode_test.go @@ -1508,7 +1508,7 @@ opt_int32: 42 opt_string: "not a messageset extension" } `, - wantErr: "unknown field: [pb2.FakeMessageSetExtension]", + wantErr: `unable to resolve [[pb2.FakeMessageSetExtension]]: found wrong type`, skip: !flags.ProtoLegacy, }, { desc: "not real MessageSet 3", diff --git a/internal/encoding/messageset/messageset.go b/internal/encoding/messageset/messageset.go index b1eeea50..453a81a5 100644 --- a/internal/encoding/messageset/messageset.go +++ b/internal/encoding/messageset/messageset.go @@ -11,7 +11,6 @@ import ( "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/internal/errors" pref "google.golang.org/protobuf/reflect/protoreflect" - preg "google.golang.org/protobuf/reflect/protoregistry" ) // The MessageSet wire format is equivalent to a message defiend as follows, @@ -48,33 +47,17 @@ func IsMessageSet(md pref.MessageDescriptor) bool { return ok && xmd.IsMessageSet() } -// IsMessageSetExtension reports this field extends a MessageSet. +// IsMessageSetExtension reports this field properly extends a MessageSet. func IsMessageSetExtension(fd pref.FieldDescriptor) bool { - if fd.Name() != ExtensionName { + switch { + case fd.Name() != ExtensionName: + return false + case !IsMessageSet(fd.ContainingMessage()): + return false + case fd.FullName().Parent() != fd.Message().FullName(): return false } - if fd.FullName().Parent() != fd.Message().FullName() { - return false - } - return IsMessageSet(fd.ContainingMessage()) -} - -// FindMessageSetExtension locates a MessageSet extension field by name. -// In text and JSON formats, the extension name used is the message itself. -// The extension field name is derived by appending ExtensionName. -func FindMessageSetExtension(r preg.ExtensionTypeResolver, s pref.FullName) (pref.ExtensionType, error) { - name := s.Append(ExtensionName) - xt, err := r.FindExtensionByName(name) - if err != nil { - if err == preg.NotFound { - return nil, err - } - return nil, errors.Wrap(err, "%q", name) - } - if !IsMessageSetExtension(xt.TypeDescriptor()) { - return nil, preg.NotFound - } - return xt, nil + return true } // SizeField returns the size of a MessageSet item field containing an extension diff --git a/reflect/protoregistry/registry.go b/reflect/protoregistry/registry.go index 5e5f9671..7dcf4d9e 100644 --- a/reflect/protoregistry/registry.go +++ b/reflect/protoregistry/registry.go @@ -21,7 +21,9 @@ import ( "strings" "sync" + "google.golang.org/protobuf/internal/encoding/messageset" "google.golang.org/protobuf/internal/errors" + "google.golang.org/protobuf/internal/flags" "google.golang.org/protobuf/reflect/protoreflect" ) @@ -613,6 +615,26 @@ func (r *Types) FindExtensionByName(field protoreflect.FullName) (protoreflect.E if xt, _ := v.(protoreflect.ExtensionType); xt != nil { return xt, nil } + + // MessageSet extensions are special in that the name of the extension + // is the name of the message type used to extend the MessageSet. + // This naming scheme is used by text and JSON serialization. + // + // This feature is protected by the ProtoLegacy flag since MessageSets + // are a proto1 feature that is long deprecated. + if flags.ProtoLegacy { + if _, ok := v.(protoreflect.MessageType); ok { + field := field.Append(messageset.ExtensionName) + if v := r.typesByName[field]; v != nil { + if xt, _ := v.(protoreflect.ExtensionType); xt != nil { + if messageset.IsMessageSetExtension(xt.TypeDescriptor()) { + return xt, nil + } + } + } + } + } + return nil, errors.New("found wrong type: got %v, want extension", typeName(v)) } return nil, NotFound