mirror of
https://github.com/protocolbuffers/protobuf-go.git
synced 2025-02-05 15:40:09 +00:00
reflect/protoreflect: add ExtensionType IsValid{Interface,Value} methods
Add a way to typecheck a Value or interface{} without converting it to the other form. This permits implementations which store field values as a Value (such as dynamicpb, or (soon) extensions in generated messages) to validate inputs without an unnecessary conversion. Fixes golang/protobuf#905 Change-Id: I1b78612b22ae832efbb55f81ae420871729e3a02 Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/192457 Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
This commit is contained in:
parent
50f860a45a
commit
835b271169
@ -25,6 +25,12 @@ type Converter interface {
|
||||
// GoValueOf converts a protoreflect.Value to a reflect.Value.
|
||||
GoValueOf(pref.Value) reflect.Value
|
||||
|
||||
// IsValidPB returns whether a protoreflect.Value is compatible with this type.
|
||||
IsValidPB(pref.Value) bool
|
||||
|
||||
// IsValidGo returns whether a reflect.Value is compatible with this type.
|
||||
IsValidGo(reflect.Value) bool
|
||||
|
||||
// New returns a new field value.
|
||||
// For scalars, it returns the default value of the field.
|
||||
// For composite types, it returns a new mutable value.
|
||||
@ -151,6 +157,13 @@ func (c *boolConverter) PBValueOf(v reflect.Value) pref.Value {
|
||||
func (c *boolConverter) GoValueOf(v pref.Value) reflect.Value {
|
||||
return reflect.ValueOf(v.Bool()).Convert(c.goType)
|
||||
}
|
||||
func (c *boolConverter) IsValidPB(v pref.Value) bool {
|
||||
_, ok := v.Interface().(bool)
|
||||
return ok
|
||||
}
|
||||
func (c *boolConverter) IsValidGo(v reflect.Value) bool {
|
||||
return v.Type() == c.goType
|
||||
}
|
||||
func (c *boolConverter) New() pref.Value { return c.def }
|
||||
func (c *boolConverter) Zero() pref.Value { return c.def }
|
||||
|
||||
@ -168,6 +181,13 @@ func (c *int32Converter) PBValueOf(v reflect.Value) pref.Value {
|
||||
func (c *int32Converter) GoValueOf(v pref.Value) reflect.Value {
|
||||
return reflect.ValueOf(int32(v.Int())).Convert(c.goType)
|
||||
}
|
||||
func (c *int32Converter) IsValidPB(v pref.Value) bool {
|
||||
_, ok := v.Interface().(int32)
|
||||
return ok
|
||||
}
|
||||
func (c *int32Converter) IsValidGo(v reflect.Value) bool {
|
||||
return v.Type() == c.goType
|
||||
}
|
||||
func (c *int32Converter) New() pref.Value { return c.def }
|
||||
func (c *int32Converter) Zero() pref.Value { return c.def }
|
||||
|
||||
@ -185,6 +205,13 @@ func (c *int64Converter) PBValueOf(v reflect.Value) pref.Value {
|
||||
func (c *int64Converter) GoValueOf(v pref.Value) reflect.Value {
|
||||
return reflect.ValueOf(int64(v.Int())).Convert(c.goType)
|
||||
}
|
||||
func (c *int64Converter) IsValidPB(v pref.Value) bool {
|
||||
_, ok := v.Interface().(int64)
|
||||
return ok
|
||||
}
|
||||
func (c *int64Converter) IsValidGo(v reflect.Value) bool {
|
||||
return v.Type() == c.goType
|
||||
}
|
||||
func (c *int64Converter) New() pref.Value { return c.def }
|
||||
func (c *int64Converter) Zero() pref.Value { return c.def }
|
||||
|
||||
@ -202,6 +229,13 @@ func (c *uint32Converter) PBValueOf(v reflect.Value) pref.Value {
|
||||
func (c *uint32Converter) GoValueOf(v pref.Value) reflect.Value {
|
||||
return reflect.ValueOf(uint32(v.Uint())).Convert(c.goType)
|
||||
}
|
||||
func (c *uint32Converter) IsValidPB(v pref.Value) bool {
|
||||
_, ok := v.Interface().(uint32)
|
||||
return ok
|
||||
}
|
||||
func (c *uint32Converter) IsValidGo(v reflect.Value) bool {
|
||||
return v.Type() == c.goType
|
||||
}
|
||||
func (c *uint32Converter) New() pref.Value { return c.def }
|
||||
func (c *uint32Converter) Zero() pref.Value { return c.def }
|
||||
|
||||
@ -219,6 +253,13 @@ func (c *uint64Converter) PBValueOf(v reflect.Value) pref.Value {
|
||||
func (c *uint64Converter) GoValueOf(v pref.Value) reflect.Value {
|
||||
return reflect.ValueOf(uint64(v.Uint())).Convert(c.goType)
|
||||
}
|
||||
func (c *uint64Converter) IsValidPB(v pref.Value) bool {
|
||||
_, ok := v.Interface().(uint64)
|
||||
return ok
|
||||
}
|
||||
func (c *uint64Converter) IsValidGo(v reflect.Value) bool {
|
||||
return v.Type() == c.goType
|
||||
}
|
||||
func (c *uint64Converter) New() pref.Value { return c.def }
|
||||
func (c *uint64Converter) Zero() pref.Value { return c.def }
|
||||
|
||||
@ -236,6 +277,13 @@ func (c *float32Converter) PBValueOf(v reflect.Value) pref.Value {
|
||||
func (c *float32Converter) GoValueOf(v pref.Value) reflect.Value {
|
||||
return reflect.ValueOf(float32(v.Float())).Convert(c.goType)
|
||||
}
|
||||
func (c *float32Converter) IsValidPB(v pref.Value) bool {
|
||||
_, ok := v.Interface().(float32)
|
||||
return ok
|
||||
}
|
||||
func (c *float32Converter) IsValidGo(v reflect.Value) bool {
|
||||
return v.Type() == c.goType
|
||||
}
|
||||
func (c *float32Converter) New() pref.Value { return c.def }
|
||||
func (c *float32Converter) Zero() pref.Value { return c.def }
|
||||
|
||||
@ -253,6 +301,13 @@ func (c *float64Converter) PBValueOf(v reflect.Value) pref.Value {
|
||||
func (c *float64Converter) GoValueOf(v pref.Value) reflect.Value {
|
||||
return reflect.ValueOf(float64(v.Float())).Convert(c.goType)
|
||||
}
|
||||
func (c *float64Converter) IsValidPB(v pref.Value) bool {
|
||||
_, ok := v.Interface().(float64)
|
||||
return ok
|
||||
}
|
||||
func (c *float64Converter) IsValidGo(v reflect.Value) bool {
|
||||
return v.Type() == c.goType
|
||||
}
|
||||
func (c *float64Converter) New() pref.Value { return c.def }
|
||||
func (c *float64Converter) Zero() pref.Value { return c.def }
|
||||
|
||||
@ -276,6 +331,13 @@ func (c *stringConverter) GoValueOf(v pref.Value) reflect.Value {
|
||||
}
|
||||
return reflect.ValueOf(s).Convert(c.goType)
|
||||
}
|
||||
func (c *stringConverter) IsValidPB(v pref.Value) bool {
|
||||
_, ok := v.Interface().(string)
|
||||
return ok
|
||||
}
|
||||
func (c *stringConverter) IsValidGo(v reflect.Value) bool {
|
||||
return v.Type() == c.goType
|
||||
}
|
||||
func (c *stringConverter) New() pref.Value { return c.def }
|
||||
func (c *stringConverter) Zero() pref.Value { return c.def }
|
||||
|
||||
@ -296,6 +358,13 @@ func (c *bytesConverter) PBValueOf(v reflect.Value) pref.Value {
|
||||
func (c *bytesConverter) GoValueOf(v pref.Value) reflect.Value {
|
||||
return reflect.ValueOf(v.Bytes()).Convert(c.goType)
|
||||
}
|
||||
func (c *bytesConverter) IsValidPB(v pref.Value) bool {
|
||||
_, ok := v.Interface().([]byte)
|
||||
return ok
|
||||
}
|
||||
func (c *bytesConverter) IsValidGo(v reflect.Value) bool {
|
||||
return v.Type() == c.goType
|
||||
}
|
||||
func (c *bytesConverter) New() pref.Value { return c.def }
|
||||
func (c *bytesConverter) Zero() pref.Value { return c.def }
|
||||
|
||||
@ -325,6 +394,15 @@ func (c *enumConverter) GoValueOf(v pref.Value) reflect.Value {
|
||||
return reflect.ValueOf(v.Enum()).Convert(c.goType)
|
||||
}
|
||||
|
||||
func (c *enumConverter) IsValidPB(v pref.Value) bool {
|
||||
_, ok := v.Interface().(pref.EnumNumber)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (c *enumConverter) IsValidGo(v reflect.Value) bool {
|
||||
return v.Type() == c.goType
|
||||
}
|
||||
|
||||
func (c *enumConverter) New() pref.Value {
|
||||
return c.def
|
||||
}
|
||||
@ -365,6 +443,21 @@ func (c *messageConverter) GoValueOf(v pref.Value) reflect.Value {
|
||||
return rv
|
||||
}
|
||||
|
||||
func (c *messageConverter) IsValidPB(v pref.Value) bool {
|
||||
m := v.Message()
|
||||
var rv reflect.Value
|
||||
if u, ok := m.(Unwrapper); ok {
|
||||
rv = reflect.ValueOf(u.ProtoUnwrap())
|
||||
} else {
|
||||
rv = reflect.ValueOf(m.Interface())
|
||||
}
|
||||
return rv.Type() == c.goType
|
||||
}
|
||||
|
||||
func (c *messageConverter) IsValidGo(v reflect.Value) bool {
|
||||
return v.Type() == c.goType
|
||||
}
|
||||
|
||||
func (c *messageConverter) New() pref.Value {
|
||||
return c.PBValueOf(reflect.New(c.goType.Elem()))
|
||||
}
|
||||
|
@ -34,6 +34,18 @@ func (c *listConverter) GoValueOf(v pref.Value) reflect.Value {
|
||||
return v.List().(*listReflect).v
|
||||
}
|
||||
|
||||
func (c *listConverter) IsValidPB(v pref.Value) bool {
|
||||
list, ok := v.Interface().(*listReflect)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return list.v.Type() == c.goType
|
||||
}
|
||||
|
||||
func (c *listConverter) IsValidGo(v reflect.Value) bool {
|
||||
return v.Type() == c.goType
|
||||
}
|
||||
|
||||
func (c *listConverter) New() pref.Value {
|
||||
return c.PBValueOf(reflect.New(c.goType.Elem()))
|
||||
}
|
||||
|
@ -38,6 +38,18 @@ func (c *mapConverter) GoValueOf(v pref.Value) reflect.Value {
|
||||
return v.Map().(*mapReflect).v
|
||||
}
|
||||
|
||||
func (c *mapConverter) IsValidPB(v pref.Value) bool {
|
||||
mapv, ok := v.Interface().(*mapReflect)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return mapv.v.Type() == c.goType
|
||||
}
|
||||
|
||||
func (c *mapConverter) IsValidGo(v reflect.Value) bool {
|
||||
return v.Type() == c.goType
|
||||
}
|
||||
|
||||
func (c *mapConverter) New() pref.Value {
|
||||
return c.PBValueOf(reflect.MakeMap(c.goType))
|
||||
}
|
||||
|
@ -114,6 +114,12 @@ func (xi *ExtensionInfo) ValueOf(v interface{}) pref.Value {
|
||||
func (xi *ExtensionInfo) InterfaceOf(v pref.Value) interface{} {
|
||||
return xi.lazyInit().GoValueOf(v).Interface()
|
||||
}
|
||||
func (xi *ExtensionInfo) IsValidValue(v pref.Value) bool {
|
||||
return xi.lazyInit().IsValidPB(v)
|
||||
}
|
||||
func (xi *ExtensionInfo) IsValidInterface(v interface{}) bool {
|
||||
return xi.lazyInit().IsValidGo(reflect.ValueOf(v))
|
||||
}
|
||||
func (xi *ExtensionInfo) GoType() reflect.Type {
|
||||
xi.lazyInit()
|
||||
return xi.goType
|
||||
|
130
internal/impl/extension_test.go
Normal file
130
internal/impl/extension_test.go
Normal file
@ -0,0 +1,130 @@
|
||||
// Copyright 2019 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package impl_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
cmp "github.com/google/go-cmp/cmp"
|
||||
testpb "google.golang.org/protobuf/internal/testprotos/test"
|
||||
pref "google.golang.org/protobuf/reflect/protoreflect"
|
||||
)
|
||||
|
||||
func TestExtensionType(t *testing.T) {
|
||||
cmpOpts := cmp.Options{
|
||||
cmp.Comparer(func(x, y proto.Message) bool {
|
||||
return proto.Equal(x, y)
|
||||
}),
|
||||
}
|
||||
for _, test := range []struct {
|
||||
xt pref.ExtensionType
|
||||
value interface{}
|
||||
}{
|
||||
{
|
||||
xt: testpb.E_OptionalInt32Extension,
|
||||
value: int32(0),
|
||||
},
|
||||
{
|
||||
xt: testpb.E_OptionalInt64Extension,
|
||||
value: int64(0),
|
||||
},
|
||||
{
|
||||
xt: testpb.E_OptionalUint32Extension,
|
||||
value: uint32(0),
|
||||
},
|
||||
{
|
||||
xt: testpb.E_OptionalUint64Extension,
|
||||
value: uint64(0),
|
||||
},
|
||||
{
|
||||
xt: testpb.E_OptionalFloatExtension,
|
||||
value: float32(0),
|
||||
},
|
||||
{
|
||||
xt: testpb.E_OptionalDoubleExtension,
|
||||
value: float64(0),
|
||||
},
|
||||
{
|
||||
xt: testpb.E_OptionalBoolExtension,
|
||||
value: true,
|
||||
},
|
||||
{
|
||||
xt: testpb.E_OptionalStringExtension,
|
||||
value: "",
|
||||
},
|
||||
{
|
||||
xt: testpb.E_OptionalBytesExtension,
|
||||
value: []byte{},
|
||||
},
|
||||
{
|
||||
xt: testpb.E_OptionalNestedMessageExtension,
|
||||
value: &testpb.TestAllTypes_NestedMessage{},
|
||||
},
|
||||
{
|
||||
xt: testpb.E_OptionalNestedEnumExtension,
|
||||
value: testpb.TestAllTypes_FOO,
|
||||
},
|
||||
{
|
||||
xt: testpb.E_RepeatedInt32Extension,
|
||||
value: []int32{0},
|
||||
},
|
||||
{
|
||||
xt: testpb.E_RepeatedInt64Extension,
|
||||
value: []int64{0},
|
||||
},
|
||||
{
|
||||
xt: testpb.E_RepeatedUint32Extension,
|
||||
value: []uint32{0},
|
||||
},
|
||||
{
|
||||
xt: testpb.E_RepeatedUint64Extension,
|
||||
value: []uint64{0},
|
||||
},
|
||||
{
|
||||
xt: testpb.E_RepeatedFloatExtension,
|
||||
value: []float32{0},
|
||||
},
|
||||
{
|
||||
xt: testpb.E_RepeatedDoubleExtension,
|
||||
value: []float64{0},
|
||||
},
|
||||
{
|
||||
xt: testpb.E_RepeatedBoolExtension,
|
||||
value: []bool{true},
|
||||
},
|
||||
{
|
||||
xt: testpb.E_RepeatedStringExtension,
|
||||
value: []string{""},
|
||||
},
|
||||
{
|
||||
xt: testpb.E_RepeatedBytesExtension,
|
||||
value: [][]byte{nil},
|
||||
},
|
||||
{
|
||||
xt: testpb.E_RepeatedNestedMessageExtension,
|
||||
value: []*testpb.TestAllTypes_NestedMessage{{}},
|
||||
},
|
||||
{
|
||||
xt: testpb.E_RepeatedNestedEnumExtension,
|
||||
value: []testpb.TestAllTypes_NestedEnum{testpb.TestAllTypes_FOO},
|
||||
},
|
||||
} {
|
||||
name := test.xt.TypeDescriptor().FullName()
|
||||
t.Run(fmt.Sprint(name), func(t *testing.T) {
|
||||
if !test.xt.IsValidInterface(test.value) {
|
||||
t.Fatalf("IsValidInterface(%[1]T(%[1]v)) = false, want true", test.value)
|
||||
}
|
||||
v := test.xt.ValueOf(test.value)
|
||||
if !test.xt.IsValidValue(v) {
|
||||
t.Fatalf("IsValidValue(%[1]T(%[1]v)) = false, want true", v)
|
||||
}
|
||||
if got, want := test.xt.InterfaceOf(v), test.value; !cmp.Equal(got, want, cmpOpts) {
|
||||
t.Fatalf("round trip InterfaceOf(ValueOf(x)) = %v, want %v", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -484,6 +484,12 @@ type ExtensionType interface {
|
||||
// InterfaceOf is able to unwrap the Value further than Value.Interface
|
||||
// as it has more type information available.
|
||||
InterfaceOf(Value) interface{}
|
||||
|
||||
// IsValidValue returns whether the Value is valid to assign to the field.
|
||||
IsValidValue(Value) bool
|
||||
|
||||
// IsValidInterface returns whether the input is valid to assign to the field.
|
||||
IsValidInterface(interface{}) bool
|
||||
}
|
||||
|
||||
// EnumDescriptor describes an enum and
|
||||
|
@ -166,8 +166,9 @@ func (m *Message) Set(fd pref.FieldDescriptor, v pref.Value) {
|
||||
m.checkField(fd)
|
||||
switch {
|
||||
case fd.IsExtension():
|
||||
// Call InterfaceOf just to let the extension typecheck the value.
|
||||
_ = fd.(pref.ExtensionTypeDescriptor).Type().InterfaceOf(v)
|
||||
if !fd.(pref.ExtensionTypeDescriptor).Type().IsValidValue(v) {
|
||||
panic(errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface()))
|
||||
}
|
||||
m.ext[fd.Number()] = fd
|
||||
case fd.IsMap():
|
||||
if mapv, ok := v.Interface().(*dynamicMap); !ok || mapv.desc != fd {
|
||||
|
Loading…
x
Reference in New Issue
Block a user