/* BEGIN_HEADER */
#include "mbedtls/pk.h"
#include "mbedtls/pem.h"
#include "mbedtls/oid.h"
#include "psa/crypto_sizes.h"

/* Encoding of a key file used by the pkwrite tests. */
typedef enum {
    TEST_PEM = 0,   /* PEM (text) encoding */
    TEST_DER = 1    /* DER (binary) encoding */
} pkwrite_file_format_t;

/* Helper function for removing "\r" chars from a buffer.
 *
 * Compacts in_str in place so that every '\r' byte is dropped and the
 * remaining bytes keep their relative order. The bytes between the new and
 * the old length are zeroed, so no stale data is left behind the shortened
 * content.
 *
 * in_str  Buffer edited in place; must hold at least *len bytes.
 * len     In: number of valid bytes in in_str.
 *         Out: number of bytes left after removing every '\r'.
 */
static void fix_new_lines(unsigned char *in_str, size_t *len)
{
    size_t read_pos;
    size_t write_pos = 0;

    /* Two cursors instead of shifting the tail on every match. This handles
     * runs of consecutive '\r' correctly (shifting the tail and then
     * advancing the index skips the byte moved into the current slot) and
     * is linear instead of quadratic. */
    for (read_pos = 0; read_pos < *len; read_pos++) {
        if (in_str[read_pos] != '\r') {
            in_str[write_pos++] = in_str[read_pos];
        }
    }

    /* Clear the vacated tail. */
    if (write_pos < *len) {
        memset(&in_str[write_pos], 0, *len - write_pos);
    }
    *len = write_pos;
}

/* Round-trip check shared by the pk_write_pubkey_check and
 * pk_write_key_check test cases.
 *
 * Loads key_file verbatim, parses it as a public key (is_public_key != 0)
 * or as a private key (is_public_key == 0), writes the parsed key back as
 * DER (is_der != 0) or PEM (is_der == 0), and checks that the output is
 * byte-identical to the file contents.
 */
static void pk_write_check_common(char *key_file, int is_public_key, int is_der)
{
    mbedtls_pk_context key;
    unsigned char *buf = NULL;
    unsigned char *check_buf = NULL;
    unsigned char *start_buf;
    size_t buf_len, check_buf_len;
    int ret;

    mbedtls_pk_init(&key);
    USE_PSA_INIT();

    /* Note: if mbedtls_pk_load_file() successfully reads the file, then
       it also allocates check_buf, which should be freed on exit */
    TEST_EQUAL(mbedtls_pk_load_file(key_file, &check_buf, &check_buf_len), 0);
    TEST_ASSERT(check_buf_len > 0);

    /* Windows' line ending is different from the Linux's one ("\r\n" vs "\n").
     * Git treats PEM files as text, so when on Windows, it replaces new lines
     * with "\r\n" on checkout.
     * Unfortunately mbedtls_pk_load_file() loads files in binary format,
     * while mbedtls_pk_write_pubkey_pem() goes through the I/O layer which
     * uses "\n" for newlines in both Windows and Linux.
     * Here we remove the extra "\r" so that "buf" and "check_buf" can be
     * easily compared later. */
    if (!is_der) {
        fix_new_lines(check_buf, &check_buf_len);
    }
    TEST_ASSERT(check_buf_len > 0);

    /* A correct writer produces exactly check_buf_len bytes, so the output
     * buffer is sized to match. */
    ASSERT_ALLOC(buf, check_buf_len);

    if (is_public_key) {
        TEST_EQUAL(mbedtls_pk_parse_public_keyfile(&key, key_file), 0);
        if (is_der) {
            ret = mbedtls_pk_write_pubkey_der(&key, buf, check_buf_len);
        } else {
#if defined(MBEDTLS_PEM_WRITE_C)
            ret = mbedtls_pk_write_pubkey_pem(&key, buf, check_buf_len);
#else
            ret = MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE;
#endif
        }
    } else {
        TEST_EQUAL(mbedtls_pk_parse_keyfile(&key, key_file, NULL,
                                            mbedtls_test_rnd_std_rand, NULL), 0);
        if (is_der) {
            ret = mbedtls_pk_write_key_der(&key, buf, check_buf_len);
        } else {
#if defined(MBEDTLS_PEM_WRITE_C)
            ret = mbedtls_pk_write_key_pem(&key, buf, check_buf_len);
#else
            ret = MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE;
#endif
        }
    }

    if (is_der) {
        /* The DER writers place their output at the end of buf and return
         * its length, or a negative error code. The check must be signed:
         * an unsigned comparison would turn a negative error code into a
         * huge value, accept it, and the pointer arithmetic below would
         * then go out of bounds. */
        TEST_LE_S(1, ret);
        buf_len = (size_t) ret;
        start_buf = buf + check_buf_len - buf_len;
    } else {
        TEST_EQUAL(ret, 0);
        buf_len = strlen((char *) buf) + 1; /* +1 takes the string terminator into account */
        start_buf = buf;
    }

    ASSERT_COMPARE(start_buf, buf_len, check_buf, check_buf_len);

exit:
    mbedtls_free(buf);
    mbedtls_free(check_buf);
    mbedtls_pk_free(&key);
    USE_PSA_DONE();
}
/* END_HEADER */

