/* BEGIN_HEADER */
#include <stdint.h>

#include "psa/crypto.h"

/*
 * Stage at which a test in this file is made to fail (ERR_NONE: nowhere).
 * The value is taken from the err_stage_arg parameter of the test functions.
 *
 * NOTE: ecjpake_do_round() compares the ERR_INJECT_ROUND* values with range
 * checks (e.g. ERR_INJECT_ROUND1_SERVER_KEY_SHARE_PART1 ..
 * ERR_INJECT_ROUND1_SERVER_ZK_PROOF_PART2), so each round/party group below
 * must stay contiguous and keep its current order.
 */
typedef enum {
    ERR_NONE = 0,
    /* errors forced by the test code itself (invalid, out-of-order or
     * duplicate calls, corrupted round data) */
    ERR_INJECT_UNINITIALIZED_ACCESS,
    ERR_INJECT_DUPLICATE_SETUP,
    ERR_INJECT_SET_USER,
    ERR_INJECT_SET_PEER,
    ERR_INJECT_SET_ROLE,
    ERR_DUPLICATE_SET_USER,
    ERR_DUPLICATE_SET_PEER,
    ERR_INJECT_EMPTY_IO_BUFFER,
    ERR_INJECT_UNKNOWN_STEP,
    ERR_INJECT_INVALID_FIRST_STEP,
    ERR_INJECT_WRONG_BUFFER_SIZE,
    ERR_INJECT_VALID_OPERATION_AFTER_FAILURE,
    ERR_INJECT_ANTICIPATE_KEY_DERIVATION_1,
    ERR_INJECT_ANTICIPATE_KEY_DERIVATION_2,
    /* client-side round-1 corruption (checked as one contiguous range) */
    ERR_INJECT_ROUND1_CLIENT_KEY_SHARE_PART1,
    ERR_INJECT_ROUND1_CLIENT_ZK_PUBLIC_PART1,
    ERR_INJECT_ROUND1_CLIENT_ZK_PROOF_PART1,
    ERR_INJECT_ROUND1_CLIENT_KEY_SHARE_PART2,
    ERR_INJECT_ROUND1_CLIENT_ZK_PUBLIC_PART2,
    ERR_INJECT_ROUND1_CLIENT_ZK_PROOF_PART2,
    /* client-side round-2 corruption (checked as one contiguous range) */
    ERR_INJECT_ROUND2_CLIENT_KEY_SHARE,
    ERR_INJECT_ROUND2_CLIENT_ZK_PUBLIC,
    ERR_INJECT_ROUND2_CLIENT_ZK_PROOF,
    /* server-side round-1 corruption (checked as one contiguous range) */
    ERR_INJECT_ROUND1_SERVER_KEY_SHARE_PART1,
    ERR_INJECT_ROUND1_SERVER_ZK_PUBLIC_PART1,
    ERR_INJECT_ROUND1_SERVER_ZK_PROOF_PART1,
    ERR_INJECT_ROUND1_SERVER_KEY_SHARE_PART2,
    ERR_INJECT_ROUND1_SERVER_ZK_PUBLIC_PART2,
    ERR_INJECT_ROUND1_SERVER_ZK_PROOF_PART2,
    /* server-side round-2 corruption (checked as one contiguous range) */
    ERR_INJECT_ROUND2_SERVER_KEY_SHARE,
    ERR_INJECT_ROUND2_SERVER_ZK_PUBLIC,
    ERR_INJECT_ROUND2_SERVER_ZK_PROOF,
    /* errors caused by the parameters passed in from the .data file */
    ERR_IN_SETUP,
    ERR_IN_SET_USER,
    ERR_IN_SET_PEER,
    ERR_IN_SET_ROLE,
    ERR_IN_SET_PASSWORD_KEY,
    ERR_IN_INPUT,
    ERR_IN_OUTPUT,
} ecjpake_error_stage_t;

/* Which of the two EC J-PAKE rounds ecjpake_do_round() should run. */
typedef enum {
    PAKE_ROUND_ONE = 0, /* first round: two key shares per party */
    PAKE_ROUND_TWO = 1, /* second round: one key share per party */
} pake_round_t;

/*
 * Inject an error into the specified buffer ONLY if this is the correct
 * stage, i.e. if `this_stage` equals the variable `err_stage` in scope at the
 * call site. The error is a single flipped bit in the byte at offset 7.
 * Offset 7 is arbitrary, but chosen because it's "in the middle" of the part
 * we're corrupting.
 */
#define DO_ROUND_CONDITIONAL_INJECT(this_stage, buf) \
    do {                                             \
        if ((this_stage) == err_stage) {             \
            *((buf) + 7) ^= 1;                       \
        }                                            \
    } while (0)

/*
 * Record where the step that was just written starts (`step_offset`) and
 * advance the running write position `main_buf_offset` by `step_size` bytes.
 */
#define DO_ROUND_UPDATE_OFFSETS(main_buf_offset, step_offset, step_size) \
    do {                                                                 \
        (step_offset) = (main_buf_offset);                               \
        (main_buf_offset) += (step_size);                                \
    } while (0)

/*
 * Check the `status` of the preceding psa_pake_input() call. The variables
 * `status`, `err_stage` and `expected_error_arg` must be in scope.
 * - If an error is being injected and the call failed, the status must be the
 *   expected error, and `break` leaves the enclosing switch (or loop).
 * - Otherwise the call must have succeeded.
 * This macro is deliberately NOT wrapped in do { } while (0): that would make
 * `break` exit the do-loop instead of the caller's switch statement.
 */
#define DO_ROUND_CHECK_FAILURE()                                  \
    if (err_stage != ERR_NONE && status != PSA_SUCCESS)            \
    {                                                               \
        TEST_EQUAL(status, expected_error_arg);                   \
        break;                                                      \
    }                                                               \
    else                                                            \
    {                                                               \
        TEST_EQUAL(status, PSA_SUCCESS);                          \
    }

#if defined(PSA_WANT_ALG_JPAKE)
/*
 * Run one EC J-PAKE round between `server` and `client`.
 *
 * Each party writes its outputs for the round back to back into its own heap
 * buffer (buffer0 for the server, buffer1 for the client). The recorded
 * offsets/lengths are then used to feed those outputs to the other party,
 * either before (client_input_first == 1) or after (== 0) the client has
 * produced its own outputs for the round.
 *
 * If `err_stage` names one of the ERR_INJECT_ROUND* stages, one bit of the
 * matching output is flipped before it is handed to the peer. The peer may
 * report the problem on any psa_pake_input() call of that sequence. The
 * status is then compared with `expected_error_arg` and the round is
 * abandoned. If the whole sequence is accepted, the test is marked as failed.
 *
 * alg, primitive      cipher-suite parameters, used to size the buffers
 * server, client      operations expected to be ready for `round`
 *                     (presumably configured by the caller; verify there)
 * round               PAKE_ROUND_ONE or PAKE_ROUND_TWO
 * err_stage           where to inject an error, or ERR_NONE
 * expected_error_arg  status expected once the injected error is detected
 */
