diff --git a/include/mbedtls/pk.h b/include/mbedtls/pk.h index 36af5d98d0..5c7b2f646b 100644 --- a/include/mbedtls/pk.h +++ b/include/mbedtls/pk.h @@ -603,6 +603,8 @@ int mbedtls_pk_encrypt( mbedtls_pk_context *ctx, * * \param pub Context holding a public key. * \param prv Context holding a private (and public) key. + * \param f_rng RNG function, must not be \c NULL. + * \param p_rng RNG parameter * * \return \c 0 on success (keys were checked and match each other). * \return #MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE if the keys could not @@ -610,7 +612,10 @@ int mbedtls_pk_encrypt( mbedtls_pk_context *ctx, * \return #MBEDTLS_ERR_PK_BAD_INPUT_DATA if a context is invalid. * \return Another non-zero value if the keys do not match. */ -int mbedtls_pk_check_pair( const mbedtls_pk_context *pub, const mbedtls_pk_context *prv ); +int mbedtls_pk_check_pair( const mbedtls_pk_context *pub, + const mbedtls_pk_context *prv, + int (*f_rng)(void *, unsigned char *, size_t), + void *p_rng ); /** * \brief Export debug information diff --git a/library/pk.c b/library/pk.c index 06021e26c0..275d34bb1b 100644 --- a/library/pk.c +++ b/library/pk.c @@ -500,7 +500,10 @@ int mbedtls_pk_encrypt( mbedtls_pk_context *ctx, /* * Check public-private key pair */ -int mbedtls_pk_check_pair( const mbedtls_pk_context *pub, const mbedtls_pk_context *prv ) +int mbedtls_pk_check_pair( const mbedtls_pk_context *pub, + const mbedtls_pk_context *prv, + int (*f_rng)(void *, unsigned char *, size_t), + void *p_rng ) { PK_VALIDATE_RET( pub != NULL ); PK_VALIDATE_RET( prv != NULL ); @@ -511,6 +514,9 @@ int mbedtls_pk_check_pair( const mbedtls_pk_context *pub, const mbedtls_pk_conte return( MBEDTLS_ERR_PK_BAD_INPUT_DATA ); } + if( f_rng == NULL ) + return( MBEDTLS_ERR_PK_BAD_INPUT_DATA ); + if( prv->pk_info->check_pair_func == NULL ) return( MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE ); @@ -525,7 +531,7 @@ int mbedtls_pk_check_pair( const mbedtls_pk_context *pub, const mbedtls_pk_conte return( MBEDTLS_ERR_PK_TYPE_MISMATCH ); } - return( prv->pk_info->check_pair_func( pub->pk_ctx, prv->pk_ctx ) ); + return( prv->pk_info->check_pair_func( pub->pk_ctx, prv->pk_ctx, f_rng, p_rng ) ); } /* diff --git a/library/pk_wrap.c b/library/pk_wrap.c index 7c317c52d3..864e495b3c 100644 --- a/library/pk_wrap.c +++ b/library/pk_wrap.c @@ -154,8 +154,12 @@ static int rsa_encrypt_wrap( void *ctx, ilen, input, output ) ); } -static int rsa_check_pair_wrap( const void *pub, const void *prv ) +static int rsa_check_pair_wrap( const void *pub, const void *prv, + int (*f_rng)(void *, unsigned char *, size_t), + void *p_rng ) { + (void) f_rng; + (void) p_rng; return( mbedtls_rsa_check_pub_priv( (const mbedtls_rsa_context *) pub, (const mbedtls_rsa_context *) prv ) ); } @@ -388,10 +392,13 @@ cleanup: #endif /* MBEDTLS_ECP_RESTARTABLE */ #endif /* MBEDTLS_ECDSA_C */ -static int eckey_check_pair( const void *pub, const void *prv ) +static int eckey_check_pair( const void *pub, const void *prv, + int (*f_rng)(void *, unsigned char *, size_t), + void *p_rng ) { return( mbedtls_ecp_check_pub_priv( (const mbedtls_ecp_keypair *) pub, - (const mbedtls_ecp_keypair *) prv ) ); + (const mbedtls_ecp_keypair *) prv, + f_rng, p_rng ) ); } static void *eckey_alloc_wrap( void ) @@ -799,7 +806,9 @@ static int rsa_alt_decrypt_wrap( void *ctx, } #if defined(MBEDTLS_RSA_C) -static int rsa_alt_check_pair( const void *pub, const void *prv ) +static int rsa_alt_check_pair( const void *pub, const void *prv, + int (*f_rng)(void *, unsigned char *, size_t), + void *p_rng ) { unsigned char sig[MBEDTLS_MPI_MAX_SIZE]; unsigned char hash[32]; @@ -813,7 +822,7 @@ static int rsa_alt_check_pair( const void *pub, const void *prv ) if( ( ret = rsa_alt_sign_wrap( (void *) prv, MBEDTLS_MD_NONE, hash, sizeof( hash ), - sig, &sig_len, NULL, NULL ) ) != 0 ) + sig, &sig_len, f_rng, p_rng ) ) != 0 ) { return( ret ); } diff --git a/library/pk_wrap.h b/library/pk_wrap.h index f7f938a88d..b2db63739f 100644 --- a/library/pk_wrap.h +++ b/library/pk_wrap.h @@ -85,7 +85,9 @@ struct mbedtls_pk_info_t void *p_rng ); /** Check public-private key pair */ - int (*check_pair_func)( const void *pub, const void *prv ); + int (*check_pair_func)( const void *pub, const void *prv, + int (*f_rng)(void *, unsigned char *, size_t), + void *p_rng ); /** Allocate a new context */ void * (*ctx_alloc_func)( void ); diff --git a/programs/x509/cert_write.c b/programs/x509/cert_write.c index ff7cf98074..041f459cfc 100644 --- a/programs/x509/cert_write.c +++ b/programs/x509/cert_write.c @@ -606,7 +606,8 @@ int main( int argc, char *argv[] ) // if( strlen( opt.issuer_crt ) ) { - if( mbedtls_pk_check_pair( &issuer_crt.MBEDTLS_PRIVATE(pk), issuer_key ) != 0 ) + if( mbedtls_pk_check_pair( &issuer_crt.MBEDTLS_PRIVATE(pk), issuer_key, + mbedtls_ctr_drbg_random, &ctr_drbg ) != 0 ) { mbedtls_printf( " failed\n ! issuer_key does not match " "issuer certificate\n\n" ); diff --git a/tests/suites/test_suite_pk.function b/tests/suites/test_suite_pk.function index 573c9d4306..b46cf05cfb 100644 --- a/tests/suites/test_suite_pk.function +++ b/tests/suites/test_suite_pk.function @@ -177,7 +177,8 @@ void pk_psa_utils( ) /* unsupported functions: check_pair, debug */ TEST_ASSERT( mbedtls_pk_setup( &pk2, mbedtls_pk_info_from_type( MBEDTLS_PK_ECKEY ) ) == 0 ); - TEST_ASSERT( mbedtls_pk_check_pair( &pk, &pk2 ) + TEST_ASSERT( mbedtls_pk_check_pair( &pk, &pk2, + mbedtls_test_rnd_std_rand, NULL ) == MBEDTLS_ERR_PK_TYPE_MISMATCH ); TEST_ASSERT( mbedtls_pk_debug( &pk, &dbg ) == MBEDTLS_ERR_PK_TYPE_MISMATCH ); @@ -350,7 +351,9 @@ void mbedtls_pk_check_pair( char * pub_file, char * prv_file, int ret ) TEST_ASSERT( mbedtls_pk_parse_public_keyfile( &pub, pub_file ) == 0 ); TEST_ASSERT( mbedtls_pk_parse_keyfile( &prv, prv_file, NULL ) == 0 ); - TEST_ASSERT( mbedtls_pk_check_pair( &pub, &prv ) == ret ); + TEST_ASSERT( mbedtls_pk_check_pair( &pub, &prv, + mbedtls_test_rnd_std_rand, NULL ) + == ret ); #if defined(MBEDTLS_RSA_C) && defined(MBEDTLS_PK_RSA_ALT_SUPPORT) if( mbedtls_pk_get_type( &prv ) == MBEDTLS_PK_RSA ) @@ -358,7 +361,9 @@ void mbedtls_pk_check_pair( char * pub_file, char * prv_file, int ret ) TEST_ASSERT( mbedtls_pk_setup_rsa_alt( &alt, mbedtls_pk_rsa( prv ), mbedtls_rsa_decrypt_func, mbedtls_rsa_sign_func, mbedtls_rsa_key_len_func ) == 0 ); - TEST_ASSERT( mbedtls_pk_check_pair( &pub, &alt ) == ret ); + TEST_ASSERT( mbedtls_pk_check_pair( &pub, &alt, + mbedtls_test_rnd_std_rand, NULL ) + == ret ); } #endif