types/dynamicpb: support dynamic extensions

Add a dynamicpb.NewExtensionType function to permit creating extension
types from descriptors.

Also fix a some bugs around extension field handling:
When creating a new value for an extension field, use the
ExtensionType's Zero or New method to create the value.

Ensure that prototest exercises true zero-values of fields. (i.e.,
getting a list, map, or message from an empty message rather than
creating a new empty one with NewField.)

Change-Id: Idb8e87cdc92692610e12a4b8a68c34b129fae617
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/186180
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
This commit is contained in:
Damien Neil 2019-07-15 13:39:43 -07:00
parent 293dc761cb
commit 290ceea663
3 changed files with 249 additions and 31 deletions

View File

@ -29,6 +29,12 @@ type MessageOptions struct {
// //
// If nil, TestMessage will look for extension types in the global registry. // If nil, TestMessage will look for extension types in the global registry.
ExtensionTypes []pref.ExtensionType ExtensionTypes []pref.ExtensionType
// Resolver is used for looking up types when unmarshaling extension fields.
// If nil, this defaults to using protoregistry.GlobalTypes.
Resolver interface {
preg.ExtensionTypeResolver
}
} }
// TestMessage runs the provided m through a series of tests // TestMessage runs the provided m through a series of tests
@ -57,12 +63,20 @@ func TestMessage(t testing.TB, m proto.Message, opts MessageOptions) {
// Test round-trip marshal/unmarshal. // Test round-trip marshal/unmarshal.
m2 := m.ProtoReflect().New().Interface() m2 := m.ProtoReflect().New().Interface()
populateMessage(m2.ProtoReflect(), 1, nil) populateMessage(m2.ProtoReflect(), 1, nil)
b, err := (proto.MarshalOptions{AllowPartial: true}).Marshal(m2) for _, xt := range opts.ExtensionTypes {
m2.ProtoReflect().Set(xt.TypeDescriptor(), newValue(m2.ProtoReflect(), xt.TypeDescriptor(), 1, nil))
}
b, err := proto.MarshalOptions{
AllowPartial: true,
}.Marshal(m2)
if err != nil { if err != nil {
t.Errorf("Marshal() = %v, want nil\n%v", err, marshalText(m2)) t.Errorf("Marshal() = %v, want nil\n%v", err, marshalText(m2))
} }
m3 := m.ProtoReflect().New().Interface() m3 := m.ProtoReflect().New().Interface()
if err := (proto.UnmarshalOptions{AllowPartial: true}).Unmarshal(b, m3); err != nil { if err := (proto.UnmarshalOptions{
AllowPartial: true,
Resolver: opts.Resolver,
}.Unmarshal(b, m3)); err != nil {
t.Errorf("Unmarshal() = %v, want nil\n%v", err, marshalText(m2)) t.Errorf("Unmarshal() = %v, want nil\n%v", err, marshalText(m2))
} }
if !proto.Equal(m2, m3) { if !proto.Equal(m2, m3) {
@ -150,7 +164,7 @@ func testField(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
} }
case fd.IsMap(): case fd.IsMap():
if got := m.Get(fd); got.Map().Len() != 0 { if got := m.Get(fd); got.Map().Len() != 0 {
t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want empty list", name, num, formatValue(got)) t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want empty map", name, num, formatValue(got))
} }
case fd.Message() == nil: case fd.Message() == nil:
if got, want := m.Get(fd), fd.Default(); !valueEqual(got, want) { if got, want := m.Get(fd), fd.Default(); !valueEqual(got, want) {
@ -158,6 +172,21 @@ func testField(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
} }
} }
// Set to the default value.
switch {
case fd.IsList() || fd.IsMap():
m.Set(fd, m.Get(fd))
if got, want := m.Has(fd), fd.IsExtension() || fd.ContainingOneof() != nil; got != want {
t.Errorf("after setting %q to default:\nMessage.Has(%v) = %v, want %v", name, num, got, want)
}
case fd.Message() == nil:
m.Set(fd, m.Get(fd))
if got, want := m.Get(fd), fd.Default(); !valueEqual(got, want) {
t.Errorf("after setting %q to default:\nMessage.Get(%v) = %v, want default %v", name, num, formatValue(got), formatValue(want))
}
}
m.Clear(fd)
// Set to the wrong type. // Set to the wrong type.
v := pref.ValueOf("") v := pref.ValueOf("")
if fd.Kind() == pref.StringKind { if fd.Kind() == pref.StringKind {
@ -508,26 +537,29 @@ func newSeed(n seed, adjust ...int) seed {
func newValue(m pref.Message, fd pref.FieldDescriptor, n seed, stack []pref.MessageDescriptor) pref.Value { func newValue(m pref.Message, fd pref.FieldDescriptor, n seed, stack []pref.MessageDescriptor) pref.Value {
switch { switch {
case fd.IsList(): case fd.IsList():
list := m.NewField(fd).List()
if n == 0 { if n == 0 {
return pref.ValueOf(list) return m.New().Get(fd)
} }
list := m.NewField(fd).List()
list.Append(newListElement(fd, list, 0, stack)) list.Append(newListElement(fd, list, 0, stack))
list.Append(newListElement(fd, list, minVal, stack)) list.Append(newListElement(fd, list, minVal, stack))
list.Append(newListElement(fd, list, maxVal, stack)) list.Append(newListElement(fd, list, maxVal, stack))
list.Append(newListElement(fd, list, n, stack)) list.Append(newListElement(fd, list, n, stack))
return pref.ValueOf(list) return pref.ValueOf(list)
case fd.IsMap(): case fd.IsMap():
mapv := m.NewField(fd).Map()
if n == 0 { if n == 0 {
return pref.ValueOf(mapv) return m.New().Get(fd)
} }
mapv := m.NewField(fd).Map()
mapv.Set(newMapKey(fd, 0), newMapValue(fd, mapv, 0, stack)) mapv.Set(newMapKey(fd, 0), newMapValue(fd, mapv, 0, stack))
mapv.Set(newMapKey(fd, minVal), newMapValue(fd, mapv, minVal, stack)) mapv.Set(newMapKey(fd, minVal), newMapValue(fd, mapv, minVal, stack))
mapv.Set(newMapKey(fd, maxVal), newMapValue(fd, mapv, maxVal, stack)) mapv.Set(newMapKey(fd, maxVal), newMapValue(fd, mapv, maxVal, stack))
mapv.Set(newMapKey(fd, n), newMapValue(fd, mapv, newSeed(n, 0), stack)) mapv.Set(newMapKey(fd, n), newMapValue(fd, mapv, newSeed(n, 0), stack))
return pref.ValueOf(mapv) return pref.ValueOf(mapv)
case fd.Message() != nil: case fd.Message() != nil:
//if n == 0 {
// return m.New().Get(fd)
//}
return populateMessage(m.NewField(fd).Message(), n, stack) return populateMessage(m.NewField(fd).Message(), n, stack)
default: default:
return newScalarValue(fd, n) return newScalarValue(fd, n)

View File

@ -122,16 +122,22 @@ func (m *Message) Clear(fd pref.FieldDescriptor) {
func (m *Message) Get(fd pref.FieldDescriptor) pref.Value { func (m *Message) Get(fd pref.FieldDescriptor) pref.Value {
m.checkField(fd) m.checkField(fd)
num := fd.Number() num := fd.Number()
if v, ok := m.known[num]; ok { if fd.IsExtension() {
if !fd.IsExtension() || fd == m.ext[num] { if fd != m.ext[num] {
return v return fd.(pref.ExtensionTypeDescriptor).Type().Zero()
} }
return m.known[num]
}
if v, ok := m.known[num]; ok {
return v
} }
switch { switch {
case fd.IsMap(): case fd.IsMap():
return pref.ValueOf(&dynamicMap{desc: fd}) return pref.ValueOf(&dynamicMap{desc: fd})
case fd.Cardinality() == pref.Repeated: case fd.IsList():
return pref.ValueOf(emptyList{desc: fd}) return pref.ValueOf(emptyList{desc: fd})
case fd.Message() != nil:
return pref.ValueOf(&Message{desc: fd.Message()})
case fd.Kind() == pref.BytesKind: case fd.Kind() == pref.BytesKind:
return pref.ValueOf(append([]byte(nil), fd.Default().Bytes()...)) return pref.ValueOf(append([]byte(nil), fd.Default().Bytes()...))
default: default:
@ -143,15 +149,23 @@ func (m *Message) Get(fd pref.FieldDescriptor) pref.Value {
// See protoreflect.Message for details. // See protoreflect.Message for details.
func (m *Message) Mutable(fd pref.FieldDescriptor) pref.Value { func (m *Message) Mutable(fd pref.FieldDescriptor) pref.Value {
m.checkField(fd) m.checkField(fd)
num := fd.Number()
if v, ok := m.known[num]; ok {
if !fd.IsExtension() || fd == m.ext[num] {
return v
}
}
if !fd.IsMap() && !fd.IsList() && fd.Message() == nil { if !fd.IsMap() && !fd.IsList() && fd.Message() == nil {
panic(errors.New("%v: getting mutable reference to non-composite type", fd.FullName())) panic(errors.New("%v: getting mutable reference to non-composite type", fd.FullName()))
} }
if m.known == nil {
panic(errors.New("%v: modification of read-only message", fd.FullName()))
}
num := fd.Number()
if fd.IsExtension() {
if fd != m.ext[num] {
m.ext[num] = fd
m.known[num] = fd.(pref.ExtensionTypeDescriptor).Type().New()
}
return m.known[num]
}
if v, ok := m.known[num]; ok {
return v
}
m.clearOtherOneofFields(fd) m.clearOtherOneofFields(fd)
m.known[num] = m.NewField(fd) m.known[num] = m.NewField(fd)
if fd.IsExtension() { if fd.IsExtension() {
@ -164,22 +178,16 @@ func (m *Message) Mutable(fd pref.FieldDescriptor) pref.Value {
// See protoreflect.Message for details. // See protoreflect.Message for details.
func (m *Message) Set(fd pref.FieldDescriptor, v pref.Value) { func (m *Message) Set(fd pref.FieldDescriptor, v pref.Value) {
m.checkField(fd) m.checkField(fd)
switch { if m.known == nil {
case fd.IsExtension(): panic(errors.New("%v: modification of read-only message", fd.FullName()))
}
if fd.IsExtension() {
if !fd.(pref.ExtensionTypeDescriptor).Type().IsValidValue(v) { if !fd.(pref.ExtensionTypeDescriptor).Type().IsValidValue(v) {
panic(errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface())) panic(errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface()))
} }
m.ext[fd.Number()] = fd m.ext[fd.Number()] = fd
case fd.IsMap(): } else {
if mapv, ok := v.Interface().(*dynamicMap); !ok || mapv.desc != fd { typecheck(fd, v)
panic(errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface()))
}
case fd.IsList():
if list, ok := v.Interface().(*dynamicList); !ok || list.desc != fd {
panic(errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface()))
}
default:
typecheckSingular(fd, v)
} }
m.clearOtherOneofFields(fd) m.clearOtherOneofFields(fd)
m.known[fd.Number()] = v m.known[fd.Number()] = v
@ -251,6 +259,9 @@ func (m *Message) GetUnknown() pref.RawFields {
// SetUnknown sets the raw unknown fields. // SetUnknown sets the raw unknown fields.
// See protoreflect.Message for details. // See protoreflect.Message for details.
func (m *Message) SetUnknown(r pref.RawFields) { func (m *Message) SetUnknown(r pref.RawFields) {
if m.known == nil {
panic(errors.New("%v: modification of read-only message", m.desc.FullName()))
}
m.unknown = r m.unknown = r
} }
@ -406,7 +417,43 @@ func isSet(fd pref.FieldDescriptor, v pref.Value) bool {
return true return true
} }
func typecheck(fd pref.FieldDescriptor, v pref.Value) {
if err := typeIsValid(fd, v); err != nil {
panic(err)
}
}
func typeIsValid(fd pref.FieldDescriptor, v pref.Value) error {
switch {
case fd.IsMap():
if mapv, ok := v.Interface().(*dynamicMap); !ok || mapv.desc != fd {
return errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface())
}
return nil
case fd.IsList():
switch list := v.Interface().(type) {
case *dynamicList:
if list.desc == fd {
return nil
}
case emptyList:
if list.desc == fd {
return nil
}
}
return errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface())
default:
return singularTypeIsValid(fd, v)
}
}
func typecheckSingular(fd pref.FieldDescriptor, v pref.Value) { func typecheckSingular(fd pref.FieldDescriptor, v pref.Value) {
if err := singularTypeIsValid(fd, v); err != nil {
panic(err)
}
}
func singularTypeIsValid(fd pref.FieldDescriptor, v pref.Value) error {
vi := v.Interface() vi := v.Interface()
var ok bool var ok bool
switch fd.Kind() { switch fd.Kind() {
@ -435,12 +482,16 @@ func typecheckSingular(fd pref.FieldDescriptor, v pref.Value) {
var m pref.Message var m pref.Message
m, ok = vi.(pref.Message) m, ok = vi.(pref.Message)
if ok && m.Descriptor().FullName() != fd.Message().FullName() { if ok && m.Descriptor().FullName() != fd.Message().FullName() {
panic(errors.New("%v: assigning invalid message type %v", fd.FullName(), m.Descriptor().FullName())) return errors.New("%v: assigning invalid message type %v", fd.FullName(), m.Descriptor().FullName())
}
if dm, ok := vi.(*Message); ok && dm.known == nil {
return errors.New("%v: assigning invalid zero-value message", fd.FullName())
} }
} }
if !ok { if !ok {
panic(errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface())) return errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface())
} }
return nil
} }
func newListEntry(fd pref.FieldDescriptor) pref.Value { func newListEntry(fd pref.FieldDescriptor) pref.Value {
@ -470,3 +521,102 @@ func newListEntry(fd pref.FieldDescriptor) pref.Value {
} }
panic(errors.New("%v: unknown kind %v", fd.FullName(), fd.Kind())) panic(errors.New("%v: unknown kind %v", fd.FullName(), fd.Kind()))
} }
// extensionType is a dynamic protoreflect.ExtensionType.
type extensionType struct {
desc extensionTypeDescriptor
}
// NewExtensionType creates a new ExtensionType with the provided descriptor.
//
// Dynamic ExtensionTypes with the same descriptor compare as equal. That is,
// if xd1 == xd2, then NewExtensionType(xd1) == NewExtensionType(xd2).
//
// The InterfaceOf and ValueOf methods of the extension type are defined as:
//
// func (xt extensionType) ValueOf(iv interface{}) protoreflect.Value {
// return protoreflect.ValueOf(iv)
// }
//
// func (xt extensionType) InterfaceOf(v protoreflect.Value) interface{} {
// return v.Interface()
// }
//
// The Go type used by the proto.GetExtension and proto.SetExtension functions
// is determined by these methods, and is therefore equivalent to the Go type
// used to represent a protoreflect.Value. See the protoreflect.Value
// documentation for more details.
func NewExtensionType(desc pref.ExtensionDescriptor) pref.ExtensionType {
if xt, ok := desc.(pref.ExtensionTypeDescriptor); ok {
desc = xt.Descriptor()
}
return extensionType{extensionTypeDescriptor{desc}}
}
func (xt extensionType) New() pref.Value {
switch {
case xt.desc.IsMap():
return pref.ValueOf(&dynamicMap{
desc: xt.desc,
mapv: make(map[interface{}]pref.Value),
})
case xt.desc.IsList():
return pref.ValueOf(&dynamicList{desc: xt.desc})
case xt.desc.Message() != nil:
return pref.ValueOf(New(xt.desc.Message()))
default:
return xt.desc.Default()
}
}
func (xt extensionType) Zero() pref.Value {
switch {
case xt.desc.IsMap():
return pref.ValueOf(&dynamicMap{desc: xt.desc})
case xt.desc.Cardinality() == pref.Repeated:
return pref.ValueOf(emptyList{desc: xt.desc})
case xt.desc.Message() != nil:
return pref.ValueOf(&Message{desc: xt.desc.Message()})
default:
return xt.desc.Default()
}
}
func (xt extensionType) GoType() reflect.Type {
return reflect.TypeOf(xt.InterfaceOf(xt.New()))
}
func (xt extensionType) TypeDescriptor() pref.ExtensionTypeDescriptor {
return xt.desc
}
func (xt extensionType) ValueOf(iv interface{}) pref.Value {
v := pref.ValueOf(iv)
typecheck(xt.desc, v)
return v
}
func (xt extensionType) InterfaceOf(v pref.Value) interface{} {
typecheck(xt.desc, v)
return v.Interface()
}
func (xt extensionType) IsValidInterface(iv interface{}) bool {
return typeIsValid(xt.desc, pref.ValueOf(iv)) == nil
}
func (xt extensionType) IsValidValue(v pref.Value) bool {
return typeIsValid(xt.desc, v) == nil
}
type extensionTypeDescriptor struct {
pref.ExtensionDescriptor
}
func (xt extensionTypeDescriptor) Type() pref.ExtensionType {
return extensionType{xt}
}
func (xt extensionTypeDescriptor) Descriptor() pref.ExtensionDescriptor {
return xt.ExtensionDescriptor
}

View File

@ -8,6 +8,8 @@ import (
"testing" "testing"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
pref "google.golang.org/protobuf/reflect/protoreflect"
preg "google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/testing/prototest" "google.golang.org/protobuf/testing/prototest"
"google.golang.org/protobuf/types/dynamicpb" "google.golang.org/protobuf/types/dynamicpb"
@ -24,3 +26,37 @@ func TestConformance(t *testing.T) {
prototest.TestMessage(t, dynamicpb.New(message.ProtoReflect().Descriptor()), prototest.MessageOptions{}) prototest.TestMessage(t, dynamicpb.New(message.ProtoReflect().Descriptor()), prototest.MessageOptions{})
} }
} }
func TestDynamicExtensions(t *testing.T) {
file, err := preg.GlobalFiles.FindFileByPath("test/ext.proto")
if err != nil {
t.Fatal(err)
}
md := (&testpb.TestAllExtensions{}).ProtoReflect().Descriptor()
opts := prototest.MessageOptions{
Resolver: extResolver{},
}
for i := 0; i < file.Extensions().Len(); i++ {
opts.ExtensionTypes = append(opts.ExtensionTypes, dynamicpb.NewExtensionType(file.Extensions().Get(i)))
}
prototest.TestMessage(t, dynamicpb.New(md), opts)
}
type extResolver struct{}
func (extResolver) FindExtensionByName(field pref.FullName) (pref.ExtensionType, error) {
xt, err := preg.GlobalTypes.FindExtensionByName(field)
if err != nil {
return nil, err
}
return dynamicpb.NewExtensionType(xt.TypeDescriptor().Descriptor()), nil
}
func (extResolver) FindExtensionByNumber(message pref.FullName, field pref.FieldNumber) (pref.ExtensionType, error) {
xt, err := preg.GlobalTypes.FindExtensionByNumber(message, field)
if err != nil {
return nil, err
}
return dynamicpb.NewExtensionType(xt.TypeDescriptor().Descriptor()), nil
}