static void ecjpake_do_round(psa_algorithm_t alg, unsigned int primitive,
                             psa_pake_operation_t *server,
                             psa_pake_operation_t *client,
                             int client_input_first,
                             pake_round_t round,
                             ecjpake_error_stage_t err_stage,
                             int expected_error_arg)
{
    unsigned char *buffer0 = NULL, *buffer1 = NULL;
    /* Round one produces two (key share, ZK public, ZK proof) triples per
     * party, hence the factor 2. Every psa_pake_output() call below is given
     * the space actually left in this allocation (buffer_length - offset). */
    size_t buffer_length = (
        PSA_PAKE_OUTPUT_SIZE(alg, primitive, PSA_PAKE_STEP_KEY_SHARE) +
        PSA_PAKE_OUTPUT_SIZE(alg, primitive, PSA_PAKE_STEP_ZK_PUBLIC) +
        PSA_PAKE_OUTPUT_SIZE(alg, primitive, PSA_PAKE_STEP_ZK_PROOF)) * 2;
    /* The output should be exactly this size according to the spec */
    const size_t expected_size_key_share =
        PSA_PAKE_OUTPUT_SIZE(alg, primitive, PSA_PAKE_STEP_KEY_SHARE);
    /* The output should be exactly this size according to the spec */
    const size_t expected_size_zk_public =
        PSA_PAKE_OUTPUT_SIZE(alg, primitive, PSA_PAKE_STEP_ZK_PUBLIC);
    /* The output can be smaller: the spec allows stripping leading zeroes */
    const size_t max_expected_size_zk_proof =
        PSA_PAKE_OUTPUT_SIZE(alg, primitive, PSA_PAKE_STEP_ZK_PROOF);
    size_t buffer0_off = 0;
    size_t buffer1_off = 0;
    size_t s_g1_len, s_g2_len, s_a_len;
    size_t s_g1_off, s_g2_off, s_a_off;
    size_t s_x1_pk_len, s_x2_pk_len, s_x2s_pk_len;
    size_t s_x1_pk_off, s_x2_pk_off, s_x2s_pk_off;
    size_t s_x1_pr_len, s_x2_pr_len, s_x2s_pr_len;
    size_t s_x1_pr_off, s_x2_pr_off, s_x2s_pr_off;
    size_t c_g1_len, c_g2_len, c_a_len;
    size_t c_g1_off, c_g2_off, c_a_off;
    size_t c_x1_pk_len, c_x2_pk_len, c_x2s_pk_len;
    size_t c_x1_pk_off, c_x2_pk_off, c_x2s_pk_off;
    size_t c_x1_pr_len, c_x2_pr_len, c_x2s_pr_len;
    size_t c_x1_pr_off, c_x2_pr_off, c_x2s_pr_off;
    psa_status_t status;

    /* buffer0 collects the server's outputs, buffer1 the client's. */
    ASSERT_ALLOC(buffer0, buffer_length);
    ASSERT_ALLOC(buffer1, buffer_length);

    switch (round) {
        case PAKE_ROUND_ONE:
            /* Server first round Output */
            PSA_ASSERT(psa_pake_output(server, PSA_PAKE_STEP_KEY_SHARE,
                                       buffer0 + buffer0_off,
                                       buffer_length - buffer0_off, &s_g1_len));
            TEST_EQUAL(s_g1_len, expected_size_key_share);
            DO_ROUND_CONDITIONAL_INJECT(
                ERR_INJECT_ROUND1_SERVER_KEY_SHARE_PART1,
                buffer0 + buffer0_off);
            DO_ROUND_UPDATE_OFFSETS(buffer0_off, s_g1_off, s_g1_len);

            PSA_ASSERT(psa_pake_output(server, PSA_PAKE_STEP_ZK_PUBLIC,
                                       buffer0 + buffer0_off,
                                       buffer_length - buffer0_off, &s_x1_pk_len));
            TEST_EQUAL(s_x1_pk_len, expected_size_zk_public);
            DO_ROUND_CONDITIONAL_INJECT(
                ERR_INJECT_ROUND1_SERVER_ZK_PUBLIC_PART1,
                buffer0 + buffer0_off);
            DO_ROUND_UPDATE_OFFSETS(buffer0_off, s_x1_pk_off, s_x1_pk_len);

            PSA_ASSERT(psa_pake_output(server, PSA_PAKE_STEP_ZK_PROOF,
                                       buffer0 + buffer0_off,
                                       buffer_length - buffer0_off, &s_x1_pr_len));
            TEST_LE_U(s_x1_pr_len, max_expected_size_zk_proof);
            DO_ROUND_CONDITIONAL_INJECT(
                ERR_INJECT_ROUND1_SERVER_ZK_PROOF_PART1,
                buffer0 + buffer0_off);
            DO_ROUND_UPDATE_OFFSETS(buffer0_off, s_x1_pr_off, s_x1_pr_len);

            PSA_ASSERT(psa_pake_output(server, PSA_PAKE_STEP_KEY_SHARE,
                                       buffer0 + buffer0_off,
                                       buffer_length - buffer0_off, &s_g2_len));
            TEST_EQUAL(s_g2_len, expected_size_key_share);
            DO_ROUND_CONDITIONAL_INJECT(
                ERR_INJECT_ROUND1_SERVER_KEY_SHARE_PART2,
                buffer0 + buffer0_off);
            DO_ROUND_UPDATE_OFFSETS(buffer0_off, s_g2_off, s_g2_len);

            PSA_ASSERT(psa_pake_output(server, PSA_PAKE_STEP_ZK_PUBLIC,
                                       buffer0 + buffer0_off,
                                       buffer_length - buffer0_off, &s_x2_pk_len));
            TEST_EQUAL(s_x2_pk_len, expected_size_zk_public);
            DO_ROUND_CONDITIONAL_INJECT(
                ERR_INJECT_ROUND1_SERVER_ZK_PUBLIC_PART2,
                buffer0 + buffer0_off);
            DO_ROUND_UPDATE_OFFSETS(buffer0_off, s_x2_pk_off, s_x2_pk_len);

            PSA_ASSERT(psa_pake_output(server, PSA_PAKE_STEP_ZK_PROOF,
                                       buffer0 + buffer0_off,
                                       buffer_length - buffer0_off, &s_x2_pr_len));
            TEST_LE_U(s_x2_pr_len, max_expected_size_zk_proof);
            DO_ROUND_CONDITIONAL_INJECT(
                ERR_INJECT_ROUND1_SERVER_ZK_PROOF_PART2,
                buffer0 + buffer0_off);
            DO_ROUND_UPDATE_OFFSETS(buffer0_off, s_x2_pr_off, s_x2_pr_len);

            /*
             * When injecting errors in inputs, the implementation is
             * free to detect it right away or with a delay.
             * This permits delaying the error until the end of the input
             * sequence, if no error appears then, this will be treated
             * as an error.
             */
            if (client_input_first == 1) {
                /* Client first round Input */
                status = psa_pake_input(client, PSA_PAKE_STEP_KEY_SHARE,
                                        buffer0 + s_g1_off, s_g1_len);
                DO_ROUND_CHECK_FAILURE();

                status = psa_pake_input(client, PSA_PAKE_STEP_ZK_PUBLIC,
                                        buffer0 + s_x1_pk_off,
                                        s_x1_pk_len);
                DO_ROUND_CHECK_FAILURE();

                status = psa_pake_input(client, PSA_PAKE_STEP_ZK_PROOF,
                                        buffer0 + s_x1_pr_off,
                                        s_x1_pr_len);
                DO_ROUND_CHECK_FAILURE();

                status = psa_pake_input(client, PSA_PAKE_STEP_KEY_SHARE,
                                        buffer0 + s_g2_off,
                                        s_g2_len);
                DO_ROUND_CHECK_FAILURE();

                status = psa_pake_input(client, PSA_PAKE_STEP_ZK_PUBLIC,
                                        buffer0 + s_x2_pk_off,
                                        s_x2_pk_len);
                DO_ROUND_CHECK_FAILURE();

                status = psa_pake_input(client, PSA_PAKE_STEP_ZK_PROOF,
                                        buffer0 + s_x2_pr_off,
                                        s_x2_pr_len);
                DO_ROUND_CHECK_FAILURE();

                /* Error didn't trigger, make test fail */
                if ((err_stage >= ERR_INJECT_ROUND1_SERVER_KEY_SHARE_PART1) &&
                    (err_stage <= ERR_INJECT_ROUND1_SERVER_ZK_PROOF_PART2)) {
                    TEST_ASSERT(
                        !"One of the last psa_pake_input() calls should have returned the expected error.");
                }
            }

            /* Client first round Output */
            PSA_ASSERT(psa_pake_output(client, PSA_PAKE_STEP_KEY_SHARE,
                                       buffer1 + buffer1_off,
                                       buffer_length - buffer1_off, &c_g1_len));
            TEST_EQUAL(c_g1_len, expected_size_key_share);
            DO_ROUND_CONDITIONAL_INJECT(
                ERR_INJECT_ROUND1_CLIENT_KEY_SHARE_PART1,
                buffer1 + buffer1_off);
            DO_ROUND_UPDATE_OFFSETS(buffer1_off, c_g1_off, c_g1_len);

            PSA_ASSERT(psa_pake_output(client, PSA_PAKE_STEP_ZK_PUBLIC,
                                       buffer1 + buffer1_off,
                                       buffer_length - buffer1_off, &c_x1_pk_len));
            TEST_EQUAL(c_x1_pk_len, expected_size_zk_public);
            DO_ROUND_CONDITIONAL_INJECT(
                ERR_INJECT_ROUND1_CLIENT_ZK_PUBLIC_PART1,
                buffer1 + buffer1_off);
            DO_ROUND_UPDATE_OFFSETS(buffer1_off, c_x1_pk_off, c_x1_pk_len);

            PSA_ASSERT(psa_pake_output(client, PSA_PAKE_STEP_ZK_PROOF,
                                       buffer1 + buffer1_off,
                                       buffer_length - buffer1_off, &c_x1_pr_len));
            TEST_LE_U(c_x1_pr_len, max_expected_size_zk_proof);
            DO_ROUND_CONDITIONAL_INJECT(
                ERR_INJECT_ROUND1_CLIENT_ZK_PROOF_PART1,
                buffer1 + buffer1_off);
            DO_ROUND_UPDATE_OFFSETS(buffer1_off, c_x1_pr_off, c_x1_pr_len);

            PSA_ASSERT(psa_pake_output(client, PSA_PAKE_STEP_KEY_SHARE,
                                       buffer1 + buffer1_off,
                                       buffer_length - buffer1_off, &c_g2_len));
            TEST_EQUAL(c_g2_len, expected_size_key_share);
            DO_ROUND_CONDITIONAL_INJECT(
                ERR_INJECT_ROUND1_CLIENT_KEY_SHARE_PART2,
                buffer1 + buffer1_off);
            DO_ROUND_UPDATE_OFFSETS(buffer1_off, c_g2_off, c_g2_len);

            PSA_ASSERT(psa_pake_output(client, PSA_PAKE_STEP_ZK_PUBLIC,
                                       buffer1 + buffer1_off,
                                       buffer_length - buffer1_off, &c_x2_pk_len));
            TEST_EQUAL(c_x2_pk_len, expected_size_zk_public);
            DO_ROUND_CONDITIONAL_INJECT(
                ERR_INJECT_ROUND1_CLIENT_ZK_PUBLIC_PART2,
                buffer1 + buffer1_off);
            DO_ROUND_UPDATE_OFFSETS(buffer1_off, c_x2_pk_off, c_x2_pk_len);

            PSA_ASSERT(psa_pake_output(client, PSA_PAKE_STEP_ZK_PROOF,
                                       buffer1 + buffer1_off,
                                       buffer_length - buffer1_off, &c_x2_pr_len));
            TEST_LE_U(c_x2_pr_len, max_expected_size_zk_proof);
            DO_ROUND_CONDITIONAL_INJECT(
                ERR_INJECT_ROUND1_CLIENT_ZK_PROOF_PART2,
                buffer1 + buffer1_off);
            /* Advance by the proof length (not by the current offset). */
            DO_ROUND_UPDATE_OFFSETS(buffer1_off, c_x2_pr_off, c_x2_pr_len);

            if (client_input_first == 0) {
                /* Client first round Input */
                status = psa_pake_input(client, PSA_PAKE_STEP_KEY_SHARE,
                                        buffer0 + s_g1_off, s_g1_len);
                DO_ROUND_CHECK_FAILURE();

                status = psa_pake_input(client, PSA_PAKE_STEP_ZK_PUBLIC,
                                        buffer0 + s_x1_pk_off,
                                        s_x1_pk_len);
                DO_ROUND_CHECK_FAILURE();

                status = psa_pake_input(client, PSA_PAKE_STEP_ZK_PROOF,
                                        buffer0 + s_x1_pr_off,
                                        s_x1_pr_len);
                DO_ROUND_CHECK_FAILURE();

                status = psa_pake_input(client, PSA_PAKE_STEP_KEY_SHARE,
                                        buffer0 + s_g2_off,
                                        s_g2_len);
                DO_ROUND_CHECK_FAILURE();

                status = psa_pake_input(client, PSA_PAKE_STEP_ZK_PUBLIC,
                                        buffer0 + s_x2_pk_off,
                                        s_x2_pk_len);
                DO_ROUND_CHECK_FAILURE();

                status = psa_pake_input(client, PSA_PAKE_STEP_ZK_PROOF,
                                        buffer0 + s_x2_pr_off,
                                        s_x2_pr_len);
                DO_ROUND_CHECK_FAILURE();

                /* Error didn't trigger, make test fail */
                if ((err_stage >= ERR_INJECT_ROUND1_SERVER_KEY_SHARE_PART1) &&
                    (err_stage <= ERR_INJECT_ROUND1_SERVER_ZK_PROOF_PART2)) {
                    TEST_ASSERT(
                        !"One of the last psa_pake_input() calls should have returned the expected error.");
                }
            }

            /* Server first round Input */
            status = psa_pake_input(server, PSA_PAKE_STEP_KEY_SHARE,
                                    buffer1 + c_g1_off, c_g1_len);
            DO_ROUND_CHECK_FAILURE();

            status = psa_pake_input(server, PSA_PAKE_STEP_ZK_PUBLIC,
                                    buffer1 + c_x1_pk_off, c_x1_pk_len);
            DO_ROUND_CHECK_FAILURE();

            status = psa_pake_input(server, PSA_PAKE_STEP_ZK_PROOF,
                                    buffer1 + c_x1_pr_off, c_x1_pr_len);
            DO_ROUND_CHECK_FAILURE();

            status = psa_pake_input(server, PSA_PAKE_STEP_KEY_SHARE,
                                    buffer1 + c_g2_off, c_g2_len);
            DO_ROUND_CHECK_FAILURE();

            status = psa_pake_input(server, PSA_PAKE_STEP_ZK_PUBLIC,
                                    buffer1 + c_x2_pk_off, c_x2_pk_len);
            DO_ROUND_CHECK_FAILURE();

            status = psa_pake_input(server, PSA_PAKE_STEP_ZK_PROOF,
                                    buffer1 + c_x2_pr_off, c_x2_pr_len);
            DO_ROUND_CHECK_FAILURE();

            /* Error didn't trigger, make test fail */
            if ((err_stage >= ERR_INJECT_ROUND1_CLIENT_KEY_SHARE_PART1) &&
                (err_stage <= ERR_INJECT_ROUND1_CLIENT_ZK_PROOF_PART2)) {
                TEST_ASSERT(
                    !"One of the last psa_pake_input() calls should have returned the expected error.");
            }

            break;

        case PAKE_ROUND_TWO:
            /* Server second round Output */
            buffer0_off = 0;

            PSA_ASSERT(psa_pake_output(server, PSA_PAKE_STEP_KEY_SHARE,
                                       buffer0 + buffer0_off,
                                       buffer_length - buffer0_off, &s_a_len));
            TEST_EQUAL(s_a_len, expected_size_key_share);
            DO_ROUND_CONDITIONAL_INJECT(
                ERR_INJECT_ROUND2_SERVER_KEY_SHARE,
                buffer0 + buffer0_off);
            DO_ROUND_UPDATE_OFFSETS(buffer0_off, s_a_off, s_a_len);

            PSA_ASSERT(psa_pake_output(server, PSA_PAKE_STEP_ZK_PUBLIC,
                                       buffer0 + buffer0_off,
                                       buffer_length - buffer0_off, &s_x2s_pk_len));
            TEST_EQUAL(s_x2s_pk_len, expected_size_zk_public);
            DO_ROUND_CONDITIONAL_INJECT(
                ERR_INJECT_ROUND2_SERVER_ZK_PUBLIC,
                buffer0 + buffer0_off);
            DO_ROUND_UPDATE_OFFSETS(buffer0_off, s_x2s_pk_off, s_x2s_pk_len);

            PSA_ASSERT(psa_pake_output(server, PSA_PAKE_STEP_ZK_PROOF,
                                       buffer0 + buffer0_off,
                                       buffer_length - buffer0_off, &s_x2s_pr_len));
            TEST_LE_U(s_x2s_pr_len, max_expected_size_zk_proof);
            DO_ROUND_CONDITIONAL_INJECT(
                ERR_INJECT_ROUND2_SERVER_ZK_PROOF,
                buffer0 + buffer0_off);
            DO_ROUND_UPDATE_OFFSETS(buffer0_off, s_x2s_pr_off, s_x2s_pr_len);

            if (client_input_first == 1) {
                /* Client second round Input */
                status = psa_pake_input(client, PSA_PAKE_STEP_KEY_SHARE,
                                        buffer0 + s_a_off, s_a_len);
                DO_ROUND_CHECK_FAILURE();

                status = psa_pake_input(client, PSA_PAKE_STEP_ZK_PUBLIC,
                                        buffer0 + s_x2s_pk_off,
                                        s_x2s_pk_len);
                DO_ROUND_CHECK_FAILURE();

                status = psa_pake_input(client, PSA_PAKE_STEP_ZK_PROOF,
                                        buffer0 + s_x2s_pr_off,
                                        s_x2s_pr_len);
                DO_ROUND_CHECK_FAILURE();

                /* Error didn't trigger, make test fail */
                if ((err_stage >= ERR_INJECT_ROUND2_SERVER_KEY_SHARE) &&
                    (err_stage <= ERR_INJECT_ROUND2_SERVER_ZK_PROOF)) {
                    TEST_ASSERT(
                        !"One of the last psa_pake_input() calls should have returned the expected error.");
                }
            }

            /* Client second round Output */
            buffer1_off = 0;

            PSA_ASSERT(psa_pake_output(client, PSA_PAKE_STEP_KEY_SHARE,
                                       buffer1 + buffer1_off,
                                       buffer_length - buffer1_off, &c_a_len));
            TEST_EQUAL(c_a_len, expected_size_key_share);
            DO_ROUND_CONDITIONAL_INJECT(
                ERR_INJECT_ROUND2_CLIENT_KEY_SHARE,
                buffer1 + buffer1_off);
            DO_ROUND_UPDATE_OFFSETS(buffer1_off, c_a_off, c_a_len);

            PSA_ASSERT(psa_pake_output(client, PSA_PAKE_STEP_ZK_PUBLIC,
                                       buffer1 + buffer1_off,
                                       buffer_length - buffer1_off, &c_x2s_pk_len));
            TEST_EQUAL(c_x2s_pk_len, expected_size_zk_public);
            DO_ROUND_CONDITIONAL_INJECT(
                ERR_INJECT_ROUND2_CLIENT_ZK_PUBLIC,
                buffer1 + buffer1_off);
            DO_ROUND_UPDATE_OFFSETS(buffer1_off, c_x2s_pk_off, c_x2s_pk_len);

            PSA_ASSERT(psa_pake_output(client, PSA_PAKE_STEP_ZK_PROOF,
                                       buffer1 + buffer1_off,
                                       buffer_length - buffer1_off, &c_x2s_pr_len));
            TEST_LE_U(c_x2s_pr_len, max_expected_size_zk_proof);
            DO_ROUND_CONDITIONAL_INJECT(
                ERR_INJECT_ROUND2_CLIENT_ZK_PROOF,
                buffer1 + buffer1_off);
            DO_ROUND_UPDATE_OFFSETS(buffer1_off, c_x2s_pr_off, c_x2s_pr_len);

            if (client_input_first == 0) {
                /* Client second round Input */
                status = psa_pake_input(client, PSA_PAKE_STEP_KEY_SHARE,
                                        buffer0 + s_a_off, s_a_len);
                DO_ROUND_CHECK_FAILURE();

                status = psa_pake_input(client, PSA_PAKE_STEP_ZK_PUBLIC,
                                        buffer0 + s_x2s_pk_off,
                                        s_x2s_pk_len);
                DO_ROUND_CHECK_FAILURE();

                status = psa_pake_input(client, PSA_PAKE_STEP_ZK_PROOF,
                                        buffer0 + s_x2s_pr_off,
                                        s_x2s_pr_len);
                DO_ROUND_CHECK_FAILURE();

                /* Error didn't trigger, make test fail */
                if ((err_stage >= ERR_INJECT_ROUND2_SERVER_KEY_SHARE) &&
                    (err_stage <= ERR_INJECT_ROUND2_SERVER_ZK_PROOF)) {
                    TEST_ASSERT(
                        !"One of the last psa_pake_input() calls should have returned the expected error.");
                }
            }

            /* Server second round Input */
            status = psa_pake_input(server, PSA_PAKE_STEP_KEY_SHARE,
                                    buffer1 + c_a_off, c_a_len);
            DO_ROUND_CHECK_FAILURE();

            status = psa_pake_input(server, PSA_PAKE_STEP_ZK_PUBLIC,
                                    buffer1 + c_x2s_pk_off, c_x2s_pk_len);
            DO_ROUND_CHECK_FAILURE();

            status = psa_pake_input(server, PSA_PAKE_STEP_ZK_PROOF,
                                    buffer1 + c_x2s_pr_off, c_x2s_pr_len);
            DO_ROUND_CHECK_FAILURE();

            /* Error didn't trigger, make test fail */
            if ((err_stage >= ERR_INJECT_ROUND2_CLIENT_KEY_SHARE) &&
                (err_stage <= ERR_INJECT_ROUND2_CLIENT_ZK_PROOF)) {
                TEST_ASSERT(
                    !"One of the last psa_pake_input() calls should have returned the expected error.");
            }

            break;

        default:
            /* Only the two rounds above exist. */
            TEST_ASSERT(!"Unexpected PAKE round.");
            break;
    }

exit:
    mbedtls_free(buffer0);
    mbedtls_free(buffer1);
}
#endif /* PSA_WANT_ALG_JPAKE */

