diff --git a/internal/encoding/messageset/messageset.go b/internal/encoding/messageset/messageset.go index 55526cf3..4bd0e4eb 100644 --- a/internal/encoding/messageset/messageset.go +++ b/internal/encoding/messageset/messageset.go @@ -6,6 +6,8 @@ package messageset import ( + "math" + "google.golang.org/protobuf/internal/encoding/wire" "google.golang.org/protobuf/internal/errors" pref "google.golang.org/protobuf/reflect/protoreflect" @@ -146,6 +148,9 @@ func ConsumeFieldValue(b []byte, wantLen bool) (typeid wire.Number, message []by return 0, nil, 0, wire.ParseError(n) } b = b[n:] + if v < 1 || v > math.MaxInt32 { + return 0, nil, 0, errors.New("invalid type_id in message set") + } typeid = wire.Number(v) case num == FieldMessage && wtyp == wire.BytesType: m, n := wire.ConsumeBytes(b) @@ -178,6 +183,13 @@ func ConsumeFieldValue(b []byte, wantLen bool) (typeid wire.Number, message []by } } b = b[n:] + default: + // We have no place to put it, so we just ignore unknown fields. + n := wire.ConsumeFieldValue(num, wtyp, b) + if n < 0 { + return 0, nil, 0, wire.ParseError(n) + } + b = b[n:] } } } diff --git a/proto/messageset_test.go b/proto/messageset_test.go index ff413421..1eedcc96 100644 --- a/proto/messageset_test.go +++ b/proto/messageset_test.go @@ -17,6 +17,7 @@ import ( func init() { if flags.ProtoLegacy { testValidMessages = append(testValidMessages, messageSetTestProtos...) + testInvalidMessages = append(testInvalidMessages, messageSetInvalidTestProtos...) } } @@ -217,6 +218,27 @@ var messageSetTestProtos = []testProto{ }), }.Marshal(), }, + { + desc: "MessageSet with unknown field", + decodeTo: []proto.Message{func() proto.Message { + m := &messagesetpb.MessageSetContainer{MessageSet: &messagesetpb.MessageSet{}} + proto.SetExtension(m.MessageSet, msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{ + Ext1Field1: proto.Int32(10), + }) + return m + }()}, + wire: pack.Message{ + pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{ + pack.Tag{1, pack.StartGroupType}, + pack.Tag{2, pack.VarintType}, pack.Varint(1000), + pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{ + pack.Tag{1, pack.VarintType}, pack.Varint(10), + }), + pack.Tag{4, pack.VarintType}, pack.Varint(0), + pack.Tag{1, pack.EndGroupType}, + }), + }.Marshal(), + }, { desc: "MessageSet with required field set", checkFastInit: true, @@ -257,3 +279,34 @@ var messageSetTestProtos = []testProto{ }.Marshal(), }, } + +var messageSetInvalidTestProtos = []testProto{ + { + desc: "MessageSet with type id 0", + decodeTo: []proto.Message{ + (*messagesetpb.MessageSetContainer)(nil), + }, + wire: pack.Message{ + pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{ + pack.Tag{1, pack.StartGroupType}, + pack.Tag{2, pack.VarintType}, pack.Uvarint(0), + pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{}), + pack.Tag{1, pack.EndGroupType}, + }), + }.Marshal(), + }, + { + desc: "MessageSet with type id overflowing int32", + decodeTo: []proto.Message{ + (*messagesetpb.MessageSetContainer)(nil), + }, + wire: pack.Message{ + pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{ + pack.Tag{1, pack.StartGroupType}, + pack.Tag{2, pack.VarintType}, pack.Uvarint(0x80000000), + pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{}), + pack.Tag{1, pack.EndGroupType}, + }), + }.Marshal(), + }, +}