/* BEGIN_HEADER */
#include "mbedtls/dhm.h"

/* Fetch one parameter from a DHM context through mbedtls_dhm_get_value()
 * and compare it with an expected value.
 *
 * ctx      - context to query.
 * param    - selector of the parameter to fetch.
 * expected - value the fetched parameter must compare equal to.
 *
 * Returns 1 if the getter returned 0 and mbedtls_mpi_cmp_mpi() reported
 * equality. Returns 0 otherwise: a failing TEST_ASSERT is expected to jump
 * to the exit label before ok is set. The temporary MPI is freed on both
 * paths.
 *
 * Declared static like the other helpers in this header section, so that
 * this test-only function does not get external linkage. */
static int check_get_value(const mbedtls_dhm_context *ctx,
                           mbedtls_dhm_parameter param,
                           const mbedtls_mpi *expected)
{
    mbedtls_mpi actual;
    int ok = 0;
    mbedtls_mpi_init(&actual);

    TEST_ASSERT(mbedtls_dhm_get_value(ctx, param, &actual) == 0);
    TEST_ASSERT(mbedtls_mpi_cmp_mpi(&actual, expected) == 0);
    ok = 1;

exit:
    /* Reached on success and on assertion failure alike. */
    mbedtls_mpi_free(&actual);
    return ok;
}

/* Validate one length-prefixed Diffie-Hellman parameter stored in a buffer.
 *
 * The parameter is expected at buffer[*offset]: a 2-byte length (most
 * significant byte first) followed by that many bytes of value.
 *
 * expected - value the parameter must be equal to (taken from the DHM
 *            context by the caller).
 * buffer   - buffer holding the serialized parameters.
 * size     - number of valid bytes in buffer.
 * offset   - in: position of the parameter; out: position just past it.
 *            On failure it may have been advanced past the length prefix
 *            only.
 *
 * Returns 1 if all checks pass, 0 otherwise. Each call increments the
 * current test step. */
static int check_dhm_param_output(const mbedtls_mpi *expected,
                                  const unsigned char *buffer,
                                  size_t size,
                                  size_t *offset)
{
    int ok = 0;
    size_t n;
    const unsigned char *len_field;
    mbedtls_mpi actual;

    mbedtls_mpi_init(&actual);

    mbedtls_test_info.step += 1;

    /* Decode the 2-byte length prefix and step over it. */
    TEST_ASSERT(size >= *offset + 2);
    len_field = buffer + *offset;
    n = ((size_t) len_field[0] << 8) | len_field[1];
    *offset += 2;

    /* Mbed TLS writes DHM parameters with leading zeros stripped, which
     * RFC 5246 section 4.4 permits but does not require. The encoded
     * length must therefore match the size of the expected value. */
    TEST_EQUAL(n, mbedtls_mpi_size(expected));

    /* The value must fit in the buffer and be equal to the expected one. */
    TEST_ASSERT(size >= *offset + n);
    TEST_EQUAL(0, mbedtls_mpi_read_binary(&actual, buffer + *offset, n));
    TEST_EQUAL(0, mbedtls_mpi_cmp_mpi(expected, &actual));
    *offset += n;

    ok = 1;
exit:
    mbedtls_mpi_free(&actual);
    return ok;
}

/* Sanity-check a DHM context and the serialized parameters produced from it.
 *
 * ctx     - context whose X, GX, P and G fields are inspected.
 * x_size  - upper bound on mbedtls_mpi_size() of the private value X.
 * ske     - buffer expected to contain P, G and G^X in that order, each
 *           prefixed with a 2-byte size.
 * ske_len - number of bytes in ske; the three parameters must fill it
 *           exactly.
 *
 * Returns 1 if every check passes, 0 otherwise. */
static int check_dhm_params(const mbedtls_dhm_context *ctx,
                            size_t x_size,
                            const unsigned char *ske, size_t ske_len)
{
    int ok = 0;
    size_t offset = 0;

    /* Range of the private value: 1 < X < P, and no larger than x_size. */
    TEST_ASSERT(mbedtls_mpi_cmp_int(&ctx->X, 1) > 0);
    TEST_ASSERT(mbedtls_mpi_cmp_mpi(&ctx->X, &ctx->P) < 0);
    TEST_ASSERT(mbedtls_mpi_size(&ctx->X) <= x_size);

    /* Range of the public value: 1 < G^X < P. */
    TEST_ASSERT(mbedtls_mpi_cmp_int(&ctx->GX, 1) > 0);
    TEST_ASSERT(mbedtls_mpi_cmp_mpi(&ctx->GX, &ctx->P) < 0);

    /* Walk through ske. The checks run in order and stop at the first
     * failure, since each one advances offset for the next. */
    if (!check_dhm_param_output(&ctx->P, ske, ske_len, &offset) ||
        !check_dhm_param_output(&ctx->G, ske, ske_len, &offset) ||
        !check_dhm_param_output(&ctx->GX, ske, ske_len, &offset)) {
        goto exit;
    }

    /* Nothing may be left over after the three parameters. */
    TEST_EQUAL(offset, ske_len);

    ok = 1;
exit:
    return ok;
}

