/* BEGIN_HEADER */
#include "mbedtls/nist_kw.h"
/* END_HEADER */

/* BEGIN_DEPENDENCIES
 * depends_on:MBEDTLS_NIST_KW_C
 * END_DEPENDENCIES
 */

/* BEGIN_CASE depends_on:MBEDTLS_SELF_TEST:MBEDTLS_AES_C */
void mbedtls_nist_kw_self_test()
{
    /* Run the library's built-in NIST KW self-test (called with
     * argument 1) and require that it reports success, i.e. returns 0. */
    int self_test_ret = mbedtls_nist_kw_self_test(1);

    TEST_ASSERT(self_test_ret == 0);
}
/* END_CASE */

/* BEGIN_CASE depends_on:MBEDTLS_AES_C */
void mbedtls_nist_kw_mix_contexts()
{
    /*
     * Wrap with one context and unwrap with another, both set up with the
     * same all-zero 128-bit AES key, then check that a single wrapping
     * context (and a single unwrapping context) can serve both KW and KWP
     * modes.
     *
     * Unwrapped data is written to a dedicated buffer that is pre-filled
     * with a non-zero pattern before each unwrap. Comparing that buffer
     * with the all-zero plaintext therefore proves that unwrap actually
     * produced the expected bytes; checking a buffer that was already zero
     * would pass even if nothing had been written.
     */
    mbedtls_nist_kw_context ctx1, ctx2;
    unsigned char key[16];
    unsigned char plaintext[32];
    unsigned char recovered[32];
    unsigned char ciphertext1[40];
    unsigned char ciphertext2[40];
    size_t output_len;

    memset(plaintext, 0, sizeof(plaintext));
    memset(ciphertext1, 0, sizeof(ciphertext1));
    memset(ciphertext2, 0, sizeof(ciphertext2));
    memset(key, 0, sizeof(key));

    /*
     * 1. Check wrap and unwrap with two separate contexts
     */
    mbedtls_nist_kw_init(&ctx1);
    mbedtls_nist_kw_init(&ctx2);

    TEST_ASSERT(mbedtls_nist_kw_setkey(&ctx1,
                                       MBEDTLS_CIPHER_ID_AES,
                                       key, sizeof(key) * 8,
                                       1) == 0);

    TEST_ASSERT(mbedtls_nist_kw_wrap(&ctx1, MBEDTLS_KW_MODE_KW,
                                     plaintext, sizeof(plaintext),
                                     ciphertext1, &output_len,
                                     sizeof(ciphertext1)) == 0);
    TEST_ASSERT(output_len == sizeof(ciphertext1));

    TEST_ASSERT(mbedtls_nist_kw_setkey(&ctx2,
                                       MBEDTLS_CIPHER_ID_AES,
                                       key, sizeof(key) * 8,
                                       0) == 0);

    /* Poison the destination so that a no-op unwrap cannot go unnoticed. */
    memset(recovered, 0xA5, sizeof(recovered));
    TEST_ASSERT(mbedtls_nist_kw_unwrap(&ctx2, MBEDTLS_KW_MODE_KW,
                                       ciphertext1, sizeof(ciphertext1),
                                       recovered, &output_len,
                                       sizeof(recovered)) == 0);

    TEST_ASSERT(output_len == sizeof(plaintext));
    TEST_ASSERT(memcmp(recovered, plaintext, sizeof(plaintext)) == 0);

    mbedtls_nist_kw_free(&ctx1);
    mbedtls_nist_kw_free(&ctx2);

    /*
     * 2. Check wrapping with two modes, on same context
     */
    mbedtls_nist_kw_init(&ctx1);
    mbedtls_nist_kw_init(&ctx2);

    TEST_ASSERT(mbedtls_nist_kw_setkey(&ctx1,
                                       MBEDTLS_CIPHER_ID_AES,
                                       key, sizeof(key) * 8,
                                       1) == 0);

    TEST_ASSERT(mbedtls_nist_kw_wrap(&ctx1, MBEDTLS_KW_MODE_KW,
                                     plaintext, sizeof(plaintext),
                                     ciphertext1, &output_len,
                                     sizeof(ciphertext1)) == 0);
    TEST_ASSERT(output_len == sizeof(ciphertext1));

    TEST_ASSERT(mbedtls_nist_kw_wrap(&ctx1, MBEDTLS_KW_MODE_KWP,
                                     plaintext, sizeof(plaintext),
                                     ciphertext2, &output_len,
                                     sizeof(ciphertext2)) == 0);
    TEST_ASSERT(output_len == sizeof(ciphertext2));

    TEST_ASSERT(mbedtls_nist_kw_setkey(&ctx2,
                                       MBEDTLS_CIPHER_ID_AES,
                                       key, sizeof(key) * 8,
                                       0) == 0);

    memset(recovered, 0xA5, sizeof(recovered));
    TEST_ASSERT(mbedtls_nist_kw_unwrap(&ctx2, MBEDTLS_KW_MODE_KW,
                                       ciphertext1, sizeof(ciphertext1),
                                       recovered, &output_len,
                                       sizeof(recovered)) == 0);

    TEST_ASSERT(output_len == sizeof(plaintext));
    TEST_ASSERT(memcmp(recovered, plaintext, sizeof(plaintext)) == 0);

    memset(recovered, 0xA5, sizeof(recovered));
    TEST_ASSERT(mbedtls_nist_kw_unwrap(&ctx2, MBEDTLS_KW_MODE_KWP,
                                       ciphertext2, sizeof(ciphertext2),
                                       recovered, &output_len,
                                       sizeof(recovered)) == 0);

    TEST_ASSERT(output_len == sizeof(plaintext));
    TEST_ASSERT(memcmp(recovered, plaintext, sizeof(plaintext)) == 0);

exit:
    mbedtls_nist_kw_free(&ctx1);
    mbedtls_nist_kw_free(&ctx2);
}
/* END_CASE */

