/* BEGIN_HEADER */
#include <stdint.h>

#include "common.h"

#include "psa/crypto.h"

#include "psa_crypto_core.h"
#include "psa_crypto_invasive.h"

#include "test/psa_crypto_helpers.h"
#include "test/memory.h"

/* Write a deterministic byte pattern into buffer: the byte at each index
 * receives the low 8 bits of that index. The particular values do not
 * matter; they only make copies checkable in a way that exposes offset
 * errors. */
static void fill_buffer_pattern(uint8_t *buffer, size_t len)
{
    size_t pos = 0;

    while (pos < len) {
        /* Masking with 0xFF is the same as reducing modulo 256. */
        buffer[pos] = (uint8_t) (pos & 0xFF);
        pos++;
    }
}
/* END_HEADER */

/* BEGIN_DEPENDENCIES
 * depends_on:MBEDTLS_PSA_CRYPTO_C:MBEDTLS_TEST_HOOKS
 * END_DEPENDENCIES
 */

/* BEGIN_CASE */
void copy_input(int src_len, int dst_len, psa_status_t exp_status)
{
    /* Exercise psa_crypto_copy_input(): copy a patterned source buffer of
     * src_len bytes into a destination buffer of dst_len bytes, check the
     * returned status against exp_status and, when success is expected,
     * check that the first src_len bytes of the destination equal the
     * source. */
    psa_status_t status;
    uint8_t *dst_buffer = NULL;
    uint8_t *src_buffer = NULL;

    TEST_CALLOC(src_buffer, src_len);
    TEST_CALLOC(dst_buffer, dst_len);

    /* Recognisable source content, so a misplaced copy is detectable. */
    fill_buffer_pattern(src_buffer, src_len);

    status = psa_crypto_copy_input(src_buffer, src_len, dst_buffer, dst_len);
    TEST_EQUAL(status, exp_status);

    /* The buffer contents are only checked when success is expected. */
    if (exp_status != PSA_SUCCESS) {
        goto exit;
    }

    MBEDTLS_TEST_MEMORY_UNPOISON(src_buffer, src_len);
    /* Only src_len bytes are copied, so only that prefix of the
     * destination is compared. */
    TEST_MEMORY_COMPARE(src_buffer, src_len, dst_buffer, src_len);

exit:
    mbedtls_free(dst_buffer);
    mbedtls_free(src_buffer);
}
/* END_CASE */

/* BEGIN_CASE */
void copy_output(int src_len, int dst_len, psa_status_t exp_status)
{
    /* Exercise psa_crypto_copy_output(): copy a patterned source buffer of
     * src_len bytes into a destination buffer of dst_len bytes, check the
     * returned status against exp_status and, when success is expected,
     * check that the first src_len bytes of the destination equal the
     * source. */
    psa_status_t status;
    uint8_t *dst_buffer = NULL;
    uint8_t *src_buffer = NULL;

    TEST_CALLOC(src_buffer, src_len);
    TEST_CALLOC(dst_buffer, dst_len);

    /* Recognisable source content, so a misplaced copy is detectable. */
    fill_buffer_pattern(src_buffer, src_len);

    status = psa_crypto_copy_output(src_buffer, src_len, dst_buffer, dst_len);
    TEST_EQUAL(status, exp_status);

    /* The buffer contents are only checked when success is expected. */
    if (exp_status != PSA_SUCCESS) {
        goto exit;
    }

    MBEDTLS_TEST_MEMORY_UNPOISON(dst_buffer, dst_len);
    /* Only src_len bytes are copied, so only that prefix of the
     * destination is compared. */
    TEST_MEMORY_COMPARE(src_buffer, src_len, dst_buffer, src_len);

exit:
    mbedtls_free(dst_buffer);
    mbedtls_free(src_buffer);
}
/* END_CASE */

