mirror of
https://github.com/protocolbuffers/protobuf-go.git
synced 2025-03-28 19:21:22 +00:00
internal/impl: support typed nil source for Merge of aberrant messages
When merging aberrant messages with legacy Marshal and Unmarshal methods, check for a typed nil source before calling Marshal. Add an aberrant message with Marshal/Unmarshal methods to internal/testprotos/nullable and use it to test the internal/impl support for these methods. Fixes golang/protobuf#1324 Change-Id: Ib6ce85b30b46e3392a226ca6abe411932a371f02 Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/321529 Trust: Damien Neil <dneil@google.com> Run-TryBot: Damien Neil <dneil@google.com> Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
This commit is contained in:
parent
0e358a402f
commit
24d799b3c1
@ -440,6 +440,13 @@ func legacyMerge(in piface.MergeInput) piface.MergeOutput {
|
||||
if !ok {
|
||||
return piface.MergeOutput{}
|
||||
}
|
||||
if !in.Source.IsValid() {
|
||||
// Legacy Marshal methods may not function on nil messages.
|
||||
// Check for a typed nil source only after we confirm that
|
||||
// legacy Marshal/Unmarshal methods are present, for
|
||||
// consistency.
|
||||
return piface.MergeOutput{Flags: piface.MergeComplete}
|
||||
}
|
||||
b, err := marshaler.Marshal()
|
||||
if err != nil {
|
||||
return piface.MergeOutput{}
|
||||
|
@ -2,45 +2,19 @@
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Only test compatibility with the Marshal/Unmarshal functionality with
|
||||
// For messages which do not provide legacy Marshal and Unmarshal methods,
|
||||
// only test compatibility with the Marshal/Unmarshal functionality with
|
||||
// pure protobuf reflection since there is no support for nullable fields
|
||||
// in the table-driven implementation.
|
||||
// +build protoreflect
|
||||
|
||||
package nullable
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
"google.golang.org/protobuf/testing/protocmp"
|
||||
)
|
||||
import "google.golang.org/protobuf/runtime/protoimpl"
|
||||
|
||||
func init() {
|
||||
testMethods = func(t *testing.T, mt protoreflect.MessageType) {
|
||||
m1 := mt.New()
|
||||
populated := testPopulateMessage(t, m1, 2)
|
||||
b, err := proto.Marshal(m1.Interface())
|
||||
if err != nil {
|
||||
t.Errorf("proto.Marshal error: %v", err)
|
||||
}
|
||||
if populated && len(b) == 0 {
|
||||
t.Errorf("len(proto.Marshal) = 0, want >0")
|
||||
}
|
||||
m2 := mt.New()
|
||||
if err := proto.Unmarshal(b, m2.Interface()); err != nil {
|
||||
t.Errorf("proto.Unmarshal error: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(m1.Interface(), m2.Interface(), protocmp.Transform()); diff != "" {
|
||||
t.Errorf("message mismatch:\n%v", diff)
|
||||
}
|
||||
proto.Reset(m2.Interface())
|
||||
testEmptyMessage(t, m2, true)
|
||||
proto.Merge(m2.Interface(), m1.Interface())
|
||||
if diff := cmp.Diff(m1.Interface(), m2.Interface(), protocmp.Transform()); diff != "" {
|
||||
t.Errorf("message mismatch:\n%v", diff)
|
||||
}
|
||||
}
|
||||
methodTestProtos = append(methodTestProtos,
|
||||
protoimpl.X.ProtoMessageV2Of((*Proto2)(nil)).ProtoReflect().Type(),
|
||||
protoimpl.X.ProtoMessageV2Of((*Proto3)(nil)).ProtoReflect().Type(),
|
||||
)
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ package nullable
|
||||
|
||||
import (
|
||||
"google.golang.org/protobuf/encoding/prototext"
|
||||
"google.golang.org/protobuf/encoding/protowire"
|
||||
"google.golang.org/protobuf/runtime/protoimpl"
|
||||
"google.golang.org/protobuf/types/descriptorpb"
|
||||
)
|
||||
@ -223,3 +224,43 @@ func (*Proto3_OneofString) isProto3_OneofUnion() {}
|
||||
func (*Proto3_OneofBytes) isProto3_OneofUnion() {}
|
||||
func (*Proto3_OneofEnum) isProto3_OneofUnion() {}
|
||||
func (*Proto3_OneofMessage) isProto3_OneofUnion() {}
|
||||
|
||||
type Methods struct {
|
||||
OptionalInt32 int32 `protobuf:"varint,101,opt,name=optional_int32"`
|
||||
}
|
||||
|
||||
func (x *Methods) ProtoMessage() {}
|
||||
func (x *Methods) Reset() { *x = Methods{} }
|
||||
func (x *Methods) String() string { return prototext.Format(protoimpl.X.ProtoMessageV2Of(x)) }
|
||||
|
||||
func (x *Methods) Marshal() ([]byte, error) {
|
||||
var b []byte
|
||||
b = protowire.AppendTag(b, 101, protowire.VarintType)
|
||||
b = protowire.AppendVarint(b, uint64(x.OptionalInt32))
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (x *Methods) Unmarshal(b []byte) error {
|
||||
for len(b) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(b)
|
||||
if n < 0 {
|
||||
return protowire.ParseError(n)
|
||||
}
|
||||
b = b[n:]
|
||||
if num != 101 || typ != protowire.VarintType {
|
||||
n = protowire.ConsumeFieldValue(num, typ, b)
|
||||
if n < 0 {
|
||||
return protowire.ParseError(n)
|
||||
}
|
||||
b = b[n:]
|
||||
continue
|
||||
}
|
||||
v, n := protowire.ConsumeVarint(b)
|
||||
if n < 0 {
|
||||
return protowire.ParseError(n)
|
||||
}
|
||||
b = b[n:]
|
||||
x.OptionalInt32 = int32(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -8,8 +8,11 @@ import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
"google.golang.org/protobuf/runtime/protoimpl"
|
||||
"google.golang.org/protobuf/testing/protocmp"
|
||||
)
|
||||
|
||||
func Test(t *testing.T) {
|
||||
@ -20,12 +23,48 @@ func Test(t *testing.T) {
|
||||
t.Run(string(mt.Descriptor().FullName()), func(t *testing.T) {
|
||||
testEmptyMessage(t, mt.Zero(), false)
|
||||
testEmptyMessage(t, mt.New(), true)
|
||||
//testMethods(t, mt)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var methodTestProtos = []protoreflect.MessageType{
|
||||
protoimpl.X.ProtoMessageV2Of((*Methods)(nil)).ProtoReflect().Type(),
|
||||
}
|
||||
|
||||
func TestMethods(t *testing.T) {
|
||||
for _, mt := range methodTestProtos {
|
||||
t.Run(string(mt.Descriptor().FullName()), func(t *testing.T) {
|
||||
testMethods(t, mt)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var testMethods = func(*testing.T, protoreflect.MessageType) {}
|
||||
func testMethods(t *testing.T, mt protoreflect.MessageType) {
|
||||
m1 := mt.New()
|
||||
populated := testPopulateMessage(t, m1, 2)
|
||||
b, err := proto.Marshal(m1.Interface())
|
||||
if err != nil {
|
||||
t.Errorf("proto.Marshal error: %v", err)
|
||||
}
|
||||
if populated && len(b) == 0 {
|
||||
t.Errorf("len(proto.Marshal) = 0, want >0")
|
||||
}
|
||||
m2 := mt.New()
|
||||
if err := proto.Unmarshal(b, m2.Interface()); err != nil {
|
||||
t.Errorf("proto.Unmarshal error: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(m1.Interface(), m2.Interface(), protocmp.Transform()); diff != "" {
|
||||
t.Errorf("message mismatch:\n%v", diff)
|
||||
}
|
||||
proto.Reset(m2.Interface())
|
||||
testEmptyMessage(t, m2, true)
|
||||
proto.Merge(m2.Interface(), m1.Interface())
|
||||
if diff := cmp.Diff(m1.Interface(), m2.Interface(), protocmp.Transform()); diff != "" {
|
||||
t.Errorf("message mismatch:\n%v", diff)
|
||||
}
|
||||
proto.Merge(mt.New().Interface(), mt.Zero().Interface())
|
||||
}
|
||||
|
||||
func testEmptyMessage(t *testing.T, m protoreflect.Message, wantValid bool) {
|
||||
numFields := func(m protoreflect.Message) (n int) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user