reflect/protoreflect: add MessageFieldTypes

The MessageFieldTypes interface (if implemented by a MessageType)
provides Go type information about the fields if they are
an enum or message type.

Change-Id: I68b20f5726377f6b0f2c20a8b6e45f9802b43f67
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/236777
Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
Joe Tsai 2020-06-05 15:31:25 -07:00
parent b5523e32b3
commit 1a290e9a0e
9 changed files with 261 additions and 8 deletions

View File

@ -167,7 +167,7 @@ func (Export) MessageTypeOf(m message) pref.MessageType {
if mv := (Export{}).protoMessageV2Of(m); mv != nil {
return mv.ProtoReflect().Type()
}
return legacyLoadMessageInfo(reflect.TypeOf(m), "")
return legacyLoadMessageType(reflect.TypeOf(m), "")
}
// MessageStringOf returns the message value as a string,

View File

@ -30,7 +30,7 @@ func (Export) LegacyMessageTypeOf(m piface.MessageV1, name pref.FullName) pref.M
if mv := (Export{}).protoMessageV2Of(m); mv != nil {
return mv.ProtoReflect().Type()
}
return legacyLoadMessageInfo(reflect.TypeOf(m), name)
return legacyLoadMessageType(reflect.TypeOf(m), name)
}
// UnmarshalJSONEnum unmarshals an enum from a JSON-encoded input.

View File

@ -32,6 +32,16 @@ func legacyWrapMessage(v reflect.Value) pref.Message {
return mt.MessageOf(v.Interface())
}
// legacyLoadMessageType dynamically loads a protoreflect.Type for t,
// where t must be not implement the v2 API already.
// The provided name is used if it cannot be determined from the message.
func legacyLoadMessageType(t reflect.Type, name pref.FullName) protoreflect.MessageType {
if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Struct {
return aberrantMessageType{t}
}
return legacyLoadMessageInfo(t, name)
}
var legacyMessageTypeCache sync.Map // map[reflect.Type]*MessageInfo
// legacyLoadMessageInfo dynamically loads a *MessageInfo for t,

View File