/*
 * This check is used for functions that might either succeed or fail depending
 * on the parameters that are passed in from the *.data file:
 * - in case of success, the following functions depend on the current one, so
 *   the status must be PSA_SUCCESS unless this is the selected error stage;
 * - in case of failure the test is always terminated. There are two options
 *   here:
 *     - terminated successfully if this exact error was expected at this stage;
 *     - terminated with failure otherwise (either no error was expected at this
 *       stage or a different error code was expected).
 * Requires `status`, `err_stage`, `expected_error` and an `exit` label in
 * scope. It expands to an assignment followed by an if/else (not a
 * do { } while (0)), so it is valid with or without a trailing semicolon but
 * must not be used as the unbraced body of an if/else/loop.
 */
#define SETUP_ALWAYS_CHECK_STEP(test_function, this_check_err_stage)      \
    status = (test_function);                                               \
    if (err_stage != (this_check_err_stage))                               \
    {                                                                       \
        PSA_ASSERT(status);                                               \
    }                                                                       \
    else                                                                    \
    {                                                                       \
        TEST_EQUAL(status, expected_error);                               \
        goto exit;                                                          \
    }

/*
 * This check is used for failures that are injected at code level. The only
 * relevant input parameter in this case is the stage at which the error
 * should be injected; `test_function` is evaluated only at that stage.
 * The check is conditional because, once the error is triggered, the PAKE
 * operation's state is compromised and the setup function cannot proceed
 * further. As a consequence the test is terminated (goto exit).
 * The test succeeds if the returned error is exactly the expected one,
 * otherwise it fails.
 * Requires `err_stage`, `expected_error` and an `exit` label in scope.
 */
