/* BEGIN_HEADER */
#include "mbedtls/pkcs12.h"
#include "common.h"

#include <string.h>

/* Selects how a test case feeds a buffer argument to the function under test. */
typedef enum {
    /* Leave the pointer NULL instead of pointing it at the test data. */
    USE_NULL_INPUT = 0,

    /* Point at the buffer supplied by the test data. */
    USE_GIVEN_INPUT = 1,
} input_usage_method_t;

/* END_HEADER */

/* BEGIN_DEPENDENCIES
 * depends_on:MBEDTLS_PKCS12_C
 * END_DEPENDENCIES
 */

/* BEGIN_CASE */
void pkcs12_derive_key(int md_type, int key_size_arg,
                       data_t *password_arg, int password_usage,
                       data_t *salt_arg, int salt_usage,
                       int iterations,
                       data_t *expected_output, int expected_status)
{
    /*
     * Call mbedtls_pkcs12_derivation() with MBEDTLS_PKCS12_DERIVE_KEY and
     * check its return value against expected_status. When success is
     * expected, the derived bytes are also compared with expected_output.
     *
     * password_usage / salt_usage (input_usage_method_t values) decide
     * whether the corresponding pointer is taken from the test data or left
     * NULL. The lengths are always taken from the test data, whichever
     * pointer is passed.
     */
    int ret;
    size_t key_size = key_size_arg;
    unsigned char *output_data = NULL;

    unsigned char *pwd_ptr =
        (password_usage == USE_GIVEN_INPUT) ? password_arg->x : NULL;
    unsigned char *salt_ptr =
        (salt_usage == USE_GIVEN_INPUT) ? salt_arg->x : NULL;

    MD_PSA_INIT();

    /* Output buffer sized exactly to the requested key length. */
    TEST_CALLOC(output_data, key_size);

    ret = mbedtls_pkcs12_derivation(output_data, key_size,
                                    pwd_ptr, password_arg->len,
                                    salt_ptr, salt_arg->len,
                                    md_type,
                                    MBEDTLS_PKCS12_DERIVE_KEY,
                                    iterations);

    TEST_EQUAL(ret, expected_status);

    /* The output is only meaningful when the derivation was meant to succeed. */
    if (expected_status == 0) {
        TEST_MEMORY_COMPARE(expected_output->x, expected_output->len,
                            output_data, key_size);
    }

exit:
    mbedtls_free(output_data);
    MD_PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE depends_on:MBEDTLS_ASN1_PARSE_C */
void pkcs12_pbe_encrypt(int params_tag, int cipher, int md, data_t *params_hex, data_t *pw,
                        data_t *data, int outsize, int ref_ret, data_t *ref_out)
{
    /*
     * Encrypt `data` using the PBE parameters in `params_hex` (wrapped in an
     * ASN.1 buffer tagged `params_tag`) and the password `pw`. The return
     * value is checked against `ref_ret`; when `ref_ret` is 0 the produced
     * bytes are also checked against `ref_out`.
     *
     * Two entry points are exercised with the same inputs:
     *  - mbedtls_pkcs12_pbe(), which is not given the output buffer size.
     *    It is skipped when MBEDTLS_ERR_ASN1_BUF_TOO_SMALL is expected
     *    (presumably because it cannot detect an undersized buffer).
     *  - mbedtls_pkcs12_pbe_ext(), only when MBEDTLS_CIPHER_PADDING_PKCS7 is
     *    defined. It receives `outsize` and reports the output length.
     */
    int my_ret;
    mbedtls_asn1_buf pbe_params;
    unsigned char *my_out = NULL;
    mbedtls_cipher_type_t cipher_alg = (mbedtls_cipher_type_t) cipher;
    mbedtls_md_type_t md_alg = (mbedtls_md_type_t) md;
#if defined(MBEDTLS_CIPHER_PADDING_PKCS7)
    size_t my_out_len = 0;
#endif

    MD_PSA_INIT();

    TEST_CALLOC(my_out, outsize);

    pbe_params.tag = params_tag;
    pbe_params.len = params_hex->len;
    pbe_params.p = params_hex->x;

    if (ref_ret != MBEDTLS_ERR_ASN1_BUF_TOO_SMALL) {
        my_ret = mbedtls_pkcs12_pbe(&pbe_params, MBEDTLS_PKCS12_PBE_ENCRYPT, cipher_alg,
                                    md_alg, pw->x, pw->len, data->x, data->len, my_out);
        TEST_EQUAL(my_ret, ref_ret);

        /* This entry point reports no output length, so the expected length
         * is used for the comparison. */
        if (ref_ret == 0) {
            TEST_MEMORY_COMPARE(my_out, ref_out->len,
                                ref_out->x, ref_out->len);
        }
    }

#if defined(MBEDTLS_CIPHER_PADDING_PKCS7)

    /* Clear what the first call may have written, so that the comparison
     * below can only pass on bytes produced by mbedtls_pkcs12_pbe_ext().
     * my_out is NULL when outsize is 0, in which case there is nothing to
     * clear. */
    if (my_out != NULL) {
        memset(my_out, 0, (size_t) outsize);
    }

    /* Start the second call from freshly populated parameters. */
    pbe_params.tag = params_tag;
    pbe_params.len = params_hex->len;
    pbe_params.p = params_hex->x;

    my_ret = mbedtls_pkcs12_pbe_ext(&pbe_params, MBEDTLS_PKCS12_PBE_ENCRYPT, cipher_alg,
                                    md_alg, pw->x, pw->len, data->x, data->len, my_out,
                                    outsize, &my_out_len);
    TEST_EQUAL(my_ret, ref_ret);
    if (ref_ret == 0) {
        TEST_MEMORY_COMPARE(my_out, my_out_len,
                            ref_out->x, ref_out->len);
    }
#endif

exit:
    mbedtls_free(my_out);
    MD_PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE depends_on:MBEDTLS_ASN1_PARSE_C */
void pkcs12_pbe_decrypt(int params_tag, int cipher, int md, data_t *params_hex, data_t *pw,
                        data_t *data, int outsize, int ref_ret, data_t *ref_out)
{
    /*
     * Decrypt `data` using the PBE parameters in `params_hex` (wrapped in an
     * ASN.1 buffer tagged `params_tag`) and the password `pw`. The return
     * value is checked against `ref_ret`; when `ref_ret` is 0 the produced
     * bytes are also checked against `ref_out`.
     *
     * Two entry points are exercised with the same inputs:
     *  - mbedtls_pkcs12_pbe(), which is not given the output buffer size.
     *    It is skipped when MBEDTLS_ERR_ASN1_BUF_TOO_SMALL is expected
     *    (presumably because it cannot detect an undersized buffer).
     *  - mbedtls_pkcs12_pbe_ext(), only when MBEDTLS_CIPHER_PADDING_PKCS7 is
     *    defined. It receives `outsize` and reports the output length.
     */
    int my_ret;
    mbedtls_asn1_buf pbe_params;
    unsigned char *my_out = NULL;
    mbedtls_cipher_type_t cipher_alg = (mbedtls_cipher_type_t) cipher;
    mbedtls_md_type_t md_alg = (mbedtls_md_type_t) md;
#if defined(MBEDTLS_CIPHER_PADDING_PKCS7)
    size_t my_out_len = 0;
#endif

    MD_PSA_INIT();

    TEST_CALLOC(my_out, outsize);

    pbe_params.tag = params_tag;
    pbe_params.len = params_hex->len;
    pbe_params.p = params_hex->x;

    if (ref_ret != MBEDTLS_ERR_ASN1_BUF_TOO_SMALL) {
        my_ret = mbedtls_pkcs12_pbe(&pbe_params, MBEDTLS_PKCS12_PBE_DECRYPT, cipher_alg,
                                    md_alg, pw->x, pw->len, data->x, data->len, my_out);
        TEST_EQUAL(my_ret, ref_ret);

        /* This entry point reports no output length, so the expected length
         * is used for the comparison. */
        if (ref_ret == 0) {
            TEST_MEMORY_COMPARE(my_out, ref_out->len,
                                ref_out->x, ref_out->len);
        }
    }

#if defined(MBEDTLS_CIPHER_PADDING_PKCS7)

    /* Clear what the first call may have written, so that the comparison
     * below can only pass on bytes produced by mbedtls_pkcs12_pbe_ext().
     * my_out is NULL when outsize is 0, in which case there is nothing to
     * clear. */
    if (my_out != NULL) {
        memset(my_out, 0, (size_t) outsize);
    }

    /* Start the second call from freshly populated parameters. */
    pbe_params.tag = params_tag;
    pbe_params.len = params_hex->len;
    pbe_params.p = params_hex->x;

    my_ret = mbedtls_pkcs12_pbe_ext(&pbe_params, MBEDTLS_PKCS12_PBE_DECRYPT, cipher_alg,
                                    md_alg, pw->x, pw->len, data->x, data->len, my_out,
                                    outsize, &my_out_len);
    TEST_EQUAL(my_ret, ref_ret);
    if (ref_ret == 0) {
        TEST_MEMORY_COMPARE(my_out, my_out_len,
                            ref_out->x, ref_out->len);
    }
#endif

exit:
    mbedtls_free(my_out);
    MD_PSA_DONE();
}
/* END_CASE */