diff --git a/scripts/mbedtls_dev/bignum_common.py b/scripts/mbedtls_dev/bignum_common.py index b22846b710..7d7170d170 100644 --- a/scripts/mbedtls_dev/bignum_common.py +++ b/scripts/mbedtls_dev/bignum_common.py @@ -209,12 +209,12 @@ class OperationCommon(test_data_generation.BaseTest): if cls.arity not in cls.arities: raise ValueError("Unsupported number of operands!") if cls.input_style == "arch_split": - test_objects = (cls(a_value, b_value, bits_in_limb=bil) - for a_value, b_value in cls.get_value_pairs() + test_objects = (cls(a, b, bits_in_limb=bil) + for a, b in cls.get_value_pairs() for bil in cls.limb_sizes) else: - test_objects = (cls(a_value, b_value) for - a_value, b_value in cls.get_value_pairs()) + test_objects = (cls(a, b) + for a, b in cls.get_value_pairs()) yield from (valid_test_object.create_test_case() for valid_test_object in filter( lambda test_object: test_object.is_valid, @@ -225,6 +225,7 @@ class OperationCommon(test_data_generation.BaseTest): class ModOperationCommon(OperationCommon): #pylint: disable=abstract-method """Target for bignum mod_raw test case generation.""" + moduli = [] # type: List[str] def __init__(self, val_n: str, val_a: str, val_b: str = "0", bits_in_limb: int = 64) -> None: @@ -258,6 +259,34 @@ class ModOperationCommon(OperationCommon): def r2(self) -> int: # pylint: disable=invalid-name return pow(self.r, 2) + @property + def is_valid(self) -> bool: + if self.int_a >= self.int_n: + return False + if self.arity == 2 and self.int_b >= self.int_n: + return False + return True + + @classmethod + def generate_function_tests(cls) -> Iterator[test_case.TestCase]: + if cls.input_style not in cls.input_styles: + raise ValueError("Unknown input style!") + if cls.arity not in cls.arities: + raise ValueError("Unsupported number of operands!") + if cls.input_style == "arch_split": + test_objects = (cls(n, a, b, bits_in_limb=bil) + for n in cls.moduli + for a, b in cls.get_value_pairs() + for bil in cls.limb_sizes) + else: + test_objects = (cls(n, a, b) + for n in cls.moduli + for a, b in cls.get_value_pairs()) + yield from (valid_test_object.create_test_case() + for valid_test_object in filter( + lambda test_object: test_object.is_valid, + test_objects + )) # BEGIN MERGE SLOT 1