Redefine result() method to return List

Many bignum tests have multiple calculated result values, so return
these as a list, rather than formatting as a string.

Signed-off-by: Werner Lewis <werner.lewis@arm.com>
This commit is contained in:
Werner Lewis 2022-10-12 14:53:17 +01:00
parent 7a2731463b
commit 1b20e7e645
3 changed files with 38 additions and 25 deletions

View File

@ -108,10 +108,12 @@ class OperationCommon:
self.int_b = hex_to_int(val_b)
def arguments(self) -> List[str]:
return [quote_str(self.arg_a), quote_str(self.arg_b), self.result()]
return [
quote_str(self.arg_a), quote_str(self.arg_b)
] + self.result()
@abstractmethod
def result(self) -> str:
def result(self) -> List[str]:
"""Get the result of the operation.
This could be calculated during initialization and stored as `_result`

View File

@ -80,16 +80,19 @@ class BignumCoreAddIf(BignumCoreOperation):
test_function = "mpi_core_add_if"
test_name = "mbedtls_mpi_core_add_if"
def result(self) -> str:
def result(self) -> List[str]:
tmp = self.int_a + self.int_b
bound_val = max(self.int_a, self.int_b)
bound_4 = bignum_common.bound_mpi4(bound_val)
bound_8 = bignum_common.bound_mpi8(bound_val)
carry_4, remainder_4 = divmod(tmp, bound_4)
carry_8, remainder_8 = divmod(tmp, bound_8)
return "\"{:x}\":{}:\"{:x}\":{}".format(
remainder_4, carry_4, remainder_8, carry_8
)
return [
"\"{:x}\"".format(remainder_4),
str(carry_4),
"\"{:x}\"".format(remainder_8),
str(carry_8)
]
class BignumCoreSub(BignumCoreOperation):
@ -100,7 +103,7 @@ class BignumCoreSub(BignumCoreOperation):
test_name = "mbedtls_mpi_core_sub"
unique_combinations_only = False
def result(self) -> str:
def result(self) -> List[str]:
if self.int_a >= self.int_b:
result_4 = result_8 = self.int_a - self.int_b
carry = 0
@ -111,7 +114,11 @@ class BignumCoreSub(BignumCoreOperation):
bound_8 = bignum_common.bound_mpi8(bound_val)
result_8 = bound_8 + self.int_a - self.int_b
carry = 1
return "\"{:x}\":\"{:x}\":{}".format(result_4, result_8, carry)
return [
"\"{:x}\"".format(result_4),
"\"{:x}\"".format(result_8),
str(carry)
]
class BignumCoreMLA(BignumCoreOperation):
@ -153,9 +160,8 @@ class BignumCoreMLA(BignumCoreOperation):
return [
bignum_common.quote_str(self.arg_a),
bignum_common.quote_str(self.arg_b),
bignum_common.quote_str(self.arg_scalar),
self.result()
]
bignum_common.quote_str(self.arg_scalar)
] + self.result()
def description(self) -> str:
"""Override and add the additional scalar."""
@ -165,16 +171,19 @@ class BignumCoreMLA(BignumCoreOperation):
)
return super().description()
def result(self) -> str:
def result(self) -> List[str]:
result = self.int_a + (self.int_b * self.int_scalar)
bound_val = max(self.int_a, self.int_b)
bound_4 = bignum_common.bound_mpi4(bound_val)
bound_8 = bignum_common.bound_mpi8(bound_val)
carry_4, remainder_4 = divmod(result, bound_4)
carry_8, remainder_8 = divmod(result, bound_8)
return "\"{:x}\":\"{:x}\":\"{:x}\":\"{:x}\"".format(
remainder_4, carry_4, remainder_8, carry_8
)
return [
"\"{:x}\"".format(remainder_4),
"\"{:x}\"".format(carry_4),
"\"{:x}\"".format(remainder_8),
"\"{:x}\"".format(carry_8)
]
@classmethod
def generate_function_tests(cls) -> Iterator[test_case.TestCase]:
@ -557,9 +566,8 @@ class BignumCoreMontmul(BignumCoreTarget):
str(self.limbs_an8), str(self.limbs_b8),
bignum_common.quote_str(self.arg_a),
bignum_common.quote_str(self.arg_b),
bignum_common.quote_str(self.arg_n),
self.result()
]
bignum_common.quote_str(self.arg_n)
] + self.result()
def description(self) -> str:
if self.case_description != "replay":
@ -574,7 +582,7 @@ class BignumCoreMontmul(BignumCoreTarget):
self.case_description = tmp + self.case_description
return super().description()
def result(self) -> str:
def result(self) -> List[str]:
"""Get the result of the operation."""
r4 = bignum_common.bound_mpi4_limbs(self.limbs_an4)
i4 = bignum_common.invmod(r4, self.int_n)
@ -585,7 +593,10 @@ class BignumCoreMontmul(BignumCoreTarget):
i8 = bignum_common.invmod(r8, self.int_n)
x8 = self.int_a * self.int_b * i8
x8 = x8 % self.int_n
return "\"{:x}\":\"{:x}\"".format(x4, x8)
return [
"\"{:x}\"".format(x4),
"\"{:x}\"".format(x8)
]
def set_limbs(
self, limbs_an4: int, limbs_b4: int, limbs_an8: int, limbs_b8: int

View File

@ -57,7 +57,7 @@ of BaseTarget in test_data_generation.py.
import sys
from abc import ABCMeta
from typing import Iterator
from typing import Iterator, List
import scripts_path # pylint: disable=unused-import
from mbedtls_dev import test_case
@ -144,8 +144,8 @@ class BignumCmp(BignumOperation):
self._result = int(self.int_a > self.int_b) - int(self.int_a < self.int_b)
self.symbol = ["<", "==", ">"][self._result + 1]
def result(self) -> str:
return str(self._result)
def result(self) -> List[str]:
return [str(self._result)]
class BignumCmpAbs(BignumCmp):
@ -171,8 +171,8 @@ class BignumAdd(BignumOperation):
]
)
def result(self) -> str:
return bignum_common.quote_str("{:x}").format(self.int_a + self.int_b)
def result(self) -> List[str]:
return [bignum_common.quote_str("{:x}").format(self.int_a + self.int_b)]
if __name__ == '__main__':