diff --git a/tests/scripts/check_names.py b/tests/scripts/check_names.py index 5fe85b7bd6..c129def4ed 100755 --- a/tests/scripts/check_names.py +++ b/tests/scripts/check_names.py @@ -33,6 +33,7 @@ subprocess error. Must be run from Mbed TLS root. """ import argparse +import glob import textwrap import os import sys @@ -177,13 +178,19 @@ class Typo(Problem): # pylint: disable=too-few-public-methods class NameCheck(): """ Representation of the core name checking operation performed by this script. - Shares a common logger, common excluded filenames, and a shared return_code. + Shares a common logger, and a shared return code. """ - def __init__(self): + def __init__(self, verbose=False): self.log = None self.check_repo_path() self.return_code = 0 + + self.setup_logger(verbose) + + # Globally excluded filenames self.excluded_files = ["bn_mul", "compat-2.x.h"] + + # Will contain the parse result after a comprehensive parse self.parse_result = {} def set_return_code(self, return_code): @@ -213,30 +220,30 @@ class NameCheck(): self.log.setLevel(logging.INFO) self.log.addHandler(logging.StreamHandler()) - def get_files(self, extension, directory): + def get_files(self, wildcard): """ - Get all files that end with .extension in the specified directory - recursively. + Get all files that match a UNIX-style wildcard recursively. While the + script is designed only for use on UNIX/macOS (due to nm), this function + would work fine on Windows even with forward slashes in the wildcard. Args: - * extension: the file extension to search for, without the dot - * directory: the directory to recursively search for + * wildcard: shell-style wildcards to match filepaths against. Returns a List of relative filepaths. """ - filenames = [] - for root, _, files in sorted(os.walk(directory)): - for filename in sorted(files): - if (filename not in self.excluded_files and - filename.endswith("." + extension)): - filenames.append(os.path.join(root, filename)) - return filenames + accumulator = [] + + for filepath in glob.iglob(wildcard, recursive=True): + if os.path.basename(filepath) not in self.excluded_files: + accumulator.append(filepath) + return accumulator def parse_names_in_source(self): """ - Calls each parsing function to retrieve various elements of the code, - together with their source location. Puts the parsed values in the - internal variable self.parse_result. + Comprehensive function to call each parsing function and retrieve + various elements of the code, together with their source location. + Puts the parsed values in the internal variable self.parse_result, so + they can be used from perform_checks(). """ self.log.info("Parsing source code...") self.log.debug( @@ -244,13 +251,13 @@ class NameCheck(): .format(str(self.excluded_files)) ) - m_headers = self.get_files("h", os.path.join("include", "mbedtls")) - p_headers = self.get_files("h", os.path.join("include", "psa")) + m_headers = self.get_files("include/mbedtls/*.h") + p_headers = self.get_files("include/psa/*.h") t_headers = ["3rdparty/everest/include/everest/everest.h", "3rdparty/everest/include/everest/x25519.h"] - d_headers = self.get_files("h", os.path.join("tests", "include", "test", "drivers")) - l_headers = self.get_files("h", "library") - libraries = self.get_files("c", "library") + [ + d_headers = self.get_files("tests/include/test/drivers/*.h") + l_headers = self.get_files("library/*.h") + libraries = self.get_files("library/*.c") + [ "3rdparty/everest/library/everest.c", "3rdparty/everest/library/x25519.c"] @@ -589,6 +596,7 @@ class NameCheck(): """ Perform each check in order, output its PASS/FAIL status. Maintain an overall test status, and output that at the end. + Assumes parse_names_in_source() was called before this. Args: * quiet: whether to hide detailed problem explanation. @@ -620,6 +628,7 @@ class NameCheck(): """ Perform a check that all detected symbols in the library object files are properly declared in headers. + Assumes parse_names_in_source() was called before this. Args: * quiet: whether to hide detailed problem explanation. @@ -645,6 +654,7 @@ class NameCheck(): def check_match_pattern(self, quiet, group_to_check, check_pattern): """ Perform a check that all items of a group conform to a regex pattern. + Assumes parse_names_in_source() was called before this. Args: * quiet: whether to hide detailed problem explanation. @@ -674,6 +684,7 @@ class NameCheck(): """ Perform a check that all words in the soure code beginning with MBED are either defined as macros, or as enum constants. + Assumes parse_names_in_source() was called before this. Args: * quiet: whether to hide detailed problem explanation. @@ -725,7 +736,7 @@ def main(): Perform argument parsing, and create an instance of NameCheck to begin the core operation. """ - parser = argparse.ArgumentParser( + argparser = argparse.ArgumentParser( formatter_class=argparse.RawDescriptionHelpFormatter, description=( "This script confirms that the naming of all symbols and identifiers " @@ -733,19 +744,18 @@ def main(): "self-consistent.\n\n" "Expected to be run from the MbedTLS root directory.")) - parser.add_argument("-v", "--verbose", + argparser.add_argument("-v", "--verbose", action="store_true", help="show parse results") - parser.add_argument("-q", "--quiet", + argparser.add_argument("-q", "--quiet", action="store_true", help="hide unnecessary text, explanations, and highlighs") - args = parser.parse_args() + args = argparser.parse_args() try: - name_check = NameCheck() - name_check.setup_logger(verbose=args.verbose) + name_check = NameCheck(verbose=args.verbose) name_check.parse_names_in_source() name_check.perform_checks(quiet=args.quiet) sys.exit(name_check.return_code)