diff --git a/library/ssl_misc.h b/library/ssl_misc.h
index e76086a12c..33a6533750 100644
--- a/library/ssl_misc.h
+++ b/library/ssl_misc.h
@@ -2498,6 +2498,13 @@ MBEDTLS_CHECK_RETURN_CRITICAL
 int mbedtls_ssl_tls13_write_binders_of_pre_shared_key_ext(
     mbedtls_ssl_context *ssl,
     unsigned char *buf, unsigned char *end );
+
+/**
+ * \brief Remove psk from handshake context
+ *
+ * \param[in]   ssl     SSL context
+ */
+void mbedtls_ssl_remove_psk( mbedtls_ssl_context *ssl );
 #endif /* MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED */
 
 #endif /* ssl_misc.h */
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index fe38a0939a..1cda3a7444 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -1662,7 +1662,7 @@ int mbedtls_ssl_conf_psk( mbedtls_ssl_config *conf,
     return( ret );
 }
 
-static void ssl_remove_psk( mbedtls_ssl_context *ssl )
+void mbedtls_ssl_remove_psk( mbedtls_ssl_context *ssl )
 {
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
     if( ! mbedtls_svc_key_id_is_null( ssl->handshake->psk_opaque ) )
@@ -1682,6 +1682,7 @@ static void ssl_remove_psk( mbedtls_ssl_context *ssl )
         mbedtls_platform_zeroize( ssl->handshake->psk,
                                   ssl->handshake->psk_len );
         mbedtls_free( ssl->handshake->psk );
+        ssl->handshake->psk = NULL;
         ssl->handshake->psk_len = 0;
     }
 #endif /* MBEDTLS_USE_PSA_CRYPTO */
@@ -1703,7 +1704,7 @@ int mbedtls_ssl_set_hs_psk( mbedtls_ssl_context *ssl,
     if( psk_len > MBEDTLS_PSK_MAX_LEN )
         return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
 
-    ssl_remove_psk( ssl );
+    mbedtls_ssl_remove_psk( ssl );
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
 #if defined(MBEDTLS_SSL_PROTO_TLS1_2)
@@ -1780,7 +1781,7 @@ int mbedtls_ssl_set_hs_psk_opaque( mbedtls_ssl_context *ssl,
         ( ssl->handshake == NULL ) )
         return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
 
-    ssl_remove_psk( ssl );
+    mbedtls_ssl_remove_psk( ssl );
     ssl->handshake->psk_opaque = psk;
     return( 0 );
 }
@@ -3522,25 +3523,7 @@ void mbedtls_ssl_handshake_free( mbedtls_ssl_context *ssl )
 #endif
 
 #if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
-    if( ! mbedtls_svc_key_id_is_null( ssl->handshake->psk_opaque ) )
-    {
-        /* The maintenance of the external PSK key slot is the
-         * user's responsibility. */
-        if( ssl->handshake->psk_opaque_is_internal )
-        {
-            psa_destroy_key( ssl->handshake->psk_opaque );
-            ssl->handshake->psk_opaque_is_internal = 0;
-        }
-        ssl->handshake->psk_opaque = MBEDTLS_SVC_KEY_ID_INIT;
-    }
-#else
-    if( handshake->psk != NULL )
-    {
-        mbedtls_platform_zeroize( handshake->psk, handshake->psk_len );
-        mbedtls_free( handshake->psk );
-    }
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
+    mbedtls_ssl_remove_psk( ssl );
 #endif
 
 #if defined(MBEDTLS_X509_CRT_PARSE_C) && \
diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c
index 15656fe7f8..d527959542 100644
--- a/library/ssl_tls13_server.c
+++ b/library/ssl_tls13_server.c
@@ -761,6 +761,9 @@ static int ssl_tls13_determine_key_exchange_mode( mbedtls_ssl_context *ssl )
     else
     if( ssl_tls13_check_ephemeral_key_exchange( ssl ) )
     {
+#if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
+        mbedtls_ssl_remove_psk( ssl );
+#endif
         ssl->handshake->key_exchange_mode =
             MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_EPHEMERAL;
         MBEDTLS_SSL_DEBUG_MSG( 2, ( "key exchange mode: ephemeral" ) );