internal/impl: fix race over messageState.mi

The messageState.mi field is atomically checked and set
in generated code to the *MessageInfo associated with that message.
However, the messageState type accesses the mi field without
any atomic loads, thus being a potential race.
We fix this by always calling a messageInfo method that performs
a atomic.LoadPointer on the *MessageInfo.

There is no performance effect from this change on x86 since
an atomic.LoadPointer is identical to a MOV instruction.
From an assembly perspective, there was no memory race previously.
However, the lack of an atomic.LoadPointer meant that the compiler
could in theory reorder the "normal" load to produce truly racy code.

Change-Id: I8afefaf35c1916872781abc0239cbb63d62edf16
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/189017
Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
Joe Tsai 2019-07-31 12:27:30 -07:00
parent d57568e763
commit 2aea614c5e
5 changed files with 105 additions and 100 deletions

View File

@ -565,13 +565,13 @@ func generateImplMessage() string {
var implMessageTemplate = template.Must(template.New("").Parse(`
{{range . -}}
func (m *{{.}}) Descriptor() protoreflect.MessageDescriptor {
return m.mi.PBType.Descriptor()
return m.messageInfo().PBType.Descriptor()
}
func (m *{{.}}) Type() protoreflect.MessageType {
return m.mi.PBType
return m.messageInfo().PBType
}
func (m *{{.}}) New() protoreflect.Message {
return m.mi.PBType.New()
return m.messageInfo().PBType.New()
}
func (m *{{.}}) Interface() protoreflect.ProtoMessage {
{{if eq . "messageState" -}}
@ -584,11 +584,11 @@ func (m *{{.}}) Interface() protoreflect.ProtoMessage {
{{- end -}}
}
func (m *{{.}}) ProtoUnwrap() interface{} {
return m.pointer().AsIfaceOf(m.mi.GoType.Elem())
return m.pointer().AsIfaceOf(m.messageInfo().GoType.Elem())
}
func (m *{{.}}) ProtoMethods() *protoiface.Methods {
m.mi.init()
return &m.mi.methods
m.messageInfo().init()
return &m.messageInfo().methods
}
// ProtoMessageInfo is a pseudo-internal API for allowing the v1 code
@ -597,82 +597,82 @@ func (m *{{.}}) ProtoMethods() *protoiface.Methods {
// WARNING: This method is exempt from the compatibility promise and
// may be removed in the future without warning.
func (m *{{.}}) ProtoMessageInfo() *MessageInfo {
return m.mi
return m.messageInfo()
}
func (m *{{.}}) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
m.mi.init()
for _, fi := range m.mi.fields {
m.messageInfo().init()
for _, fi := range m.messageInfo().fields {
if fi.has(m.pointer()) {
if !f(fi.fieldDesc, fi.get(m.pointer())) {
return
}
}
}
m.mi.extensionMap(m.pointer()).Range(f)
m.messageInfo().extensionMap(m.pointer()).Range(f)
}
func (m *{{.}}) Has(fd protoreflect.FieldDescriptor) bool {
m.mi.init()
if fi, xt := m.mi.checkField(fd); fi != nil {
m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
return fi.has(m.pointer())
} else {
return m.mi.extensionMap(m.pointer()).Has(xt)
return m.messageInfo().extensionMap(m.pointer()).Has(xt)
}
}
func (m *{{.}}) Clear(fd protoreflect.FieldDescriptor) {
m.mi.init()
if fi, xt := m.mi.checkField(fd); fi != nil {
m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
fi.clear(m.pointer())
} else {
m.mi.extensionMap(m.pointer()).Clear(xt)
m.messageInfo().extensionMap(m.pointer()).Clear(xt)
}
}
func (m *{{.}}) Get(fd protoreflect.FieldDescriptor) protoreflect.Value {
m.mi.init()
if fi, xt := m.mi.checkField(fd); fi != nil {
m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
return fi.get(m.pointer())
} else {
return m.mi.extensionMap(m.pointer()).Get(xt)
return m.messageInfo().extensionMap(m.pointer()).Get(xt)
}
}
func (m *{{.}}) Set(fd protoreflect.FieldDescriptor, v protoreflect.Value) {
m.mi.init()
if fi, xt := m.mi.checkField(fd); fi != nil {
m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
fi.set(m.pointer(), v)
} else {
m.mi.extensionMap(m.pointer()).Set(xt, v)
m.messageInfo().extensionMap(m.pointer()).Set(xt, v)
}
}
func (m *{{.}}) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Value {
m.mi.init()
if fi, xt := m.mi.checkField(fd); fi != nil {
m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
return fi.mutable(m.pointer())
} else {
return m.mi.extensionMap(m.pointer()).Mutable(xt)
return m.messageInfo().extensionMap(m.pointer()).Mutable(xt)
}
}
func (m *{{.}}) NewMessage(fd protoreflect.FieldDescriptor) protoreflect.Message {
m.mi.init()
if fi, xt := m.mi.checkField(fd); fi != nil {
m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
return fi.newMessage()
} else {
return xt.New().Message()
}
}
func (m *{{.}}) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor {
m.mi.init()
if oi := m.mi.oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
m.messageInfo().init()
if oi := m.messageInfo().oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
return od.Fields().ByNumber(oi.which(m.pointer()))
}
panic("invalid oneof descriptor")
}
func (m *{{.}}) GetUnknown() protoreflect.RawFields {
m.mi.init()
return m.mi.getUnknown(m.pointer())
m.messageInfo().init()
return m.messageInfo().getUnknown(m.pointer())
}
func (m *{{.}}) SetUnknown(b protoreflect.RawFields) {
m.mi.init()
m.mi.setUnknown(m.pointer(), b)
m.messageInfo().init()
m.messageInfo().setUnknown(m.pointer(), b)
}
{{end}}

View File

@ -106,7 +106,8 @@ func (mi *MessageInfo) MessageOf(m interface{}) pref.Message {
return &messageReflectWrapper{p, mi}
}
func (m *messageReflectWrapper) pointer() pointer { return m.p }
func (m *messageReflectWrapper) pointer() pointer { return m.p }
func (m *messageReflectWrapper) messageInfo() *MessageInfo { return m.mi }
func (m *messageIfaceWrapper) ProtoReflect() pref.Message {
return (*messageReflectWrapper)(m)

View File

@ -12,23 +12,23 @@ import (
)
func (m *messageState) Descriptor() protoreflect.MessageDescriptor {
return m.mi.PBType.Descriptor()
return m.messageInfo().PBType.Descriptor()
}
func (m *messageState) Type() protoreflect.MessageType {
return m.mi.PBType
return m.messageInfo().PBType
}
func (m *messageState) New() protoreflect.Message {
return m.mi.PBType.New()
return m.messageInfo().PBType.New()
}
func (m *messageState) Interface() protoreflect.ProtoMessage {
return m.ProtoUnwrap().(protoreflect.ProtoMessage)
}
func (m *messageState) ProtoUnwrap() interface{} {
return m.pointer().AsIfaceOf(m.mi.GoType.Elem())
return m.pointer().AsIfaceOf(m.messageInfo().GoType.Elem())
}
func (m *messageState) ProtoMethods() *protoiface.Methods {
m.mi.init()
return &m.mi.methods
m.messageInfo().init()
return &m.messageInfo().methods
}
// ProtoMessageInfo is a pseudo-internal API for allowing the v1 code
@ -37,92 +37,92 @@ func (m *messageState) ProtoMethods() *protoiface.Methods {
// WARNING: This method is exempt from the compatibility promise and
// may be removed in the future without warning.
func (m *messageState) ProtoMessageInfo() *MessageInfo {
return m.mi
return m.messageInfo()
}
func (m *messageState) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
m.mi.init()
for _, fi := range m.mi.fields {
m.messageInfo().init()
for _, fi := range m.messageInfo().fields {
if fi.has(m.pointer()) {
if !f(fi.fieldDesc, fi.get(m.pointer())) {
return
}
}
}
m.mi.extensionMap(m.pointer()).Range(f)
m.messageInfo().extensionMap(m.pointer()).Range(f)
}
func (m *messageState) Has(fd protoreflect.FieldDescriptor) bool {
m.mi.init()
if fi, xt := m.mi.checkField(fd); fi != nil {
m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
return fi.has(m.pointer())
} else {
return m.mi.extensionMap(m.pointer()).Has(xt)
return m.messageInfo().extensionMap(m.pointer()).Has(xt)
}
}
func (m *messageState) Clear(fd protoreflect.FieldDescriptor) {
m.mi.init()
if fi, xt := m.mi.checkField(fd); fi != nil {
m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
fi.clear(m.pointer())
} else {
m.mi.extensionMap(m.pointer()).Clear(xt)
m.messageInfo().extensionMap(m.pointer()).Clear(xt)
}
}
func (m *messageState) Get(fd protoreflect.FieldDescriptor) protoreflect.Value {
m.mi.init()
if fi, xt := m.mi.checkField(fd); fi != nil {
m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
return fi.get(m.pointer())
} else {
return m.mi.extensionMap(m.pointer()).Get(xt)
return m.messageInfo().extensionMap(m.pointer()).Get(xt)
}
}
func (m *messageState) Set(fd protoreflect.FieldDescriptor, v protoreflect.Value) {
m.mi.init()
if fi, xt := m.mi.checkField(fd); fi != nil {
m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
fi.set(m.pointer(), v)
} else {
m.mi.extensionMap(m.pointer()).Set(xt, v)
m.messageInfo().extensionMap(m.pointer()).Set(xt, v)
}
}
func (m *messageState) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Value {
m.mi.init()
if fi, xt := m.mi.checkField(fd); fi != nil {
m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
return fi.mutable(m.pointer())
} else {
return m.mi.extensionMap(m.pointer()).Mutable(xt)
return m.messageInfo().extensionMap(m.pointer()).Mutable(xt)
}
}
func (m *messageState) NewMessage(fd protoreflect.FieldDescriptor) protoreflect.Message {
m.mi.init()
if fi, xt := m.mi.checkField(fd); fi != nil {
m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
return fi.newMessage()
} else {
return xt.New().Message()
}
}
func (m *messageState) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor {
m.mi.init()
if oi := m.mi.oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
m.messageInfo().init()
if oi := m.messageInfo().oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
return od.Fields().ByNumber(oi.which(m.pointer()))
}
panic("invalid oneof descriptor")
}
func (m *messageState) GetUnknown() protoreflect.RawFields {
m.mi.init()
return m.mi.getUnknown(m.pointer())
m.messageInfo().init()
return m.messageInfo().getUnknown(m.pointer())
}
func (m *messageState) SetUnknown(b protoreflect.RawFields) {
m.mi.init()
m.mi.setUnknown(m.pointer(), b)
m.messageInfo().init()
m.messageInfo().setUnknown(m.pointer(), b)
}
func (m *messageReflectWrapper) Descriptor() protoreflect.MessageDescriptor {
return m.mi.PBType.Descriptor()
return m.messageInfo().PBType.Descriptor()
}
func (m *messageReflectWrapper) Type() protoreflect.MessageType {
return m.mi.PBType
return m.messageInfo().PBType
}
func (m *messageReflectWrapper) New() protoreflect.Message {
return m.mi.PBType.New()
return m.messageInfo().PBType.New()
}
func (m *messageReflectWrapper) Interface() protoreflect.ProtoMessage {
if m, ok := m.ProtoUnwrap().(protoreflect.ProtoMessage); ok {
@ -131,11 +131,11 @@ func (m *messageReflectWrapper) Interface() protoreflect.ProtoMessage {
return (*messageIfaceWrapper)(m)
}
func (m *messageReflectWrapper) ProtoUnwrap() interface{} {
return m.pointer().AsIfaceOf(m.mi.GoType.Elem())
return m.pointer().AsIfaceOf(m.messageInfo().GoType.Elem())
}
func (m *messageReflectWrapper) ProtoMethods() *protoiface.Methods {
m.mi.init()
return &m.mi.methods
m.messageInfo().init()
return &m.messageInfo().methods
}
// ProtoMessageInfo is a pseudo-internal API for allowing the v1 code
@ -144,80 +144,80 @@ func (m *messageReflectWrapper) ProtoMethods() *protoiface.Methods {
// WARNING: This method is exempt from the compatibility promise and
// may be removed in the future without warning.
func (m *messageReflectWrapper) ProtoMessageInfo() *MessageInfo {
return m.mi
return m.messageInfo()
}
func (m *messageReflectWrapper) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
m.mi.init()
for _, fi := range m.mi.fields {
m.messageInfo().init()
for _, fi := range m.messageInfo().fields {
if fi.has(m.pointer()) {
if !f(fi.fieldDesc, fi.get(m.pointer())) {
return
}
}
}
m.mi.extensionMap(m.pointer()).Range(f)
m.messageInfo().extensionMap(m.pointer()).Range(f)
}
func (m *messageReflectWrapper) Has(fd protoreflect.FieldDescriptor) bool {
m.mi.init()
if fi, xt := m.mi.checkField(fd); fi != nil {
m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
return fi.has(m.pointer())
} else {
return m.mi.extensionMap(m.pointer()).Has(xt)
return m.messageInfo().extensionMap(m.pointer()).Has(xt)
}
}
func (m *messageReflectWrapper) Clear(fd protoreflect.FieldDescriptor) {
m.mi.init()
if fi, xt := m.mi.checkField(fd); fi != nil {
m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
fi.clear(m.pointer())
} else {
m.mi.extensionMap(m.pointer()).Clear(xt)
m.messageInfo().extensionMap(m.pointer()).Clear(xt)
}
}
func (m *messageReflectWrapper) Get(fd protoreflect.FieldDescriptor) protoreflect.Value {
m.mi.init()
if fi, xt := m.mi.checkField(fd); fi != nil {
m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
return fi.get(m.pointer())
} else {
return m.mi.extensionMap(m.pointer()).Get(xt)
return m.messageInfo().extensionMap(m.pointer()).Get(xt)
}
}
func (m *messageReflectWrapper) Set(fd protoreflect.FieldDescriptor, v protoreflect.Value) {
m.mi.init()
if fi, xt := m.mi.checkField(fd); fi != nil {
m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
fi.set(m.pointer(), v)
} else {
m.mi.extensionMap(m.pointer()).Set(xt, v)
m.messageInfo().extensionMap(m.pointer()).Set(xt, v)
}
}
func (m *messageReflectWrapper) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Value {
m.mi.init()
if fi, xt := m.mi.checkField(fd); fi != nil {
m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
return fi.mutable(m.pointer())
} else {
return m.mi.extensionMap(m.pointer()).Mutable(xt)
return m.messageInfo().extensionMap(m.pointer()).Mutable(xt)
}
}
func (m *messageReflectWrapper) NewMessage(fd protoreflect.FieldDescriptor) protoreflect.Message {
m.mi.init()
if fi, xt := m.mi.checkField(fd); fi != nil {
m.messageInfo().init()
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
return fi.newMessage()
} else {
return xt.New().Message()
}
}
func (m *messageReflectWrapper) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor {
m.mi.init()
if oi := m.mi.oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
m.messageInfo().init()
if oi := m.messageInfo().oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
return od.Fields().ByNumber(oi.which(m.pointer()))
}
panic("invalid oneof descriptor")
}
func (m *messageReflectWrapper) GetUnknown() protoreflect.RawFields {
m.mi.init()
return m.mi.getUnknown(m.pointer())
m.messageInfo().init()
return m.messageInfo().getUnknown(m.pointer())
}
func (m *messageReflectWrapper) SetUnknown(b protoreflect.RawFields) {
m.mi.init()
m.mi.setUnknown(m.pointer(), b)
m.messageInfo().init()
m.messageInfo().setUnknown(m.pointer(), b)
}

View File

@ -159,6 +159,7 @@ func (p pointer) SetPointer(v pointer) {
func (Export) MessageStateOf(p Pointer) *messageState { panic("not supported") }
func (ms *messageState) pointer() pointer { panic("not supported") }
func (ms *messageState) messageInfo() *MessageInfo { panic("not supported") }
func (ms *messageState) LoadMessageInfo() *MessageInfo { panic("not supported") }
func (ms *messageState) StoreMessageInfo(mi *MessageInfo) { panic("not supported") }

View File

@ -147,6 +147,9 @@ func (ms *messageState) pointer() pointer {
// Super-tricky - see documentation on MessageState.
return pointer{p: unsafe.Pointer(ms)}
}
func (ms *messageState) messageInfo() *MessageInfo {
return ms.LoadMessageInfo()
}
func (ms *messageState) LoadMessageInfo() *MessageInfo {
return (*MessageInfo)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&ms.mi))))
}