diff --git a/internal/encoding/messageset/messageset.go b/internal/encoding/messageset/messageset.go index b7b5477a..77522de2 100644 --- a/internal/encoding/messageset/messageset.go +++ b/internal/encoding/messageset/messageset.go @@ -77,31 +77,50 @@ func SizeField(num wire.Number) int { return 2*wire.SizeTag(FieldItem) + wire.SizeTag(FieldTypeID) + wire.SizeVarint(uint64(num)) } -// ConsumeField parses a MessageSet item field and returns the contents of the -// type_id and message subfields and the total item length. -func ConsumeField(b []byte) (typeid wire.Number, message []byte, n int, err error) { - num, wtyp, n := wire.ConsumeTag(b) - if n < 0 { - return 0, nil, 0, wire.ParseError(n) +// Unmarshal parses a MessageSet. +// +// It calls fn with the type ID and value of each item in the MessageSet. +// Unknown fields are discarded. +// +// If wantLen is true, the item values include the varint length prefix. +// This is ugly, but simplifies the fast-path decoder in internal/impl. +func Unmarshal(b []byte, wantLen bool, fn func(typeID wire.Number, value []byte) error) error { + for len(b) > 0 { + num, wtyp, n := wire.ConsumeTag(b) + if n < 0 { + return wire.ParseError(n) + } + b = b[n:] + if num != FieldItem || wtyp != wire.StartGroupType { + n := wire.ConsumeFieldValue(num, wtyp, b) + if n < 0 { + return wire.ParseError(n) + } + b = b[n:] + continue + } + typeID, value, n, err := consumeFieldValue(b, wantLen) + if err != nil { + return err + } + b = b[n:] + if typeID == 0 { + continue + } + if err := fn(typeID, value); err != nil { + return err + } } - if num != FieldItem || wtyp != wire.StartGroupType { - return 0, nil, 0, errors.New("invalid MessageSet field number") - } - typeid, message, fieldLen, err := ConsumeFieldValue(b[n:], false) - if err != nil { - return 0, nil, 0, err - } - return typeid, message, n + fieldLen, nil + return nil } -// ConsumeFieldValue parses b as a MessageSet item field value until and including +// consumeFieldValue parses b as a MessageSet item field value until and including // the trailing end group marker. It assumes the start group tag has already been parsed. // It returns the contents of the type_id and message subfields and the total // item length. // // If wantLen is true, the returned message value includes the length prefix. -// This is ugly, but simplifies the fast-path decoder in internal/impl. -func ConsumeFieldValue(b []byte, wantLen bool) (typeid wire.Number, message []byte, n int, err error) { +func consumeFieldValue(b []byte, wantLen bool) (typeid wire.Number, message []byte, n int, err error) { ilen := len(b) for { num, wtyp, n := wire.ConsumeTag(b) @@ -173,3 +192,51 @@ func AppendFieldStart(b []byte, num wire.Number) []byte { func AppendFieldEnd(b []byte) []byte { return wire.AppendTag(b, FieldItem, wire.EndGroupType) } + +// SizeUnknown returns the size of an unknown fields section in MessageSet format. +// +// See AppendUnknown. +func SizeUnknown(unknown []byte) (size int) { + for len(unknown) > 0 { + num, typ, n := wire.ConsumeTag(unknown) + if n < 0 || typ != wire.BytesType { + return 0 + } + unknown = unknown[n:] + _, n = wire.ConsumeBytes(unknown) + if n < 0 { + return 0 + } + unknown = unknown[n:] + size += SizeField(num) + wire.SizeTag(FieldMessage) + n + } + return size +} + +// AppendUnknown appends unknown fields to b in MessageSet format. +// +// For historic reasons, unresolved items in a MessageSet are stored in a +// message's unknown fields section in non-MessageSet format. That is, an +// unknown item with typeID T and value V appears in the unknown fields as +// a field with number T and value V. +// +// This function converts the unknown fields back into MessageSet form. +func AppendUnknown(b, unknown []byte) ([]byte, error) { + for len(unknown) > 0 { + num, typ, n := wire.ConsumeTag(unknown) + if n < 0 || typ != wire.BytesType { + return nil, errors.New("invalid data in message set unknown fields") + } + unknown = unknown[n:] + _, n = wire.ConsumeBytes(unknown) + if n < 0 { + return nil, errors.New("invalid data in message set unknown fields") + } + b = AppendFieldStart(b, num) + b = wire.AppendTag(b, FieldMessage, wire.BytesType) + b = append(b, unknown[:n]...) + b = AppendFieldEnd(b) + unknown = unknown[n:] + } + return b, nil +} diff --git a/internal/impl/codec_message.go b/internal/impl/codec_message.go index 4694718f..ab1d4ec5 100644 --- a/internal/impl/codec_message.go +++ b/internal/impl/codec_message.go @@ -29,6 +29,7 @@ type coderMessageInfo struct { unknownOffset offset extensionOffset offset needsInitCheck bool + isMessageSet bool extensionFieldInfosMu sync.RWMutex extensionFieldInfos map[pref.ExtensionType]*extensionFieldInfo @@ -97,16 +98,10 @@ func (mi *MessageInfo) makeCoderMethods(t reflect.Type, si structInfo) { if !mi.extensionOffset.IsValid() { panic(fmt.Sprintf("%v: MessageSet with no extensions field", mi.Desc.FullName())) } - cf := &coderFieldInfo{ - num: messageset.FieldItem, - offset: si.extensionOffset, - isPointer: true, - funcs: makeMessageSetFieldCoder(mi), + if !mi.unknownOffset.IsValid() { + panic(fmt.Sprintf("%v: MessageSet with no unknown field", mi.Desc.FullName())) } - mi.orderedCoderFields = append(mi.orderedCoderFields, cf) - mi.coderFields[cf.num] = cf - // Invalidate the extension offset, since the field codec handles extensions. - mi.extensionOffset = invalidOffset + mi.isMessageSet = true } sort.Slice(mi.orderedCoderFields, func(i, j int) bool { return mi.orderedCoderFields[i].num < mi.orderedCoderFields[j].num diff --git a/internal/impl/codec_messageset.go b/internal/impl/codec_messageset.go index 0b3746d8..d78afeb7 100644 --- a/internal/impl/codec_messageset.go +++ b/internal/impl/codec_messageset.go @@ -13,48 +13,36 @@ import ( "google.golang.org/protobuf/internal/flags" ) -func makeMessageSetFieldCoder(mi *MessageInfo) pointerCoderFuncs { - return pointerCoderFuncs{ - size: func(p pointer, tagsize int, opts marshalOptions) int { - return sizeMessageSet(mi, p, tagsize, opts) - }, - marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) { - return marshalMessageSet(mi, b, p, wiretag, opts) - }, - unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) { - return unmarshalMessageSet(mi, b, p, wtyp, opts) - }, - } -} - -func sizeMessageSet(mi *MessageInfo, p pointer, tagsize int, opts marshalOptions) (n int) { - ext := *p.Extensions() - if ext == nil { +func sizeMessageSet(mi *MessageInfo, p pointer, opts marshalOptions) (size int) { + if !flags.ProtoLegacy { return 0 } + + ext := *p.Apply(mi.extensionOffset).Extensions() for _, x := range ext { xi := mi.extensionFieldInfo(x.Type()) if xi.funcs.size == nil { continue } num, _ := wire.DecodeTag(xi.wiretag) - n += messageset.SizeField(num) - n += xi.funcs.size(x.Value(), wire.SizeTag(messageset.FieldMessage), opts) + size += messageset.SizeField(num) + size += xi.funcs.size(x.Value(), wire.SizeTag(messageset.FieldMessage), opts) } - return n + + unknown := *p.Apply(mi.unknownOffset).Bytes() + size += messageset.SizeUnknown(unknown) + + return size } -func marshalMessageSet(mi *MessageInfo, b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) { +func marshalMessageSet(mi *MessageInfo, b []byte, p pointer, opts marshalOptions) ([]byte, error) { if !flags.ProtoLegacy { return b, errors.New("no support for message_set_wire_format") } - ext := *p.Extensions() - if ext == nil { - return b, nil - } + + ext := *p.Apply(mi.extensionOffset).Extensions() switch len(ext) { case 0: - return b, nil case 1: // Fast-path for one extension: Don't bother sorting the keys. for _, x := range ext { @@ -64,7 +52,6 @@ func marshalMessageSet(mi *MessageInfo, b []byte, p pointer, wiretag uint64, opt return b, err } } - return b, nil default: // Sort the keys to provide a deterministic encoding. // Not sure this is required, but the old code does it. @@ -80,8 +67,15 @@ func marshalMessageSet(mi *MessageInfo, b []byte, p pointer, wiretag uint64, opt return b, err } } - return b, nil } + + unknown := *p.Apply(mi.unknownOffset).Bytes() + b, err := messageset.AppendUnknown(b, unknown) + if err != nil { + return b, err + } + + return b, nil } func marshalMessageSetField(mi *MessageInfo, b []byte, x ExtensionField, opts marshalOptions) ([]byte, error) { @@ -96,24 +90,25 @@ func marshalMessageSetField(mi *MessageInfo, b []byte, x ExtensionField, opts ma return b, nil } -func unmarshalMessageSet(mi *MessageInfo, b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) { +func unmarshalMessageSet(mi *MessageInfo, b []byte, p pointer, opts unmarshalOptions) (int, error) { if !flags.ProtoLegacy { return 0, errors.New("no support for message_set_wire_format") } - if wtyp != wire.StartGroupType { - return 0, errUnknown - } - ep := p.Extensions() + + ep := p.Apply(mi.extensionOffset).Extensions() if *ep == nil { *ep = make(map[int32]ExtensionField) } ext := *ep - num, v, n, err := messageset.ConsumeFieldValue(b, true) - if err != nil { - return 0, err - } - if _, err := mi.unmarshalExtension(v, num, wire.BytesType, ext, opts); err != nil { - return 0, err - } - return n, nil + unknown := p.Apply(mi.unknownOffset).Bytes() + err := messageset.Unmarshal(b, true, func(num wire.Number, v []byte) error { + _, err := mi.unmarshalExtension(v, num, wire.BytesType, ext, opts) + if err == errUnknown { + *unknown = wire.AppendTag(*unknown, num, wire.BytesType) + *unknown = append(*unknown, v...) + return nil + } + return err + }) + return len(b), err } diff --git a/internal/impl/decode.go b/internal/impl/decode.go index 757bbde0..85cc6b0c 100644 --- a/internal/impl/decode.go +++ b/internal/impl/decode.go @@ -7,6 +7,7 @@ package impl import ( "google.golang.org/protobuf/internal/encoding/wire" "google.golang.org/protobuf/internal/errors" + "google.golang.org/protobuf/internal/flags" "google.golang.org/protobuf/proto" pref "google.golang.org/protobuf/reflect/protoreflect" preg "google.golang.org/protobuf/reflect/protoregistry" @@ -72,6 +73,9 @@ var errUnknown = errors.New("unknown") func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag wire.Number, opts unmarshalOptions) (int, error) { mi.init() + if flags.ProtoLegacy && mi.isMessageSet { + return unmarshalMessageSet(mi, b, p, opts) + } var exts *map[int32]ExtensionField start := len(b) for len(b) > 0 { diff --git a/internal/impl/encode.go b/internal/impl/encode.go index d7dc8929..cd57998b 100644 --- a/internal/impl/encode.go +++ b/internal/impl/encode.go @@ -8,6 +8,7 @@ import ( "sort" "sync/atomic" + "google.golang.org/protobuf/internal/flags" proto "google.golang.org/protobuf/proto" pref "google.golang.org/protobuf/reflect/protoreflect" piface "google.golang.org/protobuf/runtime/protoiface" @@ -69,6 +70,13 @@ func (mi *MessageInfo) sizePointer(p pointer, opts marshalOptions) (size int) { } func (mi *MessageInfo) sizePointerSlow(p pointer, opts marshalOptions) (size int) { + if flags.ProtoLegacy && mi.isMessageSet { + size = sizeMessageSet(mi, p, opts) + if mi.sizecacheOffset.IsValid() { + atomic.StoreInt32(p.Apply(mi.sizecacheOffset).Int32(), int32(size)) + } + return size + } if mi.extensionOffset.IsValid() { e := p.Apply(mi.extensionOffset).Extensions() size += mi.sizeExtensions(e, opts) @@ -109,6 +117,9 @@ func (mi *MessageInfo) marshalAppendPointer(b []byte, p pointer, opts marshalOpt if p.IsNil() { return b, nil } + if flags.ProtoLegacy && mi.isMessageSet { + return marshalMessageSet(mi, b, p, opts) + } var err error // The old marshaler encodes extensions at beginning. if mi.extensionOffset.IsValid() { @@ -132,7 +143,7 @@ func (mi *MessageInfo) marshalAppendPointer(b []byte, p pointer, opts marshalOpt return b, err } } - if mi.unknownOffset.IsValid() { + if mi.unknownOffset.IsValid() && !mi.isMessageSet { u := *p.Apply(mi.unknownOffset).Bytes() b = append(b, u...) } diff --git a/internal/testprotos/messageset/messagesetpb/message_set.pb.go b/internal/testprotos/messageset/messagesetpb/message_set.pb.go index 1268e720..13406b2d 100644 --- a/internal/testprotos/messageset/messagesetpb/message_set.pb.go +++ b/internal/testprotos/messageset/messagesetpb/message_set.pb.go @@ -63,6 +63,53 @@ func (*MessageSet) ExtensionRangeArray() []protoiface.ExtensionRangeV1 { return extRange_MessageSet } +type MessageSetContainer struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + MessageSet *MessageSet `protobuf:"bytes,1,opt,name=message_set,json=messageSet" json:"message_set,omitempty"` +} + +func (x *MessageSetContainer) Reset() { + *x = MessageSetContainer{} + if protoimpl.UnsafeEnabled { + mi := &file_messageset_messagesetpb_message_set_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MessageSetContainer) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MessageSetContainer) ProtoMessage() {} + +func (x *MessageSetContainer) ProtoReflect() protoreflect.Message { + mi := &file_messageset_messagesetpb_message_set_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MessageSetContainer.ProtoReflect.Descriptor instead. +func (*MessageSetContainer) Descriptor() ([]byte, []int) { + return file_messageset_messagesetpb_message_set_proto_rawDescGZIP(), []int{1} +} + +func (x *MessageSetContainer) GetMessageSet() *MessageSet { + if x != nil { + return x.MessageSet + } + return nil +} + var File_messageset_messagesetpb_message_set_proto protoreflect.FileDescriptor var file_messageset_messagesetpb_message_set_proto_rawDesc = []byte{ @@ -72,11 +119,17 @@ var file_messageset_messagesetpb_message_set_proto_rawDesc = []byte{ 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x22, 0x1a, 0x0a, 0x0a, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x53, 0x65, 0x74, 0x2a, 0x08, 0x08, 0x04, 0x10, 0xff, 0xff, 0xff, 0xff, 0x07, 0x3a, 0x02, 0x08, - 0x01, 0x42, 0x48, 0x5a, 0x46, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x67, 0x6f, 0x6c, 0x61, - 0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, - 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x73, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x2f, 0x6d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x70, 0x62, + 0x01, 0x22, 0x5c, 0x0a, 0x13, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x53, 0x65, 0x74, 0x43, + 0x6f, 0x6e, 0x74, 0x61, 0x69, 0x6e, 0x65, 0x72, 0x12, 0x45, 0x0a, 0x0b, 0x6d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x5f, 0x73, 0x65, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x24, 0x2e, + 0x67, 0x6f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x65, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x53, 0x65, 0x74, 0x52, 0x0a, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x53, 0x65, 0x74, 0x42, + 0x48, 0x5a, 0x46, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x67, 0x6f, 0x6c, 0x61, 0x6e, 0x67, + 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x69, 0x6e, + 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x73, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x2f, 0x6d, 0x65, 0x73, + 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x70, 0x62, } var ( @@ -91,16 +144,18 @@ func file_messageset_messagesetpb_message_set_proto_rawDescGZIP() []byte { return file_messageset_messagesetpb_message_set_proto_rawDescData } -var file_messageset_messagesetpb_message_set_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_messageset_messagesetpb_message_set_proto_msgTypes = make([]protoimpl.MessageInfo, 2) var file_messageset_messagesetpb_message_set_proto_goTypes = []interface{}{ - (*MessageSet)(nil), // 0: goproto.proto.messageset.MessageSet + (*MessageSet)(nil), // 0: goproto.proto.messageset.MessageSet + (*MessageSetContainer)(nil), // 1: goproto.proto.messageset.MessageSetContainer } var file_messageset_messagesetpb_message_set_proto_depIdxs = []int32{ - 0, // [0:0] is the sub-list for method output_type - 0, // [0:0] is the sub-list for method input_type - 0, // [0:0] is the sub-list for extension type_name - 0, // [0:0] is the sub-list for extension extendee - 0, // [0:0] is the sub-list for field type_name + 0, // 0: goproto.proto.messageset.MessageSetContainer.message_set:type_name -> goproto.proto.messageset.MessageSet + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name } func init() { file_messageset_messagesetpb_message_set_proto_init() } @@ -123,6 +178,18 @@ func file_messageset_messagesetpb_message_set_proto_init() { return nil } } + file_messageset_messagesetpb_message_set_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*MessageSetContainer); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } type x struct{} out := protoimpl.TypeBuilder{ @@ -130,7 +197,7 @@ func file_messageset_messagesetpb_message_set_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_messageset_messagesetpb_message_set_proto_rawDesc, NumEnums: 0, - NumMessages: 1, + NumMessages: 2, NumExtensions: 0, NumServices: 0, }, diff --git a/internal/testprotos/messageset/messagesetpb/message_set.proto b/internal/testprotos/messageset/messagesetpb/message_set.proto index 08f7a4d5..4887977b 100644 --- a/internal/testprotos/messageset/messagesetpb/message_set.proto +++ b/internal/testprotos/messageset/messagesetpb/message_set.proto @@ -12,3 +12,7 @@ message MessageSet { option message_set_wire_format = true; extensions 4 to max; } + +message MessageSetContainer { + optional MessageSet message_set = 1; +} diff --git a/proto/messageset.go b/proto/messageset.go index 0d880974..e27e0b7d 100644 --- a/proto/messageset.go +++ b/proto/messageset.go @@ -20,7 +20,7 @@ func sizeMessageSet(m protoreflect.Message) (size int) { size += wire.SizeBytes(sizeMessage(v.Message())) return true }) - size += len(m.GetUnknown()) + size += messageset.SizeUnknown(m.GetUnknown()) return size } @@ -36,8 +36,7 @@ func marshalMessageSet(b []byte, m protoreflect.Message, o MarshalOptions) ([]by if err != nil { return b, err } - b = append(b, m.GetUnknown()...) - return b, nil + return messageset.AppendUnknown(b, m.GetUnknown()) } func marshalMessageSetField(b []byte, fd protoreflect.FieldDescriptor, value protoreflect.Value, o MarshalOptions) ([]byte, error) { @@ -56,48 +55,34 @@ func unmarshalMessageSet(b []byte, m protoreflect.Message, o UnmarshalOptions) e if !flags.ProtoLegacy { return errors.New("no support for message_set_wire_format") } - md := m.Descriptor() - for len(b) > 0 { - err := func() error { - num, v, n, err := messageset.ConsumeField(b) - if err != nil { - // Not a message set field. - // - // Return errUnknown to try to add this to the unknown fields. - // If the field is completely unparsable, we'll catch it - // when trying to skip the field. - return errUnknown - } - if !md.ExtensionRanges().Has(num) { - return errUnknown - } - xt, err := o.Resolver.FindExtensionByNumber(md.FullName(), num) - if err == protoregistry.NotFound { - return errUnknown - } - if err != nil { - return err - } - xd := xt.TypeDescriptor() - if err := o.unmarshalMessage(v, m.Mutable(xd).Message()); err != nil { - // Contents cannot be unmarshaled. - return err - } - b = b[n:] - return nil - }() + return messageset.Unmarshal(b, false, func(num wire.Number, v []byte) error { + err := unmarshalMessageSetField(m, num, v, o) if err == errUnknown { - _, _, n := wire.ConsumeField(b) - if n < 0 { - return wire.ParseError(n) - } - m.SetUnknown(append(m.GetUnknown(), b[:n]...)) - b = b[n:] - continue - } - if err != nil { - return err + unknown := m.GetUnknown() + unknown = wire.AppendTag(unknown, num, wire.BytesType) + unknown = wire.AppendBytes(unknown, v) + m.SetUnknown(unknown) + return nil } + return err + }) +} + +func unmarshalMessageSetField(m protoreflect.Message, num wire.Number, v []byte, o UnmarshalOptions) error { + md := m.Descriptor() + if !md.ExtensionRanges().Has(num) { + return errUnknown + } + xt, err := o.Resolver.FindExtensionByNumber(md.FullName(), num) + if err == protoregistry.NotFound { + return errUnknown + } + if err != nil { + return err + } + xd := xt.TypeDescriptor() + if err := o.unmarshalMessage(v, m.Mutable(xd).Message()); err != nil { + return err } return nil } diff --git a/proto/messageset_test.go b/proto/messageset_test.go index c1ef6c94..b7c4c727 100644 --- a/proto/messageset_test.go +++ b/proto/messageset_test.go @@ -22,48 +22,51 @@ func init() { var messageSetTestProtos = []testProto{ { desc: "MessageSet type_id before message content", - decodeTo: []proto.Message{build( - &messagesetpb.MessageSet{}, - extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{ + decodeTo: []proto.Message{func() proto.Message { + m := &messagesetpb.MessageSetContainer{MessageSet: &messagesetpb.MessageSet{}} + proto.SetExtension(m.MessageSet, msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{ Ext1Field1: proto.Int32(10), - }), - )}, + }) + return m + }()}, wire: pack.Message{ - pack.Tag{1, pack.StartGroupType}, - pack.Tag{2, pack.VarintType}, pack.Varint(1000), - pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{ - pack.Tag{1, pack.VarintType}, pack.Varint(10), + pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{ + pack.Tag{1, pack.StartGroupType}, + pack.Tag{2, pack.VarintType}, pack.Varint(1000), + pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{ + pack.Tag{1, pack.VarintType}, pack.Varint(10), + }), + pack.Tag{1, pack.EndGroupType}, }), - pack.Tag{1, pack.EndGroupType}, }.Marshal(), }, { desc: "MessageSet type_id after message content", - decodeTo: []proto.Message{build( - &messagesetpb.MessageSet{}, - extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{ + decodeTo: []proto.Message{func() proto.Message { + m := &messagesetpb.MessageSetContainer{MessageSet: &messagesetpb.MessageSet{}} + proto.SetExtension(m.MessageSet, msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{ Ext1Field1: proto.Int32(10), - }), - )}, + }) + return m + }()}, wire: pack.Message{ - pack.Tag{1, pack.StartGroupType}, - pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{ - pack.Tag{1, pack.VarintType}, pack.Varint(10), + pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{ + pack.Tag{1, pack.StartGroupType}, + pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{ + pack.Tag{1, pack.VarintType}, pack.Varint(10), + }), + pack.Tag{2, pack.VarintType}, pack.Varint(1000), + pack.Tag{1, pack.EndGroupType}, }), - pack.Tag{2, pack.VarintType}, pack.Varint(1000), - pack.Tag{1, pack.EndGroupType}, }.Marshal(), }, { - desc: "MessageSet preserves unknown field", + desc: "MessageSet does not preserve unknown field", decodeTo: []proto.Message{build( &messagesetpb.MessageSet{}, extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{ Ext1Field1: proto.Int32(10), }), - unknown(pack.Message{ - pack.Tag{4, pack.VarintType}, pack.Varint(30), - }.Marshal()), )}, wire: pack.Message{ pack.Tag{1, pack.StartGroupType}, @@ -81,12 +84,9 @@ var messageSetTestProtos = []testProto{ decodeTo: []proto.Message{build( &messagesetpb.MessageSet{}, unknown(pack.Message{ - pack.Tag{1, pack.StartGroupType}, - pack.Tag{2, pack.VarintType}, pack.Varint(1002), - pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{ + pack.Tag{1002, pack.BytesType}, pack.LengthPrefix(pack.Message{ pack.Tag{1, pack.VarintType}, pack.Varint(10), }), - pack.Tag{1, pack.EndGroupType}, }.Marshal()), )}, wire: pack.Message{ @@ -159,13 +159,6 @@ var messageSetTestProtos = []testProto{ desc: "MessageSet with missing type_id", decodeTo: []proto.Message{build( &messagesetpb.MessageSet{}, - unknown(pack.Message{ - pack.Tag{1, pack.StartGroupType}, - pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{ - pack.Tag{1, pack.VarintType}, pack.Varint(10), - }), - pack.Tag{1, pack.EndGroupType}, - }.Marshal()), )}, wire: pack.Message{ pack.Tag{1, pack.StartGroupType},