diff --git a/cmd/protoc-gen-go/internal_gengo/well_known_types.go b/cmd/protoc-gen-go/internal_gengo/well_known_types.go index 9a1b7bdf..dbaa529c 100644 --- a/cmd/protoc-gen-go/internal_gengo/well_known_types.go +++ b/cmd/protoc-gen-go/internal_gengo/well_known_types.go @@ -971,9 +971,12 @@ func genMessageKnownFunctions(g *protogen.GeneratedFile, f *fileInfo, m *message g.P() g.P(" // Identify the next message to search within.") g.P(" md = fd.Message() // may be nil") - g.P(" if fd.IsMap() {") - g.P(" md = fd.MapValue().Message() // may be nil") + g.P() + g.P(" // Repeated fields are only allowed at the last postion.") + g.P(" if fd.IsList() || fd.IsMap() {") + g.P(" md = nil") g.P(" }") + g.P() g.P(" return true") g.P(" }) {") g.P(" return i") diff --git a/types/known/fieldmaskpb/field_mask.pb.go b/types/known/fieldmaskpb/field_mask.pb.go index 6a8d872c..a852befe 100644 --- a/types/known/fieldmaskpb/field_mask.pb.go +++ b/types/known/fieldmaskpb/field_mask.pb.go @@ -393,9 +393,12 @@ func numValidPaths(m proto.Message, paths []string) int { // Identify the next message to search within. md = fd.Message() // may be nil - if fd.IsMap() { - md = fd.MapValue().Message() // may be nil + + // Repeated fields are only allowed at the last postion. + if fd.IsList() || fd.IsMap() { + md = nil } + return true }) { return i diff --git a/types/known/fieldmaskpb/field_mask_test.go b/types/known/fieldmaskpb/field_mask_test.go index 6d21711f..19756c50 100644 --- a/types/known/fieldmaskpb/field_mask_test.go +++ b/types/known/fieldmaskpb/field_mask_test.go @@ -37,7 +37,8 @@ func TestAppend(t *testing.T) { }, { inMessage: (*testpb.TestAllTypes)(nil), inPaths: []string{"optional_int32", "OptionalGroup.optional_nested_message", "map_uint32_uint32", "map_string_nested_message.corecursive", "oneof_bool"}, - wantPaths: []string{"optional_int32", "OptionalGroup.optional_nested_message", "map_uint32_uint32", "map_string_nested_message.corecursive", "oneof_bool"}, + wantPaths: []string{"optional_int32", "OptionalGroup.optional_nested_message", "map_uint32_uint32"}, + wantError: cmpopts.AnyError, }, { inMessage: (*testpb.TestAllTypes)(nil), inPaths: []string{"optional_nested_message", "optional_nested_message.corecursive", "optional_nested_message.corecursive.optional_nested_message", "optional_nested_message.corecursive.optional_nested_message.corecursive"}, @@ -194,3 +195,144 @@ func TestNormalize(t *testing.T) { }) } } + +func TestIsValid(t *testing.T) { + tests := []struct { + message proto.Message + paths []string + want bool + }{{ + message: (*testpb.TestAllTypes)(nil), + paths: []string{"no_such_field"}, + want: false, + }, { + message: (*testpb.TestAllTypes)(nil), + paths: []string{""}, + want: false, + }, { + message: (*testpb.TestAllTypes)(nil), + paths: []string{ + "optional_int32", + "optional_int32", + "optional_int64", + "optional_uint32", + "optional_uint64", + "optional_sint32", + "optional_sint64", + "optional_fixed32", + "optional_fixed64", + "optional_sfixed32", + "optional_sfixed64", + "optional_float", + "optional_double", + "optional_bool", + "optional_string", + "optional_bytes", + "OptionalGroup", + "optional_nested_message", + "optional_foreign_message", + "optional_import_message", + "optional_nested_enum", + "optional_foreign_enum", + "optional_import_enum", + "repeated_int32", + "repeated_int64", + "repeated_uint32", + "repeated_uint64", + "repeated_sint32", + "repeated_sint64", + "repeated_fixed32", + "repeated_fixed64", + "repeated_sfixed32", + "repeated_sfixed64", + "repeated_float", + "repeated_double", + "repeated_bool", + "repeated_string", + "repeated_bytes", + "RepeatedGroup", + "repeated_nested_message", + "repeated_foreign_message", + "repeated_importmessage", + "repeated_nested_enum", + "repeated_foreign_enum", + "repeated_importenum", + "map_int32_int32", + "map_int64_int64", + "map_uint32_uint32", + "map_uint64_uint64", + "map_sint32_sint32", + "map_sint64_sint64", + "map_fixed32_fixed32", + "map_fixed64_fixed64", + "map_sfixed32_sfixed32", + "map_sfixed64_sfixed64", + "map_int32_float", + "map_int32_double", + "map_bool_bool", + "map_string_string", + "map_string_bytes", + "map_string_nested_message", + "map_string_nested_enum", + "oneof_uint32", + "oneof_nested_message", + "oneof_string", + "oneof_bytes", + "oneof_bool", + "oneof_uint64", + "oneof_float", + "oneof_double", + "oneof_enum", + "OneofGroup", + }, + want: true, + }, { + message: (*testpb.TestAllTypes)(nil), + paths: []string{ + "optional_nested_message.a", + "optional_nested_message.corecursive", + "optional_nested_message.corecursive.optional_int32", + "optional_nested_message.corecursive.optional_nested_message.corecursive.optional_nested_message.a", + "OptionalGroup.a", + "OptionalGroup.optional_nested_message", + "OptionalGroup.optional_nested_message.corecursive", + "oneof_nested_message.a", + "oneof_nested_message.corecursive", + }, + want: true, + }, { + message: (*testpb.TestAllTypes)(nil), + paths: []string{"repeated_nested_message.a"}, + want: false, + }, { + message: (*testpb.TestAllTypes)(nil), + paths: []string{"repeated_nested_message[0]"}, + want: false, + }, { + message: (*testpb.TestAllTypes)(nil), + paths: []string{"repeated_nested_message[0].a"}, + want: false, + }, { + message: (*testpb.TestAllTypes)(nil), + paths: []string{"map_string_nested_message.a"}, + want: false, + }, { + message: (*testpb.TestAllTypes)(nil), + paths: []string{`map_string_nested_message["key"]`}, + want: false, + }, { + message: (*testpb.TestAllExtensions)(nil), + paths: []string{"nested_string_extension"}, + want: false, + }} + + for _, tt := range tests { + t.Run("", func(t *testing.T) { + mask := &fmpb.FieldMask{Paths: tt.paths} + got := mask.IsValid(tt.message) + if got != tt.want { + t.Errorf("IsValid() returns %v want %v", got, tt.want) + } + }) + } +}