Add RNG parameter to check_pair functions

- mbedtls_ecp_check_pub_priv() because it calls ecp_mul()
- mbedtls_pk_check_pair() because it calls the former

Signed-off-by: Manuel Pégourié-Gonnard <manuel.pegourie-gonnard@arm.com>
This commit is contained in:
Manuel Pégourié-Gonnard 2021-06-15 11:29:26 +02:00
parent f8c24bf507
commit 39be1410fd
6 changed files with 41 additions and 13 deletions

View File

@ -603,6 +603,8 @@ int mbedtls_pk_encrypt( mbedtls_pk_context *ctx,
* *
* \param pub Context holding a public key. * \param pub Context holding a public key.
* \param prv Context holding a private (and public) key. * \param prv Context holding a private (and public) key.
* \param f_rng RNG function, must not be \c NULL.
* \param p_rng RNG parameter
* *
* \return \c 0 on success (keys were checked and match each other). * \return \c 0 on success (keys were checked and match each other).
* \return #MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE if the keys could not * \return #MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE if the keys could not
@ -610,7 +612,10 @@ int mbedtls_pk_encrypt( mbedtls_pk_context *ctx,
* \return #MBEDTLS_ERR_PK_BAD_INPUT_DATA if a context is invalid. * \return #MBEDTLS_ERR_PK_BAD_INPUT_DATA if a context is invalid.
* \return Another non-zero value if the keys do not match. * \return Another non-zero value if the keys do not match.
*/ */
int mbedtls_pk_check_pair( const mbedtls_pk_context *pub, const mbedtls_pk_context *prv ); int mbedtls_pk_check_pair( const mbedtls_pk_context *pub,
const mbedtls_pk_context *prv,
int (*f_rng)(void *, unsigned char *, size_t),
void *p_rng );
/** /**
* \brief Export debug information * \brief Export debug information

View File

@ -500,7 +500,10 @@ int mbedtls_pk_encrypt( mbedtls_pk_context *ctx,
/* /*
* Check public-private key pair * Check public-private key pair
*/ */
int mbedtls_pk_check_pair( const mbedtls_pk_context *pub, const mbedtls_pk_context *prv ) int mbedtls_pk_check_pair( const mbedtls_pk_context *pub,
const mbedtls_pk_context *prv,
int (*f_rng)(void *, unsigned char *, size_t),
void *p_rng )
{ {
PK_VALIDATE_RET( pub != NULL ); PK_VALIDATE_RET( pub != NULL );
PK_VALIDATE_RET( prv != NULL ); PK_VALIDATE_RET( prv != NULL );
@ -511,6 +514,9 @@ int mbedtls_pk_check_pair( const mbedtls_pk_context *pub, const mbedtls_pk_conte
return( MBEDTLS_ERR_PK_BAD_INPUT_DATA ); return( MBEDTLS_ERR_PK_BAD_INPUT_DATA );
} }
if( f_rng == NULL )
return( MBEDTLS_ERR_PK_BAD_INPUT_DATA );
if( prv->pk_info->check_pair_func == NULL ) if( prv->pk_info->check_pair_func == NULL )
return( MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE ); return( MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE );
@ -525,7 +531,7 @@ int mbedtls_pk_check_pair( const mbedtls_pk_context *pub, const mbedtls_pk_conte
return( MBEDTLS_ERR_PK_TYPE_MISMATCH ); return( MBEDTLS_ERR_PK_TYPE_MISMATCH );
} }
return( prv->pk_info->check_pair_func( pub->pk_ctx, prv->pk_ctx ) ); return( prv->pk_info->check_pair_func( pub->pk_ctx, prv->pk_ctx, f_rng, p_rng ) );
} }
/* /*

View File

@ -154,8 +154,12 @@ static int rsa_encrypt_wrap( void *ctx,
ilen, input, output ) ); ilen, input, output ) );
} }
static int rsa_check_pair_wrap( const void *pub, const void *prv ) static int rsa_check_pair_wrap( const void *pub, const void *prv,
int (*f_rng)(void *, unsigned char *, size_t),
void *p_rng )
{ {
(void) f_rng;
(void) p_rng;
return( mbedtls_rsa_check_pub_priv( (const mbedtls_rsa_context *) pub, return( mbedtls_rsa_check_pub_priv( (const mbedtls_rsa_context *) pub,
(const mbedtls_rsa_context *) prv ) ); (const mbedtls_rsa_context *) prv ) );
} }
@ -388,10 +392,13 @@ cleanup:
#endif /* MBEDTLS_ECP_RESTARTABLE */ #endif /* MBEDTLS_ECP_RESTARTABLE */
#endif /* MBEDTLS_ECDSA_C */ #endif /* MBEDTLS_ECDSA_C */
static int eckey_check_pair( const void *pub, const void *prv ) static int eckey_check_pair( const void *pub, const void *prv,
int (*f_rng)(void *, unsigned char *, size_t),
void *p_rng )
{ {
return( mbedtls_ecp_check_pub_priv( (const mbedtls_ecp_keypair *) pub, return( mbedtls_ecp_check_pub_priv( (const mbedtls_ecp_keypair *) pub,
(const mbedtls_ecp_keypair *) prv ) ); (const mbedtls_ecp_keypair *) prv,
f_rng, p_rng ) );
} }
static void *eckey_alloc_wrap( void ) static void *eckey_alloc_wrap( void )
@ -799,7 +806,9 @@ static int rsa_alt_decrypt_wrap( void *ctx,
} }
#if defined(MBEDTLS_RSA_C) #if defined(MBEDTLS_RSA_C)
static int rsa_alt_check_pair( const void *pub, const void *prv ) static int rsa_alt_check_pair( const void *pub, const void *prv,
int (*f_rng)(void *, unsigned char *, size_t),
void *p_rng )
{ {
unsigned char sig[MBEDTLS_MPI_MAX_SIZE]; unsigned char sig[MBEDTLS_MPI_MAX_SIZE];
unsigned char hash[32]; unsigned char hash[32];
@ -813,7 +822,7 @@ static int rsa_alt_check_pair( const void *pub, const void *prv )
if( ( ret = rsa_alt_sign_wrap( (void *) prv, MBEDTLS_MD_NONE, if( ( ret = rsa_alt_sign_wrap( (void *) prv, MBEDTLS_MD_NONE,
hash, sizeof( hash ), hash, sizeof( hash ),
sig, &sig_len, NULL, NULL ) ) != 0 ) sig, &sig_len, f_rng, p_rng ) ) != 0 )
{ {
return( ret ); return( ret );
} }

View File

@ -85,7 +85,9 @@ struct mbedtls_pk_info_t
void *p_rng ); void *p_rng );
/** Check public-private key pair */ /** Check public-private key pair */
int (*check_pair_func)( const void *pub, const void *prv ); int (*check_pair_func)( const void *pub, const void *prv,
int (*f_rng)(void *, unsigned char *, size_t),
void *p_rng );
/** Allocate a new context */ /** Allocate a new context */
void * (*ctx_alloc_func)( void ); void * (*ctx_alloc_func)( void );

