From c600d6c086a67c0ccaaa6b5b3b68215495300c02 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Tue, 21 Jan 2020 15:00:33 -0800 Subject: [PATCH] all: do best-effort initialization check on fast path unmarshal Add a fast check for required fields to the fast path unmarshal. This is best-effort and will fail to detect some initialized messages: Messages with more than 64 required fields, messages split across multiple tags, possibly other cases. In the cases where it works (which is most of them in practice), this permits us to skip the IsInitialized check. Change-Id: I6b70953a333033a5e64fb7ca37a59786cb0f75a0 Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/215878 Reviewed-by: Joe Tsai --- internal/impl/codec_field.go | 48 ++++++++++-- internal/impl/codec_map.go | 8 +- internal/impl/decode.go | 34 +++++++-- proto/methods_test.go | 26 +++++++ proto/testmessages_test.go | 125 ++++++++++++++++++++------------ reflect/protoreflect/methods.go | 1 + runtime/protoiface/methods.go | 7 +- 7 files changed, 185 insertions(+), 64 deletions(-) diff --git a/internal/impl/codec_field.go b/internal/impl/codec_field.go index f1f0671c..433dacb4 100644 --- a/internal/impl/codec_field.go +++ b/internal/impl/codec_field.go @@ -13,6 +13,7 @@ import ( "google.golang.org/protobuf/proto" pref "google.golang.org/protobuf/reflect/protoreflect" preg "google.golang.org/protobuf/reflect/protoregistry" + piface "google.golang.org/protobuf/runtime/protoiface" ) type errInvalidUTF8 struct{} @@ -227,10 +228,12 @@ func consumeMessageInfo(b []byte, p pointer, mi *MessageInfo, wtyp wire.Type, op if p.Elem().IsNil() { p.SetPointer(pointerOfValue(reflect.New(mi.GoReflectType.Elem()))) } - if _, err := mi.unmarshalPointer(v, p.Elem(), 0, opts); err != nil { + o, err := mi.unmarshalPointer(v, p.Elem(), 0, opts) + if err != nil { return out, err } out.n = n + out.initialized = o.initialized return out, nil } @@ -252,10 +255,14 @@ func consumeMessage(b []byte, m proto.Message, wtyp wire.Type, opts unmarshalOpt if n < 0 { return out, wire.ParseError(n) } - if err := opts.Options().Unmarshal(v, m); err != nil { + o, err := opts.Options().UnmarshalState(m, piface.UnmarshalInput{ + Buf: v, + }) + if err != nil { return out, err } out.n = n + out.initialized = o.Initialized return out, nil } @@ -395,8 +402,15 @@ func consumeGroup(b []byte, m proto.Message, num wire.Number, wtyp wire.Type, op if n < 0 { return out, wire.ParseError(n) } + o, err := opts.Options().UnmarshalState(m, piface.UnmarshalInput{ + Buf: b, + }) + if err != nil { + return out, err + } out.n = n - return out, opts.Options().Unmarshal(b, m) + out.initialized = o.Initialized + return out, nil } func makeMessageSliceFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs { @@ -469,11 +483,13 @@ func consumeMessageSliceInfo(b []byte, p pointer, mi *MessageInfo, wtyp wire.Typ } m := reflect.New(mi.GoReflectType.Elem()).Interface() mp := pointerOfIface(m) - if _, err := mi.unmarshalPointer(v, mp, 0, opts); err != nil { + o, err := mi.unmarshalPointer(v, mp, 0, opts) + if err != nil { return out, err } p.AppendPointerSlice(mp) out.n = n + out.initialized = o.initialized return out, nil } @@ -522,11 +538,15 @@ func consumeMessageSlice(b []byte, p pointer, goType reflect.Type, wtyp wire.Typ return out, wire.ParseError(n) } mp := reflect.New(goType.Elem()) - if err := opts.Options().Unmarshal(v, asMessage(mp)); err != nil { + o, err := opts.Options().UnmarshalState(asMessage(mp), piface.UnmarshalInput{ + Buf: v, + }) + if err != nil { return out, err } p.AppendPointerSlice(pointerOfValue(mp)) out.n = n + out.initialized = o.Initialized return out, nil } @@ -580,11 +600,15 @@ func consumeMessageSliceValue(b []byte, listv pref.Value, _ wire.Number, wtyp wi return pref.Value{}, out, wire.ParseError(n) } m := list.NewElement() - if err := opts.Options().Unmarshal(v, m.Message().Interface()); err != nil { + o, err := opts.Options().UnmarshalState(m.Message().Interface(), piface.UnmarshalInput{ + Buf: v, + }) + if err != nil { return pref.Value{}, out, err } list.Append(m) out.n = n + out.initialized = o.Initialized return listv, out, nil } @@ -642,11 +666,15 @@ func consumeGroupSliceValue(b []byte, listv pref.Value, num wire.Number, wtyp wi return pref.Value{}, out, wire.ParseError(n) } m := list.NewElement() - if err := opts.Options().Unmarshal(b, m.Message().Interface()); err != nil { + o, err := opts.Options().UnmarshalState(m.Message().Interface(), piface.UnmarshalInput{ + Buf: b, + }) + if err != nil { return pref.Value{}, out, err } list.Append(m) out.n = n + out.initialized = o.Initialized return listv, out, nil } @@ -728,11 +756,15 @@ func consumeGroupSlice(b []byte, p pointer, num wire.Number, wtyp wire.Type, goT return out, wire.ParseError(n) } mp := reflect.New(goType.Elem()) - if err := opts.Options().Unmarshal(b, asMessage(mp)); err != nil { + o, err := opts.Options().UnmarshalState(asMessage(mp), piface.UnmarshalInput{ + Buf: b, + }) + if err != nil { return out, err } p.AppendPointerSlice(pointerOfValue(mp)) out.n = n + out.initialized = o.Initialized return out, nil } diff --git a/internal/impl/codec_map.go b/internal/impl/codec_map.go index b69ee1aa..c8c09255 100644 --- a/internal/impl/codec_map.go +++ b/internal/impl/codec_map.go @@ -202,7 +202,13 @@ func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *map if n < 0 { return out, wire.ParseError(n) } - _, err = mapi.valMessageInfo.unmarshalPointer(v, pointerOfValue(val), 0, opts) + var o unmarshalOutput + o, err = mapi.valMessageInfo.unmarshalPointer(v, pointerOfValue(val), 0, opts) + if o.initialized { + // Consider this map item initialized so long as we see + // an initialized value. + out.initialized = true + } } if err == errUnknown { n = wire.ConsumeFieldValue(num, wtyp, b) diff --git a/internal/impl/decode.go b/internal/impl/decode.go index 5427317b..fc93525a 100644 --- a/internal/impl/decode.go +++ b/internal/impl/decode.go @@ -5,6 +5,8 @@ package impl import ( + "math/bits" + "google.golang.org/protobuf/internal/encoding/wire" "google.golang.org/protobuf/internal/errors" "google.golang.org/protobuf/internal/flags" @@ -58,7 +60,8 @@ func (o unmarshalOptions) DiscardUnknown() bool { return o.flags func (o unmarshalOptions) Resolver() preg.ExtensionTypeResolver { return o.resolver } type unmarshalOutput struct { - n int // number of bytes consumed + n int // number of bytes consumed + initialized bool } // unmarshal is protoreflect.Methods.Unmarshal. @@ -69,8 +72,10 @@ func (mi *MessageInfo) unmarshal(m pref.Message, in piface.UnmarshalInput, opts } else { p = m.(*messageReflectWrapper).pointer() } - _, err := mi.unmarshalPointer(in.Buf, p, 0, newUnmarshalOptions(opts)) - return piface.UnmarshalOutput{}, err + out, err := mi.unmarshalPointer(in.Buf, p, 0, newUnmarshalOptions(opts)) + return piface.UnmarshalOutput{ + Initialized: out.initialized, + }, err } // errUnknown is returned during unmarshaling to indicate a parse error that @@ -86,6 +91,8 @@ func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag wire.Numbe if flags.ProtoLegacy && mi.isMessageSet { return unmarshalMessageSet(mi, b, p, opts) } + initialized := true + var requiredMask uint64 var exts *map[int32]ExtensionField start := len(b) for len(b) > 0 { @@ -104,8 +111,8 @@ func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag wire.Numbe if num != groupTag { return out, errors.New("mismatching end group marker") } - out.n = start - len(b) - return out, nil + groupTag = 0 + break } var f *coderFieldInfo @@ -123,6 +130,12 @@ func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag wire.Numbe var o unmarshalOutput o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, opts) n = o.n + if reqi := f.validation.requiredIndex; reqi > 0 && err == nil { + requiredMask |= 1 << (reqi - 1) + } + if f.funcs.isInit != nil && !o.initialized { + initialized = false + } default: // Possible extension. if exts == nil && mi.extensionOffset.IsValid() { @@ -137,6 +150,9 @@ func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag wire.Numbe var o unmarshalOutput o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts) n = o.n + if !o.initialized { + initialized = false + } } if err != nil { if err != errUnknown { @@ -157,7 +173,13 @@ func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag wire.Numbe if groupTag != 0 { return out, errors.New("missing end group marker") } - out.n = start + if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) { + initialized = false + } + if initialized { + out.initialized = true + } + out.n = start - len(b) return out, nil } diff --git a/proto/methods_test.go b/proto/methods_test.go index 91019fa3..ee8d6f16 100644 --- a/proto/methods_test.go +++ b/proto/methods_test.go @@ -15,6 +15,7 @@ import ( "google.golang.org/protobuf/internal/impl" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/runtime/protoiface" legacypb "google.golang.org/protobuf/internal/testprotos/legacy" ) @@ -129,3 +130,28 @@ func TestSelfMarshalerWithDescriptor(t *testing.T) { t.Fatalf("proto.Marshal(%v) = %v, %v; want %v, nil", m, got, err, want) } } + +func TestDecodeFastIsInitialized(t *testing.T) { + for _, test := range testValidMessages { + if !test.checkFastInit { + continue + } + for _, message := range test.decodeTo { + t.Run(fmt.Sprintf("%s (%T)", test.desc, message), func(t *testing.T) { + m := message.ProtoReflect().New() + opts := proto.UnmarshalOptions{ + AllowPartial: true, + } + out, err := opts.UnmarshalState(m.Interface(), protoiface.UnmarshalInput{ + Buf: test.wire, + }) + if err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + if got, want := out.Initialized, !test.partial; got != want { + t.Errorf("out.Initialized = %v, want %v", got, want) + } + }) + } + } +} diff --git a/proto/testmessages_test.go b/proto/testmessages_test.go index 39ddeb30..8a7cc296 100644 --- a/proto/testmessages_test.go +++ b/proto/testmessages_test.go @@ -23,6 +23,7 @@ type testProto struct { wire []byte partial bool noEncode bool + checkFastInit bool validationStatus impl.ValidationStatus } @@ -1150,17 +1151,20 @@ var testValidMessages = []testProto{ }.Marshal(), }, { - desc: "required field in nil message unset", - partial: true, - decodeTo: []proto.Message{(*testpb.TestRequired)(nil)}, + desc: "required field in nil message unset", + checkFastInit: true, + partial: true, + decodeTo: []proto.Message{(*testpb.TestRequired)(nil)}, }, { - desc: "required int32 unset", - partial: true, - decodeTo: []proto.Message{&requiredpb.Int32{}}, + desc: "required int32 unset", + checkFastInit: true, + partial: true, + decodeTo: []proto.Message{&requiredpb.Int32{}}, }, { - desc: "required int32 set", + desc: "required int32 set", + checkFastInit: true, decodeTo: []proto.Message{&requiredpb.Int32{ V: proto.Int32(1), }}, @@ -1169,12 +1173,14 @@ var testValidMessages = []testProto{ }.Marshal(), }, { - desc: "required fixed32 unset", - partial: true, - decodeTo: []proto.Message{&requiredpb.Fixed32{}}, + desc: "required fixed32 unset", + checkFastInit: true, + partial: true, + decodeTo: []proto.Message{&requiredpb.Fixed32{}}, }, { - desc: "required fixed32 set", + desc: "required fixed32 set", + checkFastInit: true, decodeTo: []proto.Message{&requiredpb.Fixed32{ V: proto.Uint32(1), }}, @@ -1183,12 +1189,14 @@ var testValidMessages = []testProto{ }.Marshal(), }, { - desc: "required fixed64 unset", - partial: true, - decodeTo: []proto.Message{&requiredpb.Fixed64{}}, + desc: "required fixed64 unset", + checkFastInit: true, + partial: true, + decodeTo: []proto.Message{&requiredpb.Fixed64{}}, }, { - desc: "required fixed64 set", + desc: "required fixed64 set", + checkFastInit: true, decodeTo: []proto.Message{&requiredpb.Fixed64{ V: proto.Uint64(1), }}, @@ -1197,12 +1205,14 @@ var testValidMessages = []testProto{ }.Marshal(), }, { - desc: "required bytes unset", - partial: true, - decodeTo: []proto.Message{&requiredpb.Bytes{}}, + desc: "required bytes unset", + checkFastInit: true, + partial: true, + decodeTo: []proto.Message{&requiredpb.Bytes{}}, }, { - desc: "required bytes set", + desc: "required bytes set", + checkFastInit: true, decodeTo: []proto.Message{&requiredpb.Bytes{ V: []byte{}, }}, @@ -1211,8 +1221,9 @@ var testValidMessages = []testProto{ }.Marshal(), }, { - desc: "required field with incompatible wire type", - partial: true, + desc: "required field with incompatible wire type", + checkFastInit: true, + partial: true, decodeTo: []proto.Message{build( &testpb.TestRequired{}, unknown(pack.Message{ @@ -1224,8 +1235,9 @@ var testValidMessages = []testProto{ }.Marshal(), }, { - desc: "required field in optional message unset", - partial: true, + desc: "required field in optional message unset", + checkFastInit: true, + partial: true, decodeTo: []proto.Message{&testpb.TestRequiredForeign{ OptionalMessage: &testpb.TestRequired{}, }}, @@ -1234,7 +1246,8 @@ var testValidMessages = []testProto{ }.Marshal(), }, { - desc: "required field in optional message set", + desc: "required field in optional message set", + checkFastInit: true, decodeTo: []proto.Message{&testpb.TestRequiredForeign{ OptionalMessage: &testpb.TestRequired{ RequiredField: proto.Int32(1), @@ -1247,7 +1260,8 @@ var testValidMessages = []testProto{ }.Marshal(), }, { - desc: "required field in optional message set (split across multiple tags)", + desc: "required field in optional message set (split across multiple tags)", + checkFastInit: false, // fast init checks don't handle split messages decodeTo: []proto.Message{&testpb.TestRequiredForeign{ OptionalMessage: &testpb.TestRequired{ RequiredField: proto.Int32(1), @@ -1262,8 +1276,9 @@ var testValidMessages = []testProto{ validationStatus: impl.ValidationValidMaybeUninitalized, }, { - desc: "required field in repeated message unset", - partial: true, + desc: "required field in repeated message unset", + checkFastInit: true, + partial: true, decodeTo: []proto.Message{&testpb.TestRequiredForeign{ RepeatedMessage: []*testpb.TestRequired{ {RequiredField: proto.Int32(1)}, @@ -1278,7 +1293,8 @@ var testValidMessages = []testProto{ }.Marshal(), }, { - desc: "required field in repeated message set", + desc: "required field in repeated message set", + checkFastInit: true, decodeTo: []proto.Message{&testpb.TestRequiredForeign{ RepeatedMessage: []*testpb.TestRequired{ {RequiredField: proto.Int32(1)}, @@ -1295,8 +1311,9 @@ var testValidMessages = []testProto{ }.Marshal(), }, { - desc: "required field in map message unset", - partial: true, + desc: "required field in map message unset", + checkFastInit: true, + partial: true, decodeTo: []proto.Message{&testpb.TestRequiredForeign{ MapMessage: map[int32]*testpb.TestRequired{ 1: {RequiredField: proto.Int32(1)}, @@ -1317,8 +1334,9 @@ var testValidMessages = []testProto{ }.Marshal(), }, { - desc: "required field in absent map message value", - partial: true, + desc: "required field in absent map message value", + checkFastInit: true, + partial: true, decodeTo: []proto.Message{&testpb.TestRequiredForeign{ MapMessage: map[int32]*testpb.TestRequired{ 2: {}, @@ -1331,7 +1349,8 @@ var testValidMessages = []testProto{ }.Marshal(), }, { - desc: "required field in map message set", + desc: "required field in map message set", + checkFastInit: true, decodeTo: []proto.Message{&testpb.TestRequiredForeign{ MapMessage: map[int32]*testpb.TestRequired{ 1: {RequiredField: proto.Int32(1)}, @@ -1354,8 +1373,9 @@ var testValidMessages = []testProto{ }.Marshal(), }, { - desc: "required field in optional group unset", - partial: true, + desc: "required field in optional group unset", + checkFastInit: true, + partial: true, decodeTo: []proto.Message{&testpb.TestRequiredGroupFields{ Optionalgroup: &testpb.TestRequiredGroupFields_OptionalGroup{}, }}, @@ -1365,7 +1385,8 @@ var testValidMessages = []testProto{ }.Marshal(), }, { - desc: "required field in optional group set", + desc: "required field in optional group set", + checkFastInit: true, decodeTo: []proto.Message{&testpb.TestRequiredGroupFields{ Optionalgroup: &testpb.TestRequiredGroupFields_OptionalGroup{ A: proto.Int32(1), @@ -1378,8 +1399,9 @@ var testValidMessages = []testProto{ }.Marshal(), }, { - desc: "required field in repeated group unset", - partial: true, + desc: "required field in repeated group unset", + checkFastInit: true, + partial: true, decodeTo: []proto.Message{&testpb.TestRequiredGroupFields{ Repeatedgroup: []*testpb.TestRequiredGroupFields_RepeatedGroup{ {A: proto.Int32(1)}, @@ -1395,7 +1417,8 @@ var testValidMessages = []testProto{ }.Marshal(), }, { - desc: "required field in repeated group set", + desc: "required field in repeated group set", + checkFastInit: true, decodeTo: []proto.Message{&testpb.TestRequiredGroupFields{ Repeatedgroup: []*testpb.TestRequiredGroupFields_RepeatedGroup{ {A: proto.Int32(1)}, @@ -1412,8 +1435,9 @@ var testValidMessages = []testProto{ }.Marshal(), }, { - desc: "required field in oneof message unset", - partial: true, + desc: "required field in oneof message unset", + checkFastInit: true, + partial: true, decodeTo: []proto.Message{ &testpb.TestRequiredForeign{OneofField: &testpb.TestRequiredForeign_OneofMessage{ &testpb.TestRequired{}, @@ -1422,7 +1446,8 @@ var testValidMessages = []testProto{ wire: pack.Message{pack.Tag{4, pack.BytesType}, pack.LengthPrefix(pack.Message{})}.Marshal(), }, { - desc: "required field in oneof message set", + desc: "required field in oneof message set", + checkFastInit: true, decodeTo: []proto.Message{ &testpb.TestRequiredForeign{OneofField: &testpb.TestRequiredForeign_OneofMessage{ &testpb.TestRequired{ @@ -1435,8 +1460,9 @@ var testValidMessages = []testProto{ })}.Marshal(), }, { - desc: "required field in extension message unset", - partial: true, + desc: "required field in extension message unset", + checkFastInit: true, + partial: true, decodeTo: []proto.Message{build( &testpb.TestAllExtensions{}, extend(testpb.E_TestRequired_Single, &testpb.TestRequired{}), @@ -1446,7 +1472,8 @@ var testValidMessages = []testProto{ }.Marshal(), }, { - desc: "required field in extension message set", + desc: "required field in extension message set", + checkFastInit: true, decodeTo: []proto.Message{build( &testpb.TestAllExtensions{}, extend(testpb.E_TestRequired_Single, &testpb.TestRequired{ @@ -1460,8 +1487,9 @@ var testValidMessages = []testProto{ }.Marshal(), }, { - desc: "required field in repeated extension message unset", - partial: true, + desc: "required field in repeated extension message unset", + checkFastInit: true, + partial: true, decodeTo: []proto.Message{build( &testpb.TestAllExtensions{}, extend(testpb.E_TestRequired_Multi, []*testpb.TestRequired{ @@ -1477,7 +1505,8 @@ var testValidMessages = []testProto{ }.Marshal(), }, { - desc: "required field in repeated extension message set", + desc: "required field in repeated extension message set", + checkFastInit: true, decodeTo: []proto.Message{build( &testpb.TestAllExtensions{}, extend(testpb.E_TestRequired_Multi, []*testpb.TestRequired{ diff --git a/reflect/protoreflect/methods.go b/reflect/protoreflect/methods.go index d15e2bf8..fd4e07b3 100644 --- a/reflect/protoreflect/methods.go +++ b/reflect/protoreflect/methods.go @@ -44,6 +44,7 @@ type ( } unmarshalOutput = struct { pragma.NoUnkeyedLiterals + Initialized bool } unmarshalOptions = struct { pragma.NoUnkeyedLiterals diff --git a/runtime/protoiface/methods.go b/runtime/protoiface/methods.go index a79ed31c..d5a7677c 100644 --- a/runtime/protoiface/methods.go +++ b/runtime/protoiface/methods.go @@ -27,9 +27,11 @@ type Methods = struct { // Marshal writes the wire-format encoding of m to the provided buffer. // Size should be provided if a custom MarshalAppend is provided. + // It should not return an error for a partial message. Marshal func(m protoreflect.Message, in MarshalInput, opts MarshalOptions) (MarshalOutput, error) // Unmarshal parses the wire-format encoding of a message and merges the result to m. + // It should not reset the target message or return an error for a partial message. Unmarshal func(m protoreflect.Message, in UnmarshalInput, opts UnmarshalOptions) (UnmarshalOutput, error) // IsInitialized returns an error if any required fields in m are not set. @@ -82,7 +84,10 @@ type UnmarshalInput = struct { type UnmarshalOutput = struct { pragma.NoUnkeyedLiterals - // Contents available for future expansion. + // Initialized may be set on return if all required fields are known to be set. + // A value of false does not indicate that the message is uninitialized, only + // that its status could not be confirmed. + Initialized bool } // UnmarshalOptions configures the unmarshaler.