internal/fileinit: prevent map entry descriptors from implementing MessageType

The protobuf type system hacks the representation of map entries into that
of a pseudo-message descriptor.

Previously, we made all message descriptors implement MessageType
where type descriptors had a GoType method that simply returned nil.
Unfortunately, this violates a nice property in the Go type system
where being able to assert to a MessageType guarantees that Go type
information is truly associated with that descriptor.

This CL makes it such that message descriptors for map entries
do not implement MessageType.

Change-Id: I23873cb71fe0ab3c0befd8052830ea6e53c97ca9
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/168399
Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
Joe Tsai 2019-03-19 17:04:06 -07:00 committed by Joe Tsai
parent 300cff08c7
commit 4532dd7969
6 changed files with 87 additions and 40 deletions

View File

@ -227,11 +227,19 @@ var fileinitDescListTemplate = template.Must(template.New("").Funcs(template.Fun
return len(p.list)
}
func (p *{{$nameList}}) Get(i int) {{.Expr}} {
{{- if (eq . "Message")}}
return p.list[i].asDesc()
{{- else}}
return &p.list[i]
{{- end}}
}
func (p *{{$nameList}}) ByName(s protoreflect.Name) {{.Expr}} {
if d := p.lazyInit().byName[s]; d != nil {
{{- if (eq . "Message")}}
return d.asDesc()
{{- else}}
return d
{{- end}}
}
return nil
}

View File

@ -108,7 +108,8 @@ type FileBuilder struct {
// in "flattened ordering".
EnumOutputTypes []pref.EnumType
// MessageOutputTypes is where Init stores all initialized message types
// in "flattened ordering"; this includes map entry types.
// in "flattened ordering". This includes slots for map entry messages,
// which are skipped over.
MessageOutputTypes []pref.MessageType
// ExtensionOutputTypes is where Init stores all initialized extension types
// in "flattened ordering".
@ -141,7 +142,9 @@ func (fb FileBuilder) Init() pref.FileDescriptor {
fb.EnumOutputTypes[i] = &fd.allEnums[i]
}
for i := range fd.allMessages {
fb.MessageOutputTypes[i] = &fd.allMessages[i]
if mt, _ := fd.allMessages[i].asDesc().(pref.MessageType); mt != nil {
fb.MessageOutputTypes[i] = mt
}
}
for i := range fd.allExtensions {
fb.ExtensionOutputTypes[i] = &fd.allExtensions[i]
@ -160,8 +163,10 @@ func (fb FileBuilder) Init() pref.FileDescriptor {
}
}
for i := range fd.allMessages {
if err := fb.TypesRegistry.Register(&fd.allMessages[i]); err != nil {
panic(err)
if mt, _ := fd.allMessages[i].asDesc().(pref.MessageType); mt != nil {
if err := fb.TypesRegistry.Register(mt); err != nil {
panic(err)
}
}
}
for i := range fd.allExtensions {
@ -278,6 +283,11 @@ func (ed *enumValueDesc) Format(s fmt.State, r rune) { pfmt.FormatDesc(s
func (ed *enumValueDesc) ProtoType(pref.EnumValueDescriptor) {}
type (
messageType struct{ *messageDesc }
messageDescriptor struct{ *messageDesc }
// messageDesc does not implement protoreflect.Descriptor to avoid
// accidental usages of it as such. Use the asDesc method to retrieve one.
messageDesc struct {
baseDesc
@ -285,13 +295,13 @@ type (
messages messageDescs
extensions extensionDescs
lazy *messageLazy // protected by fileDesc.once
isMapEntry bool
lazy *messageLazy // protected by fileDesc.once
}
messageLazy struct {
typ reflect.Type
new func() pref.Message
isMapEntry bool
isMessageSet bool
fields fieldDescs
oneofs oneofDescs
@ -328,12 +338,10 @@ type (
}
)
func (md *messageDesc) GoType() reflect.Type { return md.lazyInit().typ }
func (md *messageDesc) New() pref.Message { return md.lazyInit().new() }
func (md *messageDesc) Options() pref.OptionsMessage {
func (md *messageDesc) options() pref.OptionsMessage {
return unmarshalOptions(ptype.X.MessageOptions(), md.lazyInit().options)
}
func (md *messageDesc) IsMapEntry() bool { return md.lazyInit().isMapEntry }
func (md *messageDesc) IsMapEntry() bool { return md.isMapEntry }
func (md *messageDesc) Fields() pref.FieldDescriptors { return &md.lazyInit().fields }
func (md *messageDesc) Oneofs() pref.OneofDescriptors { return &md.lazyInit().oneofs }
func (md *messageDesc) ReservedNames() pref.Names { return &md.lazyInit().resvNames }
@ -346,8 +354,8 @@ func (md *messageDesc) ExtensionRangeOptions(i int) pref.OptionsMessage {
func (md *messageDesc) Enums() pref.EnumDescriptors { return &md.enums }
func (md *messageDesc) Messages() pref.MessageDescriptors { return &md.messages }
func (md *messageDesc) Extensions() pref.ExtensionDescriptors { return &md.extensions }
func (md *messageDesc) Format(s fmt.State, r rune) { pfmt.FormatDesc(s, r, md) }
func (md *messageDesc) ProtoType(pref.MessageDescriptor) {}
func (md *messageDesc) Format(s fmt.State, r rune) { pfmt.FormatDesc(s, r, md.asDesc()) }
func (md *messageDesc) lazyInit() *messageLazy {
md.parentFile.lazyInit() // implicitly initializes messageLazy
return md.lazy
@ -359,6 +367,19 @@ func (md *messageDesc) IsMessageSet() bool {
return md.lazyInit().isMessageSet
}
// asDesc returns a protoreflect.MessageDescriptor or protoreflect.MessageType
// depending on whether the message is a map entry or not.
func (mb *messageDesc) asDesc() pref.MessageDescriptor {
if !mb.isMapEntry {
return messageType{mb}
}
return messageDescriptor{mb}
}
func (mt messageType) GoType() reflect.Type { return mt.lazyInit().typ }
func (mt messageType) New() pref.Message { return mt.lazyInit().new() }
func (mt messageType) Options() pref.OptionsMessage { return mt.options() }
func (md messageDescriptor) Options() pref.OptionsMessage { return md.options() }
func (fd *fieldDesc) Options() pref.OptionsMessage {
return unmarshalOptions(ptype.X.FieldOptions(), fd.options)
}

View File

@ -19,6 +19,13 @@ func newFileDesc(fb FileBuilder) *fileDesc {
file.initDecls(len(fb.EnumOutputTypes), len(fb.MessageOutputTypes), len(fb.ExtensionOutputTypes))
file.unmarshalSeed(fb.RawDescriptor)
// Determine which message descriptors represent map entries based on the
// lack of an associated Go type.
messageDecls := file.GoTypes[len(file.allEnums):]
for i := range file.allMessages {
file.allMessages[i].isMapEntry = messageDecls[i] == nil
}
// Extended message dependencies are eagerly handled since registration
// needs this information at program init time.
for i := range file.allExtensions {
@ -31,7 +38,7 @@ func newFileDesc(fb FileBuilder) *fileDesc {
}
// initDecls pre-allocates slices for the exact number of enums, messages
// (excluding map entries), and extensions declared in the proto file.
// (including map entries), and extensions declared in the proto file.
// This is done to avoid regrowing the slice, which would change the address
// for any previously seen declaration.
//
@ -279,7 +286,7 @@ func (md *messageDesc) unmarshalSeed(b []byte, nb *nameBuilder, pf *fileDesc, pd
for i := range md.enums.list {
_, n := wire.ConsumeVarint(b)
v, m := wire.ConsumeBytes(b[n:])
md.enums.list[i].unmarshalSeed(v, nb, pf, md, i)
md.enums.list[i].unmarshalSeed(v, nb, pf, md.asDesc(), i)
b = b[n+m:]
}
}
@ -288,7 +295,7 @@ func (md *messageDesc) unmarshalSeed(b []byte, nb *nameBuilder, pf *fileDesc, pd
for i := range md.messages.list {
_, n := wire.ConsumeVarint(b)
v, m := wire.ConsumeBytes(b[n:])
md.messages.list[i].unmarshalSeed(v, nb, pf, md, i)
md.messages.list[i].unmarshalSeed(v, nb, pf, md.asDesc(), i)
b = b[n+m:]
}
}
@ -297,7 +304,7 @@ func (md *messageDesc) unmarshalSeed(b []byte, nb *nameBuilder, pf *fileDesc, pd
for i := range md.extensions.list {
_, n := wire.ConsumeVarint(b)
v, m := wire.ConsumeBytes(b[n:])
md.extensions.list[i].unmarshalSeed(v, nb, pf, md, i)
md.extensions.list[i].unmarshalSeed(v, nb, pf, md.asDesc(), i)
b = b[n+m:]
}
}

View File

@ -64,15 +64,12 @@ func (file *fileDesc) resolveMessages() {
md := &file.allMessages[i]
// Associate the MessageType with a concrete Go type.
//
// Note that descriptors for map entries, which have no associated
// Go type, also implement the protoreflect.MessageType interface,
// but have a GoType accessor that reports nil. Calling New results
// in a panic, which is sensible behavior.
md.lazy.typ = reflect.TypeOf(messageDecls[i])
md.lazy.new = func() pref.Message {
t := md.lazy.typ.Elem()
return reflect.New(t).Interface().(pref.ProtoMessage).ProtoReflect()
if !md.isMapEntry {
md.lazy.typ = reflect.TypeOf(messageDecls[i])
md.lazy.new = func() pref.Message {
t := md.lazy.typ.Elem()
return reflect.New(t).Interface().(pref.ProtoMessage).ProtoReflect()
}
}
// Resolve message field dependencies.
@ -173,9 +170,9 @@ func (file *fileDesc) resolveExtensions() {
// Resolve extension field dependency.
switch xd.lazy.kind {
case pref.EnumKind:
xd.lazy.enumType = file.popEnumDependency()
xd.lazy.enumType = file.popEnumDependency().(pref.EnumType)
case pref.MessageKind, pref.GroupKind:
xd.lazy.messageType = file.popMessageDependency()
xd.lazy.messageType = file.popMessageDependency().(pref.MessageType)
}
xd.lazy.defVal.lazyInit(xd.lazy.kind, file.enumValuesOf(xd.lazy.enumType))
}
@ -219,8 +216,8 @@ func (fd *fileDesc) isMapEntry(md pref.MessageDescriptor) bool {
if md == nil {
return false
}
if md, ok := md.(*messageDesc); ok && md.parentFile == fd {
return md.lazy.isMapEntry
if md, ok := md.(*messageDescriptor); ok && md.parentFile == fd {
return md.isMapEntry
}
return md.IsMapEntry()
}
@ -238,7 +235,7 @@ func (fd *fileDesc) enumValuesOf(ed pref.EnumDescriptor) pref.EnumValueDescripto
return ed.Values()
}
func (fd *fileDesc) popEnumDependency() pref.EnumType {
func (fd *fileDesc) popEnumDependency() pref.EnumDescriptor {
depIdx := fd.popDependencyIndex()
if depIdx < len(fd.allEnums)+len(fd.allMessages) {
return &fd.allEnums[depIdx]
@ -247,10 +244,10 @@ func (fd *fileDesc) popEnumDependency() pref.EnumType {
}
}
func (fd *fileDesc) popMessageDependency() pref.MessageType {
func (fd *fileDesc) popMessageDependency() pref.MessageDescriptor {
depIdx := fd.popDependencyIndex()
if depIdx < len(fd.allEnums)+len(fd.allMessages) {
return &fd.allMessages[depIdx-len(fd.allEnums)]
return fd.allMessages[depIdx-len(fd.allEnums)].asDesc()
} else {
return pimpl.Export{}.MessageTypeOf(fd.GoTypes[depIdx])
}
@ -490,6 +487,7 @@ func (vd *enumValueDesc) unmarshalFull(b []byte, nb *nameBuilder, pf *fileDesc,
func (md *messageDesc) unmarshalFull(b []byte, nb *nameBuilder) {
var rawFields, rawOneofs [][]byte
var enumIdx, messageIdx, extensionIdx int
var isMapEntry bool
md.lazy = new(messageLazy)
for len(b) > 0 {
num, typ, n := wire.ConsumeTag(b)
@ -521,7 +519,7 @@ func (md *messageDesc) unmarshalFull(b []byte, nb *nameBuilder) {
md.extensions.list[extensionIdx].unmarshalFull(v, nb)
extensionIdx++
case fieldnum.DescriptorProto_Options:
md.unmarshalOptions(v)
md.unmarshalOptions(v, &isMapEntry)
}
default:
m := wire.ConsumeFieldValue(num, typ, b)
@ -534,21 +532,25 @@ func (md *messageDesc) unmarshalFull(b []byte, nb *nameBuilder) {
md.lazy.oneofs.list = make([]oneofDesc, len(rawOneofs))
for i, b := range rawFields {
fd := &md.lazy.fields.list[i]
fd.unmarshalFull(b, nb, md.parentFile, md, i)
fd.unmarshalFull(b, nb, md.parentFile, md.asDesc(), i)
if fd.cardinality == pref.Required {
md.lazy.reqNumbers.list = append(md.lazy.reqNumbers.list, fd.number)
}
}
for i, b := range rawOneofs {
od := &md.lazy.oneofs.list[i]
od.unmarshalFull(b, nb, md.parentFile, md, i)
od.unmarshalFull(b, nb, md.parentFile, md.asDesc(), i)
}
}
md.parentFile.lazy.byName[md.FullName()] = md
if isMapEntry != md.isMapEntry {
panic("mismatching map entry property")
}
md.parentFile.lazy.byName[md.FullName()] = md.asDesc()
}
func (md *messageDesc) unmarshalOptions(b []byte) {
func (md *messageDesc) unmarshalOptions(b []byte, isMapEntry *bool) {
md.lazy.options = append(md.lazy.options, b...)
for len(b) > 0 {
num, typ, n := wire.ConsumeTag(b)
@ -559,7 +561,7 @@ func (md *messageDesc) unmarshalOptions(b []byte) {
b = b[m:]
switch num {
case fieldnum.MessageOptions_MapEntry:
md.lazy.isMapEntry = wire.DecodeBool(v)
*isMapEntry = wire.DecodeBool(v)
case fieldnum.MessageOptions_MessageSetWireFormat:
md.lazy.isMessageSet = wire.DecodeBool(v)
}
@ -646,7 +648,7 @@ func (fd *fieldDesc) unmarshalFull(b []byte, nb *nameBuilder, pf *fileDesc, pd p
// In messageDesc.UnmarshalFull, we allocate slices for both
// the field and oneof descriptors before unmarshaling either
// of them. This ensures pointers to slice elements are stable.
od := &pd.(*messageDesc).lazy.oneofs.list[v]
od := &pd.(messageType).lazy.oneofs.list[v]
od.fields.list = append(od.fields.list, fd)
if fd.oneofType != nil {
panic("oneof type already set")

View File

@ -110,11 +110,11 @@ func (p *messageDescs) Len() int {
return len(p.list)
}
func (p *messageDescs) Get(i int) protoreflect.MessageDescriptor {
return &p.list[i]
return p.list[i].asDesc()
}
func (p *messageDescs) ByName(s protoreflect.Name) protoreflect.MessageDescriptor {
if d := p.lazyInit().byName[s]; d != nil {
return d
return d.asDesc()
}
return nil
}

View File

@ -68,6 +68,15 @@ func TestInit(t *testing.T) {
}
}
// Verify that message descriptors for map entries have no Go type info.
mapEntryName := protoreflect.FullName("goproto.proto.test.TestAllTypes.MapInt32Int32Entry")
d := testpb.File_test_test_proto.DescriptorByName(mapEntryName)
if _, ok := d.(protoreflect.MessageDescriptor); !ok {
t.Errorf("message descriptor for %v not found", mapEntryName)
}
if _, ok := d.(protoreflect.MessageType); ok {
t.Errorf("message descriptor for %v must not implement protoreflect.MessageType", mapEntryName)
}
}
// visitFields calls f for every field set in m and its children.