From 008d2bf80b813138d1375bc73ab617812cd7103d Mon Sep 17 00:00:00 2001
From: XiaokangQian <xiaokang.qian@arm.com>
Date: Thu, 14 Jul 2022 07:54:01 +0000
Subject: [PATCH] Address comments in psk client review

Improve comments
Refine cipher suite related code in psk
Refine get_psk_offered()

Change-Id: Ic3b0b5f86eb1e71f11bb499961aa8494284f1840
Signed-off-by: XiaokangQian <xiaokang.qian@arm.com>
---
 library/ssl_client.c        |   2 +-
 library/ssl_misc.h          |  39 +++------
 library/ssl_tls13_client.c  | 152 ++++++++++++------------------------
 library/ssl_tls13_generic.c |  38 +++++++++
 4 files changed, 97 insertions(+), 134 deletions(-)

diff --git a/library/ssl_client.c b/library/ssl_client.c
index 5e775cca41..8e4e9688fa 100644
--- a/library/ssl_client.c
+++ b/library/ssl_client.c
@@ -653,7 +653,7 @@ static int ssl_write_client_hello_body( mbedtls_ssl_context *ssl,
             return( ret );
         p += output_len;
     }
-#endif /* MBEDTLS_SSL_PROTO_TLS1_3 || MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED */
+#endif /* MBEDTLS_SSL_PROTO_TLS1_3 && MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED */
 
     /* Write the length of the list of extensions. */
     extensions_len = p - p_extensions_len - 2;
