use super::{sign_digest, Signature, VerifyingKey};
use crate::{Result, RsaPrivateKey};
use core::marker::PhantomData;
use digest::{Digest, FixedOutputReset, Update};
use rand_core::{CryptoRng, TryCryptoRng};
use signature::{
hazmat::RandomizedPrehashSigner, Keypair, RandomizedDigestSigner, RandomizedMultipartSigner,
RandomizedSigner,
};
use zeroize::ZeroizeOnDrop;
#[cfg(feature = "serde")]
use {
pkcs8::DecodePrivateKey,
serdect::serde::{de, ser, Deserialize, Serialize},
};
#[cfg(feature = "encoding")]
use {
super::get_pss_signature_algo_id,
crate::encoding::verify_algorithm_id,
const_oid::AssociatedOid,
pkcs8::{EncodePrivateKey, SecretDocument},
spki::{
der::AnyRef, AlgorithmIdentifierOwned, AlgorithmIdentifierRef,
AssociatedAlgorithmIdentifier, DynSignatureAlgorithmIdentifier,
},
};
#[cfg(feature = "os_rng")]
use {
rand_core::OsRng,
signature::{hazmat::PrehashSigner, MultipartSigner, Signer},
};
/// Signing key that pairs an [`RsaPrivateKey`] with a salt length and a
/// digest type `D`.
///
/// The digest is chosen at the type level; no value of `D` is stored.
#[derive(Debug, Clone)]
pub struct SigningKey<D: Digest> {
    /// Private key handed to `sign_digest` when producing signatures.
    inner: RsaPrivateKey,
    /// Salt length (in bytes) forwarded to `sign_digest`.
    salt_len: usize,
    /// Zero-sized marker binding this key to the digest `D`.
    phantom: PhantomData<D>,
}
impl<D> SigningKey<D>
where
D: Digest,
{
pub fn new(key: RsaPrivateKey) -> Self {
Self::new_with_salt_len(key, <D as Digest>::output_size())
}
pub fn new_with_salt_len(key: RsaPrivateKey, salt_len: usize) -> Self {
Self {
inner: key,
salt_len,
phantom: Default::default(),
}
}
pub fn random<R: CryptoRng + ?Sized>(rng: &mut R, bit_size: usize) -> Result<Self> {
Self::random_with_salt_len(rng, bit_size, <D as Digest>::output_size())
}
pub fn random_with_salt_len<R: CryptoRng + ?Sized>(
rng: &mut R,
bit_size: usize,
salt_len: usize,
) -> Result<Self> {
Ok(Self {
inner: RsaPrivateKey::new(rng, bit_size)?,
salt_len,
phantom: Default::default(),
})
}
pub fn salt_len(&self) -> usize {
self.salt_len
}
}
impl<D> RandomizedDigestSigner<D, Signature> for SigningKey<D>
where
    D: Digest + FixedOutputReset + Update,
{
    /// Signs the digest produced by letting `f` feed a fresh `D` hasher.
    ///
    /// A new hasher is created, passed to `f`, finalized, and the resulting
    /// hash is signed through `sign_digest` with this key's salt length.
    ///
    /// # Errors
    ///
    /// Returns an error if `f` fails, if `sign_digest` fails, or if the raw
    /// signature bytes cannot be converted into a [`Signature`].
    fn try_sign_digest_with_rng<R, F>(&self, rng: &mut R, f: F) -> signature::Result<Signature>
    where
        R: TryCryptoRng + ?Sized,
        F: Fn(&mut D) -> signature::Result<()>,
    {
        let mut hasher = D::new();
        f(&mut hasher)?;
        let hashed = hasher.finalize();

        let raw = sign_digest::<_, D>(rng, false, &self.inner, &hashed, self.salt_len)?;
        raw.as_slice().try_into()
    }
}
impl<D> RandomizedSigner<Signature> for SigningKey<D>
where
    D: Digest + FixedOutputReset + Update,
{
    /// Hashes `msg` with `D` and signs the result using randomness from `rng`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `try_sign_digest_with_rng`.
    fn try_sign_with_rng<R: TryCryptoRng + ?Sized>(
        &self,
        rng: &mut R,
        msg: &[u8],
    ) -> signature::Result<Signature> {
        self.try_sign_digest_with_rng(rng, |hasher: &mut D| {
            // Fully qualified: `Digest` and `Update` both expose `update`.
            <D as Update>::update(hasher, msg);
            Ok(())
        })
    }
}
impl<D> RandomizedMultipartSigner<Signature> for SigningKey<D>
where
    D: Digest + FixedOutputReset + Update,
{
    /// Hashes every slice of `msg` in order with a single `D` hasher and signs
    /// the result using randomness from `rng`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `try_sign_digest_with_rng`.
    fn try_multipart_sign_with_rng<R: TryCryptoRng + ?Sized>(
        &self,
        rng: &mut R,
        msg: &[&[u8]],
    ) -> signature::Result<Signature> {
        self.try_sign_digest_with_rng(rng, |hasher: &mut D| {
            for part in msg {
                <D as Update>::update(hasher, part);
            }
            Ok(())
        })
    }
}
impl<D> RandomizedPrehashSigner<Signature> for SigningKey<D>
where
    D: Digest + FixedOutputReset + Update,
{
    /// Signs an already-computed hash using randomness from `rng`.
    ///
    /// `prehash` is forwarded to `sign_digest` as-is; no hashing happens here.
    ///
    /// # Errors
    ///
    /// Returns an error if `sign_digest` fails or if the raw signature bytes
    /// cannot be converted into a [`Signature`].
    fn sign_prehash_with_rng<R: TryCryptoRng + ?Sized>(
        &self,
        rng: &mut R,
        prehash: &[u8],
    ) -> signature::Result<Signature> {
        let raw = sign_digest::<_, D>(rng, false, &self.inner, prehash, self.salt_len)?;
        raw.as_slice().try_into()
    }
}
#[cfg(feature = "os_rng")]
impl<D> PrehashSigner<Signature> for SigningKey<D>
where
    D: Digest + FixedOutputReset,
{
    /// Signs an already-computed hash, drawing randomness from [`OsRng`].
    ///
    /// # Errors
    ///
    /// Propagates any error from `sign_prehash_with_rng`.
    fn sign_prehash(&self, prehash: &[u8]) -> signature::Result<Signature> {
        let mut os_rng = OsRng;
        self.sign_prehash_with_rng(&mut os_rng, prehash)
    }
}
#[cfg(feature = "os_rng")]
impl<D> Signer<Signature> for SigningKey<D>
where
    D: Digest + FixedOutputReset,
{
    /// Hashes and signs `msg`, drawing randomness from [`OsRng`].
    ///
    /// # Errors
    ///
    /// Propagates any error from `try_sign_with_rng`.
    fn try_sign(&self, msg: &[u8]) -> signature::Result<Signature> {
        let mut os_rng = OsRng;
        self.try_sign_with_rng(&mut os_rng, msg)
    }
}
#[cfg(feature = "os_rng")]
impl<D> MultipartSigner<Signature> for SigningKey<D>
where
    D: Digest + FixedOutputReset,
{
    /// Hashes and signs the concatenated parts of `msg`, drawing randomness
    /// from [`OsRng`].
    ///
    /// # Errors
    ///
    /// Propagates any error from `try_multipart_sign_with_rng`.
    fn try_multipart_sign(&self, msg: &[&[u8]]) -> signature::Result<Signature> {
        let mut os_rng = OsRng;
        self.try_multipart_sign_with_rng(&mut os_rng, msg)
    }
}
impl<D: Digest> AsRef<RsaPrivateKey> for SigningKey<D> {
    /// Borrows the wrapped private key.
    fn as_ref(&self) -> &RsaPrivateKey {
        let Self { inner, .. } = self;
        inner
    }
}
#[cfg(feature = "encoding")]
impl<D: Digest> AssociatedAlgorithmIdentifier for SigningKey<D> {
    type Params = AnyRef<'static>;

