/* BEGIN_HEADER */
#include "test/drivers/test_driver.h"

/* Auxiliary variables for pake tests.
 * They are global (rather than static) to silence compiler warnings in
 * configurations where they are unused.
 *
 * pake_expected_hit_count: expected value of
 *     mbedtls_test_driver_pake_hooks.hits.total; ecjpake_do_round() compares
 *     the driver hit counter against it after each PAKE call.
 * pake_in_driver: when nonzero, ecjpake_do_round() advances
 *     pake_expected_hit_count after each comparison, i.e. each PAKE call is
 *     expected to add one driver hit. */
size_t pake_expected_hit_count = 0;
int pake_in_driver = 0;
#if defined(PSA_WANT_ALG_JPAKE) && defined(PSA_WANT_KEY_TYPE_ECC_KEY_PAIR) && \
    defined(PSA_WANT_ECC_SECP_R1_256) && defined(PSA_WANT_ALG_SHA_256)
/* Run one round of an EC J-PAKE exchange between a server and a client
 * operation, checking every step.
 *
 * Round 1: each side outputs two (KEY_SHARE, ZK_PUBLIC, ZK_PROOF) triples,
 * which the other side then inputs.
 * Round 2: each side outputs a single (KEY_SHARE, ZK_PUBLIC, ZK_PROOF)
 * triple, which the other side then inputs.
 *
 * After every psa_pake_output()/psa_pake_input() call, the test driver's
 * total PAKE hit counter is compared with pake_expected_hit_count. That
 * expectation is advanced after each comparison only when pake_in_driver
 * is nonzero.
 *
 * \param alg                 PAKE algorithm (used to compute output sizes).
 * \param primitive           PAKE primitive (used to compute output sizes).
 * \param server              Server-side operation (presumably already set
 *                            up by the caller).
 * \param client              Client-side operation (presumably already set
 *                            up by the caller).
 * \param client_input_first  1: the client inputs the server's messages
 *                            before producing its own output;
 *                            0: the client inputs them afterwards.
 * \param round               1 or 2. Any other value fails the test.
 */
static void ecjpake_do_round(psa_algorithm_t alg, unsigned int primitive,
                             psa_pake_operation_t *server,
                             psa_pake_operation_t *client,
                             int client_input_first,
                             int round)
{
    unsigned char *buffer0 = NULL, *buffer1 = NULL;
    /* Each buffer holds one side's messages for a round: at most two
     * (KEY_SHARE, ZK_PUBLIC, ZK_PROOF) triples in round 1. Every output call
     * below must be bounded by this allocated size, not by a fixed constant,
     * otherwise the library may be told it can write past the allocation. */
    size_t buffer_length = (
        PSA_PAKE_OUTPUT_SIZE(alg, primitive, PSA_PAKE_STEP_KEY_SHARE) +
        PSA_PAKE_OUTPUT_SIZE(alg, primitive, PSA_PAKE_STEP_ZK_PUBLIC) +
        PSA_PAKE_OUTPUT_SIZE(alg, primitive, PSA_PAKE_STEP_ZK_PROOF)) * 2;
    /* The output should be exactly this size according to the spec */
    const size_t expected_size_key_share =
        PSA_PAKE_OUTPUT_SIZE(alg, primitive, PSA_PAKE_STEP_KEY_SHARE);
    /* The output should be exactly this size according to the spec */
    const size_t expected_size_zk_public =
        PSA_PAKE_OUTPUT_SIZE(alg, primitive, PSA_PAKE_STEP_ZK_PUBLIC);
    /* The output can be smaller: the spec allows stripping leading zeroes */
    const size_t max_expected_size_zk_proof =
        PSA_PAKE_OUTPUT_SIZE(alg, primitive, PSA_PAKE_STEP_ZK_PROOF);
    size_t buffer0_off = 0;
    size_t buffer1_off = 0;
    size_t s_g1_len, s_g2_len, s_a_len;
    size_t s_g1_off, s_g2_off, s_a_off;
    size_t s_x1_pk_len, s_x2_pk_len, s_x2s_pk_len;
    size_t s_x1_pk_off, s_x2_pk_off, s_x2s_pk_off;
    size_t s_x1_pr_len, s_x2_pr_len, s_x2s_pr_len;
    size_t s_x1_pr_off, s_x2_pr_off, s_x2s_pr_off;
    size_t c_g1_len, c_g2_len, c_a_len;
    size_t c_g1_off, c_g2_off, c_a_off;
    size_t c_x1_pk_len, c_x2_pk_len, c_x2s_pk_len;
    size_t c_x1_pk_off, c_x2_pk_off, c_x2s_pk_off;
    size_t c_x1_pr_len, c_x2_pr_len, c_x2s_pr_len;
    size_t c_x1_pr_off, c_x2_pr_off, c_x2s_pr_off;
    psa_status_t status;

    ASSERT_ALLOC(buffer0, buffer_length);
    ASSERT_ALLOC(buffer1, buffer_length);

    switch (round) {
        case 1:
            /* Server first round Output */
            PSA_ASSERT(psa_pake_output(server, PSA_PAKE_STEP_KEY_SHARE,
                                       buffer0 + buffer0_off,
                                       buffer_length - buffer0_off, &s_g1_len));
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                       pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
            TEST_EQUAL(s_g1_len, expected_size_key_share);
            s_g1_off = buffer0_off;
            buffer0_off += s_g1_len;
            PSA_ASSERT(psa_pake_output(server, PSA_PAKE_STEP_ZK_PUBLIC,
                                       buffer0 + buffer0_off,
                                       buffer_length - buffer0_off, &s_x1_pk_len));
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                       pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
            TEST_EQUAL(s_x1_pk_len, expected_size_zk_public);
            s_x1_pk_off = buffer0_off;
            buffer0_off += s_x1_pk_len;
            PSA_ASSERT(psa_pake_output(server, PSA_PAKE_STEP_ZK_PROOF,
                                       buffer0 + buffer0_off,
                                       buffer_length - buffer0_off, &s_x1_pr_len));
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                       pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
            TEST_LE_U(s_x1_pr_len, max_expected_size_zk_proof);
            s_x1_pr_off = buffer0_off;
            buffer0_off += s_x1_pr_len;
            PSA_ASSERT(psa_pake_output(server, PSA_PAKE_STEP_KEY_SHARE,
                                       buffer0 + buffer0_off,
                                       buffer_length - buffer0_off, &s_g2_len));
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                       pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
            TEST_EQUAL(s_g2_len, expected_size_key_share);
            s_g2_off = buffer0_off;
            buffer0_off += s_g2_len;
            PSA_ASSERT(psa_pake_output(server, PSA_PAKE_STEP_ZK_PUBLIC,
                                       buffer0 + buffer0_off,
                                       buffer_length - buffer0_off, &s_x2_pk_len));
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                       pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
            TEST_EQUAL(s_x2_pk_len, expected_size_zk_public);
            s_x2_pk_off = buffer0_off;
            buffer0_off += s_x2_pk_len;
            PSA_ASSERT(psa_pake_output(server, PSA_PAKE_STEP_ZK_PROOF,
                                       buffer0 + buffer0_off,
                                       buffer_length - buffer0_off, &s_x2_pr_len));
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                       pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
            TEST_LE_U(s_x2_pr_len, max_expected_size_zk_proof);
            s_x2_pr_off = buffer0_off;
            buffer0_off += s_x2_pr_len;

            if (client_input_first == 1) {
                /* Client first round Input */
                status = psa_pake_input(client, PSA_PAKE_STEP_KEY_SHARE,
                                        buffer0 + s_g1_off, s_g1_len);
                TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                           pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
                TEST_EQUAL(status, PSA_SUCCESS);

                status = psa_pake_input(client, PSA_PAKE_STEP_ZK_PUBLIC,
                                        buffer0 + s_x1_pk_off,
                                        s_x1_pk_len);
                TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                           pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
                TEST_EQUAL(status, PSA_SUCCESS);

                status = psa_pake_input(client, PSA_PAKE_STEP_ZK_PROOF,
                                        buffer0 + s_x1_pr_off,
                                        s_x1_pr_len);
                TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                           pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
                TEST_EQUAL(status, PSA_SUCCESS);

                status = psa_pake_input(client, PSA_PAKE_STEP_KEY_SHARE,
                                        buffer0 + s_g2_off,
                                        s_g2_len);
                TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                           pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
                TEST_EQUAL(status, PSA_SUCCESS);

                status = psa_pake_input(client, PSA_PAKE_STEP_ZK_PUBLIC,
                                        buffer0 + s_x2_pk_off,
                                        s_x2_pk_len);
                TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                           pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
                TEST_EQUAL(status, PSA_SUCCESS);

                status = psa_pake_input(client, PSA_PAKE_STEP_ZK_PROOF,
                                        buffer0 + s_x2_pr_off,
                                        s_x2_pr_len);
                TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                           pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
                TEST_EQUAL(status, PSA_SUCCESS);
            }

            /* Adjust for indirect client driver setup in first pake_output call. */
            pake_expected_hit_count++;

            /* Client first round Output */
            PSA_ASSERT(psa_pake_output(client, PSA_PAKE_STEP_KEY_SHARE,
                                       buffer1 + buffer1_off,
                                       buffer_length - buffer1_off, &c_g1_len));
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                       pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
            TEST_EQUAL(c_g1_len, expected_size_key_share);
            c_g1_off = buffer1_off;
            buffer1_off += c_g1_len;
            PSA_ASSERT(psa_pake_output(client, PSA_PAKE_STEP_ZK_PUBLIC,
                                       buffer1 + buffer1_off,
                                       buffer_length - buffer1_off, &c_x1_pk_len));
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                       pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
            TEST_EQUAL(c_x1_pk_len, expected_size_zk_public);
            c_x1_pk_off = buffer1_off;
            buffer1_off += c_x1_pk_len;
            PSA_ASSERT(psa_pake_output(client, PSA_PAKE_STEP_ZK_PROOF,
                                       buffer1 + buffer1_off,
                                       buffer_length - buffer1_off, &c_x1_pr_len));
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                       pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
            TEST_LE_U(c_x1_pr_len, max_expected_size_zk_proof);
            c_x1_pr_off = buffer1_off;
            buffer1_off += c_x1_pr_len;
            PSA_ASSERT(psa_pake_output(client, PSA_PAKE_STEP_KEY_SHARE,
                                       buffer1 + buffer1_off,
                                       buffer_length - buffer1_off, &c_g2_len));
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                       pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
            TEST_EQUAL(c_g2_len, expected_size_key_share);
            c_g2_off = buffer1_off;
            buffer1_off += c_g2_len;
            PSA_ASSERT(psa_pake_output(client, PSA_PAKE_STEP_ZK_PUBLIC,
                                       buffer1 + buffer1_off,
                                       buffer_length - buffer1_off, &c_x2_pk_len));
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                       pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
            TEST_EQUAL(c_x2_pk_len, expected_size_zk_public);
            c_x2_pk_off = buffer1_off;
            buffer1_off += c_x2_pk_len;
            PSA_ASSERT(psa_pake_output(client, PSA_PAKE_STEP_ZK_PROOF,
                                       buffer1 + buffer1_off,
                                       buffer_length - buffer1_off, &c_x2_pr_len));
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                       pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
            TEST_LE_U(c_x2_pr_len, max_expected_size_zk_proof);
            c_x2_pr_off = buffer1_off;
            buffer1_off += c_x2_pr_len;

            if (client_input_first == 0) {
                /* Client first round Input */
                status = psa_pake_input(client, PSA_PAKE_STEP_KEY_SHARE,
                                        buffer0 + s_g1_off, s_g1_len);
                TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                           pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
                TEST_EQUAL(status, PSA_SUCCESS);

                status = psa_pake_input(client, PSA_PAKE_STEP_ZK_PUBLIC,
                                        buffer0 + s_x1_pk_off,
                                        s_x1_pk_len);
                TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                           pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
                TEST_EQUAL(status, PSA_SUCCESS);

                status = psa_pake_input(client, PSA_PAKE_STEP_ZK_PROOF,
                                        buffer0 + s_x1_pr_off,
                                        s_x1_pr_len);
                TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                           pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
                TEST_EQUAL(status, PSA_SUCCESS);

                status = psa_pake_input(client, PSA_PAKE_STEP_KEY_SHARE,
                                        buffer0 + s_g2_off,
                                        s_g2_len);
                TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                           pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
                TEST_EQUAL(status, PSA_SUCCESS);

                status = psa_pake_input(client, PSA_PAKE_STEP_ZK_PUBLIC,
                                        buffer0 + s_x2_pk_off,
                                        s_x2_pk_len);
                TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                           pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
                TEST_EQUAL(status, PSA_SUCCESS);

                status = psa_pake_input(client, PSA_PAKE_STEP_ZK_PROOF,
                                        buffer0 + s_x2_pr_off,
                                        s_x2_pr_len);
                TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                           pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
                TEST_EQUAL(status, PSA_SUCCESS);
            }

            /* Server first round Input */
            status = psa_pake_input(server, PSA_PAKE_STEP_KEY_SHARE,
                                    buffer1 + c_g1_off, c_g1_len);
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                       pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
            TEST_EQUAL(status, PSA_SUCCESS);

            status = psa_pake_input(server, PSA_PAKE_STEP_ZK_PUBLIC,
                                    buffer1 + c_x1_pk_off, c_x1_pk_len);
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                       pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
            TEST_EQUAL(status, PSA_SUCCESS);

            status = psa_pake_input(server, PSA_PAKE_STEP_ZK_PROOF,
                                    buffer1 + c_x1_pr_off, c_x1_pr_len);
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                       pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
            TEST_EQUAL(status, PSA_SUCCESS);

            status = psa_pake_input(server, PSA_PAKE_STEP_KEY_SHARE,
                                    buffer1 + c_g2_off, c_g2_len);
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                       pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
            TEST_EQUAL(status, PSA_SUCCESS);

            status = psa_pake_input(server, PSA_PAKE_STEP_ZK_PUBLIC,
                                    buffer1 + c_x2_pk_off, c_x2_pk_len);
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                       pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
            TEST_EQUAL(status, PSA_SUCCESS);

            status = psa_pake_input(server, PSA_PAKE_STEP_ZK_PROOF,
                                    buffer1 + c_x2_pr_off, c_x2_pr_len);
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                       pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
            TEST_EQUAL(status, PSA_SUCCESS);

            break;

        case 2:
            /* Server second round Output */
            buffer0_off = 0;

            PSA_ASSERT(psa_pake_output(server, PSA_PAKE_STEP_KEY_SHARE,
                                       buffer0 + buffer0_off,
                                       buffer_length - buffer0_off, &s_a_len));
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                       pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
            TEST_EQUAL(s_a_len, expected_size_key_share);
            s_a_off = buffer0_off;
            buffer0_off += s_a_len;
            PSA_ASSERT(psa_pake_output(server, PSA_PAKE_STEP_ZK_PUBLIC,
                                       buffer0 + buffer0_off,
                                       buffer_length - buffer0_off, &s_x2s_pk_len));
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                       pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
            TEST_EQUAL(s_x2s_pk_len, expected_size_zk_public);
            s_x2s_pk_off = buffer0_off;
            buffer0_off += s_x2s_pk_len;
            PSA_ASSERT(psa_pake_output(server, PSA_PAKE_STEP_ZK_PROOF,
                                       buffer0 + buffer0_off,
                                       buffer_length - buffer0_off, &s_x2s_pr_len));
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                       pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
            TEST_LE_U(s_x2s_pr_len, max_expected_size_zk_proof);
            s_x2s_pr_off = buffer0_off;
            buffer0_off += s_x2s_pr_len;

            if (client_input_first == 1) {
                /* Client second round Input */
                status = psa_pake_input(client, PSA_PAKE_STEP_KEY_SHARE,
                                        buffer0 + s_a_off, s_a_len);
                TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                           pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
                TEST_EQUAL(status, PSA_SUCCESS);

                status = psa_pake_input(client, PSA_PAKE_STEP_ZK_PUBLIC,
                                        buffer0 + s_x2s_pk_off,
                                        s_x2s_pk_len);
                TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                           pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
                TEST_EQUAL(status, PSA_SUCCESS);

                status = psa_pake_input(client, PSA_PAKE_STEP_ZK_PROOF,
                                        buffer0 + s_x2s_pr_off,
                                        s_x2s_pr_len);
                TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                           pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
                TEST_EQUAL(status, PSA_SUCCESS);
            }

            /* Client second round Output */
            buffer1_off = 0;

            PSA_ASSERT(psa_pake_output(client, PSA_PAKE_STEP_KEY_SHARE,
                                       buffer1 + buffer1_off,
                                       buffer_length - buffer1_off, &c_a_len));
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                       pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
            TEST_EQUAL(c_a_len, expected_size_key_share);
            c_a_off = buffer1_off;
            buffer1_off += c_a_len;
            PSA_ASSERT(psa_pake_output(client, PSA_PAKE_STEP_ZK_PUBLIC,
                                       buffer1 + buffer1_off,
                                       buffer_length - buffer1_off, &c_x2s_pk_len));
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                       pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
            TEST_EQUAL(c_x2s_pk_len, expected_size_zk_public);
            c_x2s_pk_off = buffer1_off;
            buffer1_off += c_x2s_pk_len;
            PSA_ASSERT(psa_pake_output(client, PSA_PAKE_STEP_ZK_PROOF,
                                       buffer1 + buffer1_off,
                                       buffer_length - buffer1_off, &c_x2s_pr_len));
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                       pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
            TEST_LE_U(c_x2s_pr_len, max_expected_size_zk_proof);
            c_x2s_pr_off = buffer1_off;
            buffer1_off += c_x2s_pr_len;

            if (client_input_first == 0) {
                /* Client second round Input */
                status = psa_pake_input(client, PSA_PAKE_STEP_KEY_SHARE,
                                        buffer0 + s_a_off, s_a_len);
                TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                           pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
                TEST_EQUAL(status, PSA_SUCCESS);

                status = psa_pake_input(client, PSA_PAKE_STEP_ZK_PUBLIC,
                                        buffer0 + s_x2s_pk_off,
                                        s_x2s_pk_len);
                TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                           pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
                TEST_EQUAL(status, PSA_SUCCESS);

                status = psa_pake_input(client, PSA_PAKE_STEP_ZK_PROOF,
                                        buffer0 + s_x2s_pr_off,
                                        s_x2s_pr_len);
                TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                           pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
                TEST_EQUAL(status, PSA_SUCCESS);
            }

            /* Server second round Input */
            status = psa_pake_input(server, PSA_PAKE_STEP_KEY_SHARE,
                                    buffer1 + c_a_off, c_a_len);
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                       pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
            TEST_EQUAL(status, PSA_SUCCESS);

            status = psa_pake_input(server, PSA_PAKE_STEP_ZK_PUBLIC,
                                    buffer1 + c_x2s_pk_off, c_x2s_pk_len);
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                       pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
            TEST_EQUAL(status, PSA_SUCCESS);

            status = psa_pake_input(server, PSA_PAKE_STEP_ZK_PROOF,
                                    buffer1 + c_x2s_pr_off, c_x2s_pr_len);
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
                       pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
            TEST_EQUAL(status, PSA_SUCCESS);

            break;

        default:
            /* Only rounds 1 and 2 exist; silently doing nothing would hide
             * a broken caller. */
            TEST_ASSERT(!"Unsupported EC J-PAKE round");
            break;
    }

exit:
    mbedtls_free(buffer0);
    mbedtls_free(buffer1);
}
#endif /* PSA_WANT_ALG_JPAKE */

