diff --git a/library/bignum.c b/library/bignum.c index b11239e274..5cd1c3e842 100644 --- a/library/bignum.c +++ b/library/bignum.c @@ -1339,29 +1339,32 @@ cleanup: /** * Helper for mbedtls_mpi subtraction. * - * Calculate d - s where d and s have the same size. + * Calculate l - r where l and r have the same size. * This function operates modulo (2^ciL)^n and returns the carry - * (1 if there was a wraparound, i.e. if `d < s`, and 0 otherwise). + * (1 if there was a wraparound, i.e. if `l < r`, and 0 otherwise). * - * \param n Number of limbs of \p d and \p s. - * \param[in,out] d On input, the left operand. - * On output, the result of the subtraction: - * \param[in] s The right operand. + * d may be aliased to l or r. * - * \return 1 if `d < s`. - * 0 if `d >= s`. + * \param n Number of limbs of \p d, \p l and \p r. + * \param[out] d The result of the subtraction. + * \param[in] l The left operand. + * \param[in] r The right operand. + * + * \return 1 if `l < r`. + * 0 if `l >= r`. */ static mbedtls_mpi_uint mpi_sub_hlp( size_t n, mbedtls_mpi_uint *d, - const mbedtls_mpi_uint *s ) + const mbedtls_mpi_uint *l, + const mbedtls_mpi_uint *r ) { size_t i; - mbedtls_mpi_uint c, z; + mbedtls_mpi_uint c = 0, t, z; - for( i = c = 0; i < n; i++, s++, d++ ) + for( i = 0; i < n; i++ ) { - z = ( *d < c ); *d -= c; - c = ( *d < *s ) + z; *d -= *s; + z = ( l[i] < c ); t = l[i] - c; + c = ( t < r[i] ) + z; d[i] = t - r[i]; } return( c ); @@ -1372,7 +1375,6 @@ static mbedtls_mpi_uint mpi_sub_hlp( size_t n, */ int mbedtls_mpi_sub_abs( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *B ) { - mbedtls_mpi TB; int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; size_t n; mbedtls_mpi_uint carry; @@ -1380,29 +1382,21 @@ int mbedtls_mpi_sub_abs( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi MPI_VALIDATE_RET( A != NULL ); MPI_VALIDATE_RET( B != NULL ); - mbedtls_mpi_init( &TB ); - - if( X == B ) - { - MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &TB, B ) ); - B = &TB; - } - - if( X != A ) - MBEDTLS_MPI_CHK( mbedtls_mpi_copy( X, A ) ); - - /* - * X should always be positive as a result of unsigned subtractions. - */ - X->s = 1; - - ret = 0; - for( n = B->n; n > 0; n-- ) if( B->p[n - 1] != 0 ) break; - carry = mpi_sub_hlp( n, X->p, B->p ); + MBEDTLS_MPI_CHK( mbedtls_mpi_grow( X, A->n ) ); + + /* Set the high limbs of X to match A. Don't touch the lower limbs + * because X might be aliased to B, and we must not overwrite the + * significant digits of B. */ + if( A->n > n ) + memcpy( X->p + n, A->p + n, ( A->n - n ) * ciL ); + if( X->n > A->n ) + memset( X->p + A->n, 0, ( X->n - A->n ) * ciL ); + + carry = mpi_sub_hlp( n, X->p, A->p, B->p ); if( carry != 0 ) { /* Propagate the carry to the first nonzero limb of X. */ @@ -1418,10 +1412,10 @@ int mbedtls_mpi_sub_abs( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi --X->p[n]; } + /* X should always be positive as a result of unsigned subtractions. */ + X->s = 1; + cleanup: - - mbedtls_mpi_free( &TB ); - return( ret ); } @@ -2065,7 +2059,7 @@ static void mpi_montmul( mbedtls_mpi *A, const mbedtls_mpi *B, const mbedtls_mpi * do the calculation without using conditional tests. */ /* Set d to d0 + (2^biL)^n - N where d0 is the current value of d. */ d[n] += 1; - d[n] -= mpi_sub_hlp( n, d, N->p ); + d[n] -= mpi_sub_hlp( n, d, d, N->p ); /* If d0 < N then d < (2^biL)^n * so d[n] == 0 and we want to keep A as it is. * If d0 >= N then d >= (2^biL)^n, and d <= (2^biL)^n + N < 2 * (2^biL)^n