internal/impl: fix race in aberrant message logic

Previously, when aberrantLoadMessageDesc returned it was guaranteed
to have initialized the current message through the use of the done signal.
However, this does not guarantee that the descriptor for a cylic reference
has also finished initialization.

Rather than add more complicated logic to wait until all cyclic references
have finished initializing, just add a global lock for the entire
aberrantLoadMessageDesc function.

This slows down performance, but is easier to reason about.

Change-Id: I4cdae8b955f71ee40fa6979f5a8d548d9749042c
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/184657
Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
Joe Tsai 2019-07-02 10:51:24 -07:00
parent 3274acc926
commit 32e8a52cbf
2 changed files with 57 additions and 34 deletions

View File

@ -7,12 +7,14 @@ package impl_test
import (
"io"
"reflect"
"sync"
"testing"
"google.golang.org/protobuf/encoding/prototext"
"google.golang.org/protobuf/internal/impl"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoiface"
"google.golang.org/protobuf/types/descriptorpb"
@ -286,3 +288,35 @@ func TestAberrant(t *testing.T) {
t.Errorf("mismatching descriptor:\ngot %v\nwant %v", got, want)
}
}
type AberrantMessage1 struct {
M *AberrantMessage2 `protobuf:"bytes,1,opt,name=message"`
}
type AberrantMessage2 struct {
M *AberrantMessage1 `protobuf:"bytes,1,opt,name=message"`
}
func TestAberrantRace(t *testing.T) {
var gotMD1, wantMD1, gotMD2, wantMD2 protoreflect.MessageDescriptor
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
md := impl.LegacyLoadMessageDesc(reflect.TypeOf(&AberrantMessage1{}))
wantMD2 = md.Fields().Get(0).Message()
gotMD2 = wantMD2.Fields().Get(0).Message().Fields().Get(0).Message()
}()
go func() {
defer wg.Done()
md := impl.LegacyLoadMessageDesc(reflect.TypeOf(&AberrantMessage2{}))
wantMD1 = md.Fields().Get(0).Message()
gotMD1 = wantMD1.Fields().Get(0).Message().Fields().Get(0).Message()
}()
wg.Wait()
if gotMD1 != wantMD1 || gotMD2 != wantMD2 {
t.Errorf("mismatching exact message descriptors")
}
}

View File

@ -59,9 +59,6 @@ var legacyMessageDescCache sync.Map // map[reflect.Type]protoreflect.MessageDesc
//
// This is exported for testing purposes.
func LegacyLoadMessageDesc(t reflect.Type) pref.MessageDescriptor {
return legacyLoadMessageDesc(t, true)
}
func legacyLoadMessageDesc(t reflect.Type, finalized bool) pref.MessageDescriptor {
// Fast-path: check if a MessageDescriptor is cached for this concrete type.
if mi, ok := legacyMessageDescCache.Load(t); ok {
return mi.(pref.MessageDescriptor)
@ -74,7 +71,7 @@ func legacyLoadMessageDesc(t reflect.Type, finalized bool) pref.MessageDescripto
}
mdV1, ok := mv.(messageV1)
if !ok {
return aberrantLoadMessageDesc(t, finalized)
return aberrantLoadMessageDesc(t)
}
b, idxs := mdV1.Descriptor()
@ -88,16 +85,10 @@ func legacyLoadMessageDesc(t reflect.Type, finalized bool) pref.MessageDescripto
return md
}
var aberrantMessageDescCache sync.Map // map[reflect.Type]aberrantMessageDesc
// aberrantMessageDesc is a tuple containing a MessageDescriptor and a channel
// to signal whether the descriptor is initialized. For external lookups,
// we must ensure that the descriptor is fully initialized. For internal lookups
// to resolve cycles, we only need to obtain the descriptor reference.
type aberrantMessageDesc struct {
desc protoreflect.MessageDescriptor
done chan struct{} // closed when desc is fully initialized
}
var (
aberrantMessageDescLock sync.Mutex
aberrantMessageDescCache map[reflect.Type]protoreflect.MessageDescriptor
)
// aberrantLoadEnumDesc returns an EnumDescriptor derived from the Go type,
// which must not implement protoreflect.ProtoMessage or messageV1.
@ -107,31 +98,27 @@ type aberrantMessageDesc struct {
//
// The finalized flag determines whether the returned message descriptor must
// be fully initialized.
func aberrantLoadMessageDesc(t reflect.Type, finalized bool) pref.MessageDescriptor {
func aberrantLoadMessageDesc(t reflect.Type) pref.MessageDescriptor {
aberrantMessageDescLock.Lock()
defer aberrantMessageDescLock.Unlock()
if aberrantMessageDescCache == nil {
aberrantMessageDescCache = make(map[reflect.Type]protoreflect.MessageDescriptor)
}
return aberrantLoadMessageDescReentrant(t)
}
func aberrantLoadMessageDescReentrant(t reflect.Type) pref.MessageDescriptor {
// Fast-path: check if an MessageDescriptor is cached for this concrete type.
if mdi, ok := aberrantMessageDescCache.Load(t); ok {
if finalized {
<-mdi.(aberrantMessageDesc).done
}
return mdi.(aberrantMessageDesc).desc
if md, ok := aberrantMessageDescCache[t]; ok {
return md
}
// Medium-path: create an initial descriptor and cache it immediately,
// so that cyclic references can be resolved. Each descriptor is paired
// with a channel to signal when the descriptor is fully initialized.
md := &filedesc.Message{L2: new(filedesc.MessageL2)}
mdi := aberrantMessageDesc{desc: md, done: make(chan struct{})}
if mdi, ok := aberrantMessageDescCache.LoadOrStore(t, mdi); ok {
if finalized {
<-mdi.(aberrantMessageDesc).done
}
return mdi.(aberrantMessageDesc).desc
}
defer func() { close(mdi.done) }()
// Slow-path: construct a descriptor from the Go struct type (best-effort).
// Cache the MessageDescriptor early on so that we can resolve internal
// cyclic references.
md := &filedesc.Message{L2: new(filedesc.MessageL2)}
md.L0.FullName = aberrantDeriveFullName(t.Elem())
md.L0.ParentFile = filedesc.SurrogateProto2
aberrantMessageDescCache[t] = md
// Try to determine if the message is using proto3 by checking scalars.
for i := 0; i < t.Elem().NumField(); i++ {
@ -257,6 +244,8 @@ func aberrantAppendField(md *filedesc.Message, goType reflect.Type, tag, tagKey,
switch v := reflect.Zero(t).Interface().(type) {
case pref.ProtoMessage:
fd.L1.Message = v.ProtoReflect().Descriptor()
case messageV1:
fd.L1.Message = LegacyLoadMessageDesc(t)
default:
if t.Kind() == reflect.Map {
n := len(md.L1.Messages.List)
@ -280,7 +269,7 @@ func aberrantAppendField(md *filedesc.Message, goType reflect.Type, tag, tagKey,
fd.L1.Message = md2
break
}
fd.L1.Message = aberrantLoadMessageDesc(t, false)
fd.L1.Message = aberrantLoadMessageDescReentrant(t)
}
}
}