/* BEGIN_CASE */
void local_input_alloc(int input_len, psa_status_t exp_status)
{
    /* Exercise psa_crypto_local_input_alloc() on a patterned input buffer
     * of input_len bytes: check the returned status against exp_status
     * and, when success is expected, check that the local copy has the
     * same length and contents as the input and (for a non-empty input)
     * lives at a different address. */
    psa_status_t status;
    psa_crypto_local_input_t local_input;
    uint8_t *input = NULL;

    /* The cleanup code frees this pointer unconditionally, so give it a
     * defined value before anything can jump to exit. */
    local_input.buffer = NULL;

    TEST_CALLOC(input, input_len);
    fill_buffer_pattern(input, input_len);

    status = psa_crypto_local_input_alloc(input, input_len, &local_input);
    TEST_EQUAL(status, exp_status);

    /* The local copy is only inspected when success is expected. */
    if (exp_status != PSA_SUCCESS) {
        goto exit;
    }

    MBEDTLS_TEST_MEMORY_UNPOISON(input, input_len);
    /* The address check is skipped for a zero-length input. */
    if (input_len != 0) {
        TEST_ASSERT(local_input.buffer != input);
    }
    TEST_MEMORY_COMPARE(input, input_len,
                        local_input.buffer, local_input.length);

exit:
    mbedtls_free(input);
    mbedtls_free(local_input.buffer);
}
/* END_CASE */

/* BEGIN_CASE */
void local_input_free(int input_len)
{
    /* Exercise psa_crypto_local_input_free() on a hand-built local input
     * holding a buffer of input_len bytes: afterwards the structure must
     * have a NULL buffer and a zero length. */
    psa_crypto_local_input_t local_input;

    /* Build the local input by hand rather than through the alloc
     * function, so that only the free function is under test. */
    local_input.length = input_len;
    local_input.buffer = NULL;
    TEST_CALLOC(local_input.buffer, local_input.length);

    psa_crypto_local_input_free(&local_input);

    TEST_ASSERT(local_input.buffer == NULL);
    TEST_EQUAL(local_input.length, 0);

exit:
    /* Releases the buffer if the test bailed out before the free call;
     * a no-op when the buffer pointer is already NULL. */
    mbedtls_free(local_input.buffer);
    local_input.length = 0;
    local_input.buffer = NULL;
}
/* END_CASE */

/* BEGIN_CASE */
void local_input_round_trip()
{
    /* Round-trip test for local inputs: allocate a local copy of a
     * 200-byte patterned input with psa_crypto_local_input_alloc(), check
     * that the copy matches the input and lives at a different address,
     * then release it with psa_crypto_local_input_free() and check that
     * the structure has been reset (NULL buffer, zero length). */
    psa_crypto_local_input_t local_input;
    uint8_t input[200];
    psa_status_t status;

    /* Give the structure a defined state up front: the cleanup code below
     * frees local_input.buffer, and it can be reached before the alloc
     * call has produced a valid pointer. */
    local_input.buffer = NULL;
    local_input.length = 0;

    fill_buffer_pattern(input, sizeof(input));

    status = psa_crypto_local_input_alloc(input, sizeof(input), &local_input);
    TEST_EQUAL(status, PSA_SUCCESS);

    MBEDTLS_TEST_MEMORY_UNPOISON(input, sizeof(input));
    TEST_MEMORY_COMPARE(local_input.buffer, local_input.length,
                        input, sizeof(input));
    TEST_ASSERT(local_input.buffer != input);

    psa_crypto_local_input_free(&local_input);
    TEST_ASSERT(local_input.buffer == NULL);
    TEST_EQUAL(local_input.length, 0);

exit:
    /* Avoid leaking the local copy when an assertion fails between the
     * alloc and free calls; a no-op once the buffer pointer is NULL. */
    mbedtls_free(local_input.buffer);
}
/* END_CASE */