View File

@ -606,7 +606,8 @@ int main( int argc, char *argv[] )
// //
if( strlen( opt.issuer_crt ) ) if( strlen( opt.issuer_crt ) )
{ {
if( mbedtls_pk_check_pair( &issuer_crt.MBEDTLS_PRIVATE(pk), issuer_key ) != 0 ) if( mbedtls_pk_check_pair( &issuer_crt.MBEDTLS_PRIVATE(pk), issuer_key,
mbedtls_ctr_drbg_random, &ctr_drbg ) != 0 )
{ {
mbedtls_printf( " failed\n ! issuer_key does not match " mbedtls_printf( " failed\n ! issuer_key does not match "
"issuer certificate\n\n" ); "issuer certificate\n\n" );

View File

@ -177,7 +177,8 @@ void pk_psa_utils( )
/* unsupported functions: check_pair, debug */ /* unsupported functions: check_pair, debug */
TEST_ASSERT( mbedtls_pk_setup( &pk2, TEST_ASSERT( mbedtls_pk_setup( &pk2,
mbedtls_pk_info_from_type( MBEDTLS_PK_ECKEY ) ) == 0 ); mbedtls_pk_info_from_type( MBEDTLS_PK_ECKEY ) ) == 0 );
TEST_ASSERT( mbedtls_pk_check_pair( &pk, &pk2 ) TEST_ASSERT( mbedtls_pk_check_pair( &pk, &pk2,
mbedtls_test_rnd_std_rand, NULL )
== MBEDTLS_ERR_PK_TYPE_MISMATCH ); == MBEDTLS_ERR_PK_TYPE_MISMATCH );
TEST_ASSERT( mbedtls_pk_debug( &pk, &dbg ) TEST_ASSERT( mbedtls_pk_debug( &pk, &dbg )
== MBEDTLS_ERR_PK_TYPE_MISMATCH ); == MBEDTLS_ERR_PK_TYPE_MISMATCH );
@ -350,7 +351,9 @@ void mbedtls_pk_check_pair( char * pub_file, char * prv_file, int ret )
TEST_ASSERT( mbedtls_pk_parse_public_keyfile( &pub, pub_file ) == 0 ); TEST_ASSERT( mbedtls_pk_parse_public_keyfile( &pub, pub_file ) == 0 );
TEST_ASSERT( mbedtls_pk_parse_keyfile( &prv, prv_file, NULL ) == 0 ); TEST_ASSERT( mbedtls_pk_parse_keyfile( &prv, prv_file, NULL ) == 0 );
TEST_ASSERT( mbedtls_pk_check_pair( &pub, &prv ) == ret ); TEST_ASSERT( mbedtls_pk_check_pair( &pub, &prv,
mbedtls_test_rnd_std_rand, NULL )
== ret );
#if defined(MBEDTLS_RSA_C) && defined(MBEDTLS_PK_RSA_ALT_SUPPORT) #if defined(MBEDTLS_RSA_C) && defined(MBEDTLS_PK_RSA_ALT_SUPPORT)
if( mbedtls_pk_get_type( &prv ) == MBEDTLS_PK_RSA ) if( mbedtls_pk_get_type( &prv ) == MBEDTLS_PK_RSA )
@ -358,7 +361,9 @@ void mbedtls_pk_check_pair( char * pub_file, char * prv_file, int ret )
TEST_ASSERT( mbedtls_pk_setup_rsa_alt( &alt, mbedtls_pk_rsa( prv ), TEST_ASSERT( mbedtls_pk_setup_rsa_alt( &alt, mbedtls_pk_rsa( prv ),
mbedtls_rsa_decrypt_func, mbedtls_rsa_sign_func, mbedtls_rsa_decrypt_func, mbedtls_rsa_sign_func,
mbedtls_rsa_key_len_func ) == 0 ); mbedtls_rsa_key_len_func ) == 0 );
TEST_ASSERT( mbedtls_pk_check_pair( &pub, &alt ) == ret ); TEST_ASSERT( mbedtls_pk_check_pair( &pub, &alt,
mbedtls_test_rnd_std_rand, NULL )
== ret );
} }
#endif #endif