diff --git a/tests/scripts/audit-validity-dates.py b/tests/scripts/audit-validity-dates.py new file mode 100755 index 0000000000..67d4096659 --- /dev/null +++ b/tests/scripts/audit-validity-dates.py @@ -0,0 +1,282 @@ +#!/usr/bin/env python3 +# +# copyright the mbed tls contributors +# spdx-license-identifier: apache-2.0 +# +# licensed under the apache license, version 2.0 (the "license"); you may +# not use this file except in compliance with the license. +# you may obtain a copy of the license at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Audit validity date of X509 crt/crl/csr + +This script is used to audit the validity date of crt/crl/csr used for testing. +The files are in tests/data_files/ while some data are in test suites data in +tests/suites/*.data files. +""" + +import os +import sys +import re +import typing +import types +import argparse +import datetime +from enum import Enum + +from cryptography import x509 + +class DataType(Enum): + CRT = 1 # Certificate + CRL = 2 # Certificate Revocation List + CSR = 3 # Certificate Signing Request + +class DataFormat(Enum): + PEM = 1 # Privacy-Enhanced Mail + DER = 2 # Distinguished Encoding Rules + +class AuditData: + """Store file, type and expiration date for audit.""" + #pylint: disable=too-few-public-methods + def __init__(self, data_type: DataType): + self.data_type = data_type + self.filename = "" + self.not_valid_after: datetime.datetime + self.not_valid_before: datetime.datetime + + def fill_validity_duration(self, x509_obj): + """Fill expiration_date field from a x509 object""" + # Certificate expires after "not_valid_after" + # Certificate is invalid before "not_valid_before" + if self.data_type == DataType.CRT: + self.not_valid_after = x509_obj.not_valid_after + self.not_valid_before = x509_obj.not_valid_before + # CertificateRevocationList expires after "next_update" + # CertificateRevocationList is invalid before "last_update" + elif self.data_type == DataType.CRL: + self.not_valid_after = x509_obj.next_update + self.not_valid_before = x509_obj.last_update + # CertificateSigningRequest is always valid. + elif self.data_type == DataType.CSR: + self.not_valid_after = datetime.datetime.max + self.not_valid_before = datetime.datetime.min + else: + raise ValueError("Unsupported file_type: {}".format(self.data_type)) + +class X509Parser(): + """A parser class to parse crt/crl/csr file or data in PEM/DER format.""" + PEM_REGEX = br'-{5}BEGIN (?P.*?)-{5}\n(?P.*?)-{5}END (?P=type)-{5}\n' + PEM_TAG_REGEX = br'-{5}BEGIN (?P.*?)-{5}\n' + PEM_TAGS = { + DataType.CRT: 'CERTIFICATE', + DataType.CRL: 'X509 CRL', + DataType.CSR: 'CERTIFICATE REQUEST' + } + + def __init__(self, backends: dict): + self.backends = backends + self.__generate_parsers() + + def __generate_parser(self, data_type: DataType): + """Parser generator for a specific DataType""" + tag = self.PEM_TAGS[data_type] + pem_loader = self.backends[data_type][DataFormat.PEM] + der_loader = self.backends[data_type][DataFormat.DER] + def wrapper(data: bytes): + pem_type = X509Parser.pem_data_type(data) + # It is in PEM format with target tag + if pem_type == tag: + return pem_loader(data) + # It is in PEM format without target tag + if pem_type: + return None + # It might be in DER format + try: + result = der_loader(data) + except ValueError: + result = None + return result + wrapper.__name__ = "{}.parser[{}]".format(type(self).__name__, tag) + return wrapper + + def __generate_parsers(self): + """Generate parsers for all support DataType""" + self.parsers = {} + for data_type, _ in self.PEM_TAGS.items(): + self.parsers[data_type] = self.__generate_parser(data_type) + + def __getitem__(self, item): + return self.parsers[item] + + @staticmethod + def pem_data_type(data: bytes) -> str: + """Get the tag from the data in PEM format + + :param data: data to be checked in binary mode. + :return: PEM tag or "" when no tag detected. + """ + m = re.search(X509Parser.PEM_TAG_REGEX, data) + if m is not None: + return m.group('type').decode('UTF-8') + else: + return "" + +class Auditor: + """A base class for audit.""" + def __init__(self, verbose): + self.verbose = verbose + self.default_files = [] + self.audit_data = [] + self.parser = X509Parser({ + DataType.CRT: { + DataFormat.PEM: x509.load_pem_x509_certificate, + DataFormat.DER: x509.load_der_x509_certificate + }, + DataType.CRL: { + DataFormat.PEM: x509.load_pem_x509_crl, + DataFormat.DER: x509.load_der_x509_crl + }, + DataType.CSR: { + DataFormat.PEM: x509.load_pem_x509_csr, + DataFormat.DER: x509.load_der_x509_csr + }, + }) + + def error(self, *args): + #pylint: disable=no-self-use + print("Error: ", *args, file=sys.stderr) + + def warn(self, *args): + if self.verbose: + print("Warn: ", *args, file=sys.stderr) + + def parse_file(self, filename: str) -> typing.List[AuditData]: + """ + Parse a list of AuditData from file. + + :param filename: name of the file to parse. + :return list of AuditData parsed from the file. + """ + with open(filename, 'rb') as f: + data = f.read() + result_list = [] + result = self.parse_bytes(data) + if result is not None: + result.filename = filename + result_list.append(result) + return result_list + + def parse_bytes(self, data: bytes): + """Parse AuditData from bytes.""" + for data_type in list(DataType): + try: + result = self.parser[data_type](data) + except ValueError as val_error: + result = None + self.warn(val_error) + if result is not None: + audit_data = AuditData(data_type) + audit_data.fill_validity_duration(result) + return audit_data + return None + + def walk_all(self, file_list): + """ + Iterate over all the files in the list and get audit data. + """ + if not file_list: + file_list = self.default_files + for filename in file_list: + data_list = self.parse_file(filename) + self.audit_data.extend(data_list) + + def for_each(self, do, *args, **kwargs): + """ + Sort the audit data and iterate over them. + """ + if not isinstance(do, types.FunctionType): + return + for d in self.audit_data: + do(d, *args, **kwargs) + + @staticmethod + def find_test_dir(): + """Get the relative path for the MbedTLS test directory.""" + if os.path.isdir('tests'): + tests_dir = 'tests' + elif os.path.isdir('suites'): + tests_dir = '.' + elif os.path.isdir('../suites'): + tests_dir = '..' + else: + raise Exception("Mbed TLS source tree not found") + return tests_dir + +class TestDataAuditor(Auditor): + """Class for auditing files in tests/data_files/""" + def __init__(self, verbose): + super().__init__(verbose) + self.default_files = self.collect_default_files() + + def collect_default_files(self): + """collect all files in tests/data_files/""" + test_dir = self.find_test_dir() + test_data_folder = os.path.join(test_dir, 'data_files') + data_files = [] + for (dir_path, _, file_names) in os.walk(test_data_folder): + data_files.extend(os.path.join(dir_path, file_name) + for file_name in file_names) + return data_files + + +def list_all(audit_data: AuditData): + print("{}\t{}\t{}\t{}".format( + audit_data.not_valid_before.isoformat(timespec='seconds'), + audit_data.not_valid_after.isoformat(timespec='seconds'), + audit_data.data_type.name, + audit_data.filename)) + +def main(): + """ + Perform argument parsing. + """ + parser = argparse.ArgumentParser( + description='Audit script for X509 crt/crl/csr files.' + ) + + parser.add_argument('-a', '--all', + action='store_true', + help='list the information of all files') + parser.add_argument('-v', '--verbose', + action='store_true', dest='verbose', + help='Show warnings') + parser.add_argument('-f', '--file', dest='file', + help='file to audit (Debug only)', + metavar='FILE') + + args = parser.parse_args() + + # start main routine + td_auditor = TestDataAuditor(args.verbose) + + if args.file: + data_files = [args.file] + else: + data_files = td_auditor.default_files + + td_auditor.walk_all(data_files) + + if args.all: + td_auditor.for_each(list_all) + + print("\nDone!\n") + +if __name__ == "__main__": + main()