#![allow(clippy::needless_range_loop)]
use crate::{Block, ParBlocks};
use cipher::{
consts::{U16, U24, U32, U8},
generic_array::GenericArray,
BlockCipher, BlockDecrypt, BlockEncrypt, NewBlockCipher,
};
use core::{arch::aarch64::*, convert::TryInto, mem, slice};
/// Size in bytes of an AES word (one column of the key schedule).
const WORD_SIZE: usize = 4;
/// Number of AES words in one 16-byte block.
const BLOCK_WORDS: usize = 16 / WORD_SIZE;
/// Key-schedule round constants: successive powers of `x` in GF(2^8), i.e. the
/// leading byte of FIPS-197's `Rcon[1..=10]`.
const ROUND_CONSTS: [u32; 10] = [
    0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, // x^0 ..= x^7
    0x1b, 0x36, // x^8, x^9 reduced modulo the AES polynomial
];
// Generates, for one AES key size, three public cipher types:
// - `$name`: combined encrypt + decrypt cipher holding both key schedules,
// - `$name_enc`: encrypt-only cipher,
// - `$name_dec`: decrypt-only cipher (keys derived from `$name_enc`).
//
// Arguments, in order: the three type names, the key size as a typenum
// (`U16`/`U24`/`U32`), the number of round keys (rounds + 1, i.e. 11/13/15)
// and a doc-string prefix such as "AES-128".
macro_rules! define_aes_impl {
    (
        $name:ident,
        $name_enc:ident,
        $name_dec:ident,
        $key_size:ty,
        $num_round_keys:tt,
        $doc:expr
    ) => {
        #[doc=$doc]
        #[doc = "block cipher"]
        #[derive(Clone)]
        pub struct $name {
            encrypt: $name_enc,
            decrypt: $name_dec,
        }

        impl NewBlockCipher for $name {
            type KeySize = $key_size;

            #[inline]
            fn new(key: &GenericArray<u8, $key_size>) -> Self {
                let encrypt = $name_enc::new(key);
                // Derive the decryption schedule from the already-expanded
                // encryption keys instead of expanding the key a second time.
                let decrypt = $name_dec::from(&encrypt);
                Self { encrypt, decrypt }
            }
        }

        impl BlockCipher for $name {
            type BlockSize = U16;
            type ParBlocks = U8;
        }

        impl BlockEncrypt for $name {
            #[inline]
            fn encrypt_block(&self, block: &mut Block) {
                self.encrypt.encrypt_block(block)
            }

            #[inline]
            fn encrypt_par_blocks(&self, blocks: &mut ParBlocks) {
                self.encrypt.encrypt_par_blocks(blocks)
            }
        }

        impl BlockDecrypt for $name {
            #[inline]
            fn decrypt_block(&self, block: &mut Block) {
                self.decrypt.decrypt_block(block)
            }

            #[inline]
            fn decrypt_par_blocks(&self, blocks: &mut ParBlocks) {
                self.decrypt.decrypt_par_blocks(blocks)
            }
        }

        #[doc=$doc]
        #[doc = "block cipher (encrypt-only)"]
        #[derive(Clone)]
        pub struct $name_enc {
            round_keys: [uint8x16_t; $num_round_keys],
        }

        impl NewBlockCipher for $name_enc {
            type KeySize = $key_size;

            #[inline]
            fn new(key: &GenericArray<u8, $key_size>) -> Self {
                Self {
                    round_keys: expand_key(key.as_ref()),
                }
            }
        }

        impl BlockCipher for $name_enc {
            type BlockSize = U16;
            type ParBlocks = U8;
        }

        impl BlockEncrypt for $name_enc {
            #[inline]
            fn encrypt_block(&self, block: &mut Block) {
                // SAFETY: `encrypt` requires NEON + the ARMv8 crypto extensions.
                // NOTE(review): soundness relies on this backend only being
                // compiled/selected for CPUs that have them; that gating is not
                // visible in this file — confirm at the selection site.
                unsafe { encrypt(&self.round_keys, block) }
            }

            #[inline]
            fn encrypt_par_blocks(&self, blocks: &mut ParBlocks) {
                // SAFETY: same CPU-feature requirement as `encrypt_block`.
                unsafe { encrypt8(&self.round_keys, blocks) }
            }
        }

        #[doc=$doc]
        #[doc = "block cipher (decrypt-only)"]
        #[derive(Clone)]
        pub struct $name_dec {
            round_keys: [uint8x16_t; $num_round_keys],
        }

        impl NewBlockCipher for $name_dec {
            type KeySize = $key_size;

            #[inline]
            fn new(key: &GenericArray<u8, $key_size>) -> Self {
                $name_enc::new(key).into()
            }
        }

        impl From<$name_enc> for $name_dec {
            #[inline]
            fn from(enc: $name_enc) -> $name_dec {
                Self::from(&enc)
            }
        }

        impl From<&$name_enc> for $name_dec {
            #[inline]
            fn from(enc: &$name_enc) -> $name_dec {
                // Decryption uses the equivalent inverse cipher, whose round
                // keys are derived from (a copy of) the encryption schedule.
                let mut round_keys = enc.round_keys;
                inverse_expanded_keys(&mut round_keys);
                Self { round_keys }
            }
        }

        impl BlockCipher for $name_dec {
            type BlockSize = U16;
            type ParBlocks = U8;
        }

        impl BlockDecrypt for $name_dec {
            #[inline]
            fn decrypt_block(&self, block: &mut Block) {
                // SAFETY: `decrypt` requires NEON + the ARMv8 crypto extensions.
                // NOTE(review): see the CPU-feature gating note on the
                // encrypt-only type's `encrypt_block`.
                unsafe { decrypt(&self.round_keys, block) }
            }

            #[inline]
            fn decrypt_par_blocks(&self, blocks: &mut ParBlocks) {
                // SAFETY: same CPU-feature requirement as `decrypt_block`.
                unsafe { decrypt8(&self.round_keys, blocks) }
            }
        }

        opaque_debug::implement!($name);
        opaque_debug::implement!($name_enc);
        opaque_debug::implement!($name_dec);
    };
}
// AES-128: 16-byte key, 10 rounds, 11 round keys.
define_aes_impl!(
    Aes128,
    Aes128Enc,
    Aes128Dec,
    U16,
    11,
    "AES-128"
);
// AES-192: 24-byte key, 12 rounds, 13 round keys.
define_aes_impl!(
    Aes192,
    Aes192Enc,
    Aes192Dec,
    U24,
    13,
    "AES-192"
);
// AES-256: 32-byte key, 14 rounds, 15 round keys.
define_aes_impl!(
    Aes256,
    Aes256Enc,
    Aes256Dec,
    U32,
    15,
    "AES-256"
);
/// Expand an `L`-byte AES key into `N` round keys (the FIPS-197 key schedule).
///
/// The first `L / 4` words are the key itself. Every later word is the word
/// `nk` positions back XORed with a transform of the previous word:
/// - RotWord + SubWord + round constant at the start of each `nk`-word period;
/// - an extra SubWord halfway through the period for 256-bit keys (`nk == 8`);
/// - the word unchanged otherwise.
///
/// NOTE(review): words are read with `from_ne_bytes`, and RotWord is encoded as
/// `rotate_right(8)` with the round constant XORed into the low byte. That
/// encoding is only correct on little-endian targets.
///
/// # Panics
/// Panics unless `(L, N)` is `(16, 11)`, `(24, 13)` or `(32, 15)`.
#[inline]
fn expand_key<const L: usize, const N: usize>(key: &[u8; L]) -> [uint8x16_t; N] {
    assert!((L == 16 && N == 11) || (L == 24 && N == 13) || (L == 32 && N == 15));

    // SAFETY: `uint8x16_t` is a plain vector of bytes; all-zero is a valid value.
    let mut round_keys: [uint8x16_t; N] = unsafe { mem::zeroed() };

    // SAFETY: the array spans exactly `N * BLOCK_WORDS` `u32`s, its alignment is
    // at least that of `u32`, and it is only accessed through `words` until
    // `words` is last used below.
    let words: &mut [u32] = unsafe {
        slice::from_raw_parts_mut(round_keys.as_mut_ptr().cast::<u32>(), N * BLOCK_WORDS)
    };

    // Seed the schedule with the cipher key itself.
    for (word, bytes) in words.iter_mut().zip(key.chunks_exact(WORD_SIZE)) {
        *word = u32::from_ne_bytes(bytes.try_into().unwrap());
    }

    // `Nk` in FIPS-197: the key length measured in words.
    let nk = L / WORD_SIZE;
    for i in nk..words.len() {
        let prev = words[i - 1];
        let temp = match i % nk {
            0 => sub_word(prev).rotate_right(8) ^ ROUND_CONSTS[i / nk - 1],
            4 if nk > 6 => sub_word(prev),
            _ => prev,
        };
        words[i] = words[i - nk] ^ temp;
    }

    round_keys
}
/// Convert an encryption key schedule, in place, into the decryption schedule
/// used by the equivalent inverse cipher (FIPS-197 §5.3.5).
///
/// InvMixColumns (`vaesimcq_u8`) is applied to every round key except the first
/// and the last, and then the order of all round keys is reversed.
///
/// # Panics
/// Panics unless `N` is 11, 13 or 15.
#[inline]
fn inverse_expanded_keys<const N: usize>(expanded_keys: &mut [uint8x16_t; N]) {
    assert!(N == 11 || N == 13 || N == 15);

    for round_key in &mut expanded_keys[1..N - 1] {
        // SAFETY: `vaesimcq_u8` needs the ARMv8 crypto extension.
        // NOTE(review): assumes this backend is only used on CPUs that have
        // it; the gating lives outside this file — confirm.
        *round_key = unsafe { vaesimcq_u8(*round_key) };
    }

    expanded_keys.reverse();
}
/// Encrypt one block in place with the round keys produced by `expand_key`.
///
/// The block is loaded and stored bytewise, so it need not be 16-byte aligned.
///
/// NOTE(review): recent toolchains call this aarch64 feature `aes`, and may
/// reject `crypto` — confirm against the supported toolchain.
///
/// # Safety
/// The CPU must support NEON and the ARMv8 crypto extensions (see the
/// `target_feature` attributes).
///
/// # Panics
/// Panics unless `N` is 11, 13 or 15.
#[target_feature(enable = "crypto")]
#[target_feature(enable = "neon")]
unsafe fn encrypt<const N: usize>(expanded_keys: &[uint8x16_t; N], block: &mut Block) {
    let rounds = N - 1;
    assert!(rounds == 10 || rounds == 12 || rounds == 14);

    // All but the last two round keys drive full rounds; the final pair is used
    // by the last round (no MixColumns) and the closing AddRoundKey.
    let (inner_keys, final_keys) = expanded_keys.split_at(rounds - 1);

    let mut state = vld1q_u8(block.as_ptr());
    for &round_key in inner_keys {
        // AESE = AddRoundKey + ShiftRows + SubBytes; AESMC = MixColumns.
        state = vaesmcq_u8(vaeseq_u8(state, round_key));
    }
    state = vaeseq_u8(state, final_keys[0]);
    state = veorq_u8(state, final_keys[1]);

    vst1q_u8(block.as_mut_ptr(), state);
}
/// Encrypt eight blocks in place with the same round keys.
///
/// Each round is applied to all eight blocks before moving on to the next
/// round key.
///
/// # Safety
/// The CPU must support NEON and the ARMv8 crypto extensions (see the
/// `target_feature` attributes).
///
/// # Panics
/// Panics unless `N` is 11, 13 or 15.
#[target_feature(enable = "crypto")]
#[target_feature(enable = "neon")]
unsafe fn encrypt8<const N: usize>(expanded_keys: &[uint8x16_t; N], blocks: &mut ParBlocks) {
    let rounds = N - 1;
    assert!(rounds == 10 || rounds == 12 || rounds == 14);
    let (inner_keys, final_keys) = expanded_keys.split_at(rounds - 1);

    let mut lanes = [
        vld1q_u8(blocks[0].as_ptr()),
        vld1q_u8(blocks[1].as_ptr()),
        vld1q_u8(blocks[2].as_ptr()),
        vld1q_u8(blocks[3].as_ptr()),
        vld1q_u8(blocks[4].as_ptr()),
        vld1q_u8(blocks[5].as_ptr()),
        vld1q_u8(blocks[6].as_ptr()),
        vld1q_u8(blocks[7].as_ptr()),
    ];

    for &round_key in inner_keys {
        for lane in lanes.iter_mut() {
            *lane = vaesmcq_u8(vaeseq_u8(*lane, round_key));
        }
    }

    // Last round (no MixColumns), final AddRoundKey, then write back.
    for (lane, block) in lanes.iter_mut().zip(blocks.iter_mut()) {
        *lane = veorq_u8(vaeseq_u8(*lane, final_keys[0]), final_keys[1]);
        vst1q_u8(block.as_mut_ptr(), *lane);
    }
}
/// Decrypt one block in place with round keys prepared by
/// `inverse_expanded_keys` (equivalent inverse cipher).
///
/// The block is loaded and stored bytewise, so it need not be 16-byte aligned.
///
/// # Safety
/// The CPU must support NEON and the ARMv8 crypto extensions (see the
/// `target_feature` attributes).
///
/// # Panics
/// Panics unless `N` is 11, 13 or 15.
#[target_feature(enable = "crypto")]
#[target_feature(enable = "neon")]
unsafe fn decrypt<const N: usize>(expanded_keys: &[uint8x16_t; N], block: &mut Block) {
    let rounds = N - 1;
    assert!(rounds == 10 || rounds == 12 || rounds == 14);
    let (inner_keys, final_keys) = expanded_keys.split_at(rounds - 1);

    let mut state = vld1q_u8(block.as_ptr());
    for &round_key in inner_keys {
        // AESD = AddRoundKey + InvShiftRows + InvSubBytes; AESIMC = InvMixColumns.
        state = vaesimcq_u8(vaesdq_u8(state, round_key));
    }
    state = vaesdq_u8(state, final_keys[0]);
    state = veorq_u8(state, final_keys[1]);

    vst1q_u8(block.as_mut_ptr(), state);
}
/// Decrypt eight blocks in place with the same inverse round keys.
///
/// Each round is applied to all eight blocks before moving on to the next
/// round key.
///
/// # Safety
/// The CPU must support NEON and the ARMv8 crypto extensions (see the
/// `target_feature` attributes).
///
/// # Panics
/// Panics unless `N` is 11, 13 or 15.
#[target_feature(enable = "crypto")]
#[target_feature(enable = "neon")]
unsafe fn decrypt8<const N: usize>(expanded_keys: &[uint8x16_t; N], blocks: &mut ParBlocks) {
    let rounds = N - 1;
    assert!(rounds == 10 || rounds == 12 || rounds == 14);
    let (inner_keys, final_keys) = expanded_keys.split_at(rounds - 1);

    let mut lanes = [
        vld1q_u8(blocks[0].as_ptr()),
        vld1q_u8(blocks[1].as_ptr()),
        vld1q_u8(blocks[2].as_ptr()),
        vld1q_u8(blocks[3].as_ptr()),
        vld1q_u8(blocks[4].as_ptr()),
        vld1q_u8(blocks[5].as_ptr()),
        vld1q_u8(blocks[6].as_ptr()),
        vld1q_u8(blocks[7].as_ptr()),
    ];

    for &round_key in inner_keys {
        for lane in lanes.iter_mut() {
            *lane = vaesimcq_u8(vaesdq_u8(*lane, round_key));
        }
    }

    // Last round (no InvMixColumns), final AddRoundKey, then write back.
    for (lane, block) in lanes.iter_mut().zip(blocks.iter_mut()) {
        *lane = veorq_u8(vaesdq_u8(*lane, final_keys[0]), final_keys[1]);
        vst1q_u8(block.as_mut_ptr(), *lane);
    }
}
/// Apply the AES S-box to each byte of a key-schedule word (FIPS-197 SubWord).
///
/// The word is broadcast into all four state columns, so the ShiftRows step
/// inside `vaeseq_u8` moves only identical bytes around. With an all-zero
/// round key, the AddRoundKey step is a no-op as well. That leaves only
/// SubBytes, and lane 0 holds the result.
#[inline(always)]
fn sub_word(input: u32) -> u32 {
    // SAFETY: the intrinsics only touch registers. `vaeseq_u8` needs the ARMv8
    // crypto extension — NOTE(review): assumes this backend is only used on
    // CPUs that have it; confirm at the backend selection site.
    unsafe {
        let zero_key = vdupq_n_u8(0);
        let columns = vreinterpretq_u8_u32(vdupq_n_u32(input));
        let substituted = vreinterpretq_u32_u8(vaeseq_u8(columns, zero_key));
        vgetq_lane_u32(substituted, 0)
    }
}
/// Store all 16 byte lanes of `src` to `dst`, in lane order.
///
/// This local definition shadows any `vst1q_u8` brought in by the
/// `core::arch::aarch64::*` glob import. NOTE(review): presumably written by
/// hand because the intrinsic was unavailable on the targeted toolchain —
/// confirm.
///
/// # Safety
/// `dst` must be valid for writes of 16 bytes. No alignment is required.
#[inline(always)]
unsafe fn vst1q_u8(dst: *mut u8, src: uint8x16_t) {
    let src_bytes = (&src as *const uint8x16_t).cast::<u8>();
    src_bytes.copy_to_nonoverlapping(dst, 16);
}
#[cfg(test)]
mod tests {
    use super::{
        decrypt, decrypt8, encrypt, encrypt8, expand_key, inverse_expanded_keys, vst1q_u8,
        ParBlocks,
    };
    use core::{arch::aarch64::*, convert::TryInto};
    use hex_literal::hex;

