From 93ecbef6a89464e62c3225f423667e9541a48f22 Mon Sep 17 00:00:00 2001 From: Valerio Setti Date: Wed, 14 Feb 2024 11:44:48 +0100 Subject: [PATCH] pk_wrap: set proper PSA algin rsa wrappers based on padding mode set in RSA context Signed-off-by: Valerio Setti --- library/pk_wrap.c | 50 ++++++++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/library/pk_wrap.c b/library/pk_wrap.c index 69e1baf2e1..b472cfbb7a 100644 --- a/library/pk_wrap.c +++ b/library/pk_wrap.c @@ -74,8 +74,7 @@ static int rsa_verify_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg, int key_len; unsigned char buf[MBEDTLS_PK_RSA_PUB_DER_MAX_BYTES]; unsigned char *p = buf + sizeof(buf); - psa_algorithm_t psa_alg_md = - PSA_ALG_RSA_PKCS1V15_SIGN(mbedtls_md_psa_alg_from_type(md_alg)); + psa_algorithm_t psa_alg_md; size_t rsa_len = mbedtls_rsa_get_len(rsa); #if SIZE_MAX > UINT_MAX @@ -84,6 +83,12 @@ static int rsa_verify_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg, } #endif + if (mbedtls_rsa_get_padding_mode(rsa) == MBEDTLS_RSA_PKCS_V21) { + psa_alg_md = PSA_ALG_RSA_PSS(mbedtls_md_psa_alg_from_type(md_alg)); + } else { + psa_alg_md = PSA_ALG_RSA_PKCS1V15_SIGN(mbedtls_md_psa_alg_from_type(md_alg)); + } + if (sig_len < rsa_len) { return MBEDTLS_ERR_RSA_VERIFY_FAILED; } @@ -235,10 +240,14 @@ static int rsa_sign_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg, if (psa_md_alg == 0) { return MBEDTLS_ERR_PK_BAD_INPUT_DATA; } + psa_algorithm_t psa_alg; + if (mbedtls_rsa_get_padding_mode(mbedtls_pk_rsa(*pk)) == MBEDTLS_RSA_PKCS_V21) { + psa_alg = PSA_ALG_RSA_PSS(psa_md_alg); + } else { + psa_alg = PSA_ALG_RSA_PKCS1V15_SIGN(psa_md_alg); + } - return mbedtls_pk_psa_rsa_sign_ext(PSA_ALG_RSA_PKCS1V15_SIGN( - psa_md_alg), - pk->pk_ctx, hash, hash_len, + return mbedtls_pk_psa_rsa_sign_ext(psa_alg, pk->pk_ctx, hash, hash_len, sig, sig_size, sig_len); } #else /* MBEDTLS_USE_PSA_CRYPTO */ @@ -276,6 +285,7 @@ static int rsa_decrypt_wrap(mbedtls_pk_context *pk, int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT; mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT; + psa_algorithm_t psa_md_alg, decrypt_alg; psa_status_t status; int key_len; unsigned char buf[MBEDTLS_PK_RSA_PRV_DER_MAX_BYTES]; @@ -284,12 +294,6 @@ static int rsa_decrypt_wrap(mbedtls_pk_context *pk, ((void) f_rng); ((void) p_rng); -#if !defined(MBEDTLS_RSA_ALT) - if (rsa->padding != MBEDTLS_RSA_PKCS_V15) { - return MBEDTLS_ERR_RSA_INVALID_PADDING; - } -#endif /* !MBEDTLS_RSA_ALT */ - if (ilen != mbedtls_rsa_get_len(rsa)) { return MBEDTLS_ERR_RSA_BAD_INPUT_DATA; } @@ -301,7 +305,13 @@ static int rsa_decrypt_wrap(mbedtls_pk_context *pk, psa_set_key_type(&attributes, PSA_KEY_TYPE_RSA_KEY_PAIR); psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_DECRYPT); - psa_set_key_algorithm(&attributes, PSA_ALG_RSA_PKCS1V15_CRYPT); + if (mbedtls_rsa_get_padding_mode(rsa) == MBEDTLS_RSA_PKCS_V21) { + psa_md_alg = mbedtls_md_psa_alg_from_type(mbedtls_rsa_get_md_alg(rsa)); + decrypt_alg = PSA_ALG_RSA_OAEP(psa_md_alg); + } else { + decrypt_alg = PSA_ALG_RSA_PKCS1V15_CRYPT; + } + psa_set_key_algorithm(&attributes, decrypt_alg); status = psa_import_key(&attributes, buf + sizeof(buf) - key_len, key_len, @@ -311,7 +321,7 @@ static int rsa_decrypt_wrap(mbedtls_pk_context *pk, goto cleanup; } - status = psa_asymmetric_decrypt(key_id, PSA_ALG_RSA_PKCS1V15_CRYPT, + status = psa_asymmetric_decrypt(key_id, decrypt_alg, input, ilen, NULL, 0, output, osize, olen); @@ -358,6 +368,7 @@ static int rsa_encrypt_wrap(mbedtls_pk_context *pk, int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT; mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT; + psa_algorithm_t psa_md_alg; psa_status_t status; int key_len; unsigned char buf[MBEDTLS_PK_RSA_PUB_DER_MAX_BYTES]; @@ -366,12 +377,6 @@ static int rsa_encrypt_wrap(mbedtls_pk_context *pk, ((void) f_rng); ((void) p_rng); -#if !defined(MBEDTLS_RSA_ALT) - if (rsa->padding != MBEDTLS_RSA_PKCS_V15) { - return MBEDTLS_ERR_RSA_INVALID_PADDING; - } -#endif - if (mbedtls_rsa_get_len(rsa) > osize) { return MBEDTLS_ERR_RSA_OUTPUT_TOO_LARGE; } @@ -382,7 +387,12 @@ static int rsa_encrypt_wrap(mbedtls_pk_context *pk, } psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_ENCRYPT); - psa_set_key_algorithm(&attributes, PSA_ALG_RSA_PKCS1V15_CRYPT); + if (mbedtls_rsa_get_padding_mode(rsa) == MBEDTLS_RSA_PKCS_V21) { + psa_md_alg = mbedtls_md_psa_alg_from_type(mbedtls_rsa_get_md_alg(rsa)); + psa_set_key_algorithm(&attributes, PSA_ALG_RSA_OAEP(psa_md_alg)); + } else { + psa_set_key_algorithm(&attributes, PSA_ALG_RSA_PKCS1V15_CRYPT); + } psa_set_key_type(&attributes, PSA_KEY_TYPE_RSA_PUBLIC_KEY); status = psa_import_key(&attributes,