diff --git a/library/psa_crypto.c b/library/psa_crypto.c index 93e76aee86..adbd7af82a 100644 --- a/library/psa_crypto.c +++ b/library/psa_crypto.c @@ -7237,13 +7237,18 @@ psa_status_t psa_pake_setup( psa_pake_operation_t *operation, const psa_pake_cipher_suite_t *cipher_suite) { + psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED; + psa_status_t abort_status = PSA_ERROR_CORRUPTION_DETECTED; + if (operation->stage != PSA_PAKE_OPERATION_STAGE_SETUP) { - return PSA_ERROR_BAD_STATE; + status = PSA_ERROR_BAD_STATE; + goto exit; } if (PSA_ALG_IS_PAKE(cipher_suite->algorithm) == 0 || PSA_ALG_IS_HASH(cipher_suite->hash) == 0) { - return PSA_ERROR_INVALID_ARGUMENT; + status = PSA_ERROR_INVALID_ARGUMENT; + goto exit; } memset(&operation->data.inputs, 0, sizeof(operation->data.inputs)); @@ -7264,6 +7269,9 @@ psa_status_t psa_pake_setup( operation->stage = PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS; return PSA_SUCCESS; +exit: + abort_status = psa_pake_abort(operation); + return status == PSA_SUCCESS ? abort_status : status; } psa_status_t psa_pake_set_password_key( @@ -7272,17 +7280,19 @@ psa_status_t psa_pake_set_password_key( { psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED; psa_status_t unlock_status = PSA_ERROR_CORRUPTION_DETECTED; + psa_status_t abort_status = PSA_ERROR_CORRUPTION_DETECTED; psa_key_slot_t *slot = NULL; if (operation->stage != PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) { - return PSA_ERROR_BAD_STATE; + status = PSA_ERROR_BAD_STATE; + goto exit; } status = psa_get_and_lock_key_slot_with_policy(password, &slot, PSA_KEY_USAGE_DERIVE, operation->alg); if (status != PSA_SUCCESS) { - return status; + goto exit; } psa_key_attributes_t attributes = { @@ -7294,21 +7304,27 @@ psa_status_t psa_pake_set_password_key( if (type != PSA_KEY_TYPE_PASSWORD && type != PSA_KEY_TYPE_PASSWORD_HASH) { status = PSA_ERROR_INVALID_ARGUMENT; - goto error; + goto exit; } operation->data.inputs.password = mbedtls_calloc(1, slot->key.bytes); if (operation->data.inputs.password == NULL) { status = PSA_ERROR_INSUFFICIENT_MEMORY; - goto error; + goto exit; } memcpy(operation->data.inputs.password, slot->key.data, slot->key.bytes); operation->data.inputs.password_len = slot->key.bytes; operation->data.inputs.attributes = attributes; -error: + unlock_status = psa_unlock_key_slot(slot); - return (status == PSA_SUCCESS) ? unlock_status : status; + + return unlock_status; +exit: + unlock_status = psa_unlock_key_slot(slot); + abort_status = psa_pake_abort(operation); + status = (status == PSA_SUCCESS) ? unlock_status : status; + return (status == PSA_SUCCESS) ? abort_status : status; } psa_status_t psa_pake_set_user( @@ -7316,17 +7332,24 @@ psa_status_t psa_pake_set_user( const uint8_t *user_id, size_t user_id_len) { + psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED; + psa_status_t abort_status = PSA_ERROR_CORRUPTION_DETECTED; (void) user_id; if (operation->stage != PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) { - return PSA_ERROR_BAD_STATE; + status = PSA_ERROR_BAD_STATE; + goto exit; } if (user_id_len == 0) { - return PSA_ERROR_INVALID_ARGUMENT; + status = PSA_ERROR_INVALID_ARGUMENT; + goto exit; } return PSA_ERROR_NOT_SUPPORTED; +exit: + abort_status = psa_pake_abort(operation); + return status == PSA_SUCCESS ? abort_status : status; } psa_status_t psa_pake_set_peer( @@ -7334,25 +7357,36 @@ psa_status_t psa_pake_set_peer( const uint8_t *peer_id, size_t peer_id_len) { + psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED; + psa_status_t abort_status = PSA_ERROR_CORRUPTION_DETECTED; (void) peer_id; if (operation->stage != PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) { - return PSA_ERROR_BAD_STATE; + status = PSA_ERROR_BAD_STATE; + goto exit; } if (peer_id_len == 0) { - return PSA_ERROR_INVALID_ARGUMENT; + status = PSA_ERROR_INVALID_ARGUMENT; + goto exit; } return PSA_ERROR_NOT_SUPPORTED; +exit: + abort_status = psa_pake_abort(operation); + return status == PSA_SUCCESS ? abort_status : status; } psa_status_t psa_pake_set_role( psa_pake_operation_t *operation, psa_pake_role_t role) { + psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED; + psa_status_t abort_status = PSA_ERROR_CORRUPTION_DETECTED; + if (operation->stage != PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) { - return PSA_ERROR_BAD_STATE; + status = PSA_ERROR_BAD_STATE; + goto exit; } if (role != PSA_PAKE_ROLE_NONE && @@ -7360,12 +7394,16 @@ psa_status_t psa_pake_set_role( role != PSA_PAKE_ROLE_SECOND && role != PSA_PAKE_ROLE_CLIENT && role != PSA_PAKE_ROLE_SERVER) { - return PSA_ERROR_INVALID_ARGUMENT; + status = PSA_ERROR_INVALID_ARGUMENT; + goto exit; } operation->data.inputs.role = role; return PSA_SUCCESS; +exit: + abort_status = psa_pake_abort(operation); + return status == PSA_SUCCESS ? abort_status : status; } /* Auxiliary function to convert core computation stage(step, sequence, state) to single driver step. */ @@ -7572,32 +7610,36 @@ psa_status_t psa_pake_output( size_t *output_length) { psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED; + psa_status_t abort_status = PSA_ERROR_CORRUPTION_DETECTED; *output_length = 0; if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) { status = psa_pake_complete_inputs(operation); if (status != PSA_SUCCESS) { - return status; + goto exit; } } if (operation->stage != PSA_PAKE_OPERATION_STAGE_COMPUTATION) { - return PSA_ERROR_BAD_STATE; + status = PSA_ERROR_BAD_STATE; + goto exit; } if (output_size == 0) { - return PSA_ERROR_INVALID_ARGUMENT; + status = PSA_ERROR_INVALID_ARGUMENT; + goto exit; } switch (operation->alg) { case PSA_ALG_JPAKE: status = psa_jpake_output_prologue(operation, step); if (status != PSA_SUCCESS) { - return status; + goto exit; } break; default: - return PSA_ERROR_NOT_SUPPORTED; + status = PSA_ERROR_NOT_SUPPORTED; + goto exit; } status = psa_driver_wrapper_pake_output(operation, @@ -7608,21 +7650,25 @@ psa_status_t psa_pake_output( output_length); if (status != PSA_SUCCESS) { - return status; + goto exit; } switch (operation->alg) { case PSA_ALG_JPAKE: status = psa_jpake_output_epilogue(operation); if (status != PSA_SUCCESS) { - return status; + goto exit; } break; default: - return PSA_ERROR_NOT_SUPPORTED; + status = PSA_ERROR_NOT_SUPPORTED; + goto exit; } - return status; + return PSA_SUCCESS; +exit: + abort_status = psa_pake_abort(operation); + return status == PSA_SUCCESS ? abort_status : status; } static psa_status_t psa_jpake_input_prologue( @@ -7731,27 +7777,30 @@ psa_status_t psa_pake_input( size_t input_length) { psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED; + psa_status_t abort_status = PSA_ERROR_CORRUPTION_DETECTED; if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) { status = psa_pake_complete_inputs(operation); if (status != PSA_SUCCESS) { - return status; + goto exit; } } if (operation->stage != PSA_PAKE_OPERATION_STAGE_COMPUTATION) { - return PSA_ERROR_BAD_STATE; + status = PSA_ERROR_BAD_STATE; + goto exit; } if (input_length == 0) { - return PSA_ERROR_INVALID_ARGUMENT; + status = PSA_ERROR_INVALID_ARGUMENT; + goto exit; } switch (operation->alg) { case PSA_ALG_JPAKE: status = psa_jpake_input_prologue(operation, step, input_length); if (status != PSA_SUCCESS) { - return status; + goto exit; } break; default: @@ -7765,21 +7814,25 @@ psa_status_t psa_pake_input( input_length); if (status != PSA_SUCCESS) { - return status; + goto exit; } switch (operation->alg) { case PSA_ALG_JPAKE: status = psa_jpake_input_epilogue(operation); if (status != PSA_SUCCESS) { - return status; + goto exit; } break; default: - return PSA_ERROR_NOT_SUPPORTED; + status = PSA_ERROR_NOT_SUPPORTED; + goto exit; } - return status; + return PSA_SUCCESS; +exit: + abort_status = psa_pake_abort(operation); + return status == PSA_SUCCESS ? abort_status : status; } psa_status_t psa_pake_get_implicit_key( @@ -7787,19 +7840,22 @@ psa_status_t psa_pake_get_implicit_key( psa_key_derivation_operation_t *output) { psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED; + psa_status_t abort_status = PSA_ERROR_CORRUPTION_DETECTED; uint8_t shared_key[MBEDTLS_PSA_PAKE_BUFFER_SIZE]; size_t shared_key_len = 0; - psa_jpake_computation_stage_t *computation_stage = - &operation->computation_stage.jpake; if (operation->stage != PSA_PAKE_OPERATION_STAGE_COMPUTATION) { - return PSA_ERROR_BAD_STATE; + status = PSA_ERROR_BAD_STATE; + goto exit; } if (operation->alg == PSA_ALG_JPAKE) { + psa_jpake_computation_stage_t *computation_stage = + &operation->computation_stage.jpake; if (computation_stage->input_step != PSA_PAKE_STEP_DERIVE || computation_stage->output_step != PSA_PAKE_STEP_DERIVE) { - return PSA_ERROR_BAD_STATE; + status = PSA_ERROR_BAD_STATE; + goto exit; } } @@ -7808,7 +7864,7 @@ psa_status_t psa_pake_get_implicit_key( &shared_key_len); if (status != PSA_SUCCESS) { - return status; + goto exit; } status = psa_key_derivation_input_bytes(output, @@ -7816,15 +7872,10 @@ psa_status_t psa_pake_get_implicit_key( shared_key, shared_key_len); - if (status != PSA_SUCCESS) { - psa_key_derivation_abort(output); - } - mbedtls_platform_zeroize(shared_key, MBEDTLS_PSA_PAKE_BUFFER_SIZE); - - psa_pake_abort(operation); - - return status; +exit: + abort_status = psa_pake_abort(operation); + return status == PSA_SUCCESS ? abort_status : status; } psa_status_t psa_pake_abort( diff --git a/tests/suites/test_suite_psa_crypto_driver_wrappers.function b/tests/suites/test_suite_psa_crypto_driver_wrappers.function index 3220c62a6e..c1eea5059c 100644 --- a/tests/suites/test_suite_psa_crypto_driver_wrappers.function +++ b/tests/suites/test_suite_psa_crypto_driver_wrappers.function @@ -3082,18 +3082,18 @@ void pake_operations(data_t *pw_data, int forced_status_setup_arg, int forced_st break; case 2: /* input */ - /* --- psa_pake_input (driver: setup, input) --- */ + /* --- psa_pake_input (driver: setup, input, (abort)) --- */ mbedtls_test_driver_pake_hooks.forced_setup_status = forced_status_setup; mbedtls_test_driver_pake_hooks.forced_status = forced_status; mbedtls_test_driver_pake_hooks.hits = 0; TEST_EQUAL(psa_pake_input(&operation, PSA_PAKE_STEP_KEY_SHARE, input_buffer, size_key_share), expected_status_input); - TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits, in_driver ? 2 : 1); + TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits, in_driver ? 3 : 1); break; case 3: /* output */ - /* --- psa_pake_input (driver: setup, output) --- */ + /* --- psa_pake_input (driver: setup, output, (abort)) --- */ mbedtls_test_driver_pake_hooks.forced_setup_status = forced_status_setup; mbedtls_test_driver_pake_hooks.forced_status = forced_status; mbedtls_test_driver_pake_hooks.hits = 0; @@ -3105,10 +3105,12 @@ void pake_operations(data_t *pw_data, int forced_status_setup_arg, int forced_st output_buffer, output_size, &output_len), expected_status_output); - TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits, in_driver ? 2 : 1); if (forced_output->len > 0) { + TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits, in_driver ? 2 : 1); TEST_EQUAL(output_len, forced_output->len); TEST_EQUAL(memcmp(output_buffer, forced_output->x, output_len), 0); + } else { + TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits, in_driver ? 3 : 1); } break; @@ -3127,7 +3129,7 @@ void pake_operations(data_t *pw_data, int forced_status_setup_arg, int forced_st mbedtls_test_driver_pake_hooks.hits = 0; TEST_EQUAL(psa_pake_get_implicit_key(&operation, &implicit_key), expected_status_get_key); - TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits, 1); + TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits, 2); break;