From 54a0a0476acc1f76f3aa6b788ca33fa52a7ca76f Mon Sep 17 00:00:00 2001
From: Damien Neil <dneil@google.com>
Date: Wed, 8 Jan 2020 17:53:16 -0800
Subject: [PATCH] internal/impl: check for required fields in missing map value

If a map value is a message with required fields, the validator should
note that it is uninitialized if a map item contains no value. In this
case, the value is an empty message which obviously does not have the
required field set.

Change-Id: I7698e60765e3c95478f293e121bba3ad7fc88e27
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/213900
Reviewed-by: Joe Tsai <joetsai@google.com>
---
 internal/impl/validate.go  | 19 ++++++++++++++-----
 proto/testmessages_test.go | 14 ++++++++++++++
 2 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/internal/impl/validate.go b/internal/impl/validate.go
index e6b3d233..5f3c3f5a 100644
--- a/internal/impl/validate.go
+++ b/internal/impl/validate.go
@@ -266,6 +266,7 @@ State:
 				case 2:
 					vi.typ = st.valType
 					vi.mi = st.mi
+					vi.requiredIndex = 1
 				}
 			default:
 				var f *coderFieldInfo
@@ -436,15 +437,23 @@ State:
 		}
 		b = st.tail
 	PopState:
+		numRequiredFields := 0
 		switch st.typ {
 		case validationTypeMessage, validationTypeGroup:
-			// If there are more than 64 required fields, this check will
-			// always fail and we will report that the message is potentially
-			// uninitialized.
-			if st.mi.numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != int(st.mi.numRequiredFields) {
-				initialized = false
+			numRequiredFields = int(st.mi.numRequiredFields)
+		case validationTypeMap:
+			// If this is a map field with a message value that contains
+			// required fields, require that the value be present.
+			if st.mi != nil && st.mi.numRequiredFields > 0 {
+				numRequiredFields = 1
 			}
 		}
+		// If there are more than 64 required fields, this check will
+		// always fail and we will report that the message is potentially
+		// uninitialized.
+		if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
+			initialized = false
+		}
 		states = states[:len(states)-1]
 	}
 	if !initialized {
diff --git a/proto/testmessages_test.go b/proto/testmessages_test.go
index 8c62dbb9..de85421c 100644
--- a/proto/testmessages_test.go
+++ b/proto/testmessages_test.go
@@ -1270,6 +1270,20 @@ var testValidMessages = []testProto{
 			}),
 		}.Marshal(),
 	},
+	{
+		desc:    "required field in absent map message value",
+		partial: true,
+		decodeTo: []proto.Message{&testpb.TestRequiredForeign{
+			MapMessage: map[int32]*testpb.TestRequired{
+				2: {},
+			},
+		}},
+		wire: pack.Message{
+			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(2),
+			}),
+		}.Marshal(),
+	},
 	{
 		desc: "required field in map message set",
 		decodeTo: []proto.Message{&testpb.TestRequiredForeign{