diff --git a/include/mbedtls/config_psa.h b/include/mbedtls/config_psa.h index 2032a365dc..f5db94ea0a 100644 --- a/include/mbedtls/config_psa.h +++ b/include/mbedtls/config_psa.h @@ -38,6 +38,30 @@ extern "C" { #endif + + +/****************************************************************/ +/* De facto synonyms */ +/****************************************************************/ + +#if defined(PSA_WANT_ALG_ECDSA_ANY) && !defined(PSA_WANT_ALG_ECDSA) +#define PSA_WANT_ALG_ECDSA PSA_WANT_ALG_ECDSA_ANY +#elif !defined(PSA_WANT_ALG_ECDSA_ANY) && defined(PSA_WANT_ALG_ECDSA) +#define PSA_WANT_ALG_ECDSA_ANY PSA_WANT_ALG_ECDSA +#endif + +#if defined(PSA_WANT_ALG_RSA_PKCS1V15_SIGN_RAW) && !defined(PSA_WANT_ALG_RSA_PKCS1V15_SIGN) +#define PSA_WANT_ALG_RSA_PKCS1V15_SIGN PSA_WANT_ALG_RSA_PKCS1V15_SIGN_RAW +#elif !defined(PSA_WANT_ALG_RSA_PKCS1V15_SIGN_RAW) && defined(PSA_WANT_ALG_RSA_PKCS1V15_SIGN) +#define PSA_WANT_ALG_RSA_PKCS1V15_SIGN_RAW PSA_WANT_ALG_RSA_PKCS1V15_SIGN +#endif + + + +/****************************************************************/ +/* Require built-in implementations based on PSA requirements */ +/****************************************************************/ + #if defined(MBEDTLS_PSA_CRYPTO_CONFIG) #if defined(PSA_WANT_ALG_DETERMINISTIC_ECDSA) @@ -497,6 +521,12 @@ extern "C" { #endif /* !MBEDTLS_PSA_ACCEL_ECC_SECP_K1_256 */ #endif /* PSA_WANT_ECC_SECP_K1_256 */ + + +/****************************************************************/ +/* Infer PSA requirements from Mbed TLS capabilities */ +/****************************************************************/ + #else /* MBEDTLS_PSA_CRYPTO_CONFIG */ /* @@ -522,6 +552,7 @@ extern "C" { #if defined(MBEDTLS_ECDSA_C) #define MBEDTLS_PSA_BUILTIN_ALG_ECDSA 1 #define PSA_WANT_ALG_ECDSA 1 +#define PSA_WANT_ALG_ECDSA_ANY 1 // Only add in DETERMINISTIC support if ECDSA is also enabled #if defined(MBEDTLS_ECDSA_DETERMINISTIC) @@ -586,6 +617,7 @@ extern "C" { #define PSA_WANT_ALG_RSA_PKCS1V15_CRYPT 1 #define MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN 1 #define PSA_WANT_ALG_RSA_PKCS1V15_SIGN 1 +#define PSA_WANT_ALG_RSA_PKCS1V15_SIGN_RAW 1 #endif /* MBEDTLSS_PKCS1_V15 */ #if defined(MBEDTLS_PKCS1_V21) #define MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP 1 diff --git a/scripts/mbedtls_dev/crypto_knowledge.py b/scripts/mbedtls_dev/crypto_knowledge.py index aa5279027c..94a97e7e96 100644 --- a/scripts/mbedtls_dev/crypto_knowledge.py +++ b/scripts/mbedtls_dev/crypto_knowledge.py @@ -33,7 +33,7 @@ class KeyType: `name` is a string 'PSA_KEY_TYPE_xxx' which is the name of a PSA key type macro. For key types that take arguments, the arguments can be passed either through the optional argument `params` or by - passing an expression of the form 'PSA_KEY_TYPE_xxx(param1, param2)' + passing an expression of the form 'PSA_KEY_TYPE_xxx(param1, ...)' in `name` as a string. """ @@ -48,7 +48,7 @@ class KeyType: m = re.match(r'(\w+)\s*\((.*)\)\Z', self.name) assert m is not None self.name = m.group(1) - params = ','.split(m.group(2)) + params = m.group(2).split(',') self.params = (None if params is None else [param.strip() for param in params]) """The parameters of the key type, if there are any. diff --git a/scripts/mbedtls_dev/macro_collector.py b/scripts/mbedtls_dev/macro_collector.py index a2192baf48..0e76435f38 100644 --- a/scripts/mbedtls_dev/macro_collector.py +++ b/scripts/mbedtls_dev/macro_collector.py @@ -18,7 +18,55 @@ import itertools import re -from typing import Dict, Iterable, Iterator, List, Set +from typing import Dict, Iterable, Iterator, List, Optional, Pattern, Set, Tuple, Union + + +class ReadFileLineException(Exception): + def __init__(self, filename: str, line_number: Union[int, str]) -> None: + message = 'in {} at {}'.format(filename, line_number) + super(ReadFileLineException, self).__init__(message) + self.filename = filename + self.line_number = line_number + + +class read_file_lines: + # Dear Pylint, conventionally, a context manager class name is lowercase. + # pylint: disable=invalid-name,too-few-public-methods + """Context manager to read a text file line by line. + + ``` + with read_file_lines(filename) as lines: + for line in lines: + process(line) + ``` + is equivalent to + ``` + with open(filename, 'r') as input_file: + for line in input_file: + process(line) + ``` + except that if process(line) raises an exception, then the read_file_lines + snippet annotates the exception with the file name and line number. + """ + def __init__(self, filename: str, binary: bool = False) -> None: + self.filename = filename + self.line_number = 'entry' #type: Union[int, str] + self.generator = None #type: Optional[Iterable[Tuple[int, str]]] + self.binary = binary + def __enter__(self) -> 'read_file_lines': + self.generator = enumerate(open(self.filename, + 'rb' if self.binary else 'r')) + return self + def __iter__(self) -> Iterator[str]: + assert self.generator is not None + for line_number, content in self.generator: + self.line_number = line_number + yield content + self.line_number = 'exit' + def __exit__(self, exc_type, exc_value, exc_traceback) -> None: + if exc_type is not None: + raise ReadFileLineException(self.filename, self.line_number) \ + from exc_value class PSAMacroEnumerator: @@ -57,6 +105,20 @@ class PSAMacroEnumerator: 'tag_length': [], 'min_tag_length': [], } #type: Dict[str, List[str]] + # Whether to include intermediate macros in enumerations. Intermediate + # macros serve as category headers and are not valid values of their + # type. See `is_internal_name`. + # Always false in this class, may be set to true in derived classes. + self.include_intermediate = False + + def is_internal_name(self, name: str) -> bool: + """Whether this is an internal macro. Internal macros will be skipped.""" + if not self.include_intermediate: + if name.endswith('_BASE') or name.endswith('_NONE'): + return True + if '_CATEGORY_' in name: + return True + return name.endswith('_FLAG') or name.endswith('_MASK') def gather_arguments(self) -> None: """Populate the list of values for macro arguments. @@ -73,7 +135,11 @@ class PSAMacroEnumerator: @staticmethod def _format_arguments(name: str, arguments: Iterable[str]) -> str: - """Format a macro call with arguments..""" + """Format a macro call with arguments. + + The resulting format is consistent with + `InputsForTest.normalize_argument`. + """ return name + '(' + ', '.join(arguments) + ')' _argument_split_re = re.compile(r' *, *') @@ -111,6 +177,15 @@ class PSAMacroEnumerator: except BaseException as e: raise Exception('distribute_arguments({})'.format(name)) from e + def distribute_arguments_without_duplicates( + self, seen: Set[str], name: str + ) -> Iterator[str]: + """Same as `distribute_arguments`, but don't repeat seen results.""" + for result in self.distribute_arguments(name): + if result not in seen: + seen.add(result) + yield result + def generate_expressions(self, names: Iterable[str]) -> Iterator[str]: """Generate expressions covering values constructed from the given names. @@ -123,7 +198,11 @@ class PSAMacroEnumerator: * ``macros.generate_expressions(macros.key_types)`` generates all key types. """ - return itertools.chain(*map(self.distribute_arguments, names)) + seen = set() #type: Set[str] + return itertools.chain(*( + self.distribute_arguments_without_duplicates(seen, name) + for name in names + )) class PSAMacroCollector(PSAMacroEnumerator): @@ -144,15 +223,6 @@ class PSAMacroCollector(PSAMacroEnumerator): self.key_types_from_group = {} #type: Dict[str, str] self.algorithms_from_hash = {} #type: Dict[str, str] - def is_internal_name(self, name: str) -> bool: - """Whether this is an internal macro. Internal macros will be skipped.""" - if not self.include_intermediate: - if name.endswith('_BASE') or name.endswith('_NONE'): - return True - if '_CATEGORY_' in name: - return True - return name.endswith('_FLAG') or name.endswith('_MASK') - def record_algorithm_subtype(self, name: str, expansion: str) -> None: """Record the subtype of an algorithm constructor. @@ -251,3 +321,179 @@ class PSAMacroCollector(PSAMacroEnumerator): m = re.search(self._continued_line_re, line) line = re.sub(self._nonascii_re, rb'', line).decode('ascii') self.read_line(line) + + +class InputsForTest(PSAMacroEnumerator): + # pylint: disable=too-many-instance-attributes + """Accumulate information about macros to test. +enumerate + This includes macro names as well as information about their arguments + when applicable. + """ + + def __init__(self) -> None: + super().__init__() + self.all_declared = set() #type: Set[str] + # Identifier prefixes + self.table_by_prefix = { + 'ERROR': self.statuses, + 'ALG': self.algorithms, + 'ECC_CURVE': self.ecc_curves, + 'DH_GROUP': self.dh_groups, + 'KEY_TYPE': self.key_types, + 'KEY_USAGE': self.key_usage_flags, + } #type: Dict[str, Set[str]] + # Test functions + self.table_by_test_function = { + # Any function ending in _algorithm also gets added to + # self.algorithms. + 'key_type': [self.key_types], + 'block_cipher_key_type': [self.key_types], + 'stream_cipher_key_type': [self.key_types], + 'ecc_key_family': [self.ecc_curves], + 'ecc_key_types': [self.ecc_curves], + 'dh_key_family': [self.dh_groups], + 'dh_key_types': [self.dh_groups], + 'hash_algorithm': [self.hash_algorithms], + 'mac_algorithm': [self.mac_algorithms], + 'cipher_algorithm': [], + 'hmac_algorithm': [self.mac_algorithms], + 'aead_algorithm': [self.aead_algorithms], + 'key_derivation_algorithm': [self.kdf_algorithms], + 'key_agreement_algorithm': [self.ka_algorithms], + 'asymmetric_signature_algorithm': [], + 'asymmetric_signature_wildcard': [self.algorithms], + 'asymmetric_encryption_algorithm': [], + 'other_algorithm': [], + } #type: Dict[str, List[Set[str]]] + self.arguments_for['mac_length'] += ['1', '63'] + self.arguments_for['min_mac_length'] += ['1', '63'] + self.arguments_for['tag_length'] += ['1', '63'] + self.arguments_for['min_tag_length'] += ['1', '63'] + + def add_numerical_values(self) -> None: + """Add numerical values that are not supported to the known identifiers.""" + # Sets of names per type + self.algorithms.add('0xffffffff') + self.ecc_curves.add('0xff') + self.dh_groups.add('0xff') + self.key_types.add('0xffff') + self.key_usage_flags.add('0x80000000') + + # Hard-coded values for unknown algorithms + # + # These have to have values that are correct for their respective + # PSA_ALG_IS_xxx macros, but are also not currently assigned and are + # not likely to be assigned in the near future. + self.hash_algorithms.add('0x020000fe') # 0x020000ff is PSA_ALG_ANY_HASH + self.mac_algorithms.add('0x03007fff') + self.ka_algorithms.add('0x09fc0000') + self.kdf_algorithms.add('0x080000ff') + # For AEAD algorithms, the only variability is over the tag length, + # and this only applies to known algorithms, so don't test an + # unknown algorithm. + + def get_names(self, type_word: str) -> Set[str]: + """Return the set of known names of values of the given type.""" + return { + 'status': self.statuses, + 'algorithm': self.algorithms, + 'ecc_curve': self.ecc_curves, + 'dh_group': self.dh_groups, + 'key_type': self.key_types, + 'key_usage': self.key_usage_flags, + }[type_word] + + # Regex for interesting header lines. + # Groups: 1=macro name, 2=type, 3=argument list (optional). + _header_line_re = \ + re.compile(r'#define +' + + r'(PSA_((?:(?:DH|ECC|KEY)_)?[A-Z]+)_\w+)' + + r'(?:\(([^\n()]*)\))?') + # Regex of macro names to exclude. + _excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z') + # Additional excluded macros. + _excluded_names = set([ + # Macros that provide an alternative way to build the same + # algorithm as another macro. + 'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG', + 'PSA_ALG_FULL_LENGTH_MAC', + # Auxiliary macro whose name doesn't fit the usual patterns for + # auxiliary macros. + 'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG_CASE', + ]) + def parse_header_line(self, line: str) -> None: + """Parse a C header line, looking for "#define PSA_xxx".""" + m = re.match(self._header_line_re, line) + if not m: + return + name = m.group(1) + self.all_declared.add(name) + if re.search(self._excluded_name_re, name) or \ + name in self._excluded_names or \ + self.is_internal_name(name): + return + dest = self.table_by_prefix.get(m.group(2)) + if dest is None: + return + dest.add(name) + if m.group(3): + self.argspecs[name] = self._argument_split(m.group(3)) + + _nonascii_re = re.compile(rb'[^\x00-\x7f]+') #type: Pattern + def parse_header(self, filename: str) -> None: + """Parse a C header file, looking for "#define PSA_xxx".""" + with read_file_lines(filename, binary=True) as lines: + for line in lines: + line = re.sub(self._nonascii_re, rb'', line).decode('ascii') + self.parse_header_line(line) + + _macro_identifier_re = re.compile(r'[A-Z]\w+') + def generate_undeclared_names(self, expr: str) -> Iterable[str]: + for name in re.findall(self._macro_identifier_re, expr): + if name not in self.all_declared: + yield name + + def accept_test_case_line(self, function: str, argument: str) -> bool: + #pylint: disable=unused-argument + undeclared = list(self.generate_undeclared_names(argument)) + if undeclared: + raise Exception('Undeclared names in test case', undeclared) + return True + + @staticmethod + def normalize_argument(argument: str) -> str: + """Normalize whitespace in the given C expression. + + The result uses the same whitespace as + ` PSAMacroEnumerator.distribute_arguments`. + """ + return re.sub(r',', r', ', re.sub(r' +', r'', argument)) + + def add_test_case_line(self, function: str, argument: str) -> None: + """Parse a test case data line, looking for algorithm metadata tests.""" + sets = [] + if function.endswith('_algorithm'): + sets.append(self.algorithms) + if function == 'key_agreement_algorithm' and \ + argument.startswith('PSA_ALG_KEY_AGREEMENT('): + # We only want *raw* key agreement algorithms as such, so + # exclude ones that are already chained with a KDF. + # Keep the expression as one to test as an algorithm. + function = 'other_algorithm' + sets += self.table_by_test_function[function] + if self.accept_test_case_line(function, argument): + for s in sets: + s.add(self.normalize_argument(argument)) + + # Regex matching a *.data line containing a test function call and + # its arguments. The actual definition is partly positional, but this + # regex is good enough in practice. + _test_case_line_re = re.compile(r'(?!depends_on:)(\w+):([^\n :][^:\n]*)') + def parse_test_cases(self, filename: str) -> None: + """Parse a test case file (*.data), looking for algorithm metadata tests.""" + with read_file_lines(filename) as lines: + for line in lines: + m = re.match(self._test_case_line_re, line) + if m: + self.add_test_case_line(m.group(1), m.group(2)) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index a9c9cf3a99..7898004c1c 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -151,6 +151,8 @@ add_test_suite(psa_crypto_se_driver_hal) add_test_suite(psa_crypto_se_driver_hal_mocks) add_test_suite(psa_crypto_slot_management) add_test_suite(psa_crypto_storage_format psa_crypto_storage_format.misc) +add_test_suite(psa_crypto_storage_format psa_crypto_storage_format.current) +add_test_suite(psa_crypto_storage_format psa_crypto_storage_format.v0) add_test_suite(psa_its) add_test_suite(random) add_test_suite(rsa) diff --git a/tests/scripts/check-generated-files.sh b/tests/scripts/check-generated-files.sh index b480837866..a2c285fad5 100755 --- a/tests/scripts/check-generated-files.sh +++ b/tests/scripts/check-generated-files.sh @@ -44,23 +44,28 @@ if [ $# -ne 0 ] && [ "$1" = "-u" ]; then UPDATE='y' fi +# check SCRIPT FILENAME[...] +# check SCRIPT DIRECTORY +# Run SCRIPT and check that it does not modify any of the specified files. +# In the first form, there can be any number of FILENAMEs, which must be +# regular files. +# In the second form, there must be a single DIRECTORY, standing for the +# list of files in the directory. Running SCRIPT must not modify any file +# in the directory and must not add or remove files either. +# If $UPDATE is empty, abort with an error status if a file is modified. check() { SCRIPT=$1 - TO_CHECK=$2 - PATTERN="" - FILES="" + shift - if [ -d $TO_CHECK ]; then - rm -f "$TO_CHECK"/*.bak - for FILE in $TO_CHECK/*; do - FILES="$FILE $FILES" - done - else - FILES=$TO_CHECK + directory= + if [ -d "$1" ]; then + directory="$1" + rm -f "$directory"/*.bak + set -- "$1"/* fi - for FILE in $FILES; do + for FILE in "$@"; do if [ -e "$FILE" ]; then cp "$FILE" "$FILE.bak" else @@ -68,37 +73,32 @@ check() fi done - $SCRIPT + "$SCRIPT" # Compare the script output to the old files and remove backups - for FILE in $FILES; do - if ! diff $FILE $FILE.bak >/dev/null 2>&1; then + for FILE in "$@"; do + if ! diff "$FILE" "$FILE.bak" >/dev/null 2>&1; then echo "'$FILE' was either modified or deleted by '$SCRIPT'" if [ -z "$UPDATE" ]; then exit 1 fi fi if [ -z "$UPDATE" ]; then - mv $FILE.bak $FILE + mv "$FILE.bak" "$FILE" else rm -f "$FILE.bak" fi - - if [ -d $TO_CHECK ]; then - # Create a grep regular expression that we can check against the - # directory contents to test whether new files have been created - if [ -z $PATTERN ]; then - PATTERN="$(basename $FILE)" - else - PATTERN="$PATTERN\|$(basename $FILE)" - fi - fi done - if [ -d $TO_CHECK ]; then + if [ -n "$directory" ]; then + old_list="$*" + set -- "$directory"/* + new_list="$*" # Check if there are any new files - if ls -1 $TO_CHECK | grep -v "$PATTERN" >/dev/null 2>&1; then - echo "Files were created by '$SCRIPT'" + if [ "$old_list" != "$new_list" ]; then + echo "Files were deleted or created by '$SCRIPT'" + echo "Before: $old_list" + echo "After: $new_list" if [ -z "$UPDATE" ]; then exit 1 fi diff --git a/tests/scripts/generate_psa_tests.py b/tests/scripts/generate_psa_tests.py index 30f82db257..8c53414f25 100755 --- a/tests/scripts/generate_psa_tests.py +++ b/tests/scripts/generate_psa_tests.py @@ -60,6 +60,14 @@ def finish_family_dependencies(dependencies: List[str], bits: int) -> List[str]: """ return [finish_family_dependency(dep, bits) for dep in dependencies] +SYMBOLS_WITHOUT_DEPENDENCY = frozenset([ + 'PSA_ALG_AEAD_WITH_AT_LEAST_THIS_LENGTH_TAG', # modifier, only in policies + 'PSA_ALG_AEAD_WITH_SHORTENED_TAG', # modifier + 'PSA_ALG_ANY_HASH', # only in policies + 'PSA_ALG_AT_LEAST_THIS_LENGTH_MAC', # modifier, only in policies + 'PSA_ALG_KEY_AGREEMENT', # chaining + 'PSA_ALG_TRUNCATED_MAC', # modifier +]) def automatic_dependencies(*expressions: str) -> List[str]: """Infer dependencies of a test case by looking for PSA_xxx symbols. @@ -70,6 +78,7 @@ def automatic_dependencies(*expressions: str) -> List[str]: used = set() for expr in expressions: used.update(re.findall(r'PSA_(?:ALG|ECC_FAMILY|KEY_TYPE)_\w+', expr)) + used.difference_update(SYMBOLS_WITHOUT_DEPENDENCY) return sorted(psa_want_symbol(name) for name in used) # A temporary hack: at the time of writing, not all dependency symbols @@ -100,24 +109,27 @@ class Information: @staticmethod def remove_unwanted_macros( - constructors: macro_collector.PSAMacroCollector + constructors: macro_collector.PSAMacroEnumerator ) -> None: - # Mbed TLS doesn't support DSA. Don't attempt to generate any related - # test case. + # Mbed TLS doesn't support finite-field DH yet and will not support + # finite-field DSA. Don't attempt to generate any related test case. + constructors.key_types.discard('PSA_KEY_TYPE_DH_KEY_PAIR') + constructors.key_types.discard('PSA_KEY_TYPE_DH_PUBLIC_KEY') constructors.key_types.discard('PSA_KEY_TYPE_DSA_KEY_PAIR') constructors.key_types.discard('PSA_KEY_TYPE_DSA_PUBLIC_KEY') - constructors.algorithms_from_hash.pop('PSA_ALG_DSA', None) - constructors.algorithms_from_hash.pop('PSA_ALG_DETERMINISTIC_DSA', None) - def read_psa_interface(self) -> macro_collector.PSAMacroCollector: + def read_psa_interface(self) -> macro_collector.PSAMacroEnumerator: """Return the list of known key types, algorithms, etc.""" - constructors = macro_collector.PSAMacroCollector() + constructors = macro_collector.InputsForTest() header_file_names = ['include/psa/crypto_values.h', 'include/psa/crypto_extra.h'] + test_suites = ['tests/suites/test_suite_psa_crypto_metadata.data'] for header_file_name in header_file_names: - with open(header_file_name, 'rb') as header_file: - constructors.read_file(header_file) + constructors.parse_header(header_file_name) + for test_cases in test_suites: + constructors.parse_test_cases(test_cases) self.remove_unwanted_macros(constructors) + constructors.gather_arguments() return constructors @@ -199,14 +211,18 @@ class NotSupported: ) # To be added: derive + ECC_KEY_TYPES = ('PSA_KEY_TYPE_ECC_KEY_PAIR', + 'PSA_KEY_TYPE_ECC_PUBLIC_KEY') + def test_cases_for_not_supported(self) -> Iterator[test_case.TestCase]: """Generate test cases that exercise the creation of keys of unsupported types.""" for key_type in sorted(self.constructors.key_types): + if key_type in self.ECC_KEY_TYPES: + continue kt = crypto_knowledge.KeyType(key_type) yield from self.test_cases_for_key_type_not_supported(kt) for curve_family in sorted(self.constructors.ecc_curves): - for constr in ('PSA_KEY_TYPE_ECC_KEY_PAIR', - 'PSA_KEY_TYPE_ECC_PUBLIC_KEY'): + for constr in self.ECC_KEY_TYPES: kt = crypto_knowledge.KeyType(constr, [curve_family]) yield from self.test_cases_for_key_type_not_supported( kt, param_descr='type') @@ -260,13 +276,17 @@ class StorageFormat: if self.forward: extra_arguments = [] else: + flags = [] # Some test keys have the RAW_DATA type and attributes that don't # necessarily make sense. We do this to validate numerical # encodings of the attributes. # Raw data keys have no useful exercise anyway so there is no # loss of test coverage. - exercise = key.type.string != 'PSA_KEY_TYPE_RAW_DATA' - extra_arguments = ['1' if exercise else '0'] + if key.type.string != 'PSA_KEY_TYPE_RAW_DATA': + flags.append('TEST_FLAG_EXERCISE') + if 'READ_ONLY' in key.lifetime.string: + flags.append('TEST_FLAG_READ_ONLY') + extra_arguments = [' | '.join(flags) if flags else '0'] tc.set_arguments([key.lifetime.string, key.type.string, str(key.bits), key.usage.string, key.alg.string, key.alg2.string, @@ -335,23 +355,17 @@ class StorageFormat: def all_keys_for_types(self) -> Iterator[StorageKey]: """Generate test keys covering key types and their representations.""" - for key_type in sorted(self.constructors.key_types): + key_types = sorted(self.constructors.key_types) + for key_type in self.constructors.generate_expressions(key_types): yield from self.keys_for_type(key_type) - for key_type in sorted(self.constructors.key_types_from_curve): - for curve in sorted(self.constructors.ecc_curves): - yield from self.keys_for_type(key_type, [curve]) - ## Diffie-Hellman (FFDH) is not supported yet, either in - ## crypto_knowledge.py or in Mbed TLS. - # for key_type in sorted(self.constructors.key_types_from_group): - # for group in sorted(self.constructors.dh_groups): - # yield from self.keys_for_type(key_type, [group]) def keys_for_algorithm(self, alg: str) -> Iterator[StorageKey]: """Generate test keys for the specified algorithm.""" # For now, we don't have information on the compatibility of key # types and algorithms. So we just test the encoding of algorithms, # and not that operations can be performed with them. - descr = alg + descr = re.sub(r'PSA_ALG_', r'', alg) + descr = re.sub(r',', r', ', re.sub(r' +', r'', descr)) usage = 'PSA_KEY_USAGE_EXPORT' key1 = StorageKey(version=self.version, id=1, lifetime=0x00000001, @@ -370,17 +384,21 @@ class StorageFormat: def all_keys_for_algorithms(self) -> Iterator[StorageKey]: """Generate test keys covering algorithm encodings.""" - for alg in sorted(self.constructors.algorithms): + algorithms = sorted(self.constructors.algorithms) + for alg in self.constructors.generate_expressions(algorithms): yield from self.keys_for_algorithm(alg) - # To do: algorithm constructors with parameters def all_test_cases(self) -> Iterator[test_case.TestCase]: """Generate all storage format test cases.""" - for key in self.all_keys_for_usage_flags(): - yield self.make_test_case(key) - for key in self.all_keys_for_types(): - yield self.make_test_case(key) - for key in self.all_keys_for_algorithms(): + # First build a list of all keys, then construct all the corresponding + # test cases. This allows all required information to be obtained in + # one go, which is a significant performance gain as the information + # includes numerical values obtained by compiling a C program. + keys = [] #type: List[StorageKey] + keys += self.all_keys_for_usage_flags() + keys += self.all_keys_for_types() + keys += self.all_keys_for_algorithms() + for key in keys: yield self.make_test_case(key) # To do: vary id, lifetime diff --git a/tests/scripts/test_psa_constant_names.py b/tests/scripts/test_psa_constant_names.py index b3fdb8d994..07c8ab2e97 100755 --- a/tests/scripts/test_psa_constant_names.py +++ b/tests/scripts/test_psa_constant_names.py @@ -28,231 +28,30 @@ import os import re import subprocess import sys +from typing import Iterable, List, Optional, Tuple import scripts_path # pylint: disable=unused-import from mbedtls_dev import c_build_helper -from mbedtls_dev import macro_collector +from mbedtls_dev.macro_collector import InputsForTest, PSAMacroEnumerator +from mbedtls_dev import typing_util -class ReadFileLineException(Exception): - def __init__(self, filename, line_number): - message = 'in {} at {}'.format(filename, line_number) - super(ReadFileLineException, self).__init__(message) - self.filename = filename - self.line_number = line_number - -class read_file_lines: - # Dear Pylint, conventionally, a context manager class name is lowercase. - # pylint: disable=invalid-name,too-few-public-methods - """Context manager to read a text file line by line. - - ``` - with read_file_lines(filename) as lines: - for line in lines: - process(line) - ``` - is equivalent to - ``` - with open(filename, 'r') as input_file: - for line in input_file: - process(line) - ``` - except that if process(line) raises an exception, then the read_file_lines - snippet annotates the exception with the file name and line number. - """ - def __init__(self, filename, binary=False): - self.filename = filename - self.line_number = 'entry' - self.generator = None - self.binary = binary - def __enter__(self): - self.generator = enumerate(open(self.filename, - 'rb' if self.binary else 'r')) - return self - def __iter__(self): - for line_number, content in self.generator: - self.line_number = line_number - yield content - self.line_number = 'exit' - def __exit__(self, exc_type, exc_value, exc_traceback): - if exc_type is not None: - raise ReadFileLineException(self.filename, self.line_number) \ - from exc_value - -class InputsForTest(macro_collector.PSAMacroEnumerator): - # pylint: disable=too-many-instance-attributes - """Accumulate information about macros to test. - - This includes macro names as well as information about their arguments - when applicable. - """ - - def __init__(self): - super().__init__() - self.all_declared = set() - # Sets of names per type - self.statuses.add('PSA_SUCCESS') - self.algorithms.add('0xffffffff') - self.ecc_curves.add('0xff') - self.dh_groups.add('0xff') - self.key_types.add('0xffff') - self.key_usage_flags.add('0x80000000') - - # Hard-coded values for unknown algorithms - # - # These have to have values that are correct for their respective - # PSA_ALG_IS_xxx macros, but are also not currently assigned and are - # not likely to be assigned in the near future. - self.hash_algorithms.add('0x020000fe') # 0x020000ff is PSA_ALG_ANY_HASH - self.mac_algorithms.add('0x03007fff') - self.ka_algorithms.add('0x09fc0000') - self.kdf_algorithms.add('0x080000ff') - # For AEAD algorithms, the only variability is over the tag length, - # and this only applies to known algorithms, so don't test an - # unknown algorithm. - - # Identifier prefixes - self.table_by_prefix = { - 'ERROR': self.statuses, - 'ALG': self.algorithms, - 'ECC_CURVE': self.ecc_curves, - 'DH_GROUP': self.dh_groups, - 'KEY_TYPE': self.key_types, - 'KEY_USAGE': self.key_usage_flags, - } - # Test functions - self.table_by_test_function = { - # Any function ending in _algorithm also gets added to - # self.algorithms. - 'key_type': [self.key_types], - 'block_cipher_key_type': [self.key_types], - 'stream_cipher_key_type': [self.key_types], - 'ecc_key_family': [self.ecc_curves], - 'ecc_key_types': [self.ecc_curves], - 'dh_key_family': [self.dh_groups], - 'dh_key_types': [self.dh_groups], - 'hash_algorithm': [self.hash_algorithms], - 'mac_algorithm': [self.mac_algorithms], - 'cipher_algorithm': [], - 'hmac_algorithm': [self.mac_algorithms], - 'aead_algorithm': [self.aead_algorithms], - 'key_derivation_algorithm': [self.kdf_algorithms], - 'key_agreement_algorithm': [self.ka_algorithms], - 'asymmetric_signature_algorithm': [], - 'asymmetric_signature_wildcard': [self.algorithms], - 'asymmetric_encryption_algorithm': [], - 'other_algorithm': [], - } - self.arguments_for['mac_length'] += ['1', '63'] - self.arguments_for['min_mac_length'] += ['1', '63'] - self.arguments_for['tag_length'] += ['1', '63'] - self.arguments_for['min_tag_length'] += ['1', '63'] - - def get_names(self, type_word): - """Return the set of known names of values of the given type.""" - return { - 'status': self.statuses, - 'algorithm': self.algorithms, - 'ecc_curve': self.ecc_curves, - 'dh_group': self.dh_groups, - 'key_type': self.key_types, - 'key_usage': self.key_usage_flags, - }[type_word] - - # Regex for interesting header lines. - # Groups: 1=macro name, 2=type, 3=argument list (optional). - _header_line_re = \ - re.compile(r'#define +' + - r'(PSA_((?:(?:DH|ECC|KEY)_)?[A-Z]+)_\w+)' + - r'(?:\(([^\n()]*)\))?') - # Regex of macro names to exclude. - _excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z') - # Additional excluded macros. - _excluded_names = set([ - # Macros that provide an alternative way to build the same - # algorithm as another macro. - 'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG', - 'PSA_ALG_FULL_LENGTH_MAC', - # Auxiliary macro whose name doesn't fit the usual patterns for - # auxiliary macros. - 'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG_CASE', - ]) - def parse_header_line(self, line): - """Parse a C header line, looking for "#define PSA_xxx".""" - m = re.match(self._header_line_re, line) - if not m: - return - name = m.group(1) - self.all_declared.add(name) - if re.search(self._excluded_name_re, name) or \ - name in self._excluded_names: - return - dest = self.table_by_prefix.get(m.group(2)) - if dest is None: - return - dest.add(name) - if m.group(3): - self.argspecs[name] = self._argument_split(m.group(3)) - - _nonascii_re = re.compile(rb'[^\x00-\x7f]+') - def parse_header(self, filename): - """Parse a C header file, looking for "#define PSA_xxx".""" - with read_file_lines(filename, binary=True) as lines: - for line in lines: - line = re.sub(self._nonascii_re, rb'', line).decode('ascii') - self.parse_header_line(line) - - _macro_identifier_re = re.compile(r'[A-Z]\w+') - def generate_undeclared_names(self, expr): - for name in re.findall(self._macro_identifier_re, expr): - if name not in self.all_declared: - yield name - - def accept_test_case_line(self, function, argument): - #pylint: disable=unused-argument - undeclared = list(self.generate_undeclared_names(argument)) - if undeclared: - raise Exception('Undeclared names in test case', undeclared) - return True - - def add_test_case_line(self, function, argument): - """Parse a test case data line, looking for algorithm metadata tests.""" - sets = [] - if function.endswith('_algorithm'): - sets.append(self.algorithms) - if function == 'key_agreement_algorithm' and \ - argument.startswith('PSA_ALG_KEY_AGREEMENT('): - # We only want *raw* key agreement algorithms as such, so - # exclude ones that are already chained with a KDF. - # Keep the expression as one to test as an algorithm. - function = 'other_algorithm' - sets += self.table_by_test_function[function] - if self.accept_test_case_line(function, argument): - for s in sets: - s.add(argument) - - # Regex matching a *.data line containing a test function call and - # its arguments. The actual definition is partly positional, but this - # regex is good enough in practice. - _test_case_line_re = re.compile(r'(?!depends_on:)(\w+):([^\n :][^:\n]*)') - def parse_test_cases(self, filename): - """Parse a test case file (*.data), looking for algorithm metadata tests.""" - with read_file_lines(filename) as lines: - for line in lines: - m = re.match(self._test_case_line_re, line) - if m: - self.add_test_case_line(m.group(1), m.group(2)) - -def gather_inputs(headers, test_suites, inputs_class=InputsForTest): +def gather_inputs(headers: Iterable[str], + test_suites: Iterable[str], + inputs_class=InputsForTest) -> PSAMacroEnumerator: """Read the list of inputs to test psa_constant_names with.""" inputs = inputs_class() for header in headers: inputs.parse_header(header) for test_cases in test_suites: inputs.parse_test_cases(test_cases) + inputs.add_numerical_values() inputs.gather_arguments() return inputs -def run_c(type_word, expressions, include_path=None, keep_c=False): +def run_c(type_word: str, + expressions: Iterable[str], + include_path: Optional[str] = None, + keep_c: bool = False) -> List[str]: """Generate and run a program to print out numerical values of C expressions.""" if type_word == 'status': cast_to = 'long' @@ -271,14 +70,17 @@ def run_c(type_word, expressions, include_path=None, keep_c=False): ) NORMALIZE_STRIP_RE = re.compile(r'\s+') -def normalize(expr): +def normalize(expr: str) -> str: """Normalize the C expression so as not to care about trivial differences. Currently "trivial differences" means whitespace. """ return re.sub(NORMALIZE_STRIP_RE, '', expr) -def collect_values(inputs, type_word, include_path=None, keep_c=False): +def collect_values(inputs: InputsForTest, + type_word: str, + include_path: Optional[str] = None, + keep_c: bool = False) -> Tuple[List[str], List[str]]: """Generate expressions using known macro names and calculate their values. Return a list of pairs of (expr, value) where expr is an expression and @@ -296,12 +98,12 @@ class Tests: Error = namedtuple('Error', ['type', 'expression', 'value', 'output']) - def __init__(self, options): + def __init__(self, options) -> None: self.options = options self.count = 0 - self.errors = [] + self.errors = [] #type: List[Tests.Error] - def run_one(self, inputs, type_word): + def run_one(self, inputs: InputsForTest, type_word: str) -> None: """Test psa_constant_names for the specified type. Run the program on the names for this type. @@ -311,9 +113,10 @@ class Tests: expressions, values = collect_values(inputs, type_word, include_path=self.options.include, keep_c=self.options.keep_c) - output = subprocess.check_output([self.options.program, type_word] + - values) - outputs = output.decode('ascii').strip().split('\n') + output_bytes = subprocess.check_output([self.options.program, + type_word] + values) + output = output_bytes.decode('ascii') + outputs = output.strip().split('\n') self.count += len(expressions) for expr, value, output in zip(expressions, values, outputs): if self.options.show: @@ -324,13 +127,13 @@ class Tests: value=value, output=output)) - def run_all(self, inputs): + def run_all(self, inputs: InputsForTest) -> None: """Run psa_constant_names on all the gathered inputs.""" for type_word in ['status', 'algorithm', 'ecc_curve', 'dh_group', 'key_type', 'key_usage']: self.run_one(inputs, type_word) - def report(self, out): + def report(self, out: typing_util.Writable) -> None: """Describe each case where the output is not as expected. Write the errors to ``out``. @@ -365,7 +168,7 @@ def main(): help='Program to test') parser.add_argument('--show', action='store_true', - help='Keep the intermediate C file') + help='Show tested values on stdout') parser.add_argument('--no-show', action='store_false', dest='show', help='Don\'t show tested values (default)') diff --git a/tests/suites/test_suite_psa_crypto_storage_format.function b/tests/suites/test_suite_psa_crypto_storage_format.function index 76cfe57758..34d63a745b 100644 --- a/tests/suites/test_suite_psa_crypto_storage_format.function +++ b/tests/suites/test_suite_psa_crypto_storage_format.function @@ -7,6 +7,8 @@ #include +#define TEST_FLAG_EXERCISE 0x00000001 + /** Write a key with the given attributes and key material to storage. * Test that it has the expected representation. * @@ -67,7 +69,7 @@ static int test_read_key( const psa_key_attributes_t *expected_attributes, const data_t *expected_material, psa_storage_uid_t uid, const data_t *representation, - int exercise ) + int flags ) { psa_key_attributes_t actual_attributes = PSA_KEY_ATTRIBUTES_INIT; mbedtls_svc_key_id_t key_id = psa_get_key_id( expected_attributes ); @@ -105,7 +107,7 @@ static int test_read_key( const psa_key_attributes_t *expected_attributes, exported_material, length ); } - if( exercise ) + if( flags & TEST_FLAG_EXERCISE ) { TEST_ASSERT( mbedtls_test_psa_exercise_key( key_id, @@ -183,7 +185,7 @@ exit: void key_storage_read( int lifetime_arg, int type_arg, int bits_arg, int usage_arg, int alg_arg, int alg2_arg, data_t *material, - data_t *representation, int exercise ) + data_t *representation, int flags ) { /* Backward compatibility: read a key in the format of a past version * and check that this version can use it. */ @@ -213,7 +215,7 @@ void key_storage_read( int lifetime_arg, int type_arg, int bits_arg, * guarantees backward compatibility with keys that were stored by * past versions of Mbed TLS. */ TEST_ASSERT( test_read_key( &attributes, material, - uid, representation, exercise ) ); + uid, representation, flags ) ); exit: psa_reset_key_attributes( &attributes );