    /// Key algorithm identifier, taken verbatim from `pkcs1::ALGORITHM_ID`.
    ///
    /// This identifies the key itself; the signature algorithm identifier is
    /// produced separately by `signature_algorithm_identifier`.
    const ALGORITHM_IDENTIFIER: AlgorithmIdentifierRef<'static> = pkcs1::ALGORITHM_ID;
}
#[cfg(feature = "encoding")]
impl<D> DynSignatureAlgorithmIdentifier for SigningKey<D>
where
    D: Digest + AssociatedOid,
{
    /// Builds the signature algorithm identifier for digest `D` and this
    /// key's salt length.
    ///
    /// # Errors
    ///
    /// Returns an ASN.1 overflow error when the configured salt length does
    /// not fit in a `u8`, which is the width `get_pss_signature_algo_id`
    /// accepts. Any error from `get_pss_signature_algo_id` is propagated.
    fn signature_algorithm_identifier(&self) -> spki::Result<AlgorithmIdentifierOwned> {
        // A plain `as u8` cast would silently wrap salt lengths above 255 and
        // advertise a salt length the signer does not actually use.
        let salt_len = u8::try_from(self.salt_len)
            .map_err(|_| spki::Error::Asn1(spki::der::ErrorKind::Overflow.into()))?;
        get_pss_signature_algo_id::<D>(salt_len)
    }
}
#[cfg(feature = "encoding")]
impl<D: Digest> EncodePrivateKey for SigningKey<D> {
    /// Encodes the wrapped private key as a PKCS#8 DER document.
    ///
    /// Only the inner key is encoded; the salt length is not part of the
    /// output.
    fn to_pkcs8_der(&self) -> pkcs8::Result<SecretDocument> {
        EncodePrivateKey::to_pkcs8_der(&self.inner)
    }
}
impl<D> From<RsaPrivateKey> for SigningKey<D>
where
D: Digest,
{
fn from(key: RsaPrivateKey) -> Self {
Self::new(key)
}
}
impl<D> From<SigningKey<D>> for RsaPrivateKey
where
D: Digest,
{
fn from(key: SigningKey<D>) -> Self {
key.inner
}
}
impl<D: Digest> Keypair for SigningKey<D> {
    type VerifyingKey = VerifyingKey<D>;

