Damien Neil 5698f90d86 internal/impl: fix messageset validation bug
The validator was not ensuring that the MessageInfo for messageset
items was initialized. Fixed.

One or more of the existing messageset tests fail when run in isolation
due to this bug, but running all of them in sequence passes due to an
earlier test initializing the MessageInfo first.

Change-Id: Ifa7bd525c6d1cef9d1bed7bf761b0380907e35ee
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/221023
Reviewed-by: Joe Tsai <joetsai@google.com>
2020-02-26 18:56:34 +00:00

576 lines
15 KiB
Go

// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"fmt"
"math"
"math/bits"
"reflect"
"unicode/utf8"
"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/internal/strs"
pref "google.golang.org/protobuf/reflect/protoreflect"
preg "google.golang.org/protobuf/reflect/protoregistry"
piface "google.golang.org/protobuf/runtime/protoiface"
)
// ValidationStatus reports the outcome of checking whether a buffer holds a
// valid wire-format encoding of a message.
type ValidationStatus int

const (
	// ValidationUnknown means the validator could not reach a verdict:
	// unmarshaling the message might succeed or it might fail.
	//
	// This status is only produced when an aberrant message type appears
	// somewhere in the message, or when the extension resolver fails.
	ValidationUnknown ValidationStatus = iota + 1

	// ValidationInvalid means that unmarshaling the message will fail.
	ValidationInvalid

	// ValidationValid means that unmarshaling the message will succeed.
	ValidationValid
)

// String returns the name of the status constant, or "ValidationStatus(N)"
// for a value N that does not correspond to one of the defined constants.
func (v ValidationStatus) String() string {
	// Index 0 is intentionally left empty: the defined statuses start at 1.
	names := [...]string{
		ValidationUnknown: "ValidationUnknown",
		ValidationInvalid: "ValidationInvalid",
		ValidationValid:   "ValidationValid",
	}
	if v < ValidationUnknown || v > ValidationValid {
		return fmt.Sprintf("ValidationStatus(%d)", int(v))
	}
	return names[v]
}
// Validate reports whether in.Buf holds a valid wire encoding of the message
// type mt.
//
// If mt is not a *MessageInfo the validator cannot inspect it and the result
// is ValidationUnknown. When in.Resolver is nil, preg.GlobalTypes is used to
// resolve extensions. The returned output has the UnmarshalInitialized flag
// set when validation found the message to be initialized.
//
// This function is exposed for testing.
func Validate(mt pref.MessageType, in piface.UnmarshalInput) (out piface.UnmarshalOutput, _ ValidationStatus) {
	info, isMessageInfo := mt.(*MessageInfo)
	if !isMessageInfo {
		return out, ValidationUnknown
	}

	// Fall back to the global registry when the caller supplied no resolver.
	resolver := in.Resolver
	if resolver == nil {
		resolver = preg.GlobalTypes
	}
	opts := unmarshalOptions{
		flags:    in.Flags,
		resolver: resolver,
	}

	res, status := info.validate(in.Buf, 0, opts)
	if res.initialized {
		out.Flags |= piface.UnmarshalInitialized
	}
	return out, status
}
// validationInfo is the per-field information the validator consults when it
// encounters a field on the wire.
type validationInfo struct {
	// mi is the MessageInfo of the field's message type (for message and
	// group fields, or for the message value of a map field). It is nil
	// when the field has no message type or none could be determined.
	mi *MessageInfo

	// typ selects how the field's contents are checked.
	typ validationType

	// keyType and valType describe the key and value of a map field.
	// They are left at their zero value for non-map fields.
	keyType validationType
	valType validationType

	// requiredBit identifies a required field within its message.
	//
	// It is 0 for fields that are not required. For a required field it has
	// exactly the nth bit set, where n is a unique index in the range
	// [0, MessageInfo.numRequiredFields).
	//
	// When a message has more than 64 required fields, requiredBit is 0.
	requiredBit uint64
}
// validationType classifies a field by the kind of check the validator
// applies to its wire data.
type validationType uint8