#if defined(PSA_WANT_KEY_TYPE_RSA_PUBLIC_KEY)
/* Sanity checks on the output of RSA encryption.
 *
 * The ciphertext length must equal the modulus length. When MBEDTLS_BIGNUM_C
 * is enabled, the ciphertext is also decrypted with the raw private-key
 * operation (X = C^D mod N), and the recovered encoded message is checked
 * against the encoding rules of \p alg. Without MBEDTLS_BIGNUM_C only the
 * length is checked.
 *
 * \param alg                   An RSA algorithm.
 * \param modulus               Key modulus. Must not have leading zeros.
 * \param private_exponent      Key private exponent.
 * \param input_data            The input plaintext.
 * \param buf                   The ciphertext produced by the driver.
 *                              When MBEDTLS_BIGNUM_C is enabled, this buffer
 *                              is overwritten in place with the decrypted
 *                              (still padded) message.
 * \param length                Length of \p buf in bytes.
 *
 * \return 1 if all checks pass, 0 if a check failed.
 */
static int sanity_check_rsa_encryption_result(
    psa_algorithm_t alg,
    const data_t *modulus, const data_t *private_exponent,
    const data_t *input_data,
    uint8_t *buf, size_t length)
{
#if defined(MBEDTLS_BIGNUM_C)
    mbedtls_mpi N, D, C, X;
    mbedtls_mpi_init(&N);
    mbedtls_mpi_init(&D);
    mbedtls_mpi_init(&C);
    mbedtls_mpi_init(&X);
#endif /* MBEDTLS_BIGNUM_C */

    int ok = 0;

    TEST_ASSERT(length == modulus->len);

#if defined(MBEDTLS_BIGNUM_C)
    /* Perform the private key operation */
    TEST_ASSERT(mbedtls_mpi_read_binary(&N, modulus->x, modulus->len) == 0);
    TEST_ASSERT(mbedtls_mpi_read_binary(&D,
                                        private_exponent->x,
                                        private_exponent->len) == 0);
    TEST_ASSERT(mbedtls_mpi_read_binary(&C, buf, length) == 0);
    TEST_ASSERT(mbedtls_mpi_exp_mod(&X, &C, &D, &N, NULL) == 0);

    /* Sanity checks on the padded plaintext (written back into buf) */
    TEST_ASSERT(mbedtls_mpi_write_binary(&X, buf, length) == 0);

    if (alg == PSA_ALG_RSA_PKCS1V15_CRYPT) {
        /* EME-PKCS1-v1_5 encoding (RFC 8017, section 7.2.1):
         *     0x00 || 0x02 || PS || 0x00 || M
         * where PS consists of at least 8 nonzero bytes. */
        TEST_ASSERT(length >= input_data->len + 11);
        TEST_EQUAL(buf[0], 0x00);
        TEST_EQUAL(buf[1], 0x02);
        for (size_t i = 2; i < length - input_data->len - 1; i++) {
            TEST_ASSERT(buf[i] != 0x00);
        }
        TEST_EQUAL(buf[length - input_data->len - 1], 0x00);
        ASSERT_COMPARE(buf + length - input_data->len, input_data->len,
                       input_data->x, input_data->len);
    } else if (PSA_ALG_IS_RSA_OAEP(alg)) {
        TEST_EQUAL(buf[0], 0x00);
        /* The rest is too hard to check */
    } else {
        TEST_ASSERT(!"Encryption result sanity check not implemented for RSA algorithm");
    }
#else
    /* Only the length can be checked without bignum support. */
    (void) alg;
    (void) private_exponent;
    (void) input_data;
    (void) buf;
#endif /* MBEDTLS_BIGNUM_C */

    ok = 1;

exit:
#if defined(MBEDTLS_BIGNUM_C)
    mbedtls_mpi_free(&N);
    mbedtls_mpi_free(&D);
    mbedtls_mpi_free(&C);
    mbedtls_mpi_free(&X);
#endif /* MBEDTLS_BIGNUM_C */
    return ok;
}
#endif
/* END_HEADER */

/* BEGIN_DEPENDENCIES
 * depends_on:MBEDTLS_PSA_CRYPTO_C:MBEDTLS_PSA_CRYPTO_DRIVERS:PSA_CRYPTO_DRIVER_TEST
 * END_DEPENDENCIES
 */

/* BEGIN_CASE */
void sign_hash(int key_type_arg,
               int alg_arg,
               int force_status_arg,
               data_t *key_input,
               data_t *data_input,
               data_t *expected_output,
               int fake_output,
               int expected_status_arg)
{
    /* Check that psa_sign_hash() calls the test signature driver exactly
     * once. The driver's status is forced to force_status_arg and, when
     * fake_output is 1, its output is forced to expected_output. */
    psa_status_t force_status = force_status_arg;
    psa_status_t expected_status = expected_status_arg;
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    psa_algorithm_t alg = alg_arg;
    size_t key_bits;
    psa_key_type_t key_type = key_type_arg;
    unsigned char *signature = NULL;
    size_t signature_size;
    size_t signature_length = 0xdeadbeef;
    psa_status_t actual_status;
    mbedtls_test_driver_signature_sign_hooks =
        mbedtls_test_driver_signature_hooks_init();

    PSA_ASSERT(psa_crypto_init());
    psa_set_key_type(&attributes, key_type);
    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_SIGN_HASH);
    psa_set_key_algorithm(&attributes, alg);
    /* Fail here, with a clear location, if the key cannot be imported. */
    PSA_ASSERT(psa_import_key(&attributes,
                              key_input->x, key_input->len,
                              &key));

    mbedtls_test_driver_signature_sign_hooks.forced_status = force_status;
    if (fake_output == 1) {
        mbedtls_test_driver_signature_sign_hooks.forced_output =
            expected_output->x;
        mbedtls_test_driver_signature_sign_hooks.forced_output_length =
            expected_output->len;
    }

    /* Allocate a buffer which has the size advertised by the
     * library. */
    PSA_ASSERT(psa_get_key_attributes(key, &attributes));
    key_bits = psa_get_key_bits(&attributes);
    signature_size = PSA_SIGN_OUTPUT_SIZE(key_type, key_bits, alg);

    TEST_ASSERT(signature_size != 0);
    TEST_ASSERT(signature_size <= PSA_SIGNATURE_MAX_SIZE);
    ASSERT_ALLOC(signature, signature_size);

    actual_status = psa_sign_hash(key, alg,
                                  data_input->x, data_input->len,
                                  signature, signature_size,
                                  &signature_length);
    TEST_EQUAL(actual_status, expected_status);
    if (expected_status == PSA_SUCCESS) {
        ASSERT_COMPARE(signature, signature_length,
                       expected_output->x, expected_output->len);
    }
    TEST_EQUAL(mbedtls_test_driver_signature_sign_hooks.hits, 1);

exit:
    psa_reset_key_attributes(&attributes);
    psa_destroy_key(key);
    mbedtls_free(signature);
    PSA_DONE();
    mbedtls_test_driver_signature_sign_hooks =
        mbedtls_test_driver_signature_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE */
void verify_hash(int key_type_arg,
                 int key_type_public_arg,
                 int alg_arg,
                 int force_status_arg,
                 int register_public_key,
                 data_t *key_input,
                 data_t *data_input,
                 data_t *signature_input,
                 int expected_status_arg)
{
    /* Check that psa_verify_hash() calls the test signature driver exactly
     * once, with the driver's status forced to force_status_arg. The key is
     * imported as key_type_public_arg when register_public_key is nonzero,
     * otherwise as key_type_arg. */
    psa_status_t force_status = force_status_arg;
    psa_status_t expected_status = expected_status_arg;
    psa_algorithm_t alg = alg_arg;
    psa_key_type_t key_type = key_type_arg;
    psa_key_type_t key_type_public = key_type_public_arg;
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    psa_status_t actual_status;
    mbedtls_test_driver_signature_verify_hooks =
        mbedtls_test_driver_signature_hooks_init();

    PSA_ASSERT(psa_crypto_init());
    /* Only the key type depends on register_public_key; usage and
     * algorithm are the same in both cases. */
    psa_set_key_type(&attributes,
                     register_public_key ? key_type_public : key_type);
    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_VERIFY_HASH);
    psa_set_key_algorithm(&attributes, alg);
    /* The driver hit check below already requires a usable key, so a failed
     * import is reported here, at its actual location. */
    PSA_ASSERT(psa_import_key(&attributes,
                              key_input->x, key_input->len,
                              &key));

    mbedtls_test_driver_signature_verify_hooks.forced_status = force_status;

    actual_status = psa_verify_hash(key, alg,
                                    data_input->x, data_input->len,
                                    signature_input->x, signature_input->len);
    TEST_EQUAL(actual_status, expected_status);
    TEST_EQUAL(mbedtls_test_driver_signature_verify_hooks.hits, 1);

exit:
    psa_reset_key_attributes(&attributes);
    psa_destroy_key(key);
    PSA_DONE();
    mbedtls_test_driver_signature_verify_hooks =
        mbedtls_test_driver_signature_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE */
void sign_message(int key_type_arg,
                  int alg_arg,
                  int force_status_arg,
                  data_t *key_input,
                  data_t *data_input,
                  data_t *expected_output,
                  int fake_output,
                  int expected_status_arg)
{
    /* Check how often psa_sign_message() calls the test signature driver.
     * The driver's status is forced to force_status_arg and, when
     * fake_output is 1, its output is forced to expected_output. */
    psa_status_t force_status = force_status_arg;
    psa_status_t expected_status = expected_status_arg;
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    psa_algorithm_t alg = alg_arg;
    size_t key_bits;
    psa_key_type_t key_type = key_type_arg;
    unsigned char *signature = NULL;
    size_t signature_size;
    size_t signature_length = 0xdeadbeef;
    psa_status_t actual_status;
    mbedtls_test_driver_signature_sign_hooks =
        mbedtls_test_driver_signature_hooks_init();

    PSA_ASSERT(psa_crypto_init());
    psa_set_key_type(&attributes, key_type);
    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_SIGN_MESSAGE);
    psa_set_key_algorithm(&attributes, alg);
    /* Fail here, with a clear location, if the key cannot be imported. */
    PSA_ASSERT(psa_import_key(&attributes,
                              key_input->x, key_input->len,
                              &key));

    mbedtls_test_driver_signature_sign_hooks.forced_status = force_status;
    if (fake_output == 1) {
        mbedtls_test_driver_signature_sign_hooks.forced_output =
            expected_output->x;
        mbedtls_test_driver_signature_sign_hooks.forced_output_length =
            expected_output->len;
    }

    /* Allocate a buffer which has the size advertised by the
     * library. */
    PSA_ASSERT(psa_get_key_attributes(key, &attributes));
    key_bits = psa_get_key_bits(&attributes);
    signature_size = PSA_SIGN_OUTPUT_SIZE(key_type, key_bits, alg);

    TEST_ASSERT(signature_size != 0);
    TEST_ASSERT(signature_size <= PSA_SIGNATURE_MAX_SIZE);
    ASSERT_ALLOC(signature, signature_size);

    actual_status = psa_sign_message(key, alg,
                                     data_input->x, data_input->len,
                                     signature, signature_size,
                                     &signature_length);
    TEST_EQUAL(actual_status, expected_status);
    if (expected_status == PSA_SUCCESS) {
        ASSERT_COMPARE(signature, signature_length,
                       expected_output->x, expected_output->len);
    }
    /* In the builtin algorithm the driver is called twice (presumably the
     * message entry point, then the hash entry point after the core's
     * fallback to hash-and-sign). */
    TEST_EQUAL(mbedtls_test_driver_signature_sign_hooks.hits,
               force_status == PSA_ERROR_NOT_SUPPORTED ? 2 : 1);

exit:
    psa_reset_key_attributes(&attributes);
    psa_destroy_key(key);
    mbedtls_free(signature);
    PSA_DONE();
    mbedtls_test_driver_signature_sign_hooks =
        mbedtls_test_driver_signature_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE */
void verify_message(int key_type_arg,
                    int key_type_public_arg,
                    int alg_arg,
                    int force_status_arg,
                    int register_public_key,
                    data_t *key_input,
                    data_t *data_input,
                    data_t *signature_input,
                    int expected_status_arg)
{
    /*
     * Verify signature_input over data_input with psa_verify_message() and
     * check both the returned status and the number of hits recorded by the
     * test driver's signature-verify hooks.
     *
     * key_input is imported as key_type_public_arg when register_public_key
     * is non-zero and as key_type_arg otherwise. The driver hooks are forced
     * to return force_status_arg.
     */
    psa_status_t force_status = force_status_arg;
    psa_status_t expected_status = expected_status_arg;
    psa_algorithm_t alg = alg_arg;
    psa_key_type_t key_type = key_type_arg;
    psa_key_type_t key_type_public = key_type_public_arg;
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    psa_status_t actual_status;
    mbedtls_test_driver_signature_verify_hooks =
        mbedtls_test_driver_signature_hooks_init();

    PSA_ASSERT(psa_crypto_init());

    /* Only the key type depends on register_public_key; the usage policy is
     * identical in both cases. */
    psa_set_key_type(&attributes,
                     register_public_key ? key_type_public : key_type);
    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_VERIFY_MESSAGE);
    psa_set_key_algorithm(&attributes, alg);
    /* Check the import: if it failed, the verification below would run on an
     * invalid key id and report a misleading status instead of the real
     * import error. */
    PSA_ASSERT(psa_import_key(&attributes,
                              key_input->x, key_input->len,
                              &key));

    mbedtls_test_driver_signature_verify_hooks.forced_status = force_status;

    actual_status = psa_verify_message(key, alg,
                                       data_input->x, data_input->len,
                                       signature_input->x, signature_input->len);
    TEST_EQUAL(actual_status, expected_status);
    /* In the builtin algorithm the driver is called twice. */
    TEST_EQUAL(mbedtls_test_driver_signature_verify_hooks.hits,
               force_status == PSA_ERROR_NOT_SUPPORTED ? 2 : 1);

exit:
    psa_reset_key_attributes(&attributes);
    psa_destroy_key(key);
    PSA_DONE();
    mbedtls_test_driver_signature_verify_hooks =
        mbedtls_test_driver_signature_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE depends_on:PSA_WANT_ALG_ECDSA:PSA_WANT_ECC_SECP_R1_256 */
void generate_key(int force_status_arg,
                  data_t *fake_output,
                  int expected_status_arg)
{
    /*
     * Generate a SECP256R1 ECDSA(SHA-256) key pair with the key management
     * test driver forced to return force_status_arg. Checks that the driver
     * was hit once and that the generation returned expected_status_arg.
     *
     * When fake_output is non-empty it is installed as the driver's forced
     * output, and the exported key must equal it. Otherwise the exported key
     * must not consist only of zero bytes.
     */
    psa_status_t force_status = force_status_arg;
    psa_status_t expected_status = expected_status_arg;
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    psa_algorithm_t alg = PSA_ALG_ECDSA(PSA_ALG_SHA_256);
    const uint8_t *expected_output = NULL;
    size_t expected_output_length = 0;
    psa_status_t actual_status;
    uint8_t actual_output[PSA_KEY_EXPORT_ECC_KEY_PAIR_MAX_SIZE(256)] = { 0 };
    size_t actual_output_length = 0;
    mbedtls_test_driver_key_management_hooks =
        mbedtls_test_driver_key_management_hooks_init();

    psa_set_key_type(&attributes,
                     PSA_KEY_TYPE_ECC_KEY_PAIR(PSA_ECC_FAMILY_SECP_R1));
    psa_set_key_bits(&attributes, 256);
    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_SIGN_HASH | PSA_KEY_USAGE_EXPORT);
    psa_set_key_algorithm(&attributes, alg);

    if (fake_output->len > 0) {
        expected_output =
            mbedtls_test_driver_key_management_hooks.forced_output =
                fake_output->x;

        expected_output_length =
            mbedtls_test_driver_key_management_hooks.forced_output_length =
                fake_output->len;
    }

    mbedtls_test_driver_key_management_hooks.hits = 0;
    mbedtls_test_driver_key_management_hooks.forced_status = force_status;

    PSA_ASSERT(psa_crypto_init());

    actual_status = psa_generate_key(&attributes, &key);
    TEST_EQUAL(mbedtls_test_driver_key_management_hooks.hits, 1);
    TEST_EQUAL(actual_status, expected_status);

    if (actual_status == PSA_SUCCESS) {
        /* The export must succeed; otherwise actual_output_length does not
         * describe meaningful data and the checks below are bogus. */
        PSA_ASSERT(psa_export_key(key, actual_output, sizeof(actual_output),
                                  &actual_output_length));

        if (fake_output->len > 0) {
            ASSERT_COMPARE(actual_output, actual_output_length,
                           expected_output, expected_output_length);
        } else {
            /* Only the exported bytes are inspected. An empty export also
             * fails this check, since then zeroes == actual_output_length. */
            size_t zeroes = 0;
            for (size_t i = 0; i < actual_output_length; i++) {
                if (actual_output[i] == 0) {
                    zeroes++;
                }
            }
            TEST_ASSERT(zeroes != actual_output_length);
        }
    }
exit:
    psa_reset_key_attributes(&attributes);
    psa_destroy_key(key);
    PSA_DONE();
    mbedtls_test_driver_key_management_hooks =
        mbedtls_test_driver_key_management_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE */
void validate_key(int force_status_arg,
                  int location,
                  int owner_id_arg,
                  int id_arg,
                  int key_type_arg,
                  data_t *key_input,
                  int expected_status_arg)
{
    /*
     * Import key_input as a persistent key identified by
     * (owner_id_arg, id_arg) in the given location, with the key management
     * test driver forced to return force_status_arg.
     *
     * Checks three things:
     * - the driver hook was hit exactly once;
     * - the import returned expected_status_arg;
     * - the hook recorded the requested location.
     */
    const psa_status_t expected_status = expected_status_arg;
    const mbedtls_svc_key_id_t requested_id =
        mbedtls_svc_key_id_make(owner_id_arg, id_arg);
    const psa_key_lifetime_t requested_lifetime =
        PSA_KEY_LIFETIME_FROM_PERSISTENCE_AND_LOCATION(
            PSA_KEY_PERSISTENCE_DEFAULT, location);
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    mbedtls_svc_key_id_t imported_key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_status_t actual_status;

    mbedtls_test_driver_key_management_hooks =
        mbedtls_test_driver_key_management_hooks_init();

    /* Explicit id and lifetime; the bit size is left as 0 (unspecified). */
    psa_set_key_id(&attributes, requested_id);
    psa_set_key_type(&attributes, (psa_key_type_t) key_type_arg);
    psa_set_key_lifetime(&attributes, requested_lifetime);
    psa_set_key_bits(&attributes, 0);
    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_EXPORT);

    mbedtls_test_driver_key_management_hooks.forced_status =
        (psa_status_t) force_status_arg;

    PSA_ASSERT(psa_crypto_init());

    actual_status = psa_import_key(&attributes,
                                   key_input->x, key_input->len,
                                   &imported_key);
    TEST_EQUAL(mbedtls_test_driver_key_management_hooks.hits, 1);
    TEST_EQUAL(actual_status, expected_status);
    TEST_EQUAL(mbedtls_test_driver_key_management_hooks.location, location);
