mirror of
https://github.com/protocolbuffers/protobuf-go.git
synced 2025-02-19 03:39:48 +00:00
proto, internal/impl: implement support for weak fields
Change-Id: I0a3ff79542a3316295fd6c58e1447e597be97ab9 Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/189923 Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
parent
fc5f8c340a
commit
6e095998ae
@ -201,6 +201,8 @@ func (o UnmarshalOptions) unmarshalFields(m pref.Message, skipTypeURL bool) erro
|
||||
fd = nil // reset since field name is actually the message name
|
||||
}
|
||||
}
|
||||
}
|
||||
if flags.ProtoLegacy {
|
||||
if fd != nil && fd.IsWeak() && fd.Message().IsPlaceholder() {
|
||||
fd = nil // reset since the weak reference is not linked in
|
||||
}
|
||||
|
@ -138,8 +138,10 @@ func (o UnmarshalOptions) unmarshalMessage(tmsg [][2]text.Value, m pref.Message)
|
||||
} else if xtErr != nil && xtErr != protoregistry.NotFound {
|
||||
return errors.New("unable to resolve: %v", xtErr)
|
||||
}
|
||||
if fd != nil && fd.IsWeak() && fd.Message().IsPlaceholder() {
|
||||
fd = nil // reset since the weak reference is not linked in
|
||||
if flags.ProtoLegacy {
|
||||
if fd != nil && fd.IsWeak() && fd.Message().IsPlaceholder() {
|
||||
fd = nil // reset since the weak reference is not linked in
|
||||
}
|
||||
}
|
||||
|
||||
// Handle unknown fields.
|
||||
|
@ -6,10 +6,13 @@ package impl
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"google.golang.org/protobuf/internal/encoding/wire"
|
||||
"google.golang.org/protobuf/proto"
|
||||
pref "google.golang.org/protobuf/reflect/protoreflect"
|
||||
preg "google.golang.org/protobuf/reflect/protoregistry"
|
||||
piface "google.golang.org/protobuf/runtime/protoiface"
|
||||
)
|
||||
|
||||
type errInvalidUTF8 struct{}
|
||||
@ -17,7 +20,7 @@ type errInvalidUTF8 struct{}
|
||||
func (errInvalidUTF8) Error() string { return "string field contains invalid UTF-8" }
|
||||
func (errInvalidUTF8) InvalidUTF8() bool { return true }
|
||||
|
||||
func makeOneofFieldCoder(si structInfo, fd pref.FieldDescriptor) pointerCoderFuncs {
|
||||
func makeOneofFieldCoder(fd pref.FieldDescriptor, si structInfo) pointerCoderFuncs {
|
||||
ot := si.oneofWrappersByNumber[fd.Number()]
|
||||
funcs := fieldCoder(fd, ot.Field(0).Type)
|
||||
fs := si.oneofsByName[fd.ContainingOneof().Name()]
|
||||
@ -78,6 +81,61 @@ func makeOneofFieldCoder(si structInfo, fd pref.FieldDescriptor) pointerCoderFun
|
||||
return pcf
|
||||
}
|
||||
|
||||
func makeWeakMessageFieldCoder(fd pref.FieldDescriptor) pointerCoderFuncs {
|
||||
var once sync.Once
|
||||
var messageType pref.MessageType
|
||||
lazyInit := func() {
|
||||
once.Do(func() {
|
||||
messageName := fd.Message().FullName()
|
||||
messageType, _ = preg.GlobalTypes.FindMessageByName(messageName)
|
||||
})
|
||||
}
|
||||
|
||||
num := int32(fd.Number())
|
||||
return pointerCoderFuncs{
|
||||
size: func(p pointer, tagsize int, opts marshalOptions) int {
|
||||
fs := p.WeakFields()
|
||||
m, ok := (*fs)[num]
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
return sizeMessage(m.(proto.Message), tagsize, opts)
|
||||
},
|
||||
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
|
||||
fs := p.WeakFields()
|
||||
m, ok := (*fs)[num]
|
||||
if !ok {
|
||||
return b, nil
|
||||
}
|
||||
return appendMessage(b, m.(proto.Message), wiretag, opts)
|
||||
},
|
||||
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
|
||||
fs := p.WeakFields()
|
||||
m, ok := (*fs)[num]
|
||||
if !ok {
|
||||
lazyInit()
|
||||
if messageType == nil {
|
||||
return 0, errUnknown
|
||||
}
|
||||
m = messageType.New().Interface().(piface.MessageV1)
|
||||
if *fs == nil {
|
||||
*fs = make(WeakFields)
|
||||
}
|
||||
(*fs)[num] = m
|
||||
}
|
||||
return consumeMessage(b, m.(proto.Message), wtyp, opts)
|
||||
},
|
||||
isInit: func(p pointer) error {
|
||||
fs := p.WeakFields()
|
||||
m, ok := (*fs)[num]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return proto.IsInitialized(m.(proto.Message))
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func makeMessageFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
|
||||
if mi := getMessageInfo(ft); mi != nil {
|
||||
return pointerCoderFuncs{
|
||||
|
@ -65,15 +65,22 @@ func (mi *MessageInfo) makeCoderMethods(t reflect.Type, si structInfo) {
|
||||
} else {
|
||||
wiretag = wire.EncodeTag(fd.Number(), wire.BytesType)
|
||||
}
|
||||
var fieldOffset offset
|
||||
var funcs pointerCoderFuncs
|
||||
if fd.ContainingOneof() != nil {
|
||||
funcs = makeOneofFieldCoder(si, fd)
|
||||
} else {
|
||||
switch {
|
||||
case fd.ContainingOneof() != nil:
|
||||
fieldOffset = offsetOf(fs, mi.Exporter)
|
||||
funcs = makeOneofFieldCoder(fd, si)
|
||||
case fd.IsWeak():
|
||||
fieldOffset = si.weakOffset
|
||||
funcs = makeWeakMessageFieldCoder(fd)
|
||||
default:
|
||||
fieldOffset = offsetOf(fs, mi.Exporter)
|
||||
funcs = fieldCoder(fd, ft)
|
||||
}
|
||||
cf := &coderFieldInfo{
|
||||
num: fd.Number(),
|
||||
offset: offsetOf(fs, mi.Exporter),
|
||||
offset: fieldOffset,
|
||||
wiretag: wiretag,
|
||||
tagsize: wire.SizeVarint(wiretag),
|
||||
funcs: funcs,
|
||||
|
@ -19,7 +19,7 @@ type WeakImportMessage1 struct {
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
A *int32 `protobuf:"varint,1,opt,name=a" json:"a,omitempty"`
|
||||
A *int32 `protobuf:"varint,1,req,name=a" json:"a,omitempty"`
|
||||
}
|
||||
|
||||
func (x *WeakImportMessage1) Reset() {
|
||||
@ -64,7 +64,7 @@ var file_test_weak1_test_weak_proto_rawDesc = []byte{
|
||||
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x74, 0x65, 0x73, 0x74,
|
||||
0x2e, 0x77, 0x65, 0x61, 0x6b, 0x22, 0x22, 0x0a, 0x12, 0x57, 0x65, 0x61, 0x6b, 0x49, 0x6d, 0x70,
|
||||
0x6f, 0x72, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x31, 0x12, 0x0c, 0x0a, 0x01, 0x61,
|
||||
0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x01, 0x61, 0x42, 0x3b, 0x5a, 0x39, 0x67, 0x6f, 0x6f,
|
||||
0x18, 0x01, 0x20, 0x02, 0x28, 0x05, 0x52, 0x01, 0x61, 0x42, 0x3b, 0x5a, 0x39, 0x67, 0x6f, 0x6f,
|
||||
0x67, 0x6c, 0x65, 0x2e, 0x67, 0x6f, 0x6c, 0x61, 0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x70,
|
||||
0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c,
|
||||
0x2f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x2f, 0x74, 0x65, 0x73, 0x74,
|
||||
|
@ -9,5 +9,5 @@ package goproto.proto.test.weak;
|
||||
option go_package = "google.golang.org/protobuf/internal/testprotos/test/weak1";
|
||||
|
||||
message WeakImportMessage1 {
|
||||
optional int32 a = 1;
|
||||
required int32 a = 1;
|
||||
}
|
||||
|
@ -19,7 +19,7 @@ type WeakImportMessage2 struct {
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
A *int32 `protobuf:"varint,1,opt,name=a" json:"a,omitempty"`
|
||||
A *int32 `protobuf:"varint,1,req,name=a" json:"a,omitempty"`
|
||||
}
|
||||
|
||||
func (x *WeakImportMessage2) Reset() {
|
||||
@ -64,7 +64,7 @@ var file_test_weak2_test_weak_proto_rawDesc = []byte{
|
||||
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x74, 0x65, 0x73, 0x74,
|
||||
0x2e, 0x77, 0x65, 0x61, 0x6b, 0x22, 0x22, 0x0a, 0x12, 0x57, 0x65, 0x61, 0x6b, 0x49, 0x6d, 0x70,
|
||||
0x6f, 0x72, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x32, 0x12, 0x0c, 0x0a, 0x01, 0x61,
|
||||
0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x01, 0x61, 0x42, 0x3b, 0x5a, 0x39, 0x67, 0x6f, 0x6f,
|
||||
0x18, 0x01, 0x20, 0x02, 0x28, 0x05, 0x52, 0x01, 0x61, 0x42, 0x3b, 0x5a, 0x39, 0x67, 0x6f, 0x6f,
|
||||
0x67, 0x6c, 0x65, 0x2e, 0x67, 0x6f, 0x6c, 0x61, 0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x70,
|
||||
0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c,
|
||||
0x2f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x2f, 0x74, 0x65, 0x73, 0x74,
|
||||
|
@ -9,5 +9,5 @@ package goproto.proto.test.weak;
|
||||
option go_package = "google.golang.org/protobuf/internal/testprotos/test/weak2";
|
||||
|
||||
message WeakImportMessage2 {
|
||||
optional int32 a = 1;
|
||||
required int32 a = 1;
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"google.golang.org/protobuf/internal/encoding/messageset"
|
||||
"google.golang.org/protobuf/internal/encoding/wire"
|
||||
"google.golang.org/protobuf/internal/errors"
|
||||
"google.golang.org/protobuf/internal/flags"
|
||||
"google.golang.org/protobuf/internal/pragma"
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
"google.golang.org/protobuf/reflect/protoregistry"
|
||||
@ -88,7 +89,7 @@ func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message)
|
||||
return wire.ParseError(tagLen)
|
||||
}
|
||||
|
||||
// Parse the field value.
|
||||
// Find the field descriptor for this field number.
|
||||
fd := fields.ByNumber(num)
|
||||
if fd == nil && md.ExtensionRanges().Has(num) {
|
||||
extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
|
||||
@ -100,10 +101,18 @@ func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message)
|
||||
}
|
||||
}
|
||||
var err error
|
||||
if fd == nil {
|
||||
err = errUnknown
|
||||
} else if flags.ProtoLegacy {
|
||||
if fd.IsWeak() && fd.Message().IsPlaceholder() {
|
||||
err = errUnknown // weak referent is not linked in
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the field value.
|
||||
var valLen int
|
||||
switch {
|
||||
case fd == nil:
|
||||
err = errUnknown
|
||||
case err != nil:
|
||||
case fd.IsList():
|
||||
valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd)
|
||||
case fd.IsMap():
|
||||
@ -111,14 +120,15 @@ func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message)
|
||||
default:
|
||||
valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd)
|
||||
}
|
||||
if err == errUnknown {
|
||||
if err != nil {
|
||||
if err != errUnknown {
|
||||
return err
|
||||
}
|
||||
valLen = wire.ConsumeFieldValue(num, wtyp, b[tagLen:])
|
||||
if valLen < 0 {
|
||||
return wire.ParseError(valLen)
|
||||
}
|
||||
m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
b = b[tagLen+valLen:]
|
||||
}
|
||||
|
@ -21,6 +21,7 @@ import (
|
||||
legacypb "google.golang.org/protobuf/internal/testprotos/legacy"
|
||||
legacy1pb "google.golang.org/protobuf/internal/testprotos/legacy/proto2.v0.0.0-20160225-2fc053c5"
|
||||
testpb "google.golang.org/protobuf/internal/testprotos/test"
|
||||
weakpb "google.golang.org/protobuf/internal/testprotos/test/weak1"
|
||||
test3pb "google.golang.org/protobuf/internal/testprotos/test3"
|
||||
"google.golang.org/protobuf/types/descriptorpb"
|
||||
)
|
||||
@ -1726,6 +1727,46 @@ var invalidFieldNumberTestProtos = []struct {
|
||||
},
|
||||
}
|
||||
|
||||
func TestWeak(t *testing.T) {
|
||||
if !flags.ProtoLegacy {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
m := new(testpb.TestWeak)
|
||||
b := pack.Message{
|
||||
pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
|
||||
pack.Tag{1, pack.VarintType}, pack.Varint(1000),
|
||||
}),
|
||||
pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
|
||||
pack.Tag{1, pack.VarintType}, pack.Varint(2000),
|
||||
}),
|
||||
}.Marshal()
|
||||
if err := proto.Unmarshal(b, m); err != nil {
|
||||
t.Errorf("Unmarshal error: %v", err)
|
||||
}
|
||||
|
||||
mw := m.GetWeakMessage1().(*weakpb.WeakImportMessage1)
|
||||
if mw.GetA() != 1000 {
|
||||
t.Errorf("m.WeakMessage1.a = %d, want %d", mw.GetA(), 1000)
|
||||
}
|
||||
|
||||
if len(m.ProtoReflect().GetUnknown()) == 0 {
|
||||
t.Errorf("m has no unknown fields, expected at least something")
|
||||
}
|
||||
|
||||
if n := proto.Size(m); n != len(b) {
|
||||
t.Errorf("Size() = %d, want %d", n, len(b))
|
||||
}
|
||||
|
||||
b2, err := proto.Marshal(m)
|
||||
if err != nil {
|
||||
t.Errorf("Marshal error: %v", err)
|
||||
}
|
||||
if len(b2) != len(b) {
|
||||
t.Errorf("len(Marshal) = %d, want %d", len(b2), len(b))
|
||||
}
|
||||
}
|
||||
|
||||
func build(m proto.Message, opts ...buildOpt) proto.Message {
|
||||
for _, opt := range opts {
|
||||
opt(m)
|
||||
|
@ -9,51 +9,80 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/protobuf/internal/flags"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
testpb "google.golang.org/protobuf/internal/testprotos/test"
|
||||
weakpb "google.golang.org/protobuf/internal/testprotos/test/weak1"
|
||||
)
|
||||
|
||||
func TestIsInitializedErrors(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
type test struct {
|
||||
m proto.Message
|
||||
want string
|
||||
}{
|
||||
{
|
||||
&testpb.TestRequired{},
|
||||
`goproto.proto.test.TestRequired.required_field`,
|
||||
skip bool
|
||||
}
|
||||
tests := []test{{
|
||||
m: &testpb.TestRequired{},
|
||||
want: `goproto.proto.test.TestRequired.required_field`,
|
||||
}, {
|
||||
m: &testpb.TestRequiredForeign{
|
||||
OptionalMessage: &testpb.TestRequired{},
|
||||
},
|
||||
{
|
||||
&testpb.TestRequiredForeign{
|
||||
OptionalMessage: &testpb.TestRequired{},
|
||||
want: `goproto.proto.test.TestRequired.required_field`,
|
||||
}, {
|
||||
m: &testpb.TestRequiredForeign{
|
||||
RepeatedMessage: []*testpb.TestRequired{
|
||||
{RequiredField: proto.Int32(1)},
|
||||
{},
|
||||
},
|
||||
`goproto.proto.test.TestRequired.required_field`,
|
||||
},
|
||||
{
|
||||
&testpb.TestRequiredForeign{
|
||||
RepeatedMessage: []*testpb.TestRequired{
|
||||
{RequiredField: proto.Int32(1)},
|
||||
{},
|
||||
},
|
||||
want: `goproto.proto.test.TestRequired.required_field`,
|
||||
}, {
|
||||
m: &testpb.TestRequiredForeign{
|
||||
MapMessage: map[int32]*testpb.TestRequired{
|
||||
1: {},
|
||||
},
|
||||
`goproto.proto.test.TestRequired.required_field`,
|
||||
},
|
||||
{
|
||||
&testpb.TestRequiredForeign{
|
||||
MapMessage: map[int32]*testpb.TestRequired{
|
||||
1: {},
|
||||
},
|
||||
},
|
||||
`goproto.proto.test.TestRequired.required_field`,
|
||||
},
|
||||
} {
|
||||
err := proto.IsInitialized(test.m)
|
||||
got := "<nil>"
|
||||
if err != nil {
|
||||
got = fmt.Sprintf("%q", err)
|
||||
}
|
||||
if !strings.Contains(got, test.want) {
|
||||
t.Errorf("IsInitialized(m):\n got: %v\nwant contains: %v\nMessage:\n%v", got, test.want, marshalText(test.m))
|
||||
}
|
||||
want: `goproto.proto.test.TestRequired.required_field`,
|
||||
}, {
|
||||
m: &testpb.TestWeak{},
|
||||
want: `<nil>`,
|
||||
skip: !flags.ProtoLegacy,
|
||||
}, {
|
||||
m: func() proto.Message {
|
||||
m := &testpb.TestWeak{}
|
||||
m.SetWeakMessage1(&weakpb.WeakImportMessage1{})
|
||||
return m
|
||||
}(),
|
||||
want: `goproto.proto.test.weak.WeakImportMessage1.a`,
|
||||
skip: !flags.ProtoLegacy,
|
||||
}, {
|
||||
m: func() proto.Message {
|
||||
m := &testpb.TestWeak{}
|
||||
m.SetWeakMessage1(&weakpb.WeakImportMessage1{
|
||||
A: proto.Int32(1),
|
||||
})
|
||||
return m
|
||||
}(),
|
||||
want: `<nil>`,
|
||||
skip: !flags.ProtoLegacy,
|
||||
}}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run("", func(t *testing.T) {
|
||||
if tt.skip {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
err := proto.IsInitialized(tt.m)
|
||||
got := "<nil>"
|
||||
if err != nil {
|
||||
got = fmt.Sprintf("%q", err)
|
||||
}
|
||||
if !strings.Contains(got, tt.want) {
|
||||
t.Errorf("IsInitialized(m):\n got: %v\nwant contains: %v\nMessage:\n%v", got, tt.want, marshalText(tt.m))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user