    /// Derives the matching verifying key.
    ///
    /// The public key is obtained from the inner private key, and the
    /// verifying key is pinned to this signing key's salt length.
    fn verifying_key(&self) -> Self::VerifyingKey {
        let public = self.inner.to_public_key();
        let pinned_salt_len = Some(self.salt_len);

        VerifyingKey {
            inner: public,
            salt_len: pinned_salt_len,
            phantom: Default::default(),
        }
    }
}
#[cfg(feature = "encoding")]
impl<D> TryFrom<pkcs8::PrivateKeyInfoRef<'_>> for SigningKey<D>
where
    D: Digest + AssociatedOid,
{
    type Error = pkcs8::Error;

    /// Builds a signing key from parsed PKCS#8 private key info.
    ///
    /// The algorithm identifier is checked with `verify_algorithm_id` before
    /// the key material is decoded. The resulting key uses the default salt
    /// length (see [`SigningKey::new`]).
    ///
    /// # Errors
    ///
    /// Returns an error if the algorithm identifier is rejected or if the
    /// private key cannot be decoded.
    fn try_from(private_key_info: pkcs8::PrivateKeyInfoRef<'_>) -> pkcs8::Result<Self> {
        verify_algorithm_id(&private_key_info.algorithm)?;
        let decoded = RsaPrivateKey::try_from(private_key_info)?;
        Ok(Self::new(decoded))
    }
}
// Marker impl only: no `Drop` is defined in this file for `SigningKey`.
// NOTE(review): presumably relies on `RsaPrivateKey` wiping itself on drop — confirm.
impl<D: Digest> ZeroizeOnDrop for SigningKey<D> {}
impl<D: Digest> PartialEq for SigningKey<D> {
    /// Two signing keys are equal when both their private keys and their salt
    /// lengths are equal.
    fn eq(&self, other: &Self) -> bool {
        let same_key = self.inner == other.inner;
        same_key && self.salt_len == other.salt_len
    }
}
#[cfg(feature = "serde")]
impl<D: Digest> Serialize for SigningKey<D> {
    /// Serializes the key as its PKCS#8 DER encoding, handed to
    /// `serdect::slice::serialize_hex_lower_or_bin`.
    ///
    /// # Errors
    ///
    /// Returns a custom serializer error if PKCS#8 encoding fails, or
    /// whatever error the serializer itself reports.
    fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
    where
        S: serdect::serde::Serializer,
    {
        let document = match self.to_pkcs8_der() {
            Ok(document) => document,
            Err(err) => return Err(ser::Error::custom(err)),
        };
        serdect::slice::serialize_hex_lower_or_bin(&document.as_bytes(), serializer)
    }
}
#[cfg(feature = "serde")]
impl<'de, D> Deserialize<'de> for SigningKey<D>
where
    D: Digest + AssociatedOid,
{
    /// Deserializes a key from the bytes produced by
    /// `serdect::slice::deserialize_hex_or_bin_vec`, parsed as PKCS#8 DER.
    ///
    /// # Errors
    ///
    /// Returns the deserializer's error if the bytes cannot be read, or a
    /// custom error if they are not a valid PKCS#8 private key.
    fn deserialize<De>(deserializer: De) -> core::result::Result<Self, De::Error>
    where
        De: serdect::serde::Deserializer<'de>,
    {
        let encoded = serdect::slice::deserialize_hex_or_bin_vec(deserializer)?;
        match Self::from_pkcs8_der(&encoded) {
            Ok(key) => Ok(key),
            Err(err) => Err(de::Error::custom(err)),
        }
    }
}
#[cfg(test)]
mod tests {
    /// Checks the human-readable serde form of a key generated from a fixed
    /// RNG seed against the expected hex string token.
    #[test]
    #[cfg(all(feature = "hazmat", feature = "serde"))]
    fn test_serde() {
        use super::*;
        use crate::RsaPrivateKey;
        use rand_chacha::{rand_core::SeedableRng, ChaCha8Rng};
        use serde_test::{assert_tokens, Configure, Token};
        use sha2::Sha256;

        // Fixed seed so the generated key, and thus the encoding, is stable.
        let mut seeded_rng = ChaCha8Rng::from_seed([42; 32]);
        let private = RsaPrivateKey::new_unchecked(&mut seeded_rng, 64)
            .expect("failed to generate key");
        let key_under_test = SigningKey::<Sha256>::new(private);

        let expected_hex: &'static str = concat!(
            "3056020100300d06092a864886f70d010101050004423040020100020900ab240c",
            "3361d02e370203010001020811e54a15259d22f9020500ceff5cf3020500d3a7aa",
            "ad020500ccaddf17020500cb529d3d020500bb526d6f"
        );
        let expected_tokens = [Token::Str(expected_hex)];

        assert_tokens(&key_under_test.readable(), &expected_tokens);
    }
}