diff --git a/testing/protocmp/util_test.go b/testing/protocmp/util_test.go index f9da7e52..6f41ba4f 100644 --- a/testing/protocmp/util_test.go +++ b/testing/protocmp/util_test.go @@ -145,6 +145,65 @@ func TestEqual(t *testing.T) { want: true, }}...) + // Test message values. + tests = append(tests, []test{{ + x: testpb.TestAllTypes{OptionalSint64: proto.Int64(1)}, + y: testpb.TestAllTypes{OptionalSint64: proto.Int64(1)}, + opts: cmp.Options{Transform()}, + want: true, + }, { + x: testpb.TestAllTypes{OptionalSint64: proto.Int64(1)}, + y: testpb.TestAllTypes{OptionalSint64: proto.Int64(2)}, + opts: cmp.Options{Transform()}, + want: false, + }, { + x: struct{ M testpb.TestAllTypes }{M: testpb.TestAllTypes{OptionalSint64: proto.Int64(1)}}, + y: struct{ M testpb.TestAllTypes }{M: testpb.TestAllTypes{OptionalSint64: proto.Int64(1)}}, + opts: cmp.Options{Transform()}, + want: true, + }, { + x: struct{ M testpb.TestAllTypes }{M: testpb.TestAllTypes{OptionalSint64: proto.Int64(1)}}, + y: struct{ M testpb.TestAllTypes }{M: testpb.TestAllTypes{OptionalSint64: proto.Int64(2)}}, + opts: cmp.Options{Transform()}, + want: false, + }, { + x: struct{ M []testpb.TestAllTypes }{M: []testpb.TestAllTypes{{OptionalSint64: proto.Int64(1)}}}, + y: struct{ M []testpb.TestAllTypes }{M: []testpb.TestAllTypes{{OptionalSint64: proto.Int64(1)}}}, + opts: cmp.Options{Transform()}, + want: true, + }, { + x: struct{ M []testpb.TestAllTypes }{M: []testpb.TestAllTypes{{OptionalSint64: proto.Int64(1)}}}, + y: struct{ M []testpb.TestAllTypes }{M: []testpb.TestAllTypes{{OptionalSint64: proto.Int64(2)}}}, + opts: cmp.Options{Transform()}, + want: false, + }, { + x: struct { + M map[string]testpb.TestAllTypes + }{ + M: map[string]testpb.TestAllTypes{"k": {OptionalSint64: proto.Int64(1)}}, + }, + y: struct { + M map[string]testpb.TestAllTypes + }{ + M: map[string]testpb.TestAllTypes{"k": {OptionalSint64: proto.Int64(1)}}, + }, + opts: cmp.Options{Transform()}, + want: true, + }, { + x: struct { + M map[string]testpb.TestAllTypes + }{ + M: map[string]testpb.TestAllTypes{"k": {OptionalSint64: proto.Int64(1)}}, + }, + y: struct { + M map[string]testpb.TestAllTypes + }{ + M: map[string]testpb.TestAllTypes{"k": {OptionalSint64: proto.Int64(2)}}, + }, + opts: cmp.Options{Transform()}, + want: false, + }}...) + // Test IgnoreUnknown. raw := pack.Message{ pack.Tag{1, pack.BytesType}, pack.String("Hello, goodbye!"), diff --git a/testing/protocmp/xform.go b/testing/protocmp/xform.go index e27c8908..a42bf9bb 100644 --- a/testing/protocmp/xform.go +++ b/testing/protocmp/xform.go @@ -161,10 +161,18 @@ func Transform(...option) cmp.Option { // NOTE: There are currently no custom options for Transform, // but the use of an unexported type keeps the future open. + // addrType returns a pointer to t if t isn't a pointer or interface. + addrType := func(t reflect.Type) reflect.Type { + if k := t.Kind(); k == reflect.Interface || k == reflect.Ptr { + return t + } + return reflect.PtrTo(t) + } + // TODO: Should this transform protoreflect.Enum types to Enum as well? return cmp.FilterPath(func(p cmp.Path) bool { ps := p.Last() - if isMessageType(ps.Type()) { + if isMessageType(addrType(ps.Type())) { return true } @@ -175,11 +183,19 @@ func Transform(...option) cmp.Option { if !vx.IsValid() || vx.IsNil() || !vy.IsValid() || vy.IsNil() { return false } - return isMessageType(vx.Elem().Type()) && isMessageType(vy.Elem().Type()) + return isMessageType(addrType(vx.Elem().Type())) && isMessageType(addrType(vy.Elem().Type())) } return false }, cmp.Transformer("protocmp.Transform", func(v interface{}) Message { + // For user convenience, shallow copy the message value if necessary + // in order for it to implement the message interface. + if rv := reflect.ValueOf(v); rv.IsValid() && rv.Kind() != reflect.Ptr && !isMessageType(rv.Type()) { + pv := reflect.New(rv.Type()) + pv.Elem().Set(rv) + v = pv.Interface() + } + m := protoimpl.X.MessageOf(v) switch { case m == nil: