all: add NewField, NewElement, NewValue

Add methods to protoreflect.{Message,List,Map} to constrict values
assignable to a message field, list element, or map value. These
methods return the default value for scalar fields, the zero value for
scalar list elements and map values, and an empty, mutable value for
messages, lists, and maps.

Deprecate the NewMessage methods on these types, which are superseded.

Updates 

Change-Id: I0f064f60c89a239330ccea81523f559f14fd2c4f
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/188997
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
This commit is contained in:
Damien Neil 2019-08-05 10:48:38 -07:00
parent c36f3ae703
commit f5274511fe
10 changed files with 235 additions and 44 deletions
internal
reflect/protoreflect
testing/prototest
types/dynamicpb

@ -689,11 +689,14 @@ func (m *{{.}}) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Value {
}
}
func (m *{{.}}) NewMessage(fd protoreflect.FieldDescriptor) protoreflect.Message {
return m.NewField(fd).Message()
}
func (m *{{.}}) NewField(fd protoreflect.FieldDescriptor) protoreflect.Value {
m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
return fi.newMessage()
return fi.newField()
} else {
return xt.New().Message()
return xt.New()
}
}
func (m *{{.}}) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor {

@ -57,48 +57,67 @@ var (
byteType = reflect.TypeOf(byte(0))
)
var (
boolZero = pref.ValueOf(bool(false))
int32Zero = pref.ValueOf(int32(0))
int64Zero = pref.ValueOf(int64(0))
uint32Zero = pref.ValueOf(uint32(0))
uint64Zero = pref.ValueOf(uint64(0))
float32Zero = pref.ValueOf(float32(0))
float64Zero = pref.ValueOf(float64(0))
stringZero = pref.ValueOf(string(""))
bytesZero = pref.ValueOf([]byte(nil))
)
type scalarConverter struct {
goType, pbType reflect.Type
def pref.Value
}
func newSingularConverter(t reflect.Type, fd pref.FieldDescriptor) Converter {
defVal := func(fd pref.FieldDescriptor, zero pref.Value) pref.Value {
if fd.Cardinality() == pref.Repeated {
// Default isn't defined for repeated fields.
return zero
}
return fd.Default()
}
switch fd.Kind() {
case pref.BoolKind:
if t.Kind() == reflect.Bool {
return &scalarConverter{t, boolType, fd.Default()}
return &scalarConverter{t, boolType, defVal(fd, boolZero)}
}
case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
if t.Kind() == reflect.Int32 {
return &scalarConverter{t, int32Type, fd.Default()}
return &scalarConverter{t, int32Type, defVal(fd, int32Zero)}
}
case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
if t.Kind() == reflect.Int64 {
return &scalarConverter{t, int64Type, fd.Default()}
return &scalarConverter{t, int64Type, defVal(fd, int64Zero)}
}
case pref.Uint32Kind, pref.Fixed32Kind:
if t.Kind() == reflect.Uint32 {
return &scalarConverter{t, uint32Type, fd.Default()}
return &scalarConverter{t, uint32Type, defVal(fd, uint32Zero)}
}
case pref.Uint64Kind, pref.Fixed64Kind:
if t.Kind() == reflect.Uint64 {
return &scalarConverter{t, uint64Type, fd.Default()}
return &scalarConverter{t, uint64Type, defVal(fd, uint64Zero)}
}
case pref.FloatKind:
if t.Kind() == reflect.Float32 {
return &scalarConverter{t, float32Type, fd.Default()}
return &scalarConverter{t, float32Type, defVal(fd, float32Zero)}
}
case pref.DoubleKind:
if t.Kind() == reflect.Float64 {
return &scalarConverter{t, float64Type, fd.Default()}
return &scalarConverter{t, float64Type, defVal(fd, float64Zero)}
}
case pref.StringKind:
if t.Kind() == reflect.String || (t.Kind() == reflect.Slice && t.Elem() == byteType) {
return &scalarConverter{t, stringType, fd.Default()}
return &scalarConverter{t, stringType, defVal(fd, stringZero)}
}
case pref.BytesKind:
if t.Kind() == reflect.String || (t.Kind() == reflect.Slice && t.Elem() == byteType) {
return &scalarConverter{t, bytesType, fd.Default()}
return &scalarConverter{t, bytesType, defVal(fd, bytesZero)}
}
case pref.EnumKind:
// Handle enums, which must be a named int32 type.
@ -133,6 +152,9 @@ func (c *scalarConverter) GoValueOf(v pref.Value) reflect.Value {
}
func (c *scalarConverter) New() pref.Value {
if c.pbType == bytesType {
return pref.ValueOf(append(([]byte)(nil), c.def.Bytes()...))
}
return c.def
}
@ -142,7 +164,13 @@ type enumConverter struct {
}
func newEnumConverter(goType reflect.Type, fd pref.FieldDescriptor) Converter {
return &enumConverter{goType, fd.Default()}
var def pref.Value
if fd.Cardinality() == pref.Repeated {
def = pref.ValueOf(fd.Enum().Values().Get(0).Number())
} else {
def = fd.Default()
}
return &enumConverter{goType, def}
}
func (c *enumConverter) PBValueOf(v reflect.Value) pref.Value {

@ -62,7 +62,10 @@ func (ls *listReflect) Truncate(i int) {
ls.v.Elem().Set(ls.v.Elem().Slice(0, i))
}
func (ls *listReflect) NewMessage() pref.Message {
return ls.conv.New().Message()
return ls.NewElement().Message()
}
func (ls *listReflect) NewElement() pref.Value {
return ls.conv.New()
}
func (ls *listReflect) ProtoUnwrap() interface{} {
return ls.v.Interface()

@ -85,7 +85,10 @@ func (ms *mapReflect) Range(f func(pref.MapKey, pref.Value) bool) {
}
}
func (ms *mapReflect) NewMessage() pref.Message {
return ms.valConv.New().Message()
return ms.NewValue().Message()
}
func (ms *mapReflect) NewValue() pref.Value {
return ms.valConv.New()
}
func (ms *mapReflect) ProtoUnwrap() interface{} {
return ms.v.Interface()

@ -26,6 +26,7 @@ type fieldInfo struct {
set func(pointer, pref.Value)
mutable func(pointer) pref.Value
newMessage func() pref.Message
newField func() pref.Value
}
func fieldInfoForOneof(fd pref.FieldDescriptor, fs reflect.StructField, x exporter, ot reflect.Type) fieldInfo {
@ -113,6 +114,9 @@ func fieldInfoForOneof(fd pref.FieldDescriptor, fs reflect.StructField, x export
newMessage: func() pref.Message {
return conv.New().Message()
},
newField: func() pref.Value {
return conv.New()
},
}
}
@ -160,6 +164,9 @@ func fieldInfoForMap(fd pref.FieldDescriptor, fs reflect.StructField, x exporter
}
return conv.PBValueOf(v)
},
newField: func() pref.Value {
return conv.New()
},
}
}
@ -204,6 +211,9 @@ func fieldInfoForList(fd pref.FieldDescriptor, fs reflect.StructField, x exporte
v := p.Apply(fieldOffset).AsValueOf(fs.Type)
return conv.PBValueOf(v)
},
newField: func() pref.Value {
return conv.New()
},
}
}
@ -289,6 +299,9 @@ func fieldInfoForScalar(fd pref.FieldDescriptor, fs reflect.StructField, x expor
}
}
},
newField: func() pref.Value {
return conv.New()
},
}
}
@ -367,6 +380,10 @@ func fieldInfoForWeakMessage(fd pref.FieldDescriptor, weakOffset offset) fieldIn
lazyInit()
return messageType.New()
},
newField: func() pref.Value {
lazyInit()
return pref.ValueOf(messageType.New())
},
}
}
@ -417,6 +434,9 @@ func fieldInfoForMessage(fd pref.FieldDescriptor, fs reflect.StructField, x expo
newMessage: func() pref.Message {
return conv.New().Message()
},
newField: func() pref.Value {
return conv.New()
},
}
}

