use alloc::boxed::Box;
use alloc::string::{String, ToString};
use alloc::vec;
use alloc::vec::Vec;
use core::fmt;
use rand_core::CryptoRngCore;
use digest::{Digest, DynDigest};
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
use zeroize::Zeroizing;
use crate::algorithms::mgf1_xor;
use crate::errors::{Error, Result};
use crate::key::{self, PrivateKey, PublicKey};
use crate::padding::PaddingScheme;
/// Largest accepted OAEP label, in bytes: 2^61 - 1, the SHA-1 input-length limit that
/// RFC 8017 uses as the bound on the label length.
const MAX_LABEL_LEN: u64 = (1 << 61) - 1;
/// Parameters for RSAES-OAEP encryption and decryption.
///
/// Consumed by value by the [`PaddingScheme`] implementation, which forwards these fields to
/// the module-level `encrypt` / `decrypt` functions.
pub struct Oaep {
    /// Hash applied to the label. Its output size also fixes the length of the seed in the
    /// encoded message.
    pub digest: Box<dyn DynDigest + Send + Sync>,
    /// Hash passed to `mgf1_xor` when masking / unmasking the seed and data block.
    pub mgf_digest: Box<dyn DynDigest + Send + Sync>,
    /// Optional label; `None` is treated as the empty label. Decryption only accepts a
    /// ciphertext whose embedded label hash matches the hash of this label.
    pub label: Option<String>,
}
impl Oaep {
    /// Creates an OAEP scheme that uses `T` both for hashing the label and for MGF1, with no
    /// label.
    pub fn new<T: 'static + Digest + DynDigest + Send + Sync>() -> Self {
        Self::new_with_mgf_hash::<T, T>()
    }

    /// Creates an OAEP scheme that uses `T` both for hashing the label and for MGF1, with the
    /// given `label`.
    pub fn new_with_label<T: 'static + Digest + DynDigest + Send + Sync, S: AsRef<str>>(
        label: S,
    ) -> Self {
        Self::new_with_mgf_hash_and_label::<T, T, S>(label)
    }

    /// Creates an OAEP scheme that hashes the label with `T` and runs MGF1 with `U`, with no
    /// label.
    pub fn new_with_mgf_hash<
        T: 'static + Digest + DynDigest + Send + Sync,
        U: 'static + Digest + DynDigest + Send + Sync,
    >() -> Self {
        let digest: Box<dyn DynDigest + Send + Sync> = Box::new(T::new());
        let mgf_digest: Box<dyn DynDigest + Send + Sync> = Box::new(U::new());
        Self {
            digest,
            mgf_digest,
            label: None,
        }
    }

    /// Creates an OAEP scheme that hashes the label with `T` and runs MGF1 with `U`, with the
    /// given `label`.
    pub fn new_with_mgf_hash_and_label<
        T: 'static + Digest + DynDigest + Send + Sync,
        U: 'static + Digest + DynDigest + Send + Sync,
        S: AsRef<str>,
    >(
        label: S,
    ) -> Self {
        let mut scheme = Self::new_with_mgf_hash::<T, U>();
        scheme.label = Some(label.as_ref().to_string());
        scheme
    }
}
impl PaddingScheme for Oaep {
fn decrypt<Rng: CryptoRngCore, Priv: PrivateKey>(
mut self,
rng: Option<&mut Rng>,
priv_key: &Priv,
ciphertext: &[u8],
) -> Result<Vec<u8>> {
decrypt(
rng,
priv_key,
ciphertext,
&mut *self.digest,
&mut *self.mgf_digest,
self.label,
)
}
fn encrypt<Rng: CryptoRngCore, Pub: PublicKey>(
mut self,
rng: &mut Rng,
pub_key: &Pub,
msg: &[u8],
) -> Result<Vec<u8>> {
encrypt(
rng,
pub_key,
msg,
&mut *self.digest,
&mut *self.mgf_digest,
self.label,
)
}
}
impl fmt::Debug for Oaep {
    /// Formats the scheme for debugging. The digests are opaque trait objects, so they are
    /// shown as a placeholder; only the label is printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let opaque = "...";
        let mut out = f.debug_struct("OAEP");
        out.field("digest", &opaque);
        out.field("mgf_digest", &opaque);
        out.field("label", &self.label);
        out.finish()
    }
}
/// Encrypts `msg` for `pub_key` using RSAES-OAEP.
///
/// Builds the encoded message `EM = 0x00 || maskedSeed || maskedDB`, where
/// `DB = Hash(label) || 0x00.. || 0x01 || msg` and the seed is drawn from `rng`, then applies
/// the raw RSA public-key primitive. A `None` label is treated as the empty label.
///
/// # Errors
///
/// * any error returned by `key::check_public`;
/// * [`Error::MessageTooLong`] if `msg` does not fit next to two hash outputs and two framing
///   bytes in a `pub_key.size()`-byte block;
/// * [`Error::LabelTooLong`] if the label is longer than `MAX_LABEL_LEN` bytes;
/// * any error returned by the raw encryption primitive.
#[inline]
pub fn encrypt<R: CryptoRngCore, K: PublicKey>(
    rng: &mut R,
    pub_key: &K,
    msg: &[u8],
    digest: &mut dyn DynDigest,
    mgf_digest: &mut dyn DynDigest,
    label: Option<String>,
) -> Result<Vec<u8>> {
    key::check_public(pub_key)?;

    let key_len = pub_key.size();
    let hash_len = digest.output_size();

    // Room is needed for 0x00 || seed || lHash || 0x01 on top of the message itself.
    if msg.len() + 2 * hash_len + 2 > key_len {
        return Err(Error::MessageTooLong);
    }

    let label = label.unwrap_or_default();
    if label.len() as u64 > MAX_LABEL_LEN {
        return Err(Error::LabelTooLong);
    }

    // The encoded message is wiped on drop; its first byte stays 0x00.
    let mut encoded = Zeroizing::new(vec![0u8; key_len]);
    let (seed, db) = encoded[1..].split_at_mut(hash_len);
    rng.fill_bytes(seed);

    // DB = lHash || PS (already zero) || 0x01 || M
    let msg_start = db.len() - msg.len();
    digest.update(label.as_bytes());
    db[..hash_len].copy_from_slice(&digest.finalize_reset());
    db[msg_start - 1] = 1;
    db[msg_start..].copy_from_slice(msg);

    // Mask DB with the seed, then mask the seed with the masked DB.
    mgf1_xor(db, mgf_digest, seed);
    mgf1_xor(seed, mgf_digest, db);

    pub_key.raw_encryption_primitive(&encoded, key_len)
}
#[inline]
pub fn decrypt<R: CryptoRngCore, SK: PrivateKey>(
rng: Option<&mut R>,
priv_key: &SK,
ciphertext: &[u8],
digest: &mut dyn DynDigest,
mgf_digest: &mut dyn DynDigest,
label: Option<String>,
) -> Result<Vec<u8>> {
key::check_public(priv_key)?;
let res = decrypt_inner(rng, priv_key, ciphertext, digest, mgf_digest, label)?;
if res.is_none().into() {
return Err(Error::Decryption);
}
let (out, index) = res.unwrap();
Ok(out[index as usize..].to_vec())
}
/// Runs the RSA decryption primitive and decodes the OAEP padding, folding every padding check
/// into a `subtle` [`Choice`] instead of branching on decrypted data.
///
/// Early `Err` returns happen only for a key smaller than 11 bytes, a ciphertext whose length
/// is not the key size, a key too small for two hash outputs plus two framing bytes, an error
/// from `raw_decryption_primitive`, or an over-long label. None of these inspect the decoded
/// padding.
///
/// Otherwise returns a [`CtOption`] holding the whole decoded buffer `em` and the byte offset
/// at which the message starts. It is "some" only if the leading byte is zero, the label hash
/// matches, the padding string is all zeros, and a 0x01 separator is present.
#[inline]
fn decrypt_inner<R: CryptoRngCore, SK: PrivateKey>(
    rng: Option<&mut R>,
    priv_key: &SK,
    ciphertext: &[u8],
    digest: &mut dyn DynDigest,
    mgf_digest: &mut dyn DynDigest,
    label: Option<String>,
) -> Result<CtOption<(Vec<u8>, u32)>> {
    let k = priv_key.size();
    // Reject tiny keys outright; the OAEP-specific size bound is checked below.
    if k < 11 {
        return Err(Error::Decryption);
    }
    let h_size = digest.output_size();
    // The ciphertext must be exactly one key-size long, and the key must leave room for
    // 0x00 || seed || lHash || 0x01.
    if ciphertext.len() != k || k < h_size * 2 + 2 {
        return Err(Error::Decryption);
    }
    // NOTE(review): the decoding below assumes `em` is the full k-byte encoded message (the
    // third argument presumably requests that length) — confirm against the primitive.
    // NOTE(review): `em` is a plain Vec, so on the rejection path this decoded buffer is
    // dropped without being zeroized.
    let mut em = priv_key.raw_decryption_primitive(rng, ciphertext, priv_key.size())?;
    // A missing label is treated as the empty label.
    let label = label.unwrap_or_default();
    if label.len() as u64 > MAX_LABEL_LEN {
        return Err(Error::LabelTooLong);
    }
    // Expected lHash = Hash(label).
    digest.update(label.as_bytes());
    let expected_p_hash = &*digest.finalize_reset();
    // EM = 0x00 || maskedSeed || maskedDB; record the leading-byte check as a Choice.
    let first_byte_is_zero = em[0].ct_eq(&0u8);
    let (_, payload) = em.split_at_mut(1);
    let (seed, db) = payload.split_at_mut(h_size);
    // Undo the masking in the reverse order used by `encrypt`: first unmask the seed using the
    // masked DB, then unmask DB using the recovered seed (assuming `mgf1_xor(out, digest, seed)`
    // XORs MGF1(seed) into `out`).
    mgf1_xor(seed, mgf_digest, db);
    mgf1_xor(db, mgf_digest, seed);
    // DB = lHash' || PS || 0x01 || M; lHash' must equal the expected label hash.
    let hash_are_equal = db[0..h_size].ct_eq(expected_p_hash);
    // Scan the bytes after lHash' for the 0x01 separator without data-dependent branches:
    //   looking_for_index: stays 1 until the first 0x01 is seen
    //   index: offset of that first 0x01, relative to the end of lHash'
    //   nonzero_before_one: becomes 1 if any byte other than 0x00 precedes that 0x01
    let mut looking_for_index = Choice::from(1u8);
    let mut index = 0u32;
    let mut nonzero_before_one = Choice::from(0u8);
    for (i, el) in db.iter().skip(h_size).enumerate() {
        let equals0 = el.ct_eq(&0u8);
        let equals1 = el.ct_eq(&1u8);
        index.conditional_assign(&(i as u32), looking_for_index & equals1);
        looking_for_index &= !equals1;
        nonzero_before_one |= looking_for_index & !equals0;
    }
    // Valid iff the leading byte is zero, the label hash matches, PS is all zeros and a
    // separator was found.
    let valid = first_byte_is_zero & hash_are_equal & !nonzero_before_one & !looking_for_index;
    // Message offset in em: 0x00 (1) + seed (h_size) + lHash' (h_size) + PS (index) + 0x01 (1).
    Ok(CtOption::new((em, index + 2 + (h_size * 2) as u32), valid))
}