exit:
    psa_reset_key_attributes(&attributes);
    psa_destroy_key(imported_key);
    PSA_DONE();
    mbedtls_test_driver_key_management_hooks =
        mbedtls_test_driver_key_management_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE */
void export_key(int force_status_arg,
                data_t *fake_output,
                int key_in_type_arg,
                data_t *key_in,
                int key_out_type_arg,
                data_t *expected_output,
                int expected_status_arg)
{
    /*
     * Import key_in as a 256-bit key of type key_in_type_arg, then export it
     * with the key management test driver forced to return force_status_arg:
     * - psa_export_public_key() when key_out_type_arg is an ECC public key
     *   type;
     * - psa_export_key() otherwise.
     *
     * On success the output must equal fake_output when that is non-empty
     * (installed as the driver's forced output), and expected_output
     * otherwise.
     */
    psa_status_t force_status = force_status_arg;
    psa_status_t expected_status = expected_status_arg;
    /* Same key id type and initializer as the other test cases: the key id
     * type may be a structure in some configurations, where a plain 0 is not
     * a valid initializer. */
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    psa_key_type_t input_key_type = key_in_type_arg;
    psa_key_type_t output_key_type = key_out_type_arg;
    const uint8_t *expected_output_ptr = NULL;
    size_t expected_output_length = 0;
    psa_status_t actual_status;
    uint8_t actual_output[PSA_KEY_EXPORT_ECC_PUBLIC_KEY_MAX_SIZE(256)] = { 0 };
    size_t actual_output_length = 0;
    mbedtls_test_driver_key_management_hooks =
        mbedtls_test_driver_key_management_hooks_init();

    psa_set_key_type(&attributes, input_key_type);
    psa_set_key_bits(&attributes, 256);
    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_EXPORT);

    PSA_ASSERT(psa_crypto_init());
    PSA_ASSERT(psa_import_key(&attributes, key_in->x, key_in->len, &key));

    if (fake_output->len > 0) {
        expected_output_ptr =
            mbedtls_test_driver_key_management_hooks.forced_output =
                fake_output->x;

        expected_output_length =
            mbedtls_test_driver_key_management_hooks.forced_output_length =
                fake_output->len;
    } else {
        expected_output_ptr = expected_output->x;
        expected_output_length = expected_output->len;
    }

    /* Count only the hits caused by the export itself, not by the import. */
    mbedtls_test_driver_key_management_hooks.hits = 0;
    mbedtls_test_driver_key_management_hooks.forced_status = force_status;

    if (PSA_KEY_TYPE_IS_ECC_PUBLIC_KEY(output_key_type)) {
        actual_status = psa_export_public_key(key,
                                              actual_output,
                                              sizeof(actual_output),
                                              &actual_output_length);
    } else {
        actual_status = psa_export_key(key,
                                       actual_output,
                                       sizeof(actual_output),
                                       &actual_output_length);
    }
    TEST_EQUAL(actual_status, expected_status);

    /* The driver hit count is only checked when a public key is derived
     * from a non-public (key pair) input. */
    if (PSA_KEY_TYPE_IS_ECC_PUBLIC_KEY(output_key_type) &&
        !PSA_KEY_TYPE_IS_ECC_PUBLIC_KEY(input_key_type)) {
        TEST_EQUAL(mbedtls_test_driver_key_management_hooks.hits, 1);
    }

    if (actual_status == PSA_SUCCESS) {
        ASSERT_COMPARE(actual_output, actual_output_length,
                       expected_output_ptr, expected_output_length);
    }
exit:
    psa_reset_key_attributes(&attributes);
    psa_destroy_key(key);
    PSA_DONE();
    mbedtls_test_driver_key_management_hooks =
        mbedtls_test_driver_key_management_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE */
void key_agreement(int alg_arg,
                   int force_status_arg,
                   int our_key_type_arg,
                   data_t *our_key_data,
                   data_t *peer_key_data,
                   data_t *expected_output,
                   data_t *fake_output,
                   int expected_status_arg)
{
    /*
     * Perform a raw key agreement between our_key_data (imported as
     * our_key_type_arg) and peer_key_data, with the key agreement test driver
     * forced to return force_status_arg. Checks the status and that the
     * driver was hit once.
     *
     * On success the shared secret must equal fake_output when that is
     * non-empty (installed as the driver's forced output), and
     * expected_output otherwise. The output buffer is sized to
     * expected_output->len.
     */
    psa_status_t force_status = force_status_arg;
    psa_status_t expected_status = expected_status_arg;
    psa_algorithm_t alg = alg_arg;
    psa_key_type_t our_key_type = our_key_type_arg;
    mbedtls_svc_key_id_t our_key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    const uint8_t *expected_output_ptr = NULL;
    size_t expected_output_length = 0;
    unsigned char *actual_output = NULL;
    /* Sentinel: an output length the call never legitimately returns. */
    size_t actual_output_length = ~0;
    size_t key_bits;
    psa_status_t actual_status;
    mbedtls_test_driver_key_agreement_hooks =
        mbedtls_test_driver_key_agreement_hooks_init();

    PSA_ASSERT(psa_crypto_init());

    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_DERIVE);
    psa_set_key_algorithm(&attributes, alg);
    psa_set_key_type(&attributes, our_key_type);
    PSA_ASSERT(psa_import_key(&attributes,
                              our_key_data->x, our_key_data->len,
                              &our_key));

    PSA_ASSERT(psa_get_key_attributes(our_key, &attributes));
    key_bits = psa_get_key_bits(&attributes);

    /* Sanity-check the output size macros against the test vector. */
    TEST_LE_U(expected_output->len,
              PSA_RAW_KEY_AGREEMENT_OUTPUT_SIZE(our_key_type, key_bits));
    TEST_LE_U(PSA_RAW_KEY_AGREEMENT_OUTPUT_SIZE(our_key_type, key_bits),
              PSA_RAW_KEY_AGREEMENT_OUTPUT_MAX_SIZE);

    if (fake_output->len > 0) {
        expected_output_ptr =
            mbedtls_test_driver_key_agreement_hooks.forced_output =
                fake_output->x;

        expected_output_length =
            mbedtls_test_driver_key_agreement_hooks.forced_output_length =
                fake_output->len;
    } else {
        expected_output_ptr = expected_output->x;
        expected_output_length = expected_output->len;
    }

    mbedtls_test_driver_key_agreement_hooks.hits = 0;
    mbedtls_test_driver_key_agreement_hooks.forced_status = force_status;

    ASSERT_ALLOC(actual_output, expected_output->len);
    actual_status = psa_raw_key_agreement(alg, our_key,
                                          peer_key_data->x, peer_key_data->len,
                                          actual_output, expected_output->len,
                                          &actual_output_length);
    TEST_EQUAL(actual_status, expected_status);
    TEST_EQUAL(mbedtls_test_driver_key_agreement_hooks.hits, 1);

    if (actual_status == PSA_SUCCESS) {
        ASSERT_COMPARE(actual_output, actual_output_length,
                       expected_output_ptr, expected_output_length);
    }

exit:
    /* Freed here rather than before the label so that a failing check,
     * which jumps straight to exit, does not leak the buffer. */
    mbedtls_free(actual_output);
    psa_reset_key_attributes(&attributes);
    psa_destroy_key(our_key);
    PSA_DONE();
    mbedtls_test_driver_key_agreement_hooks =
        mbedtls_test_driver_key_agreement_hooks_init();
}

/* END_CASE */

/* BEGIN_CASE */
void cipher_encrypt_validation(int alg_arg,
                               int key_type_arg,
                               data_t *key_data,
                               data_t *input)
{
    /*
     * Encrypt input twice and check that both paths agree:
     * - once with one-shot psa_cipher_encrypt() into output1;
     * - once with the multi-part API into output2, using the first iv_size
     *   bytes of output1 as the IV.
     *
     * The bytes of output1 after the IV must equal output2. Every step must
     * hit the cipher test driver the expected number of times; the hit
     * counter is reset after each check.
     */
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_type_t key_type = key_type_arg;
    psa_algorithm_t alg = alg_arg;
    size_t iv_size = PSA_CIPHER_IV_LENGTH(key_type, alg);
    unsigned char *output1 = NULL;
    size_t output1_buffer_size = 0;
    size_t output1_length = 0;
    unsigned char *output2 = NULL;
    size_t output2_buffer_size = 0;
    size_t output2_length = 0;
    size_t function_output_length = 0;
    psa_cipher_operation_t operation = PSA_CIPHER_OPERATION_INIT;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    mbedtls_test_driver_cipher_hooks = mbedtls_test_driver_cipher_hooks_init();

    PSA_ASSERT(psa_crypto_init());

    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_ENCRYPT);
    psa_set_key_algorithm(&attributes, alg);
    psa_set_key_type(&attributes, key_type);

    /* output1 holds the one-shot result; output2 holds update + finish. */
    output1_buffer_size = PSA_CIPHER_ENCRYPT_OUTPUT_SIZE(key_type, alg, input->len);
    output2_buffer_size = PSA_CIPHER_UPDATE_OUTPUT_SIZE(key_type, alg, input->len) +
                          PSA_CIPHER_FINISH_OUTPUT_SIZE(key_type, alg);
    ASSERT_ALLOC(output1, output1_buffer_size);
    ASSERT_ALLOC(output2, output2_buffer_size);

    PSA_ASSERT(psa_import_key(&attributes, key_data->x, key_data->len,
                              &key));

    /* One-shot encryption: a single driver hit. */
    PSA_ASSERT(psa_cipher_encrypt(key, alg, input->x, input->len, output1,
                                  output1_buffer_size, &output1_length));
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 1);
    mbedtls_test_driver_cipher_hooks.hits = 0;

    PSA_ASSERT(psa_cipher_encrypt_setup(&operation, key, alg));
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 1);
    mbedtls_test_driver_cipher_hooks.hits = 0;

    /* Reuse the IV from the start of the one-shot output so that both
     * ciphertexts can be compared. */
    PSA_ASSERT(psa_cipher_set_iv(&operation, output1, iv_size));
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 1);
    mbedtls_test_driver_cipher_hooks.hits = 0;

    PSA_ASSERT(psa_cipher_update(&operation,
                                 input->x, input->len,
                                 output2, output2_buffer_size,
                                 &function_output_length));
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 1);
    mbedtls_test_driver_cipher_hooks.hits = 0;

    output2_length += function_output_length;
    PSA_ASSERT(psa_cipher_finish(&operation,
                                 output2 + output2_length,
                                 output2_buffer_size - output2_length,
                                 &function_output_length));
    /* Finish will have called abort as well, so expecting two hits here */
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 2);
    mbedtls_test_driver_cipher_hooks.hits = 0;

    output2_length += function_output_length;

    PSA_ASSERT(psa_cipher_abort(&operation));
    // driver function should've been called as part of the finish() core routine
    // (so aborting the already-finished operation must not hit the driver again)
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 0);
    /* Skip the IV prefix of output1 before comparing the ciphertexts. */
    ASSERT_COMPARE(output1 + iv_size, output1_length - iv_size,
                   output2, output2_length);

exit:
    psa_cipher_abort(&operation);
    mbedtls_free(output1);
    mbedtls_free(output2);
    psa_destroy_key(key);
    PSA_DONE();
    mbedtls_test_driver_cipher_hooks = mbedtls_test_driver_cipher_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE */
void cipher_encrypt_multipart(int alg_arg,
                              int key_type_arg,
                              data_t *key_data,
                              data_t *iv,
                              data_t *input,
                              int first_part_size_arg,
                              int output1_length_arg,
                              int output2_length_arg,
                              data_t *expected_output,
                              int mock_output_arg,
                              int force_status_arg,
                              int expected_status_arg)
{
    /*
     * Multi-part encryption of input, split at first_part_size_arg, with the
     * cipher test driver forced to return force_status_arg.
     *
     * Checks:
     * - the output length of each update (output1_length_arg,
     *   output2_length_arg);
     * - the status of psa_cipher_finish();
     * - on expected success, the full ciphertext against expected_output;
     * - the driver hit count after every step.
     *
     * When mock_output_arg is non-zero, the driver is given expected_output
     * as forced output for the update calls.
     *
     * When force_status is not PSA_SUCCESS, only setup is expected to hit
     * the driver; every later step expects zero hits (presumably the core
     * fell back to another implementation at setup).
     */
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_type_t key_type = key_type_arg;
    psa_algorithm_t alg = alg_arg;
    psa_status_t status;
    psa_status_t expected_status = expected_status_arg;
    psa_status_t force_status = force_status_arg;
    size_t first_part_size = first_part_size_arg;
    size_t output1_length = output1_length_arg;
    size_t output2_length = output2_length_arg;
    unsigned char *output = NULL;
    size_t output_buffer_size = 0;
    size_t function_output_length = 0;
    size_t total_output_length = 0;
    psa_cipher_operation_t operation = PSA_CIPHER_OPERATION_INIT;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    mbedtls_test_driver_cipher_hooks = mbedtls_test_driver_cipher_hooks_init();
    mbedtls_test_driver_cipher_hooks.forced_status = force_status;

    /* Test operation initialization: exercises the per-implementation
     * *_OPERATION_INIT macros by assigning them into the operation context. */
    mbedtls_psa_cipher_operation_t mbedtls_operation =
        MBEDTLS_PSA_CIPHER_OPERATION_INIT;

    mbedtls_transparent_test_driver_cipher_operation_t transparent_operation =
        MBEDTLS_TRANSPARENT_TEST_DRIVER_CIPHER_OPERATION_INIT;

    mbedtls_opaque_test_driver_cipher_operation_t opaque_operation =
        MBEDTLS_OPAQUE_TEST_DRIVER_CIPHER_OPERATION_INIT;

    operation.ctx.mbedtls_ctx = mbedtls_operation;
    operation.ctx.transparent_test_driver_ctx = transparent_operation;
    operation.ctx.opaque_test_driver_ctx = opaque_operation;

    PSA_ASSERT(psa_crypto_init());

    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_ENCRYPT);
    psa_set_key_algorithm(&attributes, alg);
    psa_set_key_type(&attributes, key_type);

    PSA_ASSERT(psa_import_key(&attributes, key_data->x, key_data->len,
                              &key));

    PSA_ASSERT(psa_cipher_encrypt_setup(&operation, key, alg));
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 1);
    mbedtls_test_driver_cipher_hooks.hits = 0;

    PSA_ASSERT(psa_cipher_set_iv(&operation, iv->x, iv->len));
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, (force_status == PSA_SUCCESS ? 1 : 0));
    mbedtls_test_driver_cipher_hooks.hits = 0;

    output_buffer_size = ((size_t) input->len +
                          PSA_BLOCK_CIPHER_BLOCK_LENGTH(key_type));
    ASSERT_ALLOC(output, output_buffer_size);

    if (mock_output_arg) {
        mbedtls_test_driver_cipher_hooks.forced_output = expected_output->x;
        mbedtls_test_driver_cipher_hooks.forced_output_length = expected_output->len;
    }

    TEST_ASSERT(first_part_size <= input->len);
    PSA_ASSERT(psa_cipher_update(&operation, input->x, first_part_size,
                                 output, output_buffer_size,
                                 &function_output_length));
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, (force_status == PSA_SUCCESS ? 1 : 0));
    mbedtls_test_driver_cipher_hooks.hits = 0;

    TEST_ASSERT(function_output_length == output1_length);
    total_output_length += function_output_length;

    if (first_part_size < input->len) {
        PSA_ASSERT(psa_cipher_update(&operation,
                                     input->x + first_part_size,
                                     input->len - first_part_size,
                                     output + total_output_length,
                                     output_buffer_size - total_output_length,
                                     &function_output_length));
        /* Same dispatch as the first update: no driver hit once setup has
         * been rejected by the forced status. */
        TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, (force_status == PSA_SUCCESS ? 1 : 0));
        mbedtls_test_driver_cipher_hooks.hits = 0;

        TEST_ASSERT(function_output_length == output2_length);
        total_output_length += function_output_length;
    }

    if (mock_output_arg) {
        mbedtls_test_driver_cipher_hooks.forced_output = NULL;
        mbedtls_test_driver_cipher_hooks.forced_output_length = 0;
    }

    status = psa_cipher_finish(&operation,
                               output + total_output_length,
                               output_buffer_size - total_output_length,
                               &function_output_length);
    /* Finish will have called abort as well, so expecting two hits here */
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, (force_status == PSA_SUCCESS ? 2 : 0));
    mbedtls_test_driver_cipher_hooks.hits = 0;
    total_output_length += function_output_length;
    TEST_EQUAL(status, expected_status);

    if (expected_status == PSA_SUCCESS) {
        /* Aborting a finished operation must not reach the driver again. */
        PSA_ASSERT(psa_cipher_abort(&operation));
        TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 0);

        ASSERT_COMPARE(expected_output->x, expected_output->len,
                       output, total_output_length);
    }

exit:
    psa_cipher_abort(&operation);
    mbedtls_free(output);
    psa_destroy_key(key);
    PSA_DONE();
    mbedtls_test_driver_cipher_hooks = mbedtls_test_driver_cipher_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE */
