all: add NewField, NewElement, NewValue

Add methods to protoreflect.{Message,List,Map} to constrict values
assignable to a message field, list element, or map value. These
methods return the default value for scalar fields, the zero value for
scalar list elements and map values, and an empty, mutable value for
messages, lists, and maps.

Deprecate the NewMessage methods on these types, which are superseded.

Updates golang/protobuf#879

Change-Id: I0f064f60c89a239330ccea81523f559f14fd2c4f
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/188997
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
This commit is contained in:
Damien Neil 2019-08-05 10:48:38 -07:00
parent c36f3ae703
commit f5274511fe
10 changed files with 235 additions and 44 deletions

View File

@ -689,11 +689,14 @@ func (m *{{.}}) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Value {
} }
} }
func (m *{{.}}) NewMessage(fd protoreflect.FieldDescriptor) protoreflect.Message { func (m *{{.}}) NewMessage(fd protoreflect.FieldDescriptor) protoreflect.Message {
return m.NewField(fd).Message()
}
func (m *{{.}}) NewField(fd protoreflect.FieldDescriptor) protoreflect.Value {
m.messageInfo().init() m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil { if fi, xt := m.messageInfo().checkField(fd); fi != nil {
return fi.newMessage() return fi.newField()
} else { } else {
return xt.New().Message() return xt.New()
} }
} }
func (m *{{.}}) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor { func (m *{{.}}) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor {

View File

@ -57,48 +57,67 @@ var (
byteType = reflect.TypeOf(byte(0)) byteType = reflect.TypeOf(byte(0))
) )
var (
boolZero = pref.ValueOf(bool(false))
int32Zero = pref.ValueOf(int32(0))
int64Zero = pref.ValueOf(int64(0))
uint32Zero = pref.ValueOf(uint32(0))
uint64Zero = pref.ValueOf(uint64(0))
float32Zero = pref.ValueOf(float32(0))
float64Zero = pref.ValueOf(float64(0))
stringZero = pref.ValueOf(string(""))
bytesZero = pref.ValueOf([]byte(nil))
)
type scalarConverter struct { type scalarConverter struct {
goType, pbType reflect.Type goType, pbType reflect.Type
def pref.Value def pref.Value
} }
func newSingularConverter(t reflect.Type, fd pref.FieldDescriptor) Converter { func newSingularConverter(t reflect.Type, fd pref.FieldDescriptor) Converter {
defVal := func(fd pref.FieldDescriptor, zero pref.Value) pref.Value {
if fd.Cardinality() == pref.Repeated {
// Default isn't defined for repeated fields.
return zero
}
return fd.Default()
}
switch fd.Kind() { switch fd.Kind() {
case pref.BoolKind: case pref.BoolKind:
if t.Kind() == reflect.Bool { if t.Kind() == reflect.Bool {
return &scalarConverter{t, boolType, fd.Default()} return &scalarConverter{t, boolType, defVal(fd, boolZero)}
} }
case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind: case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
if t.Kind() == reflect.Int32 { if t.Kind() == reflect.Int32 {
return &scalarConverter{t, int32Type, fd.Default()} return &scalarConverter{t, int32Type, defVal(fd, int32Zero)}
} }
case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind: case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
if t.Kind() == reflect.Int64 { if t.Kind() == reflect.Int64 {
return &scalarConverter{t, int64Type, fd.Default()} return &scalarConverter{t, int64Type, defVal(fd, int64Zero)}
} }
case pref.Uint32Kind, pref.Fixed32Kind: case pref.Uint32Kind, pref.Fixed32Kind:
if t.Kind() == reflect.Uint32 { if t.Kind() == reflect.Uint32 {
return &scalarConverter{t, uint32Type, fd.Default()} return &scalarConverter{t, uint32Type, defVal(fd, uint32Zero)}
} }
case pref.Uint64Kind, pref.Fixed64Kind: case pref.Uint64Kind, pref.Fixed64Kind:
if t.Kind() == reflect.Uint64 { if t.Kind() == reflect.Uint64 {
return &scalarConverter{t, uint64Type, fd.Default()} return &scalarConverter{t, uint64Type, defVal(fd, uint64Zero)}
} }
case pref.FloatKind: case pref.FloatKind:
if t.Kind() == reflect.Float32 { if t.Kind() == reflect.Float32 {
return &scalarConverter{t, float32Type, fd.Default()} return &scalarConverter{t, float32Type, defVal(fd, float32Zero)}
} }
case pref.DoubleKind: case pref.DoubleKind:
if t.Kind() == reflect.Float64 { if t.Kind() == reflect.Float64 {
return &scalarConverter{t, float64Type, fd.Default()} return &scalarConverter{t, float64Type, defVal(fd, float64Zero)}
} }
case pref.StringKind: case pref.StringKind:
if t.Kind() == reflect.String || (t.Kind() == reflect.Slice && t.Elem() == byteType) { if t.Kind() == reflect.String || (t.Kind() == reflect.Slice && t.Elem() == byteType) {
return &scalarConverter{t, stringType, fd.Default()} return &scalarConverter{t, stringType, defVal(fd, stringZero)}
} }
case pref.BytesKind: case pref.BytesKind:
if t.Kind() == reflect.String || (t.Kind() == reflect.Slice && t.Elem() == byteType) { if t.Kind() == reflect.String || (t.Kind() == reflect.Slice && t.Elem() == byteType) {
return &scalarConverter{t, bytesType, fd.Default()} return &scalarConverter{t, bytesType, defVal(fd, bytesZero)}
} }
case pref.EnumKind: case pref.EnumKind:
// Handle enums, which must be a named int32 type. // Handle enums, which must be a named int32 type.
@ -133,6 +152,9 @@ func (c *scalarConverter) GoValueOf(v pref.Value) reflect.Value {
} }
func (c *scalarConverter) New() pref.Value { func (c *scalarConverter) New() pref.Value {
if c.pbType == bytesType {
return pref.ValueOf(append(([]byte)(nil), c.def.Bytes()...))
}
return c.def return c.def
} }
@ -142,7 +164,13 @@ type enumConverter struct {
} }
func newEnumConverter(goType reflect.Type, fd pref.FieldDescriptor) Converter { func newEnumConverter(goType reflect.Type, fd pref.FieldDescriptor) Converter {
return &enumConverter{goType, fd.Default()} var def pref.Value
if fd.Cardinality() == pref.Repeated {
def = pref.ValueOf(fd.Enum().Values().Get(0).Number())
} else {
def = fd.Default()
}
return &enumConverter{goType, def}
} }
func (c *enumConverter) PBValueOf(v reflect.Value) pref.Value { func (c *enumConverter) PBValueOf(v reflect.Value) pref.Value {

View File

@ -62,7 +62,10 @@ func (ls *listReflect) Truncate(i int) {
ls.v.Elem().Set(ls.v.Elem().Slice(0, i)) ls.v.Elem().Set(ls.v.Elem().Slice(0, i))
} }
func (ls *listReflect) NewMessage() pref.Message { func (ls *listReflect) NewMessage() pref.Message {
return ls.conv.New().Message() return ls.NewElement().Message()
}
func (ls *listReflect) NewElement() pref.Value {
return ls.conv.New()
} }
func (ls *listReflect) ProtoUnwrap() interface{} { func (ls *listReflect) ProtoUnwrap() interface{} {
return ls.v.Interface() return ls.v.Interface()

View File

@ -85,7 +85,10 @@ func (ms *mapReflect) Range(f func(pref.MapKey, pref.Value) bool) {
} }
} }
func (ms *mapReflect) NewMessage() pref.Message { func (ms *mapReflect) NewMessage() pref.Message {
return ms.valConv.New().Message() return ms.NewValue().Message()
}
func (ms *mapReflect) NewValue() pref.Value {
return ms.valConv.New()
} }
func (ms *mapReflect) ProtoUnwrap() interface{} { func (ms *mapReflect) ProtoUnwrap() interface{} {
return ms.v.Interface() return ms.v.Interface()

View File

@ -26,6 +26,7 @@ type fieldInfo struct {
set func(pointer, pref.Value) set func(pointer, pref.Value)
mutable func(pointer) pref.Value mutable func(pointer) pref.Value
newMessage func() pref.Message newMessage func() pref.Message
newField func() pref.Value
} }
func fieldInfoForOneof(fd pref.FieldDescriptor, fs reflect.StructField, x exporter, ot reflect.Type) fieldInfo { func fieldInfoForOneof(fd pref.FieldDescriptor, fs reflect.StructField, x exporter, ot reflect.Type) fieldInfo {
@ -113,6 +114,9 @@ func fieldInfoForOneof(fd pref.FieldDescriptor, fs reflect.StructField, x export
newMessage: func() pref.Message { newMessage: func() pref.Message {
return conv.New().Message() return conv.New().Message()
}, },
newField: func() pref.Value {
return conv.New()
},
} }
} }
@ -160,6 +164,9 @@ func fieldInfoForMap(fd pref.FieldDescriptor, fs reflect.StructField, x exporter
} }
return conv.PBValueOf(v) return conv.PBValueOf(v)
}, },
newField: func() pref.Value {
return conv.New()
},
} }
} }
@ -204,6 +211,9 @@ func fieldInfoForList(fd pref.FieldDescriptor, fs reflect.StructField, x exporte
v := p.Apply(fieldOffset).AsValueOf(fs.Type) v := p.Apply(fieldOffset).AsValueOf(fs.Type)
return conv.PBValueOf(v) return conv.PBValueOf(v)
}, },
newField: func() pref.Value {
return conv.New()
},
} }
} }
@ -289,6 +299,9 @@ func fieldInfoForScalar(fd pref.FieldDescriptor, fs reflect.StructField, x expor
} }
} }
}, },
newField: func() pref.Value {
return conv.New()
},
} }
} }
@ -367,6 +380,10 @@ func fieldInfoForWeakMessage(fd pref.FieldDescriptor, weakOffset offset) fieldIn
lazyInit() lazyInit()
return messageType.New() return messageType.New()
}, },
newField: func() pref.Value {
lazyInit()
return pref.ValueOf(messageType.New())
},
} }
} }
@ -417,6 +434,9 @@ func fieldInfoForMessage(fd pref.FieldDescriptor, fs reflect.StructField, x expo
newMessage: func() pref.Message { newMessage: func() pref.Message {
return conv.New().Message() return conv.New().Message()
}, },
newField: func() pref.Value {
return conv.New()
},
} }
} }

