diff --git a/tests/scripts/generate_tls13_compat_tests.py b/tests/scripts/generate_tls13_compat_tests.py index 103de692bd..b9e09d85c6 100755 --- a/tests/scripts/generate_tls13_compat_tests.py +++ b/tests/scripts/generate_tls13_compat_tests.py @@ -45,7 +45,6 @@ CERTIFICATES = { } - CAFILE = { 'ecdsa_secp256r1_sha256': 'data_files/test-ca2.crt', 'ecdsa_secp384r1_sha384': 'data_files/test-ca2.crt', @@ -76,7 +75,6 @@ NAMED_GROUP_IANA_VALUE = { 'x448': 0x1e, } -OUTPUT_FILE=sys.stdout def remove_duplicates(seq): seen = set() @@ -372,7 +370,8 @@ def generate_compat_test(server=None, client=None, cipher=None, # pylint: disab cmd = prefix.join(cmd) return '\n'.join(server.pre_checks() + client.pre_checks() + [cmd]) -SSL_OUTPUT_HEADER='''#!/bin/sh + +SSL_OUTPUT_HEADER = '''#!/bin/sh # {filename} # @@ -401,11 +400,12 @@ SSL_OUTPUT_HEADER='''#!/bin/sh # ''' + def main(): parser = argparse.ArgumentParser() parser.add_argument('-o', '--output', nargs='?', - default=None, help='Output file path') + default=None, help='Output file path if `-a` was set') parser.add_argument('-a', '--generate-all-tls13-compat-tests', action='store_true', default=False, help='Generate all available tls13 compat tests') @@ -442,34 +442,38 @@ def main(): help='Choose cipher suite for test') args = parser.parse_args() - if args.output: - OUTPUT_FILE=open(args.output,'w') - OUTPUT_FILE.write(SSL_OUTPUT_HEADER.format(filename=args.output)) - if args.generate_all_tls13_compat_tests: + def get_all_test_cases(): for i in itertools.product(CIPHER_SUITE_IANA_VALUE.keys(), SIG_ALG_IANA_VALUE.keys(), NAMED_GROUP_IANA_VALUE.keys(), SERVER_CLS.keys(), CLIENT_CLS.keys()): - test_case = generate_compat_test( **dict( + yield generate_compat_test(**dict( zip(['cipher', 'sig_alg', 'named_group', 'server', 'client'], i))) - print(test_case,file=OUTPUT_FILE) + + if args.generate_all_tls13_compat_tests: + if args.output: + with open(args.output, 'w', encoding="utf-8") as f: + f.write(SSL_OUTPUT_HEADER.format(filename=args.output)) + f.write('\n\n'.join(get_all_test_cases())) + else: + print('\n'.join(get_all_test_cases())) return 0 if args.list_ciphers or args.list_sig_algs or args.list_named_groups \ or args.list_servers or args.list_clients: if args.list_ciphers: - print(*CIPHER_SUITE_IANA_VALUE.keys(),file=OUTPUT_FILE) + print(*CIPHER_SUITE_IANA_VALUE.keys()) if args.list_sig_algs: - print(*SIG_ALG_IANA_VALUE.keys(),file=OUTPUT_FILE) + print(*SIG_ALG_IANA_VALUE.keys()) if args.list_named_groups: - print(*NAMED_GROUP_IANA_VALUE.keys(),file=OUTPUT_FILE) + print(*NAMED_GROUP_IANA_VALUE.keys()) if args.list_servers: - print(*SERVER_CLS.keys(),file=OUTPUT_FILE) + print(*SERVER_CLS.keys()) if args.list_clients: - print(*CLIENT_CLS.keys(),file=OUTPUT_FILE) + print(*CLIENT_CLS.keys()) return 0 - print(generate_compat_test(**vars(args)),file=OUTPUT_FILE) + print(generate_compat_test(**vars(args))) return 0