diff --git a/library/pk_wrap.c b/library/pk_wrap.c index f424a3789e..664c266abc 100644 --- a/library/pk_wrap.c +++ b/library/pk_wrap.c @@ -950,12 +950,27 @@ static int pk_ecdsa_sig_asn1_from_psa(unsigned char *sig, size_t *sig_len, /* This is the common helper used by ecdsa_sign_wrap() functions below (they * differ in having PK_USE_PSA_EC_DATA defined or not) to sign using PSA * functions. */ -static int ecdsa_sign_psa(mbedtls_svc_key_id_t key_id, psa_algorithm_t psa_sig_md, +static int ecdsa_sign_psa(mbedtls_svc_key_id_t key_id, mbedtls_md_type_t md_alg, const unsigned char *hash, size_t hash_len, unsigned char *sig, size_t sig_size, size_t *sig_len) { int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; psa_status_t status; + psa_algorithm_t psa_sig_md; + psa_key_attributes_t key_attr = PSA_KEY_ATTRIBUTES_INIT; + psa_algorithm_t alg; + + status = psa_get_key_attributes(key_id, &key_attr); + if (status != PSA_SUCCESS) { + return PSA_PK_ECDSA_TO_MBEDTLS_ERR(status); + } + alg = psa_get_key_algorithm(&key_attr); + + if (PSA_ALG_IS_DETERMINISTIC_ECDSA(alg)) { + psa_sig_md = PSA_ALG_DETERMINISTIC_ECDSA(mbedtls_md_psa_alg_from_type(md_alg)); + } else { + psa_sig_md = PSA_ALG_ECDSA(mbedtls_md_psa_alg_from_type(md_alg)); + } status = psa_sign_hash(key_id, psa_sig_md, hash, hash_len, sig, sig_size, sig_len); @@ -983,14 +998,8 @@ static int pk_opaque_ecdsa_sign_wrap(mbedtls_pk_context *pk, { ((void) f_rng); ((void) p_rng); - psa_algorithm_t psa_sig_md = - PSA_ALG_ECDSA(mbedtls_md_psa_alg_from_type(md_alg)); - if (MBEDTLS_SVC_KEY_ID_GET_KEY_ID(pk->priv_id) == PSA_KEY_ID_NULL) { - return MBEDTLS_ERR_PK_BAD_INPUT_DATA; - } - - return ecdsa_sign_psa(pk->priv_id, psa_sig_md, hash, hash_len, sig, sig_size, + return ecdsa_sign_psa(pk->priv_id, md_alg, hash, hash_len, sig, sig_size, sig_len); } @@ -1002,22 +1011,8 @@ static int ecdsa_sign_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg, { ((void) f_rng); ((void) p_rng); -#if defined(MBEDTLS_ECDSA_DETERMINISTIC) - psa_algorithm_t psa_sig_md = - PSA_ALG_DETERMINISTIC_ECDSA(mbedtls_md_psa_alg_from_type(md_alg)); -#else - psa_algorithm_t psa_sig_md = - PSA_ALG_ECDSA(mbedtls_md_psa_alg_from_type(md_alg)); -#endif - if (pk->ec_family == 0) { - return MBEDTLS_ERR_PK_BAD_INPUT_DATA; - } - if (MBEDTLS_SVC_KEY_ID_GET_KEY_ID(pk->priv_id) == PSA_KEY_ID_NULL) { - return MBEDTLS_ERR_PK_BAD_INPUT_DATA; - } - - return ecdsa_sign_psa(pk->priv_id, psa_sig_md, hash, hash_len, sig, sig_size, + return ecdsa_sign_psa(pk->priv_id, md_alg, hash, hash_len, sig, sig_size, sig_len); } #else /* MBEDTLS_PK_USE_PSA_EC_DATA */ @@ -1068,7 +1063,7 @@ static int ecdsa_sign_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg, goto cleanup; } - ret = ecdsa_sign_psa(key_id, psa_sig_md, hash, hash_len, sig, sig_size, sig_len); + ret = ecdsa_sign_psa(key_id, md_alg, hash, hash_len, sig, sig_size, sig_len); cleanup: mbedtls_platform_zeroize(buf, sizeof(buf));