// Patch summary (explanatory preamble placed ahead of the first "diff --git" header).
//
// Purpose: make decoded zero-length bytes fields come out as non-nil empty
// slices in the ordinary coders, while the "NoZero" coders keep producing nil.
//
//   - internal/cmd/generate-types/proto.go: ProtoKind gains a ToGoTypeNoZero
//     expression. For the Bytes kind, ToGoType becomes
//     "append(emptyBuf[:], v...)" and ToGoTypeNoZero keeps the previous
//     "append(([]byte)(nil), v...)".
//   - internal/cmd/generate-types/impl.go: when ToGoTypeNoZero is set, the
//     template emits consume{{.Name}}NoZero and
//     consume{{.Name}}NoZeroValidateUTF8, and the coder{{.Name}}NoZero* tables
//     use them for unmarshal; kinds without ToGoTypeNoZero keep the plain
//     consume functions. It also declares emptyBuf, a [0]byte array.
//   - internal/impl/codec_gen.go: the same change in expanded form for Bytes
//     (presumably the generated output of the template above; verify by
//     regenerating). consumeBytes, the slice variants and the iface variants
//     now append to emptyBuf[:]; the new consumeBytesNoZero* append to a nil
//     []byte.
//   - proto/decode_test.go: TestDecodeZeroLengthBytes expects
//     test3pb.TestAllTypes.OptionalBytes to be nil after decoding a zero-length
//     field 15; the new "zero values" case expects
//     testpb.TestAllTypes.OptionalBytes to be []byte{}.
//
// NOTE(review): the hunks below appear with their original line breaks and
// tabs collapsed into spaces; the line structure would have to be restored
// before this text could be applied as a patch.
diff --git a/internal/cmd/generate-types/impl.go b/internal/cmd/generate-types/impl.go index 085710d4..0bbbbe97 100644 --- a/internal/cmd/generate-types/impl.go +++ b/internal/cmd/generate-types/impl.go @@ -152,10 +152,26 @@ func append{{.Name}}NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions return b, nil } +{{if .ToGoTypeNoZero}} +// consume{{.Name}}NoZero wire decodes a {{.GoType}} pointer as a {{.Name}}. +// The zero value is not decoded. +func consume{{.Name}}NoZero(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) { + if wtyp != {{.WireType.Expr}} { + return 0, errUnknown + } + v, n := {{template "Consume" .}} + if n < 0 { + return 0, wire.ParseError(n) + } + *p.{{.GoType.PointerMethod}}() = {{.ToGoTypeNoZero}} + return n, nil +} +{{end}} + var coder{{.Name}}NoZero = pointerCoderFuncs{ size: size{{.Name}}NoZero, marshal: append{{.Name}}NoZero, - unmarshal: consume{{.Name}}, + unmarshal: consume{{.Name}}{{if .ToGoTypeNoZero}}NoZero{{end}}, } {{if or (eq .Name "Bytes") (eq .Name "String")}} @@ -174,10 +190,28 @@ func append{{.Name}}NoZeroValidateUTF8(b []byte, p pointer, wiretag uint64, _ ma return b, nil } +{{if .ToGoTypeNoZero}} +// consume{{.Name}}NoZeroValidateUTF8 wire decodes a {{.GoType}} pointer as a {{.Name}}. 
+func consume{{.Name}}NoZeroValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) { + if wtyp != {{.WireType.Expr}} { + return 0, errUnknown + } + v, n := {{template "Consume" .}} + if n < 0 { + return 0, wire.ParseError(n) + } + if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) { + return 0, errInvalidUTF8{} + } + *p.{{.GoType.PointerMethod}}() = {{.ToGoTypeNoZero}} + return n, nil +} +{{end}} + var coder{{.Name}}NoZeroValidateUTF8 = pointerCoderFuncs{ size: size{{.Name}}NoZero, marshal: append{{.Name}}NoZeroValidateUTF8, - unmarshal: consume{{.Name}}ValidateUTF8, + unmarshal: consume{{.Name}}{{if .ToGoTypeNoZero}}NoZero{{end}}ValidateUTF8, } {{end}} @@ -551,6 +585,9 @@ var coder{{.Name}}PackedSliceIface = ifaceCoderFuncs{ {{end -}} {{end -}} +// We append to an empty array rather than a nil []byte to get non-nil zero-length byte slices. +var emptyBuf [0]byte + var wireTypes = map[protoreflect.Kind]wire.Type{ {{range . -}} protoreflect.{{.Name}}Kind: {{.WireType.Expr}}, diff --git a/internal/cmd/generate-types/proto.go b/internal/cmd/generate-types/proto.go index 023cde8a..5cbf4ceb 100644 --- a/internal/cmd/generate-types/proto.go +++ b/internal/cmd/generate-types/proto.go @@ -85,10 +85,11 @@ type ProtoKind struct { FromValue Expr // Conversions to/from generated structures. 
- GoType GoType - ToGoType Expr - FromGoType Expr - NoPointer bool + GoType GoType + ToGoType Expr + ToGoTypeNoZero Expr + FromGoType Expr + NoPointer bool } func (k ProtoKind) Expr() Expr { @@ -229,14 +230,15 @@ var ProtoKinds = []ProtoKind{ FromGoType: "v", }, { - Name: "Bytes", - WireType: WireBytes, - ToValue: "append(([]byte)(nil), v...)", - FromValue: "v.Bytes()", - GoType: GoBytes, - ToGoType: "append(([]byte)(nil), v...)", - FromGoType: "v", - NoPointer: true, + Name: "Bytes", + WireType: WireBytes, + ToValue: "append(([]byte)(nil), v...)", + FromValue: "v.Bytes()", + GoType: GoBytes, + ToGoType: "append(emptyBuf[:], v...)", + ToGoTypeNoZero: "append(([]byte)(nil), v...)", + FromGoType: "v", + NoPointer: true, }, { Name: "Message", diff --git a/internal/impl/codec_gen.go b/internal/impl/codec_gen.go index 46380f57..f40af3dc 100644 --- a/internal/impl/codec_gen.go +++ b/internal/impl/codec_gen.go @@ -4395,7 +4395,7 @@ func consumeBytes(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n in if n < 0 { return 0, wire.ParseError(n) } - *p.Bytes() = append(([]byte)(nil), v...) + *p.Bytes() = append(emptyBuf[:], v...) return n, nil } @@ -4428,7 +4428,7 @@ func consumeBytesValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOp if !utf8.Valid(v) { return 0, errInvalidUTF8{} } - *p.Bytes() = append(([]byte)(nil), v...) + *p.Bytes() = append(emptyBuf[:], v...) return n, nil } @@ -4460,10 +4460,24 @@ func appendBytesNoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([ return b, nil } +// consumeBytesNoZero wire decodes a []byte pointer as a Bytes. +// The zero value is not decoded. +func consumeBytesNoZero(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) { + if wtyp != wire.BytesType { + return 0, errUnknown + } + v, n := wire.ConsumeBytes(b) + if n < 0 { + return 0, wire.ParseError(n) + } + *p.Bytes() = append(([]byte)(nil), v...) 
+ return n, nil +} + var coderBytesNoZero = pointerCoderFuncs{ size: sizeBytesNoZero, marshal: appendBytesNoZero, - unmarshal: consumeBytes, + unmarshal: consumeBytesNoZero, } // appendBytesNoZeroValidateUTF8 wire encodes a []byte pointer as a Bytes. @@ -4481,10 +4495,26 @@ func appendBytesNoZeroValidateUTF8(b []byte, p pointer, wiretag uint64, _ marsha return b, nil } +// consumeBytesNoZeroValidateUTF8 wire decodes a []byte pointer as a Bytes. +func consumeBytesNoZeroValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) { + if wtyp != wire.BytesType { + return 0, errUnknown + } + v, n := wire.ConsumeBytes(b) + if n < 0 { + return 0, wire.ParseError(n) + } + if !utf8.Valid(v) { + return 0, errInvalidUTF8{} + } + *p.Bytes() = append(([]byte)(nil), v...) + return n, nil +} + var coderBytesNoZeroValidateUTF8 = pointerCoderFuncs{ size: sizeBytesNoZero, marshal: appendBytesNoZeroValidateUTF8, - unmarshal: consumeBytesValidateUTF8, + unmarshal: consumeBytesNoZeroValidateUTF8, } // sizeBytesSlice returns the size of wire encoding a [][]byte pointer as a repeated Bytes. 
@@ -4516,7 +4546,7 @@ func consumeBytesSlice(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) if n < 0 { return 0, wire.ParseError(n) } - *sp = append(*sp, append(([]byte)(nil), v...)) + *sp = append(*sp, append(emptyBuf[:], v...)) return n, nil } @@ -4552,7 +4582,7 @@ func consumeBytesSliceValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmars if !utf8.Valid(v) { return 0, errInvalidUTF8{} } - *sp = append(*sp, append(([]byte)(nil), v...)) + *sp = append(*sp, append(emptyBuf[:], v...)) return n, nil } @@ -4585,7 +4615,7 @@ func consumeBytesIface(b []byte, _ interface{}, _ wire.Number, wtyp wire.Type, _ if n < 0 { return nil, 0, wire.ParseError(n) } - return append(([]byte)(nil), v...), n, nil + return append(emptyBuf[:], v...), n, nil } var coderBytesIface = ifaceCoderFuncs{ @@ -4617,7 +4647,7 @@ func consumeBytesIfaceValidateUTF8(b []byte, _ interface{}, _ wire.Number, wtyp if !utf8.Valid(v) { return nil, 0, errInvalidUTF8{} } - return append(([]byte)(nil), v...), n, nil + return append(emptyBuf[:], v...), n, nil } var coderBytesIfaceValidateUTF8 = ifaceCoderFuncs{ @@ -4655,7 +4685,7 @@ func consumeBytesSliceIface(b []byte, ival interface{}, _ wire.Number, wtyp wire if n < 0 { return nil, 0, wire.ParseError(n) } - *sp = append(*sp, append(([]byte)(nil), v...)) + *sp = append(*sp, append(emptyBuf[:], v...)) return ival, n, nil } @@ -4665,6 +4695,9 @@ var coderBytesSliceIface = ifaceCoderFuncs{ unmarshal: consumeBytesSliceIface, } +// We append to an empty array rather than a nil []byte to get non-nil zero-length byte slices. 
+var emptyBuf [0]byte + var wireTypes = map[protoreflect.Kind]wire.Type{ protoreflect.BoolKind: wire.VarintType, protoreflect.EnumKind: wire.VarintType, diff --git a/proto/decode_test.go b/proto/decode_test.go index c32c94ce..6088eb55 100644 --- a/proto/decode_test.go +++ b/proto/decode_test.go @@ -108,6 +108,21 @@ func TestDecodeNoEnforceUTF8(t *testing.T) { } } +func TestDecodeZeroLengthBytes(t *testing.T) { + // Verify that proto3 bytes fields don't give the mistaken + // impression that they preserve presence. + wire := pack.Message{ + pack.Tag{15, pack.BytesType}, pack.Bytes(nil), + }.Marshal() + m := &test3pb.TestAllTypes{} + if err := proto.Unmarshal(wire, m); err != nil { + t.Fatal(err) + } + if m.OptionalBytes != nil { + t.Errorf("unmarshal zero-length proto3 bytes field: got %v, want nil", m.OptionalBytes) + } +} + var testProtos = []testProto{ { desc: "basic scalar types", @@ -183,6 +198,60 @@ var testProtos = []testProto{ pack.Tag{21, pack.VarintType}, pack.Varint(int(testpb.TestAllTypes_BAR)), }.Marshal(), }, + { + desc: "zero values", + decodeTo: []proto.Message{&testpb.TestAllTypes{ + OptionalInt32: proto.Int32(0), + OptionalInt64: proto.Int64(0), + OptionalUint32: proto.Uint32(0), + OptionalUint64: proto.Uint64(0), + OptionalSint32: proto.Int32(0), + OptionalSint64: proto.Int64(0), + OptionalFixed32: proto.Uint32(0), + OptionalFixed64: proto.Uint64(0), + OptionalSfixed32: proto.Int32(0), + OptionalSfixed64: proto.Int64(0), + OptionalFloat: proto.Float32(0), + OptionalDouble: proto.Float64(0), + OptionalBool: proto.Bool(false), + OptionalString: proto.String(""), + OptionalBytes: []byte{}, + }, &test3pb.TestAllTypes{}, build( + &testpb.TestAllExtensions{}, + extend(testpb.E_OptionalInt32Extension, int32(0)), + extend(testpb.E_OptionalInt64Extension, int64(0)), + extend(testpb.E_OptionalUint32Extension, uint32(0)), + extend(testpb.E_OptionalUint64Extension, uint64(0)), + extend(testpb.E_OptionalSint32Extension, int32(0)), 
+ extend(testpb.E_OptionalSint64Extension, int64(0)), + extend(testpb.E_OptionalFixed32Extension, uint32(0)), + extend(testpb.E_OptionalFixed64Extension, uint64(0)), + extend(testpb.E_OptionalSfixed32Extension, int32(0)), + extend(testpb.E_OptionalSfixed64Extension, int64(0)), + extend(testpb.E_OptionalFloatExtension, float32(0)), + extend(testpb.E_OptionalDoubleExtension, float64(0)), + extend(testpb.E_OptionalBoolExtension, bool(false)), + extend(testpb.E_OptionalStringExtension, string("")), + extend(testpb.E_OptionalBytesExtension, []byte{}), + )}, + wire: pack.Message{ + pack.Tag{1, pack.VarintType}, pack.Varint(0), + pack.Tag{2, pack.VarintType}, pack.Varint(0), + pack.Tag{3, pack.VarintType}, pack.Uvarint(0), + pack.Tag{4, pack.VarintType}, pack.Uvarint(0), + pack.Tag{5, pack.VarintType}, pack.Svarint(0), + pack.Tag{6, pack.VarintType}, pack.Svarint(0), + pack.Tag{7, pack.Fixed32Type}, pack.Uint32(0), + pack.Tag{8, pack.Fixed64Type}, pack.Uint64(0), + pack.Tag{9, pack.Fixed32Type}, pack.Int32(0), + pack.Tag{10, pack.Fixed64Type}, pack.Int64(0), + pack.Tag{11, pack.Fixed32Type}, pack.Float32(0), + pack.Tag{12, pack.Fixed64Type}, pack.Float64(0), + pack.Tag{13, pack.VarintType}, pack.Bool(false), + pack.Tag{14, pack.BytesType}, pack.String(""), + pack.Tag{15, pack.BytesType}, pack.Bytes(nil), + }.Marshal(), + }, { desc: "groups", decodeTo: []proto.Message{&testpb.TestAllTypes{