#define SETUP_CONDITIONAL_CHECK_STEP(test_function, this_check_err_stage) \
    if (err_stage == (this_check_err_stage))                               \
    {                                                                       \
        TEST_EQUAL((test_function), expected_error);                      \
        goto exit;                                                          \
    }
/* END_HEADER */

/* BEGIN_DEPENDENCIES
 * depends_on:MBEDTLS_PSA_CRYPTO_C
 * END_DEPENDENCIES
 */

/* BEGIN_CASE depends_on:PSA_WANT_ALG_JPAKE */
void ecjpake_setup(int alg_arg, int key_type_pw_arg, int key_usage_pw_arg,
                   int primitive_arg, int hash_arg, char *user_arg, char *peer_arg,
                   int test_input,
                   int err_stage_arg,
                   int expected_error_arg)
{
    /*
     * Exercise the PAKE setup sequence (setup, set_role, set_user, set_peer,
     * set_password_key) followed by the first input (test_input != 0) or
     * output (test_input == 0) step, failing at the stage err_stage_arg:
     * - ERR_IN_* stages: the .data parameters are expected to make that call
     *   return expected_error_arg;
     * - ERR_INJECT_* / ERR_DUPLICATE_* stages: the test itself performs an
     *   invalid call (uninitialized access, duplicate call, unknown or wrong
     *   first step, wrong buffer size, ...) and checks that it returns
     *   expected_error_arg.
     */
    psa_pake_cipher_suite_t cipher_suite = psa_pake_cipher_suite_init();
    psa_pake_operation_t operation = psa_pake_operation_init();
    psa_algorithm_t alg = alg_arg;
    psa_pake_primitive_t primitive = primitive_arg;
    psa_key_type_t key_type_pw = key_type_pw_arg;
    psa_key_usage_t key_usage_pw = key_usage_pw_arg;
    psa_algorithm_t hash_alg = hash_arg;
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    ecjpake_error_stage_t err_stage = err_stage_arg;
    psa_status_t expected_error = expected_error_arg;
    psa_status_t status;
    unsigned char *output_buffer = NULL;
    size_t output_len = 0;
    /* sizeof(password) includes the terminating NUL (5 bytes imported). */
    const uint8_t password[] = "abcd";
    uint8_t *user = (uint8_t*)user_arg;
    uint8_t *peer = (uint8_t*)peer_arg;
    size_t user_len = strlen(user_arg);
    size_t peer_len = strlen(peer_arg);

    psa_key_derivation_operation_t key_derivation =
        PSA_KEY_DERIVATION_OPERATION_INIT;

    PSA_INIT();

    size_t buf_size = PSA_PAKE_OUTPUT_SIZE(alg, primitive_arg,
                                           PSA_PAKE_STEP_KEY_SHARE);
    ASSERT_ALLOC(output_buffer, buf_size);

    psa_set_key_usage_flags(&attributes, key_usage_pw);
    psa_set_key_algorithm(&attributes, alg);
    psa_set_key_type(&attributes, key_type_pw);
    PSA_ASSERT(psa_import_key(&attributes, password, sizeof(password),
                              &key));

    psa_pake_cs_set_algorithm(&cipher_suite, alg);
    psa_pake_cs_set_primitive(&cipher_suite, primitive);
    psa_pake_cs_set_hash(&cipher_suite, hash_alg);

    PSA_ASSERT(psa_pake_abort(&operation));

    /* Every API call on an operation that was never set up must fail. */
    if (err_stage == ERR_INJECT_UNINITIALIZED_ACCESS) {
        TEST_EQUAL(psa_pake_set_user(&operation, user, user_len),
                   expected_error);
        TEST_EQUAL(psa_pake_set_peer(&operation, peer, peer_len),
                   expected_error);
        TEST_EQUAL(psa_pake_set_password_key(&operation, key),
                   expected_error);
        TEST_EQUAL(psa_pake_set_role(&operation, PSA_PAKE_ROLE_SERVER),
                   expected_error);
        TEST_EQUAL(psa_pake_output(&operation, PSA_PAKE_STEP_KEY_SHARE,
                                   output_buffer, 0, &output_len),
                   expected_error);
        TEST_EQUAL(psa_pake_input(&operation, PSA_PAKE_STEP_KEY_SHARE,
                                  output_buffer, 0),
                   expected_error);
        TEST_EQUAL(psa_pake_get_implicit_key(&operation, &key_derivation),
                   expected_error);
        goto exit;
    }

    SETUP_ALWAYS_CHECK_STEP(psa_pake_setup(&operation, &cipher_suite),
                            ERR_IN_SETUP);

    SETUP_CONDITIONAL_CHECK_STEP(psa_pake_setup(&operation, &cipher_suite),
                                 ERR_INJECT_DUPLICATE_SETUP);

    SETUP_CONDITIONAL_CHECK_STEP(psa_pake_set_role(&operation, PSA_PAKE_ROLE_SERVER),
                                 ERR_INJECT_SET_ROLE);

    SETUP_ALWAYS_CHECK_STEP(psa_pake_set_role(&operation, PSA_PAKE_ROLE_NONE),
                            ERR_IN_SET_ROLE);

    SETUP_ALWAYS_CHECK_STEP(psa_pake_set_user(&operation, user, user_len),
                            ERR_IN_SET_USER);

    SETUP_ALWAYS_CHECK_STEP(psa_pake_set_peer(&operation, peer, peer_len),
                            ERR_IN_SET_PEER);

    SETUP_CONDITIONAL_CHECK_STEP(psa_pake_set_user(&operation, user, user_len),
                                 ERR_DUPLICATE_SET_USER);

    SETUP_CONDITIONAL_CHECK_STEP(psa_pake_set_peer(&operation, peer, peer_len),
                                 ERR_DUPLICATE_SET_PEER);

    SETUP_ALWAYS_CHECK_STEP(psa_pake_set_password_key(&operation, key),
                            ERR_IN_SET_PASSWORD_KEY);

    const size_t size_key_share = PSA_PAKE_INPUT_SIZE(alg, primitive,
                                                      PSA_PAKE_STEP_KEY_SHARE);
    const size_t size_zk_public = PSA_PAKE_INPUT_SIZE(alg, primitive,
                                                      PSA_PAKE_STEP_ZK_PUBLIC);
    const size_t size_zk_proof = PSA_PAKE_INPUT_SIZE(alg, primitive,
                                                     PSA_PAKE_STEP_ZK_PROOF);

    if (test_input) {
        SETUP_CONDITIONAL_CHECK_STEP(psa_pake_input(&operation,
                                                    PSA_PAKE_STEP_ZK_PROOF,
                                                    output_buffer, 0),
                                     ERR_INJECT_EMPTY_IO_BUFFER);

        SETUP_CONDITIONAL_CHECK_STEP(psa_pake_input(&operation,
                                                    PSA_PAKE_STEP_ZK_PROOF + 10,
                                                    output_buffer, size_zk_proof),
                                     ERR_INJECT_UNKNOWN_STEP);

        SETUP_CONDITIONAL_CHECK_STEP(psa_pake_input(&operation,
                                                    PSA_PAKE_STEP_ZK_PROOF,
                                                    output_buffer, size_zk_proof),
                                     ERR_INJECT_INVALID_FIRST_STEP);

        SETUP_ALWAYS_CHECK_STEP(psa_pake_input(&operation,
                                               PSA_PAKE_STEP_KEY_SHARE,
                                               output_buffer, size_key_share),
                                ERR_IN_INPUT);

        SETUP_CONDITIONAL_CHECK_STEP(psa_pake_input(&operation,
                                                    PSA_PAKE_STEP_ZK_PUBLIC,
                                                    output_buffer, size_zk_public + 1),
                                     ERR_INJECT_WRONG_BUFFER_SIZE);

        /* A failed call must poison the operation: the following valid call
         * (second operand of the comma expression) must fail as well. */
        SETUP_CONDITIONAL_CHECK_STEP(
            (psa_pake_input(&operation, PSA_PAKE_STEP_ZK_PUBLIC,
                            output_buffer, size_zk_public + 1),
             psa_pake_input(&operation, PSA_PAKE_STEP_ZK_PUBLIC,
                            output_buffer, size_zk_public)),
            ERR_INJECT_VALID_OPERATION_AFTER_FAILURE);
    } else {
        SETUP_CONDITIONAL_CHECK_STEP(psa_pake_output(&operation,
                                                     PSA_PAKE_STEP_ZK_PROOF,
                                                     output_buffer, 0,
                                                     &output_len),
                                     ERR_INJECT_EMPTY_IO_BUFFER);

        SETUP_CONDITIONAL_CHECK_STEP(psa_pake_output(&operation,
                                                     PSA_PAKE_STEP_ZK_PROOF + 10,
                                                     output_buffer, buf_size, &output_len),
                                     ERR_INJECT_UNKNOWN_STEP);

        SETUP_CONDITIONAL_CHECK_STEP(psa_pake_output(&operation,
                                                     PSA_PAKE_STEP_ZK_PROOF,
                                                     output_buffer, buf_size, &output_len),
                                     ERR_INJECT_INVALID_FIRST_STEP);

        SETUP_ALWAYS_CHECK_STEP(psa_pake_output(&operation,
                                                PSA_PAKE_STEP_KEY_SHARE,
                                                output_buffer, buf_size, &output_len),
                                ERR_IN_OUTPUT);

        TEST_ASSERT(output_len > 0);

        SETUP_CONDITIONAL_CHECK_STEP(psa_pake_output(&operation,
                                                     PSA_PAKE_STEP_ZK_PUBLIC,
                                                     output_buffer, size_zk_public - 1,
                                                     &output_len),
                                     ERR_INJECT_WRONG_BUFFER_SIZE);

        /* A failed call must poison the operation: the following valid call
         * (second operand of the comma expression) must fail as well. */
        SETUP_CONDITIONAL_CHECK_STEP(
            (psa_pake_output(&operation, PSA_PAKE_STEP_ZK_PUBLIC,
                             output_buffer, size_zk_public - 1, &output_len),
             psa_pake_output(&operation, PSA_PAKE_STEP_ZK_PUBLIC,
                             output_buffer, buf_size, &output_len)),
            ERR_INJECT_VALID_OPERATION_AFTER_FAILURE);
    }

exit:
    /* No PSA_ASSERT/TEST_* here: in the Mbed TLS test framework these report
     * a failure by jumping to `exit`, so a failing cleanup call would re-enter
     * this block and could loop forever. */
    psa_destroy_key(key);
    psa_pake_abort(&operation);
    mbedtls_free(output_buffer);
    PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE depends_on:PSA_WANT_ALG_JPAKE */
/*
 * Run the J-PAKE rounds between a server and a client operation with the
 * error selected by err_stage_arg injected.
 *
 * alg_arg, primitive_arg, hash_arg: cipher suite shared by both parties.
 * client_input_first: forwarded unchanged to ecjpake_do_round().
 * pw_data:            password, imported as a PSA_KEY_TYPE_PASSWORD key with
 *                     PSA_KEY_USAGE_DERIVE and given to both operations.
 * err_stage_arg:      ecjpake_error_stage_t value; ERR_NONE injects nothing.
 * expected_error_arg: forwarded to ecjpake_do_round() as the status expected
 *                     from the step where the error is injected.
 */
void ecjpake_rounds_inject(int alg_arg, int primitive_arg, int hash_arg,
                           int client_input_first,
                           data_t *pw_data,
                           int err_stage_arg,
                           int expected_error_arg)
{
    psa_pake_cipher_suite_t cipher_suite = psa_pake_cipher_suite_init();
    psa_pake_operation_t server = psa_pake_operation_init();
    psa_pake_operation_t client = psa_pake_operation_init();
    psa_algorithm_t alg = alg_arg;
    psa_algorithm_t hash_alg = hash_arg;
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    ecjpake_error_stage_t err_stage = err_stage_arg;
    const uint8_t server_id[] = PSA_JPAKE_SERVER_ID;
    const uint8_t client_id[] = PSA_JPAKE_CLIENT_ID;
    const size_t server_id_len = strlen(PSA_JPAKE_SERVER_ID);
    const size_t client_id_len = strlen(PSA_JPAKE_CLIENT_ID);

    PSA_INIT();

    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_DERIVE);
    psa_set_key_algorithm(&attributes, alg);
    psa_set_key_type(&attributes, PSA_KEY_TYPE_PASSWORD);

    PSA_ASSERT(psa_import_key(&attributes, pw_data->x, pw_data->len,
                              &key));

    psa_pake_cs_set_algorithm(&cipher_suite, alg);
    psa_pake_cs_set_primitive(&cipher_suite, primitive_arg);
    psa_pake_cs_set_hash(&cipher_suite, hash_alg);

    PSA_ASSERT(psa_pake_setup(&server, &cipher_suite));
    PSA_ASSERT(psa_pake_setup(&client, &cipher_suite));

    PSA_ASSERT(psa_pake_set_user(&server, server_id, server_id_len));
    PSA_ASSERT(psa_pake_set_peer(&server, client_id, client_id_len));
    PSA_ASSERT(psa_pake_set_user(&client, client_id, client_id_len));
    PSA_ASSERT(psa_pake_set_peer(&client, server_id, server_id_len));

    PSA_ASSERT(psa_pake_set_password_key(&server, key));
    PSA_ASSERT(psa_pake_set_password_key(&client, key));

    ecjpake_do_round(alg, primitive_arg, &server, &client,
                     client_input_first, PAKE_ROUND_ONE,
                     err_stage, expected_error_arg);

    /* An error injected in round one ends the exchange there. Round two must
     * still run when nothing was injected, or when the injected error belongs
     * to round two. Exiting on every stage other than ERR_NONE would mean the
     * ERR_INJECT_ROUND2_* stages are never exercised. */
    switch (err_stage) {
        case ERR_NONE:
        case ERR_INJECT_ROUND2_CLIENT_KEY_SHARE:
        case ERR_INJECT_ROUND2_CLIENT_ZK_PUBLIC:
        case ERR_INJECT_ROUND2_CLIENT_ZK_PROOF:
        case ERR_INJECT_ROUND2_SERVER_KEY_SHARE:
        case ERR_INJECT_ROUND2_SERVER_ZK_PUBLIC:
        case ERR_INJECT_ROUND2_SERVER_ZK_PROOF:
            break;
        default:
            goto exit;
    }

    ecjpake_do_round(alg, primitive_arg, &server, &client,
                     client_input_first, PAKE_ROUND_TWO,
                     err_stage, expected_error_arg);

exit:
    psa_destroy_key(key);
    psa_pake_abort(&server);
    psa_pake_abort(&client);
    PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE depends_on:PSA_WANT_ALG_JPAKE */
/*
 * Run a J-PAKE exchange between a server and a client operation, then hand
 * each side a key derivation operation (set up for derive_alg_arg) through
 * psa_pake_get_implicit_key().
 *
 * destroy_key:   when 1, the password key is destroyed right after both
 *                operations received it, before any round is run.
 * err_stage_arg: ERR_INJECT_ANTICIPATE_KEY_DERIVATION_1 requests the
 *                implicit key before round one, and
 *                ERR_INJECT_ANTICIPATE_KEY_DERIVATION_2 requests it before
 *                round two. Both must yield PSA_ERROR_BAD_STATE. Any other
 *                value runs both rounds and expects the key to be delivered.
 */
void ecjpake_rounds(int alg_arg, int primitive_arg, int hash_arg,
                    int derive_alg_arg, data_t *pw_data,
                    int client_input_first, int destroy_key,
                    int err_stage_arg)
{
    const uint8_t server_id[] = PSA_JPAKE_SERVER_ID;
    const uint8_t client_id[] = PSA_JPAKE_CLIENT_ID;
    const size_t server_id_len = strlen(PSA_JPAKE_SERVER_ID);
    const size_t client_id_len = strlen(PSA_JPAKE_CLIENT_ID);
    const psa_algorithm_t alg = alg_arg;
    const psa_algorithm_t hash_alg = hash_arg;
    const psa_algorithm_t derive_alg = derive_alg_arg;
    const ecjpake_error_stage_t err_stage = err_stage_arg;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_pake_cipher_suite_t cipher_suite = psa_pake_cipher_suite_init();
    psa_pake_operation_t server = psa_pake_operation_init();
    psa_pake_operation_t client = psa_pake_operation_init();
    psa_key_derivation_operation_t server_derive =
        PSA_KEY_DERIVATION_OPERATION_INIT;
    psa_key_derivation_operation_t client_derive =
        PSA_KEY_DERIVATION_OPERATION_INIT;

    PSA_INIT();

    /* Password key shared by both parties. */
    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_DERIVE);
    psa_set_key_algorithm(&attributes, alg);
    psa_set_key_type(&attributes, PSA_KEY_TYPE_PASSWORD);
    PSA_ASSERT(psa_import_key(&attributes, pw_data->x, pw_data->len, &key));

    psa_pake_cs_set_algorithm(&cipher_suite, alg);
    psa_pake_cs_set_primitive(&cipher_suite, primitive_arg);
    psa_pake_cs_set_hash(&cipher_suite, hash_alg);

    /* Derivation operations that will receive the shared key. TLS 1.2 PRF
     * and PSK-to-MS derivations are given an empty seed first. */
    PSA_ASSERT(psa_key_derivation_setup(&server_derive, derive_alg));
    PSA_ASSERT(psa_key_derivation_setup(&client_derive, derive_alg));
    if (PSA_ALG_IS_TLS12_PRF(derive_alg) || PSA_ALG_IS_TLS12_PSK_TO_MS(derive_alg)) {
        PSA_ASSERT(psa_key_derivation_input_bytes(&server_derive,
                                                  PSA_KEY_DERIVATION_INPUT_SEED,
                                                  (const uint8_t *) "", 0));
        PSA_ASSERT(psa_key_derivation_input_bytes(&client_derive,
                                                  PSA_KEY_DERIVATION_INPUT_SEED,
                                                  (const uint8_t *) "", 0));
    }

    /* Each side knows itself as "user" and the other side as "peer". */
    PSA_ASSERT(psa_pake_setup(&server, &cipher_suite));
    PSA_ASSERT(psa_pake_setup(&client, &cipher_suite));
    PSA_ASSERT(psa_pake_set_user(&server, server_id, server_id_len));
    PSA_ASSERT(psa_pake_set_peer(&server, client_id, client_id_len));
    PSA_ASSERT(psa_pake_set_user(&client, client_id, client_id_len));
    PSA_ASSERT(psa_pake_set_peer(&client, server_id, server_id_len));
    PSA_ASSERT(psa_pake_set_password_key(&server, key));
    PSA_ASSERT(psa_pake_set_password_key(&client, key));

    /* Optionally drop the password key now that both operations hold it;
     * the rounds below are still run without any injected error. */
    if (destroy_key == 1) {
        psa_destroy_key(key);
    }

    /* Play as many rounds as the requested stage allows. */
    if (err_stage != ERR_INJECT_ANTICIPATE_KEY_DERIVATION_1) {
        ecjpake_do_round(alg, primitive_arg, &server, &client,
                         client_input_first, PAKE_ROUND_ONE,
                         ERR_NONE, PSA_SUCCESS);

        if (err_stage != ERR_INJECT_ANTICIPATE_KEY_DERIVATION_2) {
            ecjpake_do_round(alg, primitive_arg, &server, &client,
                             client_input_first, PAKE_ROUND_TWO,
                             ERR_NONE, PSA_SUCCESS);
        }
    }

    if (err_stage == ERR_INJECT_ANTICIPATE_KEY_DERIVATION_1 ||
        err_stage == ERR_INJECT_ANTICIPATE_KEY_DERIVATION_2) {
        /* Handshake not finished: no shared key may be released yet. */
        TEST_EQUAL(psa_pake_get_implicit_key(&server, &server_derive),
                   PSA_ERROR_BAD_STATE);
        TEST_EQUAL(psa_pake_get_implicit_key(&client, &client_derive),
                   PSA_ERROR_BAD_STATE);
    } else {
        PSA_ASSERT(psa_pake_get_implicit_key(&server, &server_derive));
        PSA_ASSERT(psa_pake_get_implicit_key(&client, &client_derive));
    }

exit:
    psa_key_derivation_abort(&server_derive);
    psa_key_derivation_abort(&client_derive);
    psa_destroy_key(key);
    psa_pake_abort(&server);
    psa_pake_abort(&client);
    PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE */
/*
 * Sanity-check the PAKE size macros for J-PAKE over a 256-bit ECC primitive
 * of the PSA_ECC_FAMILY_SECP_R1 family:
 *   - per-step output sizes match the public-key export size (KEY_SHARE,
 *     ZK_PUBLIC) or the curve size in bytes (ZK_PROOF);
 *   - per-step input sizes equal the corresponding output sizes;
 *   - every per-step size stays within PSA_PAKE_OUTPUT_MAX_SIZE /
 *     PSA_PAKE_INPUT_MAX_SIZE.
 */
void ecjpake_size_macros()
{
    const size_t bits = 256;
    const psa_algorithm_t alg = PSA_ALG_JPAKE;
    const psa_key_type_t key_type =
        PSA_KEY_TYPE_ECC_KEY_PAIR(PSA_ECC_FAMILY_SECP_R1);
    const psa_pake_primitive_t prim =
        PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC,
                           PSA_ECC_FAMILY_SECP_R1, bits);

    /* Step encodings follow the PSA PAKE extension:
     * https://armmbed.github.io/mbed-crypto/1.1_PAKE_Extension.0-bet.0/html/pake.html#pake-step-types
     * KEY_SHARE and ZK_PUBLIC are sized like an exported public key, while
     * ZK_PROOF has the size of the curve in bytes. */
    TEST_EQUAL(PSA_PAKE_OUTPUT_SIZE(alg, prim, PSA_PAKE_STEP_KEY_SHARE),
               PSA_EXPORT_PUBLIC_KEY_OUTPUT_SIZE(key_type, bits));
    TEST_EQUAL(PSA_PAKE_OUTPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PUBLIC),
               PSA_EXPORT_PUBLIC_KEY_OUTPUT_SIZE(key_type, bits));
    TEST_EQUAL(PSA_PAKE_OUTPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PROOF),
               PSA_BITS_TO_BYTES(bits));

    /* J-PAKE is symmetric: what one side outputs, the other side inputs. */
    TEST_EQUAL(PSA_PAKE_OUTPUT_SIZE(alg, prim, PSA_PAKE_STEP_KEY_SHARE),
               PSA_PAKE_INPUT_SIZE(alg, prim, PSA_PAKE_STEP_KEY_SHARE));
    TEST_EQUAL(PSA_PAKE_OUTPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PUBLIC),
               PSA_PAKE_INPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PUBLIC));
    TEST_EQUAL(PSA_PAKE_OUTPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PROOF),
               PSA_PAKE_INPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PROOF));

    /* The generic maxima must cover every step; this has to remain true
     * whatever PAKE algorithms get added later. */
    TEST_LE_U(PSA_PAKE_OUTPUT_SIZE(alg, prim, PSA_PAKE_STEP_KEY_SHARE),
              PSA_PAKE_OUTPUT_MAX_SIZE);
    TEST_LE_U(PSA_PAKE_OUTPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PUBLIC),
              PSA_PAKE_OUTPUT_MAX_SIZE);
    TEST_LE_U(PSA_PAKE_OUTPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PROOF),
              PSA_PAKE_OUTPUT_MAX_SIZE);
    TEST_LE_U(PSA_PAKE_INPUT_SIZE(alg, prim, PSA_PAKE_STEP_KEY_SHARE),
              PSA_PAKE_INPUT_MAX_SIZE);
    TEST_LE_U(PSA_PAKE_INPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PUBLIC),
              PSA_PAKE_INPUT_MAX_SIZE);
    TEST_LE_U(PSA_PAKE_INPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PROOF),
              PSA_PAKE_INPUT_MAX_SIZE);
}
/* END_CASE */

