mirror of
https://github.com/protocolbuffers/protobuf-go.git
synced 2025-01-30 03:32:49 +00:00
internal/impl: fix race in aberrant message logic
Previously, when aberrantLoadMessageDesc returned it was guaranteed to have initialized the current message through the use of the done signal. However, this does not guarantee that the descriptor for a cylic reference has also finished initialization. Rather than add more complicated logic to wait until all cyclic references have finished initializing, just add a global lock for the entire aberrantLoadMessageDesc function. This slows down performance, but is easier to reason about. Change-Id: I4cdae8b955f71ee40fa6979f5a8d548d9749042c Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/184657 Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
parent
3274acc926
commit
32e8a52cbf
@ -7,12 +7,14 @@ package impl_test
|
||||
import (
|
||||
"io"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/protobuf/encoding/prototext"
|
||||
"google.golang.org/protobuf/internal/impl"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/reflect/protodesc"
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
"google.golang.org/protobuf/runtime/protoiface"
|
||||
|
||||
"google.golang.org/protobuf/types/descriptorpb"
|
||||
@ -286,3 +288,35 @@ func TestAberrant(t *testing.T) {
|
||||
t.Errorf("mismatching descriptor:\ngot %v\nwant %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
type AberrantMessage1 struct {
|
||||
M *AberrantMessage2 `protobuf:"bytes,1,opt,name=message"`
|
||||
}
|
||||
|
||||
type AberrantMessage2 struct {
|
||||
M *AberrantMessage1 `protobuf:"bytes,1,opt,name=message"`
|
||||
}
|
||||
|
||||
func TestAberrantRace(t *testing.T) {
|
||||
var gotMD1, wantMD1, gotMD2, wantMD2 protoreflect.MessageDescriptor
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
md := impl.LegacyLoadMessageDesc(reflect.TypeOf(&AberrantMessage1{}))
|
||||
wantMD2 = md.Fields().Get(0).Message()
|
||||
gotMD2 = wantMD2.Fields().Get(0).Message().Fields().Get(0).Message()
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
md := impl.LegacyLoadMessageDesc(reflect.TypeOf(&AberrantMessage2{}))
|
||||
wantMD1 = md.Fields().Get(0).Message()
|
||||
gotMD1 = wantMD1.Fields().Get(0).Message().Fields().Get(0).Message()
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
if gotMD1 != wantMD1 || gotMD2 != wantMD2 {
|
||||
t.Errorf("mismatching exact message descriptors")
|
||||
}
|
||||
}
|
||||
|
@ -59,9 +59,6 @@ var legacyMessageDescCache sync.Map // map[reflect.Type]protoreflect.MessageDesc
|
||||
//
|
||||
// This is exported for testing purposes.
|
||||
func LegacyLoadMessageDesc(t reflect.Type) pref.MessageDescriptor {
|
||||
return legacyLoadMessageDesc(t, true)
|
||||
}
|
||||
func legacyLoadMessageDesc(t reflect.Type, finalized bool) pref.MessageDescriptor {
|
||||
// Fast-path: check if a MessageDescriptor is cached for this concrete type.
|
||||
if mi, ok := legacyMessageDescCache.Load(t); ok {
|
||||
return mi.(pref.MessageDescriptor)
|
||||
@ -74,7 +71,7 @@ func legacyLoadMessageDesc(t reflect.Type, finalized bool) pref.MessageDescripto
|
||||
}
|
||||
mdV1, ok := mv.(messageV1)
|
||||
if !ok {
|
||||
return aberrantLoadMessageDesc(t, finalized)
|
||||
return aberrantLoadMessageDesc(t)
|
||||
}
|
||||
b, idxs := mdV1.Descriptor()
|
||||
|
||||
@ -88,16 +85,10 @@ func legacyLoadMessageDesc(t reflect.Type, finalized bool) pref.MessageDescripto
|
||||
return md
|
||||
}
|
||||
|
||||
var aberrantMessageDescCache sync.Map // map[reflect.Type]aberrantMessageDesc
|
||||
|
||||
// aberrantMessageDesc is a tuple containing a MessageDescriptor and a channel
|
||||
// to signal whether the descriptor is initialized. For external lookups,
|
||||
// we must ensure that the descriptor is fully initialized. For internal lookups
|
||||
// to resolve cycles, we only need to obtain the descriptor reference.
|
||||
type aberrantMessageDesc struct {
|
||||
desc protoreflect.MessageDescriptor
|
||||
done chan struct{} // closed when desc is fully initialized
|
||||
}
|
||||
var (
|
||||
aberrantMessageDescLock sync.Mutex
|
||||
aberrantMessageDescCache map[reflect.Type]protoreflect.MessageDescriptor
|
||||
)
|
||||
|
||||
// aberrantLoadEnumDesc returns an EnumDescriptor derived from the Go type,
|
||||
// which must not implement protoreflect.ProtoMessage or messageV1.
|
||||
@ -107,31 +98,27 @@ type aberrantMessageDesc struct {
|
||||
//
|
||||
// The finalized flag determines whether the returned message descriptor must
|
||||
// be fully initialized.
|
||||
func aberrantLoadMessageDesc(t reflect.Type, finalized bool) pref.MessageDescriptor {
|
||||
func aberrantLoadMessageDesc(t reflect.Type) pref.MessageDescriptor {
|
||||
aberrantMessageDescLock.Lock()
|
||||
defer aberrantMessageDescLock.Unlock()
|
||||
if aberrantMessageDescCache == nil {
|
||||
aberrantMessageDescCache = make(map[reflect.Type]protoreflect.MessageDescriptor)
|
||||
}
|
||||
return aberrantLoadMessageDescReentrant(t)
|
||||
}
|
||||
func aberrantLoadMessageDescReentrant(t reflect.Type) pref.MessageDescriptor {
|
||||
// Fast-path: check if an MessageDescriptor is cached for this concrete type.
|
||||
if mdi, ok := aberrantMessageDescCache.Load(t); ok {
|
||||
if finalized {
|
||||
<-mdi.(aberrantMessageDesc).done
|
||||
}
|
||||
return mdi.(aberrantMessageDesc).desc
|
||||
if md, ok := aberrantMessageDescCache[t]; ok {
|
||||
return md
|
||||
}
|
||||
|
||||
// Medium-path: create an initial descriptor and cache it immediately,
|
||||
// so that cyclic references can be resolved. Each descriptor is paired
|
||||
// with a channel to signal when the descriptor is fully initialized.
|
||||
md := &filedesc.Message{L2: new(filedesc.MessageL2)}
|
||||
mdi := aberrantMessageDesc{desc: md, done: make(chan struct{})}
|
||||
if mdi, ok := aberrantMessageDescCache.LoadOrStore(t, mdi); ok {
|
||||
if finalized {
|
||||
<-mdi.(aberrantMessageDesc).done
|
||||
}
|
||||
return mdi.(aberrantMessageDesc).desc
|
||||
}
|
||||
defer func() { close(mdi.done) }()
|
||||
|
||||
// Slow-path: construct a descriptor from the Go struct type (best-effort).
|
||||
// Cache the MessageDescriptor early on so that we can resolve internal
|
||||
// cyclic references.
|
||||
md := &filedesc.Message{L2: new(filedesc.MessageL2)}
|
||||
md.L0.FullName = aberrantDeriveFullName(t.Elem())
|
||||
md.L0.ParentFile = filedesc.SurrogateProto2
|
||||
aberrantMessageDescCache[t] = md
|
||||
|
||||
// Try to determine if the message is using proto3 by checking scalars.
|
||||
for i := 0; i < t.Elem().NumField(); i++ {
|
||||
@ -257,6 +244,8 @@ func aberrantAppendField(md *filedesc.Message, goType reflect.Type, tag, tagKey,
|
||||
switch v := reflect.Zero(t).Interface().(type) {
|
||||
case pref.ProtoMessage:
|
||||
fd.L1.Message = v.ProtoReflect().Descriptor()
|
||||
case messageV1:
|
||||
fd.L1.Message = LegacyLoadMessageDesc(t)
|
||||
default:
|
||||
if t.Kind() == reflect.Map {
|
||||
n := len(md.L1.Messages.List)
|
||||
@ -280,7 +269,7 @@ func aberrantAppendField(md *filedesc.Message, goType reflect.Type, tag, tagKey,
|
||||
fd.L1.Message = md2
|
||||
break
|
||||
}
|
||||
fd.L1.Message = aberrantLoadMessageDesc(t, false)
|
||||
fd.L1.Message = aberrantLoadMessageDescReentrant(t)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user