From 73618879f4294820aef0e4f0cf2137e8c3d547a2 Mon Sep 17 00:00:00 2001 From: Joe Tsai Date: Fri, 4 Oct 2019 14:58:46 -0700 Subject: [PATCH] reflect/protoregistry: protect global registries with a lock The global registry is initialized via generated code. The Go language guarantees that these are serialized (non concurrently). The main concern is when a concurrent read operation occurs while registration is still ongoing. In such a case, we do need a lock to serialize the read with regard to the writes (i.e. registrations). Change-Id: Ied35d6f8d2620f448cb281c3ec46d8de893b5671 Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/199217 Reviewed-by: Damien Neil --- reflect/protoregistry/registry.go | 83 +++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/reflect/protoregistry/registry.go b/reflect/protoregistry/registry.go index a4df65b4..cba713f5 100644 --- a/reflect/protoregistry/registry.go +++ b/reflect/protoregistry/registry.go @@ -20,6 +20,7 @@ import ( "log" "reflect" "strings" + "sync" "google.golang.org/protobuf/internal/errors" "google.golang.org/protobuf/reflect/protoreflect" @@ -37,6 +38,8 @@ var ignoreConflict = func(d protoreflect.Descriptor, err error) bool { return true } +var globalMutex sync.RWMutex + // GlobalFiles is a global registry of file descriptors. var GlobalFiles *Files = new(Files) @@ -87,6 +90,10 @@ func NewFiles(files ...protoreflect.FileDescriptor) *Files { // // It is permitted for multiple files to have the same file path. func (r *Files) Register(files ...protoreflect.FileDescriptor) error { + if r == GlobalFiles { + globalMutex.Lock() + defer globalMutex.Unlock() + } if r.descsByName == nil { r.descsByName = map[protoreflect.FullName]interface{}{ "": &packageDescriptor{}, @@ -161,6 +168,10 @@ func (r *Files) FindDescriptorByName(name protoreflect.FullName) (protoreflect.D if r == nil { return nil, NotFound } + if r == GlobalFiles { + globalMutex.RLock() + defer globalMutex.RUnlock() + } prefix := name suffix := nameSuffix("") for prefix != "" { @@ -249,6 +260,10 @@ func (r *Files) FindFileByPath(path string) (protoreflect.FileDescriptor, error) if r == nil { return nil, NotFound } + if r == GlobalFiles { + globalMutex.RLock() + defer globalMutex.RUnlock() + } if fd, ok := r.filesByPath[path]; ok { return fd, nil } @@ -260,6 +275,10 @@ func (r *Files) NumFiles() int { if r == nil { return 0 } + if r == GlobalFiles { + globalMutex.RLock() + defer globalMutex.RUnlock() + } return len(r.filesByPath) } @@ -269,6 +288,10 @@ func (r *Files) RangeFiles(f func(protoreflect.FileDescriptor) bool) { if r == nil { return } + if r == GlobalFiles { + globalMutex.RLock() + defer globalMutex.RUnlock() + } for _, file := range r.filesByPath { if !f(file) { return @@ -281,6 +304,10 @@ func (r *Files) NumFilesByPackage(name protoreflect.FullName) int { if r == nil { return 0 } + if r == GlobalFiles { + globalMutex.RLock() + defer globalMutex.RUnlock() + } p, ok := r.descsByName[name].(*packageDescriptor) if !ok { return 0 @@ -294,6 +321,10 @@ func (r *Files) RangeFilesByPackage(name protoreflect.FullName, f func(protorefl if r == nil { return } + if r == GlobalFiles { + globalMutex.RLock() + defer globalMutex.RUnlock() + } p, ok := r.descsByName[name].(*packageDescriptor) if !ok { return @@ -441,6 +472,10 @@ func NewTypes(typs ...Type) *Types { // (e.g., two different types have the same full name), // then the first type takes precedence and an error is returned. func (r *Types) Register(typs ...Type) error { + if r == GlobalTypes { + globalMutex.Lock() + defer globalMutex.Unlock() + } var firstErr error typeLoop: for _, typ := range typs { @@ -525,6 +560,10 @@ func (r *Types) FindEnumByName(enum protoreflect.FullName) (protoreflect.EnumTyp if r == nil { return nil, NotFound } + if r == GlobalTypes { + globalMutex.RLock() + defer globalMutex.RUnlock() + } if v := r.typesByName[enum]; v != nil { if et, _ := v.(protoreflect.EnumType); et != nil { return et, nil @@ -551,6 +590,10 @@ func (r *Types) FindMessageByURL(url string) (protoreflect.MessageType, error) { if r == nil { return nil, NotFound } + if r == GlobalTypes { + globalMutex.RLock() + defer globalMutex.RUnlock() + } message := protoreflect.FullName(url) if i := strings.LastIndexByte(url, '/'); i >= 0 { message = message[i+len("/"):] @@ -575,6 +618,10 @@ func (r *Types) FindExtensionByName(field protoreflect.FullName) (protoreflect.E if r == nil { return nil, NotFound } + if r == GlobalTypes { + globalMutex.RLock() + defer globalMutex.RUnlock() + } if v := r.typesByName[field]; v != nil { if xt, _ := v.(protoreflect.ExtensionType); xt != nil { return xt, nil @@ -592,6 +639,10 @@ func (r *Types) FindExtensionByNumber(message protoreflect.FullName, field proto if r == nil { return nil, NotFound } + if r == GlobalTypes { + globalMutex.RLock() + defer globalMutex.RUnlock() + } if xt, ok := r.extensionsByMessage[message][field]; ok { return xt, nil } @@ -603,6 +654,10 @@ func (r *Types) NumEnums() int { if r == nil { return 0 } + if r == GlobalTypes { + globalMutex.RLock() + defer globalMutex.RUnlock() + } return r.numEnums } @@ -612,6 +667,10 @@ func (r *Types) RangeEnums(f func(protoreflect.EnumType) bool) { if r == nil { return } + if r == GlobalTypes { + globalMutex.RLock() + defer globalMutex.RUnlock() + } for _, typ := range r.typesByName { if et, ok := typ.(protoreflect.EnumType); ok { if !f(et) { @@ -626,6 +685,10 @@ func (r *Types) NumMessages() int { if r == nil { return 0 } + if r == GlobalTypes { + globalMutex.RLock() + defer globalMutex.RUnlock() + } return r.numMessages } @@ -635,6 +698,10 @@ func (r *Types) RangeMessages(f func(protoreflect.MessageType) bool) { if r == nil { return } + if r == GlobalTypes { + globalMutex.RLock() + defer globalMutex.RUnlock() + } for _, typ := range r.typesByName { if mt, ok := typ.(protoreflect.MessageType); ok { if !f(mt) { @@ -649,6 +716,10 @@ func (r *Types) NumExtensions() int { if r == nil { return 0 } + if r == GlobalTypes { + globalMutex.RLock() + defer globalMutex.RUnlock() + } return r.numExtensions } @@ -658,6 +729,10 @@ func (r *Types) RangeExtensions(f func(protoreflect.ExtensionType) bool) { if r == nil { return } + if r == GlobalTypes { + globalMutex.RLock() + defer globalMutex.RUnlock() + } for _, typ := range r.typesByName { if xt, ok := typ.(protoreflect.ExtensionType); ok { if !f(xt) { @@ -673,6 +748,10 @@ func (r *Types) NumExtensionsByMessage(message protoreflect.FullName) int { if r == nil { return 0 } + if r == GlobalTypes { + globalMutex.RLock() + defer globalMutex.RUnlock() + } return len(r.extensionsByMessage[message]) } @@ -682,6 +761,10 @@ func (r *Types) RangeExtensionsByMessage(message protoreflect.FullName, f func(p if r == nil { return } + if r == GlobalTypes { + globalMutex.RLock() + defer globalMutex.RUnlock() + } for _, xt := range r.extensionsByMessage[message] { if !f(xt) { return