View File

@ -92,11 +92,14 @@ func (m *messageState) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Val
} }
} }
func (m *messageState) NewMessage(fd protoreflect.FieldDescriptor) protoreflect.Message { func (m *messageState) NewMessage(fd protoreflect.FieldDescriptor) protoreflect.Message {
return m.NewField(fd).Message()
}
func (m *messageState) NewField(fd protoreflect.FieldDescriptor) protoreflect.Value {
m.messageInfo().init() m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil { if fi, xt := m.messageInfo().checkField(fd); fi != nil {
return fi.newMessage() return fi.newField()
} else { } else {
return xt.New().Message() return xt.New()
} }
} }
func (m *messageState) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor { func (m *messageState) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor {
@ -199,11 +202,14 @@ func (m *messageReflectWrapper) Mutable(fd protoreflect.FieldDescriptor) protore
} }
} }
func (m *messageReflectWrapper) NewMessage(fd protoreflect.FieldDescriptor) protoreflect.Message { func (m *messageReflectWrapper) NewMessage(fd protoreflect.FieldDescriptor) protoreflect.Message {
return m.NewField(fd).Message()
}
func (m *messageReflectWrapper) NewField(fd protoreflect.FieldDescriptor) protoreflect.Value {
m.messageInfo().init() m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil { if fi, xt := m.messageInfo().checkField(fd); fi != nil {
return fi.newMessage() return fi.newField()
} else { } else {
return xt.New().Message() return xt.New()
} }
} }
func (m *messageReflectWrapper) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor { func (m *messageReflectWrapper) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor {

View File

@ -82,6 +82,10 @@ func (m *message) NewMessage(pref.FieldDescriptor) pref.Message {
panic("invalid field descriptor") panic("invalid field descriptor")
} }
func (m *message) NewField(pref.FieldDescriptor) pref.Value {
panic("invalid field descriptor")
}
func (m *message) WhichOneof(pref.OneofDescriptor) pref.FieldDescriptor { func (m *message) WhichOneof(pref.OneofDescriptor) pref.FieldDescriptor {
panic("invalid oneof descriptor") panic("invalid oneof descriptor")
} }

View File

@ -119,8 +119,15 @@ type Message interface {
// NewMessage returns a newly allocated empty message assignable to // NewMessage returns a newly allocated empty message assignable to
// the field of the given descriptor. // the field of the given descriptor.
// It panics if the field is not a singular message. // It panics if the field is not a singular message.
//
// Deprecated: Use NewField instead.
NewMessage(FieldDescriptor) Message NewMessage(FieldDescriptor) Message
// NewField returns a new value for assignable to the field of a given descriptor.
// For scalars, this returns the default value.
// For lists, maps, and messages, this returns a new, empty, mutable value.
NewField(FieldDescriptor) Value
// WhichOneof reports which field within the oneof is populated, // WhichOneof reports which field within the oneof is populated,
// returning nil if none are populated. // returning nil if none are populated.
// It panics if the oneof descriptor does not belong to this message. // It panics if the oneof descriptor does not belong to this message.
@ -197,7 +204,15 @@ type List interface {
// NewMessage returns a newly allocated empty message assignable as a list entry. // NewMessage returns a newly allocated empty message assignable as a list entry.
// It panics if the list entry type is not a message. // It panics if the list entry type is not a message.
//
// Deprecated: Use NewElement instead.
NewMessage() Message NewMessage() Message
// NewElement returns a new value for a list element.
// For enums, this returns the first enum value.
// For other scalars, this returns the zero value.
// For messages, this returns a new, empty, mutable value.
NewElement() Value
} }
// Map is an unordered, associative map. // Map is an unordered, associative map.
@ -240,5 +255,13 @@ type Map interface {
// NewMessage returns a newly allocated empty message assignable as a map value. // NewMessage returns a newly allocated empty message assignable as a map value.
// It panics if the map value type is not a message. // It panics if the map value type is not a message.
//
// Deprecated: Use NewValue instead.
NewMessage() Message NewMessage() Message
// NewValue returns a new value assignable as a map value.
// For enums, this returns the first enum value.
// For other scalars, this returns the zero value.
// For messages, this returns a new, empty, mutable value.
NewValue() Value
} }

