use crate::{belt_block_raw, from_u32, g5, g13, g21, key_idx, to_u32};
use cipher::{
AlgorithmName, Block, BlockCipherDecBackend, BlockCipherDecClosure, BlockCipherDecrypt,
BlockCipherEncBackend, BlockCipherEncClosure, BlockCipherEncrypt, BlockSizeUser, InOut, Key,
KeyInit, KeySizeUser, ParBlocksSizeUser,
consts::{U1, U16, U32},
};
use core::{fmt, mem::swap, num::Wrapping};
#[cfg(feature = "zeroize")]
use cipher::zeroize::{Zeroize, ZeroizeOnDrop};
/// Block cipher instance that owns its key as eight 32-bit words.
///
/// Built through [`KeyInit::new`] from a 32-byte key and used through the
/// `cipher` crate's block encrypt/decrypt traits on 16-byte blocks.
#[derive(Clone)]
pub struct BeltBlock {
/// Key words produced by `to_u32` from the caller-supplied key bytes.
key: [u32; 8],
}
impl KeySizeUser for BeltBlock {
/// The cipher is keyed with 32 bytes (256 bits).
type KeySize = U32;
}
impl KeyInit for BeltBlock {
    /// Creates a cipher instance from a 32-byte key.
    ///
    /// The key bytes are converted once into eight `u32` words via `to_u32`
    /// and stored, so the block functions work on words directly.
    fn new(key: &Key<Self>) -> Self {
        let key_words: [u32; 8] = to_u32(key);
        Self { key: key_words }
    }
}
impl BlockSizeUser for BeltBlock {
/// The cipher operates on 16-byte (128-bit) blocks.
type BlockSize = U16;
}
impl ParBlocksSizeUser for BeltBlock {
/// The backend processes a single block per call; no multi-block batching.
type ParBlocksSize = U1;
}
impl BlockCipherEncrypt for BeltBlock {
    /// Runs the supplied encryption closure with this cipher as the backend.
    ///
    /// `BeltBlock` is its own encryption backend, so the closure is handed
    /// `self` directly.
    #[inline]
    fn encrypt_with_backend(&self, f: impl BlockCipherEncClosure<BlockSize = Self::BlockSize>) {
        f.call(self);
    }
}
impl BlockCipherEncBackend for BeltBlock {
    /// Encrypts one 16-byte block.
    ///
    /// Reads the input side of `block`, converts it to `u32` words, passes
    /// them with the stored key to `belt_block_raw`, and writes the resulting
    /// words back as bytes into the output side of `block`.
    #[inline]
    fn encrypt_block(&self, mut block: InOut<'_, '_, Block<Self>>) {
        let input_words = to_u32(block.get_in());
        let output_words = belt_block_raw(input_words, &self.key);
        *block.get_out() = from_u32(&output_words).into();
    }
}
impl BlockCipherDecrypt for BeltBlock {
    /// Runs the supplied decryption closure with this cipher as the backend.
    ///
    /// `BeltBlock` is its own decryption backend, so the closure is handed
    /// `self` directly.
    #[inline]
    fn decrypt_with_backend(&self, f: impl BlockCipherDecClosure<BlockSize = Self::BlockSize>) {
        f.call(self);
    }
}
impl BlockCipherDecBackend for BeltBlock {
    /// Decrypts one 16-byte block.
    ///
    /// The input is split into four `u32` words `a`, `b`, `c`, `d`, which are
    /// then transformed by eight rounds with the round index running from 8
    /// down to 1. All additions and subtractions wrap modulo 2^32
    /// (`Wrapping<u32>`). The result is written to the output side of `block`
    /// in the word order `c, a, d, b`.
    ///
    /// NOTE(review): the round structure appears to follow the block
    /// decryption procedure of STB 34.101.31 — confirm against the standard
    /// before changing any statement order.
    #[inline]
    fn decrypt_block(&self, mut block: InOut<'_, '_, Block<Self>>) {
        let round_keys = &self.key;

        let [w0, w1, w2, w3]: [u32; 4] = to_u32(block.get_in());
        let (mut a, mut b, mut c, mut d) = (Wrapping(w0), Wrapping(w1), Wrapping(w2), Wrapping(w3));

        // Rounds are applied in descending order: 8, 7, ..., 1.
        for round in (1..=8).rev() {
            b ^= g5(a + key_idx(round_keys, round, 0));
            c ^= g21(d + key_idx(round_keys, round, 1));
            a -= g13(b + key_idx(round_keys, round, 2));

            // The round number is folded into the value mixed into `b` and `c`.
            let mix = g21(b + c + key_idx(round_keys, round, 3)) ^ Wrapping(round as u32);
            b += mix;
            c -= mix;

            d += g13(c + key_idx(round_keys, round, 4));
            b ^= g21(a + key_idx(round_keys, round, 5));
            c ^= g5(d + key_idx(round_keys, round, 6));

            // Taken together, the three swaps rotate the state so that
            // (a, b, c, d) becomes (c, a, d, b).
            swap(&mut a, &mut b);
            swap(&mut c, &mut d);
            swap(&mut a, &mut d);
        }

        let output_words = [c.0, a.0, d.0, b.0];
        *block.get_out() = from_u32(&output_words).into();
    }
}
impl AlgorithmName for BeltBlock {
    /// Writes the algorithm's name, `"BeltBlock"`, to the formatter.
    fn write_alg_name(f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const NAME: &str = "BeltBlock";
        f.write_str(NAME)
    }
}
impl Drop for BeltBlock {
    /// Zeroizes the stored key words on drop when the `zeroize` feature is
    /// enabled; without the feature this is a no-op.
    fn drop(&mut self) {
        #[cfg(feature = "zeroize")]
        Zeroize::zeroize(&mut self.key);
    }
}
// Marker impl: with the `zeroize` feature enabled, the `Drop` impl for
// `BeltBlock` zeroizes the key, which is what this trait advertises.
#[cfg(feature = "zeroize")]
impl ZeroizeOnDrop for BeltBlock {}