    // FIPS-197 Appendix A.1 key and its expanded schedule.
    const AES128_KEY: [u8; 16] = hex!("2b7e151628aed2a6abf7158809cf4f3c");
    const AES128_EXP_KEYS: [[u8; 16]; 11] = [
        AES128_KEY,
        hex!("a0fafe1788542cb123a339392a6c7605"),
        hex!("f2c295f27a96b9435935807a7359f67f"),
        hex!("3d80477d4716fe3e1e237e446d7a883b"),
        hex!("ef44a541a8525b7fb671253bdb0bad00"),
        hex!("d4d1c6f87c839d87caf2b8bc11f915bc"),
        hex!("6d88a37a110b3efddbf98641ca0093fd"),
        hex!("4e54f70e5f5fc9f384a64fb24ea6dc4f"),
        hex!("ead27321b58dbad2312bf5607f8d292f"),
        hex!("ac7766f319fadc2128d12941575c006e"),
        hex!("d014f9a8c9ee2589e13f0cc8b6630ca6"),
    ];

    // Decryption schedule (equivalent inverse cipher) for `AES128_KEY`.
    const AES128_EXP_INVKEYS: [[u8; 16]; 11] = [
        hex!("d014f9a8c9ee2589e13f0cc8b6630ca6"),
        hex!("0c7b5a631319eafeb0398890664cfbb4"),
        hex!("df7d925a1f62b09da320626ed6757324"),
        hex!("12c07647c01f22c7bc42d2f37555114a"),
        hex!("6efcd876d2df54807c5df034c917c3b9"),
        hex!("6ea30afcbc238cf6ae82a4b4b54a338d"),
        hex!("90884413d280860a12a128421bc89739"),
        hex!("7c1f13f74208c219c021ae480969bf7b"),
        hex!("cc7505eb3e17d1ee82296c51c9481133"),
        hex!("2b3708a7f262d405bc3ebdbf4b617d62"),
        AES128_KEY,
    ];

