diff --git a/encoding/textpb/decode.go b/encoding/textpb/decode.go index 59c98b19..218d95f3 100644 --- a/encoding/textpb/decode.go +++ b/encoding/textpb/decode.go @@ -7,6 +7,7 @@ package textpb import ( "fmt" "strings" + "unicode/utf8" "github.com/golang/protobuf/v2/internal/encoding/text" "github.com/golang/protobuf/v2/internal/errors" @@ -293,7 +294,13 @@ func unmarshalScalar(input text.Value, fd pref.FieldDescriptor) (pref.Value, err } case pref.StringKind: if input.Type() == text.String { - return pref.ValueOf(string(input.String())), nil + s := input.String() + if utf8.ValidString(s) { + return pref.ValueOf(s), nil + } + var nerr errors.NonFatal + nerr.AppendInvalidUTF8(string(fd.FullName())) + return pref.ValueOf(s), nerr.E } case pref.BytesKind: if input.Type() == text.String { @@ -421,11 +428,12 @@ func unmarshalMapKey(input text.Value, fd pref.FieldDescriptor) (pref.MapKey, er return fd.Default().MapKey(), nil } + var nerr errors.NonFatal val, err := unmarshalScalar(input, fd) - if err != nil { + if !nerr.Merge(err) { return pref.MapKey{}, errors.New("%v contains invalid key: %v", fd.FullName(), input) } - return val.MapKey(), nil + return val.MapKey(), nerr.E } // unmarshalMapMessageValue unmarshals given message-type text.Value into a protoreflect.Map for @@ -447,18 +455,19 @@ func (o UnmarshalOptions) unmarshalMapMessageValue(input text.Value, pkey pref.M // unmarshalMapScalarValue unmarshals given scalar-type text.Value into a protoreflect.Map // for the given MapKey. func unmarshalMapScalarValue(input text.Value, pkey pref.MapKey, fd pref.FieldDescriptor, mmap pref.Map) error { + var nerr errors.NonFatal var val pref.Value if input.Type() == 0 { val = fd.Default() } else { var err error val, err = unmarshalScalar(input, fd) - if err != nil { + if !nerr.Merge(err) { return err } } mmap.Set(pkey, val) - return nil + return nerr.E } // isExpandedAny returns true if given [][2]text.Value may be an expanded Any that contains only one diff --git a/encoding/textpb/decode_test.go b/encoding/textpb/decode_test.go index 7c45641f..e98b0b3c 100644 --- a/encoding/textpb/decode_test.go +++ b/encoding/textpb/decode_test.go @@ -10,6 +10,7 @@ import ( protoV1 "github.com/golang/protobuf/proto" "github.com/golang/protobuf/v2/encoding/textpb" + "github.com/golang/protobuf/v2/internal/errors" "github.com/golang/protobuf/v2/internal/legacy" "github.com/golang/protobuf/v2/internal/scalar" "github.com/golang/protobuf/v2/proto" @@ -182,6 +183,14 @@ s_string: "谷歌" SBytes: []byte("\xe8\xb0\xb7\xe6\xad\x8c"), SString: "谷歌", }, + }, { + desc: "string with invalid UTF-8", + inputMessage: &pb3.Scalars{}, + inputText: `s_string: "abc\xff"`, + wantMessage: &pb3.Scalars{ + SString: "abc\xff", + }, + wantErr: true, }, { desc: "proto2 message contains unknown field", inputMessage: &pb2.Scalars{}, @@ -473,6 +482,19 @@ s_nested: { }, }, }, + }, { + desc: "proto3 nested message contains invalid UTF-8", + inputMessage: &pb3.Nests{}, + inputText: `s_nested: { + s_string: "abc\xff" +} +`, + wantMessage: &pb3.Nests{ + SNested: &pb3.Nested{ + SString: "abc\xff", + }, + }, + wantErr: true, }, { desc: "oneof set to empty string", inputMessage: &pb3.Oneofs{}, @@ -560,6 +582,14 @@ rpt_string: "b" RptString: []string{"a", "x", "y", "b"}, RptBool: []bool{true, false, true}, }, + }, { + desc: "repeated contains invalid UTF-8", + inputMessage: &pb2.Repeats{}, + inputText: `rpt_string: "abc\xff"`, + wantMessage: &pb2.Repeats{ + RptString: []string{"abc\xff"}, + }, + wantErr: true, }, { desc: "repeated enums", inputMessage: &pb2.Enums{}, @@ -870,6 +900,34 @@ int32_to_str: {} 0: "", }, }, + }, { + desc: "map field value contains invalid UTF-8", + inputMessage: &pb3.Maps{}, + inputText: `int32_to_str: { + key: 101 + value: "abc\xff" +} +`, + wantMessage: &pb3.Maps{ + Int32ToStr: map[int32]string{ + 101: "abc\xff", + }, + }, + wantErr: true, + }, { + desc: "map field key contains invalid UTF-8", + inputMessage: &pb3.Maps{}, + inputText: `str_to_nested: { + key: "abc\xff" + value: {} +} +`, + wantMessage: &pb3.Maps{ + StrToNested: map[string]*pb3.Nested{ + "abc\xff": {}, + }, + }, + wantErr: true, }, { desc: "map contains unknown field", inputMessage: &pb3.Maps{}, @@ -1164,6 +1222,16 @@ opt_int32: 42 }) return m }(), + }, { + desc: "extension field contains invalid UTF-8", + inputMessage: &pb2.Extensions{}, + inputText: `[pb2.opt_ext_string]: "abc\xff"`, + wantMessage: func() proto.Message { + m := &pb2.Extensions{} + setExtension(m, pb2.E_OptExtString, "abc\xff") + return m + }(), + wantErr: true, }, { desc: "extensions of repeated fields", inputMessage: &pb2.Extensions{}, @@ -1418,6 +1486,32 @@ value: "some bytes" } }(), wantErr: true, + }, { + desc: "Any with invalid UTF-8", + umo: textpb.UnmarshalOptions{ + Resolver: preg.NewTypes((&pb3.Nested{}).ProtoReflect().Type()), + }, + inputMessage: &knownpb.Any{}, + inputText: ` +[pb3.Nested]: { + s_string: "abc\xff" +} +`, + wantMessage: func() proto.Message { + m := &pb3.Nested{ + SString: "abc\xff", + } + var nerr errors.NonFatal + b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m) + if !nerr.Merge(err) { + t.Fatalf("error in binary marshaling message for Any.value: %v", err) + } + return &knownpb.Any{ + TypeUrl: string(m.ProtoReflect().Type().FullName()), + Value: b, + } + }(), + wantErr: true, }, { desc: "Any expanded with unregistered type", umo: textpb.UnmarshalOptions{Resolver: preg.NewTypes()}, @@ -1459,7 +1553,6 @@ type_url: "pb2.Nested" for _, tt := range tests { tt := tt t.Run(tt.desc, func(t *testing.T) { - t.Parallel() err := tt.umo.Unmarshal(tt.inputMessage, []byte(tt.inputText)) if err != nil && !tt.wantErr { t.Errorf("Unmarshal() returned error: %v\n\n", err) diff --git a/encoding/textpb/encode.go b/encoding/textpb/encode.go index e2931432..c706898d 100644 --- a/encoding/textpb/encode.go +++ b/encoding/textpb/encode.go @@ -7,6 +7,7 @@ package textpb import ( "fmt" "sort" + "unicode/utf8" "github.com/golang/protobuf/v2/internal/encoding/text" "github.com/golang/protobuf/v2/internal/encoding/wire" @@ -174,9 +175,18 @@ func (o MarshalOptions) marshalSingular(val pref.Value, fd pref.FieldDescriptor) pref.Sfixed32Kind, pref.Fixed32Kind, pref.Sfixed64Kind, pref.Fixed64Kind, pref.FloatKind, pref.DoubleKind, - pref.StringKind, pref.BytesKind: + pref.BytesKind: return text.ValueOf(val.Interface()), nil + case pref.StringKind: + s := val.String() + if utf8.ValidString(s) { + return text.ValueOf(s), nil + } + var nerr errors.NonFatal + nerr.AppendInvalidUTF8(string(fd.FullName())) + return text.ValueOf(s), nerr.E + case pref.EnumKind: num := val.Enum() if desc := fd.EnumType().Values().ByNumber(num); desc != nil { diff --git a/encoding/textpb/encode_test.go b/encoding/textpb/encode_test.go index 5b9ee38a..3397d660 100644 --- a/encoding/textpb/encode_test.go +++ b/encoding/textpb/encode_test.go @@ -169,6 +169,14 @@ opt_double: 1.0199999809265137 opt_bytes: "谷歌" opt_string: "谷歌" `, + }, { + desc: "string with invalid UTF-8", + input: &pb3.Scalars{ + SString: "abc\xff", + }, + want: `s_string: "abc\xff" +`, + wantErr: true, }, { desc: "float nan", input: &pb3.Scalars{ @@ -363,6 +371,18 @@ OptGroup: {} } } `, + }, { + desc: "proto3 nested message contains invalid UTF-8", + input: &pb3.Nests{ + SNested: &pb3.Nested{ + SString: "abc\xff", + }, + }, + want: `s_nested: { + s_string: "abc\xff" +} +`, + wantErr: true, }, { desc: "oneof not set", input: &pb3.Oneofs{}, @@ -472,6 +492,14 @@ rpt_string: "世界" rpt_bytes: "hello" rpt_bytes: "世界" `, + }, { + desc: "repeated contains invalid UTF-8", + input: &pb2.Repeats{ + RptString: []string{"abc\xff"}, + }, + want: `rpt_string: "abc\xff" +`, + wantErr: true, }, { desc: "repeated enums", input: &pb2.Enums{ @@ -670,6 +698,32 @@ str_to_oneofs: { } } `, + }, { + desc: "map field value contains invalid UTF-8", + input: &pb3.Maps{ + Int32ToStr: map[int32]string{ + 101: "abc\xff", + }, + }, + want: `int32_to_str: { + key: 101 + value: "abc\xff" +} +`, + wantErr: true, + }, { + desc: "map field key contains invalid UTF-8", + input: &pb3.Maps{ + StrToNested: map[string]*pb3.Nested{ + "abc\xff": {}, + }, + }, + want: `str_to_nested: { + key: "abc\xff" + value: {} +} +`, + wantErr: true, }, { desc: "map field contains nil value", input: &pb3.Maps{ @@ -918,6 +972,16 @@ opt_int32: 42 } [pb2.opt_ext_string]: "extension field" `, + }, { + desc: "extension field contains invalid UTF-8", + input: func() proto.Message { + m := &pb2.Extensions{} + setExtension(m, pb2.E_OptExtString, "abc\xff") + return m + }(), + want: `[pb2.opt_ext_string]: "abc\xff" +`, + wantErr: true, }, { desc: "extension partial returns error", input: func() proto.Message { @@ -1175,6 +1239,29 @@ value: "\n\x13embedded inside Any\x12\x0b\n\tinception" want: `[pb2.PartialRequired]: { opt_string: "embedded inside Any" } +`, + wantErr: true, + }, { + desc: "Any with invalid UTF-8", + mo: textpb.MarshalOptions{ + Resolver: preg.NewTypes((&pb3.Nested{}).ProtoReflect().Type()), + }, + input: func() proto.Message { + m := &pb3.Nested{ + SString: "abc\xff", + } + b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m) + if err != nil { + t.Fatalf("error in binary marshaling message for Any.value: %v", err) + } + return &knownpb.Any{ + TypeUrl: string(m.ProtoReflect().Type().FullName()), + Value: b, + } + }(), + want: `[pb3.Nested]: { + s_string: "abc\xff" +} `, wantErr: true, }, {