reflect/protoregistry: add Num methods for every Range method

The Num methods provide an O(1) lookup for the number of entries that Range
would return. This is needed to implement efficient cache invalidation logic
for caches that wrap the global registry.

Change-Id: I7c4ff97f674c4e9e4caae291f017cfad7294856c
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/193599
Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
Joe Tsai 2019-09-05 10:19:36 -07:00
parent ea5ada15be
commit 72980ee410
2 changed files with 113 additions and 29 deletions

View File

@ -255,23 +255,39 @@ func (r *Files) FindFileByPath(path string) (protoreflect.FileDescriptor, error)
return nil, NotFound return nil, NotFound
} }
// NumFiles reports the number of registered files.
func (r *Files) NumFiles() int {
if r == nil {
return 0
}
return len(r.filesByPath)
}
// RangeFiles iterates over all registered files. // RangeFiles iterates over all registered files.
// The iteration order is undefined. // The iteration order is undefined.
func (r *Files) RangeFiles(f func(protoreflect.FileDescriptor) bool) { func (r *Files) RangeFiles(f func(protoreflect.FileDescriptor) bool) {
if r == nil { if r == nil {
return return
} }
for _, d := range r.descsByName { for _, file := range r.filesByPath {
if p, ok := d.(*packageDescriptor); ok { if !f(file) {
for _, file := range p.files { return
if !f(file) {
return
}
}
} }
} }
} }
// NumFilesByPackage reports the number of registered files in a proto package.
func (r *Files) NumFilesByPackage(name protoreflect.FullName) int {
if r == nil {
return 0
}
p, ok := r.descsByName[name].(*packageDescriptor)
if !ok {
return 0
}
return len(p.files)
}
// RangeFilesByPackage iterates over all registered files in a give proto package. // RangeFilesByPackage iterates over all registered files in a give proto package.
// The iteration order is undefined. // The iteration order is undefined.
func (r *Files) RangeFilesByPackage(name protoreflect.FullName, f func(protoreflect.FileDescriptor) bool) { func (r *Files) RangeFilesByPackage(name protoreflect.FullName, f func(protoreflect.FileDescriptor) bool) {
@ -399,6 +415,10 @@ type Types struct {
typesByName typesByName typesByName typesByName
extensionsByMessage extensionsByMessage extensionsByMessage extensionsByMessage
numEnums int
numMessages int
numExtensions int
} }
type ( type (
@ -428,13 +448,17 @@ typeLoop:
case protoreflect.EnumType, protoreflect.MessageType, protoreflect.ExtensionType: case protoreflect.EnumType, protoreflect.MessageType, protoreflect.ExtensionType:
// Check for conflicts in typesByName. // Check for conflicts in typesByName.
var desc protoreflect.Descriptor var desc protoreflect.Descriptor
var pcnt *int
switch t := typ.(type) { switch t := typ.(type) {
case protoreflect.EnumType: case protoreflect.EnumType:
desc = t.Descriptor() desc = t.Descriptor()
pcnt = &r.numEnums
case protoreflect.MessageType: case protoreflect.MessageType:
desc = t.Descriptor() desc = t.Descriptor()
pcnt = &r.numMessages
case protoreflect.ExtensionType: case protoreflect.ExtensionType:
desc = t.TypeDescriptor() desc = t.TypeDescriptor()
pcnt = &r.numExtensions
default: default:
panic(fmt.Sprintf("invalid type: %T", t)) panic(fmt.Sprintf("invalid type: %T", t))
} }
@ -478,11 +502,12 @@ typeLoop:
r.extensionsByMessage[message][field] = xt r.extensionsByMessage[message][field] = xt
} }
// Update typesByName. // Update typesByName and the count.
if r.typesByName == nil { if r.typesByName == nil {
r.typesByName = make(typesByName) r.typesByName = make(typesByName)
} }
r.typesByName[name] = typ r.typesByName[name] = typ
(*pcnt)++
default: default:
if firstErr == nil { if firstErr == nil {
firstErr = errors.New("invalid type: %v", typeName(typ)) firstErr = errors.New("invalid type: %v", typeName(typ))
@ -573,6 +598,14 @@ func (r *Types) FindExtensionByNumber(message protoreflect.FullName, field proto
return nil, NotFound return nil, NotFound
} }
// NumEnums reports the number of registered enums.
func (r *Types) NumEnums() int {
if r == nil {
return 0
}
return r.numEnums
}
// RangeEnums iterates over all registered enums. // RangeEnums iterates over all registered enums.
// Iteration order is undefined. // Iteration order is undefined.
func (r *Types) RangeEnums(f func(protoreflect.EnumType) bool) { func (r *Types) RangeEnums(f func(protoreflect.EnumType) bool) {
@ -588,6 +621,14 @@ func (r *Types) RangeEnums(f func(protoreflect.EnumType) bool) {
} }
} }
// NumMessages reports the number of registered messages.
func (r *Types) NumMessages() int {
if r == nil {
return 0
}
return r.numMessages
}
// RangeMessages iterates over all registered messages. // RangeMessages iterates over all registered messages.
// Iteration order is undefined. // Iteration order is undefined.
func (r *Types) RangeMessages(f func(protoreflect.MessageType) bool) { func (r *Types) RangeMessages(f func(protoreflect.MessageType) bool) {
@ -603,6 +644,14 @@ func (r *Types) RangeMessages(f func(protoreflect.MessageType) bool) {
} }
} }
// NumExtensions reports the number of registered extensions.
func (r *Types) NumExtensions() int {
if r == nil {
return 0
}
return r.numExtensions
}
// RangeExtensions iterates over all registered extensions. // RangeExtensions iterates over all registered extensions.
// Iteration order is undefined. // Iteration order is undefined.
func (r *Types) RangeExtensions(f func(protoreflect.ExtensionType) bool) { func (r *Types) RangeExtensions(f func(protoreflect.ExtensionType) bool) {
@ -618,6 +667,15 @@ func (r *Types) RangeExtensions(f func(protoreflect.ExtensionType) bool) {
} }
} }
// NumExtensionsByMessage reports the number of registered extensions for
// a given message type.
func (r *Types) NumExtensionsByMessage(message protoreflect.FullName) int {
if r == nil {
return 0
}
return len(r.extensionsByMessage[message])
}
// RangeExtensionsByMessage iterates over all registered extensions filtered // RangeExtensionsByMessage iterates over all registered extensions filtered
// by a given message type. Iteration order is undefined. // by a given message type. Iteration order is undefined.
func (r *Types) RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool) { func (r *Types) RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool) {

View File

@ -298,10 +298,16 @@ func TestFiles(t *testing.T) {
for _, tc := range tt.rangePkgs { for _, tc := range tt.rangePkgs {
var gotFiles []file var gotFiles []file
var gotCnt int
wantCnt := files.NumFilesByPackage(tc.inPkg)
files.RangeFilesByPackage(tc.inPkg, func(fd pref.FileDescriptor) bool { files.RangeFilesByPackage(tc.inPkg, func(fd pref.FileDescriptor) bool {
gotFiles = append(gotFiles, file{fd.Path(), fd.Package()}) gotFiles = append(gotFiles, file{fd.Path(), fd.Package()})
gotCnt++
return true return true
}) })
if gotCnt != wantCnt {
t.Errorf("NumFilesByPackage(%v) = %v, want %v", tc.inPkg, gotCnt, wantCnt)
}
if diff := cmp.Diff(tc.wantFiles, gotFiles, sortFiles); diff != "" { if diff := cmp.Diff(tc.wantFiles, gotFiles, sortFiles); diff != "" {
t.Errorf("RangeFilesByPackage(%v) mismatch (-want +got):\n%v", tc.inPkg, diff) t.Errorf("RangeFilesByPackage(%v) mismatch (-want +got):\n%v", tc.inPkg, diff)
} }
@ -552,44 +558,59 @@ func TestTypes(t *testing.T) {
return x == y return x == y
}) })
t.Run("RangeMessages", func(t *testing.T) {
want := []preg.Type{mt1}
var got []preg.Type
registry.RangeMessages(func(mt pref.MessageType) bool {
got = append(got, mt)
return true
})
diff := cmp.Diff(want, got, sortTypes, compare)
if diff != "" {
t.Errorf("RangeMessages() mismatch (-want +got):\n%v", diff)
}
})
t.Run("RangeEnums", func(t *testing.T) { t.Run("RangeEnums", func(t *testing.T) {
want := []preg.Type{et1} want := []preg.Type{et1}
var got []preg.Type var got []preg.Type
var gotCnt int
wantCnt := registry.NumEnums()
registry.RangeEnums(func(et pref.EnumType) bool { registry.RangeEnums(func(et pref.EnumType) bool {
got = append(got, et) got = append(got, et)
gotCnt++
return true return true
}) })
diff := cmp.Diff(want, got, sortTypes, compare) if gotCnt != wantCnt {
if diff != "" { t.Errorf("NumEnums() = %v, want %v", gotCnt, wantCnt)
}
if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
t.Errorf("RangeEnums() mismatch (-want +got):\n%v", diff) t.Errorf("RangeEnums() mismatch (-want +got):\n%v", diff)
} }
}) })
t.Run("RangeMessages", func(t *testing.T) {
want := []preg.Type{mt1}
var got []preg.Type
var gotCnt int
wantCnt := registry.NumMessages()
registry.RangeMessages(func(mt pref.MessageType) bool {
got = append(got, mt)
gotCnt++
return true
})
if gotCnt != wantCnt {
t.Errorf("NumMessages() = %v, want %v", gotCnt, wantCnt)
}
if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
t.Errorf("RangeMessages() mismatch (-want +got):\n%v", diff)
}
})
t.Run("RangeExtensions", func(t *testing.T) { t.Run("RangeExtensions", func(t *testing.T) {
want := []preg.Type{xt1, xt2} want := []preg.Type{xt1, xt2}
var got []preg.Type var got []preg.Type
var gotCnt int
wantCnt := registry.NumExtensions()
registry.RangeExtensions(func(xt pref.ExtensionType) bool { registry.RangeExtensions(func(xt pref.ExtensionType) bool {
got = append(got, xt) got = append(got, xt)
gotCnt++
return true return true
}) })
diff := cmp.Diff(want, got, sortTypes, compare) if gotCnt != wantCnt {
if diff != "" { t.Errorf("NumExtensions() = %v, want %v", gotCnt, wantCnt)
}
if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
t.Errorf("RangeExtensions() mismatch (-want +got):\n%v", diff) t.Errorf("RangeExtensions() mismatch (-want +got):\n%v", diff)
} }
}) })
@ -597,13 +618,18 @@ func TestTypes(t *testing.T) {
t.Run("RangeExtensionsByMessage", func(t *testing.T) { t.Run("RangeExtensionsByMessage", func(t *testing.T) {
want := []preg.Type{xt1, xt2} want := []preg.Type{xt1, xt2}
var got []preg.Type var got []preg.Type
registry.RangeExtensionsByMessage(pref.FullName("testprotos.Message1"), func(xt pref.ExtensionType) bool { var gotCnt int
wantCnt := registry.NumExtensionsByMessage("testprotos.Message1")
registry.RangeExtensionsByMessage("testprotos.Message1", func(xt pref.ExtensionType) bool {
got = append(got, xt) got = append(got, xt)
gotCnt++
return true return true
}) })
diff := cmp.Diff(want, got, sortTypes, compare) if gotCnt != wantCnt {
if diff != "" { t.Errorf("NumExtensionsByMessage() = %v, want %v", gotCnt, wantCnt)
}
if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
t.Errorf("RangeExtensionsByMessage() mismatch (-want +got):\n%v", diff) t.Errorf("RangeExtensionsByMessage() mismatch (-want +got):\n%v", diff)
} }
}) })