/* END_HEADER */

/* BEGIN_DEPENDENCIES
 * depends_on:MBEDTLS_DHM_C:MBEDTLS_BIGNUM_C
 * END_DEPENDENCIES
 */

/* BEGIN_CASE */
void dhm_do_dhm(char *input_P, int x_size,
                char *input_G, int result)
{
    /* Run Diffie-Hellman key exchanges between a server context and a
     * client context and check that both sides derive the same secret.
     *
     * input_P, input_G - textual group parameters, parsed with
     *                    mbedtls_test_read_mpi() into the server context.
     * x_size           - private value size passed to make_params and
     *                    make_public.
     * result           - expected return value of the first
     *                    mbedtls_dhm_make_params() call. When non-zero, the
     *                    test stops right after checking that return value.
     */
    mbedtls_dhm_context ctx_srv;
    mbedtls_dhm_context ctx_cli;
    unsigned char ske[1000];
    unsigned char *p = ske;
    unsigned char pub_cli[1000];
    unsigned char sec_srv[1000];
    unsigned char sec_cli[1000];
    size_t ske_len = 0;
    size_t pub_cli_len = 0;
    size_t sec_srv_len;
    size_t sec_cli_len;
    int i;
    mbedtls_test_rnd_pseudo_info rnd_info;

    mbedtls_dhm_init(&ctx_srv);
    mbedtls_dhm_init(&ctx_cli);
    memset(ske, 0x00, 1000);
    memset(pub_cli, 0x00, 1000);
    memset(sec_srv, 0x00, 1000);
    memset(sec_cli, 0x00, 1000);
    /* The pseudo-random generator state starts all-zero and is shared by
     * every RNG-consuming call below. */
    memset(&rnd_info, 0x00, sizeof(mbedtls_test_rnd_pseudo_info));

    /*
     * Set params
     */
    TEST_ASSERT(mbedtls_test_read_mpi(&ctx_srv.P, input_P) == 0);
    TEST_ASSERT(mbedtls_test_read_mpi(&ctx_srv.G, input_G) == 0);
    /* The client's public value is written and read using the size of P. */
    pub_cli_len = mbedtls_mpi_size(&ctx_srv.P);
    /* The getter must report the values that were just stored. */
    TEST_ASSERT(check_get_value(&ctx_srv, MBEDTLS_DHM_PARAM_P, &ctx_srv.P));
    TEST_ASSERT(check_get_value(&ctx_srv, MBEDTLS_DHM_PARAM_G, &ctx_srv.G));

    /*
     * First key exchange
     */
    mbedtls_test_set_step(10);
    TEST_ASSERT(mbedtls_dhm_make_params(&ctx_srv, x_size, ske, &ske_len,
                                        &mbedtls_test_rnd_pseudo_rand,
                                        &rnd_info) == result);
    /* An expected failure has been confirmed: nothing more to check. */
    if (result != 0) {
        goto exit;
    }
    if (!check_dhm_params(&ctx_srv, x_size, ske, ske_len)) {
        goto exit;
    }

    /* Append two zero bytes after the generated parameters, so the end
     * pointer given to read_params lies two bytes past them.
     * NOTE(review): presumably this stands in for data that follows the
     * parameters in a real message - confirm. */
    ske[ske_len++] = 0;
    ske[ske_len++] = 0;
    TEST_ASSERT(mbedtls_dhm_read_params(&ctx_cli, &p, ske + ske_len) == 0);
    /* The domain parameters must be the same on both sides. */
    TEST_ASSERT(check_get_value(&ctx_cli, MBEDTLS_DHM_PARAM_P, &ctx_srv.P));
    TEST_ASSERT(check_get_value(&ctx_cli, MBEDTLS_DHM_PARAM_G, &ctx_srv.G));

    /* Client generates its public value, which the server then imports. */
    TEST_ASSERT(mbedtls_dhm_make_public(&ctx_cli, x_size, pub_cli, pub_cli_len,
                                        &mbedtls_test_rnd_pseudo_rand,
                                        &rnd_info) == 0);
    TEST_ASSERT(mbedtls_dhm_read_public(&ctx_srv, pub_cli, pub_cli_len) == 0);

    /* Both sides compute the shared secret independently. */
    TEST_ASSERT(mbedtls_dhm_calc_secret(&ctx_srv, sec_srv, sizeof(sec_srv),
                                        &sec_srv_len,
                                        &mbedtls_test_rnd_pseudo_rand,
                                        &rnd_info) == 0);
    TEST_ASSERT(mbedtls_dhm_calc_secret(&ctx_cli, sec_cli, sizeof(sec_cli),
                                        &sec_cli_len,
                                        &mbedtls_test_rnd_pseudo_rand,
                                        &rnd_info) == 0);

    /* The two secrets must be non-empty and identical. */
    TEST_ASSERT(sec_srv_len == sec_cli_len);
    TEST_ASSERT(sec_srv_len != 0);
    TEST_ASSERT(memcmp(sec_srv, sec_cli, sec_srv_len) == 0);

    /* Internal value checks */
    TEST_ASSERT(check_get_value(&ctx_cli, MBEDTLS_DHM_PARAM_X, &ctx_cli.X));
    TEST_ASSERT(check_get_value(&ctx_srv, MBEDTLS_DHM_PARAM_X, &ctx_srv.X));
    /* Cross-checks: each side's own public value (GX) must be the peer's
     * view of it (GY), and both sides must hold the same K. */
    TEST_ASSERT(check_get_value(&ctx_cli, MBEDTLS_DHM_PARAM_GX, &ctx_srv.GY));
    TEST_ASSERT(check_get_value(&ctx_cli, MBEDTLS_DHM_PARAM_GY, &ctx_srv.GX));
    TEST_ASSERT(check_get_value(&ctx_cli, MBEDTLS_DHM_PARAM_K, &ctx_srv.K));
    TEST_ASSERT(check_get_value(&ctx_srv, MBEDTLS_DHM_PARAM_GX, &ctx_cli.GY));
    TEST_ASSERT(check_get_value(&ctx_srv, MBEDTLS_DHM_PARAM_GY, &ctx_cli.GX));
    TEST_ASSERT(check_get_value(&ctx_srv, MBEDTLS_DHM_PARAM_K, &ctx_cli.K));

    /* Re-do calc_secret on server a few times to test update of blinding values */
    for (i = 0; i < 3; i++) {
        mbedtls_test_set_step(20 + i);
        /* Preset the output length before the call; the buffer size itself
         * is passed separately as sizeof(sec_srv). */
        sec_srv_len = 1000;
        TEST_ASSERT(mbedtls_dhm_calc_secret(&ctx_srv, sec_srv,
                                            sizeof(sec_srv), &sec_srv_len,
                                            &mbedtls_test_rnd_pseudo_rand,
                                            &rnd_info) == 0);

        /* Every recomputation must still match the client's secret. */
        TEST_ASSERT(sec_srv_len == sec_cli_len);
        TEST_ASSERT(sec_srv_len != 0);
        TEST_ASSERT(memcmp(sec_srv, sec_cli, sec_srv_len) == 0);
    }

    /*
     * Second key exchange to test change of blinding values on server
     */
    /* read_params receives &p, so rewind p to the start of ske before the
     * buffer is reused. */
    p = ske;

    mbedtls_test_set_step(30);
    TEST_ASSERT(mbedtls_dhm_make_params(&ctx_srv, x_size, ske, &ske_len,
                                        &mbedtls_test_rnd_pseudo_rand,
                                        &rnd_info) == 0);
    if (!check_dhm_params(&ctx_srv, x_size, ske, ske_len)) {
        goto exit;
    }
    /* Same two trailing zero bytes as in the first exchange. */
    ske[ske_len++] = 0;
    ske[ske_len++] = 0;
    TEST_ASSERT(mbedtls_dhm_read_params(&ctx_cli, &p, ske + ske_len) == 0);

    TEST_ASSERT(mbedtls_dhm_make_public(&ctx_cli, x_size, pub_cli, pub_cli_len,
                                        &mbedtls_test_rnd_pseudo_rand,
                                        &rnd_info) == 0);
    TEST_ASSERT(mbedtls_dhm_read_public(&ctx_srv, pub_cli, pub_cli_len) == 0);

    TEST_ASSERT(mbedtls_dhm_calc_secret(&ctx_srv, sec_srv, sizeof(sec_srv),
                                        &sec_srv_len,
                                        &mbedtls_test_rnd_pseudo_rand,
                                        &rnd_info) == 0);
    TEST_ASSERT(mbedtls_dhm_calc_secret(&ctx_cli, sec_cli, sizeof(sec_cli),
                                        &sec_cli_len,
                                        &mbedtls_test_rnd_pseudo_rand,
                                        &rnd_info) == 0);

    TEST_ASSERT(sec_srv_len == sec_cli_len);
    TEST_ASSERT(sec_srv_len != 0);
    TEST_ASSERT(memcmp(sec_srv, sec_cli, sec_srv_len) == 0);

exit:
    mbedtls_dhm_free(&ctx_srv);
    mbedtls_dhm_free(&ctx_cli);
}
/* END_CASE */

