mirror of
https://github.com/protocolbuffers/protobuf-go.git
synced 2025-03-14 10:21:28 +00:00
proto: Implement proto.Equal fast-path
Also adds better benchmark cases for large message where some fields are actually populated. This change was previously done in Google internal cl/660848520. Change-Id: I682aae0c9c2850bfe7638de29ab743ad7d7b119a Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/609035 Reviewed-by: Christian Höppner <hoeppi@google.com> Reviewed-by: Cassondra Foesch <cfoesch@gmail.com> Reviewed-by: Michael Stapelberg <stapelberg@google.com> Reviewed-by: Damien Neil <dneil@google.com> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
This commit is contained in:
parent
013dd178dc
commit
03df6c145d
@ -189,6 +189,9 @@ func (mi *MessageInfo) makeCoderMethods(t reflect.Type, si structInfo) {
|
||||
if mi.methods.Merge == nil {
|
||||
mi.methods.Merge = mi.merge
|
||||
}
|
||||
if mi.methods.Equal == nil {
|
||||
mi.methods.Equal = equal
|
||||
}
|
||||
}
|
||||
|
||||
// getUnknownBytes returns a *[]byte for the unknown fields.
|
||||
|
224
internal/impl/equal.go
Normal file
224
internal/impl/equal.go
Normal file
@ -0,0 +1,224 @@
|
||||
// Copyright 2024 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package impl
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"google.golang.org/protobuf/encoding/protowire"
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
"google.golang.org/protobuf/runtime/protoiface"
|
||||
)
|
||||
|
||||
func equal(in protoiface.EqualInput) protoiface.EqualOutput {
|
||||
return protoiface.EqualOutput{Equal: equalMessage(in.MessageA, in.MessageB)}
|
||||
}
|
||||
|
||||
// equalMessage is a fast-path variant of protoreflect.equalMessage.
|
||||
// It takes advantage of the internal messageState type to avoid
|
||||
// unnecessary allocations, type assertions.
|
||||
func equalMessage(mx, my protoreflect.Message) bool {
|
||||
if mx == nil || my == nil {
|
||||
return mx == my
|
||||
}
|
||||
if mx.Descriptor() != my.Descriptor() {
|
||||
return false
|
||||
}
|
||||
|
||||
msx, ok := mx.(*messageState)
|
||||
if !ok {
|
||||
return protoreflect.ValueOfMessage(mx).Equal(protoreflect.ValueOfMessage(my))
|
||||
}
|
||||
msy, ok := my.(*messageState)
|
||||
if !ok {
|
||||
return protoreflect.ValueOfMessage(mx).Equal(protoreflect.ValueOfMessage(my))
|
||||
}
|
||||
|
||||
mi := msx.messageInfo()
|
||||
miy := msy.messageInfo()
|
||||
if mi != miy {
|
||||
return protoreflect.ValueOfMessage(mx).Equal(protoreflect.ValueOfMessage(my))
|
||||
}
|
||||
mi.init()
|
||||
// Compares regular fields
|
||||
// Modified Message.Range code that compares two messages of the same type
|
||||
// while going over the fields.
|
||||
for _, ri := range mi.rangeInfos {
|
||||
var fd protoreflect.FieldDescriptor
|
||||
var vx, vy protoreflect.Value
|
||||
|
||||
switch ri := ri.(type) {
|
||||
case *fieldInfo:
|
||||
hx := ri.has(msx.pointer())
|
||||
hy := ri.has(msy.pointer())
|
||||
if hx != hy {
|
||||
return false
|
||||
}
|
||||
if !hx {
|
||||
continue
|
||||
}
|
||||
fd = ri.fieldDesc
|
||||
vx = ri.get(msx.pointer())
|
||||
vy = ri.get(msy.pointer())
|
||||
case *oneofInfo:
|
||||
fnx := ri.which(msx.pointer())
|
||||
fny := ri.which(msy.pointer())
|
||||
if fnx != fny {
|
||||
return false
|
||||
}
|
||||
if fnx <= 0 {
|
||||
continue
|
||||
}
|
||||
fi := mi.fields[fnx]
|
||||
fd = fi.fieldDesc
|
||||
vx = fi.get(msx.pointer())
|
||||
vy = fi.get(msy.pointer())
|
||||
}
|
||||
|
||||
if !equalValue(fd, vx, vy) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Compare extensions.
|
||||
// This is more complicated because mx or my could have empty/nil extension maps,
|
||||
// however some populated extension map values are equal to nil extension maps.
|
||||
emx := mi.extensionMap(msx.pointer())
|
||||
emy := mi.extensionMap(msy.pointer())
|
||||
if emx != nil {
|
||||
for k, x := range *emx {
|
||||
xd := x.Type().TypeDescriptor()
|
||||
xv := x.Value()
|
||||
var y ExtensionField
|
||||
ok := false
|
||||
if emy != nil {
|
||||
y, ok = (*emy)[k]
|
||||
}
|
||||
// We need to treat empty lists as equal to nil values
|
||||
if emy == nil || !ok {
|
||||
if xd.IsList() && xv.List().Len() == 0 {
|
||||
continue
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if !equalValue(xd, xv, y.Value()) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
if emy != nil {
|
||||
// emy may have extensions emx does not have, need to check them as well
|
||||
for k, y := range *emy {
|
||||
if emx != nil {
|
||||
// emx has the field, so we already checked it
|
||||
if _, ok := (*emx)[k]; ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Empty lists are equal to nil
|
||||
if y.Type().TypeDescriptor().IsList() && y.Value().List().Len() == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Cant be equal if the extension is populated
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return equalUnknown(mx.GetUnknown(), my.GetUnknown())
|
||||
}
|
||||
|
||||
func equalValue(fd protoreflect.FieldDescriptor, vx, vy protoreflect.Value) bool {
|
||||
// slow path
|
||||
if fd.Kind() != protoreflect.MessageKind {
|
||||
return vx.Equal(vy)
|
||||
}
|
||||
|
||||
// fast path special cases
|
||||
if fd.IsMap() {
|
||||
if fd.MapValue().Kind() == protoreflect.MessageKind {
|
||||
return equalMessageMap(vx.Map(), vy.Map())
|
||||
}
|
||||
return vx.Equal(vy)
|
||||
}
|
||||
|
||||
if fd.IsList() {
|
||||
return equalMessageList(vx.List(), vy.List())
|
||||
}
|
||||
|
||||
return equalMessage(vx.Message(), vy.Message())
|
||||
}
|
||||
|
||||
// Mostly copied from protoreflect.equalMap.
|
||||
// This variant only works for messages as map types.
|
||||
// All other map types should be handled via Value.Equal.
|
||||
func equalMessageMap(mx, my protoreflect.Map) bool {
|
||||
if mx.Len() != my.Len() {
|
||||
return false
|
||||
}
|
||||
equal := true
|
||||
mx.Range(func(k protoreflect.MapKey, vx protoreflect.Value) bool {
|
||||
if !my.Has(k) {
|
||||
equal = false
|
||||
return false
|
||||
}
|
||||
vy := my.Get(k)
|
||||
equal = equalMessage(vx.Message(), vy.Message())
|
||||
return equal
|
||||
})
|
||||
return equal
|
||||
}
|
||||
|
||||
// Mostly copied from protoreflect.equalList.
|
||||
// The only change is the usage of equalImpl instead of protoreflect.equalValue.
|
||||
func equalMessageList(lx, ly protoreflect.List) bool {
|
||||
if lx.Len() != ly.Len() {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < lx.Len(); i++ {
|
||||
// We only operate on messages here since equalImpl will not call us in any other case.
|
||||
if !equalMessage(lx.Get(i).Message(), ly.Get(i).Message()) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// equalUnknown compares unknown fields by direct comparison on the raw bytes
|
||||
// of each individual field number.
|
||||
// Copied from protoreflect.equalUnknown.
|
||||
func equalUnknown(x, y protoreflect.RawFields) bool {
|
||||
if len(x) != len(y) {
|
||||
return false
|
||||
}
|
||||
if bytes.Equal([]byte(x), []byte(y)) {
|
||||
return true
|
||||
}
|
||||
|
||||
mx := make(map[protoreflect.FieldNumber]protoreflect.RawFields)
|
||||
my := make(map[protoreflect.FieldNumber]protoreflect.RawFields)
|
||||
for len(x) > 0 {
|
||||
fnum, _, n := protowire.ConsumeField(x)
|
||||
mx[fnum] = append(mx[fnum], x[:n]...)
|
||||
x = x[n:]
|
||||
}
|
||||
for len(y) > 0 {
|
||||
fnum, _, n := protowire.ConsumeField(y)
|
||||
my[fnum] = append(my[fnum], y[:n]...)
|
||||
y = y[n:]
|
||||
}
|
||||
if len(mx) != len(my) {
|
||||
return false
|
||||
}
|
||||
|
||||
for k, v1 := range mx {
|
||||
if v2, ok := my[k]; !ok || !bytes.Equal([]byte(v1), []byte(v2)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -142,6 +142,109 @@ message TestAllTypes {
|
||||
}
|
||||
}
|
||||
|
||||
message TestManyMessageFieldsMessage {
|
||||
optional TestAllTypes f1 = 1;
|
||||
optional TestAllTypes f2 = 2;
|
||||
optional TestAllTypes f3 = 3;
|
||||
optional TestAllTypes f4 = 4;
|
||||
optional TestAllTypes f5 = 5;
|
||||
optional TestAllTypes f6 = 6;
|
||||
optional TestAllTypes f7 = 7;
|
||||
optional TestAllTypes f8 = 8;
|
||||
optional TestAllTypes f9 = 9;
|
||||
optional TestAllTypes f10 = 10;
|
||||
optional TestAllTypes f11 = 11;
|
||||
optional TestAllTypes f12 = 12;
|
||||
optional TestAllTypes f13 = 13;
|
||||
optional TestAllTypes f14 = 14;
|
||||
optional TestAllTypes f15 = 15;
|
||||
optional TestAllTypes f16 = 16;
|
||||
optional TestAllTypes f17 = 17;
|
||||
optional TestAllTypes f18 = 18;
|
||||
optional TestAllTypes f19 = 19;
|
||||
optional TestAllTypes f20 = 20;
|
||||
optional TestAllTypes f21 = 21;
|
||||
optional TestAllTypes f22 = 22;
|
||||
optional TestAllTypes f23 = 23;
|
||||
optional TestAllTypes f24 = 24;
|
||||
optional TestAllTypes f25 = 25;
|
||||
optional TestAllTypes f26 = 26;
|
||||
optional TestAllTypes f27 = 27;
|
||||
optional TestAllTypes f28 = 28;
|
||||
optional TestAllTypes f29 = 29;
|
||||
optional TestAllTypes f30 = 30;
|
||||
optional TestAllTypes f31 = 31;
|
||||
optional TestAllTypes f32 = 32;
|
||||
optional TestAllTypes f33 = 33;
|
||||
optional TestAllTypes f34 = 34;
|
||||
optional TestAllTypes f35 = 35;
|
||||
optional TestAllTypes f36 = 36;
|
||||
optional TestAllTypes f37 = 37;
|
||||
optional TestAllTypes f38 = 38;
|
||||
optional TestAllTypes f39 = 39;
|
||||
optional TestAllTypes f40 = 40;
|
||||
optional TestAllTypes f41 = 41;
|
||||
optional TestAllTypes f42 = 42;
|
||||
optional TestAllTypes f43 = 43;
|
||||
optional TestAllTypes f44 = 44;
|
||||
optional TestAllTypes f45 = 45;
|
||||
optional TestAllTypes f46 = 46;
|
||||
optional TestAllTypes f47 = 47;
|
||||
optional TestAllTypes f48 = 48;
|
||||
optional TestAllTypes f49 = 49;
|
||||
optional TestAllTypes f50 = 50;
|
||||
optional TestAllTypes f51 = 51;
|
||||
optional TestAllTypes f52 = 52;
|
||||
optional TestAllTypes f53 = 53;
|
||||
optional TestAllTypes f54 = 54;
|
||||
optional TestAllTypes f55 = 55;
|
||||
optional TestAllTypes f56 = 56;
|
||||
optional TestAllTypes f57 = 57;
|
||||
optional TestAllTypes f58 = 58;
|
||||
optional TestAllTypes f59 = 59;
|
||||
optional TestAllTypes f60 = 60;
|
||||
optional TestAllTypes f61 = 61;
|
||||
optional TestAllTypes f62 = 62;
|
||||
optional TestAllTypes f63 = 63;
|
||||
optional TestAllTypes f64 = 64;
|
||||
optional TestAllTypes f65 = 65;
|
||||
optional TestAllTypes f66 = 66;
|
||||
optional TestAllTypes f67 = 67;
|
||||
optional TestAllTypes f68 = 68;
|
||||
optional TestAllTypes f69 = 69;
|
||||
optional TestAllTypes f70 = 70;
|
||||
optional TestAllTypes f71 = 71;
|
||||
optional TestAllTypes f72 = 72;
|
||||
optional TestAllTypes f73 = 73;
|
||||
optional TestAllTypes f74 = 74;
|
||||
optional TestAllTypes f75 = 75;
|
||||
optional TestAllTypes f76 = 76;
|
||||
optional TestAllTypes f77 = 77;
|
||||
optional TestAllTypes f78 = 78;
|
||||
optional TestAllTypes f79 = 79;
|
||||
optional TestAllTypes f80 = 80;
|
||||
optional TestAllTypes f81 = 81;
|
||||
optional TestAllTypes f82 = 82;
|
||||
optional TestAllTypes f83 = 83;
|
||||
optional TestAllTypes f84 = 84;
|
||||
optional TestAllTypes f85 = 85;
|
||||
optional TestAllTypes f86 = 86;
|
||||
optional TestAllTypes f87 = 87;
|
||||
optional TestAllTypes f88 = 88;
|
||||
optional TestAllTypes f89 = 89;
|
||||
optional TestAllTypes f90 = 90;
|
||||
optional TestAllTypes f91 = 91;
|
||||
optional TestAllTypes f92 = 92;
|
||||
optional TestAllTypes f93 = 93;
|
||||
optional TestAllTypes f94 = 94;
|
||||
optional TestAllTypes f95 = 95;
|
||||
optional TestAllTypes f96 = 96;
|
||||
optional TestAllTypes f97 = 97;
|
||||
optional TestAllTypes f98 = 98;
|
||||
optional TestAllTypes f99 = 99;
|
||||
optional TestAllTypes f100 = 100;
|
||||
}
|
||||
|
||||
message TestDeprecatedMessage {
|
||||
option deprecated = true;
|
||||
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"reflect"
|
||||
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
"google.golang.org/protobuf/runtime/protoiface"
|
||||
)
|
||||
|
||||
// Equal reports whether two messages are equal,
|
||||
@ -51,6 +52,14 @@ func Equal(x, y Message) bool {
|
||||
if mx.IsValid() != my.IsValid() {
|
||||
return false
|
||||
}
|
||||
|
||||
// Only one of the messages needs to implement the fast-path for it to work.
|
||||
pmx := protoMethods(mx)
|
||||
pmy := protoMethods(my)
|
||||
if pmx != nil && pmy != nil && pmx.Equal != nil && pmy.Equal != nil {
|
||||
return pmx.Equal(protoiface.EqualInput{MessageA: mx, MessageB: my}).Equal
|
||||
}
|
||||
|
||||
vx := protoreflect.ValueOfMessage(mx)
|
||||
vy := protoreflect.ValueOfMessage(my)
|
||||
return vx.Equal(vy)
|
||||
|
@ -1004,6 +1004,7 @@ func TestEqual(t *testing.T) {
|
||||
}
|
||||
|
||||
func BenchmarkEqualWithSmallEmpty(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
x := &testpb.ForeignMessage{}
|
||||
y := &testpb.ForeignMessage{}
|
||||
|
||||
@ -1014,6 +1015,7 @@ func BenchmarkEqualWithSmallEmpty(b *testing.B) {
|
||||
}
|
||||
|
||||
func BenchmarkEqualWithIdenticalPtrEmpty(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
x := &testpb.ForeignMessage{}
|
||||
|
||||
b.ResetTimer()
|
||||
@ -1023,8 +1025,31 @@ func BenchmarkEqualWithIdenticalPtrEmpty(b *testing.B) {
|
||||
}
|
||||
|
||||
func BenchmarkEqualWithLargeEmpty(b *testing.B) {
|
||||
x := &testpb.TestAllTypes{}
|
||||
y := &testpb.TestAllTypes{}
|
||||
b.ReportAllocs()
|
||||
x := &testpb.TestManyMessageFieldsMessage{
|
||||
F1: makeNested(2),
|
||||
F10: makeNested(2),
|
||||
F20: makeNested(2),
|
||||
F30: makeNested(2),
|
||||
F40: makeNested(2),
|
||||
F50: makeNested(2),
|
||||
F60: makeNested(2),
|
||||
F70: makeNested(2),
|
||||
F80: makeNested(2),
|
||||
F90: makeNested(2),
|
||||
}
|
||||
y := &testpb.TestManyMessageFieldsMessage{
|
||||
F1: makeNested(2),
|
||||
F10: makeNested(2),
|
||||
F20: makeNested(2),
|
||||
F30: makeNested(2),
|
||||
F40: makeNested(2),
|
||||
F50: makeNested(2),
|
||||
F60: makeNested(2),
|
||||
F70: makeNested(2),
|
||||
F80: makeNested(2),
|
||||
F90: makeNested(2),
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
@ -1044,6 +1069,7 @@ func makeNested(depth int) *testpb.TestAllTypes {
|
||||
}
|
||||
|
||||
func BenchmarkEqualWithDeeplyNestedEqual(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
x := makeNested(20)
|
||||
y := makeNested(20)
|
||||
|
||||
@ -1054,6 +1080,7 @@ func BenchmarkEqualWithDeeplyNestedEqual(b *testing.B) {
|
||||
}
|
||||
|
||||
func BenchmarkEqualWithDeeplyNestedDifferent(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
x := makeNested(20)
|
||||
y := makeNested(21)
|
||||
|
||||
@ -1064,6 +1091,7 @@ func BenchmarkEqualWithDeeplyNestedDifferent(b *testing.B) {
|
||||
}
|
||||
|
||||
func BenchmarkEqualWithDeeplyNestedIdenticalPtr(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
x := makeNested(20)
|
||||
|
||||
b.ResetTimer()
|
||||
|
@ -23,6 +23,7 @@ type (
|
||||
Unmarshal func(unmarshalInput) (unmarshalOutput, error)
|
||||
Merge func(mergeInput) mergeOutput
|
||||
CheckInitialized func(checkInitializedInput) (checkInitializedOutput, error)
|
||||
Equal func(equalInput) equalOutput
|
||||
}
|
||||
supportFlags = uint64
|
||||
sizeInput = struct {
|
||||
@ -75,4 +76,13 @@ type (
|
||||
checkInitializedOutput = struct {
|
||||
pragma.NoUnkeyedLiterals
|
||||
}
|
||||
equalInput = struct {
|
||||
pragma.NoUnkeyedLiterals
|
||||
MessageA Message
|
||||
MessageB Message
|
||||
}
|
||||
equalOutput = struct {
|
||||
pragma.NoUnkeyedLiterals
|
||||
Equal bool
|
||||
}
|
||||
)
|
||||
|
@ -39,6 +39,9 @@ type Methods = struct {
|
||||
|
||||
// CheckInitialized returns an error if any required fields in the message are not set.
|
||||
CheckInitialized func(CheckInitializedInput) (CheckInitializedOutput, error)
|
||||
|
||||
// Equal compares two messages and returns EqualOutput.Equal == true if they are equal.
|
||||
Equal func(EqualInput) EqualOutput
|
||||
}
|
||||
|
||||
// SupportFlags indicate support for optional features.
|
||||
@ -166,3 +169,18 @@ type CheckInitializedInput = struct {
|
||||
type CheckInitializedOutput = struct {
|
||||
pragma.NoUnkeyedLiterals
|
||||
}
|
||||
|
||||
// EqualInput is input to the Equal method.
|
||||
type EqualInput = struct {
|
||||
pragma.NoUnkeyedLiterals
|
||||
|
||||
MessageA protoreflect.Message
|
||||
MessageB protoreflect.Message
|
||||
}
|
||||
|
||||
// EqualOutput is output from the Equal method.
|
||||
type EqualOutput = struct {
|
||||
pragma.NoUnkeyedLiterals
|
||||
|
||||
Equal bool
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user