proto, internal/impl: zero-length proto2 bytes fields should be non-nil

Fix decoding of zero-length proto2 bytes fields to produce a non-nil []byte. Zero-length proto3 bytes fields, which do not preserve presence, continue to decode to nil.

Change-Id: Ifb7791a47df81091700f7226523371d1386fb1ad
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/188765
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
This commit is contained in:
Damien Neil 2019-08-02 15:13:00 -07:00
parent afd3633ce3
commit 8003f08e51
4 changed files with 164 additions and 23 deletions

View File

@ -152,10 +152,26 @@ func append{{.Name}}NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions
return b, nil
}
{{if .ToGoTypeNoZero}}
// consume{{.Name}}NoZero wire decodes a {{.GoType}} pointer as a {{.Name}}.
// The zero value is not decoded.
// It returns errUnknown if wtyp is not {{.WireType.Expr}}, and a wire.ParseError
// if the value cannot be consumed from b.
func consume{{.Name}}NoZero(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
if wtyp != {{.WireType.Expr}} {
return 0, errUnknown
}
v, n := {{template "Consume" .}}
if n < 0 {
return 0, wire.ParseError(n)
}
// The conversion expression below refers to the consumed value by the name v.
*p.{{.GoType.PointerMethod}}() = {{.ToGoTypeNoZero}}
return n, nil
}
{{end}}
var coder{{.Name}}NoZero = pointerCoderFuncs{
size: size{{.Name}}NoZero,
marshal: append{{.Name}}NoZero,
unmarshal: consume{{.Name}},
unmarshal: consume{{.Name}}{{if .ToGoTypeNoZero}}NoZero{{end}},
}
{{if or (eq .Name "Bytes") (eq .Name "String")}}
@ -174,10 +190,28 @@ func append{{.Name}}NoZeroValidateUTF8(b []byte, p pointer, wiretag uint64, _ ma
return b, nil
}
{{if .ToGoTypeNoZero}}
// consume{{.Name}}NoZeroValidateUTF8 wire decodes a {{.GoType}} pointer as a {{.Name}}.
// The zero value is not decoded.
// It returns errUnknown if wtyp is not {{.WireType.Expr}}, a wire.ParseError
// if the value cannot be consumed from b, and errInvalidUTF8 if the consumed
// value is not valid UTF-8.
func consume{{.Name}}NoZeroValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
if wtyp != {{.WireType.Expr}} {
return 0, errUnknown
}
v, n := {{template "Consume" .}}
if n < 0 {
return 0, wire.ParseError(n)
}
// Validation happens before the store, so invalid input leaves the field untouched.
if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) {
return 0, errInvalidUTF8{}
}
// The conversion expression below refers to the consumed value by the name v.
*p.{{.GoType.PointerMethod}}() = {{.ToGoTypeNoZero}}
return n, nil
}
{{end}}
var coder{{.Name}}NoZeroValidateUTF8 = pointerCoderFuncs{
size: size{{.Name}}NoZero,
marshal: append{{.Name}}NoZeroValidateUTF8,
unmarshal: consume{{.Name}}ValidateUTF8,
unmarshal: consume{{.Name}}{{if .ToGoTypeNoZero}}NoZero{{end}}ValidateUTF8,
}
{{end}}
@ -551,6 +585,9 @@ var coder{{.Name}}PackedSliceIface = ifaceCoderFuncs{
{{end -}}
{{end -}}
// emptyBuf is a zero-length array used as the base of an append when copying
// decoded bytes. We append to an empty array rather than a nil []byte to get
// non-nil zero-length byte slices.
var emptyBuf [0]byte
var wireTypes = map[protoreflect.Kind]wire.Type{
{{range . -}}
protoreflect.{{.Name}}Kind: {{.WireType.Expr}},

View File

@ -85,10 +85,11 @@ type ProtoKind struct {
FromValue Expr
// Conversions to/from generated structures.
GoType GoType
ToGoType Expr
FromGoType Expr
NoPointer bool
GoType GoType
ToGoType Expr
ToGoTypeNoZero Expr
FromGoType Expr
NoPointer bool
}
func (k ProtoKind) Expr() Expr {
@ -229,14 +230,15 @@ var ProtoKinds = []ProtoKind{
FromGoType: "v",
},
{
Name: "Bytes",
WireType: WireBytes,
ToValue: "append(([]byte)(nil), v...)",
FromValue: "v.Bytes()",
GoType: GoBytes,
ToGoType: "append(([]byte)(nil), v...)",
FromGoType: "v",
NoPointer: true,
Name: "Bytes",
WireType: WireBytes,
ToValue: "append(([]byte)(nil), v...)",
FromValue: "v.Bytes()",
GoType: GoBytes,
ToGoType: "append(emptyBuf[:], v...)",
ToGoTypeNoZero: "append(([]byte)(nil), v...)",
FromGoType: "v",
NoPointer: true,
},
{
Name: "Message",

View File

@ -4395,7 +4395,7 @@ func consumeBytes(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n in
if n < 0 {
return 0, wire.ParseError(n)
}
*p.Bytes() = append(([]byte)(nil), v...)
*p.Bytes() = append(emptyBuf[:], v...)
return n, nil
}
@ -4428,7 +4428,7 @@ func consumeBytesValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOp
if !utf8.Valid(v) {
return 0, errInvalidUTF8{}
}
*p.Bytes() = append(([]byte)(nil), v...)
*p.Bytes() = append(emptyBuf[:], v...)
return n, nil
}
@ -4460,10 +4460,24 @@ func appendBytesNoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([
return b, nil
}
// consumeBytesNoZero wire decodes a Bytes value from b and stores a copy of it
// in the []byte addressed by p.Bytes().
//
// It returns errUnknown when wtyp is not wire.BytesType and a wire.ParseError
// when the value cannot be consumed. The copy is appended to a nil slice, so a
// zero-length value is stored as nil rather than as an empty non-nil slice.
func consumeBytesNoZero(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
	if wtyp != wire.BytesType {
		return 0, errUnknown
	}
	payload, size := wire.ConsumeBytes(b)
	if size < 0 {
		return 0, wire.ParseError(size)
	}
	dst := p.Bytes()
	*dst = append(([]byte)(nil), payload...)
	return size, nil
}
var coderBytesNoZero = pointerCoderFuncs{
size: sizeBytesNoZero,
marshal: appendBytesNoZero,
unmarshal: consumeBytes,
unmarshal: consumeBytesNoZero,
}
// appendBytesNoZeroValidateUTF8 wire encodes a []byte pointer as a Bytes.
@ -4481,10 +4495,26 @@ func appendBytesNoZeroValidateUTF8(b []byte, p pointer, wiretag uint64, _ marsha
return b, nil
}
// consumeBytesNoZeroValidateUTF8 wire decodes a Bytes value from b, checks that
// it is valid UTF-8, and stores a copy of it in the []byte addressed by
// p.Bytes().
//
// It returns errUnknown when wtyp is not wire.BytesType, a wire.ParseError when
// the value cannot be consumed, and errInvalidUTF8 when the value is not valid
// UTF-8; the destination is not written in any of those cases. The copy is
// appended to a nil slice, so a zero-length value is stored as nil.
func consumeBytesNoZeroValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
	if wtyp != wire.BytesType {
		return 0, errUnknown
	}
	payload, size := wire.ConsumeBytes(b)
	if size < 0 {
		return 0, wire.ParseError(size)
	}
	if ok := utf8.Valid(payload); !ok {
		return 0, errInvalidUTF8{}
	}
	dst := p.Bytes()
	*dst = append(([]byte)(nil), payload...)
	return size, nil
}
var coderBytesNoZeroValidateUTF8 = pointerCoderFuncs{
size: sizeBytesNoZero,
marshal: appendBytesNoZeroValidateUTF8,
unmarshal: consumeBytesValidateUTF8,
unmarshal: consumeBytesNoZeroValidateUTF8,
}
// sizeBytesSlice returns the size of wire encoding a [][]byte pointer as a repeated Bytes.
@ -4516,7 +4546,7 @@ func consumeBytesSlice(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions)
if n < 0 {
return 0, wire.ParseError(n)
}
*sp = append(*sp, append(([]byte)(nil), v...))
*sp = append(*sp, append(emptyBuf[:], v...))
return n, nil
}
@ -4552,7 +4582,7 @@ func consumeBytesSliceValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmars
if !utf8.Valid(v) {
return 0, errInvalidUTF8{}
}
*sp = append(*sp, append(([]byte)(nil), v...))
*sp = append(*sp, append(emptyBuf[:], v...))
return n, nil
}
@ -4585,7 +4615,7 @@ func consumeBytesIface(b []byte, _ interface{}, _ wire.Number, wtyp wire.Type, _
if n < 0 {
return nil, 0, wire.ParseError(n)
}
return append(([]byte)(nil), v...), n, nil
return append(emptyBuf[:], v...), n, nil
}
var coderBytesIface = ifaceCoderFuncs{
@ -4617,7 +4647,7 @@ func consumeBytesIfaceValidateUTF8(b []byte, _ interface{}, _ wire.Number, wtyp
if !utf8.Valid(v) {
return nil, 0, errInvalidUTF8{}
}
return append(([]byte)(nil), v...), n, nil
return append(emptyBuf[:], v...), n, nil
}
var coderBytesIfaceValidateUTF8 = ifaceCoderFuncs{
@ -4655,7 +4685,7 @@ func consumeBytesSliceIface(b []byte, ival interface{}, _ wire.Number, wtyp wire
if n < 0 {
return nil, 0, wire.ParseError(n)
}
*sp = append(*sp, append(([]byte)(nil), v...))
*sp = append(*sp, append(emptyBuf[:], v...))
return ival, n, nil
}
@ -4665,6 +4695,9 @@ var coderBytesSliceIface = ifaceCoderFuncs{
unmarshal: consumeBytesSliceIface,
}
// emptyBuf is a zero-length array used as the base of an append when copying
// decoded bytes. We append to an empty array rather than a nil []byte to get
// non-nil zero-length byte slices.
var emptyBuf [0]byte
var wireTypes = map[protoreflect.Kind]wire.Type{
protoreflect.BoolKind: wire.VarintType,
protoreflect.EnumKind: wire.VarintType,

View File

@ -108,6 +108,21 @@ func TestDecodeNoEnforceUTF8(t *testing.T) {
}
}
// TestDecodeZeroLengthBytes unmarshals a message containing a zero-length
// bytes value into a proto3 TestAllTypes and checks that OptionalBytes stays
// nil, so that proto3 bytes fields don't give the mistaken impression that
// they preserve presence.
func TestDecodeZeroLengthBytes(t *testing.T) {
	encoded := pack.Message{
		pack.Tag{15, pack.BytesType}, pack.Bytes(nil),
	}.Marshal()

	msg := new(test3pb.TestAllTypes)
	err := proto.Unmarshal(encoded, msg)
	if err != nil {
		t.Fatal(err)
	}

	if got := msg.OptionalBytes; got != nil {
		t.Errorf("unmarshal zero-length proto3 bytes field: got %v, want nil", got)
	}
}
var testProtos = []testProto{
{
desc: "basic scalar types",
@ -183,6 +198,60 @@ var testProtos = []testProto{
pack.Tag{21, pack.VarintType}, pack.Varint(int(testpb.TestAllTypes_BAR)),
}.Marshal(),
},
{
desc: "zero values",
decodeTo: []proto.Message{&testpb.TestAllTypes{
OptionalInt32: proto.Int32(0),
OptionalInt64: proto.Int64(0),
OptionalUint32: proto.Uint32(0),
OptionalUint64: proto.Uint64(0),
OptionalSint32: proto.Int32(0),
OptionalSint64: proto.Int64(0),
OptionalFixed32: proto.Uint32(0),
OptionalFixed64: proto.Uint64(0),
OptionalSfixed32: proto.Int32(0),
OptionalSfixed64: proto.Int64(0),
OptionalFloat: proto.Float32(0),
OptionalDouble: proto.Float64(0),
OptionalBool: proto.Bool(false),
OptionalString: proto.String(""),
OptionalBytes: []byte{},
}, &test3pb.TestAllTypes{}, build(
&testpb.TestAllExtensions{},
extend(testpb.E_OptionalInt32Extension, int32(0)),
extend(testpb.E_OptionalInt64Extension, int64(0)),
extend(testpb.E_OptionalUint32Extension, uint32(0)),
extend(testpb.E_OptionalUint64Extension, uint64(0)),
extend(testpb.E_OptionalSint32Extension, int32(0)),
extend(testpb.E_OptionalSint64Extension, int64(0)),
extend(testpb.E_OptionalFixed32Extension, uint32(0)),
extend(testpb.E_OptionalFixed64Extension, uint64(0)),
extend(testpb.E_OptionalSfixed32Extension, int32(0)),
extend(testpb.E_OptionalSfixed64Extension, int64(0)),
extend(testpb.E_OptionalFloatExtension, float32(0)),
extend(testpb.E_OptionalDoubleExtension, float64(0)),
extend(testpb.E_OptionalBoolExtension, bool(false)),
extend(testpb.E_OptionalStringExtension, string("")),
extend(testpb.E_OptionalBytesExtension, []byte{}),
)},
wire: pack.Message{
pack.Tag{1, pack.VarintType}, pack.Varint(0),
pack.Tag{2, pack.VarintType}, pack.Varint(0),
pack.Tag{3, pack.VarintType}, pack.Uvarint(0),
pack.Tag{4, pack.VarintType}, pack.Uvarint(0),
pack.Tag{5, pack.VarintType}, pack.Svarint(0),
pack.Tag{6, pack.VarintType}, pack.Svarint(0),
pack.Tag{7, pack.Fixed32Type}, pack.Uint32(0),
pack.Tag{8, pack.Fixed64Type}, pack.Uint64(0),
pack.Tag{9, pack.Fixed32Type}, pack.Int32(0),
pack.Tag{10, pack.Fixed64Type}, pack.Int64(0),
pack.Tag{11, pack.Fixed32Type}, pack.Float32(0),
pack.Tag{12, pack.Fixed64Type}, pack.Float64(0),
pack.Tag{13, pack.VarintType}, pack.Bool(false),
pack.Tag{14, pack.BytesType}, pack.String(""),
pack.Tag{15, pack.BytesType}, pack.Bytes(nil),
}.Marshal(),
},
{
desc: "groups",
decodeTo: []proto.Message{&testpb.TestAllTypes{