protobuf-go/proto/merge_test.go
Joe Tsai 641611d984 proto: fix self-merging
While odd, it is possible to merge a message into itself.
In such a situation, the material impact is that repeated
and unknown fields are duplicated. The previous logic would
loop infinitely, because the list iteration re-checked the current
length on every pass; since appending to the list keeps growing that
length, the loop never terminated. Instead, record the list length
once and iterate exactly that many times.

Change-Id: Ief98afa1b20bd950a9c2422d4462b170dbe6fa11
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/196857
Reviewed-by: Damien Neil <dneil@google.com>
2019-09-23 16:14:39 +00:00

442 lines
13 KiB
Go

// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proto_test
import (
"sync"
"testing"
"google.golang.org/protobuf/internal/encoding/pack"
"google.golang.org/protobuf/proto"
testpb "google.golang.org/protobuf/internal/testprotos/test"
)
func TestMerge(t *testing.T) {
dst := new(testpb.TestAllTypes)
src := (*testpb.TestAllTypes)(nil)
proto.Merge(dst, src)
// Mutating the source should not affect dst.
tests := []struct {
desc string
dst proto.Message
src proto.Message
want proto.Message
mutator func(proto.Message) // if provided, is run on src after merging
}{{
desc: "merge from nil message",
dst: new(testpb.TestAllTypes),
src: (*testpb.TestAllTypes)(nil),
want: new(testpb.TestAllTypes),
}, {
desc: "clone a large message",
dst: new(testpb.TestAllTypes),
src: &testpb.TestAllTypes{
OptionalInt64: proto.Int64(0),
OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(1).Enum(),
OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
A: proto.Int32(100),
},
RepeatedSfixed32: []int32{1, 2, 3},
RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
{A: proto.Int32(200)},
{A: proto.Int32(300)},
},
MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
"fizz": 400,
"buzz": 500,
},
MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
"foo": {A: proto.Int32(600)},
"bar": {A: proto.Int32(700)},
},
OneofField: &testpb.TestAllTypes_OneofNestedMessage{
&testpb.TestAllTypes_NestedMessage{
A: proto.Int32(800),
},
},
},
want: &testpb.TestAllTypes{
OptionalInt64: proto.Int64(0),
OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(1).Enum(),
OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
A: proto.Int32(100),
},
RepeatedSfixed32: []int32{1, 2, 3},
RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
{A: proto.Int32(200)},
{A: proto.Int32(300)},
},
MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
"fizz": 400,
"buzz": 500,
},
MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
"foo": {A: proto.Int32(600)},
"bar": {A: proto.Int32(700)},
},
OneofField: &testpb.TestAllTypes_OneofNestedMessage{
&testpb.TestAllTypes_NestedMessage{
A: proto.Int32(800),
},
},
},
mutator: func(mi proto.Message) {
m := mi.(*testpb.TestAllTypes)
*m.OptionalInt64++
*m.OptionalNestedEnum++
*m.OptionalNestedMessage.A++
m.RepeatedSfixed32[0]++
*m.RepeatedNestedMessage[0].A++
delete(m.MapStringNestedEnum, "fizz")
*m.MapStringNestedMessage["foo"].A++
*m.OneofField.(*testpb.TestAllTypes_OneofNestedMessage).OneofNestedMessage.A++
},
}, {
desc: "merge bytes",
dst: &testpb.TestAllTypes{
OptionalBytes: []byte{1, 2, 3},
RepeatedBytes: [][]byte{{1, 2}, {3, 4}},
MapStringBytes: map[string][]byte{"alpha": {1, 2, 3}},
},
src: &testpb.TestAllTypes{
OptionalBytes: []byte{4, 5, 6},
RepeatedBytes: [][]byte{{5, 6}, {7, 8}},
MapStringBytes: map[string][]byte{"alpha": {4, 5, 6}, "bravo": {1, 2, 3}},
},
want: &testpb.TestAllTypes{
OptionalBytes: []byte{4, 5, 6},
RepeatedBytes: [][]byte{{1, 2}, {3, 4}, {5, 6}, {7, 8}},
MapStringBytes: map[string][]byte{"alpha": {4, 5, 6}, "bravo": {1, 2, 3}},
},
mutator: func(mi proto.Message) {
m := mi.(*testpb.TestAllTypes)
m.OptionalBytes[0]++
m.RepeatedBytes[0][0]++
m.MapStringBytes["alpha"][0]++
},
}, {
desc: "merge singular fields",
dst: &testpb.TestAllTypes{
OptionalInt32: proto.Int32(1),
OptionalInt64: proto.Int64(1),
OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(10).Enum(),
OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
A: proto.Int32(100),
Corecursive: &testpb.TestAllTypes{
OptionalInt64: proto.Int64(1000),
},
},
},
src: &testpb.TestAllTypes{
OptionalInt64: proto.Int64(2),
OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(20).Enum(),
OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
A: proto.Int32(200),
},
},
want: &testpb.TestAllTypes{
OptionalInt32: proto.Int32(1),
OptionalInt64: proto.Int64(2),
OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(20).Enum(),
OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
A: proto.Int32(200),
Corecursive: &testpb.TestAllTypes{
OptionalInt64: proto.Int64(1000),
},
},
},
mutator: func(mi proto.Message) {
m := mi.(*testpb.TestAllTypes)
*m.OptionalInt64++
*m.OptionalNestedEnum++
*m.OptionalNestedMessage.A++
},
}, {
desc: "merge list fields",
dst: &testpb.TestAllTypes{
RepeatedSfixed32: []int32{1, 2, 3},
RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
{A: proto.Int32(100)},
{A: proto.Int32(200)},
},
},
src: &testpb.TestAllTypes{
RepeatedSfixed32: []int32{4, 5, 6},
RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
{A: proto.Int32(300)},
{A: proto.Int32(400)},
},
},
want: &testpb.TestAllTypes{
RepeatedSfixed32: []int32{1, 2, 3, 4, 5, 6},
RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
{A: proto.Int32(100)},
{A: proto.Int32(200)},
{A: proto.Int32(300)},
{A: proto.Int32(400)},
},
},
mutator: func(mi proto.Message) {
m := mi.(*testpb.TestAllTypes)
m.RepeatedSfixed32[0]++
*m.RepeatedNestedMessage[0].A++
},
}, {
desc: "merge map fields",
dst: &testpb.TestAllTypes{
MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
"fizz": 100,
"buzz": 200,
"guzz": 300,
},
MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
"foo": {A: proto.Int32(400)},
},
},
src: &testpb.TestAllTypes{
MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
"fizz": 1000,
"buzz": 2000,
},
MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
"foo": {A: proto.Int32(3000)},
"bar": {},
},
},
want: &testpb.TestAllTypes{
MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
"fizz": 1000,
"buzz": 2000,
"guzz": 300,
},
MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
"foo": {A: proto.Int32(3000)},
"bar": {},
},
},
mutator: func(mi proto.Message) {
m := mi.(*testpb.TestAllTypes)
delete(m.MapStringNestedEnum, "fizz")
m.MapStringNestedMessage["bar"].A = proto.Int32(1)
},
}, {
desc: "merge oneof message fields",
dst: &testpb.TestAllTypes{
OneofField: &testpb.TestAllTypes_OneofNestedMessage{
&testpb.TestAllTypes_NestedMessage{
A: proto.Int32(100),
},
},
},
src: &testpb.TestAllTypes{
OneofField: &testpb.TestAllTypes_OneofNestedMessage{
&testpb.TestAllTypes_NestedMessage{
Corecursive: &testpb.TestAllTypes{
OptionalInt64: proto.Int64(1000),
},
},
},
},
want: &testpb.TestAllTypes{
OneofField: &testpb.TestAllTypes_OneofNestedMessage{
&testpb.TestAllTypes_NestedMessage{
A: proto.Int32(100),
Corecursive: &testpb.TestAllTypes{
OptionalInt64: proto.Int64(1000),
},
},
},
},
mutator: func(mi proto.Message) {
m := mi.(*testpb.TestAllTypes)
*m.OneofField.(*testpb.TestAllTypes_OneofNestedMessage).OneofNestedMessage.Corecursive.OptionalInt64++
},
}, {
desc: "merge oneof scalar fields",
dst: &testpb.TestAllTypes{
OneofField: &testpb.TestAllTypes_OneofUint32{100},
},
src: &testpb.TestAllTypes{
OneofField: &testpb.TestAllTypes_OneofFloat{3.14152},
},
want: &testpb.TestAllTypes{
OneofField: &testpb.TestAllTypes_OneofFloat{3.14152},
},
mutator: func(mi proto.Message) {
m := mi.(*testpb.TestAllTypes)
m.OneofField.(*testpb.TestAllTypes_OneofFloat).OneofFloat++
},
}, {
desc: "merge extension fields",
dst: func() proto.Message {
m := new(testpb.TestAllExtensions)
proto.SetExtension(m, testpb.E_OptionalInt32Extension, int32(32))
proto.SetExtension(m, testpb.E_OptionalNestedMessageExtension,
&testpb.TestAllTypes_NestedMessage{
A: proto.Int32(50),
},
)
proto.SetExtension(m, testpb.E_RepeatedFixed32Extension, []uint32{1, 2, 3})
return m
}(),
src: func() proto.Message {
m := new(testpb.TestAllExtensions)
proto.SetExtension(m, testpb.E_OptionalInt64Extension, int64(64))
proto.SetExtension(m, testpb.E_OptionalNestedMessageExtension,
&testpb.TestAllTypes_NestedMessage{
Corecursive: &testpb.TestAllTypes{
OptionalInt64: proto.Int64(1000),
},
},
)
proto.SetExtension(m, testpb.E_RepeatedFixed32Extension, []uint32{4, 5, 6})
return m
}(),
want: func() proto.Message {
m := new(testpb.TestAllExtensions)
proto.SetExtension(m, testpb.E_OptionalInt32Extension, int32(32))
proto.SetExtension(m, testpb.E_OptionalInt64Extension, int64(64))
proto.SetExtension(m, testpb.E_OptionalNestedMessageExtension,
&testpb.TestAllTypes_NestedMessage{
A: proto.Int32(50),
Corecursive: &testpb.TestAllTypes{
OptionalInt64: proto.Int64(1000),
},
},
)
proto.SetExtension(m, testpb.E_RepeatedFixed32Extension, []uint32{1, 2, 3, 4, 5, 6})
return m
}(),
}, {
desc: "merge unknown fields",
dst: func() proto.Message {
m := new(testpb.TestAllTypes)
m.ProtoReflect().SetUnknown(pack.Message{
pack.Tag{Number: 50000, Type: pack.VarintType}, pack.Svarint(-5),
}.Marshal())
return m
}(),
src: func() proto.Message {
m := new(testpb.TestAllTypes)
m.ProtoReflect().SetUnknown(pack.Message{
pack.Tag{Number: 500000, Type: pack.VarintType}, pack.Svarint(-50),
}.Marshal())
return m
}(),
want: func() proto.Message {
m := new(testpb.TestAllTypes)
m.ProtoReflect().SetUnknown(pack.Message{
pack.Tag{Number: 50000, Type: pack.VarintType}, pack.Svarint(-5),
pack.Tag{Number: 500000, Type: pack.VarintType}, pack.Svarint(-50),
}.Marshal())
return m
}(),
}}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
// Merge should be semantically equivalent to unmarshaling the
// encoded form of src into the current dst.
b1, err := proto.MarshalOptions{AllowPartial: true}.Marshal(tt.dst)
if err != nil {
t.Fatalf("Marshal(dst) error: %v", err)
}
b2, err := proto.MarshalOptions{AllowPartial: true}.Marshal(tt.src)
if err != nil {
t.Fatalf("Marshal(src) error: %v", err)
}
dst := tt.dst.ProtoReflect().New().Interface()
err = proto.UnmarshalOptions{AllowPartial: true}.Unmarshal(append(b1, b2...), dst)
if err != nil {
t.Fatalf("Unmarshal() error: %v", err)
}
if !proto.Equal(dst, tt.want) {
t.Fatalf("Unmarshal(Marshal(dst)+Marshal(src)) mismatch: got %v, want %v", dst, tt.want)
}
proto.Merge(tt.dst, tt.src)
if tt.mutator != nil {
tt.mutator(tt.src) // should not be observable by dst
}
if !proto.Equal(tt.dst, tt.want) {
t.Fatalf("Merge() mismatch:\n got %v\nwant %v", tt.dst, tt.want)
}
})
}
}
// TestMergeRace merges several sources into one shared destination from
// separate goroutines. Each source populates a different field of
// TestAllTypes, so the merges touch non-overlapping fields. The test makes no
// assertion about the final contents. Presumably it is meant to be run with
// the race detector (-race) to catch unsynchronized writes; confirm with the
// CI setup.
func TestMergeRace(t *testing.T) {
	dst := new(testpb.TestAllTypes)

	// A source whose only content is a single unknown varint field.
	unknownOnly := new(testpb.TestAllTypes)
	unknownOnly.ProtoReflect().SetUnknown(pack.Message{
		pack.Tag{Number: 50000, Type: pack.VarintType}, pack.Svarint(-5),
	}.Marshal())

	srcs := []*testpb.TestAllTypes{
		{OptionalInt32: proto.Int32(1)},
		{OptionalString: proto.String("hello")},
		{RepeatedInt32: []int32{2, 3, 4}},
		{RepeatedString: []string{"goodbye"}},
		{MapStringString: map[string]string{"key": "value"}},
		{OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{A: proto.Int32(5)}},
		unknownOnly,
	}

	// It should be safe to concurrently merge non-overlapping fields.
	var wg sync.WaitGroup
	wg.Add(len(srcs))
	for i := range srcs {
		go func(m proto.Message) {
			defer wg.Done()
			proto.Merge(dst, m)
		}(srcs[i])
	}
	wg.Wait()
}
func TestMergeSelf(t *testing.T) {
got := &testpb.TestAllTypes{
OptionalInt32: proto.Int32(1),
OptionalString: proto.String("hello"),
RepeatedInt32: []int32{2, 3, 4},
RepeatedString: []string{"goodbye"},
MapStringString: map[string]string{"key": "value"},
OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
A: proto.Int32(5),
},
}
got.ProtoReflect().SetUnknown(pack.Message{
pack.Tag{Number: 50000, Type: pack.VarintType}, pack.Svarint(-5),
}.Marshal())
proto.Merge(got, got)
// The main impact of merging to self is that repeated fields and
// unknown fields are doubled.
want := &testpb.TestAllTypes{
OptionalInt32: proto.Int32(1),
OptionalString: proto.String("hello"),
RepeatedInt32: []int32{2, 3, 4, 2, 3, 4},
RepeatedString: []string{"goodbye", "goodbye"},
MapStringString: map[string]string{"key": "value"},
OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
A: proto.Int32(5),
},
}
want.ProtoReflect().SetUnknown(pack.Message{
pack.Tag{Number: 50000, Type: pack.VarintType}, pack.Svarint(-5),
pack.Tag{Number: 50000, Type: pack.VarintType}, pack.Svarint(-5),
}.Marshal())
if !proto.Equal(got, want) {
t.Errorf("Equal mismatch:\ngot %v\nwant %v", got, want)
}
}