diff --git a/library/pk.c b/library/pk.c index 77012e1578..cccadb1f92 100644 --- a/library/pk.c +++ b/library/pk.c @@ -912,24 +912,34 @@ int mbedtls_pk_wrap_as_opaque(mbedtls_pk_context *pk, #else /* !MBEDTLS_ECP_LIGHT && !MBEDTLS_RSA_C */ #if defined(MBEDTLS_ECP_LIGHT) if (mbedtls_pk_get_type(pk) == MBEDTLS_PK_ECKEY) { - mbedtls_ecp_keypair *ec; unsigned char d[MBEDTLS_ECP_MAX_BYTES]; size_t d_len; psa_ecc_family_t curve_id; psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT; psa_key_type_t key_type; size_t bits; - int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; psa_status_t status; /* export the private key material in the format PSA wants */ - ec = mbedtls_pk_ec_rw(*pk); +#if defined(MBEDTLS_PK_USE_PSA_EC_DATA) + status = psa_export_key(pk->priv_id, d, sizeof(d), &d_len); + if (status != PSA_SUCCESS) { + return psa_pk_status_to_mbedtls(status); + } + + curve_id = pk->ec_family; + bits = pk->ec_bits; +#else /* MBEDTLS_PK_USE_PSA_EC_DATA */ + mbedtls_ecp_keypair *ec = mbedtls_pk_ec_rw(*pk); + int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; + d_len = PSA_BITS_TO_BYTES(ec->grp.nbits); if ((ret = mbedtls_ecp_write_key(ec, d, d_len)) != 0) { return ret; } curve_id = mbedtls_ecc_group_to_psa(ec->grp.id, &bits); +#endif /* MBEDTLS_PK_USE_PSA_EC_DATA */ key_type = PSA_KEY_TYPE_ECC_KEY_PAIR(curve_id); /* prepare the key attributes */ diff --git a/library/pk_wrap.c b/library/pk_wrap.c index 7f5e751a9b..f3a44aedfb 100644 --- a/library/pk_wrap.c +++ b/library/pk_wrap.c @@ -925,12 +925,9 @@ static int ecdsa_sign_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg, unsigned char *sig, size_t sig_size, size_t *sig_len, int (*f_rng)(void *, unsigned char *, size_t), void *p_rng) { - mbedtls_ecp_keypair *ctx = pk->pk_ctx; int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; - psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT; mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT; psa_status_t status; - unsigned char buf[MBEDTLS_PSA_MAX_EC_KEY_PAIR_LENGTH]; #if defined(MBEDTLS_ECDSA_DETERMINISTIC) psa_algorithm_t psa_sig_md = PSA_ALG_DETERMINISTIC_ECDSA(mbedtls_hash_info_psa_from_md(md_alg)); @@ -938,10 +935,17 @@ static int ecdsa_sign_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg, psa_algorithm_t psa_sig_md = PSA_ALG_ECDSA(mbedtls_hash_info_psa_from_md(md_alg)); #endif +#if defined(MBEDTLS_PK_USE_PSA_EC_DATA) + psa_ecc_family_t curve = pk->ec_family; +#else /* MBEDTLS_PK_USE_PSA_EC_DATA */ + mbedtls_ecp_keypair *ctx = pk->pk_ctx; + psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT; + unsigned char buf[MBEDTLS_PSA_MAX_EC_KEY_PAIR_LENGTH]; size_t curve_bits; psa_ecc_family_t curve = mbedtls_ecc_group_to_psa(ctx->grp.id, &curve_bits); size_t key_len = PSA_BITS_TO_BYTES(curve_bits); +#endif /* MBEDTLS_PK_USE_PSA_EC_DATA */ /* PSA has its own RNG */ ((void) f_rng); @@ -951,6 +955,12 @@ static int ecdsa_sign_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg, return MBEDTLS_ERR_PK_BAD_INPUT_DATA; } +#if defined(MBEDTLS_PK_USE_PSA_EC_DATA) + if (MBEDTLS_SVC_KEY_ID_GET_KEY_ID(pk->priv_id) == PSA_KEY_ID_NULL) { + return MBEDTLS_ERR_PK_BAD_INPUT_DATA; + } + key_id = pk->priv_id; +#else if (key_len > sizeof(buf)) { return MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; } @@ -970,6 +980,7 @@ static int ecdsa_sign_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg, ret = PSA_PK_TO_MBEDTLS_ERR(status); goto cleanup; } +#endif /* MBEDTLS_PK_USE_PSA_EC_DATA */ status = psa_sign_hash(key_id, psa_sig_md, hash, hash_len, sig, sig_size, sig_len); @@ -981,8 +992,11 @@ static int ecdsa_sign_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg, ret = pk_ecdsa_sig_asn1_from_psa(sig, sig_len, sig_size); cleanup: + +#if !defined(MBEDTLS_PK_USE_PSA_EC_DATA) mbedtls_platform_zeroize(buf, sizeof(buf)); status = psa_destroy_key(key_id); +#endif /* MBEDTLS_PK_USE_PSA_EC_DATA */ if (ret == 0 && status != PSA_SUCCESS) { ret = PSA_PK_TO_MBEDTLS_ERR(status); } @@ -1123,24 +1137,19 @@ cleanup: static int eckey_check_pair_psa(mbedtls_pk_context *pub, mbedtls_pk_context *prv) { psa_status_t status, destruction_status; - psa_key_attributes_t key_attr = PSA_KEY_ATTRIBUTES_INIT; int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; - /* We are using MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH for the size of this - * buffer because it will be used to hold the private key at first and - * then its public part (but not at the same time). */ uint8_t prv_key_buf[MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH]; size_t prv_key_len; - mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT; #if defined(MBEDTLS_PK_USE_PSA_EC_DATA) - const psa_ecc_family_t curve = prv->ec_family; - const size_t curve_bits = prv->ec_bits; + mbedtls_svc_key_id_t key_id = prv->priv_id; #else /* !MBEDTLS_PK_USE_PSA_EC_DATA */ + mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT; + psa_key_attributes_t key_attr = PSA_KEY_ATTRIBUTES_INIT; uint8_t pub_key_buf[MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH]; size_t pub_key_len; size_t curve_bits; const psa_ecc_family_t curve = mbedtls_ecc_group_to_psa(mbedtls_pk_ec_ro(*prv)->grp.id, &curve_bits); -#endif /* !MBEDTLS_PK_USE_PSA_EC_DATA */ const size_t curve_bytes = PSA_BITS_TO_BYTES(curve_bits); if (curve == 0) { @@ -1163,6 +1172,7 @@ static int eckey_check_pair_psa(mbedtls_pk_context *pub, mbedtls_pk_context *prv } mbedtls_platform_zeroize(prv_key_buf, sizeof(prv_key_buf)); +#endif /* !MBEDTLS_PK_USE_PSA_EC_DATA */ status = psa_export_public_key(key_id, prv_key_buf, sizeof(prv_key_buf), &prv_key_len);