mirror of
https://github.com/protocolbuffers/protobuf-go.git
synced 2025-03-10 07:14:24 +00:00
reflect/protoregistry: add Num methods for every Range method
The Num methods provide an O(1) lookup for the number of entries that Range would return. This is needed to implement efficient cache invalidation logic for caches that wrap the global registry. Change-Id: I7c4ff97f674c4e9e4caae291f017cfad7294856c Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/193599 Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
parent
ea5ada15be
commit
72980ee410
@ -255,23 +255,39 @@ func (r *Files) FindFileByPath(path string) (protoreflect.FileDescriptor, error)
|
||||
return nil, NotFound
|
||||
}
|
||||
|
||||
// NumFiles reports the number of registered files.
|
||||
func (r *Files) NumFiles() int {
|
||||
if r == nil {
|
||||
return 0
|
||||
}
|
||||
return len(r.filesByPath)
|
||||
}
|
||||
|
||||
// RangeFiles iterates over all registered files.
|
||||
// The iteration order is undefined.
|
||||
func (r *Files) RangeFiles(f func(protoreflect.FileDescriptor) bool) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
for _, d := range r.descsByName {
|
||||
if p, ok := d.(*packageDescriptor); ok {
|
||||
for _, file := range p.files {
|
||||
if !f(file) {
|
||||
return
|
||||
}
|
||||
}
|
||||
for _, file := range r.filesByPath {
|
||||
if !f(file) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NumFilesByPackage reports the number of registered files in a proto package.
|
||||
func (r *Files) NumFilesByPackage(name protoreflect.FullName) int {
|
||||
if r == nil {
|
||||
return 0
|
||||
}
|
||||
p, ok := r.descsByName[name].(*packageDescriptor)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
return len(p.files)
|
||||
}
|
||||
|
||||
// RangeFilesByPackage iterates over all registered files in a give proto package.
|
||||
// The iteration order is undefined.
|
||||
func (r *Files) RangeFilesByPackage(name protoreflect.FullName, f func(protoreflect.FileDescriptor) bool) {
|
||||
@ -399,6 +415,10 @@ type Types struct {
|
||||
|
||||
typesByName typesByName
|
||||
extensionsByMessage extensionsByMessage
|
||||
|
||||
numEnums int
|
||||
numMessages int
|
||||
numExtensions int
|
||||
}
|
||||
|
||||
type (
|
||||
@ -428,13 +448,17 @@ typeLoop:
|
||||
case protoreflect.EnumType, protoreflect.MessageType, protoreflect.ExtensionType:
|
||||
// Check for conflicts in typesByName.
|
||||
var desc protoreflect.Descriptor
|
||||
var pcnt *int
|
||||
switch t := typ.(type) {
|
||||
case protoreflect.EnumType:
|
||||
desc = t.Descriptor()
|
||||
pcnt = &r.numEnums
|
||||
case protoreflect.MessageType:
|
||||
desc = t.Descriptor()
|
||||
pcnt = &r.numMessages
|
||||
case protoreflect.ExtensionType:
|
||||
desc = t.TypeDescriptor()
|
||||
pcnt = &r.numExtensions
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid type: %T", t))
|
||||
}
|
||||
@ -478,11 +502,12 @@ typeLoop:
|
||||
r.extensionsByMessage[message][field] = xt
|
||||
}
|
||||
|
||||
// Update typesByName.
|
||||
// Update typesByName and the count.
|
||||
if r.typesByName == nil {
|
||||
r.typesByName = make(typesByName)
|
||||
}
|
||||
r.typesByName[name] = typ
|
||||
(*pcnt)++
|
||||
default:
|
||||
if firstErr == nil {
|
||||
firstErr = errors.New("invalid type: %v", typeName(typ))
|
||||
@ -573,6 +598,14 @@ func (r *Types) FindExtensionByNumber(message protoreflect.FullName, field proto
|
||||
return nil, NotFound
|
||||
}
|
||||
|
||||
// NumEnums reports the number of registered enums.
|
||||
func (r *Types) NumEnums() int {
|
||||
if r == nil {
|
||||
return 0
|
||||
}
|
||||
return r.numEnums
|
||||
}
|
||||
|
||||
// RangeEnums iterates over all registered enums.
|
||||
// Iteration order is undefined.
|
||||
func (r *Types) RangeEnums(f func(protoreflect.EnumType) bool) {
|
||||
@ -588,6 +621,14 @@ func (r *Types) RangeEnums(f func(protoreflect.EnumType) bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// NumMessages reports the number of registered messages.
|
||||
func (r *Types) NumMessages() int {
|
||||
if r == nil {
|
||||
return 0
|
||||
}
|
||||
return r.numMessages
|
||||
}
|
||||
|
||||
// RangeMessages iterates over all registered messages.
|
||||
// Iteration order is undefined.
|
||||
func (r *Types) RangeMessages(f func(protoreflect.MessageType) bool) {
|
||||
@ -603,6 +644,14 @@ func (r *Types) RangeMessages(f func(protoreflect.MessageType) bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// NumExtensions reports the number of registered extensions.
|
||||
func (r *Types) NumExtensions() int {
|
||||
if r == nil {
|
||||
return 0
|
||||
}
|
||||
return r.numExtensions
|
||||
}
|
||||
|
||||
// RangeExtensions iterates over all registered extensions.
|
||||
// Iteration order is undefined.
|
||||
func (r *Types) RangeExtensions(f func(protoreflect.ExtensionType) bool) {
|
||||
@ -618,6 +667,15 @@ func (r *Types) RangeExtensions(f func(protoreflect.ExtensionType) bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// NumExtensionsByMessage reports the number of registered extensions for
|
||||
// a given message type.
|
||||
func (r *Types) NumExtensionsByMessage(message protoreflect.FullName) int {
|
||||
if r == nil {
|
||||
return 0
|
||||
}
|
||||
return len(r.extensionsByMessage[message])
|
||||
}
|
||||
|
||||
// RangeExtensionsByMessage iterates over all registered extensions filtered
|
||||
// by a given message type. Iteration order is undefined.
|
||||
func (r *Types) RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool) {
|
||||
|
@ -298,10 +298,16 @@ func TestFiles(t *testing.T) {
|
||||
|
||||
for _, tc := range tt.rangePkgs {
|
||||
var gotFiles []file
|
||||
var gotCnt int
|
||||
wantCnt := files.NumFilesByPackage(tc.inPkg)
|
||||
files.RangeFilesByPackage(tc.inPkg, func(fd pref.FileDescriptor) bool {
|
||||
gotFiles = append(gotFiles, file{fd.Path(), fd.Package()})
|
||||
gotCnt++
|
||||
return true
|
||||
})
|
||||
if gotCnt != wantCnt {
|
||||
t.Errorf("NumFilesByPackage(%v) = %v, want %v", tc.inPkg, gotCnt, wantCnt)
|
||||
}
|
||||
if diff := cmp.Diff(tc.wantFiles, gotFiles, sortFiles); diff != "" {
|
||||
t.Errorf("RangeFilesByPackage(%v) mismatch (-want +got):\n%v", tc.inPkg, diff)
|
||||
}
|
||||
@ -552,44 +558,59 @@ func TestTypes(t *testing.T) {
|
||||
return x == y
|
||||
})
|
||||
|
||||
t.Run("RangeMessages", func(t *testing.T) {
|
||||
want := []preg.Type{mt1}
|
||||
var got []preg.Type
|
||||
registry.RangeMessages(func(mt pref.MessageType) bool {
|
||||
got = append(got, mt)
|
||||
return true
|
||||
})
|
||||
|
||||
diff := cmp.Diff(want, got, sortTypes, compare)
|
||||
if diff != "" {
|
||||
t.Errorf("RangeMessages() mismatch (-want +got):\n%v", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RangeEnums", func(t *testing.T) {
|
||||
want := []preg.Type{et1}
|
||||
var got []preg.Type
|
||||
var gotCnt int
|
||||
wantCnt := registry.NumEnums()
|
||||
registry.RangeEnums(func(et pref.EnumType) bool {
|
||||
got = append(got, et)
|
||||
gotCnt++
|
||||
return true
|
||||
})
|
||||
|
||||
diff := cmp.Diff(want, got, sortTypes, compare)
|
||||
if diff != "" {
|
||||
if gotCnt != wantCnt {
|
||||
t.Errorf("NumEnums() = %v, want %v", gotCnt, wantCnt)
|
||||
}
|
||||
if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
|
||||
t.Errorf("RangeEnums() mismatch (-want +got):\n%v", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RangeMessages", func(t *testing.T) {
|
||||
want := []preg.Type{mt1}
|
||||
var got []preg.Type
|
||||
var gotCnt int
|
||||
wantCnt := registry.NumMessages()
|
||||
registry.RangeMessages(func(mt pref.MessageType) bool {
|
||||
got = append(got, mt)
|
||||
gotCnt++
|
||||
return true
|
||||
})
|
||||
|
||||
if gotCnt != wantCnt {
|
||||
t.Errorf("NumMessages() = %v, want %v", gotCnt, wantCnt)
|
||||
}
|
||||
if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
|
||||
t.Errorf("RangeMessages() mismatch (-want +got):\n%v", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RangeExtensions", func(t *testing.T) {
|
||||
want := []preg.Type{xt1, xt2}
|
||||
var got []preg.Type
|
||||
var gotCnt int
|
||||
wantCnt := registry.NumExtensions()
|
||||
registry.RangeExtensions(func(xt pref.ExtensionType) bool {
|
||||
got = append(got, xt)
|
||||
gotCnt++
|
||||
return true
|
||||
})
|
||||
|
||||
diff := cmp.Diff(want, got, sortTypes, compare)
|
||||
if diff != "" {
|
||||
if gotCnt != wantCnt {
|
||||
t.Errorf("NumExtensions() = %v, want %v", gotCnt, wantCnt)
|
||||
}
|
||||
if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
|
||||
t.Errorf("RangeExtensions() mismatch (-want +got):\n%v", diff)
|
||||
}
|
||||
})
|
||||
@ -597,13 +618,18 @@ func TestTypes(t *testing.T) {
|
||||
t.Run("RangeExtensionsByMessage", func(t *testing.T) {
|
||||
want := []preg.Type{xt1, xt2}
|
||||
var got []preg.Type
|
||||
registry.RangeExtensionsByMessage(pref.FullName("testprotos.Message1"), func(xt pref.ExtensionType) bool {
|
||||
var gotCnt int
|
||||
wantCnt := registry.NumExtensionsByMessage("testprotos.Message1")
|
||||
registry.RangeExtensionsByMessage("testprotos.Message1", func(xt pref.ExtensionType) bool {
|
||||
got = append(got, xt)
|
||||
gotCnt++
|
||||
return true
|
||||
})
|
||||
|
||||
diff := cmp.Diff(want, got, sortTypes, compare)
|
||||
if diff != "" {
|
||||
if gotCnt != wantCnt {
|
||||
t.Errorf("NumExtensionsByMessage() = %v, want %v", gotCnt, wantCnt)
|
||||
}
|
||||
if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
|
||||
t.Errorf("RangeExtensionsByMessage() mismatch (-want +got):\n%v", diff)
|
||||
}
|
||||
})
|
||||
|
Loading…
x
Reference in New Issue
Block a user