all: support enforce_utf8 override

In 2014, when proto3 was being developed, there were a number of early
adopters of the new syntax. Before the finalization of proto3 when
it was released in open-source in July 2016, a decision was made to
strictly validate strings in proto3. However, some of the early adopters
were already using invalid UTF-8 with string fields.
The google.protobuf.FieldOptions.enforce_utf8 option only exists to support
those grandfathered users where they can opt-out of the validation logic.
Practical use of that option in open source is impossible even if a user
specifies the proto1_legacy build tag since it requires a hacked
variant of descriptor.proto that is not externally available.

This CL supports enforce_utf8 by modifiyng internal/filedesc to
expose the flag if it detects it in the raw descriptor.
We add an strs.EnforceUTF8 function as a centralized place to determine
whether to perform validation. Validation opt-out is supported
only in builds with legacy support.

We implement support for validating UTF-8 in all proto3 string fields,
even if they are backed by a Go []byte.

Change-Id: I9c0628b84909bc7181125f09db730c80d490e485
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/186002
Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
Joe Tsai 2019-07-13 00:44:41 -07:00
parent 302cb325fb
commit c51e2e0293
13 changed files with 634 additions and 173 deletions

View File

@ -95,7 +95,42 @@ var coder{{.Name}} = pointerCoderFuncs{
unmarshal: consume{{.Name}},
}
// size{{.Name}} returns the size of wire encoding a {{.GoType}} pointer as a {{.Name}}.
{{if or (eq .Name "Bytes") (eq .Name "String")}}
// append{{.Name}}ValidateUTF8 wire encodes a {{.GoType}} pointer as a {{.Name}}.
func append{{.Name}}ValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.{{.GoType.PointerMethod}}()
b = wire.AppendVarint(b, wiretag)
{{template "Append" .}}
if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) {
return b, errInvalidUTF8{}
}
return b, nil
}
// consume{{.Name}}ValidateUTF8 wire decodes a {{.GoType}} pointer as a {{.Name}}.
func consume{{.Name}}ValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
if wtyp != {{.WireType.Expr}} {
return 0, errUnknown
}
v, n := {{template "Consume" .}}
if n < 0 {
return 0, wire.ParseError(n)
}
if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) {
return 0, errInvalidUTF8{}
}
*p.{{.GoType.PointerMethod}}() = {{.ToGoType}}
return n, nil
}
var coder{{.Name}}ValidateUTF8 = pointerCoderFuncs{
size: size{{.Name}},
marshal: append{{.Name}}ValidateUTF8,
unmarshal: consume{{.Name}}ValidateUTF8,
}
{{end}}
// size{{.Name}}NoZero returns the size of wire encoding a {{.GoType}} pointer as a {{.Name}}.
// The zero value is not encoded.
func size{{.Name}}NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.{{.GoType.PointerMethod}}()
@ -105,7 +140,7 @@ func size{{.Name}}NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
return tagsize + {{template "Size" .}}
}
// append{{.Name}} wire encodes a {{.GoType}} pointer as a {{.Name}}.
// append{{.Name}}NoZero wire encodes a {{.GoType}} pointer as a {{.Name}}.
// The zero value is not encoded.
func append{{.Name}}NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.{{.GoType.PointerMethod}}()
@ -123,6 +158,29 @@ var coder{{.Name}}NoZero = pointerCoderFuncs{
unmarshal: consume{{.Name}},
}
{{if or (eq .Name "Bytes") (eq .Name "String")}}
// append{{.Name}}NoZeroValidateUTF8 wire encodes a {{.GoType}} pointer as a {{.Name}}.
// The zero value is not encoded.
func append{{.Name}}NoZeroValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.{{.GoType.PointerMethod}}()
if {{template "IsZero" .}} {
return b, nil
}
b = wire.AppendVarint(b, wiretag)
{{template "Append" .}}
if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) {
return b, errInvalidUTF8{}
}
return b, nil
}
var coder{{.Name}}NoZeroValidateUTF8 = pointerCoderFuncs{
size: size{{.Name}}NoZero,
marshal: append{{.Name}}NoZeroValidateUTF8,
unmarshal: consume{{.Name}}ValidateUTF8,
}
{{end}}
{{- if not .NoPointer}}
// size{{.Name}}Ptr returns the size of wire encoding a *{{.GoType}} pointer as a {{.Name}}.
// It panics if the pointer is nil.
@ -228,6 +286,44 @@ var coder{{.Name}}Slice = pointerCoderFuncs{
unmarshal: consume{{.Name}}Slice,
}
{{if or (eq .Name "Bytes") (eq .Name "String")}}
// append{{.Name}}SliceValidateUTF8 encodes a []{{.GoType}} pointer as a repeated {{.Name}}.
func append{{.Name}}SliceValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
s := *p.{{.GoType.PointerMethod}}Slice()
for _, v := range s {
b = wire.AppendVarint(b, wiretag)
{{template "Append" .}}
if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) {
return b, errInvalidUTF8{}
}
}
return b, nil
}
// consume{{.Name}}SliceValidateUTF8 wire decodes a []{{.GoType}} pointer as a repeated {{.Name}}.
func consume{{.Name}}SliceValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
sp := p.{{.GoType.PointerMethod}}Slice()
if wtyp != {{.WireType.Expr}} {
return 0, errUnknown
}
v, n := {{template "Consume" .}}
if n < 0 {
return 0, wire.ParseError(n)
}
if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) {
return 0, errInvalidUTF8{}
}
*sp = append(*sp, {{.ToGoType}})
return n, nil
}
var coder{{.Name}}SliceValidateUTF8 = pointerCoderFuncs{
size: size{{.Name}}Slice,
marshal: append{{.Name}}SliceValidateUTF8,
unmarshal: consume{{.Name}}SliceValidateUTF8,
}
{{end}}
{{if or (eq .WireType "Varint") (eq .WireType "Fixed32") (eq .WireType "Fixed64")}}
// size{{.Name}}PackedSlice returns the size of wire encoding a []{{.GoType}} pointer as a packed repeated {{.Name}}.
func size{{.Name}}PackedSlice(p pointer, tagsize int, _ marshalOptions) (size int) {
@ -309,6 +405,40 @@ var coder{{.Name}}Iface = ifaceCoderFuncs{
unmarshal: consume{{.Name}}Iface,
}
{{if or (eq .Name "Bytes") (eq .Name "String")}}
// append{{.Name}}IfaceValidateUTF8 encodes a {{.GoType}} value as a {{.Name}}.
func append{{.Name}}IfaceValidateUTF8(b []byte, ival interface{}, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := ival.({{.GoType}})
b = wire.AppendVarint(b, wiretag)
{{template "Append" .}}
if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) {
return b, errInvalidUTF8{}
}
return b, nil
}
// consume{{.Name}}IfaceValidateUTF8 decodes a {{.GoType}} value as a {{.Name}}.
func consume{{.Name}}IfaceValidateUTF8(b []byte, _ interface{}, _ wire.Number, wtyp wire.Type, _ unmarshalOptions) (interface{}, int, error) {
if wtyp != {{.WireType.Expr}} {
return nil, 0, errUnknown
}
v, n := {{template "Consume" .}}
if n < 0 {
return nil, 0, wire.ParseError(n)
}
if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) {
return nil, 0, errInvalidUTF8{}
}
return {{.ToGoType}}, n, nil
}
var coder{{.Name}}IfaceValidateUTF8 = ifaceCoderFuncs{
size: size{{.Name}}Iface,
marshal: append{{.Name}}IfaceValidateUTF8,
unmarshal: consume{{.Name}}IfaceValidateUTF8,
}
{{end}}
// size{{.Name}}SliceIface returns the size of wire encoding a []{{.GoType}} value as a repeated {{.Name}}.
func size{{.Name}}SliceIface(ival interface{}, tagsize int, _ marshalOptions) (size int) {
s := *ival.(*[]{{.GoType}})

View File

@ -191,6 +191,7 @@ func writeSource(file, src string) {
"google.golang.org/protobuf/internal/descfmt",
"google.golang.org/protobuf/internal/encoding/wire",
"google.golang.org/protobuf/internal/errors",
"google.golang.org/protobuf/internal/strs",
"google.golang.org/protobuf/internal/pragma",
"google.golang.org/protobuf/reflect/protoreflect",
"google.golang.org/protobuf/runtime/protoiface",

View File

@ -276,7 +276,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, fd protorefl
return val, 0, wire.ParseError(n)
}
{{if (eq .Name "String") -}}
if fd.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) {
if strs.EnforceUTF8(fd) && !utf8.Valid(v) {
return protoreflect.Value{}, 0, errors.InvalidUTF8(string(fd.FullName()))
}
{{end -}}
@ -320,7 +320,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl
return 0, wire.ParseError(n)
}
{{if (eq .Name "String") -}}
if fd.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) {
if strs.EnforceUTF8(fd) && !utf8.Valid(v) {
return 0, errors.InvalidUTF8(string(fd.FullName()))
}
{{end -}}
@ -357,7 +357,7 @@ func (o MarshalOptions) marshalSingular(b []byte, fd protoreflect.FieldDescripto
{{- range .}}
case {{.Expr}}:
{{- if (eq .Name "String") }}
if fd.Syntax() == protoreflect.Proto3 && !utf8.ValidString(v.String()) {
if strs.EnforceUTF8(fd) && !utf8.ValidString(v.String()) {
return b, errors.InvalidUTF8(string(fd.FullName()))
}
b = wire.AppendString(b, {{.FromValue}})

View File

@ -200,6 +200,8 @@ type (
IsWeak bool // promoted from google.protobuf.FieldOptions
HasPacked bool // promoted from google.protobuf.FieldOptions
IsPacked bool // promoted from google.protobuf.FieldOptions
HasEnforceUTF8 bool // promoted from google.protobuf.FieldOptions
EnforceUTF8 bool // promoted from google.protobuf.FieldOptions
Default defaultValue
ContainingOneof pref.OneofDescriptor // must be consistent with Message.Oneofs.Fields
Enum pref.EnumDescriptor
@ -303,6 +305,20 @@ func (fd *Field) Message() pref.MessageDescriptor { return fd.L1.Message }
func (fd *Field) Format(s fmt.State, r rune) { descfmt.FormatDesc(s, r, fd) }
func (fd *Field) ProtoType(pref.FieldDescriptor) {}
// EnforceUTF8 is a pseudo-internal API to determine whether to enforce UTF-8
// validation for the string field. This exists for Google-internal use only
// since proto3 did not enforce UTF-8 validity prior to the open-source release.
// If this method does not exist, the default is to enforce valid UTF-8.
//
// WARNING: This method is exempt from the compatibility promise and may be
// removed in the future without warning.
func (fd *Field) EnforceUTF8() bool {
if fd.L1.HasEnforceUTF8 {
return fd.L1.EnforceUTF8
}
return fd.L0.ParentFile.L1.Syntax == pref.Proto3
}
func (od *Oneof) Options() pref.ProtoMessage {
if f := od.L1.Options; f != nil {
return f()

View File

@ -480,6 +480,8 @@ func (fd *Field) unmarshalFull(b []byte, sb *strs.Builder, pf *File, pd pref.Des
}
func (fd *Field) unmarshalOptions(b []byte) {
const FieldOptions_EnforceUTF8 = 13
for len(b) > 0 {
num, typ, n := wire.ConsumeTag(b)
b = b[n:]
@ -493,6 +495,9 @@ func (fd *Field) unmarshalOptions(b []byte) {
fd.L1.IsPacked = wire.DecodeBool(v)
case fieldnum.FieldOptions_Weak:
fd.L1.IsWeak = wire.DecodeBool(v)
case FieldOptions_EnforceUTF8:
fd.L1.HasEnforceUTF8 = true
fd.L1.EnforceUTF8 = !wire.DecodeBool(v)
}
default:
m := wire.ConsumeFieldValue(num, typ, b)

View File

@ -6,7 +6,6 @@ package impl
import (
"reflect"
"unicode/utf8"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/proto"
@ -747,136 +746,6 @@ var coderEnumPackedSliceIface = ifaceCoderFuncs{
unmarshal: consumeEnumSliceIface,
}
// Strings with UTF8 validation.
func appendStringValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.String()
b = wire.AppendVarint(b, wiretag)
b = wire.AppendString(b, v)
if !utf8.ValidString(v) {
return b, errInvalidUTF8{}
}
return b, nil
}
func consumeStringValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
if wtyp != wire.BytesType {
return 0, errUnknown
}
v, n := wire.ConsumeString(b)
if n < 0 {
return 0, wire.ParseError(n)
}
if !utf8.ValidString(v) {
return 0, errInvalidUTF8{}
}
*p.String() = v
return n, nil
}
var coderStringValidateUTF8 = pointerCoderFuncs{
size: sizeString,
marshal: appendStringValidateUTF8,
unmarshal: consumeStringValidateUTF8,
}
func appendStringNoZeroValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.String()
if len(v) == 0 {
return b, nil
}
b = wire.AppendVarint(b, wiretag)
b = wire.AppendString(b, v)
if !utf8.ValidString(v) {
return b, errInvalidUTF8{}
}
return b, nil
}
var coderStringNoZeroValidateUTF8 = pointerCoderFuncs{
size: sizeStringNoZero,
marshal: appendStringNoZeroValidateUTF8,
unmarshal: consumeStringValidateUTF8,
}
func sizeStringSliceValidateUTF8(p pointer, tagsize int, _ marshalOptions) (size int) {
s := *p.StringSlice()
for _, v := range s {
size += tagsize + wire.SizeBytes(len(v))
}
return size
}
func appendStringSliceValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
s := *p.StringSlice()
var err error
for _, v := range s {
b = wire.AppendVarint(b, wiretag)
b = wire.AppendString(b, v)
if !utf8.ValidString(v) {
return b, errInvalidUTF8{}
}
}
return b, err
}
func consumeStringSliceValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
if wtyp != wire.BytesType {
return 0, errUnknown
}
sp := p.StringSlice()
v, n := wire.ConsumeString(b)
if n < 0 {
return 0, wire.ParseError(n)
}
if !utf8.ValidString(v) {
return 0, errInvalidUTF8{}
}
*sp = append(*sp, v)
return n, nil
}
var coderStringSliceValidateUTF8 = pointerCoderFuncs{
size: sizeStringSliceValidateUTF8,
marshal: appendStringSliceValidateUTF8,
unmarshal: consumeStringSliceValidateUTF8,
}
func sizeStringIfaceValidateUTF8(ival interface{}, tagsize int, _ marshalOptions) int {
v := ival.(string)
return tagsize + wire.SizeBytes(len(v))
}
func appendStringIfaceValidateUTF8(b []byte, ival interface{}, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := ival.(string)
b = wire.AppendVarint(b, wiretag)
b = wire.AppendString(b, v)
if !utf8.ValidString(v) {
return b, errInvalidUTF8{}
}
return b, nil
}
func consumeStringIfaceValidateUTF8(b []byte, _ interface{}, _ wire.Number, wtyp wire.Type, _ unmarshalOptions) (interface{}, int, error) {
if wtyp != wire.BytesType {
return nil, 0, errUnknown
}
v, n := wire.ConsumeString(b)
if n < 0 {
return nil, 0, wire.ParseError(n)
}
if !utf8.ValidString(v) {
return nil, 0, errInvalidUTF8{}
}
return v, n, nil
}
var coderStringIfaceValidateUTF8 = ifaceCoderFuncs{
size: sizeStringIfaceValidateUTF8,
marshal: appendStringIfaceValidateUTF8,
unmarshal: consumeStringIfaceValidateUTF8,
}
func asMessage(v reflect.Value) pref.ProtoMessage {
if m, ok := v.Interface().(pref.ProtoMessage); ok {
return m

View File

@ -8,6 +8,7 @@ package impl
import (
"math"
"unicode/utf8"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/reflect/protoreflect"
@ -46,7 +47,7 @@ var coderBool = pointerCoderFuncs{
unmarshal: consumeBool,
}
// sizeBool returns the size of wire encoding a bool pointer as a Bool.
// sizeBoolNoZero returns the size of wire encoding a bool pointer as a Bool.
// The zero value is not encoded.
func sizeBoolNoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Bool()
@ -56,7 +57,7 @@ func sizeBoolNoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
return tagsize + wire.SizeVarint(wire.EncodeBool(v))
}
// appendBool wire encodes a bool pointer as a Bool.
// appendBoolNoZero wire encodes a bool pointer as a Bool.
// The zero value is not encoded.
func appendBoolNoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Bool()
@ -364,7 +365,7 @@ var coderInt32 = pointerCoderFuncs{
unmarshal: consumeInt32,
}
// sizeInt32 returns the size of wire encoding a int32 pointer as a Int32.
// sizeInt32NoZero returns the size of wire encoding a int32 pointer as a Int32.
// The zero value is not encoded.
func sizeInt32NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Int32()
@ -374,7 +375,7 @@ func sizeInt32NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
return tagsize + wire.SizeVarint(uint64(v))
}
// appendInt32 wire encodes a int32 pointer as a Int32.
// appendInt32NoZero wire encodes a int32 pointer as a Int32.
// The zero value is not encoded.
func appendInt32NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Int32()
@ -682,7 +683,7 @@ var coderSint32 = pointerCoderFuncs{
unmarshal: consumeSint32,
}
// sizeSint32 returns the size of wire encoding a int32 pointer as a Sint32.
// sizeSint32NoZero returns the size of wire encoding a int32 pointer as a Sint32.
// The zero value is not encoded.
func sizeSint32NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Int32()
@ -692,7 +693,7 @@ func sizeSint32NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
return tagsize + wire.SizeVarint(wire.EncodeZigZag(int64(v)))
}
// appendSint32 wire encodes a int32 pointer as a Sint32.
// appendSint32NoZero wire encodes a int32 pointer as a Sint32.
// The zero value is not encoded.
func appendSint32NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Int32()
@ -1000,7 +1001,7 @@ var coderUint32 = pointerCoderFuncs{
unmarshal: consumeUint32,
}
// sizeUint32 returns the size of wire encoding a uint32 pointer as a Uint32.
// sizeUint32NoZero returns the size of wire encoding a uint32 pointer as a Uint32.
// The zero value is not encoded.
func sizeUint32NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Uint32()
@ -1010,7 +1011,7 @@ func sizeUint32NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
return tagsize + wire.SizeVarint(uint64(v))
}
// appendUint32 wire encodes a uint32 pointer as a Uint32.
// appendUint32NoZero wire encodes a uint32 pointer as a Uint32.
// The zero value is not encoded.
func appendUint32NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Uint32()
@ -1318,7 +1319,7 @@ var coderInt64 = pointerCoderFuncs{
unmarshal: consumeInt64,
}
// sizeInt64 returns the size of wire encoding a int64 pointer as a Int64.
// sizeInt64NoZero returns the size of wire encoding a int64 pointer as a Int64.
// The zero value is not encoded.
func sizeInt64NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Int64()
@ -1328,7 +1329,7 @@ func sizeInt64NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
return tagsize + wire.SizeVarint(uint64(v))
}
// appendInt64 wire encodes a int64 pointer as a Int64.
// appendInt64NoZero wire encodes a int64 pointer as a Int64.
// The zero value is not encoded.
func appendInt64NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Int64()
@ -1636,7 +1637,7 @@ var coderSint64 = pointerCoderFuncs{
unmarshal: consumeSint64,
}
// sizeSint64 returns the size of wire encoding a int64 pointer as a Sint64.
// sizeSint64NoZero returns the size of wire encoding a int64 pointer as a Sint64.
// The zero value is not encoded.
func sizeSint64NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Int64()
@ -1646,7 +1647,7 @@ func sizeSint64NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
return tagsize + wire.SizeVarint(wire.EncodeZigZag(v))
}
// appendSint64 wire encodes a int64 pointer as a Sint64.
// appendSint64NoZero wire encodes a int64 pointer as a Sint64.
// The zero value is not encoded.
func appendSint64NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Int64()
@ -1954,7 +1955,7 @@ var coderUint64 = pointerCoderFuncs{
unmarshal: consumeUint64,
}
// sizeUint64 returns the size of wire encoding a uint64 pointer as a Uint64.
// sizeUint64NoZero returns the size of wire encoding a uint64 pointer as a Uint64.
// The zero value is not encoded.
func sizeUint64NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Uint64()
@ -1964,7 +1965,7 @@ func sizeUint64NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
return tagsize + wire.SizeVarint(v)
}
// appendUint64 wire encodes a uint64 pointer as a Uint64.
// appendUint64NoZero wire encodes a uint64 pointer as a Uint64.
// The zero value is not encoded.
func appendUint64NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Uint64()
@ -2272,7 +2273,7 @@ var coderSfixed32 = pointerCoderFuncs{
unmarshal: consumeSfixed32,
}
// sizeSfixed32 returns the size of wire encoding a int32 pointer as a Sfixed32.
// sizeSfixed32NoZero returns the size of wire encoding a int32 pointer as a Sfixed32.
// The zero value is not encoded.
func sizeSfixed32NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Int32()
@ -2282,7 +2283,7 @@ func sizeSfixed32NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
return tagsize + wire.SizeFixed32()
}
// appendSfixed32 wire encodes a int32 pointer as a Sfixed32.
// appendSfixed32NoZero wire encodes a int32 pointer as a Sfixed32.
// The zero value is not encoded.
func appendSfixed32NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Int32()
@ -2572,7 +2573,7 @@ var coderFixed32 = pointerCoderFuncs{
unmarshal: consumeFixed32,
}
// sizeFixed32 returns the size of wire encoding a uint32 pointer as a Fixed32.
// sizeFixed32NoZero returns the size of wire encoding a uint32 pointer as a Fixed32.
// The zero value is not encoded.
func sizeFixed32NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Uint32()
@ -2582,7 +2583,7 @@ func sizeFixed32NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
return tagsize + wire.SizeFixed32()
}
// appendFixed32 wire encodes a uint32 pointer as a Fixed32.
// appendFixed32NoZero wire encodes a uint32 pointer as a Fixed32.
// The zero value is not encoded.
func appendFixed32NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Uint32()
@ -2872,7 +2873,7 @@ var coderFloat = pointerCoderFuncs{
unmarshal: consumeFloat,
}
// sizeFloat returns the size of wire encoding a float32 pointer as a Float.
// sizeFloatNoZero returns the size of wire encoding a float32 pointer as a Float.
// The zero value is not encoded.
func sizeFloatNoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Float32()
@ -2882,7 +2883,7 @@ func sizeFloatNoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
return tagsize + wire.SizeFixed32()
}
// appendFloat wire encodes a float32 pointer as a Float.
// appendFloatNoZero wire encodes a float32 pointer as a Float.
// The zero value is not encoded.
func appendFloatNoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Float32()
@ -3172,7 +3173,7 @@ var coderSfixed64 = pointerCoderFuncs{
unmarshal: consumeSfixed64,
}
// sizeSfixed64 returns the size of wire encoding a int64 pointer as a Sfixed64.
// sizeSfixed64NoZero returns the size of wire encoding a int64 pointer as a Sfixed64.
// The zero value is not encoded.
func sizeSfixed64NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Int64()
@ -3182,7 +3183,7 @@ func sizeSfixed64NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
return tagsize + wire.SizeFixed64()
}
// appendSfixed64 wire encodes a int64 pointer as a Sfixed64.
// appendSfixed64NoZero wire encodes a int64 pointer as a Sfixed64.
// The zero value is not encoded.
func appendSfixed64NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Int64()
@ -3472,7 +3473,7 @@ var coderFixed64 = pointerCoderFuncs{
unmarshal: consumeFixed64,
}
// sizeFixed64 returns the size of wire encoding a uint64 pointer as a Fixed64.
// sizeFixed64NoZero returns the size of wire encoding a uint64 pointer as a Fixed64.
// The zero value is not encoded.
func sizeFixed64NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Uint64()
@ -3482,7 +3483,7 @@ func sizeFixed64NoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
return tagsize + wire.SizeFixed64()
}
// appendFixed64 wire encodes a uint64 pointer as a Fixed64.
// appendFixed64NoZero wire encodes a uint64 pointer as a Fixed64.
// The zero value is not encoded.
func appendFixed64NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Uint64()
@ -3772,7 +3773,7 @@ var coderDouble = pointerCoderFuncs{
unmarshal: consumeDouble,
}
// sizeDouble returns the size of wire encoding a float64 pointer as a Double.
// sizeDoubleNoZero returns the size of wire encoding a float64 pointer as a Double.
// The zero value is not encoded.
func sizeDoubleNoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Float64()
@ -3782,7 +3783,7 @@ func sizeDoubleNoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
return tagsize + wire.SizeFixed64()
}
// appendDouble wire encodes a float64 pointer as a Double.
// appendDoubleNoZero wire encodes a float64 pointer as a Double.
// The zero value is not encoded.
func appendDoubleNoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Float64()
@ -4072,7 +4073,40 @@ var coderString = pointerCoderFuncs{
unmarshal: consumeString,
}
// sizeString returns the size of wire encoding a string pointer as a String.
// appendStringValidateUTF8 wire encodes a string pointer as a String.
func appendStringValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.String()
b = wire.AppendVarint(b, wiretag)
b = wire.AppendString(b, v)
if !utf8.ValidString(v) {
return b, errInvalidUTF8{}
}
return b, nil
}
// consumeStringValidateUTF8 wire decodes a string pointer as a String.
func consumeStringValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
if wtyp != wire.BytesType {
return 0, errUnknown
}
v, n := wire.ConsumeString(b)
if n < 0 {
return 0, wire.ParseError(n)
}
if !utf8.ValidString(v) {
return 0, errInvalidUTF8{}
}
*p.String() = v
return n, nil
}
var coderStringValidateUTF8 = pointerCoderFuncs{
size: sizeString,
marshal: appendStringValidateUTF8,
unmarshal: consumeStringValidateUTF8,
}
// sizeStringNoZero returns the size of wire encoding a string pointer as a String.
// The zero value is not encoded.
func sizeStringNoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.String()
@ -4082,7 +4116,7 @@ func sizeStringNoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
return tagsize + wire.SizeBytes(len(v))
}
// appendString wire encodes a string pointer as a String.
// appendStringNoZero wire encodes a string pointer as a String.
// The zero value is not encoded.
func appendStringNoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.String()
@ -4100,6 +4134,27 @@ var coderStringNoZero = pointerCoderFuncs{
unmarshal: consumeString,
}
// appendStringNoZeroValidateUTF8 wire encodes a string pointer as a String.
// The zero value is not encoded.
func appendStringNoZeroValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.String()
if len(v) == 0 {
return b, nil
}
b = wire.AppendVarint(b, wiretag)
b = wire.AppendString(b, v)
if !utf8.ValidString(v) {
return b, errInvalidUTF8{}
}
return b, nil
}
var coderStringNoZeroValidateUTF8 = pointerCoderFuncs{
size: sizeStringNoZero,
marshal: appendStringNoZeroValidateUTF8,
unmarshal: consumeStringValidateUTF8,
}
// sizeStringPtr returns the size of wire encoding a *string pointer as a String.
// It panics if the pointer is nil.
func sizeStringPtr(p pointer, tagsize int, _ marshalOptions) (size int) {
@ -4178,6 +4233,42 @@ var coderStringSlice = pointerCoderFuncs{
unmarshal: consumeStringSlice,
}
// appendStringSliceValidateUTF8 encodes a []string pointer as a repeated String.
func appendStringSliceValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
s := *p.StringSlice()
for _, v := range s {
b = wire.AppendVarint(b, wiretag)
b = wire.AppendString(b, v)
if !utf8.ValidString(v) {
return b, errInvalidUTF8{}
}
}
return b, nil
}
// consumeStringSliceValidateUTF8 wire decodes a []string pointer as a repeated String.
func consumeStringSliceValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
sp := p.StringSlice()
if wtyp != wire.BytesType {
return 0, errUnknown
}
v, n := wire.ConsumeString(b)
if n < 0 {
return 0, wire.ParseError(n)
}
if !utf8.ValidString(v) {
return 0, errInvalidUTF8{}
}
*sp = append(*sp, v)
return n, nil
}
var coderStringSliceValidateUTF8 = pointerCoderFuncs{
size: sizeStringSlice,
marshal: appendStringSliceValidateUTF8,
unmarshal: consumeStringSliceValidateUTF8,
}
// sizeStringIface returns the size of wire encoding a string value as a String.
func sizeStringIface(ival interface{}, tagsize int, _ marshalOptions) int {
v := ival.(string)
@ -4210,6 +4301,38 @@ var coderStringIface = ifaceCoderFuncs{
unmarshal: consumeStringIface,
}
// appendStringIfaceValidateUTF8 encodes a string value as a String.
func appendStringIfaceValidateUTF8(b []byte, ival interface{}, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := ival.(string)
b = wire.AppendVarint(b, wiretag)
b = wire.AppendString(b, v)
if !utf8.ValidString(v) {
return b, errInvalidUTF8{}
}
return b, nil
}
// consumeStringIfaceValidateUTF8 decodes a string value as a String.
func consumeStringIfaceValidateUTF8(b []byte, _ interface{}, _ wire.Number, wtyp wire.Type, _ unmarshalOptions) (interface{}, int, error) {
if wtyp != wire.BytesType {
return nil, 0, errUnknown
}
v, n := wire.ConsumeString(b)
if n < 0 {
return nil, 0, wire.ParseError(n)
}
if !utf8.ValidString(v) {
return nil, 0, errInvalidUTF8{}
}
return v, n, nil
}
var coderStringIfaceValidateUTF8 = ifaceCoderFuncs{
size: sizeStringIface,
marshal: appendStringIfaceValidateUTF8,
unmarshal: consumeStringIfaceValidateUTF8,
}
// sizeStringSliceIface returns the size of wire encoding a []string value as a repeated String.
func sizeStringSliceIface(ival interface{}, tagsize int, _ marshalOptions) (size int) {
s := *ival.(*[]string)
@ -4282,7 +4405,40 @@ var coderBytes = pointerCoderFuncs{
unmarshal: consumeBytes,
}
// sizeBytes returns the size of wire encoding a []byte pointer as a Bytes.
// appendBytesValidateUTF8 wire encodes a []byte pointer as a Bytes.
func appendBytesValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Bytes()
b = wire.AppendVarint(b, wiretag)
b = wire.AppendBytes(b, v)
if !utf8.Valid(v) {
return b, errInvalidUTF8{}
}
return b, nil
}
// consumeBytesValidateUTF8 wire decodes a []byte pointer as a Bytes.
func consumeBytesValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
if wtyp != wire.BytesType {
return 0, errUnknown
}
v, n := wire.ConsumeBytes(b)
if n < 0 {
return 0, wire.ParseError(n)
}
if !utf8.Valid(v) {
return 0, errInvalidUTF8{}
}
*p.Bytes() = append(([]byte)(nil), v...)
return n, nil
}
var coderBytesValidateUTF8 = pointerCoderFuncs{
size: sizeBytes,
marshal: appendBytesValidateUTF8,
unmarshal: consumeBytesValidateUTF8,
}
// sizeBytesNoZero returns the size of wire encoding a []byte pointer as a Bytes.
// The zero value is not encoded.
func sizeBytesNoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
v := *p.Bytes()
@ -4292,7 +4448,7 @@ func sizeBytesNoZero(p pointer, tagsize int, _ marshalOptions) (size int) {
return tagsize + wire.SizeBytes(len(v))
}
// appendBytes wire encodes a []byte pointer as a Bytes.
// appendBytesNoZero wire encodes a []byte pointer as a Bytes.
// The zero value is not encoded.
func appendBytesNoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Bytes()
@ -4310,6 +4466,27 @@ var coderBytesNoZero = pointerCoderFuncs{
unmarshal: consumeBytes,
}
// appendBytesNoZeroValidateUTF8 wire encodes a []byte pointer as a Bytes.
// The zero value is not encoded.
func appendBytesNoZeroValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := *p.Bytes()
if len(v) == 0 {
return b, nil
}
b = wire.AppendVarint(b, wiretag)
b = wire.AppendBytes(b, v)
if !utf8.Valid(v) {
return b, errInvalidUTF8{}
}
return b, nil
}
var coderBytesNoZeroValidateUTF8 = pointerCoderFuncs{
size: sizeBytesNoZero,
marshal: appendBytesNoZeroValidateUTF8,
unmarshal: consumeBytesValidateUTF8,
}
// sizeBytesSlice returns the size of wire encoding a [][]byte pointer as a repeated Bytes.
func sizeBytesSlice(p pointer, tagsize int, _ marshalOptions) (size int) {
s := *p.BytesSlice()
@ -4349,6 +4526,42 @@ var coderBytesSlice = pointerCoderFuncs{
unmarshal: consumeBytesSlice,
}
// appendBytesSliceValidateUTF8 encodes a [][]byte pointer as a repeated Bytes.
func appendBytesSliceValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
s := *p.BytesSlice()
for _, v := range s {
b = wire.AppendVarint(b, wiretag)
b = wire.AppendBytes(b, v)
if !utf8.Valid(v) {
return b, errInvalidUTF8{}
}
}
return b, nil
}
// consumeBytesSliceValidateUTF8 wire decodes a [][]byte pointer as a repeated Bytes.
func consumeBytesSliceValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
sp := p.BytesSlice()
if wtyp != wire.BytesType {
return 0, errUnknown
}
v, n := wire.ConsumeBytes(b)
if n < 0 {
return 0, wire.ParseError(n)
}
if !utf8.Valid(v) {
return 0, errInvalidUTF8{}
}
*sp = append(*sp, append(([]byte)(nil), v...))
return n, nil
}
var coderBytesSliceValidateUTF8 = pointerCoderFuncs{
size: sizeBytesSlice,
marshal: appendBytesSliceValidateUTF8,
unmarshal: consumeBytesSliceValidateUTF8,
}
// sizeBytesIface returns the size of wire encoding a []byte value as a Bytes.
func sizeBytesIface(ival interface{}, tagsize int, _ marshalOptions) int {
v := ival.([]byte)
@ -4381,6 +4594,38 @@ var coderBytesIface = ifaceCoderFuncs{
unmarshal: consumeBytesIface,
}
// appendBytesIfaceValidateUTF8 encodes a []byte value as a Bytes.
func appendBytesIfaceValidateUTF8(b []byte, ival interface{}, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := ival.([]byte)
b = wire.AppendVarint(b, wiretag)
b = wire.AppendBytes(b, v)
if !utf8.Valid(v) {
return b, errInvalidUTF8{}
}
return b, nil
}
// consumeBytesIfaceValidateUTF8 decodes a []byte value as a Bytes.
func consumeBytesIfaceValidateUTF8(b []byte, _ interface{}, _ wire.Number, wtyp wire.Type, _ unmarshalOptions) (interface{}, int, error) {
if wtyp != wire.BytesType {
return nil, 0, errUnknown
}
v, n := wire.ConsumeBytes(b)
if n < 0 {
return nil, 0, wire.ParseError(n)
}
if !utf8.Valid(v) {
return nil, 0, errInvalidUTF8{}
}
return append(([]byte)(nil), v...), n, nil
}
var coderBytesIfaceValidateUTF8 = ifaceCoderFuncs{
size: sizeBytesIface,
marshal: appendBytesIfaceValidateUTF8,
unmarshal: consumeBytesIfaceValidateUTF8,
}
// sizeBytesSliceIface returns the size of wire encoding a [][]byte value as a repeated Bytes.
func sizeBytesSliceIface(ival interface{}, tagsize int, _ marshalOptions) (size int) {
s := *ival.(*[][]byte)

View File

@ -9,6 +9,7 @@ import (
"reflect"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/strs"
pref "google.golang.org/protobuf/reflect/protoreflect"
)
@ -98,12 +99,15 @@ func fieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
return coderDoubleSlice
}
case pref.StringKind:
if ft.Kind() == reflect.String && fd.Syntax() == pref.Proto3 {
if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) {
return coderStringSliceValidateUTF8
}
if ft.Kind() == reflect.String {
return coderStringSlice
}
if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 && strs.EnforceUTF8(fd) {
return coderBytesSliceValidateUTF8
}
if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 {
return coderBytesSlice
}
@ -251,9 +255,15 @@ func fieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
return coderDoubleNoZero
}
case pref.StringKind:
if ft.Kind() == reflect.String {
if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) {
return coderStringNoZeroValidateUTF8
}
if ft.Kind() == reflect.String {
return coderStringNoZero
}
if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 && strs.EnforceUTF8(fd) {
return coderBytesNoZeroValidateUTF8
}
if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 {
return coderBytesNoZero
}
@ -392,12 +402,15 @@ func fieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
return coderDouble
}
case pref.StringKind:
if fd.Syntax() == pref.Proto3 && ft.Kind() == reflect.String {
if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) {
return coderStringValidateUTF8
}
if ft.Kind() == reflect.String {
return coderString
}
if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 && strs.EnforceUTF8(fd) {
return coderBytesValidateUTF8
}
if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 {
return coderBytes
}
@ -620,12 +633,15 @@ func encoderFuncsForValue(fd pref.FieldDescriptor, ft reflect.Type) ifaceCoderFu
return coderDoubleIface
}
case pref.StringKind:
if fd.Syntax() == pref.Proto3 && ft.Kind() == reflect.String {
if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) {
return coderStringIfaceValidateUTF8
}
if ft.Kind() == reflect.String {
return coderStringIface
}
if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 && strs.EnforceUTF8(fd) {
return coderBytesIfaceValidateUTF8
}
if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 {
return coderBytesIface
}

