From 2e9711f7667880720f2aa6bb592823f88a1fd3ff Mon Sep 17 00:00:00 2001
From: Przemyslaw Stekiel <przemyslaw.stekiel@mobica.com>
Date: Thu, 13 Jan 2022 14:50:15 +0100
Subject: [PATCH] mbedtls_ssl_decrypt_buf(): replace mbedtls_cipher_crypt() and
 mbedtls_cipher_auth_decrypt_ext() with PSA calls

Signed-off-by: Przemyslaw Stekiel <przemyslaw.stekiel@mobica.com>
---
 library/ssl_msg.c | 90 ++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 89 insertions(+), 1 deletion(-)

diff --git a/library/ssl_msg.c b/library/ssl_msg.c
index 8964369021..fe7d1e5cb8 100644
--- a/library/ssl_msg.c
+++ b/library/ssl_msg.c
@@ -784,7 +784,6 @@ int mbedtls_ssl_encrypt_buf( mbedtls_ssl_context *ssl,
             ssl_transform_aead_dynamic_iv_is_explicit( transform );
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
         psa_status_t status;
-        psa_cipher_operation_t cipher_op = PSA_CIPHER_OPERATION_INIT;
 #endif /* MBEDTLS_USE_PSA_CRYPTO */
 
 
@@ -1127,6 +1126,41 @@ int mbedtls_ssl_decrypt_buf( mbedtls_ssl_context const *ssl,
     if( mode == MBEDTLS_MODE_STREAM )
     {
         padlen = 0;
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+        psa_status_t status;
+        size_t part_len;
+        psa_cipher_operation_t cipher_op = PSA_CIPHER_OPERATION_INIT;
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+        status = psa_cipher_decrypt_setup( &cipher_op,
+                                     transform->psa_key_dec, transform->psa_alg );
+
+        if( status != PSA_SUCCESS )
+            return( MBEDTLS_ERR_PLATFORM_HW_ACCEL_FAILED );
+
+        status = psa_cipher_set_iv( &cipher_op, transform->iv_dec, transform->ivlen );
+
+        if( status != PSA_SUCCESS )
+            return( MBEDTLS_ERR_PLATFORM_HW_ACCEL_FAILED );
+
+        status = psa_cipher_update( &cipher_op,
+                                    data, rec->data_len,
+                                    data, rec->data_len, &olen );
+
+        if( status != PSA_SUCCESS )
+            return( MBEDTLS_ERR_PLATFORM_HW_ACCEL_FAILED );
+
+        status = psa_cipher_finish( &cipher_op,
+                                    data + olen, rec->data_len - olen,
+                                    &part_len );
+
+        if( status != PSA_SUCCESS )
+            return( MBEDTLS_ERR_PLATFORM_HW_ACCEL_FAILED );
+
+        olen += part_len;
+#else
+
         if( ( ret = mbedtls_cipher_crypt( &transform->cipher_ctx_dec,
                                    transform->iv_dec,
                                    transform->ivlen,
@@ -1136,12 +1170,14 @@ int mbedtls_ssl_decrypt_buf( mbedtls_ssl_context const *ssl,
             MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_crypt", ret );
             return( ret );
         }
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
         if( rec->data_len != olen )
         {
             MBEDTLS_SSL_DEBUG_MSG( 1, ( "should never happen" ) );
             return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
         }
+
     }
     else
 #endif /* MBEDTLS_SSL_SOME_SUITES_USE_STREAM */
@@ -1155,6 +1191,9 @@ int mbedtls_ssl_decrypt_buf( mbedtls_ssl_context const *ssl,
         unsigned char iv[12];
         unsigned char *dynamic_iv;
         size_t dynamic_iv_len;
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+        psa_status_t status;
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
         /*
          * Extract dynamic part of nonce for AEAD decryption.
@@ -1229,6 +1268,18 @@ int mbedtls_ssl_decrypt_buf( mbedtls_ssl_context const *ssl,
         /*
          * Decrypt and authenticate
          */
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+        status = psa_aead_decrypt( transform->psa_key_dec,
+                               transform->psa_alg,
+                               iv, transform->ivlen,
+                               add_data, add_data_len,
+                               data, rec->data_len + transform->taglen,
+                               data, rec->buf_len - (data - rec->buf),
+                               &rec->data_len );
+
+        if( status != PSA_SUCCESS )
+            return( MBEDTLS_ERR_PLATFORM_HW_ACCEL_FAILED );
+#else
         if( ( ret = mbedtls_cipher_auth_decrypt_ext( &transform->cipher_ctx_dec,
                   iv, transform->ivlen,
                   add_data, add_data_len,
@@ -1243,6 +1294,8 @@ int mbedtls_ssl_decrypt_buf( mbedtls_ssl_context const *ssl,
 
             return( ret );
         }
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+
         auth_done++;
 
         /* Double-check that AEAD decryption doesn't change content length. */
@@ -1258,6 +1311,11 @@ int mbedtls_ssl_decrypt_buf( mbedtls_ssl_context const *ssl,
     if( mode == MBEDTLS_MODE_CBC )
     {
         size_t minlen = 0;
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+        psa_status_t status;
+        size_t part_len;
+        psa_cipher_operation_t cipher_op = PSA_CIPHER_OPERATION_INIT;
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
         /*
          * Check immediate ciphertext sanity
@@ -1398,6 +1456,35 @@ int mbedtls_ssl_decrypt_buf( mbedtls_ssl_context const *ssl,
 
         /* We still have data_len % ivlen == 0 and data_len >= ivlen here. */
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+        status = psa_cipher_decrypt_setup( &cipher_op,
+                                     transform->psa_key_dec, transform->psa_alg );
+
+        if( status != PSA_SUCCESS )
+            return( MBEDTLS_ERR_PLATFORM_HW_ACCEL_FAILED );
+
+        status = psa_cipher_set_iv( &cipher_op, transform->iv_dec, transform->ivlen );
+
+        if( status != PSA_SUCCESS )
+            return( MBEDTLS_ERR_PLATFORM_HW_ACCEL_FAILED );
+
+        status = psa_cipher_update( &cipher_op,
+                                    data, rec->data_len,
+                                    data, rec->data_len, &olen );
+
+        if( status != PSA_SUCCESS )
+            return( MBEDTLS_ERR_PLATFORM_HW_ACCEL_FAILED );
+
+        status = psa_cipher_finish( &cipher_op,
+                                    data + olen, rec->data_len - olen,
+                                    &part_len );
+
+        if( status != PSA_SUCCESS )
+            return( MBEDTLS_ERR_PLATFORM_HW_ACCEL_FAILED );
+
+        olen += part_len;
+#else
+
         if( ( ret = mbedtls_cipher_crypt( &transform->cipher_ctx_dec,
                                    transform->iv_dec, transform->ivlen,
                                    data, rec->data_len, data, &olen ) ) != 0 )
@@ -1405,6 +1492,7 @@ int mbedtls_ssl_decrypt_buf( mbedtls_ssl_context const *ssl,
             MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_crypt", ret );
             return( ret );
         }
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
         /* Double-check that length hasn't changed during decryption. */
         if( rec->data_len != olen )