diff --git a/internal/impl/decode.go b/internal/impl/decode.go index 4b1bc6d6..74fd821d 100644 --- a/internal/impl/decode.go +++ b/internal/impl/decode.go @@ -176,7 +176,7 @@ func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag wire.Numbe if n < 0 { return out, wire.ParseError(n) } - if mi.unknownOffset.IsValid() { + if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() { u := p.Apply(mi.unknownOffset).Bytes() *u = wire.AppendTag(*u, num, wtyp) *u = append(*u, b[:n]...) diff --git a/proto/decode.go b/proto/decode.go index 83942eae..9a6b2f71 100644 --- a/proto/decode.go +++ b/proto/decode.go @@ -154,7 +154,9 @@ func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) if valLen < 0 { return wire.ParseError(valLen) } - m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...)) + if !o.DiscardUnknown { + m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...)) + } } b = b[tagLen+valLen:] } diff --git a/proto/decode_test.go b/proto/decode_test.go index 02f07d48..5ccb8164 100644 --- a/proto/decode_test.go +++ b/proto/decode_test.go @@ -25,9 +25,8 @@ func TestDecode(t *testing.T) { } for _, want := range test.decodeTo { t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) { - opts := proto.UnmarshalOptions{ - AllowPartial: test.partial, - } + opts := test.unmarshalOptions + opts.AllowPartial = test.partial wire := append(([]byte)(nil), test.wire...) got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message) if err := opts.Unmarshal(wire, got); err != nil { @@ -55,6 +54,8 @@ func TestDecodeRequiredFieldChecks(t *testing.T) { } for _, m := range test.decodeTo { t.Run(fmt.Sprintf("%s (%T)", test.desc, m), func(t *testing.T) { + opts := test.unmarshalOptions + opts.AllowPartial = false got := reflect.New(reflect.TypeOf(m).Elem()).Interface().(proto.Message) if err := proto.Unmarshal(test.wire, got); err == nil { t.Fatalf("Unmarshal succeeded (want error)\nMessage:\n%v", marshalText(got)) @@ -71,9 +72,8 @@ func TestDecodeInvalidMessages(t *testing.T) { } for _, want := range test.decodeTo { t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) { - opts := proto.UnmarshalOptions{ - AllowPartial: test.partial, - } + opts := test.unmarshalOptions + opts.AllowPartial = test.partial got := want.ProtoReflect().New().Interface() if err := opts.Unmarshal(test.wire, got); err == nil { t.Errorf("Unmarshal unexpectedly succeeded\ninput bytes: [%x]\nMessage:\n%v", test.wire, marshalText(got)) diff --git a/proto/testmessages_test.go b/proto/testmessages_test.go index 8a7cc296..6f66380d 100644 --- a/proto/testmessages_test.go +++ b/proto/testmessages_test.go @@ -5,10 +5,12 @@ package proto_test import ( + "google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/internal/encoding/pack" "google.golang.org/protobuf/internal/encoding/wire" "google.golang.org/protobuf/internal/impl" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoregistry" legacypb "google.golang.org/protobuf/internal/testprotos/legacy" legacy1pb "google.golang.org/protobuf/internal/testprotos/legacy/proto2_20160225_2fc053c5" @@ -24,6 +26,7 @@ type testProto struct { partial bool noEncode bool checkFastInit bool + unmarshalOptions proto.UnmarshalOptions validationStatus impl.ValidationStatus } @@ -1117,6 +1120,19 @@ var testValidMessages = []testProto{ pack.Tag{100000, pack.VarintType}, pack.Varint(1), }.Marshal(), }, + { + desc: "discarded unknown fields", + unmarshalOptions: proto.UnmarshalOptions{ + DiscardUnknown: true, + }, + decodeTo: []proto.Message{ + &testpb.TestAllTypes{}, + &test3pb.TestAllTypes{}, + }, + wire: pack.Message{ + pack.Tag{100000, pack.VarintType}, pack.Varint(1), + }.Marshal(), + }, { desc: "field type mismatch", decodeTo: []proto.Message{build( @@ -1615,6 +1631,46 @@ var testValidMessages = []testProto{ pack.Tag{pack.LastReservedNumber, pack.VarintType}, pack.Varint(1005), }.Marshal(), }, + { + desc: "nested unknown extension", + unmarshalOptions: proto.UnmarshalOptions{ + DiscardUnknown: true, + Resolver: func() protoregistry.ExtensionTypeResolver { + types := &protoregistry.Types{} + types.RegisterExtension(testpb.E_OptionalNestedMessageExtension) + types.RegisterExtension(testpb.E_OptionalInt32Extension) + return types + }(), + }, + decodeTo: []proto.Message{func() proto.Message { + m := &testpb.TestAllExtensions{} + if err := prototext.Unmarshal([]byte(` + [goproto.proto.test.optional_nested_message_extension]: { + corecursive: { + [goproto.proto.test.optional_nested_message_extension]: { + corecursive: { + [goproto.proto.test.optional_int32_extension]: 42 + } + } + } + }`), m); err != nil { + panic(err) + } + return m + }()}, + wire: pack.Message{ + pack.Tag{18, pack.BytesType}, pack.LengthPrefix(pack.Message{ + pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{ + pack.Tag{18, pack.BytesType}, pack.LengthPrefix(pack.Message{ + pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{ + pack.Tag{1, pack.VarintType}, pack.Varint(42), + pack.Tag{2, pack.VarintType}, pack.Varint(43), + }), + }), + }), + }), + }.Marshal(), + }, } var testInvalidMessages = []testProto{