internal/impl: refactor validation a bit

Return the size of the field read from the validator, permitting us to
avoid an extra parse when skipping over groups.

Return an UnmarshalOutput from the validator, since it already combines
two of the validator outputs: bytes read and initialization status.

Remove initialization status from the ValidationStatus enum, since it's
covered by the UnmarshalOutput.

Change-Id: I3e684c45d15aa1992d8dc3bde0f608880d34a94b
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/217763
Reviewed-by: Joe Tsai <joetsai@google.com>
This commit is contained in:
Damien Neil 2020-02-03 16:17:31 -08:00
parent 9b3d97c473
commit cadb4ab3b1
6 changed files with 83 additions and 91 deletions

View File

@ -55,7 +55,9 @@ func BenchmarkEmptyMessage(b *testing.B) {
Resolver: protoregistry.GlobalTypes,
}
for pb.Next() {
if got, want := impl.Validate([]byte{}, mt, opts), impl.ValidationValidInitialized; got != want {
_, got := impl.Validate([]byte{}, mt, opts)
want := impl.ValidationValid
if got != want {
b.Fatalf("Validate = %v, want %v", got, want)
}
}
@ -106,7 +108,9 @@ func BenchmarkRepeatedInt32(b *testing.B) {
Resolver: protoregistry.GlobalTypes,
}
for pb.Next() {
if got, want := impl.Validate(w, mt, opts), impl.ValidationValidInitialized; got != want {
_, got := impl.Validate(w, mt, opts)
want := impl.ValidationValid
if got != want {
b.Fatalf("Validate = %v, want %v", got, want)
}
}
@ -167,7 +171,9 @@ func BenchmarkRequired(b *testing.B) {
Resolver: protoregistry.GlobalTypes,
}
for pb.Next() {
if got, want := impl.Validate(w, mt, opts), impl.ValidationValidInitialized; got != want {
_, got := impl.Validate(w, mt, opts)
want := impl.ValidationValid
if got != want {
b.Fatalf("Validate = %v, want %v", got, want)
}
}

View File

@ -19,7 +19,7 @@ import (
// Fuzz is a fuzzer for proto.Marshal and proto.Unmarshal.
func Fuzz(data []byte) (score int) {
m1 := &fuzzpb.Fuzz{}
valid := impl.Validate(data, m1.ProtoReflect().Type(), piface.UnmarshalOptions{
vout, valid := impl.Validate(data, m1.ProtoReflect().Type(), piface.UnmarshalOptions{
Resolver: protoregistry.GlobalTypes,
})
if err := (proto.UnmarshalOptions{
@ -33,21 +33,14 @@ func Fuzz(data []byte) (score int) {
}
return 0
}
if proto.IsInitialized(m1) == nil {
switch valid {
case impl.ValidationUnknown:
case impl.ValidationValidInitialized:
case impl.ValidationValidMaybeUninitalized:
default:
panic("unmarshal ok with validation status: " + valid.String())
}
} else {
switch valid {
case impl.ValidationUnknown:
case impl.ValidationValidMaybeUninitalized:
default:
panic("partial unmarshal ok with validation status: " + valid.String())
}
switch valid {
case impl.ValidationUnknown:
case impl.ValidationValid:
default:
panic("unmarshal ok with validation status: " + valid.String())
}
if proto.IsInitialized(m1) != nil && vout.Initialized {
panic("validation reports partial message is initialized")
}
data1, err := proto.MarshalOptions{
AllowPartial: true,

View File

@ -196,11 +196,9 @@ func (mi *MessageInfo) unmarshalExtension(b []byte, num wire.Number, wtyp wire.T
}
if flags.LazyUnmarshalExtensions {
if opts.IsDefault() && x.canLazy(xt) {
if n, ok := skipExtension(b, xi, num, wtyp, opts); ok {
x.appendLazyBytes(xt, xi, num, wtyp, b[:n])
if out, ok := skipExtension(b, xi, num, wtyp, opts); ok && out.initialized {
x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
exts[int32(num)] = x
out.n = n
out.initialized = true
return out, nil
}
}
@ -224,35 +222,31 @@ func (mi *MessageInfo) unmarshalExtension(b []byte, num wire.Number, wtyp wire.T
return out, nil
}
func skipExtension(b []byte, xi *extensionFieldInfo, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (n int, ok bool) {
func skipExtension(b []byte, xi *extensionFieldInfo, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (out unmarshalOutput, ok bool) {
if xi.validation.mi == nil {
return 0, false
return out, false
}
xi.validation.mi.init()
var v []byte
switch xi.validation.typ {
case validationTypeMessage:
if wtyp != wire.BytesType {
return 0, false
return out, false
}
v, n = wire.ConsumeBytes(b)
v, n := wire.ConsumeBytes(b)
if n < 0 {
return 0, false
return out, false
}
out, st := xi.validation.mi.validate(v, 0, opts)
out.n = n
return out, st == ValidationValid
case validationTypeGroup:
if wtyp != wire.StartGroupType {
return 0, false
}
v, n = wire.ConsumeGroup(num, b)
if n < 0 {
return 0, false
return out, false
}
// Groups are not length-prefixed: pass the unparsed input b so that validate
// scans up to the end-group tag for num and reports the size consumed in out.n.
out, st := xi.validation.mi.validate(b, num, opts)
return out, st == ValidationValid
default:
return 0, false
return out, false
}
if xi.validation.mi.validate(v, 0, opts) != ValidationValidInitialized {
return 0, false
}
return n, true
}

View File

@ -33,16 +33,8 @@ const (
// ValidationInvalid indicates that unmarshaling the message will fail.
ValidationInvalid
// ValidationValidInitialized indicates that unmarshaling the message will succeed
// and IsInitialized on the result will report success.
ValidationValidInitialized
// ValidationValidMaybeUninitalized indicates unmarshaling the message will succeed,
// but the output of IsInitialized on the result is unknown.
//
// This status may be returned for an initialized message when a message value
// is split across multiple fields.
ValidationValidMaybeUninitalized
// ValidationValid indicates that unmarshaling the message will succeed.
ValidationValid
)
func (v ValidationStatus) String() string {
@ -51,10 +43,8 @@ func (v ValidationStatus) String() string {
return "ValidationUnknown"
case ValidationInvalid:
return "ValidationInvalid"
case ValidationValidInitialized:
return "ValidationValidInitialized"
case ValidationValidMaybeUninitalized:
return "ValidationValidMaybeUninitalized"
case ValidationValid:
return "ValidationValid"
default:
return fmt.Sprintf("ValidationStatus(%d)", int(v))
}
@ -64,12 +54,14 @@ func (v ValidationStatus) String() string {
// of the message type.
//
// This function is exposed for testing.
func Validate(b []byte, mt pref.MessageType, opts piface.UnmarshalOptions) ValidationStatus {
func Validate(b []byte, mt pref.MessageType, opts piface.UnmarshalOptions) (out piface.UnmarshalOutput, _ ValidationStatus) {
mi, ok := mt.(*MessageInfo)
if !ok {
return ValidationUnknown
return out, ValidationUnknown
}
return mi.validate(b, 0, unmarshalOptions(opts))
o, st := mi.validate(b, 0, unmarshalOptions(opts))
out.Initialized = o.initialized
return out, st
}
type validationInfo struct {
@ -219,7 +211,7 @@ func newValidationInfo(fd pref.FieldDescriptor, ft reflect.Type) validationInfo
return vi
}
func (mi *MessageInfo) validate(b []byte, groupTag wire.Number, opts unmarshalOptions) (result ValidationStatus) {
func (mi *MessageInfo) validate(b []byte, groupTag wire.Number, opts unmarshalOptions) (out unmarshalOutput, result ValidationStatus) {
mi.init()
type validationState struct {
typ validationType
@ -241,12 +233,13 @@ func (mi *MessageInfo) validate(b []byte, groupTag wire.Number, opts unmarshalOp
states[0].endGroup = groupTag
}
initialized := true
start := len(b)
State:
for len(states) > 0 {
st := &states[len(states)-1]
if st.mi != nil {
if flags.ProtoLegacy && st.mi.isMessageSet {
return ValidationUnknown
return out, ValidationUnknown
}
}
for len(b) > 0 {
@ -262,13 +255,13 @@ State:
var n int
tag, n = wire.ConsumeVarint(b)
if n < 0 {
return ValidationInvalid
return out, ValidationInvalid
}
b = b[n:]
}
var num wire.Number
if n := tag >> 3; n < uint64(wire.MinValidNumber) || n > uint64(wire.MaxValidNumber) {
return ValidationInvalid
return out, ValidationInvalid
} else {
num = wire.Number(n)
}
@ -278,7 +271,7 @@ State:
if st.endGroup == num {
goto PopState
}
return ValidationInvalid
return out, ValidationInvalid
}
var vi validationInfo
switch st.typ {
@ -317,7 +310,7 @@ State:
case preg.NotFound:
vi.typ = validationTypeBytes
default:
return ValidationUnknown
return out, ValidationUnknown
}
}
break
@ -332,7 +325,7 @@ State:
// determine if the resolver is frozen.
xt, err := opts.Resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
if err != nil && err != preg.NotFound {
return ValidationUnknown
return out, ValidationUnknown
}
if err == nil {
vi = getExtensionFieldInfo(xt).validation
@ -383,7 +376,7 @@ State:
case b[9] < 0x80 && b[9] < 2:
b = b[10:]
default:
return ValidationInvalid
return out, ValidationInvalid
}
} else {
switch {
@ -408,7 +401,7 @@ State:
case len(b) > 9 && b[9] < 2:
b = b[10:]
default:
return ValidationInvalid
return out, ValidationInvalid
}
}
continue State
@ -424,19 +417,19 @@ State:
var n int
size, n = wire.ConsumeVarint(b)
if n < 0 {
return ValidationInvalid
return out, ValidationInvalid
}
b = b[n:]
}
if size > uint64(len(b)) {
return ValidationInvalid
return out, ValidationInvalid
}
v := b[:size]
b = b[size:]
switch vi.typ {
case validationTypeMessage:
if vi.mi == nil {
return ValidationUnknown
return out, ValidationUnknown
}
vi.mi.init()
fallthrough
@ -455,40 +448,40 @@ State:
for len(v) > 0 {
_, n := wire.ConsumeVarint(v)
if n < 0 {
return ValidationInvalid
return out, ValidationInvalid
}
v = v[n:]
}
case validationTypeRepeatedFixed32:
// Packed field.
if len(v)%4 != 0 {
return ValidationInvalid
return out, ValidationInvalid
}
case validationTypeRepeatedFixed64:
// Packed field.
if len(v)%8 != 0 {
return ValidationInvalid
return out, ValidationInvalid
}
case validationTypeUTF8String:
if !utf8.Valid(v) {
return ValidationInvalid
return out, ValidationInvalid
}
}
case wire.Fixed32Type:
if len(b) < 4 {
return ValidationInvalid
return out, ValidationInvalid
}
b = b[4:]
case wire.Fixed64Type:
if len(b) < 8 {
return ValidationInvalid
return out, ValidationInvalid
}
b = b[8:]
case wire.StartGroupType:
switch vi.typ {
case validationTypeGroup:
if vi.mi == nil {
return ValidationUnknown
return out, ValidationUnknown
}
vi.mi.init()
states = append(states, validationState{
@ -500,19 +493,19 @@ State:
default:
n := wire.ConsumeFieldValue(num, wtyp, b)
if n < 0 {
return ValidationInvalid
return out, ValidationInvalid
}
b = b[n:]
}
default:
return ValidationInvalid
return out, ValidationInvalid
}
}
if st.endGroup != 0 {
return ValidationInvalid
return out, ValidationInvalid
}
if len(b) != 0 {
return ValidationInvalid
return out, ValidationInvalid
}
b = st.tail
PopState:
@ -535,8 +528,9 @@ State:
}
states = states[:len(states)-1]
}
if !initialized {
return ValidationValidMaybeUninitalized
out.n = start - len(b)
if initialized {
out.initialized = true
}
return ValidationValidInitialized
return out, ValidationValid
}

View File

@ -28,6 +28,7 @@ type testProto struct {
checkFastInit bool
unmarshalOptions proto.UnmarshalOptions
validationStatus impl.ValidationStatus
nocheckValidInit bool
}
func makeMessages(in protobuild.Message, messages ...proto.Message) []proto.Message {
@ -1045,8 +1046,9 @@ var testValidMessages = []testProto{
}.Marshal(),
},
{
desc: "required field in optional message set (split across multiple tags)",
checkFastInit: false, // fast init checks don't handle split messages
desc: "required field in optional message set (split across multiple tags)",
checkFastInit: false, // fast init checks don't handle split messages
nocheckValidInit: true, // validation doesn't either
decodeTo: makeMessages(protobuild.Message{
"optional_message": protobuild.Message{
"required_field": 1,
@ -1058,7 +1060,6 @@ var testValidMessages = []testProto{
pack.Tag{1, pack.VarintType}, pack.Varint(1),
}),
}.Marshal(),
validationStatus: impl.ValidationValidMaybeUninitalized,
},
{
desc: "required field in repeated message unset",

View File

@ -23,16 +23,18 @@ func TestValidateValid(t *testing.T) {
for _, m := range test.decodeTo {
t.Run(fmt.Sprintf("%s (%T)", test.desc, m), func(t *testing.T) {
mt := m.ProtoReflect().Type()
want := impl.ValidationValidInitialized
want := impl.ValidationValid
if test.validationStatus != 0 {
want = test.validationStatus
} else if test.partial {
want = impl.ValidationValidMaybeUninitalized
}
var opts piface.UnmarshalOptions
opts.Resolver = protoregistry.GlobalTypes
if got, want := impl.Validate(test.wire, mt, opts), want; got != want {
t.Errorf("Validate(%x) = %v, want %v", test.wire, got, want)
out, status := impl.Validate(test.wire, mt, opts)
if status != want {
t.Errorf("Validate(%x) = %v, want %v", test.wire, status, want)
}
if got, want := out.Initialized, !test.partial; got != want && !test.nocheckValidInit && status == impl.ValidationValid {
t.Errorf("Validate(%x): initialized = %v, want %v", test.wire, got, want)
}
})
}
@ -46,7 +48,9 @@ func TestValidateInvalid(t *testing.T) {
mt := m.ProtoReflect().Type()
var opts piface.UnmarshalOptions
opts.Resolver = protoregistry.GlobalTypes
if got, want := impl.Validate(test.wire, mt, opts), impl.ValidationInvalid; got != want {
_, got := impl.Validate(test.wire, mt, opts)
want := impl.ValidationInvalid
if got != want {
t.Errorf("Validate(%x) = %v, want %v", test.wire, got, want)
}
})