Merge pull request #6632 from yanesca/refactor_bignum_test_framework

Refactor bignum test framework
This commit is contained in:
Janos Follath 2022-11-22 14:53:58 +00:00 committed by GitHub
commit 0fc88779ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 434 additions and 313 deletions

View File

@ -15,7 +15,12 @@
# limitations under the License.
from abc import abstractmethod
from typing import Iterator, List, Tuple, TypeVar
from typing import Iterator, List, Tuple, TypeVar, Any
from itertools import chain
from . import test_case
from . import test_data_generation
from .bignum_data import INPUTS_DEFAULT, MODULI_DEFAULT
T = TypeVar('T') #pylint: disable=invalid-name
@ -63,8 +68,7 @@ def combination_pairs(values: List[T]) -> List[Tuple[T, T]]:
"""Return all pair combinations from input values."""
return [(x, y) for x in values for y in values]
class OperationCommon:
class OperationCommon(test_data_generation.BaseTest):
"""Common features for bignum binary operations.
This adds functionality common in binary operation tests.
@ -78,22 +82,106 @@ class OperationCommon:
unique_combinations_only: Boolean to select if test case combinations
must be unique. If True, only A,B or B,A would be included as a test
case. If False, both A,B and B,A would be included.
input_style: Controls the way how test data is passed to the functions
in the generated test cases. "variable" passes them as they are
defined in the python source. "arch_split" pads the values with
zeroes depending on the architecture/limb size. If this is set,
test cases are generated for all architectures.
arity: the number of operands for the operation. Currently supported
values are 1 and 2.
"""
symbol = ""
input_values = [] # type: List[str]
input_cases = [] # type: List[Tuple[str, str]]
unique_combinations_only = True
input_values = INPUTS_DEFAULT # type: List[str]
input_cases = [] # type: List[Any]
unique_combinations_only = False
input_styles = ["variable", "fixed", "arch_split"] # type: List[str]
input_style = "variable" # type: str
limb_sizes = [32, 64] # type: List[int]
arities = [1, 2]
arity = 2
def __init__(self, val_a: str, val_b: str) -> None:
self.arg_a = val_a
self.arg_b = val_b
def __init__(self, val_a: str, val_b: str = "0", bits_in_limb: int = 32) -> None:
self.val_a = val_a
self.val_b = val_b
# Setting the int versions here as opposed to making them @properties
# provides earlier/more robust input validation.
self.int_a = hex_to_int(val_a)
self.int_b = hex_to_int(val_b)
if bits_in_limb not in self.limb_sizes:
raise ValueError("Invalid number of bits in limb!")
if self.input_style == "arch_split":
self.dependencies = ["MBEDTLS_HAVE_INT{:d}".format(bits_in_limb)]
self.bits_in_limb = bits_in_limb
@property
def boundary(self) -> int:
if self.arity == 1:
return self.int_a
elif self.arity == 2:
return max(self.int_a, self.int_b)
raise ValueError("Unsupported number of operands!")
@property
def limb_boundary(self) -> int:
return bound_mpi(self.boundary, self.bits_in_limb)
@property
def limbs(self) -> int:
return limbs_mpi(self.boundary, self.bits_in_limb)
@property
def hex_digits(self) -> int:
return 2 * (self.limbs * self.bits_in_limb // 8)
def format_arg(self, val) -> str:
if self.input_style not in self.input_styles:
raise ValueError("Unknown input style!")
if self.input_style == "variable":
return val
else:
return val.zfill(self.hex_digits)
def format_result(self, res) -> str:
res_str = '{:x}'.format(res)
return quote_str(self.format_arg(res_str))
@property
def arg_a(self) -> str:
return self.format_arg(self.val_a)
@property
def arg_b(self) -> str:
if self.arity == 1:
raise AttributeError("Operation is unary and doesn't have arg_b!")
return self.format_arg(self.val_b)
def arguments(self) -> List[str]:
return [
quote_str(self.arg_a), quote_str(self.arg_b)
] + self.result()
args = [quote_str(self.arg_a)]
if self.arity == 2:
args.append(quote_str(self.arg_b))
return args + self.result()
def description(self) -> str:
"""Generate a description for the test case.
If not set, case_description uses the form A `symbol` B, where symbol
is used to represent the operation. Descriptions of each value are
generated to provide some context to the test case.
"""
if not self.case_description:
if self.arity == 1:
self.case_description = "{} {:x}".format(
self.symbol, self.int_a
)
elif self.arity == 2:
self.case_description = "{:x} {} {:x}".format(
self.int_a, self.symbol, self.int_b
)
return super().description()
@property
def is_valid(self) -> bool:
return True
@abstractmethod
def result(self) -> List[str]:
@ -111,15 +199,134 @@ class OperationCommon:
Combinations are first generated from all input values, and then
specific cases provided.
"""
if cls.unique_combinations_only:
yield from combination_pairs(cls.input_values)
if cls.arity == 1:
yield from ((a, "0") for a in cls.input_values)
elif cls.arity == 2:
if cls.unique_combinations_only:
yield from combination_pairs(cls.input_values)
else:
yield from (
(a, b)
for a in cls.input_values
for b in cls.input_values
)
else:
yield from (
(a, b)
for a in cls.input_values
for b in cls.input_values
)
yield from cls.input_cases
raise ValueError("Unsupported number of operands!")
@classmethod
def generate_function_tests(cls) -> Iterator[test_case.TestCase]:
if cls.input_style not in cls.input_styles:
raise ValueError("Unknown input style!")
if cls.arity not in cls.arities:
raise ValueError("Unsupported number of operands!")
if cls.input_style == "arch_split":
test_objects = (cls(a, b, bits_in_limb=bil)
for a, b in cls.get_value_pairs()
for bil in cls.limb_sizes)
special_cases = (cls(*args, bits_in_limb=bil) # type: ignore
for args in cls.input_cases
for bil in cls.limb_sizes)
else:
test_objects = (cls(a, b)
for a, b in cls.get_value_pairs())
special_cases = (cls(*args) for args in cls.input_cases)
yield from (valid_test_object.create_test_case()
for valid_test_object in filter(
lambda test_object: test_object.is_valid,
chain(test_objects, special_cases)
)
)
class ModOperationCommon(OperationCommon):
#pylint: disable=abstract-method
"""Target for bignum mod_raw test case generation."""
moduli = MODULI_DEFAULT # type: List[str]
def __init__(self, val_n: str, val_a: str, val_b: str = "0",
bits_in_limb: int = 64) -> None:
super().__init__(val_a=val_a, val_b=val_b, bits_in_limb=bits_in_limb)
self.val_n = val_n
# Setting the int versions here as opposed to making them @properties
# provides earlier/more robust input validation.
self.int_n = hex_to_int(val_n)
@property
def boundary(self) -> int:
return self.int_n
@property
def arg_n(self) -> str:
return self.format_arg(self.val_n)
def arguments(self) -> List[str]:
return [quote_str(self.arg_n)] + super().arguments()
@property
def r(self) -> int: # pylint: disable=invalid-name
l = limbs_mpi(self.int_n, self.bits_in_limb)
return bound_mpi_limbs(l, self.bits_in_limb)
@property
def r_inv(self) -> int:
return invmod(self.r, self.int_n)
@property
def r2(self) -> int: # pylint: disable=invalid-name
return pow(self.r, 2)
@property
def is_valid(self) -> bool:
if self.int_a >= self.int_n:
return False
if self.arity == 2 and self.int_b >= self.int_n:
return False
return True
def description(self) -> str:
"""Generate a description for the test case.
It uses the form A `symbol` B mod N, where symbol is used to represent
the operation.
"""
if not self.case_description:
return super().description() + " mod {:x}".format(self.int_n)
return super().description()
@classmethod
def input_cases_args(cls) -> Iterator[Tuple[Any, Any, Any]]:
if cls.arity == 1:
yield from ((n, a, "0") for a, n in cls.input_cases)
elif cls.arity == 2:
yield from ((n, a, b) for a, b, n in cls.input_cases)
else:
raise ValueError("Unsupported number of operands!")
@classmethod
def generate_function_tests(cls) -> Iterator[test_case.TestCase]:
if cls.input_style not in cls.input_styles:
raise ValueError("Unknown input style!")
if cls.arity not in cls.arities:
raise ValueError("Unsupported number of operands!")
if cls.input_style == "arch_split":
test_objects = (cls(n, a, b, bits_in_limb=bil)
for n in cls.moduli
for a, b in cls.get_value_pairs()
for bil in cls.limb_sizes)
special_cases = (cls(*args, bits_in_limb=bil)
for args in cls.input_cases_args()
for bil in cls.limb_sizes)
else:
test_objects = (cls(n, a, b)
for n in cls.moduli
for a, b in cls.get_value_pairs())
special_cases = (cls(*args) for args in cls.input_cases_args())
yield from (valid_test_object.create_test_case()
for valid_test_object in filter(
lambda test_object: test_object.is_valid,
chain(test_objects, special_cases)
))
# BEGIN MERGE SLOT 1

View File

@ -16,20 +16,19 @@
import random
from abc import ABCMeta
from typing import Dict, Iterator, List, Tuple
from . import test_case
from . import test_data_generation
from . import bignum_common
class BignumCoreTarget(test_data_generation.BaseTarget, metaclass=ABCMeta):
#pylint: disable=abstract-method
class BignumCoreTarget(test_data_generation.BaseTarget):
#pylint: disable=abstract-method, too-few-public-methods
"""Target for bignum core test case generation."""
target_basename = 'test_suite_bignum_core.generated'
class BignumCoreShiftR(BignumCoreTarget, metaclass=ABCMeta):
class BignumCoreShiftR(BignumCoreTarget, test_data_generation.BaseTest):
"""Test cases for mbedtls_bignum_core_shift_r()."""
count = 0
test_function = "mpi_core_shift_r"
@ -69,7 +68,7 @@ class BignumCoreShiftR(BignumCoreTarget, metaclass=ABCMeta):
for count in counts:
yield cls(input_hex, descr, count).create_test_case()
class BignumCoreCTLookup(BignumCoreTarget, metaclass=ABCMeta):
class BignumCoreCTLookup(BignumCoreTarget, test_data_generation.BaseTest):
"""Test cases for mbedtls_mpi_core_ct_uint_table_lookup()."""
test_function = "mpi_core_ct_uint_table_lookup"
test_name = "Constant time MPI table lookup"
@ -107,104 +106,33 @@ class BignumCoreCTLookup(BignumCoreTarget, metaclass=ABCMeta):
yield (cls(bitsize, bitsize_description, window_size)
.create_test_case())
class BignumCoreOperation(bignum_common.OperationCommon, BignumCoreTarget, metaclass=ABCMeta):
#pylint: disable=abstract-method
"""Common features for bignum core operations."""
input_values = [
"0", "1", "3", "f", "fe", "ff", "100", "ff00", "fffe", "ffff", "10000",
"fffffffe", "ffffffff", "100000000", "1f7f7f7f7f7f7f",
"8000000000000000", "fefefefefefefefe", "fffffffffffffffe",
"ffffffffffffffff", "10000000000000000", "1234567890abcdef0",
"fffffffffffffffffefefefefefefefe", "fffffffffffffffffffffffffffffffe",
"ffffffffffffffffffffffffffffffff", "100000000000000000000000000000000",
"1234567890abcdef01234567890abcdef0",
"fffffffffffffffffffffffffffffffffffffffffffffffffefefefefefefefe",
"fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe",
"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
"10000000000000000000000000000000000000000000000000000000000000000",
"1234567890abcdef01234567890abcdef01234567890abcdef01234567890abcdef0",
(
"4df72d07b4b71c8dacb6cffa954f8d88254b6277099308baf003fab73227f34029"
"643b5a263f66e0d3c3fa297ef71755efd53b8fb6cb812c6bbf7bcf179298bd9947"
"c4c8b14324140a2c0f5fad7958a69050a987a6096e9f055fb38edf0c5889eca4a0"
"cfa99b45fbdeee4c696b328ddceae4723945901ec025076b12b"
)
]
def description(self) -> str:
"""Generate a description for the test case.
If not set, case_description uses the form A `symbol` B, where symbol
is used to represent the operation. Descriptions of each value are
generated to provide some context to the test case.
"""
if not self.case_description:
self.case_description = "{:x} {} {:x}".format(
self.int_a, self.symbol, self.int_b
)
return super().description()
@classmethod
def generate_function_tests(cls) -> Iterator[test_case.TestCase]:
for a_value, b_value in cls.get_value_pairs():
yield cls(a_value, b_value).create_test_case()
class BignumCoreOperationArchSplit(BignumCoreOperation):
#pylint: disable=abstract-method
"""Common features for bignum core operations where the result depends on
the limb size."""
def __init__(self, val_a: str, val_b: str, bits_in_limb: int) -> None:
super().__init__(val_a, val_b)
bound_val = max(self.int_a, self.int_b)
self.bits_in_limb = bits_in_limb
self.bound = bignum_common.bound_mpi(bound_val, self.bits_in_limb)
limbs = bignum_common.limbs_mpi(bound_val, self.bits_in_limb)
byte_len = limbs * self.bits_in_limb // 8
self.hex_digits = 2 * byte_len
if self.bits_in_limb == 32:
self.dependencies = ["MBEDTLS_HAVE_INT32"]
elif self.bits_in_limb == 64:
self.dependencies = ["MBEDTLS_HAVE_INT64"]
else:
raise ValueError("Invalid number of bits in limb!")
self.arg_a = self.arg_a.zfill(self.hex_digits)
self.arg_b = self.arg_b.zfill(self.hex_digits)
def pad_to_limbs(self, val) -> str:
return "{:x}".format(val).zfill(self.hex_digits)
@classmethod
def generate_function_tests(cls) -> Iterator[test_case.TestCase]:
for a_value, b_value in cls.get_value_pairs():
yield cls(a_value, b_value, 32).create_test_case()
yield cls(a_value, b_value, 64).create_test_case()
class BignumCoreAddAndAddIf(BignumCoreOperationArchSplit):
class BignumCoreAddAndAddIf(BignumCoreTarget, bignum_common.OperationCommon):
"""Test cases for bignum core add and add-if."""
count = 0
symbol = "+"
test_function = "mpi_core_add_and_add_if"
test_name = "mpi_core_add_and_add_if"
input_style = "arch_split"
unique_combinations_only = True
def result(self) -> List[str]:
result = self.int_a + self.int_b
carry, result = divmod(result, self.bound)
carry, result = divmod(result, self.limb_boundary)
return [
bignum_common.quote_str(self.pad_to_limbs(result)),
self.format_result(result),
str(carry)
]
class BignumCoreSub(BignumCoreOperation):
class BignumCoreSub(BignumCoreTarget, bignum_common.OperationCommon):
"""Test cases for bignum core sub."""
count = 0
symbol = "-"
test_function = "mpi_core_sub"
test_name = "mbedtls_mpi_core_sub"
unique_combinations_only = False
def result(self) -> List[str]:
if self.int_a >= self.int_b:
@ -224,12 +152,11 @@ class BignumCoreSub(BignumCoreOperation):
]
class BignumCoreMLA(BignumCoreOperation):
class BignumCoreMLA(BignumCoreTarget, bignum_common.OperationCommon):
"""Test cases for fixed-size multiply accumulate."""
count = 0
test_function = "mpi_core_mla"
test_name = "mbedtls_mpi_core_mla"
unique_combinations_only = False
input_values = [
"0", "1", "fffe", "ffffffff", "100000000", "20000000000000",
@ -288,6 +215,16 @@ class BignumCoreMLA(BignumCoreOperation):
"\"{:x}\"".format(carry_8)
]
@classmethod
def get_value_pairs(cls) -> Iterator[Tuple[str, str]]:
"""Generator to yield pairs of inputs.
Combinations are first generated from all input values, and then
specific cases provided.
"""
yield from super().get_value_pairs()
yield from cls.input_cases
@classmethod
def generate_function_tests(cls) -> Iterator[test_case.TestCase]:
"""Override for additional scalar input."""
@ -297,7 +234,7 @@ class BignumCoreMLA(BignumCoreOperation):
yield cur_op.create_test_case()
class BignumCoreMontmul(BignumCoreTarget):
class BignumCoreMontmul(BignumCoreTarget, test_data_generation.BaseTest):
"""Test cases for Montgomery multiplication."""
count = 0
test_function = "mpi_core_montmul"

View File

@ -0,0 +1,136 @@
"""Base values and datasets for bignum generated tests and helper functions that
produced them."""
# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
# Functions calling these were used to produce test data and are here only for
# reproducability, they are not used by the test generation framework/classes
try:
from Cryptodome.Util.number import isPrime, getPrime #type: ignore #pylint: disable=import-error
except ImportError:
pass
# Generated by bignum_common.gen_safe_prime(192,1)
SAFE_PRIME_192_BIT_SEED_1 = "d1c127a667786703830500038ebaef20e5a3e2dc378fb75b"
# First number generated by random.getrandbits(192) - seed(2,2), not a prime
RANDOM_192_BIT_SEED_2_NO1 = "177219d30e7a269fd95bafc8f2a4d27bdcf4bb99f4bea973"
# Second number generated by random.getrandbits(192) - seed(2,2), not a prime
RANDOM_192_BIT_SEED_2_NO2 = "cf1822ffbc6887782b491044d5e341245c6e433715ba2bdd"
# Third number generated by random.getrandbits(192) - seed(2,2), not a prime
RANDOM_192_BIT_SEED_2_NO3 = "3653f8dd9b1f282e4067c3584ee207f8da94e3e8ab73738f"
# Fourth number generated by random.getrandbits(192) - seed(2,2), not a prime
RANDOM_192_BIT_SEED_2_NO4 = "ffed9235288bc781ae66267594c9c9500925e4749b575bd1"
# Ninth number generated by random.getrandbits(192) - seed(2,2), not a prime
RANDOM_192_BIT_SEED_2_NO9 = "2a1be9cd8697bbd0e2520e33e44c50556c71c4a66148a86f"
# Generated by bignum_common.gen_safe_prime(1024,3)
SAFE_PRIME_1024_BIT_SEED_3 = ("c93ba7ec74d96f411ba008bdb78e63ff11bb5df46a51e16b"
"2c9d156f8e4e18abf5e052cb01f47d0d1925a77f60991577"
"e128fb6f52f34a27950a594baadd3d8057abeb222cf3cca9"
"62db16abf79f2ada5bd29ab2f51244bf295eff9f6aaba130"
"2efc449b128be75eeaca04bc3c1a155d11d14e8be32a2c82"
"87b3996cf6ad5223")
# First number generated by random.getrandbits(1024) - seed(4,2), not a prime
RANDOM_1024_BIT_SEED_4_NO1 = ("6905269ed6f0b09f165c8ce36e2f24b43000de01b2ed40ed"
"3addccb2c33be0ac79d679346d4ac7a5c3902b38963dc6e8"
"534f45738d048ec0f1099c6c3e1b258fd724452ccea71ff4"
"a14876aeaff1a098ca5996666ceab360512bd13110722311"
"710cf5327ac435a7a97c643656412a9b8a1abcd1a6916c74"
"da4f9fc3c6da5d7")
# Second number generated by random.getrandbits(1024) - seed(4,2), not a prime
RANDOM_1024_BIT_SEED_4_NO2 = ("f1cfd99216df648647adec26793d0e453f5082492d83a823"
"3fb62d2c81862fc9634f806fabf4a07c566002249b191bf4"
"d8441b5616332aca5f552773e14b0190d93936e1daca3c06"
"f5ff0c03bb5d7385de08caa1a08179104a25e4664f5253a0"
"2a3187853184ff27459142deccea264542a00403ce80c4b0"
"a4042bb3d4341aad")
# Third number generated by random.getrandbits(1024) - seed(4,2), not a prime
RANDOM_1024_BIT_SEED_4_NO3 = ("14c15c910b11ad28cc21ce88d0060cc54278c2614e1bcb38"
"3bb4a570294c4ea3738d243a6e58d5ca49c7b59b995253fd"
"6c79a3de69f85e3131f3b9238224b122c3e4a892d9196ada"
"4fcfa583e1df8af9b474c7e89286a1754abcb06ae8abb93f"
"01d89a024cdce7a6d7288ff68c320f89f1347e0cdd905ecf"
"d160c5d0ef412ed6")
# Fourth number generated by random.getrandbits(1024) - seed(4,2), not a prime
RANDOM_1024_BIT_SEED_4_NO4 = ("32decd6b8efbc170a26a25c852175b7a96b98b5fbf37a2be"
"6f98bca35b17b9662f0733c846bbe9e870ef55b1a1f65507"
"a2909cb633e238b4e9dd38b869ace91311021c9e32111ac1"
"ac7cc4a4ff4dab102522d53857c49391b36cc9aa78a330a1"
"a5e333cb88dcf94384d4cd1f47ca7883ff5a52f1a05885ac"
"7671863c0bdbc23a")
# Fifth number generated by random.getrandbits(1024) - seed(4,2), not a prime
RANDOM_1024_BIT_SEED_4_NO5 = ("53be4721f5b9e1f5acdac615bc20f6264922b9ccf469aef8"
"f6e7d078e55b85dd1525f363b281b8885b69dc230af5ac87"
"0692b534758240df4a7a03052d733dcdef40af2e54c0ce68"
"1f44ebd13cc75f3edcb285f89d8cf4d4950b16ffc3e1ac3b"
"4708d9893a973000b54a23020fc5b043d6e4a51519d9c9cc"
"52d32377e78131c1")
# Adding 192 bit and 1024 bit numbers because these are the shortest required
# for ECC and RSA respectively.
INPUTS_DEFAULT = [
"0", "1", # corner cases
"2", "3", # small primes
"4", # non-prime even
"38", # small random
SAFE_PRIME_192_BIT_SEED_1, # prime
RANDOM_192_BIT_SEED_2_NO1, # not a prime
RANDOM_192_BIT_SEED_2_NO2, # not a prime
SAFE_PRIME_1024_BIT_SEED_3, # prime
RANDOM_1024_BIT_SEED_4_NO1, # not a prime
RANDOM_1024_BIT_SEED_4_NO3, # not a prime
RANDOM_1024_BIT_SEED_4_NO2, # largest (not a prime)
]
# Only odd moduli are present as in the new bignum code only odd moduli are
# supported for now.
MODULI_DEFAULT = [
"53", # safe prime
"45", # non-prime
SAFE_PRIME_192_BIT_SEED_1, # safe prime
RANDOM_192_BIT_SEED_2_NO4, # not a prime
SAFE_PRIME_1024_BIT_SEED_3, # safe prime
RANDOM_1024_BIT_SEED_4_NO5, # not a prime
]
def __gen_safe_prime(bits, seed):
'''
Generate a safe prime.
This function is intended for generating constants offline and shouldn't be
used in test generation classes.
Requires pycryptodomex for getPrime and isPrime and python 3.9 or later for
randbytes.
'''
rng = random.Random()
# We want reproducability across python versions
rng.seed(seed, version=2)
while True:
prime = 2*getPrime(bits-1, rng.randbytes)+1 #pylint: disable=no-member
if isPrime(prime, 1e-30):
return prime

View File

@ -14,12 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABCMeta
from . import test_data_generation
class BignumModTarget(test_data_generation.BaseTarget, metaclass=ABCMeta):
#pylint: disable=abstract-method
class BignumModTarget(test_data_generation.BaseTarget):
#pylint: disable=abstract-method, too-few-public-methods
"""Target for bignum mod test case generation."""
target_basename = 'test_suite_bignum_mod.generated'

View File

@ -14,89 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABCMeta
from typing import Dict, Iterator, List
from typing import Dict, List
from . import test_case
from . import test_data_generation
from . import bignum_common
class BignumModRawTarget(test_data_generation.BaseTarget, metaclass=ABCMeta):
#pylint: disable=abstract-method
class BignumModRawTarget(test_data_generation.BaseTarget):
#pylint: disable=abstract-method, too-few-public-methods
"""Target for bignum mod_raw test case generation."""
target_basename = 'test_suite_bignum_mod_raw.generated'
class BignumModRawOperation(bignum_common.OperationCommon, BignumModRawTarget, metaclass=ABCMeta):
#pylint: disable=abstract-method
"""Target for bignum mod_raw test case generation."""
def __init__(self, val_n: str, val_a: str, val_b: str = "0", bits_in_limb: int = 64) -> None:
super().__init__(val_a=val_a, val_b=val_b)
self.val_n = val_n
self.bits_in_limb = bits_in_limb
@property
def int_n(self) -> int:
return bignum_common.hex_to_int(self.val_n)
@property
def boundary(self) -> int:
data_in = [self.int_a, self.int_b, self.int_n]
return max([n for n in data_in if n is not None])
@property
def limbs(self) -> int:
return bignum_common.limbs_mpi(self.boundary, self.bits_in_limb)
@property
def hex_digits(self) -> int:
return 2 * (self.limbs * self.bits_in_limb // 8)
@property
def hex_n(self) -> str:
return "{:x}".format(self.int_n).zfill(self.hex_digits)
@property
def hex_a(self) -> str:
return "{:x}".format(self.int_a).zfill(self.hex_digits)
@property
def hex_b(self) -> str:
return "{:x}".format(self.int_b).zfill(self.hex_digits)
@property
def r(self) -> int: # pylint: disable=invalid-name
l = bignum_common.limbs_mpi(self.int_n, self.bits_in_limb)
return bignum_common.bound_mpi_limbs(l, self.bits_in_limb)
@property
def r_inv(self) -> int:
return bignum_common.invmod(self.r, self.int_n)
@property
def r2(self) -> int: # pylint: disable=invalid-name
return pow(self.r, 2)
class BignumModRawOperationArchSplit(BignumModRawOperation):
#pylint: disable=abstract-method
"""Common features for bignum mod raw operations where the result depends on
the limb size."""
limb_sizes = [32, 64] # type: List[int]
def __init__(self, val_n: str, val_a: str, val_b: str = "0", bits_in_limb: int = 64) -> None:
super().__init__(val_n=val_n, val_a=val_a, val_b=val_b, bits_in_limb=bits_in_limb)
if bits_in_limb not in self.limb_sizes:
raise ValueError("Invalid number of bits in limb!")
self.dependencies = ["MBEDTLS_HAVE_INT{:d}".format(bits_in_limb)]
@classmethod
def generate_function_tests(cls) -> Iterator[test_case.TestCase]:
for a_value, b_value in cls.get_value_pairs():
for bil in cls.limb_sizes:
yield cls(a_value, b_value, bits_in_limb=bil).create_test_case()
# BEGIN MERGE SLOT 1
# END MERGE SLOT 1
@ -122,126 +49,35 @@ class BignumModRawOperationArchSplit(BignumModRawOperation):
# END MERGE SLOT 6
# BEGIN MERGE SLOT 7
class BignumModRawConvertToMont(BignumModRawOperationArchSplit):
""" Test cases for mpi_mod_raw_to_mont_rep(). """
class BignumModRawConvertToMont(bignum_common.ModOperationCommon,
BignumModRawTarget):
""" Test cases for mpi_mod_raw_to_mont_rep(). """
test_function = "mpi_mod_raw_to_mont_rep"
test_name = "Convert into Mont: "
test_data_moduli = ["b",
"fd",
"eeff99aa37",
"eeff99aa11",
"800000000005",
"7fffffffffffffff",
"80fe000a10000001",
"25a55a46e5da99c71c7",
"1058ad82120c3a10196bb36229c1",
"7e35b84cb19ea5bc57ec37f5e431462fa962d98c1e63738d4657f"
"18ad6532e6adc3eafe67f1e5fa262af94cee8d3e7268593942a2a"
"98df75154f8c914a282f8b",
"8335616aed761f1f7f44e6bd49e807b82e3bf2bf11bfa63",
"ffcece570f2f991013f26dd5b03c4c5b65f97be5905f36cb4664f"
"2c78ff80aa8135a4aaf57ccb8a0aca2f394909a74cef1ef6758a6"
"4d11e2c149c393659d124bfc94196f0ce88f7d7d567efa5a649e2"
"deefaa6e10fdc3deac60d606bf63fc540ac95294347031aefd73d"
"6a9ee10188aaeb7a90d920894553cb196881691cadc51808715a0"
"7e8b24fcb1a63df047c7cdf084dd177ba368c806f3d51ddb5d389"
"8c863e687ecaf7d649a57a46264a582f94d3c8f2edaf59f77a7f6"
"bdaf83c991e8f06abe220ec8507386fce8c3da84c6c3903ab8f3a"
"d4630a204196a7dbcbd9bcca4e40ec5cc5c09938d49f5e1e6181d"
"b8896f33bb12e6ef73f12ec5c5ea7a8a337"
]
test_input_numbers = ["0",
"1",
"97",
"f5",
"6f5c3",
"745bfe50f7",
"ffa1f9924123",
"334a8b983c79bd",
"5b84f632b58f3461",
"19acd15bc38008e1",
"ffffffffffffffff",
"54ce6a6bb8247fa0427cfc75a6b0599",
"fecafe8eca052f154ce6a6bb8247fa019558bfeecce9bb9",
"a87d7a56fa4bfdc7da42ef798b9cf6843d4c54794698cb14d72"
"851dec9586a319f4bb6d5695acbd7c92e7a42a5ede6972adcbc"
"f68425265887f2d721f462b7f1b91531bac29fa648facb8e3c6"
"1bd5ae42d5a59ba1c89a95897bfe541a8ce1d633b98f379c481"
"6f25e21f6ac49286b261adb4b78274fe5f61c187581f213e84b"
"2a821e341ef956ecd5de89e6c1a35418cd74a549379d2d4594a"
"577543147f8e35b3514e62cf3e89d1156cdc91ab5f4c928fbd6"
"9148c35df5962fed381f4d8a62852a36823d5425f7487c13a12"
"523473fb823aa9d6ea5f42e794e15f2c1a8785cf6b7d51a4617"
"947fb3baf674f74a673cf1d38126983a19ed52c7439fab42c2185"
]
descr_tpl = '{} #{} N: \"{}\" A: \"{}\".'
symbol = "R *"
input_style = "arch_split"
arity = 1
def result(self) -> List[str]:
return [self.hex_x]
result = (self.int_a * self.r) % self.int_n
return [self.format_result(result)]
def arguments(self) -> List[str]:
return [bignum_common.quote_str(n) for n in [self.hex_n,
self.hex_a,
self.hex_x]]
def description(self) -> str:
return self.descr_tpl.format(self.test_name,
self.count,
self.int_n,
self.int_a)
@classmethod
def generate_function_tests(cls) -> Iterator[test_case.TestCase]:
for bil in [32, 64]:
for n in cls.test_data_moduli:
for i in cls.test_input_numbers:
# Skip invalid combinations where A.limbs > N.limbs
if bignum_common.hex_to_int(i) > bignum_common.hex_to_int(n):
continue
yield cls(n, i, bits_in_limb=bil).create_test_case()
@property
def x(self) -> int: # pylint: disable=invalid-name
return (self.int_a * self.r) % self.int_n
@property
def hex_x(self) -> str:
return "{:x}".format(self.x).zfill(self.hex_digits)
class BignumModRawConvertFromMont(BignumModRawConvertToMont):
class BignumModRawConvertFromMont(bignum_common.ModOperationCommon,
BignumModRawTarget):
""" Test cases for mpi_mod_raw_from_mont_rep(). """
test_function = "mpi_mod_raw_from_mont_rep"
test_name = "Convert from Mont: "
symbol = "1/R *"
input_style = "arch_split"
arity = 1
def result(self) -> List[str]:
result = (self.int_a * self.r_inv) % self.int_n
return [self.format_result(result)]
test_input_numbers = ["0",
"1",
"3ca",
"539ed428",
"7dfe5c6beb35a2d6",
"dca8de1c2adfc6d7aafb9b48e",
"a7d17b6c4be72f3d5c16bf9c1af6fc933",
"2fec97beec546f9553142ed52f147845463f579",
"378dc83b8bc5a7b62cba495af4919578dce6d4f175cadc4f",
"b6415f2a1a8e48a518345db11f56db3829c8f2c6415ab4a395a"
"b3ac2ea4cbef4af86eb18a84eb6ded4c6ecbfc4b59c2879a675"
"487f687adea9d197a84a5242a5cf6125ce19a6ad2e7341f1c57"
"d43ea4f4c852a51cb63dabcd1c9de2b827a3146a3d175b35bea"
"41ae75d2a286a3e9d43623152ac513dcdea1d72a7da846a8ab3"
"58d9be4926c79cfb287cf1cf25b689de3b912176be5dcaf4d4c"
"6e7cb839a4a3243a6c47c1e2c99d65c59d6fa3672575c2f1ca8"
"de6a32e854ec9d8ec635c96af7679fce26d7d159e4a9da3bd74"
"e1272c376cd926d74fe3fb164a5935cff3d5cdb92b35fe2cea32"
"138a7e6bfbc319ebd1725dacb9a359cbf693f2ecb785efb9d627"
]
@property
def x(self): # pylint: disable=invalid-name
return (self.int_a * self.r_inv) % self.int_n
# END MERGE SLOT 7
# BEGIN MERGE SLOT 8

View File

@ -25,6 +25,7 @@ import argparse
import os
import posixpath
import re
import inspect
from abc import ABCMeta, abstractmethod
from typing import Callable, Dict, Iterable, Iterator, List, Type, TypeVar
@ -35,12 +36,8 @@ from . import test_case
T = TypeVar('T') #pylint: disable=invalid-name
class BaseTarget(metaclass=ABCMeta):
"""Base target for test case generation.
Child classes of this class represent an output file, and can be referred
to as file targets. These indicate where test cases will be written to for
all subclasses of the file target, which is set by `target_basename`.
class BaseTest(metaclass=ABCMeta):
"""Base class for test case generation.
Attributes:
count: Counter for test cases from this class.
@ -48,8 +45,6 @@ class BaseTarget(metaclass=ABCMeta):
automatically generated using the class, or manually set.
dependencies: A list of dependencies required for the test case.
show_test_count: Toggle for inclusion of `count` in the test description.
target_basename: Basename of file to write generated tests to. This
should be specified in a child class of BaseTarget.
test_function: Test function which the class generates cases for.
test_name: A common name or description of the test function. This can
be `test_function`, a clearer equivalent, or a short summary of the
@ -59,7 +54,6 @@ class BaseTarget(metaclass=ABCMeta):
case_description = ""
dependencies = [] # type: List[str]
show_test_count = True
target_basename = ""
test_function = ""
test_name = ""
@ -121,6 +115,21 @@ class BaseTarget(metaclass=ABCMeta):
"""
raise NotImplementedError
class BaseTarget:
#pylint: disable=too-few-public-methods
"""Base target for test case generation.
Child classes of this class represent an output file, and can be referred
to as file targets. These indicate where test cases will be written to for
all subclasses of the file target, which is set by `target_basename`.
Attributes:
target_basename: Basename of file to write generated tests to. This
should be specified in a child class of BaseTarget.
"""
target_basename = ""
@classmethod
def generate_tests(cls) -> Iterator[test_case.TestCase]:
"""Generate test cases for the class and its subclasses.
@ -132,7 +141,8 @@ class BaseTarget(metaclass=ABCMeta):
yield from `generate_tests()` in each. Calling this method on a class X
will yield test cases from all classes derived from X.
"""
if cls.test_function:
if issubclass(cls, BaseTest) and not inspect.isabstract(cls):
#pylint: disable=no-member
yield from cls.generate_function_tests()
for subclass in sorted(cls.__subclasses__(), key=lambda c: c.__name__):
yield from subclass.generate_tests()

View File

@ -57,7 +57,7 @@ of BaseTarget in test_data_generation.py.
import sys
from abc import ABCMeta
from typing import Iterator, List
from typing import List
import scripts_path # pylint: disable=unused-import
from mbedtls_dev import test_case
@ -68,15 +68,17 @@ from mbedtls_dev import bignum_common
# the framework
from mbedtls_dev import bignum_core, bignum_mod_raw # pylint: disable=unused-import
class BignumTarget(test_data_generation.BaseTarget, metaclass=ABCMeta):
#pylint: disable=abstract-method
class BignumTarget(test_data_generation.BaseTarget):
#pylint: disable=too-few-public-methods
"""Target for bignum (legacy) test case generation."""
target_basename = 'test_suite_bignum.generated'
class BignumOperation(bignum_common.OperationCommon, BignumTarget, metaclass=ABCMeta):
class BignumOperation(bignum_common.OperationCommon, BignumTarget,
metaclass=ABCMeta):
#pylint: disable=abstract-method
"""Common features for bignum operations in legacy tests."""
unique_combinations_only = True
input_values = [
"", "0", "-", "-0",
"7b", "-7b",
@ -132,11 +134,6 @@ class BignumOperation(bignum_common.OperationCommon, BignumTarget, metaclass=ABC
tmp = "large " + tmp
return tmp
@classmethod
def generate_function_tests(cls) -> Iterator[test_case.TestCase]:
for a_value, b_value in cls.get_value_pairs():
yield cls(a_value, b_value).create_test_case()
class BignumCmp(BignumOperation):
"""Test cases for bignum value comparison."""