void cipher_decrypt_multipart(int alg_arg,
                              int key_type_arg,
                              data_t *key_data,
                              data_t *iv,
                              data_t *input,
                              int first_part_size_arg,
                              int output1_length_arg,
                              int output2_length_arg,
                              data_t *expected_output,
                              int mock_output_arg,
                              int force_status_arg,
                              int expected_status_arg)
{
    /*
     * Multi-part decryption of input, split at first_part_size_arg, with the
     * cipher test driver forced to return force_status_arg.
     *
     * Checks:
     * - the output length of each update (output1_length_arg,
     *   output2_length_arg);
     * - the status of psa_cipher_finish();
     * - on expected success, the full plaintext against expected_output;
     * - the driver hit count after every step.
     *
     * When mock_output_arg is non-zero, the driver is given expected_output
     * as forced output for the update calls.
     */
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_type_t key_type = key_type_arg;
    psa_algorithm_t alg = alg_arg;
    psa_status_t status;
    psa_status_t expected_status = expected_status_arg;
    psa_status_t force_status = force_status_arg;
    size_t first_part_size = first_part_size_arg;
    size_t output1_length = output1_length_arg;
    size_t output2_length = output2_length_arg;
    unsigned char *output = NULL;
    size_t output_buffer_size = 0;
    size_t function_output_length = 0;
    size_t total_output_length = 0;
    psa_cipher_operation_t operation = PSA_CIPHER_OPERATION_INIT;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    mbedtls_test_driver_cipher_hooks = mbedtls_test_driver_cipher_hooks_init();
    mbedtls_test_driver_cipher_hooks.forced_status = force_status;

    /* Test operation initialization */
    /* (exercises the per-implementation *_OPERATION_INIT macros by assigning
     * them into the operation context) */
    mbedtls_psa_cipher_operation_t mbedtls_operation =
        MBEDTLS_PSA_CIPHER_OPERATION_INIT;

    mbedtls_transparent_test_driver_cipher_operation_t transparent_operation =
        MBEDTLS_TRANSPARENT_TEST_DRIVER_CIPHER_OPERATION_INIT;

    mbedtls_opaque_test_driver_cipher_operation_t opaque_operation =
        MBEDTLS_OPAQUE_TEST_DRIVER_CIPHER_OPERATION_INIT;

    operation.ctx.mbedtls_ctx = mbedtls_operation;
    operation.ctx.transparent_test_driver_ctx = transparent_operation;
    operation.ctx.opaque_test_driver_ctx = opaque_operation;

    PSA_ASSERT(psa_crypto_init());

    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_DECRYPT);
    psa_set_key_algorithm(&attributes, alg);
    psa_set_key_type(&attributes, key_type);

    PSA_ASSERT(psa_import_key(&attributes, key_data->x, key_data->len,
                              &key));

    /* Setup always hits the driver once. The later checks expect no hits
     * when force_status != PSA_SUCCESS (presumably the core fell back to
     * another implementation at setup). */
    PSA_ASSERT(psa_cipher_decrypt_setup(&operation, key, alg));
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 1);
    mbedtls_test_driver_cipher_hooks.hits = 0;

    PSA_ASSERT(psa_cipher_set_iv(&operation, iv->x, iv->len));
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, (force_status == PSA_SUCCESS ? 1 : 0));
    mbedtls_test_driver_cipher_hooks.hits = 0;

    /* Room for the whole input plus one extra block. */
    output_buffer_size = ((size_t) input->len +
                          PSA_BLOCK_CIPHER_BLOCK_LENGTH(key_type));
    ASSERT_ALLOC(output, output_buffer_size);

    if (mock_output_arg) {
        mbedtls_test_driver_cipher_hooks.forced_output = expected_output->x;
        mbedtls_test_driver_cipher_hooks.forced_output_length = expected_output->len;
    }

    TEST_ASSERT(first_part_size <= input->len);
    PSA_ASSERT(psa_cipher_update(&operation,
                                 input->x, first_part_size,
                                 output, output_buffer_size,
                                 &function_output_length));
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, (force_status == PSA_SUCCESS ? 1 : 0));
    mbedtls_test_driver_cipher_hooks.hits = 0;

    TEST_ASSERT(function_output_length == output1_length);
    total_output_length += function_output_length;

    /* Feed the remainder, if any, as a second update. */
    if (first_part_size < input->len) {
        PSA_ASSERT(psa_cipher_update(&operation,
                                     input->x + first_part_size,
                                     input->len - first_part_size,
                                     output + total_output_length,
                                     output_buffer_size - total_output_length,
                                     &function_output_length));
        TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, (force_status == PSA_SUCCESS ? 1 : 0));
        mbedtls_test_driver_cipher_hooks.hits = 0;

        TEST_ASSERT(function_output_length == output2_length);
        total_output_length += function_output_length;
    }

    /* Stop forcing the driver output before finishing. */
    if (mock_output_arg) {
        mbedtls_test_driver_cipher_hooks.forced_output = NULL;
        mbedtls_test_driver_cipher_hooks.forced_output_length = 0;
    }

    status = psa_cipher_finish(&operation,
                               output + total_output_length,
                               output_buffer_size - total_output_length,
                               &function_output_length);
    /* Finish will have called abort as well, so expecting two hits here */
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, (force_status == PSA_SUCCESS ? 2 : 0));
    mbedtls_test_driver_cipher_hooks.hits = 0;
    total_output_length += function_output_length;
    TEST_EQUAL(status, expected_status);

    if (expected_status == PSA_SUCCESS) {
        /* Aborting a finished operation must not reach the driver again. */
        PSA_ASSERT(psa_cipher_abort(&operation));
        TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 0);

        ASSERT_COMPARE(expected_output->x, expected_output->len,
                       output, total_output_length);
    }

exit:
    psa_cipher_abort(&operation);
    mbedtls_free(output);
    psa_destroy_key(key);
    PSA_DONE();
    mbedtls_test_driver_cipher_hooks = mbedtls_test_driver_cipher_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE */
void cipher_decrypt(int alg_arg,
                    int key_type_arg,
                    data_t *key_data,
                    data_t *iv,
                    data_t *input_arg,
                    data_t *expected_output,
                    int mock_output_arg,
                    int force_status_arg,
                    int expected_status_arg)
{
    /*
     * One-shot decryption through psa_cipher_decrypt(). The buffer passed to
     * the API is iv followed by input_arg. The cipher test driver is forced
     * to return force_status_arg and, when mock_output_arg is non-zero, is
     * given expected_output as forced output.
     *
     * Checks the driver hit count, the status and, on expected success, the
     * decrypted output.
     */
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_status_t status;
    psa_key_type_t key_type = key_type_arg;
    psa_algorithm_t alg = alg_arg;
    psa_status_t expected_status = expected_status_arg;
    psa_status_t force_status = force_status_arg;
    unsigned char *input = NULL;
    size_t input_buffer_size = 0;
    unsigned char *output = NULL;
    size_t output_buffer_size = 0;
    size_t output_length = 0;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    mbedtls_test_driver_cipher_hooks = mbedtls_test_driver_cipher_hooks_init();
    mbedtls_test_driver_cipher_hooks.forced_status = force_status;

    PSA_ASSERT(psa_crypto_init());

    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_DECRYPT);
    psa_set_key_algorithm(&attributes, alg);
    psa_set_key_type(&attributes, key_type);

    /* Allocate the input buffer and fill it with the IV followed by the
     * ciphertext. */
    input_buffer_size = ((size_t) input_arg->len + (size_t) iv->len);
    if (input_buffer_size > 0) {
        ASSERT_ALLOC(input, input_buffer_size);
        memcpy(input, iv->x, iv->len);
        memcpy(input + iv->len, input_arg->x, input_arg->len);
    }

    output_buffer_size = PSA_CIPHER_DECRYPT_OUTPUT_SIZE(key_type, alg, input_buffer_size);
    ASSERT_ALLOC(output, output_buffer_size);

    PSA_ASSERT(psa_import_key(&attributes, key_data->x, key_data->len,
                              &key));

    if (mock_output_arg) {
        mbedtls_test_driver_cipher_hooks.forced_output = expected_output->x;
        mbedtls_test_driver_cipher_hooks.forced_output_length = expected_output->len;
    }

    status = psa_cipher_decrypt(key, alg, input, input_buffer_size, output,
                                output_buffer_size, &output_length);
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 1);
    mbedtls_test_driver_cipher_hooks.hits = 0;

    TEST_EQUAL(status, expected_status);

    if (expected_status == PSA_SUCCESS) {
        ASSERT_COMPARE(expected_output->x, expected_output->len,
                       output, output_length);
    }

exit:
    /* Reset the attributes like the other test cases do. */
    psa_reset_key_attributes(&attributes);
    mbedtls_free(input);
    mbedtls_free(output);
    psa_destroy_key(key);
    PSA_DONE();
    mbedtls_test_driver_cipher_hooks = mbedtls_test_driver_cipher_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE */
void cipher_entry_points(int alg_arg, int key_type_arg,
                         data_t *key_data, data_t *iv,
                         data_t *input)
{
    /*
     * Exercise the failure paths of the cipher driver entry points reached
     * through the driver wrappers: one-shot encrypt, encrypt/decrypt setup,
     * set_iv, generate_iv, update and finish.
     *
     * For each step the cipher driver hook is forced to return
     * PSA_ERROR_GENERIC_ERROR and the test checks:
     * - that the returned status is the forced one;
     * - the number of driver hook hits (2 when the failing multipart step is
     *   expected to be followed by an abort, 1 for a failing setup);
     * - that a further call on the same operation returns
     *   PSA_ERROR_BAD_STATE without reaching the driver.
     * The hook hit counter is reset to 0 between steps.
     */
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_status_t status;
    psa_key_type_t key_type = key_type_arg;
    psa_algorithm_t alg = alg_arg;
    unsigned char *output = NULL;
    size_t output_buffer_size = 0;
    size_t function_output_length = 0;
    psa_cipher_operation_t operation = PSA_CIPHER_OPERATION_INIT;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    mbedtls_test_driver_cipher_hooks = mbedtls_test_driver_cipher_hooks_init();

    /* Input length plus 16 bytes of slack; this also guarantees room for
     * the 16-byte IV buffer used in the generate_iv step below. */
    ASSERT_ALLOC(output, input->len + 16);
    output_buffer_size = input->len + 16;

    PSA_ASSERT(psa_crypto_init());

    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT);
    psa_set_key_algorithm(&attributes, alg);
    psa_set_key_type(&attributes, key_type);

    PSA_ASSERT(psa_import_key(&attributes, key_data->x, key_data->len,
                              &key));

    /*
     * Test encrypt failure
     * First test that if we don't force a driver error, encryption is
     * successful, then force driver error.
     */
    status = psa_cipher_encrypt(
        key, alg, input->x, input->len,
        output, output_buffer_size, &function_output_length);
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 1);
    TEST_EQUAL(status, PSA_SUCCESS);
    mbedtls_test_driver_cipher_hooks.hits = 0;

    mbedtls_test_driver_cipher_hooks.forced_status = PSA_ERROR_GENERIC_ERROR;
    /* Set the output buffer in a given state. */
    for (size_t i = 0; i < output_buffer_size; i++) {
        output[i] = 0xa5;
    }

    status = psa_cipher_encrypt(
        key, alg, input->x, input->len,
        output, output_buffer_size, &function_output_length);
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 1);
    TEST_EQUAL(status, PSA_ERROR_GENERIC_ERROR);
    /*
     * Check that the output buffer is still in the same state.
     * This will fail if the output buffer is used by the core to pass the IV
     * it generated to the driver (and is not restored).
     */
    for (size_t i = 0; i < output_buffer_size; i++) {
        TEST_EQUAL(output[i], 0xa5);
    }
    mbedtls_test_driver_cipher_hooks.hits = 0;

    /* Test setup call failure, encrypt */
    mbedtls_test_driver_cipher_hooks.forced_status = PSA_ERROR_GENERIC_ERROR;
    status = psa_cipher_encrypt_setup(&operation, key, alg);
    /* When setup fails, it shouldn't call any further entry points */
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 1);
    TEST_EQUAL(status, mbedtls_test_driver_cipher_hooks.forced_status);
    mbedtls_test_driver_cipher_hooks.hits = 0;
    status = psa_cipher_set_iv(&operation, iv->x, iv->len);
    TEST_EQUAL(status, PSA_ERROR_BAD_STATE);
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 0);

    /* Test setup call failure, decrypt */
    /* (forced_status is still PSA_ERROR_GENERIC_ERROR from the step above.) */
    status = psa_cipher_decrypt_setup(&operation, key, alg);
    /* When setup fails, it shouldn't call any further entry points */
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 1);
    TEST_EQUAL(status, mbedtls_test_driver_cipher_hooks.forced_status);
    mbedtls_test_driver_cipher_hooks.hits = 0;
    status = psa_cipher_set_iv(&operation, iv->x, iv->len);
    TEST_EQUAL(status, PSA_ERROR_BAD_STATE);
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 0);

    /* Test IV setting failure */
    mbedtls_test_driver_cipher_hooks.forced_status = PSA_SUCCESS;
    status = psa_cipher_encrypt_setup(&operation, key, alg);
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 1);
    TEST_EQUAL(status, mbedtls_test_driver_cipher_hooks.forced_status);
    mbedtls_test_driver_cipher_hooks.hits = 0;

    mbedtls_test_driver_cipher_hooks.forced_status = PSA_ERROR_GENERIC_ERROR;
    status = psa_cipher_set_iv(&operation, iv->x, iv->len);
    /* When setting the IV fails, it should call abort too */
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 2);
    TEST_EQUAL(status, mbedtls_test_driver_cipher_hooks.forced_status);
    /* Failure should prevent further operations from executing on the driver */
    mbedtls_test_driver_cipher_hooks.hits = 0;
    status = psa_cipher_update(&operation,
                               input->x, input->len,
                               output, output_buffer_size,
                               &function_output_length);
    TEST_EQUAL(status, PSA_ERROR_BAD_STATE);
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 0);
    psa_cipher_abort(&operation);

    /* Test IV generation failure */
    mbedtls_test_driver_cipher_hooks.forced_status = PSA_SUCCESS;
    status = psa_cipher_encrypt_setup(&operation, key, alg);
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 1);
    TEST_EQUAL(status, mbedtls_test_driver_cipher_hooks.forced_status);
    mbedtls_test_driver_cipher_hooks.hits = 0;

    mbedtls_test_driver_cipher_hooks.forced_status = PSA_ERROR_GENERIC_ERROR;
    /* Set the output buffer in a given state. */
    for (size_t i = 0; i < 16; i++) {
        output[i] = 0xa5;
    }

    status = psa_cipher_generate_iv(&operation, output, 16, &function_output_length);
    /* When generating the IV fails, it should call abort too */
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 2);
    TEST_EQUAL(status, mbedtls_test_driver_cipher_hooks.forced_status);
    /*
     * Check that the output buffer is still in the same state.
     * This will fail if the output buffer is used by the core to pass the IV
     * it generated to the driver (and is not restored).
     */
    for (size_t i = 0; i < 16; i++) {
        TEST_EQUAL(output[i], 0xa5);
    }
    /* Failure should prevent further operations from executing on the driver */
    mbedtls_test_driver_cipher_hooks.hits = 0;
    status = psa_cipher_update(&operation,
                               input->x, input->len,
                               output, output_buffer_size,
                               &function_output_length);
    TEST_EQUAL(status, PSA_ERROR_BAD_STATE);
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 0);
    psa_cipher_abort(&operation);

    /* Test update failure */
    mbedtls_test_driver_cipher_hooks.forced_status = PSA_SUCCESS;
    status = psa_cipher_encrypt_setup(&operation, key, alg);
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 1);
    TEST_EQUAL(status, mbedtls_test_driver_cipher_hooks.forced_status);
    mbedtls_test_driver_cipher_hooks.hits = 0;

    status = psa_cipher_set_iv(&operation, iv->x, iv->len);
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 1);
    TEST_EQUAL(status, mbedtls_test_driver_cipher_hooks.forced_status);
    mbedtls_test_driver_cipher_hooks.hits = 0;

    mbedtls_test_driver_cipher_hooks.forced_status = PSA_ERROR_GENERIC_ERROR;
    status = psa_cipher_update(&operation,
                               input->x, input->len,
                               output, output_buffer_size,
                               &function_output_length);
    /* When the update call fails, it should call abort too */
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 2);
    TEST_EQUAL(status, mbedtls_test_driver_cipher_hooks.forced_status);
    /* Failure should prevent further operations from executing on the driver */
    mbedtls_test_driver_cipher_hooks.hits = 0;
    status = psa_cipher_update(&operation,
                               input->x, input->len,
                               output, output_buffer_size,
                               &function_output_length);
    TEST_EQUAL(status, PSA_ERROR_BAD_STATE);
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 0);
    psa_cipher_abort(&operation);

    /* Test finish failure */
    mbedtls_test_driver_cipher_hooks.forced_status = PSA_SUCCESS;
    status = psa_cipher_encrypt_setup(&operation, key, alg);
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 1);
    TEST_EQUAL(status, mbedtls_test_driver_cipher_hooks.forced_status);
    mbedtls_test_driver_cipher_hooks.hits = 0;

    status = psa_cipher_set_iv(&operation, iv->x, iv->len);
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 1);
    TEST_EQUAL(status, mbedtls_test_driver_cipher_hooks.forced_status);
    mbedtls_test_driver_cipher_hooks.hits = 0;

    status = psa_cipher_update(&operation,
                               input->x, input->len,
                               output, output_buffer_size,
                               &function_output_length);
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 1);
    TEST_EQUAL(status, mbedtls_test_driver_cipher_hooks.forced_status);
    mbedtls_test_driver_cipher_hooks.hits = 0;

    mbedtls_test_driver_cipher_hooks.forced_status = PSA_ERROR_GENERIC_ERROR;
    /* Finish writes after whatever the update produced. */
    status = psa_cipher_finish(&operation,
                               output + function_output_length,
                               output_buffer_size - function_output_length,
                               &function_output_length);
    /* When the finish call fails, it should call abort too */
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 2);
    TEST_EQUAL(status, mbedtls_test_driver_cipher_hooks.forced_status);
    /* Failure should prevent further operations from executing on the driver */
    mbedtls_test_driver_cipher_hooks.hits = 0;
    status = psa_cipher_update(&operation,
                               input->x, input->len,
                               output, output_buffer_size,
                               &function_output_length);
    TEST_EQUAL(status, PSA_ERROR_BAD_STATE);
    TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 0);
    psa_cipher_abort(&operation);

exit:
    psa_cipher_abort(&operation);
    mbedtls_free(output);
    psa_destroy_key(key);
    PSA_DONE();
    /* Leave the hooks in their default state for the next test case. */
    mbedtls_test_driver_cipher_hooks = mbedtls_test_driver_cipher_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE */
