From 3d0706ac2e40495e56815f8d86b07077bc207eb3 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Tue, 9 Jul 2019 11:40:49 -0700 Subject: [PATCH] proto, internal/impl: make IsInitialized more consistent Make the fast-path and slow-path versions of IsInitialized report exactly the same errors: An errors.RequiredNotSet containing the full name of one of the unset required fields. Bugfix: Fast-path IsInitialized on a nil message reports an error only when the message directly contains required fields. Bugfix: Include fast-path IsInitialized in legacy messageIfaceWrapper. Fixes golang/protobuf#887 Change-Id: Ia5e4b386f8c23f6f855d995f4a098b1338acbae3 Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/185397 Reviewed-by: Joe Tsai --- internal/impl/isinit.go | 15 ++++++++------- internal/impl/message.go | 4 ++++ proto/decode_test.go | 13 +++++++++++++ proto/doc.go | 9 +++++++++ proto/isinit.go | 37 +++++++------------------------------ proto/isinit_test.go | 8 ++++---- 6 files changed, 45 insertions(+), 41 deletions(-) create mode 100644 proto/doc.go diff --git a/internal/impl/isinit.go b/internal/impl/isinit.go index 29f9ad8f..a533cf41 100644 --- a/internal/impl/isinit.go +++ b/internal/impl/isinit.go @@ -7,15 +7,11 @@ package impl import ( "sync" + "google.golang.org/protobuf/internal/errors" "google.golang.org/protobuf/proto" pref "google.golang.org/protobuf/reflect/protoreflect" ) -type errRequiredNotSet struct{} - -func (errRequiredNotSet) Error() string { return "proto: required field not set" } -func (errRequiredNotSet) RequiredNotSet() bool { return true } - func (mi *MessageInfo) isInitialized(msg proto.Message) error { return mi.isInitializedPointer(pointerOfIface(msg)) } @@ -26,7 +22,12 @@ func (mi *MessageInfo) isInitializedPointer(p pointer) error { return nil } if p.IsNil() { - return errRequiredNotSet{} + for _, f := range mi.orderedCoderFields { + if f.isRequired { + return errors.RequiredNotSet(string(mi.PBType.Fields().ByNumber(f.num).FullName())) + } + } + return nil } if mi.extensionOffset.IsValid() { e := p.Apply(mi.extensionOffset).Extensions() @@ -41,7 +42,7 @@ func (mi *MessageInfo) isInitializedPointer(p pointer) error { fptr := p.Apply(f.offset) if f.isPointer && fptr.Elem().IsNil() { if f.isRequired { - return errRequiredNotSet{} + return errors.RequiredNotSet(string(mi.PBType.Fields().ByNumber(f.num).FullName())) } continue } diff --git a/internal/impl/message.go b/internal/impl/message.go index ec9be4ec..e9a38c21 100644 --- a/internal/impl/message.go +++ b/internal/impl/message.go @@ -517,6 +517,7 @@ func (m *messageIfaceWrapper) XXX_Methods() *piface.Methods { Flags: piface.MethodFlagDeterministicMarshal, MarshalAppend: m.marshalAppend, Size: m.size, + IsInitialized: m.isInitialized, } } func (m *messageIfaceWrapper) ProtoUnwrap() interface{} { @@ -528,3 +529,6 @@ func (m *messageIfaceWrapper) marshalAppend(b []byte, _ pref.ProtoMessage, opts func (m *messageIfaceWrapper) size(msg pref.ProtoMessage) (size int) { return m.mi.sizePointer(m.p, 0) } +func (m *messageIfaceWrapper) isInitialized(msg pref.ProtoMessage) error { + return m.mi.isInitializedPointer(m.p) +} diff --git a/proto/decode_test.go b/proto/decode_test.go index 3b919aa9..03f337a0 100644 --- a/proto/decode_test.go +++ b/proto/decode_test.go @@ -951,6 +951,11 @@ var testProtos = []testProto{ }), }.Marshal(), }, + { + desc: "required field in nil message unset", + partial: true, + decodeTo: []proto.Message{(*testpb.TestRequired)(nil)}, + }, { desc: "required field unset", partial: true, @@ -1223,6 +1228,14 @@ var testProtos = []testProto{ }), }.Marshal(), }, + { + desc: "nil messages", + decodeTo: []proto.Message{ + (*testpb.TestAllTypes)(nil), + (*test3pb.TestAllTypes)(nil), + (*testpb.TestAllExtensions)(nil), + }, + }, { desc: "legacy", partial: true, diff --git a/proto/doc.go b/proto/doc.go new file mode 100644 index 00000000..0638dd5a --- /dev/null +++ b/proto/doc.go @@ -0,0 +1,9 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package proto needs documentation. +// +// TODO: Add package docs. +// TODO: Document guarantees (or lack thereof) made about errors. +package proto diff --git a/proto/isinit.go b/proto/isinit.go index 1bb29bdf..e0fd947b 100644 --- a/proto/isinit.go +++ b/proto/isinit.go @@ -5,9 +5,6 @@ package proto import ( - "bytes" - "fmt" - "google.golang.org/protobuf/internal/errors" pref "google.golang.org/protobuf/reflect/protoreflect" ) @@ -15,66 +12,46 @@ import ( // IsInitialized returns an error if any required fields in m are not set. func IsInitialized(m Message) error { if methods := protoMethods(m); methods != nil && methods.IsInitialized != nil { - if err := methods.IsInitialized(m); err == nil { - return nil - } - // Fall through to the slow path, since the fast-path - // implementation doesn't produce nice errors. - // - // TODO: Consider producing better errors from the fast path. + return methods.IsInitialized(m) } - return isInitialized(m.ProtoReflect(), nil) + return isInitialized(m.ProtoReflect()) } // IsInitialized returns an error if any required fields in m are not set. -func isInitialized(m pref.Message, stack []interface{}) error { +func isInitialized(m pref.Message) error { md := m.Descriptor() fds := md.Fields() for i, nums := 0, md.RequiredNumbers(); i < nums.Len(); i++ { fd := fds.ByNumber(nums.Get(i)) if !m.Has(fd) { - stack = append(stack, fd.Name()) - return newRequiredNotSetError(stack) + return errors.RequiredNotSet(string(fd.FullName())) } } var err error m.Range(func(fd pref.FieldDescriptor, v pref.Value) bool { - // Recurse into fields containing message values. - stack := append(stack, fd.Name()) switch { case fd.IsList(): if fd.Message() == nil { return true } for i, list := 0, v.List(); i < list.Len() && err == nil; i++ { - stack := append(stack, "[", i, "].") - err = isInitialized(list.Get(i).Message(), stack) + err = IsInitialized(list.Get(i).Message().Interface()) } case fd.IsMap(): if fd.MapValue().Message() == nil { return true } v.Map().Range(func(key pref.MapKey, v pref.Value) bool { - stack := append(stack, "[", key, "].") - err = isInitialized(v.Message(), stack) + err = IsInitialized(v.Message().Interface()) return err == nil }) default: if fd.Message() == nil { return true } - stack := append(stack, ".") - err = isInitialized(v.Message(), stack) + err = IsInitialized(v.Message().Interface()) } return err == nil }) return err } - -func newRequiredNotSetError(stack []interface{}) error { - var buf bytes.Buffer - for _, s := range stack { - fmt.Fprint(&buf, s) - } - return errors.RequiredNotSet(buf.String()) -} diff --git a/proto/isinit_test.go b/proto/isinit_test.go index 232202f4..1edbfb4c 100644 --- a/proto/isinit_test.go +++ b/proto/isinit_test.go @@ -21,13 +21,13 @@ func TestIsInitializedErrors(t *testing.T) { }{ { &testpb.TestRequired{}, - `proto: required field required_field not set`, + `proto: required field goproto.proto.test.TestRequired.required_field not set`, }, { &testpb.TestRequiredForeign{ OptionalMessage: &testpb.TestRequired{}, }, - `proto: required field optional_message.required_field not set`, + `proto: required field goproto.proto.test.TestRequired.required_field not set`, }, { &testpb.TestRequiredForeign{ @@ -36,7 +36,7 @@ func TestIsInitializedErrors(t *testing.T) { {}, }, }, - `proto: required field repeated_message[1].required_field not set`, + `proto: required field goproto.proto.test.TestRequired.required_field not set`, }, { &testpb.TestRequiredForeign{ @@ -44,7 +44,7 @@ func TestIsInitializedErrors(t *testing.T) { 1: {}, }, }, - `proto: required field map_message[1].required_field not set`, + `proto: required field goproto.proto.test.TestRequired.required_field not set`, }, } { err := proto.IsInitialized(test.m)