From 882e57ecba4d524bb11c4d72ce5d0fb1d7614763 Mon Sep 17 00:00:00 2001 From: Gilles Peskine Date: Fri, 12 Apr 2019 00:12:07 +0200 Subject: [PATCH] psa_constant_names: support key agreement algorithms --- programs/psa/psa_constant_names.c | 21 ++++++----- scripts/generate_psa_constants.py | 46 ++++++++++++++++++++---- tests/scripts/test_psa_constant_names.py | 4 ++- 3 files changed, 52 insertions(+), 19 deletions(-) diff --git a/programs/psa/psa_constant_names.c b/programs/psa/psa_constant_names.c index 5514100219..5240b084a4 100644 --- a/programs/psa/psa_constant_names.c +++ b/programs/psa/psa_constant_names.c @@ -84,22 +84,21 @@ static void append_with_curve(char **buffer, size_t buffer_size, append(buffer, buffer_size, required_size, ")", 1); } -static void append_with_hash(char **buffer, size_t buffer_size, - size_t *required_size, - const char *string, size_t length, - psa_algorithm_t hash_alg) +typedef const char *(*psa_get_algorithm_name_func_ptr)(psa_algorithm_t alg); + +static void append_with_alg(char **buffer, size_t buffer_size, + size_t *required_size, + psa_get_algorithm_name_func_ptr get_name, + psa_algorithm_t alg) { - const char *hash_name = psa_hash_algorithm_name(hash_alg); - append(buffer, buffer_size, required_size, string, length); - append(buffer, buffer_size, required_size, "(", 1); - if (hash_name != NULL) { + const char *name = get_name(alg); + if (name != NULL) { append(buffer, buffer_size, required_size, - hash_name, strlen(hash_name)); + name, strlen(name)); } else { append_integer(buffer, buffer_size, required_size, - "0x%08lx", hash_alg); + "0x%08lx", alg); } - append(buffer, buffer_size, required_size, ")", 1); } #include "psa_constant_names_generated.c" diff --git a/scripts/generate_psa_constants.py b/scripts/generate_psa_constants.py index 382fd23e74..dac60034d4 100755 --- a/scripts/generate_psa_constants.py +++ b/scripts/generate_psa_constants.py @@ -30,6 +30,14 @@ static const char *psa_hash_algorithm_name(psa_algorithm_t hash_alg) } } +static const char *psa_ka_algorithm_name(psa_algorithm_t ka_alg) +{ + switch (ka_alg) { + %(ka_algorithm_cases)s + default: return NULL; + } +} + static int psa_snprint_key_type(char *buffer, size_t buffer_size, psa_key_type_t type) { @@ -47,12 +55,13 @@ static int psa_snprint_key_type(char *buffer, size_t buffer_size, return (int) required_size; } +#define NO_LENGTH_MODIFIER 0xfffffffflu static int psa_snprint_algorithm(char *buffer, size_t buffer_size, psa_algorithm_t alg) { size_t required_size = 0; psa_algorithm_t core_alg = alg; - unsigned long length_modifier = 0; + unsigned long length_modifier = NO_LENGTH_MODIFIER; if (PSA_ALG_IS_MAC(alg)) { core_alg = PSA_ALG_TRUNCATED_MAC(alg, 0); if (core_alg != alg) { @@ -70,6 +79,15 @@ static int psa_snprint_algorithm(char *buffer, size_t buffer_size, "PSA_ALG_AEAD_WITH_TAG_LENGTH(", 29); length_modifier = PSA_AEAD_TAG_LENGTH(alg); } + } else if (PSA_ALG_IS_KEY_AGREEMENT(alg) && + !PSA_ALG_IS_RAW_KEY_AGREEMENT(alg)) { + core_alg = PSA_ALG_KEY_AGREEMENT_GET_KDF(alg); + append(&buffer, buffer_size, &required_size, + "PSA_ALG_KEY_AGREEMENT(", 22); + append_with_alg(&buffer, buffer_size, &required_size, + psa_ka_algorithm_name, + PSA_ALG_KEY_AGREEMENT_GET_BASE(alg)); + append(&buffer, buffer_size, &required_size, ", ", 2); } switch (core_alg) { %(algorithm_cases)s @@ -81,9 +99,11 @@ static int psa_snprint_algorithm(char *buffer, size_t buffer_size, break; } if (core_alg != alg) { - append(&buffer, buffer_size, &required_size, ", ", 2); - append_integer(&buffer, buffer_size, &required_size, - "%%lu", length_modifier); + if (length_modifier != NO_LENGTH_MODIFIER) { + append(&buffer, buffer_size, &required_size, ", ", 2); + append_integer(&buffer, buffer_size, &required_size, + "%%lu", length_modifier); + } append(&buffer, buffer_size, &required_size, ")", 1); } buffer[0] = 0; @@ -126,9 +146,12 @@ key_type_from_curve_template = '''if (%(tester)s(type)) { } else ''' algorithm_from_hash_template = '''if (%(tester)s(core_alg)) { - append_with_hash(&buffer, buffer_size, &required_size, - "%(builder)s", %(builder_length)s, - PSA_ALG_GET_HASH(core_alg)); + append(&buffer, buffer_size, &required_size, + "%(builder)s(", %(builder_length)s + 1); + append_with_alg(&buffer, buffer_size, &required_size, + psa_hash_algorithm_name, + PSA_ALG_GET_HASH(core_alg)); + append(&buffer, buffer_size, &required_size, ")", 1); } else ''' bit_test_template = '''\ @@ -149,6 +172,7 @@ class MacroCollector: self.ecc_curves = set() self.algorithms = set() self.hash_algorithms = set() + self.ka_algorithms = set() self.algorithms_from_hash = {} self.key_usages = set() @@ -193,6 +217,9 @@ class MacroCollector: # Ad hoc detection of hash algorithms if re.search(r'0x010000[0-9A-Fa-f]{2}', definition): self.hash_algorithms.add(name) + # Ad hoc detection of key agreement algorithms + if re.search(r'0x30[0-9A-Fa-f]{2}0000', definition): + self.ka_algorithms.add(name) elif name.startswith('PSA_ALG_') and parameter == 'hash_alg': if name in ['PSA_ALG_DSA', 'PSA_ALG_ECDSA']: # A naming irregularity @@ -256,6 +283,10 @@ class MacroCollector: return '\n '.join(map(self.make_return_case, sorted(self.hash_algorithms))) + def make_ka_algorithm_cases(self): + return '\n '.join(map(self.make_return_case, + sorted(self.ka_algorithms))) + def make_algorithm_cases(self): return '\n '.join(map(self.make_append_case, sorted(self.algorithms))) @@ -281,6 +312,7 @@ class MacroCollector: data['key_type_cases'] = self.make_key_type_cases() data['key_type_code'] = self.make_key_type_code() data['hash_algorithm_cases'] = self.make_hash_algorithm_cases() + data['ka_algorithm_cases'] = self.make_ka_algorithm_cases() data['algorithm_cases'] = self.make_algorithm_cases() data['algorithm_code'] = self.make_algorithm_code() data['key_usage_code'] = self.make_key_usage_code() diff --git a/tests/scripts/test_psa_constant_names.py b/tests/scripts/test_psa_constant_names.py index 5e128eb7d1..421cf4e48a 100755 --- a/tests/scripts/test_psa_constant_names.py +++ b/tests/scripts/test_psa_constant_names.py @@ -63,7 +63,8 @@ when applicable.''' # Hard-coded value for unknown algorithms self.hash_algorithms = set(['0x010000fe']) self.mac_algorithms = set(['0x02ff00ff']) - self.kdf_algorithms = set(['0x300000ff', '0x310000ff']) + self.ka_algorithms = set(['0x30fc0000']) + self.kdf_algorithms = set(['0x200000ff']) # For AEAD algorithms, the only variability is over the tag length, # and this only applies to known algorithms, so don't test an # unknown algorithm. @@ -89,6 +90,7 @@ when applicable.''' Call this after parsing all the inputs.''' self.arguments_for['hash_alg'] = sorted(self.hash_algorithms) self.arguments_for['mac_alg'] = sorted(self.mac_algorithms) + self.arguments_for['ka_alg'] = sorted(self.ka_algorithms) self.arguments_for['kdf_alg'] = sorted(self.kdf_algorithms) self.arguments_for['aead_alg'] = sorted(self.aead_algorithms) self.arguments_for['curve'] = sorted(self.ecc_curves)