diff --git a/library/aes.c b/library/aes.c index 566e74715f..f0ade21490 100644 --- a/library/aes.c +++ b/library/aes.c @@ -39,6 +39,9 @@ #if defined(MBEDTLS_AESNI_C) #include "aesni.h" #endif +#if defined(MBEDTLS_AESCE_C) +#include "aesce.h" +#endif #include "mbedtls/platform.h" @@ -544,6 +547,12 @@ int mbedtls_aes_setkey_enc(mbedtls_aes_context *ctx, const unsigned char *key, } #endif +#if defined(MBEDTLS_AESCE_C) && defined(MBEDTLS_HAVE_ARM64) + if (mbedtls_aesce_has_support()) { + return mbedtls_aesce_setkey_enc((unsigned char *) RK, key, keybits); + } +#endif + for (i = 0; i < (keybits >> 5); i++) { RK[i] = MBEDTLS_GET_UINT32_LE(key, i << 2); } diff --git a/library/aesce.c b/library/aesce.c index f33d593427..4b0f9d7449 100644 --- a/library/aesce.c +++ b/library/aesce.c @@ -65,6 +65,114 @@ int mbedtls_aesce_has_support(void) #endif } + +static uint8_t const rcon[] = { 0x01, 0x02, 0x04, 0x08, 0x10, + 0x20, 0x40, 0x80, 0x1b, 0x36 }; + +static inline uint32_t ror32_8(uint32_t word) +{ + return (word << (32 - 8)) | (word >> 8); +} + +static inline uint32_t aes_sub(uint32_t in) +{ + uint32x4_t _in = vdupq_n_u32(in); + uint32x4_t v; + uint8x16_t zero = vdupq_n_u8(0); + v = vreinterpretq_u32_u8(vaeseq_u8(zero, vreinterpretq_u8_u32(_in))); + return vgetq_lane_u32(v, 0); +} + +/* + * Key expansion, 128-bit case + */ +static void aesce_setkey_enc_128(unsigned char *rk, + const unsigned char *key) +{ + uint32_t *rki; + uint32_t *rko; + uint32_t *rk_u32 = (uint32_t *) rk; + memcpy(rk, key, (128 / 8)); + + for (size_t i = 0; i < sizeof(rcon); i++) { + rki = rk_u32 + i * (128 / 32); + rko = rki + (128 / 32); + rko[0] = ror32_8(aes_sub(rki[(128 / 32) - 1])) ^ rcon[i] ^ rki[0]; + rko[1] = rko[0] ^ rki[1]; + rko[2] = rko[1] ^ rki[2]; + rko[3] = rko[2] ^ rki[3]; + } +} + +/* + * Key expansion, 192-bit case + */ +static void aesce_setkey_enc_192(unsigned char *rk, + const unsigned char *key) +{ + uint32_t *rki; + uint32_t *rko; + uint32_t *rk_u32 = (uint32_t *) rk; + memcpy(rk, key, (192 / 8)); + + for (size_t i = 0; i < 8; i++) { + rki = rk_u32 + i * (192 / 32); + rko = rki + (192 / 32); + rko[0] = ror32_8(aes_sub(rki[(192 / 32) - 1])) ^ rcon[i] ^ rki[0]; + rko[1] = rko[0] ^ rki[1]; + rko[2] = rko[1] ^ rki[2]; + rko[3] = rko[2] ^ rki[3]; + if (i < 7) { + rko[4] = rko[3] ^ rki[4]; + rko[5] = rko[4] ^ rki[5]; + } + } +} + +/* + * Key expansion, 256-bit case + */ +static void aesce_setkey_enc_256(unsigned char *rk, + const unsigned char *key) +{ + uint32_t *rki; + uint32_t *rko; + uint32_t *rk_u32 = (uint32_t *) rk; + memcpy(rk, key, (256 / 8)); + + for (size_t i = 0; i < 7; i++) { + rki = rk_u32 + i * (256 / 32); + rko = rki + (256 / 32); + rko[0] = ror32_8(aes_sub(rki[(256 / 32) - 1])) ^ rcon[i] ^ rki[0]; + rko[1] = rko[0] ^ rki[1]; + rko[2] = rko[1] ^ rki[2]; + rko[3] = rko[2] ^ rki[3]; + if (i < 6) { + rko[4] = aes_sub(rko[3]) ^ rki[4]; + rko[5] = rko[4] ^ rki[5]; + rko[6] = rko[5] ^ rki[6]; + rko[7] = rko[6] ^ rki[7]; + } + } +} + +/* + * Key expansion, wrapper + */ +int mbedtls_aesce_setkey_enc(unsigned char *rk, + const unsigned char *key, + size_t bits) +{ + switch (bits) { + case 128: aesce_setkey_enc_128(rk, key); break; + case 192: aesce_setkey_enc_192(rk, key); break; + case 256: aesce_setkey_enc_256(rk, key); break; + default: return MBEDTLS_ERR_AES_INVALID_KEY_LENGTH; + } + + return 0; +} + #endif /* MBEDTLS_HAVE_ARM64 */ #endif /* MBEDTLS_AESCE_C */ diff --git a/library/aesce.h b/library/aesce.h index 2d5dde985f..7fc0cfa0eb 100644 --- a/library/aesce.h +++ b/library/aesce.h @@ -49,6 +49,20 @@ extern "C" { */ int mbedtls_aesce_has_support(void); + +/** + * \brief Internal key expansion for encryption + * + * \param rk Destination buffer where the round keys are written + * \param key Encryption key + * \param bits Key size in bits (must be 128, 192 or 256) + * + * \return 0 if successful, or MBEDTLS_ERR_AES_INVALID_KEY_LENGTH + */ +int mbedtls_aesce_setkey_enc(unsigned char *rk, + const unsigned char *key, + size_t bits); + #ifdef __cplusplus } #endif