diff --git a/proto/equal.go b/proto/equal.go index 1aca6703..44f840a8 100644 --- a/proto/equal.go +++ b/proto/equal.go @@ -6,6 +6,7 @@ package proto import ( "bytes" + "math" "reflect" "google.golang.org/protobuf/internal/encoding/wire" @@ -13,14 +14,16 @@ import ( ) // Equal reports whether two messages are equal. +// If two messages marshal to the same bytes under deterministic serialization, +// then Equal is guaranteed to report true. // // Two messages are equal if they belong to the same message descriptor, // have the same set of populated known and extension field values, // and the same set of unknown fields values. // // Scalar values are compared with the equivalent of the == operator in Go, -// except bytes values which are compared using bytes.Equal. -// Note that this means that floating point NaNs are considered inequal. +// except bytes values which are compared using bytes.Equal and +// floating point values which specially treat NaNs as equal. // Message values are compared by recursively calling Equal. // Lists are equal if each element value is also equal. // Maps are equal if they have the same set of keys, where the pair of values @@ -104,6 +107,13 @@ func equalValue(fd pref.FieldDescriptor, x, y pref.Value) bool { return equalMessage(x.Message(), y.Message()) case fd.Kind() == pref.BytesKind: return bytes.Equal(x.Bytes(), y.Bytes()) + case fd.Kind() == pref.FloatKind, fd.Kind() == pref.DoubleKind: + fx := x.Float() + fy := y.Float() + if math.IsNaN(fx) || math.IsNaN(fy) { + return math.IsNaN(fx) && math.IsNaN(fy) + } + return fx == fy default: return x.Interface() == y.Interface() } diff --git a/proto/equal_test.go b/proto/equal_test.go index 40507a9f..3afca838 100644 --- a/proto/equal_test.go +++ b/proto/equal_test.go @@ -5,6 +5,7 @@ package proto_test import ( + "math" "testing" "google.golang.org/protobuf/internal/encoding/pack" @@ -74,6 +75,14 @@ var inequalities = []struct{ a, b proto.Message }{ &testpb.TestAllTypes{OptionalDouble: proto.Float64(1)}, &testpb.TestAllTypes{OptionalDouble: proto.Float64(2)}, }, + { + &testpb.TestAllTypes{OptionalFloat: proto.Float32(float32(math.NaN()))}, + &testpb.TestAllTypes{OptionalFloat: proto.Float32(0)}, + }, + { + &testpb.TestAllTypes{OptionalDouble: proto.Float64(float64(math.NaN()))}, + &testpb.TestAllTypes{OptionalDouble: proto.Float64(0)}, + }, { &testpb.TestAllTypes{OptionalBool: proto.Bool(true)}, &testpb.TestAllTypes{OptionalBool: proto.Bool(false)},