diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 19b8a41351..fe38a0939a 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -1693,7 +1693,7 @@ int mbedtls_ssl_set_hs_psk( mbedtls_ssl_context *ssl,
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
     psa_key_attributes_t key_attributes = psa_key_attributes_init();
     psa_status_t status;
-    psa_algorithm_t alg;
+    psa_algorithm_t alg = PSA_ALG_ANY_HASH;
     mbedtls_svc_key_id_t key;
 #endif /* MBEDTLS_USE_PSA_CRYPTO */
 
@@ -1706,17 +1706,26 @@ int mbedtls_ssl_set_hs_psk( mbedtls_ssl_context *ssl,
     ssl_remove_psk( ssl );
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
-    if( ssl->handshake->ciphersuite_info->mac == MBEDTLS_MD_SHA384)
-        alg = PSA_ALG_TLS12_PSK_TO_MS(PSA_ALG_SHA_384);
-    else
-        alg = PSA_ALG_TLS12_PSK_TO_MS(PSA_ALG_SHA_256);
+#if defined(MBEDTLS_SSL_PROTO_TLS1_2)
+    if( ssl->tls_version == MBEDTLS_SSL_VERSION_TLS1_2 )
+    {
+        if( ssl->handshake->ciphersuite_info->mac == MBEDTLS_MD_SHA384)
+            alg = PSA_ALG_TLS12_PSK_TO_MS( PSA_ALG_SHA_384 );
+        else
+            alg = PSA_ALG_TLS12_PSK_TO_MS( PSA_ALG_SHA_256 );
+        psa_set_key_usage_flags( &key_attributes, PSA_KEY_USAGE_DERIVE );
+    }
+#endif /* MBEDTLS_SSL_PROTO_TLS1_2 */
 
 #if defined(MBEDTLS_SSL_PROTO_TLS1_3)
-    psa_set_key_usage_flags( &key_attributes,
-                             PSA_KEY_USAGE_DERIVE | PSA_KEY_USAGE_EXPORT );
-#else
-    psa_set_key_usage_flags( &key_attributes, PSA_KEY_USAGE_DERIVE );
-#endif
+    if( ssl->tls_version == MBEDTLS_SSL_VERSION_TLS1_3 )
+    {
+        alg = PSA_ALG_HKDF_EXTRACT( PSA_ALG_ANY_HASH );
+        psa_set_key_usage_flags( &key_attributes,
+                                 PSA_KEY_USAGE_DERIVE | PSA_KEY_USAGE_EXPORT );
+    }
+#endif /* MBEDTLS_SSL_PROTO_TLS1_3 */
+
     psa_set_key_algorithm( &key_attributes, alg );
     psa_set_key_type( &key_attributes, PSA_KEY_TYPE_DERIVE );