protobuf-go/proto/messageset_test.go
Damien Neil e1c61a307e internal/encoding/messageset: fix decoding of some invalid data
For historical reasons, MessageSet items are allowed to have field
numbers outside the usual valid range. Detect the case where the field
number cannot fit in an int32 and report an error. Also check for
a field number of 0 (always invalid).

Handle the case where a MessageSet item includes an unknown field.
We have no place to put the contents of the field, so drop it. This is,
I believe, consistent with other implementations.

Change-Id: Ic403427e1c276cbfa232ca577e7a799cce706bc7
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/221939
Reviewed-by: Herbie Ong <herbie@google.com>
2020-03-04 02:12:26 +00:00

313 lines
10 KiB
Go

// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proto_test
import (
"google.golang.org/protobuf/internal/encoding/pack"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/proto"
messagesetpb "google.golang.org/protobuf/internal/testprotos/messageset/messagesetpb"
msetextpb "google.golang.org/protobuf/internal/testprotos/messageset/msetextpb"
)
// init adds the MessageSet cases to the shared valid and invalid test
// tables. The cases are registered only when flags.ProtoLegacy is set;
// otherwise the tables are left untouched.
func init() {
	if !flags.ProtoLegacy {
		return
	}
	testValidMessages = append(testValidMessages, messageSetTestProtos...)
	testInvalidMessages = append(testInvalidMessages, messageSetInvalidTestProtos...)
}
// messageSetTestProtos are MessageSet wire encodings that must decode
// successfully into the messages listed in each case's decodeTo.
//
// In every case a MessageSet item is encoded as group field 1, with the
// item's type_id in varint field 2 and the item's payload in
// length-delimited field 3. These cases are registered by init only when
// flags.ProtoLegacy is set.
var messageSetTestProtos = []testProto{
	// The item's type_id appears before its message payload.
	{
		desc: "MessageSet type_id before message content",
		decodeTo: []proto.Message{func() proto.Message {
			m := &messagesetpb.MessageSetContainer{MessageSet: &messagesetpb.MessageSet{}}
			proto.SetExtension(m.MessageSet, msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
				Ext1Field1: proto.Int32(10),
			})
			return m
		}()},
		wire: pack.Message{
			pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
				pack.Tag{1, pack.StartGroupType},
				pack.Tag{2, pack.VarintType}, pack.Varint(1000),
				pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
					pack.Tag{1, pack.VarintType}, pack.Varint(10),
				}),
				pack.Tag{1, pack.EndGroupType},
			}),
		}.Marshal(),
	},
	// The item's message payload appears before its type_id; the decoded
	// result must be the same as in the previous case.
	{
		desc: "MessageSet type_id after message content",
		decodeTo: []proto.Message{func() proto.Message {
			m := &messagesetpb.MessageSetContainer{MessageSet: &messagesetpb.MessageSet{}}
			proto.SetExtension(m.MessageSet, msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
				Ext1Field1: proto.Int32(10),
			})
			return m
		}()},
		wire: pack.Message{
			pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
				pack.Tag{1, pack.StartGroupType},
				pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
					pack.Tag{1, pack.VarintType}, pack.Varint(10),
				}),
				pack.Tag{2, pack.VarintType}, pack.Varint(1000),
				pack.Tag{1, pack.EndGroupType},
			}),
		}.Marshal(),
	},
	// A non-item field (field 4) at the top level of the MessageSet is not
	// retained in the decoded message.
	{
		desc: "MessageSet does not preserve unknown field",
		decodeTo: []proto.Message{build(
			&messagesetpb.MessageSet{},
			extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
				Ext1Field1: proto.Int32(10),
			}),
		)},
		wire: pack.Message{
			pack.Tag{1, pack.StartGroupType},
			pack.Tag{2, pack.VarintType}, pack.Varint(1000),
			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
				pack.Tag{1, pack.VarintType}, pack.Varint(10),
			}),
			pack.Tag{1, pack.EndGroupType},
			// Unknown field
			pack.Tag{4, pack.VarintType}, pack.Varint(30),
		}.Marshal(),
	},
	// An item whose type_id matches no known extension is kept as unknown
	// bytes, re-encoded as a length-delimited field numbered by the type_id.
	{
		desc: "MessageSet with unknown type_id",
		decodeTo: []proto.Message{build(
			&messagesetpb.MessageSet{},
			unknown(pack.Message{
				pack.Tag{999, pack.BytesType}, pack.LengthPrefix(pack.Message{
					pack.Tag{1, pack.VarintType}, pack.Varint(10),
				}),
			}.Marshal()),
		)},
		wire: pack.Message{
			pack.Tag{1, pack.StartGroupType},
			pack.Tag{2, pack.VarintType}, pack.Varint(999),
			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
				pack.Tag{1, pack.VarintType}, pack.Varint(10),
			}),
			pack.Tag{1, pack.EndGroupType},
		}.Marshal(),
	},
	// Two message payloads inside a single item are merged into one value.
	{
		desc: "MessageSet merges repeated message fields in item",
		decodeTo: []proto.Message{build(
			&messagesetpb.MessageSet{},
			extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
				Ext1Field1: proto.Int32(10),
				Ext1Field2: proto.Int32(20),
			}),
		)},
		wire: pack.Message{
			pack.Tag{1, pack.StartGroupType},
			pack.Tag{2, pack.VarintType}, pack.Varint(1000),
			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
				pack.Tag{1, pack.VarintType}, pack.Varint(10),
			}),
			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
				pack.Tag{2, pack.VarintType}, pack.Varint(20),
			}),
			pack.Tag{1, pack.EndGroupType},
		}.Marshal(),
	},
	// Separate items with the same type_id are merged, even when an item
	// for a different type_id lies between them.
	{
		desc: "MessageSet merges message fields in repeated items",
		decodeTo: []proto.Message{build(
			&messagesetpb.MessageSet{},
			extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
				Ext1Field1: proto.Int32(10),
				Ext1Field2: proto.Int32(20),
			}),
			extend(msetextpb.E_Ext2_MessageSetExtension, &msetextpb.Ext2{
				Ext2Field1: proto.Int32(30),
			}),
		)},
		wire: pack.Message{
			// Ext1, field1
			pack.Tag{1, pack.StartGroupType},
			pack.Tag{2, pack.VarintType}, pack.Varint(1000),
			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
				pack.Tag{1, pack.VarintType}, pack.Varint(10),
			}),
			pack.Tag{1, pack.EndGroupType},
			// Ext2, field1
			pack.Tag{1, pack.StartGroupType},
			pack.Tag{2, pack.VarintType}, pack.Varint(1001),
			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
				pack.Tag{1, pack.VarintType}, pack.Varint(30),
			}),
			pack.Tag{1, pack.EndGroupType},
			// Ext1, field2 (type_id 1000 again, merged into the first item)
			pack.Tag{1, pack.StartGroupType},
			pack.Tag{2, pack.VarintType}, pack.Varint(1000),
			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
				pack.Tag{2, pack.VarintType}, pack.Varint(20),
			}),
			pack.Tag{1, pack.EndGroupType},
		}.Marshal(),
	},
	// An item with no type_id cannot be attributed to any extension and
	// decodes to nothing.
	{
		desc: "MessageSet with missing type_id",
		decodeTo: []proto.Message{build(
			&messagesetpb.MessageSet{},
		)},
		wire: pack.Message{
			pack.Tag{1, pack.StartGroupType},
			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
				pack.Tag{1, pack.VarintType}, pack.Varint(10),
			}),
			pack.Tag{1, pack.EndGroupType},
		}.Marshal(),
	},
	// An item with a type_id but no payload sets the extension to an
	// empty message.
	{
		desc: "MessageSet with missing message",
		decodeTo: []proto.Message{build(
			&messagesetpb.MessageSet{},
			extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{}),
		)},
		wire: pack.Message{
			pack.Tag{1, pack.StartGroupType},
			pack.Tag{2, pack.VarintType}, pack.Varint(1000),
			pack.Tag{1, pack.EndGroupType},
		}.Marshal(),
	},
	// A type_id just above wire.MaxValidNumber is accepted and is expected
	// to resolve to the ExtLargeNumber extension.
	{
		desc: "MessageSet with type id out of valid field number range",
		decodeTo: []proto.Message{func() proto.Message {
			m := &messagesetpb.MessageSetContainer{MessageSet: &messagesetpb.MessageSet{}}
			proto.SetExtension(m.MessageSet, msetextpb.E_ExtLargeNumber_MessageSetExtension, &msetextpb.ExtLargeNumber{})
			return m
		}()},
		wire: pack.Message{
			pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
				pack.Tag{1, pack.StartGroupType},
				pack.Tag{2, pack.VarintType}, pack.Varint(wire.MaxValidNumber + 1),
				pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
				pack.Tag{1, pack.EndGroupType},
			}),
		}.Marshal(),
	},
	// An unrecognized type_id above wire.MaxValidNumber is still kept as
	// unknown bytes under that (otherwise invalid) field number.
	{
		desc: "MessageSet with unknown type id out of valid field number range",
		decodeTo: []proto.Message{func() proto.Message {
			m := &messagesetpb.MessageSetContainer{MessageSet: &messagesetpb.MessageSet{}}
			m.MessageSet.ProtoReflect().SetUnknown(
				pack.Message{
					pack.Tag{wire.MaxValidNumber + 2, pack.BytesType}, pack.LengthPrefix{},
				}.Marshal(),
			)
			return m
		}()},
		wire: pack.Message{
			pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
				pack.Tag{1, pack.StartGroupType},
				pack.Tag{2, pack.VarintType}, pack.Varint(wire.MaxValidNumber + 2),
				pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
				pack.Tag{1, pack.EndGroupType},
			}),
		}.Marshal(),
	},
	// An unexpected field (field 4) inside an item has nowhere to be stored
	// and is dropped; the item's extension still decodes normally.
	{
		desc: "MessageSet with unknown field",
		decodeTo: []proto.Message{func() proto.Message {
			m := &messagesetpb.MessageSetContainer{MessageSet: &messagesetpb.MessageSet{}}
			proto.SetExtension(m.MessageSet, msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
				Ext1Field1: proto.Int32(10),
			})
			return m
		}()},
		wire: pack.Message{
			pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
				pack.Tag{1, pack.StartGroupType},
				pack.Tag{2, pack.VarintType}, pack.Varint(1000),
				pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
					pack.Tag{1, pack.VarintType}, pack.Varint(10),
				}),
				pack.Tag{4, pack.VarintType}, pack.Varint(0),
				pack.Tag{1, pack.EndGroupType},
			}),
		}.Marshal(),
	},
	// An extension message whose required field is present.
	{
		desc:          "MessageSet with required field set",
		checkFastInit: true,
		decodeTo: []proto.Message{func() proto.Message {
			m := &messagesetpb.MessageSetContainer{MessageSet: &messagesetpb.MessageSet{}}
			proto.SetExtension(m.MessageSet, msetextpb.E_ExtRequired_MessageSetExtension, &msetextpb.ExtRequired{
				RequiredField1: proto.Int32(1),
			})
			return m
		}()},
		wire: pack.Message{
			pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
				pack.Tag{1, pack.StartGroupType},
				pack.Tag{2, pack.VarintType}, pack.Varint(1002),
				pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
					pack.Tag{1, pack.VarintType}, pack.Varint(1),
				}),
				pack.Tag{1, pack.EndGroupType},
			}),
		}.Marshal(),
	},
	// An extension message whose required field is absent; marked partial.
	{
		desc:          "MessageSet with required field unset",
		checkFastInit: true,
		partial:       true,
		decodeTo: []proto.Message{func() proto.Message {
			m := &messagesetpb.MessageSetContainer{MessageSet: &messagesetpb.MessageSet{}}
			proto.SetExtension(m.MessageSet, msetextpb.E_ExtRequired_MessageSetExtension, &msetextpb.ExtRequired{})
			return m
		}()},
		wire: pack.Message{
			pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
				pack.Tag{1, pack.StartGroupType},
				pack.Tag{2, pack.VarintType}, pack.Varint(1002),
				pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
				pack.Tag{1, pack.EndGroupType},
			}),
		}.Marshal(),
	},
}
// messageSetInvalidTestProtos are MessageSet encodings that must be
// rejected when decoded. Each contains a single item whose type_id cannot
// serve as a field number: either zero, or too large to fit in an int32.
var messageSetInvalidTestProtos = func() []testProto {
	// badTypeID builds a case where a MessageSetContainer's MessageSet holds
	// exactly one item with the given type_id and an empty message payload.
	badTypeID := func(desc string, typeID uint64) testProto {
		return testProto{
			desc: desc,
			decodeTo: []proto.Message{
				(*messagesetpb.MessageSetContainer)(nil),
			},
			wire: pack.Message{
				pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
					pack.Tag{1, pack.StartGroupType},
					pack.Tag{2, pack.VarintType}, pack.Uvarint(typeID),
					pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
					pack.Tag{1, pack.EndGroupType},
				}),
			}.Marshal(),
		}
	}
	return []testProto{
		badTypeID("MessageSet with type id 0", 0),
		badTypeID("MessageSet with type id overflowing int32", 0x80000000),
	}
}()