View File

@ -85,9 +85,15 @@ func testField(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
testFieldList(t, m, fd) testFieldList(t, m, fd)
case fd.IsMap(): case fd.IsMap():
testFieldMap(t, m, fd) testFieldMap(t, m, fd)
case fd.Kind() == pref.FloatKind || fd.Kind() == pref.DoubleKind: case fd.Message() != nil:
default:
if got, want := m.NewField(fd), fd.Default(); !valueEqual(got, want) {
t.Errorf("Message.NewField(%v) = %v, want default value %v", name, formatValue(got), formatValue(want))
}
if fd.Kind() == pref.FloatKind || fd.Kind() == pref.DoubleKind {
testFieldFloat(t, m, fd) testFieldFloat(t, m, fd)
} }
}
// Set to a non-zero value, the zero value, different non-zero values. // Set to a non-zero value, the zero value, different non-zero values.
for _, n := range []seed{1, 0, minVal, maxVal} { for _, n := range []seed{1, 0, minVal, maxVal} {
@ -108,6 +114,9 @@ func testField(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
wantHas = true wantHas = true
} }
} }
if fd.Syntax() == pref.Proto3 && fd.Cardinality() != pref.Repeated && fd.ContainingOneof() == nil && fd.Kind() == pref.EnumKind && v.Enum() == 0 {
wantHas = false
}
if got, want := m.Has(fd), wantHas; got != want { if got, want := m.Has(fd), wantHas; got != want {
t.Errorf("after setting %q to %v:\nMessage.Has(%v) = %v, want %v", name, formatValue(v), num, got, want) t.Errorf("after setting %q to %v:\nMessage.Has(%v) = %v, want %v", name, formatValue(v), num, got, want)
} }
@ -166,8 +175,16 @@ func testFieldMap(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
name := fd.FullName() name := fd.FullName()
num := fd.Number() num := fd.Number()
// New values.
m.Clear(fd) // start with an empty map m.Clear(fd) // start with an empty map
mapv := m.Mutable(fd).Map() mapv := m.Get(fd).Map()
if got, want := mapv.NewValue(), newMapValue(fd, mapv, 0, nil); !valueEqual(got, want) {
t.Errorf("message.Get(%v).NewValue() = %v, want %v", name, formatValue(got), formatValue(want))
}
mapv = m.Mutable(fd).Map() // mutable map
if got, want := mapv.NewValue(), newMapValue(fd, mapv, 0, nil); !valueEqual(got, want) {
t.Errorf("message.Mutable(%v).NewValue() = %v, want %v", name, formatValue(got), formatValue(want))
}
// Add values. // Add values.
want := make(testMap) want := make(testMap)
@ -228,6 +245,7 @@ func (m testMap) Has(k pref.MapKey) bool { return m.Get(k).IsValid() }
func (m testMap) Clear(k pref.MapKey) { delete(m, k.Interface()) } func (m testMap) Clear(k pref.MapKey) { delete(m, k.Interface()) }
func (m testMap) Len() int { return len(m) } func (m testMap) Len() int { return len(m) }
func (m testMap) NewMessage() pref.Message { panic("unimplemented") } func (m testMap) NewMessage() pref.Message { panic("unimplemented") }
func (m testMap) NewValue() pref.Value { panic("unimplemented") }
func (m testMap) Range(f func(pref.MapKey, pref.Value) bool) { func (m testMap) Range(f func(pref.MapKey, pref.Value) bool) {
for k, v := range m { for k, v := range m {
if !f(pref.ValueOf(k).MapKey(), v) { if !f(pref.ValueOf(k).MapKey(), v) {
@ -242,7 +260,14 @@ func testFieldList(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
num := fd.Number() num := fd.Number()
m.Clear(fd) // start with an empty list m.Clear(fd) // start with an empty list
list := m.Mutable(fd).List() list := m.Get(fd).List()
if got, want := list.NewElement(), newListElement(fd, list, 0, nil); !valueEqual(got, want) {
t.Errorf("message.Get(%v).NewElement() = %v, want %v", name, formatValue(got), formatValue(want))
}
list = m.Mutable(fd).List() // mutable list
if got, want := list.NewElement(), newListElement(fd, list, 0, nil); !valueEqual(got, want) {
t.Errorf("message.Mutable(%v).NewElement() = %v, want %v", name, formatValue(got), formatValue(want))
}
// Append values. // Append values.
var want pref.List = &testList{} var want pref.List = &testList{}
@ -293,6 +318,7 @@ func (l *testList) Len() int { return len(l.a) }
func (l *testList) Set(n int, v pref.Value) { l.a[n] = v } func (l *testList) Set(n int, v pref.Value) { l.a[n] = v }
func (l *testList) Truncate(n int) { l.a = l.a[:n] } func (l *testList) Truncate(n int) { l.a = l.a[:n] }
func (l *testList) NewMessage() pref.Message { panic("unimplemented") } func (l *testList) NewMessage() pref.Message { panic("unimplemented") }
func (l *testList) NewElement() pref.Value { panic("unimplemented") }
// testFieldFloat exercises some interesting floating-point scalar field values. // testFieldFloat exercises some interesting floating-point scalar field values.
func testFieldFloat(t testing.TB, m pref.Message, fd pref.FieldDescriptor) { func testFieldFloat(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
@ -462,6 +488,19 @@ const (
maxVal seed = -2 maxVal seed = -2
) )
// newSeed creates new seed values from a base, for example to create seeds for the
// elements in a list. If the input seed is minVal or maxVal, so is the output.
func newSeed(n seed, adjust ...int) seed {
switch n {
case minVal, maxVal:
return n
}
for _, a := range adjust {
n = 10*n + seed(a)
}
return n
}
// newValue returns a new value assignable to a field. // newValue returns a new value assignable to a field.
// //
// The stack parameter is used to avoid infinite recursion when populating circular // The stack parameter is used to avoid infinite recursion when populating circular
@ -469,7 +508,7 @@ const (
func newValue(m pref.Message, fd pref.FieldDescriptor, n seed, stack []pref.MessageDescriptor) pref.Value { func newValue(m pref.Message, fd pref.FieldDescriptor, n seed, stack []pref.MessageDescriptor) pref.Value {
switch { switch {
case fd.IsList(): case fd.IsList():
list := m.New().Mutable(fd).List() list := m.NewField(fd).List()
if n == 0 { if n == 0 {
return pref.ValueOf(list) return pref.ValueOf(list)
} }
@ -479,17 +518,17 @@ func newValue(m pref.Message, fd pref.FieldDescriptor, n seed, stack []pref.Mess
list.Append(newListElement(fd, list, n, stack)) list.Append(newListElement(fd, list, n, stack))
return pref.ValueOf(list) return pref.ValueOf(list)
case fd.IsMap(): case fd.IsMap():
mapv := m.New().Mutable(fd).Map() mapv := m.NewField(fd).Map()
if n == 0 { if n == 0 {
return pref.ValueOf(mapv) return pref.ValueOf(mapv)
} }
mapv.Set(newMapKey(fd, 0), newMapValue(fd, mapv, 0, stack)) mapv.Set(newMapKey(fd, 0), newMapValue(fd, mapv, 0, stack))
mapv.Set(newMapKey(fd, minVal), newMapValue(fd, mapv, minVal, stack)) mapv.Set(newMapKey(fd, minVal), newMapValue(fd, mapv, minVal, stack))
mapv.Set(newMapKey(fd, maxVal), newMapValue(fd, mapv, maxVal, stack)) mapv.Set(newMapKey(fd, maxVal), newMapValue(fd, mapv, maxVal, stack))
mapv.Set(newMapKey(fd, n), newMapValue(fd, mapv, 10*n, stack)) mapv.Set(newMapKey(fd, n), newMapValue(fd, mapv, newSeed(n, 0), stack))
return pref.ValueOf(mapv) return pref.ValueOf(mapv)
case fd.Message() != nil: case fd.Message() != nil:
return populateMessage(m.Mutable(fd).Message(), n, stack) return populateMessage(m.NewField(fd).Message(), n, stack)
default: default:
return newScalarValue(fd, n) return newScalarValue(fd, n)
} }
@ -499,7 +538,7 @@ func newListElement(fd pref.FieldDescriptor, list pref.List, n seed, stack []pre
if fd.Message() == nil { if fd.Message() == nil {
return newScalarValue(fd, n) return newScalarValue(fd, n)
} }
return populateMessage(list.NewMessage(), n, stack) return populateMessage(list.NewElement().Message(), n, stack)
} }
func newMapKey(fd pref.FieldDescriptor, n seed) pref.MapKey { func newMapKey(fd pref.FieldDescriptor, n seed) pref.MapKey {
@ -512,7 +551,7 @@ func newMapValue(fd pref.FieldDescriptor, mapv pref.Map, n seed, stack []pref.Me
if vd.Message() == nil { if vd.Message() == nil {
return newScalarValue(vd, n) return newScalarValue(vd, n)
} }
return populateMessage(mapv.NewMessage(), n, stack) return populateMessage(mapv.NewValue().Message(), n, stack)
} }
func newScalarValue(fd pref.FieldDescriptor, n seed) pref.Value { func newScalarValue(fd pref.FieldDescriptor, n seed) pref.Value {
@ -520,8 +559,17 @@ func newScalarValue(fd pref.FieldDescriptor, n seed) pref.Value {
case pref.BoolKind: case pref.BoolKind:
return pref.ValueOf(n != 0) return pref.ValueOf(n != 0)
case pref.EnumKind: case pref.EnumKind:
// TODO: use actual value vals := fd.Enum().Values()
return pref.ValueOf(pref.EnumNumber(n)) var i int
switch n {
case minVal:
i = 0
case maxVal:
i = vals.Len() - 1
default:
i = int(n) % vals.Len()
}
return pref.ValueOf(vals.Get(i).Number())
case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind: case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
switch n { switch n {
case minVal: case minVal:
@ -608,7 +656,7 @@ func populateMessage(m pref.Message, n seed, stack []pref.MessageDescriptor) pre
if fd.IsWeak() { if fd.IsWeak() {
continue continue
} }
m.Set(fd, newValue(m, fd, 10*n+seed(i), stack)) m.Set(fd, newValue(m, fd, newSeed(n, i), stack))
} }
return pref.ValueOf(m) return pref.ValueOf(m)
} }

View File

@ -23,7 +23,7 @@ import (
// package documentation for that interface for how to get and set fields and // package documentation for that interface for how to get and set fields and
// otherwise interact with the contents of a Message. // otherwise interact with the contents of a Message.
// //
// Reflection API functions which construct messages, such as NewMessage, // Reflection API functions which construct messages, such as NewField,
// return new dynamic messages of the appropriate type. Functions which take // return new dynamic messages of the appropriate type. Functions which take
// messages, such as Set for a message-value field, will accept any message // messages, such as Set for a message-value field, will accept any message
// with a compatible type. // with a compatible type.
@ -156,19 +156,9 @@ func (m *Message) Mutable(fd pref.FieldDescriptor) pref.Value {
panic(errors.New("%v: getting mutable reference to non-composite type", fd.FullName())) panic(errors.New("%v: getting mutable reference to non-composite type", fd.FullName()))
} }
m.clearOtherOneofFields(fd) m.clearOtherOneofFields(fd)
switch { m.known[num] = m.NewField(fd)
case fd.IsExtension(): if fd.IsExtension() {
m.known[num] = fd.(pref.ExtensionType).New()
m.ext[num] = fd m.ext[num] = fd
case fd.IsMap():
m.known[num] = pref.ValueOf(&dynamicMap{
desc: fd,
mapv: make(map[interface{}]pref.Value),
})
case fd.IsList():
m.known[num] = pref.ValueOf(&dynamicList{desc: fd})
case fd.Message() != nil:
m.known[num] = pref.ValueOf(m.NewMessage(fd))
} }
return m.known[num] return m.known[num]
} }
@ -221,6 +211,27 @@ func (m *Message) NewMessage(fd pref.FieldDescriptor) pref.Message {
return New(md).ProtoReflect() return New(md).ProtoReflect()
} }
// NewField returns a new value for assignable to the field of a given descriptor.
// See protoreflect.Message for details.
func (m *Message) NewField(fd pref.FieldDescriptor) pref.Value {
m.checkField(fd)
switch {
case fd.IsExtension():
return fd.(pref.ExtensionType).New()
case fd.IsMap():
return pref.ValueOf(&dynamicMap{
desc: fd,
mapv: make(map[interface{}]pref.Value),
})
case fd.IsList():
return pref.ValueOf(&dynamicList{desc: fd})
case fd.Message() != nil:
return pref.ValueOf(New(fd.Message()).ProtoReflect())
default:
return fd.Default()
}
}
// WhichOneof reports which field in a oneof is populated, returning nil if none are populated. // WhichOneof reports which field in a oneof is populated, returning nil if none are populated.
// See protoreflect.Message for details. // See protoreflect.Message for details.
func (m *Message) WhichOneof(od pref.OneofDescriptor) pref.FieldDescriptor { func (m *Message) WhichOneof(od pref.OneofDescriptor) pref.FieldDescriptor {
@ -278,6 +289,9 @@ func (x emptyList) NewMessage() pref.Message {
} }
return New(md).ProtoReflect() return New(md).ProtoReflect()
} }
func (x emptyList) NewElement() pref.Value {
return newListEntry(x.desc)
}
type dynamicList struct { type dynamicList struct {
desc pref.FieldDescriptor desc pref.FieldDescriptor
@ -318,6 +332,10 @@ func (x *dynamicList) NewMessage() pref.Message {
return New(md).ProtoReflect() return New(md).ProtoReflect()
} }
func (x *dynamicList) NewElement() pref.Value {
return newListEntry(x.desc)
}
type dynamicMap struct { type dynamicMap struct {
desc pref.FieldDescriptor desc pref.FieldDescriptor
mapv map[interface{}]pref.Value mapv map[interface{}]pref.Value
@ -339,6 +357,13 @@ func (x *dynamicMap) NewMessage() pref.Message {
} }
return New(md).ProtoReflect() return New(md).ProtoReflect()
} }
func (x *dynamicMap) NewValue() pref.Value {
if md := x.desc.MapValue().Message(); md != nil {
return pref.ValueOf(New(md).ProtoReflect())
}
return x.desc.MapValue().Default()
}
func (x *dynamicMap) Range(f func(pref.MapKey, pref.Value) bool) { func (x *dynamicMap) Range(f func(pref.MapKey, pref.Value) bool) {
for k, v := range x.mapv { for k, v := range x.mapv {
if !f(pref.ValueOf(k).MapKey(), v) { if !f(pref.ValueOf(k).MapKey(), v) {
@ -412,3 +437,31 @@ func typecheckSingular(fd pref.FieldDescriptor, v pref.Value) {
panic(errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface())) panic(errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface()))
} }
} }
func newListEntry(fd pref.FieldDescriptor) pref.Value {
switch fd.Kind() {
case pref.BoolKind:
return pref.ValueOf(false)
case pref.EnumKind:
return pref.ValueOf(fd.Enum().Values().Get(0).Number())
case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
return pref.ValueOf(int32(0))
case pref.Uint32Kind, pref.Fixed32Kind:
return pref.ValueOf(uint32(0))
case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
return pref.ValueOf(int64(0))
case pref.Uint64Kind, pref.Fixed64Kind:
return pref.ValueOf(uint64(0))
case pref.FloatKind:
return pref.ValueOf(float32(0))
case pref.DoubleKind:
return pref.ValueOf(float64(0))
case pref.StringKind:
return pref.ValueOf("")
case pref.BytesKind:
return pref.ValueOf(([]byte)(nil))
case pref.MessageKind, pref.GroupKind:
return pref.ValueOf(New(fd.Message()).ProtoReflect())
}
panic(errors.New("%v: unknown kind %v", fd.FullName(), fd.Kind()))
}