mirror of
https://github.com/protocolbuffers/protobuf-go.git
synced 2024-12-27 15:26:51 +00:00
proto: validate UTF-8 in proto3 strings
Change-Id: I6a495730c3f438e7b2c4ca86edade7d6f25aa47d Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/171700 Reviewed-by: Herbie Ong <herbie@google.com>
This commit is contained in:
parent
a5f43e834e
commit
bc310b58c6
@ -5,6 +5,7 @@
|
||||
package jsonpb_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
@ -2130,14 +2131,14 @@ func TestUnmarshal(t *testing.T) {
|
||||
"value": "` + "abc\xff" + `"
|
||||
}`,
|
||||
wantMessage: func() proto.Message {
|
||||
m := &knownpb.StringValue{Value: "abc\xff"}
|
||||
m := &knownpb.StringValue{Value: "abcd"}
|
||||
b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("error in binary marshaling message for Any.value: %v", err)
|
||||
}
|
||||
return &knownpb.Any{
|
||||
TypeUrl: "google.protobuf.StringValue",
|
||||
Value: b,
|
||||
Value: bytes.Replace(b, []byte("abcd"), []byte("abc\xff"), -1),
|
||||
}
|
||||
}(),
|
||||
wantErr: true,
|
||||
@ -2216,14 +2217,14 @@ func TestUnmarshal(t *testing.T) {
|
||||
"value": "` + "abc\xff" + `"
|
||||
}`,
|
||||
wantMessage: func() proto.Message {
|
||||
m := &knownpb.Value{Kind: &knownpb.Value_StringValue{"abc\xff"}}
|
||||
m := &knownpb.Value{Kind: &knownpb.Value_StringValue{"abcd"}}
|
||||
b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("error in binary marshaling message for Any.value: %v", err)
|
||||
}
|
||||
return &knownpb.Any{
|
||||
TypeUrl: "google.protobuf.Value",
|
||||
Value: b,
|
||||
Value: bytes.Replace(b, []byte("abcd"), []byte("abc\xff"), -1),
|
||||
}
|
||||
}(),
|
||||
wantErr: true,
|
||||
@ -2369,7 +2370,7 @@ func TestUnmarshal(t *testing.T) {
|
||||
}
|
||||
}`,
|
||||
wantMessage: func() proto.Message {
|
||||
m1 := &knownpb.StringValue{Value: "abc\xff"}
|
||||
m1 := &knownpb.StringValue{Value: "abcd"}
|
||||
b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m1)
|
||||
if err != nil {
|
||||
t.Fatalf("error in binary marshaling message for Any.value: %v", err)
|
||||
@ -2385,7 +2386,7 @@ func TestUnmarshal(t *testing.T) {
|
||||
}
|
||||
return &knownpb.Any{
|
||||
TypeUrl: "pb2.KnownTypes",
|
||||
Value: b,
|
||||
Value: bytes.Replace(b, []byte("abcd"), []byte("abc\xff"), -1),
|
||||
}
|
||||
}(),
|
||||
wantErr: true,
|
||||
|
@ -5,6 +5,7 @@
|
||||
package jsonpb_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"math"
|
||||
"strings"
|
||||
@ -1687,14 +1688,14 @@ func TestMarshal(t *testing.T) {
|
||||
Resolver: preg.NewTypes((&knownpb.StringValue{}).ProtoReflect().Type()),
|
||||
},
|
||||
input: func() proto.Message {
|
||||
m := &knownpb.StringValue{Value: "abc\xff"}
|
||||
m := &knownpb.StringValue{Value: "abcd"}
|
||||
b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("error in binary marshaling message for Any.value: %v", err)
|
||||
}
|
||||
return &knownpb.Any{
|
||||
TypeUrl: "google.protobuf.StringValue",
|
||||
Value: b,
|
||||
Value: bytes.Replace(b, []byte("abcd"), []byte("abc\xff"), -1),
|
||||
}
|
||||
}(),
|
||||
want: `{
|
||||
@ -1765,14 +1766,14 @@ func TestMarshal(t *testing.T) {
|
||||
Resolver: preg.NewTypes((&knownpb.Value{}).ProtoReflect().Type()),
|
||||
},
|
||||
input: func() proto.Message {
|
||||
m := &knownpb.Value{Kind: &knownpb.Value_StringValue{"abc\xff"}}
|
||||
m := &knownpb.Value{Kind: &knownpb.Value_StringValue{"abcd"}}
|
||||
b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("error in binary marshaling message for Any.value: %v", err)
|
||||
}
|
||||
return &knownpb.Any{
|
||||
TypeUrl: "type.googleapis.com/google.protobuf.Value",
|
||||
Value: b,
|
||||
Value: bytes.Replace(b, []byte("abcd"), []byte("abc\xff"), -1),
|
||||
}
|
||||
}(),
|
||||
want: `{
|
||||
|
@ -5,6 +5,7 @@
|
||||
package textpb_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"math"
|
||||
"strings"
|
||||
@ -1248,7 +1249,7 @@ value: "\n\x13embedded inside Any\x12\x0b\n\tinception"
|
||||
},
|
||||
input: func() proto.Message {
|
||||
m := &pb3.Nested{
|
||||
SString: "abc\xff",
|
||||
SString: "abcd",
|
||||
}
|
||||
b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m)
|
||||
if err != nil {
|
||||
@ -1256,7 +1257,7 @@ value: "\n\x13embedded inside Any\x12\x0b\n\tinception"
|
||||
}
|
||||
return &knownpb.Any{
|
||||
TypeUrl: string(m.ProtoReflect().Type().FullName()),
|
||||
Value: b,
|
||||
Value: bytes.Replace(b, []byte("abcd"), []byte("abc\xff"), -1),
|
||||
}
|
||||
}(),
|
||||
want: `[pb3.Nested]: {
|
||||
|
@ -312,6 +312,7 @@ func writeSource(file, src string) {
|
||||
"fmt",
|
||||
"math",
|
||||
"sync",
|
||||
"unicode/utf8",
|
||||
"",
|
||||
"github.com/golang/protobuf/v2/internal/encoding/wire",
|
||||
"github.com/golang/protobuf/v2/internal/errors",
|
||||
|
@ -157,8 +157,8 @@ var protoDecodeTemplate = template.Must(template.New("").Parse(`
|
||||
// unmarshalScalar decodes a value of the given kind.
|
||||
//
|
||||
// Message values are decoded into a []byte which aliases the input data.
|
||||
func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Number, kind protoreflect.Kind) (val protoreflect.Value, n int, err error) {
|
||||
switch kind {
|
||||
func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Number, field protoreflect.FieldDescriptor) (val protoreflect.Value, n int, err error) {
|
||||
switch field.Kind() {
|
||||
{{- range .}}
|
||||
case {{.Expr}}:
|
||||
if wtyp != {{.WireType.Expr}} {
|
||||
@ -172,6 +172,13 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Num
|
||||
if n < 0 {
|
||||
return val, 0, wire.ParseError(n)
|
||||
}
|
||||
{{if (eq .Name "String") -}}
|
||||
if field.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) {
|
||||
var nerr errors.NonFatal
|
||||
nerr.AppendInvalidUTF8(string(field.FullName()))
|
||||
return protoreflect.ValueOf(string(v)), n, nerr.E
|
||||
}
|
||||
{{end -}}
|
||||
return protoreflect.ValueOf({{.ToValue}}), n, nil
|
||||
{{- end}}
|
||||
default:
|
||||
@ -179,9 +186,9 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Num
|
||||
}
|
||||
}
|
||||
|
||||
func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, num wire.Number, list protoreflect.List, kind protoreflect.Kind) (n int, err error) {
|
||||
func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, num wire.Number, list protoreflect.List, field protoreflect.FieldDescriptor) (n int, err error) {
|
||||
var nerr errors.NonFatal
|
||||
switch kind {
|
||||
switch field.Kind() {
|
||||
{{- range .}}
|
||||
case {{.Expr}}:
|
||||
{{- if .WireType.Packable}}
|
||||
@ -212,6 +219,11 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, num wire.Numbe
|
||||
if n < 0 {
|
||||
return 0, wire.ParseError(n)
|
||||
}
|
||||
{{if (eq .Name "String") -}}
|
||||
if field.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) {
|
||||
nerr.AppendInvalidUTF8(string(field.FullName()))
|
||||
}
|
||||
{{end -}}
|
||||
{{if or (eq .Name "Message") (eq .Name "Group") -}}
|
||||
m := list.NewMessage()
|
||||
if err := o.unmarshalMessage(v, m); !nerr.Merge(err) {
|
||||
@ -240,12 +252,17 @@ var wireTypes = map[protoreflect.Kind]wire.Type{
|
||||
{{- end}}
|
||||
}
|
||||
|
||||
func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, kind protoreflect.Kind, v protoreflect.Value) ([]byte, error) {
|
||||
func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, field protoreflect.FieldDescriptor, v protoreflect.Value) ([]byte, error) {
|
||||
var nerr errors.NonFatal
|
||||
switch kind {
|
||||
switch field.Kind() {
|
||||
{{- range .}}
|
||||
case {{.Expr}}:
|
||||
{{if (eq .Name "Message") -}}
|
||||
{{- if (eq .Name "String") }}
|
||||
if field.Syntax() == protoreflect.Proto3 && !utf8.ValidString(v.String()) {
|
||||
nerr.AppendInvalidUTF8(string(field.FullName()))
|
||||
}
|
||||
{{end -}}
|
||||
{{- if (eq .Name "Message") -}}
|
||||
var pos int
|
||||
var err error
|
||||
b, pos = appendSpeculativeLength(b)
|
||||
@ -266,7 +283,7 @@ func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, kind protoref
|
||||
{{- end}}
|
||||
{{- end}}
|
||||
default:
|
||||
return b, errors.New("invalid kind %v", kind)
|
||||
return b, errors.New("invalid kind %v", field.Kind())
|
||||
}
|
||||
return b, nerr.E
|
||||
}
|
||||
|
@ -86,7 +86,7 @@ func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) err
|
||||
case fieldType.Cardinality() != protoreflect.Repeated:
|
||||
valLen, err = o.unmarshalScalarField(b[tagLen:], wtyp, num, knownFields, fieldType)
|
||||
case !fieldType.IsMap():
|
||||
valLen, err = o.unmarshalList(b[tagLen:], wtyp, num, knownFields.Get(num).List(), fieldType.Kind())
|
||||
valLen, err = o.unmarshalList(b[tagLen:], wtyp, num, knownFields.Get(num).List(), fieldType)
|
||||
default:
|
||||
valLen, err = o.unmarshalMap(b[tagLen:], wtyp, num, knownFields.Get(num).Map(), fieldType)
|
||||
}
|
||||
@ -105,8 +105,9 @@ func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) err
|
||||
}
|
||||
|
||||
func (o UnmarshalOptions) unmarshalScalarField(b []byte, wtyp wire.Type, num wire.Number, knownFields protoreflect.KnownFields, field protoreflect.FieldDescriptor) (n int, err error) {
|
||||
v, n, err := o.unmarshalScalar(b, wtyp, num, field.Kind())
|
||||
if err != nil {
|
||||
var nerr errors.NonFatal
|
||||
v, n, err := o.unmarshalScalar(b, wtyp, num, field)
|
||||
if !nerr.Merge(err) {
|
||||
return 0, err
|
||||
}
|
||||
switch field.Kind() {
|
||||
@ -124,12 +125,14 @@ func (o UnmarshalOptions) unmarshalScalarField(b []byte, wtyp wire.Type, num wir
|
||||
knownFields.Set(num, protoreflect.ValueOf(m))
|
||||
}
|
||||
// Pass up errors (fatal and otherwise).
|
||||
err = o.unmarshalMessage(v.Bytes(), m)
|
||||
if err := o.unmarshalMessage(v.Bytes(), m); !nerr.Merge(err) {
|
||||
return n, err
|
||||
}
|
||||
default:
|
||||
// Non-message scalars replace the previous value.
|
||||
knownFields.Set(num, v)
|
||||
}
|
||||
return n, err
|
||||
return n, nerr.E
|
||||
}
|
||||
|
||||
func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp wire.Type, num wire.Number, mapv protoreflect.Map, field protoreflect.FieldDescriptor) (n int, err error) {
|
||||
@ -164,17 +167,19 @@ func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp wire.Type, num wire.Number
|
||||
err = errUnknown
|
||||
switch num {
|
||||
case 1:
|
||||
key, n, err = o.unmarshalScalar(b, wtyp, num, keyField.Kind())
|
||||
if err != nil {
|
||||
key, n, err = o.unmarshalScalar(b, wtyp, num, keyField)
|
||||
if !nerr.Merge(err) {
|
||||
break
|
||||
}
|
||||
err = nil
|
||||
haveKey = true
|
||||
case 2:
|
||||
var v protoreflect.Value
|
||||
v, n, err = o.unmarshalScalar(b, wtyp, num, valField.Kind())
|
||||
if err != nil {
|
||||
v, n, err = o.unmarshalScalar(b, wtyp, num, valField)
|
||||
if !nerr.Merge(err) {
|
||||
break
|
||||
}
|
||||
err = nil
|
||||
switch valField.Kind() {
|
||||
case protoreflect.GroupKind, protoreflect.MessageKind:
|
||||
if err := o.unmarshalMessage(v.Bytes(), val.Message()); !nerr.Merge(err) {
|
||||
@ -190,7 +195,7 @@ func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp wire.Type, num wire.Number
|
||||
if n < 0 {
|
||||
return 0, wire.ParseError(n)
|
||||
}
|
||||
} else if !nerr.Merge(err) {
|
||||
} else if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
b = b[n:]
|
||||
|
@ -8,6 +8,7 @@ package proto
|
||||
|
||||
import (
|
||||
"math"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/golang/protobuf/v2/internal/encoding/wire"
|
||||
"github.com/golang/protobuf/v2/internal/errors"
|
||||
@ -17,8 +18,8 @@ import (
|
||||
// unmarshalScalar decodes a value of the given kind.
|
||||
//
|
||||
// Message values are decoded into a []byte which aliases the input data.
|
||||
func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Number, kind protoreflect.Kind) (val protoreflect.Value, n int, err error) {
|
||||
switch kind {
|
||||
func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Number, field protoreflect.FieldDescriptor) (val protoreflect.Value, n int, err error) {
|
||||
switch field.Kind() {
|
||||
case protoreflect.BoolKind:
|
||||
if wtyp != wire.VarintType {
|
||||
return val, 0, errUnknown
|
||||
@ -153,6 +154,11 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Num
|
||||
if n < 0 {
|
||||
return val, 0, wire.ParseError(n)
|
||||
}
|
||||
if field.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) {
|
||||
var nerr errors.NonFatal
|
||||
nerr.AppendInvalidUTF8(string(field.FullName()))
|
||||
return protoreflect.ValueOf(string(v)), n, nerr.E
|
||||
}
|
||||
return protoreflect.ValueOf(string(v)), n, nil
|
||||
case protoreflect.BytesKind:
|
||||
if wtyp != wire.BytesType {
|
||||
@ -186,9 +192,9 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Num
|
||||
}
|
||||
}
|
||||
|
||||
func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, num wire.Number, list protoreflect.List, kind protoreflect.Kind) (n int, err error) {
|
||||
func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, num wire.Number, list protoreflect.List, field protoreflect.FieldDescriptor) (n int, err error) {
|
||||
var nerr errors.NonFatal
|
||||
switch kind {
|
||||
switch field.Kind() {
|
||||
case protoreflect.BoolKind:
|
||||
if wtyp == wire.BytesType {
|
||||
buf, n := wire.ConsumeBytes(b)
|
||||
@ -547,6 +553,9 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, num wire.Numbe
|
||||
if n < 0 {
|
||||
return 0, wire.ParseError(n)
|
||||
}
|
||||
if field.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) {
|
||||
nerr.AppendInvalidUTF8(string(field.FullName()))
|
||||
}
|
||||
list.Append(protoreflect.ValueOf(string(v)))
|
||||
return n, nerr.E
|
||||
case protoreflect.BytesKind:
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
protoV1 "github.com/golang/protobuf/proto"
|
||||
"github.com/golang/protobuf/v2/encoding/textpb"
|
||||
"github.com/golang/protobuf/v2/internal/encoding/pack"
|
||||
"github.com/golang/protobuf/v2/internal/errors"
|
||||
"github.com/golang/protobuf/v2/internal/scalar"
|
||||
"github.com/golang/protobuf/v2/proto"
|
||||
pref "github.com/golang/protobuf/v2/reflect/protoreflect"
|
||||
@ -80,6 +81,23 @@ func TestDecodeRequiredFieldChecks(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeInvalidUTF8(t *testing.T) {
|
||||
for _, test := range invalidUTF8TestProtos {
|
||||
for _, want := range test.decodeTo {
|
||||
t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
|
||||
got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
|
||||
err := proto.Unmarshal(test.wire, got)
|
||||
if !isErrInvalidUTF8(err) {
|
||||
t.Errorf("Unmarshal did not return expected error for invalid UTF8: %v\nMessage:\n%v", err, marshalText(want))
|
||||
}
|
||||
if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) {
|
||||
t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var testProtos = []testProto{
|
||||
{
|
||||
desc: "basic scalar types",
|
||||
@ -1158,6 +1176,69 @@ var testProtos = []testProto{
|
||||
},
|
||||
}
|
||||
|
||||
var invalidUTF8TestProtos = []testProto{
|
||||
{
|
||||
desc: "invalid UTF-8 in optional string field",
|
||||
decodeTo: []proto.Message{&test3pb.TestAllTypes{
|
||||
OptionalString: "abc\xff",
|
||||
}},
|
||||
wire: pack.Message{
|
||||
pack.Tag{14, pack.BytesType}, pack.String("abc\xff"),
|
||||
}.Marshal(),
|
||||
},
|
||||
{
|
||||
desc: "invalid UTF-8 in repeated string field",
|
||||
decodeTo: []proto.Message{&test3pb.TestAllTypes{
|
||||
RepeatedString: []string{"foo", "abc\xff"},
|
||||
}},
|
||||
wire: pack.Message{
|
||||
pack.Tag{44, pack.BytesType}, pack.String("foo"),
|
||||
pack.Tag{44, pack.BytesType}, pack.String("abc\xff"),
|
||||
}.Marshal(),
|
||||
},
|
||||
{
|
||||
desc: "invalid UTF-8 in nested message",
|
||||
decodeTo: []proto.Message{&test3pb.TestAllTypes{
|
||||
OptionalNestedMessage: &test3pb.TestAllTypes_NestedMessage{
|
||||
Corecursive: &test3pb.TestAllTypes{
|
||||
OptionalString: "abc\xff",
|
||||
},
|
||||
},
|
||||
}},
|
||||
wire: pack.Message{
|
||||
pack.Tag{18, pack.BytesType}, pack.LengthPrefix(pack.Message{
|
||||
pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
|
||||
pack.Tag{14, pack.BytesType}, pack.String("abc\xff"),
|
||||
}),
|
||||
}),
|
||||
}.Marshal(),
|
||||
},
|
||||
{
|
||||
desc: "invalid UTF-8 in map key",
|
||||
decodeTo: []proto.Message{&test3pb.TestAllTypes{
|
||||
MapStringString: map[string]string{"key\xff": "val"},
|
||||
}},
|
||||
wire: pack.Message{
|
||||
pack.Tag{69, pack.BytesType}, pack.LengthPrefix(pack.Message{
|
||||
pack.Tag{1, pack.BytesType}, pack.String("key\xff"),
|
||||
pack.Tag{2, pack.BytesType}, pack.String("val"),
|
||||
}),
|
||||
}.Marshal(),
|
||||
},
|
||||
{
|
||||
desc: "invalid UTF-8 in map value",
|
||||
decodeTo: []proto.Message{&test3pb.TestAllTypes{
|
||||
MapStringString: map[string]string{"key": "val\xff"},
|
||||
}},
|
||||
wire: pack.Message{
|
||||
pack.Tag{69, pack.BytesType}, pack.LengthPrefix(pack.Message{
|
||||
pack.Tag{1, pack.BytesType}, pack.String("key"),
|
||||
pack.Tag{2, pack.BytesType}, pack.String("val\xff"),
|
||||
}),
|
||||
}.Marshal(),
|
||||
},
|
||||
}
|
||||
|
||||
func build(m proto.Message, opts ...buildOpt) proto.Message {
|
||||
for _, opt := range opts {
|
||||
opt(m)
|
||||
@ -1185,3 +1266,17 @@ func marshalText(m proto.Message) string {
|
||||
b, _ := textpb.Marshal(m)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func isErrInvalidUTF8(err error) bool {
|
||||
nerr, ok := err.(errors.NonFatalErrors)
|
||||
if !ok || len(nerr) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, err := range nerr {
|
||||
if e, ok := err.(interface{ InvalidUTF8() bool }); ok && e.InvalidUTF8() {
|
||||
continue
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
@ -182,13 +182,13 @@ func (o MarshalOptions) marshalField(b []byte, field protoreflect.FieldDescripto
|
||||
switch {
|
||||
case field.Cardinality() != protoreflect.Repeated:
|
||||
b = wire.AppendTag(b, num, wireTypes[kind])
|
||||
return o.marshalSingular(b, num, kind, value)
|
||||
return o.marshalSingular(b, num, field, value)
|
||||
case field.IsMap():
|
||||
return o.marshalMap(b, num, kind, field.MessageType(), value.Map())
|
||||
case field.IsPacked():
|
||||
return o.marshalPacked(b, num, kind, value.List())
|
||||
return o.marshalPacked(b, num, field, value.List())
|
||||
default:
|
||||
return o.marshalList(b, num, kind, value.List())
|
||||
return o.marshalList(b, num, field, value.List())
|
||||
}
|
||||
}
|
||||
|
||||
@ -229,13 +229,13 @@ func (o MarshalOptions) rangeMap(mapv protoreflect.Map, kind protoreflect.Kind,
|
||||
mapsort.Range(mapv, kind, f)
|
||||
}
|
||||
|
||||
func (o MarshalOptions) marshalPacked(b []byte, num wire.Number, kind protoreflect.Kind, list protoreflect.List) ([]byte, error) {
|
||||
func (o MarshalOptions) marshalPacked(b []byte, num wire.Number, field protoreflect.FieldDescriptor, list protoreflect.List) ([]byte, error) {
|
||||
b = wire.AppendTag(b, num, wire.BytesType)
|
||||
b, pos := appendSpeculativeLength(b)
|
||||
var nerr errors.NonFatal
|
||||
for i, llen := 0, list.Len(); i < llen; i++ {
|
||||
var err error
|
||||
b, err = o.marshalSingular(b, num, kind, list.Get(i))
|
||||
b, err = o.marshalSingular(b, num, field, list.Get(i))
|
||||
if !nerr.Merge(err) {
|
||||
return b, err
|
||||
}
|
||||
@ -244,12 +244,13 @@ func (o MarshalOptions) marshalPacked(b []byte, num wire.Number, kind protorefle
|
||||
return b, nerr.E
|
||||
}
|
||||
|
||||
func (o MarshalOptions) marshalList(b []byte, num wire.Number, kind protoreflect.Kind, list protoreflect.List) ([]byte, error) {
|
||||
func (o MarshalOptions) marshalList(b []byte, num wire.Number, field protoreflect.FieldDescriptor, list protoreflect.List) ([]byte, error) {
|
||||
kind := field.Kind()
|
||||
var nerr errors.NonFatal
|
||||
for i, llen := 0, list.Len(); i < llen; i++ {
|
||||
var err error
|
||||
b = wire.AppendTag(b, num, wireTypes[kind])
|
||||
b, err = o.marshalSingular(b, num, kind, list.Get(i))
|
||||
b, err = o.marshalSingular(b, num, field, list.Get(i))
|
||||
if !nerr.Merge(err) {
|
||||
return b, err
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ package proto
|
||||
|
||||
import (
|
||||
"math"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/golang/protobuf/v2/internal/encoding/wire"
|
||||
"github.com/golang/protobuf/v2/internal/errors"
|
||||
@ -35,9 +36,9 @@ var wireTypes = map[protoreflect.Kind]wire.Type{
|
||||
protoreflect.GroupKind: wire.StartGroupType,
|
||||
}
|
||||
|
||||
func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, kind protoreflect.Kind, v protoreflect.Value) ([]byte, error) {
|
||||
func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, field protoreflect.FieldDescriptor, v protoreflect.Value) ([]byte, error) {
|
||||
var nerr errors.NonFatal
|
||||
switch kind {
|
||||
switch field.Kind() {
|
||||
case protoreflect.BoolKind:
|
||||
b = wire.AppendVarint(b, wire.EncodeBool(v.Bool()))
|
||||
case protoreflect.EnumKind:
|
||||
@ -67,6 +68,9 @@ func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, kind protoref
|
||||
case protoreflect.DoubleKind:
|
||||
b = wire.AppendFixed64(b, math.Float64bits(v.Float()))
|
||||
case protoreflect.StringKind:
|
||||
if field.Syntax() == protoreflect.Proto3 && !utf8.ValidString(v.String()) {
|
||||
nerr.AppendInvalidUTF8(string(field.FullName()))
|
||||
}
|
||||
b = wire.AppendBytes(b, []byte(v.String()))
|
||||
case protoreflect.BytesKind:
|
||||
b = wire.AppendBytes(b, v.Bytes())
|
||||
@ -87,7 +91,7 @@ func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, kind protoref
|
||||
}
|
||||
b = wire.AppendVarint(b, wire.EncodeTag(num, wire.EndGroupType))
|
||||
default:
|
||||
return b, errors.New("invalid kind %v", kind)
|
||||
return b, errors.New("invalid kind %v", field.Kind())
|
||||
}
|
||||
return b, nerr.E
|
||||
}
|
||||
|
@ -92,6 +92,27 @@ func TestEncodeDeterministic(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeInvalidUTF8(t *testing.T) {
|
||||
for _, test := range invalidUTF8TestProtos {
|
||||
for _, want := range test.decodeTo {
|
||||
t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
|
||||
wire, err := proto.Marshal(want)
|
||||
if !isErrInvalidUTF8(err) {
|
||||
t.Errorf("Marshal did not return expected error for invalid UTF8: %v\nMessage:\n%v", err, marshalText(want))
|
||||
}
|
||||
got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
|
||||
if err := proto.Unmarshal(wire, got); !isErrInvalidUTF8(err) {
|
||||
t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(want))
|
||||
return
|
||||
}
|
||||
if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) {
|
||||
t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeRequiredFieldChecks(t *testing.T) {
|
||||
for _, test := range testProtos {
|
||||
if !test.partial {
|
||||
|
Loading…
Reference in New Issue
Block a user