From 6be9cf542f3e5763371a347d199c6db6bdd96d06 Mon Sep 17 00:00:00 2001 From: Przemyslaw Stekiel Date: Wed, 19 Jan 2022 16:00:22 +0100 Subject: [PATCH] Cleanup the code Use conditional compilation for psa and mbedtls code (MBEDTLS_USE_PSA_CRYPTO). Signed-off-by: Przemyslaw Stekiel --- library/ssl_misc.h | 5 +- library/ssl_msg.c | 106 +++++++++++++++++++++++++++++++++-- library/ssl_tls.c | 117 ++++++++++++--------------------------- library/ssl_tls13_keys.c | 11 +++- 4 files changed, 150 insertions(+), 89 deletions(-) diff --git a/library/ssl_misc.h b/library/ssl_misc.h index 68cc4f038d..a6439dc3eb 100644 --- a/library/ssl_misc.h +++ b/library/ssl_misc.h @@ -937,14 +937,15 @@ struct mbedtls_ssl_transform #endif /* MBEDTLS_SSL_SOME_SUITES_USE_MAC */ - mbedtls_cipher_context_t cipher_ctx_enc; /*!< encryption context */ - mbedtls_cipher_context_t cipher_ctx_dec; /*!< decryption context */ int minor_ver; #if defined(MBEDTLS_USE_PSA_CRYPTO) mbedtls_svc_key_id_t psa_key_enc; /*!< psa encryption key */ mbedtls_svc_key_id_t psa_key_dec; /*!< psa decryption key */ psa_algorithm_t psa_alg; /*!< psa algorithm */ +#else + mbedtls_cipher_context_t cipher_ctx_enc; /*!< encryption context */ + mbedtls_cipher_context_t cipher_ctx_dec; /*!< decryption context */ #endif /* MBEDTLS_USE_PSA_CRYPTO */ #if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID) diff --git a/library/ssl_msg.c b/library/ssl_msg.c index c9f75de6b3..2353c5e441 100644 --- a/library/ssl_msg.c +++ b/library/ssl_msg.c @@ -522,7 +522,9 @@ int mbedtls_ssl_encrypt_buf( mbedtls_ssl_context *ssl, int (*f_rng)(void *, unsigned char *, size_t), void *p_rng ) { +#if !defined(MBEDTLS_USE_PSA_CRYPTO) mbedtls_cipher_mode_t mode; +#endif /* MBEDTLS_USE_PSA_CRYPTO */ int auth_done = 0; unsigned char * data; unsigned char add_data[13 + 1 + MBEDTLS_SSL_CID_OUT_LEN_MAX ]; @@ -568,7 +570,9 @@ int mbedtls_ssl_encrypt_buf( mbedtls_ssl_context *ssl, MBEDTLS_SSL_DEBUG_BUF( 4, "before encrypt: output payload", data, rec->data_len ); +#if !defined(MBEDTLS_USE_PSA_CRYPTO) mode = mbedtls_cipher_get_cipher_mode( &transform->cipher_ctx_enc ); +#endif /* MBEDTLS_USE_PSA_CRYPTO */ if( rec->data_len > MBEDTLS_SSL_OUT_CONTENT_LEN ) { @@ -649,8 +653,13 @@ int mbedtls_ssl_encrypt_buf( mbedtls_ssl_context *ssl, * Add MAC before if needed */ #if defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC) +#if defined(MBEDTLS_USE_PSA_CRYPTO) + if ( transform->psa_alg == MBEDTLS_SSL_NULL_CIPHER || + ( transform->psa_alg == PSA_ALG_CBC_NO_PADDING +#else if( mode == MBEDTLS_MODE_STREAM || ( mode == MBEDTLS_MODE_CBC +#endif #if defined(MBEDTLS_SSL_ENCRYPT_THEN_MAC) && transform->encrypt_then_mac == MBEDTLS_SSL_ETM_DISABLED #endif @@ -707,7 +716,11 @@ int mbedtls_ssl_encrypt_buf( mbedtls_ssl_context *ssl, * Encrypt */ #if defined(MBEDTLS_SSL_SOME_SUITES_USE_STREAM) +#if defined(MBEDTLS_USE_PSA_CRYPTO) + if ( transform->psa_alg == MBEDTLS_SSL_NULL_CIPHER ) +#else if( mode == MBEDTLS_MODE_STREAM ) +#endif { size_t olen; #if defined(MBEDTLS_USE_PSA_CRYPTO) @@ -779,9 +792,18 @@ int mbedtls_ssl_encrypt_buf( mbedtls_ssl_context *ssl, #if defined(MBEDTLS_GCM_C) || \ defined(MBEDTLS_CCM_C) || \ defined(MBEDTLS_CHACHAPOLY_C) +#if defined(MBEDTLS_USE_PSA_CRYPTO) + if ( transform->psa_alg == PSA_ALG_GCM || + /* PSA_ALG_IS_AEAD( transform->psa_alg ) corresponds to + psa_alg == PSA_ALG_CCM || psa_alg == PSA_ALG_AEAD_WITH_SHORTENED_TAG( PSA_ALG_CCM, 8 ) + in tls context (TLS only uses the default taglen or 8) */ + PSA_ALG_IS_AEAD( transform->psa_alg ) || + transform->psa_alg == PSA_ALG_CHACHA20_POLY1305 ) +#else if( mode == MBEDTLS_MODE_GCM || mode == MBEDTLS_MODE_CCM || mode == MBEDTLS_MODE_CHACHAPOLY ) +#endif /* MBEDTLS_USE_PSA_CRYPTO */ { unsigned char iv[12]; unsigned char *dynamic_iv; @@ -897,7 +919,11 @@ int mbedtls_ssl_encrypt_buf( mbedtls_ssl_context *ssl, else #endif /* MBEDTLS_GCM_C || MBEDTLS_CCM_C || MBEDTLS_CHACHAPOLY_C */ #if defined(MBEDTLS_SSL_SOME_SUITES_USE_CBC) +#if defined(MBEDTLS_USE_PSA_CRYPTO) + if ( transform->psa_alg == PSA_ALG_CBC_NO_PADDING ) +#else if( mode == MBEDTLS_MODE_CBC ) +#endif /* MBEDTLS_USE_PSA_CRYPTO */ { int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; size_t padlen, i; @@ -1092,7 +1118,9 @@ int mbedtls_ssl_decrypt_buf( mbedtls_ssl_context const *ssl, mbedtls_record *rec ) { size_t olen; +#if !defined(MBEDTLS_USE_PSA_CRYPTO) mbedtls_cipher_mode_t mode; +#endif /* MBEDTLS_USE_PSA_CRYPTO */ int ret, auth_done = 0; #if defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC) size_t padlen = 0, correct = 1; @@ -1117,7 +1145,9 @@ int mbedtls_ssl_decrypt_buf( mbedtls_ssl_context const *ssl, } data = rec->buf + rec->data_offset; +#if !defined(MBEDTLS_USE_PSA_CRYPTO) mode = mbedtls_cipher_get_cipher_mode( &transform->cipher_ctx_dec ); +#endif /* MBEDTLS_USE_PSA_CRYPTO */ #if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID) /* @@ -1131,7 +1161,11 @@ int mbedtls_ssl_decrypt_buf( mbedtls_ssl_context const *ssl, #endif /* MBEDTLS_SSL_DTLS_CONNECTION_ID */ #if defined(MBEDTLS_SSL_SOME_SUITES_USE_STREAM) +#if defined(MBEDTLS_USE_PSA_CRYPTO) + if ( transform->psa_alg == MBEDTLS_SSL_NULL_CIPHER ) +#else if( mode == MBEDTLS_MODE_STREAM ) +#endif /* MBEDTLS_USE_PSA_CRYPTO */ { padlen = 0; #if defined(MBEDTLS_USE_PSA_CRYPTO) @@ -1198,9 +1232,18 @@ int mbedtls_ssl_decrypt_buf( mbedtls_ssl_context const *ssl, #if defined(MBEDTLS_GCM_C) || \ defined(MBEDTLS_CCM_C) || \ defined(MBEDTLS_CHACHAPOLY_C) +#if defined(MBEDTLS_USE_PSA_CRYPTO) + if ( transform->psa_alg == PSA_ALG_GCM || + /* PSA_ALG_IS_AEAD( transform->psa_alg ) corresponds to + psa_alg == PSA_ALG_CCM || psa_alg == PSA_ALG_AEAD_WITH_SHORTENED_TAG( PSA_ALG_CCM, 8 ) + in tls context (TLS only uses the default taglen or 8) */ + PSA_ALG_IS_AEAD( transform->psa_alg ) || + transform->psa_alg == PSA_ALG_CHACHA20_POLY1305 ) +#else if( mode == MBEDTLS_MODE_GCM || mode == MBEDTLS_MODE_CCM || mode == MBEDTLS_MODE_CHACHAPOLY ) +#endif /* MBEDTLS_USE_PSA_CRYPTO */ { unsigned char iv[12]; unsigned char *dynamic_iv; @@ -1322,7 +1365,11 @@ int mbedtls_ssl_decrypt_buf( mbedtls_ssl_context const *ssl, else #endif /* MBEDTLS_GCM_C || MBEDTLS_CCM_C */ #if defined(MBEDTLS_SSL_SOME_SUITES_USE_CBC) +#if defined(MBEDTLS_USE_PSA_CRYPTO) + if ( transform->psa_alg == PSA_ALG_CBC_NO_PADDING ) +#else if( mode == MBEDTLS_MODE_CBC ) +#endif /* MBEDTLS_USE_PSA_CRYPTO */ { size_t minlen = 0; #if defined(MBEDTLS_USE_PSA_CRYPTO) @@ -5047,12 +5094,62 @@ int mbedtls_ssl_get_record_expansion( const mbedtls_ssl_context *ssl ) size_t transform_expansion = 0; const mbedtls_ssl_transform *transform = ssl->transform_out; unsigned block_size; +#if defined(MBEDTLS_USE_PSA_CRYPTO) + psa_key_attributes_t attr = PSA_KEY_ATTRIBUTES_INIT; + psa_key_type_t key_type; +#endif /* MBEDTLS_USE_PSA_CRYPTO */ size_t out_hdr_len = mbedtls_ssl_out_hdr_len( ssl ); if( transform == NULL ) return( (int) out_hdr_len ); + +#if defined(MBEDTLS_USE_PSA_CRYPTO) + switch( transform->psa_alg ) + { + case PSA_ALG_GCM: + case PSA_ALG_CHACHA20_POLY1305: + case MBEDTLS_SSL_NULL_CIPHER: + transform_expansion = transform->minlen; + break; + + case PSA_ALG_CBC_NO_PADDING: + (void) psa_get_key_attributes( transform->psa_key_enc, &attr ); + key_type = psa_get_key_type( &attr ); + + block_size = PSA_BLOCK_CIPHER_BLOCK_LENGTH( key_type ); + + /* Expansion due to the addition of the MAC. */ + transform_expansion += transform->maclen; + + /* Expansion due to the addition of CBC padding; + * Theoretically up to 256 bytes, but we never use + * more than the block size of the underlying cipher. */ + transform_expansion += block_size; + + /* For TLS 1.2 or higher, an explicit IV is added + * after the record header. */ +#if defined(MBEDTLS_SSL_PROTO_TLS1_2) + transform_expansion += block_size; +#endif /* MBEDTLS_SSL_PROTO_TLS1_2 */ + break; + + default: + /* Handle CCM case in default: + PSA_ALG_IS_AEAD( transform->psa_alg ) corresponds to + psa_alg == PSA_ALG_CCM || psa_alg == PSA_ALG_AEAD_WITH_SHORTENED_TAG( PSA_ALG_CCM, 8 ) + in tls context (TLS only uses the default taglen or 8) */ + if ( PSA_ALG_IS_AEAD( transform->psa_alg ) ) + { + transform_expansion = transform->minlen; + break; + } + + MBEDTLS_SSL_DEBUG_MSG( 1, ( "should never happen" ) ); + return( MBEDTLS_ERR_SSL_INTERNAL_ERROR ); + } +#else switch( mbedtls_cipher_get_cipher_mode( &transform->cipher_ctx_enc ) ) { case MBEDTLS_MODE_GCM: @@ -5087,6 +5184,7 @@ int mbedtls_ssl_get_record_expansion( const mbedtls_ssl_context *ssl ) MBEDTLS_SSL_DEBUG_MSG( 1, ( "should never happen" ) ); return( MBEDTLS_ERR_SSL_INTERNAL_ERROR ); } +#endif /* MBEDTLS_USE_PSA_CRYPTO */ #if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID) if( transform->out_cid_len != 0 ) @@ -5591,13 +5689,13 @@ void mbedtls_ssl_transform_free( mbedtls_ssl_transform *transform ) if( transform == NULL ) return; - mbedtls_cipher_free( &transform->cipher_ctx_enc ); - mbedtls_cipher_free( &transform->cipher_ctx_dec ); - #if defined(MBEDTLS_USE_PSA_CRYPTO) psa_destroy_key( transform->psa_key_enc ); psa_destroy_key( transform->psa_key_dec ); -#endif +#else + mbedtls_cipher_free( &transform->cipher_ctx_enc ); + mbedtls_cipher_free( &transform->cipher_ctx_dec ); +#endif /* MBEDTLS_USE_PSA_CRYPTO */ #if defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC) mbedtls_md_free( &transform->md_ctx_enc ); diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 6191d634af..4266af4d37 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -705,9 +705,6 @@ static int ssl_tls12_populate_transform( mbedtls_ssl_transform *transform, const mbedtls_ssl_context *ssl ) { int ret = 0; -#if defined(MBEDTLS_USE_PSA_CRYPTO) - int psa_fallthrough; -#endif /* MBEDTLS_USE_PSA_CRYPTO */ unsigned char keyblk[256]; unsigned char *key1; unsigned char *key2; @@ -1011,80 +1008,6 @@ static int ssl_tls12_populate_transform( mbedtls_ssl_transform *transform, tls_prf_get_type( tls_prf ) ); } -#if defined(MBEDTLS_USE_PSA_CRYPTO) - ret = mbedtls_cipher_setup_psa( &transform->cipher_ctx_enc, - cipher_info, transform->taglen ); - if( ret != 0 && ret != MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE ) - { - MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setup_psa", ret ); - goto end; - } - - if( ret == 0 ) - { - MBEDTLS_SSL_DEBUG_MSG( 3, ( "Successfully setup PSA-based encryption cipher context" ) ); - psa_fallthrough = 0; - } - else - { - MBEDTLS_SSL_DEBUG_MSG( 1, ( "Failed to setup PSA-based cipher context for record encryption - fall through to default setup." ) ); - psa_fallthrough = 1; - } - - if( psa_fallthrough == 1 ) -#endif /* MBEDTLS_USE_PSA_CRYPTO */ - if( ( ret = mbedtls_cipher_setup( &transform->cipher_ctx_enc, - cipher_info ) ) != 0 ) - { - MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setup", ret ); - goto end; - } - -#if defined(MBEDTLS_USE_PSA_CRYPTO) - ret = mbedtls_cipher_setup_psa( &transform->cipher_ctx_dec, - cipher_info, transform->taglen ); - if( ret != 0 && ret != MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE ) - { - MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setup_psa", ret ); - goto end; - } - - if( ret == 0 ) - { - MBEDTLS_SSL_DEBUG_MSG( 3, ( "Successfully setup PSA-based decryption cipher context" ) ); - psa_fallthrough = 0; - } - else - { - MBEDTLS_SSL_DEBUG_MSG( 1, ( "Failed to setup PSA-based cipher context for record decryption - fall through to default setup." ) ); - psa_fallthrough = 1; - } - - if( psa_fallthrough == 1 ) -#endif /* MBEDTLS_USE_PSA_CRYPTO */ - if( ( ret = mbedtls_cipher_setup( &transform->cipher_ctx_dec, - cipher_info ) ) != 0 ) - { - MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setup", ret ); - goto end; - } - - if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx_enc, key1, - (int) mbedtls_cipher_info_get_key_bitlen( cipher_info ), - MBEDTLS_ENCRYPT ) ) != 0 ) - { - MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret ); - goto end; - } - - if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx_dec, key2, - (int) mbedtls_cipher_info_get_key_bitlen( cipher_info ), - MBEDTLS_DECRYPT ) ) != 0 ) - { - MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret ); - goto end; - } - #if defined(MBEDTLS_USE_PSA_CRYPTO) if( ( status = mbedtls_cipher_to_psa( cipher_info->type, transform->taglen, @@ -1099,6 +1022,7 @@ static int ssl_tls12_populate_transform( mbedtls_ssl_transform *transform, psa_set_key_usage_flags( &attributes, PSA_KEY_USAGE_ENCRYPT ); psa_set_key_algorithm( &attributes, alg ); + psa_set_key_type( &attributes, key_type ); transform->psa_alg = alg; @@ -1123,7 +1047,36 @@ static int ssl_tls12_populate_transform( mbedtls_ssl_transform *transform, MBEDTLS_SSL_DEBUG_RET( 1, "psa_import_key", ret ); goto end; } -#endif /* MBEDTLS_USE_PSA_CRYPTO */ +#else + if( ( ret = mbedtls_cipher_setup( &transform->cipher_ctx_enc, + cipher_info ) ) != 0 ) + { + MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setup", ret ); + goto end; + } + + if( ( ret = mbedtls_cipher_setup( &transform->cipher_ctx_dec, + cipher_info ) ) != 0 ) + { + MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setup", ret ); + goto end; + } + + if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx_enc, key1, + (int) mbedtls_cipher_info_get_key_bitlen( cipher_info ), + MBEDTLS_ENCRYPT ) ) != 0 ) + { + MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret ); + goto end; + } + + if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx_dec, key2, + (int) mbedtls_cipher_info_get_key_bitlen( cipher_info ), + MBEDTLS_DECRYPT ) ) != 0 ) + { + MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret ); + goto end; + } #if defined(MBEDTLS_CIPHER_MODE_CBC) if( mbedtls_cipher_info_get_mode( cipher_info ) == MBEDTLS_MODE_CBC ) @@ -1143,7 +1096,7 @@ static int ssl_tls12_populate_transform( mbedtls_ssl_transform *transform, } } #endif /* MBEDTLS_CIPHER_MODE_CBC */ - +#endif /* MBEDTLS_USE_PSA_CRYPTO */ end: mbedtls_platform_zeroize( keyblk, sizeof( keyblk ) ); @@ -3070,12 +3023,12 @@ void mbedtls_ssl_transform_init( mbedtls_ssl_transform *transform ) { memset( transform, 0, sizeof(mbedtls_ssl_transform) ); - mbedtls_cipher_init( &transform->cipher_ctx_enc ); - mbedtls_cipher_init( &transform->cipher_ctx_dec ); - #if defined(MBEDTLS_USE_PSA_CRYPTO) transform->psa_key_enc = MBEDTLS_SVC_KEY_ID_INIT; transform->psa_key_dec = MBEDTLS_SVC_KEY_ID_INIT; +#else + mbedtls_cipher_init( &transform->cipher_ctx_enc ); + mbedtls_cipher_init( &transform->cipher_ctx_dec ); #endif #if defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC) diff --git a/library/ssl_tls13_keys.c b/library/ssl_tls13_keys.c index 0aade35b0e..a3c1fe54f5 100644 --- a/library/ssl_tls13_keys.c +++ b/library/ssl_tls13_keys.c @@ -801,7 +801,9 @@ int mbedtls_ssl_tls13_populate_transform( mbedtls_ssl_transform *transform, mbedtls_ssl_key_set const *traffic_keys, mbedtls_ssl_context *ssl /* DEBUG ONLY */ ) { +#if !defined(MBEDTLS_USE_PSA_CRYPTO) int ret; +#endif /* MBEDTLS_USE_PSA_CRYPTO */ mbedtls_cipher_info_t const *cipher_info; const mbedtls_ssl_ciphersuite_t *ciphersuite_info; unsigned char const *key_enc; @@ -838,10 +840,10 @@ int mbedtls_ssl_tls13_populate_transform( mbedtls_ssl_transform *transform, return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); } +#if !defined(MBEDTLS_USE_PSA_CRYPTO) /* * Setup cipher contexts in target transform */ - if( ( ret = mbedtls_cipher_setup( &transform->cipher_ctx_enc, cipher_info ) ) != 0 ) { @@ -855,6 +857,7 @@ int mbedtls_ssl_tls13_populate_transform( mbedtls_ssl_transform *transform, MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setup", ret ); return( ret ); } +#endif /* MBEDTLS_USE_PSA_CRYPTO */ #if defined(MBEDTLS_SSL_SRV_C) if( endpoint == MBEDTLS_SSL_IS_SERVER ) @@ -884,6 +887,7 @@ int mbedtls_ssl_tls13_populate_transform( mbedtls_ssl_transform *transform, memcpy( transform->iv_enc, iv_enc, traffic_keys->iv_len ); memcpy( transform->iv_dec, iv_dec, traffic_keys->iv_len ); +#if !defined(MBEDTLS_USE_PSA_CRYPTO) if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx_enc, key_enc, cipher_info->key_bitlen, MBEDTLS_ENCRYPT ) ) != 0 ) @@ -899,6 +903,7 @@ int mbedtls_ssl_tls13_populate_transform( mbedtls_ssl_transform *transform, MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret ); return( ret ); } +#endif /* MBEDTLS_USE_PSA_CRYPTO */ /* * Setup other fields in SSL transform @@ -922,6 +927,9 @@ int mbedtls_ssl_tls13_populate_transform( mbedtls_ssl_transform *transform, transform->taglen + MBEDTLS_SSL_CID_TLS1_3_PADDING_GRANULARITY; #if defined(MBEDTLS_USE_PSA_CRYPTO) + /* + * Setup psa keys and alg + */ if( ( status = mbedtls_cipher_to_psa( cipher_info->type, transform->taglen, &alg, @@ -934,6 +942,7 @@ int mbedtls_ssl_tls13_populate_transform( mbedtls_ssl_transform *transform, psa_set_key_usage_flags( &attributes, PSA_KEY_USAGE_ENCRYPT ); psa_set_key_algorithm( &attributes, alg ); + psa_set_key_type( &attributes, key_type ); transform->psa_alg = alg;