/* BEGIN_CASE */
void mbedtls_nist_kw_setkey(int cipher_id, int key_size,
                            int is_wrap, int result)
{
    /*
     * Call mbedtls_nist_kw_setkey() with the given cipher id, key size
     * (in bits) and direction flag, and check that it returns `result`.
     * The key material is a fixed 0x2A-filled buffer; only the status
     * code is examined.
     */
    unsigned char key[32];
    mbedtls_nist_kw_context ctx;
    int ret;

    memset(key, 0x2A, sizeof(key));
    mbedtls_nist_kw_init(&ctx);

    /* The requested key size must not exceed the bits available in key[]. */
    TEST_ASSERT((unsigned) key_size <= 8 * sizeof(key));

    ret = mbedtls_nist_kw_setkey(&ctx, cipher_id, key, key_size, is_wrap);
    TEST_ASSERT(ret == result);

exit:
    mbedtls_nist_kw_free(&ctx);
}
/* END_CASE */

/* BEGIN_CASE depends_on:MBEDTLS_AES_C */
void nist_kw_plaintext_lengths(int in_len, int out_len, int mode, int res)
{
    /*
     * Wrap an all-zero plaintext of in_len bytes into a buffer of out_len
     * bytes, using an all-zero 128-bit AES key, and check that
     * mbedtls_nist_kw_wrap() returns `res`.
     *
     * On success the reported output length must be the plaintext length
     * plus one 8-byte semiblock, with the plaintext first rounded up to a
     * multiple of 8 in KWP mode. On failure it must be 0.
     *
     * A length of 0 leaves the corresponding buffer pointer NULL.
     */
    mbedtls_nist_kw_context ctx;
    unsigned char key[16];
    unsigned char *plaintext = NULL;
    unsigned char *ciphertext = NULL;
    size_t output_len = out_len;

    mbedtls_nist_kw_init(&ctx);

    memset(key, 0, sizeof(key));

    if (in_len != 0) {
        plaintext = mbedtls_calloc(1, in_len);
        TEST_ASSERT(plaintext != NULL);
    }

    if (out_len != 0) {
        ciphertext = mbedtls_calloc(1, output_len);
        TEST_ASSERT(ciphertext != NULL);
    }

    TEST_ASSERT(mbedtls_nist_kw_setkey(&ctx, MBEDTLS_CIPHER_ID_AES,
                                       key, 8 * sizeof(key), 1) == 0);

    TEST_ASSERT(mbedtls_nist_kw_wrap(&ctx, mode, plaintext, in_len,
                                     ciphertext, &output_len,
                                     output_len) == res);
    if (res == 0) {
        size_t expected_len = (size_t) in_len + 8;

        /* KWP pads only when the input is not already a whole number of
         * semiblocks; an aligned input gets no extra padding block. */
        if (mode == MBEDTLS_KW_MODE_KWP && (size_t) in_len % 8 != 0) {
            expected_len += 8 - ((size_t) in_len % 8);
        }
        TEST_ASSERT(output_len == expected_len);
    } else {
        TEST_ASSERT(output_len == 0);
    }

exit:
    mbedtls_free(ciphertext);
    mbedtls_free(plaintext);
    mbedtls_nist_kw_free(&ctx);
}
/* END_CASE */

