diff --git a/tests/scripts/audit-validity-dates.py b/tests/scripts/audit-validity-dates.py
index 1ccfc2188f..5506e40e7f 100755
--- a/tests/scripts/audit-validity-dates.py
+++ b/tests/scripts/audit-validity-dates.py
@@ -31,6 +31,7 @@ import argparse
 import datetime
 import glob
 import logging
+import hashlib
 from enum import Enum

 # The script requires cryptography >= 35.0.0 which is only available
@@ -45,7 +46,7 @@ from mbedtls_dev import build_tree

 def check_cryptography_version():
     match = re.match(r'^[0-9]+', cryptography.__version__)
-    if match is None or int(match[0]) < 35:
+    if match is None or int(match.group(0)) < 35:
         raise Exception("audit-validity-dates requires cryptography >= 35.0.0"
                         + "({} is too old)".format(cryptography.__version__))

@@ -65,8 +66,20 @@ class AuditData:
     #pylint: disable=too-few-public-methods
     def __init__(self, data_type: DataType, x509_obj):
         self.data_type = data_type
-        self.location = ""
+        # the locations where the x509 object could be found
+        self.locations = [] # type: typing.List[str]
         self.fill_validity_duration(x509_obj)
+        self._obj = x509_obj
+        encoding = cryptography.hazmat.primitives.serialization.Encoding.DER
+        self._identifier = hashlib.sha1(self._obj.public_bytes(encoding)).hexdigest()
+
+    @property
+    def identifier(self):
+        """
+        Identifier of the underlying X.509 object, which is consistent across
+        different runs.
+        """
+        return self._identifier

     def fill_validity_duration(self, x509_obj):
         """Read validity period from an X.509 object."""
@@ -90,7 +103,7 @@ class AuditData:

 class X509Parser:
     """A parser class to parse crt/crl/csr file or data in PEM/DER format."""
-    PEM_REGEX = br'-{5}BEGIN (?P<type>.*?)-{5}\n(?P<data>.*?)-{5}END (?P=type)-{5}\n'
+    PEM_REGEX = br'-{5}BEGIN (?P<type>.*?)-{5}(?P<data>.*?)-{5}END (?P=type)-{5}'
     PEM_TAG_REGEX = br'-{5}BEGIN (?P<type>.*?)-{5}\n'
     PEM_TAGS = {
         DataType.CRT: 'CERTIFICATE',
@@ -193,13 +206,11 @@ class Auditor:
             X.509 data(DER/PEM format) to an X.509 object.
     - walk_all: Defaultly, it iterates over all the files in the provided
           file name list, calls `parse_file` for each file and stores the results
-          by extending Auditor.audit_data.
+          by extending the `results` passed to the function.
     """
     def __init__(self, logger):
         self.logger = logger
         self.default_files = self.collect_default_files()
-        # A list to store the parsed audit_data.
-        self.audit_data = [] # type: typing.List[AuditData]
         self.parser = X509Parser({
             DataType.CRT: {
                 DataFormat.PEM: x509.load_pem_x509_certificate,
@@ -241,15 +252,27 @@ class Auditor:
                 return audit_data
         return None

-    def walk_all(self, file_list: typing.Optional[typing.List[str]] = None):
+    def walk_all(self,
+                 results: typing.Dict[str, AuditData],
+                 file_list: typing.Optional[typing.List[str]] = None) \
+            -> None:
         """
-        Iterate over all the files in the list and get audit data.
+        Iterate over all the files in the list and get audit data. The
+        results will be written to `results` passed to this function.
+
+        :param results: The dictionary used to store the parsed
+                        AuditData. The keys of this dictionary should
+                        be the identifier of the AuditData.
         """
         if file_list is None:
             file_list = self.default_files
         for filename in file_list:
             data_list = self.parse_file(filename)
-            self.audit_data.extend(data_list)
+            for d in data_list:
+                if d.identifier in results:
+                    results[d.identifier].locations.extend(d.locations)
+                else:
+                    results[d.identifier] = d

     @staticmethod
     def find_test_dir():
@@ -277,12 +300,25 @@ class TestDataAuditor(Auditor):
         """
         with open(filename, 'rb') as f:
             data = f.read()
-        result = self.parse_bytes(data)
-        if result is not None:
-            result.location = filename
-            return [result]
-        else:
-            return []
+
+        results = []
+        # Try to parse all PEM blocks.
+        is_pem = False
+        for idx, m in enumerate(re.finditer(X509Parser.PEM_REGEX, data, flags=re.S), 1):
+            is_pem = True
+            result = self.parse_bytes(data[m.start():m.end()])
+            if result is not None:
+                result.locations.append("{}#{}".format(filename, idx))
+                results.append(result)
+
+        # Might be DER format.
+        if not is_pem:
+            result = self.parse_bytes(data)
+            if result is not None:
+                result.locations.append("{}".format(filename))
+                results.append(result)
+
+        return results


 def parse_suite_data(data_f):
@@ -339,20 +375,22 @@ class SuiteDataAuditor(Auditor):
             audit_data = self.parse_bytes(bytes.fromhex(match.group('data')))
             if audit_data is None:
                 continue
-            audit_data.location = "{}:{}:#{}".format(filename,
-                                                     data_f.line_no,
-                                                     idx + 1)
+            audit_data.locations.append("{}:{}:#{}".format(filename,
+                                                           data_f.line_no,
+                                                           idx + 1))
             audit_data_list.append(audit_data)

         return audit_data_list


 def list_all(audit_data: AuditData):
-    print("{}\t{}\t{}\t{}".format(
-        audit_data.not_valid_before.isoformat(timespec='seconds'),
-        audit_data.not_valid_after.isoformat(timespec='seconds'),
-        audit_data.data_type.name,
-        audit_data.location))
+    for loc in audit_data.locations:
+        print("{}\t{:20}\t{:20}\t{:3}\t{}".format(
+            audit_data.identifier,
+            audit_data.not_valid_before.isoformat(timespec='seconds'),
+            audit_data.not_valid_after.isoformat(timespec='seconds'),
+            audit_data.data_type.name,
+            loc))


 def configure_logger(logger: logging.Logger) -> None:
@@ -448,20 +486,24 @@ def main():
         end_date = start_date

     # go through all the files
-    td_auditor.walk_all(data_files)
-    sd_auditor.walk_all(suite_data_files)
-    audit_results = td_auditor.audit_data + sd_auditor.audit_data
+    audit_results = {}
+    td_auditor.walk_all(audit_results, data_files)
+    sd_auditor.walk_all(audit_results, suite_data_files)
+
+    logger.info("Total: {} objects found!".format(len(audit_results)))

     # we filter out the files whose validity duration covers the provided
     # duration.
     filter_func = lambda d: (start_date < d.not_valid_before) or \
                             (d.not_valid_after < end_date)

+    sortby_end = lambda d: d.not_valid_after
+
     if args.all:
         filter_func = None

     # filter and output the results
-    for d in filter(filter_func, audit_results):
+    for d in sorted(filter(filter_func, audit_results.values()), key=sortby_end):
         list_all(d)

     logger.debug("Done!")