From 1e855601ca2372070ad41b861a0586bc00d8cf8a Mon Sep 17 00:00:00 2001 From: Neil Armstrong Date: Wed, 15 Jun 2022 11:32:11 +0200 Subject: [PATCH] Fix psa_pake_get_implicit_key() state & add corresponding tests in ecjpake_rounds() Signed-off-by: Neil Armstrong --- library/psa_crypto_pake.c | 4 +- tests/suites/test_suite_psa_crypto.function | 54 +++++++++++++++------ 2 files changed, 40 insertions(+), 18 deletions(-) diff --git a/library/psa_crypto_pake.c b/library/psa_crypto_pake.c index f7fb384dd7..8ceacd952f 100644 --- a/library/psa_crypto_pake.c +++ b/library/psa_crypto_pake.c @@ -660,8 +660,8 @@ psa_status_t psa_pake_get_implicit_key(psa_pake_operation_t *operation, if( operation->alg == 0 || operation->state != PSA_PAKE_STATE_READY || - ( operation->input_step != PSA_PAKE_STEP_DERIVE && - operation->output_step != PSA_PAKE_STEP_DERIVE ) ) + operation->input_step != PSA_PAKE_STEP_DERIVE || + operation->output_step != PSA_PAKE_STEP_DERIVE ) return( PSA_ERROR_BAD_STATE ); #if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE) diff --git a/tests/suites/test_suite_psa_crypto.function b/tests/suites/test_suite_psa_crypto.function index 727784f4ab..6d4f2a8a01 100644 --- a/tests/suites/test_suite_psa_crypto.function +++ b/tests/suites/test_suite_psa_crypto.function @@ -8316,6 +8316,21 @@ void ecjpake_rounds( int alg_arg, int primitive_arg, int hash_arg, psa_pake_cs_set_primitive( &cipher_suite, primitive_arg ); psa_pake_cs_set_hash( &cipher_suite, hash_alg ); + /* Get shared key */ + PSA_ASSERT( psa_key_derivation_setup( &server_derive, derive_alg ) ); + PSA_ASSERT( psa_key_derivation_setup( &client_derive, derive_alg ) ); + + if( PSA_ALG_IS_TLS12_PRF( derive_alg ) || + PSA_ALG_IS_TLS12_PSK_TO_MS( derive_alg ) ) + { + PSA_ASSERT( psa_key_derivation_input_bytes( &server_derive, + PSA_KEY_DERIVATION_INPUT_SEED, + (const uint8_t*) "", 0) ); + PSA_ASSERT( psa_key_derivation_input_bytes( &client_derive, + PSA_KEY_DERIVATION_INPUT_SEED, + (const uint8_t*) "", 0) ); + } + PSA_ASSERT( psa_pake_setup( &server, &cipher_suite ) ); PSA_ASSERT( psa_pake_setup( &client, &cipher_suite ) ); @@ -8325,6 +8340,11 @@ void ecjpake_rounds( int alg_arg, int primitive_arg, int hash_arg, PSA_ASSERT( psa_pake_set_password_key( &server, key ) ); PSA_ASSERT( psa_pake_set_password_key( &client, key ) ); + TEST_EQUAL( psa_pake_get_implicit_key( &server, &server_derive ), + PSA_ERROR_BAD_STATE ); + TEST_EQUAL( psa_pake_get_implicit_key( &client, &client_derive ), + PSA_ERROR_BAD_STATE ); + /* Server first round Output */ PSA_ASSERT( psa_pake_output( &server, PSA_PAKE_STEP_KEY_SHARE, buffer0 + buffer0_off, @@ -8389,6 +8409,11 @@ void ecjpake_rounds( int alg_arg, int primitive_arg, int hash_arg, c_x2_pr_off = buffer1_off; buffer1_off += c_x2_pr_len; + TEST_EQUAL( psa_pake_get_implicit_key( &server, &server_derive ), + PSA_ERROR_BAD_STATE ); + TEST_EQUAL( psa_pake_get_implicit_key( &client, &client_derive ), + PSA_ERROR_BAD_STATE ); + /* Client first round Input */ PSA_ASSERT( psa_pake_input( &client, PSA_PAKE_STEP_KEY_SHARE, buffer0 + s_g1_off, s_g1_len ) ); @@ -8417,6 +8442,11 @@ void ecjpake_rounds( int alg_arg, int primitive_arg, int hash_arg, PSA_ASSERT( psa_pake_input( &server, PSA_PAKE_STEP_ZK_PROOF, buffer1 + c_x2_pr_off, c_x2_pr_len ) ); + TEST_EQUAL( psa_pake_get_implicit_key( &server, &server_derive ), + PSA_ERROR_BAD_STATE ); + TEST_EQUAL( psa_pake_get_implicit_key( &client, &client_derive ), + PSA_ERROR_BAD_STATE ); + /* Server second round Output */ buffer0_off = 0; @@ -8455,6 +8485,11 @@ void ecjpake_rounds( int alg_arg, int primitive_arg, int hash_arg, c_x2s_pr_off = buffer1_off; buffer1_off += c_x2s_pr_len; + TEST_EQUAL( psa_pake_get_implicit_key( &server, &server_derive ), + PSA_ERROR_BAD_STATE ); + TEST_EQUAL( psa_pake_get_implicit_key( &client, &client_derive ), + PSA_ERROR_BAD_STATE ); + /* Client second round Input */ PSA_ASSERT( psa_pake_input( &client, PSA_PAKE_STEP_KEY_SHARE, buffer0 + s_a_off, s_a_len ) ); @@ -8463,6 +8498,9 @@ void ecjpake_rounds( int alg_arg, int primitive_arg, int hash_arg, PSA_ASSERT( psa_pake_input( &client, PSA_PAKE_STEP_ZK_PROOF, buffer0 + s_x2s_pr_off, s_x2s_pr_len ) ); + TEST_EQUAL( psa_pake_get_implicit_key( &server, &server_derive ), + PSA_ERROR_BAD_STATE ); + /* Server second round Input */ PSA_ASSERT( psa_pake_input( &server, PSA_PAKE_STEP_KEY_SHARE, buffer1 + c_a_off, c_a_len ) ); @@ -8471,22 +8509,6 @@ void ecjpake_rounds( int alg_arg, int primitive_arg, int hash_arg, PSA_ASSERT( psa_pake_input( &server, PSA_PAKE_STEP_ZK_PROOF, buffer1 + c_x2s_pr_off, c_x2s_pr_len ) ); - - /* Get shared key */ - PSA_ASSERT( psa_key_derivation_setup( &server_derive, derive_alg ) ); - PSA_ASSERT( psa_key_derivation_setup( &client_derive, derive_alg ) ); - - if( PSA_ALG_IS_TLS12_PRF( derive_alg ) || - PSA_ALG_IS_TLS12_PSK_TO_MS( derive_alg ) ) - { - PSA_ASSERT( psa_key_derivation_input_bytes( &server_derive, - PSA_KEY_DERIVATION_INPUT_SEED, - (const uint8_t*) "", 0) ); - PSA_ASSERT( psa_key_derivation_input_bytes( &client_derive, - PSA_KEY_DERIVATION_INPUT_SEED, - (const uint8_t*) "", 0) ); - } - PSA_ASSERT( psa_pake_get_implicit_key( &server, &server_derive ) ); PSA_ASSERT( psa_pake_get_implicit_key( &client, &client_derive ) );