mirror of
https://github.com/protocolbuffers/protobuf-go.git
synced 2025-02-19 12:40:24 +00:00
internal/impl: fix race over messageState.mi
The messageState.mi field is atomically checked and set in generated code to the *MessageInfo associated with that message. However, the messageState type accesses the mi field without any atomic loads, thus being a potential race. We fix this by always calling a messageInfo method that performs a atomic.LoadPointer on the *MessageInfo. There is no performance effect from this change on x86 since an atomic.LoadPointer is identical to a MOV instruction. From an assembly perspective, there was no memory race previously. However, the lack of an atomic.LoadPointer meant that the compiler could in theory reorder the "normal" load to produce truly racy code. Change-Id: I8afefaf35c1916872781abc0239cbb63d62edf16 Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/189017 Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
parent
d57568e763
commit
2aea614c5e
@ -565,13 +565,13 @@ func generateImplMessage() string {
|
||||
var implMessageTemplate = template.Must(template.New("").Parse(`
|
||||
{{range . -}}
|
||||
func (m *{{.}}) Descriptor() protoreflect.MessageDescriptor {
|
||||
return m.mi.PBType.Descriptor()
|
||||
return m.messageInfo().PBType.Descriptor()
|
||||
}
|
||||
func (m *{{.}}) Type() protoreflect.MessageType {
|
||||
return m.mi.PBType
|
||||
return m.messageInfo().PBType
|
||||
}
|
||||
func (m *{{.}}) New() protoreflect.Message {
|
||||
return m.mi.PBType.New()
|
||||
return m.messageInfo().PBType.New()
|
||||
}
|
||||
func (m *{{.}}) Interface() protoreflect.ProtoMessage {
|
||||
{{if eq . "messageState" -}}
|
||||
@ -584,11 +584,11 @@ func (m *{{.}}) Interface() protoreflect.ProtoMessage {
|
||||
{{- end -}}
|
||||
}
|
||||
func (m *{{.}}) ProtoUnwrap() interface{} {
|
||||
return m.pointer().AsIfaceOf(m.mi.GoType.Elem())
|
||||
return m.pointer().AsIfaceOf(m.messageInfo().GoType.Elem())
|
||||
}
|
||||
func (m *{{.}}) ProtoMethods() *protoiface.Methods {
|
||||
m.mi.init()
|
||||
return &m.mi.methods
|
||||
m.messageInfo().init()
|
||||
return &m.messageInfo().methods
|
||||
}
|
||||
|
||||
// ProtoMessageInfo is a pseudo-internal API for allowing the v1 code
|
||||
@ -597,82 +597,82 @@ func (m *{{.}}) ProtoMethods() *protoiface.Methods {
|
||||
// WARNING: This method is exempt from the compatibility promise and
|
||||
// may be removed in the future without warning.
|
||||
func (m *{{.}}) ProtoMessageInfo() *MessageInfo {
|
||||
return m.mi
|
||||
return m.messageInfo()
|
||||
}
|
||||
|
||||
func (m *{{.}}) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
|
||||
m.mi.init()
|
||||
for _, fi := range m.mi.fields {
|
||||
m.messageInfo().init()
|
||||
for _, fi := range m.messageInfo().fields {
|
||||
if fi.has(m.pointer()) {
|
||||
if !f(fi.fieldDesc, fi.get(m.pointer())) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
m.mi.extensionMap(m.pointer()).Range(f)
|
||||
m.messageInfo().extensionMap(m.pointer()).Range(f)
|
||||
}
|
||||
func (m *{{.}}) Has(fd protoreflect.FieldDescriptor) bool {
|
||||
m.mi.init()
|
||||
if fi, xt := m.mi.checkField(fd); fi != nil {
|
||||
m.messageInfo().init()
|
||||
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
|
||||
return fi.has(m.pointer())
|
||||
} else {
|
||||
return m.mi.extensionMap(m.pointer()).Has(xt)
|
||||
return m.messageInfo().extensionMap(m.pointer()).Has(xt)
|
||||
}
|
||||
}
|
||||
func (m *{{.}}) Clear(fd protoreflect.FieldDescriptor) {
|
||||
m.mi.init()
|
||||
if fi, xt := m.mi.checkField(fd); fi != nil {
|
||||
m.messageInfo().init()
|
||||
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
|
||||
fi.clear(m.pointer())
|
||||
} else {
|
||||
m.mi.extensionMap(m.pointer()).Clear(xt)
|
||||
m.messageInfo().extensionMap(m.pointer()).Clear(xt)
|
||||
}
|
||||
}
|
||||
func (m *{{.}}) Get(fd protoreflect.FieldDescriptor) protoreflect.Value {
|
||||
m.mi.init()
|
||||
if fi, xt := m.mi.checkField(fd); fi != nil {
|
||||
m.messageInfo().init()
|
||||
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
|
||||
return fi.get(m.pointer())
|
||||
} else {
|
||||
return m.mi.extensionMap(m.pointer()).Get(xt)
|
||||
return m.messageInfo().extensionMap(m.pointer()).Get(xt)
|
||||
}
|
||||
}
|
||||
func (m *{{.}}) Set(fd protoreflect.FieldDescriptor, v protoreflect.Value) {
|
||||
m.mi.init()
|
||||
if fi, xt := m.mi.checkField(fd); fi != nil {
|
||||
m.messageInfo().init()
|
||||
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
|
||||
fi.set(m.pointer(), v)
|
||||
} else {
|
||||
m.mi.extensionMap(m.pointer()).Set(xt, v)
|
||||
m.messageInfo().extensionMap(m.pointer()).Set(xt, v)
|
||||
}
|
||||
}
|
||||
func (m *{{.}}) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Value {
|
||||
m.mi.init()
|
||||
if fi, xt := m.mi.checkField(fd); fi != nil {
|
||||
m.messageInfo().init()
|
||||
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
|
||||
return fi.mutable(m.pointer())
|
||||
} else {
|
||||
return m.mi.extensionMap(m.pointer()).Mutable(xt)
|
||||
return m.messageInfo().extensionMap(m.pointer()).Mutable(xt)
|
||||
}
|
||||
}
|
||||
func (m *{{.}}) NewMessage(fd protoreflect.FieldDescriptor) protoreflect.Message {
|
||||
m.mi.init()
|
||||
if fi, xt := m.mi.checkField(fd); fi != nil {
|
||||
m.messageInfo().init()
|
||||
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
|
||||
return fi.newMessage()
|
||||
} else {
|
||||
return xt.New().Message()
|
||||
}
|
||||
}
|
||||
func (m *{{.}}) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor {
|
||||
m.mi.init()
|
||||
if oi := m.mi.oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
|
||||
m.messageInfo().init()
|
||||
if oi := m.messageInfo().oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
|
||||
return od.Fields().ByNumber(oi.which(m.pointer()))
|
||||
}
|
||||
panic("invalid oneof descriptor")
|
||||
}
|
||||
func (m *{{.}}) GetUnknown() protoreflect.RawFields {
|
||||
m.mi.init()
|
||||
return m.mi.getUnknown(m.pointer())
|
||||
m.messageInfo().init()
|
||||
return m.messageInfo().getUnknown(m.pointer())
|
||||
}
|
||||
func (m *{{.}}) SetUnknown(b protoreflect.RawFields) {
|
||||
m.mi.init()
|
||||
m.mi.setUnknown(m.pointer(), b)
|
||||
m.messageInfo().init()
|
||||
m.messageInfo().setUnknown(m.pointer(), b)
|
||||
}
|
||||
|
||||
{{end}}
|
||||
|
@ -106,7 +106,8 @@ func (mi *MessageInfo) MessageOf(m interface{}) pref.Message {
|
||||
return &messageReflectWrapper{p, mi}
|
||||
}
|
||||
|
||||
func (m *messageReflectWrapper) pointer() pointer { return m.p }
|
||||
func (m *messageReflectWrapper) pointer() pointer { return m.p }
|
||||
func (m *messageReflectWrapper) messageInfo() *MessageInfo { return m.mi }
|
||||
|
||||
func (m *messageIfaceWrapper) ProtoReflect() pref.Message {
|
||||
return (*messageReflectWrapper)(m)
|
||||
|
@ -12,23 +12,23 @@ import (
|
||||
)
|
||||
|
||||
func (m *messageState) Descriptor() protoreflect.MessageDescriptor {
|
||||
return m.mi.PBType.Descriptor()
|
||||
return m.messageInfo().PBType.Descriptor()
|
||||
}
|
||||
func (m *messageState) Type() protoreflect.MessageType {
|
||||
return m.mi.PBType
|
||||
return m.messageInfo().PBType
|
||||
}
|
||||
func (m *messageState) New() protoreflect.Message {
|
||||
return m.mi.PBType.New()
|
||||
return m.messageInfo().PBType.New()
|
||||
}
|
||||
func (m *messageState) Interface() protoreflect.ProtoMessage {
|
||||
return m.ProtoUnwrap().(protoreflect.ProtoMessage)
|
||||
}
|
||||
func (m *messageState) ProtoUnwrap() interface{} {
|
||||
return m.pointer().AsIfaceOf(m.mi.GoType.Elem())
|
||||
return m.pointer().AsIfaceOf(m.messageInfo().GoType.Elem())
|
||||
}
|
||||
func (m *messageState) ProtoMethods() *protoiface.Methods {
|
||||
m.mi.init()
|
||||
return &m.mi.methods
|
||||
m.messageInfo().init()
|
||||
return &m.messageInfo().methods
|
||||
}
|
||||
|
||||
// ProtoMessageInfo is a pseudo-internal API for allowing the v1 code
|
||||
@ -37,92 +37,92 @@ func (m *messageState) ProtoMethods() *protoiface.Methods {
|
||||
// WARNING: This method is exempt from the compatibility promise and
|
||||
// may be removed in the future without warning.
|
||||
func (m *messageState) ProtoMessageInfo() *MessageInfo {
|
||||
return m.mi
|
||||
return m.messageInfo()
|
||||
}
|
||||
|
||||
func (m *messageState) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
|
||||
m.mi.init()
|
||||
for _, fi := range m.mi.fields {
|
||||
m.messageInfo().init()
|
||||
for _, fi := range m.messageInfo().fields {
|
||||
if fi.has(m.pointer()) {
|
||||
if !f(fi.fieldDesc, fi.get(m.pointer())) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
m.mi.extensionMap(m.pointer()).Range(f)
|
||||
m.messageInfo().extensionMap(m.pointer()).Range(f)
|
||||
}
|
||||
func (m *messageState) Has(fd protoreflect.FieldDescriptor) bool {
|
||||
m.mi.init()
|
||||
if fi, xt := m.mi.checkField(fd); fi != nil {
|
||||
m.messageInfo().init()
|
||||
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
|
||||
return fi.has(m.pointer())
|
||||
} else {
|
||||
return m.mi.extensionMap(m.pointer()).Has(xt)
|
||||
return m.messageInfo().extensionMap(m.pointer()).Has(xt)
|
||||
}
|
||||
}
|
||||
func (m *messageState) Clear(fd protoreflect.FieldDescriptor) {
|
||||
m.mi.init()
|
||||
if fi, xt := m.mi.checkField(fd); fi != nil {
|
||||
m.messageInfo().init()
|
||||
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
|
||||
fi.clear(m.pointer())
|
||||
} else {
|
||||
m.mi.extensionMap(m.pointer()).Clear(xt)
|
||||
m.messageInfo().extensionMap(m.pointer()).Clear(xt)
|
||||
}
|
||||
}
|
||||
func (m *messageState) Get(fd protoreflect.FieldDescriptor) protoreflect.Value {
|
||||
m.mi.init()
|
||||
if fi, xt := m.mi.checkField(fd); fi != nil {
|
||||
m.messageInfo().init()
|
||||
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
|
||||
return fi.get(m.pointer())
|
||||
} else {
|
||||
return m.mi.extensionMap(m.pointer()).Get(xt)
|
||||
return m.messageInfo().extensionMap(m.pointer()).Get(xt)
|
||||
}
|
||||
}
|
||||
func (m *messageState) Set(fd protoreflect.FieldDescriptor, v protoreflect.Value) {
|
||||
m.mi.init()
|
||||
if fi, xt := m.mi.checkField(fd); fi != nil {
|
||||
m.messageInfo().init()
|
||||
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
|
||||
fi.set(m.pointer(), v)
|
||||
} else {
|
||||
m.mi.extensionMap(m.pointer()).Set(xt, v)
|
||||
m.messageInfo().extensionMap(m.pointer()).Set(xt, v)
|
||||
}
|
||||
}
|
||||
func (m *messageState) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Value {
|
||||
m.mi.init()
|
||||
if fi, xt := m.mi.checkField(fd); fi != nil {
|
||||
m.messageInfo().init()
|
||||
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
|
||||
return fi.mutable(m.pointer())
|
||||
} else {
|
||||
return m.mi.extensionMap(m.pointer()).Mutable(xt)
|
||||
return m.messageInfo().extensionMap(m.pointer()).Mutable(xt)
|
||||
}
|
||||
}
|
||||
func (m *messageState) NewMessage(fd protoreflect.FieldDescriptor) protoreflect.Message {
|
||||
m.mi.init()
|
||||
if fi, xt := m.mi.checkField(fd); fi != nil {
|
||||
m.messageInfo().init()
|
||||
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
|
||||
return fi.newMessage()
|
||||
} else {
|
||||
return xt.New().Message()
|
||||
}
|
||||
}
|
||||
func (m *messageState) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor {
|
||||
m.mi.init()
|
||||
if oi := m.mi.oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
|
||||
m.messageInfo().init()
|
||||
if oi := m.messageInfo().oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
|
||||
return od.Fields().ByNumber(oi.which(m.pointer()))
|
||||
}
|
||||
panic("invalid oneof descriptor")
|
||||
}
|
||||
func (m *messageState) GetUnknown() protoreflect.RawFields {
|
||||
m.mi.init()
|
||||
return m.mi.getUnknown(m.pointer())
|
||||
m.messageInfo().init()
|
||||
return m.messageInfo().getUnknown(m.pointer())
|
||||
}
|
||||
func (m *messageState) SetUnknown(b protoreflect.RawFields) {
|
||||
m.mi.init()
|
||||
m.mi.setUnknown(m.pointer(), b)
|
||||
m.messageInfo().init()
|
||||
m.messageInfo().setUnknown(m.pointer(), b)
|
||||
}
|
||||
|
||||
func (m *messageReflectWrapper) Descriptor() protoreflect.MessageDescriptor {
|
||||
return m.mi.PBType.Descriptor()
|
||||
return m.messageInfo().PBType.Descriptor()
|
||||
}
|
||||
func (m *messageReflectWrapper) Type() protoreflect.MessageType {
|
||||
return m.mi.PBType
|
||||
return m.messageInfo().PBType
|
||||
}
|
||||
func (m *messageReflectWrapper) New() protoreflect.Message {
|
||||
return m.mi.PBType.New()
|
||||
return m.messageInfo().PBType.New()
|
||||
}
|
||||
func (m *messageReflectWrapper) Interface() protoreflect.ProtoMessage {
|
||||
if m, ok := m.ProtoUnwrap().(protoreflect.ProtoMessage); ok {
|
||||
@ -131,11 +131,11 @@ func (m *messageReflectWrapper) Interface() protoreflect.ProtoMessage {
|
||||
return (*messageIfaceWrapper)(m)
|
||||
}
|
||||
func (m *messageReflectWrapper) ProtoUnwrap() interface{} {
|
||||
return m.pointer().AsIfaceOf(m.mi.GoType.Elem())
|
||||
return m.pointer().AsIfaceOf(m.messageInfo().GoType.Elem())
|
||||
}
|
||||
func (m *messageReflectWrapper) ProtoMethods() *protoiface.Methods {
|
||||
m.mi.init()
|
||||
return &m.mi.methods
|
||||
m.messageInfo().init()
|
||||
return &m.messageInfo().methods
|
||||
}
|
||||
|
||||
// ProtoMessageInfo is a pseudo-internal API for allowing the v1 code
|
||||
@ -144,80 +144,80 @@ func (m *messageReflectWrapper) ProtoMethods() *protoiface.Methods {
|
||||
// WARNING: This method is exempt from the compatibility promise and
|
||||
// may be removed in the future without warning.
|
||||
func (m *messageReflectWrapper) ProtoMessageInfo() *MessageInfo {
|
||||
return m.mi
|
||||
return m.messageInfo()
|
||||
}
|
||||
|
||||
func (m *messageReflectWrapper) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
|
||||
m.mi.init()
|
||||
for _, fi := range m.mi.fields {
|
||||
m.messageInfo().init()
|
||||
for _, fi := range m.messageInfo().fields {
|
||||
if fi.has(m.pointer()) {
|
||||
if !f(fi.fieldDesc, fi.get(m.pointer())) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
m.mi.extensionMap(m.pointer()).Range(f)
|
||||
m.messageInfo().extensionMap(m.pointer()).Range(f)
|
||||
}
|
||||
func (m *messageReflectWrapper) Has(fd protoreflect.FieldDescriptor) bool {
|
||||
m.mi.init()
|
||||
if fi, xt := m.mi.checkField(fd); fi != nil {
|
||||
m.messageInfo().init()
|
||||
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
|
||||
return fi.has(m.pointer())
|
||||
} else {
|
||||
return m.mi.extensionMap(m.pointer()).Has(xt)
|
||||
return m.messageInfo().extensionMap(m.pointer()).Has(xt)
|
||||
}
|
||||
}
|
||||
func (m *messageReflectWrapper) Clear(fd protoreflect.FieldDescriptor) {
|
||||
m.mi.init()
|
||||
if fi, xt := m.mi.checkField(fd); fi != nil {
|
||||
m.messageInfo().init()
|
||||
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
|
||||
fi.clear(m.pointer())
|
||||
} else {
|
||||
m.mi.extensionMap(m.pointer()).Clear(xt)
|
||||
m.messageInfo().extensionMap(m.pointer()).Clear(xt)
|
||||
}
|
||||
}
|
||||
func (m *messageReflectWrapper) Get(fd protoreflect.FieldDescriptor) protoreflect.Value {
|
||||
m.mi.init()
|
||||
if fi, xt := m.mi.checkField(fd); fi != nil {
|
||||
m.messageInfo().init()
|
||||
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
|
||||
return fi.get(m.pointer())
|
||||
} else {
|
||||
return m.mi.extensionMap(m.pointer()).Get(xt)
|
||||
return m.messageInfo().extensionMap(m.pointer()).Get(xt)
|
||||
}
|
||||
}
|
||||
func (m *messageReflectWrapper) Set(fd protoreflect.FieldDescriptor, v protoreflect.Value) {
|
||||
m.mi.init()
|
||||
if fi, xt := m.mi.checkField(fd); fi != nil {
|
||||
m.messageInfo().init()
|
||||
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
|
||||
fi.set(m.pointer(), v)
|
||||
} else {
|
||||
m.mi.extensionMap(m.pointer()).Set(xt, v)
|
||||
m.messageInfo().extensionMap(m.pointer()).Set(xt, v)
|
||||
}
|
||||
}
|
||||
func (m *messageReflectWrapper) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Value {
|
||||
m.mi.init()
|
||||
if fi, xt := m.mi.checkField(fd); fi != nil {
|
||||
m.messageInfo().init()
|
||||
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
|
||||
return fi.mutable(m.pointer())
|
||||
} else {
|
||||
return m.mi.extensionMap(m.pointer()).Mutable(xt)
|
||||
return m.messageInfo().extensionMap(m.pointer()).Mutable(xt)
|
||||
}
|
||||
}
|
||||
func (m *messageReflectWrapper) NewMessage(fd protoreflect.FieldDescriptor) protoreflect.Message {
|
||||
m.mi.init()
|
||||
if fi, xt := m.mi.checkField(fd); fi != nil {
|
||||
m.messageInfo().init()
|
||||
if fi, xt := m.messageInfo().checkField(fd); fi != nil {
|
||||
return fi.newMessage()
|
||||
} else {
|
||||
return xt.New().Message()
|
||||
}
|
||||
}
|
||||
func (m *messageReflectWrapper) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor {
|
||||
m.mi.init()
|
||||
if oi := m.mi.oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
|
||||
m.messageInfo().init()
|
||||
if oi := m.messageInfo().oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
|
||||
return od.Fields().ByNumber(oi.which(m.pointer()))
|
||||
}
|
||||
panic("invalid oneof descriptor")
|
||||
}
|
||||
func (m *messageReflectWrapper) GetUnknown() protoreflect.RawFields {
|
||||
m.mi.init()
|
||||
return m.mi.getUnknown(m.pointer())
|
||||
m.messageInfo().init()
|
||||
return m.messageInfo().getUnknown(m.pointer())
|
||||
}
|
||||
func (m *messageReflectWrapper) SetUnknown(b protoreflect.RawFields) {
|
||||
m.mi.init()
|
||||
m.mi.setUnknown(m.pointer(), b)
|
||||
m.messageInfo().init()
|
||||
m.messageInfo().setUnknown(m.pointer(), b)
|
||||
}
|
||||
|
@ -159,6 +159,7 @@ func (p pointer) SetPointer(v pointer) {
|
||||
|
||||
func (Export) MessageStateOf(p Pointer) *messageState { panic("not supported") }
|
||||
func (ms *messageState) pointer() pointer { panic("not supported") }
|
||||
func (ms *messageState) messageInfo() *MessageInfo { panic("not supported") }
|
||||
func (ms *messageState) LoadMessageInfo() *MessageInfo { panic("not supported") }
|
||||
func (ms *messageState) StoreMessageInfo(mi *MessageInfo) { panic("not supported") }
|
||||
|
||||
|
@ -147,6 +147,9 @@ func (ms *messageState) pointer() pointer {
|
||||
// Super-tricky - see documentation on MessageState.
|
||||
return pointer{p: unsafe.Pointer(ms)}
|
||||
}
|
||||
func (ms *messageState) messageInfo() *MessageInfo {
|
||||
return ms.LoadMessageInfo()
|
||||
}
|
||||
func (ms *messageState) LoadMessageInfo() *MessageInfo {
|
||||
return (*MessageInfo)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&ms.mi))))
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user