diff --git a/internal/legacy/enum.go b/internal/legacy/enum.go index 98fd0fa5..23eebd8e 100644 --- a/internal/legacy/enum.go +++ b/internal/legacy/enum.go @@ -45,8 +45,10 @@ func loadEnumType(t reflect.Type) pref.EnumType { m.Store(n, e) return e }) - enumTypeCache.Store(t, et) - return et.(pref.EnumType) + if et, ok := enumTypeCache.LoadOrStore(t, et); ok { + return et.(pref.EnumType) + } + return et } type enumWrapper struct { @@ -83,8 +85,8 @@ var enumNumberType = reflect.TypeOf(pref.EnumNumber(0)) // which must be an int32 kind and not implement the v2 API already. func loadEnumDesc(t reflect.Type) pref.EnumDescriptor { // Fast-path: check if an EnumDescriptor is cached for this concrete type. - if v, ok := enumDescCache.Load(t); ok { - return v.(pref.EnumDescriptor) + if ed, ok := enumDescCache.Load(t); ok { + return ed.(pref.EnumDescriptor) } // Slow-path: initialize EnumDescriptor from the proto descriptor. @@ -157,6 +159,8 @@ func loadEnumDesc(t reflect.Type) pref.EnumDescriptor { if err != nil { panic(err) } - enumDescCache.Store(t, ed) + if ed, ok := enumDescCache.LoadOrStore(t, ed); ok { + return ed.(pref.EnumDescriptor) + } return ed } diff --git a/internal/legacy/extension.go b/internal/legacy/extension.go index 839d5973..0c88c832 100644 --- a/internal/legacy/extension.go +++ b/internal/legacy/extension.go @@ -134,7 +134,9 @@ func extensionDescFromType(t pref.ExtensionType) *papi.ExtensionDesc { Tag: ptag.Marshal(t, enumName), Filename: filename, } - extensionDescCache.Store(t, d) + if d, ok := extensionDescCache.LoadOrStore(t, d); ok { + return d.(*papi.ExtensionDesc) + } return d } @@ -145,10 +147,6 @@ func extensionDescFromType(t pref.ExtensionType) *papi.ExtensionDesc { func extensionTypeFromDesc(d *papi.ExtensionDesc) pref.ExtensionType { // Fast-path: check whether an extension type is already nested within. if d.Type != nil { - // Cache descriptor for future extensionDescFromType operation. - // This assumes that there is only one legacy protoapi.ExtensionDesc - // that wraps any given specific protoreflect.ExtensionType. - extensionDescCache.LoadOrStore(d.Type, d) return d.Type } @@ -192,8 +190,10 @@ func extensionTypeFromDesc(d *papi.ExtensionDesc) pref.ExtensionType { xt := pimpl.Export{}.ExtensionTypeOf(xd, zv) // Cache the conversion for both directions. - extensionDescCache.Store(xt, d) - extensionTypeCache.Store(dk, xt) + extensionDescCache.LoadOrStore(xt, d) + if xt, ok := extensionTypeCache.LoadOrStore(dk, xt); ok { + return xt.(pref.ExtensionType) + } return xt } diff --git a/internal/legacy/file.go b/internal/legacy/file.go index 5c798104..c438c1d2 100644 --- a/internal/legacy/file.go +++ b/internal/legacy/file.go @@ -42,12 +42,12 @@ var fileDescCache sync.Map // map[*byte]*descriptorpb.FileDescriptorProto // File descriptors generated by protoc-gen-go do not rely on that property. func loadFileDesc(b []byte) *descriptorpb.FileDescriptorProto { // Fast-path: check whether we already have a cached file descriptor. - if v, ok := fileDescCache.Load(&b[0]); ok { - return v.(*descriptorpb.FileDescriptorProto) + if fd, ok := fileDescCache.Load(&b[0]); ok { + return fd.(*descriptorpb.FileDescriptorProto) } // Slow-path: decompress and unmarshal the file descriptor proto. - m := new(descriptorpb.FileDescriptorProto) + fd := new(descriptorpb.FileDescriptorProto) zr, err := gzip.NewReader(bytes.NewReader(b)) if err != nil { panic(err) @@ -56,12 +56,14 @@ func loadFileDesc(b []byte) *descriptorpb.FileDescriptorProto { if err != nil { panic(err) } - err = proto.UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m) + err = proto.UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, fd) if err != nil { panic(err) } - fileDescCache.Store(&b[0], m) - return m + if fd, ok := fileDescCache.LoadOrStore(&b[0], fd); ok { + return fd.(*descriptorpb.FileDescriptorProto) + } + return fd } // parentFileDescriptor returns the parent protoreflect.FileDescriptor for the diff --git a/internal/legacy/legacy_test.go b/internal/legacy/legacy_test.go new file mode 100644 index 00000000..2f4dbbc2 --- /dev/null +++ b/internal/legacy/legacy_test.go @@ -0,0 +1,98 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package legacy + +import ( + "sync" + "testing" + + "github.com/golang/protobuf/v2/reflect/protoreflect" +) + +type ( + MessageA struct { + A1 *MessageA `protobuf:"bytes,1,req,name=a1"` + A2 *MessageB `protobuf:"bytes,2,req,name=a2"` + A3 Enum `protobuf:"varint,3,opt,name=a3,enum=legacy.Enum"` + } + MessageB struct { + B1 *MessageA `protobuf:"bytes,1,req,name=b1"` + B2 *MessageB `protobuf:"bytes,2,req,name=b2"` + B3 Enum `protobuf:"varint,3,opt,name=b3,enum=legacy.Enum"` + } + Enum int32 +) + +// TestConcurrentInit tests that concurrent wrapping of multiple legacy types +// results in the exact same descriptor being created. +func TestConcurrentInit(t *testing.T) { + const numParallel = 5 + var messageATypes [numParallel]protoreflect.MessageType + var messageBTypes [numParallel]protoreflect.MessageType + var enumTypes [numParallel]protoreflect.EnumType + + // Concurrently load message and enum types. + var wg sync.WaitGroup + for i := 0; i < numParallel; i++ { + i := i + wg.Add(3) + go func() { + defer wg.Done() + messageATypes[i] = Export{}.MessageTypeOf((*MessageA)(nil)) + }() + go func() { + defer wg.Done() + messageBTypes[i] = Export{}.MessageTypeOf((*MessageB)(nil)) + }() + go func() { + defer wg.Done() + enumTypes[i] = Export{}.EnumTypeOf(Enum(0)) + }() + } + wg.Wait() + + var ( + wantMTA = messageATypes[0] + wantMDA = messageATypes[0].Fields().ByNumber(1).MessageType() + wantMTB = messageBTypes[0] + wantMDB = messageBTypes[0].Fields().ByNumber(2).MessageType() + wantET = enumTypes[0] + wantED = messageATypes[0].Fields().ByNumber(3).EnumType() + ) + + for _, gotMT := range messageATypes[1:] { + if gotMT != wantMTA { + t.Error("MessageType(MessageA) mismatch") + } + if gotMDA := gotMT.Fields().ByNumber(1).MessageType(); gotMDA != wantMDA { + t.Error("MessageDescriptor(MessageA) mismatch") + } + if gotMDB := gotMT.Fields().ByNumber(2).MessageType(); gotMDB != wantMDB { + t.Error("MessageDescriptor(MessageB) mismatch") + } + if gotED := gotMT.Fields().ByNumber(3).EnumType(); gotED != wantED { + t.Error("EnumDescriptor(Enum) mismatch") + } + } + for _, gotMT := range messageBTypes[1:] { + if gotMT != wantMTB { + t.Error("MessageType(MessageB) mismatch") + } + if gotMDA := gotMT.Fields().ByNumber(1).MessageType(); gotMDA != wantMDA { + t.Error("MessageDescriptor(MessageA) mismatch") + } + if gotMDB := gotMT.Fields().ByNumber(2).MessageType(); gotMDB != wantMDB { + t.Error("MessageDescriptor(MessageB) mismatch") + } + if gotED := gotMT.Fields().ByNumber(3).EnumType(); gotED != wantED { + t.Error("EnumDescriptor(Enum) mismatch") + } + } + for _, gotET := range enumTypes[1:] { + if gotET != wantET { + t.Error("EnumType(Enum) mismatch") + } + } +} diff --git a/internal/legacy/message.go b/internal/legacy/message.go index c604ec88..cc200533 100644 --- a/internal/legacy/message.go +++ b/internal/legacy/message.go @@ -45,10 +45,13 @@ func loadMessageType(t reflect.Type) *pimpl.MessageType { p := reflect.New(t.Elem()).Interface() return mt.MessageOf(p) }) - messageTypeCache.Store(t, mt) + if mt, ok := messageTypeCache.LoadOrStore(t, mt); ok { + return mt.(*pimpl.MessageType) + } return mt } +var messageDescLock sync.Mutex var messageDescCache sync.Map // map[reflect.Type]protoreflect.MessageDescriptor // loadMessageDesc returns an MessageDescriptor derived from the Go type, @@ -70,6 +73,15 @@ func (ms messageDescSet) Load(t reflect.Type) pref.MessageDescriptor { } // Slow-path: initialize MessageDescriptor from the Go type. + // + // Hold a global lock during message creation to ensure that each Go type + // maps to exactly one MessageDescriptor. After obtaining the lock, we must + // check again whether the message has already been handled. + messageDescLock.Lock() + defer messageDescLock.Unlock() + if mi, ok := messageDescCache.Load(t); ok { + return mi.(pref.MessageDescriptor) + } // Processing t recursively populates descs and types with all sub-messages. // The descriptor for the first type is guaranteed to be at the front.