@ -15,6 +15,7 @@ import (
"google.golang.org/protobuf/internal/genid"
"google.golang.org/protobuf/reflect/protoreflect"
pref "google.golang.org/protobuf/reflect/protoreflect"
preg "google.golang.org/protobuf/reflect/protoregistry"
)
// MessageInfo provides protobuf related functionality for a given Go type
@ -212,4 +213,53 @@ func (mi *MessageInfo) New() protoreflect.Message {
func (mi *MessageInfo) Zero() protoreflect.Message {
return mi.MessageOf(reflect.Zero(mi.GoReflectType).Interface())
}
func (mi *MessageInfo) Descriptor() protoreflect.MessageDescriptor { return mi.Desc }
func (mi *MessageInfo) Descriptor() protoreflect.MessageDescriptor {
return mi.Desc
}
func (mi *MessageInfo) Enum(i int) protoreflect.EnumType {
mi.init()
fd := mi.Desc.Fields().Get(i)
return Export{}.EnumTypeOf(mi.fieldTypes[fd.Number()])
}
func (mi *MessageInfo) Message(i int) protoreflect.MessageType {
mi.init()
fd := mi.Desc.Fields().Get(i)
switch {
case fd.IsWeak():
mt, _ := preg.GlobalTypes.FindMessageByName(fd.Message().FullName())
return mt
case fd.IsMap():
return mapEntryType{fd.Message(), mi.fieldTypes[fd.Number()]}
default:
return Export{}.MessageTypeOf(mi.fieldTypes[fd.Number()])
}
}
type mapEntryType struct {
desc protoreflect.MessageDescriptor
valType interface{} // zero value of enum or message type
}
func (mt mapEntryType) New() protoreflect.Message {
return nil
}
func (mt mapEntryType) Zero() protoreflect.Message {
return nil
}
func (mt mapEntryType) Descriptor() protoreflect.MessageDescriptor {
return mt.desc
}
func (mt mapEntryType) Enum(i int) protoreflect.EnumType {
fd := mt.desc.Fields().Get(i)
if fd.Enum() == nil {
return nil
}
return Export{}.EnumTypeOf(mt.valType)
}
func (mt mapEntryType) Message(i int) protoreflect.MessageType {
fd := mt.desc.Fields().Get(i)
if fd.Message() == nil {
return nil
}
return Export{}.MessageTypeOf(mt.valType)
}

View File

@ -17,6 +17,11 @@ type reflectMessageInfo struct {
fields map[pref.FieldNumber]*fieldInfo
oneofs map[pref.Name]*oneofInfo
// fieldTypes contains the zero value of an enum or message field.
// For lists, it contains the element type.
// For maps, it contains the entry value type.
fieldTypes map[pref.FieldNumber]interface{}
// denseFields is a subset of fields where:
// 0 < fieldDesc.Number() < len(denseFields)
// It provides faster access to the fieldInfo, but may be incomplete.
@ -37,6 +42,7 @@ func (mi *MessageInfo) makeReflectFuncs(t reflect.Type, si structInfo) {
mi.makeKnownFieldsFunc(si)
mi.makeUnknownFieldsFunc(t, si)
mi.makeExtensionFieldsFunc(t, si)
mi.makeFieldTypes(si)
}
// makeKnownFieldsFunc generates functions for operations that can be performed
@ -62,7 +68,7 @@ func (mi *MessageInfo) makeKnownFieldsFunc(si structInfo) {
fi = fieldInfoForList(fd, fs, mi.Exporter)
case fd.IsWeak():
fi = fieldInfoForWeakMessage(fd, si.weakOffset)
case fd.Kind() == pref.MessageKind || fd.Kind() == pref.GroupKind:
case fd.Message() != nil:
fi = fieldInfoForMessage(fd, fs, mi.Exporter)
default:
fi = fieldInfoForScalar(fd, fs, mi.Exporter)
@ -146,6 +152,45 @@ func (mi *MessageInfo) makeExtensionFieldsFunc(t reflect.Type, si structInfo) {
}
}
}
func (mi *MessageInfo) makeFieldTypes(si structInfo) {
md := mi.Desc
fds := md.Fields()
for i := 0; i < fds.Len(); i++ {
var ft reflect.Type
fd := fds.Get(i)
fs := si.fieldsByNumber[fd.Number()]
switch {
case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic():
if fd.Enum() != nil || fd.Message() != nil {
ft = si.oneofWrappersByNumber[fd.Number()].Field(0).Type
}
case fd.IsMap():
if fd.MapValue().Enum() != nil || fd.MapValue().Message() != nil {
ft = fs.Type.Elem()
}
case fd.IsList():
if fd.Enum() != nil || fd.Message() != nil {
ft = fs.Type.Elem()
}
case fd.Enum() != nil:
ft = fs.Type
if fd.HasPresence() {
ft = ft.Elem()
}
case fd.Message() != nil:
ft = fs.Type
if fd.IsWeak() {
ft = nil
}
}
if ft != nil {
if mi.fieldTypes == nil {
mi.fieldTypes = make(map[pref.FieldNumber]interface{})
}
mi.fieldTypes[fd.Number()] = reflect.Zero(ft).Interface()
}
}
}
type extensionMap map[int32]ExtensionField
@ -313,7 +358,6 @@ var (
// pointer to a named Go struct. If the provided type has a ProtoReflect method,
// it must be implemented by calling this method.
func (mi *MessageInfo) MessageOf(m interface{}) pref.Message {
// TODO: Switch the input to be an opaque Pointer.
if reflect.TypeOf(m) != mi.GoReflectType {
panic(fmt.Sprintf("type mismatch: got %T, want %v", m, mi.GoReflectType))
}

View File

@ -22,10 +22,15 @@ func (m *IrregularMessage) ProtoReflect() pref.Message { return (*message)(m) }
type message IrregularMessage
func (m *message) Descriptor() pref.MessageDescriptor { return fileDesc.Messages().Get(0) }
func (m *message) Type() pref.MessageType { return m }
type messageType struct{}
func (messageType) New() pref.Message { return &message{} }
func (messageType) Zero() pref.Message { return (*message)(nil) }
func (messageType) Descriptor() pref.MessageDescriptor { return fileDesc.Messages().Get(0) }
func (m *message) New() pref.Message { return &message{} }
func (m *message) Zero() pref.Message { return (*message)(nil) }
func (m *message) Descriptor() pref.MessageDescriptor { return fileDesc.Messages().Get(0) }
func (m *message) Type() pref.MessageType { return messageType{} }
func (m *message) Interface() pref.ProtoMessage { return (*IrregularMessage)(m) }
func (m *message) ProtoMethods() *protoiface.Methods { return nil }

View File

@ -232,11 +232,15 @@ type MessageDescriptor interface {
type isMessageDescriptor interface{ ProtoType(MessageDescriptor) }
// MessageType encapsulates a MessageDescriptor with a concrete Go implementation.
// It is recommended that implementations of this interface also implement the
// MessageFieldTypes interface.
type MessageType interface {
// New returns a newly allocated empty message.
// It may return nil for synthetic messages representing a map entry.
New() Message
// Zero returns an empty, read-only message.
// It may return nil for synthetic messages representing a map entry.
Zero() Message
// Descriptor returns the message descriptor.
@ -245,6 +249,26 @@ type MessageType interface {
Descriptor() MessageDescriptor
}
// MessageFieldTypes extends a MessageType by providing type information
// regarding enums and messages referenced by the message fields.
type MessageFieldTypes interface {
MessageType
// Enum returns the EnumType for the ith field in Descriptor.Fields.
// It returns nil if the ith field is not an enum kind.
// It panics if out of bounds.
//
// Invariant: mt.Enum(i).Descriptor() == mt.Descriptor().Fields(i).Enum()
Enum(i int) EnumType
// Message returns the MessageType for the ith field in Descriptor.Fields.
// It returns nil if the ith field is not a message or group kind.
// It panics if out of bounds.
//
// Invariant: mt.Message(i).Descriptor() == mt.Descriptor().Fields(i).Message()
Message(i int) MessageType
}
// MessageDescriptors is a list of message declarations.
type MessageDescriptors interface {
// Len reports the number of messages.

View File

@ -11,11 +11,13 @@ import (
"math"
"reflect"
"sort"
"strings"
"testing"
"google.golang.org/protobuf/encoding/prototext"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
pref "google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
)
@ -96,6 +98,112 @@ func testType(t testing.TB, mt pref.MessageType) {
if got := reflect.TypeOf(m.ProtoReflect().Type().Zero().Interface()); got != want {
t.Errorf("type mismatch: reflect.TypeOf(m) != reflect.TypeOf(m.ProtoReflect().Type().Zero().Interface()): %v != %v", got, want)
}
if mt, ok := mt.(pref.MessageFieldTypes); ok {
testFieldTypes(t, mt)
}
}
func testFieldTypes(t testing.TB, mt pref.MessageFieldTypes) {
descName := func(d pref.Descriptor) pref.FullName {
if d == nil {
return "<nil>"
}
return d.FullName()
}
typeName := func(mt pref.MessageType) pref.FullName {
if mt == nil {
return "<nil>"
}
return mt.Descriptor().FullName()
}
adjustExpr := func(idx int, expr string) string {
expr = strings.Replace(expr, "fd.", "md.Fields().Get(i).", -1)
expr = strings.Replace(expr, "(fd)", "(md.Fields().Get(i))", -1)
expr = strings.Replace(expr, "mti.", "mt.Message(i).", -1)
expr = strings.Replace(expr, "(i)", fmt.Sprintf("(%d)", idx), -1)
return expr
}
checkEnumDesc := func(idx int, gotExpr, wantExpr string, got, want protoreflect.EnumDescriptor) {
if got != want {
t.Errorf("descriptor mismatch: %v != %v: %v != %v", adjustExpr(idx, gotExpr), adjustExpr(idx, wantExpr), descName(got), descName(want))
}
}
checkMessageDesc := func(idx int, gotExpr, wantExpr string, got, want protoreflect.MessageDescriptor) {
if got != want {
t.Errorf("descriptor mismatch: %v != %v: %v != %v", adjustExpr(idx, gotExpr), adjustExpr(idx, wantExpr), descName(got), descName(want))
}
}
checkMessageType := func(idx int, gotExpr, wantExpr string, got, want protoreflect.MessageType) {
if got != want {
t.Errorf("type mismatch: %v != %v: %v != %v", adjustExpr(idx, gotExpr), adjustExpr(idx, wantExpr), typeName(got), typeName(want))
}
}
fds := mt.Descriptor().Fields()
m := mt.New()
for i := 0; i < fds.Len(); i++ {
fd := fds.Get(i)
switch {
case fd.IsList():
if fd.Enum() != nil {
checkEnumDesc(i,
"mt.Enum(i).Descriptor()", "fd.Enum()",
mt.Enum(i).Descriptor(), fd.Enum())
}
if fd.Message() != nil {
checkMessageDesc(i,
"mt.Message(i).Descriptor()", "fd.Message()",
mt.Message(i).Descriptor(), fd.Message())
checkMessageType(i,
"mt.Message(i)", "m.NewField(fd).List().NewElement().Message().Type()",
mt.Message(i), m.NewField(fd).List().NewElement().Message().Type())
}
case fd.IsMap():
mti := mt.Message(i)
if m := mti.New(); m != nil {
checkMessageDesc(i,
"m.Descriptor()", "fd.Message()",
m.Descriptor(), fd.Message())
}
if m := mti.Zero(); m != nil {
checkMessageDesc(i,
"m.Descriptor()", "fd.Message()",
m.Descriptor(), fd.Message())
}
checkMessageDesc(i,
"mti.Descriptor()", "fd.Message()",
mti.Descriptor(), fd.Message())
if mti := mti.(pref.MessageFieldTypes); mti != nil {
if fd.MapValue().Enum() != nil {
checkEnumDesc(i,
"mti.Enum(fd.MapValue().Index()).Descriptor()", "fd.MapValue().Enum()",
mti.Enum(fd.MapValue().Index()).Descriptor(), fd.MapValue().Enum())
}
if fd.MapValue().Message() != nil {
checkMessageDesc(i,
"mti.Message(fd.MapValue().Index()).Descriptor()", "fd.MapValue().Message()",
mti.Message(fd.MapValue().Index()).Descriptor(), fd.MapValue().Message())
checkMessageType(i,
"mti.Message(fd.MapValue().Index())", "m.NewField(fd).Map().NewValue().Message().Type()",
mti.Message(fd.MapValue().Index()), m.NewField(fd).Map().NewValue().Message().Type())
}
}
default:
if fd.Enum() != nil {
checkEnumDesc(i,
"mt.Enum(i).Descriptor()", "fd.Enum()",
mt.Enum(i).Descriptor(), fd.Enum())
}
if fd.Message() != nil {
checkMessageDesc(i,
"mt.Message(i).Descriptor()", "fd.Message()",
mt.Message(i).Descriptor(), fd.Message())
checkMessageType(i,
"mt.Message(i)", "m.NewField(fd).Message().Type()",
mt.Message(i), m.NewField(fd).Message().Type())
}
}
}
}
// testField exercises set/get/has/clear of a field.

View File

@ -369,6 +369,18 @@ func NewMessageType(desc pref.MessageDescriptor) pref.MessageType {
func (mt messageType) New() pref.Message { return NewMessage(mt.desc) }
func (mt messageType) Zero() pref.Message { return &Message{typ: messageType{mt.desc}} }
func (mt messageType) Descriptor() pref.MessageDescriptor { return mt.desc }
func (mt messageType) Enum(i int) pref.EnumType {
if ed := mt.desc.Fields().Get(i).Enum(); ed != nil {
return NewEnumType(ed)
}
return nil
}
func (mt messageType) Message(i int) pref.MessageType {
if md := mt.desc.Fields().Get(i).Message(); md != nil {
return NewMessageType(md)
}
return nil
}
type emptyList struct {
desc pref.FieldDescriptor