all: don't allow invalid field numbers when legacy support is on

The deprecated messageset format permits extension fields with numbers
greater than the usual maximum (1<<29-1). To support this, the
internal/encoding/wire package has disabled field number validation when
legacy support is enabled.

We shouldn't skip validating all field numbers for validity just because
we support larger ones in messagesets.

This change drops range validation from the wire package (other than
checking that numbers fit in an int32) and adds it to the wire
unmarshalers instead. This gives us validation where we care
about it (when unmarshaling a wire-format message) and allows for
best-effort handling of out-of-range numbers everywhere else.

Fixes golang/protobuf#996

Change-Id: I4e11b8a8aa177dd60e89723570af074a317c2451
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/210290
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
This commit is contained in:
Damien Neil 2019-12-06 15:36:03 -08:00
parent 5366f825ad
commit fe15dd4cdd
7 changed files with 142 additions and 29 deletions

View File

@ -13,7 +13,6 @@ import (
"math/bits"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/flags"
)
// Number represents the field number.
@ -490,18 +489,12 @@ func SizeGroup(num Number, n int) int {
}
// DecodeTag decodes the field Number and wire Type from its unified form.
// The Number is -1 if the decoded field number overflows.
// The Number is -1 if the decoded field number overflows int32.
// Other than overflow, this does not check for field number validity.
func DecodeTag(x uint64) (Number, Type) {
// NOTE: MessageSet allows for larger field numbers than normal.
if flags.ProtoLegacy {
if x>>3 > uint64(math.MaxInt32) {
return -1, 0
}
} else {
if x>>3 > uint64(MaxValidNumber) {
return -1, 0
}
if x>>3 > uint64(math.MaxInt32) {
return -1, 0
}
return Number(x >> 3), Type(x & 7)
}

View File

@ -85,6 +85,9 @@ func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag wire.Numbe
if n < 0 {
return 0, wire.ParseError(n)
}
if num > wire.MaxValidNumber {
return 0, errors.New("invalid field number")
}
b = b[n:]
var f *coderFieldInfo

View File

@ -117,6 +117,44 @@ func (x *Ext2) GetExt2Field1() int32 {
return 0
}
type ExtLargeNumber struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *ExtLargeNumber) Reset() {
*x = ExtLargeNumber{}
if protoimpl.UnsafeEnabled {
mi := &file_messageset_msetextpb_msetextpb_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *ExtLargeNumber) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ExtLargeNumber) ProtoMessage() {}
func (x *ExtLargeNumber) ProtoReflect() protoreflect.Message {
mi := &file_messageset_msetextpb_msetextpb_proto_msgTypes[2]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ExtLargeNumber.ProtoReflect.Descriptor instead.
func (*ExtLargeNumber) Descriptor() ([]byte, []int) {
return file_messageset_msetextpb_msetextpb_proto_rawDescGZIP(), []int{2}
}
var file_messageset_msetextpb_msetextpb_proto_extTypes = []protoimpl.ExtensionInfo{
{
ExtendedType: (*messagesetpb.MessageSet)(nil),
@ -134,6 +172,14 @@ var file_messageset_msetextpb_msetextpb_proto_extTypes = []protoimpl.ExtensionIn
Tag: "bytes,1001,opt,name=message_set_extension",
Filename: "messageset/msetextpb/msetextpb.proto",
},
{
ExtendedType: (*messagesetpb.MessageSet)(nil),
ExtensionType: (*ExtLargeNumber)(nil),
Field: 536870912,
Name: "goproto.proto.messageset.ExtLargeNumber",
Tag: "bytes,536870912,opt,name=message_set_extension",
Filename: "messageset/msetextpb/msetextpb.proto",
},
}
// Extension fields to messagesetpb.MessageSet.
@ -142,6 +188,8 @@ var (
E_Ext1_MessageSetExtension = &file_messageset_msetextpb_msetextpb_proto_extTypes[0]
// optional goproto.proto.messageset.Ext2 message_set_extension = 1001;
E_Ext2_MessageSetExtension = &file_messageset_msetextpb_msetextpb_proto_extTypes[1]
// optional goproto.proto.messageset.ExtLargeNumber message_set_extension = 536870912;
E_ExtLargeNumber_MessageSetExtension = &file_messageset_msetextpb_msetextpb_proto_extTypes[2] // 1<<29
)
var File_messageset_msetextpb_msetextpb_proto protoreflect.FileDescriptor
@ -176,11 +224,21 @@ var file_messageset_msetextpb_msetextpb_proto_rawDesc = []byte{
0x0b, 0x32, 0x1e, 0x2e, 0x67, 0x6f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x2e, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x2e, 0x45, 0x78, 0x74,
0x32, 0x52, 0x13, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x53, 0x65, 0x74, 0x45, 0x78, 0x74,
0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x42, 0x45, 0x5a, 0x43, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65,
0x2e, 0x67, 0x6f, 0x6c, 0x61, 0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x62, 0x75, 0x66, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x74, 0x65,
0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
0x73, 0x65, 0x74, 0x2f, 0x6d, 0x73, 0x65, 0x74, 0x65, 0x78, 0x74, 0x70, 0x62,
0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x99, 0x01, 0x0a, 0x0e, 0x45, 0x78, 0x74, 0x4c, 0x61,
0x72, 0x67, 0x65, 0x4e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x32, 0x86, 0x01, 0x0a, 0x15, 0x6d, 0x65,
0x73, 0x73, 0x61, 0x67, 0x65, 0x5f, 0x73, 0x65, 0x74, 0x5f, 0x65, 0x78, 0x74, 0x65, 0x6e, 0x73,
0x69, 0x6f, 0x6e, 0x12, 0x24, 0x2e, 0x67, 0x6f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x2e, 0x4d,
0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x53, 0x65, 0x74, 0x18, 0x80, 0x80, 0x80, 0x80, 0x02, 0x20,
0x01, 0x28, 0x0b, 0x32, 0x28, 0x2e, 0x67, 0x6f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x2e, 0x45,
0x78, 0x74, 0x4c, 0x61, 0x72, 0x67, 0x65, 0x4e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x52, 0x13, 0x6d,
0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x53, 0x65, 0x74, 0x45, 0x78, 0x74, 0x65, 0x6e, 0x73, 0x69,
0x6f, 0x6e, 0x42, 0x45, 0x5a, 0x43, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x67, 0x6f, 0x6c,
0x61, 0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66,
0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x73, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x2f,
0x6d, 0x73, 0x65, 0x74, 0x65, 0x78, 0x74, 0x70, 0x62,
}
var (
@ -195,21 +253,24 @@ func file_messageset_msetextpb_msetextpb_proto_rawDescGZIP() []byte {
return file_messageset_msetextpb_msetextpb_proto_rawDescData
}
var file_messageset_msetextpb_msetextpb_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_messageset_msetextpb_msetextpb_proto_msgTypes = make([]protoimpl.MessageInfo, 3)
var file_messageset_msetextpb_msetextpb_proto_goTypes = []interface{}{
(*Ext1)(nil), // 0: goproto.proto.messageset.Ext1
(*Ext2)(nil), // 1: goproto.proto.messageset.Ext2
(*messagesetpb.MessageSet)(nil), // 2: goproto.proto.messageset.MessageSet
(*ExtLargeNumber)(nil), // 2: goproto.proto.messageset.ExtLargeNumber
(*messagesetpb.MessageSet)(nil), // 3: goproto.proto.messageset.MessageSet
}
var file_messageset_msetextpb_msetextpb_proto_depIdxs = []int32{
2, // 0: goproto.proto.messageset.Ext1.message_set_extension:extendee -> goproto.proto.messageset.MessageSet
2, // 1: goproto.proto.messageset.Ext2.message_set_extension:extendee -> goproto.proto.messageset.MessageSet
0, // 2: goproto.proto.messageset.Ext1.message_set_extension:type_name -> goproto.proto.messageset.Ext1
1, // 3: goproto.proto.messageset.Ext2.message_set_extension:type_name -> goproto.proto.messageset.Ext2
4, // [4:4] is the sub-list for method output_type
4, // [4:4] is the sub-list for method input_type
2, // [2:4] is the sub-list for extension type_name
0, // [0:2] is the sub-list for extension extendee
3, // 0: goproto.proto.messageset.Ext1.message_set_extension:extendee -> goproto.proto.messageset.MessageSet
3, // 1: goproto.proto.messageset.Ext2.message_set_extension:extendee -> goproto.proto.messageset.MessageSet
3, // 2: goproto.proto.messageset.ExtLargeNumber.message_set_extension:extendee -> goproto.proto.messageset.MessageSet
0, // 3: goproto.proto.messageset.Ext1.message_set_extension:type_name -> goproto.proto.messageset.Ext1
1, // 4: goproto.proto.messageset.Ext2.message_set_extension:type_name -> goproto.proto.messageset.Ext2
2, // 5: goproto.proto.messageset.ExtLargeNumber.message_set_extension:type_name -> goproto.proto.messageset.ExtLargeNumber
6, // [6:6] is the sub-list for method output_type
6, // [6:6] is the sub-list for method input_type
3, // [3:6] is the sub-list for extension type_name
0, // [0:3] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
@ -243,6 +304,18 @@ func file_messageset_msetextpb_msetextpb_proto_init() {
return nil
}
}
file_messageset_msetextpb_msetextpb_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ExtLargeNumber); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
@ -250,8 +323,8 @@ func file_messageset_msetextpb_msetextpb_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_messageset_msetextpb_msetextpb_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 2,
NumMessages: 3,
NumExtensions: 3,
NumServices: 0,
},
GoTypes: file_messageset_msetextpb_msetextpb_proto_goTypes,