/* BEGIN_CASE depends_on:PSA_WANT_ALG_JPAKE */
/*
 * Check the driver-side password getters on the inputs of a PAKE operation.
 * Before psa_pake_set_password_key() both getters report PSA_ERROR_BAD_STATE.
 * Afterwards:
 *   - the length getter returns the password length;
 *   - a buffer one byte short yields PSA_ERROR_BUFFER_TOO_SMALL;
 *   - an exact-size buffer receives the password bytes.
 */
void pake_input_getters_password()
{
    psa_pake_cipher_suite_t cipher_suite = psa_pake_cipher_suite_init();
    psa_pake_operation_t operation = psa_pake_operation_init();
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    const char *password = "password";
    /* Output buffer, large enough for the test password above. */
    uint8_t password_ret[20] = { 0 };
    size_t password_len_ret = 0;
    size_t buffer_len_ret = 0;

    psa_pake_primitive_t primitive = PSA_PAKE_PRIMITIVE(
        PSA_PAKE_PRIMITIVE_TYPE_ECC,
        PSA_ECC_FAMILY_SECP_R1, 256);

    PSA_INIT();

    psa_pake_cs_set_algorithm(&cipher_suite, PSA_ALG_JPAKE);
    psa_pake_cs_set_primitive(&cipher_suite, primitive);
    psa_pake_cs_set_hash(&cipher_suite, PSA_ALG_SHA_256);

    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_DERIVE);
    psa_set_key_algorithm(&attributes, PSA_ALG_JPAKE);
    psa_set_key_type(&attributes, PSA_KEY_TYPE_PASSWORD);

    PSA_ASSERT(psa_pake_setup(&operation, &cipher_suite));

    PSA_ASSERT(psa_import_key(&attributes, (uint8_t *) password, strlen(password), &key));

    /* No password key has been set on the operation yet. */
    TEST_EQUAL(psa_crypto_driver_pake_get_password(&operation.data.inputs,
                                                   password_ret,
                                                   10, &buffer_len_ret),
               PSA_ERROR_BAD_STATE);

    TEST_EQUAL(psa_crypto_driver_pake_get_password_len(&operation.data.inputs, &password_len_ret),
               PSA_ERROR_BAD_STATE);

    PSA_ASSERT(psa_pake_set_password_key(&operation, key));

    TEST_EQUAL(psa_crypto_driver_pake_get_password_len(&operation.data.inputs, &password_len_ret),
               PSA_SUCCESS);

    TEST_EQUAL(password_len_ret, strlen(password));

    TEST_EQUAL(psa_crypto_driver_pake_get_password(&operation.data.inputs,
                                                   password_ret,
                                                   password_len_ret - 1,
                                                   &buffer_len_ret),
               PSA_ERROR_BUFFER_TOO_SMALL);

    TEST_EQUAL(psa_crypto_driver_pake_get_password(&operation.data.inputs,
                                                   password_ret,
                                                   password_len_ret,
                                                   &buffer_len_ret),
               PSA_SUCCESS);

    TEST_EQUAL(buffer_len_ret, strlen(password));
    /* memcmp() returns 0 on a match; compare explicitly rather than treating
     * its result as a PSA status. */
    TEST_EQUAL(memcmp(password_ret, password, buffer_len_ret), 0);