View File

@ -8,8 +8,21 @@ package strs
import (
"strings"
"unicode"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/reflect/protoreflect"
)
// EnforceUTF8 reports whether to enforce strict UTF-8 validation.
func EnforceUTF8(fd protoreflect.FieldDescriptor) bool {
if flags.Proto1Legacy {
if fd, ok := fd.(interface{ EnforceUTF8() bool }); ok {
return fd.EnforceUTF8()
}
}
return fd.Syntax() == protoreflect.Proto3
}
// JSONCamelCase converts a snake_case identifier to a camelCase identifier,
// according to the protobuf JSON specification.
func JSONCamelCase(s string) string {

View File

@ -12,6 +12,7 @@ import (
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/strs"
"google.golang.org/protobuf/reflect/protoreflect"
)
@ -154,7 +155,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, fd protorefl
if n < 0 {
return val, 0, wire.ParseError(n)
}
if fd.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) {
if strs.EnforceUTF8(fd) && !utf8.Valid(v) {
return protoreflect.Value{}, 0, errors.InvalidUTF8(string(fd.FullName()))
}
return protoreflect.ValueOf(string(v)), n, nil
@ -550,7 +551,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl
if n < 0 {
return 0, wire.ParseError(n)
}
if fd.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) {
if strs.EnforceUTF8(fd) && !utf8.Valid(v) {
return 0, errors.InvalidUTF8(string(fd.FullName()))
}
list.Append(protoreflect.ValueOf(string(v)))

View File

@ -12,13 +12,20 @@ import (
protoV1 "github.com/golang/protobuf/proto"
"google.golang.org/protobuf/encoding/prototext"
"google.golang.org/protobuf/internal/encoding/pack"
"google.golang.org/protobuf/internal/filedesc"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
pref "google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/prototype"
"google.golang.org/protobuf/runtime/protoimpl"
legacypb "google.golang.org/protobuf/internal/testprotos/legacy"
legacy1pb "google.golang.org/protobuf/internal/testprotos/legacy/proto2.v0.0.0-20160225-2fc053c5"
testpb "google.golang.org/protobuf/internal/testprotos/test"
test3pb "google.golang.org/protobuf/internal/testprotos/test3"
"google.golang.org/protobuf/types/descriptorpb"
)
type testProto struct {
@ -85,6 +92,23 @@ func TestDecodeInvalidUTF8(t *testing.T) {
}
}
func TestDecodeNoEnforceUTF8(t *testing.T) {
for _, test := range noEnforceUTF8TestProtos {
for _, want := range test.decodeTo {
t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
err := proto.Unmarshal(test.wire, got)
switch {
case flags.Proto1Legacy && err != nil:
t.Errorf("Unmarshal returned unexpected error: %v\nMessage:\n%v", err, marshalText(want))
case !flags.Proto1Legacy && err == nil:
t.Errorf("Unmarshal did not return expected error for invalid UTF8: %v\nMessage:\n%v", err, marshalText(want))
}
})
}
}
}
var testProtos = []testProto{
{
desc: "basic scalar types",
@ -1442,6 +1466,129 @@ var invalidUTF8TestProtos = []testProto{
},
}
var noEnforceUTF8TestProtos = []testProto{
{
desc: "invalid UTF-8 in optional string field",
decodeTo: []proto.Message{&TestNoEnforceUTF8{
OptionalString: string("abc\xff"),
}},
wire: pack.Message{
pack.Tag{1, pack.BytesType}, pack.String("abc\xff"),
}.Marshal(),
},
{
desc: "invalid UTF-8 in optional string field of Go bytes",
decodeTo: []proto.Message{&TestNoEnforceUTF8{
OptionalBytes: []byte("abc\xff"),
}},
wire: pack.Message{
pack.Tag{2, pack.BytesType}, pack.String("abc\xff"),
}.Marshal(),
},
{
desc: "invalid UTF-8 in repeated string field",
decodeTo: []proto.Message{&TestNoEnforceUTF8{
RepeatedString: []string{string("foo"), string("abc\xff")},
}},
wire: pack.Message{
pack.Tag{3, pack.BytesType}, pack.String("foo"),
pack.Tag{3, pack.BytesType}, pack.String("abc\xff"),
}.Marshal(),
},
{
desc: "invalid UTF-8 in repeated string field of Go bytes",
decodeTo: []proto.Message{&TestNoEnforceUTF8{
RepeatedBytes: [][]byte{[]byte("foo"), []byte("abc\xff")},
}},
wire: pack.Message{
pack.Tag{4, pack.BytesType}, pack.String("foo"),
pack.Tag{4, pack.BytesType}, pack.String("abc\xff"),
}.Marshal(),
},
{
desc: "invalid UTF-8 in oneof string field",
decodeTo: []proto.Message{
&TestNoEnforceUTF8{OneofField: &TestNoEnforceUTF8_OneofString{string("abc\xff")}},
},
wire: pack.Message{pack.Tag{5, pack.BytesType}, pack.String("abc\xff")}.Marshal(),
},
{
desc: "invalid UTF-8 in oneof string field of Go bytes",
decodeTo: []proto.Message{
&TestNoEnforceUTF8{OneofField: &TestNoEnforceUTF8_OneofBytes{[]byte("abc\xff")}},
},
wire: pack.Message{pack.Tag{6, pack.BytesType}, pack.String("abc\xff")}.Marshal(),
},
}
type TestNoEnforceUTF8 struct {
OptionalString string `protobuf:"bytes,1,opt,name=optional_string"`
OptionalBytes []byte `protobuf:"bytes,2,opt,name=optional_bytes"`
RepeatedString []string `protobuf:"bytes,3,rep,name=repeated_string"`
RepeatedBytes [][]byte `protobuf:"bytes,4,rep,name=repeated_bytes"`
OneofField isOneofField `protobuf_oneof:"oneof_field"`
}
type isOneofField interface{ isOneofField() }
type TestNoEnforceUTF8_OneofString struct {
OneofString string `protobuf:"bytes,5,opt,name=oneof_string,oneof"`
}
type TestNoEnforceUTF8_OneofBytes struct {
OneofBytes []byte `protobuf:"bytes,6,opt,name=oneof_bytes,oneof"`
}
func (*TestNoEnforceUTF8_OneofString) isOneofField() {}
func (*TestNoEnforceUTF8_OneofBytes) isOneofField() {}
func (m *TestNoEnforceUTF8) ProtoReflect() pref.Message {
return messageInfo_TestNoEnforceUTF8.MessageOf(m)
}
var messageInfo_TestNoEnforceUTF8 = protoimpl.MessageInfo{
GoType: reflect.TypeOf((*TestNoEnforceUTF8)(nil)),
PBType: &prototype.Message{
MessageDescriptor: func() protoreflect.MessageDescriptor {
pb := new(descriptorpb.FileDescriptorProto)
if err := prototext.Unmarshal([]byte(`
syntax: "proto3"
name: "test.proto"
message_type: [{
name: "TestNoEnforceUTF8"
field: [
{name:"optional_string" number:1 label:LABEL_OPTIONAL type:TYPE_STRING},
{name:"optional_bytes" number:2 label:LABEL_OPTIONAL type:TYPE_STRING},
{name:"repeated_string" number:3 label:LABEL_REPEATED type:TYPE_STRING},
{name:"repeated_bytes" number:4 label:LABEL_REPEATED type:TYPE_STRING},
{name:"oneof_string" number:5 label:LABEL_OPTIONAL type:TYPE_STRING, oneof_index:0},
{name:"oneof_bytes" number:6 label:LABEL_OPTIONAL type:TYPE_STRING, oneof_index:0}
]
oneof_decl: [{name:"oneof_field"}]
}]
`), pb); err != nil {
panic(err)
}
fd, err := protodesc.NewFile(pb, nil)
if err != nil {
panic(err)
}
md := fd.Messages().Get(0)
for i := 0; i < md.Fields().Len(); i++ {
md.Fields().Get(i).(*filedesc.Field).L1.HasEnforceUTF8 = true
md.Fields().Get(i).(*filedesc.Field).L1.EnforceUTF8 = false
}
return md
}(),
NewMessage: func() pref.Message {
return pref.ProtoMessage(new(TestNoEnforceUTF8)).ProtoReflect()
},
},
OneofWrappers: []interface{}{
(*TestNoEnforceUTF8_OneofString)(nil),
(*TestNoEnforceUTF8_OneofBytes)(nil),
},
}
func build(m proto.Message, opts ...buildOpt) proto.Message {
for _, opt := range opts {
opt(m)

View File

@ -12,6 +12,7 @@ import (
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/strs"
"google.golang.org/protobuf/reflect/protoreflect"
)
@ -67,7 +68,7 @@ func (o MarshalOptions) marshalSingular(b []byte, fd protoreflect.FieldDescripto
case protoreflect.DoubleKind:
b = wire.AppendFixed64(b, math.Float64bits(v.Float()))
case protoreflect.StringKind:
if fd.Syntax() == protoreflect.Proto3 && !utf8.ValidString(v.String()) {
if strs.EnforceUTF8(fd) && !utf8.ValidString(v.String()) {
return b, errors.InvalidUTF8(string(fd.FullName()))
}
b = wire.AppendString(b, v.String())

View File

@ -10,6 +10,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/proto"
test3pb "google.golang.org/protobuf/internal/testprotos/test3"
@ -97,6 +98,22 @@ func TestEncodeInvalidUTF8(t *testing.T) {
}
}
func TestEncodeNoEnforceUTF8(t *testing.T) {
for _, test := range noEnforceUTF8TestProtos {
for _, want := range test.decodeTo {
t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
_, err := proto.Marshal(want)
switch {
case flags.Proto1Legacy && err != nil:
t.Errorf("Marshal returned unexpected error: %v\nMessage:\n%v", err, marshalText(want))
case !flags.Proto1Legacy && err == nil:
t.Errorf("Marshal did not return expected error for invalid UTF8: %v\nMessage:\n%v", err, marshalText(want))
}
})
}
}
}
func TestEncodeRequiredFieldChecks(t *testing.T) {
for _, test := range testProtos {
if !test.partial {