Pass KNOWN_TASKS as an argument to main

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
This commit is contained in:
Gilles Peskine 2024-10-03 18:36:09 +02:00
parent 005dca6ad8
commit e41cde57c3

View File

@ -761,7 +761,7 @@ KNOWN_TASKS = {
'analyze_block_cipher_dispatch': DriverVSReference_block_cipher_dispatch, 'analyze_block_cipher_dispatch': DriverVSReference_block_cipher_dispatch,
} }
def main() -> None: def main(known_tasks: typing.Dict[str, typing.Type[Task]]) -> None:
main_results = Results() main_results = Results()
try: try:
@ -783,16 +783,16 @@ def main() -> None:
options = parser.parse_args() options = parser.parse_args()
if options.list: if options.list:
for task_name in KNOWN_TASKS: for task_name in known_tasks:
print(task_name) print(task_name)
sys.exit(0) sys.exit(0)
if options.specified_tasks == 'all': if options.specified_tasks == 'all':
tasks_list = list(KNOWN_TASKS.keys()) tasks_list = list(known_tasks.keys())
else: else:
tasks_list = re.split(r'[, ]+', options.specified_tasks) tasks_list = re.split(r'[, ]+', options.specified_tasks)
for task_name in tasks_list: for task_name in tasks_list:
if task_name not in KNOWN_TASKS: if task_name not in known_tasks:
sys.stderr.write('invalid task: {}\n'.format(task_name)) sys.stderr.write('invalid task: {}\n'.format(task_name))
sys.exit(2) sys.exit(2)
@ -805,7 +805,7 @@ def main() -> None:
sys.exit(2) sys.exit(2)
task_name = tasks_list[0] task_name = tasks_list[0]
task_class = KNOWN_TASKS[task_name] task_class = known_tasks[task_name]
if not issubclass(task_class, DriverVSReference): if not issubclass(task_class, DriverVSReference):
sys.stderr.write("please provide valid outcomes file for {}.\n".format(task_name)) sys.stderr.write("please provide valid outcomes file for {}.\n".format(task_name))
sys.exit(2) sys.exit(2)
@ -824,7 +824,7 @@ def main() -> None:
outcomes = read_outcome_file(options.outcomes) outcomes = read_outcome_file(options.outcomes)
for task_name in tasks_list: for task_name in tasks_list:
task_constructor = KNOWN_TASKS[task_name] task_constructor = known_tasks[task_name]
task_instance = task_constructor(options) task_instance = task_constructor(options)
main_results.new_section(task_instance.section_name()) main_results.new_section(task_instance.section_name())
task_instance.run(main_results, outcomes) task_instance.run(main_results, outcomes)
@ -840,4 +840,4 @@ def main() -> None:
sys.exit(120) sys.exit(120)
if __name__ == '__main__': if __name__ == '__main__':
main() main(KNOWN_TASKS)