internal/impl: support typed nil source for Merge of aberrant messages

When merging aberrant messages with legacy Marshal and Unmarshal
methods, check for a typed nil source before calling Marshal.

Add an aberrant message with Marshal/Unmarshal methods to
internal/testprotos/nullable and use it to test the internal/impl
support for these methods.

Fixes golang/protobuf#1324

Change-Id: Ib6ce85b30b46e3392a226ca6abe411932a371f02
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/321529
Trust: Damien Neil <dneil@google.com>
Run-TryBot: Damien Neil <dneil@google.com>
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
This commit is contained in:
Damien Neil 2021-05-20 10:35:18 -07:00
parent 0e358a402f
commit 24d799b3c1
4 changed files with 95 additions and 34 deletions

View File

@ -440,6 +440,13 @@ func legacyMerge(in piface.MergeInput) piface.MergeOutput {
if !ok {
return piface.MergeOutput{}
}
if !in.Source.IsValid() {
// Legacy Marshal methods may not function on nil messages.
// Check for a typed nil source only after we confirm that
// legacy Marshal/Unmarshal methods are present, for
// consistency.
return piface.MergeOutput{Flags: piface.MergeComplete}
}
b, err := marshaler.Marshal()
if err != nil {
return piface.MergeOutput{}

View File

@ -2,45 +2,19 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Only test compatibility with the Marshal/Unmarshal functionality with
// For messages which do not provide legacy Marshal and Unmarshal methods,
// only test compatibility with the Marshal/Unmarshal functionality with
// pure protobuf reflection since there is no support for nullable fields
// in the table-driven implementation.
// +build protoreflect
package nullable
import (
"testing"
"github.com/google/go-cmp/cmp"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/testing/protocmp"
)
import "google.golang.org/protobuf/runtime/protoimpl"
func init() {
testMethods = func(t *testing.T, mt protoreflect.MessageType) {
m1 := mt.New()
populated := testPopulateMessage(t, m1, 2)
b, err := proto.Marshal(m1.Interface())
if err != nil {
t.Errorf("proto.Marshal error: %v", err)
}
if populated && len(b) == 0 {
t.Errorf("len(proto.Marshal) = 0, want >0")
}
m2 := mt.New()
if err := proto.Unmarshal(b, m2.Interface()); err != nil {
t.Errorf("proto.Unmarshal error: %v", err)
}
if diff := cmp.Diff(m1.Interface(), m2.Interface(), protocmp.Transform()); diff != "" {
t.Errorf("message mismatch:\n%v", diff)
}
proto.Reset(m2.Interface())
testEmptyMessage(t, m2, true)
proto.Merge(m2.Interface(), m1.Interface())
if diff := cmp.Diff(m1.Interface(), m2.Interface(), protocmp.Transform()); diff != "" {
t.Errorf("message mismatch:\n%v", diff)
}
}
methodTestProtos = append(methodTestProtos,
protoimpl.X.ProtoMessageV2Of((*Proto2)(nil)).ProtoReflect().Type(),
protoimpl.X.ProtoMessageV2Of((*Proto3)(nil)).ProtoReflect().Type(),
)
}

View File

@ -6,6 +6,7 @@ package nullable
import (
"google.golang.org/protobuf/encoding/prototext"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/runtime/protoimpl"
"google.golang.org/protobuf/types/descriptorpb"
)
@ -223,3 +224,43 @@ func (*Proto3_OneofString) isProto3_OneofUnion() {}
func (*Proto3_OneofBytes) isProto3_OneofUnion() {}
func (*Proto3_OneofEnum) isProto3_OneofUnion() {}
func (*Proto3_OneofMessage) isProto3_OneofUnion() {}
type Methods struct {
OptionalInt32 int32 `protobuf:"varint,101,opt,name=optional_int32"`
}
func (x *Methods) ProtoMessage() {}
func (x *Methods) Reset() { *x = Methods{} }
func (x *Methods) String() string { return prototext.Format(protoimpl.X.ProtoMessageV2Of(x)) }
func (x *Methods) Marshal() ([]byte, error) {
var b []byte
b = protowire.AppendTag(b, 101, protowire.VarintType)
b = protowire.AppendVarint(b, uint64(x.OptionalInt32))
return b, nil
}
func (x *Methods) Unmarshal(b []byte) error {
for len(b) > 0 {
num, typ, n := protowire.ConsumeTag(b)
if n < 0 {
return protowire.ParseError(n)
}
b = b[n:]
if num != 101 || typ != protowire.VarintType {
n = protowire.ConsumeFieldValue(num, typ, b)
if n < 0 {
return protowire.ParseError(n)
}
b = b[n:]
continue
}
v, n := protowire.ConsumeVarint(b)
if n < 0 {
return protowire.ParseError(n)
}
b = b[n:]
x.OptionalInt32 = int32(v)
}
return nil
}

View File

@ -8,8 +8,11 @@ import (
"reflect"
"testing"
"github.com/google/go-cmp/cmp"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoimpl"
"google.golang.org/protobuf/testing/protocmp"
)
func Test(t *testing.T) {
@ -20,12 +23,48 @@ func Test(t *testing.T) {
t.Run(string(mt.Descriptor().FullName()), func(t *testing.T) {
testEmptyMessage(t, mt.Zero(), false)
testEmptyMessage(t, mt.New(), true)
//testMethods(t, mt)
})
}
}
var methodTestProtos = []protoreflect.MessageType{
protoimpl.X.ProtoMessageV2Of((*Methods)(nil)).ProtoReflect().Type(),
}
func TestMethods(t *testing.T) {
for _, mt := range methodTestProtos {
t.Run(string(mt.Descriptor().FullName()), func(t *testing.T) {
testMethods(t, mt)
})
}
}
var testMethods = func(*testing.T, protoreflect.MessageType) {}
func testMethods(t *testing.T, mt protoreflect.MessageType) {
m1 := mt.New()
populated := testPopulateMessage(t, m1, 2)
b, err := proto.Marshal(m1.Interface())
if err != nil {
t.Errorf("proto.Marshal error: %v", err)
}
if populated && len(b) == 0 {
t.Errorf("len(proto.Marshal) = 0, want >0")
}
m2 := mt.New()
if err := proto.Unmarshal(b, m2.Interface()); err != nil {
t.Errorf("proto.Unmarshal error: %v", err)
}
if diff := cmp.Diff(m1.Interface(), m2.Interface(), protocmp.Transform()); diff != "" {
t.Errorf("message mismatch:\n%v", diff)
}
proto.Reset(m2.Interface())
testEmptyMessage(t, m2, true)
proto.Merge(m2.Interface(), m1.Interface())
if diff := cmp.Diff(m1.Interface(), m2.Interface(), protocmp.Transform()); diff != "" {
t.Errorf("message mismatch:\n%v", diff)
}
proto.Merge(mt.New().Interface(), mt.Zero().Interface())
}
func testEmptyMessage(t *testing.T, m protoreflect.Message, wantValid bool) {
numFields := func(m protoreflect.Message) (n int) {