    // FIPS-197 Appendix A.2 key and its expanded schedule.
    const AES192_KEY: [u8; 24] = hex!("8e73b0f7da0e6452c810f32b809079e562f8ead2522c6b7b");
    const AES192_EXP_KEYS: [[u8; 16]; 13] = [
        hex!("8e73b0f7da0e6452c810f32b809079e5"),
        hex!("62f8ead2522c6b7bfe0c91f72402f5a5"),
        hex!("ec12068e6c827f6b0e7a95b95c56fec2"),
        hex!("4db7b4bd69b5411885a74796e92538fd"),
        hex!("e75fad44bb095386485af05721efb14f"),
        hex!("a448f6d94d6dce24aa326360113b30e6"),
        hex!("a25e7ed583b1cf9a27f939436a94f767"),
        hex!("c0a69407d19da4e1ec1786eb6fa64971"),
        hex!("485f703222cb8755e26d135233f0b7b3"),
        hex!("40beeb282f18a2596747d26b458c553e"),
        hex!("a7e1466c9411f1df821f750aad07d753"),
        hex!("ca4005388fcc5006282d166abc3ce7b5"),
        hex!("e98ba06f448c773c8ecc720401002202"),
    ];

    // FIPS-197 Appendix A.3 key and its expanded schedule.
    const AES256_KEY: [u8; 32] =
        hex!("603deb1015ca71be2b73aef0857d77811f352c073b6108d72d9810a30914dff4");
    const AES256_EXP_KEYS: [[u8; 16]; 15] = [
        hex!("603deb1015ca71be2b73aef0857d7781"),
        hex!("1f352c073b6108d72d9810a30914dff4"),
        hex!("9ba354118e6925afa51a8b5f2067fcde"),
        hex!("a8b09c1a93d194cdbe49846eb75d5b9a"),
        hex!("d59aecb85bf3c917fee94248de8ebe96"),
        hex!("b5a9328a2678a647983122292f6c79b3"),
        hex!("812c81addadf48ba24360af2fab8b464"),
        hex!("98c5bfc9bebd198e268c3ba709e04214"),
        hex!("68007bacb2df331696e939e46c518d80"),
        hex!("c814e20476a9fb8a5025c02d59c58239"),
        hex!("de1369676ccc5a71fa2563959674ee15"),
        hex!("5886ca5d2e2f31d77e0af1fa27cf73c3"),
        hex!("749c47ab18501ddae2757e4f7401905a"),
        hex!("cafaaae3e4d59b349adf6acebd10190d"),
        hex!("fe4890d1e6188d0b046df344706c631e"),
    ];

