From bc310b58c62c4620dcda356e20d298cf8fbde4ef Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Thu, 11 Apr 2019 11:46:55 -0700 Subject: [PATCH] proto: validate UTF-8 in proto3 strings Change-Id: I6a495730c3f438e7b2c4ca86edade7d6f25aa47d Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/171700 Reviewed-by: Herbie Ong --- encoding/jsonpb/decode_test.go | 13 ++-- encoding/jsonpb/encode_test.go | 9 +-- encoding/textpb/encode_test.go | 5 +- internal/cmd/generate-types/main.go | 1 + internal/cmd/generate-types/proto.go | 33 +++++++--- proto/decode.go | 25 +++++--- proto/decode_gen.go | 17 +++-- proto/decode_test.go | 95 ++++++++++++++++++++++++++++ proto/encode.go | 15 +++-- proto/encode_gen.go | 10 ++- proto/encode_test.go | 21 ++++++ 11 files changed, 200 insertions(+), 44 deletions(-) diff --git a/encoding/jsonpb/decode_test.go b/encoding/jsonpb/decode_test.go index 941e4169..cb99bcdc 100644 --- a/encoding/jsonpb/decode_test.go +++ b/encoding/jsonpb/decode_test.go @@ -5,6 +5,7 @@ package jsonpb_test import ( + "bytes" "math" "testing" @@ -2130,14 +2131,14 @@ func TestUnmarshal(t *testing.T) { "value": "` + "abc\xff" + `" }`, wantMessage: func() proto.Message { - m := &knownpb.StringValue{Value: "abc\xff"} + m := &knownpb.StringValue{Value: "abcd"} b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m) if err != nil { t.Fatalf("error in binary marshaling message for Any.value: %v", err) } return &knownpb.Any{ TypeUrl: "google.protobuf.StringValue", - Value: b, + Value: bytes.Replace(b, []byte("abcd"), []byte("abc\xff"), -1), } }(), wantErr: true, @@ -2216,14 +2217,14 @@ func TestUnmarshal(t *testing.T) { "value": "` + "abc\xff" + `" }`, wantMessage: func() proto.Message { - m := &knownpb.Value{Kind: &knownpb.Value_StringValue{"abc\xff"}} + m := &knownpb.Value{Kind: &knownpb.Value_StringValue{"abcd"}} b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m) if err != nil { t.Fatalf("error in binary marshaling message for Any.value: %v", err) } return &knownpb.Any{ TypeUrl: "google.protobuf.Value", - Value: b, + Value: bytes.Replace(b, []byte("abcd"), []byte("abc\xff"), -1), } }(), wantErr: true, @@ -2369,7 +2370,7 @@ func TestUnmarshal(t *testing.T) { } }`, wantMessage: func() proto.Message { - m1 := &knownpb.StringValue{Value: "abc\xff"} + m1 := &knownpb.StringValue{Value: "abcd"} b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m1) if err != nil { t.Fatalf("error in binary marshaling message for Any.value: %v", err) @@ -2385,7 +2386,7 @@ func TestUnmarshal(t *testing.T) { } return &knownpb.Any{ TypeUrl: "pb2.KnownTypes", - Value: b, + Value: bytes.Replace(b, []byte("abcd"), []byte("abc\xff"), -1), } }(), wantErr: true, diff --git a/encoding/jsonpb/encode_test.go b/encoding/jsonpb/encode_test.go index 1a2858e5..4277a3d0 100644 --- a/encoding/jsonpb/encode_test.go +++ b/encoding/jsonpb/encode_test.go @@ -5,6 +5,7 @@ package jsonpb_test import ( + "bytes" "encoding/hex" "math" "strings" @@ -1687,14 +1688,14 @@ func TestMarshal(t *testing.T) { Resolver: preg.NewTypes((&knownpb.StringValue{}).ProtoReflect().Type()), }, input: func() proto.Message { - m := &knownpb.StringValue{Value: "abc\xff"} + m := &knownpb.StringValue{Value: "abcd"} b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m) if err != nil { t.Fatalf("error in binary marshaling message for Any.value: %v", err) } return &knownpb.Any{ TypeUrl: "google.protobuf.StringValue", - Value: b, + Value: bytes.Replace(b, []byte("abcd"), []byte("abc\xff"), -1), } }(), want: `{ @@ -1765,14 +1766,14 @@ func TestMarshal(t *testing.T) { Resolver: preg.NewTypes((&knownpb.Value{}).ProtoReflect().Type()), }, input: func() proto.Message { - m := &knownpb.Value{Kind: &knownpb.Value_StringValue{"abc\xff"}} + m := &knownpb.Value{Kind: &knownpb.Value_StringValue{"abcd"}} b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m) if err != nil { t.Fatalf("error in binary marshaling message for Any.value: %v", err) } return &knownpb.Any{ TypeUrl: "type.googleapis.com/google.protobuf.Value", - Value: b, + Value: bytes.Replace(b, []byte("abcd"), []byte("abc\xff"), -1), } }(), want: `{ diff --git a/encoding/textpb/encode_test.go b/encoding/textpb/encode_test.go index 3397d660..41cae94f 100644 --- a/encoding/textpb/encode_test.go +++ b/encoding/textpb/encode_test.go @@ -5,6 +5,7 @@ package textpb_test import ( + "bytes" "encoding/hex" "math" "strings" @@ -1248,7 +1249,7 @@ value: "\n\x13embedded inside Any\x12\x0b\n\tinception" }, input: func() proto.Message { m := &pb3.Nested{ - SString: "abc\xff", + SString: "abcd", } b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m) if err != nil { @@ -1256,7 +1257,7 @@ value: "\n\x13embedded inside Any\x12\x0b\n\tinception" } return &knownpb.Any{ TypeUrl: string(m.ProtoReflect().Type().FullName()), - Value: b, + Value: bytes.Replace(b, []byte("abcd"), []byte("abc\xff"), -1), } }(), want: `[pb3.Nested]: { diff --git a/internal/cmd/generate-types/main.go b/internal/cmd/generate-types/main.go index 2f738722..c8931d41 100644 --- a/internal/cmd/generate-types/main.go +++ b/internal/cmd/generate-types/main.go @@ -312,6 +312,7 @@ func writeSource(file, src string) { "fmt", "math", "sync", + "unicode/utf8", "", "github.com/golang/protobuf/v2/internal/encoding/wire", "github.com/golang/protobuf/v2/internal/errors", diff --git a/internal/cmd/generate-types/proto.go b/internal/cmd/generate-types/proto.go index 3609be33..dbbc566a 100644 --- a/internal/cmd/generate-types/proto.go +++ b/internal/cmd/generate-types/proto.go @@ -157,8 +157,8 @@ var protoDecodeTemplate = template.Must(template.New("").Parse(` // unmarshalScalar decodes a value of the given kind. // // Message values are decoded into a []byte which aliases the input data. -func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Number, kind protoreflect.Kind) (val protoreflect.Value, n int, err error) { - switch kind { +func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Number, field protoreflect.FieldDescriptor) (val protoreflect.Value, n int, err error) { + switch field.Kind() { {{- range .}} case {{.Expr}}: if wtyp != {{.WireType.Expr}} { @@ -172,6 +172,13 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Num if n < 0 { return val, 0, wire.ParseError(n) } + {{if (eq .Name "String") -}} + if field.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) { + var nerr errors.NonFatal + nerr.AppendInvalidUTF8(string(field.FullName())) + return protoreflect.ValueOf(string(v)), n, nerr.E + } + {{end -}} return protoreflect.ValueOf({{.ToValue}}), n, nil {{- end}} default: @@ -179,9 +186,9 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Num } } -func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, num wire.Number, list protoreflect.List, kind protoreflect.Kind) (n int, err error) { +func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, num wire.Number, list protoreflect.List, field protoreflect.FieldDescriptor) (n int, err error) { var nerr errors.NonFatal - switch kind { + switch field.Kind() { {{- range .}} case {{.Expr}}: {{- if .WireType.Packable}} @@ -212,6 +219,11 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, num wire.Numbe if n < 0 { return 0, wire.ParseError(n) } + {{if (eq .Name "String") -}} + if field.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) { + nerr.AppendInvalidUTF8(string(field.FullName())) + } + {{end -}} {{if or (eq .Name "Message") (eq .Name "Group") -}} m := list.NewMessage() if err := o.unmarshalMessage(v, m); !nerr.Merge(err) { @@ -240,12 +252,17 @@ var wireTypes = map[protoreflect.Kind]wire.Type{ {{- end}} } -func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, kind protoreflect.Kind, v protoreflect.Value) ([]byte, error) { +func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, field protoreflect.FieldDescriptor, v protoreflect.Value) ([]byte, error) { var nerr errors.NonFatal - switch kind { + switch field.Kind() { {{- range .}} case {{.Expr}}: - {{if (eq .Name "Message") -}} + {{- if (eq .Name "String") }} + if field.Syntax() == protoreflect.Proto3 && !utf8.ValidString(v.String()) { + nerr.AppendInvalidUTF8(string(field.FullName())) + } + {{end -}} + {{- if (eq .Name "Message") -}} var pos int var err error b, pos = appendSpeculativeLength(b) @@ -266,7 +283,7 @@ func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, kind protoref {{- end}} {{- end}} default: - return b, errors.New("invalid kind %v", kind) + return b, errors.New("invalid kind %v", field.Kind()) } return b, nerr.E } diff --git a/proto/decode.go b/proto/decode.go index 3e00074d..0b1aa3fe 100644 --- a/proto/decode.go +++ b/proto/decode.go @@ -86,7 +86,7 @@ func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) err case fieldType.Cardinality() != protoreflect.Repeated: valLen, err = o.unmarshalScalarField(b[tagLen:], wtyp, num, knownFields, fieldType) case !fieldType.IsMap(): - valLen, err = o.unmarshalList(b[tagLen:], wtyp, num, knownFields.Get(num).List(), fieldType.Kind()) + valLen, err = o.unmarshalList(b[tagLen:], wtyp, num, knownFields.Get(num).List(), fieldType) default: valLen, err = o.unmarshalMap(b[tagLen:], wtyp, num, knownFields.Get(num).Map(), fieldType) } @@ -105,8 +105,9 @@ func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) err } func (o UnmarshalOptions) unmarshalScalarField(b []byte, wtyp wire.Type, num wire.Number, knownFields protoreflect.KnownFields, field protoreflect.FieldDescriptor) (n int, err error) { - v, n, err := o.unmarshalScalar(b, wtyp, num, field.Kind()) - if err != nil { + var nerr errors.NonFatal + v, n, err := o.unmarshalScalar(b, wtyp, num, field) + if !nerr.Merge(err) { return 0, err } switch field.Kind() { @@ -124,12 +125,14 @@ func (o UnmarshalOptions) unmarshalScalarField(b []byte, wtyp wire.Type, num wir knownFields.Set(num, protoreflect.ValueOf(m)) } // Pass up errors (fatal and otherwise). - err = o.unmarshalMessage(v.Bytes(), m) + if err := o.unmarshalMessage(v.Bytes(), m); !nerr.Merge(err) { + return n, err + } default: // Non-message scalars replace the previous value. knownFields.Set(num, v) } - return n, err + return n, nerr.E } func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp wire.Type, num wire.Number, mapv protoreflect.Map, field protoreflect.FieldDescriptor) (n int, err error) { @@ -164,17 +167,19 @@ func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp wire.Type, num wire.Number err = errUnknown switch num { case 1: - key, n, err = o.unmarshalScalar(b, wtyp, num, keyField.Kind()) - if err != nil { + key, n, err = o.unmarshalScalar(b, wtyp, num, keyField) + if !nerr.Merge(err) { break } + err = nil haveKey = true case 2: var v protoreflect.Value - v, n, err = o.unmarshalScalar(b, wtyp, num, valField.Kind()) - if err != nil { + v, n, err = o.unmarshalScalar(b, wtyp, num, valField) + if !nerr.Merge(err) { break } + err = nil switch valField.Kind() { case protoreflect.GroupKind, protoreflect.MessageKind: if err := o.unmarshalMessage(v.Bytes(), val.Message()); !nerr.Merge(err) { @@ -190,7 +195,7 @@ func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp wire.Type, num wire.Number if n < 0 { return 0, wire.ParseError(n) } - } else if !nerr.Merge(err) { + } else if err != nil { return 0, err } b = b[n:] diff --git a/proto/decode_gen.go b/proto/decode_gen.go index 51b85d7b..1a3ef15e 100644 --- a/proto/decode_gen.go +++ b/proto/decode_gen.go @@ -8,6 +8,7 @@ package proto import ( "math" + "unicode/utf8" "github.com/golang/protobuf/v2/internal/encoding/wire" "github.com/golang/protobuf/v2/internal/errors" @@ -17,8 +18,8 @@ import ( // unmarshalScalar decodes a value of the given kind. // // Message values are decoded into a []byte which aliases the input data. -func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Number, kind protoreflect.Kind) (val protoreflect.Value, n int, err error) { - switch kind { +func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Number, field protoreflect.FieldDescriptor) (val protoreflect.Value, n int, err error) { + switch field.Kind() { case protoreflect.BoolKind: if wtyp != wire.VarintType { return val, 0, errUnknown @@ -153,6 +154,11 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Num if n < 0 { return val, 0, wire.ParseError(n) } + if field.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) { + var nerr errors.NonFatal + nerr.AppendInvalidUTF8(string(field.FullName())) + return protoreflect.ValueOf(string(v)), n, nerr.E + } return protoreflect.ValueOf(string(v)), n, nil case protoreflect.BytesKind: if wtyp != wire.BytesType { @@ -186,9 +192,9 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Num } } -func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, num wire.Number, list protoreflect.List, kind protoreflect.Kind) (n int, err error) { +func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, num wire.Number, list protoreflect.List, field protoreflect.FieldDescriptor) (n int, err error) { var nerr errors.NonFatal - switch kind { + switch field.Kind() { case protoreflect.BoolKind: if wtyp == wire.BytesType { buf, n := wire.ConsumeBytes(b) @@ -547,6 +553,9 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, num wire.Numbe if n < 0 { return 0, wire.ParseError(n) } + if field.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) { + nerr.AppendInvalidUTF8(string(field.FullName())) + } list.Append(protoreflect.ValueOf(string(v))) return n, nerr.E case protoreflect.BytesKind: diff --git a/proto/decode_test.go b/proto/decode_test.go index dda4db1d..2c95f6b2 100644 --- a/proto/decode_test.go +++ b/proto/decode_test.go @@ -12,6 +12,7 @@ import ( protoV1 "github.com/golang/protobuf/proto" "github.com/golang/protobuf/v2/encoding/textpb" "github.com/golang/protobuf/v2/internal/encoding/pack" + "github.com/golang/protobuf/v2/internal/errors" "github.com/golang/protobuf/v2/internal/scalar" "github.com/golang/protobuf/v2/proto" pref "github.com/golang/protobuf/v2/reflect/protoreflect" @@ -80,6 +81,23 @@ func TestDecodeRequiredFieldChecks(t *testing.T) { } } +func TestDecodeInvalidUTF8(t *testing.T) { + for _, test := range invalidUTF8TestProtos { + for _, want := range test.decodeTo { + t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) { + got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message) + err := proto.Unmarshal(test.wire, got) + if !isErrInvalidUTF8(err) { + t.Errorf("Unmarshal did not return expected error for invalid UTF8: %v\nMessage:\n%v", err, marshalText(want)) + } + if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) { + t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want)) + } + }) + } + } +} + var testProtos = []testProto{ { desc: "basic scalar types", @@ -1158,6 +1176,69 @@ var testProtos = []testProto{ }, } +var invalidUTF8TestProtos = []testProto{ + { + desc: "invalid UTF-8 in optional string field", + decodeTo: []proto.Message{&test3pb.TestAllTypes{ + OptionalString: "abc\xff", + }}, + wire: pack.Message{ + pack.Tag{14, pack.BytesType}, pack.String("abc\xff"), + }.Marshal(), + }, + { + desc: "invalid UTF-8 in repeated string field", + decodeTo: []proto.Message{&test3pb.TestAllTypes{ + RepeatedString: []string{"foo", "abc\xff"}, + }}, + wire: pack.Message{ + pack.Tag{44, pack.BytesType}, pack.String("foo"), + pack.Tag{44, pack.BytesType}, pack.String("abc\xff"), + }.Marshal(), + }, + { + desc: "invalid UTF-8 in nested message", + decodeTo: []proto.Message{&test3pb.TestAllTypes{ + OptionalNestedMessage: &test3pb.TestAllTypes_NestedMessage{ + Corecursive: &test3pb.TestAllTypes{ + OptionalString: "abc\xff", + }, + }, + }}, + wire: pack.Message{ + pack.Tag{18, pack.BytesType}, pack.LengthPrefix(pack.Message{ + pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{ + pack.Tag{14, pack.BytesType}, pack.String("abc\xff"), + }), + }), + }.Marshal(), + }, + { + desc: "invalid UTF-8 in map key", + decodeTo: []proto.Message{&test3pb.TestAllTypes{ + MapStringString: map[string]string{"key\xff": "val"}, + }}, + wire: pack.Message{ + pack.Tag{69, pack.BytesType}, pack.LengthPrefix(pack.Message{ + pack.Tag{1, pack.BytesType}, pack.String("key\xff"), + pack.Tag{2, pack.BytesType}, pack.String("val"), + }), + }.Marshal(), + }, + { + desc: "invalid UTF-8 in map value", + decodeTo: []proto.Message{&test3pb.TestAllTypes{ + MapStringString: map[string]string{"key": "val\xff"}, + }}, + wire: pack.Message{ + pack.Tag{69, pack.BytesType}, pack.LengthPrefix(pack.Message{ + pack.Tag{1, pack.BytesType}, pack.String("key"), + pack.Tag{2, pack.BytesType}, pack.String("val\xff"), + }), + }.Marshal(), + }, +} + func build(m proto.Message, opts ...buildOpt) proto.Message { for _, opt := range opts { opt(m) @@ -1185,3 +1266,17 @@ func marshalText(m proto.Message) string { b, _ := textpb.Marshal(m) return string(b) } + +func isErrInvalidUTF8(err error) bool { + nerr, ok := err.(errors.NonFatalErrors) + if !ok || len(nerr) == 0 { + return false + } + for _, err := range nerr { + if e, ok := err.(interface{ InvalidUTF8() bool }); ok && e.InvalidUTF8() { + continue + } + return false + } + return true +} diff --git a/proto/encode.go b/proto/encode.go index b294392f..86357906 100644 --- a/proto/encode.go +++ b/proto/encode.go @@ -182,13 +182,13 @@ func (o MarshalOptions) marshalField(b []byte, field protoreflect.FieldDescripto switch { case field.Cardinality() != protoreflect.Repeated: b = wire.AppendTag(b, num, wireTypes[kind]) - return o.marshalSingular(b, num, kind, value) + return o.marshalSingular(b, num, field, value) case field.IsMap(): return o.marshalMap(b, num, kind, field.MessageType(), value.Map()) case field.IsPacked(): - return o.marshalPacked(b, num, kind, value.List()) + return o.marshalPacked(b, num, field, value.List()) default: - return o.marshalList(b, num, kind, value.List()) + return o.marshalList(b, num, field, value.List()) } } @@ -229,13 +229,13 @@ func (o MarshalOptions) rangeMap(mapv protoreflect.Map, kind protoreflect.Kind, mapsort.Range(mapv, kind, f) } -func (o MarshalOptions) marshalPacked(b []byte, num wire.Number, kind protoreflect.Kind, list protoreflect.List) ([]byte, error) { +func (o MarshalOptions) marshalPacked(b []byte, num wire.Number, field protoreflect.FieldDescriptor, list protoreflect.List) ([]byte, error) { b = wire.AppendTag(b, num, wire.BytesType) b, pos := appendSpeculativeLength(b) var nerr errors.NonFatal for i, llen := 0, list.Len(); i < llen; i++ { var err error - b, err = o.marshalSingular(b, num, kind, list.Get(i)) + b, err = o.marshalSingular(b, num, field, list.Get(i)) if !nerr.Merge(err) { return b, err } @@ -244,12 +244,13 @@ func (o MarshalOptions) marshalPacked(b []byte, num wire.Number, kind protorefle return b, nerr.E } -func (o MarshalOptions) marshalList(b []byte, num wire.Number, kind protoreflect.Kind, list protoreflect.List) ([]byte, error) { +func (o MarshalOptions) marshalList(b []byte, num wire.Number, field protoreflect.FieldDescriptor, list protoreflect.List) ([]byte, error) { + kind := field.Kind() var nerr errors.NonFatal for i, llen := 0, list.Len(); i < llen; i++ { var err error b = wire.AppendTag(b, num, wireTypes[kind]) - b, err = o.marshalSingular(b, num, kind, list.Get(i)) + b, err = o.marshalSingular(b, num, field, list.Get(i)) if !nerr.Merge(err) { return b, err } diff --git a/proto/encode_gen.go b/proto/encode_gen.go index 46621c84..4919b96d 100644 --- a/proto/encode_gen.go +++ b/proto/encode_gen.go @@ -8,6 +8,7 @@ package proto import ( "math" + "unicode/utf8" "github.com/golang/protobuf/v2/internal/encoding/wire" "github.com/golang/protobuf/v2/internal/errors" @@ -35,9 +36,9 @@ var wireTypes = map[protoreflect.Kind]wire.Type{ protoreflect.GroupKind: wire.StartGroupType, } -func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, kind protoreflect.Kind, v protoreflect.Value) ([]byte, error) { +func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, field protoreflect.FieldDescriptor, v protoreflect.Value) ([]byte, error) { var nerr errors.NonFatal - switch kind { + switch field.Kind() { case protoreflect.BoolKind: b = wire.AppendVarint(b, wire.EncodeBool(v.Bool())) case protoreflect.EnumKind: @@ -67,6 +68,9 @@ func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, kind protoref case protoreflect.DoubleKind: b = wire.AppendFixed64(b, math.Float64bits(v.Float())) case protoreflect.StringKind: + if field.Syntax() == protoreflect.Proto3 && !utf8.ValidString(v.String()) { + nerr.AppendInvalidUTF8(string(field.FullName())) + } b = wire.AppendBytes(b, []byte(v.String())) case protoreflect.BytesKind: b = wire.AppendBytes(b, v.Bytes()) @@ -87,7 +91,7 @@ func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, kind protoref } b = wire.AppendVarint(b, wire.EncodeTag(num, wire.EndGroupType)) default: - return b, errors.New("invalid kind %v", kind) + return b, errors.New("invalid kind %v", field.Kind()) } return b, nerr.E } diff --git a/proto/encode_test.go b/proto/encode_test.go index 30722e05..d670edfd 100644 --- a/proto/encode_test.go +++ b/proto/encode_test.go @@ -92,6 +92,27 @@ func TestEncodeDeterministic(t *testing.T) { } } +func TestEncodeInvalidUTF8(t *testing.T) { + for _, test := range invalidUTF8TestProtos { + for _, want := range test.decodeTo { + t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) { + wire, err := proto.Marshal(want) + if !isErrInvalidUTF8(err) { + t.Errorf("Marshal did not return expected error for invalid UTF8: %v\nMessage:\n%v", err, marshalText(want)) + } + got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message) + if err := proto.Unmarshal(wire, got); !isErrInvalidUTF8(err) { + t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(want)) + return + } + if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) { + t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want)) + } + }) + } + } +} + func TestEncodeRequiredFieldChecks(t *testing.T) { for _, test := range testProtos { if !test.partial {