/* BEGIN_CASE */
void dhm_make_public(int P_bytes, char *input_G, int result)
{
    /* Call mbedtls_dhm_make_public() on a group whose modulus is built
     * from P_bytes and check that it returns the expected value.
     *
     * P_bytes - controls the magnitude of the generated modulus P.
     * input_G - textual generator, parsed with mbedtls_test_read_mpi().
     * result  - expected return value of mbedtls_dhm_make_public().
     */
    unsigned char output[MBEDTLS_MPI_MAX_SIZE];
    mbedtls_dhm_context ctx;
    mbedtls_mpi P;
    mbedtls_mpi G;

    mbedtls_dhm_init(&ctx);
    mbedtls_mpi_init(&P);
    mbedtls_mpi_init(&G);

    /* Build P: start from 1, shift it left by (8 * P_bytes - 1) bits, then
     * set bit 0 so that the value is odd. */
    TEST_ASSERT(mbedtls_mpi_lset(&P, 1) == 0);
    TEST_ASSERT(mbedtls_mpi_shift_l(&P, (P_bytes * 8) - 1) == 0);
    TEST_ASSERT(mbedtls_mpi_set_bit(&P, 0, 1) == 0);

    TEST_ASSERT(mbedtls_test_read_mpi(&G, input_G) == 0);

    /* Install the group. The call under test is then given the full output
     * buffer and no RNG state (NULL). */
    TEST_ASSERT(mbedtls_dhm_set_group(&ctx, &P, &G) == 0);
    TEST_ASSERT(mbedtls_dhm_make_public(&ctx, (int) mbedtls_mpi_size(&P),
                                        output, sizeof(output),
                                        &mbedtls_test_rnd_pseudo_rand,
                                        NULL) == result);

exit:
    mbedtls_dhm_free(&ctx);
    mbedtls_mpi_free(&G);
    mbedtls_mpi_free(&P);
}
/* END_CASE */

