From 05e92d67bbe15f30c9ad74c123362e1701125dad Mon Sep 17 00:00:00 2001
From: Mateusz Starzyk <mateusz.starzyk@mobica.com>
Date: Fri, 9 Jul 2021 12:44:07 +0200
Subject: [PATCH] Fix crypt mode configuration. Validate parameters in chunked
 input functions.

Signed-off-by: Mateusz Starzyk <mateusz.starzyk@mobica.com>
---
 library/ccm.c | 107 ++++++++++++++++++++++++++++----------------------
 1 file changed, 60 insertions(+), 47 deletions(-)

diff --git a/library/ccm.c b/library/ccm.c
index 4b1b499ad1..4f7ebfa827 100644
--- a/library/ccm.c
+++ b/library/ccm.c
@@ -52,8 +52,6 @@
 #define CCM_VALIDATE( cond ) \
     MBEDTLS_INTERNAL_VALIDATE( cond )
 
-#define CCM_ENCRYPT 0
-#define CCM_DECRYPT 1
 
 /*
  * Initialize context
@@ -174,6 +172,10 @@ static int mbedtls_ccm_calculate_first_block(mbedtls_ccm_context *ctx)
     if( !(ctx->state & CCM_STATE__STARTED) || !(ctx->state & CCM_STATE__LENGHTS_SET) )
         return 0;
 
+    if( ctx->tag_len == 0 && \
+        ( ctx->mode == MBEDTLS_CCM_ENCRYPT || ctx->mode == MBEDTLS_CCM_DECRYPT ) )
+        return( MBEDTLS_ERR_CCM_BAD_INPUT );
+
     /*
      * First block B_0:
      * 0        .. 0        flags
@@ -210,6 +212,13 @@ int mbedtls_ccm_starts( mbedtls_ccm_context *ctx,
                         const unsigned char *iv,
                         size_t iv_len )
 {
+    CCM_VALIDATE_RET( ctx != NULL );
+    CCM_VALIDATE_RET( iv != NULL );
+    CCM_VALIDATE_RET( mode == MBEDTLS_CCM_DECRYPT      || \
+                      mode == MBEDTLS_CCM_STAR_DECRYPT || \
+                      mode == MBEDTLS_CCM_ENCRYPT      || \
+                      mode == MBEDTLS_CCM_STAR_ENCRYPT );
+
     /* Also implies q is within bounds */
     if( iv_len < 7 || iv_len > 13 )
         return( MBEDTLS_ERR_CCM_BAD_INPUT );
