mirror of
https://github.com/protocolbuffers/protobuf-go.git
synced 2025-03-23 19:21:40 +00:00
proto, internal/impl: zero-length proto2 bytes fields should be non-nil
Fix decoding of zero-length bytes fields to produce a non-nil []byte. Change-Id: Ifb7791a47df81091700f7226523371d1386fb1ad Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/188765 Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
This commit is contained in:
parent
afd3633ce3
commit
8003f08e51
@ -152,10 +152,26 @@ func append{{.Name}}NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions
|
||||
return b, nil
|
||||
}
|
||||
|
||||
{{if .ToGoTypeNoZero}}
|
||||
// consume{{.Name}}NoZero wire decodes a {{.GoType}} pointer as a {{.Name}}.
|
||||
// The zero value is not decoded.
|
||||
func consume{{.Name}}NoZero(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
|
||||
if wtyp != {{.WireType.Expr}} {
|
||||
return 0, errUnknown
|
||||
}
|
||||
v, n := {{template "Consume" .}}
|
||||
if n < 0 {
|
||||
return 0, wire.ParseError(n)
|
||||
}
|
||||
*p.{{.GoType.PointerMethod}}() = {{.ToGoTypeNoZero}}
|
||||
return n, nil
|
||||
}
|
||||
{{end}}
|
||||
|
||||
var coder{{.Name}}NoZero = pointerCoderFuncs{
|
||||
size: size{{.Name}}NoZero,
|
||||
marshal: append{{.Name}}NoZero,
|
||||
unmarshal: consume{{.Name}},
|
||||
unmarshal: consume{{.Name}}{{if .ToGoTypeNoZero}}NoZero{{end}},
|
||||
}
|
||||
|
||||
{{if or (eq .Name "Bytes") (eq .Name "String")}}
|
||||
@ -174,10 +190,28 @@ func append{{.Name}}NoZeroValidateUTF8(b []byte, p pointer, wiretag uint64, _ ma
|
||||
return b, nil
|
||||
}
|
||||
|
||||
{{if .ToGoTypeNoZero}}
|
||||
// consume{{.Name}}NoZeroValidateUTF8 wire decodes a {{.GoType}} pointer as a {{.Name}}.
|
||||
func consume{{.Name}}NoZeroValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
|
||||
if wtyp != {{.WireType.Expr}} {
|
||||
return 0, errUnknown
|
||||
}
|
||||
v, n := {{template "Consume" .}}
|
||||
if n < 0 {
|
||||
return 0, wire.ParseError(n)
|
||||
}
|
||||
if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) {
|
||||
return 0, errInvalidUTF8{}
|
||||
}
|
||||
*p.{{.GoType.PointerMethod}}() = {{.ToGoTypeNoZero}}
|
||||
return n, nil
|
||||
}
|
||||
{{end}}
|
||||
|
||||
var coder{{.Name}}NoZeroValidateUTF8 = pointerCoderFuncs{
|
||||
size: size{{.Name}}NoZero,
|
||||
marshal: append{{.Name}}NoZeroValidateUTF8,
|
||||
unmarshal: consume{{.Name}}ValidateUTF8,
|
||||
unmarshal: consume{{.Name}}{{if .ToGoTypeNoZero}}NoZero{{end}}ValidateUTF8,
|
||||
}
|
||||
{{end}}
|
||||
|
||||
@ -551,6 +585,9 @@ var coder{{.Name}}PackedSliceIface = ifaceCoderFuncs{
|
||||
{{end -}}
|
||||
{{end -}}
|
||||
|
||||
// We append to an empty array rather than a nil []byte to get non-nil zero-length byte slices.
|
||||
var emptyBuf [0]byte
|
||||
|
||||
var wireTypes = map[protoreflect.Kind]wire.Type{
|
||||
{{range . -}}
|
||||
protoreflect.{{.Name}}Kind: {{.WireType.Expr}},
|
||||
|
@ -85,10 +85,11 @@ type ProtoKind struct {
|
||||
FromValue Expr
|
||||
|
||||
// Conversions to/from generated structures.
|
||||
GoType GoType
|
||||
ToGoType Expr
|
||||
FromGoType Expr
|
||||
NoPointer bool
|
||||
GoType GoType
|
||||
ToGoType Expr
|
||||
ToGoTypeNoZero Expr
|
||||
FromGoType Expr
|
||||
NoPointer bool
|
||||
}
|
||||
|
||||
func (k ProtoKind) Expr() Expr {
|
||||
@ -229,14 +230,15 @@ var ProtoKinds = []ProtoKind{
|
||||
FromGoType: "v",
|
||||
},
|
||||
{
|
||||
Name: "Bytes",
|
||||
WireType: WireBytes,
|
||||
ToValue: "append(([]byte)(nil), v...)",
|
||||
FromValue: "v.Bytes()",
|
||||
GoType: GoBytes,
|
||||
ToGoType: "append(([]byte)(nil), v...)",
|
||||
FromGoType: "v",
|
||||
NoPointer: true,
|
||||
Name: "Bytes",
|
||||
WireType: WireBytes,
|
||||
ToValue: "append(([]byte)(nil), v...)",
|
||||
FromValue: "v.Bytes()",
|
||||
GoType: GoBytes,
|
||||
ToGoType: "append(emptyBuf[:], v...)",
|
||||
ToGoTypeNoZero: "append(([]byte)(nil), v...)",
|
||||
FromGoType: "v",
|
||||
NoPointer: true,
|
||||
},
|
||||
{
|
||||
Name: "Message",
|
||||
|
@ -4395,7 +4395,7 @@ func consumeBytes(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n in
|
||||
if n < 0 {
|
||||
return 0, wire.ParseError(n)
|
||||
}
|
||||
*p.Bytes() = append(([]byte)(nil), v...)
|
||||
*p.Bytes() = append(emptyBuf[:], v...)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
@ -4428,7 +4428,7 @@ func consumeBytesValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOp
|
||||
if !utf8.Valid(v) {
|
||||
return 0, errInvalidUTF8{}
|
||||
}
|
||||
*p.Bytes() = append(([]byte)(nil), v...)
|
||||
*p.Bytes() = append(emptyBuf[:], v...)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
@ -4460,10 +4460,24 @@ func appendBytesNoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// consumeBytesNoZero wire decodes a []byte pointer as a Bytes.
|
||||
// The zero value is not decoded.
|
||||
func consumeBytesNoZero(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
|
||||
if wtyp != wire.BytesType {
|
||||
return 0, errUnknown
|
||||
}
|
||||
v, n := wire.ConsumeBytes(b)
|
||||
if n < 0 {
|
||||
return 0, wire.ParseError(n)
|
||||
}
|
||||
*p.Bytes() = append(([]byte)(nil), v...)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
var coderBytesNoZero = pointerCoderFuncs{
|
||||
size: sizeBytesNoZero,
|
||||
marshal: appendBytesNoZero,
|
||||
unmarshal: consumeBytes,
|
||||
unmarshal: consumeBytesNoZero,
|
||||
}
|
||||
|
||||
// appendBytesNoZeroValidateUTF8 wire encodes a []byte pointer as a Bytes.
|
||||
@ -4481,10 +4495,26 @@ func appendBytesNoZeroValidateUTF8(b []byte, p pointer, wiretag uint64, _ marsha
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// consumeBytesNoZeroValidateUTF8 wire decodes a []byte pointer as a Bytes.
|
||||
func consumeBytesNoZeroValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
|
||||
if wtyp != wire.BytesType {
|
||||
return 0, errUnknown
|
||||
}
|
||||
v, n := wire.ConsumeBytes(b)
|
||||
if n < 0 {
|
||||
return 0, wire.ParseError(n)
|
||||
}
|
||||
if !utf8.Valid(v) {
|
||||
return 0, errInvalidUTF8{}
|
||||
}
|
||||
*p.Bytes() = append(([]byte)(nil), v...)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
var coderBytesNoZeroValidateUTF8 = pointerCoderFuncs{
|
||||
size: sizeBytesNoZero,
|
||||
marshal: appendBytesNoZeroValidateUTF8,
|
||||
unmarshal: consumeBytesValidateUTF8,
|
||||
unmarshal: consumeBytesNoZeroValidateUTF8,
|
||||
}
|
||||
|
||||
// sizeBytesSlice returns the size of wire encoding a [][]byte pointer as a repeated Bytes.
|
||||
@ -4516,7 +4546,7 @@ func consumeBytesSlice(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions)
|
||||
if n < 0 {
|
||||
return 0, wire.ParseError(n)
|
||||
}
|
||||
*sp = append(*sp, append(([]byte)(nil), v...))
|
||||
*sp = append(*sp, append(emptyBuf[:], v...))
|
||||
return n, nil
|
||||
}
|
||||
|
||||
@ -4552,7 +4582,7 @@ func consumeBytesSliceValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmars
|
||||
if !utf8.Valid(v) {
|
||||
return 0, errInvalidUTF8{}
|
||||
}
|
||||
*sp = append(*sp, append(([]byte)(nil), v...))
|
||||
*sp = append(*sp, append(emptyBuf[:], v...))
|
||||
return n, nil
|
||||
}
|
||||
|
||||
@ -4585,7 +4615,7 @@ func consumeBytesIface(b []byte, _ interface{}, _ wire.Number, wtyp wire.Type, _
|
||||
if n < 0 {
|
||||
return nil, 0, wire.ParseError(n)
|
||||
}
|
||||
return append(([]byte)(nil), v...), n, nil
|
||||
return append(emptyBuf[:], v...), n, nil
|
||||
}
|
||||
|
||||
var coderBytesIface = ifaceCoderFuncs{
|
||||
@ -4617,7 +4647,7 @@ func consumeBytesIfaceValidateUTF8(b []byte, _ interface{}, _ wire.Number, wtyp
|
||||
if !utf8.Valid(v) {
|
||||
return nil, 0, errInvalidUTF8{}
|
||||
}
|
||||
return append(([]byte)(nil), v...), n, nil
|
||||
return append(emptyBuf[:], v...), n, nil
|
||||
}
|
||||
|
||||
var coderBytesIfaceValidateUTF8 = ifaceCoderFuncs{
|
||||
@ -4655,7 +4685,7 @@ func consumeBytesSliceIface(b []byte, ival interface{}, _ wire.Number, wtyp wire
|
||||
if n < 0 {
|
||||
return nil, 0, wire.ParseError(n)
|
||||
}
|
||||
*sp = append(*sp, append(([]byte)(nil), v...))
|
||||
*sp = append(*sp, append(emptyBuf[:], v...))
|
||||
return ival, n, nil
|
||||
}
|
||||
|
||||
@ -4665,6 +4695,9 @@ var coderBytesSliceIface = ifaceCoderFuncs{
|
||||
unmarshal: consumeBytesSliceIface,
|
||||
}
|
||||
|
||||
// We append to an empty array rather than a nil []byte to get non-nil zero-length byte slices.
|
||||
var emptyBuf [0]byte
|
||||
|
||||
var wireTypes = map[protoreflect.Kind]wire.Type{
|
||||
protoreflect.BoolKind: wire.VarintType,
|
||||
protoreflect.EnumKind: wire.VarintType,
|
||||
|
@ -108,6 +108,21 @@ func TestDecodeNoEnforceUTF8(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeZeroLengthBytes(t *testing.T) {
|
||||
// Verify that proto3 bytes fields don't give the mistaken
|
||||
// impression that they preserve presence.
|
||||
wire := pack.Message{
|
||||
pack.Tag{15, pack.BytesType}, pack.Bytes(nil),
|
||||
}.Marshal()
|
||||
m := &test3pb.TestAllTypes{}
|
||||
if err := proto.Unmarshal(wire, m); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if m.OptionalBytes != nil {
|
||||
t.Errorf("unmarshal zero-length proto3 bytes field: got %v, want nil", m.OptionalBytes)
|
||||
}
|
||||
}
|
||||
|
||||
var testProtos = []testProto{
|
||||
{
|
||||
desc: "basic scalar types",
|
||||
@ -183,6 +198,60 @@ var testProtos = []testProto{
|
||||
pack.Tag{21, pack.VarintType}, pack.Varint(int(testpb.TestAllTypes_BAR)),
|
||||
}.Marshal(),
|
||||
},
|
||||
{
|
||||
desc: "zero values",
|
||||
decodeTo: []proto.Message{&testpb.TestAllTypes{
|
||||
OptionalInt32: proto.Int32(0),
|
||||
OptionalInt64: proto.Int64(0),
|
||||
OptionalUint32: proto.Uint32(0),
|
||||
OptionalUint64: proto.Uint64(0),
|
||||
OptionalSint32: proto.Int32(0),
|
||||
OptionalSint64: proto.Int64(0),
|
||||
OptionalFixed32: proto.Uint32(0),
|
||||
OptionalFixed64: proto.Uint64(0),
|
||||
OptionalSfixed32: proto.Int32(0),
|
||||
OptionalSfixed64: proto.Int64(0),
|
||||
OptionalFloat: proto.Float32(0),
|
||||
OptionalDouble: proto.Float64(0),
|
||||
OptionalBool: proto.Bool(false),
|
||||
OptionalString: proto.String(""),
|
||||
OptionalBytes: []byte{},
|
||||
}, &test3pb.TestAllTypes{}, build(
|
||||
&testpb.TestAllExtensions{},
|
||||
extend(testpb.E_OptionalInt32Extension, int32(0)),
|
||||
extend(testpb.E_OptionalInt64Extension, int64(0)),
|
||||
extend(testpb.E_OptionalUint32Extension, uint32(0)),
|
||||
extend(testpb.E_OptionalUint64Extension, uint64(0)),
|
||||
extend(testpb.E_OptionalSint32Extension, int32(0)),
|
||||
extend(testpb.E_OptionalSint64Extension, int64(0)),
|
||||
extend(testpb.E_OptionalFixed32Extension, uint32(0)),
|
||||
extend(testpb.E_OptionalFixed64Extension, uint64(0)),
|
||||
extend(testpb.E_OptionalSfixed32Extension, int32(0)),
|
||||
extend(testpb.E_OptionalSfixed64Extension, int64(0)),
|
||||
extend(testpb.E_OptionalFloatExtension, float32(0)),
|
||||
extend(testpb.E_OptionalDoubleExtension, float64(0)),
|
||||
extend(testpb.E_OptionalBoolExtension, bool(false)),
|
||||
extend(testpb.E_OptionalStringExtension, string("")),
|
||||
extend(testpb.E_OptionalBytesExtension, []byte{}),
|
||||
)},
|
||||
wire: pack.Message{
|
||||
pack.Tag{1, pack.VarintType}, pack.Varint(0),
|
||||
pack.Tag{2, pack.VarintType}, pack.Varint(0),
|
||||
pack.Tag{3, pack.VarintType}, pack.Uvarint(0),
|
||||
pack.Tag{4, pack.VarintType}, pack.Uvarint(0),
|
||||
pack.Tag{5, pack.VarintType}, pack.Svarint(0),
|
||||
pack.Tag{6, pack.VarintType}, pack.Svarint(0),
|
||||
pack.Tag{7, pack.Fixed32Type}, pack.Uint32(0),
|
||||
pack.Tag{8, pack.Fixed64Type}, pack.Uint64(0),
|
||||
pack.Tag{9, pack.Fixed32Type}, pack.Int32(0),
|
||||
pack.Tag{10, pack.Fixed64Type}, pack.Int64(0),
|
||||
pack.Tag{11, pack.Fixed32Type}, pack.Float32(0),
|
||||
pack.Tag{12, pack.Fixed64Type}, pack.Float64(0),
|
||||
pack.Tag{13, pack.VarintType}, pack.Bool(false),
|
||||
pack.Tag{14, pack.BytesType}, pack.String(""),
|
||||
pack.Tag{15, pack.BytesType}, pack.Bytes(nil),
|
||||
}.Marshal(),
|
||||
},
|
||||
{
|
||||
desc: "groups",
|
||||
decodeTo: []proto.Message{&testpb.TestAllTypes{
|
||||
|
Loading…
x
Reference in New Issue
Block a user