mirror of
https://github.com/protocolbuffers/protobuf-go.git
synced 2025-03-10 07:14:24 +00:00
proto: implement Merge
Change-Id: Ibb579bf5ad8548359dfd9805fd3022bcd53a6379 Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/183679 Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
parent
424139789a
commit
2fc306a8e3
78
proto/merge.go
Normal file
78
proto/merge.go
Normal file
@ -0,0 +1,78 @@
|
||||
// Copyright 2019 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package proto
|
||||
|
||||
import "google.golang.org/protobuf/reflect/protoreflect"
|
||||
|
||||
// Merge merges src into dst, which must be messages with the same descriptor.
|
||||
//
|
||||
// Populated scalar fields in src are copied to dst, while populated
|
||||
// singular messages in src are merged into dst by recursively calling Merge.
|
||||
// The elements of every list field in src is appended to the corresponded
|
||||
// list fields in dst. The entries of every map field in src is copied into
|
||||
// the corresponding map field in dst, possibly replacing existing entries.
|
||||
// The unknown fields of src are appended to the unknown fields of dst.
|
||||
func Merge(dst, src Message) {
|
||||
mergeMessage(dst.ProtoReflect(), src.ProtoReflect())
|
||||
}
|
||||
|
||||
func mergeMessage(dst, src protoreflect.Message) {
|
||||
if dst.Descriptor() != src.Descriptor() {
|
||||
panic("descriptor mismatch")
|
||||
}
|
||||
|
||||
src.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
|
||||
switch {
|
||||
case fd.IsList():
|
||||
mergeList(dst.Mutable(fd).List(), v.List(), fd)
|
||||
case fd.IsMap():
|
||||
mergeMap(dst.Mutable(fd).Map(), v.Map(), fd.MapValue())
|
||||
case fd.Message() != nil:
|
||||
mergeMessage(dst.Mutable(fd).Message(), v.Message())
|
||||
case fd.Kind() == protoreflect.BytesKind:
|
||||
dst.Set(fd, cloneBytes(v))
|
||||
default:
|
||||
dst.Set(fd, v)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
dst.SetUnknown(append(dst.GetUnknown(), src.GetUnknown()...))
|
||||
}
|
||||
|
||||
func mergeList(dst, src protoreflect.List, fd protoreflect.FieldDescriptor) {
|
||||
for i := 0; i < src.Len(); i++ {
|
||||
switch v := src.Get(i); {
|
||||
case fd.Message() != nil:
|
||||
m := dst.NewMessage()
|
||||
mergeMessage(m, v.Message())
|
||||
dst.Append(protoreflect.ValueOf(m))
|
||||
case fd.Kind() == protoreflect.BytesKind:
|
||||
dst.Append(cloneBytes(v))
|
||||
default:
|
||||
dst.Append(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func mergeMap(dst, src protoreflect.Map, fd protoreflect.FieldDescriptor) {
|
||||
src.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
|
||||
switch {
|
||||
case fd.Message() != nil:
|
||||
m := dst.NewMessage()
|
||||
mergeMessage(m, v.Message())
|
||||
dst.Set(k, protoreflect.ValueOf(m)) // may replace existing entry
|
||||
case fd.Kind() == protoreflect.BytesKind:
|
||||
dst.Set(k, cloneBytes(v))
|
||||
default:
|
||||
dst.Set(k, v)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func cloneBytes(v protoreflect.Value) protoreflect.Value {
|
||||
return protoreflect.ValueOf(append([]byte{}, v.Bytes()...))
|
||||
}
|
398
proto/merge_test.go
Normal file
398
proto/merge_test.go
Normal file
@ -0,0 +1,398 @@
|
||||
// Copyright 2019 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package proto_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"google.golang.org/protobuf/internal/encoding/pack"
|
||||
"google.golang.org/protobuf/internal/scalar"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
testpb "google.golang.org/protobuf/internal/testprotos/test"
|
||||
)
|
||||
|
||||
func TestMerge(t *testing.T) {
|
||||
dst := new(testpb.TestAllTypes)
|
||||
src := (*testpb.TestAllTypes)(nil)
|
||||
proto.Merge(dst, src)
|
||||
|
||||
// Mutating the source should not affect dst.
|
||||
|
||||
tests := []struct {
|
||||
desc string
|
||||
dst proto.Message
|
||||
src proto.Message
|
||||
want proto.Message
|
||||
mutator func(proto.Message) // if provided, is run on src after merging
|
||||
|
||||
skipMarshalUnmarshal bool // TODO: Remove this when proto.Unmarshal is fixed for messages in oneofs
|
||||
}{{
|
||||
desc: "merge from nil message",
|
||||
dst: new(testpb.TestAllTypes),
|
||||
src: (*testpb.TestAllTypes)(nil),
|
||||
want: new(testpb.TestAllTypes),
|
||||
}, {
|
||||
desc: "clone a large message",
|
||||
dst: new(testpb.TestAllTypes),
|
||||
src: &testpb.TestAllTypes{
|
||||
OptionalInt64: scalar.Int64(0),
|
||||
OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(1).Enum(),
|
||||
OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
|
||||
A: scalar.Int32(100),
|
||||
},
|
||||
RepeatedSfixed32: []int32{1, 2, 3},
|
||||
RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
|
||||
{A: scalar.Int32(200)},
|
||||
{A: scalar.Int32(300)},
|
||||
},
|
||||
MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
|
||||
"fizz": 400,
|
||||
"buzz": 500,
|
||||
},
|
||||
MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
|
||||
"foo": {A: scalar.Int32(600)},
|
||||
"bar": {A: scalar.Int32(700)},
|
||||
},
|
||||
OneofField: &testpb.TestAllTypes_OneofNestedMessage{
|
||||
&testpb.TestAllTypes_NestedMessage{
|
||||
A: scalar.Int32(800),
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &testpb.TestAllTypes{
|
||||
OptionalInt64: scalar.Int64(0),
|
||||
OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(1).Enum(),
|
||||
OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
|
||||
A: scalar.Int32(100),
|
||||
},
|
||||
RepeatedSfixed32: []int32{1, 2, 3},
|
||||
RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
|
||||
{A: scalar.Int32(200)},
|
||||
{A: scalar.Int32(300)},
|
||||
},
|
||||
MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
|
||||
"fizz": 400,
|
||||
"buzz": 500,
|
||||
},
|
||||
MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
|
||||
"foo": {A: scalar.Int32(600)},
|
||||
"bar": {A: scalar.Int32(700)},
|
||||
},
|
||||
OneofField: &testpb.TestAllTypes_OneofNestedMessage{
|
||||
&testpb.TestAllTypes_NestedMessage{
|
||||
A: scalar.Int32(800),
|
||||
},
|
||||
},
|
||||
},
|
||||
mutator: func(mi proto.Message) {
|
||||
m := mi.(*testpb.TestAllTypes)
|
||||
*m.OptionalInt64++
|
||||
*m.OptionalNestedEnum++
|
||||
*m.OptionalNestedMessage.A++
|
||||
m.RepeatedSfixed32[0]++
|
||||
*m.RepeatedNestedMessage[0].A++
|
||||
delete(m.MapStringNestedEnum, "fizz")
|
||||
*m.MapStringNestedMessage["foo"].A++
|
||||
*m.OneofField.(*testpb.TestAllTypes_OneofNestedMessage).OneofNestedMessage.A++
|
||||
},
|
||||
}, {
|
||||
desc: "merge bytes",
|
||||
dst: &testpb.TestAllTypes{
|
||||
OptionalBytes: []byte{1, 2, 3},
|
||||
RepeatedBytes: [][]byte{{1, 2}, {3, 4}},
|
||||
MapStringBytes: map[string][]byte{"alpha": {1, 2, 3}},
|
||||
},
|
||||
src: &testpb.TestAllTypes{
|
||||
OptionalBytes: []byte{4, 5, 6},
|
||||
RepeatedBytes: [][]byte{{5, 6}, {7, 8}},
|
||||
MapStringBytes: map[string][]byte{"alpha": {4, 5, 6}, "bravo": {1, 2, 3}},
|
||||
},
|
||||
want: &testpb.TestAllTypes{
|
||||
OptionalBytes: []byte{4, 5, 6},
|
||||
RepeatedBytes: [][]byte{{1, 2}, {3, 4}, {5, 6}, {7, 8}},
|
||||
MapStringBytes: map[string][]byte{"alpha": {4, 5, 6}, "bravo": {1, 2, 3}},
|
||||
},
|
||||
mutator: func(mi proto.Message) {
|
||||
m := mi.(*testpb.TestAllTypes)
|
||||
m.OptionalBytes[0]++
|
||||
m.RepeatedBytes[0][0]++
|
||||
m.MapStringBytes["alpha"][0]++
|
||||
},
|
||||
}, {
|
||||
desc: "merge singular fields",
|
||||
dst: &testpb.TestAllTypes{
|
||||
OptionalInt32: scalar.Int32(1),
|
||||
OptionalInt64: scalar.Int64(1),
|
||||
OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(10).Enum(),
|
||||
OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
|
||||
A: scalar.Int32(100),
|
||||
Corecursive: &testpb.TestAllTypes{
|
||||
OptionalInt64: scalar.Int64(1000),
|
||||
},
|
||||
},
|
||||
},
|
||||
src: &testpb.TestAllTypes{
|
||||
OptionalInt64: scalar.Int64(2),
|
||||
OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(20).Enum(),
|
||||
OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
|
||||
A: scalar.Int32(200),
|
||||
},
|
||||
},
|
||||
want: &testpb.TestAllTypes{
|
||||
OptionalInt32: scalar.Int32(1),
|
||||
OptionalInt64: scalar.Int64(2),
|
||||
OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(20).Enum(),
|
||||
OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
|
||||
A: scalar.Int32(200),
|
||||
Corecursive: &testpb.TestAllTypes{
|
||||
OptionalInt64: scalar.Int64(1000),
|
||||
},
|
||||
},
|
||||
},
|
||||
mutator: func(mi proto.Message) {
|
||||
m := mi.(*testpb.TestAllTypes)
|
||||
*m.OptionalInt64++
|
||||
*m.OptionalNestedEnum++
|
||||
*m.OptionalNestedMessage.A++
|
||||
},
|
||||
}, {
|
||||
desc: "merge list fields",
|
||||
dst: &testpb.TestAllTypes{
|
||||
RepeatedSfixed32: []int32{1, 2, 3},
|
||||
RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
|
||||
{A: scalar.Int32(100)},
|
||||
{A: scalar.Int32(200)},
|
||||
},
|
||||
},
|
||||
src: &testpb.TestAllTypes{
|
||||
RepeatedSfixed32: []int32{4, 5, 6},
|
||||
RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
|
||||
{A: scalar.Int32(300)},
|
||||
{A: scalar.Int32(400)},
|
||||
},
|
||||
},
|
||||
want: &testpb.TestAllTypes{
|
||||
RepeatedSfixed32: []int32{1, 2, 3, 4, 5, 6},
|
||||
RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
|
||||
{A: scalar.Int32(100)},
|
||||
{A: scalar.Int32(200)},
|
||||
{A: scalar.Int32(300)},
|
||||
{A: scalar.Int32(400)},
|
||||
},
|
||||
},
|
||||
mutator: func(mi proto.Message) {
|
||||
m := mi.(*testpb.TestAllTypes)
|
||||
m.RepeatedSfixed32[0]++
|
||||
*m.RepeatedNestedMessage[0].A++
|
||||
},
|
||||
}, {
|
||||
desc: "merge map fields",
|
||||
dst: &testpb.TestAllTypes{
|
||||
MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
|
||||
"fizz": 100,
|
||||
"buzz": 200,
|
||||
"guzz": 300,
|
||||
},
|
||||
MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
|
||||
"foo": {A: scalar.Int32(400)},
|
||||
},
|
||||
},
|
||||
src: &testpb.TestAllTypes{
|
||||
MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
|
||||
"fizz": 1000,
|
||||
"buzz": 2000,
|
||||
},
|
||||
MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
|
||||
"foo": {A: scalar.Int32(3000)},
|
||||
"bar": {},
|
||||
},
|
||||
},
|
||||
want: &testpb.TestAllTypes{
|
||||
MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
|
||||
"fizz": 1000,
|
||||
"buzz": 2000,
|
||||
"guzz": 300,
|
||||
},
|
||||
MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
|
||||
"foo": {A: scalar.Int32(3000)},
|
||||
"bar": {},
|
||||
},
|
||||
},
|
||||
mutator: func(mi proto.Message) {
|
||||
m := mi.(*testpb.TestAllTypes)
|
||||
delete(m.MapStringNestedEnum, "fizz")
|
||||
m.MapStringNestedMessage["bar"].A = scalar.Int32(1)
|
||||
},
|
||||
}, {
|
||||
desc: "merge oneof message fields",
|
||||
dst: &testpb.TestAllTypes{
|
||||
OneofField: &testpb.TestAllTypes_OneofNestedMessage{
|
||||
&testpb.TestAllTypes_NestedMessage{
|
||||
A: scalar.Int32(100),
|
||||
},
|
||||
},
|
||||
},
|
||||
src: &testpb.TestAllTypes{
|
||||
OneofField: &testpb.TestAllTypes_OneofNestedMessage{
|
||||
&testpb.TestAllTypes_NestedMessage{
|
||||
Corecursive: &testpb.TestAllTypes{
|
||||
OptionalInt64: scalar.Int64(1000),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &testpb.TestAllTypes{
|
||||
OneofField: &testpb.TestAllTypes_OneofNestedMessage{
|
||||
&testpb.TestAllTypes_NestedMessage{
|
||||
A: scalar.Int32(100),
|
||||
Corecursive: &testpb.TestAllTypes{
|
||||
OptionalInt64: scalar.Int64(1000),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
mutator: func(mi proto.Message) {
|
||||
m := mi.(*testpb.TestAllTypes)
|
||||
*m.OneofField.(*testpb.TestAllTypes_OneofNestedMessage).OneofNestedMessage.Corecursive.OptionalInt64++
|
||||
},
|
||||
skipMarshalUnmarshal: true,
|
||||
}, {
|
||||
desc: "merge oneof scalar fields",
|
||||
dst: &testpb.TestAllTypes{
|
||||
OneofField: &testpb.TestAllTypes_OneofUint32{100},
|
||||
},
|
||||
src: &testpb.TestAllTypes{
|
||||
OneofField: &testpb.TestAllTypes_OneofFloat{3.14152},
|
||||
},
|
||||
want: &testpb.TestAllTypes{
|
||||
OneofField: &testpb.TestAllTypes_OneofFloat{3.14152},
|
||||
},
|
||||
mutator: func(mi proto.Message) {
|
||||
m := mi.(*testpb.TestAllTypes)
|
||||
m.OneofField.(*testpb.TestAllTypes_OneofFloat).OneofFloat++
|
||||
},
|
||||
}, {
|
||||
desc: "merge extension fields",
|
||||
dst: func() proto.Message {
|
||||
m := new(testpb.TestAllExtensions)
|
||||
m.ProtoReflect().Set(
|
||||
testpb.E_OptionalInt32Extension.Type,
|
||||
testpb.E_OptionalInt32Extension.Type.ValueOf(int32(32)),
|
||||
)
|
||||
m.ProtoReflect().Set(
|
||||
testpb.E_OptionalNestedMessageExtension.Type,
|
||||
testpb.E_OptionalNestedMessageExtension.Type.ValueOf(&testpb.TestAllTypes_NestedMessage{
|
||||
A: scalar.Int32(50),
|
||||
}),
|
||||
)
|
||||
m.ProtoReflect().Set(
|
||||
testpb.E_RepeatedFixed32Extension.Type,
|
||||
testpb.E_RepeatedFixed32Extension.Type.ValueOf(&[]uint32{1, 2, 3}),
|
||||
)
|
||||
return m
|
||||
}(),
|
||||
src: func() proto.Message {
|
||||
m := new(testpb.TestAllExtensions)
|
||||
m.ProtoReflect().Set(
|
||||
testpb.E_OptionalInt64Extension.Type,
|
||||
testpb.E_OptionalInt64Extension.Type.ValueOf(int64(64)),
|
||||
)
|
||||
m.ProtoReflect().Set(
|
||||
testpb.E_OptionalNestedMessageExtension.Type,
|
||||
testpb.E_OptionalNestedMessageExtension.Type.ValueOf(&testpb.TestAllTypes_NestedMessage{
|
||||
Corecursive: &testpb.TestAllTypes{
|
||||
OptionalInt64: scalar.Int64(1000),
|
||||
},
|
||||
}),
|
||||
)
|
||||
m.ProtoReflect().Set(
|
||||
testpb.E_RepeatedFixed32Extension.Type,
|
||||
testpb.E_RepeatedFixed32Extension.Type.ValueOf(&[]uint32{4, 5, 6}),
|
||||
)
|
||||
return m
|
||||
}(),
|
||||
want: func() proto.Message {
|
||||
m := new(testpb.TestAllExtensions)
|
||||
m.ProtoReflect().Set(
|
||||
testpb.E_OptionalInt32Extension.Type,
|
||||
testpb.E_OptionalInt32Extension.Type.ValueOf(int32(32)),
|
||||
)
|
||||
m.ProtoReflect().Set(
|
||||
testpb.E_OptionalInt64Extension.Type,
|
||||
testpb.E_OptionalInt64Extension.Type.ValueOf(int64(64)),
|
||||
)
|
||||
m.ProtoReflect().Set(
|
||||
testpb.E_OptionalNestedMessageExtension.Type,
|
||||
testpb.E_OptionalNestedMessageExtension.Type.ValueOf(&testpb.TestAllTypes_NestedMessage{
|
||||
A: scalar.Int32(50),
|
||||
Corecursive: &testpb.TestAllTypes{
|
||||
OptionalInt64: scalar.Int64(1000),
|
||||
},
|
||||
}),
|
||||
)
|
||||
m.ProtoReflect().Set(
|
||||
testpb.E_RepeatedFixed32Extension.Type,
|
||||
testpb.E_RepeatedFixed32Extension.Type.ValueOf(&[]uint32{1, 2, 3, 4, 5, 6}),
|
||||
)
|
||||
return m
|
||||
}(),
|
||||
}, {
|
||||
desc: "merge unknown fields",
|
||||
dst: func() proto.Message {
|
||||
m := new(testpb.TestAllTypes)
|
||||
m.ProtoReflect().SetUnknown(pack.Message{
|
||||
pack.Tag{Number: 50000, Type: pack.VarintType}, pack.Svarint(-5),
|
||||
}.Marshal())
|
||||
return m
|
||||
}(),
|
||||
src: func() proto.Message {
|
||||
m := new(testpb.TestAllTypes)
|
||||
m.ProtoReflect().SetUnknown(pack.Message{
|
||||
pack.Tag{Number: 500000, Type: pack.VarintType}, pack.Svarint(-50),
|
||||
}.Marshal())
|
||||
return m
|
||||
}(),
|
||||
want: func() proto.Message {
|
||||
m := new(testpb.TestAllTypes)
|
||||
m.ProtoReflect().SetUnknown(pack.Message{
|
||||
pack.Tag{Number: 50000, Type: pack.VarintType}, pack.Svarint(-5),
|
||||
pack.Tag{Number: 500000, Type: pack.VarintType}, pack.Svarint(-50),
|
||||
}.Marshal())
|
||||
return m
|
||||
}(),
|
||||
}}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
// Merge should be semantically equivalent to unmarshaling the
|
||||
// encoded form of src into the current dst.
|
||||
b1, err := proto.MarshalOptions{AllowPartial: true}.Marshal(tt.dst)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal(dst) error: %v", err)
|
||||
}
|
||||
b2, err := proto.MarshalOptions{AllowPartial: true}.Marshal(tt.src)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal(src) error: %v", err)
|
||||
}
|
||||
dst := tt.dst.ProtoReflect().New().Interface()
|
||||
err = proto.UnmarshalOptions{AllowPartial: true}.Unmarshal(append(b1, b2...), dst)
|
||||
if err != nil {
|
||||
t.Fatalf("Unmarshal() error: %v", err)
|
||||
}
|
||||
if !proto.Equal(dst, tt.want) && !tt.skipMarshalUnmarshal {
|
||||
t.Fatalf("Unmarshal(Marshal(dst)+Marshal(src)) mismatch: got %v, want %v", dst, tt.want)
|
||||
}
|
||||
|
||||
proto.Merge(tt.dst, tt.src)
|
||||
if tt.mutator != nil {
|
||||
tt.mutator(tt.src) // should not be observable by dst
|
||||
}
|
||||
if !proto.Equal(tt.dst, tt.want) {
|
||||
t.Fatalf("Merge() mismatch: got %v, want %v", tt.dst, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user