@@ -252,6 +261,8 @@ int mbedtls_ccm_set_lengths( mbedtls_ccm_context *ctx,
                              size_t plaintext_len,
                              size_t tag_len )
 {
+    CCM_VALIDATE_RET( ctx != NULL );
+
     /*
      * Check length requirements: SP800-38C A.1
      * Additional requirement: a < 2^16 - 2^8 to simplify the code.
@@ -283,6 +294,8 @@ int mbedtls_ccm_update_ad( mbedtls_ccm_context *ctx,
                            const unsigned char *add,
                            size_t add_len )
 {
+    CCM_VALIDATE_RET( ctx->add_len == 0 || add != NULL );
+
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     unsigned char i;
     size_t olen, use_len, offset;
@@ -331,6 +344,9 @@ int mbedtls_ccm_update( mbedtls_ccm_context *ctx,
                         unsigned char *output, size_t output_size,
                         size_t *output_len )
 {
+    CCM_VALIDATE_RET( ctx->plaintext_len == 0 || input != NULL );
+    CCM_VALIDATE_RET( ctx->plaintext_len == 0 || output != NULL );
+
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     unsigned char i;
     size_t use_len, offset, olen;
@@ -359,7 +375,8 @@ int mbedtls_ccm_update( mbedtls_ccm_context *ctx,
 
         if( use_len + offset == 16 || ctx->processed == ctx->plaintext_len )
         {
-            if( ctx->mode == CCM_ENCRYPT )
+            if( ctx->mode == MBEDTLS_CCM_ENCRYPT || \
+                ctx->mode == MBEDTLS_CCM_STAR_ENCRYPT )
             {
                 UPDATE_CBC_MAC;
                 ret = mbedtls_ccm_crypt( ctx, 0, use_len, ctx->b, output );
@@ -368,7 +385,8 @@ int mbedtls_ccm_update( mbedtls_ccm_context *ctx,
                 memset( ctx->b, 0, 16 );
             }
 
-            if( ctx->mode == CCM_DECRYPT )
+            if( ctx->mode == MBEDTLS_CCM_DECRYPT || \
+                ctx->mode == MBEDTLS_CCM_STAR_DECRYPT )
             {
                 ret = mbedtls_ccm_crypt( ctx, 0, use_len, ctx->b, output );
                 if( ret != 0 )
@@ -402,6 +420,7 @@ int mbedtls_ccm_finish( mbedtls_ccm_context *ctx,
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     unsigned char i;
 
+    CCM_VALIDATE_RET( tag_len == 0 || tag != NULL );
     /*
      * Authentication: reset counter and crypt/mask internal tag
      */
@@ -457,13 +476,7 @@ int mbedtls_ccm_star_encrypt_and_tag( mbedtls_ccm_context *ctx, size_t length,
                          const unsigned char *input, unsigned char *output,
                          unsigned char *tag, size_t tag_len )
 {
-    CCM_VALIDATE_RET( ctx != NULL );
-    CCM_VALIDATE_RET( iv != NULL );
-    CCM_VALIDATE_RET( add_len == 0 || add != NULL );
-    CCM_VALIDATE_RET( length == 0 || input != NULL );
-    CCM_VALIDATE_RET( length == 0 || output != NULL );
-    CCM_VALIDATE_RET( tag_len == 0 || tag != NULL );
-    return( ccm_auth_crypt( ctx, CCM_ENCRYPT, length, iv, iv_len,
+    return( ccm_auth_crypt( ctx, MBEDTLS_CCM_STAR_ENCRYPT, length, iv, iv_len,
                             add, add_len, input, output, tag, tag_len ) );
 }
 
@@ -473,17 +486,25 @@ int mbedtls_ccm_encrypt_and_tag( mbedtls_ccm_context *ctx, size_t length,
                          const unsigned char *input, unsigned char *output,
                          unsigned char *tag, size_t tag_len )
 {
-    CCM_VALIDATE_RET( ctx != NULL );
-    CCM_VALIDATE_RET( iv != NULL );
-    CCM_VALIDATE_RET( add_len == 0 || add != NULL );
-    CCM_VALIDATE_RET( length == 0 || input != NULL );
-    CCM_VALIDATE_RET( length == 0 || output != NULL );
-    CCM_VALIDATE_RET( tag_len == 0 || tag != NULL );
-    if( tag_len == 0 )
-        return( MBEDTLS_ERR_CCM_BAD_INPUT );
+    return( ccm_auth_crypt( ctx, MBEDTLS_CCM_ENCRYPT, length, iv, iv_len,
+                            add, add_len, input, output, tag, tag_len ) );
+}
 
-    return( mbedtls_ccm_star_encrypt_and_tag( ctx, length, iv, iv_len, add,
-                add_len, input, output, tag, tag_len ) );
+static int mbedtls_ccm_compare_tags(const unsigned char *tag1, const unsigned char *tag2, size_t tag_len)
+{
+    unsigned char i;
+    int diff;
+
+    /* Check tag in "constant-time" */
+    for( diff = 0, i = 0; i < tag_len; i++ )
+        diff |= tag1[i] ^ tag2[i];
+
+    if( diff != 0 )
+    {
+        return( MBEDTLS_ERR_CCM_AUTH_FAILED );
+    }
+
+    return( 0 );
 }
 
 /*
@@ -497,31 +518,18 @@ int mbedtls_ccm_star_auth_decrypt( mbedtls_ccm_context *ctx, size_t length,
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     unsigned char check_tag[16];
-    unsigned char i;
-    int diff;
 
-    CCM_VALIDATE_RET( ctx != NULL );
-    CCM_VALIDATE_RET( iv != NULL );
-    CCM_VALIDATE_RET( add_len == 0 || add != NULL );
-    CCM_VALIDATE_RET( length == 0 || input != NULL );
-    CCM_VALIDATE_RET( length == 0 || output != NULL );
-    CCM_VALIDATE_RET( tag_len == 0 || tag != NULL );
-
-    if( ( ret = ccm_auth_crypt( ctx, CCM_DECRYPT, length,
+    if( ( ret = ccm_auth_crypt( ctx, MBEDTLS_CCM_STAR_DECRYPT, length,
                                 iv, iv_len, add, add_len,
                                 input, output, check_tag, tag_len ) ) != 0 )
     {
         return( ret );
     }
 
-    /* Check tag in "constant-time" */
-    for( diff = 0, i = 0; i < tag_len; i++ )
-        diff |= tag[i] ^ check_tag[i];
-
-    if( diff != 0 )
+    if( ( ret = mbedtls_ccm_compare_tags( tag, check_tag, tag_len ) ) != 0 )
     {
         mbedtls_platform_zeroize( output, length );
-        return( MBEDTLS_ERR_CCM_AUTH_FAILED );
+        return( ret );
     }
 
     return( 0 );
@@ -533,18 +541,23 @@ int mbedtls_ccm_auth_decrypt( mbedtls_ccm_context *ctx, size_t length,
                       const unsigned char *input, unsigned char *output,
                       const unsigned char *tag, size_t tag_len )
 {
-    CCM_VALIDATE_RET( ctx != NULL );
-    CCM_VALIDATE_RET( iv != NULL );
-    CCM_VALIDATE_RET( add_len == 0 || add != NULL );
-    CCM_VALIDATE_RET( length == 0 || input != NULL );
-    CCM_VALIDATE_RET( length == 0 || output != NULL );
-    CCM_VALIDATE_RET( tag_len == 0 || tag != NULL );
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    unsigned char check_tag[16];
 
-    if( tag_len == 0 )
-        return( MBEDTLS_ERR_CCM_BAD_INPUT );
+    if( ( ret = ccm_auth_crypt( ctx, MBEDTLS_CCM_DECRYPT, length,
+                                iv, iv_len, add, add_len,
+                                input, output, check_tag, tag_len ) ) != 0 )
+    {
+        return( ret );
+    }
 
-    return( mbedtls_ccm_star_auth_decrypt( ctx, length, iv, iv_len, add,
-                add_len, input, output, tag, tag_len ) );
+    if( ( ret = mbedtls_ccm_compare_tags( tag, check_tag, tag_len ) ) != 0 )
+    {
+        mbedtls_platform_zeroize( output, length );
+        return( ret );
+    }
+
+    return( 0 );
 }
 #endif /* !MBEDTLS_CCM_ALT */