From c5060d2fe624120c91275235630f6afe6aed129f Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Thu, 22 Aug 2019 17:01:56 -0700 Subject: [PATCH] reflect/protoreflect: add non-allocating Value constructors Passing a non-pointer type to protoreflect.NewValue causes an unnecessary allocation in order to store the value in an interface{}. While this allocation could be avoided by a smarter compiler, no such compiler exists today. Add functions for creating new values of a specific type, avoiding the allocation. (And also adding a small amount of type safety, although this is unlikely to be important.) Update the proto and internal/impl packages to use these functions. Change-Id: Ic733de22ddf19c530189166c853348e1b54b7391 Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/191457 Reviewed-by: Joe Tsai --- internal/cmd/generate-types/proto.go | 57 ++++---- internal/impl/convert.go | 189 ++++++++++++++++++++++----- proto/decode_gen.go | 96 +++++++------- reflect/protoreflect/value_union.go | 95 +++++++++++--- 4 files changed, 314 insertions(+), 123 deletions(-) diff --git a/internal/cmd/generate-types/proto.go b/internal/cmd/generate-types/proto.go index 5cbf4ceb..1cd0ed7b 100644 --- a/internal/cmd/generate-types/proto.go +++ b/internal/cmd/generate-types/proto.go @@ -90,6 +90,7 @@ type ProtoKind struct { ToGoTypeNoZero Expr FromGoType Expr NoPointer bool + NoValueCodec bool } func (k ProtoKind) Expr() Expr { @@ -100,7 +101,7 @@ var ProtoKinds = []ProtoKind{ { Name: "Bool", WireType: WireVarint, - ToValue: "wire.DecodeBool(v)", + ToValue: "protoreflect.ValueOfBool(wire.DecodeBool(v))", FromValue: "wire.EncodeBool(v.Bool())", GoType: GoBool, ToGoType: "wire.DecodeBool(v)", @@ -109,13 +110,13 @@ var ProtoKinds = []ProtoKind{ { Name: "Enum", WireType: WireVarint, - ToValue: "protoreflect.EnumNumber(v)", + ToValue: "protoreflect.ValueOfEnum(protoreflect.EnumNumber(v))", FromValue: "uint64(v.Enum())", }, { Name: "Int32", WireType: WireVarint, - ToValue: "int32(v)", + ToValue: "protoreflect.ValueOfInt32(int32(v))", FromValue: "uint64(int32(v.Int()))", GoType: GoInt32, ToGoType: "int32(v)", @@ -124,7 +125,7 @@ var ProtoKinds = []ProtoKind{ { Name: "Sint32", WireType: WireVarint, - ToValue: "int32(wire.DecodeZigZag(v & math.MaxUint32))", + ToValue: "protoreflect.ValueOfInt32(int32(wire.DecodeZigZag(v & math.MaxUint32)))", FromValue: "wire.EncodeZigZag(int64(int32(v.Int())))", GoType: GoInt32, ToGoType: "int32(wire.DecodeZigZag(v & math.MaxUint32))", @@ -133,7 +134,7 @@ var ProtoKinds = []ProtoKind{ { Name: "Uint32", WireType: WireVarint, - ToValue: "uint32(v)", + ToValue: "protoreflect.ValueOfUint32(uint32(v))", FromValue: "uint64(uint32(v.Uint()))", GoType: GoUint32, ToGoType: "uint32(v)", @@ -142,7 +143,7 @@ var ProtoKinds = []ProtoKind{ { Name: "Int64", WireType: WireVarint, - ToValue: "int64(v)", + ToValue: "protoreflect.ValueOfInt64(int64(v))", FromValue: "uint64(v.Int())", GoType: GoInt64, ToGoType: "int64(v)", @@ -151,7 +152,7 @@ var ProtoKinds = []ProtoKind{ { Name: "Sint64", WireType: WireVarint, - ToValue: "wire.DecodeZigZag(v)", + ToValue: "protoreflect.ValueOfInt64(wire.DecodeZigZag(v))", FromValue: "wire.EncodeZigZag(v.Int())", GoType: GoInt64, ToGoType: "wire.DecodeZigZag(v)", @@ -160,7 +161,7 @@ var ProtoKinds = []ProtoKind{ { Name: "Uint64", WireType: WireVarint, - ToValue: "v", + ToValue: "protoreflect.ValueOfUint64(v)", FromValue: "v.Uint()", GoType: GoUint64, ToGoType: "v", @@ -169,7 +170,7 @@ var ProtoKinds = []ProtoKind{ { Name: "Sfixed32", WireType: WireFixed32, - ToValue: "int32(v)", + ToValue: "protoreflect.ValueOfInt32(int32(v))", FromValue: "uint32(v.Int())", GoType: GoInt32, ToGoType: "int32(v)", @@ -178,7 +179,7 @@ var ProtoKinds = []ProtoKind{ { Name: "Fixed32", WireType: WireFixed32, - ToValue: "uint32(v)", + ToValue: "protoreflect.ValueOfUint32(uint32(v))", FromValue: "uint32(v.Uint())", GoType: GoUint32, ToGoType: "v", @@ -187,7 +188,7 @@ var ProtoKinds = []ProtoKind{ { Name: "Float", WireType: WireFixed32, - ToValue: "math.Float32frombits(uint32(v))", + ToValue: "protoreflect.ValueOfFloat32(math.Float32frombits(uint32(v)))", FromValue: "math.Float32bits(float32(v.Float()))", GoType: GoFloat32, ToGoType: "math.Float32frombits(v)", @@ -196,7 +197,7 @@ var ProtoKinds = []ProtoKind{ { Name: "Sfixed64", WireType: WireFixed64, - ToValue: "int64(v)", + ToValue: "protoreflect.ValueOfInt64(int64(v))", FromValue: "uint64(v.Int())", GoType: GoInt64, ToGoType: "int64(v)", @@ -205,7 +206,7 @@ var ProtoKinds = []ProtoKind{ { Name: "Fixed64", WireType: WireFixed64, - ToValue: "v", + ToValue: "protoreflect.ValueOfUint64(v)", FromValue: "v.Uint()", GoType: GoUint64, ToGoType: "v", @@ -214,7 +215,7 @@ var ProtoKinds = []ProtoKind{ { Name: "Double", WireType: WireFixed64, - ToValue: "math.Float64frombits(v)", + ToValue: "protoreflect.ValueOfFloat64(math.Float64frombits(v))", FromValue: "math.Float64bits(v.Float())", GoType: GoFloat64, ToGoType: "math.Float64frombits(v)", @@ -223,7 +224,7 @@ var ProtoKinds = []ProtoKind{ { Name: "String", WireType: WireBytes, - ToValue: "string(v)", + ToValue: "protoreflect.ValueOfString(string(v))", FromValue: "v.String()", GoType: GoString, ToGoType: "v", @@ -232,7 +233,7 @@ var ProtoKinds = []ProtoKind{ { Name: "Bytes", WireType: WireBytes, - ToValue: "append(([]byte)(nil), v...)", + ToValue: "protoreflect.ValueOfBytes(append(([]byte)(nil), v...))", FromValue: "v.Bytes()", GoType: GoBytes, ToGoType: "append(emptyBuf[:], v...)", @@ -241,16 +242,18 @@ var ProtoKinds = []ProtoKind{ NoPointer: true, }, { - Name: "Message", - WireType: WireBytes, - ToValue: "v", - FromValue: "v", + Name: "Message", + WireType: WireBytes, + ToValue: "protoreflect.ValueOfBytes(v)", + FromValue: "v", + NoValueCodec: true, }, { - Name: "Group", - WireType: WireGroup, - ToValue: "v", - FromValue: "v", + Name: "Group", + WireType: WireGroup, + ToValue: "protoreflect.ValueOfBytes(v)", + FromValue: "v", + NoValueCodec: true, }, } @@ -282,7 +285,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, fd protorefl return protoreflect.Value{}, 0, errors.InvalidUTF8(string(fd.FullName())) } {{end -}} - return protoreflect.ValueOf({{.ToValue}}), n, nil + return {{.ToValue}}, n, nil {{- end}} default: return val, 0, errUnknown @@ -305,7 +308,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl return 0, wire.ParseError(n) } buf = buf[n:] - list.Append(protoreflect.ValueOf({{.ToValue}})) + list.Append({{.ToValue}}) } return n, nil } @@ -333,7 +336,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl } list.Append(protoreflect.ValueOf(m)) {{- else -}} - list.Append(protoreflect.ValueOf({{.ToValue}})) + list.Append({{.ToValue}}) {{- end}} return n, nil {{- end}} diff --git a/internal/impl/convert.go b/internal/impl/convert.go index 4b280ef1..9b74c037 100644 --- a/internal/impl/convert.go +++ b/internal/impl/convert.go @@ -81,11 +81,6 @@ var ( bytesZero = pref.ValueOf([]byte(nil)) ) -type scalarConverter struct { - goType, pbType reflect.Type - def pref.Value -} - func newSingularConverter(t reflect.Type, fd pref.FieldDescriptor) Converter { defVal := func(fd pref.FieldDescriptor, zero pref.Value) pref.Value { if fd.Cardinality() == pref.Repeated { @@ -97,39 +92,39 @@ func newSingularConverter(t reflect.Type, fd pref.FieldDescriptor) Converter { switch fd.Kind() { case pref.BoolKind: if t.Kind() == reflect.Bool { - return &scalarConverter{t, boolType, defVal(fd, boolZero)} + return &boolConverter{t, defVal(fd, boolZero)} } case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind: if t.Kind() == reflect.Int32 { - return &scalarConverter{t, int32Type, defVal(fd, int32Zero)} + return &int32Converter{t, defVal(fd, int32Zero)} } case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind: if t.Kind() == reflect.Int64 { - return &scalarConverter{t, int64Type, defVal(fd, int64Zero)} + return &int64Converter{t, defVal(fd, int64Zero)} } case pref.Uint32Kind, pref.Fixed32Kind: if t.Kind() == reflect.Uint32 { - return &scalarConverter{t, uint32Type, defVal(fd, uint32Zero)} + return &uint32Converter{t, defVal(fd, uint32Zero)} } case pref.Uint64Kind, pref.Fixed64Kind: if t.Kind() == reflect.Uint64 { - return &scalarConverter{t, uint64Type, defVal(fd, uint64Zero)} + return &uint64Converter{t, defVal(fd, uint64Zero)} } case pref.FloatKind: if t.Kind() == reflect.Float32 { - return &scalarConverter{t, float32Type, defVal(fd, float32Zero)} + return &float32Converter{t, defVal(fd, float32Zero)} } case pref.DoubleKind: if t.Kind() == reflect.Float64 { - return &scalarConverter{t, float64Type, defVal(fd, float64Zero)} + return &float64Converter{t, defVal(fd, float64Zero)} } case pref.StringKind: if t.Kind() == reflect.String || (t.Kind() == reflect.Slice && t.Elem() == byteType) { - return &scalarConverter{t, stringType, defVal(fd, stringZero)} + return &stringConverter{t, defVal(fd, stringZero)} } case pref.BytesKind: if t.Kind() == reflect.String || (t.Kind() == reflect.Slice && t.Elem() == byteType) { - return &scalarConverter{t, bytesType, defVal(fd, bytesZero)} + return &bytesConverter{t, defVal(fd, bytesZero)} } case pref.EnumKind: // Handle enums, which must be a named int32 type. @@ -142,37 +137,167 @@ func newSingularConverter(t reflect.Type, fd pref.FieldDescriptor) Converter { panic(fmt.Sprintf("invalid Go type %v for field %v", t, fd.FullName())) } -func (c *scalarConverter) PBValueOf(v reflect.Value) pref.Value { +type boolConverter struct { + goType reflect.Type + def pref.Value +} + +func (c *boolConverter) PBValueOf(v reflect.Value) pref.Value { if v.Type() != c.goType { panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType)) } - if c.goType.Kind() == reflect.String && c.pbType.Kind() == reflect.Slice && v.Len() == 0 { - return pref.ValueOf([]byte(nil)) // ensure empty string is []byte(nil) - } - return pref.ValueOf(v.Convert(c.pbType).Interface()) + return pref.ValueOfBool(v.Bool()) +} +func (c *boolConverter) GoValueOf(v pref.Value) reflect.Value { + return reflect.ValueOf(v.Bool()).Convert(c.goType) +} +func (c *boolConverter) New() pref.Value { return c.def } +func (c *boolConverter) Zero() pref.Value { return c.def } + +type int32Converter struct { + goType reflect.Type + def pref.Value } -func (c *scalarConverter) GoValueOf(v pref.Value) reflect.Value { - rv := reflect.ValueOf(v.Interface()) - if rv.Type() != c.pbType { - panic(fmt.Sprintf("invalid type: got %v, want %v", rv.Type(), c.pbType)) +func (c *int32Converter) PBValueOf(v reflect.Value) pref.Value { + if v.Type() != c.goType { + panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType)) } - if c.pbType.Kind() == reflect.String && c.goType.Kind() == reflect.Slice && rv.Len() == 0 { + return pref.ValueOfInt32(int32(v.Int())) +} +func (c *int32Converter) GoValueOf(v pref.Value) reflect.Value { + return reflect.ValueOf(int32(v.Int())).Convert(c.goType) +} +func (c *int32Converter) New() pref.Value { return c.def } +func (c *int32Converter) Zero() pref.Value { return c.def } + +type int64Converter struct { + goType reflect.Type + def pref.Value +} + +func (c *int64Converter) PBValueOf(v reflect.Value) pref.Value { + if v.Type() != c.goType { + panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType)) + } + return pref.ValueOfInt64(int64(v.Int())) +} +func (c *int64Converter) GoValueOf(v pref.Value) reflect.Value { + return reflect.ValueOf(int64(v.Int())).Convert(c.goType) +} +func (c *int64Converter) New() pref.Value { return c.def } +func (c *int64Converter) Zero() pref.Value { return c.def } + +type uint32Converter struct { + goType reflect.Type + def pref.Value +} + +func (c *uint32Converter) PBValueOf(v reflect.Value) pref.Value { + if v.Type() != c.goType { + panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType)) + } + return pref.ValueOfUint32(uint32(v.Uint())) +} +func (c *uint32Converter) GoValueOf(v pref.Value) reflect.Value { + return reflect.ValueOf(uint32(v.Uint())).Convert(c.goType) +} +func (c *uint32Converter) New() pref.Value { return c.def } +func (c *uint32Converter) Zero() pref.Value { return c.def } + +type uint64Converter struct { + goType reflect.Type + def pref.Value +} + +func (c *uint64Converter) PBValueOf(v reflect.Value) pref.Value { + if v.Type() != c.goType { + panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType)) + } + return pref.ValueOfUint64(uint64(v.Uint())) +} +func (c *uint64Converter) GoValueOf(v pref.Value) reflect.Value { + return reflect.ValueOf(uint64(v.Uint())).Convert(c.goType) +} +func (c *uint64Converter) New() pref.Value { return c.def } +func (c *uint64Converter) Zero() pref.Value { return c.def } + +type float32Converter struct { + goType reflect.Type + def pref.Value +} + +func (c *float32Converter) PBValueOf(v reflect.Value) pref.Value { + if v.Type() != c.goType { + panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType)) + } + return pref.ValueOfFloat32(float32(v.Float())) +} +func (c *float32Converter) GoValueOf(v pref.Value) reflect.Value { + return reflect.ValueOf(float32(v.Float())).Convert(c.goType) +} +func (c *float32Converter) New() pref.Value { return c.def } +func (c *float32Converter) Zero() pref.Value { return c.def } + +type float64Converter struct { + goType reflect.Type + def pref.Value +} + +func (c *float64Converter) PBValueOf(v reflect.Value) pref.Value { + if v.Type() != c.goType { + panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType)) + } + return pref.ValueOfFloat64(float64(v.Float())) +} +func (c *float64Converter) GoValueOf(v pref.Value) reflect.Value { + return reflect.ValueOf(float64(v.Float())).Convert(c.goType) +} +func (c *float64Converter) New() pref.Value { return c.def } +func (c *float64Converter) Zero() pref.Value { return c.def } + +type stringConverter struct { + goType reflect.Type + def pref.Value +} + +func (c *stringConverter) PBValueOf(v reflect.Value) pref.Value { + if v.Type() != c.goType { + panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType)) + } + return pref.ValueOfString(v.Convert(stringType).String()) +} +func (c *stringConverter) GoValueOf(v pref.Value) reflect.Value { + // pref.Value.String never panics, so we go through an interface + // conversion here to check the type. + s := v.Interface().(string) + if c.goType.Kind() == reflect.Slice && s == "" { return reflect.Zero(c.goType) // ensure empty string is []byte(nil) } - return rv.Convert(c.goType) + return reflect.ValueOf(s).Convert(c.goType) +} +func (c *stringConverter) New() pref.Value { return c.def } +func (c *stringConverter) Zero() pref.Value { return c.def } + +type bytesConverter struct { + goType reflect.Type + def pref.Value } -func (c *scalarConverter) New() pref.Value { - if c.pbType == bytesType { - return pref.ValueOf(append(([]byte)(nil), c.def.Bytes()...)) +func (c *bytesConverter) PBValueOf(v reflect.Value) pref.Value { + if v.Type() != c.goType { + panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType)) } - return c.def + if c.goType.Kind() == reflect.String && v.Len() == 0 { + return pref.ValueOfBytes(nil) // ensure empty string is []byte(nil) + } + return pref.ValueOfBytes(v.Convert(bytesType).Bytes()) } - -func (c *scalarConverter) Zero() pref.Value { - return c.New() +func (c *bytesConverter) GoValueOf(v pref.Value) reflect.Value { + return reflect.ValueOf(v.Bytes()).Convert(c.goType) } +func (c *bytesConverter) New() pref.Value { return c.def } +func (c *bytesConverter) Zero() pref.Value { return c.def } type enumConverter struct { goType reflect.Type diff --git a/proto/decode_gen.go b/proto/decode_gen.go index dbb4c877..ce53dcd2 100644 --- a/proto/decode_gen.go +++ b/proto/decode_gen.go @@ -29,7 +29,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, fd protorefl if n < 0 { return val, 0, wire.ParseError(n) } - return protoreflect.ValueOf(wire.DecodeBool(v)), n, nil + return protoreflect.ValueOfBool(wire.DecodeBool(v)), n, nil case protoreflect.EnumKind: if wtyp != wire.VarintType { return val, 0, errUnknown @@ -38,7 +38,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, fd protorefl if n < 0 { return val, 0, wire.ParseError(n) } - return protoreflect.ValueOf(protoreflect.EnumNumber(v)), n, nil + return protoreflect.ValueOfEnum(protoreflect.EnumNumber(v)), n, nil case protoreflect.Int32Kind: if wtyp != wire.VarintType { return val, 0, errUnknown @@ -47,7 +47,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, fd protorefl if n < 0 { return val, 0, wire.ParseError(n) } - return protoreflect.ValueOf(int32(v)), n, nil + return protoreflect.ValueOfInt32(int32(v)), n, nil case protoreflect.Sint32Kind: if wtyp != wire.VarintType { return val, 0, errUnknown @@ -56,7 +56,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, fd protorefl if n < 0 { return val, 0, wire.ParseError(n) } - return protoreflect.ValueOf(int32(wire.DecodeZigZag(v & math.MaxUint32))), n, nil + return protoreflect.ValueOfInt32(int32(wire.DecodeZigZag(v & math.MaxUint32))), n, nil case protoreflect.Uint32Kind: if wtyp != wire.VarintType { return val, 0, errUnknown @@ -65,7 +65,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, fd protorefl if n < 0 { return val, 0, wire.ParseError(n) } - return protoreflect.ValueOf(uint32(v)), n, nil + return protoreflect.ValueOfUint32(uint32(v)), n, nil case protoreflect.Int64Kind: if wtyp != wire.VarintType { return val, 0, errUnknown @@ -74,7 +74,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, fd protorefl if n < 0 { return val, 0, wire.ParseError(n) } - return protoreflect.ValueOf(int64(v)), n, nil + return protoreflect.ValueOfInt64(int64(v)), n, nil case protoreflect.Sint64Kind: if wtyp != wire.VarintType { return val, 0, errUnknown @@ -83,7 +83,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, fd protorefl if n < 0 { return val, 0, wire.ParseError(n) } - return protoreflect.ValueOf(wire.DecodeZigZag(v)), n, nil + return protoreflect.ValueOfInt64(wire.DecodeZigZag(v)), n, nil case protoreflect.Uint64Kind: if wtyp != wire.VarintType { return val, 0, errUnknown @@ -92,7 +92,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, fd protorefl if n < 0 { return val, 0, wire.ParseError(n) } - return protoreflect.ValueOf(v), n, nil + return protoreflect.ValueOfUint64(v), n, nil case protoreflect.Sfixed32Kind: if wtyp != wire.Fixed32Type { return val, 0, errUnknown @@ -101,7 +101,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, fd protorefl if n < 0 { return val, 0, wire.ParseError(n) } - return protoreflect.ValueOf(int32(v)), n, nil + return protoreflect.ValueOfInt32(int32(v)), n, nil case protoreflect.Fixed32Kind: if wtyp != wire.Fixed32Type { return val, 0, errUnknown @@ -110,7 +110,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, fd protorefl if n < 0 { return val, 0, wire.ParseError(n) } - return protoreflect.ValueOf(uint32(v)), n, nil + return protoreflect.ValueOfUint32(uint32(v)), n, nil case protoreflect.FloatKind: if wtyp != wire.Fixed32Type { return val, 0, errUnknown @@ -119,7 +119,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, fd protorefl if n < 0 { return val, 0, wire.ParseError(n) } - return protoreflect.ValueOf(math.Float32frombits(uint32(v))), n, nil + return protoreflect.ValueOfFloat32(math.Float32frombits(uint32(v))), n, nil case protoreflect.Sfixed64Kind: if wtyp != wire.Fixed64Type { return val, 0, errUnknown @@ -128,7 +128,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, fd protorefl if n < 0 { return val, 0, wire.ParseError(n) } - return protoreflect.ValueOf(int64(v)), n, nil + return protoreflect.ValueOfInt64(int64(v)), n, nil case protoreflect.Fixed64Kind: if wtyp != wire.Fixed64Type { return val, 0, errUnknown @@ -137,7 +137,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, fd protorefl if n < 0 { return val, 0, wire.ParseError(n) } - return protoreflect.ValueOf(v), n, nil + return protoreflect.ValueOfUint64(v), n, nil case protoreflect.DoubleKind: if wtyp != wire.Fixed64Type { return val, 0, errUnknown @@ -146,7 +146,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, fd protorefl if n < 0 { return val, 0, wire.ParseError(n) } - return protoreflect.ValueOf(math.Float64frombits(v)), n, nil + return protoreflect.ValueOfFloat64(math.Float64frombits(v)), n, nil case protoreflect.StringKind: if wtyp != wire.BytesType { return val, 0, errUnknown @@ -158,7 +158,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, fd protorefl if strs.EnforceUTF8(fd) && !utf8.Valid(v) { return protoreflect.Value{}, 0, errors.InvalidUTF8(string(fd.FullName())) } - return protoreflect.ValueOf(string(v)), n, nil + return protoreflect.ValueOfString(string(v)), n, nil case protoreflect.BytesKind: if wtyp != wire.BytesType { return val, 0, errUnknown @@ -167,7 +167,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, fd protorefl if n < 0 { return val, 0, wire.ParseError(n) } - return protoreflect.ValueOf(append(([]byte)(nil), v...)), n, nil + return protoreflect.ValueOfBytes(append(([]byte)(nil), v...)), n, nil case protoreflect.MessageKind: if wtyp != wire.BytesType { return val, 0, errUnknown @@ -176,7 +176,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, fd protorefl if n < 0 { return val, 0, wire.ParseError(n) } - return protoreflect.ValueOf(v), n, nil + return protoreflect.ValueOfBytes(v), n, nil case protoreflect.GroupKind: if wtyp != wire.StartGroupType { return val, 0, errUnknown @@ -185,7 +185,7 @@ func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, fd protorefl if n < 0 { return val, 0, wire.ParseError(n) } - return protoreflect.ValueOf(v), n, nil + return protoreflect.ValueOfBytes(v), n, nil default: return val, 0, errUnknown } @@ -205,7 +205,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl return 0, wire.ParseError(n) } buf = buf[n:] - list.Append(protoreflect.ValueOf(wire.DecodeBool(v))) + list.Append(protoreflect.ValueOfBool(wire.DecodeBool(v))) } return n, nil } @@ -216,7 +216,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl if n < 0 { return 0, wire.ParseError(n) } - list.Append(protoreflect.ValueOf(wire.DecodeBool(v))) + list.Append(protoreflect.ValueOfBool(wire.DecodeBool(v))) return n, nil case protoreflect.EnumKind: if wtyp == wire.BytesType { @@ -230,7 +230,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl return 0, wire.ParseError(n) } buf = buf[n:] - list.Append(protoreflect.ValueOf(protoreflect.EnumNumber(v))) + list.Append(protoreflect.ValueOfEnum(protoreflect.EnumNumber(v))) } return n, nil } @@ -241,7 +241,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl if n < 0 { return 0, wire.ParseError(n) } - list.Append(protoreflect.ValueOf(protoreflect.EnumNumber(v))) + list.Append(protoreflect.ValueOfEnum(protoreflect.EnumNumber(v))) return n, nil case protoreflect.Int32Kind: if wtyp == wire.BytesType { @@ -255,7 +255,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl return 0, wire.ParseError(n) } buf = buf[n:] - list.Append(protoreflect.ValueOf(int32(v))) + list.Append(protoreflect.ValueOfInt32(int32(v))) } return n, nil } @@ -266,7 +266,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl if n < 0 { return 0, wire.ParseError(n) } - list.Append(protoreflect.ValueOf(int32(v))) + list.Append(protoreflect.ValueOfInt32(int32(v))) return n, nil case protoreflect.Sint32Kind: if wtyp == wire.BytesType { @@ -280,7 +280,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl return 0, wire.ParseError(n) } buf = buf[n:] - list.Append(protoreflect.ValueOf(int32(wire.DecodeZigZag(v & math.MaxUint32)))) + list.Append(protoreflect.ValueOfInt32(int32(wire.DecodeZigZag(v & math.MaxUint32)))) } return n, nil } @@ -291,7 +291,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl if n < 0 { return 0, wire.ParseError(n) } - list.Append(protoreflect.ValueOf(int32(wire.DecodeZigZag(v & math.MaxUint32)))) + list.Append(protoreflect.ValueOfInt32(int32(wire.DecodeZigZag(v & math.MaxUint32)))) return n, nil case protoreflect.Uint32Kind: if wtyp == wire.BytesType { @@ -305,7 +305,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl return 0, wire.ParseError(n) } buf = buf[n:] - list.Append(protoreflect.ValueOf(uint32(v))) + list.Append(protoreflect.ValueOfUint32(uint32(v))) } return n, nil } @@ -316,7 +316,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl if n < 0 { return 0, wire.ParseError(n) } - list.Append(protoreflect.ValueOf(uint32(v))) + list.Append(protoreflect.ValueOfUint32(uint32(v))) return n, nil case protoreflect.Int64Kind: if wtyp == wire.BytesType { @@ -330,7 +330,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl return 0, wire.ParseError(n) } buf = buf[n:] - list.Append(protoreflect.ValueOf(int64(v))) + list.Append(protoreflect.ValueOfInt64(int64(v))) } return n, nil } @@ -341,7 +341,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl if n < 0 { return 0, wire.ParseError(n) } - list.Append(protoreflect.ValueOf(int64(v))) + list.Append(protoreflect.ValueOfInt64(int64(v))) return n, nil case protoreflect.Sint64Kind: if wtyp == wire.BytesType { @@ -355,7 +355,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl return 0, wire.ParseError(n) } buf = buf[n:] - list.Append(protoreflect.ValueOf(wire.DecodeZigZag(v))) + list.Append(protoreflect.ValueOfInt64(wire.DecodeZigZag(v))) } return n, nil } @@ -366,7 +366,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl if n < 0 { return 0, wire.ParseError(n) } - list.Append(protoreflect.ValueOf(wire.DecodeZigZag(v))) + list.Append(protoreflect.ValueOfInt64(wire.DecodeZigZag(v))) return n, nil case protoreflect.Uint64Kind: if wtyp == wire.BytesType { @@ -380,7 +380,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl return 0, wire.ParseError(n) } buf = buf[n:] - list.Append(protoreflect.ValueOf(v)) + list.Append(protoreflect.ValueOfUint64(v)) } return n, nil } @@ -391,7 +391,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl if n < 0 { return 0, wire.ParseError(n) } - list.Append(protoreflect.ValueOf(v)) + list.Append(protoreflect.ValueOfUint64(v)) return n, nil case protoreflect.Sfixed32Kind: if wtyp == wire.BytesType { @@ -405,7 +405,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl return 0, wire.ParseError(n) } buf = buf[n:] - list.Append(protoreflect.ValueOf(int32(v))) + list.Append(protoreflect.ValueOfInt32(int32(v))) } return n, nil } @@ -416,7 +416,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl if n < 0 { return 0, wire.ParseError(n) } - list.Append(protoreflect.ValueOf(int32(v))) + list.Append(protoreflect.ValueOfInt32(int32(v))) return n, nil case protoreflect.Fixed32Kind: if wtyp == wire.BytesType { @@ -430,7 +430,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl return 0, wire.ParseError(n) } buf = buf[n:] - list.Append(protoreflect.ValueOf(uint32(v))) + list.Append(protoreflect.ValueOfUint32(uint32(v))) } return n, nil } @@ -441,7 +441,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl if n < 0 { return 0, wire.ParseError(n) } - list.Append(protoreflect.ValueOf(uint32(v))) + list.Append(protoreflect.ValueOfUint32(uint32(v))) return n, nil case protoreflect.FloatKind: if wtyp == wire.BytesType { @@ -455,7 +455,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl return 0, wire.ParseError(n) } buf = buf[n:] - list.Append(protoreflect.ValueOf(math.Float32frombits(uint32(v)))) + list.Append(protoreflect.ValueOfFloat32(math.Float32frombits(uint32(v)))) } return n, nil } @@ -466,7 +466,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl if n < 0 { return 0, wire.ParseError(n) } - list.Append(protoreflect.ValueOf(math.Float32frombits(uint32(v)))) + list.Append(protoreflect.ValueOfFloat32(math.Float32frombits(uint32(v)))) return n, nil case protoreflect.Sfixed64Kind: if wtyp == wire.BytesType { @@ -480,7 +480,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl return 0, wire.ParseError(n) } buf = buf[n:] - list.Append(protoreflect.ValueOf(int64(v))) + list.Append(protoreflect.ValueOfInt64(int64(v))) } return n, nil } @@ -491,7 +491,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl if n < 0 { return 0, wire.ParseError(n) } - list.Append(protoreflect.ValueOf(int64(v))) + list.Append(protoreflect.ValueOfInt64(int64(v))) return n, nil case protoreflect.Fixed64Kind: if wtyp == wire.BytesType { @@ -505,7 +505,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl return 0, wire.ParseError(n) } buf = buf[n:] - list.Append(protoreflect.ValueOf(v)) + list.Append(protoreflect.ValueOfUint64(v)) } return n, nil } @@ -516,7 +516,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl if n < 0 { return 0, wire.ParseError(n) } - list.Append(protoreflect.ValueOf(v)) + list.Append(protoreflect.ValueOfUint64(v)) return n, nil case protoreflect.DoubleKind: if wtyp == wire.BytesType { @@ -530,7 +530,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl return 0, wire.ParseError(n) } buf = buf[n:] - list.Append(protoreflect.ValueOf(math.Float64frombits(v))) + list.Append(protoreflect.ValueOfFloat64(math.Float64frombits(v))) } return n, nil } @@ -541,7 +541,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl if n < 0 { return 0, wire.ParseError(n) } - list.Append(protoreflect.ValueOf(math.Float64frombits(v))) + list.Append(protoreflect.ValueOfFloat64(math.Float64frombits(v))) return n, nil case protoreflect.StringKind: if wtyp != wire.BytesType { @@ -554,7 +554,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl if strs.EnforceUTF8(fd) && !utf8.Valid(v) { return 0, errors.InvalidUTF8(string(fd.FullName())) } - list.Append(protoreflect.ValueOf(string(v))) + list.Append(protoreflect.ValueOfString(string(v))) return n, nil case protoreflect.BytesKind: if wtyp != wire.BytesType { @@ -564,7 +564,7 @@ func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protorefl if n < 0 { return 0, wire.ParseError(n) } - list.Append(protoreflect.ValueOf(append(([]byte)(nil), v...))) + list.Append(protoreflect.ValueOfBytes(append(([]byte)(nil), v...))) return n, nil case protoreflect.MessageKind: if wtyp != wire.BytesType { diff --git a/reflect/protoreflect/value_union.go b/reflect/protoreflect/value_union.go index 0223d026..d695a4d6 100644 --- a/reflect/protoreflect/value_union.go +++ b/reflect/protoreflect/value_union.go @@ -65,38 +65,101 @@ func ValueOf(v interface{}) Value { case nil: return Value{} case bool: - if v { - return Value{typ: boolType, num: 1} - } else { - return Value{typ: boolType, num: 0} - } + return ValueOfBool(v) case int32: - return Value{typ: int32Type, num: uint64(v)} + return ValueOfInt32(v) case int64: - return Value{typ: int64Type, num: uint64(v)} + return ValueOfInt64(v) case uint32: - return Value{typ: uint32Type, num: uint64(v)} + return ValueOfUint32(v) case uint64: - return Value{typ: uint64Type, num: uint64(v)} + return ValueOfUint64(v) case float32: - return Value{typ: float32Type, num: uint64(math.Float64bits(float64(v)))} + return ValueOfFloat32(v) case float64: - return Value{typ: float64Type, num: uint64(math.Float64bits(float64(v)))} + return ValueOfFloat64(v) case string: - return valueOfString(v) + return ValueOfString(v) case []byte: - return valueOfBytes(v[:len(v):len(v)]) + return ValueOfBytes(v) case EnumNumber: - return Value{typ: enumType, num: uint64(v)} + return ValueOfEnum(v) case Message, List, Map: return valueOfIface(v) default: - // TODO: Special case Enum, ProtoMessage, *[]T, and *map[K]V? - // Note: this would violate the documented invariant in Interface. panic(fmt.Sprintf("invalid type: %v", reflect.TypeOf(v))) } } +// ValueOfBool returns a new boolean value. +func ValueOfBool(v bool) Value { + if v { + return Value{typ: boolType, num: 1} + } else { + return Value{typ: boolType, num: 0} + } +} + +// ValueOfInt32 returns a new int32 value. +func ValueOfInt32(v int32) Value { + return Value{typ: int32Type, num: uint64(v)} +} + +// ValueOfInt64 returns a new int64 value. +func ValueOfInt64(v int64) Value { + return Value{typ: int64Type, num: uint64(v)} +} + +// ValueOfUint32 returns a new uint32 value. +func ValueOfUint32(v uint32) Value { + return Value{typ: uint32Type, num: uint64(v)} +} + +// ValueOfUint64 returns a new uint64 value. +func ValueOfUint64(v uint64) Value { + return Value{typ: uint64Type, num: v} +} + +// ValueOfFloat32 returns a new float32 value. +func ValueOfFloat32(v float32) Value { + return Value{typ: float32Type, num: uint64(math.Float64bits(float64(v)))} +} + +// ValueOfFloat64 returns a new float64 value. +func ValueOfFloat64(v float64) Value { + return Value{typ: float64Type, num: uint64(math.Float64bits(float64(v)))} +} + +// ValueOfString returns a new string value. +func ValueOfString(v string) Value { + return valueOfString(v) +} + +// ValueOfBytes returns a new bytes value. +func ValueOfBytes(v []byte) Value { + return valueOfBytes(v[:len(v):len(v)]) +} + +// ValueOfEnum returns a new enum value. +func ValueOfEnum(v EnumNumber) Value { + return Value{typ: enumType, num: uint64(v)} +} + +// ValueOfMessage returns a new Message value. +func ValueOfMessage(v Message) Value { + return valueOfIface(v) +} + +// ValueOfList returns a new List value. +func ValueOfList(v List) Value { + return valueOfIface(v) +} + +// ValueOfMap returns a new Map value. +func ValueOfMap(v Map) Value { + return valueOfIface(v) +} + // IsValid reports whether v is populated with a value. func (v Value) IsValid() bool { return v.typ != nilType