Change J-PAKE internal state machine

Keep track of the J-PAKE internal state in a more intuitive way.
Specifically, replace the current state with a struct of 5 fields:

* The round of J-PAKE we are currently in, FIRST or SECOND
* The 'mode' we are currently working in, INPUT or OUTPUT
* The number of inputs so far this round
* The number of outputs so far this round
* The PAKE step we are expecting, KEY_SHARE, ZK_PUBLIC or ZK_PROOF

This should improve the readability of the state-transformation code.

Signed-off-by: David Horstmann <david.horstmann@arm.com>
This commit is contained in:
David Horstmann 2023-05-12 18:17:21 +01:00
parent e25c43bd66
commit e7f21e65b6
3 changed files with 133 additions and 200 deletions

View File

@ -2028,14 +2028,33 @@ typedef enum psa_crypto_driver_pake_step {
PSA_JPAKE_X4S_STEP_ZK_PROOF = 12 /* Round 2: input Schnorr NIZKP proof for the X4S key (from peer) */ PSA_JPAKE_X4S_STEP_ZK_PROOF = 12 /* Round 2: input Schnorr NIZKP proof for the X4S key (from peer) */
} psa_crypto_driver_pake_step_t; } psa_crypto_driver_pake_step_t;
typedef enum psa_jpake_round {
FIRST = 0,
SECOND = 1,
FINISHED = 2
} psa_jpake_round_t;
typedef enum psa_jpake_io_mode {
INPUT = 0,
OUTPUT = 1
} psa_jpake_io_mode_t;
struct psa_jpake_computation_stage_s { struct psa_jpake_computation_stage_s {
psa_jpake_state_t MBEDTLS_PRIVATE(state); /* The J-PAKE round we are currently on */
psa_jpake_sequence_t MBEDTLS_PRIVATE(sequence); psa_jpake_round_t MBEDTLS_PRIVATE(round);
psa_jpake_step_t MBEDTLS_PRIVATE(input_step); /* The 'mode' we are currently in (inputting or outputting) */
psa_jpake_step_t MBEDTLS_PRIVATE(output_step); psa_jpake_io_mode_t MBEDTLS_PRIVATE(mode);
/* The number of inputs so far this round */
uint8_t MBEDTLS_PRIVATE(inputs);
/* The number of outputs so far this round */
uint8_t MBEDTLS_PRIVATE(outputs);
/* The next expected step (KEY_SHARE, ZK_PUBLIC or ZK_PROOF) */
psa_pake_step_t MBEDTLS_PRIVATE(step);
}; };
#define PSA_JPAKE_EXPECTED_INPUTS(round) (((round) == FIRST) ? 2 : 1)
#define PSA_JPAKE_EXPECTED_OUTPUTS(round) (((round) == FIRST) ? 2 : 1)
struct psa_pake_operation_s { struct psa_pake_operation_s {
/** Unique ID indicating which driver got assigned to do the /** Unique ID indicating which driver got assigned to do the
* operation. Since driver contexts are driver-specific, swapping * operation. Since driver contexts are driver-specific, swapping

View File

@ -7767,10 +7767,11 @@ psa_status_t psa_pake_setup(
psa_jpake_computation_stage_t *computation_stage = psa_jpake_computation_stage_t *computation_stage =
&operation->computation_stage.jpake; &operation->computation_stage.jpake;
computation_stage->state = PSA_PAKE_STATE_SETUP; computation_stage->round = FIRST;
computation_stage->sequence = PSA_PAKE_SEQ_INVALID; computation_stage->mode = INPUT;
computation_stage->input_step = PSA_PAKE_STEP_X1_X2; computation_stage->inputs = 0;
computation_stage->output_step = PSA_PAKE_STEP_X1_X2; computation_stage->outputs = 0;
computation_stage->step = PSA_PAKE_STEP_KEY_SHARE;
} else } else
#endif /* PSA_WANT_ALG_JPAKE */ #endif /* PSA_WANT_ALG_JPAKE */
{ {
@ -7939,57 +7940,66 @@ exit:
return status; return status;
} }
/* Auxiliary function to convert core computation stage(step, sequence, state) to single driver step. */ /* Auxiliary function to convert core computation stage to single driver step. */
#if defined(PSA_WANT_ALG_JPAKE) #if defined(PSA_WANT_ALG_JPAKE)
static psa_crypto_driver_pake_step_t convert_jpake_computation_stage_to_driver_step( static psa_crypto_driver_pake_step_t convert_jpake_computation_stage_to_driver_step(
psa_jpake_computation_stage_t *stage) psa_jpake_computation_stage_t *stage)
{ {
switch (stage->state) { if (stage->round == FIRST) {
case PSA_PAKE_OUTPUT_X1_X2: int is_x1;
case PSA_PAKE_INPUT_X1_X2: if (stage->mode == OUTPUT) {
switch (stage->sequence) { is_x1 = (stage->outputs < 1);
case PSA_PAKE_X1_STEP_KEY_SHARE: } else {
is_x1 = (stage->inputs < 1);
}
if (is_x1) {
switch (stage->step) {
case PSA_PAKE_STEP_KEY_SHARE:
return PSA_JPAKE_X1_STEP_KEY_SHARE; return PSA_JPAKE_X1_STEP_KEY_SHARE;
case PSA_PAKE_X1_STEP_ZK_PUBLIC: case PSA_PAKE_STEP_ZK_PUBLIC:
return PSA_JPAKE_X1_STEP_ZK_PUBLIC; return PSA_JPAKE_X1_STEP_ZK_PUBLIC;
case PSA_PAKE_X1_STEP_ZK_PROOF: case PSA_PAKE_STEP_ZK_PROOF:
return PSA_JPAKE_X1_STEP_ZK_PROOF; return PSA_JPAKE_X1_STEP_ZK_PROOF;
case PSA_PAKE_X2_STEP_KEY_SHARE: default:
return PSA_JPAKE_STEP_INVALID;
}
} else {
switch (stage->step) {
case PSA_PAKE_STEP_KEY_SHARE:
return PSA_JPAKE_X2_STEP_KEY_SHARE; return PSA_JPAKE_X2_STEP_KEY_SHARE;
case PSA_PAKE_X2_STEP_ZK_PUBLIC: case PSA_PAKE_STEP_ZK_PUBLIC:
return PSA_JPAKE_X2_STEP_ZK_PUBLIC; return PSA_JPAKE_X2_STEP_ZK_PUBLIC;
case PSA_PAKE_X2_STEP_ZK_PROOF: case PSA_PAKE_STEP_ZK_PROOF:
return PSA_JPAKE_X2_STEP_ZK_PROOF; return PSA_JPAKE_X2_STEP_ZK_PROOF;
default: default:
return PSA_JPAKE_STEP_INVALID; return PSA_JPAKE_STEP_INVALID;
} }
break; }
case PSA_PAKE_OUTPUT_X2S: } else if (stage->round == SECOND) {
switch (stage->sequence) { if (stage->mode == OUTPUT) {
case PSA_PAKE_X1_STEP_KEY_SHARE: switch (stage->step) {
case PSA_PAKE_STEP_KEY_SHARE:
return PSA_JPAKE_X2S_STEP_KEY_SHARE; return PSA_JPAKE_X2S_STEP_KEY_SHARE;
case PSA_PAKE_X1_STEP_ZK_PUBLIC: case PSA_PAKE_STEP_ZK_PUBLIC:
return PSA_JPAKE_X2S_STEP_ZK_PUBLIC; return PSA_JPAKE_X2S_STEP_ZK_PUBLIC;
case PSA_PAKE_X1_STEP_ZK_PROOF: case PSA_PAKE_STEP_ZK_PROOF:
return PSA_JPAKE_X2S_STEP_ZK_PROOF; return PSA_JPAKE_X2S_STEP_ZK_PROOF;
default: default:
return PSA_JPAKE_STEP_INVALID; return PSA_JPAKE_STEP_INVALID;
} }
break; } else {
case PSA_PAKE_INPUT_X4S: switch (stage->step) {
switch (stage->sequence) { case PSA_PAKE_STEP_KEY_SHARE:
case PSA_PAKE_X1_STEP_KEY_SHARE:
return PSA_JPAKE_X4S_STEP_KEY_SHARE; return PSA_JPAKE_X4S_STEP_KEY_SHARE;
case PSA_PAKE_X1_STEP_ZK_PUBLIC: case PSA_PAKE_STEP_ZK_PUBLIC:
return PSA_JPAKE_X4S_STEP_ZK_PUBLIC; return PSA_JPAKE_X4S_STEP_ZK_PUBLIC;
case PSA_PAKE_X1_STEP_ZK_PROOF: case PSA_PAKE_STEP_ZK_PROOF:
return PSA_JPAKE_X4S_STEP_ZK_PROOF; return PSA_JPAKE_X4S_STEP_ZK_PROOF;
default: default:
return PSA_JPAKE_STEP_INVALID; return PSA_JPAKE_STEP_INVALID;
} }
break; }
default:
return PSA_JPAKE_STEP_INVALID;
} }
return PSA_JPAKE_STEP_INVALID; return PSA_JPAKE_STEP_INVALID;
} }
@ -8032,10 +8042,11 @@ static psa_status_t psa_pake_complete_inputs(
operation->stage = PSA_PAKE_OPERATION_STAGE_COMPUTATION; operation->stage = PSA_PAKE_OPERATION_STAGE_COMPUTATION;
psa_jpake_computation_stage_t *computation_stage = psa_jpake_computation_stage_t *computation_stage =
&operation->computation_stage.jpake; &operation->computation_stage.jpake;
computation_stage->state = PSA_PAKE_STATE_READY; computation_stage->round = FIRST;
computation_stage->sequence = PSA_PAKE_SEQ_INVALID; computation_stage->mode = INPUT;
computation_stage->input_step = PSA_PAKE_STEP_X1_X2; computation_stage->inputs = 0;
computation_stage->output_step = PSA_PAKE_STEP_X1_X2; computation_stage->outputs = 0;
computation_stage->step = PSA_PAKE_STEP_KEY_SHARE;
} else } else
#endif /* PSA_WANT_ALG_JPAKE */ #endif /* PSA_WANT_ALG_JPAKE */
{ {
@ -8046,9 +8057,10 @@ static psa_status_t psa_pake_complete_inputs(
} }
#if defined(PSA_WANT_ALG_JPAKE) #if defined(PSA_WANT_ALG_JPAKE)
static psa_status_t psa_jpake_output_prologue( static psa_status_t psa_jpake_prologue(
psa_pake_operation_t *operation, psa_pake_operation_t *operation,
psa_pake_step_t step) psa_pake_step_t step,
psa_jpake_io_mode_t function_mode)
{ {
if (step != PSA_PAKE_STEP_KEY_SHARE && if (step != PSA_PAKE_STEP_KEY_SHARE &&
step != PSA_PAKE_STEP_ZK_PUBLIC && step != PSA_PAKE_STEP_ZK_PUBLIC &&
@ -8059,84 +8071,79 @@ static psa_status_t psa_jpake_output_prologue(
psa_jpake_computation_stage_t *computation_stage = psa_jpake_computation_stage_t *computation_stage =
&operation->computation_stage.jpake; &operation->computation_stage.jpake;
if (computation_stage->state == PSA_PAKE_STATE_INVALID) { if (computation_stage->round != FIRST &&
computation_stage->round != SECOND) {
return PSA_ERROR_BAD_STATE; return PSA_ERROR_BAD_STATE;
} }
if (computation_stage->state != PSA_PAKE_STATE_READY && /* Check that the step we are given is the one we were expecting */
computation_stage->state != PSA_PAKE_OUTPUT_X1_X2 && if (step != computation_stage->step) {
computation_stage->state != PSA_PAKE_OUTPUT_X2S) {
return PSA_ERROR_BAD_STATE; return PSA_ERROR_BAD_STATE;
} }
if (computation_stage->state == PSA_PAKE_STATE_READY) { if (step == PSA_PAKE_STEP_KEY_SHARE &&
if (step != PSA_PAKE_STEP_KEY_SHARE) { computation_stage->inputs == 0 &&
computation_stage->outputs == 0) {
/* Start of the round, so function decides whether we are inputting
* or outputting */
computation_stage->mode = function_mode;
} else if (computation_stage->mode != function_mode) {
/* Middle of the round so the mode we are in must match the function
* called by the user */
return PSA_ERROR_BAD_STATE; return PSA_ERROR_BAD_STATE;
} }
switch (computation_stage->output_step) { /* Check that we do not already have enough inputs/outputs
case PSA_PAKE_STEP_X1_X2: * this round */
computation_stage->state = PSA_PAKE_OUTPUT_X1_X2; if (function_mode == INPUT) {
break; if (computation_stage->inputs >=
case PSA_PAKE_STEP_X2S: PSA_JPAKE_EXPECTED_INPUTS(computation_stage->round)) {
computation_stage->state = PSA_PAKE_OUTPUT_X2S;
break;
default:
return PSA_ERROR_BAD_STATE; return PSA_ERROR_BAD_STATE;
} }
} else {
computation_stage->sequence = PSA_PAKE_X1_STEP_KEY_SHARE; if (computation_stage->outputs >=
} PSA_JPAKE_EXPECTED_OUTPUTS(computation_stage->round)) {
/* Check if step matches current sequence */
switch (computation_stage->sequence) {
case PSA_PAKE_X1_STEP_KEY_SHARE:
case PSA_PAKE_X2_STEP_KEY_SHARE:
if (step != PSA_PAKE_STEP_KEY_SHARE) {
return PSA_ERROR_BAD_STATE; return PSA_ERROR_BAD_STATE;
} }
break;
case PSA_PAKE_X1_STEP_ZK_PUBLIC:
case PSA_PAKE_X2_STEP_ZK_PUBLIC:
if (step != PSA_PAKE_STEP_ZK_PUBLIC) {
return PSA_ERROR_BAD_STATE;
} }
break;
case PSA_PAKE_X1_STEP_ZK_PROOF:
case PSA_PAKE_X2_STEP_ZK_PROOF:
if (step != PSA_PAKE_STEP_ZK_PROOF) {
return PSA_ERROR_BAD_STATE;
}
break;
default:
return PSA_ERROR_BAD_STATE;
}
return PSA_SUCCESS; return PSA_SUCCESS;
} }
static psa_status_t psa_jpake_output_epilogue( static psa_status_t psa_jpake_epilogue(
psa_pake_operation_t *operation) psa_pake_operation_t *operation,
psa_jpake_io_mode_t function_mode)
{ {
psa_jpake_computation_stage_t *computation_stage = psa_jpake_computation_stage_t *stage =
&operation->computation_stage.jpake; &operation->computation_stage.jpake;
if ((computation_stage->state == PSA_PAKE_OUTPUT_X1_X2 && if (stage->step == PSA_PAKE_STEP_ZK_PROOF) {
computation_stage->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) || /* End of an input/output */
(computation_stage->state == PSA_PAKE_OUTPUT_X2S && if (function_mode == INPUT) {
computation_stage->sequence == PSA_PAKE_X1_STEP_ZK_PROOF)) { stage->inputs++;
computation_stage->state = PSA_PAKE_STATE_READY; if (stage->inputs >= PSA_JPAKE_EXPECTED_INPUTS(stage->round)) {
computation_stage->output_step++; stage->mode = OUTPUT;
computation_stage->sequence = PSA_PAKE_SEQ_INVALID; }
} else { }
computation_stage->sequence++; if (function_mode == OUTPUT) {
stage->outputs++;
if (stage->outputs >= PSA_JPAKE_EXPECTED_OUTPUTS(stage->round)) {
stage->mode = INPUT;
}
}
if (stage->inputs >= PSA_JPAKE_EXPECTED_INPUTS(stage->round) &&
stage->outputs >= PSA_JPAKE_EXPECTED_OUTPUTS(stage->round)) {
/* End of a round, move to the next round */
stage->inputs = 0;
stage->outputs = 0;
stage->round++;
}
stage->step = PSA_PAKE_STEP_KEY_SHARE;
} else {
stage->step++;
} }
return PSA_SUCCESS; return PSA_SUCCESS;
} }
#endif /* PSA_WANT_ALG_JPAKE */ #endif /* PSA_WANT_ALG_JPAKE */
psa_status_t psa_pake_output( psa_status_t psa_pake_output(
@ -8170,7 +8177,7 @@ psa_status_t psa_pake_output(
switch (operation->alg) { switch (operation->alg) {
#if defined(PSA_WANT_ALG_JPAKE) #if defined(PSA_WANT_ALG_JPAKE)
case PSA_ALG_JPAKE: case PSA_ALG_JPAKE:
status = psa_jpake_output_prologue(operation, step); status = psa_jpake_prologue(operation, step, OUTPUT);
if (status != PSA_SUCCESS) { if (status != PSA_SUCCESS) {
goto exit; goto exit;
} }
@ -8194,7 +8201,7 @@ psa_status_t psa_pake_output(
switch (operation->alg) { switch (operation->alg) {
#if defined(PSA_WANT_ALG_JPAKE) #if defined(PSA_WANT_ALG_JPAKE)
case PSA_ALG_JPAKE: case PSA_ALG_JPAKE:
status = psa_jpake_output_epilogue(operation); status = psa_jpake_epilogue(operation, OUTPUT);
if (status != PSA_SUCCESS) { if (status != PSA_SUCCESS) {
goto exit; goto exit;
} }
@ -8211,100 +8218,6 @@ exit:
return status; return status;
} }
#if defined(PSA_WANT_ALG_JPAKE)
static psa_status_t psa_jpake_input_prologue(
psa_pake_operation_t *operation,
psa_pake_step_t step)
{
if (step != PSA_PAKE_STEP_KEY_SHARE &&
step != PSA_PAKE_STEP_ZK_PUBLIC &&
step != PSA_PAKE_STEP_ZK_PROOF) {
return PSA_ERROR_INVALID_ARGUMENT;
}
psa_jpake_computation_stage_t *computation_stage =
&operation->computation_stage.jpake;
if (computation_stage->state == PSA_PAKE_STATE_INVALID) {
return PSA_ERROR_BAD_STATE;
}
if (computation_stage->state != PSA_PAKE_STATE_READY &&
computation_stage->state != PSA_PAKE_INPUT_X1_X2 &&
computation_stage->state != PSA_PAKE_INPUT_X4S) {
return PSA_ERROR_BAD_STATE;
}
if (computation_stage->state == PSA_PAKE_STATE_READY) {
if (step != PSA_PAKE_STEP_KEY_SHARE) {
return PSA_ERROR_BAD_STATE;
}
switch (computation_stage->input_step) {
case PSA_PAKE_STEP_X1_X2:
computation_stage->state = PSA_PAKE_INPUT_X1_X2;
break;
case PSA_PAKE_STEP_X2S:
computation_stage->state = PSA_PAKE_INPUT_X4S;
break;
default:
return PSA_ERROR_BAD_STATE;
}
computation_stage->sequence = PSA_PAKE_X1_STEP_KEY_SHARE;
}
/* Check if step matches current sequence */
switch (computation_stage->sequence) {
case PSA_PAKE_X1_STEP_KEY_SHARE:
case PSA_PAKE_X2_STEP_KEY_SHARE:
if (step != PSA_PAKE_STEP_KEY_SHARE) {
return PSA_ERROR_BAD_STATE;
}
break;
case PSA_PAKE_X1_STEP_ZK_PUBLIC:
case PSA_PAKE_X2_STEP_ZK_PUBLIC:
if (step != PSA_PAKE_STEP_ZK_PUBLIC) {
return PSA_ERROR_BAD_STATE;
}
break;
case PSA_PAKE_X1_STEP_ZK_PROOF:
case PSA_PAKE_X2_STEP_ZK_PROOF:
if (step != PSA_PAKE_STEP_ZK_PROOF) {
return PSA_ERROR_BAD_STATE;
}
break;
default:
return PSA_ERROR_BAD_STATE;
}
return PSA_SUCCESS;
}
static psa_status_t psa_jpake_input_epilogue(
psa_pake_operation_t *operation)
{
psa_jpake_computation_stage_t *computation_stage =
&operation->computation_stage.jpake;
if ((computation_stage->state == PSA_PAKE_INPUT_X1_X2 &&
computation_stage->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) ||
(computation_stage->state == PSA_PAKE_INPUT_X4S &&
computation_stage->sequence == PSA_PAKE_X1_STEP_ZK_PROOF)) {
computation_stage->state = PSA_PAKE_STATE_READY;
computation_stage->input_step++;
computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
} else {
computation_stage->sequence++;
}
return PSA_SUCCESS;
}
#endif /* PSA_WANT_ALG_JPAKE */
psa_status_t psa_pake_input( psa_status_t psa_pake_input(
psa_pake_operation_t *operation, psa_pake_operation_t *operation,
psa_pake_step_t step, psa_pake_step_t step,
@ -8337,7 +8250,7 @@ psa_status_t psa_pake_input(
switch (operation->alg) { switch (operation->alg) {
#if defined(PSA_WANT_ALG_JPAKE) #if defined(PSA_WANT_ALG_JPAKE)
case PSA_ALG_JPAKE: case PSA_ALG_JPAKE:
status = psa_jpake_input_prologue(operation, step); status = psa_jpake_prologue(operation, step, INPUT);
if (status != PSA_SUCCESS) { if (status != PSA_SUCCESS) {
goto exit; goto exit;
} }
@ -8361,7 +8274,7 @@ psa_status_t psa_pake_input(
switch (operation->alg) { switch (operation->alg) {
#if defined(PSA_WANT_ALG_JPAKE) #if defined(PSA_WANT_ALG_JPAKE)
case PSA_ALG_JPAKE: case PSA_ALG_JPAKE:
status = psa_jpake_input_epilogue(operation); status = psa_jpake_epilogue(operation, INPUT);
if (status != PSA_SUCCESS) { if (status != PSA_SUCCESS) {
goto exit; goto exit;
} }
@ -8396,8 +8309,7 @@ psa_status_t psa_pake_get_implicit_key(
if (operation->alg == PSA_ALG_JPAKE) { if (operation->alg == PSA_ALG_JPAKE) {
psa_jpake_computation_stage_t *computation_stage = psa_jpake_computation_stage_t *computation_stage =
&operation->computation_stage.jpake; &operation->computation_stage.jpake;
if (computation_stage->input_step != PSA_PAKE_STEP_DERIVE || if (computation_stage->round != FINISHED) {
computation_stage->output_step != PSA_PAKE_STEP_DERIVE) {
status = PSA_ERROR_BAD_STATE; status = PSA_ERROR_BAD_STATE;
goto exit; goto exit;
} }

View File

@ -3127,8 +3127,10 @@ void pake_operations(data_t *pw_data, int forced_status_setup_arg, int forced_st
PSA_SUCCESS); PSA_SUCCESS);
/* Simulate that we are ready to get implicit key. */ /* Simulate that we are ready to get implicit key. */
operation.computation_stage.jpake.input_step = PSA_PAKE_STEP_DERIVE; operation.computation_stage.jpake.round = PSA_JPAKE_FINISHED;
operation.computation_stage.jpake.output_step = PSA_PAKE_STEP_DERIVE; operation.computation_stage.jpake.inputs = 0;
operation.computation_stage.jpake.outputs = 0;
operation.computation_stage.jpake.step = PSA_PAKE_STEP_KEY_SHARE;
/* --- psa_pake_get_implicit_key --- */ /* --- psa_pake_get_implicit_key --- */
mbedtls_test_driver_pake_hooks.forced_status = forced_status; mbedtls_test_driver_pake_hooks.forced_status = forced_status;