reflect/protoregistry: protect global registries with a lock

The global registry is initialized via generated code.
The Go language guarantees that these are serialized (non concurrently).
The main concern is when a concurrent read operation occurs while
registration is still ongoing. In such a case, we do need a lock to
serialize the read with regard to the writes (i.e. registrations).

Change-Id: Ied35d6f8d2620f448cb281c3ec46d8de893b5671
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/199217
Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
Joe Tsai 2019-10-04 14:58:46 -07:00
parent 8e9d5f6e8a
commit 73618879f4

View File

@ -20,6 +20,7 @@ import (
"log" "log"
"reflect" "reflect"
"strings" "strings"
"sync"
"google.golang.org/protobuf/internal/errors" "google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoreflect"
@ -37,6 +38,8 @@ var ignoreConflict = func(d protoreflect.Descriptor, err error) bool {
return true return true
} }
var globalMutex sync.RWMutex
// GlobalFiles is a global registry of file descriptors. // GlobalFiles is a global registry of file descriptors.
var GlobalFiles *Files = new(Files) var GlobalFiles *Files = new(Files)
@ -87,6 +90,10 @@ func NewFiles(files ...protoreflect.FileDescriptor) *Files {
// //
// It is permitted for multiple files to have the same file path. // It is permitted for multiple files to have the same file path.
func (r *Files) Register(files ...protoreflect.FileDescriptor) error { func (r *Files) Register(files ...protoreflect.FileDescriptor) error {
if r == GlobalFiles {
globalMutex.Lock()
defer globalMutex.Unlock()
}
if r.descsByName == nil { if r.descsByName == nil {
r.descsByName = map[protoreflect.FullName]interface{}{ r.descsByName = map[protoreflect.FullName]interface{}{
"": &packageDescriptor{}, "": &packageDescriptor{},
@ -161,6 +168,10 @@ func (r *Files) FindDescriptorByName(name protoreflect.FullName) (protoreflect.D
if r == nil { if r == nil {
return nil, NotFound return nil, NotFound
} }
if r == GlobalFiles {
globalMutex.RLock()
defer globalMutex.RUnlock()
}
prefix := name prefix := name
suffix := nameSuffix("") suffix := nameSuffix("")
for prefix != "" { for prefix != "" {
@ -249,6 +260,10 @@ func (r *Files) FindFileByPath(path string) (protoreflect.FileDescriptor, error)
if r == nil { if r == nil {
return nil, NotFound return nil, NotFound
} }
if r == GlobalFiles {
globalMutex.RLock()
defer globalMutex.RUnlock()
}
if fd, ok := r.filesByPath[path]; ok { if fd, ok := r.filesByPath[path]; ok {
return fd, nil return fd, nil
} }
@ -260,6 +275,10 @@ func (r *Files) NumFiles() int {
if r == nil { if r == nil {
return 0 return 0
} }
if r == GlobalFiles {
globalMutex.RLock()
defer globalMutex.RUnlock()
}
return len(r.filesByPath) return len(r.filesByPath)
} }
@ -269,6 +288,10 @@ func (r *Files) RangeFiles(f func(protoreflect.FileDescriptor) bool) {
if r == nil { if r == nil {
return return
} }
if r == GlobalFiles {
globalMutex.RLock()
defer globalMutex.RUnlock()
}
for _, file := range r.filesByPath { for _, file := range r.filesByPath {
if !f(file) { if !f(file) {
return return
@ -281,6 +304,10 @@ func (r *Files) NumFilesByPackage(name protoreflect.FullName) int {
if r == nil { if r == nil {
return 0 return 0
} }
if r == GlobalFiles {
globalMutex.RLock()
defer globalMutex.RUnlock()
}
p, ok := r.descsByName[name].(*packageDescriptor) p, ok := r.descsByName[name].(*packageDescriptor)
if !ok { if !ok {
return 0 return 0
@ -294,6 +321,10 @@ func (r *Files) RangeFilesByPackage(name protoreflect.FullName, f func(protorefl
if r == nil { if r == nil {
return return
} }
if r == GlobalFiles {
globalMutex.RLock()
defer globalMutex.RUnlock()
}
p, ok := r.descsByName[name].(*packageDescriptor) p, ok := r.descsByName[name].(*packageDescriptor)
if !ok { if !ok {
return return
@ -441,6 +472,10 @@ func NewTypes(typs ...Type) *Types {
// (e.g., two different types have the same full name), // (e.g., two different types have the same full name),
// then the first type takes precedence and an error is returned. // then the first type takes precedence and an error is returned.
func (r *Types) Register(typs ...Type) error { func (r *Types) Register(typs ...Type) error {
if r == GlobalTypes {
globalMutex.Lock()
defer globalMutex.Unlock()
}
var firstErr error var firstErr error
typeLoop: typeLoop:
for _, typ := range typs { for _, typ := range typs {
@ -525,6 +560,10 @@ func (r *Types) FindEnumByName(enum protoreflect.FullName) (protoreflect.EnumTyp
if r == nil { if r == nil {
return nil, NotFound return nil, NotFound
} }
if r == GlobalTypes {
globalMutex.RLock()
defer globalMutex.RUnlock()
}
if v := r.typesByName[enum]; v != nil { if v := r.typesByName[enum]; v != nil {
if et, _ := v.(protoreflect.EnumType); et != nil { if et, _ := v.(protoreflect.EnumType); et != nil {
return et, nil return et, nil
@ -551,6 +590,10 @@ func (r *Types) FindMessageByURL(url string) (protoreflect.MessageType, error) {
if r == nil { if r == nil {
return nil, NotFound return nil, NotFound
} }
if r == GlobalTypes {
globalMutex.RLock()
defer globalMutex.RUnlock()
}
message := protoreflect.FullName(url) message := protoreflect.FullName(url)
if i := strings.LastIndexByte(url, '/'); i >= 0 { if i := strings.LastIndexByte(url, '/'); i >= 0 {
message = message[i+len("/"):] message = message[i+len("/"):]
@ -575,6 +618,10 @@ func (r *Types) FindExtensionByName(field protoreflect.FullName) (protoreflect.E
if r == nil { if r == nil {
return nil, NotFound return nil, NotFound
} }
if r == GlobalTypes {
globalMutex.RLock()
defer globalMutex.RUnlock()
}
if v := r.typesByName[field]; v != nil { if v := r.typesByName[field]; v != nil {
if xt, _ := v.(protoreflect.ExtensionType); xt != nil { if xt, _ := v.(protoreflect.ExtensionType); xt != nil {
return xt, nil return xt, nil
@ -592,6 +639,10 @@ func (r *Types) FindExtensionByNumber(message protoreflect.FullName, field proto
if r == nil { if r == nil {
return nil, NotFound return nil, NotFound
} }
if r == GlobalTypes {
globalMutex.RLock()
defer globalMutex.RUnlock()
}
if xt, ok := r.extensionsByMessage[message][field]; ok { if xt, ok := r.extensionsByMessage[message][field]; ok {
return xt, nil return xt, nil
} }
@ -603,6 +654,10 @@ func (r *Types) NumEnums() int {
if r == nil { if r == nil {
return 0 return 0
} }
if r == GlobalTypes {
globalMutex.RLock()
defer globalMutex.RUnlock()
}
return r.numEnums return r.numEnums
} }
@ -612,6 +667,10 @@ func (r *Types) RangeEnums(f func(protoreflect.EnumType) bool) {
if r == nil { if r == nil {
return return
} }
if r == GlobalTypes {
globalMutex.RLock()
defer globalMutex.RUnlock()
}
for _, typ := range r.typesByName { for _, typ := range r.typesByName {
if et, ok := typ.(protoreflect.EnumType); ok { if et, ok := typ.(protoreflect.EnumType); ok {
if !f(et) { if !f(et) {
@ -626,6 +685,10 @@ func (r *Types) NumMessages() int {
if r == nil { if r == nil {
return 0 return 0
} }
if r == GlobalTypes {
globalMutex.RLock()
defer globalMutex.RUnlock()
}
return r.numMessages return r.numMessages
} }
@ -635,6 +698,10 @@ func (r *Types) RangeMessages(f func(protoreflect.MessageType) bool) {
if r == nil { if r == nil {
return return
} }
if r == GlobalTypes {
globalMutex.RLock()
defer globalMutex.RUnlock()
}
for _, typ := range r.typesByName { for _, typ := range r.typesByName {
if mt, ok := typ.(protoreflect.MessageType); ok { if mt, ok := typ.(protoreflect.MessageType); ok {
if !f(mt) { if !f(mt) {
@ -649,6 +716,10 @@ func (r *Types) NumExtensions() int {
if r == nil { if r == nil {
return 0 return 0
} }
if r == GlobalTypes {
globalMutex.RLock()
defer globalMutex.RUnlock()
}
return r.numExtensions return r.numExtensions
} }
@ -658,6 +729,10 @@ func (r *Types) RangeExtensions(f func(protoreflect.ExtensionType) bool) {
if r == nil { if r == nil {
return return
} }
if r == GlobalTypes {
globalMutex.RLock()
defer globalMutex.RUnlock()
}
for _, typ := range r.typesByName { for _, typ := range r.typesByName {
if xt, ok := typ.(protoreflect.ExtensionType); ok { if xt, ok := typ.(protoreflect.ExtensionType); ok {
if !f(xt) { if !f(xt) {
@ -673,6 +748,10 @@ func (r *Types) NumExtensionsByMessage(message protoreflect.FullName) int {
if r == nil { if r == nil {
return 0 return 0
} }
if r == GlobalTypes {
globalMutex.RLock()
defer globalMutex.RUnlock()
}
return len(r.extensionsByMessage[message]) return len(r.extensionsByMessage[message])
} }
@ -682,6 +761,10 @@ func (r *Types) RangeExtensionsByMessage(message protoreflect.FullName, f func(p
if r == nil { if r == nil {
return return
} }
if r == GlobalTypes {
globalMutex.RLock()
defer globalMutex.RUnlock()
}
for _, xt := range r.extensionsByMessage[message] { for _, xt := range r.extensionsByMessage[message] {
if !f(xt) { if !f(xt) {
return return