From d3470999428befce9bbefe77980ff65ac5a494c4 Mon Sep 17 00:00:00 2001 From: Joe Tsai Date: Mon, 19 Oct 2020 20:57:54 -0700 Subject: [PATCH] internal/cmd/generate-types: use ConsumeBytes instead of ConsumeString The protowire.{ConsumeBytes,ConsumeString} funcs are identical except that the latter allocates a string by implicitly converting the []byte. Avoid using ConsumeString since we can do the conversion ourselves at a latter point and sometimes avoid the allocation. Change-Id: Idf31edc013b72ee5ee8461a68d10e57ad461d95c Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/263628 Trust: Joe Tsai Reviewed-by: Damien Neil --- internal/cmd/generate-types/impl.go | 14 +++++----- internal/cmd/generate-types/proto.go | 2 +- internal/impl/codec_gen.go | 38 ++++++++++++++-------------- 3 files changed, 26 insertions(+), 28 deletions(-) diff --git a/internal/cmd/generate-types/impl.go b/internal/cmd/generate-types/impl.go index bef43ee7..07657c44 100644 --- a/internal/cmd/generate-types/impl.go +++ b/internal/cmd/generate-types/impl.go @@ -69,9 +69,7 @@ b = protowire.Append{{.WireType}}(b, {{.FromValue}}) {{- end -}} {{- define "Consume" -}} -{{- if eq .Name "String" -}} -v, n := protowire.ConsumeString(b) -{{- else if eq .WireType "Varint" -}} +{{- if eq .WireType "Varint" -}} var v uint64 var n int if len(b) >= 1 && b[0] < 0x80 { @@ -149,7 +147,7 @@ func consume{{.Name}}ValidateUTF8(b []byte, p pointer, wtyp protowire.Type, f *c if n < 0 { return out, errDecode } - if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) { + if !utf8.Valid(v) { return out, errInvalidUTF8{} } *p.{{.GoType.PointerMethod}}() = {{.ToGoType}} @@ -237,7 +235,7 @@ func consume{{.Name}}NoZeroValidateUTF8(b []byte, p pointer, wtyp protowire.Type if n < 0 { return out, errDecode } - if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) { + if !utf8.Valid(v) { return out, errInvalidUTF8{} } *p.{{.GoType.PointerMethod}}() = {{.ToGoTypeNoZero}} @@ -321,7 +319,7 @@ func consume{{.Name}}PtrValidateUTF8(b []byte, p pointer, wtyp protowire.Type, f if n < 0 { return out, errDecode } - if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) { + if !utf8.Valid(v) { return out, errInvalidUTF8{} } vp := p.{{.GoType.PointerMethod}}Ptr() @@ -429,7 +427,7 @@ func consume{{.Name}}SliceValidateUTF8(b []byte, p pointer, wtyp protowire.Type, if n < 0 { return out, errDecode } - if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) { + if !utf8.Valid(v) { return out, errInvalidUTF8{} } sp := p.{{.GoType.PointerMethod}}Slice() @@ -553,7 +551,7 @@ func consume{{.Name}}ValueValidateUTF8(b []byte, _ protoreflect.Value, _ protowi if n < 0 { return protoreflect.Value{}, out, errDecode } - if !utf8.ValidString(v) { + if !utf8.Valid(v) { return protoreflect.Value{}, out, errInvalidUTF8{} } out.n = n diff --git a/internal/cmd/generate-types/proto.go b/internal/cmd/generate-types/proto.go index e91cf8d7..c70cc21f 100644 --- a/internal/cmd/generate-types/proto.go +++ b/internal/cmd/generate-types/proto.go @@ -239,7 +239,7 @@ var ProtoKinds = []ProtoKind{ ToValue: "protoreflect.ValueOfString(string(v))", FromValue: "v.String()", GoType: GoString, - ToGoType: "v", + ToGoType: "string(v)", FromGoType: "v", }, { diff --git a/internal/impl/codec_gen.go b/internal/impl/codec_gen.go index 2da10292..1a509b63 100644 --- a/internal/impl/codec_gen.go +++ b/internal/impl/codec_gen.go @@ -4935,11 +4935,11 @@ func consumeString(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, if wtyp != protowire.BytesType { return out, errUnknown } - v, n := protowire.ConsumeString(b) + v, n := protowire.ConsumeBytes(b) if n < 0 { return out, errDecode } - *p.String() = v + *p.String() = string(v) out.n = n return out, nil } @@ -4967,14 +4967,14 @@ func consumeStringValidateUTF8(b []byte, p pointer, wtyp protowire.Type, f *code if wtyp != protowire.BytesType { return out, errUnknown } - v, n := protowire.ConsumeString(b) + v, n := protowire.ConsumeBytes(b) if n < 0 { return out, errDecode } - if !utf8.ValidString(v) { + if !utf8.Valid(v) { return out, errInvalidUTF8{} } - *p.String() = v + *p.String() = string(v) out.n = n return out, nil } @@ -5058,7 +5058,7 @@ func consumeStringPtr(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInf if wtyp != protowire.BytesType { return out, errUnknown } - v, n := protowire.ConsumeString(b) + v, n := protowire.ConsumeBytes(b) if n < 0 { return out, errDecode } @@ -5066,7 +5066,7 @@ func consumeStringPtr(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInf if *vp == nil { *vp = new(string) } - **vp = v + **vp = string(v) out.n = n return out, nil } @@ -5095,18 +5095,18 @@ func consumeStringPtrValidateUTF8(b []byte, p pointer, wtyp protowire.Type, f *c if wtyp != protowire.BytesType { return out, errUnknown } - v, n := protowire.ConsumeString(b) + v, n := protowire.ConsumeBytes(b) if n < 0 { return out, errDecode } - if !utf8.ValidString(v) { + if !utf8.Valid(v) { return out, errInvalidUTF8{} } vp := p.StringPtr() if *vp == nil { *vp = new(string) } - **vp = v + **vp = string(v) out.n = n return out, nil } @@ -5143,11 +5143,11 @@ func consumeStringSlice(b []byte, p pointer, wtyp protowire.Type, f *coderFieldI if wtyp != protowire.BytesType { return out, errUnknown } - v, n := protowire.ConsumeString(b) + v, n := protowire.ConsumeBytes(b) if n < 0 { return out, errDecode } - *sp = append(*sp, v) + *sp = append(*sp, string(v)) out.n = n return out, nil } @@ -5177,15 +5177,15 @@ func consumeStringSliceValidateUTF8(b []byte, p pointer, wtyp protowire.Type, f if wtyp != protowire.BytesType { return out, errUnknown } - v, n := protowire.ConsumeString(b) + v, n := protowire.ConsumeBytes(b) if n < 0 { return out, errDecode } - if !utf8.ValidString(v) { + if !utf8.Valid(v) { return out, errInvalidUTF8{} } sp := p.StringSlice() - *sp = append(*sp, v) + *sp = append(*sp, string(v)) out.n = n return out, nil } @@ -5214,7 +5214,7 @@ func consumeStringValue(b []byte, _ protoreflect.Value, _ protowire.Number, wtyp if wtyp != protowire.BytesType { return protoreflect.Value{}, out, errUnknown } - v, n := protowire.ConsumeString(b) + v, n := protowire.ConsumeBytes(b) if n < 0 { return protoreflect.Value{}, out, errDecode } @@ -5244,11 +5244,11 @@ func consumeStringValueValidateUTF8(b []byte, _ protoreflect.Value, _ protowire. if wtyp != protowire.BytesType { return protoreflect.Value{}, out, errUnknown } - v, n := protowire.ConsumeString(b) + v, n := protowire.ConsumeBytes(b) if n < 0 { return protoreflect.Value{}, out, errDecode } - if !utf8.ValidString(v) { + if !utf8.Valid(v) { return protoreflect.Value{}, out, errInvalidUTF8{} } out.n = n @@ -5289,7 +5289,7 @@ func consumeStringSliceValue(b []byte, listv protoreflect.Value, _ protowire.Num if wtyp != protowire.BytesType { return protoreflect.Value{}, out, errUnknown } - v, n := protowire.ConsumeString(b) + v, n := protowire.ConsumeBytes(b) if n < 0 { return protoreflect.Value{}, out, errDecode }