From c51e2e0293388cb210278960154ed5adf3569b3e Mon Sep 17 00:00:00 2001 From: Joe Tsai Date: Sat, 13 Jul 2019 00:44:41 -0700 Subject: [PATCH] all: support enforce_utf8 override In 2014, when proto3 was being developed, there were a number of early adopters of the new syntax. Before the finalization of proto3 when it was released in open-source in July 2016, a decision was made to strictly validate strings in proto3. However, some of the early adopters were already using invalid UTF-8 with string fields. The google.protobuf.FieldOptions.enforce_utf8 option only exists to support those grandfathered users where they can opt-out of the validation logic. Practical use of that option in open source is impossible even if a user specifies the proto1_legacy build tag since it requires a hacked variant of descriptor.proto that is not externally available. This CL supports enforce_utf8 by modifiyng internal/filedesc to expose the flag if it detects it in the raw descriptor. We add an strs.EnforceUTF8 function as a centralized place to determine whether to perform validation. Validation opt-out is supported only in builds with legacy support. We implement support for validating UTF-8 in all proto3 string fields, even if they are backed by a Go []byte. Change-Id: I9c0628b84909bc7181125f09db730c80d490e485 Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/186002 Reviewed-by: Damien Neil --- internal/cmd/generate-types/impl.go | 134 +++++++++++- internal/cmd/generate-types/main.go | 1 + internal/cmd/generate-types/proto.go | 6 +- internal/filedesc/desc.go | 16 ++ internal/filedesc/desc_lazy.go | 5 + internal/impl/codec_field.go | 131 ------------ internal/impl/codec_gen.go | 305 ++++++++++++++++++++++++--- internal/impl/codec_tables.go | 24 ++- internal/strs/strings.go | 13 ++ proto/decode_gen.go | 5 +- proto/decode_test.go | 147 +++++++++++++ proto/encode_gen.go | 3 +- proto/encode_test.go | 17 ++ 13 files changed, 634 insertions(+), 173 deletions(-) diff --git a/internal/cmd/generate-types/impl.go b/internal/cmd/generate-types/impl.go index a92c7ea0..99192433 100644 --- a/internal/cmd/generate-types/impl.go +++ b/internal/cmd/generate-types/impl.go @@ -95,7 +95,42 @@ var coder{{.Name}} = pointerCoderFuncs{ unmarshal: consume{{.Name}}, } -// size{{.Name}} returns the size of wire encoding a {{.GoType}} pointer as a {{.Name}}. +{{if or (eq .Name "Bytes") (eq .Name "String")}} +// append{{.Name}}ValidateUTF8 wire encodes a {{.GoType}} pointer as a {{.Name}}. +func append{{.Name}}ValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { + v := *p.{{.GoType.PointerMethod}}() + b = wire.AppendVarint(b, wiretag) + {{template "Append" .}} + if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) { + return b, errInvalidUTF8{} + } + return b, nil +} + +// consume{{.Name}}ValidateUTF8 wire decodes a {{.GoType}} pointer as a {{.Name}}. +func consume{{.Name}}ValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) { + if wtyp != {{.WireType.Expr}} { + return 0, errUnknown + } + v, n := {{template "Consume" .}} + if n < 0 { + return 0, wire.ParseError(n) + } + if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) { + return 0, errInvalidUTF8{} + } + *p.{{.GoType.PointerMethod}}() = {{.ToGoType}} + return n, nil +} + +var coder{{.Name}}ValidateUTF8 = pointerCoderFuncs{ + size: size{{.Name}}, + marshal: append{{.Name}}ValidateUTF8, + unmarshal: consume{{.Name}}ValidateUTF8, +} +{{end}} + +// size{{.Name}}NoZero returns the size of wire encoding a {{.GoType}} pointer as a {{.Name}}. // The zero value is not encoded. func size{{.Name}}NoZero(p pointer, tagsize int, _ marshalOptions) (size int) { v := *p.{{.GoType.PointerMethod}}() @@ -105,7 +140,7 @@ func size{{.Name}}NoZero(p pointer, tagsize int, _ marshalOptions) (size int) { return tagsize + {{template "Size" .}} } -// append{{.Name}} wire encodes a {{.GoType}} pointer as a {{.Name}}. +// append{{.Name}}NoZero wire encodes a {{.GoType}} pointer as a {{.Name}}. // The zero value is not encoded. func append{{.Name}}NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { v := *p.{{.GoType.PointerMethod}}() @@ -123,6 +158,29 @@ var coder{{.Name}}NoZero = pointerCoderFuncs{ unmarshal: consume{{.Name}}, } +{{if or (eq .Name "Bytes") (eq .Name "String")}} +// append{{.Name}}NoZeroValidateUTF8 wire encodes a {{.GoType}} pointer as a {{.Name}}. +// The zero value is not encoded. +func append{{.Name}}NoZeroValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { + v := *p.{{.GoType.PointerMethod}}() + if {{template "IsZero" .}} { + return b, nil + } + b = wire.AppendVarint(b, wiretag) + {{template "Append" .}} + if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) { + return b, errInvalidUTF8{} + } + return b, nil +} + +var coder{{.Name}}NoZeroValidateUTF8 = pointerCoderFuncs{ + size: size{{.Name}}NoZero, + marshal: append{{.Name}}NoZeroValidateUTF8, + unmarshal: consume{{.Name}}ValidateUTF8, +} +{{end}} + {{- if not .NoPointer}} // size{{.Name}}Ptr returns the size of wire encoding a *{{.GoType}} pointer as a {{.Name}}. // It panics if the pointer is nil. @@ -228,6 +286,44 @@ var coder{{.Name}}Slice = pointerCoderFuncs{ unmarshal: consume{{.Name}}Slice, } +{{if or (eq .Name "Bytes") (eq .Name "String")}} +// append{{.Name}}SliceValidateUTF8 encodes a []{{.GoType}} pointer as a repeated {{.Name}}. +func append{{.Name}}SliceValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { + s := *p.{{.GoType.PointerMethod}}Slice() + for _, v := range s { + b = wire.AppendVarint(b, wiretag) + {{template "Append" .}} + if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) { + return b, errInvalidUTF8{} + } + } + return b, nil +} + +// consume{{.Name}}SliceValidateUTF8 wire decodes a []{{.GoType}} pointer as a repeated {{.Name}}. +func consume{{.Name}}SliceValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) { + sp := p.{{.GoType.PointerMethod}}Slice() + if wtyp != {{.WireType.Expr}} { + return 0, errUnknown + } + v, n := {{template "Consume" .}} + if n < 0 { + return 0, wire.ParseError(n) + } + if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) { + return 0, errInvalidUTF8{} + } + *sp = append(*sp, {{.ToGoType}}) + return n, nil +} + +var coder{{.Name}}SliceValidateUTF8 = pointerCoderFuncs{ + size: size{{.Name}}Slice, + marshal: append{{.Name}}SliceValidateUTF8, + unmarshal: consume{{.Name}}SliceValidateUTF8, +} +{{end}} + {{if or (eq .WireType "Varint") (eq .WireType "Fixed32") (eq .WireType "Fixed64")}} // size{{.Name}}PackedSlice returns the size of wire encoding a []{{.GoType}} pointer as a packed repeated {{.Name}}. func size{{.Name}}PackedSlice(p pointer, tagsize int, _ marshalOptions) (size int) { @@ -309,6 +405,40 @@ var coder{{.Name}}Iface = ifaceCoderFuncs{ unmarshal: consume{{.Name}}Iface, } +{{if or (eq .Name "Bytes") (eq .Name "String")}} +// append{{.Name}}IfaceValidateUTF8 encodes a {{.GoType}} value as a {{.Name}}. +func append{{.Name}}IfaceValidateUTF8(b []byte, ival interface{}, wiretag uint64, _ marshalOptions) ([]byte, error) { + v := ival.({{.GoType}}) + b = wire.AppendVarint(b, wiretag) + {{template "Append" .}} + if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) { + return b, errInvalidUTF8{} + } + return b, nil +} + +// consume{{.Name}}IfaceValidateUTF8 decodes a {{.GoType}} value as a {{.Name}}. +func consume{{.Name}}IfaceValidateUTF8(b []byte, _ interface{}, _ wire.Number, wtyp wire.Type, _ unmarshalOptions) (interface{}, int, error) { + if wtyp != {{.WireType.Expr}} { + return nil, 0, errUnknown + } + v, n := {{template "Consume" .}} + if n < 0 { + return nil, 0, wire.ParseError(n) + } + if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) { + return nil, 0, errInvalidUTF8{} + } + return {{.ToGoType}}, n, nil +} + +var coder{{.Name}}IfaceValidateUTF8 = ifaceCoderFuncs{ + size: size{{.Name}}Iface, + marshal: append{{.Name}}IfaceValidateUTF8, + unmarshal: consume{{.Name}}IfaceValidateUTF8, +} +{{end}} + // size{{.Name}}SliceIface returns the size of wire encoding a []{{.GoType}} value as a repeated {{.Name}}. func size{{.Name}}SliceIface(ival interface{}, tagsize int, _ marshalOptions) (size int) { s := *ival.(*[]{{.GoType}}) diff --git a/internal/cmd/generate-types/main.go b/internal/cmd/generate-types/main.go index 6490b5c6..6d19508d 100644 --- a/internal/cmd/generate-types/main.go +++ b/internal/cmd/generate-types/main.go @@ -191,6 +191,7 @@ func writeSource(file, src string) { "google.golang.org/protobuf/internal/descfmt", "google.golang.org/protobuf/internal/encoding/wire", "google.golang.org/protobuf/internal/errors", + "google.golang.org/protobuf/internal/strs", "google.golang.org/protobuf/internal/pragma", "google.golang.org/protobuf/reflect/protoreflect", "google.golang.org/protobuf/runtime/protoiface", diff --git a/internal/cmd/generate-types/proto.go b/internal/cmd/generate-types/proto.go index e507b031..023cde8a 100644 --- a/internal/cmd/generate-types/proto.go +++ b/internal/cmd/generate-types/proto.go @@ -276,7 +276,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, fd protorefl return val, 0, wire.ParseError(n) } {{if (eq .Name "String") -}} - if fd.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) { + if strs.EnforceUTF8(fd) && !utf8.Valid(v) { return protoreflect.Value{}, 0, errors.InvalidUTF8(string(fd.FullName())) } {{end -}} @@ -320,7 +320,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl return 0, wire.ParseError(n) } {{if (eq .Name "String") -}} - if fd.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) { + if strs.EnforceUTF8(fd) && !utf8.Valid(v) { return 0, errors.InvalidUTF8(string(fd.FullName())) } {{end -}} @@ -357,7 +357,7 @@ func (o MarshalOptions) marshalSingular(b []byte, fd protoreflect.FieldDescripto {{- range .}} case {{.Expr}}: {{- if (eq .Name "String") }} - if fd.Syntax() == protoreflect.Proto3 && !utf8.ValidString(v.String()) { + if strs.EnforceUTF8(fd) && !utf8.ValidString(v.String()) { return b, errors.InvalidUTF8(string(fd.FullName())) } b = wire.AppendString(b, {{.FromValue}}) diff --git a/internal/filedesc/desc.go b/internal/filedesc/desc.go index 59984e87..d42bcd72 100644 --- a/internal/filedesc/desc.go +++ b/internal/filedesc/desc.go @@ -200,6 +200,8 @@ type ( IsWeak bool // promoted from google.protobuf.FieldOptions HasPacked bool // promoted from google.protobuf.FieldOptions IsPacked bool // promoted from google.protobuf.FieldOptions + HasEnforceUTF8 bool // promoted from google.protobuf.FieldOptions + EnforceUTF8 bool // promoted from google.protobuf.FieldOptions Default defaultValue ContainingOneof pref.OneofDescriptor // must be consistent with Message.Oneofs.Fields Enum pref.EnumDescriptor @@ -303,6 +305,20 @@ func (fd *Field) Message() pref.MessageDescriptor { return fd.L1.Message } func (fd *Field) Format(s fmt.State, r rune) { descfmt.FormatDesc(s, r, fd) } func (fd *Field) ProtoType(pref.FieldDescriptor) {} +// EnforceUTF8 is a pseudo-internal API to determine whether to enforce UTF-8 +// validation for the string field. This exists for Google-internal use only +// since proto3 did not enforce UTF-8 validity prior to the open-source release. +// If this method does not exist, the default is to enforce valid UTF-8. +// +// WARNING: This method is exempt from the compatibility promise and may be +// removed in the future without warning. +func (fd *Field) EnforceUTF8() bool { + if fd.L1.HasEnforceUTF8 { + return fd.L1.EnforceUTF8 + } + return fd.L0.ParentFile.L1.Syntax == pref.Proto3 +} + func (od *Oneof) Options() pref.ProtoMessage { if f := od.L1.Options; f != nil { return f() diff --git a/internal/filedesc/desc_lazy.go b/internal/filedesc/desc_lazy.go index 55104ad5..9b54e6d1 100644 --- a/internal/filedesc/desc_lazy.go +++ b/internal/filedesc/desc_lazy.go @@ -480,6 +480,8 @@ func (fd *Field) unmarshalFull(b []byte, sb *strs.Builder, pf *File, pd pref.Des } func (fd *Field) unmarshalOptions(b []byte) { + const FieldOptions_EnforceUTF8 = 13 + for len(b) > 0 { num, typ, n := wire.ConsumeTag(b) b = b[n:] @@ -493,6 +495,9 @@ func (fd *Field) unmarshalOptions(b []byte) { fd.L1.IsPacked = wire.DecodeBool(v) case fieldnum.FieldOptions_Weak: fd.L1.IsWeak = wire.DecodeBool(v) + case FieldOptions_EnforceUTF8: + fd.L1.HasEnforceUTF8 = true + fd.L1.EnforceUTF8 = !wire.DecodeBool(v) } default: m := wire.ConsumeFieldValue(num, typ, b) diff --git a/internal/impl/codec_field.go b/internal/impl/codec_field.go index 8d0e3397..94b7d6ab 100644 --- a/internal/impl/codec_field.go +++ b/internal/impl/codec_field.go @@ -6,7 +6,6 @@ package impl import ( "reflect" - "unicode/utf8" "google.golang.org/protobuf/internal/encoding/wire" "google.golang.org/protobuf/proto" @@ -747,136 +746,6 @@ var coderEnumPackedSliceIface = ifaceCoderFuncs{ unmarshal: consumeEnumSliceIface, } -// Strings with UTF8 validation. - -func appendStringValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { - v := *p.String() - b = wire.AppendVarint(b, wiretag) - b = wire.AppendString(b, v) - if !utf8.ValidString(v) { - return b, errInvalidUTF8{} - } - return b, nil -} - -func consumeStringValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) { - if wtyp != wire.BytesType { - return 0, errUnknown - } - v, n := wire.ConsumeString(b) - if n < 0 { - return 0, wire.ParseError(n) - } - if !utf8.ValidString(v) { - return 0, errInvalidUTF8{} - } - *p.String() = v - return n, nil -} - -var coderStringValidateUTF8 = pointerCoderFuncs{ - size: sizeString, - marshal: appendStringValidateUTF8, - unmarshal: consumeStringValidateUTF8, -} - -func appendStringNoZeroValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { - v := *p.String() - if len(v) == 0 { - return b, nil - } - b = wire.AppendVarint(b, wiretag) - b = wire.AppendString(b, v) - if !utf8.ValidString(v) { - return b, errInvalidUTF8{} - } - return b, nil -} - -var coderStringNoZeroValidateUTF8 = pointerCoderFuncs{ - size: sizeStringNoZero, - marshal: appendStringNoZeroValidateUTF8, - unmarshal: consumeStringValidateUTF8, -} - -func sizeStringSliceValidateUTF8(p pointer, tagsize int, _ marshalOptions) (size int) { - s := *p.StringSlice() - for _, v := range s { - size += tagsize + wire.SizeBytes(len(v)) - } - return size -} - -func appendStringSliceValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { - s := *p.StringSlice() - var err error - for _, v := range s { - b = wire.AppendVarint(b, wiretag) - b = wire.AppendString(b, v) - if !utf8.ValidString(v) { - return b, errInvalidUTF8{} - } - } - return b, err -} - -func consumeStringSliceValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) { - if wtyp != wire.BytesType { - return 0, errUnknown - } - sp := p.StringSlice() - v, n := wire.ConsumeString(b) - if n < 0 { - return 0, wire.ParseError(n) - } - if !utf8.ValidString(v) { - return 0, errInvalidUTF8{} - } - *sp = append(*sp, v) - return n, nil -} - -var coderStringSliceValidateUTF8 = pointerCoderFuncs{ - size: sizeStringSliceValidateUTF8, - marshal: appendStringSliceValidateUTF8, - unmarshal: consumeStringSliceValidateUTF8, -} - -func sizeStringIfaceValidateUTF8(ival interface{}, tagsize int, _ marshalOptions) int { - v := ival.(string) - return tagsize + wire.SizeBytes(len(v)) -} - -func appendStringIfaceValidateUTF8(b []byte, ival interface{}, wiretag uint64, _ marshalOptions) ([]byte, error) { - v := ival.(string) - b = wire.AppendVarint(b, wiretag) - b = wire.AppendString(b, v) - if !utf8.ValidString(v) { - return b, errInvalidUTF8{} - } - return b, nil -} - -func consumeStringIfaceValidateUTF8(b []byte, _ interface{}, _ wire.Number, wtyp wire.Type, _ unmarshalOptions) (interface{}, int, error) { - if wtyp != wire.BytesType { - return nil, 0, errUnknown - } - v, n := wire.ConsumeString(b) - if n < 0 { - return nil, 0, wire.ParseError(n) - } - if !utf8.ValidString(v) { - return nil, 0, errInvalidUTF8{} - } - return v, n, nil -} - -var coderStringIfaceValidateUTF8 = ifaceCoderFuncs{ - size: sizeStringIfaceValidateUTF8, - marshal: appendStringIfaceValidateUTF8, - unmarshal: consumeStringIfaceValidateUTF8, -} - func asMessage(v reflect.Value) pref.ProtoMessage { if m, ok := v.Interface().(pref.ProtoMessage); ok { return m diff --git a/internal/impl/codec_gen.go b/internal/impl/codec_gen.go index 41bd0991..46380f57 100644 --- a/internal/impl/codec_gen.go +++ b/internal/impl/codec_gen.go @@ -8,6 +8,7 @@ package impl import ( "math" + "unicode/utf8" "google.golang.org/protobuf/internal/encoding/wire" "google.golang.org/protobuf/reflect/protoreflect" @@ -46,7 +47,7 @@ var coderBool = pointerCoderFuncs{ unmarshal: consumeBool, } -// sizeBool returns the size of wire encoding a bool pointer as a Bool. +// sizeBoolNoZero returns the size of wire encoding a bool pointer as a Bool. // The zero value is not encoded. func sizeBoolNoZero(p pointer, tagsize int, _ marshalOptions) (size int) { v := *p.Bool() @@ -56,7 +57,7 @@ func sizeBoolNoZero(p pointer, tagsize int, _ marshalOptions) (size int) { return tagsize + wire.SizeVarint(wire.EncodeBool(v)) } -// appendBool wire encodes a bool pointer as a Bool. +// appendBoolNoZero wire encodes a bool pointer as a Bool. // The zero value is not encoded. func appendBoolNoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { v := *p.Bool() @@ -364,7 +365,7 @@ var coderInt32 = pointerCoderFuncs{ unmarshal: consumeInt32, } -// sizeInt32 returns the size of wire encoding a int32 pointer as a Int32. +// sizeInt32NoZero returns the size of wire encoding a int32 pointer as a Int32. // The zero value is not encoded. func sizeInt32NoZero(p pointer, tagsize int, _ marshalOptions) (size int) { v := *p.Int32() @@ -374,7 +375,7 @@ func sizeInt32NoZero(p pointer, tagsize int, _ marshalOptions) (size int) { return tagsize + wire.SizeVarint(uint64(v)) } -// appendInt32 wire encodes a int32 pointer as a Int32. +// appendInt32NoZero wire encodes a int32 pointer as a Int32. // The zero value is not encoded. func appendInt32NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { v := *p.Int32() @@ -682,7 +683,7 @@ var coderSint32 = pointerCoderFuncs{ unmarshal: consumeSint32, } -// sizeSint32 returns the size of wire encoding a int32 pointer as a Sint32. +// sizeSint32NoZero returns the size of wire encoding a int32 pointer as a Sint32. // The zero value is not encoded. func sizeSint32NoZero(p pointer, tagsize int, _ marshalOptions) (size int) { v := *p.Int32() @@ -692,7 +693,7 @@ func sizeSint32NoZero(p pointer, tagsize int, _ marshalOptions) (size int) { return tagsize + wire.SizeVarint(wire.EncodeZigZag(int64(v))) } -// appendSint32 wire encodes a int32 pointer as a Sint32. +// appendSint32NoZero wire encodes a int32 pointer as a Sint32. // The zero value is not encoded. func appendSint32NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { v := *p.Int32() @@ -1000,7 +1001,7 @@ var coderUint32 = pointerCoderFuncs{ unmarshal: consumeUint32, } -// sizeUint32 returns the size of wire encoding a uint32 pointer as a Uint32. +// sizeUint32NoZero returns the size of wire encoding a uint32 pointer as a Uint32. // The zero value is not encoded. func sizeUint32NoZero(p pointer, tagsize int, _ marshalOptions) (size int) { v := *p.Uint32() @@ -1010,7 +1011,7 @@ func sizeUint32NoZero(p pointer, tagsize int, _ marshalOptions) (size int) { return tagsize + wire.SizeVarint(uint64(v)) } -// appendUint32 wire encodes a uint32 pointer as a Uint32. +// appendUint32NoZero wire encodes a uint32 pointer as a Uint32. // The zero value is not encoded. func appendUint32NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { v := *p.Uint32() @@ -1318,7 +1319,7 @@ var coderInt64 = pointerCoderFuncs{ unmarshal: consumeInt64, } -// sizeInt64 returns the size of wire encoding a int64 pointer as a Int64. +// sizeInt64NoZero returns the size of wire encoding a int64 pointer as a Int64. // The zero value is not encoded. func sizeInt64NoZero(p pointer, tagsize int, _ marshalOptions) (size int) { v := *p.Int64() @@ -1328,7 +1329,7 @@ func sizeInt64NoZero(p pointer, tagsize int, _ marshalOptions) (size int) { return tagsize + wire.SizeVarint(uint64(v)) } -// appendInt64 wire encodes a int64 pointer as a Int64. +// appendInt64NoZero wire encodes a int64 pointer as a Int64. // The zero value is not encoded. func appendInt64NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { v := *p.Int64() @@ -1636,7 +1637,7 @@ var coderSint64 = pointerCoderFuncs{ unmarshal: consumeSint64, } -// sizeSint64 returns the size of wire encoding a int64 pointer as a Sint64. +// sizeSint64NoZero returns the size of wire encoding a int64 pointer as a Sint64. // The zero value is not encoded. func sizeSint64NoZero(p pointer, tagsize int, _ marshalOptions) (size int) { v := *p.Int64() @@ -1646,7 +1647,7 @@ func sizeSint64NoZero(p pointer, tagsize int, _ marshalOptions) (size int) { return tagsize + wire.SizeVarint(wire.EncodeZigZag(v)) } -// appendSint64 wire encodes a int64 pointer as a Sint64. +// appendSint64NoZero wire encodes a int64 pointer as a Sint64. // The zero value is not encoded. func appendSint64NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { v := *p.Int64() @@ -1954,7 +1955,7 @@ var coderUint64 = pointerCoderFuncs{ unmarshal: consumeUint64, } -// sizeUint64 returns the size of wire encoding a uint64 pointer as a Uint64. +// sizeUint64NoZero returns the size of wire encoding a uint64 pointer as a Uint64. // The zero value is not encoded. func sizeUint64NoZero(p pointer, tagsize int, _ marshalOptions) (size int) { v := *p.Uint64() @@ -1964,7 +1965,7 @@ func sizeUint64NoZero(p pointer, tagsize int, _ marshalOptions) (size int) { return tagsize + wire.SizeVarint(v) } -// appendUint64 wire encodes a uint64 pointer as a Uint64. +// appendUint64NoZero wire encodes a uint64 pointer as a Uint64. // The zero value is not encoded. func appendUint64NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { v := *p.Uint64() @@ -2272,7 +2273,7 @@ var coderSfixed32 = pointerCoderFuncs{ unmarshal: consumeSfixed32, } -// sizeSfixed32 returns the size of wire encoding a int32 pointer as a Sfixed32. +// sizeSfixed32NoZero returns the size of wire encoding a int32 pointer as a Sfixed32. // The zero value is not encoded. func sizeSfixed32NoZero(p pointer, tagsize int, _ marshalOptions) (size int) { v := *p.Int32() @@ -2282,7 +2283,7 @@ func sizeSfixed32NoZero(p pointer, tagsize int, _ marshalOptions) (size int) { return tagsize + wire.SizeFixed32() } -// appendSfixed32 wire encodes a int32 pointer as a Sfixed32. +// appendSfixed32NoZero wire encodes a int32 pointer as a Sfixed32. // The zero value is not encoded. func appendSfixed32NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { v := *p.Int32() @@ -2572,7 +2573,7 @@ var coderFixed32 = pointerCoderFuncs{ unmarshal: consumeFixed32, } -// sizeFixed32 returns the size of wire encoding a uint32 pointer as a Fixed32. +// sizeFixed32NoZero returns the size of wire encoding a uint32 pointer as a Fixed32. // The zero value is not encoded. func sizeFixed32NoZero(p pointer, tagsize int, _ marshalOptions) (size int) { v := *p.Uint32() @@ -2582,7 +2583,7 @@ func sizeFixed32NoZero(p pointer, tagsize int, _ marshalOptions) (size int) { return tagsize + wire.SizeFixed32() } -// appendFixed32 wire encodes a uint32 pointer as a Fixed32. +// appendFixed32NoZero wire encodes a uint32 pointer as a Fixed32. // The zero value is not encoded. func appendFixed32NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { v := *p.Uint32() @@ -2872,7 +2873,7 @@ var coderFloat = pointerCoderFuncs{ unmarshal: consumeFloat, } -// sizeFloat returns the size of wire encoding a float32 pointer as a Float. +// sizeFloatNoZero returns the size of wire encoding a float32 pointer as a Float. // The zero value is not encoded. func sizeFloatNoZero(p pointer, tagsize int, _ marshalOptions) (size int) { v := *p.Float32() @@ -2882,7 +2883,7 @@ func sizeFloatNoZero(p pointer, tagsize int, _ marshalOptions) (size int) { return tagsize + wire.SizeFixed32() } -// appendFloat wire encodes a float32 pointer as a Float. +// appendFloatNoZero wire encodes a float32 pointer as a Float. // The zero value is not encoded. func appendFloatNoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { v := *p.Float32() @@ -3172,7 +3173,7 @@ var coderSfixed64 = pointerCoderFuncs{ unmarshal: consumeSfixed64, } -// sizeSfixed64 returns the size of wire encoding a int64 pointer as a Sfixed64. +// sizeSfixed64NoZero returns the size of wire encoding a int64 pointer as a Sfixed64. // The zero value is not encoded. func sizeSfixed64NoZero(p pointer, tagsize int, _ marshalOptions) (size int) { v := *p.Int64() @@ -3182,7 +3183,7 @@ func sizeSfixed64NoZero(p pointer, tagsize int, _ marshalOptions) (size int) { return tagsize + wire.SizeFixed64() } -// appendSfixed64 wire encodes a int64 pointer as a Sfixed64. +// appendSfixed64NoZero wire encodes a int64 pointer as a Sfixed64. // The zero value is not encoded. func appendSfixed64NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { v := *p.Int64() @@ -3472,7 +3473,7 @@ var coderFixed64 = pointerCoderFuncs{ unmarshal: consumeFixed64, } -// sizeFixed64 returns the size of wire encoding a uint64 pointer as a Fixed64. +// sizeFixed64NoZero returns the size of wire encoding a uint64 pointer as a Fixed64. // The zero value is not encoded. func sizeFixed64NoZero(p pointer, tagsize int, _ marshalOptions) (size int) { v := *p.Uint64() @@ -3482,7 +3483,7 @@ func sizeFixed64NoZero(p pointer, tagsize int, _ marshalOptions) (size int) { return tagsize + wire.SizeFixed64() } -// appendFixed64 wire encodes a uint64 pointer as a Fixed64. +// appendFixed64NoZero wire encodes a uint64 pointer as a Fixed64. // The zero value is not encoded. func appendFixed64NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { v := *p.Uint64() @@ -3772,7 +3773,7 @@ var coderDouble = pointerCoderFuncs{ unmarshal: consumeDouble, } -// sizeDouble returns the size of wire encoding a float64 pointer as a Double. +// sizeDoubleNoZero returns the size of wire encoding a float64 pointer as a Double. // The zero value is not encoded. func sizeDoubleNoZero(p pointer, tagsize int, _ marshalOptions) (size int) { v := *p.Float64() @@ -3782,7 +3783,7 @@ func sizeDoubleNoZero(p pointer, tagsize int, _ marshalOptions) (size int) { return tagsize + wire.SizeFixed64() } -// appendDouble wire encodes a float64 pointer as a Double. +// appendDoubleNoZero wire encodes a float64 pointer as a Double. // The zero value is not encoded. func appendDoubleNoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { v := *p.Float64() @@ -4072,7 +4073,40 @@ var coderString = pointerCoderFuncs{ unmarshal: consumeString, } -// sizeString returns the size of wire encoding a string pointer as a String. +// appendStringValidateUTF8 wire encodes a string pointer as a String. +func appendStringValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { + v := *p.String() + b = wire.AppendVarint(b, wiretag) + b = wire.AppendString(b, v) + if !utf8.ValidString(v) { + return b, errInvalidUTF8{} + } + return b, nil +} + +// consumeStringValidateUTF8 wire decodes a string pointer as a String. +func consumeStringValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) { + if wtyp != wire.BytesType { + return 0, errUnknown + } + v, n := wire.ConsumeString(b) + if n < 0 { + return 0, wire.ParseError(n) + } + if !utf8.ValidString(v) { + return 0, errInvalidUTF8{} + } + *p.String() = v + return n, nil +} + +var coderStringValidateUTF8 = pointerCoderFuncs{ + size: sizeString, + marshal: appendStringValidateUTF8, + unmarshal: consumeStringValidateUTF8, +} + +// sizeStringNoZero returns the size of wire encoding a string pointer as a String. // The zero value is not encoded. func sizeStringNoZero(p pointer, tagsize int, _ marshalOptions) (size int) { v := *p.String() @@ -4082,7 +4116,7 @@ func sizeStringNoZero(p pointer, tagsize int, _ marshalOptions) (size int) { return tagsize + wire.SizeBytes(len(v)) } -// appendString wire encodes a string pointer as a String. +// appendStringNoZero wire encodes a string pointer as a String. // The zero value is not encoded. func appendStringNoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { v := *p.String() @@ -4100,6 +4134,27 @@ var coderStringNoZero = pointerCoderFuncs{ unmarshal: consumeString, } +// appendStringNoZeroValidateUTF8 wire encodes a string pointer as a String. +// The zero value is not encoded. +func appendStringNoZeroValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { + v := *p.String() + if len(v) == 0 { + return b, nil + } + b = wire.AppendVarint(b, wiretag) + b = wire.AppendString(b, v) + if !utf8.ValidString(v) { + return b, errInvalidUTF8{} + } + return b, nil +} + +var coderStringNoZeroValidateUTF8 = pointerCoderFuncs{ + size: sizeStringNoZero, + marshal: appendStringNoZeroValidateUTF8, + unmarshal: consumeStringValidateUTF8, +} + // sizeStringPtr returns the size of wire encoding a *string pointer as a String. // It panics if the pointer is nil. func sizeStringPtr(p pointer, tagsize int, _ marshalOptions) (size int) { @@ -4178,6 +4233,42 @@ var coderStringSlice = pointerCoderFuncs{ unmarshal: consumeStringSlice, } +// appendStringSliceValidateUTF8 encodes a []string pointer as a repeated String. +func appendStringSliceValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { + s := *p.StringSlice() + for _, v := range s { + b = wire.AppendVarint(b, wiretag) + b = wire.AppendString(b, v) + if !utf8.ValidString(v) { + return b, errInvalidUTF8{} + } + } + return b, nil +} + +// consumeStringSliceValidateUTF8 wire decodes a []string pointer as a repeated String. +func consumeStringSliceValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) { + sp := p.StringSlice() + if wtyp != wire.BytesType { + return 0, errUnknown + } + v, n := wire.ConsumeString(b) + if n < 0 { + return 0, wire.ParseError(n) + } + if !utf8.ValidString(v) { + return 0, errInvalidUTF8{} + } + *sp = append(*sp, v) + return n, nil +} + +var coderStringSliceValidateUTF8 = pointerCoderFuncs{ + size: sizeStringSlice, + marshal: appendStringSliceValidateUTF8, + unmarshal: consumeStringSliceValidateUTF8, +} + // sizeStringIface returns the size of wire encoding a string value as a String. func sizeStringIface(ival interface{}, tagsize int, _ marshalOptions) int { v := ival.(string) @@ -4210,6 +4301,38 @@ var coderStringIface = ifaceCoderFuncs{ unmarshal: consumeStringIface, } +// appendStringIfaceValidateUTF8 encodes a string value as a String. +func appendStringIfaceValidateUTF8(b []byte, ival interface{}, wiretag uint64, _ marshalOptions) ([]byte, error) { + v := ival.(string) + b = wire.AppendVarint(b, wiretag) + b = wire.AppendString(b, v) + if !utf8.ValidString(v) { + return b, errInvalidUTF8{} + } + return b, nil +} + +// consumeStringIfaceValidateUTF8 decodes a string value as a String. +func consumeStringIfaceValidateUTF8(b []byte, _ interface{}, _ wire.Number, wtyp wire.Type, _ unmarshalOptions) (interface{}, int, error) { + if wtyp != wire.BytesType { + return nil, 0, errUnknown + } + v, n := wire.ConsumeString(b) + if n < 0 { + return nil, 0, wire.ParseError(n) + } + if !utf8.ValidString(v) { + return nil, 0, errInvalidUTF8{} + } + return v, n, nil +} + +var coderStringIfaceValidateUTF8 = ifaceCoderFuncs{ + size: sizeStringIface, + marshal: appendStringIfaceValidateUTF8, + unmarshal: consumeStringIfaceValidateUTF8, +} + // sizeStringSliceIface returns the size of wire encoding a []string value as a repeated String. func sizeStringSliceIface(ival interface{}, tagsize int, _ marshalOptions) (size int) { s := *ival.(*[]string) @@ -4282,7 +4405,40 @@ var coderBytes = pointerCoderFuncs{ unmarshal: consumeBytes, } -// sizeBytes returns the size of wire encoding a []byte pointer as a Bytes. +// appendBytesValidateUTF8 wire encodes a []byte pointer as a Bytes. +func appendBytesValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { + v := *p.Bytes() + b = wire.AppendVarint(b, wiretag) + b = wire.AppendBytes(b, v) + if !utf8.Valid(v) { + return b, errInvalidUTF8{} + } + return b, nil +} + +// consumeBytesValidateUTF8 wire decodes a []byte pointer as a Bytes. +func consumeBytesValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) { + if wtyp != wire.BytesType { + return 0, errUnknown + } + v, n := wire.ConsumeBytes(b) + if n < 0 { + return 0, wire.ParseError(n) + } + if !utf8.Valid(v) { + return 0, errInvalidUTF8{} + } + *p.Bytes() = append(([]byte)(nil), v...) + return n, nil +} + +var coderBytesValidateUTF8 = pointerCoderFuncs{ + size: sizeBytes, + marshal: appendBytesValidateUTF8, + unmarshal: consumeBytesValidateUTF8, +} + +// sizeBytesNoZero returns the size of wire encoding a []byte pointer as a Bytes. // The zero value is not encoded. func sizeBytesNoZero(p pointer, tagsize int, _ marshalOptions) (size int) { v := *p.Bytes() @@ -4292,7 +4448,7 @@ func sizeBytesNoZero(p pointer, tagsize int, _ marshalOptions) (size int) { return tagsize + wire.SizeBytes(len(v)) } -// appendBytes wire encodes a []byte pointer as a Bytes. +// appendBytesNoZero wire encodes a []byte pointer as a Bytes. // The zero value is not encoded. func appendBytesNoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { v := *p.Bytes() @@ -4310,6 +4466,27 @@ var coderBytesNoZero = pointerCoderFuncs{ unmarshal: consumeBytes, } +// appendBytesNoZeroValidateUTF8 wire encodes a []byte pointer as a Bytes. +// The zero value is not encoded. +func appendBytesNoZeroValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { + v := *p.Bytes() + if len(v) == 0 { + return b, nil + } + b = wire.AppendVarint(b, wiretag) + b = wire.AppendBytes(b, v) + if !utf8.Valid(v) { + return b, errInvalidUTF8{} + } + return b, nil +} + +var coderBytesNoZeroValidateUTF8 = pointerCoderFuncs{ + size: sizeBytesNoZero, + marshal: appendBytesNoZeroValidateUTF8, + unmarshal: consumeBytesValidateUTF8, +} + // sizeBytesSlice returns the size of wire encoding a [][]byte pointer as a repeated Bytes. func sizeBytesSlice(p pointer, tagsize int, _ marshalOptions) (size int) { s := *p.BytesSlice() @@ -4349,6 +4526,42 @@ var coderBytesSlice = pointerCoderFuncs{ unmarshal: consumeBytesSlice, } +// appendBytesSliceValidateUTF8 encodes a [][]byte pointer as a repeated Bytes. +func appendBytesSliceValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) { + s := *p.BytesSlice() + for _, v := range s { + b = wire.AppendVarint(b, wiretag) + b = wire.AppendBytes(b, v) + if !utf8.Valid(v) { + return b, errInvalidUTF8{} + } + } + return b, nil +} + +// consumeBytesSliceValidateUTF8 wire decodes a [][]byte pointer as a repeated Bytes. +func consumeBytesSliceValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) { + sp := p.BytesSlice() + if wtyp != wire.BytesType { + return 0, errUnknown + } + v, n := wire.ConsumeBytes(b) + if n < 0 { + return 0, wire.ParseError(n) + } + if !utf8.Valid(v) { + return 0, errInvalidUTF8{} + } + *sp = append(*sp, append(([]byte)(nil), v...)) + return n, nil +} + +var coderBytesSliceValidateUTF8 = pointerCoderFuncs{ + size: sizeBytesSlice, + marshal: appendBytesSliceValidateUTF8, + unmarshal: consumeBytesSliceValidateUTF8, +} + // sizeBytesIface returns the size of wire encoding a []byte value as a Bytes. func sizeBytesIface(ival interface{}, tagsize int, _ marshalOptions) int { v := ival.([]byte) @@ -4381,6 +4594,38 @@ var coderBytesIface = ifaceCoderFuncs{ unmarshal: consumeBytesIface, } +// appendBytesIfaceValidateUTF8 encodes a []byte value as a Bytes. +func appendBytesIfaceValidateUTF8(b []byte, ival interface{}, wiretag uint64, _ marshalOptions) ([]byte, error) { + v := ival.([]byte) + b = wire.AppendVarint(b, wiretag) + b = wire.AppendBytes(b, v) + if !utf8.Valid(v) { + return b, errInvalidUTF8{} + } + return b, nil +} + +// consumeBytesIfaceValidateUTF8 decodes a []byte value as a Bytes. +func consumeBytesIfaceValidateUTF8(b []byte, _ interface{}, _ wire.Number, wtyp wire.Type, _ unmarshalOptions) (interface{}, int, error) { + if wtyp != wire.BytesType { + return nil, 0, errUnknown + } + v, n := wire.ConsumeBytes(b) + if n < 0 { + return nil, 0, wire.ParseError(n) + } + if !utf8.Valid(v) { + return nil, 0, errInvalidUTF8{} + } + return append(([]byte)(nil), v...), n, nil +} + +var coderBytesIfaceValidateUTF8 = ifaceCoderFuncs{ + size: sizeBytesIface, + marshal: appendBytesIfaceValidateUTF8, + unmarshal: consumeBytesIfaceValidateUTF8, +} + // sizeBytesSliceIface returns the size of wire encoding a [][]byte value as a repeated Bytes. func sizeBytesSliceIface(ival interface{}, tagsize int, _ marshalOptions) (size int) { s := *ival.(*[][]byte) diff --git a/internal/impl/codec_tables.go b/internal/impl/codec_tables.go index 564187e4..3ff42600 100644 --- a/internal/impl/codec_tables.go +++ b/internal/impl/codec_tables.go @@ -9,6 +9,7 @@ import ( "reflect" "google.golang.org/protobuf/internal/encoding/wire" + "google.golang.org/protobuf/internal/strs" pref "google.golang.org/protobuf/reflect/protoreflect" ) @@ -98,12 +99,15 @@ func fieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs { return coderDoubleSlice } case pref.StringKind: - if ft.Kind() == reflect.String && fd.Syntax() == pref.Proto3 { + if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) { return coderStringSliceValidateUTF8 } if ft.Kind() == reflect.String { return coderStringSlice } + if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 && strs.EnforceUTF8(fd) { + return coderBytesSliceValidateUTF8 + } if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 { return coderBytesSlice } @@ -251,9 +255,15 @@ func fieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs { return coderDoubleNoZero } case pref.StringKind: - if ft.Kind() == reflect.String { + if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) { return coderStringNoZeroValidateUTF8 } + if ft.Kind() == reflect.String { + return coderStringNoZero + } + if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 && strs.EnforceUTF8(fd) { + return coderBytesNoZeroValidateUTF8 + } if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 { return coderBytesNoZero } @@ -392,12 +402,15 @@ func fieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs { return coderDouble } case pref.StringKind: - if fd.Syntax() == pref.Proto3 && ft.Kind() == reflect.String { + if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) { return coderStringValidateUTF8 } if ft.Kind() == reflect.String { return coderString } + if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 && strs.EnforceUTF8(fd) { + return coderBytesValidateUTF8 + } if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 { return coderBytes } @@ -620,12 +633,15 @@ func encoderFuncsForValue(fd pref.FieldDescriptor, ft reflect.Type) ifaceCoderFu return coderDoubleIface } case pref.StringKind: - if fd.Syntax() == pref.Proto3 && ft.Kind() == reflect.String { + if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) { return coderStringIfaceValidateUTF8 } if ft.Kind() == reflect.String { return coderStringIface } + if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 && strs.EnforceUTF8(fd) { + return coderBytesIfaceValidateUTF8 + } if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 { return coderBytesIface } diff --git a/internal/strs/strings.go b/internal/strs/strings.go index 295bd296..af5f197c 100644 --- a/internal/strs/strings.go +++ b/internal/strs/strings.go @@ -8,8 +8,21 @@ package strs import ( "strings" "unicode" + + "google.golang.org/protobuf/internal/flags" + "google.golang.org/protobuf/reflect/protoreflect" ) +// EnforceUTF8 reports whether to enforce strict UTF-8 validation. +func EnforceUTF8(fd protoreflect.FieldDescriptor) bool { + if flags.Proto1Legacy { + if fd, ok := fd.(interface{ EnforceUTF8() bool }); ok { + return fd.EnforceUTF8() + } + } + return fd.Syntax() == protoreflect.Proto3 +} + // JSONCamelCase converts a snake_case identifier to a camelCase identifier, // according to the protobuf JSON specification. func JSONCamelCase(s string) string { diff --git a/proto/decode_gen.go b/proto/decode_gen.go index a2722428..dbb4c877 100644 --- a/proto/decode_gen.go +++ b/proto/decode_gen.go @@ -12,6 +12,7 @@ import ( "google.golang.org/protobuf/internal/encoding/wire" "google.golang.org/protobuf/internal/errors" + "google.golang.org/protobuf/internal/strs" "google.golang.org/protobuf/reflect/protoreflect" ) @@ -154,7 +155,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, fd protorefl if n < 0 { return val, 0, wire.ParseError(n) } - if fd.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) { + if strs.EnforceUTF8(fd) && !utf8.Valid(v) { return protoreflect.Value{}, 0, errors.InvalidUTF8(string(fd.FullName())) } return protoreflect.ValueOf(string(v)), n, nil @@ -550,7 +551,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl if n < 0 { return 0, wire.ParseError(n) } - if fd.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) { + if strs.EnforceUTF8(fd) && !utf8.Valid(v) { return 0, errors.InvalidUTF8(string(fd.FullName())) } list.Append(protoreflect.ValueOf(string(v))) diff --git a/proto/decode_test.go b/proto/decode_test.go index 5fa3a0f4..ce2e1af9 100644 --- a/proto/decode_test.go +++ b/proto/decode_test.go @@ -12,13 +12,20 @@ import ( protoV1 "github.com/golang/protobuf/proto" "google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/internal/encoding/pack" + "google.golang.org/protobuf/internal/filedesc" + "google.golang.org/protobuf/internal/flags" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoreflect" pref "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/prototype" + "google.golang.org/protobuf/runtime/protoimpl" legacypb "google.golang.org/protobuf/internal/testprotos/legacy" legacy1pb "google.golang.org/protobuf/internal/testprotos/legacy/proto2.v0.0.0-20160225-2fc053c5" testpb "google.golang.org/protobuf/internal/testprotos/test" test3pb "google.golang.org/protobuf/internal/testprotos/test3" + "google.golang.org/protobuf/types/descriptorpb" ) type testProto struct { @@ -85,6 +92,23 @@ func TestDecodeInvalidUTF8(t *testing.T) { } } +func TestDecodeNoEnforceUTF8(t *testing.T) { + for _, test := range noEnforceUTF8TestProtos { + for _, want := range test.decodeTo { + t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) { + got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message) + err := proto.Unmarshal(test.wire, got) + switch { + case flags.Proto1Legacy && err != nil: + t.Errorf("Unmarshal returned unexpected error: %v\nMessage:\n%v", err, marshalText(want)) + case !flags.Proto1Legacy && err == nil: + t.Errorf("Unmarshal did not return expected error for invalid UTF8: %v\nMessage:\n%v", err, marshalText(want)) + } + }) + } + } +} + var testProtos = []testProto{ { desc: "basic scalar types", @@ -1442,6 +1466,129 @@ var invalidUTF8TestProtos = []testProto{ }, } +var noEnforceUTF8TestProtos = []testProto{ + { + desc: "invalid UTF-8 in optional string field", + decodeTo: []proto.Message{&TestNoEnforceUTF8{ + OptionalString: string("abc\xff"), + }}, + wire: pack.Message{ + pack.Tag{1, pack.BytesType}, pack.String("abc\xff"), + }.Marshal(), + }, + { + desc: "invalid UTF-8 in optional string field of Go bytes", + decodeTo: []proto.Message{&TestNoEnforceUTF8{ + OptionalBytes: []byte("abc\xff"), + }}, + wire: pack.Message{ + pack.Tag{2, pack.BytesType}, pack.String("abc\xff"), + }.Marshal(), + }, + { + desc: "invalid UTF-8 in repeated string field", + decodeTo: []proto.Message{&TestNoEnforceUTF8{ + RepeatedString: []string{string("foo"), string("abc\xff")}, + }}, + wire: pack.Message{ + pack.Tag{3, pack.BytesType}, pack.String("foo"), + pack.Tag{3, pack.BytesType}, pack.String("abc\xff"), + }.Marshal(), + }, + { + desc: "invalid UTF-8 in repeated string field of Go bytes", + decodeTo: []proto.Message{&TestNoEnforceUTF8{ + RepeatedBytes: [][]byte{[]byte("foo"), []byte("abc\xff")}, + }}, + wire: pack.Message{ + pack.Tag{4, pack.BytesType}, pack.String("foo"), + pack.Tag{4, pack.BytesType}, pack.String("abc\xff"), + }.Marshal(), + }, + { + desc: "invalid UTF-8 in oneof string field", + decodeTo: []proto.Message{ + &TestNoEnforceUTF8{OneofField: &TestNoEnforceUTF8_OneofString{string("abc\xff")}}, + }, + wire: pack.Message{pack.Tag{5, pack.BytesType}, pack.String("abc\xff")}.Marshal(), + }, + { + desc: "invalid UTF-8 in oneof string field of Go bytes", + decodeTo: []proto.Message{ + &TestNoEnforceUTF8{OneofField: &TestNoEnforceUTF8_OneofBytes{[]byte("abc\xff")}}, + }, + wire: pack.Message{pack.Tag{6, pack.BytesType}, pack.String("abc\xff")}.Marshal(), + }, +} + +type TestNoEnforceUTF8 struct { + OptionalString string `protobuf:"bytes,1,opt,name=optional_string"` + OptionalBytes []byte `protobuf:"bytes,2,opt,name=optional_bytes"` + RepeatedString []string `protobuf:"bytes,3,rep,name=repeated_string"` + RepeatedBytes [][]byte `protobuf:"bytes,4,rep,name=repeated_bytes"` + OneofField isOneofField `protobuf_oneof:"oneof_field"` +} + +type isOneofField interface{ isOneofField() } + +type TestNoEnforceUTF8_OneofString struct { + OneofString string `protobuf:"bytes,5,opt,name=oneof_string,oneof"` +} +type TestNoEnforceUTF8_OneofBytes struct { + OneofBytes []byte `protobuf:"bytes,6,opt,name=oneof_bytes,oneof"` +} + +func (*TestNoEnforceUTF8_OneofString) isOneofField() {} +func (*TestNoEnforceUTF8_OneofBytes) isOneofField() {} + +func (m *TestNoEnforceUTF8) ProtoReflect() pref.Message { + return messageInfo_TestNoEnforceUTF8.MessageOf(m) +} + +var messageInfo_TestNoEnforceUTF8 = protoimpl.MessageInfo{ + GoType: reflect.TypeOf((*TestNoEnforceUTF8)(nil)), + PBType: &prototype.Message{ + MessageDescriptor: func() protoreflect.MessageDescriptor { + pb := new(descriptorpb.FileDescriptorProto) + if err := prototext.Unmarshal([]byte(` + syntax: "proto3" + name: "test.proto" + message_type: [{ + name: "TestNoEnforceUTF8" + field: [ + {name:"optional_string" number:1 label:LABEL_OPTIONAL type:TYPE_STRING}, + {name:"optional_bytes" number:2 label:LABEL_OPTIONAL type:TYPE_STRING}, + {name:"repeated_string" number:3 label:LABEL_REPEATED type:TYPE_STRING}, + {name:"repeated_bytes" number:4 label:LABEL_REPEATED type:TYPE_STRING}, + {name:"oneof_string" number:5 label:LABEL_OPTIONAL type:TYPE_STRING, oneof_index:0}, + {name:"oneof_bytes" number:6 label:LABEL_OPTIONAL type:TYPE_STRING, oneof_index:0} + ] + oneof_decl: [{name:"oneof_field"}] + }] + `), pb); err != nil { + panic(err) + } + fd, err := protodesc.NewFile(pb, nil) + if err != nil { + panic(err) + } + md := fd.Messages().Get(0) + for i := 0; i < md.Fields().Len(); i++ { + md.Fields().Get(i).(*filedesc.Field).L1.HasEnforceUTF8 = true + md.Fields().Get(i).(*filedesc.Field).L1.EnforceUTF8 = false + } + return md + }(), + NewMessage: func() pref.Message { + return pref.ProtoMessage(new(TestNoEnforceUTF8)).ProtoReflect() + }, + }, + OneofWrappers: []interface{}{ + (*TestNoEnforceUTF8_OneofString)(nil), + (*TestNoEnforceUTF8_OneofBytes)(nil), + }, +} + func build(m proto.Message, opts ...buildOpt) proto.Message { for _, opt := range opts { opt(m) diff --git a/proto/encode_gen.go b/proto/encode_gen.go index fe977e3c..77b65117 100644 --- a/proto/encode_gen.go +++ b/proto/encode_gen.go @@ -12,6 +12,7 @@ import ( "google.golang.org/protobuf/internal/encoding/wire" "google.golang.org/protobuf/internal/errors" + "google.golang.org/protobuf/internal/strs" "google.golang.org/protobuf/reflect/protoreflect" ) @@ -67,7 +68,7 @@ func (o MarshalOptions) marshalSingular(b []byte, fd protoreflect.FieldDescripto case protoreflect.DoubleKind: b = wire.AppendFixed64(b, math.Float64bits(v.Float())) case protoreflect.StringKind: - if fd.Syntax() == protoreflect.Proto3 && !utf8.ValidString(v.String()) { + if strs.EnforceUTF8(fd) && !utf8.ValidString(v.String()) { return b, errors.InvalidUTF8(string(fd.FullName())) } b = wire.AppendString(b, v.String()) diff --git a/proto/encode_test.go b/proto/encode_test.go index f90020a4..573a197c 100644 --- a/proto/encode_test.go +++ b/proto/encode_test.go @@ -10,6 +10,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "google.golang.org/protobuf/internal/flags" "google.golang.org/protobuf/proto" test3pb "google.golang.org/protobuf/internal/testprotos/test3" @@ -97,6 +98,22 @@ func TestEncodeInvalidUTF8(t *testing.T) { } } +func TestEncodeNoEnforceUTF8(t *testing.T) { + for _, test := range noEnforceUTF8TestProtos { + for _, want := range test.decodeTo { + t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) { + _, err := proto.Marshal(want) + switch { + case flags.Proto1Legacy && err != nil: + t.Errorf("Marshal returned unexpected error: %v\nMessage:\n%v", err, marshalText(want)) + case !flags.Proto1Legacy && err == nil: + t.Errorf("Marshal did not return expected error for invalid UTF8: %v\nMessage:\n%v", err, marshalText(want)) + } + }) + } + } +} + func TestEncodeRequiredFieldChecks(t *testing.T) { for _, test := range testProtos { if !test.partial {