diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 4f3d774af1..93e76aee86 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -7830,38 +7830,22 @@ psa_status_t psa_pake_get_implicit_key(
 psa_status_t psa_pake_abort(
     psa_pake_operation_t *operation)
 {
-    psa_status_t status = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    psa_status_t status = PSA_SUCCESS;
 
-    if (operation->id != 0) {
+    if (operation->stage == PSA_PAKE_OPERATION_STAGE_COMPUTATION) {
         status = psa_driver_wrapper_pake_abort(operation);
-        if (status != PSA_SUCCESS) {
-            return status;
-        }
     }
 
-    if (operation->data.inputs.password_len > 0) {
+    if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS &&
+        operation->data.inputs.password_len > 0) {
         mbedtls_platform_zeroize(operation->data.inputs.password,
                                  operation->data.inputs.password_len);
         mbedtls_free(operation->data.inputs.password);
     }
 
-    memset(&operation->data, 0, sizeof(operation->data));
+    memset(operation, 0, sizeof(psa_pake_operation_t));
 
-    if (operation->alg == PSA_ALG_JPAKE) {
-        psa_jpake_computation_stage_t *computation_stage =
-            &operation->computation_stage.jpake;
-
-        computation_stage->input_step = PSA_PAKE_STEP_INVALID;
-        computation_stage->output_step = PSA_PAKE_STEP_INVALID;
-        computation_stage->state = PSA_PAKE_STATE_INVALID;
-        computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
-    }
-
-    operation->alg = PSA_ALG_NONE;
-    operation->stage = PSA_PAKE_OPERATION_STAGE_SETUP;
-    operation->id = 0;
-
-    return PSA_SUCCESS;
+    return status;
 }
 
 #endif /* MBEDTLS_PSA_CRYPTO_C */