use alloc::vec::Vec;
use digest::{Digest, DynDigest, FixedOutputReset};
use subtle::{Choice, ConstantTimeEq};
use super::mgf::{mgf1_xor, mgf1_xor_digest};
use crate::errors::{Error, Result};
/// Builds the EMSA-PSS encoding of an already-hashed message with a dynamically dispatched
/// digest.
///
/// The result is `em_bits.div_ceil(8)` octets long and is laid out as
/// `data block || H || 0xBC`:
///
/// * `H` is the digest of eight zero octets followed by `m_hash` and `salt`;
/// * the data block is zero padding, a single `0x01` octet and then `salt`, passed through
///   `mgf1_xor` with `H` as the seed;
/// * the `8 * len - em_bits` most significant bits of the first octet are cleared afterwards.
///
/// `hash` is finalized with `finalize_reset` before it is handed to `mgf1_xor`.
///
/// # Errors
///
/// * [`Error::InputNotHashed`] if `m_hash` is not exactly `hash.output_size()` octets long.
/// * [`Error::Internal`] if the encoded message cannot hold the digest, the salt and the two
///   fixed octets (`0x01` separator and `0xBC` trailer).
pub(crate) fn emsa_pss_encode(
    m_hash: &[u8],
    em_bits: usize,
    salt: &[u8],
    hash: &mut dyn DynDigest,
) -> Result<Vec<u8>> {
    let digest_len = hash.output_size();
    if m_hash.len() != digest_len {
        return Err(Error::InputNotHashed);
    }

    let salt_len = salt.len();
    let encoded_len = em_bits.div_ceil(8);
    if encoded_len < digest_len + salt_len + 2 {
        return Err(Error::Internal);
    }

    // Sizes of the pieces of the output buffer; the check above keeps every subtraction
    // non-negative.
    let block_len = encoded_len - digest_len - 1;
    let padding_len = block_len - salt_len - 1;

    let mut encoded = vec![0; encoded_len];
    {
        // data block | H | trailer (one octet)
        let (block, tail) = encoded.split_at_mut(block_len);
        let (seed, trailer) = tail.split_at_mut(digest_len);

        // H = Hash(00 00 00 00 00 00 00 00 || m_hash || salt)
        for part in [&[0u8; 8][..], m_hash, salt] {
            hash.update(part);
        }
        seed.copy_from_slice(&hash.finalize_reset());

        // The padding octets are already zero from the allocation above.
        block[padding_len] = 0x01;
        block[padding_len + 1..].copy_from_slice(salt);

        mgf1_xor(block, hash, seed);

        // Drop the bits that do not fit into `em_bits`.
        block[0] &= 0xFF >> (8 * encoded_len - em_bits);

        trailer[0] = 0xBC;
    }

    Ok(encoded)
}
/// Builds the EMSA-PSS encoding of an already-hashed message with the statically known digest
/// `D`.
///
/// The result is `em_bits.div_ceil(8)` octets long and is laid out as
/// `data block || H || 0xBC`:
///
/// * `H` is the `D` digest of eight zero octets followed by `m_hash` and `salt`;
/// * the data block is zero padding, a single `0x01` octet and then `salt`, passed through
///   `mgf1_xor_digest` with `H` as the seed;
/// * the `8 * len - em_bits` most significant bits of the first octet are cleared afterwards.
///
/// # Errors
///
/// * [`Error::InputNotHashed`] if `m_hash` is not exactly one `D` output long.
/// * [`Error::Internal`] if the encoded message cannot hold the digest, the salt and the two
///   fixed octets (`0x01` separator and `0xBC` trailer).
pub(crate) fn emsa_pss_encode_digest<D>(
    m_hash: &[u8],
    em_bits: usize,
    salt: &[u8],
) -> Result<Vec<u8>>
where
    D: Digest + FixedOutputReset,
{
    let digest_len = <D as Digest>::output_size();
    if m_hash.len() != digest_len {
        return Err(Error::InputNotHashed);
    }

    let salt_len = salt.len();
    let encoded_len = em_bits.div_ceil(8);
    if encoded_len < digest_len + salt_len + 2 {
        return Err(Error::Internal);
    }

    // Sizes of the pieces of the output buffer; the check above keeps every subtraction
    // non-negative.
    let block_len = encoded_len - digest_len - 1;
    let padding_len = block_len - salt_len - 1;

    let mut encoded = vec![0; encoded_len];
    {
        // data block | H | trailer (one octet)
        let (block, tail) = encoded.split_at_mut(block_len);
        let (seed, trailer) = tail.split_at_mut(digest_len);

        // H = Hash(00 00 00 00 00 00 00 00 || m_hash || salt)
        let mut hasher = D::new();
        Digest::update(&mut hasher, [0u8; 8]);
        Digest::update(&mut hasher, m_hash);
        Digest::update(&mut hasher, salt);
        let m_prime_digest = hasher.finalize_reset();
        seed.copy_from_slice(&m_prime_digest);

        // The padding octets are already zero from the allocation above.
        block[padding_len] = 0x01;
        block[padding_len + 1..].copy_from_slice(salt);

        mgf1_xor_digest(block, &mut hasher, seed);

        // Drop the bits that do not fit into `em_bits`.
        block[0] &= 0xFF >> (8 * encoded_len - em_bits);

        trailer[0] = 0xBC;
    }

    Ok(encoded)
}
fn emsa_pss_verify_pre<'a>(
m_hash: &[u8],
em: &'a mut [u8],
em_bits: usize,
s_len: usize,
h_len: usize,
) -> Result<(&'a mut [u8], &'a mut [u8])> {
if m_hash.len() != h_len {
return Err(Error::Verification);
}
let em_len = em.len(); if em_len < h_len + s_len + 2 {
return Err(Error::Verification);
}
if em[em.len() - 1] != 0xBC {
return Err(Error::Verification);
}
let (db, h) = em.split_at_mut(em_len - h_len - 1);
let h = &mut h[..h_len];
if db[0]
& (0xFF_u8
.checked_shl(8 - (8 * em_len - em_bits) as u32)
.unwrap_or(0))
!= 0
{
return Err(Error::Verification);
}
Ok((db, h))
}
fn emsa_pss_verify_salt(db: &[u8], em_len: usize, s_len: usize, h_len: usize) -> Choice {
let (zeroes, rest) = db.split_at(em_len - h_len - s_len - 2);
let valid: Choice = zeroes
.iter()
.fold(Choice::from(1u8), |a, e| a & e.ct_eq(&0x00));
valid & rest[0].ct_eq(&0x01)
}
pub(crate) fn emsa_pss_verify(
m_hash: &[u8],
em: &mut [u8],
s_len: usize,
hash: &mut dyn DynDigest,
key_bits: usize,
) -> Result<()> {
let em_bits = key_bits - 1;
let em_len = em_bits.div_ceil(8);
let key_len = key_bits.div_ceil(8);
let h_len = hash.output_size();
let em = &mut em[key_len - em_len..];
let (db, h) = emsa_pss_verify_pre(m_hash, em, em_bits, s_len, h_len)?;
mgf1_xor(db, hash, &*h);
db[0] &= 0xFF >> (8 * em_len - em_bits);
let salt_valid = emsa_pss_verify_salt(db, em_len, s_len, h_len);
let salt = &db[db.len() - s_len..];
let prefix = [0u8; 8];
hash.update(&prefix[..]);
hash.update(m_hash);
hash.update(salt);
let h0 = hash.finalize_reset();
if (salt_valid & h0.ct_eq(h)).into() {
Ok(())
} else {
Err(Error::Verification)
}
}
pub(crate) fn emsa_pss_verify_digest<D>(
m_hash: &[u8],
em: &mut [u8],
s_len: usize,
key_bits: usize,
) -> Result<()>
where
D: Digest + FixedOutputReset,
{
let em_bits = key_bits - 1;
let em_len = em_bits.div_ceil(8);
let key_len = key_bits.div_ceil(8);
let h_len = <D as Digest>::output_size();
let em = &mut em[key_len - em_len..];
let (db, h) = emsa_pss_verify_pre(m_hash, em, em_bits, s_len, h_len)?;
let mut hash = D::new();
mgf1_xor_digest::<D>(db, &mut hash, &*h);
db[0] &= 0xFF >> (8 * em_len - em_bits);
let salt_valid = emsa_pss_verify_salt(db, em_len, s_len, h_len);
let salt = &db[db.len() - s_len..];
let prefix = [0u8; 8];
Digest::update(&mut hash, &prefix[..]);
Digest::update(&mut hash, m_hash);
Digest::update(&mut hash, salt);
let h0 = hash.finalize_reset();
if (salt_valid & h0.ct_eq(h)).into() {
Ok(())
} else {
Err(Error::Verification)
}
}