void aead_encrypt(int key_type_arg, data_t *key_data,
                  int alg_arg,
                  data_t *nonce,
                  data_t *additional_data,
                  data_t *input_data,
                  data_t *expected_result,
                  int forced_status_arg)
{
    /*
     * One-shot AEAD encryption through the driver wrappers, with the AEAD
     * driver hook forced to return forced_status_arg.
     *
     * The encrypt entry point must be hit exactly once and report the forced
     * status. A forced PSA_ERROR_NOT_SUPPORTED must still lead to overall
     * success; any other forced status is returned to the caller. On success
     * the ciphertext (including tag) must equal expected_result.
     */
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    psa_key_type_t key_type = key_type_arg;
    psa_algorithm_t alg = alg_arg;
    psa_status_t forced_status = forced_status_arg;
    psa_status_t status = PSA_ERROR_GENERIC_ERROR;
    size_t key_bits;
    unsigned char *ciphertext = NULL;
    size_t ciphertext_size = 0;
    size_t ciphertext_length = 0;
    mbedtls_test_driver_aead_hooks = mbedtls_test_driver_aead_hooks_init();

    PSA_ASSERT(psa_crypto_init());

    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_ENCRYPT);
    psa_set_key_algorithm(&attributes, alg);
    psa_set_key_type(&attributes, key_type);
    PSA_ASSERT(psa_import_key(&attributes, key_data->x, key_data->len, &key));

    /* Read the key size back: the tag length is expressed in terms of it. */
    PSA_ASSERT(psa_get_key_attributes(key, &attributes));
    key_bits = psa_get_key_bits(&attributes);

    /* Plaintext plus tag; for every currently defined algorithm the generic
     * PSA_AEAD_ENCRYPT_OUTPUT_SIZE must match this exactly. */
    ciphertext_size = input_data->len +
                      PSA_AEAD_TAG_LENGTH(key_type, key_bits, alg);
    TEST_EQUAL(ciphertext_size,
               PSA_AEAD_ENCRYPT_OUTPUT_SIZE(key_type, alg, input_data->len));
    TEST_ASSERT(ciphertext_size <=
                PSA_AEAD_ENCRYPT_OUTPUT_MAX_SIZE(input_data->len));
    ASSERT_ALLOC(ciphertext, ciphertext_size);

    mbedtls_test_driver_aead_hooks.forced_status = forced_status;
    status = psa_aead_encrypt(key, alg,
                              nonce->x, nonce->len,
                              additional_data->x, additional_data->len,
                              input_data->x, input_data->len,
                              ciphertext, ciphertext_size,
                              &ciphertext_length);
    TEST_EQUAL(mbedtls_test_driver_aead_hooks.hits_encrypt, 1);
    TEST_EQUAL(mbedtls_test_driver_aead_hooks.driver_status, forced_status);

    if (forced_status == PSA_ERROR_NOT_SUPPORTED) {
        /* The driver declined: the operation must still succeed. */
        TEST_EQUAL(status, PSA_SUCCESS);
    } else {
        TEST_EQUAL(status, forced_status);
    }

    if (status == PSA_SUCCESS) {
        ASSERT_COMPARE(expected_result->x, expected_result->len,
                       ciphertext, ciphertext_length);
    }

exit:
    psa_destroy_key(key);
    mbedtls_free(ciphertext);
    PSA_DONE();
    mbedtls_test_driver_aead_hooks = mbedtls_test_driver_aead_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE */
void aead_decrypt(int key_type_arg, data_t *key_data,
                  int alg_arg,
                  data_t *nonce,
                  data_t *additional_data,
                  data_t *input_data,
                  data_t *expected_data,
                  int forced_status_arg)
{
    /*
     * One-shot AEAD decryption through the driver wrappers, with the AEAD
     * driver hook forced to return forced_status_arg.
     *
     * The decrypt entry point must be hit exactly once and report the forced
     * status. A forced PSA_ERROR_NOT_SUPPORTED must still lead to overall
     * success; any other forced status is returned to the caller. On success
     * the plaintext must equal expected_data.
     */
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_type_t key_type = key_type_arg;
    psa_algorithm_t alg = alg_arg;
    size_t key_bits;
    size_t tag_length = 0;
    psa_status_t forced_status = forced_status_arg;
    unsigned char *output_data = NULL;
    size_t output_size = 0;
    size_t output_length = 0;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    psa_status_t status = PSA_ERROR_GENERIC_ERROR;
    mbedtls_test_driver_aead_hooks = mbedtls_test_driver_aead_hooks_init();

    PSA_ASSERT(psa_crypto_init());

    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_DECRYPT);
    psa_set_key_algorithm(&attributes, alg);
    psa_set_key_type(&attributes, key_type);

    PSA_ASSERT(psa_import_key(&attributes, key_data->x, key_data->len,
                              &key));
    PSA_ASSERT(psa_get_key_attributes(key, &attributes));
    key_bits = psa_get_key_bits(&attributes);

    /* The ciphertext carries the tag, so it cannot be shorter than the tag.
     * Check this before subtracting: otherwise the unsigned subtraction
     * below would wrap around to a huge buffer size. */
    tag_length = PSA_AEAD_TAG_LENGTH(key_type, key_bits, alg);
    TEST_EQUAL(input_data->len >= tag_length, 1);
    output_size = input_data->len - tag_length;
    ASSERT_ALLOC(output_data, output_size);

    mbedtls_test_driver_aead_hooks.forced_status = forced_status;
    status = psa_aead_decrypt(key, alg,
                              nonce->x, nonce->len,
                              additional_data->x,
                              additional_data->len,
                              input_data->x, input_data->len,
                              output_data, output_size,
                              &output_length);
    TEST_EQUAL(mbedtls_test_driver_aead_hooks.hits_decrypt, 1);
    TEST_EQUAL(mbedtls_test_driver_aead_hooks.driver_status, forced_status);

    /* A driver returning NOT_SUPPORTED must not make the operation fail. */
    TEST_EQUAL(status, (forced_status == PSA_ERROR_NOT_SUPPORTED) ?
               PSA_SUCCESS : forced_status);

    if (status == PSA_SUCCESS) {
        ASSERT_COMPARE(expected_data->x, expected_data->len,
                       output_data, output_length);
    }

exit:
    psa_destroy_key(key);
    mbedtls_free(output_data);
    PSA_DONE();
    mbedtls_test_driver_aead_hooks = mbedtls_test_driver_aead_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE */
void mac_sign(int key_type_arg,
              data_t *key_data,
              int alg_arg,
              data_t *input,
              data_t *expected_mac,
              int forced_status_arg)
{
    /*
     * One-shot psa_mac_compute() through the driver wrappers, with the MAC
     * driver hook forced to return forced_status_arg.
     *
     * The driver must be hit exactly once. PSA_SUCCESS and
     * PSA_ERROR_NOT_SUPPORTED must both give an overall success, while any
     * other forced status is returned to the caller. When the driver itself
     * succeeded, the MAC must equal expected_mac.
     */
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    psa_mac_operation_t operation = PSA_MAC_OPERATION_INIT;
    psa_key_type_t key_type = key_type_arg;
    psa_algorithm_t alg = alg_arg;
    psa_status_t forced_status = forced_status_arg;
    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
    size_t mac_size =
        PSA_MAC_LENGTH(key_type, PSA_BYTES_TO_BITS(key_data->len), alg);
    size_t mac_length = 0;
    uint8_t *mac = NULL;
    mbedtls_test_driver_mac_hooks = mbedtls_test_driver_mac_hooks_init();

    TEST_ASSERT(mac_size <= PSA_MAC_MAX_SIZE);
    /* PSA_MAC_LENGTH is expected to be exact for the tested algorithms. */
    TEST_ASSERT(expected_mac->len == mac_size);

    PSA_ASSERT(psa_crypto_init());

    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_SIGN_HASH);
    psa_set_key_algorithm(&attributes, alg);
    psa_set_key_type(&attributes, key_type);
    PSA_ASSERT(psa_import_key(&attributes, key_data->x, key_data->len, &key));

    ASSERT_ALLOC(mac, mac_size);
    mbedtls_test_driver_mac_hooks.forced_status = forced_status;

    status = psa_mac_compute(key, alg, input->x, input->len,
                             mac, mac_size, &mac_length);
    TEST_EQUAL(mbedtls_test_driver_mac_hooks.hits, 1);

    if (forced_status != PSA_SUCCESS &&
        forced_status != PSA_ERROR_NOT_SUPPORTED) {
        /* A genuine driver error is propagated unchanged. */
        TEST_EQUAL(forced_status, status);
    } else {
        PSA_ASSERT(status);
    }

    /* Aborting the never-started multipart operation must not add a hit. */
    PSA_ASSERT(psa_mac_abort(&operation));
    TEST_EQUAL(mbedtls_test_driver_mac_hooks.hits, 1);

    if (forced_status == PSA_SUCCESS) {
        ASSERT_COMPARE(expected_mac->x, expected_mac->len,
                       mac, mac_length);
    }

exit:
    psa_mac_abort(&operation);
    psa_destroy_key(key);
    PSA_DONE();
    mbedtls_free(mac);
    mbedtls_test_driver_mac_hooks = mbedtls_test_driver_mac_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE */
void mac_sign_multipart(int key_type_arg,
                        data_t *key_data,
                        int alg_arg,
                        data_t *input,
                        data_t *expected_mac,
                        int fragments_count,
                        int forced_status_arg)
{
    /*
     * Compute a MAC with the multipart API (setup, update*, finish) through
     * the driver wrappers, with the MAC driver hook forced to return
     * forced_status_arg, and check the hook hit counter after each step.
     *
     * input is fed to psa_mac_update() in fragments_count pieces; the last
     * piece also takes the division remainder. fragments_count == 0 means
     * psa_mac_update() is not called at all.
     *
     * Expected outcomes, as asserted below:
     * - PSA_SUCCESS: every call succeeds and reaches the driver, and the
     *   resulting MAC equals expected_mac.
     * - PSA_ERROR_NOT_SUPPORTED: only setup reaches the driver, yet every
     *   call succeeds.
     * - any other status: setup fails with it; later update/finish calls
     *   return PSA_ERROR_BAD_STATE without reaching the driver.
     */
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_type_t key_type = key_type_arg;
    psa_algorithm_t alg = alg_arg;
    psa_mac_operation_t operation = PSA_MAC_OPERATION_INIT;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    uint8_t *actual_mac = NULL;
    size_t mac_buffer_size =
        PSA_MAC_LENGTH(key_type, PSA_BYTES_TO_BITS(key_data->len), alg);
    size_t mac_length = 0;
    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
    psa_status_t forced_status = forced_status_arg;
    uint8_t *input_x = input->x;
    mbedtls_test_driver_mac_hooks = mbedtls_test_driver_mac_hooks_init();

    /* A negative count would otherwise silently skip every update. */
    TEST_ASSERT(fragments_count >= 0);
    TEST_ASSERT(mac_buffer_size <= PSA_MAC_MAX_SIZE);
    /* We expect PSA_MAC_LENGTH to be exact. */
    TEST_ASSERT(expected_mac->len == mac_buffer_size);

    PSA_ASSERT(psa_crypto_init());

    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_SIGN_HASH);
    psa_set_key_algorithm(&attributes, alg);
    psa_set_key_type(&attributes, key_type);

    PSA_ASSERT(psa_import_key(&attributes, key_data->x, key_data->len,
                              &key));

    ASSERT_ALLOC(actual_mac, mac_buffer_size);
    mbedtls_test_driver_mac_hooks.forced_status = forced_status;

    /*
     * Calculate the MAC, multipart case.
     */
    status = psa_mac_sign_setup(&operation, key, alg);
    TEST_EQUAL(mbedtls_test_driver_mac_hooks.hits, 1);

    if (forced_status == PSA_SUCCESS ||
        forced_status == PSA_ERROR_NOT_SUPPORTED) {
        PSA_ASSERT(status);
    } else {
        TEST_EQUAL(forced_status, status);
    }

    /* Every fragment must be non-empty. */
    if (fragments_count > 0) {
        TEST_ASSERT((input->len / (size_t) fragments_count) > 0);
    }

    for (int i = 0; i < fragments_count; i++) {
        /* Even split; the last fragment also absorbs the remainder. */
        size_t fragment_size = input->len / (size_t) fragments_count;
        if (i == fragments_count - 1) {
            fragment_size += input->len % (size_t) fragments_count;
        }

        status = psa_mac_update(&operation,
                                input_x, fragment_size);
        if (forced_status == PSA_SUCCESS) {
            TEST_EQUAL(mbedtls_test_driver_mac_hooks.hits, 2 + i);
        } else {
            TEST_EQUAL(mbedtls_test_driver_mac_hooks.hits, 1);
        }
        if (forced_status == PSA_SUCCESS ||
            forced_status == PSA_ERROR_NOT_SUPPORTED) {
            PSA_ASSERT(status);
        } else {
            TEST_EQUAL(PSA_ERROR_BAD_STATE, status);
        }
        input_x += fragment_size;
    }

    status = psa_mac_sign_finish(&operation,
                                 actual_mac, mac_buffer_size,
                                 &mac_length);
    if (forced_status == PSA_SUCCESS) {
        TEST_EQUAL(mbedtls_test_driver_mac_hooks.hits, 3 + fragments_count);
    } else {
        TEST_EQUAL(mbedtls_test_driver_mac_hooks.hits, 1);
    }

    if (forced_status == PSA_SUCCESS ||
        forced_status == PSA_ERROR_NOT_SUPPORTED) {
        PSA_ASSERT(status);
    } else {
        TEST_EQUAL(PSA_ERROR_BAD_STATE, status);
    }

    /* Aborting the finished operation must not reach the driver again. */
    PSA_ASSERT(psa_mac_abort(&operation));
    if (forced_status == PSA_SUCCESS) {
        TEST_EQUAL(mbedtls_test_driver_mac_hooks.hits, 3 + fragments_count);
    } else {
        TEST_EQUAL(mbedtls_test_driver_mac_hooks.hits, 1);
    }

    if (forced_status == PSA_SUCCESS) {
        ASSERT_COMPARE(expected_mac->x, expected_mac->len,
                       actual_mac, mac_length);
    }

exit:
    psa_mac_abort(&operation);
    psa_destroy_key(key);
    PSA_DONE();
    mbedtls_free(actual_mac);
    mbedtls_test_driver_mac_hooks = mbedtls_test_driver_mac_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE */
void mac_verify(int key_type_arg,
                data_t *key_data,
                int alg_arg,
                data_t *input,
                data_t *expected_mac,
                int forced_status_arg)
{
    /*
     * One-shot psa_mac_verify() through the driver wrappers, with the MAC
     * driver hook forced to return forced_status_arg.
     *
     * The driver must be hit exactly once. PSA_SUCCESS and
     * PSA_ERROR_NOT_SUPPORTED must both give a successful verification,
     * while any other forced status is returned to the caller.
     */
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    psa_mac_operation_t operation = PSA_MAC_OPERATION_INIT;
    psa_key_type_t key_type = key_type_arg;
    psa_algorithm_t alg = alg_arg;
    psa_status_t forced_status = forced_status_arg;
    psa_status_t status = PSA_ERROR_GENERIC_ERROR;
    int expect_success = (forced_status == PSA_SUCCESS ||
                          forced_status == PSA_ERROR_NOT_SUPPORTED);
    mbedtls_test_driver_mac_hooks = mbedtls_test_driver_mac_hooks_init();

    TEST_ASSERT(expected_mac->len <= PSA_MAC_MAX_SIZE);

    PSA_ASSERT(psa_crypto_init());

    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_VERIFY_HASH);
    psa_set_key_algorithm(&attributes, alg);
    psa_set_key_type(&attributes, key_type);
    PSA_ASSERT(psa_import_key(&attributes, key_data->x, key_data->len, &key));

    mbedtls_test_driver_mac_hooks.forced_status = forced_status;

    status = psa_mac_verify(key, alg, input->x, input->len,
                            expected_mac->x, expected_mac->len);
    TEST_EQUAL(mbedtls_test_driver_mac_hooks.hits, 1);
    if (expect_success) {
        PSA_ASSERT(status);
    } else {
        TEST_EQUAL(forced_status, status);
    }

    /* Aborting the never-started multipart operation must not add a hit. */
    PSA_ASSERT(psa_mac_abort(&operation));
    TEST_EQUAL(mbedtls_test_driver_mac_hooks.hits, 1);

exit:
    psa_mac_abort(&operation);
    psa_destroy_key(key);
    PSA_DONE();
    mbedtls_test_driver_mac_hooks = mbedtls_test_driver_mac_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE */
void mac_verify_multipart(int key_type_arg,
                          data_t *key_data,
                          int alg_arg,
                          data_t *input,
                          data_t *expected_mac,
                          int fragments_count,
                          int forced_status_arg)
{
    /*
     * Verify a MAC with the multipart API (setup, update*, verify_finish)
     * through the driver wrappers, with the MAC driver hook forced to return
     * forced_status_arg, and check the hook hit counter after each step.
     *
     * input is fed to psa_mac_update() in fragments_count pieces; the last
     * piece also takes the division remainder. fragments_count == 0 means
     * psa_mac_update() is not called at all.
     *
     * Expected outcomes, as asserted below:
     * - PSA_SUCCESS: every call succeeds and reaches the driver.
     * - PSA_ERROR_NOT_SUPPORTED: only setup reaches the driver, yet every
     *   call succeeds.
     * - any other status: setup fails with it; later update/finish calls
     *   return PSA_ERROR_BAD_STATE without reaching the driver.
     */
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_type_t key_type = key_type_arg;
    psa_algorithm_t alg = alg_arg;
    psa_mac_operation_t operation = PSA_MAC_OPERATION_INIT;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    psa_status_t status = PSA_ERROR_GENERIC_ERROR;
    psa_status_t forced_status = forced_status_arg;
    uint8_t *input_x = input->x;
    mbedtls_test_driver_mac_hooks = mbedtls_test_driver_mac_hooks_init();

    /* A negative count would otherwise silently skip every update. */
    TEST_ASSERT(fragments_count >= 0);
    TEST_ASSERT(expected_mac->len <= PSA_MAC_MAX_SIZE);

    PSA_ASSERT(psa_crypto_init());

    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_VERIFY_HASH);
    psa_set_key_algorithm(&attributes, alg);
    psa_set_key_type(&attributes, key_type);

    PSA_ASSERT(psa_import_key(&attributes, key_data->x, key_data->len,
                              &key));

    mbedtls_test_driver_mac_hooks.forced_status = forced_status;

    /*
     * Verify the MAC, multi-part case.
     */
    status = psa_mac_verify_setup(&operation, key, alg);
    TEST_EQUAL(mbedtls_test_driver_mac_hooks.hits, 1);

    if (forced_status == PSA_SUCCESS ||
        forced_status == PSA_ERROR_NOT_SUPPORTED) {
        PSA_ASSERT(status);
    } else {
        TEST_EQUAL(forced_status, status);
    }

    /* Every fragment must be non-empty. */
    if (fragments_count > 0) {
        TEST_ASSERT((input->len / (size_t) fragments_count) > 0);
    }

    for (int i = 0; i < fragments_count; i++) {
        /* Even split; the last fragment also absorbs the remainder. */
        size_t fragment_size = input->len / (size_t) fragments_count;
        if (i == fragments_count - 1) {
            fragment_size += input->len % (size_t) fragments_count;
        }

        status = psa_mac_update(&operation,
                                input_x, fragment_size);
        if (forced_status == PSA_SUCCESS) {
            TEST_EQUAL(mbedtls_test_driver_mac_hooks.hits, 2 + i);
        } else {
            TEST_EQUAL(mbedtls_test_driver_mac_hooks.hits, 1);
        }

        if (forced_status == PSA_SUCCESS ||
            forced_status == PSA_ERROR_NOT_SUPPORTED) {
            PSA_ASSERT(status);
        } else {
            TEST_EQUAL(PSA_ERROR_BAD_STATE, status);
        }
        input_x += fragment_size;
    }

    status = psa_mac_verify_finish(&operation,
                                   expected_mac->x,
                                   expected_mac->len);
    if (forced_status == PSA_SUCCESS) {
        TEST_EQUAL(mbedtls_test_driver_mac_hooks.hits, 3 + fragments_count);
    } else {
        TEST_EQUAL(mbedtls_test_driver_mac_hooks.hits, 1);
    }

    if (forced_status == PSA_SUCCESS ||
        forced_status == PSA_ERROR_NOT_SUPPORTED) {
        PSA_ASSERT(status);
    } else {
        TEST_EQUAL(PSA_ERROR_BAD_STATE, status);
    }

    /* Aborting the finished operation must not reach the driver again. */
    PSA_ASSERT(psa_mac_abort(&operation));
    if (forced_status == PSA_SUCCESS) {
        TEST_EQUAL(mbedtls_test_driver_mac_hooks.hits, 3 + fragments_count);
    } else {
        TEST_EQUAL(mbedtls_test_driver_mac_hooks.hits, 1);
    }

