diff --git a/library/ecp_curves.c b/library/ecp_curves.c index 7d029de1fc..186dabef20 100644 --- a/library/ecp_curves.c +++ b/library/ecp_curves.c @@ -5211,36 +5211,39 @@ cleanup: } MBEDTLS_STATIC_TESTABLE -int mbedtls_ecp_mod_p521_raw(mbedtls_mpi_uint *N_p, size_t N_n) +int mbedtls_ecp_mod_p521_raw(mbedtls_mpi_uint *X, size_t X_limbs) { mbedtls_mpi_uint carry = 0; - if (N_n > 2 * P521_WIDTH - 1) { - N_n = 2 * P521_WIDTH - 1; + if (X_limbs > 2 * P521_WIDTH - 1) { + X_limbs = 2 * P521_WIDTH - 1; } - if (N_n < P521_WIDTH) { + if (X_limbs < P521_WIDTH) { return 0; } - if (N_n > P521_WIDTH) { - /* Helper references for top part of N */ - mbedtls_mpi_uint *NT_p = N_p + P521_WIDTH; - size_t NT_n = N_n - P521_WIDTH; + if (X_limbs > P521_WIDTH) { + /* Helper references for bottom part of X */ + mbedtls_mpi_uint *X0 = X; + size_t X0_limbs = P521_WIDTH; + /* Helper references for top part of X */ + mbedtls_mpi_uint *X1 = X + X0_limbs; + size_t X1_limbs = X_limbs - X0_limbs; - /* Split N as A0 + 2^(512 + biL) A1 and compute A0 + 2^(biL - 9) * A1. + /* Split X as X0 + 2^(512 + biL) X1 and compute X0 + 2^(biL - 9) * X1. * This can be done in place. */ mbedtls_mpi_uint shift = ((mbedtls_mpi_uint) 1u) << (biL - 9); - carry = mbedtls_mpi_core_mla(N_p, P521_WIDTH, NT_p, NT_n, shift); + carry = mbedtls_mpi_core_mla(X0, X0_limbs, X1, X1_limbs, shift); /* Clear top part */ - memset(NT_p, 0, sizeof(mbedtls_mpi_uint) * NT_n); + memset(X1, 0, X1_limbs * sizeof(mbedtls_mpi_uint)); } - mbedtls_mpi_uint remainder[P521_WIDTH] = { 0 }; - remainder[0] = carry << (biL - 9); - remainder[0] += (N_p[P521_WIDTH - 1] >> 9); - N_p[P521_WIDTH - 1] &= P521_MASK; - (void) mbedtls_mpi_core_add(N_p, N_p, remainder, P521_WIDTH); + mbedtls_mpi_uint addend[P521_WIDTH] = { 0 }; + addend[0] = carry << (biL - 9); + addend[0] += (X[P521_WIDTH - 1] >> 9); + X[P521_WIDTH - 1] &= P521_MASK; + (void) mbedtls_mpi_core_add(X, X, addend, P521_WIDTH); return 0; }