Use glob in get_files(), call setup_logger on init

glob is more flexible and simplifies the function arguments drastically.
It is also much more intuitive to extend in the future when the filepaths
need to be extended or changed.

setup_logger had to be called as the first thing after instantiation, so
this commit simplify makes it automatic.

Several clarification comments are added too.

Signed-off-by: Yuto Takano <yuto.takano@arm.com>
This commit is contained in:
Yuto Takano 2021-08-09 11:56:15 +01:00
parent 51efcb143d
commit 977e07f5c8

View File

@ -33,6 +33,7 @@ subprocess error. Must be run from Mbed TLS root.
"""
import argparse
import glob
import textwrap
import os
import sys
@ -177,13 +178,19 @@ class Typo(Problem): # pylint: disable=too-few-public-methods
class NameCheck():
"""
Representation of the core name checking operation performed by this script.
Shares a common logger, common excluded filenames, and a shared return_code.
Shares a common logger, and a shared return code.
"""
def __init__(self):
def __init__(self, verbose=False):
self.log = None
self.check_repo_path()
self.return_code = 0
self.setup_logger(verbose)
# Globally excluded filenames
self.excluded_files = ["bn_mul", "compat-2.x.h"]
# Will contain the parse result after a comprehensive parse
self.parse_result = {}
def set_return_code(self, return_code):
@ -213,30 +220,30 @@ class NameCheck():
self.log.setLevel(logging.INFO)
self.log.addHandler(logging.StreamHandler())
def get_files(self, extension, directory):
def get_files(self, wildcard):
"""
Get all files that end with .extension in the specified directory
recursively.
Get all files that match a UNIX-style wildcard recursively. While the
script is designed only for use on UNIX/macOS (due to nm), this function
would work fine on Windows even with forward slashes in the wildcard.
Args:
* extension: the file extension to search for, without the dot
* directory: the directory to recursively search for
* wildcard: shell-style wildcards to match filepaths against.
Returns a List of relative filepaths.
"""
filenames = []
for root, _, files in sorted(os.walk(directory)):
for filename in sorted(files):
if (filename not in self.excluded_files and
filename.endswith("." + extension)):
filenames.append(os.path.join(root, filename))
return filenames
accumulator = []
for filepath in glob.iglob(wildcard, recursive=True):
if os.path.basename(filepath) not in self.excluded_files:
accumulator.append(filepath)
return accumulator
def parse_names_in_source(self):
"""
Calls each parsing function to retrieve various elements of the code,
together with their source location. Puts the parsed values in the
internal variable self.parse_result.
Comprehensive function to call each parsing function and retrieve
various elements of the code, together with their source location.
Puts the parsed values in the internal variable self.parse_result, so
they can be used from perform_checks().
"""
self.log.info("Parsing source code...")
self.log.debug(
@ -244,13 +251,13 @@ class NameCheck():
.format(str(self.excluded_files))
)
m_headers = self.get_files("h", os.path.join("include", "mbedtls"))
p_headers = self.get_files("h", os.path.join("include", "psa"))
m_headers = self.get_files("include/mbedtls/*.h")
p_headers = self.get_files("include/psa/*.h")
t_headers = ["3rdparty/everest/include/everest/everest.h",
"3rdparty/everest/include/everest/x25519.h"]
d_headers = self.get_files("h", os.path.join("tests", "include", "test", "drivers"))
l_headers = self.get_files("h", "library")
libraries = self.get_files("c", "library") + [
d_headers = self.get_files("tests/include/test/drivers/*.h")
l_headers = self.get_files("library/*.h")
libraries = self.get_files("library/*.c") + [
"3rdparty/everest/library/everest.c",
"3rdparty/everest/library/x25519.c"]
@ -589,6 +596,7 @@ class NameCheck():
"""
Perform each check in order, output its PASS/FAIL status. Maintain an
overall test status, and output that at the end.
Assumes parse_names_in_source() was called before this.
Args:
* quiet: whether to hide detailed problem explanation.
@ -620,6 +628,7 @@ class NameCheck():
"""
Perform a check that all detected symbols in the library object files
are properly declared in headers.
Assumes parse_names_in_source() was called before this.
Args:
* quiet: whether to hide detailed problem explanation.
@ -645,6 +654,7 @@ class NameCheck():
def check_match_pattern(self, quiet, group_to_check, check_pattern):
"""
Perform a check that all items of a group conform to a regex pattern.
Assumes parse_names_in_source() was called before this.
Args:
* quiet: whether to hide detailed problem explanation.
@ -674,6 +684,7 @@ class NameCheck():
"""
Perform a check that all words in the soure code beginning with MBED are
either defined as macros, or as enum constants.
Assumes parse_names_in_source() was called before this.
Args:
* quiet: whether to hide detailed problem explanation.
@ -725,7 +736,7 @@ def main():
Perform argument parsing, and create an instance of NameCheck to begin the
core operation.
"""
parser = argparse.ArgumentParser(
argparser = argparse.ArgumentParser(
formatter_class=argparse.RawDescriptionHelpFormatter,
description=(
"This script confirms that the naming of all symbols and identifiers "
@ -733,19 +744,18 @@ def main():
"self-consistent.\n\n"
"Expected to be run from the MbedTLS root directory."))
parser.add_argument("-v", "--verbose",
argparser.add_argument("-v", "--verbose",
action="store_true",
help="show parse results")
parser.add_argument("-q", "--quiet",
argparser.add_argument("-q", "--quiet",
action="store_true",
help="hide unnecessary text, explanations, and highlighs")
args = parser.parse_args()
args = argparser.parse_args()
try:
name_check = NameCheck()
name_check.setup_logger(verbose=args.verbose)
name_check = NameCheck(verbose=args.verbose)
name_check.parse_names_in_source()
name_check.perform_checks(quiet=args.quiet)
sys.exit(name_check.return_code)