/* BEGIN_CASE */
void local_output_alloc(int output_len, psa_status_t exp_status)
{
    /* Exercise psa_crypto_local_output_alloc() on an output buffer of
     * output_len bytes: check the returned status against exp_status and,
     * when success is expected, check that the local output records the
     * original buffer pointer and the requested length. */
    psa_status_t status;
    psa_crypto_local_output_t local_output;
    uint8_t *output = NULL;

    /* The cleanup code frees this pointer unconditionally, so give it a
     * defined value before anything can jump to exit. */
    local_output.buffer = NULL;

    TEST_CALLOC(output, output_len);

    status = psa_crypto_local_output_alloc(output, output_len, &local_output);
    TEST_EQUAL(status, exp_status);

    /* The structure fields are only inspected when success is expected. */
    if (exp_status != PSA_SUCCESS) {
        goto exit;
    }

    TEST_ASSERT(local_output.original == output);
    TEST_EQUAL(local_output.length, output_len);

exit:
    mbedtls_free(output);
    output = NULL;
    mbedtls_free(local_output.buffer);
    local_output.buffer = NULL;
    local_output.original = NULL;
    local_output.length = 0;
}
/* END_CASE */

/* BEGIN_CASE */
void local_output_free(int output_len, int original_is_null,
                       psa_status_t exp_status)
{
    /* Exercise psa_crypto_local_output_free() on a hand-built local output
     * whose buffer holds output_len patterned bytes. When original_is_null
     * is non-zero the structure's original pointer is left NULL. The
     * returned status is checked against exp_status; when success is
     * expected, the structure must be reset (NULL buffer, zero length)
     * and the original buffer must hold the bytes that were in the local
     * buffer. */
    psa_status_t status;
    psa_crypto_local_output_t local_output = PSA_CRYPTO_LOCAL_OUTPUT_INIT;
    uint8_t *buffer_copy_for_comparison = NULL;
    uint8_t *output = NULL;

    if (original_is_null == 0) {
        TEST_CALLOC(output, output_len);
    }
    TEST_CALLOC(buffer_copy_for_comparison, output_len);
    TEST_CALLOC(local_output.buffer, output_len);
    local_output.original = output;
    local_output.length = output_len;

    /* Keep a snapshot of the local buffer's contents, since the structure
     * is expected to drop its buffer during the call under test. */
    if (local_output.length != 0) {
        fill_buffer_pattern(local_output.buffer, local_output.length);
        memcpy(buffer_copy_for_comparison, local_output.buffer, local_output.length);
    }

    status = psa_crypto_local_output_free(&local_output);
    TEST_EQUAL(status, exp_status);

    /* The results are only inspected when success is expected. */
    if (exp_status != PSA_SUCCESS) {
        goto exit;
    }

    MBEDTLS_TEST_MEMORY_UNPOISON(output, output_len);
    TEST_ASSERT(local_output.buffer == NULL);
    TEST_EQUAL(local_output.length, 0);
    TEST_MEMORY_COMPARE(buffer_copy_for_comparison, output_len,
                        output, output_len);

exit:
    mbedtls_free(buffer_copy_for_comparison);
    mbedtls_free(output);
    mbedtls_free(local_output.buffer);
    local_output.length = 0;
}
/* END_CASE */

/* BEGIN_CASE */
void local_output_round_trip()
{
    psa_crypto_local_output_t local_output;
    uint8_t output[200];
    uint8_t *buffer_copy_for_comparison = NULL;
    psa_status_t status;

    status = psa_crypto_local_output_alloc(output, sizeof(output), &local_output);
    TEST_EQUAL(status, PSA_SUCCESS);
    TEST_ASSERT(local_output.buffer != output);

    /* Simulate the function generating output */
    fill_buffer_pattern(local_output.buffer, local_output.length);

    TEST_CALLOC(buffer_copy_for_comparison, local_output.length);
    memcpy(buffer_copy_for_comparison, local_output.buffer, local_output.length);

    psa_crypto_local_output_free(&local_output);
    TEST_ASSERT(local_output.buffer == NULL);
    TEST_EQUAL(local_output.length, 0);

    MBEDTLS_TEST_MEMORY_UNPOISON(output, sizeof(output));
    /* Check that the buffer was correctly copied back */
    TEST_MEMORY_COMPARE(output, sizeof(output),
                        buffer_copy_for_comparison, sizeof(output));

exit:
    mbedtls_free(buffer_copy_for_comparison);
}
/* END_CASE */