diff --git a/include/psa/crypto_extra.h b/include/psa/crypto_extra.h index 5529dd1c84..a3351a6d03 100644 --- a/include/psa/crypto_extra.h +++ b/include/psa/crypto_extra.h @@ -2028,14 +2028,33 @@ typedef enum psa_crypto_driver_pake_step { PSA_JPAKE_X4S_STEP_ZK_PROOF = 12 /* Round 2: input Schnorr NIZKP proof for the X4S key (from peer) */ } psa_crypto_driver_pake_step_t; +typedef enum psa_jpake_round { + FIRST = 0, + SECOND = 1, + FINISHED = 2 +} psa_jpake_round_t; + +typedef enum psa_jpake_io_mode { + INPUT = 0, + OUTPUT = 1 +} psa_jpake_io_mode_t; struct psa_jpake_computation_stage_s { - psa_jpake_state_t MBEDTLS_PRIVATE(state); - psa_jpake_sequence_t MBEDTLS_PRIVATE(sequence); - psa_jpake_step_t MBEDTLS_PRIVATE(input_step); - psa_jpake_step_t MBEDTLS_PRIVATE(output_step); + /* The J-PAKE round we are currently on */ + psa_jpake_round_t MBEDTLS_PRIVATE(round); + /* The 'mode' we are currently in (inputting or outputting) */ + psa_jpake_io_mode_t MBEDTLS_PRIVATE(mode); + /* The number of inputs so far this round */ + uint8_t MBEDTLS_PRIVATE(inputs); + /* The number of outputs so far this round */ + uint8_t MBEDTLS_PRIVATE(outputs); + /* The next expected step (KEY_SHARE, ZK_PUBLIC or ZK_PROOF) */ + psa_pake_step_t MBEDTLS_PRIVATE(step); }; +#define PSA_JPAKE_EXPECTED_INPUTS(round) (((round) == FIRST) ? 2 : 1) +#define PSA_JPAKE_EXPECTED_OUTPUTS(round) (((round) == FIRST) ? 2 : 1) + struct psa_pake_operation_s { /** Unique ID indicating which driver got assigned to do the * operation. Since driver contexts are driver-specific, swapping diff --git a/library/psa_crypto.c b/library/psa_crypto.c index 2173483230..f86ea3e6a3 100644 --- a/library/psa_crypto.c +++ b/library/psa_crypto.c @@ -7767,10 +7767,11 @@ psa_status_t psa_pake_setup( psa_jpake_computation_stage_t *computation_stage = &operation->computation_stage.jpake; - computation_stage->state = PSA_PAKE_STATE_SETUP; - computation_stage->sequence = PSA_PAKE_SEQ_INVALID; - computation_stage->input_step = PSA_PAKE_STEP_X1_X2; - computation_stage->output_step = PSA_PAKE_STEP_X1_X2; + computation_stage->round = FIRST; + computation_stage->mode = INPUT; + computation_stage->inputs = 0; + computation_stage->outputs = 0; + computation_stage->step = PSA_PAKE_STEP_KEY_SHARE; } else #endif /* PSA_WANT_ALG_JPAKE */ { @@ -7939,57 +7940,66 @@ exit: return status; } -/* Auxiliary function to convert core computation stage(step, sequence, state) to single driver step. */ +/* Auxiliary function to convert core computation stage to single driver step. */ #if defined(PSA_WANT_ALG_JPAKE) static psa_crypto_driver_pake_step_t convert_jpake_computation_stage_to_driver_step( psa_jpake_computation_stage_t *stage) { - switch (stage->state) { - case PSA_PAKE_OUTPUT_X1_X2: - case PSA_PAKE_INPUT_X1_X2: - switch (stage->sequence) { - case PSA_PAKE_X1_STEP_KEY_SHARE: + if (stage->round == FIRST) { + int is_x1; + if (stage->mode == OUTPUT) { + is_x1 = (stage->outputs < 1); + } else { + is_x1 = (stage->inputs < 1); + } + + if (is_x1) { + switch (stage->step) { + case PSA_PAKE_STEP_KEY_SHARE: return PSA_JPAKE_X1_STEP_KEY_SHARE; - case PSA_PAKE_X1_STEP_ZK_PUBLIC: + case PSA_PAKE_STEP_ZK_PUBLIC: return PSA_JPAKE_X1_STEP_ZK_PUBLIC; - case PSA_PAKE_X1_STEP_ZK_PROOF: + case PSA_PAKE_STEP_ZK_PROOF: return PSA_JPAKE_X1_STEP_ZK_PROOF; - case PSA_PAKE_X2_STEP_KEY_SHARE: + default: + return PSA_JPAKE_STEP_INVALID; + } + } else { + switch (stage->step) { + case PSA_PAKE_STEP_KEY_SHARE: return PSA_JPAKE_X2_STEP_KEY_SHARE; - case PSA_PAKE_X2_STEP_ZK_PUBLIC: + case PSA_PAKE_STEP_ZK_PUBLIC: return PSA_JPAKE_X2_STEP_ZK_PUBLIC; - case PSA_PAKE_X2_STEP_ZK_PROOF: + case PSA_PAKE_STEP_ZK_PROOF: return PSA_JPAKE_X2_STEP_ZK_PROOF; default: return PSA_JPAKE_STEP_INVALID; } - break; - case PSA_PAKE_OUTPUT_X2S: - switch (stage->sequence) { - case PSA_PAKE_X1_STEP_KEY_SHARE: + } + } else if (stage->round == SECOND) { + if (stage->mode == OUTPUT) { + switch (stage->step) { + case PSA_PAKE_STEP_KEY_SHARE: return PSA_JPAKE_X2S_STEP_KEY_SHARE; - case PSA_PAKE_X1_STEP_ZK_PUBLIC: + case PSA_PAKE_STEP_ZK_PUBLIC: return PSA_JPAKE_X2S_STEP_ZK_PUBLIC; - case PSA_PAKE_X1_STEP_ZK_PROOF: + case PSA_PAKE_STEP_ZK_PROOF: return PSA_JPAKE_X2S_STEP_ZK_PROOF; default: return PSA_JPAKE_STEP_INVALID; } - break; - case PSA_PAKE_INPUT_X4S: - switch (stage->sequence) { - case PSA_PAKE_X1_STEP_KEY_SHARE: + } else { + switch (stage->step) { + case PSA_PAKE_STEP_KEY_SHARE: return PSA_JPAKE_X4S_STEP_KEY_SHARE; - case PSA_PAKE_X1_STEP_ZK_PUBLIC: + case PSA_PAKE_STEP_ZK_PUBLIC: return PSA_JPAKE_X4S_STEP_ZK_PUBLIC; - case PSA_PAKE_X1_STEP_ZK_PROOF: + case PSA_PAKE_STEP_ZK_PROOF: return PSA_JPAKE_X4S_STEP_ZK_PROOF; default: return PSA_JPAKE_STEP_INVALID; } - break; - default: - return PSA_JPAKE_STEP_INVALID; + } } return PSA_JPAKE_STEP_INVALID; } @@ -8032,10 +8042,11 @@ static psa_status_t psa_pake_complete_inputs( operation->stage = PSA_PAKE_OPERATION_STAGE_COMPUTATION; psa_jpake_computation_stage_t *computation_stage = &operation->computation_stage.jpake; - computation_stage->state = PSA_PAKE_STATE_READY; - computation_stage->sequence = PSA_PAKE_SEQ_INVALID; - computation_stage->input_step = PSA_PAKE_STEP_X1_X2; - computation_stage->output_step = PSA_PAKE_STEP_X1_X2; + computation_stage->round = FIRST; + computation_stage->mode = INPUT; + computation_stage->inputs = 0; + computation_stage->outputs = 0; + computation_stage->step = PSA_PAKE_STEP_KEY_SHARE; } else #endif /* PSA_WANT_ALG_JPAKE */ { @@ -8046,9 +8057,10 @@ static psa_status_t psa_pake_complete_inputs( } #if defined(PSA_WANT_ALG_JPAKE) -static psa_status_t psa_jpake_output_prologue( +static psa_status_t psa_jpake_prologue( psa_pake_operation_t *operation, - psa_pake_step_t step) + psa_pake_step_t step, + psa_jpake_io_mode_t function_mode) { if (step != PSA_PAKE_STEP_KEY_SHARE && step != PSA_PAKE_STEP_ZK_PUBLIC && @@ -8059,84 +8071,79 @@ static psa_status_t psa_jpake_output_prologue( psa_jpake_computation_stage_t *computation_stage = &operation->computation_stage.jpake; - if (computation_stage->state == PSA_PAKE_STATE_INVALID) { + if (computation_stage->round != FIRST && + computation_stage->round != SECOND) { return PSA_ERROR_BAD_STATE; } - if (computation_stage->state != PSA_PAKE_STATE_READY && - computation_stage->state != PSA_PAKE_OUTPUT_X1_X2 && - computation_stage->state != PSA_PAKE_OUTPUT_X2S) { + /* Check that the step we are given is the one we were expecting */ + if (step != computation_stage->step) { return PSA_ERROR_BAD_STATE; } - if (computation_stage->state == PSA_PAKE_STATE_READY) { - if (step != PSA_PAKE_STEP_KEY_SHARE) { - return PSA_ERROR_BAD_STATE; - } - - switch (computation_stage->output_step) { - case PSA_PAKE_STEP_X1_X2: - computation_stage->state = PSA_PAKE_OUTPUT_X1_X2; - break; - case PSA_PAKE_STEP_X2S: - computation_stage->state = PSA_PAKE_OUTPUT_X2S; - break; - default: - return PSA_ERROR_BAD_STATE; - } - - computation_stage->sequence = PSA_PAKE_X1_STEP_KEY_SHARE; + if (step == PSA_PAKE_STEP_KEY_SHARE && + computation_stage->inputs == 0 && + computation_stage->outputs == 0) { + /* Start of the round, so function decides whether we are inputting + * or outputting */ + computation_stage->mode = function_mode; + } else if (computation_stage->mode != function_mode) { + /* Middle of the round so the mode we are in must match the function + * called by the user */ + return PSA_ERROR_BAD_STATE; } - /* Check if step matches current sequence */ - switch (computation_stage->sequence) { - case PSA_PAKE_X1_STEP_KEY_SHARE: - case PSA_PAKE_X2_STEP_KEY_SHARE: - if (step != PSA_PAKE_STEP_KEY_SHARE) { - return PSA_ERROR_BAD_STATE; - } - break; - - case PSA_PAKE_X1_STEP_ZK_PUBLIC: - case PSA_PAKE_X2_STEP_ZK_PUBLIC: - if (step != PSA_PAKE_STEP_ZK_PUBLIC) { - return PSA_ERROR_BAD_STATE; - } - break; - - case PSA_PAKE_X1_STEP_ZK_PROOF: - case PSA_PAKE_X2_STEP_ZK_PROOF: - if (step != PSA_PAKE_STEP_ZK_PROOF) { - return PSA_ERROR_BAD_STATE; - } - break; - - default: + /* Check that we do not already have enough inputs/outputs + * this round */ + if (function_mode == INPUT) { + if (computation_stage->inputs >= + PSA_JPAKE_EXPECTED_INPUTS(computation_stage->round)) { return PSA_ERROR_BAD_STATE; + } + } else { + if (computation_stage->outputs >= + PSA_JPAKE_EXPECTED_OUTPUTS(computation_stage->round)) { + return PSA_ERROR_BAD_STATE; + } } - return PSA_SUCCESS; } -static psa_status_t psa_jpake_output_epilogue( - psa_pake_operation_t *operation) +static psa_status_t psa_jpake_epilogue( + psa_pake_operation_t *operation, + psa_jpake_io_mode_t function_mode) { - psa_jpake_computation_stage_t *computation_stage = + psa_jpake_computation_stage_t *stage = &operation->computation_stage.jpake; - if ((computation_stage->state == PSA_PAKE_OUTPUT_X1_X2 && - computation_stage->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) || - (computation_stage->state == PSA_PAKE_OUTPUT_X2S && - computation_stage->sequence == PSA_PAKE_X1_STEP_ZK_PROOF)) { - computation_stage->state = PSA_PAKE_STATE_READY; - computation_stage->output_step++; - computation_stage->sequence = PSA_PAKE_SEQ_INVALID; + if (stage->step == PSA_PAKE_STEP_ZK_PROOF) { + /* End of an input/output */ + if (function_mode == INPUT) { + stage->inputs++; + if (stage->inputs >= PSA_JPAKE_EXPECTED_INPUTS(stage->round)) { + stage->mode = OUTPUT; + } + } + if (function_mode == OUTPUT) { + stage->outputs++; + if (stage->outputs >= PSA_JPAKE_EXPECTED_OUTPUTS(stage->round)) { + stage->mode = INPUT; + } + } + if (stage->inputs >= PSA_JPAKE_EXPECTED_INPUTS(stage->round) && + stage->outputs >= PSA_JPAKE_EXPECTED_OUTPUTS(stage->round)) { + /* End of a round, move to the next round */ + stage->inputs = 0; + stage->outputs = 0; + stage->round++; + } + stage->step = PSA_PAKE_STEP_KEY_SHARE; } else { - computation_stage->sequence++; + stage->step++; } - return PSA_SUCCESS; } + #endif /* PSA_WANT_ALG_JPAKE */ psa_status_t psa_pake_output( @@ -8170,7 +8177,7 @@ psa_status_t psa_pake_output( switch (operation->alg) { #if defined(PSA_WANT_ALG_JPAKE) case PSA_ALG_JPAKE: - status = psa_jpake_output_prologue(operation, step); + status = psa_jpake_prologue(operation, step, OUTPUT); if (status != PSA_SUCCESS) { goto exit; } @@ -8194,7 +8201,7 @@ psa_status_t psa_pake_output( switch (operation->alg) { #if defined(PSA_WANT_ALG_JPAKE) case PSA_ALG_JPAKE: - status = psa_jpake_output_epilogue(operation); + status = psa_jpake_epilogue(operation, OUTPUT); if (status != PSA_SUCCESS) { goto exit; } @@ -8211,100 +8218,6 @@ exit: return status; } -#if defined(PSA_WANT_ALG_JPAKE) -static psa_status_t psa_jpake_input_prologue( - psa_pake_operation_t *operation, - psa_pake_step_t step) -{ - if (step != PSA_PAKE_STEP_KEY_SHARE && - step != PSA_PAKE_STEP_ZK_PUBLIC && - step != PSA_PAKE_STEP_ZK_PROOF) { - return PSA_ERROR_INVALID_ARGUMENT; - } - - psa_jpake_computation_stage_t *computation_stage = - &operation->computation_stage.jpake; - - if (computation_stage->state == PSA_PAKE_STATE_INVALID) { - return PSA_ERROR_BAD_STATE; - } - - if (computation_stage->state != PSA_PAKE_STATE_READY && - computation_stage->state != PSA_PAKE_INPUT_X1_X2 && - computation_stage->state != PSA_PAKE_INPUT_X4S) { - return PSA_ERROR_BAD_STATE; - } - - if (computation_stage->state == PSA_PAKE_STATE_READY) { - if (step != PSA_PAKE_STEP_KEY_SHARE) { - return PSA_ERROR_BAD_STATE; - } - - switch (computation_stage->input_step) { - case PSA_PAKE_STEP_X1_X2: - computation_stage->state = PSA_PAKE_INPUT_X1_X2; - break; - case PSA_PAKE_STEP_X2S: - computation_stage->state = PSA_PAKE_INPUT_X4S; - break; - default: - return PSA_ERROR_BAD_STATE; - } - - computation_stage->sequence = PSA_PAKE_X1_STEP_KEY_SHARE; - } - - /* Check if step matches current sequence */ - switch (computation_stage->sequence) { - case PSA_PAKE_X1_STEP_KEY_SHARE: - case PSA_PAKE_X2_STEP_KEY_SHARE: - if (step != PSA_PAKE_STEP_KEY_SHARE) { - return PSA_ERROR_BAD_STATE; - } - break; - - case PSA_PAKE_X1_STEP_ZK_PUBLIC: - case PSA_PAKE_X2_STEP_ZK_PUBLIC: - if (step != PSA_PAKE_STEP_ZK_PUBLIC) { - return PSA_ERROR_BAD_STATE; - } - break; - - case PSA_PAKE_X1_STEP_ZK_PROOF: - case PSA_PAKE_X2_STEP_ZK_PROOF: - if (step != PSA_PAKE_STEP_ZK_PROOF) { - return PSA_ERROR_BAD_STATE; - } - break; - - default: - return PSA_ERROR_BAD_STATE; - } - - return PSA_SUCCESS; -} - -static psa_status_t psa_jpake_input_epilogue( - psa_pake_operation_t *operation) -{ - psa_jpake_computation_stage_t *computation_stage = - &operation->computation_stage.jpake; - - if ((computation_stage->state == PSA_PAKE_INPUT_X1_X2 && - computation_stage->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) || - (computation_stage->state == PSA_PAKE_INPUT_X4S && - computation_stage->sequence == PSA_PAKE_X1_STEP_ZK_PROOF)) { - computation_stage->state = PSA_PAKE_STATE_READY; - computation_stage->input_step++; - computation_stage->sequence = PSA_PAKE_SEQ_INVALID; - } else { - computation_stage->sequence++; - } - - return PSA_SUCCESS; -} -#endif /* PSA_WANT_ALG_JPAKE */ - psa_status_t psa_pake_input( psa_pake_operation_t *operation, psa_pake_step_t step, @@ -8337,7 +8250,7 @@ psa_status_t psa_pake_input( switch (operation->alg) { #if defined(PSA_WANT_ALG_JPAKE) case PSA_ALG_JPAKE: - status = psa_jpake_input_prologue(operation, step); + status = psa_jpake_prologue(operation, step, INPUT); if (status != PSA_SUCCESS) { goto exit; } @@ -8361,7 +8274,7 @@ psa_status_t psa_pake_input( switch (operation->alg) { #if defined(PSA_WANT_ALG_JPAKE) case PSA_ALG_JPAKE: - status = psa_jpake_input_epilogue(operation); + status = psa_jpake_epilogue(operation, INPUT); if (status != PSA_SUCCESS) { goto exit; } @@ -8396,8 +8309,7 @@ psa_status_t psa_pake_get_implicit_key( if (operation->alg == PSA_ALG_JPAKE) { psa_jpake_computation_stage_t *computation_stage = &operation->computation_stage.jpake; - if (computation_stage->input_step != PSA_PAKE_STEP_DERIVE || - computation_stage->output_step != PSA_PAKE_STEP_DERIVE) { + if (computation_stage->round != FINISHED) { status = PSA_ERROR_BAD_STATE; goto exit; } diff --git a/tests/suites/test_suite_psa_crypto_driver_wrappers.function b/tests/suites/test_suite_psa_crypto_driver_wrappers.function index b971f81666..87f7b37d7a 100644 --- a/tests/suites/test_suite_psa_crypto_driver_wrappers.function +++ b/tests/suites/test_suite_psa_crypto_driver_wrappers.function @@ -3127,8 +3127,10 @@ void pake_operations(data_t *pw_data, int forced_status_setup_arg, int forced_st PSA_SUCCESS); /* Simulate that we are ready to get implicit key. */ - operation.computation_stage.jpake.input_step = PSA_PAKE_STEP_DERIVE; - operation.computation_stage.jpake.output_step = PSA_PAKE_STEP_DERIVE; + operation.computation_stage.jpake.round = PSA_JPAKE_FINISHED; + operation.computation_stage.jpake.inputs = 0; + operation.computation_stage.jpake.outputs = 0; + operation.computation_stage.jpake.step = PSA_PAKE_STEP_KEY_SHARE; /* --- psa_pake_get_implicit_key --- */ mbedtls_test_driver_pake_hooks.forced_status = forced_status;