internal/fileinit: prevent map entry descriptors from implementing MessageType

The protobuf type system hacks the representation of map entries into that
of a pseudo-message descriptor.

Previously, we made all message descriptors implement MessageType
where type descriptors had a GoType method that simply returned nil.
Unfortunately, this violates a nice property in the Go type system
where being able to assert to a MessageType guarantees that Go type
information is truly associated with that descriptor.

This CL makes it such that message descriptors for map entries
do not implement MessageType.

Change-Id: I23873cb71fe0ab3c0befd8052830ea6e53c97ca9
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/168399
Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
Joe Tsai 2019-03-19 17:04:06 -07:00 committed by Joe Tsai
parent 300cff08c7
commit 4532dd7969
6 changed files with 87 additions and 40 deletions

View File

@ -227,11 +227,19 @@ var fileinitDescListTemplate = template.Must(template.New("").Funcs(template.Fun
return len(p.list) return len(p.list)
} }
func (p *{{$nameList}}) Get(i int) {{.Expr}} { func (p *{{$nameList}}) Get(i int) {{.Expr}} {
{{- if (eq . "Message")}}
return p.list[i].asDesc()
{{- else}}
return &p.list[i] return &p.list[i]
{{- end}}
} }
func (p *{{$nameList}}) ByName(s protoreflect.Name) {{.Expr}} { func (p *{{$nameList}}) ByName(s protoreflect.Name) {{.Expr}} {
if d := p.lazyInit().byName[s]; d != nil { if d := p.lazyInit().byName[s]; d != nil {
{{- if (eq . "Message")}}
return d.asDesc()
{{- else}}
return d return d
{{- end}}
} }
return nil return nil
} }

View File

