proto: add IsInitialized

Move all checks for required fields into a proto.IsInitialized function.

Initial testing makes me confident that we can provide a fast-path
implementation of IsInitialized which will perform more than
acceptably.  (In the degenerate-but-common case where a message
transitively contains no required fields, this check can be nearly
zero cost.)

Unifying checks into a single function provides consistent behavior
between the wire, text, and json codecs.

Performing the check after decoding eliminates the wire decoder bug
where a split message is incorrectly seen as missing required fields.

Performing the check after decoding also provides consistent and
arguably more correct behavior when the target message was partially
prepopulated.

Change-Id: I9478b7bebb263af00c0d9f66a1f26e31ff553522
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/170787
Reviewed-by: Herbie Ong <herbie@google.com>
This commit is contained in:
Damien Neil 2019-04-05 13:31:40 -07:00
parent 96c229ab14
commit 4686e239b6
11 changed files with 207 additions and 82 deletions

View File

@ -71,6 +71,10 @@ func (o UnmarshalOptions) Unmarshal(m proto.Message, b []byte) error {
if val.Type() != json.EOF {
return unexpectedJSONError{val}
}
if !o.AllowPartial {
nerr.Merge(proto.IsInitialized(m))
}
return nerr.E
}
@ -151,7 +155,6 @@ func (o UnmarshalOptions) unmarshalMessage(m pref.Message, skipTypeURL bool) err
// unmarshalFields unmarshals the fields into the given protoreflect.Message.
func (o UnmarshalOptions) unmarshalFields(m pref.Message, skipTypeURL bool) error {
var nerr errors.NonFatal
var reqNums set.Ints
var seenNums set.Ints
var seenOneofs set.Ints
@ -251,21 +254,6 @@ Loop:
if err := o.unmarshalSingular(knownFields, fd); !nerr.Merge(err) {
return errors.New("%v|%q: %v", fd.FullName(), name, err)
}
if !o.AllowPartial && cardinality == pref.Required {
reqNums.Set(num)
}
}
}
if !o.AllowPartial {
// Check for any missing required fields.
allReqNums := msgType.RequiredNumbers()
if reqNums.Len() != allReqNums.Len() {
for i := 0; i < allReqNums.Len(); i++ {
if num := allReqNums.Get(i); !reqNums.Has(uint64(num)) {
nerr.AppendRequiredNotSet(string(fieldDescs.ByNumber(num).FullName()))
}
}
}
}

View File

@ -63,6 +63,9 @@ func (o MarshalOptions) Marshal(m proto.Message) ([]byte, error) {
if !nerr.Merge(err) {
return nil, err
}
if !o.AllowPartial {
nerr.Merge(proto.IsInitialized(m))
}
return o.encoder.Bytes(), nerr.E
}
@ -95,10 +98,6 @@ func (o MarshalOptions) marshalFields(m pref.Message) error {
num := fd.Number()
if !knownFields.Has(num) {
if !o.AllowPartial && fd.Cardinality() == pref.Required {
// Treat unset required fields as a non-fatal error.
nerr.AppendRequiredNotSet(string(fd.FullName()))
}
continue
}

View File

@ -64,6 +64,10 @@ func (o UnmarshalOptions) Unmarshal(m proto.Message, b []byte) error {
return err
}
if !o.AllowPartial {
nerr.Merge(proto.IsInitialized(m))
}
return nerr.E
}
@ -102,7 +106,6 @@ func (o UnmarshalOptions) unmarshalMessage(tmsg [][2]text.Value, m pref.Message)
fieldDescs := msgType.Fields()
reservedNames := msgType.ReservedNames()
xtTypes := knownFields.ExtensionTypes()
var reqNums set.Ints
var seenNums set.Ints
var seenOneofs set.Ints
@ -176,25 +179,10 @@ func (o UnmarshalOptions) unmarshalMessage(tmsg [][2]text.Value, m pref.Message)
if err := o.unmarshalSingular(tval, fd, knownFields); !nerr.Merge(err) {
return err
}
if !o.AllowPartial && cardinality == pref.Required {
reqNums.Set(num)
}
seenNums.Set(num)
}
}
if !o.AllowPartial {
// Check for any missing required fields.
allReqNums := msgType.RequiredNumbers()
if reqNums.Len() != allReqNums.Len() {
for i := 0; i < allReqNums.Len(); i++ {
if num := allReqNums.Get(i); !reqNums.Has(uint64(num)) {
nerr.AppendRequiredNotSet(string(fieldDescs.ByNumber(num).FullName()))
}
}
}
}
return nerr.E
}

View File

