From 54a0a0476acc1f76f3aa6b788ca33fa52a7ca76f Mon Sep 17 00:00:00 2001 From: Damien Neil <dneil@google.com> Date: Wed, 8 Jan 2020 17:53:16 -0800 Subject: [PATCH] internal/impl: check for required fields in missing map value If a map value is a message with required fields, the validator should note that it is uninitialized if a map item contains no value. In this case, the value is an empty message which obviously does not have the required field set. Change-Id: I7698e60765e3c95478f293e121bba3ad7fc88e27 Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/213900 Reviewed-by: Joe Tsai <joetsai@google.com> --- internal/impl/validate.go | 19 ++++++++++++++----- proto/testmessages_test.go | 14 ++++++++++++++ 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/internal/impl/validate.go b/internal/impl/validate.go index e6b3d233..5f3c3f5a 100644 --- a/internal/impl/validate.go +++ b/internal/impl/validate.go @@ -266,6 +266,7 @@ State: case 2: vi.typ = st.valType vi.mi = st.mi + vi.requiredIndex = 1 } default: var f *coderFieldInfo @@ -436,15 +437,23 @@ State: } b = st.tail PopState: + numRequiredFields := 0 switch st.typ { case validationTypeMessage, validationTypeGroup: - // If there are more than 64 required fields, this check will - // always fail and we will report that the message is potentially - // uninitialized. - if st.mi.numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != int(st.mi.numRequiredFields) { - initialized = false + numRequiredFields = int(st.mi.numRequiredFields) + case validationTypeMap: + // If this is a map field with a message value that contains + // required fields, require that the value be present. + if st.mi != nil && st.mi.numRequiredFields > 0 { + numRequiredFields = 1 } } + // If there are more than 64 required fields, this check will + // always fail and we will report that the message is potentially + // uninitialized. + if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields { + initialized = false + } states = states[:len(states)-1] } if !initialized { diff --git a/proto/testmessages_test.go b/proto/testmessages_test.go index 8c62dbb9..de85421c 100644 --- a/proto/testmessages_test.go +++ b/proto/testmessages_test.go @@ -1270,6 +1270,20 @@ var testValidMessages = []testProto{ }), }.Marshal(), }, + { + desc: "required field in absent map message value", + partial: true, + decodeTo: []proto.Message{&testpb.TestRequiredForeign{ + MapMessage: map[int32]*testpb.TestRequired{ + 2: {}, + }, + }}, + wire: pack.Message{ + pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{ + pack.Tag{1, pack.VarintType}, pack.Varint(2), + }), + }.Marshal(), + }, { desc: "required field in map message set", decodeTo: []proto.Message{&testpb.TestRequiredForeign{