proto: fix DiscardUnknown

UnmarshalOptions.DiscardUnknown was simply not working. Oops. Fix it.
Add a test.

Change-Id: I76888eae1221d99a007f0e9cdb711d292e6856b1
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/216762
Reviewed-by: Joe Tsai <joetsai@google.com>
This commit is contained in:
Damien Neil 2020-01-28 14:53:44 -08:00
parent cb0bfd0f40
commit a60e709ac8
4 changed files with 66 additions and 8 deletions

View File

@ -176,7 +176,7 @@ func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag wire.Numbe
if n < 0 {
return out, wire.ParseError(n)
}
if mi.unknownOffset.IsValid() {
if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
u := p.Apply(mi.unknownOffset).Bytes()
*u = wire.AppendTag(*u, num, wtyp)
*u = append(*u, b[:n]...)

View File

@ -154,7 +154,9 @@ func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message)
if valLen < 0 {
return wire.ParseError(valLen)
}
m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
if !o.DiscardUnknown {
m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
}
}
b = b[tagLen+valLen:]
}

View File

@ -25,9 +25,8 @@ func TestDecode(t *testing.T) {
}
for _, want := range test.decodeTo {
t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
opts := proto.UnmarshalOptions{
AllowPartial: test.partial,
}
opts := test.unmarshalOptions
opts.AllowPartial = test.partial
wire := append(([]byte)(nil), test.wire...)
got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
if err := opts.Unmarshal(wire, got); err != nil {
@ -55,6 +54,8 @@ func TestDecodeRequiredFieldChecks(t *testing.T) {
}
for _, m := range test.decodeTo {
t.Run(fmt.Sprintf("%s (%T)", test.desc, m), func(t *testing.T) {
opts := test.unmarshalOptions
opts.AllowPartial = false
got := reflect.New(reflect.TypeOf(m).Elem()).Interface().(proto.Message)
if err := proto.Unmarshal(test.wire, got); err == nil {
t.Fatalf("Unmarshal succeeded (want error)\nMessage:\n%v", marshalText(got))
@ -71,9 +72,8 @@ func TestDecodeInvalidMessages(t *testing.T) {
}
for _, want := range test.decodeTo {
t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
opts := proto.UnmarshalOptions{
AllowPartial: test.partial,
}
opts := test.unmarshalOptions
opts.AllowPartial = test.partial
got := want.ProtoReflect().New().Interface()
if err := opts.Unmarshal(test.wire, got); err == nil {
t.Errorf("Unmarshal unexpectedly succeeded\ninput bytes: [%x]\nMessage:\n%v", test.wire, marshalText(got))

View File

@ -5,10 +5,12 @@
package proto_test
import (
"google.golang.org/protobuf/encoding/prototext"
"google.golang.org/protobuf/internal/encoding/pack"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/impl"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoregistry"
legacypb "google.golang.org/protobuf/internal/testprotos/legacy"
legacy1pb "google.golang.org/protobuf/internal/testprotos/legacy/proto2_20160225_2fc053c5"
@ -24,6 +26,7 @@ type testProto struct {
partial bool
noEncode bool
checkFastInit bool
unmarshalOptions proto.UnmarshalOptions
validationStatus impl.ValidationStatus
}
@ -1117,6 +1120,19 @@ var testValidMessages = []testProto{
pack.Tag{100000, pack.VarintType}, pack.Varint(1),
}.Marshal(),
},
{
desc: "discarded unknown fields",
unmarshalOptions: proto.UnmarshalOptions{
DiscardUnknown: true,
},
decodeTo: []proto.Message{
&testpb.TestAllTypes{},
&test3pb.TestAllTypes{},
},
wire: pack.Message{
pack.Tag{100000, pack.VarintType}, pack.Varint(1),
}.Marshal(),
},
{
desc: "field type mismatch",
decodeTo: []proto.Message{build(
@ -1615,6 +1631,46 @@ var testValidMessages = []testProto{
pack.Tag{pack.LastReservedNumber, pack.VarintType}, pack.Varint(1005),
}.Marshal(),
},
{
desc: "nested unknown extension",
unmarshalOptions: proto.UnmarshalOptions{
DiscardUnknown: true,
Resolver: func() protoregistry.ExtensionTypeResolver {
types := &protoregistry.Types{}
types.RegisterExtension(testpb.E_OptionalNestedMessageExtension)
types.RegisterExtension(testpb.E_OptionalInt32Extension)
return types
}(),
},
decodeTo: []proto.Message{func() proto.Message {
m := &testpb.TestAllExtensions{}
if err := prototext.Unmarshal([]byte(`
[goproto.proto.test.optional_nested_message_extension]: {
corecursive: {
[goproto.proto.test.optional_nested_message_extension]: {
corecursive: {
[goproto.proto.test.optional_int32_extension]: 42
}
}
}
}`), m); err != nil {
panic(err)
}
return m
}()},
wire: pack.Message{
pack.Tag{18, pack.BytesType}, pack.LengthPrefix(pack.Message{
pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
pack.Tag{18, pack.BytesType}, pack.LengthPrefix(pack.Message{
pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
pack.Tag{1, pack.VarintType}, pack.Varint(42),
pack.Tag{2, pack.VarintType}, pack.Varint(43),
}),
}),
}),
}),
}.Marshal(),
},
}
var testInvalidMessages = []testProto{