/* BEGIN_DEPENDENCIES
 * depends_on:MBEDTLS_PK_PARSE_C:MBEDTLS_PK_WRITE_C:MBEDTLS_BIGNUM_C:MBEDTLS_FS_IO
 * END_DEPENDENCIES
 */

/* BEGIN_CASE */
void pk_write_pubkey_check(char *key_file, int is_der)
{
    /* Public-key round trip: parse key_file as a public key, write it back
     * (DER when is_der is non-zero, PEM otherwise) and compare the result
     * with the file contents. */
    const int is_public_key = 1;

    pk_write_check_common(key_file, is_public_key, is_der);

    /* The exit label is not written in this function (presumably the test
     * generator supplies it); jumping to it keeps the compiler from
     * reporting it as unused. */
    goto exit;
}
/* END_CASE */

/* BEGIN_CASE */
void pk_write_key_check(char *key_file, int is_der)
{
    /* Private-key round trip: parse key_file as a private key, write it
     * back (DER when is_der is non-zero, PEM otherwise) and compare the
     * result with the file contents. */
    const int is_public_key = 0;

    pk_write_check_common(key_file, is_public_key, is_der);

    /* The exit label is not written in this function (presumably the test
     * generator supplies it); jumping to it keeps the compiler from
     * reporting it as unused. */
    goto exit;
}
/* END_CASE */

/* BEGIN_CASE */
void pk_write_public_from_private(char *priv_key_file, char *pub_key_file)
{
    /* Parse a private key, export its public part in DER and check it
     * against the reference public key file. With MBEDTLS_USE_PSA_CRYPTO,
     * repeat the export after wrapping the private key as an opaque PSA key
     * and check it produces the same bytes. */
    mbedtls_pk_context priv_key;
    uint8_t *derived_key_raw = NULL;
    size_t derived_key_len = 0;
    uint8_t *pub_key_raw = NULL;
    size_t pub_key_len = 0;
#if defined(MBEDTLS_USE_PSA_CRYPTO)
    mbedtls_svc_key_id_t opaque_key_id = MBEDTLS_SVC_KEY_ID_INIT;
#endif /* MBEDTLS_USE_PSA_CRYPTO */

    mbedtls_pk_init(&priv_key);
    USE_PSA_INIT();

    TEST_EQUAL(mbedtls_pk_parse_keyfile(&priv_key, priv_key_file, NULL,
                                        mbedtls_test_rnd_std_rand, NULL), 0);
    TEST_EQUAL(mbedtls_pk_load_file(pub_key_file, &pub_key_raw,
                                    &pub_key_len), 0);

    /* Size the output buffer to the reference file, so a correct export
     * fills it exactly. */
    derived_key_len = pub_key_len;
    ASSERT_ALLOC(derived_key_raw, derived_key_len);

    TEST_EQUAL(mbedtls_pk_write_pubkey_der(&priv_key, derived_key_raw,
                                           derived_key_len), pub_key_len);

    ASSERT_COMPARE(derived_key_raw, derived_key_len,
                   pub_key_raw, pub_key_len);

#if defined(MBEDTLS_USE_PSA_CRYPTO)
    /* Clear the whole output buffer so stale bytes from the first export
     * cannot satisfy the comparison below. The length must be
     * derived_key_len: sizeof(derived_key_raw) is the size of the pointer,
     * not of the allocated buffer. */
    mbedtls_platform_zeroize(derived_key_raw, derived_key_len);

    TEST_EQUAL(mbedtls_pk_wrap_as_opaque(&priv_key, &opaque_key_id,
                                         PSA_ALG_NONE, PSA_KEY_USAGE_EXPORT,
                                         PSA_ALG_NONE), 0);

    TEST_EQUAL(mbedtls_pk_write_pubkey_der(&priv_key, derived_key_raw,
                                           derived_key_len), pub_key_len);

    ASSERT_COMPARE(derived_key_raw, derived_key_len,
                   pub_key_raw, pub_key_len);
#endif /* MBEDTLS_USE_PSA_CRYPTO */

exit:
#if defined(MBEDTLS_USE_PSA_CRYPTO)
    psa_destroy_key(opaque_key_id);
#endif /* MBEDTLS_USE_PSA_CRYPTO */
    mbedtls_free(derived_key_raw);
    mbedtls_free(pub_key_raw);
    mbedtls_pk_free(&priv_key);
    USE_PSA_DONE();
}
/* END_CASE */