const (
	// validationTypeOther is the zero value: no content-specific check.
	validationTypeOther = validationType(iota)
	// validationTypeMessage is a length-delimited embedded message.
	validationTypeMessage
	// validationTypeGroup is a group delimited by start/end group tags.
	validationTypeGroup
	// validationTypeMap is a map field whose entries carry a key and a value.
	validationTypeMap
	// validationTypeRepeatedVarint is a repeated field of varint elements.
	validationTypeRepeatedVarint
	// validationTypeRepeatedFixed32 is a repeated field of 4-byte elements.
	validationTypeRepeatedFixed32
	// validationTypeRepeatedFixed64 is a repeated field of 8-byte elements.
	validationTypeRepeatedFixed64
	// validationTypeVarint is a singular varint field.
	validationTypeVarint
	// validationTypeFixed32 is a singular 4-byte field.
	validationTypeFixed32
	// validationTypeFixed64 is a singular 8-byte field.
	validationTypeFixed64
	// validationTypeBytes is a length-delimited field with unchecked contents.
	validationTypeBytes
	// validationTypeUTF8String is a length-delimited field that must be valid UTF-8.
	validationTypeUTF8String
	// validationTypeMessageSetItem is the item group of a MessageSet.
	validationTypeMessageSetItem
)
// newFieldValidationInfo builds the validationInfo for field fd, whose Go
// type is ft, of the message described by mi.
//
// Fields that are members of a oneof are handled here, using the oneof
// wrapper types recorded in si; all other fields are delegated to
// newValidationInfo. If fd is a required field, mi.numRequiredFields is
// incremented and the field is assigned its own bit in requiredBit.
func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
	var vi validationInfo
	if fd.ContainingOneof() == nil {
		vi = newValidationInfo(fd, ft)
	} else {
		switch kind := fd.Kind(); kind {
		case pref.MessageKind, pref.GroupKind:
			vi.typ = validationTypeMessage
			if kind == pref.GroupKind {
				vi.typ = validationTypeGroup
			}
			// The message type is taken from the first field of the
			// oneof wrapper struct registered for this field number.
			if wrapper, found := si.oneofWrappersByNumber[fd.Number()]; found {
				vi.mi = getMessageInfo(wrapper.Field(0).Type)
			}
		case pref.StringKind:
			if strs.EnforceUTF8(fd) {
				vi.typ = validationTypeUTF8String
			}
		}
	}

	if fd.Cardinality() != pref.Required {
		return vi
	}
	// Avoid overflow of the counter. Required fields are tracked with a
	// 64-bit mask, and a message with more than 64 required fields is always
	// reported as potentially uninitialized, so an exact count of required
	// fields beyond 64 is not needed.
	if mi.numRequiredFields < math.MaxUint8 {
		mi.numRequiredFields++
		vi.requiredBit = 1 << (mi.numRequiredFields - 1)
	}
	return vi
}
// newValidationInfo builds the validationInfo for field fd, whose Go type is
// ft, based on whether the field is a list, a map, or a singular field.
//
// It does not assign requiredBit.
func newValidationInfo(fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
	var vi validationInfo

	// Repeated fields.
	if fd.IsList() {
		switch kind := fd.Kind(); kind {
		case pref.MessageKind, pref.GroupKind:
			if kind == pref.MessageKind {
				vi.typ = validationTypeMessage
			} else {
				vi.typ = validationTypeGroup
			}
			// The element type of the slice is the message type.
			if ft.Kind() == reflect.Slice {
				vi.mi = getMessageInfo(ft.Elem())
			}
		case pref.StringKind:
			if strs.EnforceUTF8(fd) {
				vi.typ = validationTypeUTF8String
			} else {
				vi.typ = validationTypeBytes
			}
		default:
			switch wireTypes[kind] {
			case wire.VarintType:
				vi.typ = validationTypeRepeatedVarint
			case wire.Fixed32Type:
				vi.typ = validationTypeRepeatedFixed32
			case wire.Fixed64Type:
				vi.typ = validationTypeRepeatedFixed64
			}
		}
		return vi
	}

	// Map fields: record how the key and the value are to be checked.
	if fd.IsMap() {
		vi.typ = validationTypeMap
		if fd.MapKey().Kind() == pref.StringKind && strs.EnforceUTF8(fd) {
			vi.keyType = validationTypeUTF8String
		}
		switch fd.MapValue().Kind() {
		case pref.MessageKind:
			vi.valType = validationTypeMessage
			// The element type of the Go map is the message type.
			if ft.Kind() == reflect.Map {
				vi.mi = getMessageInfo(ft.Elem())
			}
		case pref.StringKind:
			if strs.EnforceUTF8(fd) {
				vi.valType = validationTypeUTF8String
			}
		}
		return vi
	}

	// Singular fields.
	switch kind := fd.Kind(); kind {
	case pref.MessageKind:
		vi.typ = validationTypeMessage
		// Weak fields are left without a MessageInfo here.
		if !fd.IsWeak() {
			vi.mi = getMessageInfo(ft)
		}
	case pref.GroupKind:
		vi.typ = validationTypeGroup
		vi.mi = getMessageInfo(ft)
	case pref.StringKind:
		if strs.EnforceUTF8(fd) {
			vi.typ = validationTypeUTF8String
		} else {
			vi.typ = validationTypeBytes
		}
	default:
		switch wireTypes[kind] {
		case wire.VarintType:
			vi.typ = validationTypeVarint
		case wire.Fixed32Type:
			vi.typ = validationTypeFixed32
		case wire.Fixed64Type:
			vi.typ = validationTypeFixed64
		case wire.BytesType:
			vi.typ = validationTypeBytes
		}
	}
	return vi
}
// validate checks whether b is a valid wire-format encoding of the message
// described by mi, without constructing the message.
//
// If groupTag is greater than zero, b is treated as the contents of a group
// with that field number: validation stops after the matching end-group tag,
// and running out of input before that tag is invalid.
//
// Nested messages, groups, map entries and MessageSet items are handled
// iteratively with an explicit stack of states rather than by recursion.
//
// On ValidationValid, out.n is the number of bytes of b that were consumed
// and out.initialized reports whether every message visited had all of its
// tracked required fields present. For any other status, out is returned as
// its zero value.
//
// ValidationUnknown is returned when a message, group, or MessageSet item
// has no MessageInfo available, or when a type lookup fails with an error
// other than preg.NotFound.
func (mi *MessageInfo) validate(b []byte, groupTag wire.Number, opts unmarshalOptions) (out unmarshalOutput, result ValidationStatus) {
	mi.init()

	// validationState is one entry of the stack of messages being validated.
	type validationState struct {
		typ              validationType
		keyType, valType validationType
		endGroup         wire.Number // field number of the expected end-group tag; 0 if not a group
		mi               *MessageInfo
		tail             []byte // input to resume with once this state is finished
		requiredMask     uint64 // required fields seen so far
	}

	// Pre-allocate some slots to avoid repeated slice reallocation.
	states := make([]validationState, 0, 16)
	states = append(states, validationState{
		typ: validationTypeMessage,
		mi:  mi,
	})
	if groupTag > 0 {
		states[0].typ = validationTypeGroup
		states[0].endGroup = groupTag
	}
	initialized := true
	start := len(b)
State:
	for len(states) > 0 {
		st := &states[len(states)-1]
		for len(b) > 0 {
			// Parse the tag (field number and wire type).
			// One- and two-byte varints are decoded inline.
			var tag uint64
			if b[0] < 0x80 {
				tag = uint64(b[0])
				b = b[1:]
			} else if len(b) >= 2 && b[1] < 128 {
				tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
				b = b[2:]
			} else {
				var n int
				tag, n = wire.ConsumeVarint(b)
				if n < 0 {
					return out, ValidationInvalid
				}
				b = b[n:]
			}
			var num wire.Number
			if n := tag >> 3; n < uint64(wire.MinValidNumber) || n > uint64(wire.MaxValidNumber) {
				return out, ValidationInvalid
			} else {
				num = wire.Number(n)
			}
			wtyp := wire.Type(tag & 7)

			if wtyp == wire.EndGroupType {
				// An end-group tag is only valid when it closes the
				// group currently being validated.
				if st.endGroup == num {
					goto PopState
				}
				return out, ValidationInvalid
			}

			// Determine how this field is to be validated.
			var vi validationInfo
			switch {
			case st.typ == validationTypeMap:
				// Map entry: field 1 is the key, field 2 is the value.
				switch num {
				case 1:
					vi.typ = st.keyType
				case 2:
					vi.typ = st.valType
					vi.mi = st.mi
					// Track presence of the value; see PopState.
					vi.requiredBit = 1
				}
			case flags.ProtoLegacy && st.mi.isMessageSet:
				switch num {
				case messageset.FieldItem:
					vi.typ = validationTypeMessageSetItem
				}
			default:
				var f *coderFieldInfo
				if int(num) < len(st.mi.denseCoderFields) {
					f = st.mi.denseCoderFields[num]
				} else {
					f = st.mi.coderFields[num]
				}
				if f != nil {
					vi = f.validation
					if vi.typ == validationTypeMessage && vi.mi == nil {
						// Probable weak field.
						//
						// TODO: Consider storing the results of this lookup somewhere
						// rather than recomputing it on every validation.
						fd := st.mi.Desc.Fields().ByNumber(num)
						if fd == nil || !fd.IsWeak() {
							break
						}
						messageName := fd.Message().FullName()
						messageType, err := preg.GlobalTypes.FindMessageByName(messageName)
						switch err {
						case nil:
							vi.mi, _ = messageType.(*MessageInfo)
						case preg.NotFound:
							// The weak message type is not registered:
							// treat the field as opaque bytes.
							vi.typ = validationTypeBytes
						default:
							return out, ValidationUnknown
						}
					}
					break
				}
				// Possible extension field.
				//
				// TODO: We should return ValidationUnknown when:
				//   1. The resolver is not frozen. (More extensions may be added to it.)
				//   2. The resolver returns preg.NotFound.
				// In this case, a type added to the resolver in the future could cause
				// unmarshaling to begin failing. Supporting this requires some way to
				// determine if the resolver is frozen.
				xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
				if err != nil && err != preg.NotFound {
					return out, ValidationUnknown
				}
				if err == nil {
					vi = getExtensionFieldInfo(xt).validation
				}
			}

			if vi.requiredBit != 0 {
				// Check that the field has a compatible wire type.
				// We only need to consider non-repeated field types,
				// since repeated fields (and maps) can never be required.
				ok := false
				switch vi.typ {
				case validationTypeVarint:
					ok = wtyp == wire.VarintType
				case validationTypeFixed32:
					ok = wtyp == wire.Fixed32Type
				case validationTypeFixed64:
					ok = wtyp == wire.Fixed64Type
				case validationTypeBytes, validationTypeUTF8String, validationTypeMessage:
					ok = wtyp == wire.BytesType
				case validationTypeGroup:
					ok = wtyp == wire.StartGroupType
				}
				if ok {
					st.requiredMask |= vi.requiredBit
				}
			}

			// Consume the field value according to its wire type.
			switch wtyp {
			case wire.VarintType:
				// A varint is at most 10 bytes long, and the tenth byte
				// may only contribute the single remaining bit of a
				// 64-bit value, so it must be less than 2.
				if len(b) >= 10 {
					switch {
					case b[0] < 0x80:
						b = b[1:]
					case b[1] < 0x80:
						b = b[2:]
					case b[2] < 0x80:
						b = b[3:]
					case b[3] < 0x80:
						b = b[4:]
					case b[4] < 0x80:
						b = b[5:]
					case b[5] < 0x80:
						b = b[6:]
					case b[6] < 0x80:
						b = b[7:]
					case b[7] < 0x80:
						b = b[8:]
					case b[8] < 0x80:
						b = b[9:]
					case b[9] < 2:
						b = b[10:]
					default:
						return out, ValidationInvalid
					}
				} else {
					// Fewer than 10 bytes remain: bounds-check each byte.
					switch {
					case len(b) > 0 && b[0] < 0x80:
						b = b[1:]
					case len(b) > 1 && b[1] < 0x80:
						b = b[2:]
					case len(b) > 2 && b[2] < 0x80:
						b = b[3:]
					case len(b) > 3 && b[3] < 0x80:
						b = b[4:]
					case len(b) > 4 && b[4] < 0x80:
						b = b[5:]
					case len(b) > 5 && b[5] < 0x80:
						b = b[6:]
					case len(b) > 6 && b[6] < 0x80:
						b = b[7:]
					case len(b) > 7 && b[7] < 0x80:
						b = b[8:]
					case len(b) > 8 && b[8] < 0x80:
						b = b[9:]
					case len(b) > 9 && b[9] < 2:
						b = b[10:]
					default:
						return out, ValidationInvalid
					}
				}
				continue State
			case wire.BytesType:
				// Parse the length prefix, decoding one- and two-byte
				// varints inline.
				var size uint64
				if len(b) >= 1 && b[0] < 0x80 {
					size = uint64(b[0])
					b = b[1:]
				} else if len(b) >= 2 && b[1] < 128 {
					size = uint64(b[0]&0x7f) + uint64(b[1])<<7
					b = b[2:]
				} else {
					var n int
					size, n = wire.ConsumeVarint(b)
					if n < 0 {
						return out, ValidationInvalid
					}
					b = b[n:]
				}
				if size > uint64(len(b)) {
					return out, ValidationInvalid
				}
				v := b[:size]
				b = b[size:]
				switch vi.typ {
				case validationTypeMessage:
					if vi.mi == nil {
						return out, ValidationUnknown
					}
					// The MessageInfo is initialized by the shared code below.
					fallthrough
				case validationTypeMap:
					if vi.mi != nil {
						vi.mi.init()
					}
					// Descend into the nested message or map entry and
					// remember where to resume in the parent.
					states = append(states, validationState{
						typ:     vi.typ,
						keyType: vi.keyType,
						valType: vi.valType,
						mi:      vi.mi,
						tail:    b,
					})
					b = v
					continue State
				case validationTypeRepeatedVarint:
					// Packed field.
					for len(v) > 0 {
						_, n := wire.ConsumeVarint(v)
						if n < 0 {
							return out, ValidationInvalid
						}
						v = v[n:]
					}
				case validationTypeRepeatedFixed32:
					// Packed field.
					if len(v)%4 != 0 {
						return out, ValidationInvalid
					}
				case validationTypeRepeatedFixed64:
					// Packed field.
					if len(v)%8 != 0 {
						return out, ValidationInvalid
					}
				case validationTypeUTF8String:
					if !utf8.Valid(v) {
						return out, ValidationInvalid
					}
				}
			case wire.Fixed32Type:
				if len(b) < 4 {
					return out, ValidationInvalid
				}
				b = b[4:]
			case wire.Fixed64Type:
				if len(b) < 8 {
					return out, ValidationInvalid
				}
				b = b[8:]
			case wire.StartGroupType:
				switch {
				case vi.typ == validationTypeGroup:
					if vi.mi == nil {
						return out, ValidationUnknown
					}
					vi.mi.init()
					// The group's contents follow inline in b and are
					// terminated by an end-group tag for num.
					states = append(states, validationState{
						typ:      validationTypeGroup,
						mi:       vi.mi,
						endGroup: num,
					})
					continue State
				case flags.ProtoLegacy && vi.typ == validationTypeMessageSetItem:
					typeid, v, n, err := messageset.ConsumeFieldValue(b, false)
					if err != nil {
						return out, ValidationInvalid
					}
					xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
					switch {
					case err == preg.NotFound:
						// Unknown item type: skip over the item.
						b = b[n:]
					case err != nil:
						return out, ValidationUnknown
					default:
						xvi := getExtensionFieldInfo(xt).validation
						if xvi.mi == nil {
							// Without a MessageInfo the item's contents cannot
							// be checked, and pushing a state with a nil mi
							// would be dereferenced on the next iteration.
							// Report the same status as the message and group
							// cases do for a missing MessageInfo.
							return out, ValidationUnknown
						}
						xvi.mi.init()
						states = append(states, validationState{
							typ:  xvi.typ,
							mi:   xvi.mi,
							tail: b[n:],
						})
						b = v
						continue State
					}
				default:
					// Group with no validation info: skip over it.
					n := wire.ConsumeFieldValue(num, wtyp, b)
					if n < 0 {
						return out, ValidationInvalid
					}
					b = b[n:]
				}
			default:
				return out, ValidationInvalid
			}
		}
		// Input for the current state is exhausted.
		if st.endGroup != 0 {
			// A group must be closed by its end-group tag, not by end of input.
			return out, ValidationInvalid
		}
		if len(b) != 0 {
			return out, ValidationInvalid
		}
		b = st.tail
	PopState:
		numRequiredFields := 0
		switch st.typ {
		case validationTypeMessage, validationTypeGroup:
			numRequiredFields = int(st.mi.numRequiredFields)
		case validationTypeMap:
			// If this is a map field with a message value that contains
			// required fields, require that the value be present.
			if st.mi != nil && st.mi.numRequiredFields > 0 {
				numRequiredFields = 1
			}
		}
		// If there are more than 64 required fields, this check will
		// always fail and we will report that the message is potentially
		// uninitialized.
		if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
			initialized = false
		}
		states = states[:len(states)-1]
	}
	out.n = start - len(b)
	if initialized {
		out.initialized = true
	}
	return out, ValidationValid
}