proto: Implement proto.Equal fast-path

Also adds better benchmark cases for large messages where some fields are
actually populated.

This change was previously done in Google internal cl/660848520.

Change-Id: I682aae0c9c2850bfe7638de29ab743ad7d7b119a
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/609035
Reviewed-by: Christian Höppner <hoeppi@google.com>
Reviewed-by: Cassondra Foesch <cfoesch@gmail.com>
Reviewed-by: Michael Stapelberg <stapelberg@google.com>
Reviewed-by: Damien Neil <dneil@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
This commit is contained in:
Reno Reckling 2024-08-27 14:07:11 +02:00 committed by Damien Neil
parent 013dd178dc
commit 03df6c145d
8 changed files with 2664 additions and 1005 deletions

View File

@ -189,6 +189,9 @@ func (mi *MessageInfo) makeCoderMethods(t reflect.Type, si structInfo) {
if mi.methods.Merge == nil {
mi.methods.Merge = mi.merge
}
if mi.methods.Equal == nil {
mi.methods.Equal = equal
}
}
// getUnknownBytes returns a *[]byte for the unknown fields.

224
internal/impl/equal.go Normal file
View File

@ -0,0 +1,224 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"bytes"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoiface"
)
// equal is the Equal entry point installed into a message's method table.
// It compares the two messages carried by the input using equalMessage and
// reports the outcome in EqualOutput.Equal.
func equal(in protoiface.EqualInput) protoiface.EqualOutput {
	var out protoiface.EqualOutput
	out.Equal = equalMessage(in.MessageA, in.MessageB)
	return out
}
// equalMessage is a fast-path variant of protoreflect.equalMessage.
// It works directly on the internal messageState type so that comparing two
// messages backed by the same MessageInfo avoids unnecessary allocations and
// type assertions.
//
// Two nil messages are equal; a nil and a non-nil message are not. Messages
// with different descriptors are never equal. When either message is not a
// *messageState, or the two do not share one MessageInfo, the comparison is
// delegated to protoreflect.Value.Equal.
func equalMessage(mx, my protoreflect.Message) bool {
	if mx == nil || my == nil {
		return mx == my
	}
	if mx.Descriptor() != my.Descriptor() {
		return false
	}

	stateX, fastX := mx.(*messageState)
	stateY, fastY := my.(*messageState)
	if !fastX || !fastY {
		return protoreflect.ValueOfMessage(mx).Equal(protoreflect.ValueOfMessage(my))
	}

	info := stateX.messageInfo()
	if info != stateY.messageInfo() {
		return protoreflect.ValueOfMessage(mx).Equal(protoreflect.ValueOfMessage(my))
	}
	// NOTE(review): init presumably populates rangeInfos and fields before
	// they are read below; confirm against MessageInfo.
	info.init()

	px := stateX.pointer()
	py := stateY.pointer()

	// Regular fields: walk the same range information Message.Range uses,
	// but look at both messages at every step.
	for _, entry := range info.rangeInfos {
		var (
			desc       protoreflect.FieldDescriptor
			valX, valY protoreflect.Value
		)
		switch r := entry.(type) {
		case *fieldInfo:
			present := r.has(px)
			if present != r.has(py) {
				return false
			}
			if !present {
				// Unset on both sides.
				continue
			}
			desc, valX, valY = r.fieldDesc, r.get(px), r.get(py)
		case *oneofInfo:
			num := r.which(px)
			if num != r.which(py) {
				return false
			}
			if num <= 0 {
				// A non-positive number is treated as "no member set".
				continue
			}
			member := info.fields[num]
			desc, valX, valY = member.fieldDesc, member.get(px), member.get(py)
		}
		if !equalValue(desc, valX, valY) {
			return false
		}
	}

	// Extensions. Either side may have a nil extension map, and an entry
	// holding an empty list must compare equal to a missing entry.
	extX := info.extensionMap(px)
	extY := info.extensionMap(py)
	if extX != nil {
		for num, fx := range *extX {
			descX := fx.Type().TypeDescriptor()
			valX := fx.Value()

			var (
				fy    ExtensionField
				found bool
			)
			if extY != nil {
				fy, found = (*extY)[num]
			}
			if !found {
				// Only an empty list is equal to an absent extension.
				if descX.IsList() && valX.List().Len() == 0 {
					continue
				}
				return false
			}
			if !equalValue(descX, valX, fy.Value()) {
				return false
			}
		}
	}
	if extY != nil {
		// my may carry extensions that mx lacks; those were not visited above.
		for num, fy := range *extY {
			if extX != nil {
				if _, seen := (*extX)[num]; seen {
					// Already compared in the first pass.
					continue
				}
			}
			// An empty list is equal to an absent extension.
			if fy.Type().TypeDescriptor().IsList() && fy.Value().List().Len() == 0 {
				continue
			}
			// A populated extension that exists on one side only.
			return false
		}
	}

	return equalUnknown(mx.GetUnknown(), my.GetUnknown())
}
// equalValue compares two values of the field described by fd.
//
// Fields whose kind is not MessageKind, and maps whose values are not
// messages, are compared with protoreflect.Value.Equal. Message-valued maps,
// message lists and singular messages are routed to the fast-path helpers.
func equalValue(fd protoreflect.FieldDescriptor, vx, vy protoreflect.Value) bool {
	if fd.Kind() != protoreflect.MessageKind {
		// Slow path: generic value comparison.
		return vx.Equal(vy)
	}

	// Fast-path special cases.
	switch {
	case fd.IsMap():
		if fd.MapValue().Kind() != protoreflect.MessageKind {
			return vx.Equal(vy)
		}
		return equalMessageMap(vx.Map(), vy.Map())
	case fd.IsList():
		return equalMessageList(vx.List(), vy.List())
	default:
		return equalMessage(vx.Message(), vy.Message())
	}
}
// equalMessageMap reports whether two maps with message values are equal.
// It is mostly copied from protoreflect.equalMap, except that the values are
// compared with equalMessage. It only works for maps whose values are
// messages; every other map type should go through Value.Equal.
func equalMessageMap(mx, my protoreflect.Map) bool {
	if mx.Len() != my.Len() {
		return false
	}
	same := true
	mx.Range(func(key protoreflect.MapKey, valX protoreflect.Value) bool {
		// Stop iterating at the first key that is missing or whose values differ.
		same = my.Has(key) && equalMessage(valX.Message(), my.Get(key).Message())
		return same
	})
	return same
}
// equalMessageList reports whether two lists of messages are equal.
// It is mostly copied from protoreflect.equalList; the only change is that
// elements are compared with equalMessage. Every element is read as a
// message, so callers must only pass lists of messages.
func equalMessageList(lx, ly protoreflect.List) bool {
	count := lx.Len()
	if count != ly.Len() {
		return false
	}
	for i := 0; i < count; i++ {
		ex := lx.Get(i).Message()
		ey := ly.Get(i).Message()
		if !equalMessage(ex, ey) {
			return false
		}
	}
	return true
}
// equalUnknown compares unknown fields by direct comparison on the raw bytes
// of each individual field number.
// Adapted from protoreflect.equalUnknown.
//
// Inputs of different length are never equal and byte-identical inputs are
// always equal. Otherwise the records of each input are concatenated per
// field number and compared number by number, so the relative order of
// records with different field numbers does not matter.
//
// If either input cannot be parsed as a sequence of wire-format fields the
// result is false: at that point the inputs are already known not to be
// byte-identical, and protowire.ConsumeField signals the failure with a
// negative length that must not be used as a slice bound.
func equalUnknown(x, y protoreflect.RawFields) bool {
	if len(x) != len(y) {
		return false
	}
	if bytes.Equal([]byte(x), []byte(y)) {
		return true
	}

	mx := make(map[protoreflect.FieldNumber]protoreflect.RawFields)
	my := make(map[protoreflect.FieldNumber]protoreflect.RawFields)
	for len(x) > 0 {
		fnum, _, n := protowire.ConsumeField(x)
		if n < 0 {
			// Malformed data; x[:n] would panic.
			return false
		}
		mx[fnum] = append(mx[fnum], x[:n]...)
		x = x[n:]
	}
	for len(y) > 0 {
		fnum, _, n := protowire.ConsumeField(y)
		if n < 0 {
			// Malformed data; y[:n] would panic.
			return false
		}
		my[fnum] = append(my[fnum], y[:n]...)
		y = y[n:]
	}
	if len(mx) != len(my) {
		return false
	}

	for k, v1 := range mx {
		if v2, ok := my[k]; !ok || !bytes.Equal([]byte(v1), []byte(v2)) {
			return false
		}
	}

	return true
}