@ -108,7 +108,8 @@ type FileBuilder struct {
// in "flattened ordering". // in "flattened ordering".
EnumOutputTypes []pref.EnumType EnumOutputTypes []pref.EnumType
// MessageOutputTypes is where Init stores all initialized message types // MessageOutputTypes is where Init stores all initialized message types
// in "flattened ordering"; this includes map entry types. // in "flattened ordering". This includes slots for map entry messages,
// which are skipped over.
MessageOutputTypes []pref.MessageType MessageOutputTypes []pref.MessageType
// ExtensionOutputTypes is where Init stores all initialized extension types // ExtensionOutputTypes is where Init stores all initialized extension types
// in "flattened ordering". // in "flattened ordering".
@ -141,7 +142,9 @@ func (fb FileBuilder) Init() pref.FileDescriptor {
fb.EnumOutputTypes[i] = &fd.allEnums[i] fb.EnumOutputTypes[i] = &fd.allEnums[i]
} }
for i := range fd.allMessages { for i := range fd.allMessages {
fb.MessageOutputTypes[i] = &fd.allMessages[i] if mt, _ := fd.allMessages[i].asDesc().(pref.MessageType); mt != nil {
fb.MessageOutputTypes[i] = mt
}
} }
for i := range fd.allExtensions { for i := range fd.allExtensions {
fb.ExtensionOutputTypes[i] = &fd.allExtensions[i] fb.ExtensionOutputTypes[i] = &fd.allExtensions[i]
@ -160,8 +163,10 @@ func (fb FileBuilder) Init() pref.FileDescriptor {
} }
} }
for i := range fd.allMessages { for i := range fd.allMessages {
if err := fb.TypesRegistry.Register(&fd.allMessages[i]); err != nil { if mt, _ := fd.allMessages[i].asDesc().(pref.MessageType); mt != nil {
panic(err) if err := fb.TypesRegistry.Register(mt); err != nil {
panic(err)
}
} }
} }
for i := range fd.allExtensions { for i := range fd.allExtensions {
@ -278,6 +283,11 @@ func (ed *enumValueDesc) Format(s fmt.State, r rune) { pfmt.FormatDesc(s
func (ed *enumValueDesc) ProtoType(pref.EnumValueDescriptor) {} func (ed *enumValueDesc) ProtoType(pref.EnumValueDescriptor) {}
type ( type (
messageType struct{ *messageDesc }
messageDescriptor struct{ *messageDesc }
// messageDesc does not implement protoreflect.Descriptor to avoid
// accidental usages of it as such. Use the asDesc method to retrieve one.
messageDesc struct { messageDesc struct {
baseDesc baseDesc
@ -285,13 +295,13 @@ type (
messages messageDescs messages messageDescs
extensions extensionDescs extensions extensionDescs
lazy *messageLazy // protected by fileDesc.once isMapEntry bool
lazy *messageLazy // protected by fileDesc.once
} }
messageLazy struct { messageLazy struct {
typ reflect.Type typ reflect.Type
new func() pref.Message new func() pref.Message
isMapEntry bool
isMessageSet bool isMessageSet bool
fields fieldDescs fields fieldDescs
oneofs oneofDescs oneofs oneofDescs
@ -328,12 +338,10 @@ type (
} }
) )
func (md *messageDesc) GoType() reflect.Type { return md.lazyInit().typ } func (md *messageDesc) options() pref.OptionsMessage {
func (md *messageDesc) New() pref.Message { return md.lazyInit().new() }
func (md *messageDesc) Options() pref.OptionsMessage {
return unmarshalOptions(ptype.X.MessageOptions(), md.lazyInit().options) return unmarshalOptions(ptype.X.MessageOptions(), md.lazyInit().options)
} }
func (md *messageDesc) IsMapEntry() bool { return md.lazyInit().isMapEntry } func (md *messageDesc) IsMapEntry() bool { return md.isMapEntry }
func (md *messageDesc) Fields() pref.FieldDescriptors { return &md.lazyInit().fields } func (md *messageDesc) Fields() pref.FieldDescriptors { return &md.lazyInit().fields }
func (md *messageDesc) Oneofs() pref.OneofDescriptors { return &md.lazyInit().oneofs } func (md *messageDesc) Oneofs() pref.OneofDescriptors { return &md.lazyInit().oneofs }
func (md *messageDesc) ReservedNames() pref.Names { return &md.lazyInit().resvNames } func (md *messageDesc) ReservedNames() pref.Names { return &md.lazyInit().resvNames }
@ -346,8 +354,8 @@ func (md *messageDesc) ExtensionRangeOptions(i int) pref.OptionsMessage {
func (md *messageDesc) Enums() pref.EnumDescriptors { return &md.enums } func (md *messageDesc) Enums() pref.EnumDescriptors { return &md.enums }
func (md *messageDesc) Messages() pref.MessageDescriptors { return &md.messages } func (md *messageDesc) Messages() pref.MessageDescriptors { return &md.messages }
func (md *messageDesc) Extensions() pref.ExtensionDescriptors { return &md.extensions } func (md *messageDesc) Extensions() pref.ExtensionDescriptors { return &md.extensions }
func (md *messageDesc) Format(s fmt.State, r rune) { pfmt.FormatDesc(s, r, md) }
func (md *messageDesc) ProtoType(pref.MessageDescriptor) {} func (md *messageDesc) ProtoType(pref.MessageDescriptor) {}
func (md *messageDesc) Format(s fmt.State, r rune) { pfmt.FormatDesc(s, r, md.asDesc()) }
func (md *messageDesc) lazyInit() *messageLazy { func (md *messageDesc) lazyInit() *messageLazy {
md.parentFile.lazyInit() // implicitly initializes messageLazy md.parentFile.lazyInit() // implicitly initializes messageLazy
return md.lazy return md.lazy
@ -359,6 +367,19 @@ func (md *messageDesc) IsMessageSet() bool {
return md.lazyInit().isMessageSet return md.lazyInit().isMessageSet
} }
// asDesc returns a protoreflect.MessageDescriptor or protoreflect.MessageType
// depending on whether the message is a map entry or not.
func (mb *messageDesc) asDesc() pref.MessageDescriptor {
if !mb.isMapEntry {
return messageType{mb}
}
return messageDescriptor{mb}
}
func (mt messageType) GoType() reflect.Type { return mt.lazyInit().typ }
func (mt messageType) New() pref.Message { return mt.lazyInit().new() }
func (mt messageType) Options() pref.OptionsMessage { return mt.options() }
func (md messageDescriptor) Options() pref.OptionsMessage { return md.options() }
func (fd *fieldDesc) Options() pref.OptionsMessage { func (fd *fieldDesc) Options() pref.OptionsMessage {
return unmarshalOptions(ptype.X.FieldOptions(), fd.options) return unmarshalOptions(ptype.X.FieldOptions(), fd.options)
} }

View File

@ -19,6 +19,13 @@ func newFileDesc(fb FileBuilder) *fileDesc {
file.initDecls(len(fb.EnumOutputTypes), len(fb.MessageOutputTypes), len(fb.ExtensionOutputTypes)) file.initDecls(len(fb.EnumOutputTypes), len(fb.MessageOutputTypes), len(fb.ExtensionOutputTypes))
file.unmarshalSeed(fb.RawDescriptor) file.unmarshalSeed(fb.RawDescriptor)
// Determine which message descriptors represent map entries based on the
// lack of an associated Go type.
messageDecls := file.GoTypes[len(file.allEnums):]
for i := range file.allMessages {
file.allMessages[i].isMapEntry = messageDecls[i] == nil
}
// Extended message dependencies are eagerly handled since registration // Extended message dependencies are eagerly handled since registration
// needs this information at program init time. // needs this information at program init time.
for i := range file.allExtensions { for i := range file.allExtensions {
@ -31,7 +38,7 @@ func newFileDesc(fb FileBuilder) *fileDesc {
} }
// initDecls pre-allocates slices for the exact number of enums, messages // initDecls pre-allocates slices for the exact number of enums, messages
// (excluding map entries), and extensions declared in the proto file. // (including map entries), and extensions declared in the proto file.
// This is done to avoid regrowing the slice, which would change the address // This is done to avoid regrowing the slice, which would change the address
// for any previously seen declaration. // for any previously seen declaration.
// //
@ -279,7 +286,7 @@ func (md *messageDesc) unmarshalSeed(b []byte, nb *nameBuilder, pf *fileDesc, pd
for i := range md.enums.list { for i := range md.enums.list {
_, n := wire.ConsumeVarint(b) _, n := wire.ConsumeVarint(b)
v, m := wire.ConsumeBytes(b[n:]) v, m := wire.ConsumeBytes(b[n:])
md.enums.list[i].unmarshalSeed(v, nb, pf, md, i) md.enums.list[i].unmarshalSeed(v, nb, pf, md.asDesc(), i)
b = b[n+m:] b = b[n+m:]
} }
} }
@ -288,7 +295,7 @@ func (md *messageDesc) unmarshalSeed(b []byte, nb *nameBuilder, pf *fileDesc, pd
for i := range md.messages.list { for i := range md.messages.list {
_, n := wire.ConsumeVarint(b) _, n := wire.ConsumeVarint(b)
v, m := wire.ConsumeBytes(b[n:]) v, m := wire.ConsumeBytes(b[n:])
md.messages.list[i].unmarshalSeed(v, nb, pf, md, i) md.messages.list[i].unmarshalSeed(v, nb, pf, md.asDesc(), i)
b = b[n+m:] b = b[n+m:]
} }
} }
@ -297,7 +304,7 @@ func (md *messageDesc) unmarshalSeed(b []byte, nb *nameBuilder, pf *fileDesc, pd
for i := range md.extensions.list { for i := range md.extensions.list {
_, n := wire.ConsumeVarint(b) _, n := wire.ConsumeVarint(b)
v, m := wire.ConsumeBytes(b[n:]) v, m := wire.ConsumeBytes(b[n:])
md.extensions.list[i].unmarshalSeed(v, nb, pf, md, i) md.extensions.list[i].unmarshalSeed(v, nb, pf, md.asDesc(), i)
b = b[n+m:] b = b[n+m:]
} }
} }

View File

@ -64,15 +64,12 @@ func (file *fileDesc) resolveMessages() {
md := &file.allMessages[i] md := &file.allMessages[i]
// Associate the MessageType with a concrete Go type. // Associate the MessageType with a concrete Go type.
// if !md.isMapEntry {
// Note that descriptors for map entries, which have no associated md.lazy.typ = reflect.TypeOf(messageDecls[i])
// Go type, also implement the protoreflect.MessageType interface, md.lazy.new = func() pref.Message {
// but have a GoType accessor that reports nil. Calling New results t := md.lazy.typ.Elem()
// in a panic, which is sensible behavior. return reflect.New(t).Interface().(pref.ProtoMessage).ProtoReflect()
md.lazy.typ = reflect.TypeOf(messageDecls[i]) }
md.lazy.new = func() pref.Message {
t := md.lazy.typ.Elem()
return reflect.New(t).Interface().(pref.ProtoMessage).ProtoReflect()
} }
// Resolve message field dependencies. // Resolve message field dependencies.
@ -173,9 +170,9 @@ func (file *fileDesc) resolveExtensions() {
// Resolve extension field dependency. // Resolve extension field dependency.
switch xd.lazy.kind { switch xd.lazy.kind {
case pref.EnumKind: case pref.EnumKind:
xd.lazy.enumType = file.popEnumDependency() xd.lazy.enumType = file.popEnumDependency().(pref.EnumType)
case pref.MessageKind, pref.GroupKind: case pref.MessageKind, pref.GroupKind:
xd.lazy.messageType = file.popMessageDependency() xd.lazy.messageType = file.popMessageDependency().(pref.MessageType)
} }
xd.lazy.defVal.lazyInit(xd.lazy.kind, file.enumValuesOf(xd.lazy.enumType)) xd.lazy.defVal.lazyInit(xd.lazy.kind, file.enumValuesOf(xd.lazy.enumType))
} }
@ -219,8 +216,8 @@ func (fd *fileDesc) isMapEntry(md pref.MessageDescriptor) bool {
if md == nil { if md == nil {
return false return false
} }
if md, ok := md.(*messageDesc); ok && md.parentFile == fd { if md, ok := md.(*messageDescriptor); ok && md.parentFile == fd {
return md.lazy.isMapEntry return md.isMapEntry
} }
return md.IsMapEntry() return md.IsMapEntry()
} }
@ -238,7 +235,7 @@ func (fd *fileDesc) enumValuesOf(ed pref.EnumDescriptor) pref.EnumValueDescripto
return ed.Values() return ed.Values()
} }
func (fd *fileDesc) popEnumDependency() pref.EnumType { func (fd *fileDesc) popEnumDependency() pref.EnumDescriptor {
depIdx := fd.popDependencyIndex() depIdx := fd.popDependencyIndex()
if depIdx < len(fd.allEnums)+len(fd.allMessages) { if depIdx < len(fd.allEnums)+len(fd.allMessages) {
return &fd.allEnums[depIdx] return &fd.allEnums[depIdx]
@ -247,10 +244,10 @@ func (fd *fileDesc) popEnumDependency() pref.EnumType {
} }
} }
func (fd *fileDesc) popMessageDependency() pref.MessageType { func (fd *fileDesc) popMessageDependency() pref.MessageDescriptor {
depIdx := fd.popDependencyIndex() depIdx := fd.popDependencyIndex()
if depIdx < len(fd.allEnums)+len(fd.allMessages) { if depIdx < len(fd.allEnums)+len(fd.allMessages) {
return &fd.allMessages[depIdx-len(fd.allEnums)] return fd.allMessages[depIdx-len(fd.allEnums)].asDesc()
} else { } else {
return pimpl.Export{}.MessageTypeOf(fd.GoTypes[depIdx]) return pimpl.Export{}.MessageTypeOf(fd.GoTypes[depIdx])
} }
@ -490,6 +487,7 @@ func (vd *enumValueDesc) unmarshalFull(b []byte, nb *nameBuilder, pf *fileDesc,
func (md *messageDesc) unmarshalFull(b []byte, nb *nameBuilder) { func (md *messageDesc) unmarshalFull(b []byte, nb *nameBuilder) {
var rawFields, rawOneofs [][]byte var rawFields, rawOneofs [][]byte
var enumIdx, messageIdx, extensionIdx int var enumIdx, messageIdx, extensionIdx int
var isMapEntry bool
md.lazy = new(messageLazy) md.lazy = new(messageLazy)
for len(b) > 0 { for len(b) > 0 {
num, typ, n := wire.ConsumeTag(b) num, typ, n := wire.ConsumeTag(b)
@ -521,7 +519,7 @@ func (md *messageDesc) unmarshalFull(b []byte, nb *nameBuilder) {
md.extensions.list[extensionIdx].unmarshalFull(v, nb) md.extensions.list[extensionIdx].unmarshalFull(v, nb)
extensionIdx++ extensionIdx++
case fieldnum.DescriptorProto_Options: case fieldnum.DescriptorProto_Options:
md.unmarshalOptions(v) md.unmarshalOptions(v, &isMapEntry)
} }
default: default:
m := wire.ConsumeFieldValue(num, typ, b) m := wire.ConsumeFieldValue(num, typ, b)
@ -534,21 +532,25 @@ func (md *messageDesc) unmarshalFull(b []byte, nb *nameBuilder) {
md.lazy.oneofs.list = make([]oneofDesc, len(rawOneofs)) md.lazy.oneofs.list = make([]oneofDesc, len(rawOneofs))
for i, b := range rawFields { for i, b := range rawFields {
fd := &md.lazy.fields.list[i] fd := &md.lazy.fields.list[i]
fd.unmarshalFull(b, nb, md.parentFile, md, i) fd.unmarshalFull(b, nb, md.parentFile, md.asDesc(), i)
if fd.cardinality == pref.Required { if fd.cardinality == pref.Required {
md.lazy.reqNumbers.list = append(md.lazy.reqNumbers.list, fd.number) md.lazy.reqNumbers.list = append(md.lazy.reqNumbers.list, fd.number)
} }
} }
for i, b := range rawOneofs { for i, b := range rawOneofs {
od := &md.lazy.oneofs.list[i] od := &md.lazy.oneofs.list[i]
od.unmarshalFull(b, nb, md.parentFile, md, i) od.unmarshalFull(b, nb, md.parentFile, md.asDesc(), i)
} }
} }
md.parentFile.lazy.byName[md.FullName()] = md if isMapEntry != md.isMapEntry {
panic("mismatching map entry property")
}
md.parentFile.lazy.byName[md.FullName()] = md.asDesc()
} }
func (md *messageDesc) unmarshalOptions(b []byte) { func (md *messageDesc) unmarshalOptions(b []byte, isMapEntry *bool) {
md.lazy.options = append(md.lazy.options, b...) md.lazy.options = append(md.lazy.options, b...)
for len(b) > 0 { for len(b) > 0 {
num, typ, n := wire.ConsumeTag(b) num, typ, n := wire.ConsumeTag(b)
@ -559,7 +561,7 @@ func (md *messageDesc) unmarshalOptions(b []byte) {
b = b[m:] b = b[m:]
switch num { switch num {
case fieldnum.MessageOptions_MapEntry: case fieldnum.MessageOptions_MapEntry:
md.lazy.isMapEntry = wire.DecodeBool(v) *isMapEntry = wire.DecodeBool(v)
case fieldnum.MessageOptions_MessageSetWireFormat: case fieldnum.MessageOptions_MessageSetWireFormat:
md.lazy.isMessageSet = wire.DecodeBool(v) md.lazy.isMessageSet = wire.DecodeBool(v)
} }
@ -646,7 +648,7 @@ func (fd *fieldDesc) unmarshalFull(b []byte, nb *nameBuilder, pf *fileDesc, pd p
// In messageDesc.UnmarshalFull, we allocate slices for both // In messageDesc.UnmarshalFull, we allocate slices for both
// the field and oneof descriptors before unmarshaling either // the field and oneof descriptors before unmarshaling either
// of them. This ensures pointers to slice elements are stable. // of them. This ensures pointers to slice elements are stable.
od := &pd.(*messageDesc).lazy.oneofs.list[v] od := &pd.(messageType).lazy.oneofs.list[v]
od.fields.list = append(od.fields.list, fd) od.fields.list = append(od.fields.list, fd)
if fd.oneofType != nil { if fd.oneofType != nil {
panic("oneof type already set") panic("oneof type already set")

View File

@ -110,11 +110,11 @@ func (p *messageDescs) Len() int {
return len(p.list) return len(p.list)
} }
func (p *messageDescs) Get(i int) protoreflect.MessageDescriptor { func (p *messageDescs) Get(i int) protoreflect.MessageDescriptor {
return &p.list[i] return p.list[i].asDesc()
} }
func (p *messageDescs) ByName(s protoreflect.Name) protoreflect.MessageDescriptor { func (p *messageDescs) ByName(s protoreflect.Name) protoreflect.MessageDescriptor {
if d := p.lazyInit().byName[s]; d != nil { if d := p.lazyInit().byName[s]; d != nil {
return d return d.asDesc()
} }
return nil return nil
} }

View File

@ -68,6 +68,15 @@ func TestInit(t *testing.T) {
} }
} }
// Verify that message descriptors for map entries have no Go type info.
mapEntryName := protoreflect.FullName("goproto.proto.test.TestAllTypes.MapInt32Int32Entry")
d := testpb.File_test_test_proto.DescriptorByName(mapEntryName)
if _, ok := d.(protoreflect.MessageDescriptor); !ok {
t.Errorf("message descriptor for %v not found", mapEntryName)
}
if _, ok := d.(protoreflect.MessageType); ok {
t.Errorf("message descriptor for %v must not implement protoreflect.MessageType", mapEntryName)
}
} }
// visitFields calls f for every field set in m and its children. // visitFields calls f for every field set in m and its children.