/* BEGIN_HEADER */
#include "mbedtls/gcm.h"

/* Run one GCM operation through the multipart interface, feeding both the
 * additional data and the payload in two pieces, and check every piece of
 * output as well as the final tag against the expected values.
 *
 * The context must have been set up with the key.
 *
 * n1 is the length of the first payload piece and n1_add the length of the
 * first additional-data piece; the second pieces are whatever remains.
 *
 * Return 1 if every check passed, 0 otherwise. */
static int check_multipart(mbedtls_gcm_context *ctx,
                           int mode,
                           const data_t *iv,
                           const data_t *add,
                           const data_t *input,
                           const data_t *expected_output,
                           const data_t *tag,
                           size_t n1,
                           size_t n1_add)
{
    int ok = 0;
    uint8_t *buf = NULL;
    size_t written;
    size_t n2;
    size_t n2_add;

    /* Reject inconsistent test data before deriving the second-piece sizes. */
    TEST_ASSERT(n1 <= input->len);
    TEST_ASSERT(n1_add <= add->len);
    TEST_EQUAL(input->len, expected_output->len);
    n2 = input->len - n1;
    n2_add = add->len - n1_add;

    TEST_EQUAL(0, mbedtls_gcm_starts(ctx, mode, iv->x, iv->len));
    TEST_EQUAL(0, mbedtls_gcm_update_ad(ctx, add->x, n1_add));
    TEST_EQUAL(0, mbedtls_gcm_update_ad(ctx, add->x + n1_add, n2_add));

    /* Every call below gets a freshly allocated buffer of exactly the
     * advertised size, so that a write past the end shows up as an overflow
     * for memory sanitizers and static checkers. */

    /* First payload piece. */
    TEST_CALLOC(buf, n1);
    written = 0xdeadbeef;
    TEST_EQUAL(0, mbedtls_gcm_update(ctx, input->x, n1, buf, n1, &written));
    TEST_EQUAL(n1, written);
    TEST_MEMORY_COMPARE(buf, written, expected_output->x, n1);
    mbedtls_free(buf);
    buf = NULL;

    /* Second payload piece. */
    TEST_CALLOC(buf, n2);
    written = 0xdeadbeef;
    TEST_EQUAL(0, mbedtls_gcm_update(ctx, input->x + n1, n2, buf, n2, &written));
    TEST_EQUAL(n2, written);
    TEST_MEMORY_COMPARE(buf, written, expected_output->x + n1, n2);
    mbedtls_free(buf);
    buf = NULL;

    /* Tag: finish is given no output buffer, so it must report 0 bytes. */
    TEST_CALLOC(buf, tag->len);
    TEST_EQUAL(0, mbedtls_gcm_finish(ctx, NULL, 0, &written, buf, tag->len));
    TEST_EQUAL(0, written);
    TEST_MEMORY_COMPARE(buf, tag->len, tag->x, tag->len);

    ok = 1;
exit:
    /* Releases the current buffer on both the success and failure paths. */
    mbedtls_free(buf);
    return ok;
}

/* Process a message without additional data through the multipart interface
 * and check the output and the tag against the expected values.
 *
 * mbedtls_gcm_update_ad() is called ad_update_count times with a NULL,
 * zero-length buffer before the whole payload is processed in a single
 * update call. Failures are reported through the test macros. */