    // FIPS-197 Appendix B plaintext/ciphertext pair for `AES128_KEY`.
    const INPUT: [u8; 16] = hex!("3243f6a8885a308d313198a2e0370734");
    const EXPECTED: [u8; 16] = hex!("3925841d02dc09fbdc118597196a0b32");

    // FIPS-197 Appendix C plaintext, shared by the per-key-size vectors below.
    const FIPS197_PLAINTEXT: [u8; 16] = hex!("00112233445566778899aabbccddeeff");

    fn load_expanded_keys<const N: usize>(input: [[u8; 16]; N]) -> [uint8x16_t; N] {
        let mut output = [unsafe { vdupq_n_u8(0) }; N];
        for (src, dst) in input.iter().zip(output.iter_mut()) {
            *dst = unsafe { vld1q_u8(src.as_ptr()) }
        }
        output
    }

    fn store_expanded_keys<const N: usize>(input: [uint8x16_t; N]) -> [[u8; 16]; N] {
        let mut output = [[0u8; 16]; N];
        for (src, dst) in input.iter().zip(output.iter_mut()) {
            unsafe { vst1q_u8(dst.as_mut_ptr(), *src) }
        }
        output
    }

    /// Expand `key`, then check single-block and 8-block encryption against
    /// `ciphertext` and decryption back to `FIPS197_PLAINTEXT`.
    fn check_fips197_vector<const L: usize, const N: usize>(
        key: &[u8; L],
        ciphertext: &[u8; 16],
    ) {
        let enc_keys: [uint8x16_t; N] = expand_key(key);
        let mut dec_keys = enc_keys;
        inverse_expanded_keys(&mut dec_keys);

        let mut block = FIPS197_PLAINTEXT;
        unsafe { encrypt(&enc_keys, (&mut block[..]).try_into().unwrap()) };
        assert_eq!(&block, ciphertext);
        unsafe { decrypt(&dec_keys, (&mut block[..]).try_into().unwrap()) };
        assert_eq!(block, FIPS197_PLAINTEXT);

        let mut blocks = ParBlocks::default();
        for b in &mut blocks {
            b.copy_from_slice(&FIPS197_PLAINTEXT);
        }
        unsafe { encrypt8(&enc_keys, &mut blocks) };
        for b in &blocks {
            assert_eq!(b.as_slice(), ciphertext);
        }
        unsafe { decrypt8(&dec_keys, &mut blocks) };
        for b in &blocks {
            assert_eq!(b.as_slice(), &FIPS197_PLAINTEXT);
        }
    }

