mirror of
https://github.com/protocolbuffers/protobuf-go.git
synced 2025-01-30 12:32:36 +00:00
types/dynamicpb: support dynamic extensions
Add a dynamicpb.NewExtensionType function to permit creating extension types from descriptors. Also fix a some bugs around extension field handling: When creating a new value for an extension field, use the ExtensionType's Zero or New method to create the value. Ensure that prototest exercises true zero-values of fields. (i.e., getting a list, map, or message from an empty message rather than creating a new empty one with NewField.) Change-Id: Idb8e87cdc92692610e12a4b8a68c34b129fae617 Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/186180 Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
This commit is contained in:
parent
293dc761cb
commit
290ceea663
@ -29,6 +29,12 @@ type MessageOptions struct {
|
||||
//
|
||||
// If nil, TestMessage will look for extension types in the global registry.
|
||||
ExtensionTypes []pref.ExtensionType
|
||||
|
||||
// Resolver is used for looking up types when unmarshaling extension fields.
|
||||
// If nil, this defaults to using protoregistry.GlobalTypes.
|
||||
Resolver interface {
|
||||
preg.ExtensionTypeResolver
|
||||
}
|
||||
}
|
||||
|
||||
// TestMessage runs the provided m through a series of tests
|
||||
@ -57,12 +63,20 @@ func TestMessage(t testing.TB, m proto.Message, opts MessageOptions) {
|
||||
// Test round-trip marshal/unmarshal.
|
||||
m2 := m.ProtoReflect().New().Interface()
|
||||
populateMessage(m2.ProtoReflect(), 1, nil)
|
||||
b, err := (proto.MarshalOptions{AllowPartial: true}).Marshal(m2)
|
||||
for _, xt := range opts.ExtensionTypes {
|
||||
m2.ProtoReflect().Set(xt.TypeDescriptor(), newValue(m2.ProtoReflect(), xt.TypeDescriptor(), 1, nil))
|
||||
}
|
||||
b, err := proto.MarshalOptions{
|
||||
AllowPartial: true,
|
||||
}.Marshal(m2)
|
||||
if err != nil {
|
||||
t.Errorf("Marshal() = %v, want nil\n%v", err, marshalText(m2))
|
||||
}
|
||||
m3 := m.ProtoReflect().New().Interface()
|
||||
if err := (proto.UnmarshalOptions{AllowPartial: true}).Unmarshal(b, m3); err != nil {
|
||||
if err := (proto.UnmarshalOptions{
|
||||
AllowPartial: true,
|
||||
Resolver: opts.Resolver,
|
||||
}.Unmarshal(b, m3)); err != nil {
|
||||
t.Errorf("Unmarshal() = %v, want nil\n%v", err, marshalText(m2))
|
||||
}
|
||||
if !proto.Equal(m2, m3) {
|
||||
@ -150,7 +164,7 @@ func testField(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
|
||||
}
|
||||
case fd.IsMap():
|
||||
if got := m.Get(fd); got.Map().Len() != 0 {
|
||||
t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want empty list", name, num, formatValue(got))
|
||||
t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want empty map", name, num, formatValue(got))
|
||||
}
|
||||
case fd.Message() == nil:
|
||||
if got, want := m.Get(fd), fd.Default(); !valueEqual(got, want) {
|
||||
@ -158,6 +172,21 @@ func testField(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
|
||||
}
|
||||
}
|
||||
|
||||
// Set to the default value.
|
||||
switch {
|
||||
case fd.IsList() || fd.IsMap():
|
||||
m.Set(fd, m.Get(fd))
|
||||
if got, want := m.Has(fd), fd.IsExtension() || fd.ContainingOneof() != nil; got != want {
|
||||
t.Errorf("after setting %q to default:\nMessage.Has(%v) = %v, want %v", name, num, got, want)
|
||||
}
|
||||
case fd.Message() == nil:
|
||||
m.Set(fd, m.Get(fd))
|
||||
if got, want := m.Get(fd), fd.Default(); !valueEqual(got, want) {
|
||||
t.Errorf("after setting %q to default:\nMessage.Get(%v) = %v, want default %v", name, num, formatValue(got), formatValue(want))
|
||||
}
|
||||
}
|
||||
m.Clear(fd)
|
||||
|
||||
// Set to the wrong type.
|
||||
v := pref.ValueOf("")
|
||||
if fd.Kind() == pref.StringKind {
|
||||
@ -508,26 +537,29 @@ func newSeed(n seed, adjust ...int) seed {
|
||||
func newValue(m pref.Message, fd pref.FieldDescriptor, n seed, stack []pref.MessageDescriptor) pref.Value {
|
||||
switch {
|
||||
case fd.IsList():
|
||||
list := m.NewField(fd).List()
|
||||
if n == 0 {
|
||||
return pref.ValueOf(list)
|
||||
return m.New().Get(fd)
|
||||
}
|
||||
list := m.NewField(fd).List()
|
||||
list.Append(newListElement(fd, list, 0, stack))
|
||||
list.Append(newListElement(fd, list, minVal, stack))
|
||||
list.Append(newListElement(fd, list, maxVal, stack))
|
||||
list.Append(newListElement(fd, list, n, stack))
|
||||
return pref.ValueOf(list)
|
||||
case fd.IsMap():
|
||||
mapv := m.NewField(fd).Map()
|
||||
if n == 0 {
|
||||
return pref.ValueOf(mapv)
|
||||
return m.New().Get(fd)
|
||||
}
|
||||
mapv := m.NewField(fd).Map()
|
||||
mapv.Set(newMapKey(fd, 0), newMapValue(fd, mapv, 0, stack))
|
||||
mapv.Set(newMapKey(fd, minVal), newMapValue(fd, mapv, minVal, stack))
|
||||
mapv.Set(newMapKey(fd, maxVal), newMapValue(fd, mapv, maxVal, stack))
|
||||
mapv.Set(newMapKey(fd, n), newMapValue(fd, mapv, newSeed(n, 0), stack))
|
||||
return pref.ValueOf(mapv)
|
||||
case fd.Message() != nil:
|
||||
//if n == 0 {
|
||||
// return m.New().Get(fd)
|
||||
//}
|
||||
return populateMessage(m.NewField(fd).Message(), n, stack)
|
||||
default:
|
||||
return newScalarValue(fd, n)
|
||||
|
@ -122,16 +122,22 @@ func (m *Message) Clear(fd pref.FieldDescriptor) {
|
||||
func (m *Message) Get(fd pref.FieldDescriptor) pref.Value {
|
||||
m.checkField(fd)
|
||||
num := fd.Number()
|
||||
if v, ok := m.known[num]; ok {
|
||||
if !fd.IsExtension() || fd == m.ext[num] {
|
||||
return v
|
||||
if fd.IsExtension() {
|
||||
if fd != m.ext[num] {
|
||||
return fd.(pref.ExtensionTypeDescriptor).Type().Zero()
|
||||
}
|
||||
return m.known[num]
|
||||
}
|
||||
if v, ok := m.known[num]; ok {
|
||||
return v
|
||||
}
|
||||
switch {
|
||||
case fd.IsMap():
|
||||
return pref.ValueOf(&dynamicMap{desc: fd})
|
||||
case fd.Cardinality() == pref.Repeated:
|
||||
case fd.IsList():
|
||||
return pref.ValueOf(emptyList{desc: fd})
|
||||
case fd.Message() != nil:
|
||||
return pref.ValueOf(&Message{desc: fd.Message()})
|
||||
case fd.Kind() == pref.BytesKind:
|
||||
return pref.ValueOf(append([]byte(nil), fd.Default().Bytes()...))
|
||||
default:
|
||||
@ -143,15 +149,23 @@ func (m *Message) Get(fd pref.FieldDescriptor) pref.Value {
|
||||
// See protoreflect.Message for details.
|
||||
func (m *Message) Mutable(fd pref.FieldDescriptor) pref.Value {
|
||||
m.checkField(fd)
|
||||
num := fd.Number()
|
||||
if v, ok := m.known[num]; ok {
|
||||
if !fd.IsExtension() || fd == m.ext[num] {
|
||||
return v
|
||||
}
|
||||
}
|
||||
if !fd.IsMap() && !fd.IsList() && fd.Message() == nil {
|
||||
panic(errors.New("%v: getting mutable reference to non-composite type", fd.FullName()))
|
||||
}
|
||||
if m.known == nil {
|
||||
panic(errors.New("%v: modification of read-only message", fd.FullName()))
|
||||
}
|
||||
num := fd.Number()
|
||||
if fd.IsExtension() {
|
||||
if fd != m.ext[num] {
|
||||
m.ext[num] = fd
|
||||
m.known[num] = fd.(pref.ExtensionTypeDescriptor).Type().New()
|
||||
}
|
||||
return m.known[num]
|
||||
}
|
||||
if v, ok := m.known[num]; ok {
|
||||
return v
|
||||
}
|
||||
m.clearOtherOneofFields(fd)
|
||||
m.known[num] = m.NewField(fd)
|
||||
if fd.IsExtension() {
|
||||
@ -164,22 +178,16 @@ func (m *Message) Mutable(fd pref.FieldDescriptor) pref.Value {
|
||||
// See protoreflect.Message for details.
|
||||
func (m *Message) Set(fd pref.FieldDescriptor, v pref.Value) {
|
||||
m.checkField(fd)
|
||||
switch {
|
||||
case fd.IsExtension():
|
||||
if m.known == nil {
|
||||
panic(errors.New("%v: modification of read-only message", fd.FullName()))
|
||||
}
|
||||
if fd.IsExtension() {
|
||||
if !fd.(pref.ExtensionTypeDescriptor).Type().IsValidValue(v) {
|
||||
panic(errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface()))
|
||||
}
|
||||
m.ext[fd.Number()] = fd
|
||||
case fd.IsMap():
|
||||
if mapv, ok := v.Interface().(*dynamicMap); !ok || mapv.desc != fd {
|
||||
panic(errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface()))
|
||||
}
|
||||
case fd.IsList():
|
||||
if list, ok := v.Interface().(*dynamicList); !ok || list.desc != fd {
|
||||
panic(errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface()))
|
||||
}
|
||||
default:
|
||||
typecheckSingular(fd, v)
|
||||
} else {
|
||||
typecheck(fd, v)
|
||||
}
|
||||
m.clearOtherOneofFields(fd)
|
||||
m.known[fd.Number()] = v
|
||||
@ -251,6 +259,9 @@ func (m *Message) GetUnknown() pref.RawFields {
|
||||
// SetUnknown sets the raw unknown fields.
|
||||
// See protoreflect.Message for details.
|
||||
func (m *Message) SetUnknown(r pref.RawFields) {
|
||||
if m.known == nil {
|
||||
panic(errors.New("%v: modification of read-only message", m.desc.FullName()))
|
||||
}
|
||||
m.unknown = r
|
||||
}
|
||||
|
||||
@ -406,7 +417,43 @@ func isSet(fd pref.FieldDescriptor, v pref.Value) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func typecheck(fd pref.FieldDescriptor, v pref.Value) {
|
||||
if err := typeIsValid(fd, v); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func typeIsValid(fd pref.FieldDescriptor, v pref.Value) error {
|
||||
switch {
|
||||
case fd.IsMap():
|
||||
if mapv, ok := v.Interface().(*dynamicMap); !ok || mapv.desc != fd {
|
||||
return errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface())
|
||||
}
|
||||
return nil
|
||||
case fd.IsList():
|
||||
switch list := v.Interface().(type) {
|
||||
case *dynamicList:
|
||||
if list.desc == fd {
|
||||
return nil
|
||||
}
|
||||
case emptyList:
|
||||
if list.desc == fd {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface())
|
||||
default:
|
||||
return singularTypeIsValid(fd, v)
|
||||
}
|
||||
}
|
||||
|
||||
func typecheckSingular(fd pref.FieldDescriptor, v pref.Value) {
|
||||
if err := singularTypeIsValid(fd, v); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func singularTypeIsValid(fd pref.FieldDescriptor, v pref.Value) error {
|
||||
vi := v.Interface()
|
||||
var ok bool
|
||||
switch fd.Kind() {
|
||||
@ -435,12 +482,16 @@ func typecheckSingular(fd pref.FieldDescriptor, v pref.Value) {
|
||||
var m pref.Message
|
||||
m, ok = vi.(pref.Message)
|
||||
if ok && m.Descriptor().FullName() != fd.Message().FullName() {
|
||||
panic(errors.New("%v: assigning invalid message type %v", fd.FullName(), m.Descriptor().FullName()))
|
||||
return errors.New("%v: assigning invalid message type %v", fd.FullName(), m.Descriptor().FullName())
|
||||
}
|
||||
if dm, ok := vi.(*Message); ok && dm.known == nil {
|
||||
return errors.New("%v: assigning invalid zero-value message", fd.FullName())
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
panic(errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface()))
|
||||
return errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newListEntry(fd pref.FieldDescriptor) pref.Value {
|
||||
@ -470,3 +521,102 @@ func newListEntry(fd pref.FieldDescriptor) pref.Value {
|
||||
}
|
||||
panic(errors.New("%v: unknown kind %v", fd.FullName(), fd.Kind()))
|
||||
}
|
||||
|
||||
// extensionType is a dynamic protoreflect.ExtensionType.
|
||||
type extensionType struct {
|
||||
desc extensionTypeDescriptor
|
||||
}
|
||||
|
||||
// NewExtensionType creates a new ExtensionType with the provided descriptor.
|
||||
//
|
||||
// Dynamic ExtensionTypes with the same descriptor compare as equal. That is,
|
||||
// if xd1 == xd2, then NewExtensionType(xd1) == NewExtensionType(xd2).
|
||||
//
|
||||
// The InterfaceOf and ValueOf methods of the extension type are defined as:
|
||||
//
|
||||
// func (xt extensionType) ValueOf(iv interface{}) protoreflect.Value {
|
||||
// return protoreflect.ValueOf(iv)
|
||||
// }
|
||||
//
|
||||
// func (xt extensionType) InterfaceOf(v protoreflect.Value) interface{} {
|
||||
// return v.Interface()
|
||||
// }
|
||||
//
|
||||
// The Go type used by the proto.GetExtension and proto.SetExtension functions
|
||||
// is determined by these methods, and is therefore equivalent to the Go type
|
||||
// used to represent a protoreflect.Value. See the protoreflect.Value
|
||||
// documentation for more details.
|
||||
func NewExtensionType(desc pref.ExtensionDescriptor) pref.ExtensionType {
|
||||
if xt, ok := desc.(pref.ExtensionTypeDescriptor); ok {
|
||||
desc = xt.Descriptor()
|
||||
}
|
||||
return extensionType{extensionTypeDescriptor{desc}}
|
||||
}
|
||||
|
||||
func (xt extensionType) New() pref.Value {
|
||||
switch {
|
||||
case xt.desc.IsMap():
|
||||
return pref.ValueOf(&dynamicMap{
|
||||
desc: xt.desc,
|
||||
mapv: make(map[interface{}]pref.Value),
|
||||
})
|
||||
case xt.desc.IsList():
|
||||
return pref.ValueOf(&dynamicList{desc: xt.desc})
|
||||
case xt.desc.Message() != nil:
|
||||
return pref.ValueOf(New(xt.desc.Message()))
|
||||
default:
|
||||
return xt.desc.Default()
|
||||
}
|
||||
}
|
||||
|
||||
func (xt extensionType) Zero() pref.Value {
|
||||
switch {
|
||||
case xt.desc.IsMap():
|
||||
return pref.ValueOf(&dynamicMap{desc: xt.desc})
|
||||
case xt.desc.Cardinality() == pref.Repeated:
|
||||
return pref.ValueOf(emptyList{desc: xt.desc})
|
||||
case xt.desc.Message() != nil:
|
||||
return pref.ValueOf(&Message{desc: xt.desc.Message()})
|
||||
default:
|
||||
return xt.desc.Default()
|
||||
}
|
||||
}
|
||||
|
||||
func (xt extensionType) GoType() reflect.Type {
|
||||
return reflect.TypeOf(xt.InterfaceOf(xt.New()))
|
||||
}
|
||||
|
||||
func (xt extensionType) TypeDescriptor() pref.ExtensionTypeDescriptor {
|
||||
return xt.desc
|
||||
}
|
||||
|
||||
func (xt extensionType) ValueOf(iv interface{}) pref.Value {
|
||||
v := pref.ValueOf(iv)
|
||||
typecheck(xt.desc, v)
|
||||
return v
|
||||
}
|
||||
|
||||
func (xt extensionType) InterfaceOf(v pref.Value) interface{} {
|
||||
typecheck(xt.desc, v)
|
||||
return v.Interface()
|
||||
}
|
||||
|
||||
func (xt extensionType) IsValidInterface(iv interface{}) bool {
|
||||
return typeIsValid(xt.desc, pref.ValueOf(iv)) == nil
|
||||
}
|
||||
|
||||
func (xt extensionType) IsValidValue(v pref.Value) bool {
|
||||
return typeIsValid(xt.desc, v) == nil
|
||||
}
|
||||
|
||||
type extensionTypeDescriptor struct {
|
||||
pref.ExtensionDescriptor
|
||||
}
|
||||
|
||||
func (xt extensionTypeDescriptor) Type() pref.ExtensionType {
|
||||
return extensionType{xt}
|
||||
}
|
||||
|
||||
func (xt extensionTypeDescriptor) Descriptor() pref.ExtensionDescriptor {
|
||||
return xt.ExtensionDescriptor
|
||||
}
|
||||
|
@ -8,6 +8,8 @@ import (
|
||||
"testing"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
pref "google.golang.org/protobuf/reflect/protoreflect"
|
||||
preg "google.golang.org/protobuf/reflect/protoregistry"
|
||||
"google.golang.org/protobuf/testing/prototest"
|
||||
"google.golang.org/protobuf/types/dynamicpb"
|
||||
|
||||
@ -24,3 +26,37 @@ func TestConformance(t *testing.T) {
|
||||
prototest.TestMessage(t, dynamicpb.New(message.ProtoReflect().Descriptor()), prototest.MessageOptions{})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDynamicExtensions(t *testing.T) {
|
||||
file, err := preg.GlobalFiles.FindFileByPath("test/ext.proto")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
md := (&testpb.TestAllExtensions{}).ProtoReflect().Descriptor()
|
||||
opts := prototest.MessageOptions{
|
||||
Resolver: extResolver{},
|
||||
}
|
||||
for i := 0; i < file.Extensions().Len(); i++ {
|
||||
opts.ExtensionTypes = append(opts.ExtensionTypes, dynamicpb.NewExtensionType(file.Extensions().Get(i)))
|
||||
}
|
||||
prototest.TestMessage(t, dynamicpb.New(md), opts)
|
||||
}
|
||||
|
||||
type extResolver struct{}
|
||||
|
||||
func (extResolver) FindExtensionByName(field pref.FullName) (pref.ExtensionType, error) {
|
||||
xt, err := preg.GlobalTypes.FindExtensionByName(field)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dynamicpb.NewExtensionType(xt.TypeDescriptor().Descriptor()), nil
|
||||
}
|
||||
|
||||
func (extResolver) FindExtensionByNumber(message pref.FullName, field pref.FieldNumber) (pref.ExtensionType, error) {
|
||||
xt, err := preg.GlobalTypes.FindExtensionByNumber(message, field)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dynamicpb.NewExtensionType(xt.TypeDescriptor().Descriptor()), nil
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user