From 2aea614c5eed1edc254edd3321506f4014191150 Mon Sep 17 00:00:00 2001 From: Joe Tsai Date: Wed, 31 Jul 2019 12:27:30 -0700 Subject: [PATCH] internal/impl: fix race over messageState.mi The messageState.mi field is atomically checked and set in generated code to the *MessageInfo associated with that message. However, the messageState type accesses the mi field without any atomic loads, thus being a potential race. We fix this by always calling a messageInfo method that performs a atomic.LoadPointer on the *MessageInfo. There is no performance effect from this change on x86 since an atomic.LoadPointer is identical to a MOV instruction. From an assembly perspective, there was no memory race previously. However, the lack of an atomic.LoadPointer meant that the compiler could in theory reorder the "normal" load to produce truly racy code. Change-Id: I8afefaf35c1916872781abc0239cbb63d62edf16 Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/189017 Reviewed-by: Damien Neil --- internal/cmd/generate-types/impl.go | 66 +++++++------- internal/impl/message_reflect.go | 3 +- internal/impl/message_reflect_gen.go | 132 +++++++++++++-------------- internal/impl/pointer_reflect.go | 1 + internal/impl/pointer_unsafe.go | 3 + 5 files changed, 105 insertions(+), 100 deletions(-) diff --git a/internal/cmd/generate-types/impl.go b/internal/cmd/generate-types/impl.go index 728e758e..085710d4 100644 --- a/internal/cmd/generate-types/impl.go +++ b/internal/cmd/generate-types/impl.go @@ -565,13 +565,13 @@ func generateImplMessage() string { var implMessageTemplate = template.Must(template.New("").Parse(` {{range . -}} func (m *{{.}}) Descriptor() protoreflect.MessageDescriptor { - return m.mi.PBType.Descriptor() + return m.messageInfo().PBType.Descriptor() } func (m *{{.}}) Type() protoreflect.MessageType { - return m.mi.PBType + return m.messageInfo().PBType } func (m *{{.}}) New() protoreflect.Message { - return m.mi.PBType.New() + return m.messageInfo().PBType.New() } func (m *{{.}}) Interface() protoreflect.ProtoMessage { {{if eq . "messageState" -}} @@ -584,11 +584,11 @@ func (m *{{.}}) Interface() protoreflect.ProtoMessage { {{- end -}} } func (m *{{.}}) ProtoUnwrap() interface{} { - return m.pointer().AsIfaceOf(m.mi.GoType.Elem()) + return m.pointer().AsIfaceOf(m.messageInfo().GoType.Elem()) } func (m *{{.}}) ProtoMethods() *protoiface.Methods { - m.mi.init() - return &m.mi.methods + m.messageInfo().init() + return &m.messageInfo().methods } // ProtoMessageInfo is a pseudo-internal API for allowing the v1 code @@ -597,82 +597,82 @@ func (m *{{.}}) ProtoMethods() *protoiface.Methods { // WARNING: This method is exempt from the compatibility promise and // may be removed in the future without warning. func (m *{{.}}) ProtoMessageInfo() *MessageInfo { - return m.mi + return m.messageInfo() } func (m *{{.}}) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) { - m.mi.init() - for _, fi := range m.mi.fields { + m.messageInfo().init() + for _, fi := range m.messageInfo().fields { if fi.has(m.pointer()) { if !f(fi.fieldDesc, fi.get(m.pointer())) { return } } } - m.mi.extensionMap(m.pointer()).Range(f) + m.messageInfo().extensionMap(m.pointer()).Range(f) } func (m *{{.}}) Has(fd protoreflect.FieldDescriptor) bool { - m.mi.init() - if fi, xt := m.mi.checkField(fd); fi != nil { + m.messageInfo().init() + if fi, xt := m.messageInfo().checkField(fd); fi != nil { return fi.has(m.pointer()) } else { - return m.mi.extensionMap(m.pointer()).Has(xt) + return m.messageInfo().extensionMap(m.pointer()).Has(xt) } } func (m *{{.}}) Clear(fd protoreflect.FieldDescriptor) { - m.mi.init() - if fi, xt := m.mi.checkField(fd); fi != nil { + m.messageInfo().init() + if fi, xt := m.messageInfo().checkField(fd); fi != nil { fi.clear(m.pointer()) } else { - m.mi.extensionMap(m.pointer()).Clear(xt) + m.messageInfo().extensionMap(m.pointer()).Clear(xt) } } func (m *{{.}}) Get(fd protoreflect.FieldDescriptor) protoreflect.Value { - m.mi.init() - if fi, xt := m.mi.checkField(fd); fi != nil { + m.messageInfo().init() + if fi, xt := m.messageInfo().checkField(fd); fi != nil { return fi.get(m.pointer()) } else { - return m.mi.extensionMap(m.pointer()).Get(xt) + return m.messageInfo().extensionMap(m.pointer()).Get(xt) } } func (m *{{.}}) Set(fd protoreflect.FieldDescriptor, v protoreflect.Value) { - m.mi.init() - if fi, xt := m.mi.checkField(fd); fi != nil { + m.messageInfo().init() + if fi, xt := m.messageInfo().checkField(fd); fi != nil { fi.set(m.pointer(), v) } else { - m.mi.extensionMap(m.pointer()).Set(xt, v) + m.messageInfo().extensionMap(m.pointer()).Set(xt, v) } } func (m *{{.}}) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Value { - m.mi.init() - if fi, xt := m.mi.checkField(fd); fi != nil { + m.messageInfo().init() + if fi, xt := m.messageInfo().checkField(fd); fi != nil { return fi.mutable(m.pointer()) } else { - return m.mi.extensionMap(m.pointer()).Mutable(xt) + return m.messageInfo().extensionMap(m.pointer()).Mutable(xt) } } func (m *{{.}}) NewMessage(fd protoreflect.FieldDescriptor) protoreflect.Message { - m.mi.init() - if fi, xt := m.mi.checkField(fd); fi != nil { + m.messageInfo().init() + if fi, xt := m.messageInfo().checkField(fd); fi != nil { return fi.newMessage() } else { return xt.New().Message() } } func (m *{{.}}) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor { - m.mi.init() - if oi := m.mi.oneofs[od.Name()]; oi != nil && oi.oneofDesc == od { + m.messageInfo().init() + if oi := m.messageInfo().oneofs[od.Name()]; oi != nil && oi.oneofDesc == od { return od.Fields().ByNumber(oi.which(m.pointer())) } panic("invalid oneof descriptor") } func (m *{{.}}) GetUnknown() protoreflect.RawFields { - m.mi.init() - return m.mi.getUnknown(m.pointer()) + m.messageInfo().init() + return m.messageInfo().getUnknown(m.pointer()) } func (m *{{.}}) SetUnknown(b protoreflect.RawFields) { - m.mi.init() - m.mi.setUnknown(m.pointer(), b) + m.messageInfo().init() + m.messageInfo().setUnknown(m.pointer(), b) } {{end}} diff --git a/internal/impl/message_reflect.go b/internal/impl/message_reflect.go index c79b55ab..fd5c8a97 100644 --- a/internal/impl/message_reflect.go +++ b/internal/impl/message_reflect.go @@ -106,7 +106,8 @@ func (mi *MessageInfo) MessageOf(m interface{}) pref.Message { return &messageReflectWrapper{p, mi} } -func (m *messageReflectWrapper) pointer() pointer { return m.p } +func (m *messageReflectWrapper) pointer() pointer { return m.p } +func (m *messageReflectWrapper) messageInfo() *MessageInfo { return m.mi } func (m *messageIfaceWrapper) ProtoReflect() pref.Message { return (*messageReflectWrapper)(m) diff --git a/internal/impl/message_reflect_gen.go b/internal/impl/message_reflect_gen.go index 40447a68..e2f6d17a 100644 --- a/internal/impl/message_reflect_gen.go +++ b/internal/impl/message_reflect_gen.go @@ -12,23 +12,23 @@ import ( ) func (m *messageState) Descriptor() protoreflect.MessageDescriptor { - return m.mi.PBType.Descriptor() + return m.messageInfo().PBType.Descriptor() } func (m *messageState) Type() protoreflect.MessageType { - return m.mi.PBType + return m.messageInfo().PBType } func (m *messageState) New() protoreflect.Message { - return m.mi.PBType.New() + return m.messageInfo().PBType.New() } func (m *messageState) Interface() protoreflect.ProtoMessage { return m.ProtoUnwrap().(protoreflect.ProtoMessage) } func (m *messageState) ProtoUnwrap() interface{} { - return m.pointer().AsIfaceOf(m.mi.GoType.Elem()) + return m.pointer().AsIfaceOf(m.messageInfo().GoType.Elem()) } func (m *messageState) ProtoMethods() *protoiface.Methods { - m.mi.init() - return &m.mi.methods + m.messageInfo().init() + return &m.messageInfo().methods } // ProtoMessageInfo is a pseudo-internal API for allowing the v1 code @@ -37,92 +37,92 @@ func (m *messageState) ProtoMethods() *protoiface.Methods { // WARNING: This method is exempt from the compatibility promise and // may be removed in the future without warning. func (m *messageState) ProtoMessageInfo() *MessageInfo { - return m.mi + return m.messageInfo() } func (m *messageState) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) { - m.mi.init() - for _, fi := range m.mi.fields { + m.messageInfo().init() + for _, fi := range m.messageInfo().fields { if fi.has(m.pointer()) { if !f(fi.fieldDesc, fi.get(m.pointer())) { return } } } - m.mi.extensionMap(m.pointer()).Range(f) + m.messageInfo().extensionMap(m.pointer()).Range(f) } func (m *messageState) Has(fd protoreflect.FieldDescriptor) bool { - m.mi.init() - if fi, xt := m.mi.checkField(fd); fi != nil { + m.messageInfo().init() + if fi, xt := m.messageInfo().checkField(fd); fi != nil { return fi.has(m.pointer()) } else { - return m.mi.extensionMap(m.pointer()).Has(xt) + return m.messageInfo().extensionMap(m.pointer()).Has(xt) } } func (m *messageState) Clear(fd protoreflect.FieldDescriptor) { - m.mi.init() - if fi, xt := m.mi.checkField(fd); fi != nil { + m.messageInfo().init() + if fi, xt := m.messageInfo().checkField(fd); fi != nil { fi.clear(m.pointer()) } else { - m.mi.extensionMap(m.pointer()).Clear(xt) + m.messageInfo().extensionMap(m.pointer()).Clear(xt) } } func (m *messageState) Get(fd protoreflect.FieldDescriptor) protoreflect.Value { - m.mi.init() - if fi, xt := m.mi.checkField(fd); fi != nil { + m.messageInfo().init() + if fi, xt := m.messageInfo().checkField(fd); fi != nil { return fi.get(m.pointer()) } else { - return m.mi.extensionMap(m.pointer()).Get(xt) + return m.messageInfo().extensionMap(m.pointer()).Get(xt) } } func (m *messageState) Set(fd protoreflect.FieldDescriptor, v protoreflect.Value) { - m.mi.init() - if fi, xt := m.mi.checkField(fd); fi != nil { + m.messageInfo().init() + if fi, xt := m.messageInfo().checkField(fd); fi != nil { fi.set(m.pointer(), v) } else { - m.mi.extensionMap(m.pointer()).Set(xt, v) + m.messageInfo().extensionMap(m.pointer()).Set(xt, v) } } func (m *messageState) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Value { - m.mi.init() - if fi, xt := m.mi.checkField(fd); fi != nil { + m.messageInfo().init() + if fi, xt := m.messageInfo().checkField(fd); fi != nil { return fi.mutable(m.pointer()) } else { - return m.mi.extensionMap(m.pointer()).Mutable(xt) + return m.messageInfo().extensionMap(m.pointer()).Mutable(xt) } } func (m *messageState) NewMessage(fd protoreflect.FieldDescriptor) protoreflect.Message { - m.mi.init() - if fi, xt := m.mi.checkField(fd); fi != nil { + m.messageInfo().init() + if fi, xt := m.messageInfo().checkField(fd); fi != nil { return fi.newMessage() } else { return xt.New().Message() } } func (m *messageState) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor { - m.mi.init() - if oi := m.mi.oneofs[od.Name()]; oi != nil && oi.oneofDesc == od { + m.messageInfo().init() + if oi := m.messageInfo().oneofs[od.Name()]; oi != nil && oi.oneofDesc == od { return od.Fields().ByNumber(oi.which(m.pointer())) } panic("invalid oneof descriptor") } func (m *messageState) GetUnknown() protoreflect.RawFields { - m.mi.init() - return m.mi.getUnknown(m.pointer()) + m.messageInfo().init() + return m.messageInfo().getUnknown(m.pointer()) } func (m *messageState) SetUnknown(b protoreflect.RawFields) { - m.mi.init() - m.mi.setUnknown(m.pointer(), b) + m.messageInfo().init() + m.messageInfo().setUnknown(m.pointer(), b) } func (m *messageReflectWrapper) Descriptor() protoreflect.MessageDescriptor { - return m.mi.PBType.Descriptor() + return m.messageInfo().PBType.Descriptor() } func (m *messageReflectWrapper) Type() protoreflect.MessageType { - return m.mi.PBType + return m.messageInfo().PBType } func (m *messageReflectWrapper) New() protoreflect.Message { - return m.mi.PBType.New() + return m.messageInfo().PBType.New() } func (m *messageReflectWrapper) Interface() protoreflect.ProtoMessage { if m, ok := m.ProtoUnwrap().(protoreflect.ProtoMessage); ok { @@ -131,11 +131,11 @@ func (m *messageReflectWrapper) Interface() protoreflect.ProtoMessage { return (*messageIfaceWrapper)(m) } func (m *messageReflectWrapper) ProtoUnwrap() interface{} { - return m.pointer().AsIfaceOf(m.mi.GoType.Elem()) + return m.pointer().AsIfaceOf(m.messageInfo().GoType.Elem()) } func (m *messageReflectWrapper) ProtoMethods() *protoiface.Methods { - m.mi.init() - return &m.mi.methods + m.messageInfo().init() + return &m.messageInfo().methods } // ProtoMessageInfo is a pseudo-internal API for allowing the v1 code @@ -144,80 +144,80 @@ func (m *messageReflectWrapper) ProtoMethods() *protoiface.Methods { // WARNING: This method is exempt from the compatibility promise and // may be removed in the future without warning. func (m *messageReflectWrapper) ProtoMessageInfo() *MessageInfo { - return m.mi + return m.messageInfo() } func (m *messageReflectWrapper) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) { - m.mi.init() - for _, fi := range m.mi.fields { + m.messageInfo().init() + for _, fi := range m.messageInfo().fields { if fi.has(m.pointer()) { if !f(fi.fieldDesc, fi.get(m.pointer())) { return } } } - m.mi.extensionMap(m.pointer()).Range(f) + m.messageInfo().extensionMap(m.pointer()).Range(f) } func (m *messageReflectWrapper) Has(fd protoreflect.FieldDescriptor) bool { - m.mi.init() - if fi, xt := m.mi.checkField(fd); fi != nil { + m.messageInfo().init() + if fi, xt := m.messageInfo().checkField(fd); fi != nil { return fi.has(m.pointer()) } else { - return m.mi.extensionMap(m.pointer()).Has(xt) + return m.messageInfo().extensionMap(m.pointer()).Has(xt) } } func (m *messageReflectWrapper) Clear(fd protoreflect.FieldDescriptor) { - m.mi.init() - if fi, xt := m.mi.checkField(fd); fi != nil { + m.messageInfo().init() + if fi, xt := m.messageInfo().checkField(fd); fi != nil { fi.clear(m.pointer()) } else { - m.mi.extensionMap(m.pointer()).Clear(xt) + m.messageInfo().extensionMap(m.pointer()).Clear(xt) } } func (m *messageReflectWrapper) Get(fd protoreflect.FieldDescriptor) protoreflect.Value { - m.mi.init() - if fi, xt := m.mi.checkField(fd); fi != nil { + m.messageInfo().init() + if fi, xt := m.messageInfo().checkField(fd); fi != nil { return fi.get(m.pointer()) } else { - return m.mi.extensionMap(m.pointer()).Get(xt) + return m.messageInfo().extensionMap(m.pointer()).Get(xt) } } func (m *messageReflectWrapper) Set(fd protoreflect.FieldDescriptor, v protoreflect.Value) { - m.mi.init() - if fi, xt := m.mi.checkField(fd); fi != nil { + m.messageInfo().init() + if fi, xt := m.messageInfo().checkField(fd); fi != nil { fi.set(m.pointer(), v) } else { - m.mi.extensionMap(m.pointer()).Set(xt, v) + m.messageInfo().extensionMap(m.pointer()).Set(xt, v) } } func (m *messageReflectWrapper) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Value { - m.mi.init() - if fi, xt := m.mi.checkField(fd); fi != nil { + m.messageInfo().init() + if fi, xt := m.messageInfo().checkField(fd); fi != nil { return fi.mutable(m.pointer()) } else { - return m.mi.extensionMap(m.pointer()).Mutable(xt) + return m.messageInfo().extensionMap(m.pointer()).Mutable(xt) } } func (m *messageReflectWrapper) NewMessage(fd protoreflect.FieldDescriptor) protoreflect.Message { - m.mi.init() - if fi, xt := m.mi.checkField(fd); fi != nil { + m.messageInfo().init() + if fi, xt := m.messageInfo().checkField(fd); fi != nil { return fi.newMessage() } else { return xt.New().Message() } } func (m *messageReflectWrapper) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor { - m.mi.init() - if oi := m.mi.oneofs[od.Name()]; oi != nil && oi.oneofDesc == od { + m.messageInfo().init() + if oi := m.messageInfo().oneofs[od.Name()]; oi != nil && oi.oneofDesc == od { return od.Fields().ByNumber(oi.which(m.pointer())) } panic("invalid oneof descriptor") } func (m *messageReflectWrapper) GetUnknown() protoreflect.RawFields { - m.mi.init() - return m.mi.getUnknown(m.pointer()) + m.messageInfo().init() + return m.messageInfo().getUnknown(m.pointer()) } func (m *messageReflectWrapper) SetUnknown(b protoreflect.RawFields) { - m.mi.init() - m.mi.setUnknown(m.pointer(), b) + m.messageInfo().init() + m.messageInfo().setUnknown(m.pointer(), b) } diff --git a/internal/impl/pointer_reflect.go b/internal/impl/pointer_reflect.go index d076b9da..7b4510a3 100644 --- a/internal/impl/pointer_reflect.go +++ b/internal/impl/pointer_reflect.go @@ -159,6 +159,7 @@ func (p pointer) SetPointer(v pointer) { func (Export) MessageStateOf(p Pointer) *messageState { panic("not supported") } func (ms *messageState) pointer() pointer { panic("not supported") } +func (ms *messageState) messageInfo() *MessageInfo { panic("not supported") } func (ms *messageState) LoadMessageInfo() *MessageInfo { panic("not supported") } func (ms *messageState) StoreMessageInfo(mi *MessageInfo) { panic("not supported") } diff --git a/internal/impl/pointer_unsafe.go b/internal/impl/pointer_unsafe.go index 3f53cbc5..b7f2b1e0 100644 --- a/internal/impl/pointer_unsafe.go +++ b/internal/impl/pointer_unsafe.go @@ -147,6 +147,9 @@ func (ms *messageState) pointer() pointer { // Super-tricky - see documentation on MessageState. return pointer{p: unsafe.Pointer(ms)} } +func (ms *messageState) messageInfo() *MessageInfo { + return ms.LoadMessageInfo() +} func (ms *messageState) LoadMessageInfo() *MessageInfo { return (*MessageInfo)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&ms.mi)))) }