diff --git a/proto/merge.go b/proto/merge.go index f7af9980..fbd9c73f 100644 --- a/proto/merge.go +++ b/proto/merge.go @@ -4,7 +4,41 @@ package proto -import "google.golang.org/protobuf/reflect/protoreflect" +import ( + "google.golang.org/protobuf/internal/pragma" + "google.golang.org/protobuf/reflect/protoreflect" +) + +// MergeOptions configures the merger. +// +// Example usage: +// MergeOptions{Shallow: true}.Merge(dst, src) +type MergeOptions struct { + pragma.NoUnkeyedLiterals + + // Shallow configures Merge to shallow copy messages, lists, and maps + // instead of allocating new ones in the destination if it does not already + // have one populated. Scalar bytes are copied by reference. + // If true, Merge must be given messages of the same concrete type. + // + // If false, Merge is guaranteed to produce deep copies of all mutable + // objects from the source into the destination. Since scalar bytes are + // mutable they are deep copied as a result. + // + // Invariant: + // var dst1, dst2 Message = ... + // Equal(dst1, dst2) // assume equal initially + // MergeOptions{Shallow: true}.Merge(dst1, src) + // MergeOptions{Shallow: false}.Merge(dst2, src) + // Equal(dst1, dst2) // equal regardless of whether Shallow is specified + Shallow bool +} + +// Merge merges src into dst, which must be messages with the same descriptor. +// See MergeOptions.Merge for details. +func Merge(dst, src Message) { + MergeOptions{}.Merge(dst, src) +} // Merge merges src into dst, which must be messages with the same descriptor. // @@ -14,25 +48,46 @@ import "google.golang.org/protobuf/reflect/protoreflect" // list fields in dst. The entries of every map field in src is copied into // the corresponding map field in dst, possibly replacing existing entries. // The unknown fields of src are appended to the unknown fields of dst. -func Merge(dst, src Message) { - mergeMessage(dst.ProtoReflect(), src.ProtoReflect()) +// +// It is semantically equivalent to unmarshaling the encoded form of src +// into dst with the UnmarshalOptions.Merge option specified. +func (o MergeOptions) Merge(dst, src Message) { + dstMsg, srcMsg := dst.ProtoReflect(), src.ProtoReflect() + if o.Shallow { + if dstMsg.Type() != srcMsg.Type() { + panic("type mismatch") + } + } else { + if dstMsg.Descriptor() != srcMsg.Descriptor() { + panic("descriptor mismatch") + } + } + o.mergeMessage(dstMsg, srcMsg) } -func mergeMessage(dst, src protoreflect.Message) { - if dst.Descriptor() != src.Descriptor() { - panic("descriptor mismatch") - } - +func (o MergeOptions) mergeMessage(dst, src protoreflect.Message) { src.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { switch { case fd.IsList(): - mergeList(dst.Mutable(fd).List(), v.List(), fd) + if o.Shallow && !dst.Has(fd) { + dst.Set(fd, v) + } else { + o.mergeList(dst.Mutable(fd).List(), v.List(), fd) + } case fd.IsMap(): - mergeMap(dst.Mutable(fd).Map(), v.Map(), fd.MapValue()) + if o.Shallow && !dst.Has(fd) { + dst.Set(fd, v) + } else { + o.mergeMap(dst.Mutable(fd).Map(), v.Map(), fd.MapValue()) + } case fd.Message() != nil: - mergeMessage(dst.Mutable(fd).Message(), v.Message()) + if o.Shallow && !dst.Has(fd) { + dst.Set(fd, v) + } else { + o.mergeMessage(dst.Mutable(fd).Message(), v.Message()) + } case fd.Kind() == protoreflect.BytesKind: - dst.Set(fd, cloneBytes(v)) + dst.Set(fd, o.cloneBytes(v)) default: dst.Set(fd, v) } @@ -40,34 +95,48 @@ func mergeMessage(dst, src protoreflect.Message) { }) if len(src.GetUnknown()) > 0 { - dst.SetUnknown(append(dst.GetUnknown(), src.GetUnknown()...)) + if o.Shallow && dst.GetUnknown() == nil { + dst.SetUnknown(src.GetUnknown()) + } else { + dst.SetUnknown(append(dst.GetUnknown(), src.GetUnknown()...)) + } } } -func mergeList(dst, src protoreflect.List, fd protoreflect.FieldDescriptor) { +func (o MergeOptions) mergeList(dst, src protoreflect.List, fd protoreflect.FieldDescriptor) { + // Merge semantics appends to the end of the existing list. for i, n := 0, src.Len(); i < n; i++ { switch v := src.Get(i); { case fd.Message() != nil: - dstv := dst.NewElement() - mergeMessage(dstv.Message(), v.Message()) - dst.Append(dstv) + if o.Shallow { + dst.Append(v) + } else { + dstv := dst.NewElement() + o.mergeMessage(dstv.Message(), v.Message()) + dst.Append(dstv) + } case fd.Kind() == protoreflect.BytesKind: - dst.Append(cloneBytes(v)) + dst.Append(o.cloneBytes(v)) default: dst.Append(v) } } } -func mergeMap(dst, src protoreflect.Map, fd protoreflect.FieldDescriptor) { +func (o MergeOptions) mergeMap(dst, src protoreflect.Map, fd protoreflect.FieldDescriptor) { + // Merge semantics replaces, rather than merges into existing entries. src.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool { switch { case fd.Message() != nil: - dstv := dst.NewValue() - mergeMessage(dstv.Message(), v.Message()) - dst.Set(k, dstv) // may replace existing entry + if o.Shallow { + dst.Set(k, v) + } else { + dstv := dst.NewValue() + o.mergeMessage(dstv.Message(), v.Message()) + dst.Set(k, dstv) + } case fd.Kind() == protoreflect.BytesKind: - dst.Set(k, cloneBytes(v)) + dst.Set(k, o.cloneBytes(v)) default: dst.Set(k, v) } @@ -75,6 +144,9 @@ func mergeMap(dst, src protoreflect.Map, fd protoreflect.FieldDescriptor) { }) } -func cloneBytes(v protoreflect.Value) protoreflect.Value { +func (o MergeOptions) cloneBytes(v protoreflect.Value) protoreflect.Value { + if o.Shallow { + return v + } return protoreflect.ValueOfBytes(append([]byte{}, v.Bytes()...)) } diff --git a/proto/merge_test.go b/proto/merge_test.go index 9b6ee0f5..18243cc4 100644 --- a/proto/merge_test.go +++ b/proto/merge_test.go @@ -15,18 +15,22 @@ import ( ) func TestMerge(t *testing.T) { - dst := new(testpb.TestAllTypes) - src := (*testpb.TestAllTypes)(nil) - proto.Merge(dst, src) - // Mutating the source should not affect dst. + t.Run("Deep", func(t *testing.T) { testMerge(t, false) }) + t.Run("Shallow", func(t *testing.T) { testMerge(t, true) }) +} +func testMerge(t *testing.T, shallow bool) { tests := []struct { - desc string - dst proto.Message - src proto.Message - want proto.Message - mutator func(proto.Message) // if provided, is run on src after merging + desc string + dst proto.Message + src proto.Message + want proto.Message + + // If provided, mutator is run on src after merging. + // It reports whether a mutation is expected to be observable in dst + // if Shallow is enabled. + mutator func(proto.Message) bool }{{ desc: "merge from nil message", dst: new(testpb.TestAllTypes), @@ -85,7 +89,7 @@ func TestMerge(t *testing.T) { }, }, }, - mutator: func(mi proto.Message) { + mutator: func(mi proto.Message) bool { m := mi.(*testpb.TestAllTypes) *m.OptionalInt64++ *m.OptionalNestedEnum++ @@ -95,6 +99,7 @@ func TestMerge(t *testing.T) { delete(m.MapStringNestedEnum, "fizz") *m.MapStringNestedMessage["foo"].A++ *m.OneofField.(*testpb.TestAllTypes_OneofNestedMessage).OneofNestedMessage.A++ + return true }, }, { desc: "merge bytes", @@ -113,11 +118,12 @@ func TestMerge(t *testing.T) { RepeatedBytes: [][]byte{{1, 2}, {3, 4}, {5, 6}, {7, 8}}, MapStringBytes: map[string][]byte{"alpha": {4, 5, 6}, "bravo": {1, 2, 3}}, }, - mutator: func(mi proto.Message) { + mutator: func(mi proto.Message) bool { m := mi.(*testpb.TestAllTypes) m.OptionalBytes[0]++ m.RepeatedBytes[0][0]++ m.MapStringBytes["alpha"][0]++ + return true }, }, { desc: "merge singular fields", @@ -150,11 +156,12 @@ func TestMerge(t *testing.T) { }, }, }, - mutator: func(mi proto.Message) { + mutator: func(mi proto.Message) bool { m := mi.(*testpb.TestAllTypes) *m.OptionalInt64++ *m.OptionalNestedEnum++ *m.OptionalNestedMessage.A++ + return false // scalar mutations are not observable in shallow copy }, }, { desc: "merge list fields", @@ -181,10 +188,11 @@ func TestMerge(t *testing.T) { {A: proto.Int32(400)}, }, }, - mutator: func(mi proto.Message) { + mutator: func(mi proto.Message) bool { m := mi.(*testpb.TestAllTypes) m.RepeatedSfixed32[0]++ *m.RepeatedNestedMessage[0].A++ + return true }, }, { desc: "merge map fields", @@ -219,10 +227,11 @@ func TestMerge(t *testing.T) { "bar": {}, }, }, - mutator: func(mi proto.Message) { + mutator: func(mi proto.Message) bool { m := mi.(*testpb.TestAllTypes) delete(m.MapStringNestedEnum, "fizz") m.MapStringNestedMessage["bar"].A = proto.Int32(1) + return true }, }, { desc: "merge oneof message fields", @@ -252,9 +261,10 @@ func TestMerge(t *testing.T) { }, }, }, - mutator: func(mi proto.Message) { + mutator: func(mi proto.Message) bool { m := mi.(*testpb.TestAllTypes) *m.OneofField.(*testpb.TestAllTypes_OneofNestedMessage).OneofNestedMessage.Corecursive.OptionalInt64++ + return true }, }, { desc: "merge oneof scalar fields", @@ -267,9 +277,10 @@ func TestMerge(t *testing.T) { want: &testpb.TestAllTypes{ OneofField: &testpb.TestAllTypes_OneofFloat{3.14152}, }, - mutator: func(mi proto.Message) { + mutator: func(mi proto.Message) bool { m := mi.(*testpb.TestAllTypes) m.OneofField.(*testpb.TestAllTypes_OneofFloat).OneofFloat++ + return false // scalar mutations are not observable in shallow copy }, }, { desc: "merge extension fields", @@ -359,13 +370,17 @@ func TestMerge(t *testing.T) { t.Fatalf("Unmarshal(Marshal(dst)+Marshal(src)) mismatch: got %v, want %v", dst, tt.want) } - proto.Merge(tt.dst, tt.src) - if tt.mutator != nil { - tt.mutator(tt.src) // should not be observable by dst - } + proto.MergeOptions{Shallow: shallow}.Merge(tt.dst, tt.src) if !proto.Equal(tt.dst, tt.want) { t.Fatalf("Merge() mismatch:\n got %v\nwant %v", tt.dst, tt.want) } + if tt.mutator != nil { + wantObservable := tt.mutator(tt.src) && shallow + gotObservable := !proto.Equal(tt.dst, tt.want) + if gotObservable != wantObservable { + t.Fatalf("mutation observed:\n got %v\nwant %v", gotObservable, wantObservable) + } + } }) } }