    #[test]
    fn aes128_key_expansion() {
        let ek = expand_key(&AES128_KEY);
        assert_eq!(store_expanded_keys(ek), AES128_EXP_KEYS);
    }

    #[test]
    fn aes128_key_expansion_inv() {
        let mut ek = load_expanded_keys(AES128_EXP_KEYS);
        inverse_expanded_keys(&mut ek);
        assert_eq!(store_expanded_keys(ek), AES128_EXP_INVKEYS);
    }

    #[test]
    fn aes192_key_expansion() {
        let ek = expand_key(&AES192_KEY);
        assert_eq!(store_expanded_keys(ek), AES192_EXP_KEYS);
    }

    #[test]
    fn aes256_key_expansion() {
        let ek = expand_key(&AES256_KEY);
        assert_eq!(store_expanded_keys(ek), AES256_EXP_KEYS);
    }

    #[test]
    fn aes128_encrypt() {
        // Offset by 3 bytes so the block is deliberately unaligned.
        let mut block = [0u8; 19];
        block[3..].copy_from_slice(&INPUT);
        unsafe {
            encrypt(
                &load_expanded_keys(AES128_EXP_KEYS),
                (&mut block[3..]).try_into().unwrap(),
            )
        };
        assert_eq!(&block[3..], &EXPECTED);
    }

