cert_audit: Code refinement

This commit is a collection of code refinements
from review comments.

Signed-off-by: Pengyu Lv <pengyu.lv@arm.com>
This commit is contained in:
Pengyu Lv 2023-04-18 17:00:47 +08:00
parent f8e5e059c5
commit 8e6794ad56

View File

@ -86,7 +86,12 @@ class X509Parser:
DataType.CSR: 'CERTIFICATE REQUEST' DataType.CSR: 'CERTIFICATE REQUEST'
} }
def __init__(self, backends: dict): def __init__(self,
backends:
typing.Dict[DataType,
typing.Dict[DataFormat,
typing.Callable[[bytes], object]]]) \
-> None:
self.backends = backends self.backends = backends
self.__generate_parsers() self.__generate_parsers()
@ -122,7 +127,7 @@ class X509Parser:
return self.parsers[item] return self.parsers[item]
@staticmethod @staticmethod
def pem_data_type(data: bytes) -> str: def pem_data_type(data: bytes) -> typing.Optional[str]:
"""Get the tag from the data in PEM format """Get the tag from the data in PEM format
:param data: data to be checked in binary mode. :param data: data to be checked in binary mode.
@ -132,7 +137,7 @@ class X509Parser:
if m is not None: if m is not None:
return m.group('type').decode('UTF-8') return m.group('type').decode('UTF-8')
else: else:
return "" return None
@staticmethod @staticmethod
def check_hex_string(hex_str: str) -> bool: def check_hex_string(hex_str: str) -> bool:
@ -165,6 +170,7 @@ class Auditor:
def __init__(self, verbose): def __init__(self, verbose):
self.verbose = verbose self.verbose = verbose
self.default_files = [] self.default_files = []
# A list to store the parsed audit_data.
self.audit_data = [] self.audit_data = []
self.parser = X509Parser({ self.parser = X509Parser({
DataType.CRT: { DataType.CRT: {
@ -198,12 +204,12 @@ class Auditor:
""" """
with open(filename, 'rb') as f: with open(filename, 'rb') as f:
data = f.read() data = f.read()
result_list = []
result = self.parse_bytes(data) result = self.parse_bytes(data)
if result is not None: if result is not None:
result.location = filename result.location = filename
result_list.append(result) return [result]
return result_list else:
return []
def parse_bytes(self, data: bytes): def parse_bytes(self, data: bytes):
"""Parse AuditData from bytes.""" """Parse AuditData from bytes."""
@ -218,11 +224,11 @@ class Auditor:
return audit_data return audit_data
return None return None
def walk_all(self, file_list): def walk_all(self, file_list: typing.Optional[typing.List[str]] = None):
""" """
Iterate over all the files in the list and get audit data. Iterate over all the files in the list and get audit data.
""" """
if not file_list: if file_list is None:
file_list = self.default_files file_list = self.default_files
for filename in file_list: for filename in file_list:
data_list = self.parse_file(filename) data_list = self.parse_file(filename)
@ -250,11 +256,9 @@ class TestDataAuditor(Auditor):
def collect_default_files(self): def collect_default_files(self):
"""Collect all files in tests/data_files/""" """Collect all files in tests/data_files/"""
test_dir = self.find_test_dir() test_dir = self.find_test_dir()
test_data_folder = os.path.join(test_dir, 'data_files') test_data_glob = os.path.join(test_dir, 'data_files/**')
data_files = [] data_files = [f for f in glob.glob(test_data_glob, recursive=True)
for (dir_path, _, file_names) in os.walk(test_data_folder): if os.path.isfile(f)]
data_files.extend(os.path.join(dir_path, file_name)
for file_name in file_names)
return data_files return data_files
class FileWrapper(): class FileWrapper():