exit:
    /* Cleanup results are deliberately not asserted: the TEST_* and
     * PSA_ASSERT macros jump to this label on failure, so a failing check
     * here would loop. */
    psa_destroy_key(key);
    psa_pake_abort(&operation);
    PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE depends_on:PSA_WANT_ALG_JPAKE */
/*
 * Check the driver-side cipher-suite getter. On a freshly initialized
 * operation it must report PSA_ERROR_BAD_STATE. Once psa_pake_setup() has
 * run, it must return the exact suite that was passed to setup.
 */
void pake_input_getters_cipher_suite()
{
    psa_pake_cipher_suite_t cipher_suite = psa_pake_cipher_suite_init();
    psa_pake_operation_t operation = psa_pake_operation_init();
    psa_pake_cipher_suite_t cipher_suite_ret = psa_pake_cipher_suite_init();

    psa_pake_primitive_t primitive = PSA_PAKE_PRIMITIVE(
        PSA_PAKE_PRIMITIVE_TYPE_ECC,
        PSA_ECC_FAMILY_SECP_R1, 256);

    PSA_INIT();

    psa_pake_cs_set_algorithm(&cipher_suite, PSA_ALG_JPAKE);
    psa_pake_cs_set_primitive(&cipher_suite, primitive);
    psa_pake_cs_set_hash(&cipher_suite, PSA_ALG_SHA_256);

    TEST_EQUAL(psa_crypto_driver_pake_get_cipher_suite(&operation.data.inputs, &cipher_suite_ret),
               PSA_ERROR_BAD_STATE);

    PSA_ASSERT(psa_pake_setup(&operation, &cipher_suite));

    TEST_EQUAL(psa_crypto_driver_pake_get_cipher_suite(&operation.data.inputs, &cipher_suite_ret),
               PSA_SUCCESS);

    /* memcmp() returns 0 on a match; compare explicitly rather than treating
     * its result as a PSA status. */
    TEST_EQUAL(memcmp(&cipher_suite_ret, &cipher_suite, sizeof(cipher_suite)), 0);

exit:
    /* Not asserted: a failing check jumps to this label and would loop. */
    psa_pake_abort(&operation);
    PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE depends_on:PSA_WANT_ALG_JPAKE */
/*
 * Check the driver-side role getter. Right after psa_pake_setup() it must
 * report PSA_ERROR_BAD_STATE. Once a role is stored in the operation inputs,
 * it must return that role.
 */
void pake_input_getters_role()
{
    psa_pake_cipher_suite_t cipher_suite = psa_pake_cipher_suite_init();
    psa_pake_operation_t operation = psa_pake_operation_init();
    psa_pake_role_t role_ret = PSA_PAKE_ROLE_NONE;

    psa_pake_primitive_t primitive = PSA_PAKE_PRIMITIVE(
        PSA_PAKE_PRIMITIVE_TYPE_ECC,
        PSA_ECC_FAMILY_SECP_R1, 256);

    PSA_INIT();

    psa_pake_cs_set_algorithm(&cipher_suite, PSA_ALG_JPAKE);
    psa_pake_cs_set_primitive(&cipher_suite, primitive);
    psa_pake_cs_set_hash(&cipher_suite, PSA_ALG_SHA_256);

    PSA_ASSERT(psa_pake_setup(&operation, &cipher_suite));

    TEST_EQUAL(psa_crypto_driver_pake_get_role(&operation.data.inputs, &role_ret),
               PSA_ERROR_BAD_STATE);

    /* Role can not be set directly using psa_pake_set_role(). It is set by the core
       based on given user/peer. Simulate that Role is already set. */
    operation.data.inputs.role = PSA_PAKE_ROLE_SERVER;
    TEST_EQUAL(psa_crypto_driver_pake_get_role(&operation.data.inputs, &role_ret),
               PSA_SUCCESS);

    TEST_EQUAL(role_ret, PSA_PAKE_ROLE_SERVER);

exit:
    /* Not asserted: a failing check jumps to this label and would loop. */
    psa_pake_abort(&operation);
    PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE depends_on:PSA_WANT_ALG_JPAKE:PSA_WANT_ALG_SHA_256 */
/*
 * Check the driver-side user-ID getters. Before psa_pake_set_user() both
 * report PSA_ERROR_BAD_STATE. Afterwards:
 *   - the length getter returns the user ID length;
 *   - a buffer one byte short yields PSA_ERROR_BUFFER_TOO_SMALL;
 *   - an exact-size buffer receives the user ID bytes.
 */
void pake_input_getters_user()
{
    psa_pake_cipher_suite_t cipher_suite = psa_pake_cipher_suite_init();
    psa_pake_operation_t operation = psa_pake_operation_init();
    const uint8_t user[] = "server";
    const size_t user_len = strlen("server");
    /* Output buffer, large enough for the test user ID above. */
    uint8_t user_ret[20] = { 0 };
    size_t user_len_ret = 0;
    size_t buffer_len_ret = 0;

    psa_pake_primitive_t primitive = PSA_PAKE_PRIMITIVE(
        PSA_PAKE_PRIMITIVE_TYPE_ECC,
        PSA_ECC_FAMILY_SECP_R1, 256);

    PSA_INIT();

    psa_pake_cs_set_algorithm(&cipher_suite, PSA_ALG_JPAKE);
    psa_pake_cs_set_primitive(&cipher_suite, primitive);
    psa_pake_cs_set_hash(&cipher_suite, PSA_ALG_SHA_256);

    PSA_ASSERT(psa_pake_setup(&operation, &cipher_suite));

    /* No user ID has been set on the operation yet. */
    TEST_EQUAL(psa_crypto_driver_pake_get_user(&operation.data.inputs,
                                               user_ret,
                                               10, &buffer_len_ret),
               PSA_ERROR_BAD_STATE);

    TEST_EQUAL(psa_crypto_driver_pake_get_user_len(&operation.data.inputs, &user_len_ret),
               PSA_ERROR_BAD_STATE);

    PSA_ASSERT(psa_pake_set_user(&operation, user, user_len));

    TEST_EQUAL(psa_crypto_driver_pake_get_user_len(&operation.data.inputs, &user_len_ret),
               PSA_SUCCESS);

    TEST_EQUAL(user_len_ret, user_len);

    TEST_EQUAL(psa_crypto_driver_pake_get_user(&operation.data.inputs,
                                               user_ret,
                                               user_len_ret - 1,
                                               &buffer_len_ret),
               PSA_ERROR_BUFFER_TOO_SMALL);

    TEST_EQUAL(psa_crypto_driver_pake_get_user(&operation.data.inputs,
                                               user_ret,
                                               user_len_ret,
                                               &buffer_len_ret),
               PSA_SUCCESS);

    TEST_EQUAL(buffer_len_ret, user_len);
    /* memcmp() returns 0 on a match; compare explicitly rather than treating
     * its result as a PSA status. */
    TEST_EQUAL(memcmp(user_ret, user, buffer_len_ret), 0);

exit:
    /* Not asserted: a failing check jumps to this label and would loop. */
    psa_pake_abort(&operation);
    PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE depends_on:PSA_WANT_ALG_JPAKE:PSA_WANT_ALG_SHA_256 */
/*
 * Check the driver-side peer-ID getters. Before psa_pake_set_peer() both
 * report PSA_ERROR_BAD_STATE. Afterwards:
 *   - the length getter returns the peer ID length;
 *   - a buffer one byte short yields PSA_ERROR_BUFFER_TOO_SMALL;
 *   - an exact-size buffer receives the peer ID bytes.
 */
void pake_input_getters_peer()
{
    psa_pake_cipher_suite_t cipher_suite = psa_pake_cipher_suite_init();
    psa_pake_operation_t operation = psa_pake_operation_init();
    const uint8_t peer[] = "server";
    const size_t peer_len = strlen("server");
    /* Output buffer, large enough for the test peer ID above. */
    uint8_t peer_ret[20] = { 0 };
    size_t peer_len_ret = 0;
    size_t buffer_len_ret = 0;

    psa_pake_primitive_t primitive = PSA_PAKE_PRIMITIVE(
        PSA_PAKE_PRIMITIVE_TYPE_ECC,
        PSA_ECC_FAMILY_SECP_R1, 256);

    PSA_INIT();

    psa_pake_cs_set_algorithm(&cipher_suite, PSA_ALG_JPAKE);
    psa_pake_cs_set_primitive(&cipher_suite, primitive);
    psa_pake_cs_set_hash(&cipher_suite, PSA_ALG_SHA_256);

    PSA_ASSERT(psa_pake_setup(&operation, &cipher_suite));

    /* No peer ID has been set on the operation yet. */
    TEST_EQUAL(psa_crypto_driver_pake_get_peer(&operation.data.inputs,
                                               peer_ret,
                                               10, &buffer_len_ret),
               PSA_ERROR_BAD_STATE);

    TEST_EQUAL(psa_crypto_driver_pake_get_peer_len(&operation.data.inputs, &peer_len_ret),
               PSA_ERROR_BAD_STATE);

    PSA_ASSERT(psa_pake_set_peer(&operation, peer, peer_len));

    TEST_EQUAL(psa_crypto_driver_pake_get_peer_len(&operation.data.inputs, &peer_len_ret),
               PSA_SUCCESS);

    TEST_EQUAL(peer_len_ret, peer_len);

    TEST_EQUAL(psa_crypto_driver_pake_get_peer(&operation.data.inputs,
                                               peer_ret,
                                               peer_len_ret - 1,
                                               &buffer_len_ret),
               PSA_ERROR_BUFFER_TOO_SMALL);

    TEST_EQUAL(psa_crypto_driver_pake_get_peer(&operation.data.inputs,
                                               peer_ret,
                                               peer_len_ret,
                                               &buffer_len_ret),
               PSA_SUCCESS);

    TEST_EQUAL(buffer_len_ret, peer_len);
    /* memcmp() returns 0 on a match; compare explicitly rather than treating
     * its result as a PSA status. */
    TEST_EQUAL(memcmp(peer_ret, peer, buffer_len_ret), 0);

exit:
    /* Not asserted: a failing check jumps to this label and would loop. */
    psa_pake_abort(&operation);
    PSA_DONE();
}
/* END_CASE */