diff --git a/scripts/abi_check.py b/scripts/abi_check.py index c2288432ce..ac1d60ffd0 100755 --- a/scripts/abi_check.py +++ b/scripts/abi_check.py @@ -113,6 +113,8 @@ from types import SimpleNamespace import xml.etree.ElementTree as ET +from mbedtls_dev import build_tree + class AbiChecker: """API and ABI checker.""" @@ -150,11 +152,6 @@ class AbiChecker: self.git_command = "git" self.make_command = "make" - @staticmethod - def check_repo_path(): - if not all(os.path.isdir(d) for d in ["include", "library", "tests"]): - raise Exception("Must be run from Mbed TLS root") - def _setup_logger(self): self.log = logging.getLogger() if self.verbose: @@ -540,7 +537,7 @@ class AbiChecker: def check_for_abi_changes(self): """Generate a report of ABI differences between self.old_rev and self.new_rev.""" - self.check_repo_path() + build_tree.check_repo_path() if self.check_api or self.check_abi: self.check_abi_tools_are_installed() self._get_abi_dump_for_ref(self.old_version) diff --git a/scripts/code_size_compare.py b/scripts/code_size_compare.py index 0ef438db7c..af6ddd4fcb 100755 --- a/scripts/code_size_compare.py +++ b/scripts/code_size_compare.py @@ -30,6 +30,9 @@ import os import subprocess import sys +from mbedtls_dev import build_tree + + class CodeSizeComparison: """Compare code size between two Git revisions.""" @@ -51,11 +54,6 @@ class CodeSizeComparison: self.git_command = "git" self.make_command = "make" - @staticmethod - def check_repo_path(): - if not all(os.path.isdir(d) for d in ["include", "library", "tests"]): - raise Exception("Must be run from Mbed TLS root") - @staticmethod def validate_revision(revision): result = subprocess.check_output(["git", "rev-parse", "--verify", @@ -172,7 +170,7 @@ class CodeSizeComparison: def get_comparision_results(self): """Compare size of library/*.o between self.old_rev and self.new_rev, and generate the result file.""" - self.check_repo_path() + build_tree.check_repo_path() self._get_code_size_for_rev(self.old_rev) self._get_code_size_for_rev(self.new_rev) return self.compare_code_size() diff --git a/scripts/mbedtls_dev/build_tree.py b/scripts/mbedtls_dev/build_tree.py index 3920d0ed6c..f52b785d95 100644 --- a/scripts/mbedtls_dev/build_tree.py +++ b/scripts/mbedtls_dev/build_tree.py @@ -25,6 +25,13 @@ def looks_like_mbedtls_root(path: str) -> bool: return all(os.path.isdir(os.path.join(path, subdir)) for subdir in ['include', 'library', 'programs', 'tests']) +def check_repo_path(): + """ + Check that the current working directory is the project root, and throw + an exception if not. + """ + if not all(os.path.isdir(d) for d in ["include", "library", "tests"]): + raise Exception("This script must be run from Mbed TLS root") def chdir_to_root() -> None: """Detect the root of the Mbed TLS source tree and change to it. diff --git a/tests/scripts/check_files.py b/tests/scripts/check_files.py index a0f5e1f538..5c18702def 100755 --- a/tests/scripts/check_files.py +++ b/tests/scripts/check_files.py @@ -34,6 +34,9 @@ try: except ImportError: pass +import scripts_path # pylint: disable=unused-import +from mbedtls_dev import build_tree + class FileIssueTracker: """Base class for file-wide issue tracking. @@ -338,7 +341,7 @@ class IntegrityChecker: """Instantiate the sanity checker. Check files under the current directory. Write a report of issues to log_file.""" - self.check_repo_path() + build_tree.check_repo_path() self.logger = None self.setup_logger(log_file) self.issues_to_check = [ @@ -353,11 +356,6 @@ class IntegrityChecker: MergeArtifactIssueTracker(), ] - @staticmethod - def check_repo_path(): - if not all(os.path.isdir(d) for d in ["include", "library", "tests"]): - raise Exception("Must be run from Mbed TLS root") - def setup_logger(self, log_file, level=logging.INFO): self.logger = logging.getLogger() self.logger.setLevel(level) diff --git a/tests/scripts/check_names.py b/tests/scripts/check_names.py index e204487290..aece1ef060 100755 --- a/tests/scripts/check_names.py +++ b/tests/scripts/check_names.py @@ -56,6 +56,10 @@ import shutil import subprocess import logging +import scripts_path # pylint: disable=unused-import +from mbedtls_dev import build_tree + + # Naming patterns to check against. These are defined outside the NameCheck # class for ease of modification. PUBLIC_MACRO_PATTERN = r"^(MBEDTLS|PSA)_[0-9A-Z_]*[0-9A-Z]$" @@ -219,7 +223,7 @@ class CodeParser(): """ def __init__(self, log): self.log = log - self.check_repo_path() + build_tree.check_repo_path() # Memo for storing "glob expression": set(filepaths) self.files = {} @@ -228,15 +232,6 @@ class CodeParser(): # Note that "*" can match directory separators in exclude lists. self.excluded_files = ["*/bn_mul", "*/compat-2.x.h"] - @staticmethod - def check_repo_path(): - """ - Check that the current working directory is the project root, and throw - an exception if not. - """ - if not all(os.path.isdir(d) for d in ["include", "library", "tests"]): - raise Exception("This script must be run from Mbed TLS root") - def comprehensive_parse(self): """ Comprehensive ("default") function to call each parsing function and