diff --git a/library/ssl_misc.h b/library/ssl_misc.h index e76086a12c..33a6533750 100644 --- a/library/ssl_misc.h +++ b/library/ssl_misc.h @@ -2498,6 +2498,13 @@ MBEDTLS_CHECK_RETURN_CRITICAL int mbedtls_ssl_tls13_write_binders_of_pre_shared_key_ext( mbedtls_ssl_context *ssl, unsigned char *buf, unsigned char *end ); + +/** + * \brief Remove psk from handshake context + * + * \param[in] ssl SSL context + */ +void mbedtls_ssl_remove_psk( mbedtls_ssl_context *ssl ); #endif /* MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED */ #endif /* ssl_misc.h */ diff --git a/library/ssl_tls.c b/library/ssl_tls.c index fe38a0939a..1cda3a7444 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -1662,7 +1662,7 @@ int mbedtls_ssl_conf_psk( mbedtls_ssl_config *conf, return( ret ); } -static void ssl_remove_psk( mbedtls_ssl_context *ssl ) +void mbedtls_ssl_remove_psk( mbedtls_ssl_context *ssl ) { #if defined(MBEDTLS_USE_PSA_CRYPTO) if( ! mbedtls_svc_key_id_is_null( ssl->handshake->psk_opaque ) ) @@ -1682,6 +1682,7 @@ static void ssl_remove_psk( mbedtls_ssl_context *ssl ) mbedtls_platform_zeroize( ssl->handshake->psk, ssl->handshake->psk_len ); mbedtls_free( ssl->handshake->psk ); + ssl->handshake->psk = NULL; ssl->handshake->psk_len = 0; } #endif /* MBEDTLS_USE_PSA_CRYPTO */ @@ -1703,7 +1704,7 @@ int mbedtls_ssl_set_hs_psk( mbedtls_ssl_context *ssl, if( psk_len > MBEDTLS_PSK_MAX_LEN ) return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); - ssl_remove_psk( ssl ); + mbedtls_ssl_remove_psk( ssl ); #if defined(MBEDTLS_USE_PSA_CRYPTO) #if defined(MBEDTLS_SSL_PROTO_TLS1_2) @@ -1780,7 +1781,7 @@ int mbedtls_ssl_set_hs_psk_opaque( mbedtls_ssl_context *ssl, ( ssl->handshake == NULL ) ) return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); - ssl_remove_psk( ssl ); + mbedtls_ssl_remove_psk( ssl ); ssl->handshake->psk_opaque = psk; return( 0 ); } @@ -3522,25 +3523,7 @@ void mbedtls_ssl_handshake_free( mbedtls_ssl_context *ssl ) #endif #if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED) -#if defined(MBEDTLS_USE_PSA_CRYPTO) - if( ! mbedtls_svc_key_id_is_null( ssl->handshake->psk_opaque ) ) - { - /* The maintenance of the external PSK key slot is the - * user's responsibility. */ - if( ssl->handshake->psk_opaque_is_internal ) - { - psa_destroy_key( ssl->handshake->psk_opaque ); - ssl->handshake->psk_opaque_is_internal = 0; - } - ssl->handshake->psk_opaque = MBEDTLS_SVC_KEY_ID_INIT; - } -#else - if( handshake->psk != NULL ) - { - mbedtls_platform_zeroize( handshake->psk, handshake->psk_len ); - mbedtls_free( handshake->psk ); - } -#endif /* MBEDTLS_USE_PSA_CRYPTO */ + mbedtls_ssl_remove_psk( ssl ); #endif #if defined(MBEDTLS_X509_CRT_PARSE_C) && \ diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c index 15656fe7f8..d527959542 100644 --- a/library/ssl_tls13_server.c +++ b/library/ssl_tls13_server.c @@ -761,6 +761,9 @@ static int ssl_tls13_determine_key_exchange_mode( mbedtls_ssl_context *ssl ) else if( ssl_tls13_check_ephemeral_key_exchange( ssl ) ) { +#if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED) + mbedtls_ssl_remove_psk( ssl ); +#endif ssl->handshake->key_exchange_mode = MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_EPHEMERAL; MBEDTLS_SSL_DEBUG_MSG( 2, ( "key exchange mode: ephemeral" ) );