/* BEGIN_CASE depends_on:MBEDTLS_FS_IO */
void dhm_file(char *filename, char *p, char *g, int len)
{
    /* Parse a DHM parameter file and compare what ends up in the context
     * with reference values.
     *
     * filename - path handed to mbedtls_dhm_parse_dhmfile().
     * p, g     - textual reference values for P and G, parsed with
     *            mbedtls_test_read_mpi().
     * len      - expected result of mbedtls_dhm_get_len().
     */
    mbedtls_mpi P;
    mbedtls_mpi G;
    mbedtls_dhm_context ctx;

    mbedtls_mpi_init(&P);
    mbedtls_mpi_init(&G);
    mbedtls_dhm_init(&ctx);

    /* Reference values. */
    TEST_ASSERT(mbedtls_test_read_mpi(&P, p) == 0);
    TEST_ASSERT(mbedtls_test_read_mpi(&G, g) == 0);

    /* Load the file into the context. */
    TEST_ASSERT(mbedtls_dhm_parse_dhmfile(&ctx, filename) == 0);

    /* Length and bit length of the context must agree with the reference
     * P, and both P and G must match through the getter. */
    TEST_EQUAL(mbedtls_dhm_get_len(&ctx), (size_t) len);
    TEST_EQUAL(mbedtls_dhm_get_bitlen(&ctx), mbedtls_mpi_bitlen(&P));
    TEST_ASSERT(check_get_value(&ctx, MBEDTLS_DHM_PARAM_P, &P));
    TEST_ASSERT(check_get_value(&ctx, MBEDTLS_DHM_PARAM_G, &G));

exit:
    mbedtls_dhm_free(&ctx);
    mbedtls_mpi_free(&G);
    mbedtls_mpi_free(&P);
}
/* END_CASE */

/* BEGIN_CASE depends_on:MBEDTLS_SELF_TEST */
void dhm_selftest()
{
    /* Run the library's built-in DHM self-test and require it to return 0.
     * NOTE(review): the argument 1 is presumably the verbose flag -
     * confirm against the mbedtls_dhm_self_test() documentation. */
    TEST_ASSERT(mbedtls_dhm_self_test(1) == 0);
}
/* END_CASE */