all: do best-effort initialization check on fast path unmarshal

Add a fast check for required fields to the fast path unmarshal.
This is best-effort and will fail to detect some initialized
messages: messages with more than 64 required fields, messages
split across multiple tags, and possibly other cases.

In the cases where it works (which is most of them in practice),
this permits us to skip the IsInitialized check.

Change-Id: I6b70953a333033a5e64fb7ca37a59786cb0f75a0
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/215878
Reviewed-by: Joe Tsai <joetsai@google.com>
This commit is contained in:
Damien Neil 2020-01-21 15:00:33 -08:00
parent d30e561d9e
commit c600d6c086
7 changed files with 185 additions and 64 deletions

View File

@ -13,6 +13,7 @@ import (
"google.golang.org/protobuf/proto"
pref "google.golang.org/protobuf/reflect/protoreflect"
preg "google.golang.org/protobuf/reflect/protoregistry"
piface "google.golang.org/protobuf/runtime/protoiface"
)
type errInvalidUTF8 struct{}
@ -227,10 +228,12 @@ func consumeMessageInfo(b []byte, p pointer, mi *MessageInfo, wtyp wire.Type, op
if p.Elem().IsNil() {
p.SetPointer(pointerOfValue(reflect.New(mi.GoReflectType.Elem())))
}
if _, err := mi.unmarshalPointer(v, p.Elem(), 0, opts); err != nil {
o, err := mi.unmarshalPointer(v, p.Elem(), 0, opts)
if err != nil {
return out, err
}
out.n = n
out.initialized = o.initialized
return out, nil
}
@ -252,10 +255,14 @@ func consumeMessage(b []byte, m proto.Message, wtyp wire.Type, opts unmarshalOpt
if n < 0 {
return out, wire.ParseError(n)
}
if err := opts.Options().Unmarshal(v, m); err != nil {
o, err := opts.Options().UnmarshalState(m, piface.UnmarshalInput{
Buf: v,
})
if err != nil {
return out, err
}
out.n = n
out.initialized = o.Initialized
return out, nil
}
@ -395,8 +402,15 @@ func consumeGroup(b []byte, m proto.Message, num wire.Number, wtyp wire.Type, op
if n < 0 {
return out, wire.ParseError(n)
}
o, err := opts.Options().UnmarshalState(m, piface.UnmarshalInput{
Buf: b,
})
if err != nil {
return out, err
}
out.n = n
return out, opts.Options().Unmarshal(b, m)
out.initialized = o.Initialized
return out, nil
}
func makeMessageSliceFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
@ -469,11 +483,13 @@ func consumeMessageSliceInfo(b []byte, p pointer, mi *MessageInfo, wtyp wire.Typ
}
m := reflect.New(mi.GoReflectType.Elem()).Interface()
mp := pointerOfIface(m)
if _, err := mi.unmarshalPointer(v, mp, 0, opts); err != nil {
o, err := mi.unmarshalPointer(v, mp, 0, opts)
if err != nil {
return out, err
}
p.AppendPointerSlice(mp)
out.n = n
out.initialized = o.initialized
return out, nil
}
@ -522,11 +538,15 @@ func consumeMessageSlice(b []byte, p pointer, goType reflect.Type, wtyp wire.Typ
return out, wire.ParseError(n)
}
mp := reflect.New(goType.Elem())
if err := opts.Options().Unmarshal(v, asMessage(mp)); err != nil {
o, err := opts.Options().UnmarshalState(asMessage(mp), piface.UnmarshalInput{
Buf: v,
})
if err != nil {
return out, err
}
p.AppendPointerSlice(pointerOfValue(mp))
out.n = n
out.initialized = o.Initialized
return out, nil
}
@ -580,11 +600,15 @@ func consumeMessageSliceValue(b []byte, listv pref.Value, _ wire.Number, wtyp wi
return pref.Value{}, out, wire.ParseError(n)
}
m := list.NewElement()
if err := opts.Options().Unmarshal(v, m.Message().Interface()); err != nil {
o, err := opts.Options().UnmarshalState(m.Message().Interface(), piface.UnmarshalInput{
Buf: v,
})
if err != nil {
return pref.Value{}, out, err
}
list.Append(m)
out.n = n
out.initialized = o.Initialized
return listv, out, nil
}
@ -642,11 +666,15 @@ func consumeGroupSliceValue(b []byte, listv pref.Value, num wire.Number, wtyp wi
return pref.Value{}, out, wire.ParseError(n)
}
m := list.NewElement()
if err := opts.Options().Unmarshal(b, m.Message().Interface()); err != nil {
o, err := opts.Options().UnmarshalState(m.Message().Interface(), piface.UnmarshalInput{
Buf: b,
})
if err != nil {
return pref.Value{}, out, err
}
list.Append(m)
out.n = n
out.initialized = o.Initialized
return listv, out, nil
}
@ -728,11 +756,15 @@ func consumeGroupSlice(b []byte, p pointer, num wire.Number, wtyp wire.Type, goT
return out, wire.ParseError(n)
}
mp := reflect.New(goType.Elem())
if err := opts.Options().Unmarshal(b, asMessage(mp)); err != nil {
o, err := opts.Options().UnmarshalState(asMessage(mp), piface.UnmarshalInput{
Buf: b,
})
if err != nil {
return out, err
}
p.AppendPointerSlice(pointerOfValue(mp))
out.n = n
out.initialized = o.Initialized
return out, nil
}

View File

@ -202,7 +202,13 @@ func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *map
if n < 0 {
return out, wire.ParseError(n)
}
_, err = mapi.valMessageInfo.unmarshalPointer(v, pointerOfValue(val), 0, opts)
var o unmarshalOutput
o, err = mapi.valMessageInfo.unmarshalPointer(v, pointerOfValue(val), 0, opts)
if o.initialized {
// Consider this map item initialized so long as we see
// an initialized value.
out.initialized = true
}
}
if err == errUnknown {
n = wire.ConsumeFieldValue(num, wtyp, b)

View File

@ -5,6 +5,8 @@
package impl
import (
"math/bits"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/flags"
@ -58,7 +60,8 @@ func (o unmarshalOptions) DiscardUnknown() bool { return o.flags
func (o unmarshalOptions) Resolver() preg.ExtensionTypeResolver { return o.resolver }
type unmarshalOutput struct {
n int // number of bytes consumed
n int // number of bytes consumed
initialized bool
}
// unmarshal is protoreflect.Methods.Unmarshal.
@ -69,8 +72,10 @@ func (mi *MessageInfo) unmarshal(m pref.Message, in piface.UnmarshalInput, opts
} else {
p = m.(*messageReflectWrapper).pointer()
}
_, err := mi.unmarshalPointer(in.Buf, p, 0, newUnmarshalOptions(opts))
return piface.UnmarshalOutput{}, err
out, err := mi.unmarshalPointer(in.Buf, p, 0, newUnmarshalOptions(opts))
return piface.UnmarshalOutput{
Initialized: out.initialized,
}, err
}
// errUnknown is returned during unmarshaling to indicate a parse error that
@ -86,6 +91,8 @@ func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag wire.Numbe
if flags.ProtoLegacy && mi.isMessageSet {
return unmarshalMessageSet(mi, b, p, opts)
}
initialized := true
var requiredMask uint64
var exts *map[int32]ExtensionField
start := len(b)
for len(b) > 0 {
@ -104,8 +111,8 @@ func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag wire.Numbe
if num != groupTag {
return out, errors.New("mismatching end group marker")
}
out.n = start - len(b)
return out, nil
groupTag = 0
break
}
var f *coderFieldInfo
@ -123,6 +130,12 @@ func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag wire.Numbe
var o unmarshalOutput
o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, opts)
n = o.n
if reqi := f.validation.requiredIndex; reqi > 0 && err == nil {
requiredMask |= 1 << (reqi - 1)
}
if f.funcs.isInit != nil && !o.initialized {
initialized = false
}
default:
// Possible extension.
if exts == nil && mi.extensionOffset.IsValid() {
@ -137,6 +150,9 @@ func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag wire.Numbe
var o unmarshalOutput
o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
n = o.n
if !o.initialized {
initialized = false
}
}
if err != nil {
if err != errUnknown {
@ -157,7 +173,13 @@ func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag wire.Numbe
if groupTag != 0 {
return out, errors.New("missing end group marker")
}
out.n = start
if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
initialized = false
}
if initialized {
out.initialized = true
}
out.n = start - len(b)
return out, nil
}

View File

@ -15,6 +15,7 @@ import (
"google.golang.org/protobuf/internal/impl"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/runtime/protoiface"
legacypb "google.golang.org/protobuf/internal/testprotos/legacy"
)
@ -129,3 +130,28 @@ func TestSelfMarshalerWithDescriptor(t *testing.T) {
t.Fatalf("proto.Marshal(%v) = %v, %v; want %v, nil", m, got, err, want)
}
}
// TestDecodeFastIsInitialized unmarshals every testValidMessages entry that
// opts in via checkFastInit and verifies that the Initialized flag reported by
// UnmarshalState is the inverse of the entry's partial flag.
func TestDecodeFastIsInitialized(t *testing.T) {
	for _, test := range testValidMessages {
		if !test.checkFastInit {
			continue
		}
		for _, message := range test.decodeTo {
			name := fmt.Sprintf("%s (%T)", test.desc, message)
			t.Run(name, func(t *testing.T) {
				// AllowPartial lets messages with missing required fields
				// decode without error, so only the Initialized flag is
				// under test here.
				unmarshalOpts := proto.UnmarshalOptions{AllowPartial: true}
				dst := message.ProtoReflect().New().Interface()
				input := protoiface.UnmarshalInput{Buf: test.wire}

				res, err := unmarshalOpts.UnmarshalState(dst, input)
				if err != nil {
					t.Fatalf("Unmarshal error: %v", err)
				}

				want := !test.partial
				if got := res.Initialized; got != want {
					t.Errorf("out.Initialized = %v, want %v", got, want)
				}
			})
		}
	}
}

View File

@ -23,6 +23,7 @@ type testProto struct {
wire []byte
partial bool
noEncode bool
checkFastInit bool
validationStatus impl.ValidationStatus
}
@ -1150,17 +1151,20 @@ var testValidMessages = []testProto{
}.Marshal(),
},
{
desc: "required field in nil message unset",
partial: true,
decodeTo: []proto.Message{(*testpb.TestRequired)(nil)},
desc: "required field in nil message unset",
checkFastInit: true,
partial: true,
decodeTo: []proto.Message{(*testpb.TestRequired)(nil)},
},
{
desc: "required int32 unset",
partial: true,
decodeTo: []proto.Message{&requiredpb.Int32{}},
desc: "required int32 unset",
checkFastInit: true,
partial: true,
decodeTo: []proto.Message{&requiredpb.Int32{}},
},
{
desc: "required int32 set",
desc: "required int32 set",
checkFastInit: true,
decodeTo: []proto.Message{&requiredpb.Int32{
V: proto.Int32(1),
}},
@ -1169,12 +1173,14 @@ var testValidMessages = []testProto{
}.Marshal(),
},
{
desc: "required fixed32 unset",
partial: true,
decodeTo: []proto.Message{&requiredpb.Fixed32{}},
desc: "required fixed32 unset",
checkFastInit: true,
partial: true,
decodeTo: []proto.Message{&requiredpb.Fixed32{}},
},
{
desc: "required fixed32 set",
desc: "required fixed32 set",
checkFastInit: true,
decodeTo: []proto.Message{&requiredpb.Fixed32{
V: proto.Uint32(1),
}},
@ -1183,12 +1189,14 @@ var testValidMessages = []testProto{
}.Marshal(),
},
{
desc: "required fixed64 unset",
partial: true,
decodeTo: []proto.Message{&requiredpb.Fixed64{}},
desc: "required fixed64 unset",
checkFastInit: true,
partial: true,
decodeTo: []proto.Message{&requiredpb.Fixed64{}},
},
{
desc: "required fixed64 set",
desc: "required fixed64 set",
checkFastInit: true,
decodeTo: []proto.Message{&requiredpb.Fixed64{
V: proto.Uint64(1),
}},
@ -1197,12 +1205,14 @@ var testValidMessages = []testProto{
}.Marshal(),
},
{
desc: "required bytes unset",
partial: true,
decodeTo: []proto.Message{&requiredpb.Bytes{}},
desc: "required bytes unset",
checkFastInit: true,
partial: true,
decodeTo: []proto.Message{&requiredpb.Bytes{}},
},
{
desc: "required bytes set",
desc: "required bytes set",
checkFastInit: true,
decodeTo: []proto.Message{&requiredpb.Bytes{
V: []byte{},
}},
@ -1211,8 +1221,9 @@ var testValidMessages = []testProto{
}.Marshal(),
},
{
desc: "required field with incompatible wire type",
partial: true,
desc: "required field with incompatible wire type",
checkFastInit: true,
partial: true,
decodeTo: []proto.Message{build(
&testpb.TestRequired{},
unknown(pack.Message{
@ -1224,8 +1235,9 @@ var testValidMessages = []testProto{
}.Marshal(),
},
{
desc: "required field in optional message unset",
partial: true,
desc: "required field in optional message unset",
checkFastInit: true,
partial: true,
decodeTo: []proto.Message{&testpb.TestRequiredForeign{
OptionalMessage: &testpb.TestRequired{},
}},
@ -1234,7 +1246,8 @@ var testValidMessages = []testProto{
}.Marshal(),
},
{
desc: "required field in optional message set",
desc: "required field in optional message set",
checkFastInit: true,
decodeTo: []proto.Message{&testpb.TestRequiredForeign{
OptionalMessage: &testpb.TestRequired{
RequiredField: proto.Int32(1),
@ -1247,7 +1260,8 @@ var testValidMessages = []testProto{
}.Marshal(),
},
{
desc: "required field in optional message set (split across multiple tags)",
desc: "required field in optional message set (split across multiple tags)",
checkFastInit: false, // fast init checks don't handle split messages
decodeTo: []proto.Message{&testpb.TestRequiredForeign{
OptionalMessage: &testpb.TestRequired{
RequiredField: proto.Int32(1),
@ -1262,8 +1276,9 @@ var testValidMessages = []testProto{
validationStatus: impl.ValidationValidMaybeUninitalized,
},
{
desc: "required field in repeated message unset",
partial: true,
desc: "required field in repeated message unset",
checkFastInit: true,
partial: true,
decodeTo: []proto.Message{&testpb.TestRequiredForeign{
RepeatedMessage: []*testpb.TestRequired{
{RequiredField: proto.Int32(1)},
@ -1278,7 +1293,8 @@ var testValidMessages = []testProto{
}.Marshal(),
},
{
desc: "required field in repeated message set",
desc: "required field in repeated message set",
checkFastInit: true,
decodeTo: []proto.Message{&testpb.TestRequiredForeign{
RepeatedMessage: []*testpb.TestRequired{
{RequiredField: proto.Int32(1)},
@ -1295,8 +1311,9 @@ var testValidMessages = []testProto{
}.Marshal(),
},
{
desc: "required field in map message unset",
partial: true,
desc: "required field in map message unset",
checkFastInit: true,
partial: true,
decodeTo: []proto.Message{&testpb.TestRequiredForeign{
MapMessage: map[int32]*testpb.TestRequired{
1: {RequiredField: proto.Int32(1)},
@ -1317,8 +1334,9 @@ var testValidMessages = []testProto{
}.Marshal(),
},
{
desc: "required field in absent map message value",
partial: true,
desc: "required field in absent map message value",
checkFastInit: true,
partial: true,
decodeTo: []proto.Message{&testpb.TestRequiredForeign{
MapMessage: map[int32]*testpb.TestRequired{
2: {},
@ -1331,7 +1349,8 @@ var testValidMessages = []testProto{
}.Marshal(),
},
{
desc: "required field in map message set",
desc: "required field in map message set",
checkFastInit: true,
decodeTo: []proto.Message{&testpb.TestRequiredForeign{
MapMessage: map[int32]*testpb.TestRequired{
1: {RequiredField: proto.Int32(1)},
@ -1354,8 +1373,9 @@ var testValidMessages = []testProto{
}.Marshal(),
},
{
desc: "required field in optional group unset",
partial: true,
desc: "required field in optional group unset",
checkFastInit: true,
partial: true,
decodeTo: []proto.Message{&testpb.TestRequiredGroupFields{
Optionalgroup: &testpb.TestRequiredGroupFields_OptionalGroup{},
}},
@ -1365,7 +1385,8 @@ var testValidMessages = []testProto{
}.Marshal(),
},
{
desc: "required field in optional group set",
desc: "required field in optional group set",
checkFastInit: true,
decodeTo: []proto.Message{&testpb.TestRequiredGroupFields{
Optionalgroup: &testpb.TestRequiredGroupFields_OptionalGroup{
A: proto.Int32(1),
@ -1378,8 +1399,9 @@ var testValidMessages = []testProto{
}.Marshal(),
},
{
desc: "required field in repeated group unset",
partial: true,
desc: "required field in repeated group unset",
checkFastInit: true,
partial: true,
decodeTo: []proto.Message{&testpb.TestRequiredGroupFields{
Repeatedgroup: []*testpb.TestRequiredGroupFields_RepeatedGroup{
{A: proto.Int32(1)},
@ -1395,7 +1417,8 @@ var testValidMessages = []testProto{
}.Marshal(),
},
{
desc: "required field in repeated group set",
desc: "required field in repeated group set",
checkFastInit: true,
decodeTo: []proto.Message{&testpb.TestRequiredGroupFields{
Repeatedgroup: []*testpb.TestRequiredGroupFields_RepeatedGroup{
{A: proto.Int32(1)},
@ -1412,8 +1435,9 @@ var testValidMessages = []testProto{
}.Marshal(),
},
{
desc: "required field in oneof message unset",
partial: true,
desc: "required field in oneof message unset",
checkFastInit: true,
partial: true,
decodeTo: []proto.Message{
&testpb.TestRequiredForeign{OneofField: &testpb.TestRequiredForeign_OneofMessage{
&testpb.TestRequired{},
@ -1422,7 +1446,8 @@ var testValidMessages = []testProto{
wire: pack.Message{pack.Tag{4, pack.BytesType}, pack.LengthPrefix(pack.Message{})}.Marshal(),
},
{
desc: "required field in oneof message set",
desc: "required field in oneof message set",
checkFastInit: true,
decodeTo: []proto.Message{
&testpb.TestRequiredForeign{OneofField: &testpb.TestRequiredForeign_OneofMessage{
&testpb.TestRequired{
@ -1435,8 +1460,9 @@ var testValidMessages = []testProto{
})}.Marshal(),
},
{
desc: "required field in extension message unset",
partial: true,
desc: "required field in extension message unset",
checkFastInit: true,
partial: true,
decodeTo: []proto.Message{build(
&testpb.TestAllExtensions{},
extend(testpb.E_TestRequired_Single, &testpb.TestRequired{}),
@ -1446,7 +1472,8 @@ var testValidMessages = []testProto{
}.Marshal(),
},
{
desc: "required field in extension message set",
desc: "required field in extension message set",
checkFastInit: true,
decodeTo: []proto.Message{build(
&testpb.TestAllExtensions{},
extend(testpb.E_TestRequired_Single, &testpb.TestRequired{
@ -1460,8 +1487,9 @@ var testValidMessages = []testProto{
}.Marshal(),
},
{
desc: "required field in repeated extension message unset",
partial: true,
desc: "required field in repeated extension message unset",
checkFastInit: true,
partial: true,
decodeTo: []proto.Message{build(
&testpb.TestAllExtensions{},
extend(testpb.E_TestRequired_Multi, []*testpb.TestRequired{
@ -1477,7 +1505,8 @@ var testValidMessages = []testProto{
}.Marshal(),
},
{
desc: "required field in repeated extension message set",
desc: "required field in repeated extension message set",
checkFastInit: true,
decodeTo: []proto.Message{build(
&testpb.TestAllExtensions{},
extend(testpb.E_TestRequired_Multi, []*testpb.TestRequired{

View File

@ -44,6 +44,7 @@ type (
}
unmarshalOutput = struct {
pragma.NoUnkeyedLiterals
Initialized bool
}
unmarshalOptions = struct {
pragma.NoUnkeyedLiterals

View File

@ -27,9 +27,11 @@ type Methods = struct {
// Marshal writes the wire-format encoding of m to the provided buffer.
// Size should be provided if a custom MarshalAppend is provided.
// It should not return an error for a partial message.
Marshal func(m protoreflect.Message, in MarshalInput, opts MarshalOptions) (MarshalOutput, error)
// Unmarshal parses the wire-format encoding of a message and merges the result to m.
// It should not reset the target message or return an error for a partial message.
Unmarshal func(m protoreflect.Message, in UnmarshalInput, opts UnmarshalOptions) (UnmarshalOutput, error)
// IsInitialized returns an error if any required fields in m are not set.
@ -82,7 +84,10 @@ type UnmarshalInput = struct {
type UnmarshalOutput = struct {
pragma.NoUnkeyedLiterals
// Contents available for future expansion.
// Initialized may be set on return if all required fields are known to be set.
// A value of false does not indicate that the message is uninitialized, only
// that its status could not be confirmed.
Initialized bool
}
// UnmarshalOptions configures the unmarshaler.