mirror of
https://github.com/protocolbuffers/protobuf-go.git
synced 2025-04-17 02:42:35 +00:00
proto: add IsInitialized
Move all checks for required fields into a proto.IsInitialized function. Initial testing makes me confident that we can provide a fast-path implementation of IsInitialized which will perform more than acceptably. (In the degenerate-but-common case where a message transitively contains no required fields, this check can be nearly zero cost.) Unifying checks into a single function provides consistent behavior between the wire, text, and json codecs. Performing the check after decoding eliminates the wire decoder bug where a split message is incorrectly seen as missing required fields. Performing the check after decoding also provides consistent and arguably more correct behavior when the target message was partially prepopulated. Change-Id: I9478b7bebb263af00c0d9f66a1f26e31ff553522 Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/170787 Reviewed-by: Herbie Ong <herbie@google.com>
This commit is contained in:
parent
96c229ab14
commit
4686e239b6
@ -71,6 +71,10 @@ func (o UnmarshalOptions) Unmarshal(m proto.Message, b []byte) error {
|
||||
if val.Type() != json.EOF {
|
||||
return unexpectedJSONError{val}
|
||||
}
|
||||
|
||||
if !o.AllowPartial {
|
||||
nerr.Merge(proto.IsInitialized(m))
|
||||
}
|
||||
return nerr.E
|
||||
}
|
||||
|
||||
@ -151,7 +155,6 @@ func (o UnmarshalOptions) unmarshalMessage(m pref.Message, skipTypeURL bool) err
|
||||
// unmarshalFields unmarshals the fields into the given protoreflect.Message.
|
||||
func (o UnmarshalOptions) unmarshalFields(m pref.Message, skipTypeURL bool) error {
|
||||
var nerr errors.NonFatal
|
||||
var reqNums set.Ints
|
||||
var seenNums set.Ints
|
||||
var seenOneofs set.Ints
|
||||
|
||||
@ -251,21 +254,6 @@ Loop:
|
||||
if err := o.unmarshalSingular(knownFields, fd); !nerr.Merge(err) {
|
||||
return errors.New("%v|%q: %v", fd.FullName(), name, err)
|
||||
}
|
||||
if !o.AllowPartial && cardinality == pref.Required {
|
||||
reqNums.Set(num)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !o.AllowPartial {
|
||||
// Check for any missing required fields.
|
||||
allReqNums := msgType.RequiredNumbers()
|
||||
if reqNums.Len() != allReqNums.Len() {
|
||||
for i := 0; i < allReqNums.Len(); i++ {
|
||||
if num := allReqNums.Get(i); !reqNums.Has(uint64(num)) {
|
||||
nerr.AppendRequiredNotSet(string(fieldDescs.ByNumber(num).FullName()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -63,6 +63,9 @@ func (o MarshalOptions) Marshal(m proto.Message) ([]byte, error) {
|
||||
if !nerr.Merge(err) {
|
||||
return nil, err
|
||||
}
|
||||
if !o.AllowPartial {
|
||||
nerr.Merge(proto.IsInitialized(m))
|
||||
}
|
||||
return o.encoder.Bytes(), nerr.E
|
||||
}
|
||||
|
||||
@ -95,10 +98,6 @@ func (o MarshalOptions) marshalFields(m pref.Message) error {
|
||||
num := fd.Number()
|
||||
|
||||
if !knownFields.Has(num) {
|
||||
if !o.AllowPartial && fd.Cardinality() == pref.Required {
|
||||
// Treat unset required fields as a non-fatal error.
|
||||
nerr.AppendRequiredNotSet(string(fd.FullName()))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -64,6 +64,10 @@ func (o UnmarshalOptions) Unmarshal(m proto.Message, b []byte) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if !o.AllowPartial {
|
||||
nerr.Merge(proto.IsInitialized(m))
|
||||
}
|
||||
|
||||
return nerr.E
|
||||
}
|
||||
|
||||
@ -102,7 +106,6 @@ func (o UnmarshalOptions) unmarshalMessage(tmsg [][2]text.Value, m pref.Message)
|
||||
fieldDescs := msgType.Fields()
|
||||
reservedNames := msgType.ReservedNames()
|
||||
xtTypes := knownFields.ExtensionTypes()
|
||||
var reqNums set.Ints
|
||||
var seenNums set.Ints
|
||||
var seenOneofs set.Ints
|
||||
|
||||
@ -176,25 +179,10 @@ func (o UnmarshalOptions) unmarshalMessage(tmsg [][2]text.Value, m pref.Message)
|
||||
if err := o.unmarshalSingular(tval, fd, knownFields); !nerr.Merge(err) {
|
||||
return err
|
||||
}
|
||||
if !o.AllowPartial && cardinality == pref.Required {
|
||||
reqNums.Set(num)
|
||||
}
|
||||
seenNums.Set(num)
|
||||
}
|
||||
}
|
||||
|
||||
if !o.AllowPartial {
|
||||
// Check for any missing required fields.
|
||||
allReqNums := msgType.RequiredNumbers()
|
||||
if reqNums.Len() != allReqNums.Len() {
|
||||
for i := 0; i < allReqNums.Len(); i++ {
|
||||
if num := allReqNums.Get(i); !reqNums.Has(uint64(num)) {
|
||||
nerr.AppendRequiredNotSet(string(fieldDescs.ByNumber(num).FullName()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nerr.E
|
||||
}
|
||||
|
||||
|
@ -63,6 +63,9 @@ func (o MarshalOptions) Marshal(m proto.Message) ([]byte, error) {
|
||||
if !nerr.Merge(err) {
|
||||
return nil, err
|
||||
}
|
||||
if !o.AllowPartial {
|
||||
nerr.Merge(proto.IsInitialized(m))
|
||||
}
|
||||
return b, nerr.E
|
||||
}
|
||||
|
||||
@ -91,10 +94,6 @@ func (o MarshalOptions) marshalMessage(m pref.Message) (text.Value, error) {
|
||||
num := fd.Number()
|
||||
|
||||
if !knownFields.Has(num) {
|
||||
if !o.AllowPartial && fd.Cardinality() == pref.Required {
|
||||
// Treat unset required fields as a non-fatal error.
|
||||
nerr.AppendRequiredNotSet(string(fd.FullName()))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -42,10 +42,18 @@ func Unmarshal(b []byte, m Message) error {
|
||||
// Unmarshal parses the wire-format message in b and places the result in m.
|
||||
func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
|
||||
// TODO: Reset m?
|
||||
if err := o.unmarshalMessageFast(b, m); err != errInternalNoFast {
|
||||
err := o.unmarshalMessageFast(b, m)
|
||||
if err == errInternalNoFast {
|
||||
err = o.unmarshalMessage(b, m.ProtoReflect())
|
||||
}
|
||||
var nerr errors.NonFatal
|
||||
if !nerr.Merge(err) {
|
||||
return err
|
||||
}
|
||||
return o.unmarshalMessage(b, m.ProtoReflect())
|
||||
if !o.AllowPartial {
|
||||
nerr.Merge(IsInitialized(m))
|
||||
}
|
||||
return nerr.E
|
||||
}
|
||||
|
||||
func (o UnmarshalOptions) unmarshalMessageFast(b []byte, m Message) error {
|
||||
@ -100,9 +108,6 @@ func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) err
|
||||
}
|
||||
b = b[tagLen+valLen:]
|
||||
}
|
||||
if !o.AllowPartial {
|
||||
checkRequiredFields(m, &nerr)
|
||||
}
|
||||
return nerr.E
|
||||
}
|
||||
|
||||
@ -204,9 +209,6 @@ func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp wire.Type, num wire.Number
|
||||
if !haveVal {
|
||||
switch valField.Kind() {
|
||||
case protoreflect.GroupKind, protoreflect.MessageKind:
|
||||
if !o.AllowPartial {
|
||||
checkRequiredFields(val.Message(), &nerr)
|
||||
}
|
||||
default:
|
||||
val = valField.Default()
|
||||
}
|
||||
|
@ -944,23 +944,20 @@ var testProtos = []testProto{
|
||||
}),
|
||||
}.Marshal(),
|
||||
},
|
||||
// TODO: Handle this case.
|
||||
/*
|
||||
{
|
||||
desc: "required field in optional message set (split across multiple tags)",
|
||||
decodeTo: []proto.Message{&testpb.TestRequiredForeign{
|
||||
OptionalMessage: &testpb.TestRequired{
|
||||
RequiredField: scalar.Int32(1),
|
||||
},
|
||||
}},
|
||||
wire: pack.Message{
|
||||
pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
|
||||
pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
|
||||
pack.Tag{1, pack.VarintType}, pack.Varint(1),
|
||||
}),
|
||||
}.Marshal(),
|
||||
},
|
||||
*/
|
||||
{
|
||||
desc: "required field in optional message set (split across multiple tags)",
|
||||
decodeTo: []proto.Message{&testpb.TestRequiredForeign{
|
||||
OptionalMessage: &testpb.TestRequired{
|
||||
RequiredField: scalar.Int32(1),
|
||||
},
|
||||
}},
|
||||
wire: pack.Message{
|
||||
pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
|
||||
pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
|
||||
pack.Tag{1, pack.VarintType}, pack.Varint(1),
|
||||
}),
|
||||
}.Marshal(),
|
||||
},
|
||||
{
|
||||
desc: "required field in repeated message unset",
|
||||
partial: true,
|
||||
|
@ -69,10 +69,18 @@ func (o MarshalOptions) Marshal(m Message) ([]byte, error) {
|
||||
// MarshalAppend appends the wire-format encoding of m to b,
|
||||
// returning the result.
|
||||
func (o MarshalOptions) MarshalAppend(b []byte, m Message) ([]byte, error) {
|
||||
if b, err := o.marshalMessageFast(b, m); err != errInternalNoFast {
|
||||
b, err := o.marshalMessageFast(b, m)
|
||||
if err == errInternalNoFast {
|
||||
b, err = o.marshalMessage(b, m.ProtoReflect())
|
||||
}
|
||||
var nerr errors.NonFatal
|
||||
if !nerr.Merge(err) {
|
||||
return b, err
|
||||
}
|
||||
return o.marshalMessage(b, m.ProtoReflect())
|
||||
if !o.AllowPartial {
|
||||
nerr.Merge(IsInitialized(m))
|
||||
}
|
||||
return b, nerr.E
|
||||
}
|
||||
|
||||
func (o MarshalOptions) marshalMessageFast(b []byte, m Message) ([]byte, error) {
|
||||
@ -129,9 +137,6 @@ func (o MarshalOptions) marshalMessage(b []byte, m protoreflect.Message) ([]byte
|
||||
b = append(b, raw...)
|
||||
return true
|
||||
})
|
||||
if !o.AllowPartial {
|
||||
checkRequiredFields(m, &nerr)
|
||||
}
|
||||
return b, nerr.E
|
||||
}
|
||||
|
||||
|
94
proto/isinit.go
Normal file
94
proto/isinit.go
Normal file
@ -0,0 +1,94 @@
|
||||
// Copyright 2019 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style.
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/golang/protobuf/v2/internal/errors"
|
||||
pref "github.com/golang/protobuf/v2/reflect/protoreflect"
|
||||
)
|
||||
|
||||
// IsInitialized returns an error if any required fields in m are not set.
|
||||
func IsInitialized(m Message) error {
|
||||
if methods := protoMethods(m); methods != nil && methods.IsInitialized != nil {
|
||||
// TODO: Do we need a way to disable the fast path here?
|
||||
//
|
||||
// TODO: Should detailed information about missing
|
||||
// fields always be provided by the slow-but-informative
|
||||
// reflective implementation?
|
||||
return methods.IsInitialized(m)
|
||||
}
|
||||
return isInitialized(m.ProtoReflect(), nil)
|
||||
}
|
||||
|
||||
// IsInitialized returns an error if any required fields in m are not set.
|
||||
func isInitialized(m pref.Message, stack []interface{}) error {
|
||||
md := m.Type()
|
||||
known := m.KnownFields()
|
||||
fields := md.Fields()
|
||||
for i, nums := 0, md.RequiredNumbers(); i < nums.Len(); i++ {
|
||||
num := nums.Get(i)
|
||||
if !known.Has(num) {
|
||||
stack = append(stack, fields.ByNumber(num).Name())
|
||||
return newRequiredNotSetError(stack)
|
||||
}
|
||||
}
|
||||
var err error
|
||||
known.Range(func(num pref.FieldNumber, v pref.Value) bool {
|
||||
field := fields.ByNumber(num)
|
||||
if field == nil {
|
||||
field = known.ExtensionTypes().ByNumber(num)
|
||||
}
|
||||
if field == nil {
|
||||
panic(fmt.Errorf("no descriptor for field %d in %q", num, md.FullName()))
|
||||
}
|
||||
// Look for fields containing a message: Messages, groups, and maps
|
||||
// with a message or group value.
|
||||
ft := field.MessageType()
|
||||
if ft == nil {
|
||||
return true
|
||||
}
|
||||
if field.IsMap() {
|
||||
if ft.Fields().ByNumber(2).MessageType() == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// Recurse into the field
|
||||
stack := append(stack, field.Name())
|
||||
switch {
|
||||
case field.IsMap():
|
||||
v.Map().Range(func(key pref.MapKey, v pref.Value) bool {
|
||||
stack := append(stack, "[", key, "].")
|
||||
err = isInitialized(v.Message(), stack)
|
||||
return err == nil
|
||||
})
|
||||
case field.Cardinality() == pref.Repeated:
|
||||
for i, list := 0, v.List(); i < list.Len(); i++ {
|
||||
stack := append(stack, "[", i, "].")
|
||||
err = isInitialized(list.Get(i).Message(), stack)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
default:
|
||||
stack := append(stack, ".")
|
||||
err = isInitialized(v.Message(), stack)
|
||||
}
|
||||
return err == nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func newRequiredNotSetError(stack []interface{}) error {
|
||||
var buf bytes.Buffer
|
||||
for _, s := range stack {
|
||||
fmt.Fprint(&buf, s)
|
||||
}
|
||||
var nerr errors.NonFatal
|
||||
nerr.AppendRequiredNotSet(buf.String())
|
||||
return nerr.E
|
||||
}
|
60
proto/isinit_test.go
Normal file
60
proto/isinit_test.go
Normal file
@ -0,0 +1,60 @@
|
||||
// Copyright 2019 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style.
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package proto_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/protobuf/v2/internal/scalar"
|
||||
"github.com/golang/protobuf/v2/proto"
|
||||
|
||||
testpb "github.com/golang/protobuf/v2/internal/testprotos/test"
|
||||
)
|
||||
|
||||
func TestIsInitializedErrors(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
m proto.Message
|
||||
want string
|
||||
}{
|
||||
{
|
||||
&testpb.TestRequired{},
|
||||
`proto: required field required_field not set`,
|
||||
},
|
||||
{
|
||||
&testpb.TestRequiredForeign{
|
||||
OptionalMessage: &testpb.TestRequired{},
|
||||
},
|
||||
`proto: required field optional_message.required_field not set`,
|
||||
},
|
||||
{
|
||||
&testpb.TestRequiredForeign{
|
||||
RepeatedMessage: []*testpb.TestRequired{
|
||||
{RequiredField: scalar.Int32(1)},
|
||||
{},
|
||||
},
|
||||
},
|
||||
`proto: required field repeated_message[1].required_field not set`,
|
||||
},
|
||||
{
|
||||
&testpb.TestRequiredForeign{
|
||||
MapMessage: map[int32]*testpb.TestRequired{
|
||||
1: {},
|
||||
},
|
||||
},
|
||||
`proto: required field map_message[1].required_field not set`,
|
||||
},
|
||||
} {
|
||||
err := proto.IsInitialized(test.m)
|
||||
got := "<nil>"
|
||||
if err != nil {
|
||||
got = fmt.Sprintf("%q", err)
|
||||
}
|
||||
want := fmt.Sprintf("%q", test.want)
|
||||
if got != want {
|
||||
t.Errorf("IsInitialized(m):\n got: %v\nwant: %v\nMessage:\n%v", got, want, marshalText(test.m))
|
||||
}
|
||||
}
|
||||
}
|
@ -22,14 +22,3 @@ func protoMethods(m Message) *protoiface.Methods {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkRequiredFields(m protoreflect.Message, nerr *errors.NonFatal) {
|
||||
req := m.Type().RequiredNumbers()
|
||||
knownFields := m.KnownFields()
|
||||
for i, reqLen := 0, req.Len(); i < reqLen; i++ {
|
||||
num := req.Get(i)
|
||||
if !knownFields.Has(num) {
|
||||
nerr.AppendRequiredNotSet(string(m.Type().Fields().ByNumber(num).FullName()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -21,6 +21,7 @@ type Methods struct {
|
||||
Flags MethodFlag
|
||||
|
||||
// MarshalAppend appends the wire-format encoding of m to b, returning the result.
|
||||
// It does not perform required field checks.
|
||||
MarshalAppend func(b []byte, m protoreflect.ProtoMessage, opts MarshalOptions) ([]byte, error)
|
||||
|
||||
// Size returns the size in bytes of the wire-format encoding of m.
|
||||
@ -31,9 +32,12 @@ type Methods struct {
|
||||
CachedSize func(m protoreflect.ProtoMessage) int
|
||||
|
||||
// Unmarshal parses the wire-format message in b and places the result in m.
|
||||
// It does not reset m.
|
||||
// It does not reset m or perform required field checks.
|
||||
Unmarshal func(b []byte, m protoreflect.ProtoMessage, opts UnmarshalOptions) error
|
||||
|
||||
// IsInitialized returns an error if any required fields in m are not set.
|
||||
IsInitialized func(m protoreflect.ProtoMessage) error
|
||||
|
||||
pragma.NoUnkeyedLiterals
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user