/* BEGIN_CASE depends_on:MBEDTLS_AES_C */
void nist_kw_ciphertext_lengths(int in_len, int out_len, int mode, int res)
{
    /*
     * Unwrap an all-zero ciphertext of in_len bytes into a buffer of
     * out_len bytes, using an all-zero 128-bit AES key.
     *
     * The unwrap is never expected to succeed here: when `res` is 0 the
     * call must fail with MBEDTLS_ERR_CIPHER_AUTH_FAILED, otherwise it
     * must fail with `res`. In both cases the reported output length
     * must be 0.
     *
     * A length of 0 leaves the corresponding buffer pointer NULL.
     */
    mbedtls_nist_kw_context ctx;
    unsigned char key[16];
    unsigned char *ciphertext = NULL;
    unsigned char *plaintext = NULL;
    size_t output_len = out_len;
    int unwrap_ret;

    memset(key, 0, sizeof(key));
    mbedtls_nist_kw_init(&ctx);

    if (out_len != 0) {
        plaintext = mbedtls_calloc(1, output_len);
        TEST_ASSERT(plaintext != NULL);
    }
    if (in_len != 0) {
        ciphertext = mbedtls_calloc(1, in_len);
        TEST_ASSERT(ciphertext != NULL);
    }

    TEST_ASSERT(mbedtls_nist_kw_setkey(&ctx, MBEDTLS_CIPHER_ID_AES,
                                       key, 8 * sizeof(key), 0) == 0);

    unwrap_ret = mbedtls_nist_kw_unwrap(&ctx, mode, ciphertext, in_len,
                                        plaintext, &output_len,
                                        output_len);

    if (res != 0) {
        TEST_ASSERT(unwrap_ret == res);
    } else {
        /* Lengths are acceptable, so the failure must come from the
         * integrity check. */
        TEST_ASSERT(unwrap_ret == MBEDTLS_ERR_CIPHER_AUTH_FAILED);
    }

    TEST_ASSERT(output_len == 0);

exit:
    mbedtls_free(ciphertext);
    mbedtls_free(plaintext);
    mbedtls_nist_kw_free(&ctx);
}
/* END_CASE */

/* BEGIN_CASE */
void mbedtls_nist_kw_wrap(int cipher_id, int mode, data_t *key, data_t *msg,
                          data_t *expected_result)
{
    /*
     * Known-answer test for wrapping: wrap `msg` under `key` in the given
     * mode and compare the output, length and content, with
     * `expected_result`. The output buffer is pre-filled with '+' so that
     * any write past the expected end of the output can be detected.
     */
    mbedtls_nist_kw_context ctx;
    unsigned char result[528];
    size_t result_len, pos, remainder, padlen;

    memset(result, '+', sizeof(result));
    mbedtls_nist_kw_init(&ctx);

    TEST_ASSERT(mbedtls_nist_kw_setkey(&ctx, cipher_id,
                                       key->x, key->len * 8, 1) == 0);

    /* Wrap from the test-data buffer into the local result buffer. */
    TEST_ASSERT(mbedtls_nist_kw_wrap(&ctx, mode, msg->x, msg->len,
                                     result, &result_len, sizeof(result)) == 0);

    TEST_ASSERT(result_len == expected_result->len);

    TEST_ASSERT(memcmp(expected_result->x, result, result_len) == 0);

    /* Bytes needed to round the message up to a multiple of 8. */
    remainder = msg->len % 8;
    if (remainder != 0) {
        padlen = 8 - remainder;
    } else {
        padlen = 0;
    }

    /* Check that the function didn't write beyond the end of the buffer:
     * everything after message + padding + one semiblock must be intact. */
    for (pos = msg->len + 8 + padlen; pos < sizeof(result); pos++) {
        TEST_ASSERT(result[pos] == '+');
    }

exit:
    mbedtls_nist_kw_free(&ctx);
}
/* END_CASE */

/* BEGIN_CASE */
void mbedtls_nist_kw_unwrap(int cipher_id, int mode, data_t *key, data_t *msg,
                            data_t *expected_result, int expected_ret)
{
    /*
     * Known-answer test for unwrapping: unwrap `msg` under `key` in the
     * given mode and check that the call returns `expected_ret`.
     * On success the output must match `expected_result` in length and
     * content; on failure the reported output length must be 0.
     * The output buffer is pre-filled with '+' so that any write past the
     * permitted region can be detected.
     */
    unsigned char result[528];
    mbedtls_nist_kw_context ctx;
    size_t result_len, i;

    mbedtls_nist_kw_init(&ctx);

    memset(result, '+', sizeof(result));

    TEST_ASSERT(mbedtls_nist_kw_setkey(&ctx, cipher_id,
                                       key->x, key->len * 8, 0) == 0);

    /* Unwrap from the test-data buffer into the local result buffer. */
    TEST_ASSERT(mbedtls_nist_kw_unwrap(&ctx, mode, msg->x, msg->len,
                                       result, &result_len, sizeof(result)) == expected_ret);
    if (expected_ret == 0) {
        TEST_ASSERT(result_len == expected_result->len);
        TEST_ASSERT(memcmp(expected_result->x, result, result_len) == 0);
    } else {
        TEST_ASSERT(result_len == 0);
    }

    /* Check that the function didn't write beyond the end of the buffer.
     * At most msg->len - 8 bytes may be touched. For inputs shorter than
     * one semiblock the subtraction would wrap around and silently skip
     * the check, so start from 0 and require an untouched buffer. */
    i = (msg->len >= 8) ? msg->len - 8 : 0;
    for (; i < sizeof(result); i++) {
        TEST_ASSERT(result[i] == '+');
    }

exit:
    mbedtls_nist_kw_free(&ctx);
}
/* END_CASE */