generate_test_keys: minor improvements

Signed-off-by: Valerio Setti <valerio.setti@nordicsemi.no>
This commit is contained in:
Valerio Setti 2024-04-17 15:12:49 +02:00
parent 37bc93cbeb
commit ee74339180

View File

@ -7,16 +7,12 @@
generating the required key at run time. This helps speeding up testing.""" generating the required key at run time. This helps speeding up testing."""
import os import os
import sys
from typing import Iterator from typing import Iterator
import re import re
# pylint: disable=wrong-import-position
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) + "/"
sys.path.append(SCRIPT_DIR + "../../scripts/")
from mbedtls_dev.asymmetric_key_data import ASYMMETRIC_KEY_DATA
import scripts_path # pylint: disable=unused-import import scripts_path # pylint: disable=unused-import
from mbedtls_dev.asymmetric_key_data import ASYMMETRIC_KEY_DATA
OUTPUT_HEADER_FILE = SCRIPT_DIR + "../src/test_keys.h" OUTPUT_HEADER_FILE = os.path.dirname(os.path.abspath(__file__)) + "/../src/test_keys.h"
BYTES_PER_LINE = 16 BYTES_PER_LINE = 16
def c_byte_array_literal_content(array_name: str, key_data: bytes) -> Iterator[str]: def c_byte_array_literal_content(array_name: str, key_data: bytes) -> Iterator[str]:
@ -53,25 +49,25 @@ def get_ec_key_family(key: str) -> str:
# - understand if the curve is supported in legacy symbols (MBEDTLS_ECP_DP_...) # - understand if the curve is supported in legacy symbols (MBEDTLS_ECP_DP_...)
EC_NAME_CONVERSION = { EC_NAME_CONVERSION = {
'PSA_ECC_FAMILY_SECP_K1': { 'PSA_ECC_FAMILY_SECP_K1': {
192: ['secp', 'k1'], 192: ('secp', 'k1'),
224: ['secp', 'k1'], 224: ('secp', 'k1'),
256: ['secp', 'k1'] 256: ('secp', 'k1')
}, },
'PSA_ECC_FAMILY_SECP_R1': { 'PSA_ECC_FAMILY_SECP_R1': {
192: ['secp', 'r1'], 192: ('secp', 'r1'),
224: ['secp', 'r1'], 224: ('secp', 'r1'),
256: ['secp', 'r1'], 256: ('secp', 'r1'),
384: ['secp', 'r1'], 384: ('secp', 'r1'),
521: ['secp', 'r1'] 521: ('secp', 'r1')
}, },
'PSA_ECC_FAMILY_BRAINPOOL_P_R1': { 'PSA_ECC_FAMILY_BRAINPOOL_P_R1': {
256: ['bp', 'r1'], 256: ('bp', 'r1'),
384: ['bp', 'r1'], 384: ('bp', 'r1'),
512: ['bp', 'r1'] 512: ('bp', 'r1')
}, },
'PSA_ECC_FAMILY_MONTGOMERY': { 'PSA_ECC_FAMILY_MONTGOMERY': {
255: ['curve', '19'], 255: ('curve', '19'),
448: ['curve', ''] 448: ('curve', '')
} }
} }
@ -80,13 +76,13 @@ def get_ec_curve_name(priv_key: str, bits: int) -> str:
try: try:
prefix = EC_NAME_CONVERSION[ec_family][bits][0] prefix = EC_NAME_CONVERSION[ec_family][bits][0]
suffix = EC_NAME_CONVERSION[ec_family][bits][1] suffix = EC_NAME_CONVERSION[ec_family][bits][1]
except: # pylint: disable=bare-except except KeyError:
return "" return ""
return prefix + str(bits) + suffix return prefix + str(bits) + suffix
def get_look_up_table_entry(key_type: str, curve_or_keybits: str, def get_look_up_table_entry(key_type: str, curve_or_keybits: str,
priv_array_name: str, pub_array_name: str) -> Iterator[str]: priv_array_name: str, pub_array_name: str) -> Iterator[str]:
yield "\n {{ {}, ".format("1" if key_type == "ec" else "0") yield " {{ {}, ".format("1" if key_type == "ec" else "0")
yield "{},\n".format(curve_or_keybits) yield "{},\n".format(curve_or_keybits)
yield " {0}, sizeof({0}),\n".format(priv_array_name) yield " {0}, sizeof({0}),\n".format(priv_array_name)
yield " {0}, sizeof({0}) }},".format(pub_array_name) yield " {0}, sizeof({0}) }},".format(pub_array_name)
@ -104,12 +100,12 @@ def main() -> None:
" *********************************************************************************/\n" " *********************************************************************************/\n"
) )
look_up_table = "" look_up_table = []
# Get a list of private keys only in order to get a single item for every # Get a list of private keys only in order to get a single item for every
# (key type, key bits) pair. We know that ASYMMETRIC_KEY_DATA # (key type, key bits) pair. We know that ASYMMETRIC_KEY_DATA
# contains also the public counterpart. # contains also the public counterpart.
priv_keys = [key for key in ASYMMETRIC_KEY_DATA if re.match(r'.*_KEY_PAIR', key)] priv_keys = [key for key in ASYMMETRIC_KEY_DATA if '_KEY_PAIR' in key]
for priv_key in priv_keys: for priv_key in priv_keys:
key_type = get_key_type(priv_key) key_type = get_key_type(priv_key)
@ -142,9 +138,8 @@ def main() -> None:
curve_or_keybits = "MBEDTLS_ECP_DP_" + curve.upper() curve_or_keybits = "MBEDTLS_ECP_DP_" + curve.upper()
else: else:
curve_or_keybits = str(bits) curve_or_keybits = str(bits)
look_up_table = look_up_table + \ look_up_table.append(''.join(get_look_up_table_entry(key_type, curve_or_keybits,
''.join(get_look_up_table_entry(key_type, curve_or_keybits, array_name_priv, array_name_pub)))
array_name_priv, array_name_pub))
# Write the lookup table: the struct containing pointers to all the arrays we created above. # Write the lookup table: the struct containing pointers to all the arrays we created above.
output_file.write(""" output_file.write("""
struct predefined_key_element { struct predefined_key_element {
@ -156,8 +151,10 @@ struct predefined_key_element {
size_t pub_key_len; size_t pub_key_len;
}; };
struct predefined_key_element predefined_keys[] = {""") struct predefined_key_element predefined_keys[] = {
output_file.write("{}\n}};\n".format(look_up_table)) """)
output_file.write("\n".join(look_up_table))
output_file.write("\n};\n")
if __name__ == '__main__': if __name__ == '__main__':
main() main()