From 259c2135457e77bdcccb407fe4b4c1b6269efb05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20P=C3=A9gouri=C3=A9-Gonnard?= Date: Fri, 15 Jul 2022 12:09:08 +0200 Subject: [PATCH] Tune API of internal function mgf_mask in RSA MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is a first step towards making a version of this function that uses PSA when MD is not available. Signed-off-by: Manuel Pégourié-Gonnard --- library/rsa.c | 38 +++++++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/library/rsa.c b/library/rsa.c index 17a7d9e7c8..74390af144 100644 --- a/library/rsa.c +++ b/library/rsa.c @@ -1095,11 +1095,13 @@ cleanup: * \param dlen length of destination buffer * \param src source of the mask generation * \param slen length of the source buffer - * \param md_ctx message digest context to use + * \param md_alg message digest to use */ static int mgf_mask( unsigned char *dst, size_t dlen, unsigned char *src, - size_t slen, mbedtls_md_context_t *md_ctx ) + size_t slen, mbedtls_md_type_t md_alg ) { + const mbedtls_md_info_t *md_info; + mbedtls_md_context_t md_ctx; unsigned char mask[MBEDTLS_MD_MAX_SIZE]; unsigned char counter[4]; unsigned char *p; @@ -1107,10 +1109,19 @@ static int mgf_mask( unsigned char *dst, size_t dlen, unsigned char *src, size_t i, use_len; int ret = 0; + mbedtls_md_init( &md_ctx ); memset( mask, 0, MBEDTLS_MD_MAX_SIZE ); memset( counter, 0, 4 ); - hlen = mbedtls_md_get_size( md_ctx->md_info ); + md_info = mbedtls_md_info_from_type( md_alg ); + if( md_info == NULL ) + return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA ); + + mbedtls_md_init( &md_ctx ); + if( ( ret = mbedtls_md_setup( &md_ctx, md_info, 0 ) ) != 0 ) + goto exit; + + hlen = mbedtls_md_get_size( md_info ); /* Generate and apply dbMask */ p = dst; @@ -1121,13 +1132,13 @@ static int mgf_mask( unsigned char *dst, size_t dlen, unsigned char *src, if( dlen < hlen ) use_len = dlen; - if( ( ret = mbedtls_md_starts( md_ctx ) ) != 0 ) + if( ( ret = mbedtls_md_starts( &md_ctx ) ) != 0 ) goto exit; - if( ( ret = mbedtls_md_update( md_ctx, src, slen ) ) != 0 ) + if( ( ret = mbedtls_md_update( &md_ctx, src, slen ) ) != 0 ) goto exit; - if( ( ret = mbedtls_md_update( md_ctx, counter, 4 ) ) != 0 ) + if( ( ret = mbedtls_md_update( &md_ctx, counter, 4 ) ) != 0 ) goto exit; - if( ( ret = mbedtls_md_finish( md_ctx, mask ) ) != 0 ) + if( ( ret = mbedtls_md_finish( &md_ctx, mask ) ) != 0 ) goto exit; for( i = 0; i < use_len; ++i ) @@ -1139,6 +1150,7 @@ static int mgf_mask( unsigned char *dst, size_t dlen, unsigned char *src, } exit: + mbedtls_md_free( &md_ctx ); mbedtls_platform_zeroize( mask, sizeof( mask ) ); return( ret ); @@ -1208,12 +1220,12 @@ int mbedtls_rsa_rsaes_oaep_encrypt( mbedtls_rsa_context *ctx, /* maskedDB: Apply dbMask to DB */ if( ( ret = mgf_mask( output + hlen + 1, olen - hlen - 1, output + 1, hlen, - &md_ctx ) ) != 0 ) + ctx->hash_id ) ) != 0 ) goto exit; /* maskedSeed: Apply seedMask to seed */ if( ( ret = mgf_mask( output + 1, hlen, output + hlen + 1, olen - hlen - 1, - &md_ctx ) ) != 0 ) + ctx->hash_id ) ) != 0 ) goto exit; exit: @@ -1384,10 +1396,10 @@ int mbedtls_rsa_rsaes_oaep_decrypt( mbedtls_rsa_context *ctx, /* seed: Apply seedMask to maskedSeed */ if( ( ret = mgf_mask( buf + 1, hlen, buf + hlen + 1, ilen - hlen - 1, - &md_ctx ) ) != 0 || + ctx->hash_id ) ) != 0 || /* DB: Apply dbMask to maskedDB */ ( ret = mgf_mask( buf + hlen + 1, ilen - hlen - 1, buf + 1, hlen, - &md_ctx ) ) != 0 ) + ctx->hash_id ) ) != 0 ) { mbedtls_md_free( &md_ctx ); goto cleanup; @@ -1648,7 +1660,7 @@ static int rsa_rsassa_pss_sign( mbedtls_rsa_context *ctx, /* maskedDB: Apply dbMask to DB */ if( ( ret = mgf_mask( sig + offset, olen - hlen - 1 - offset, p, hlen, - &md_ctx ) ) != 0 ) + ctx->hash_id ) ) != 0 ) goto exit; msb = mbedtls_mpi_bitlen( &ctx->N ) - 1; @@ -2029,7 +2041,7 @@ int mbedtls_rsa_rsassa_pss_verify_ext( mbedtls_rsa_context *ctx, if( ( ret = mbedtls_md_setup( &md_ctx, md_info, 0 ) ) != 0 ) goto exit; - ret = mgf_mask( p, siglen - hlen - 1, hash_start, hlen, &md_ctx ); + ret = mgf_mask( p, siglen - hlen - 1, hash_start, hlen, mgf1_hash_id ); if( ret != 0 ) goto exit;