use cipher::{
block::{Block, BlockCipher, NewBlockCipher, ParBlocks},
generic_array::{
typenum::{Unsigned, U16},
ArrayLength, GenericArray,
},
stream::{
FromBlockCipher, LoopError, OverflowError, SeekNum, SyncStreamCipher, SyncStreamCipherSeek,
},
};
use core::{convert::TryInto, fmt, mem};
/// XORs `key` into `buf` in place, byte by byte.
///
/// Both slices are expected to have the same length (checked only in debug
/// builds); in release builds the shorter slice bounds the operation.
#[inline(always)]
fn xor(buf: &mut [u8], key: &[u8]) {
    debug_assert_eq!(buf.len(), key.len());
    buf.iter_mut()
        .zip(key.iter())
        .for_each(|(dst, &k)| *dst ^= k);
}
/// 16-byte nonce accepted by [`Ctr128`]; it is parsed as two big-endian
/// 64-bit words that together form the initial counter block.
type Nonce = GenericArray<u8, U16>;

/// CTR mode of operation over a block cipher with a 128-bit (16-byte) block.
///
/// The keystream block for block index `i` is the encryption of
/// `nonce_hi || (nonce_lo + i)`, both words encoded big-endian, where the
/// addition wraps modulo 2^64 on the low word only (it does not carry into
/// `nonce_hi`).
pub struct Ctr128<C>
where
    C: BlockCipher<BlockSize = U16>,
    C::ParBlocks: ArrayLength<GenericArray<u8, U16>>,
{
    /// Underlying block cipher used to encrypt counter blocks.
    cipher: C,
    /// Cached keystream block for block index `counter`; it is only read
    /// while `pos != 0`, i.e. while a block has been partially consumed.
    block: Block<C>,
    /// Initial counter block as `[high, low]` 64-bit words (parsed big-endian
    /// from the nonce bytes).
    nonce: [u64; 2],
    /// Index of the current keystream block, relative to `nonce`.
    counter: u64,
    /// Number of bytes of `block` already consumed; expected to stay below
    /// the 16-byte block size (0 means no partially used block).
    pos: u8,
}
impl<C> FromBlockCipher for Ctr128<C>
where
    C: BlockCipher<BlockSize = U16> + NewBlockCipher,
    C::ParBlocks: ArrayLength<GenericArray<u8, U16>>,
{
    type BlockCipher = C;
    type NonceSize = C::BlockSize;

    /// Wraps `cipher` in CTR mode starting at keystream position zero.
    ///
    /// The 16-byte `nonce` is split into a high and a low 64-bit word, each
    /// decoded as big-endian; together they form the initial counter block.
    fn from_block_cipher(cipher: C, nonce: &Nonce) -> Self {
        let (high, low) = nonce.split_at(8);
        // Each half is exactly 8 bytes, so the conversion cannot fail.
        let be_word = |bytes: &[u8]| u64::from_be_bytes(bytes.try_into().unwrap());
        Self {
            cipher,
            block: Default::default(),
            nonce: [be_word(high), be_word(low)],
            counter: 0,
            pos: 0,
        }
    }
}
impl<C> SyncStreamCipher for Ctr128<C>
where
    C: BlockCipher<BlockSize = U16>,
    C::ParBlocks: ArrayLength<GenericArray<u8, U16>>,
{
    /// XORs the keystream into `data` in place and advances the position.
    ///
    /// Processing happens in three phases:
    /// 1. finish a partially consumed cached keystream block (`pos != 0`);
    /// 2. process whole blocks — in batches of `C::ParBlocks` blocks when the
    ///    cipher encrypts more than one block at a time, then one at a time;
    /// 3. for a trailing partial block, generate its keystream block, cache it
    ///    in `self.block`, and record how many bytes were used in `self.pos`.
    ///
    /// # Errors
    ///
    /// Returns `LoopError`, before `data` or any state is modified, if
    /// processing `data` would overflow the 64-bit block counter.
    fn try_apply_keystream(&mut self, mut data: &mut [u8]) -> Result<(), LoopError> {
        // Fail up front so a failing call leaves both `data` and `self` untouched.
        self.check_data_len(data)?;
        let bs = C::BlockSize::USIZE;
        let pos = self.pos as usize;
        debug_assert!(bs > pos);
        // Work on a local copy; `self.counter` is written back once at the end.
        let mut counter = self.counter;
        if pos != 0 {
            // A previous call left unused keystream bytes in `self.block`.
            if data.len() < bs - pos {
                // The whole input fits in the cached block: consume it and stay
                // on the same counter value.
                let n = pos + data.len();
                xor(data, &self.block[pos..n]);
                self.pos = n as u8;
                return Ok(());
            } else {
                // Use up the rest of the cached block; the next keystream block
                // is the one for `counter + 1`.
                let (l, r) = data.split_at_mut(bs - pos);
                data = r;
                xor(l, &self.block[pos..]);
                counter += 1;
            }
        }
        let pb = C::ParBlocks::USIZE;
        // Bulk path: when the cipher processes several blocks per call, handle
        // `bs * pb`-byte chunks with one `generate_par_blocks` call each.
        let data = if pb != 1 {
            let mut chunks = data.chunks_exact_mut(bs * pb);
            for chunk in &mut chunks {
                let blocks = self.generate_par_blocks(counter);
                counter += pb as u64;
                xor(chunk, to_slice::<C>(&blocks));
            }
            chunks.into_remainder()
        } else {
            data
        };
        // Remaining whole blocks, one keystream block at a time.
        let mut chunks = data.chunks_exact_mut(bs);
        for chunk in &mut chunks {
            xor(chunk, &self.generate_block(counter));
            counter += 1;
        }
        // Persist the new position. A non-empty remainder is a partial block:
        // its keystream block is cached so the next call can resume at `pos`.
        let rem = chunks.into_remainder();
        self.pos = rem.len() as u8;
        self.counter = counter;
        if !rem.is_empty() {
            self.block = self.generate_block(counter);
            xor(rem, &self.block[..rem.len()]);
        }
        Ok(())
    }
}
impl<C> SyncStreamCipherSeek for Ctr128<C>
where
    C: BlockCipher<BlockSize = U16>,
    C::ParBlocks: ArrayLength<GenericArray<u8, U16>>,
{
    /// Converts the (block counter, byte offset) pair into a keystream position
    /// using `SeekNum::from_block_byte` with the cipher's block size.
    fn try_current_pos<T: SeekNum>(&self) -> Result<T, OverflowError> {
        let block_size = C::BlockSize::U8;
        T::from_block_byte(self.counter, self.pos, block_size)
    }

    /// Moves the keystream to position `pos`.
    ///
    /// The position is split into a block index and an offset within that
    /// block. If the offset is non-zero, the keystream block for the new index
    /// is generated and cached so that reading can resume mid-block.
    fn try_seek<T: SeekNum>(&mut self, pos: T) -> Result<(), LoopError> {
        let (block_idx, byte_off) = pos.to_block_byte(C::BlockSize::U8)?;
        self.counter = block_idx;
        self.pos = byte_off;
        if byte_off != 0 {
            self.block = self.generate_block(block_idx);
        }
        Ok(())
    }
}
impl<C> Ctr128<C>
where
    C: BlockCipher<BlockSize = U16>,
    C::ParBlocks: ArrayLength<GenericArray<u8, U16>>,
{
    /// Encrypts `C::ParBlocks` consecutive counter blocks in one
    /// `encrypt_blocks` call. The first block uses index `counter`, the next
    /// uses `counter + 1`, and so on.
    ///
    /// Each counter block is the stored nonce with the block index added to
    /// its low 64-bit word. The addition wraps modulo 2^64 and does not carry
    /// into the high word. Both words are serialized big-endian.
    #[inline(always)]
    fn generate_par_blocks(&self, counter: u64) -> ParBlocks<C> {
        let mut block = self.nonce;
        block[1] = block[1].wrapping_add(counter);
        // Every element is overwritten in the loop below, so a safe default
        // initialization is enough; no `unsafe { mem::zeroed() }` is needed.
        let mut blocks: ParBlocks<C> = Default::default();
        for b in blocks.iter_mut() {
            let block_be = conv_be(block);
            // SAFETY: `block_be` is a `[u64; 2]` (16 bytes). `*b` is a
            // `GenericArray<u8, U16>`, assumed laid out as `[u8; 16]`.
            // `transmute_copy` reads exactly that many bytes and tolerates the
            // alignment difference, and every byte pattern is a valid `u8`
            // array. `conv_be` has already put both words in big-endian order.
            *b = unsafe { mem::transmute_copy(&block_be) };
            block[1] = block[1].wrapping_add(1);
        }
        self.cipher.encrypt_blocks(&mut blocks);
        blocks
    }

    /// Encrypts the counter block for block index `counter` and returns the
    /// resulting keystream block. The counter block is built as in
    /// `generate_par_blocks`.
    #[inline(always)]
    fn generate_block(&self, counter: u64) -> Block<C> {
        let mut block = self.nonce;
        block[1] = block[1].wrapping_add(counter);
        // SAFETY: `transmute` checks at compile time that `[u64; 2]` and
        // `Block<C>` (a 16-byte `GenericArray<u8, U16>`) have equal size, and
        // every byte pattern is a valid `u8` array. `conv_be` makes the
        // in-memory bytes the big-endian encoding of each word.
        let mut block: Block<C> = unsafe { mem::transmute(conv_be(block)) };
        self.cipher.encrypt_block(&mut block);
        block
    }

    /// Checks that applying the keystream to `data` will not overflow the
    /// 64-bit block counter.
    ///
    /// This mirrors the bookkeeping in `try_apply_keystream`: input shorter
    /// than the unused remainder of the current block never advances the
    /// counter. Otherwise the counter advances once for finishing the current
    /// block, plus once per further whole block.
    fn check_data_len(&self, data: &[u8]) -> Result<(), LoopError> {
        let bs = C::BlockSize::USIZE;
        let leftover_bytes = bs - self.pos as usize;
        if data.len() < leftover_bytes {
            return Ok(());
        }
        let blocks = 1 + (data.len() - leftover_bytes) / bs;
        self.counter
            .checked_add(blocks as u64)
            .ok_or(LoopError)
            .map(|_| ())
    }
}
impl<C> fmt::Debug for Ctr128<C>
where
    C: BlockCipher<BlockSize = U16> + fmt::Debug,
    C::ParBlocks: ArrayLength<GenericArray<u8, U16>>,
{
    /// Formats as `Ctr128-` followed by the wrapped cipher's `Debug` output.
    ///
    /// The nonce, counter, position, and cached keystream block are not
    /// printed.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        formatter.write_fmt(format_args!("Ctr128-{:?}", self.cipher))
    }
}
/// Converts both 64-bit words to big-endian representation.
///
/// After the conversion, the in-memory bytes of each word are its big-endian
/// encoding, regardless of host endianness.
#[inline(always)]
fn conv_be(val: [u64; 2]) -> [u64; 2] {
    let [hi, lo] = val;
    [u64::to_be(hi), u64::to_be(lo)]
}
/// Views a batch of parallel cipher blocks as one contiguous byte slice of
/// length `ParBlocks * BlockSize`.
#[inline(always)]
fn to_slice<C: BlockCipher>(blocks: &ParBlocks<C>) -> &[u8] {
    let block_len = C::BlockSize::to_usize();
    let block_count = C::ParBlocks::to_usize();
    let byte_len = block_len * block_count;
    let first_byte = blocks.as_ptr() as *const u8;
    // SAFETY: `blocks` holds `block_count` blocks of `block_len` `u8`s each.
    // This assumes generic-array stores nested `GenericArray`s contiguously
    // without padding, as `[[u8; block_len]; block_count]` would be. Under that
    // assumption `byte_len` initialized bytes start at `first_byte`, and they
    // stay borrowed for the lifetime of `blocks`.
    unsafe { core::slice::from_raw_parts(first_byte, byte_len) }
}