diff --git a/internal/cmd/generate-types/main.go b/internal/cmd/generate-types/main.go index dfca236a..ac727e89 100644 --- a/internal/cmd/generate-types/main.go +++ b/internal/cmd/generate-types/main.go @@ -227,11 +227,19 @@ var fileinitDescListTemplate = template.Must(template.New("").Funcs(template.Fun return len(p.list) } func (p *{{$nameList}}) Get(i int) {{.Expr}} { + {{- if (eq . "Message")}} + return p.list[i].asDesc() + {{- else}} return &p.list[i] + {{- end}} } func (p *{{$nameList}}) ByName(s protoreflect.Name) {{.Expr}} { if d := p.lazyInit().byName[s]; d != nil { + {{- if (eq . "Message")}} + return d.asDesc() + {{- else}} return d + {{- end}} } return nil } diff --git a/internal/fileinit/desc.go b/internal/fileinit/desc.go index 6d3b1b49..ec70095c 100644 --- a/internal/fileinit/desc.go +++ b/internal/fileinit/desc.go @@ -108,7 +108,8 @@ type FileBuilder struct { // in "flattened ordering". EnumOutputTypes []pref.EnumType // MessageOutputTypes is where Init stores all initialized message types - // in "flattened ordering"; this includes map entry types. + // in "flattened ordering". This includes slots for map entry messages, + // which are skipped over. MessageOutputTypes []pref.MessageType // ExtensionOutputTypes is where Init stores all initialized extension types // in "flattened ordering". @@ -141,7 +142,9 @@ func (fb FileBuilder) Init() pref.FileDescriptor { fb.EnumOutputTypes[i] = &fd.allEnums[i] } for i := range fd.allMessages { - fb.MessageOutputTypes[i] = &fd.allMessages[i] + if mt, _ := fd.allMessages[i].asDesc().(pref.MessageType); mt != nil { + fb.MessageOutputTypes[i] = mt + } } for i := range fd.allExtensions { fb.ExtensionOutputTypes[i] = &fd.allExtensions[i] @@ -160,8 +163,10 @@ func (fb FileBuilder) Init() pref.FileDescriptor { } } for i := range fd.allMessages { - if err := fb.TypesRegistry.Register(&fd.allMessages[i]); err != nil { - panic(err) + if mt, _ := fd.allMessages[i].asDesc().(pref.MessageType); mt != nil { + if err := fb.TypesRegistry.Register(mt); err != nil { + panic(err) + } } } for i := range fd.allExtensions { @@ -278,6 +283,11 @@ func (ed *enumValueDesc) Format(s fmt.State, r rune) { pfmt.FormatDesc(s func (ed *enumValueDesc) ProtoType(pref.EnumValueDescriptor) {} type ( + messageType struct{ *messageDesc } + messageDescriptor struct{ *messageDesc } + + // messageDesc does not implement protoreflect.Descriptor to avoid + // accidental usages of it as such. Use the asDesc method to retrieve one. messageDesc struct { baseDesc @@ -285,13 +295,13 @@ type ( messages messageDescs extensions extensionDescs - lazy *messageLazy // protected by fileDesc.once + isMapEntry bool + lazy *messageLazy // protected by fileDesc.once } messageLazy struct { typ reflect.Type new func() pref.Message - isMapEntry bool isMessageSet bool fields fieldDescs oneofs oneofDescs @@ -328,12 +338,10 @@ type ( } ) -func (md *messageDesc) GoType() reflect.Type { return md.lazyInit().typ } -func (md *messageDesc) New() pref.Message { return md.lazyInit().new() } -func (md *messageDesc) Options() pref.OptionsMessage { +func (md *messageDesc) options() pref.OptionsMessage { return unmarshalOptions(ptype.X.MessageOptions(), md.lazyInit().options) } -func (md *messageDesc) IsMapEntry() bool { return md.lazyInit().isMapEntry } +func (md *messageDesc) IsMapEntry() bool { return md.isMapEntry } func (md *messageDesc) Fields() pref.FieldDescriptors { return &md.lazyInit().fields } func (md *messageDesc) Oneofs() pref.OneofDescriptors { return &md.lazyInit().oneofs } func (md *messageDesc) ReservedNames() pref.Names { return &md.lazyInit().resvNames } @@ -346,8 +354,8 @@ func (md *messageDesc) ExtensionRangeOptions(i int) pref.OptionsMessage { func (md *messageDesc) Enums() pref.EnumDescriptors { return &md.enums } func (md *messageDesc) Messages() pref.MessageDescriptors { return &md.messages } func (md *messageDesc) Extensions() pref.ExtensionDescriptors { return &md.extensions } -func (md *messageDesc) Format(s fmt.State, r rune) { pfmt.FormatDesc(s, r, md) } func (md *messageDesc) ProtoType(pref.MessageDescriptor) {} +func (md *messageDesc) Format(s fmt.State, r rune) { pfmt.FormatDesc(s, r, md.asDesc()) } func (md *messageDesc) lazyInit() *messageLazy { md.parentFile.lazyInit() // implicitly initializes messageLazy return md.lazy @@ -359,6 +367,19 @@ func (md *messageDesc) IsMessageSet() bool { return md.lazyInit().isMessageSet } +// asDesc returns a protoreflect.MessageDescriptor or protoreflect.MessageType +// depending on whether the message is a map entry or not. +func (mb *messageDesc) asDesc() pref.MessageDescriptor { + if !mb.isMapEntry { + return messageType{mb} + } + return messageDescriptor{mb} +} +func (mt messageType) GoType() reflect.Type { return mt.lazyInit().typ } +func (mt messageType) New() pref.Message { return mt.lazyInit().new() } +func (mt messageType) Options() pref.OptionsMessage { return mt.options() } +func (md messageDescriptor) Options() pref.OptionsMessage { return md.options() } + func (fd *fieldDesc) Options() pref.OptionsMessage { return unmarshalOptions(ptype.X.FieldOptions(), fd.options) } diff --git a/internal/fileinit/desc_init.go b/internal/fileinit/desc_init.go index 7d149985..d1310a40 100644 --- a/internal/fileinit/desc_init.go +++ b/internal/fileinit/desc_init.go @@ -19,6 +19,13 @@ func newFileDesc(fb FileBuilder) *fileDesc { file.initDecls(len(fb.EnumOutputTypes), len(fb.MessageOutputTypes), len(fb.ExtensionOutputTypes)) file.unmarshalSeed(fb.RawDescriptor) + // Determine which message descriptors represent map entries based on the + // lack of an associated Go type. + messageDecls := file.GoTypes[len(file.allEnums):] + for i := range file.allMessages { + file.allMessages[i].isMapEntry = messageDecls[i] == nil + } + // Extended message dependencies are eagerly handled since registration // needs this information at program init time. for i := range file.allExtensions { @@ -31,7 +38,7 @@ func newFileDesc(fb FileBuilder) *fileDesc { } // initDecls pre-allocates slices for the exact number of enums, messages -// (excluding map entries), and extensions declared in the proto file. +// (including map entries), and extensions declared in the proto file. // This is done to avoid regrowing the slice, which would change the address // for any previously seen declaration. // @@ -279,7 +286,7 @@ func (md *messageDesc) unmarshalSeed(b []byte, nb *nameBuilder, pf *fileDesc, pd for i := range md.enums.list { _, n := wire.ConsumeVarint(b) v, m := wire.ConsumeBytes(b[n:]) - md.enums.list[i].unmarshalSeed(v, nb, pf, md, i) + md.enums.list[i].unmarshalSeed(v, nb, pf, md.asDesc(), i) b = b[n+m:] } } @@ -288,7 +295,7 @@ func (md *messageDesc) unmarshalSeed(b []byte, nb *nameBuilder, pf *fileDesc, pd for i := range md.messages.list { _, n := wire.ConsumeVarint(b) v, m := wire.ConsumeBytes(b[n:]) - md.messages.list[i].unmarshalSeed(v, nb, pf, md, i) + md.messages.list[i].unmarshalSeed(v, nb, pf, md.asDesc(), i) b = b[n+m:] } } @@ -297,7 +304,7 @@ func (md *messageDesc) unmarshalSeed(b []byte, nb *nameBuilder, pf *fileDesc, pd for i := range md.extensions.list { _, n := wire.ConsumeVarint(b) v, m := wire.ConsumeBytes(b[n:]) - md.extensions.list[i].unmarshalSeed(v, nb, pf, md, i) + md.extensions.list[i].unmarshalSeed(v, nb, pf, md.asDesc(), i) b = b[n+m:] } } diff --git a/internal/fileinit/desc_lazy.go b/internal/fileinit/desc_lazy.go index 4eef5468..7356b323 100644 --- a/internal/fileinit/desc_lazy.go +++ b/internal/fileinit/desc_lazy.go @@ -64,15 +64,12 @@ func (file *fileDesc) resolveMessages() { md := &file.allMessages[i] // Associate the MessageType with a concrete Go type. - // - // Note that descriptors for map entries, which have no associated - // Go type, also implement the protoreflect.MessageType interface, - // but have a GoType accessor that reports nil. Calling New results - // in a panic, which is sensible behavior. - md.lazy.typ = reflect.TypeOf(messageDecls[i]) - md.lazy.new = func() pref.Message { - t := md.lazy.typ.Elem() - return reflect.New(t).Interface().(pref.ProtoMessage).ProtoReflect() + if !md.isMapEntry { + md.lazy.typ = reflect.TypeOf(messageDecls[i]) + md.lazy.new = func() pref.Message { + t := md.lazy.typ.Elem() + return reflect.New(t).Interface().(pref.ProtoMessage).ProtoReflect() + } } // Resolve message field dependencies. @@ -173,9 +170,9 @@ func (file *fileDesc) resolveExtensions() { // Resolve extension field dependency. switch xd.lazy.kind { case pref.EnumKind: - xd.lazy.enumType = file.popEnumDependency() + xd.lazy.enumType = file.popEnumDependency().(pref.EnumType) case pref.MessageKind, pref.GroupKind: - xd.lazy.messageType = file.popMessageDependency() + xd.lazy.messageType = file.popMessageDependency().(pref.MessageType) } xd.lazy.defVal.lazyInit(xd.lazy.kind, file.enumValuesOf(xd.lazy.enumType)) } @@ -219,8 +216,8 @@ func (fd *fileDesc) isMapEntry(md pref.MessageDescriptor) bool { if md == nil { return false } - if md, ok := md.(*messageDesc); ok && md.parentFile == fd { - return md.lazy.isMapEntry + if md, ok := md.(*messageDescriptor); ok && md.parentFile == fd { + return md.isMapEntry } return md.IsMapEntry() } @@ -238,7 +235,7 @@ func (fd *fileDesc) enumValuesOf(ed pref.EnumDescriptor) pref.EnumValueDescripto return ed.Values() } -func (fd *fileDesc) popEnumDependency() pref.EnumType { +func (fd *fileDesc) popEnumDependency() pref.EnumDescriptor { depIdx := fd.popDependencyIndex() if depIdx < len(fd.allEnums)+len(fd.allMessages) { return &fd.allEnums[depIdx] @@ -247,10 +244,10 @@ func (fd *fileDesc) popEnumDependency() pref.EnumType { } } -func (fd *fileDesc) popMessageDependency() pref.MessageType { +func (fd *fileDesc) popMessageDependency() pref.MessageDescriptor { depIdx := fd.popDependencyIndex() if depIdx < len(fd.allEnums)+len(fd.allMessages) { - return &fd.allMessages[depIdx-len(fd.allEnums)] + return fd.allMessages[depIdx-len(fd.allEnums)].asDesc() } else { return pimpl.Export{}.MessageTypeOf(fd.GoTypes[depIdx]) } @@ -490,6 +487,7 @@ func (vd *enumValueDesc) unmarshalFull(b []byte, nb *nameBuilder, pf *fileDesc, func (md *messageDesc) unmarshalFull(b []byte, nb *nameBuilder) { var rawFields, rawOneofs [][]byte var enumIdx, messageIdx, extensionIdx int + var isMapEntry bool md.lazy = new(messageLazy) for len(b) > 0 { num, typ, n := wire.ConsumeTag(b) @@ -521,7 +519,7 @@ func (md *messageDesc) unmarshalFull(b []byte, nb *nameBuilder) { md.extensions.list[extensionIdx].unmarshalFull(v, nb) extensionIdx++ case fieldnum.DescriptorProto_Options: - md.unmarshalOptions(v) + md.unmarshalOptions(v, &isMapEntry) } default: m := wire.ConsumeFieldValue(num, typ, b) @@ -534,21 +532,25 @@ func (md *messageDesc) unmarshalFull(b []byte, nb *nameBuilder) { md.lazy.oneofs.list = make([]oneofDesc, len(rawOneofs)) for i, b := range rawFields { fd := &md.lazy.fields.list[i] - fd.unmarshalFull(b, nb, md.parentFile, md, i) + fd.unmarshalFull(b, nb, md.parentFile, md.asDesc(), i) if fd.cardinality == pref.Required { md.lazy.reqNumbers.list = append(md.lazy.reqNumbers.list, fd.number) } } for i, b := range rawOneofs { od := &md.lazy.oneofs.list[i] - od.unmarshalFull(b, nb, md.parentFile, md, i) + od.unmarshalFull(b, nb, md.parentFile, md.asDesc(), i) } } - md.parentFile.lazy.byName[md.FullName()] = md + if isMapEntry != md.isMapEntry { + panic("mismatching map entry property") + } + + md.parentFile.lazy.byName[md.FullName()] = md.asDesc() } -func (md *messageDesc) unmarshalOptions(b []byte) { +func (md *messageDesc) unmarshalOptions(b []byte, isMapEntry *bool) { md.lazy.options = append(md.lazy.options, b...) for len(b) > 0 { num, typ, n := wire.ConsumeTag(b) @@ -559,7 +561,7 @@ func (md *messageDesc) unmarshalOptions(b []byte) { b = b[m:] switch num { case fieldnum.MessageOptions_MapEntry: - md.lazy.isMapEntry = wire.DecodeBool(v) + *isMapEntry = wire.DecodeBool(v) case fieldnum.MessageOptions_MessageSetWireFormat: md.lazy.isMessageSet = wire.DecodeBool(v) } @@ -646,7 +648,7 @@ func (fd *fieldDesc) unmarshalFull(b []byte, nb *nameBuilder, pf *fileDesc, pd p // In messageDesc.UnmarshalFull, we allocate slices for both // the field and oneof descriptors before unmarshaling either // of them. This ensures pointers to slice elements are stable. - od := &pd.(*messageDesc).lazy.oneofs.list[v] + od := &pd.(messageType).lazy.oneofs.list[v] od.fields.list = append(od.fields.list, fd) if fd.oneofType != nil { panic("oneof type already set") diff --git a/internal/fileinit/desc_list_gen.go b/internal/fileinit/desc_list_gen.go index 5ec5663f..367b9218 100644 --- a/internal/fileinit/desc_list_gen.go +++ b/internal/fileinit/desc_list_gen.go @@ -110,11 +110,11 @@ func (p *messageDescs) Len() int { return len(p.list) } func (p *messageDescs) Get(i int) protoreflect.MessageDescriptor { - return &p.list[i] + return p.list[i].asDesc() } func (p *messageDescs) ByName(s protoreflect.Name) protoreflect.MessageDescriptor { if d := p.lazyInit().byName[s]; d != nil { - return d + return d.asDesc() } return nil } diff --git a/internal/fileinit/fileinit_test.go b/internal/fileinit/fileinit_test.go index f5a99230..e900168f 100644 --- a/internal/fileinit/fileinit_test.go +++ b/internal/fileinit/fileinit_test.go @@ -68,6 +68,15 @@ func TestInit(t *testing.T) { } } + // Verify that message descriptors for map entries have no Go type info. + mapEntryName := protoreflect.FullName("goproto.proto.test.TestAllTypes.MapInt32Int32Entry") + d := testpb.File_test_test_proto.DescriptorByName(mapEntryName) + if _, ok := d.(protoreflect.MessageDescriptor); !ok { + t.Errorf("message descriptor for %v not found", mapEntryName) + } + if _, ok := d.(protoreflect.MessageType); ok { + t.Errorf("message descriptor for %v must not implement protoreflect.MessageType", mapEntryName) + } } // visitFields calls f for every field set in m and its children.