File diff suppressed because it is too large Load Diff

View File

@ -142,6 +142,109 @@ message TestAllTypes {
}
}
// TestManyMessageFieldsMessage declares 100 optional TestAllTypes fields,
// f1 through f100, using field numbers 1 through 100.
// The proto.Equal benchmarks populate a sparse subset of these fields to
// measure comparison of a large message in which only some fields are set.
message TestManyMessageFieldsMessage {
optional TestAllTypes f1 = 1;
optional TestAllTypes f2 = 2;
optional TestAllTypes f3 = 3;
optional TestAllTypes f4 = 4;
optional TestAllTypes f5 = 5;
optional TestAllTypes f6 = 6;
optional TestAllTypes f7 = 7;
optional TestAllTypes f8 = 8;
optional TestAllTypes f9 = 9;
optional TestAllTypes f10 = 10;
optional TestAllTypes f11 = 11;
optional TestAllTypes f12 = 12;
optional TestAllTypes f13 = 13;
optional TestAllTypes f14 = 14;
optional TestAllTypes f15 = 15;
optional TestAllTypes f16 = 16;
optional TestAllTypes f17 = 17;
optional TestAllTypes f18 = 18;
optional TestAllTypes f19 = 19;
optional TestAllTypes f20 = 20;
optional TestAllTypes f21 = 21;
optional TestAllTypes f22 = 22;
optional TestAllTypes f23 = 23;
optional TestAllTypes f24 = 24;
optional TestAllTypes f25 = 25;
optional TestAllTypes f26 = 26;
optional TestAllTypes f27 = 27;
optional TestAllTypes f28 = 28;
optional TestAllTypes f29 = 29;
optional TestAllTypes f30 = 30;
optional TestAllTypes f31 = 31;
optional TestAllTypes f32 = 32;
optional TestAllTypes f33 = 33;
optional TestAllTypes f34 = 34;
optional TestAllTypes f35 = 35;
optional TestAllTypes f36 = 36;
optional TestAllTypes f37 = 37;
optional TestAllTypes f38 = 38;
optional TestAllTypes f39 = 39;
optional TestAllTypes f40 = 40;
optional TestAllTypes f41 = 41;
optional TestAllTypes f42 = 42;
optional TestAllTypes f43 = 43;
optional TestAllTypes f44 = 44;
optional TestAllTypes f45 = 45;
optional TestAllTypes f46 = 46;
optional TestAllTypes f47 = 47;
optional TestAllTypes f48 = 48;
optional TestAllTypes f49 = 49;
optional TestAllTypes f50 = 50;
optional TestAllTypes f51 = 51;
optional TestAllTypes f52 = 52;
optional TestAllTypes f53 = 53;
optional TestAllTypes f54 = 54;
optional TestAllTypes f55 = 55;
optional TestAllTypes f56 = 56;
optional TestAllTypes f57 = 57;
optional TestAllTypes f58 = 58;
optional TestAllTypes f59 = 59;
optional TestAllTypes f60 = 60;
optional TestAllTypes f61 = 61;
optional TestAllTypes f62 = 62;
optional TestAllTypes f63 = 63;
optional TestAllTypes f64 = 64;
optional TestAllTypes f65 = 65;
optional TestAllTypes f66 = 66;
optional TestAllTypes f67 = 67;
optional TestAllTypes f68 = 68;
optional TestAllTypes f69 = 69;
optional TestAllTypes f70 = 70;
optional TestAllTypes f71 = 71;
optional TestAllTypes f72 = 72;
optional TestAllTypes f73 = 73;
optional TestAllTypes f74 = 74;
optional TestAllTypes f75 = 75;
optional TestAllTypes f76 = 76;
optional TestAllTypes f77 = 77;
optional TestAllTypes f78 = 78;
optional TestAllTypes f79 = 79;
optional TestAllTypes f80 = 80;
optional TestAllTypes f81 = 81;
optional TestAllTypes f82 = 82;
optional TestAllTypes f83 = 83;
optional TestAllTypes f84 = 84;
optional TestAllTypes f85 = 85;
optional TestAllTypes f86 = 86;
optional TestAllTypes f87 = 87;
optional TestAllTypes f88 = 88;
optional TestAllTypes f89 = 89;
optional TestAllTypes f90 = 90;
optional TestAllTypes f91 = 91;
optional TestAllTypes f92 = 92;
optional TestAllTypes f93 = 93;
optional TestAllTypes f94 = 94;
optional TestAllTypes f95 = 95;
optional TestAllTypes f96 = 96;
optional TestAllTypes f97 = 97;
optional TestAllTypes f98 = 98;
optional TestAllTypes f99 = 99;
optional TestAllTypes f100 = 100;
}
message TestDeprecatedMessage {
option deprecated = true;

View File

@ -8,6 +8,7 @@ import (
"reflect"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoiface"
)
// Equal reports whether two messages are equal,
@ -51,6 +52,14 @@ func Equal(x, y Message) bool {
if mx.IsValid() != my.IsValid() {
return false
}
// Take the fast-path only when both messages provide an Equal method;
// the implementation belonging to mx performs the comparison.
pmx := protoMethods(mx)
pmy := protoMethods(my)
if pmx != nil && pmy != nil && pmx.Equal != nil && pmy.Equal != nil {
return pmx.Equal(protoiface.EqualInput{MessageA: mx, MessageB: my}).Equal
}
vx := protoreflect.ValueOfMessage(mx)
vy := protoreflect.ValueOfMessage(my)
return vx.Equal(vy)

View File

@ -1004,6 +1004,7 @@ func TestEqual(t *testing.T) {
}
func BenchmarkEqualWithSmallEmpty(b *testing.B) {
b.ReportAllocs()
x := &testpb.ForeignMessage{}
y := &testpb.ForeignMessage{}
@ -1014,6 +1015,7 @@ func BenchmarkEqualWithSmallEmpty(b *testing.B) {
}
func BenchmarkEqualWithIdenticalPtrEmpty(b *testing.B) {
b.ReportAllocs()
x := &testpb.ForeignMessage{}
b.ResetTimer()
@ -1023,8 +1025,31 @@ func BenchmarkEqualWithIdenticalPtrEmpty(b *testing.B) {
}
func BenchmarkEqualWithLargeEmpty(b *testing.B) {
x := &testpb.TestAllTypes{}
y := &testpb.TestAllTypes{}
b.ReportAllocs()
x := &testpb.TestManyMessageFieldsMessage{
F1: makeNested(2),
F10: makeNested(2),
F20: makeNested(2),
F30: makeNested(2),
F40: makeNested(2),
F50: makeNested(2),
F60: makeNested(2),
F70: makeNested(2),
F80: makeNested(2),
F90: makeNested(2),
}
y := &testpb.TestManyMessageFieldsMessage{
F1: makeNested(2),
F10: makeNested(2),
F20: makeNested(2),
F30: makeNested(2),
F40: makeNested(2),
F50: makeNested(2),
F60: makeNested(2),
F70: makeNested(2),
F80: makeNested(2),
F90: makeNested(2),
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
@ -1044,6 +1069,7 @@ func makeNested(depth int) *testpb.TestAllTypes {
}
func BenchmarkEqualWithDeeplyNestedEqual(b *testing.B) {
b.ReportAllocs()
x := makeNested(20)
y := makeNested(20)
@ -1054,6 +1080,7 @@ func BenchmarkEqualWithDeeplyNestedEqual(b *testing.B) {
}
func BenchmarkEqualWithDeeplyNestedDifferent(b *testing.B) {
b.ReportAllocs()
x := makeNested(20)
y := makeNested(21)
@ -1064,6 +1091,7 @@ func BenchmarkEqualWithDeeplyNestedDifferent(b *testing.B) {
}
func BenchmarkEqualWithDeeplyNestedIdenticalPtr(b *testing.B) {
b.ReportAllocs()
x := makeNested(20)
b.ResetTimer()

View File

@ -23,6 +23,7 @@ type (
Unmarshal func(unmarshalInput) (unmarshalOutput, error)
Merge func(mergeInput) mergeOutput
CheckInitialized func(checkInitializedInput) (checkInitializedOutput, error)
Equal func(equalInput) equalOutput
}
supportFlags = uint64
sizeInput = struct {
@ -75,4 +76,13 @@ type (
checkInitializedOutput = struct {
pragma.NoUnkeyedLiterals
}
equalInput = struct {
pragma.NoUnkeyedLiterals
MessageA Message
MessageB Message
}
equalOutput = struct {
pragma.NoUnkeyedLiterals
Equal bool
}
)

View File

@ -39,6 +39,9 @@ type Methods = struct {
// CheckInitialized returns an error if any required fields in the message are not set.
CheckInitialized func(CheckInitializedInput) (CheckInitializedOutput, error)
// Equal compares two messages and returns EqualOutput.Equal == true if they are equal.
Equal func(EqualInput) EqualOutput
}
// SupportFlags indicate support for optional features.
@ -166,3 +169,18 @@ type CheckInitializedInput = struct {
type CheckInitializedOutput = struct {
pragma.NoUnkeyedLiterals
}
// EqualInput is input to the Equal method.
// It carries the two messages that are to be compared with each other.
type EqualInput = struct {
pragma.NoUnkeyedLiterals
// MessageA is the first message of the comparison.
MessageA protoreflect.Message
// MessageB is the second message of the comparison.
MessageB protoreflect.Message
}
// EqualOutput is output from the Equal method.
type EqualOutput = struct {
pragma.NoUnkeyedLiterals
// Equal reports whether the two messages given in EqualInput are equal.
Equal bool
}