internal/impl: validate messagesets

Change-Id: Id90bb386e7481bb9dee5a07889f308f1e1810825
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/218438
Reviewed-by: Joe Tsai <joetsai@google.com>
This commit is contained in:
Damien Neil 2020-02-07 10:06:53 -08:00
parent f9d4fdf054
commit 9afe9bb78b
3 changed files with 35 additions and 25 deletions

View File

@ -99,7 +99,7 @@ func Unmarshal(b []byte, wantLen bool, fn func(typeID wire.Number, value []byte)
b = b[n:]
continue
}
typeID, value, n, err := consumeFieldValue(b, wantLen)
typeID, value, n, err := ConsumeFieldValue(b, wantLen)
if err != nil {
return err
}
@ -114,13 +114,13 @@ func Unmarshal(b []byte, wantLen bool, fn func(typeID wire.Number, value []byte)
return nil
}
// consumeFieldValue parses b as a MessageSet item field value until and including
// ConsumeFieldValue parses b as a MessageSet item field value until and including
// the trailing end group marker. It assumes the start group tag has already been parsed.
// It returns the contents of the type_id and message subfields and the total
// item length.
//
// If wantLen is true, the returned message value includes the length prefix.
func consumeFieldValue(b []byte, wantLen bool) (typeid wire.Number, message []byte, n int, err error) {
func ConsumeFieldValue(b []byte, wantLen bool) (typeid wire.Number, message []byte, n int, err error) {
ilen := len(b)
for {
num, wtyp, n := wire.ConsumeTag(b)

View File

@ -11,6 +11,7 @@ import (
"reflect"
"unicode/utf8"
"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/internal/strs"
@ -93,6 +94,7 @@ const (
validationTypeFixed64
validationTypeBytes
validationTypeUTF8String
validationTypeMessageSetItem
)
func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
@ -237,11 +239,6 @@ func (mi *MessageInfo) validate(b []byte, groupTag wire.Number, opts unmarshalOp
State:
for len(states) > 0 {
st := &states[len(states)-1]
if st.mi != nil {
if flags.ProtoLegacy && st.mi.isMessageSet {
return out, ValidationUnknown
}
}
for len(b) > 0 {
// Parse the tag (field number and wire type).
var tag uint64
@ -274,8 +271,8 @@ State:
return out, ValidationInvalid
}
var vi validationInfo
switch st.typ {
case validationTypeMap:
switch {
case st.typ == validationTypeMap:
switch num {
case 1:
vi.typ = st.keyType
@ -284,6 +281,11 @@ State:
vi.mi = st.mi
vi.requiredBit = 1
}
case flags.ProtoLegacy && st.mi.isMessageSet:
switch num {
case messageset.FieldItem:
vi.typ = validationTypeMessageSetItem
}
default:
var f *coderFieldInfo
if int(num) < len(st.mi.denseCoderFields) {
@ -483,8 +485,8 @@ State:
}
b = b[8:]
case wire.StartGroupType:
switch vi.typ {
case validationTypeGroup:
switch {
case vi.typ == validationTypeGroup:
if vi.mi == nil {
return out, ValidationUnknown
}
@ -495,6 +497,27 @@ State:
endGroup: num,
})
continue State
case flags.ProtoLegacy && vi.typ == validationTypeMessageSetItem:
typeid, v, n, err := messageset.ConsumeFieldValue(b, false)
if err != nil {
return out, ValidationInvalid
}
xt, err := opts.Resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
switch {
case err == preg.NotFound:
b = b[n:]
case err != nil:
return out, ValidationUnknown
default:
xvi := getExtensionFieldInfo(xt).validation
states = append(states, validationState{
typ: xvi.typ,
mi: xvi.mi,
tail: b[n:],
})
b = v
continue State
}
default:
n := wire.ConsumeFieldValue(num, wtyp, b)
if n < 0 {

View File

@ -8,7 +8,6 @@ import (
"google.golang.org/protobuf/internal/encoding/pack"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/internal/impl"
"google.golang.org/protobuf/proto"
messagesetpb "google.golang.org/protobuf/internal/testprotos/messageset/messagesetpb"
@ -41,7 +40,6 @@ var messageSetTestProtos = []testProto{
pack.Tag{1, pack.EndGroupType},
}),
}.Marshal(),
validationStatus: impl.ValidationUnknown,
},
{
desc: "MessageSet type_id after message content",
@ -62,7 +60,6 @@ var messageSetTestProtos = []testProto{
pack.Tag{1, pack.EndGroupType},
}),
}.Marshal(),
validationStatus: impl.ValidationUnknown,
},
{
desc: "MessageSet does not preserve unknown field",
@ -82,7 +79,6 @@ var messageSetTestProtos = []testProto{
// Unknown field
pack.Tag{4, pack.VarintType}, pack.Varint(30),
}.Marshal(),
validationStatus: impl.ValidationUnknown,
},
{
desc: "MessageSet with unknown type_id",
@ -102,7 +98,6 @@ var messageSetTestProtos = []testProto{
}),
pack.Tag{1, pack.EndGroupType},
}.Marshal(),
validationStatus: impl.ValidationUnknown,
},
{
desc: "MessageSet merges repeated message fields in item",
@ -124,7 +119,6 @@ var messageSetTestProtos = []testProto{
}),
pack.Tag{1, pack.EndGroupType},
}.Marshal(),
validationStatus: impl.ValidationUnknown,
},
{
desc: "MessageSet merges message fields in repeated items",
@ -161,7 +155,6 @@ var messageSetTestProtos = []testProto{
}),
pack.Tag{1, pack.EndGroupType},
}.Marshal(),
validationStatus: impl.ValidationUnknown,
},
{
desc: "MessageSet with missing type_id",
@ -175,7 +168,6 @@ var messageSetTestProtos = []testProto{
}),
pack.Tag{1, pack.EndGroupType},
}.Marshal(),
validationStatus: impl.ValidationUnknown,
},
{
desc: "MessageSet with missing message",
@ -188,7 +180,6 @@ var messageSetTestProtos = []testProto{
pack.Tag{2, pack.VarintType}, pack.Varint(1000),
pack.Tag{1, pack.EndGroupType},
}.Marshal(),
validationStatus: impl.ValidationUnknown,
},
{
desc: "MessageSet with type id out of valid field number range",
@ -205,7 +196,6 @@ var messageSetTestProtos = []testProto{
pack.Tag{1, pack.EndGroupType},
}),
}.Marshal(),
validationStatus: impl.ValidationUnknown,
},
{
desc: "MessageSet with unknown type id out of valid field number range",
@ -226,7 +216,6 @@ var messageSetTestProtos = []testProto{
pack.Tag{1, pack.EndGroupType},
}),
}.Marshal(),
validationStatus: impl.ValidationUnknown,
},
{
desc: "MessageSet with required field set",
@ -248,7 +237,6 @@ var messageSetTestProtos = []testProto{
pack.Tag{1, pack.EndGroupType},
}),
}.Marshal(),
validationStatus: impl.ValidationUnknown,
},
{
desc: "MessageSet with required field unset",
@ -267,6 +255,5 @@ var messageSetTestProtos = []testProto{
pack.Tag{1, pack.EndGroupType},
}),
}.Marshal(),
validationStatus: impl.ValidationUnknown,
},
}