proto, internal/impl: implement support for weak fields

Change-Id: I0a3ff79542a3316295fd6c58e1447e597be97ab9
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/189923
Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
Joe Tsai 2019-08-10 13:56:36 -07:00
parent fc5f8c340a
commit 6e095998ae
11 changed files with 201 additions and 52 deletions

View File

@ -201,6 +201,8 @@ func (o UnmarshalOptions) unmarshalFields(m pref.Message, skipTypeURL bool) erro
fd = nil // reset since field name is actually the message name
}
}
}
if flags.ProtoLegacy {
if fd != nil && fd.IsWeak() && fd.Message().IsPlaceholder() {
fd = nil // reset since the weak reference is not linked in
}

View File

@ -138,8 +138,10 @@ func (o UnmarshalOptions) unmarshalMessage(tmsg [][2]text.Value, m pref.Message)
} else if xtErr != nil && xtErr != protoregistry.NotFound {
return errors.New("unable to resolve: %v", xtErr)
}
if fd != nil && fd.IsWeak() && fd.Message().IsPlaceholder() {
fd = nil // reset since the weak reference is not linked in
if flags.ProtoLegacy {
if fd != nil && fd.IsWeak() && fd.Message().IsPlaceholder() {
fd = nil // reset since the weak reference is not linked in
}
}
// Handle unknown fields.

View File

@ -6,10 +6,13 @@ package impl
import (
"reflect"
"sync"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/proto"
pref "google.golang.org/protobuf/reflect/protoreflect"
preg "google.golang.org/protobuf/reflect/protoregistry"
piface "google.golang.org/protobuf/runtime/protoiface"
)
type errInvalidUTF8 struct{}
@ -17,7 +20,7 @@ type errInvalidUTF8 struct{}
func (errInvalidUTF8) Error() string { return "string field contains invalid UTF-8" }
func (errInvalidUTF8) InvalidUTF8() bool { return true }
func makeOneofFieldCoder(si structInfo, fd pref.FieldDescriptor) pointerCoderFuncs {
func makeOneofFieldCoder(fd pref.FieldDescriptor, si structInfo) pointerCoderFuncs {
ot := si.oneofWrappersByNumber[fd.Number()]
funcs := fieldCoder(fd, ot.Field(0).Type)
fs := si.oneofsByName[fd.ContainingOneof().Name()]
@ -78,6 +81,61 @@ func makeOneofFieldCoder(si structInfo, fd pref.FieldDescriptor) pointerCoderFun
return pcf
}
func makeWeakMessageFieldCoder(fd pref.FieldDescriptor) pointerCoderFuncs {
var once sync.Once
var messageType pref.MessageType
lazyInit := func() {
once.Do(func() {
messageName := fd.Message().FullName()
messageType, _ = preg.GlobalTypes.FindMessageByName(messageName)
})
}
num := int32(fd.Number())
return pointerCoderFuncs{
size: func(p pointer, tagsize int, opts marshalOptions) int {
fs := p.WeakFields()
m, ok := (*fs)[num]
if !ok {
return 0
}
return sizeMessage(m.(proto.Message), tagsize, opts)
},
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
fs := p.WeakFields()
m, ok := (*fs)[num]
if !ok {
return b, nil
}
return appendMessage(b, m.(proto.Message), wiretag, opts)
},
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
fs := p.WeakFields()
m, ok := (*fs)[num]
if !ok {
lazyInit()
if messageType == nil {
return 0, errUnknown
}
m = messageType.New().Interface().(piface.MessageV1)
if *fs == nil {
*fs = make(WeakFields)
}
(*fs)[num] = m
}
return consumeMessage(b, m.(proto.Message), wtyp, opts)
},
isInit: func(p pointer) error {
fs := p.WeakFields()
m, ok := (*fs)[num]
if !ok {
return nil
}
return proto.IsInitialized(m.(proto.Message))
},
}
}
func makeMessageFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
if mi := getMessageInfo(ft); mi != nil {
return pointerCoderFuncs{

View File

@ -65,15 +65,22 @@ func (mi *MessageInfo) makeCoderMethods(t reflect.Type, si structInfo) {
} else {
wiretag = wire.EncodeTag(fd.Number(), wire.BytesType)
}
var fieldOffset offset
var funcs pointerCoderFuncs
if fd.ContainingOneof() != nil {
funcs = makeOneofFieldCoder(si, fd)
} else {
switch {
case fd.ContainingOneof() != nil:
fieldOffset = offsetOf(fs, mi.Exporter)
funcs = makeOneofFieldCoder(fd, si)
case fd.IsWeak():
fieldOffset = si.weakOffset
funcs = makeWeakMessageFieldCoder(fd)
default:
fieldOffset = offsetOf(fs, mi.Exporter)
funcs = fieldCoder(fd, ft)
}
cf := &coderFieldInfo{
num: fd.Number(),
offset: offsetOf(fs, mi.Exporter),
offset: fieldOffset,
wiretag: wiretag,
tagsize: wire.SizeVarint(wiretag),
funcs: funcs,

View File

@ -19,7 +19,7 @@ type WeakImportMessage1 struct {
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
A *int32 `protobuf:"varint,1,opt,name=a" json:"a,omitempty"`
A *int32 `protobuf:"varint,1,req,name=a" json:"a,omitempty"`
}
func (x *WeakImportMessage1) Reset() {
@ -64,7 +64,7 @@ var file_test_weak1_test_weak_proto_rawDesc = []byte{
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x74, 0x65, 0x73, 0x74,
0x2e, 0x77, 0x65, 0x61, 0x6b, 0x22, 0x22, 0x0a, 0x12, 0x57, 0x65, 0x61, 0x6b, 0x49, 0x6d, 0x70,
0x6f, 0x72, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x31, 0x12, 0x0c, 0x0a, 0x01, 0x61,
0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x01, 0x61, 0x42, 0x3b, 0x5a, 0x39, 0x67, 0x6f, 0x6f,
0x18, 0x01, 0x20, 0x02, 0x28, 0x05, 0x52, 0x01, 0x61, 0x42, 0x3b, 0x5a, 0x39, 0x67, 0x6f, 0x6f,
0x67, 0x6c, 0x65, 0x2e, 0x67, 0x6f, 0x6c, 0x61, 0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c,
0x2f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x2f, 0x74, 0x65, 0x73, 0x74,

View File

@ -9,5 +9,5 @@ package goproto.proto.test.weak;
option go_package = "google.golang.org/protobuf/internal/testprotos/test/weak1";
message WeakImportMessage1 {
optional int32 a = 1;
required int32 a = 1;
}

View File

@ -19,7 +19,7 @@ type WeakImportMessage2 struct {
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
A *int32 `protobuf:"varint,1,opt,name=a" json:"a,omitempty"`
A *int32 `protobuf:"varint,1,req,name=a" json:"a,omitempty"`
}
func (x *WeakImportMessage2) Reset() {
@ -64,7 +64,7 @@ var file_test_weak2_test_weak_proto_rawDesc = []byte{
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x74, 0x65, 0x73, 0x74,
0x2e, 0x77, 0x65, 0x61, 0x6b, 0x22, 0x22, 0x0a, 0x12, 0x57, 0x65, 0x61, 0x6b, 0x49, 0x6d, 0x70,
0x6f, 0x72, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x32, 0x12, 0x0c, 0x0a, 0x01, 0x61,
0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x01, 0x61, 0x42, 0x3b, 0x5a, 0x39, 0x67, 0x6f, 0x6f,
0x18, 0x01, 0x20, 0x02, 0x28, 0x05, 0x52, 0x01, 0x61, 0x42, 0x3b, 0x5a, 0x39, 0x67, 0x6f, 0x6f,
0x67, 0x6c, 0x65, 0x2e, 0x67, 0x6f, 0x6c, 0x61, 0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c,
0x2f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x2f, 0x74, 0x65, 0x73, 0x74,

View File

@ -9,5 +9,5 @@ package goproto.proto.test.weak;
option go_package = "google.golang.org/protobuf/internal/testprotos/test/weak2";
message WeakImportMessage2 {
optional int32 a = 1;
required int32 a = 1;
}

View File

@ -8,6 +8,7 @@ import (
"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/internal/pragma"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
@ -88,7 +89,7 @@ func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message)
return wire.ParseError(tagLen)
}
// Parse the field value.
// Find the field descriptor for this field number.
fd := fields.ByNumber(num)
if fd == nil && md.ExtensionRanges().Has(num) {
extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
@ -100,10 +101,18 @@ func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message)
}
}
var err error
if fd == nil {
err = errUnknown
} else if flags.ProtoLegacy {
if fd.IsWeak() && fd.Message().IsPlaceholder() {
err = errUnknown // weak referent is not linked in
}
}
// Parse the field value.
var valLen int
switch {
case fd == nil:
err = errUnknown
case err != nil:
case fd.IsList():
valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd)
case fd.IsMap():
@ -111,14 +120,15 @@ func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message)
default:
valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd)
}
if err == errUnknown {
if err != nil {
if err != errUnknown {
return err
}
valLen = wire.ConsumeFieldValue(num, wtyp, b[tagLen:])
if valLen < 0 {
return wire.ParseError(valLen)
}
m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
} else if err != nil {
return err
}
b = b[tagLen+valLen:]
}

View File

@ -21,6 +21,7 @@ import (
legacypb "google.golang.org/protobuf/internal/testprotos/legacy"
legacy1pb "google.golang.org/protobuf/internal/testprotos/legacy/proto2.v0.0.0-20160225-2fc053c5"
testpb "google.golang.org/protobuf/internal/testprotos/test"
weakpb "google.golang.org/protobuf/internal/testprotos/test/weak1"
test3pb "google.golang.org/protobuf/internal/testprotos/test3"
"google.golang.org/protobuf/types/descriptorpb"
)
@ -1726,6 +1727,46 @@ var invalidFieldNumberTestProtos = []struct {
},
}
func TestWeak(t *testing.T) {
if !flags.ProtoLegacy {
t.SkipNow()
}
m := new(testpb.TestWeak)
b := pack.Message{
pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
pack.Tag{1, pack.VarintType}, pack.Varint(1000),
}),
pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
pack.Tag{1, pack.VarintType}, pack.Varint(2000),
}),
}.Marshal()
if err := proto.Unmarshal(b, m); err != nil {
t.Errorf("Unmarshal error: %v", err)
}
mw := m.GetWeakMessage1().(*weakpb.WeakImportMessage1)
if mw.GetA() != 1000 {
t.Errorf("m.WeakMessage1.a = %d, want %d", mw.GetA(), 1000)
}
if len(m.ProtoReflect().GetUnknown()) == 0 {
t.Errorf("m has no unknown fields, expected at least something")
}
if n := proto.Size(m); n != len(b) {
t.Errorf("Size() = %d, want %d", n, len(b))
}
b2, err := proto.Marshal(m)
if err != nil {
t.Errorf("Marshal error: %v", err)
}
if len(b2) != len(b) {
t.Errorf("len(Marshal) = %d, want %d", len(b2), len(b))
}
}
func build(m proto.Message, opts ...buildOpt) proto.Message {
for _, opt := range opts {
opt(m)

View File

@ -9,51 +9,80 @@ import (
"strings"
"testing"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/proto"
testpb "google.golang.org/protobuf/internal/testprotos/test"
weakpb "google.golang.org/protobuf/internal/testprotos/test/weak1"
)
func TestIsInitializedErrors(t *testing.T) {
for _, test := range []struct {
type test struct {
m proto.Message
want string
}{
{
&testpb.TestRequired{},
`goproto.proto.test.TestRequired.required_field`,
skip bool
}
tests := []test{{
m: &testpb.TestRequired{},
want: `goproto.proto.test.TestRequired.required_field`,
}, {
m: &testpb.TestRequiredForeign{
OptionalMessage: &testpb.TestRequired{},
},
{
&testpb.TestRequiredForeign{
OptionalMessage: &testpb.TestRequired{},
want: `goproto.proto.test.TestRequired.required_field`,
}, {
m: &testpb.TestRequiredForeign{
RepeatedMessage: []*testpb.TestRequired{
{RequiredField: proto.Int32(1)},
{},
},
`goproto.proto.test.TestRequired.required_field`,
},
{
&testpb.TestRequiredForeign{
RepeatedMessage: []*testpb.TestRequired{
{RequiredField: proto.Int32(1)},
{},
},
want: `goproto.proto.test.TestRequired.required_field`,
}, {
m: &testpb.TestRequiredForeign{
MapMessage: map[int32]*testpb.TestRequired{
1: {},
},
`goproto.proto.test.TestRequired.required_field`,
},
{
&testpb.TestRequiredForeign{
MapMessage: map[int32]*testpb.TestRequired{
1: {},
},
},
`goproto.proto.test.TestRequired.required_field`,
},
} {
err := proto.IsInitialized(test.m)
got := "<nil>"
if err != nil {
got = fmt.Sprintf("%q", err)
}
if !strings.Contains(got, test.want) {
t.Errorf("IsInitialized(m):\n got: %v\nwant contains: %v\nMessage:\n%v", got, test.want, marshalText(test.m))
}
want: `goproto.proto.test.TestRequired.required_field`,
}, {
m: &testpb.TestWeak{},
want: `<nil>`,
skip: !flags.ProtoLegacy,
}, {
m: func() proto.Message {
m := &testpb.TestWeak{}
m.SetWeakMessage1(&weakpb.WeakImportMessage1{})
return m
}(),
want: `goproto.proto.test.weak.WeakImportMessage1.a`,
skip: !flags.ProtoLegacy,
}, {
m: func() proto.Message {
m := &testpb.TestWeak{}
m.SetWeakMessage1(&weakpb.WeakImportMessage1{
A: proto.Int32(1),
})
return m
}(),
want: `<nil>`,
skip: !flags.ProtoLegacy,
}}
for _, tt := range tests {
t.Run("", func(t *testing.T) {
if tt.skip {
t.SkipNow()
}
err := proto.IsInitialized(tt.m)
got := "<nil>"
if err != nil {
got = fmt.Sprintf("%q", err)
}
if !strings.Contains(got, tt.want) {
t.Errorf("IsInitialized(m):\n got: %v\nwant contains: %v\nMessage:\n%v", got, tt.want, marshalText(tt.m))
}
})
}
}