@ -92,11 +92,14 @@ func (m *messageState) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Val
}
}
func (m *messageState) NewMessage(fd protoreflect.FieldDescriptor) protoreflect.Message {
return m.NewField(fd).Message()
}
func (m *messageState) NewField(fd protoreflect.FieldDescriptor) protoreflect.Value {
m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
return fi.newMessage()
return fi.newField()
} else {
return xt.New().Message()
return xt.New()
}
}
func (m *messageState) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor {
@ -199,11 +202,14 @@ func (m *messageReflectWrapper) Mutable(fd protoreflect.FieldDescriptor) protore
}
}
func (m *messageReflectWrapper) NewMessage(fd protoreflect.FieldDescriptor) protoreflect.Message {
return m.NewField(fd).Message()
}
func (m *messageReflectWrapper) NewField(fd protoreflect.FieldDescriptor) protoreflect.Value {
m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
return fi.newMessage()
return fi.newField()
} else {
return xt.New().Message()
return xt.New()
}
}
func (m *messageReflectWrapper) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor {

@ -82,6 +82,10 @@ func (m *message) NewMessage(pref.FieldDescriptor) pref.Message {
panic("invalid field descriptor")
}
func (m *message) NewField(pref.FieldDescriptor) pref.Value {
panic("invalid field descriptor")
}
func (m *message) WhichOneof(pref.OneofDescriptor) pref.FieldDescriptor {
panic("invalid oneof descriptor")
}

@ -119,8 +119,15 @@ type Message interface {
// NewMessage returns a newly allocated empty message assignable to
// the field of the given descriptor.
// It panics if the field is not a singular message.
//
// Deprecated: Use NewField instead.
NewMessage(FieldDescriptor) Message
// NewField returns a new value for assignable to the field of a given descriptor.
// For scalars, this returns the default value.
// For lists, maps, and messages, this returns a new, empty, mutable value.
NewField(FieldDescriptor) Value
// WhichOneof reports which field within the oneof is populated,
// returning nil if none are populated.
// It panics if the oneof descriptor does not belong to this message.
@ -197,7 +204,15 @@ type List interface {
// NewMessage returns a newly allocated empty message assignable as a list entry.
// It panics if the list entry type is not a message.
//
// Deprecated: Use NewElement instead.
NewMessage() Message
// NewElement returns a new value for a list element.
// For enums, this returns the first enum value.
// For other scalars, this returns the zero value.
// For messages, this returns a new, empty, mutable value.
NewElement() Value
}
// Map is an unordered, associative map.
@ -240,5 +255,13 @@ type Map interface {
// NewMessage returns a newly allocated empty message assignable as a map value.
// It panics if the map value type is not a message.
//
// Deprecated: Use NewValue instead.
NewMessage() Message
// NewValue returns a new value assignable as a map value.
// For enums, this returns the first enum value.
// For other scalars, this returns the zero value.
// For messages, this returns a new, empty, mutable value.
NewValue() Value
}

@ -85,8 +85,14 @@ func testField(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
testFieldList(t, m, fd)
case fd.IsMap():
testFieldMap(t, m, fd)
case fd.Kind() == pref.FloatKind || fd.Kind() == pref.DoubleKind:
testFieldFloat(t, m, fd)
case fd.Message() != nil:
default:
if got, want := m.NewField(fd), fd.Default(); !valueEqual(got, want) {
t.Errorf("Message.NewField(%v) = %v, want default value %v", name, formatValue(got), formatValue(want))
}
if fd.Kind() == pref.FloatKind || fd.Kind() == pref.DoubleKind {
testFieldFloat(t, m, fd)
}
}
// Set to a non-zero value, the zero value, different non-zero values.
@ -108,6 +114,9 @@ func testField(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
wantHas = true
}
}
if fd.Syntax() == pref.Proto3 && fd.Cardinality() != pref.Repeated && fd.ContainingOneof() == nil && fd.Kind() == pref.EnumKind && v.Enum() == 0 {
wantHas = false
}
if got, want := m.Has(fd), wantHas; got != want {
t.Errorf("after setting %q to %v:\nMessage.Has(%v) = %v, want %v", name, formatValue(v), num, got, want)
}
@ -166,8 +175,16 @@ func testFieldMap(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
name := fd.FullName()
num := fd.Number()
// New values.
m.Clear(fd) // start with an empty map
mapv := m.Mutable(fd).Map()
mapv := m.Get(fd).Map()
if got, want := mapv.NewValue(), newMapValue(fd, mapv, 0, nil); !valueEqual(got, want) {
t.Errorf("message.Get(%v).NewValue() = %v, want %v", name, formatValue(got), formatValue(want))
}
mapv = m.Mutable(fd).Map() // mutable map
if got, want := mapv.NewValue(), newMapValue(fd, mapv, 0, nil); !valueEqual(got, want) {
t.Errorf("message.Mutable(%v).NewValue() = %v, want %v", name, formatValue(got), formatValue(want))
}
// Add values.
want := make(testMap)
@ -228,6 +245,7 @@ func (m testMap) Has(k pref.MapKey) bool { return m.Get(k).IsValid() }
func (m testMap) Clear(k pref.MapKey) { delete(m, k.Interface()) }
func (m testMap) Len() int { return len(m) }
func (m testMap) NewMessage() pref.Message { panic("unimplemented") }
func (m testMap) NewValue() pref.Value { panic("unimplemented") }
func (m testMap) Range(f func(pref.MapKey, pref.Value) bool) {
for k, v := range m {
if !f(pref.ValueOf(k).MapKey(), v) {
@ -242,7 +260,14 @@ func testFieldList(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
num := fd.Number()
m.Clear(fd) // start with an empty list
list := m.Mutable(fd).List()
list := m.Get(fd).List()
if got, want := list.NewElement(), newListElement(fd, list, 0, nil); !valueEqual(got, want) {
t.Errorf("message.Get(%v).NewElement() = %v, want %v", name, formatValue(got), formatValue(want))
}
list = m.Mutable(fd).List() // mutable list
if got, want := list.NewElement(), newListElement(fd, list, 0, nil); !valueEqual(got, want) {
t.Errorf("message.Mutable(%v).NewElement() = %v, want %v", name, formatValue(got), formatValue(want))
}
// Append values.
var want pref.List = &testList{}
@ -293,6 +318,7 @@ func (l *testList) Len() int { return len(l.a) }
func (l *testList) Set(n int, v pref.Value) { l.a[n] = v }
func (l *testList) Truncate(n int) { l.a = l.a[:n] }
func (l *testList) NewMessage() pref.Message { panic("unimplemented") }
func (l *testList) NewElement() pref.Value { panic("unimplemented") }
// testFieldFloat exercises some interesting floating-point scalar field values.
func testFieldFloat(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
@ -462,6 +488,19 @@ const (
maxVal seed = -2
)
// newSeed creates new seed values from a base, for example to create seeds for the
// elements in a list. If the input seed is minVal or maxVal, so is the output.
func newSeed(n seed, adjust ...int) seed {
switch n {
case minVal, maxVal:
return n
}
for _, a := range adjust {
n = 10*n + seed(a)
}
return n
}
// newValue returns a new value assignable to a field.
//
// The stack parameter is used to avoid infinite recursion when populating circular
@ -469,7 +508,7 @@ const (
func newValue(m pref.Message, fd pref.FieldDescriptor, n seed, stack []pref.MessageDescriptor) pref.Value {
switch {
case fd.IsList():
list := m.New().Mutable(fd).List()
list := m.NewField(fd).List()
if n == 0 {
return pref.ValueOf(list)
}
@ -479,17 +518,17 @@ func newValue(m pref.Message, fd pref.FieldDescriptor, n seed, stack []pref.Mess
list.Append(newListElement(fd, list, n, stack))
return pref.ValueOf(list)
case fd.IsMap():
mapv := m.New().Mutable(fd).Map()
mapv := m.NewField(fd).Map()
if n == 0 {
return pref.ValueOf(mapv)
}
mapv.Set(newMapKey(fd, 0), newMapValue(fd, mapv, 0, stack))
mapv.Set(newMapKey(fd, minVal), newMapValue(fd, mapv, minVal, stack))
mapv.Set(newMapKey(fd, maxVal), newMapValue(fd, mapv, maxVal, stack))
mapv.Set(newMapKey(fd, n), newMapValue(fd, mapv, 10*n, stack))
mapv.Set(newMapKey(fd, n), newMapValue(fd, mapv, newSeed(n, 0), stack))
return pref.ValueOf(mapv)
case fd.Message() != nil:
return populateMessage(m.Mutable(fd).Message(), n, stack)
return populateMessage(m.NewField(fd).Message(), n, stack)
default:
return newScalarValue(fd, n)
}
@ -499,7 +538,7 @@ func newListElement(fd pref.FieldDescriptor, list pref.List, n seed, stack []pre
if fd.Message() == nil {
return newScalarValue(fd, n)
}
return populateMessage(list.NewMessage(), n, stack)
return populateMessage(list.NewElement().Message(), n, stack)
}
func newMapKey(fd pref.FieldDescriptor, n seed) pref.MapKey {
@ -512,7 +551,7 @@ func newMapValue(fd pref.FieldDescriptor, mapv pref.Map, n seed, stack []pref.Me
if vd.Message() == nil {
return newScalarValue(vd, n)
}
return populateMessage(mapv.NewMessage(), n, stack)
return populateMessage(mapv.NewValue().Message(), n, stack)
}
func newScalarValue(fd pref.FieldDescriptor, n seed) pref.Value {
@ -520,8 +559,17 @@ func newScalarValue(fd pref.FieldDescriptor, n seed) pref.Value {
case pref.BoolKind:
return pref.ValueOf(n != 0)
case pref.EnumKind:
// TODO: use actual value
return pref.ValueOf(pref.EnumNumber(n))
vals := fd.Enum().Values()
var i int
switch n {
case minVal:
i = 0
case maxVal:
i = vals.Len() - 1
default:
i = int(n) % vals.Len()
}
return pref.ValueOf(vals.Get(i).Number())
case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
switch n {
case minVal:
@ -608,7 +656,7 @@ func populateMessage(m pref.Message, n seed, stack []pref.MessageDescriptor) pre
if fd.IsWeak() {
continue
}
m.Set(fd, newValue(m, fd, 10*n+seed(i), stack))
m.Set(fd, newValue(m, fd, newSeed(n, i), stack))
}
return pref.ValueOf(m)
}

@ -23,7 +23,7 @@ import (
// package documentation for that interface for how to get and set fields and
// otherwise interact with the contents of a Message.
//
// Reflection API functions which construct messages, such as NewMessage,
// Reflection API functions which construct messages, such as NewField,
// return new dynamic messages of the appropriate type. Functions which take
// messages, such as Set for a message-value field, will accept any message
// with a compatible type.
@ -156,19 +156,9 @@ func (m *Message) Mutable(fd pref.FieldDescriptor) pref.Value {
panic(errors.New("%v: getting mutable reference to non-composite type", fd.FullName()))
}
m.clearOtherOneofFields(fd)
switch {
case fd.IsExtension():
m.known[num] = fd.(pref.ExtensionType).New()
m.known[num] = m.NewField(fd)
if fd.IsExtension() {
m.ext[num] = fd
case fd.IsMap():
m.known[num] = pref.ValueOf(&dynamicMap{
desc: fd,
mapv: make(map[interface{}]pref.Value),
})
case fd.IsList():
m.known[num] = pref.ValueOf(&dynamicList{desc: fd})
case fd.Message() != nil:
m.known[num] = pref.ValueOf(m.NewMessage(fd))
}
return m.known[num]
}
@ -221,6 +211,27 @@ func (m *Message) NewMessage(fd pref.FieldDescriptor) pref.Message {
return New(md).ProtoReflect()
}
// NewField returns a new value for assignable to the field of a given descriptor.
// See protoreflect.Message for details.
func (m *Message) NewField(fd pref.FieldDescriptor) pref.Value {
m.checkField(fd)
switch {
case fd.IsExtension():
return fd.(pref.ExtensionType).New()
case fd.IsMap():
return pref.ValueOf(&dynamicMap{
desc: fd,
mapv: make(map[interface{}]pref.Value),
})
case fd.IsList():
return pref.ValueOf(&dynamicList{desc: fd})
case fd.Message() != nil:
return pref.ValueOf(New(fd.Message()).ProtoReflect())
default:
return fd.Default()
}
}
// WhichOneof reports which field in a oneof is populated, returning nil if none are populated.
// See protoreflect.Message for details.
func (m *Message) WhichOneof(od pref.OneofDescriptor) pref.FieldDescriptor {
@ -278,6 +289,9 @@ func (x emptyList) NewMessage() pref.Message {
}
return New(md).ProtoReflect()
}
func (x emptyList) NewElement() pref.Value {
return newListEntry(x.desc)
}
type dynamicList struct {
desc pref.FieldDescriptor
@ -318,6 +332,10 @@ func (x *dynamicList) NewMessage() pref.Message {
return New(md).ProtoReflect()
}
func (x *dynamicList) NewElement() pref.Value {
return newListEntry(x.desc)
}
type dynamicMap struct {
desc pref.FieldDescriptor
mapv map[interface{}]pref.Value
@ -339,6 +357,13 @@ func (x *dynamicMap) NewMessage() pref.Message {
}
return New(md).ProtoReflect()
}
func (x *dynamicMap) NewValue() pref.Value {
if md := x.desc.MapValue().Message(); md != nil {
return pref.ValueOf(New(md).ProtoReflect())
}
return x.desc.MapValue().Default()
}
func (x *dynamicMap) Range(f func(pref.MapKey, pref.Value) bool) {
for k, v := range x.mapv {
if !f(pref.ValueOf(k).MapKey(), v) {
@ -412,3 +437,31 @@ func typecheckSingular(fd pref.FieldDescriptor, v pref.Value) {
panic(errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface()))
}
}
func newListEntry(fd pref.FieldDescriptor) pref.Value {
switch fd.Kind() {
case pref.BoolKind:
return pref.ValueOf(false)
case pref.EnumKind:
return pref.ValueOf(fd.Enum().Values().Get(0).Number())
case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
return pref.ValueOf(int32(0))
case pref.Uint32Kind, pref.Fixed32Kind:
return pref.ValueOf(uint32(0))
case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
return pref.ValueOf(int64(0))
case pref.Uint64Kind, pref.Fixed64Kind:
return pref.ValueOf(uint64(0))
case pref.FloatKind:
return pref.ValueOf(float32(0))
case pref.DoubleKind:
return pref.ValueOf(float64(0))
case pref.StringKind:
return pref.ValueOf("")
case pref.BytesKind:
return pref.ValueOf(([]byte)(nil))
case pref.MessageKind, pref.GroupKind:
return pref.ValueOf(New(fd.Message()).ProtoReflect())
}
panic(errors.New("%v: unknown kind %v", fd.FullName(), fd.Kind()))
}