diff --git a/proto/decode_test.go b/proto/decode_test.go index e296fc22..b2a3369e 100644 --- a/proto/decode_test.go +++ b/proto/decode_test.go @@ -134,6 +134,20 @@ func TestDecodeOneofNilWrapper(t *testing.T) { } } +func TestDecodeInvalidFieldNumbers(t *testing.T) { + for _, test := range invalidFieldNumberTestProtos { + t.Run(test.desc, func(t *testing.T) { + decoded := new(testpb.TestAllTypes) // type doesn't matter since we expect errors + err := proto.Unmarshal(test.wire, decoded) + if err == nil && !test.allowed { + t.Error("unmarshal: got nil want error") + } else if err != nil && test.allowed { + t.Errorf("unmarshal: got %v want nil since %s is allowed by Unmarshal", err, test.desc) + } + }) + } +} + var testProtos = []testProto{ { desc: "basic scalar types", @@ -1663,6 +1677,55 @@ var messageInfo_TestNoEnforceUTF8 = protoimpl.MessageInfo{ }, } +var invalidFieldNumberTestProtos = []struct { + desc string + wire []byte + allowed bool +}{ + { + desc: "zero", + wire: pack.Message{ + pack.Tag{pack.MinValidNumber - 1, pack.VarintType}, pack.Varint(1001), + }.Marshal(), + }, + { + desc: "zero and one", + wire: pack.Message{ + pack.Tag{pack.MinValidNumber - 1, pack.VarintType}, pack.Varint(1002), + pack.Tag{pack.MinValidNumber, pack.VarintType}, pack.Varint(1003), + }.Marshal(), + }, + { + desc: "first reserved", + wire: pack.Message{ + pack.Tag{pack.FirstReservedNumber, pack.VarintType}, pack.Varint(1004), + }.Marshal(), + allowed: true, + }, + { + desc: "last reserved", + wire: pack.Message{ + pack.Tag{pack.LastReservedNumber, pack.VarintType}, pack.Varint(1005), + }.Marshal(), + allowed: true, + }, + { + desc: "max and max+1", + wire: pack.Message{ + pack.Tag{pack.MaxValidNumber, pack.VarintType}, pack.Varint(1006), + pack.Tag{pack.MaxValidNumber + 1, pack.VarintType}, pack.Varint(1007), + }.Marshal(), + allowed: flags.ProtoLegacy, + }, + { + desc: "max+1", + wire: pack.Message{ + pack.Tag{pack.MaxValidNumber + 1, pack.VarintType}, pack.Varint(1008), + }.Marshal(), + allowed: flags.ProtoLegacy, + }, +} + func build(m proto.Message, opts ...buildOpt) proto.Message { for _, opt := range opts { opt(m)