diff --git a/library/aesce.c b/library/aesce.c index 29a4ce0183..b4ebdadc0a 100644 --- a/library/aesce.c +++ b/library/aesce.c @@ -160,9 +160,6 @@ void mbedtls_aesce_inverse_key(unsigned char *invkey, } -static uint8_t const rcon[] = { 0x01, 0x02, 0x04, 0x08, 0x10, - 0x20, 0x40, 0x80, 0x1b, 0x36 }; - static inline uint32_t aes_rot_word(uint32_t word) { return (word << (32 - 8)) | (word >> 8); @@ -180,75 +177,47 @@ static inline uint32_t aes_sub_word(uint32_t in) } /* - * Key expansion, 128-bit case + * Key expansion function */ -static void aesce_setkey_enc_128(unsigned char *rk, - const unsigned char *key) +static void aesce_setkey_enc(unsigned char *rk, + const unsigned char *key, + const size_t key_bit_length) { uint32_t *rki; uint32_t *rko; uint32_t *rk_u32 = (uint32_t *) rk; + const uint32_t key_len_in_words = key_bit_length / 32; + const uint32_t key_len_in_bytes = key_bit_length / 8; + static uint8_t const rcon[] = { 0x01, 0x02, 0x04, 0x08, 0x10, + 0x20, 0x40, 0x80, 0x1b, 0x36 }; + const uint32_t rounds = + key_bit_length == 128 ? sizeof(rcon) : key_bit_length == 192 ? 8 : 7; - memcpy(rk, key, (128 / 8)); + memcpy(rk, key, key_len_in_bytes); - for (size_t i = 0; i < sizeof(rcon); i++) { - rki = rk_u32 + i * (128 / 32); - rko = rki + (128 / 32); - rko[0] = aes_rot_word(aes_sub_word(rki[(128 / 32) - 1])) ^ rcon[i] ^ rki[0]; + for (size_t i = 0; i < rounds; i++) { + rki = rk_u32 + i * key_len_in_words; + rko = rki + key_len_in_words; + rko[0] = aes_rot_word(aes_sub_word(rki[key_len_in_words - 1])); + rko[0] ^= rcon[i] ^ rki[0]; rko[1] = rko[0] ^ rki[1]; rko[2] = rko[1] ^ rki[2]; rko[3] = rko[2] ^ rki[3]; - } -} - -/* - * Key expansion, 192-bit case - */ -static void aesce_setkey_enc_192(unsigned char *rk, - const unsigned char *key) -{ - uint32_t *rki; - uint32_t *rko; - uint32_t *rk_u32 = (uint32_t *) rk; - memcpy(rk, key, (192 / 8)); - - for (size_t i = 0; i < 8; i++) { - rki = rk_u32 + i * (192 / 32); - rko = rki + (192 / 32); - rko[0] = aes_rot_word(aes_sub_word(rki[(192 / 32) - 1])) ^ rcon[i] ^ rki[0]; - rko[1] = rko[0] ^ rki[1]; - rko[2] = rko[1] ^ rki[2]; - rko[3] = rko[2] ^ rki[3]; - if (i < 7) { - rko[4] = rko[3] ^ rki[4]; - rko[5] = rko[4] ^ rki[5]; - } - } -} - -/* - * Key expansion, 256-bit case - */ -static void aesce_setkey_enc_256(unsigned char *rk, - const unsigned char *key) -{ - uint32_t *rki; - uint32_t *rko; - uint32_t *rk_u32 = (uint32_t *) rk; - memcpy(rk, key, (256 / 8)); - - for (size_t i = 0; i < 7; i++) { - rki = rk_u32 + i * (256 / 32); - rko = rki + (256 / 32); - rko[0] = aes_rot_word(aes_sub_word(rki[(256 / 32) - 1])) ^ rcon[i] ^ rki[0]; - rko[1] = rko[0] ^ rki[1]; - rko[2] = rko[1] ^ rki[2]; - rko[3] = rko[2] ^ rki[3]; - if (i < 6) { - rko[4] = aes_sub_word(rko[3]) ^ rki[4]; - rko[5] = rko[4] ^ rki[5]; - rko[6] = rko[5] ^ rki[6]; - rko[7] = rko[6] ^ rki[7]; + switch (key_bit_length) { + case 192: + if (i < 7) { + rko[4] = rko[3] ^ rki[4]; + rko[5] = rko[4] ^ rki[5]; + } + break; + case 256: + if (i < 6) { + rko[4] = aes_sub_word(rko[3]) ^ rki[4]; + rko[5] = rko[4] ^ rki[5]; + rko[6] = rko[5] ^ rki[6]; + rko[7] = rko[6] ^ rki[7]; + } + break; } } } @@ -261,9 +230,10 @@ int mbedtls_aesce_setkey_enc(unsigned char *rk, size_t bits) { switch (bits) { - case 128: aesce_setkey_enc_128(rk, key); break; - case 192: aesce_setkey_enc_192(rk, key); break; - case 256: aesce_setkey_enc_256(rk, key); break; + case 128: + case 192: + case 256: + aesce_setkey_enc(rk, key, bits); break; default: return MBEDTLS_ERR_AES_INVALID_KEY_LENGTH; }