diff --git a/library/ssl_misc.h b/library/ssl_misc.h
index 5037e449b3..30c3c3a64c 100644
--- a/library/ssl_misc.h
+++ b/library/ssl_misc.h
@@ -2423,34 +2423,11 @@ int mbedtls_ssl_check_dtls_clihlo_cookie(
 /* Check if we have any PSK to offer, returns 0 if PSK is available.
  * Assign the psk and ticket if pointers are present.
  */
-static inline int mbedtls_ssl_get_psk_to_offer(
+int mbedtls_ssl_get_psk_to_offer(
         const mbedtls_ssl_context *ssl,
+        int *psk_type,
         const unsigned char **psk, size_t *psk_len,
-        const unsigned char **psk_identity, size_t *psk_identity_len )
-{
-    int ptrs_present = 0;
-
-    if( psk != NULL && psk_len != NULL &&
-        psk_identity != NULL && psk_identity_len != NULL )
-    {
-        ptrs_present = 1;
-    }
-
-    /* Check if an external PSK has been configured. */
-    if( ssl->conf->psk != NULL )
-    {
-        if( ptrs_present )
-        {
-            *psk = ssl->conf->psk;
-            *psk_len = ssl->conf->psk_len;
-            *psk_identity = ssl->conf->psk_identity;
-            *psk_identity_len = ssl->conf->psk_identity_len;
-        }
-        return( 0 );
-    }
-
-    return( 1 );
-}
+        const unsigned char **psk_identity, size_t *psk_identity_len );
 
 /**
  * \brief Given an SSL context and its associated configuration, write the TLS
@@ -2459,9 +2436,11 @@ static inline int mbedtls_ssl_get_psk_to_offer(
  * \param[in]   ssl     SSL context
  * \param[in]   buf     Base address of the buffer where to write the extension
  * \param[in]   end     End address of the buffer where to write the extension
- * \param[out]  out_len Length of the data written into the buffer \p buf
+ * \param[out]  out_len Length in bytes of the Pre-Shared key extension: data
+ *                      written into the buffer \p buf by this function plus
+ *                      the length of the binders to be written.
  * \param[out]  binders_len Length of the binders to be written at the end of
- *                          extension
+ *                          the extension.
  */
 int mbedtls_ssl_tls13_write_pre_shared_key_ext_without_binders(
     mbedtls_ssl_context *ssl,
@@ -2474,8 +2453,8 @@ int mbedtls_ssl_tls13_write_pre_shared_key_ext_without_binders(
  *        ClientHello.
  *
  * \param[in]   ssl     SSL context
- * \param[in]   buf     Base address of the buffer where to write the extension
- * \param[in]   end     End address of the buffer where to write the extension
+ * \param[in]   buf     Base address of the buffer where to write the binders
+ * \param[in]   end     End address of the buffer where to write the binders
  */
 int mbedtls_ssl_tls13_write_pre_shared_key_ext_binders(
     mbedtls_ssl_context *ssl,
diff --git a/library/ssl_tls13_client.c b/library/ssl_tls13_client.c
index 62f00fac8c..6e82631f94 100644
--- a/library/ssl_tls13_client.c
+++ b/library/ssl_tls13_client.c
@@ -615,15 +615,16 @@ static int ssl_tls13_write_psk_key_exchange_modes_ext( mbedtls_ssl_context *ssl,
     const unsigned char *psk_identity;
     size_t psk_identity_len;
     unsigned char *p = buf;
-    int num_modes = 0;
+    int ke_modes_len = 0;
 
-    ((void) num_modes );
+    ((void) ke_modes_len );
     *out_len = 0;
+
     /* Skip writing extension if no PSK key exchange mode
-     * is enabled in the config.
+     * is enabled in the config or there is no PSK to offer.
      */
     if( !mbedtls_ssl_conf_tls13_some_psk_enabled( ssl ) ||
-         mbedtls_ssl_get_psk_to_offer( ssl, &psk, &psk_len,
+         mbedtls_ssl_get_psk_to_offer( ssl, NULL, &psk, &psk_len,
                                       &psk_identity, &psk_identity_len ) != 0 )
     {
         MBEDTLS_SSL_DEBUG_MSG( 3, ( "skip psk_key_exchange_modes extension" ) );
@@ -637,18 +638,17 @@ static int ssl_tls13_write_psk_key_exchange_modes_ext( mbedtls_ssl_context *ssl,
     MBEDTLS_SSL_DEBUG_MSG(
             3, ( "client hello, adding psk_key_exchange_modes extension" ) );
 
-    /* Extension Type */
     MBEDTLS_PUT_UINT16_BE( MBEDTLS_TLS_EXT_PSK_KEY_EXCHANGE_MODES, p, 0 );
 
-    /* Skip extension length (2 byte) and
-     * PSK mode list length (1 byte) for now.
+    /* Skip extension length (2 bytes) and
+     * ke_modes length (1 byte) for now.
      */
     p += 5;
 
     if( mbedtls_ssl_conf_tls13_psk_enabled( ssl ) )
     {
         *p++ = MBEDTLS_SSL_TLS1_3_PSK_MODE_PURE;
-        num_modes++;
+        ke_modes_len++;
 
         MBEDTLS_SSL_DEBUG_MSG( 4, ( "Adding pure PSK key exchange mode" ) );
     }
@@ -656,17 +656,14 @@ static int ssl_tls13_write_psk_key_exchange_modes_ext( mbedtls_ssl_context *ssl,
     if( mbedtls_ssl_conf_tls13_psk_ephemeral_enabled( ssl ) )
     {
         *p++ = MBEDTLS_SSL_TLS1_3_PSK_MODE_ECDHE;
-        num_modes++;
+        ke_modes_len++;
 
         MBEDTLS_SSL_DEBUG_MSG( 4, ( "Adding PSK-ECDHE key exchange mode" ) );
     }
 
-    /* Add extension length: PSK mode list length byte + actual
-     * PSK mode list length
-     */
-    MBEDTLS_PUT_UINT16_BE( num_modes + 1, buf, 2 );
-    /* Add PSK mode list length */
-    buf[4] = num_modes;
+    /* Now write the extension and ke_modes length */
+    MBEDTLS_PUT_UINT16_BE( ke_modes_len + 1, buf, 2 );
+    buf[4] = ke_modes_len;
 
     *out_len = p - buf;
     ssl->handshake->extensions_present |= MBEDTLS_SSL_EXT_PSK_KEY_EXCHANGE_MODES;
@@ -685,22 +682,12 @@ static int ssl_tls13_write_psk_key_exchange_modes_ext( mbedtls_ssl_context *ssl,
  * opaque PskBinderEntry<32..255>;
  *
  * struct {
- *   select ( Handshake.msg_type ) {
  *
- *     case client_hello:
- *       PskIdentity identities<7..2^16-1>;
- *       PskBinderEntry binders<33..2^16-1>;
- *
- *     case server_hello:
- *       uint16 selected_identity;
- *   };
+ *     PskIdentity identities<7..2^16-1>;
+ *     PskBinderEntry binders<33..2^16-1>;
  *
  * } PreSharedKeyExtension;
  *
- *
- * part = 0 ==> everything up to the PSK binder list,
- *              returning the binder list length in `binder_list_length`.
- * part = 1 ==> the PSK binder list
  */
 
 #if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
@@ -715,12 +702,12 @@ int mbedtls_ssl_tls13_write_pre_shared_key_ext_without_binders(
     size_t psk_len;
     const unsigned char *psk_identity;
     size_t psk_identity_len;
-    const mbedtls_ssl_ciphersuite_t *suite_info = NULL;
+    const mbedtls_ssl_ciphersuite_t *ciphersuite_info = NULL;
     const int *ciphersuites;
+    psa_algorithm_t psa_hash_alg;
     int hash_len = 0;
     size_t identities_len, l_binders_len;
     uint32_t obfuscated_ticket_age = 0;
-    psa_algorithm_t psa_hash_alg;
 
     *out_len = 0;
     *binders_len = 0;
@@ -738,7 +725,7 @@ int mbedtls_ssl_tls13_write_pre_shared_key_ext_without_binders(
      * - Otherwise, skip the PSK extension.
      */
 
-    if( mbedtls_ssl_get_psk_to_offer( ssl, &psk, &psk_len,
+    if( mbedtls_ssl_get_psk_to_offer( ssl, NULL, &psk, &psk_len,
                                       &psk_identity, &psk_identity_len ) != 0 )
     {
         MBEDTLS_SSL_DEBUG_MSG( 3, ( "skip pre_shared_key extensions" ) );
@@ -751,22 +738,27 @@ int mbedtls_ssl_tls13_write_pre_shared_key_ext_without_binders(
     ciphersuites = ssl->conf->ciphersuite_list;
     for ( int i = 0; ciphersuites[i] != 0; i++ )
     {
-        suite_info = mbedtls_ssl_ciphersuite_from_id( ciphersuites[i] );
+        ciphersuite_info = mbedtls_ssl_ciphersuite_from_id( ciphersuites[i] );
 
-        if( suite_info == NULL )
+        if( mbedtls_ssl_validate_ciphersuite(
+                                ssl, ciphersuite_info,
+                                MBEDTLS_SSL_VERSION_TLS1_3,
+                                MBEDTLS_SSL_VERSION_TLS1_3 ) != 0 )
             continue;
 
         /* In this implementation we only add one pre-shared-key extension. */
         ssl->session_negotiate->ciphersuite = ciphersuites[i];
-        ssl->handshake->ciphersuite_info = suite_info;
         break;
     }
 
-    if( suite_info != NULL )
-    {
-        psa_hash_alg = mbedtls_psa_translate_md( suite_info->mac );
-        hash_len = PSA_HASH_LENGTH( psa_hash_alg );
-    }
+    ciphersuite_info = mbedtls_ssl_ciphersuite_from_id(
+            ssl->session_negotiate->ciphersuite );
+    /* No suitable ciphersuite for the PSK */
+    if( ciphersuite_info  == NULL )
+        return( 0 );
+
+    psa_hash_alg = mbedtls_psa_translate_md( ciphersuite_info->mac );
+    hash_len = PSA_HASH_LENGTH( psa_hash_alg );
     if( hash_len == -1 )
         return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
 
@@ -818,64 +810,28 @@ int mbedtls_ssl_tls13_write_pre_shared_key_ext_binders(
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     unsigned char *p = buf;
-    const mbedtls_ssl_ciphersuite_t *suite_info = NULL;
-    const int *ciphersuites;
+    const mbedtls_ssl_ciphersuite_t *ciphersuite_info = NULL;
+    psa_algorithm_t psa_hash_alg;
     int hash_len = 0;
-    const unsigned char *psk;
-    size_t psk_len;
-    const unsigned char *psk_identity;
-    size_t psk_identity_len;
+    const unsigned char *psk = NULL;
+    size_t psk_len = 0;
     int psk_type;
     unsigned char transcript[MBEDTLS_MD_MAX_SIZE];
     size_t transcript_len;
-    psa_algorithm_t psa_hash_alg;
 
-    /* Check if we have any PSKs to offer. If so, return the first.
-     *
-     * NOTE: Ultimately, we want to be able to offer multiple PSKs,
-     *       in which case we want to iterate over them here.
-     *
-     * As it stands, however, we only ever offer one, chosen
-     * by the following heuristic:
-     * - If a ticket has been configured, offer the corresponding PSK.
-     * - If no ticket has been configured by an external PSK has been
-     *   configured, offer that.
-     * - Otherwise, skip the PSK extension.
-     */
+    ciphersuite_info = mbedtls_ssl_ciphersuite_from_id(
+            ssl->session_negotiate->ciphersuite );
+    if( ciphersuite_info  == NULL )
+        return( 0 );
 
-    if( mbedtls_ssl_get_psk_to_offer( ssl, &psk, &psk_len,
-                                      &psk_identity, &psk_identity_len ) != 0 )
-    {
-        return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
-    }
-
-    /*
-     * Ciphersuite list
-     */
-    ciphersuites = ssl->conf->ciphersuite_list;
-    for ( int i = 0; ciphersuites[i] != 0; i++ )
-    {
-        suite_info = mbedtls_ssl_ciphersuite_from_id( ciphersuites[i] );
-
-        if( suite_info == NULL )
-            continue;
-
-        /* In this implementation we only add one pre-shared-key extension. */
-        ssl->session_negotiate->ciphersuite = ciphersuites[i];
-        ssl->handshake->ciphersuite_info = suite_info;
-        break;
-    }
-
-    if( suite_info != NULL )
-    {
-        psa_hash_alg = mbedtls_psa_translate_md( suite_info->mac );
-        hash_len = PSA_HASH_LENGTH( psa_hash_alg );
-    }
+    psa_hash_alg = mbedtls_psa_translate_md( ciphersuite_info->mac );
+    hash_len = PSA_HASH_LENGTH( psa_hash_alg );
     if( ( hash_len == -1 ) || ( ( end - buf ) != 3 + hash_len ) )
         return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
 
     MBEDTLS_SSL_DEBUG_MSG( 3, ( "client hello, adding PSK binder list" ) );
 
+    MBEDTLS_SSL_CHK_BUF_PTR( p, end, 3 + hash_len );
     /* 2 bytes length field for array of psk binders */
     MBEDTLS_PUT_UINT16_BE( hash_len + 1, p, 0 );
     p += 2;
@@ -889,14 +845,14 @@ int mbedtls_ssl_tls13_write_pre_shared_key_ext_binders(
         psk_type = MBEDTLS_SSL_TLS1_3_PSK_EXTERNAL;
 
     /* Get current state of handshake transcript. */
-    ret = mbedtls_ssl_get_handshake_transcript( ssl, suite_info->mac,
+    ret = mbedtls_ssl_get_handshake_transcript( ssl, ciphersuite_info->mac,
                                                 transcript, sizeof( transcript ),
                                                 &transcript_len );
     if( ret != 0 )
         return( ret );
 
     ret = mbedtls_ssl_tls13_create_psk_binder( ssl,
-              mbedtls_psa_translate_md( suite_info->mac ),
+              mbedtls_psa_translate_md( ciphersuite_info->mac ),
               psk, psk_len, psk_type,
               transcript, p );
     if( ret != 0 )
@@ -1269,13 +1225,8 @@ static int ssl_tls13_check_server_hello_session_id_echo( mbedtls_ssl_context *ss
  * opaque PskBinderEntry<32..255>;
  *
  * struct {
- *   select ( Handshake.msg_type ) {
- *     case client_hello:
- *          PskIdentity identities<7..2^16-1>;
- *          PskBinderEntry binders<33..2^16-1>;
- *     case server_hello:
- *          uint16 selected_identity;
- *   };
+ *
+ *   uint16 selected_identity;
  *
  * } PreSharedKeyExtension;
  *
@@ -1283,7 +1234,7 @@ static int ssl_tls13_check_server_hello_session_id_echo( mbedtls_ssl_context *ss
 
 static int ssl_tls13_parse_server_psk_identity_ext( mbedtls_ssl_context *ssl,
                                                     const unsigned char *buf,
-                                                    size_t len )
+                                                    const unsigned char *end )
 {
     int ret = 0;
     size_t selected_identity;
@@ -1299,7 +1250,7 @@ static int ssl_tls13_parse_server_psk_identity_ext( mbedtls_ssl_context *ssl,
      * NOTE: Ultimately, we want to offer multiple PSKs, and in this
      *       case, we need to iterate over them here.
      */
-    if( mbedtls_ssl_get_psk_to_offer( ssl, &psk, &psk_len,
+    if( mbedtls_ssl_get_psk_to_offer( ssl, NULL, &psk, &psk_len,
                                       &psk_identity, &psk_identity_len ) != 0 )
     {
         /* If we haven't offered a PSK, the server must not send
@@ -1307,12 +1258,7 @@ static int ssl_tls13_parse_server_psk_identity_ext( mbedtls_ssl_context *ssl,
         return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
     }
 
-    if( len != (size_t) 2 )
-    {
-        MBEDTLS_SSL_DEBUG_MSG( 1, ( "bad psk_identity extension in server hello message" ) );
-        return( MBEDTLS_ERR_SSL_DECODE_ERROR );
-    }
-
+    MBEDTLS_SSL_CHK_BUF_PTR( buf, end, 2 );
     selected_identity = MBEDTLS_GET_UINT16_BE( buf, 0 );
 
     /* We have offered only one PSK, so the only valid choice
@@ -1571,7 +1517,7 @@ static int ssl_tls13_parse_server_hello( mbedtls_ssl_context *ssl,
                 }
 
                 if( ( ret = ssl_tls13_parse_server_psk_identity_ext(
-                                ssl, p, extension_data_len ) ) != 0 )
+                                ssl, p, extension_data_end ) ) != 0 )
                 {
                     MBEDTLS_SSL_DEBUG_RET(
                         1, ( "ssl_tls13_parse_server_psk_identity_ext" ), ret );
diff --git a/library/ssl_tls13_generic.c b/library/ssl_tls13_generic.c
index 265d6d3097..4cd0d0e267 100644
--- a/library/ssl_tls13_generic.c
+++ b/library/ssl_tls13_generic.c
@@ -1505,4 +1505,42 @@ int mbedtls_ssl_tls13_generate_and_write_ecdh_key_exchange(
 }
 #endif /* MBEDTLS_ECDH_C */
 
+#if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
+/* Check if we have any PSK to offer, returns 0 if PSK is available.
+ * Assign the psk and ticket if pointers are present.
+ */
+int mbedtls_ssl_get_psk_to_offer(
+        const mbedtls_ssl_context *ssl,
+        int *psk_type,
+        const unsigned char **psk, size_t *psk_len,
+        const unsigned char **psk_identity, size_t *psk_identity_len )
+{
+    int ptrs_present = 0;
+
+    if( psk != NULL && psk_len != NULL &&
+        psk_identity != NULL && psk_identity_len != NULL )
+    {
+        ptrs_present = 1;
+    }
+
+    /* Check if an external PSK has been configured. */
+    if( ssl->conf->psk != NULL )
+    {
+        if( ptrs_present )
+        {
+            *psk = ssl->conf->psk;
+            *psk_len = ssl->conf->psk_len;
+            *psk_identity = ssl->conf->psk_identity;
+            *psk_identity_len = ssl->conf->psk_identity_len;
+        }
+
+        if( psk_type != NULL )
+            *psk_type = MBEDTLS_SSL_TLS1_3_PSK_EXTERNAL;
+        return( 0 );
+    }
+
+    return( 1 );
+}
+#endif /* MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED */
+
 #endif /* MBEDTLS_SSL_TLS_C && MBEDTLS_SSL_PROTO_TLS1_3 */