static void check_cipher_with_empty_ad(mbedtls_gcm_context *ctx,
                                       int mode,
                                       const data_t *iv,
                                       const data_t *input,
                                       const data_t *expected_output,
                                       const data_t *tag,
                                       size_t ad_update_count)
{
    uint8_t *buf = NULL;
    size_t written;
    size_t remaining;

    /* Sanity checks on the test data */
    TEST_EQUAL(input->len, expected_output->len);

    TEST_EQUAL(0, mbedtls_gcm_starts(ctx, mode, iv->x, iv->len));

    /* Feed the requested number of empty additional-data chunks. */
    for (remaining = ad_update_count; remaining > 0; remaining--) {
        TEST_EQUAL(0, mbedtls_gcm_update_ad(ctx, NULL, 0));
    }

    /* The buffers below are exactly as large as advertised to the library,
     * so that a write past the end shows up as an overflow for memory
     * sanitizers and static checkers. */
    TEST_CALLOC(buf, input->len);
    written = 0xdeadbeef;
    TEST_EQUAL(0, mbedtls_gcm_update(ctx, input->x, input->len, buf, input->len, &written));
    TEST_EQUAL(input->len, written);
    TEST_MEMORY_COMPARE(buf, written, expected_output->x, input->len);
    mbedtls_free(buf);
    buf = NULL;

    /* Finish is given no output buffer, so it must report 0 bytes. */
    TEST_CALLOC(buf, tag->len);
    TEST_EQUAL(0, mbedtls_gcm_finish(ctx, NULL, 0, &written, buf, tag->len));
    TEST_EQUAL(0, written);
    TEST_MEMORY_COMPARE(buf, tag->len, tag->x, tag->len);

exit:
    mbedtls_free(buf);
}

/* Authenticate additional data with an empty payload through the multipart
 * interface and check the resulting tag.
 *
 * All of the additional data is fed in one call, then mbedtls_gcm_update()
 * is called cipher_update_count times with NULL, zero-length buffers; each
 * such call must report 0 bytes of output. Failures are reported through
 * the test macros. */
static void check_empty_cipher_with_ad(mbedtls_gcm_context *ctx,
                                       int mode,
                                       const data_t *iv,
                                       const data_t *add,
                                       const data_t *tag,
                                       size_t cipher_update_count)
{
    uint8_t *tag_buf = NULL;
    size_t written;
    size_t remaining;

    TEST_EQUAL(0, mbedtls_gcm_starts(ctx, mode, iv->x, iv->len));
    TEST_EQUAL(0, mbedtls_gcm_update_ad(ctx, add->x, add->len));

    for (remaining = cipher_update_count; remaining > 0; remaining--) {
        /* Poison the output length so a call that leaves it untouched is
         * caught by the check below. */
        written = 0xdeadbeef;
        TEST_EQUAL(0, mbedtls_gcm_update(ctx, NULL, 0, NULL, 0, &written));
        TEST_EQUAL(0, written);
    }

    TEST_CALLOC(tag_buf, tag->len);
    TEST_EQUAL(0, mbedtls_gcm_finish(ctx, NULL, 0, &written,
                                     tag_buf, tag->len));
    TEST_EQUAL(0, written);
    TEST_MEMORY_COMPARE(tag_buf, tag->len, tag->x, tag->len);

exit:
    mbedtls_free(tag_buf);
}

/* Start a GCM operation and finish it straight away, with neither
 * additional data nor payload, then check the resulting tag.
 * Failures are reported through the test macros. */
static void check_no_cipher_no_ad(mbedtls_gcm_context *ctx,
                                  int mode,
                                  const data_t *iv,
                                  const data_t *tag)
{
    size_t written = 0;
    uint8_t *tag_buf = NULL;

    TEST_EQUAL(0, mbedtls_gcm_starts(ctx, mode, iv->x, iv->len));

    /* Tag buffer of exactly the requested size. */
    TEST_CALLOC(tag_buf, tag->len);
    TEST_EQUAL(0, mbedtls_gcm_finish(ctx, NULL, 0, &written, tag_buf, tag->len));
    TEST_EQUAL(0, written);
    TEST_MEMORY_COMPARE(tag_buf, tag->len, tag->x, tag->len);

exit:
    mbedtls_free(tag_buf);
}

/* (Re)initialize ctx, load an AES key of key_bits bits and start an
 * encryption with the given IV.
 *
 * Setting the key is expected to succeed; mbedtls_gcm_starts() is expected
 * to return starts_ret. The caller remains responsible for freeing ctx. */