View File

@ -24,3 +24,9 @@ message Ext2 {
}
optional int32 ext2_field1 = 1;
}
message ExtLargeNumber {
extend MessageSet {
optional ExtLargeNumber message_set_extension = 536870912; // 1<<29
}
}

View File

@ -88,6 +88,9 @@ func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message)
if tagLen < 0 {
return wire.ParseError(tagLen)
}
if num > wire.MaxValidNumber {
return errors.New("invalid field number")
}
// Find the field descriptor for this field number.
fd := fields.ByNumber(num)

View File

@ -1762,14 +1762,12 @@ var invalidFieldNumberTestProtos = []struct {
pack.Tag{pack.MaxValidNumber, pack.VarintType}, pack.Varint(1006),
pack.Tag{pack.MaxValidNumber + 1, pack.VarintType}, pack.Varint(1007),
}.Marshal(),
allowed: flags.ProtoLegacy,
},
{
desc: "max+1",
wire: pack.Message{
pack.Tag{pack.MaxValidNumber + 1, pack.VarintType}, pack.Varint(1008),
}.Marshal(),
allowed: flags.ProtoLegacy,
},
}

View File

@ -6,6 +6,7 @@ package proto_test
import (
"google.golang.org/protobuf/internal/encoding/pack"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/proto"
@ -180,4 +181,40 @@ var messageSetTestProtos = []testProto{
pack.Tag{1, pack.EndGroupType},
}.Marshal(),
},
{
desc: "MessageSet with type id out of valid field number range",
decodeTo: []proto.Message{func() proto.Message {
m := &messagesetpb.MessageSetContainer{MessageSet: &messagesetpb.MessageSet{}}
proto.SetExtension(m.MessageSet, msetextpb.E_ExtLargeNumber_MessageSetExtension, &msetextpb.ExtLargeNumber{})
return m
}()},
wire: pack.Message{
pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
pack.Tag{1, pack.StartGroupType},
pack.Tag{2, pack.VarintType}, pack.Varint(wire.MaxValidNumber + 1),
pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
pack.Tag{1, pack.EndGroupType},
}),
}.Marshal(),
},
{
desc: "MessageSet with unknown type id out of valid field number range",
decodeTo: []proto.Message{func() proto.Message {
m := &messagesetpb.MessageSetContainer{MessageSet: &messagesetpb.MessageSet{}}
m.MessageSet.ProtoReflect().SetUnknown(
pack.Message{
pack.Tag{wire.MaxValidNumber + 2, pack.BytesType}, pack.LengthPrefix{},
}.Marshal(),
)
return m
}()},
wire: pack.Message{
pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
pack.Tag{1, pack.StartGroupType},
pack.Tag{2, pack.VarintType}, pack.Varint(wire.MaxValidNumber + 2),
pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
pack.Tag{1, pack.EndGroupType},
}),
}.Marshal(),
},
}