    #[test]
    fn aes128_encrypt8() {
        let mut blocks = ParBlocks::default();
        for block in &mut blocks {
            block.copy_from_slice(&INPUT);
        }
        unsafe { encrypt8(&load_expanded_keys(AES128_EXP_KEYS), &mut blocks) };
        for block in &blocks {
            assert_eq!(block.as_slice(), &EXPECTED);
        }
    }

    #[test]
    fn aes128_decrypt() {
        // Offset by 3 bytes so the block is deliberately unaligned.
        let mut block = [0u8; 19];
        block[3..].copy_from_slice(&EXPECTED);
        unsafe {
            decrypt(
                &load_expanded_keys(AES128_EXP_INVKEYS),
                (&mut block[3..]).try_into().unwrap(),
            )
        };
        assert_eq!(&block[3..], &INPUT);
    }

    #[test]
    fn aes128_decrypt8() {
        let mut blocks = ParBlocks::default();
        for block in &mut blocks {
            block.copy_from_slice(&EXPECTED);
        }
        unsafe { decrypt8(&load_expanded_keys(AES128_EXP_INVKEYS), &mut blocks) };
        for block in &blocks {
            assert_eq!(block.as_slice(), &INPUT);
        }
    }

    // FIPS-197 Appendix C.1 (AES-128, 10 rounds).
    #[test]
    fn aes128_fips197_roundtrip() {
        check_fips197_vector::<16, 11>(
            &hex!("000102030405060708090a0b0c0d0e0f"),
            &hex!("69c4e0d86a7b0430d8cdb78070b4c55a"),
        );
    }

    // FIPS-197 Appendix C.2 (AES-192, 12 rounds).
    #[test]
    fn aes192_fips197_roundtrip() {
        check_fips197_vector::<24, 13>(
            &hex!("000102030405060708090a0b0c0d0e0f1011121314151617"),
            &hex!("dda97ca4864cdfe06eaf70a0ec0d7191"),
        );
    }

    // FIPS-197 Appendix C.3 (AES-256, 14 rounds).
    #[test]
    fn aes256_fips197_roundtrip() {
        check_fips197_vector::<32, 15>(
            &hex!("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"),
            &hex!("8ea2b7ca516745bfeafc49904b496089"),
        );
    }
}