static void gcm_reset_ctx(mbedtls_gcm_context *ctx, const uint8_t *key,
                          size_t key_bits, const uint8_t *iv, size_t iv_len,
                          int starts_ret)
{
    mbedtls_gcm_init(ctx);

    TEST_EQUAL(mbedtls_gcm_setkey(ctx, MBEDTLS_CIPHER_ID_AES, key, key_bits), 0);
    TEST_EQUAL(starts_ret,
               mbedtls_gcm_starts(ctx, MBEDTLS_GCM_ENCRYPT, iv, iv_len));

exit:
    /* Nothing to clean up here. */
    return;
}

/* END_HEADER */

/* BEGIN_DEPENDENCIES
 * depends_on:MBEDTLS_GCM_C
 * END_DEPENDENCIES
 */

/* BEGIN_CASE */
/* Set the key, then run a one-shot mbedtls_gcm_crypt_and_tag() with the
 * supplied parameters and check that it returns gcm_result.
 * tag_len_bits is converted to a length in bytes. */
void gcm_bad_parameters(int cipher_id, int direction,
                        data_t *key_str, data_t *src_str,
                        data_t *iv_str, data_t *add_str,
                        int tag_len_bits, int gcm_result)
{
    unsigned char output[128] = { 0 };
    unsigned char tag_output[16] = { 0 };
    const size_t tag_len = tag_len_bits / 8;
    mbedtls_gcm_context ctx;

    BLOCK_CIPHER_PSA_INIT();
    mbedtls_gcm_init(&ctx);

    /* The key itself is valid; only the crypt call is expected to return
     * the result under test. */
    TEST_ASSERT(mbedtls_gcm_setkey(&ctx, cipher_id, key_str->x, key_str->len * 8) == 0);
    TEST_ASSERT(mbedtls_gcm_crypt_and_tag(&ctx, direction, src_str->len,
                                          iv_str->x, iv_str->len,
                                          add_str->x, add_str->len,
                                          src_str->x, output,
                                          tag_len, tag_output) == gcm_result);

exit:
    mbedtls_gcm_free(&ctx);
    BLOCK_CIPHER_PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE */
/* Encrypt src_str in one shot and compare ciphertext and tag with the
 * expected values, then repeat the operation through the multipart
 * interface for every possible split of the payload and of the additional
 * data.
 *
 * init_result is the expected return value of mbedtls_gcm_setkey(); when it
 * is non-zero nothing beyond the key setup is exercised. */
void gcm_encrypt_and_tag(int cipher_id, data_t *key_str,
                         data_t *src_str, data_t *iv_str,
                         data_t *add_str, data_t *dst,
                         int tag_len_bits, data_t *tag,
                         int init_result)
{
    unsigned char output[128] = { 0 };
    unsigned char tag_output[16] = { 0 };
    const size_t tag_len = tag_len_bits / 8;
    mbedtls_gcm_context ctx;
    size_t split;
    size_t split_add;

    BLOCK_CIPHER_PSA_INIT();
    mbedtls_gcm_init(&ctx);

    TEST_ASSERT(mbedtls_gcm_setkey(&ctx, cipher_id, key_str->x, key_str->len * 8) == init_result);
    if (init_result != 0) {
        /* Key setup was expected to fail: nothing more to test. */
        goto exit;
    }

    /* One-shot interface. */
    TEST_ASSERT(mbedtls_gcm_crypt_and_tag(&ctx, MBEDTLS_GCM_ENCRYPT, src_str->len, iv_str->x,
                                          iv_str->len, add_str->x, add_str->len, src_str->x,
                                          output, tag_len, tag_output) == 0);

    TEST_MEMORY_COMPARE(output, src_str->len, dst->x, dst->len);
    TEST_MEMORY_COMPARE(tag_output, tag_len, tag->x, tag->len);

    /* Multipart interface, at every split point of payload and AD. */
    for (split = 0; split <= src_str->len; split++) {
        for (split_add = 0; split_add <= add_str->len; split_add++) {
            int passed;

            /* Encode both split points in the step number so that a
             * failure identifies the combination. */
            mbedtls_test_set_step(split * 10000 + split_add);
            passed = check_multipart(&ctx, MBEDTLS_GCM_ENCRYPT,
                                     iv_str, add_str, src_str,
                                     dst, tag,
                                     split, split_add);
            if (passed == 0) {
                goto exit;
            }
        }
    }

exit:
    mbedtls_gcm_free(&ctx);
    BLOCK_CIPHER_PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE */
/* Decrypt and authenticate src_str in one shot. When result is "FAIL" the
 * call must return MBEDTLS_ERR_GCM_AUTH_FAILED; otherwise it must succeed,
 * the plaintext must match pt_result, and the same operation is repeated
 * through the multipart interface for every possible split of the payload
 * and of the additional data.
 *
 * init_result is the expected return value of mbedtls_gcm_setkey(); when it
 * is non-zero nothing beyond the key setup is exercised. */
void gcm_decrypt_and_verify(int cipher_id, data_t *key_str,
                            data_t *src_str, data_t *iv_str,
                            data_t *add_str, int tag_len_bits,
                            data_t *tag_str, char *result,
                            data_t *pt_result, int init_result)
{
    unsigned char output[128] = { 0 };
    const size_t tag_len = tag_len_bits / 8;
    mbedtls_gcm_context ctx;
    int ret;
    size_t split;
    size_t split_add;

    BLOCK_CIPHER_PSA_INIT();
    mbedtls_gcm_init(&ctx);

    TEST_ASSERT(mbedtls_gcm_setkey(&ctx, cipher_id, key_str->x, key_str->len * 8) == init_result);
    if (init_result != 0) {
        /* Key setup was expected to fail: nothing more to test. */
        goto exit;
    }

    /* One-shot interface. */
    ret = mbedtls_gcm_auth_decrypt(&ctx, src_str->len,
                                   iv_str->x, iv_str->len,
                                   add_str->x, add_str->len,
                                   tag_str->x, tag_len,
                                   src_str->x, output);

    if (strcmp("FAIL", result) == 0) {
        /* Negative test: authentication must be rejected, and there is no
         * plaintext to compare. */
        TEST_ASSERT(ret == MBEDTLS_ERR_GCM_AUTH_FAILED);
        goto exit;
    }

    TEST_ASSERT(ret == 0);
    TEST_MEMORY_COMPARE(output, src_str->len, pt_result->x, pt_result->len);

    /* Multipart interface, at every split point of payload and AD. */
    for (split = 0; split <= src_str->len; split++) {
        for (split_add = 0; split_add <= add_str->len; split_add++) {
            int passed;

            /* Encode both split points in the step number so that a
             * failure identifies the combination. */
            mbedtls_test_set_step(split * 10000 + split_add);
            passed = check_multipart(&ctx, MBEDTLS_GCM_DECRYPT,
                                     iv_str, add_str, src_str,
                                     pt_result, tag_str,
                                     split, split_add);
            if (passed == 0) {
                goto exit;
            }
        }
    }

exit:
    mbedtls_gcm_free(&ctx);
    BLOCK_CIPHER_PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE */
/* Decrypt-mode multipart operation with additional data but an empty
 * payload: mbedtls_gcm_update() is called cipher_update_calls times with
 * empty input, and the resulting tag is compared with tag_str. */
void gcm_decrypt_and_verify_empty_cipher(int cipher_id,
                                         data_t *key_str,
                                         data_t *iv_str,
                                         data_t *add_str,
                                         data_t *tag_str,
                                         int cipher_update_calls)
{
    mbedtls_gcm_context ctx;

    BLOCK_CIPHER_PSA_INIT();
    mbedtls_gcm_init(&ctx);

    TEST_ASSERT(mbedtls_gcm_setkey(&ctx, cipher_id, key_str->x, key_str->len * 8) == 0);
    check_empty_cipher_with_ad(&ctx, MBEDTLS_GCM_DECRYPT,
                               iv_str, add_str, tag_str,
                               cipher_update_calls);

exit:
    /* Cleanup sits under an explicit exit label, as in
     * gcm_encrypt_and_tag_empty_cipher, so that it also runs when the
     * key-setup assertion above fails and jumps here. */
    mbedtls_gcm_free(&ctx);
    BLOCK_CIPHER_PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE */
/* Decrypt-mode multipart operation with a payload but no additional data:
 * mbedtls_gcm_update_ad() is called ad_update_calls times with empty input,
 * then the output is compared with pt_result and the tag with tag_str. */
void gcm_decrypt_and_verify_empty_ad(int cipher_id,
                                     data_t *key_str,
                                     data_t *iv_str,
                                     data_t *src_str,
                                     data_t *tag_str,
                                     data_t *pt_result,
                                     int ad_update_calls)
{
    mbedtls_gcm_context ctx;

    BLOCK_CIPHER_PSA_INIT();
    mbedtls_gcm_init(&ctx);

    TEST_ASSERT(mbedtls_gcm_setkey(&ctx, cipher_id, key_str->x, key_str->len * 8) == 0);
    check_cipher_with_empty_ad(&ctx, MBEDTLS_GCM_DECRYPT,
                               iv_str, src_str, pt_result, tag_str,
                               ad_update_calls);

exit:
    /* Cleanup sits under an explicit exit label, as in
     * gcm_encrypt_and_tag_empty_ad, so that it also runs when the
     * key-setup assertion above fails and jumps here. */
    mbedtls_gcm_free(&ctx);
    BLOCK_CIPHER_PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE */
/* Decrypt-mode multipart operation with neither additional data nor
 * payload: the operation is started and finished immediately, and the
 * resulting tag is compared with tag_str. */
void gcm_decrypt_and_verify_no_ad_no_cipher(int cipher_id,
                                            data_t *key_str,
                                            data_t *iv_str,
                                            data_t *tag_str)
{
    mbedtls_gcm_context ctx;

    BLOCK_CIPHER_PSA_INIT();
    mbedtls_gcm_init(&ctx);

    TEST_ASSERT(mbedtls_gcm_setkey(&ctx, cipher_id, key_str->x, key_str->len * 8) == 0);
    check_no_cipher_no_ad(&ctx, MBEDTLS_GCM_DECRYPT,
                          iv_str, tag_str);

exit:
    /* Cleanup sits under an explicit exit label so that it also runs when
     * the key-setup assertion above fails and jumps here. */
    mbedtls_gcm_free(&ctx);
    BLOCK_CIPHER_PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE */
/* Encrypt-mode multipart operation with additional data but an empty
 * payload: mbedtls_gcm_update() is called cipher_update_calls times with
 * empty input, and the resulting tag is compared with tag_str. */
void gcm_encrypt_and_tag_empty_cipher(int cipher_id,
                                      data_t *key_str,
                                      data_t *iv_str,
                                      data_t *add_str,
                                      data_t *tag_str,
                                      int cipher_update_calls)
{
    mbedtls_gcm_context ctx;
    const size_t key_bits = key_str->len * 8;

    BLOCK_CIPHER_PSA_INIT();
    mbedtls_gcm_init(&ctx);

    TEST_ASSERT(mbedtls_gcm_setkey(&ctx, cipher_id, key_str->x, key_bits) == 0);

    /* The helper reports its own failures through the test macros. */
    check_empty_cipher_with_ad(&ctx, MBEDTLS_GCM_ENCRYPT,
                               iv_str, add_str, tag_str,
                               cipher_update_calls);

exit:
    mbedtls_gcm_free(&ctx);
    BLOCK_CIPHER_PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE */
/* Encrypt-mode multipart operation with a payload but no additional data:
 * mbedtls_gcm_update_ad() is called ad_update_calls times with empty input,
 * then the output is compared with dst and the tag with tag_str. */
void gcm_encrypt_and_tag_empty_ad(int cipher_id,
                                  data_t *key_str,
                                  data_t *iv_str,
                                  data_t *src_str,
                                  data_t *dst,
                                  data_t *tag_str,
                                  int ad_update_calls)
{
    mbedtls_gcm_context ctx;
    const size_t key_bits = key_str->len * 8;

    BLOCK_CIPHER_PSA_INIT();
    mbedtls_gcm_init(&ctx);

    TEST_ASSERT(mbedtls_gcm_setkey(&ctx, cipher_id, key_str->x, key_bits) == 0);

    /* The helper reports its own failures through the test macros. */
    check_cipher_with_empty_ad(&ctx, MBEDTLS_GCM_ENCRYPT,
                               iv_str, src_str, dst, tag_str,
                               ad_update_calls);

exit:
    mbedtls_gcm_free(&ctx);
    BLOCK_CIPHER_PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE */
/* Encrypt-mode multipart operation with neither additional data nor
 * payload: the operation is started and finished immediately, and the
 * resulting tag is compared with tag_str. */
void gcm_encrypt_and_verify_no_ad_no_cipher(int cipher_id,
                                            data_t *key_str,
                                            data_t *iv_str,
                                            data_t *tag_str)
{
    mbedtls_gcm_context ctx;

    BLOCK_CIPHER_PSA_INIT();
    mbedtls_gcm_init(&ctx);

    TEST_ASSERT(mbedtls_gcm_setkey(&ctx, cipher_id, key_str->x, key_str->len * 8) == 0);
    check_no_cipher_no_ad(&ctx, MBEDTLS_GCM_ENCRYPT,
                          iv_str, tag_str);

exit:
    /* Cleanup sits under an explicit exit label so that it also runs when
     * the key-setup assertion above fails and jumps here. */
    mbedtls_gcm_free(&ctx);
    BLOCK_CIPHER_PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE */
/* Check that mbedtls_gcm_setkey() rejects an invalid key length
 * (1 bit, with an otherwise valid AES cipher id and key buffer). */
void gcm_invalid_param()
{
    unsigned char key_buf[] = { 0x01, 0x02, 0x03, 0x04, 0x05, 0x06 };
    const mbedtls_cipher_id_t cipher = MBEDTLS_CIPHER_ID_AES;
    const int bad_keybits = 1;
    mbedtls_gcm_context ctx;

    mbedtls_gcm_init(&ctx);

    /* mbedtls_gcm_setkey */
    TEST_EQUAL(MBEDTLS_ERR_GCM_BAD_INPUT,
               mbedtls_gcm_setkey(&ctx, cipher, key_buf, bad_keybits));

exit:
    mbedtls_gcm_free(&ctx);
}
/* END_CASE */

/* BEGIN_CASE */
/* Check that mbedtls_gcm_update() returns MBEDTLS_ERR_GCM_BUFFER_TOO_SMALL
 * when the output buffer is one byte shorter than the input.
 * The test data must provide a non-empty input. */
void gcm_update_output_buffer_too_small(int cipher_id, int mode,
                                        data_t *key_str, const data_t *input,
                                        const data_t *iv)
{
    mbedtls_gcm_context ctx;
    uint8_t *output = NULL;
    size_t olen = 0;
    size_t output_len = 0;

    BLOCK_CIPHER_PSA_INIT();
    mbedtls_gcm_init(&ctx);

    /* Sanity check on the test data: with an empty input the subtraction
     * below would wrap around to SIZE_MAX instead of yielding a buffer that
     * is one byte too small. Done after the context is initialized so that
     * the cleanup path stays valid. */
    TEST_ASSERT(input->len > 0);
    output_len = input->len - 1;

    TEST_EQUAL(mbedtls_gcm_setkey(&ctx, cipher_id, key_str->x, key_str->len * 8), 0);
    TEST_EQUAL(0, mbedtls_gcm_starts(&ctx, mode, iv->x, iv->len));

    /* Tight allocation: any write into it beyond output_len would be an
     * overflow visible to memory sanitizers. */
    TEST_CALLOC(output, output_len);
    TEST_EQUAL(MBEDTLS_ERR_GCM_BUFFER_TOO_SMALL,
               mbedtls_gcm_update(&ctx, input->x, input->len, output, output_len, &olen));

exit:
    mbedtls_free(output);
    mbedtls_gcm_free(&ctx);
    BLOCK_CIPHER_PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE */
/* NIST SP 800-38D, Section 5.2.1.1 requires that the bit length of the IV
 * satisfy 1 <= bit_len(IV) <= 2^64 - 1. */
void gcm_invalid_iv_len(void)
{
    uint8_t zeros[16] = { 0 };
    const size_t key_bits = sizeof(zeros) * 8;
    mbedtls_gcm_context ctx;

    mbedtls_gcm_init(&ctx);
    BLOCK_CIPHER_PSA_INIT();

    /* An empty IV must be rejected by mbedtls_gcm_starts(). */
    gcm_reset_ctx(&ctx, zeros, key_bits, zeros, 0, MBEDTLS_ERR_GCM_BAD_INPUT);
    mbedtls_gcm_free(&ctx);

    /* The oversized case needs a size_t of at least 64 bits. */
#if SIZE_MAX >= UINT64_MAX
    /* An IV of 2^61 bytes, i.e. 2^64 bits, must be rejected as well. */
    gcm_reset_ctx(&ctx, zeros, key_bits, zeros, 1ULL << 61, MBEDTLS_ERR_GCM_BAD_INPUT);
    mbedtls_gcm_free(&ctx);
#endif

    goto exit; /* To suppress error that exit is defined but not used */
exit:
    mbedtls_gcm_free(&ctx);
    BLOCK_CIPHER_PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE */
/* Check that mbedtls_gcm_update_ad() rejects additional data whose total
 * length exceeds the limit, whether it arrives in one call, in two calls,
 * or with a second length so large that the running total would overflow. */
void gcm_add_len_too_long(void)
{
    /* The lengths involved need a size_t of at least 64 bits. */
#if SIZE_MAX >= UINT64_MAX
    mbedtls_gcm_context ctx;
    mbedtls_gcm_init(&ctx);
    uint8_t zeros[16] = { 0 };
    const size_t key_bits = sizeof(zeros) * 8;
    BLOCK_CIPHER_PSA_INIT();

    /* NIST SP 800-38D, Section 5.2.1.1 requires that the bit length of the
     * AD be <= 2^64 - 1, i.e. < 2^64. This is the minimum invalid length
     * in bytes. */
    uint64_t too_long = 1ULL << 61;

    /* Case 1: a single call that just exceeds the limit. */
    gcm_reset_ctx(&ctx, zeros, key_bits, zeros, sizeof(zeros), 0);
    TEST_EQUAL(mbedtls_gcm_update_ad(&ctx, zeros, too_long),
               MBEDTLS_ERR_GCM_BAD_INPUT);
    mbedtls_gcm_free(&ctx);

    /* Case 2: two calls whose combined length just exceeds the limit. */
    gcm_reset_ctx(&ctx, zeros, key_bits, zeros, sizeof(zeros), 0);
    TEST_EQUAL(mbedtls_gcm_update_ad(&ctx, zeros, 1), 0);
    TEST_EQUAL(mbedtls_gcm_update_ad(&ctx, zeros, too_long - 1),
               MBEDTLS_ERR_GCM_BAD_INPUT);
    mbedtls_gcm_free(&ctx);

    /* Case 3: the running total would overflow; this must be handled
     * properly and rejected. */
    gcm_reset_ctx(&ctx, zeros, key_bits, zeros, sizeof(zeros), 0);
    TEST_EQUAL(mbedtls_gcm_update_ad(&ctx, zeros, 1), 0);
    TEST_EQUAL(mbedtls_gcm_update_ad(&ctx, zeros, UINT64_MAX), MBEDTLS_ERR_GCM_BAD_INPUT);

exit:
    mbedtls_gcm_free(&ctx);
    BLOCK_CIPHER_PSA_DONE();
#endif
}
/* END_CASE */

/* BEGIN_CASE */
/* Check that mbedtls_gcm_update() rejects input whose total length exceeds
 * the limit, whether it arrives in one call, in two calls, or with a second
 * length so large that the running total would overflow. */
void gcm_input_len_too_long(void)
{
    /* The lengths involved need a size_t of at least 64 bits. */
#if SIZE_MAX >= UINT64_MAX
    mbedtls_gcm_context ctx;
    uint8_t zeros[16] = { 0 };
    const size_t key_bits = sizeof(zeros) * 8;
    /* Deliberately tiny: every oversized call below is expected to fail
     * with MBEDTLS_ERR_GCM_BAD_INPUT. */
    uint8_t out[1];
    size_t written;
    mbedtls_gcm_init(&ctx);
    BLOCK_CIPHER_PSA_INIT();

    /* NIST SP 800-38D, Section 5.2.1.1 requires that the bit length of the
     * input be <= 2^39 - 256. This is the maximum valid length in bytes. */
    uint64_t max_valid = (1ULL << 36) - 32;

    /* Case 1: a single call that just exceeds the limit. */
    gcm_reset_ctx(&ctx, zeros, key_bits, zeros, sizeof(zeros), 0);
    TEST_EQUAL(mbedtls_gcm_update(&ctx, zeros, max_valid + 1, out, max_valid + 1,
                                  &written),
               MBEDTLS_ERR_GCM_BAD_INPUT);
    mbedtls_gcm_free(&ctx);

    /* Case 2: two calls whose combined length just exceeds the limit. */
    gcm_reset_ctx(&ctx, zeros, key_bits, zeros, sizeof(zeros), 0);
    TEST_EQUAL(mbedtls_gcm_update(&ctx, zeros, 1, out, 1, &written), 0);
    TEST_EQUAL(mbedtls_gcm_update(&ctx, zeros, max_valid, out, max_valid, &written),
               MBEDTLS_ERR_GCM_BAD_INPUT);
    mbedtls_gcm_free(&ctx);

    /* Case 3: the running total would overflow; this must be handled
     * properly and rejected. */
    gcm_reset_ctx(&ctx, zeros, key_bits, zeros, sizeof(zeros), 0);
    TEST_EQUAL(mbedtls_gcm_update(&ctx, zeros, 1, out, 1, &written), 0);
    TEST_EQUAL(mbedtls_gcm_update(&ctx, zeros, UINT64_MAX, out, UINT64_MAX,
                                  &written),
               MBEDTLS_ERR_GCM_BAD_INPUT);

exit:
    mbedtls_gcm_free(&ctx);
    BLOCK_CIPHER_PSA_DONE();
#endif
}
/* END_CASE */

/* BEGIN_CASE depends_on:MBEDTLS_SELF_TEST:MBEDTLS_CCM_GCM_CAN_AES */
/* Run the library's built-in GCM self-test (called with argument 1) and
 * require it to return 0. */
void gcm_selftest()
{
    BLOCK_CIPHER_PSA_INIT();
    TEST_ASSERT(mbedtls_gcm_self_test(1) == 0);

exit:
    /* Explicit exit label so that the teardown matching
     * BLOCK_CIPHER_PSA_INIT() also runs when the assertion above fails. */
    BLOCK_CIPHER_PSA_DONE();
}
/* END_CASE */