mirror of
https://github.com/protocolbuffers/protobuf-go.git
synced 2025-03-08 19:14:05 +00:00
internal/impl: refactor validation a bit
Return the size of the field read from the validator, permitting us to avoid an extra parse when skipping over groups. Return an UnmarshalOutput from the validator, since it already combines two of the validator outputs: bytes read and initialization status. Remove initialization status from the ValidationStatus enum, since it's covered by the UnmarshalOutput. Change-Id: I3e684c45d15aa1992d8dc3bde0f608880d34a94b Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/217763 Reviewed-by: Joe Tsai <joetsai@google.com>
This commit is contained in:
parent
9b3d97c473
commit
cadb4ab3b1
@ -55,7 +55,9 @@ func BenchmarkEmptyMessage(b *testing.B) {
|
||||
Resolver: protoregistry.GlobalTypes,
|
||||
}
|
||||
for pb.Next() {
|
||||
if got, want := impl.Validate([]byte{}, mt, opts), impl.ValidationValidInitialized; got != want {
|
||||
_, got := impl.Validate([]byte{}, mt, opts)
|
||||
want := impl.ValidationValid
|
||||
if got != want {
|
||||
b.Fatalf("Validate = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
@ -106,7 +108,9 @@ func BenchmarkRepeatedInt32(b *testing.B) {
|
||||
Resolver: protoregistry.GlobalTypes,
|
||||
}
|
||||
for pb.Next() {
|
||||
if got, want := impl.Validate(w, mt, opts), impl.ValidationValidInitialized; got != want {
|
||||
_, got := impl.Validate(w, mt, opts)
|
||||
want := impl.ValidationValid
|
||||
if got != want {
|
||||
b.Fatalf("Validate = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
@ -167,7 +171,9 @@ func BenchmarkRequired(b *testing.B) {
|
||||
Resolver: protoregistry.GlobalTypes,
|
||||
}
|
||||
for pb.Next() {
|
||||
if got, want := impl.Validate(w, mt, opts), impl.ValidationValidInitialized; got != want {
|
||||
_, got := impl.Validate(w, mt, opts)
|
||||
want := impl.ValidationValid
|
||||
if got != want {
|
||||
b.Fatalf("Validate = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
@ -19,7 +19,7 @@ import (
|
||||
// Fuzz is a fuzzer for proto.Marshal and proto.Unmarshal.
|
||||
func Fuzz(data []byte) (score int) {
|
||||
m1 := &fuzzpb.Fuzz{}
|
||||
valid := impl.Validate(data, m1.ProtoReflect().Type(), piface.UnmarshalOptions{
|
||||
vout, valid := impl.Validate(data, m1.ProtoReflect().Type(), piface.UnmarshalOptions{
|
||||
Resolver: protoregistry.GlobalTypes,
|
||||
})
|
||||
if err := (proto.UnmarshalOptions{
|
||||
@ -33,21 +33,14 @@ func Fuzz(data []byte) (score int) {
|
||||
}
|
||||
return 0
|
||||
}
|
||||
if proto.IsInitialized(m1) == nil {
|
||||
switch valid {
|
||||
case impl.ValidationUnknown:
|
||||
case impl.ValidationValidInitialized:
|
||||
case impl.ValidationValidMaybeUninitalized:
|
||||
default:
|
||||
panic("unmarshal ok with validation status: " + valid.String())
|
||||
}
|
||||
} else {
|
||||
switch valid {
|
||||
case impl.ValidationUnknown:
|
||||
case impl.ValidationValidMaybeUninitalized:
|
||||
default:
|
||||
panic("partial unmarshal ok with validation status: " + valid.String())
|
||||
}
|
||||
switch valid {
|
||||
case impl.ValidationUnknown:
|
||||
case impl.ValidationValid:
|
||||
default:
|
||||
panic("unmarshal ok with validation status: " + valid.String())
|
||||
}
|
||||
if proto.IsInitialized(m1) != nil && vout.Initialized {
|
||||
panic("validation reports partial message is initialized")
|
||||
}
|
||||
data1, err := proto.MarshalOptions{
|
||||
AllowPartial: true,
|
||||
|
@ -196,11 +196,9 @@ func (mi *MessageInfo) unmarshalExtension(b []byte, num wire.Number, wtyp wire.T
|
||||
}
|
||||
if flags.LazyUnmarshalExtensions {
|
||||
if opts.IsDefault() && x.canLazy(xt) {
|
||||
if n, ok := skipExtension(b, xi, num, wtyp, opts); ok {
|
||||
x.appendLazyBytes(xt, xi, num, wtyp, b[:n])
|
||||
if out, ok := skipExtension(b, xi, num, wtyp, opts); ok && out.initialized {
|
||||
x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
|
||||
exts[int32(num)] = x
|
||||
out.n = n
|
||||
out.initialized = true
|
||||
return out, nil
|
||||
}
|
||||
}
|
||||
@ -224,35 +222,31 @@ func (mi *MessageInfo) unmarshalExtension(b []byte, num wire.Number, wtyp wire.T
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func skipExtension(b []byte, xi *extensionFieldInfo, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (n int, ok bool) {
|
||||
func skipExtension(b []byte, xi *extensionFieldInfo, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (out unmarshalOutput, ok bool) {
|
||||
if xi.validation.mi == nil {
|
||||
return 0, false
|
||||
return out, false
|
||||
}
|
||||
xi.validation.mi.init()
|
||||
var v []byte
|
||||
switch xi.validation.typ {
|
||||
case validationTypeMessage:
|
||||
if wtyp != wire.BytesType {
|
||||
return 0, false
|
||||
return out, false
|
||||
}
|
||||
v, n = wire.ConsumeBytes(b)
|
||||
v, n := wire.ConsumeBytes(b)
|
||||
if n < 0 {
|
||||
return 0, false
|
||||
return out, false
|
||||
}
|
||||
out, st := xi.validation.mi.validate(v, 0, opts)
|
||||
out.n = n
|
||||
return out, st == ValidationValid
|
||||
case validationTypeGroup:
|
||||
if wtyp != wire.StartGroupType {
|
||||
return 0, false
|
||||
}
|
||||
v, n = wire.ConsumeGroup(num, b)
|
||||
if n < 0 {
|
||||
return 0, false
|
||||
return out, false
|
||||
}
|
||||
out, st := xi.validation.mi.validate(v, num, opts)
|
||||
return out, st == ValidationValid
|
||||
default:
|
||||
return 0, false
|
||||
return out, false
|
||||
}
|
||||
if xi.validation.mi.validate(v, 0, opts) != ValidationValidInitialized {
|
||||
return 0, false
|
||||
}
|
||||
return n, true
|
||||
|
||||
}
|
||||
|
@ -33,16 +33,8 @@ const (
|
||||
// ValidationInvalid indicates that unmarshaling the message will fail.
|
||||
ValidationInvalid
|
||||
|
||||
// ValidationValidInitialized indicates that unmarshaling the message will succeed
|
||||
// and IsInitialized on the result will report success.
|
||||
ValidationValidInitialized
|
||||
|
||||
// ValidationValidMaybeUninitalized indicates unmarshaling the message will succeed,
|
||||
// but the output of IsInitialized on the result is unknown.
|
||||
//
|
||||
// This status may be returned for an initialized message when a message value
|
||||
// is split across multiple fields.
|
||||
ValidationValidMaybeUninitalized
|
||||
// ValidationValid indicates that unmarshaling the message will succeed.
|
||||
ValidationValid
|
||||
)
|
||||
|
||||
func (v ValidationStatus) String() string {
|
||||
@ -51,10 +43,8 @@ func (v ValidationStatus) String() string {
|
||||
return "ValidationUnknown"
|
||||
case ValidationInvalid:
|
||||
return "ValidationInvalid"
|
||||
case ValidationValidInitialized:
|
||||
return "ValidationValidInitialized"
|
||||
case ValidationValidMaybeUninitalized:
|
||||
return "ValidationValidMaybeUninitalized"
|
||||
case ValidationValid:
|
||||
return "ValidationValid"
|
||||
default:
|
||||
return fmt.Sprintf("ValidationStatus(%d)", int(v))
|
||||
}
|
||||
@ -64,12 +54,14 @@ func (v ValidationStatus) String() string {
|
||||
// of the message type.
|
||||
//
|
||||
// This function is exposed for testing.
|
||||
func Validate(b []byte, mt pref.MessageType, opts piface.UnmarshalOptions) ValidationStatus {
|
||||
func Validate(b []byte, mt pref.MessageType, opts piface.UnmarshalOptions) (out piface.UnmarshalOutput, _ ValidationStatus) {
|
||||
mi, ok := mt.(*MessageInfo)
|
||||
if !ok {
|
||||
return ValidationUnknown
|
||||
return out, ValidationUnknown
|
||||
}
|
||||
return mi.validate(b, 0, unmarshalOptions(opts))
|
||||
o, st := mi.validate(b, 0, unmarshalOptions(opts))
|
||||
out.Initialized = o.initialized
|
||||
return out, st
|
||||
}
|
||||
|
||||
type validationInfo struct {
|
||||
@ -219,7 +211,7 @@ func newValidationInfo(fd pref.FieldDescriptor, ft reflect.Type) validationInfo
|
||||
return vi
|
||||
}
|
||||
|
||||
func (mi *MessageInfo) validate(b []byte, groupTag wire.Number, opts unmarshalOptions) (result ValidationStatus) {
|
||||
func (mi *MessageInfo) validate(b []byte, groupTag wire.Number, opts unmarshalOptions) (out unmarshalOutput, result ValidationStatus) {
|
||||
mi.init()
|
||||
type validationState struct {
|
||||
typ validationType
|
||||
@ -241,12 +233,13 @@ func (mi *MessageInfo) validate(b []byte, groupTag wire.Number, opts unmarshalOp
|
||||
states[0].endGroup = groupTag
|
||||
}
|
||||
initialized := true
|
||||
start := len(b)
|
||||
State:
|
||||
for len(states) > 0 {
|
||||
st := &states[len(states)-1]
|
||||
if st.mi != nil {
|
||||
if flags.ProtoLegacy && st.mi.isMessageSet {
|
||||
return ValidationUnknown
|
||||
return out, ValidationUnknown
|
||||
}
|
||||
}
|
||||
for len(b) > 0 {
|
||||
@ -262,13 +255,13 @@ State:
|
||||
var n int
|
||||
tag, n = wire.ConsumeVarint(b)
|
||||
if n < 0 {
|
||||
return ValidationInvalid
|
||||
return out, ValidationInvalid
|
||||
}
|
||||
b = b[n:]
|
||||
}
|
||||
var num wire.Number
|
||||
if n := tag >> 3; n < uint64(wire.MinValidNumber) || n > uint64(wire.MaxValidNumber) {
|
||||
return ValidationInvalid
|
||||
return out, ValidationInvalid
|
||||
} else {
|
||||
num = wire.Number(n)
|
||||
}
|
||||
@ -278,7 +271,7 @@ State:
|
||||
if st.endGroup == num {
|
||||
goto PopState
|
||||
}
|
||||
return ValidationInvalid
|
||||
return out, ValidationInvalid
|
||||
}
|
||||
var vi validationInfo
|
||||
switch st.typ {
|
||||
@ -317,7 +310,7 @@ State:
|
||||
case preg.NotFound:
|
||||
vi.typ = validationTypeBytes
|
||||
default:
|
||||
return ValidationUnknown
|
||||
return out, ValidationUnknown
|
||||
}
|
||||
}
|
||||
break
|
||||
@ -332,7 +325,7 @@ State:
|
||||
// determine if the resolver is frozen.
|
||||
xt, err := opts.Resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
|
||||
if err != nil && err != preg.NotFound {
|
||||
return ValidationUnknown
|
||||
return out, ValidationUnknown
|
||||
}
|
||||
if err == nil {
|
||||
vi = getExtensionFieldInfo(xt).validation
|
||||
@ -383,7 +376,7 @@ State:
|
||||
case b[9] < 0x80 && b[9] < 2:
|
||||
b = b[10:]
|
||||
default:
|
||||
return ValidationInvalid
|
||||
return out, ValidationInvalid
|
||||
}
|
||||
} else {
|
||||
switch {
|
||||
@ -408,7 +401,7 @@ State:
|
||||
case len(b) > 9 && b[9] < 2:
|
||||
b = b[10:]
|
||||
default:
|
||||
return ValidationInvalid
|
||||
return out, ValidationInvalid
|
||||
}
|
||||
}
|
||||
continue State
|
||||
@ -424,19 +417,19 @@ State:
|
||||
var n int
|
||||
size, n = wire.ConsumeVarint(b)
|
||||
if n < 0 {
|
||||
return ValidationInvalid
|
||||
return out, ValidationInvalid
|
||||
}
|
||||
b = b[n:]
|
||||
}
|
||||
if size > uint64(len(b)) {
|
||||
return ValidationInvalid
|
||||
return out, ValidationInvalid
|
||||
}
|
||||
v := b[:size]
|
||||
b = b[size:]
|
||||
switch vi.typ {
|
||||
case validationTypeMessage:
|
||||
if vi.mi == nil {
|
||||
return ValidationUnknown
|
||||
return out, ValidationUnknown
|
||||
}
|
||||
vi.mi.init()
|
||||
fallthrough
|
||||
@ -455,40 +448,40 @@ State:
|
||||
for len(v) > 0 {
|
||||
_, n := wire.ConsumeVarint(v)
|
||||
if n < 0 {
|
||||
return ValidationInvalid
|
||||
return out, ValidationInvalid
|
||||
}
|
||||
v = v[n:]
|
||||
}
|
||||
case validationTypeRepeatedFixed32:
|
||||
// Packed field.
|
||||
if len(v)%4 != 0 {
|
||||
return ValidationInvalid
|
||||
return out, ValidationInvalid
|
||||
}
|
||||
case validationTypeRepeatedFixed64:
|
||||
// Packed field.
|
||||
if len(v)%8 != 0 {
|
||||
return ValidationInvalid
|
||||
return out, ValidationInvalid
|
||||
}
|
||||
case validationTypeUTF8String:
|
||||
if !utf8.Valid(v) {
|
||||
return ValidationInvalid
|
||||
return out, ValidationInvalid
|
||||
}
|
||||
}
|
||||
case wire.Fixed32Type:
|
||||
if len(b) < 4 {
|
||||
return ValidationInvalid
|
||||
return out, ValidationInvalid
|
||||
}
|
||||
b = b[4:]
|
||||
case wire.Fixed64Type:
|
||||
if len(b) < 8 {
|
||||
return ValidationInvalid
|
||||
return out, ValidationInvalid
|
||||
}
|
||||
b = b[8:]
|
||||
case wire.StartGroupType:
|
||||
switch vi.typ {
|
||||
case validationTypeGroup:
|
||||
if vi.mi == nil {
|
||||
return ValidationUnknown
|
||||
return out, ValidationUnknown
|
||||
}
|
||||
vi.mi.init()
|
||||
states = append(states, validationState{
|
||||
@ -500,19 +493,19 @@ State:
|
||||
default:
|
||||
n := wire.ConsumeFieldValue(num, wtyp, b)
|
||||
if n < 0 {
|
||||
return ValidationInvalid
|
||||
return out, ValidationInvalid
|
||||
}
|
||||
b = b[n:]
|
||||
}
|
||||
default:
|
||||
return ValidationInvalid
|
||||
return out, ValidationInvalid
|
||||
}
|
||||
}
|
||||
if st.endGroup != 0 {
|
||||
return ValidationInvalid
|
||||
return out, ValidationInvalid
|
||||
}
|
||||
if len(b) != 0 {
|
||||
return ValidationInvalid
|
||||
return out, ValidationInvalid
|
||||
}
|
||||
b = st.tail
|
||||
PopState:
|
||||
@ -535,8 +528,9 @@ State:
|
||||
}
|
||||
states = states[:len(states)-1]
|
||||
}
|
||||
if !initialized {
|
||||
return ValidationValidMaybeUninitalized
|
||||
out.n = start - len(b)
|
||||
if initialized {
|
||||
out.initialized = true
|
||||
}
|
||||
return ValidationValidInitialized
|
||||
return out, ValidationValid
|
||||
}
|
||||
|
@ -28,6 +28,7 @@ type testProto struct {
|
||||
checkFastInit bool
|
||||
unmarshalOptions proto.UnmarshalOptions
|
||||
validationStatus impl.ValidationStatus
|
||||
nocheckValidInit bool
|
||||
}
|
||||
|
||||
func makeMessages(in protobuild.Message, messages ...proto.Message) []proto.Message {
|
||||
@ -1045,8 +1046,9 @@ var testValidMessages = []testProto{
|
||||
}.Marshal(),
|
||||
},
|
||||
{
|
||||
desc: "required field in optional message set (split across multiple tags)",
|
||||
checkFastInit: false, // fast init checks don't handle split messages
|
||||
desc: "required field in optional message set (split across multiple tags)",
|
||||
checkFastInit: false, // fast init checks don't handle split messages
|
||||
nocheckValidInit: true, // validation doesn't either
|
||||
decodeTo: makeMessages(protobuild.Message{
|
||||
"optional_message": protobuild.Message{
|
||||
"required_field": 1,
|
||||
@ -1058,7 +1060,6 @@ var testValidMessages = []testProto{
|
||||
pack.Tag{1, pack.VarintType}, pack.Varint(1),
|
||||
}),
|
||||
}.Marshal(),
|
||||
validationStatus: impl.ValidationValidMaybeUninitalized,
|
||||
},
|
||||
{
|
||||
desc: "required field in repeated message unset",
|
||||
|
@ -23,16 +23,18 @@ func TestValidateValid(t *testing.T) {
|
||||
for _, m := range test.decodeTo {
|
||||
t.Run(fmt.Sprintf("%s (%T)", test.desc, m), func(t *testing.T) {
|
||||
mt := m.ProtoReflect().Type()
|
||||
want := impl.ValidationValidInitialized
|
||||
want := impl.ValidationValid
|
||||
if test.validationStatus != 0 {
|
||||
want = test.validationStatus
|
||||
} else if test.partial {
|
||||
want = impl.ValidationValidMaybeUninitalized
|
||||
}
|
||||
var opts piface.UnmarshalOptions
|
||||
opts.Resolver = protoregistry.GlobalTypes
|
||||
if got, want := impl.Validate(test.wire, mt, opts), want; got != want {
|
||||
t.Errorf("Validate(%x) = %v, want %v", test.wire, got, want)
|
||||
out, status := impl.Validate(test.wire, mt, opts)
|
||||
if status != want {
|
||||
t.Errorf("Validate(%x) = %v, want %v", test.wire, status, want)
|
||||
}
|
||||
if got, want := out.Initialized, !test.partial; got != want && !test.nocheckValidInit && status == impl.ValidationValid {
|
||||
t.Errorf("Validate(%x): initialized = %v, want %v", test.wire, got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -46,7 +48,9 @@ func TestValidateInvalid(t *testing.T) {
|
||||
mt := m.ProtoReflect().Type()
|
||||
var opts piface.UnmarshalOptions
|
||||
opts.Resolver = protoregistry.GlobalTypes
|
||||
if got, want := impl.Validate(test.wire, mt, opts), impl.ValidationInvalid; got != want {
|
||||
_, got := impl.Validate(test.wire, mt, opts)
|
||||
want := impl.ValidationInvalid
|
||||
if got != want {
|
||||
t.Errorf("Validate(%x) = %v, want %v", test.wire, got, want)
|
||||
}
|
||||
})
|
||||
|
Loading…
x
Reference in New Issue
Block a user