exit:
    psa_mac_abort(&operation);
    psa_destroy_key(key);
    PSA_DONE();
    mbedtls_test_driver_mac_hooks = mbedtls_test_driver_mac_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE depends_on:PSA_CRYPTO_DRIVER_TEST:MBEDTLS_PSA_CRYPTO_DRIVERS:MBEDTLS_PSA_CRYPTO_BUILTIN_KEYS */
void builtin_key_export(int builtin_key_id_arg,
                        int builtin_key_type_arg,
                        int builtin_key_bits_arg,
                        int builtin_key_algorithm_arg,
                        data_t *expected_output,
                        int expected_status_arg)
{
    /*
     * Export the built-in key identified by builtin_key_id_arg with
     * psa_export_key() into a buffer of exactly expected_output->len bytes.
     *
     * On expected success, check the exported bytes and the key's bits, type
     * and algorithm. Otherwise, check that the status matches and that no
     * output length is reported.
     */
    psa_key_id_t builtin_key_id = (psa_key_id_t) builtin_key_id_arg;
    psa_key_type_t builtin_key_type = (psa_key_type_t) builtin_key_type_arg;
    psa_algorithm_t builtin_key_alg = (psa_algorithm_t) builtin_key_algorithm_arg;
    size_t builtin_key_bits = (size_t) builtin_key_bits_arg;
    psa_status_t expected_status = expected_status_arg;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;

    mbedtls_svc_key_id_t key = mbedtls_svc_key_id_make(0, builtin_key_id);
    uint8_t *output_buffer = NULL;
    size_t output_size = 0;
    psa_status_t actual_status;

    PSA_ASSERT(psa_crypto_init());
    ASSERT_ALLOC(output_buffer, expected_output->len);

    actual_status = psa_export_key(key, output_buffer, expected_output->len, &output_size);

    if (expected_status == PSA_SUCCESS) {
        PSA_ASSERT(actual_status);
        TEST_EQUAL(output_size, expected_output->len);
        ASSERT_COMPARE(output_buffer, output_size,
                       expected_output->x, expected_output->len);

        PSA_ASSERT(psa_get_key_attributes(key, &attributes));
        TEST_EQUAL(psa_get_key_bits(&attributes), builtin_key_bits);
        TEST_EQUAL(psa_get_key_type(&attributes), builtin_key_type);
        TEST_EQUAL(psa_get_key_algorithm(&attributes), builtin_key_alg);
    } else {
        /* Mismatches are reported through the test framework only, as in
         * builtin_pubkey_export; no ad-hoc stderr output. */
        TEST_EQUAL(actual_status, expected_status);
        TEST_EQUAL(output_size, 0);
    }

exit:
    mbedtls_free(output_buffer);
    psa_reset_key_attributes(&attributes);
    psa_destroy_key(key);
    PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE depends_on:PSA_CRYPTO_DRIVER_TEST:MBEDTLS_PSA_CRYPTO_DRIVERS:MBEDTLS_PSA_CRYPTO_BUILTIN_KEYS */
void builtin_pubkey_export(int builtin_key_id_arg,
                           int builtin_key_type_arg,
                           int builtin_key_bits_arg,
                           int builtin_key_algorithm_arg,
                           data_t *expected_output,
                           int expected_status_arg)
{
    /*
     * Export the public part of the built-in key identified by
     * builtin_key_id_arg with psa_export_public_key() into a buffer of
     * exactly expected_output->len bytes.
     *
     * On expected failure, only the status and a zero output length are
     * checked. On expected success, the exported bytes and the key's bits,
     * type and algorithm are checked as well.
     */
    mbedtls_svc_key_id_t key =
        mbedtls_svc_key_id_make(0, (psa_key_id_t) builtin_key_id_arg);
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    psa_status_t expected_status = expected_status_arg;
    psa_status_t actual_status;
    uint8_t *exported = NULL;
    size_t exported_length = 0;

    PSA_ASSERT(psa_crypto_init());
    ASSERT_ALLOC(exported, expected_output->len);

    actual_status = psa_export_public_key(key, exported, expected_output->len,
                                          &exported_length);

    if (expected_status != PSA_SUCCESS) {
        TEST_EQUAL(actual_status, expected_status);
        TEST_EQUAL(exported_length, 0);
    } else {
        PSA_ASSERT(actual_status);
        TEST_EQUAL(exported_length, expected_output->len);
        ASSERT_COMPARE(exported, exported_length,
                       expected_output->x, expected_output->len);

        PSA_ASSERT(psa_get_key_attributes(key, &attributes));
        TEST_EQUAL(psa_get_key_bits(&attributes),
                   (size_t) builtin_key_bits_arg);
        TEST_EQUAL(psa_get_key_type(&attributes),
                   (psa_key_type_t) builtin_key_type_arg);
        TEST_EQUAL(psa_get_key_algorithm(&attributes),
                   (psa_algorithm_t) builtin_key_algorithm_arg);
    }

exit:
    mbedtls_free(exported);
    psa_reset_key_attributes(&attributes);
    psa_destroy_key(key);
    PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE */
void hash_compute(int alg_arg,
                  data_t *input, data_t *hash,
                  int forced_status_arg,
                  int expected_status_arg)
{
    /*
     * One-shot psa_hash_compute() through the driver wrappers, with the hash
     * driver hook forced to return forced_status_arg.
     *
     * The call must return expected_status_arg, hit the driver exactly once,
     * and leave the forced status as the driver's reported status. On
     * success, the digest must equal hash.
     */
    psa_algorithm_t alg = alg_arg;
    psa_status_t forced_status = forced_status_arg;
    psa_status_t expected_status = expected_status_arg;
    size_t digest_size = PSA_HASH_LENGTH(alg);
    unsigned char *digest = NULL;
    size_t digest_length = 0;

    mbedtls_test_driver_hash_hooks = mbedtls_test_driver_hash_hooks_init();
    mbedtls_test_driver_hash_hooks.forced_status = forced_status;

    PSA_ASSERT(psa_crypto_init());
    ASSERT_ALLOC(digest, digest_size);

    TEST_EQUAL(psa_hash_compute(alg, input->x, input->len,
                                digest, digest_size, &digest_length),
               expected_status);
    TEST_EQUAL(mbedtls_test_driver_hash_hooks.hits, 1);
    TEST_EQUAL(mbedtls_test_driver_hash_hooks.driver_status, forced_status);

    if (expected_status == PSA_SUCCESS) {
        ASSERT_COMPARE(digest, digest_length, hash->x, hash->len);
    }

exit:
    mbedtls_free(digest);
    PSA_DONE();
    mbedtls_test_driver_hash_hooks = mbedtls_test_driver_hash_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE */
void hash_multipart_setup(int alg_arg,
                          data_t *input, data_t *hash,
                          int forced_status_arg,
                          int expected_status_arg)
{
    /*
     * Force the driver status observed by psa_hash_setup(). When setup is
     * expected to succeed, drive the operation to completion and check the
     * cumulative hook hit counter after each step, then the digest.
     */
    const psa_algorithm_t alg = alg_arg;
    const psa_status_t forced_status = forced_status_arg;
    const psa_status_t expected_status = expected_status_arg;
    const size_t digest_size = PSA_HASH_LENGTH(alg);
    /* The expectations below encode that, once the hook reported
     * PSA_ERROR_NOT_SUPPORTED at setup, later steps do not hit it again. */
    const int driver_bypassed = (forced_status == PSA_ERROR_NOT_SUPPORTED);
    psa_hash_operation_t hash_op = PSA_HASH_OPERATION_INIT;
    unsigned char *digest = NULL;
    size_t digest_length;

    mbedtls_test_driver_hash_hooks = mbedtls_test_driver_hash_hooks_init();
    ASSERT_ALLOC(digest, digest_size);

    PSA_ASSERT(psa_crypto_init());

    mbedtls_test_driver_hash_hooks.forced_status = forced_status;
    TEST_EQUAL(psa_hash_setup(&hash_op, alg), expected_status);
    TEST_EQUAL(mbedtls_test_driver_hash_hooks.hits, 1);
    TEST_EQUAL(mbedtls_test_driver_hash_hooks.driver_status, forced_status);

    if (expected_status == PSA_SUCCESS) {
        /* update: one extra hook hit, unless the driver was bypassed */
        PSA_ASSERT(psa_hash_update(&hash_op, input->x, input->len));
        TEST_EQUAL(mbedtls_test_driver_hash_hooks.hits,
                   driver_bypassed ? 1 : 2);
        TEST_EQUAL(mbedtls_test_driver_hash_hooks.driver_status, forced_status);

        /* finish: two extra hook hits (finish + abort), unless bypassed */
        PSA_ASSERT(psa_hash_finish(&hash_op,
                                   digest, digest_size,
                                   &digest_length));
        TEST_EQUAL(mbedtls_test_driver_hash_hooks.hits,
                   driver_bypassed ? 1 : 4);
        TEST_EQUAL(mbedtls_test_driver_hash_hooks.driver_status, forced_status);

        ASSERT_COMPARE(digest, digest_length, hash->x, hash->len);
    }

exit:
    psa_hash_abort(&hash_op);
    mbedtls_free(digest);
    PSA_DONE();
    mbedtls_test_driver_hash_hooks = mbedtls_test_driver_hash_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE */
void hash_multipart_update(int alg_arg,
                           data_t *input, data_t *hash,
                           int forced_status_arg)
{
    /*
     * Exercise psa_hash_update() against a forced driver status:
     *  - updating an inactive operation fails with PSA_ERROR_BAD_STATE
     *    without reaching the hook;
     *  - after a normal setup, update returns the forced status;
     *  - when that succeeds, the operation is finished and the digest
     *    is compared with the expected hash.
     */
    const psa_algorithm_t alg = alg_arg;
    const psa_status_t forced_status = forced_status_arg;
    const size_t digest_size = PSA_HASH_LENGTH(alg);
    const int update_ok = (forced_status == PSA_SUCCESS);
    psa_hash_operation_t hash_op = PSA_HASH_OPERATION_INIT;
    unsigned char *digest = NULL;
    size_t digest_length;

    mbedtls_test_driver_hash_hooks = mbedtls_test_driver_hash_hooks_init();
    ASSERT_ALLOC(digest, digest_size);

    PSA_ASSERT(psa_crypto_init());

    /* Inactive operation: rejected, hook never called. */
    TEST_EQUAL(psa_hash_update(&hash_op, input->x, input->len),
               PSA_ERROR_BAD_STATE);
    TEST_EQUAL(mbedtls_test_driver_hash_hooks.hits, 0);

    /* Regular setup, nothing forced yet. */
    PSA_ASSERT(psa_hash_setup(&hash_op, alg));
    TEST_EQUAL(mbedtls_test_driver_hash_hooks.hits, 1);
    TEST_EQUAL(mbedtls_test_driver_hash_hooks.driver_status, PSA_SUCCESS);

    /* Forced update: one more hook hit on success, two on failure
     * (update + abort). */
    mbedtls_test_driver_hash_hooks.forced_status = forced_status;
    TEST_EQUAL(psa_hash_update(&hash_op, input->x, input->len),
               forced_status);
    TEST_EQUAL(mbedtls_test_driver_hash_hooks.hits, update_ok ? 2 : 3);
    TEST_EQUAL(mbedtls_test_driver_hash_hooks.driver_status, forced_status);

    if (update_ok) {
        /* Fresh counters: finishing costs two hook hits (finish + abort). */
        mbedtls_test_driver_hash_hooks = mbedtls_test_driver_hash_hooks_init();
        PSA_ASSERT(psa_hash_finish(&hash_op,
                                   digest, digest_size,
                                   &digest_length));
        TEST_EQUAL(mbedtls_test_driver_hash_hooks.hits, 2);
        TEST_EQUAL(mbedtls_test_driver_hash_hooks.driver_status, PSA_SUCCESS);

        ASSERT_COMPARE(digest, digest_length, hash->x, hash->len);
    }

exit:
    psa_hash_abort(&hash_op);
    mbedtls_free(digest);
    PSA_DONE();
    mbedtls_test_driver_hash_hooks = mbedtls_test_driver_hash_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE */
void hash_multipart_finish(int alg_arg,
                           data_t *input, data_t *hash,
                           int forced_status_arg)
{
    /*
     * Exercise psa_hash_finish() against a forced driver status:
     *  - finishing an inactive operation fails with PSA_ERROR_BAD_STATE
     *    without reaching the hook;
     *  - after a normal setup + update, finish returns the forced status
     *    and costs two hook hits (finish + abort);
     *  - on success the digest is compared with the expected hash.
     */
    const psa_algorithm_t alg = alg_arg;
    const psa_status_t forced_status = forced_status_arg;
    const size_t digest_size = PSA_HASH_LENGTH(alg);
    psa_hash_operation_t hash_op = PSA_HASH_OPERATION_INIT;
    unsigned char *digest = NULL;
    size_t digest_length;

    mbedtls_test_driver_hash_hooks = mbedtls_test_driver_hash_hooks_init();
    ASSERT_ALLOC(digest, digest_size);

    PSA_ASSERT(psa_crypto_init());

    /* Inactive operation: rejected, hook never called. */
    TEST_EQUAL(psa_hash_finish(&hash_op, digest, digest_size,
                               &digest_length),
               PSA_ERROR_BAD_STATE);
    TEST_EQUAL(mbedtls_test_driver_hash_hooks.hits, 0);

    /* Regular setup and update, nothing forced yet. */
    PSA_ASSERT(psa_hash_setup(&hash_op, alg));
    TEST_EQUAL(mbedtls_test_driver_hash_hooks.hits, 1);
    TEST_EQUAL(mbedtls_test_driver_hash_hooks.driver_status, PSA_SUCCESS);

    PSA_ASSERT(psa_hash_update(&hash_op, input->x, input->len));
    TEST_EQUAL(mbedtls_test_driver_hash_hooks.hits, 2);
    TEST_EQUAL(mbedtls_test_driver_hash_hooks.driver_status, PSA_SUCCESS);

    /* Forced finish. */
    mbedtls_test_driver_hash_hooks.forced_status = forced_status;
    TEST_EQUAL(psa_hash_finish(&hash_op,
                               digest, digest_size,
                               &digest_length),
               forced_status);
    TEST_EQUAL(mbedtls_test_driver_hash_hooks.hits, 4);
    TEST_EQUAL(mbedtls_test_driver_hash_hooks.driver_status, forced_status);

    if (forced_status == PSA_SUCCESS) {
        ASSERT_COMPARE(digest, digest_length, hash->x, hash->len);
    }

exit:
    psa_hash_abort(&hash_op);
    mbedtls_free(digest);
    PSA_DONE();
    mbedtls_test_driver_hash_hooks = mbedtls_test_driver_hash_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE */
void hash_clone(int alg_arg,
                data_t *input, data_t *hash,
                int forced_status_arg)
{
    /*
     * Exercise psa_hash_clone() against a forced driver status:
     *  - cloning an inactive operation fails with PSA_ERROR_BAD_STATE
     *    without reaching the hook;
     *  - cloning a freshly set-up operation returns the forced status
     *    (the test expects one extra hook hit when it fails);
     *  - on success the clone alone is updated and finished, and its
     *    digest is compared with the expected hash.
     */
    const psa_algorithm_t alg = alg_arg;
    const psa_status_t forced_status = forced_status_arg;
    const size_t digest_size = PSA_HASH_LENGTH(alg);
    const int clone_ok = (forced_status == PSA_SUCCESS);
    psa_hash_operation_t original_op = PSA_HASH_OPERATION_INIT;
    psa_hash_operation_t cloned_op = PSA_HASH_OPERATION_INIT;
    unsigned char *digest = NULL;
    size_t digest_length;

    mbedtls_test_driver_hash_hooks = mbedtls_test_driver_hash_hooks_init();
    ASSERT_ALLOC(digest, digest_size);

    PSA_ASSERT(psa_crypto_init());

    /* Inactive source: rejected, hook never called. */
    TEST_EQUAL(psa_hash_clone(&original_op, &cloned_op),
               PSA_ERROR_BAD_STATE);
    TEST_EQUAL(mbedtls_test_driver_hash_hooks.hits, 0);

    /* Regular setup of the source, nothing forced yet. */
    PSA_ASSERT(psa_hash_setup(&original_op, alg));
    TEST_EQUAL(mbedtls_test_driver_hash_hooks.hits, 1);
    TEST_EQUAL(mbedtls_test_driver_hash_hooks.driver_status, PSA_SUCCESS);

    /* Forced clone. */
    mbedtls_test_driver_hash_hooks.forced_status = forced_status;
    TEST_EQUAL(psa_hash_clone(&original_op, &cloned_op), forced_status);
    TEST_EQUAL(mbedtls_test_driver_hash_hooks.hits, clone_ok ? 2 : 3);
    TEST_EQUAL(mbedtls_test_driver_hash_hooks.driver_status, forced_status);

    if (clone_ok) {
        /* Fresh counters: update (1 hit), then finish + abort (2 hits). */
        mbedtls_test_driver_hash_hooks = mbedtls_test_driver_hash_hooks_init();
        PSA_ASSERT(psa_hash_update(&cloned_op,
                                   input->x, input->len));
        TEST_EQUAL(mbedtls_test_driver_hash_hooks.hits, 1);
        TEST_EQUAL(mbedtls_test_driver_hash_hooks.driver_status, PSA_SUCCESS);

        PSA_ASSERT(psa_hash_finish(&cloned_op,
                                   digest, digest_size,
                                   &digest_length));
        TEST_EQUAL(mbedtls_test_driver_hash_hooks.hits, 3);
        TEST_EQUAL(mbedtls_test_driver_hash_hooks.driver_status, PSA_SUCCESS);

        ASSERT_COMPARE(digest, digest_length, hash->x, hash->len);
    }

exit:
    psa_hash_abort(&original_op);
    psa_hash_abort(&cloned_op);
    mbedtls_free(digest);
    PSA_DONE();
    mbedtls_test_driver_hash_hooks = mbedtls_test_driver_hash_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE */
void asymmetric_encrypt_decrypt(int alg_arg,
                                data_t *key_data,
                                data_t *input_data,
                                data_t *label,
                                data_t *fake_output_encrypt,
                                data_t *fake_output_decrypt,
                                int forced_status_encrypt_arg,
                                int forced_status_decrypt_arg,
                                int expected_status_encrypt_arg,
                                int expected_status_decrypt_arg)
{
    /*
     * Encrypt input_data with an imported RSA key pair through the
     * asymmetric encryption driver hooks. Unless a fake encryption output
     * is injected, the ciphertext is then decrypted and compared with the
     * plaintext (or with the injected fake decryption output).
     *
     * A non-empty fake_output_* is installed as the hook's forced output;
     * on success the test expects exactly those bytes back.
     */
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_type_t key_type = PSA_KEY_TYPE_RSA_KEY_PAIR;
    psa_algorithm_t alg = alg_arg;
    size_t key_bits;
    unsigned char *output = NULL;
    size_t output_size;
    size_t output_length = ~0;
    unsigned char *output2 = NULL;
    size_t output2_size;
    size_t output2_length = ~0;
    psa_status_t forced_status_encrypt = forced_status_encrypt_arg;
    psa_status_t forced_status_decrypt = forced_status_decrypt_arg;
    psa_status_t expected_status_encrypt = expected_status_encrypt_arg;
    psa_status_t expected_status_decrypt = expected_status_decrypt_arg;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;

    PSA_ASSERT(psa_crypto_init());
    mbedtls_test_driver_asymmetric_encryption_hooks =
        mbedtls_test_driver_asymmetric_encryption_hooks_init();

    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT);
    psa_set_key_algorithm(&attributes, alg);
    psa_set_key_type(&attributes, key_type);

    PSA_ASSERT(psa_import_key(&attributes, key_data->x, key_data->len,
                              &key));

    /* Determine the maximum ciphertext length */
    PSA_ASSERT(psa_get_key_attributes(key, &attributes));
    key_bits = psa_get_key_bits(&attributes);

    mbedtls_test_driver_asymmetric_encryption_hooks.forced_status =
        forced_status_encrypt;
    if (fake_output_encrypt->len > 0) {
        mbedtls_test_driver_asymmetric_encryption_hooks.forced_output =
            fake_output_encrypt->x;
        mbedtls_test_driver_asymmetric_encryption_hooks.forced_output_length =
            fake_output_encrypt->len;
        output_size = fake_output_encrypt->len;
        ASSERT_ALLOC(output, output_size);
    } else {
        output_size = PSA_ASYMMETRIC_ENCRYPT_OUTPUT_SIZE(key_type, key_bits, alg);
        TEST_ASSERT(output_size <= PSA_ASYMMETRIC_ENCRYPT_OUTPUT_MAX_SIZE);
        ASSERT_ALLOC(output, output_size);
    }

    /* We test encryption by checking that encrypt-then-decrypt gives back
     * the original plaintext because of the non-optional random
     * part of encryption process which prevents using fixed vectors. */
    TEST_EQUAL(psa_asymmetric_encrypt(key, alg,
                                      input_data->x, input_data->len,
                                      label->x, label->len,
                                      output, output_size,
                                      &output_length), expected_status_encrypt);
    /* We don't know what ciphertext length to expect, but check that
     * it looks sensible. */
    TEST_ASSERT(output_length <= output_size);

    if (expected_status_encrypt == PSA_SUCCESS) {
        if (fake_output_encrypt->len > 0) {
            ASSERT_COMPARE(fake_output_encrypt->x, fake_output_encrypt->len,
                           output, output_length);
        } else {
            mbedtls_test_driver_asymmetric_encryption_hooks.forced_status =
                forced_status_decrypt;
            if (fake_output_decrypt->len > 0) {
                mbedtls_test_driver_asymmetric_encryption_hooks.forced_output =
                    fake_output_decrypt->x;
                mbedtls_test_driver_asymmetric_encryption_hooks.forced_output_length =
                    fake_output_decrypt->len;
                output2_size = fake_output_decrypt->len;
                ASSERT_ALLOC(output2, output2_size);
            } else {
                output2_size = input_data->len;
                TEST_ASSERT(output2_size <=
                            PSA_ASYMMETRIC_DECRYPT_OUTPUT_SIZE(key_type, key_bits, alg));
                TEST_ASSERT(output2_size <= PSA_ASYMMETRIC_DECRYPT_OUTPUT_MAX_SIZE);
                ASSERT_ALLOC(output2, output2_size);
            }

            TEST_EQUAL(psa_asymmetric_decrypt(key, alg,
                                              output, output_length,
                                              label->x, label->len,
                                              output2, output2_size,
                                              &output2_length), expected_status_decrypt);
            if (expected_status_decrypt == PSA_SUCCESS) {
                if (fake_output_decrypt->len > 0) {
                    ASSERT_COMPARE(fake_output_decrypt->x, fake_output_decrypt->len,
                                   output2, output2_length);
                } else {
                    ASSERT_COMPARE(input_data->x, input_data->len,
                                   output2, output2_length);
                }
            }
        }
    }

exit:
    /*
     * Key attributes may have been returned by psa_get_key_attributes()
     * thus reset them as required.
     */
    psa_reset_key_attributes(&attributes);

    psa_destroy_key(key);
    mbedtls_free(output);
    mbedtls_free(output2);
    PSA_DONE();
    /* forced_output may point at this test case's parameters; do not leave
     * the global hooks referencing them once the case is over. */
    mbedtls_test_driver_asymmetric_encryption_hooks =
        mbedtls_test_driver_asymmetric_encryption_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE */
void asymmetric_decrypt(int alg_arg,
                        data_t *key_data,
                        data_t *input_data,
                        data_t *label,
                        data_t *expected_output_data,
                        data_t *fake_output_decrypt,
                        int forced_status_decrypt_arg,
                        int expected_status_decrypt_arg)
{
    /*
     * Decrypt input_data with an imported RSA key pair through the
     * asymmetric encryption driver hooks and, on success, compare the
     * result with expected_output_data.
     *
     * A non-empty fake_output_decrypt is installed as the hook's forced
     * output and sizes the output buffer.
     */
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_type_t key_type = PSA_KEY_TYPE_RSA_KEY_PAIR;
    psa_algorithm_t alg = alg_arg;
    unsigned char *output = NULL;
    size_t output_size;
    size_t output_length = ~0;
    psa_status_t forced_status_decrypt = forced_status_decrypt_arg;
    psa_status_t expected_status_decrypt = expected_status_decrypt_arg;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;

    PSA_ASSERT(psa_crypto_init());
    mbedtls_test_driver_asymmetric_encryption_hooks =
        mbedtls_test_driver_asymmetric_encryption_hooks_init();

    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_DECRYPT);
    psa_set_key_algorithm(&attributes, alg);
    psa_set_key_type(&attributes, key_type);

    PSA_ASSERT(psa_import_key(&attributes, key_data->x, key_data->len,
                              &key));

    mbedtls_test_driver_asymmetric_encryption_hooks.forced_status =
        forced_status_decrypt;

    if (fake_output_decrypt->len > 0) {
        mbedtls_test_driver_asymmetric_encryption_hooks.forced_output =
            fake_output_decrypt->x;
        mbedtls_test_driver_asymmetric_encryption_hooks.forced_output_length =
            fake_output_decrypt->len;
        output_size = fake_output_decrypt->len;
    } else {
        output_size = expected_output_data->len;
    }
    ASSERT_ALLOC(output, output_size);

    TEST_EQUAL(psa_asymmetric_decrypt(key, alg,
                                      input_data->x, input_data->len,
                                      label->x, label->len,
                                      output, output_size,
                                      &output_length), expected_status_decrypt);
    if (expected_status_decrypt == PSA_SUCCESS) {
        TEST_EQUAL(output_length, expected_output_data->len);
        ASSERT_COMPARE(expected_output_data->x, expected_output_data->len,
                       output, output_length);
    }
exit:
    /* Attributes are only populated through psa_set_key_*() here; resetting
     * them is harmless and keeps cleanup uniform with the other cases. */
    psa_reset_key_attributes(&attributes);

    psa_destroy_key(key);
    mbedtls_free(output);
    PSA_DONE();
    /* forced_output may point at this test case's parameters; do not leave
     * the global hooks referencing them once the case is over. */
    mbedtls_test_driver_asymmetric_encryption_hooks =
        mbedtls_test_driver_asymmetric_encryption_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE */
void asymmetric_encrypt(int alg_arg,
                        data_t *key_data,
                        data_t *modulus,
                        data_t *private_exponent,
                        data_t *input_data,
                        data_t *label,
                        data_t *fake_output_encrypt,
                        int forced_status_encrypt_arg,
                        int expected_status_encrypt_arg)
{
    /*
     * Encrypt input_data with an imported RSA public key through the
     * asymmetric encryption driver hooks.
     *
     * If fake_output_encrypt is non-empty it is installed as the hook's
     * forced output and, on success, must come back verbatim. Otherwise the
     * real ciphertext is checked with sanity_check_rsa_encryption_result()
     * using modulus and private_exponent.
     */
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_type_t key_type = PSA_KEY_TYPE_RSA_PUBLIC_KEY;
    psa_algorithm_t alg = alg_arg;
    size_t key_bits;
    unsigned char *output = NULL;
    size_t output_size;
    size_t output_length = ~0;
    psa_status_t forced_status_encrypt = forced_status_encrypt_arg;
    psa_status_t expected_status_encrypt = expected_status_encrypt_arg;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;

    PSA_ASSERT(psa_crypto_init());
    mbedtls_test_driver_asymmetric_encryption_hooks =
        mbedtls_test_driver_asymmetric_encryption_hooks_init();

    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_ENCRYPT);
    psa_set_key_algorithm(&attributes, alg);
    psa_set_key_type(&attributes, key_type);

    PSA_ASSERT(psa_import_key(&attributes, key_data->x, key_data->len,
                              &key));

    PSA_ASSERT(psa_get_key_attributes(key, &attributes));
    key_bits = psa_get_key_bits(&attributes);

    mbedtls_test_driver_asymmetric_encryption_hooks.forced_status =
        forced_status_encrypt;

    if (fake_output_encrypt->len > 0) {
        mbedtls_test_driver_asymmetric_encryption_hooks.forced_output =
            fake_output_encrypt->x;
        mbedtls_test_driver_asymmetric_encryption_hooks.forced_output_length =
            fake_output_encrypt->len;
        output_size = fake_output_encrypt->len;
        ASSERT_ALLOC(output, output_size);
    } else {
        output_size = PSA_ASYMMETRIC_ENCRYPT_OUTPUT_SIZE(key_type, key_bits, alg);
        /* Same bound check as in asymmetric_encrypt_decrypt. */
        TEST_ASSERT(output_size <= PSA_ASYMMETRIC_ENCRYPT_OUTPUT_MAX_SIZE);
        ASSERT_ALLOC(output, output_size);
    }

    TEST_EQUAL(psa_asymmetric_encrypt(key, alg,
                                      input_data->x, input_data->len,
                                      label->x, label->len,
                                      output, output_size,
                                      &output_length), expected_status_encrypt);
    if (expected_status_encrypt == PSA_SUCCESS) {
        if (fake_output_encrypt->len > 0) {
            TEST_EQUAL(fake_output_encrypt->len, output_length);
            ASSERT_COMPARE(fake_output_encrypt->x, fake_output_encrypt->len,
                           output, output_length);
        } else {
            /* Perform sanity checks on the output */
#if defined(PSA_WANT_KEY_TYPE_RSA_PUBLIC_KEY)
            if (PSA_KEY_TYPE_IS_RSA(key_type)) {
                if (!sanity_check_rsa_encryption_result(
                        alg, modulus, private_exponent,
                        input_data,
                        output, output_length)) {
                    goto exit;
                }
            } else
#endif
            {
                (void) modulus;
                (void) private_exponent;
                TEST_ASSERT(!"Encryption sanity checks not implemented for this key type");
            }
        }
    }
exit:
    /*
     * Key attributes may have been returned by psa_get_key_attributes()
     * thus reset them as required.
     */
    psa_reset_key_attributes(&attributes);

    psa_destroy_key(key);
    mbedtls_free(output);
    PSA_DONE();
    /* forced_output may point at this test case's parameters; do not leave
     * the global hooks referencing them once the case is over. */
    mbedtls_test_driver_asymmetric_encryption_hooks =
        mbedtls_test_driver_asymmetric_encryption_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE */
void aead_encrypt_setup(int key_type_arg, data_t *key_data,
                        int alg_arg,
                        data_t *nonce,
                        data_t *additional_data,
                        data_t *input_data,
                        data_t *expected_ciphertext,
                        data_t *expected_tag,
                        int forced_status_arg,
                        int expected_status_arg)
{
    /*
     * Multi-part AEAD encryption through the test driver hooks, with the
     * hook status forced to forced_status. psa_aead_encrypt_setup() must
     * return expected_status. On success the whole operation is run and
     * the ciphertext and tag are compared with the expected values. The
     * test expects the later hooks to be hit only when forced_status is
     * PSA_SUCCESS.
     */
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_type_t key_type = key_type_arg;
    psa_algorithm_t alg = alg_arg;
    size_t key_bits;
    psa_status_t forced_status = forced_status_arg;
    psa_status_t expected_status = expected_status_arg;
    uint8_t *output_data = NULL;
    size_t output_size = 0;
    size_t output_length = 0;
    size_t finish_output_length = 0;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    psa_status_t status = PSA_ERROR_GENERIC_ERROR;
    size_t tag_length = 0;
    uint8_t tag_buffer[PSA_AEAD_TAG_MAX_SIZE];

    psa_aead_operation_t operation = psa_aead_operation_init();

    mbedtls_test_driver_aead_hooks = mbedtls_test_driver_aead_hooks_init();

    PSA_INIT();

    mbedtls_test_driver_aead_hooks.forced_status = forced_status;

    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_ENCRYPT);
    psa_set_key_algorithm(&attributes, alg);
    psa_set_key_type(&attributes, key_type);

    PSA_ASSERT(psa_import_key(&attributes, key_data->x, key_data->len,
                              &key));
    PSA_ASSERT(psa_get_key_attributes(key, &attributes));
    key_bits = psa_get_key_bits(&attributes);

    output_size = input_data->len + PSA_AEAD_TAG_LENGTH(key_type, key_bits,
                                                        alg);

    /* For all currently defined algorithms, PSA_AEAD_ENCRYPT_OUTPUT_SIZE
     * should be exact. */
    TEST_EQUAL(output_size,
               PSA_AEAD_ENCRYPT_OUTPUT_SIZE(key_type, alg, input_data->len));
    TEST_ASSERT(output_size <=
                PSA_AEAD_ENCRYPT_OUTPUT_MAX_SIZE(input_data->len));
    ASSERT_ALLOC(output_data, output_size);

    status = psa_aead_encrypt_setup(&operation, key, alg);

    TEST_EQUAL(status, expected_status);
    TEST_EQUAL(mbedtls_test_driver_aead_hooks.hits_encrypt_setup, 1);

    if (status == PSA_SUCCESS) {
        /* Set the nonce. */
        PSA_ASSERT(psa_aead_set_nonce(&operation, nonce->x, nonce->len));

        TEST_EQUAL(mbedtls_test_driver_aead_hooks.hits_set_nonce,
                   forced_status == PSA_SUCCESS ? 1 : 0);

        /* Check hooks hits and
         * set length (additional data and data to encrypt) */
        PSA_ASSERT(psa_aead_set_lengths(&operation, additional_data->len,
                                        input_data->len));

        TEST_EQUAL(mbedtls_test_driver_aead_hooks.hits_set_lengths,
                   forced_status == PSA_SUCCESS ? 1 : 0);

        /* Pass the additional data */
        PSA_ASSERT(psa_aead_update_ad(&operation, additional_data->x,
                                      additional_data->len));

        TEST_EQUAL(mbedtls_test_driver_aead_hooks.hits_update_ad,
                   forced_status == PSA_SUCCESS ? 1 : 0);

        /* Pass the data to encrypt */
        PSA_ASSERT(psa_aead_update(&operation, input_data->x, input_data->len,
                                   output_data, output_size, &output_length));

        TEST_EQUAL(mbedtls_test_driver_aead_hooks.hits_update,
                   forced_status == PSA_SUCCESS ? 1 : 0);

        /* Finish the encryption operation */
        PSA_ASSERT(psa_aead_finish(&operation, output_data + output_length,
                                   output_size - output_length,
                                   &finish_output_length, tag_buffer,
                                   sizeof(tag_buffer), &tag_length));

        TEST_EQUAL(mbedtls_test_driver_aead_hooks.hits_finish,
                   forced_status == PSA_SUCCESS ? 1 : 0);

        TEST_EQUAL(mbedtls_test_driver_aead_hooks.hits_abort,
                   forced_status == PSA_SUCCESS ? 1 : 0);

        /* Compare output_data and expected_ciphertext */
        ASSERT_COMPARE(expected_ciphertext->x, expected_ciphertext->len,
                       output_data, output_length + finish_output_length);

        /* Compare tag and expected_tag */
        ASSERT_COMPARE(expected_tag->x, expected_tag->len, tag_buffer, tag_length);
    }

exit:
    /* Cleanup. Plain calls only: an assertion macro here would jump back to
     * this label on failure and loop forever. */
    /* Releases the operation if an assertion failed mid-sequence; no-op once
     * the operation has been finished or setup failed. */
    psa_aead_abort(&operation);
    psa_destroy_key(key);
    /* Attributes were filled in by psa_get_key_attributes(). */
    psa_reset_key_attributes(&attributes);
    mbedtls_free(output_data);
    PSA_DONE();
    mbedtls_test_driver_aead_hooks = mbedtls_test_driver_aead_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE */
void aead_decrypt_setup(int key_type_arg, data_t *key_data,
                        int alg_arg,
                        data_t *nonce,
                        data_t *additional_data,
                        data_t *input_ciphertext,
                        data_t *input_tag,
                        data_t *expected_result,
                        int forced_status_arg,
                        int expected_status_arg)
{
    /*
     * Multi-part AEAD decryption through the test driver hooks, with the
     * hook status forced to forced_status. The test expects a forced
     * PSA_ERROR_NOT_SUPPORTED to still let psa_aead_decrypt_setup()
     * succeed, and any other forced status to be returned as is. On success
     * the whole operation is run, the tag verified and the plaintext
     * compared with expected_result.
     */
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_type_t key_type = key_type_arg;
    psa_algorithm_t alg = alg_arg;
    unsigned char *output_data = NULL;
    size_t output_size = 0;
    size_t output_length = 0;
    size_t verify_output_length = 0;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    psa_status_t forced_status = forced_status_arg;
    psa_status_t expected_status = expected_status_arg;
    psa_status_t status = PSA_ERROR_GENERIC_ERROR;

    psa_aead_operation_t operation = psa_aead_operation_init();
    mbedtls_test_driver_aead_hooks = mbedtls_test_driver_aead_hooks_init();

    PSA_INIT();

    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_DECRYPT);
    psa_set_key_algorithm(&attributes, alg);
    psa_set_key_type(&attributes, key_type);

    PSA_ASSERT(psa_import_key(&attributes, key_data->x, key_data->len,
                              &key));

    output_size = input_ciphertext->len;

    ASSERT_ALLOC(output_data, output_size);

    mbedtls_test_driver_aead_hooks.forced_status = forced_status;

    status = psa_aead_decrypt_setup(&operation, key, alg);

    TEST_EQUAL(status, (forced_status == PSA_ERROR_NOT_SUPPORTED) ?
               PSA_SUCCESS : forced_status);

    TEST_EQUAL(status, expected_status);
    TEST_EQUAL(mbedtls_test_driver_aead_hooks.hits_decrypt_setup, 1);

    if (status == PSA_SUCCESS) {
        PSA_ASSERT(psa_aead_set_nonce(&operation, nonce->x, nonce->len));
        TEST_EQUAL(mbedtls_test_driver_aead_hooks.hits_set_nonce,
                   forced_status == PSA_SUCCESS ? 1 : 0);

        PSA_ASSERT(psa_aead_set_lengths(&operation, additional_data->len,
                                        input_ciphertext->len));

        TEST_EQUAL(mbedtls_test_driver_aead_hooks.hits_set_lengths,
                   forced_status == PSA_SUCCESS ? 1 : 0);

        PSA_ASSERT(psa_aead_update_ad(&operation, additional_data->x,
                                      additional_data->len));

        TEST_EQUAL(mbedtls_test_driver_aead_hooks.hits_update_ad,
                   forced_status == PSA_SUCCESS ? 1 : 0);

        PSA_ASSERT(psa_aead_update(&operation, input_ciphertext->x,
                                   input_ciphertext->len, output_data,
                                   output_size, &output_length));

        TEST_EQUAL(mbedtls_test_driver_aead_hooks.hits_update,
                   forced_status == PSA_SUCCESS ? 1 : 0);

        /* Offset applied to output_data in order to handle cases where verify()
         * outputs further data */
        PSA_ASSERT(psa_aead_verify(&operation, output_data + output_length,
                                   output_size - output_length,
                                   &verify_output_length, input_tag->x,
                                   input_tag->len));

        TEST_EQUAL(mbedtls_test_driver_aead_hooks.hits_verify,
                   forced_status == PSA_SUCCESS ? 1 : 0);

        /* Since this is a decryption operation,
         * finish should never be hit */
        TEST_EQUAL(mbedtls_test_driver_aead_hooks.hits_finish, 0);

        TEST_EQUAL(mbedtls_test_driver_aead_hooks.hits_abort,
                   forced_status == PSA_SUCCESS ? 1 : 0);

        ASSERT_COMPARE(expected_result->x, expected_result->len,
                       output_data, output_length + verify_output_length);
    }

exit:
    /* Cleanup. Plain calls only: an assertion macro here would jump back to
     * this label on failure and loop forever. */
    /* Releases the operation if an assertion failed mid-sequence; no-op once
     * the operation has been verified or setup failed. */
    psa_aead_abort(&operation);
    psa_destroy_key(key);
    psa_reset_key_attributes(&attributes);
    mbedtls_free(output_data);
    PSA_DONE();
    mbedtls_test_driver_aead_hooks = mbedtls_test_driver_aead_hooks_init();
}
/* END_CASE */

/* BEGIN_CASE depends_on:PSA_WANT_ALG_JPAKE */
void pake_operations(data_t *pw_data, int forced_status_setup_arg, int forced_status_arg,
                     data_t *forced_output, int expected_status_arg,
                     int fut)
{
    /*
     * Exercise one PAKE driver entry point, selected by `fut`, through the
     * public PSA API while the test driver forces its return status.
     *
     * pw_data                 password imported as a PSA_KEY_TYPE_PASSWORD key
     *                         (only imported when non-empty).
     * forced_status_setup_arg status forced on the driver's setup entry point;
     *                         PSA_SUCCESS means the operation stays in the
     *                         test driver (`in_driver`).
     * forced_status_arg       status forced on the entry point under test.
     * forced_output           when non-empty (fut == 3), bytes the driver is
     *                         forced to return from its output entry point.
     * expected_status_arg     status expected from the API call under test.
     * fut                     0: setup via input, 1: setup via output,
     *                         2: input, 3: output, 4: get_implicit_key,
     *                         5: abort.
     */
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_status_t forced_status = forced_status_arg;
    psa_status_t forced_status_setup = forced_status_setup_arg;
    psa_status_t expected_status = expected_status_arg;
    psa_pake_operation_t operation = psa_pake_operation_init();
    psa_pake_cipher_suite_t cipher_suite = psa_pake_cipher_suite_init();
    psa_key_derivation_operation_t implicit_key =
        PSA_KEY_DERIVATION_OPERATION_INIT;
    psa_pake_primitive_t primitive = PSA_PAKE_PRIMITIVE(
        PSA_PAKE_PRIMITIVE_TYPE_ECC,
        PSA_ECC_FAMILY_SECP_R1, 256);
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    unsigned char *input_buffer = NULL;
    const size_t size_key_share = PSA_PAKE_INPUT_SIZE(PSA_ALG_JPAKE, primitive,
                                                      PSA_PAKE_STEP_KEY_SHARE);
    unsigned char *output_buffer = NULL;
    size_t output_len = 0;
    const size_t output_size = PSA_PAKE_OUTPUT_SIZE(PSA_ALG_JPAKE, primitive,
                                                    PSA_PAKE_STEP_KEY_SHARE);
    int in_driver = (forced_status_setup_arg == PSA_SUCCESS);

    /* Allocate each buffer with exactly the size it is filled with and
     * later passed to the API. The output buffer is handed to
     * psa_pake_output() with output_size, so it must be sized by the
     * KEY_SHARE *output* size, not the input size. */
    ASSERT_ALLOC(input_buffer, size_key_share);
    memset(input_buffer, 0xAA, size_key_share);

    ASSERT_ALLOC(output_buffer, output_size);
    memset(output_buffer, 0x55, output_size);

    PSA_INIT();

    mbedtls_test_driver_pake_hooks = mbedtls_test_driver_pake_hooks_init();

    if (pw_data->len > 0) {
        psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_DERIVE);
        psa_set_key_algorithm(&attributes, PSA_ALG_JPAKE);
        psa_set_key_type(&attributes, PSA_KEY_TYPE_PASSWORD);
        PSA_ASSERT(psa_import_key(&attributes, pw_data->x, pw_data->len,
                                  &key));
    }

    psa_pake_cs_set_algorithm(&cipher_suite, PSA_ALG_JPAKE);
    psa_pake_cs_set_primitive(&cipher_suite, primitive);
    psa_pake_cs_set_hash(&cipher_suite, PSA_ALG_SHA_256);

    mbedtls_test_driver_pake_hooks.forced_status = forced_status_setup;

    /* Collecting input stage (no driver entry points) */

    TEST_EQUAL(psa_pake_setup(&operation, &cipher_suite),
               PSA_SUCCESS);

    TEST_EQUAL(psa_pake_set_role(&operation, PSA_PAKE_ROLE_SERVER),
               PSA_SUCCESS);

    TEST_EQUAL(psa_pake_set_password_key(&operation, key),
               PSA_SUCCESS);

    TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total, 0);

    /* Computation stage (driver entry points) */

    switch (fut) {
        case 0: /* setup (via input) */
            /* --- psa_pake_input (driver: setup, input) --- */
            mbedtls_test_driver_pake_hooks.forced_setup_status = forced_status_setup;
            mbedtls_test_driver_pake_hooks.forced_status = forced_status;
            TEST_EQUAL(psa_pake_input(&operation, PSA_PAKE_STEP_KEY_SHARE,
                                      input_buffer, size_key_share),
                       expected_status);
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total, 1);
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.setup, 1);
            break;

        case 1: /* setup (via output) */
            /* --- psa_pake_output (driver: setup, output) --- */
            mbedtls_test_driver_pake_hooks.forced_setup_status = forced_status_setup;
            mbedtls_test_driver_pake_hooks.forced_status = forced_status;
            TEST_EQUAL(psa_pake_output(&operation, PSA_PAKE_STEP_KEY_SHARE,
                                       output_buffer, output_size, &output_len),
                       expected_status);
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total, 1);
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.setup, 1);
            break;

        case 2: /* input */
            /* --- psa_pake_input (driver: setup, input, abort) --- */
            mbedtls_test_driver_pake_hooks.forced_setup_status = forced_status_setup;
            mbedtls_test_driver_pake_hooks.forced_status = forced_status;
            TEST_EQUAL(psa_pake_input(&operation, PSA_PAKE_STEP_KEY_SHARE,
                                      input_buffer, size_key_share),
                       expected_status);
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total, in_driver ? 3 : 1);
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.setup, 1);
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.input, in_driver ? 1 : 0);
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.abort, in_driver ? 1 : 0);
            break;

        case 3: /* output */
            /* --- psa_pake_output (driver: setup, output, (abort)) --- */
            mbedtls_test_driver_pake_hooks.forced_setup_status = forced_status_setup;
            mbedtls_test_driver_pake_hooks.forced_status = forced_status;
            if (forced_output->len > 0) {
                mbedtls_test_driver_pake_hooks.forced_output = forced_output->x;
                mbedtls_test_driver_pake_hooks.forced_output_length = forced_output->len;
            }
            TEST_EQUAL(psa_pake_output(&operation, PSA_PAKE_STEP_KEY_SHARE,
                                       output_buffer, output_size, &output_len),
                       expected_status);

            if (forced_output->len > 0) {
                TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total, in_driver ? 2 : 1);
                TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.setup, 1);
                TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.output, in_driver ? 1 : 0);
                TEST_EQUAL(output_len, forced_output->len);
                TEST_EQUAL(memcmp(output_buffer, forced_output->x, output_len), 0);
            } else {
                TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total, in_driver ? 3 : 1);
                TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.setup, 1);
                TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.output, in_driver ? 1 : 0);
                TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.abort, in_driver ? 1 : 0);
            }
            break;

        case 4: /* get_implicit_key */
            /* Call driver setup indirectly */
            TEST_EQUAL(psa_pake_input(&operation, PSA_PAKE_STEP_KEY_SHARE,
                                      input_buffer, size_key_share),
                       PSA_SUCCESS);

            /* Simulate that we are ready to get implicit key. */
            operation.computation_stage.jpake.input_step = PSA_PAKE_STEP_DERIVE;
            operation.computation_stage.jpake.output_step = PSA_PAKE_STEP_DERIVE;

            /* --- psa_pake_get_implicit_key --- */
            mbedtls_test_driver_pake_hooks.forced_status = forced_status;
            memset(&mbedtls_test_driver_pake_hooks.hits, 0,
                   sizeof(mbedtls_test_driver_pake_hooks.hits));
            TEST_EQUAL(psa_pake_get_implicit_key(&operation, &implicit_key),
                       expected_status);
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total, 2);
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.implicit_key, 1);
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.abort, 1);

            break;

        case 5: /* abort */
            /* Call driver setup indirectly */
            TEST_EQUAL(psa_pake_input(&operation, PSA_PAKE_STEP_KEY_SHARE,
                                      input_buffer, size_key_share),
                       PSA_SUCCESS);

            /* --- psa_pake_abort --- */
            mbedtls_test_driver_pake_hooks.forced_status = forced_status;
            memset(&mbedtls_test_driver_pake_hooks.hits, 0,
                   sizeof(mbedtls_test_driver_pake_hooks.hits));
            TEST_EQUAL(psa_pake_abort(&operation), expected_status);
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total, 1);
            TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.abort, 1);
            break;

        default:
            break;
    }

    /* Clean up */
    mbedtls_test_driver_pake_hooks.forced_setup_status = PSA_SUCCESS;
    mbedtls_test_driver_pake_hooks.forced_status = PSA_SUCCESS;
    TEST_EQUAL(psa_pake_abort(&operation), PSA_SUCCESS);
