diff --git a/scripts/mbedtls_dev/bignum_common.py b/scripts/mbedtls_dev/bignum_common.py index 88fa4dfd0e..c81770ed7d 100644 --- a/scripts/mbedtls_dev/bignum_common.py +++ b/scripts/mbedtls_dev/bignum_common.py @@ -43,31 +43,18 @@ def hex_to_int(val: str) -> int: def quote_str(val) -> str: return "\"{}\"".format(val) -def bound_mpi8(val: int) -> int: - """First number exceeding 8-byte limbs needed for given input value.""" - return bound_mpi8_limbs(limbs_mpi8(val)) +def bound_mpi(val: int, bits_in_limb: int) -> int: + """First number exceeding number of limbs needed for given input value.""" + return bound_mpi_limbs(limbs_mpi(val, bits_in_limb), bits_in_limb) -def bound_mpi4(val: int) -> int: - """First number exceeding 4-byte limbs needed for given input value.""" - return bound_mpi4_limbs(limbs_mpi4(val)) - -def bound_mpi8_limbs(limbs: int) -> int: - """First number exceeding maximum of given 8-byte limbs.""" - bits = 64 * limbs +def bound_mpi_limbs(limbs: int, bits_in_limb: int) -> int: + """First number exceeding maximum of given number of limbs.""" + bits = bits_in_limb * limbs return 1 << bits -def bound_mpi4_limbs(limbs: int) -> int: - """First number exceeding maximum of given 4-byte limbs.""" - bits = 32 * limbs - return 1 << bits - -def limbs_mpi8(val: int) -> int: - """Return the number of 8-byte limbs required to store value.""" - return (val.bit_length() + 63) // 64 - -def limbs_mpi4(val: int) -> int: - """Return the number of 4-byte limbs required to store value.""" - return (val.bit_length() + 31) // 32 +def limbs_mpi(val: int, bits_in_limb: int) -> int: + """Return the number of limbs required to store value.""" + return (val.bit_length() + bits_in_limb - 1) // bits_in_limb def combination_pairs(values: List[T]) -> List[Tuple[T, T]]: """Return all pair combinations from input values. diff --git a/scripts/mbedtls_dev/bignum_core.py b/scripts/mbedtls_dev/bignum_core.py index 1bd2482669..3652ac20ab 100644 --- a/scripts/mbedtls_dev/bignum_core.py +++ b/scripts/mbedtls_dev/bignum_core.py @@ -83,8 +83,8 @@ class BignumCoreAddIf(BignumCoreOperation): def result(self) -> List[str]: tmp = self.int_a + self.int_b bound_val = max(self.int_a, self.int_b) - bound_4 = bignum_common.bound_mpi4(bound_val) - bound_8 = bignum_common.bound_mpi8(bound_val) + bound_4 = bignum_common.bound_mpi(bound_val, 32) + bound_8 = bignum_common.bound_mpi(bound_val, 64) carry_4, remainder_4 = divmod(tmp, bound_4) carry_8, remainder_8 = divmod(tmp, bound_8) return [ @@ -109,9 +109,9 @@ class BignumCoreSub(BignumCoreOperation): carry = 0 else: bound_val = max(self.int_a, self.int_b) - bound_4 = bignum_common.bound_mpi4(bound_val) + bound_4 = bignum_common.bound_mpi(bound_val, 32) result_4 = bound_4 + self.int_a - self.int_b - bound_8 = bignum_common.bound_mpi8(bound_val) + bound_8 = bignum_common.bound_mpi(bound_val, 64) result_8 = bound_8 + self.int_a - self.int_b carry = 1 return [ @@ -153,7 +153,7 @@ class BignumCoreMLA(BignumCoreOperation): super().__init__(val_a, val_b) self.arg_scalar = val_s self.int_scalar = bignum_common.hex_to_int(val_s) - if bignum_common.limbs_mpi4(self.int_scalar) > 1: + if bignum_common.limbs_mpi(self.int_scalar, 32) > 1: self.dependencies = ["MBEDTLS_HAVE_INT64"] def arguments(self) -> List[str]: @@ -174,8 +174,8 @@ class BignumCoreMLA(BignumCoreOperation): def result(self) -> List[str]: result = self.int_a + (self.int_b * self.int_scalar) bound_val = max(self.int_a, self.int_b) - bound_4 = bignum_common.bound_mpi4(bound_val) - bound_8 = bignum_common.bound_mpi8(bound_val) + bound_4 = bignum_common.bound_mpi(bound_val, 32) + bound_8 = bignum_common.bound_mpi(bound_val, 64) carry_4, remainder_4 = divmod(result, bound_4) carry_8, remainder_8 = divmod(result, bound_8) return [ @@ -548,12 +548,12 @@ class BignumCoreMontmul(BignumCoreTarget): self.arg_n = val_n self.int_n = bignum_common.hex_to_int(val_n) - limbs_a4 = bignum_common.limbs_mpi4(self.int_a) - limbs_a8 = bignum_common.limbs_mpi8(self.int_a) - self.limbs_b4 = bignum_common.limbs_mpi4(self.int_b) - self.limbs_b8 = bignum_common.limbs_mpi8(self.int_b) - self.limbs_an4 = bignum_common.limbs_mpi4(self.int_n) - self.limbs_an8 = bignum_common.limbs_mpi8(self.int_n) + limbs_a4 = bignum_common.limbs_mpi(self.int_a, 32) + limbs_a8 = bignum_common.limbs_mpi(self.int_a, 64) + self.limbs_b4 = bignum_common.limbs_mpi(self.int_b, 32) + self.limbs_b8 = bignum_common.limbs_mpi(self.int_b, 64) + self.limbs_an4 = bignum_common.limbs_mpi(self.int_n, 32) + self.limbs_an8 = bignum_common.limbs_mpi(self.int_n, 64) if limbs_a4 > self.limbs_an4 or limbs_a8 > self.limbs_an8: raise Exception("Limbs of input A ({}) exceeds N ({})".format( @@ -584,12 +584,12 @@ class BignumCoreMontmul(BignumCoreTarget): def result(self) -> List[str]: """Get the result of the operation.""" - r4 = bignum_common.bound_mpi4_limbs(self.limbs_an4) + r4 = bignum_common.bound_mpi_limbs(self.limbs_an4, 32) i4 = bignum_common.invmod(r4, self.int_n) x4 = self.int_a * self.int_b * i4 x4 = x4 % self.int_n - r8 = bignum_common.bound_mpi8_limbs(self.limbs_an8) + r8 = bignum_common.bound_mpi_limbs(self.limbs_an8, 64) i8 = bignum_common.invmod(r8, self.int_n) x8 = self.int_a * self.int_b * i8 x8 = x8 % self.int_n