@ -63,6 +63,9 @@ func (o MarshalOptions) Marshal(m proto.Message) ([]byte, error) {
if !nerr.Merge(err) {
return nil, err
}
if !o.AllowPartial {
nerr.Merge(proto.IsInitialized(m))
}
return b, nerr.E
}
@ -91,10 +94,6 @@ func (o MarshalOptions) marshalMessage(m pref.Message) (text.Value, error) {
num := fd.Number()
if !knownFields.Has(num) {
if !o.AllowPartial && fd.Cardinality() == pref.Required {
// Treat unset required fields as a non-fatal error.
nerr.AppendRequiredNotSet(string(fd.FullName()))
}
continue
}

View File

@ -42,10 +42,18 @@ func Unmarshal(b []byte, m Message) error {
// Unmarshal parses the wire-format message in b and places the result in m.
func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
// TODO: Reset m?
if err := o.unmarshalMessageFast(b, m); err != errInternalNoFast {
err := o.unmarshalMessageFast(b, m)
if err == errInternalNoFast {
err = o.unmarshalMessage(b, m.ProtoReflect())
}
var nerr errors.NonFatal
if !nerr.Merge(err) {
return err
}
return o.unmarshalMessage(b, m.ProtoReflect())
if !o.AllowPartial {
nerr.Merge(IsInitialized(m))
}
return nerr.E
}
func (o UnmarshalOptions) unmarshalMessageFast(b []byte, m Message) error {
@ -100,9 +108,6 @@ func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) err
}
b = b[tagLen+valLen:]
}
if !o.AllowPartial {
checkRequiredFields(m, &nerr)
}
return nerr.E
}
@ -204,9 +209,6 @@ func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp wire.Type, num wire.Number
if !haveVal {
switch valField.Kind() {
case protoreflect.GroupKind, protoreflect.MessageKind:
if !o.AllowPartial {
checkRequiredFields(val.Message(), &nerr)
}
default:
val = valField.Default()
}

View File

@ -944,23 +944,20 @@ var testProtos = []testProto{
}),
}.Marshal(),
},
// TODO: Handle this case.
/*
{
desc: "required field in optional message set (split across multiple tags)",
decodeTo: []proto.Message{&testpb.TestRequiredForeign{
OptionalMessage: &testpb.TestRequired{
RequiredField: scalar.Int32(1),
},
}},
wire: pack.Message{
pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
pack.Tag{1, pack.VarintType}, pack.Varint(1),
}),
}.Marshal(),
},
*/
{
desc: "required field in optional message set (split across multiple tags)",
decodeTo: []proto.Message{&testpb.TestRequiredForeign{
OptionalMessage: &testpb.TestRequired{
RequiredField: scalar.Int32(1),
},
}},
wire: pack.Message{
pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
pack.Tag{1, pack.VarintType}, pack.Varint(1),
}),
}.Marshal(),
},
{
desc: "required field in repeated message unset",
partial: true,

View File

@ -69,10 +69,18 @@ func (o MarshalOptions) Marshal(m Message) ([]byte, error) {
// MarshalAppend appends the wire-format encoding of m to b,
// returning the result.
func (o MarshalOptions) MarshalAppend(b []byte, m Message) ([]byte, error) {
if b, err := o.marshalMessageFast(b, m); err != errInternalNoFast {
b, err := o.marshalMessageFast(b, m)
if err == errInternalNoFast {
b, err = o.marshalMessage(b, m.ProtoReflect())
}
var nerr errors.NonFatal
if !nerr.Merge(err) {
return b, err
}
return o.marshalMessage(b, m.ProtoReflect())
if !o.AllowPartial {
nerr.Merge(IsInitialized(m))
}
return b, nerr.E
}
func (o MarshalOptions) marshalMessageFast(b []byte, m Message) ([]byte, error) {
@ -129,9 +137,6 @@ func (o MarshalOptions) marshalMessage(b []byte, m protoreflect.Message) ([]byte
b = append(b, raw...)
return true
})
if !o.AllowPartial {
checkRequiredFields(m, &nerr)
}
return b, nerr.E
}

94
proto/isinit.go Normal file
View File

@ -0,0 +1,94 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style.
// license that can be found in the LICENSE file.
package proto
import (
"bytes"
"fmt"
"github.com/golang/protobuf/v2/internal/errors"
pref "github.com/golang/protobuf/v2/reflect/protoreflect"
)
// IsInitialized returns an error if any required fields in m are not set.
func IsInitialized(m Message) error {
if methods := protoMethods(m); methods != nil && methods.IsInitialized != nil {
// TODO: Do we need a way to disable the fast path here?
//
// TODO: Should detailed information about missing
// fields always be provided by the slow-but-informative
// reflective implementation?
return methods.IsInitialized(m)
}
return isInitialized(m.ProtoReflect(), nil)
}
// IsInitialized returns an error if any required fields in m are not set.
func isInitialized(m pref.Message, stack []interface{}) error {
md := m.Type()
known := m.KnownFields()
fields := md.Fields()
for i, nums := 0, md.RequiredNumbers(); i < nums.Len(); i++ {
num := nums.Get(i)
if !known.Has(num) {
stack = append(stack, fields.ByNumber(num).Name())
return newRequiredNotSetError(stack)
}
}
var err error
known.Range(func(num pref.FieldNumber, v pref.Value) bool {
field := fields.ByNumber(num)
if field == nil {
field = known.ExtensionTypes().ByNumber(num)
}
if field == nil {
panic(fmt.Errorf("no descriptor for field %d in %q", num, md.FullName()))
}
// Look for fields containing a message: Messages, groups, and maps
// with a message or group value.
ft := field.MessageType()
if ft == nil {
return true
}
if field.IsMap() {
if ft.Fields().ByNumber(2).MessageType() == nil {
return true
}
}
// Recurse into the field
stack := append(stack, field.Name())
switch {
case field.IsMap():
v.Map().Range(func(key pref.MapKey, v pref.Value) bool {
stack := append(stack, "[", key, "].")
err = isInitialized(v.Message(), stack)
return err == nil
})
case field.Cardinality() == pref.Repeated:
for i, list := 0, v.List(); i < list.Len(); i++ {
stack := append(stack, "[", i, "].")
err = isInitialized(list.Get(i).Message(), stack)
if err != nil {
break
}
}
default:
stack := append(stack, ".")
err = isInitialized(v.Message(), stack)
}
return err == nil
})
return err
}
func newRequiredNotSetError(stack []interface{}) error {
var buf bytes.Buffer
for _, s := range stack {
fmt.Fprint(&buf, s)
}
var nerr errors.NonFatal
nerr.AppendRequiredNotSet(buf.String())
return nerr.E
}

60
proto/isinit_test.go Normal file
View File

@ -0,0 +1,60 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style.
// license that can be found in the LICENSE file.
package proto_test
import (
"fmt"
"testing"
"github.com/golang/protobuf/v2/internal/scalar"
"github.com/golang/protobuf/v2/proto"
testpb "github.com/golang/protobuf/v2/internal/testprotos/test"
)
func TestIsInitializedErrors(t *testing.T) {
for _, test := range []struct {
m proto.Message
want string
}{
{
&testpb.TestRequired{},
`proto: required field required_field not set`,
},
{
&testpb.TestRequiredForeign{
OptionalMessage: &testpb.TestRequired{},
},
`proto: required field optional_message.required_field not set`,
},
{
&testpb.TestRequiredForeign{
RepeatedMessage: []*testpb.TestRequired{
{RequiredField: scalar.Int32(1)},
{},
},
},
`proto: required field repeated_message[1].required_field not set`,
},
{
&testpb.TestRequiredForeign{
MapMessage: map[int32]*testpb.TestRequired{
1: {},
},
},
`proto: required field map_message[1].required_field not set`,
},
} {
err := proto.IsInitialized(test.m)
got := "<nil>"
if err != nil {
got = fmt.Sprintf("%q", err)
}
want := fmt.Sprintf("%q", test.want)
if got != want {
t.Errorf("IsInitialized(m):\n got: %v\nwant: %v\nMessage:\n%v", got, want, marshalText(test.m))
}
}
}

View File

@ -22,14 +22,3 @@ func protoMethods(m Message) *protoiface.Methods {
}
return nil
}
func checkRequiredFields(m protoreflect.Message, nerr *errors.NonFatal) {
req := m.Type().RequiredNumbers()
knownFields := m.KnownFields()
for i, reqLen := 0, req.Len(); i < reqLen; i++ {
num := req.Get(i)
if !knownFields.Has(num) {
nerr.AppendRequiredNotSet(string(m.Type().Fields().ByNumber(num).FullName()))
}
}
}

View File

@ -21,6 +21,7 @@ type Methods struct {
Flags MethodFlag
// MarshalAppend appends the wire-format encoding of m to b, returning the result.
// It does not perform required field checks.
MarshalAppend func(b []byte, m protoreflect.ProtoMessage, opts MarshalOptions) ([]byte, error)
// Size returns the size in bytes of the wire-format encoding of m.
@ -31,9 +32,12 @@ type Methods struct {
CachedSize func(m protoreflect.ProtoMessage) int
// Unmarshal parses the wire-format message in b and places the result in m.
// It does not reset m.
// It does not reset m or perform required field checks.
Unmarshal func(b []byte, m protoreflect.ProtoMessage, opts UnmarshalOptions) error
// IsInitialized returns an error if any required fields in m are not set.
IsInitialized func(m protoreflect.ProtoMessage) error
pragma.NoUnkeyedLiterals
}