exit:
    /*
     * A failed check jumps here without running the abort above, so release
     * the PAKE operation (which holds a copy of the password and possibly a
     * driver context) and the derivation operation unconditionally. Both
     * aborts are no-ops on operations that are inactive or already aborted.
     */
    psa_pake_abort(&operation);
    psa_key_derivation_abort(&implicit_key);
    /* Reset the attributes used to import the password key. */
    psa_reset_key_attributes(&attributes);
    mbedtls_free(input_buffer);
    mbedtls_free(output_buffer);
    psa_destroy_key(key);
    /* Drop forced statuses/output so they cannot leak into later tests. */
    mbedtls_test_driver_pake_hooks =
        mbedtls_test_driver_pake_hooks_init();
    PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE depends_on:PSA_WANT_ALG_JPAKE:PSA_WANT_KEY_TYPE_ECC_KEY_PAIR:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256 */
void ecjpake_rounds(int alg_arg, int primitive_arg, int hash_arg,
                    int derive_alg_arg, data_t *pw_data,
                    int client_input_first, int in_driver)
{
    /*
     * Run a complete EC J-PAKE exchange (two rounds) between a server and a
     * client operation, then feed both implicit keys into key derivations,
     * checking the test driver's hit counter along the way.
     *
     * alg_arg            PAKE algorithm, also the password key's algorithm.
     * primitive_arg      PAKE primitive of the cipher suite.
     * hash_arg           hash algorithm of the cipher suite.
     * derive_alg_arg     KDF receiving the implicit keys; for TLS 1.2
     *                    PSK-to-MS an empty seed is supplied first.
     * pw_data            password imported as a PSA_KEY_TYPE_PASSWORD key.
     * client_input_first forwarded to ecjpake_do_round() (presumably selects
     *                    the message ordering between the peers).
     * in_driver          nonzero: run in the test driver. Zero: the driver's
     *                    setup is forced to PSA_ERROR_NOT_SUPPORTED
     *                    (presumably so the built-in implementation is used).
     */
    psa_pake_cipher_suite_t cipher_suite = psa_pake_cipher_suite_init();
    psa_pake_operation_t server = psa_pake_operation_init();
    psa_pake_operation_t client = psa_pake_operation_init();
    psa_algorithm_t alg = alg_arg;
    psa_algorithm_t hash_alg = hash_arg;
    psa_algorithm_t derive_alg = derive_alg_arg;
    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    psa_key_derivation_operation_t server_derive =
        PSA_KEY_DERIVATION_OPERATION_INIT;
    psa_key_derivation_operation_t client_derive =
        PSA_KEY_DERIVATION_OPERATION_INIT;
    pake_in_driver = in_driver;
    /* driver setup is called indirectly through pake_output/pake_input */
    if (pake_in_driver) {
        pake_expected_hit_count = 2;
    } else {
        pake_expected_hit_count = 1;
    }

    PSA_INIT();

    mbedtls_test_driver_pake_hooks = mbedtls_test_driver_pake_hooks_init();

    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_DERIVE);
    psa_set_key_algorithm(&attributes, alg);
    psa_set_key_type(&attributes, PSA_KEY_TYPE_PASSWORD);
    PSA_ASSERT(psa_import_key(&attributes, pw_data->x, pw_data->len,
                              &key));

    psa_pake_cs_set_algorithm(&cipher_suite, alg);
    psa_pake_cs_set_primitive(&cipher_suite, primitive_arg);
    psa_pake_cs_set_hash(&cipher_suite, hash_alg);

    /* Get shared key */
    PSA_ASSERT(psa_key_derivation_setup(&server_derive, derive_alg));
    PSA_ASSERT(psa_key_derivation_setup(&client_derive, derive_alg));

    if (PSA_ALG_IS_TLS12_PSK_TO_MS(derive_alg)) {
        PSA_ASSERT(psa_key_derivation_input_bytes(&server_derive,
                                                  PSA_KEY_DERIVATION_INPUT_SEED,
                                                  (const uint8_t *) "", 0));
        PSA_ASSERT(psa_key_derivation_input_bytes(&client_derive,
                                                  PSA_KEY_DERIVATION_INPUT_SEED,
                                                  (const uint8_t *) "", 0));
    }

    if (!pake_in_driver) {
        mbedtls_test_driver_pake_hooks.forced_setup_status = PSA_ERROR_NOT_SUPPORTED;
    }

    /* Collecting-input calls must not reach the driver. */
    PSA_ASSERT(psa_pake_setup(&server, &cipher_suite));
    TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total, 0);
    PSA_ASSERT(psa_pake_setup(&client, &cipher_suite));
    TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total, 0);

    PSA_ASSERT(psa_pake_set_role(&server, PSA_PAKE_ROLE_SERVER));
    TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total, 0);
    PSA_ASSERT(psa_pake_set_role(&client, PSA_PAKE_ROLE_CLIENT));
    TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total, 0);
    PSA_ASSERT(psa_pake_set_password_key(&server, key));
    TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total, 0);
    PSA_ASSERT(psa_pake_set_password_key(&client, key));
    TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total, 0);

    /* First round */
    ecjpake_do_round(alg, primitive_arg, &server, &client,
                     client_input_first, 1);

    /* Second round */
    ecjpake_do_round(alg, primitive_arg, &server, &client,
                     client_input_first, 2);

    /*
     * In the driver, psa_pake_get_implicit_key() hits the implicit-key entry
     * point and then aborts the operation, i.e. two hits per call. Each
     * driver-side check advances pake_expected_hit_count by one after
     * comparing, so the count is bumped once before the call (for the extra
     * abort hit) and once after the check. Outside the driver the count is
     * left unchanged. The counter is updated in plain statements rather than
     * inside TEST_EQUAL arguments, so the checks do not rely on the macro
     * evaluating its arguments exactly once.
     */
    if (pake_in_driver) {
        pake_expected_hit_count++;
    }
    PSA_ASSERT(psa_pake_get_implicit_key(&server, &server_derive));
    TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
               pake_expected_hit_count);

    if (pake_in_driver) {
        /* Advance past the server check, then account for the client's
         * extra abort hit. */
        pake_expected_hit_count += 2;
    }
    PSA_ASSERT(psa_pake_get_implicit_key(&client, &client_derive));
    TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits.total,
               pake_expected_hit_count);
exit:
    psa_key_derivation_abort(&server_derive);
    psa_key_derivation_abort(&client_derive);
    psa_destroy_key(key);
    psa_pake_abort(&server);
    psa_pake_abort(&client);
    /* Do not let forced driver statuses (e.g. forced_setup_status set above
     * when not running in the driver) leak into later test cases. */
    mbedtls_test_driver_pake_hooks =
        mbedtls_test_driver_pake_hooks_init();
    PSA_DONE();
}
/* END_CASE */