use crate::{AlgorithmIdentifierRef, Error, Result};
use der::{
asn1::{AnyRef, ObjectIdentifier, OctetStringRef},
Decode, DecodeValue, Encode, EncodeValue, ErrorKind, Length, Reader, Sequence, Tag, Tagged,
Writer,
};
/// Object identifier (OID) for PBKDF2.
pub const PBKDF2_OID: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.113549.1.5.12");
/// Object identifier (OID) for HMAC with SHA-1.
pub const HMAC_WITH_SHA1_OID: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.113549.2.7");
/// Object identifier (OID) for HMAC with SHA-224.
pub const HMAC_WITH_SHA224_OID: ObjectIdentifier =
ObjectIdentifier::new_unwrap("1.2.840.113549.2.8");
/// Object identifier (OID) for HMAC with SHA-256.
pub const HMAC_WITH_SHA256_OID: ObjectIdentifier =
ObjectIdentifier::new_unwrap("1.2.840.113549.2.9");
/// Object identifier (OID) for HMAC with SHA-384.
pub const HMAC_WITH_SHA384_OID: ObjectIdentifier =
ObjectIdentifier::new_unwrap("1.2.840.113549.2.10");
/// Object identifier (OID) for HMAC with SHA-512.
pub const HMAC_WITH_SHA512_OID: ObjectIdentifier =
ObjectIdentifier::new_unwrap("1.2.840.113549.2.11");
/// Object identifier (OID) for scrypt.
pub const SCRYPT_OID: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.3.6.1.4.1.11591.4.11");
/// Integer type used to store the scrypt cost parameter (see `ScryptParams::cost_parameter`).
type ScryptCost = u64;
/// Key derivation function together with its parameters.
///
/// Each variant wraps the parameter set of one supported KDF. The enum is
/// `#[non_exhaustive]`, so downstream matches need a wildcard arm.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum Kdf<'a> {
/// PBKDF2 with its parameters.
Pbkdf2(Pbkdf2Params<'a>),
/// scrypt with its parameters.
Scrypt(ScryptParams<'a>),
}
impl<'a> Kdf<'a> {
pub fn key_length(&self) -> Option<u16> {
match self {
Self::Pbkdf2(params) => params.key_length,
Self::Scrypt(params) => params.key_length,
}
}
pub fn oid(&self) -> ObjectIdentifier {
match self {
Self::Pbkdf2(_) => PBKDF2_OID,
Self::Scrypt(_) => SCRYPT_OID,
}
}
pub fn pbkdf2(&self) -> Option<&Pbkdf2Params<'a>> {
match self {
Self::Pbkdf2(params) => Some(params),
_ => None,
}
}
pub fn scrypt(&self) -> Option<&ScryptParams<'a>> {
match self {
Self::Scrypt(params) => Some(params),
_ => None,
}
}
pub fn is_pbkdf2(&self) -> bool {
self.pbkdf2().is_some()
}
pub fn is_scrypt(&self) -> bool {
self.scrypt().is_some()
}
pub fn to_alg_params_invalid(&self) -> Error {
Error::AlgorithmParametersInvalid { oid: self.oid() }
}
}
impl<'a> DecodeValue<'a> for Kdf<'a> {
    /// Decodes the value as an `AlgorithmIdentifierRef` and then converts it
    /// into a [`Kdf`] through the `TryFrom` impl.
    fn decode_value<R: Reader<'a>>(reader: &mut R, header: der::Header) -> der::Result<Self> {
        let alg = AlgorithmIdentifierRef::decode_value(reader, header)?;
        Self::try_from(alg)
    }
}
impl EncodeValue for Kdf<'_> {
    /// Length of the value: the encoded OID plus the encoded parameters.
    fn value_len(&self) -> der::Result<Length> {
        let oid_len = self.oid().encoded_len()?;
        let params_len = match self {
            Self::Pbkdf2(inner) => inner.encoded_len()?,
            Self::Scrypt(inner) => inner.encoded_len()?,
        };
        oid_len + params_len
    }

    /// Writes the OID followed by the parameters of the active variant.
    fn encode_value(&self, writer: &mut impl Writer) -> der::Result<()> {
        self.oid().encode(writer)?;
        match self {
            Self::Pbkdf2(inner) => inner.encode(writer),
            Self::Scrypt(inner) => inner.encode(writer),
        }
    }
}
// Opts `Kdf` into `der`'s `Sequence` trait; the impl body is empty, so no
// trait items are overridden here.
impl<'a> Sequence<'a> for Kdf<'a> {}
impl<'a> From<Pbkdf2Params<'a>> for Kdf<'a> {
fn from(params: Pbkdf2Params<'a>) -> Self {
Kdf::Pbkdf2(params)
}
}
impl<'a> From<ScryptParams<'a>> for Kdf<'a> {
fn from(params: ScryptParams<'a>) -> Self {
Kdf::Scrypt(params)
}
}
impl<'a> TryFrom<AlgorithmIdentifierRef<'a>> for Kdf<'a> {
type Error = der::Error;
fn try_from(alg: AlgorithmIdentifierRef<'a>) -> der::Result<Self> {
if let Some(params) = alg.parameters {
match alg.oid {
PBKDF2_OID => params.try_into().map(Self::Pbkdf2),
SCRYPT_OID => params.try_into().map(Self::Scrypt),
oid => Err(ErrorKind::OidUnknown { oid }.into()),
}
} else {
Err(Tag::OctetString.value_error())
}
}
}
/// Parameters for the PBKDF2 key derivation function.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct Pbkdf2Params<'a> {
/// Salt bytes (borrowed; encoded as an octet string).
pub salt: &'a [u8],
/// Iteration count.
pub iteration_count: u32,
/// Optional key length.
pub key_length: Option<u16>,
/// Pseudo-random function used by PBKDF2.
pub prf: Pbkdf2Prf,
}
impl<'a> Pbkdf2Params<'a> {
    /// Largest iteration count accepted by [`Pbkdf2Params::hmac_with_sha256`].
    pub const MAX_ITERATION_COUNT: u32 = 100_000_000;

    /// Error returned when PBKDF2 parameters are rejected.
    const INVALID_ERR: Error = Error::AlgorithmParametersInvalid { oid: PBKDF2_OID };

    /// Creates PBKDF2 parameters that use HMAC with SHA-256 as the PRF and
    /// leave `key_length` unset.
    ///
    /// # Errors
    /// Returns `Error::AlgorithmParametersInvalid` (with the PBKDF2 OID) when
    /// `iteration_count` exceeds [`Self::MAX_ITERATION_COUNT`].
    pub fn hmac_with_sha256(iteration_count: u32, salt: &'a [u8]) -> Result<Self> {
        if iteration_count <= Self::MAX_ITERATION_COUNT {
            Ok(Self {
                salt,
                iteration_count,
                key_length: None,
                prf: Pbkdf2Prf::HmacWithSha256,
            })
        } else {
            Err(Self::INVALID_ERR)
        }
    }
}
impl<'a> DecodeValue<'a> for Pbkdf2Params<'a> {
    /// Decodes the value as an `AnyRef` and converts it through the
    /// `TryFrom<AnyRef>` impl.
    fn decode_value<R: Reader<'a>>(reader: &mut R, header: der::Header) -> der::Result<Self> {
        let any = AnyRef::decode_value(reader, header)?;
        Self::try_from(any)
    }
}
impl EncodeValue for Pbkdf2Params<'_> {
    /// Sums the encoded lengths of salt, iteration count and key length, and
    /// adds the PRF length only when the PRF differs from the default.
    fn value_len(&self) -> der::Result<Length> {
        let salt_len = OctetStringRef::new(self.salt)?.encoded_len()?;
        let iterations_len = self.iteration_count.encoded_len()?;
        let key_length_len = self.key_length.encoded_len()?;
        let base_len = salt_len + iterations_len + key_length_len;

        if self.prf != Pbkdf2Prf::default() {
            base_len + self.prf.encoded_len()?
        } else {
            base_len
        }
    }

    /// Writes salt, iteration count and key length in that order; the PRF is
    /// written last and skipped when it equals the default PRF.
    fn encode_value(&self, writer: &mut impl Writer) -> der::Result<()> {
        OctetStringRef::new(self.salt)?.encode(writer)?;
        self.iteration_count.encode(writer)?;
        self.key_length.encode(writer)?;

        if self.prf != Pbkdf2Prf::default() {
            self.prf.encode(writer)?;
        }
        Ok(())
    }
}
// Opts `Pbkdf2Params` into `der`'s `Sequence` trait; the impl body is empty,
// so no trait items are overridden here.
impl<'a> Sequence<'a> for Pbkdf2Params<'a> {}
impl<'a> TryFrom<AnyRef<'a>> for Pbkdf2Params<'a> {
type Error = der::Error;
fn try_from(any: AnyRef<'a>) -> der::Result<Self> {
any.sequence(|reader| {
Ok(Self {
salt: OctetStringRef::decode(reader)?.as_bytes(),
iteration_count: reader.decode()?,
key_length: reader.decode()?,
prf: Option::<AlgorithmIdentifierRef<'_>>::decode(reader)?
.map(TryInto::try_into)
.transpose()?
.unwrap_or_default(),
})
})
}
}
/// Pseudo-random function used by PBKDF2.
///
/// The enum is `#[non_exhaustive]`, so downstream matches need a wildcard arm.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum Pbkdf2Prf {
/// HMAC with SHA-1 (this is the value returned by `Default`).
HmacWithSha1,
/// HMAC with SHA-224.
HmacWithSha224,
/// HMAC with SHA-256.
HmacWithSha256,
/// HMAC with SHA-384.
HmacWithSha384,
/// HMAC with SHA-512.
HmacWithSha512,
}
impl Pbkdf2Prf {
pub fn oid(self) -> ObjectIdentifier {
match self {
Self::HmacWithSha1 => HMAC_WITH_SHA1_OID,
Self::HmacWithSha224 => HMAC_WITH_SHA224_OID,
Self::HmacWithSha256 => HMAC_WITH_SHA256_OID,
Self::HmacWithSha384 => HMAC_WITH_SHA384_OID,
Self::HmacWithSha512 => HMAC_WITH_SHA512_OID,
}
}
}
impl Default for Pbkdf2Prf {
fn default() -> Self {
Self::HmacWithSha1
}
}
impl<'a> TryFrom<AlgorithmIdentifierRef<'a>> for Pbkdf2Prf {
    type Error = der::Error;

    /// Converts an algorithm identifier into a [`Pbkdf2Prf`].
    ///
    /// # Errors
    /// - a value error for `Tag::Null` when `parameters` is absent;
    /// - a value error for the parameters' own tag when they are not NULL;
    /// - `ErrorKind::OidUnknown` when the OID is not one of the HMAC OIDs.
    fn try_from(alg: AlgorithmIdentifierRef<'a>) -> der::Result<Self> {
        // Parameters are validated before the OID is looked at.
        match alg.parameters {
            Some(params) if params.is_null() => {}
            Some(params) => return Err(params.tag().value_error()),
            None => return Err(Tag::Null.value_error()),
        }

        let prf = match alg.oid {
            HMAC_WITH_SHA1_OID => Self::HmacWithSha1,
            HMAC_WITH_SHA224_OID => Self::HmacWithSha224,
            HMAC_WITH_SHA256_OID => Self::HmacWithSha256,
            HMAC_WITH_SHA384_OID => Self::HmacWithSha384,
            HMAC_WITH_SHA512_OID => Self::HmacWithSha512,
            oid => return Err(ErrorKind::OidUnknown { oid }.into()),
        };
        Ok(prf)
    }
}
impl<'a> From<Pbkdf2Prf> for AlgorithmIdentifierRef<'a> {
    /// Builds an algorithm identifier from the PRF's OID with NULL parameters.
    fn from(prf: Pbkdf2Prf) -> Self {
        Self {
            oid: prf.oid(),
            parameters: Some(der::asn1::Null.into()),
        }
    }
}
impl Encode for Pbkdf2Prf {
    /// Encoded length of the PRF when represented as an `AlgorithmIdentifierRef`.
    fn encoded_len(&self) -> der::Result<Length> {
        // The conversion is an infallible `From`, so there is no need to go
        // through `TryFrom` and `?`.
        AlgorithmIdentifierRef::from(*self).encoded_len()
    }

    /// Encodes the PRF as an `AlgorithmIdentifierRef`.
    fn encode(&self, writer: &mut impl Writer) -> der::Result<()> {
        AlgorithmIdentifierRef::from(*self).encode(writer)
    }
}
/// Parameters for the scrypt key derivation function.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct ScryptParams<'a> {
/// Salt bytes (borrowed; encoded as an octet string).
pub salt: &'a [u8],
/// Cost parameter; conversion to `scrypt::Params` requires a power of two.
pub cost_parameter: ScryptCost,
/// Block size.
pub block_size: u16,
/// Parallelization parameter.
pub parallelization: u16,
/// Optional key length.
pub key_length: Option<u16>,
}
impl<'a> ScryptParams<'a> {
    /// Error returned when scrypt parameters are rejected.
    #[cfg(feature = "pbes2")]
    const INVALID_ERR: Error = Error::AlgorithmParametersInvalid { oid: SCRYPT_OID };

    /// Builds [`ScryptParams`] from `scrypt::Params` and a salt, leaving
    /// `key_length` unset.
    ///
    /// The cost parameter is computed as `1 << log_n`.
    ///
    /// # Errors
    /// Returns `Error::AlgorithmParametersInvalid` (with the scrypt OID) when
    /// `r` or `p` does not fit in a `u16`.
    #[cfg(feature = "pbes2")]
    pub fn from_params_and_salt(params: scrypt::Params, salt: &'a [u8]) -> Result<Self> {
        let cost_parameter: ScryptCost = 1 << params.log_n();
        let block_size = u16::try_from(params.r()).map_err(|_| Self::INVALID_ERR)?;
        let parallelization = u16::try_from(params.p()).map_err(|_| Self::INVALID_ERR)?;

        Ok(Self {
            salt,
            cost_parameter,
            block_size,
            parallelization,
            key_length: None,
        })
    }
}
impl<'a> DecodeValue<'a> for ScryptParams<'a> {
    /// Decodes the value as an `AnyRef` and converts it through the
    /// `TryFrom<AnyRef>` impl.
    fn decode_value<R: Reader<'a>>(reader: &mut R, header: der::Header) -> der::Result<Self> {
        let any = AnyRef::decode_value(reader, header)?;
        Self::try_from(any)
    }
}
impl EncodeValue for ScryptParams<'_> {
    /// Sums the encoded lengths of all five fields.
    fn value_len(&self) -> der::Result<Length> {
        let salt_len = OctetStringRef::new(self.salt)?.encoded_len()?;
        let cost_len = self.cost_parameter.encoded_len()?;
        let block_size_len = self.block_size.encoded_len()?;
        let parallelization_len = self.parallelization.encoded_len()?;
        let key_length_len = self.key_length.encoded_len()?;

        salt_len + cost_len + block_size_len + parallelization_len + key_length_len
    }

    /// Writes salt, cost parameter, block size, parallelization and key
    /// length, in that order.
    fn encode_value(&self, writer: &mut impl Writer) -> der::Result<()> {
        OctetStringRef::new(self.salt)?.encode(writer)?;
        self.cost_parameter.encode(writer)?;
        self.block_size.encode(writer)?;
        self.parallelization.encode(writer)?;
        self.key_length.encode(writer)
    }
}
// Opts `ScryptParams` into `der`'s `Sequence` trait; the impl body is empty,
// so no trait items are overridden here.
impl<'a> Sequence<'a> for ScryptParams<'a> {}
impl<'a> TryFrom<AnyRef<'a>> for ScryptParams<'a> {
type Error = der::Error;
fn try_from(any: AnyRef<'a>) -> der::Result<Self> {
any.sequence(|reader| {
Ok(Self {
salt: OctetStringRef::decode(reader)?.as_bytes(),
cost_parameter: reader.decode()?,
block_size: reader.decode()?,
parallelization: reader.decode()?,
key_length: reader.decode()?,
})
})
}
}
#[cfg(feature = "pbes2")]
impl<'a> TryFrom<ScryptParams<'a>> for scrypt::Params {
    type Error = Error;

    /// Converts owned [`ScryptParams`] into `scrypt::Params` by delegating to
    /// the by-reference `TryFrom<&ScryptParams>` impl.
    fn try_from(params: ScryptParams<'a>) -> Result<scrypt::Params> {
        scrypt::Params::try_from(&params)
    }
}
#[cfg(feature = "pbes2")]
impl<'a> TryFrom<&ScryptParams<'a>> for scrypt::Params {
    type Error = Error;

    /// Converts [`ScryptParams`] into `scrypt::Params`.
    ///
    /// The cost parameter must be a power of two; its base-2 logarithm is
    /// passed to `scrypt::Params::new` together with the block size, the
    /// parallelization value and `scrypt::Params::RECOMMENDED_LEN`.
    ///
    /// # Errors
    /// Returns `Error::AlgorithmParametersInvalid` (with the scrypt OID) when
    /// the cost parameter is zero or not a power of two, or when
    /// `scrypt::Params::new` rejects the values.
    fn try_from(params: &ScryptParams<'a>) -> Result<scrypt::Params> {
        let n = params.cost_parameter;

        // `is_power_of_two` is false for zero, so a zero cost is rejected here
        // instead of underflowing a `bits - leading_zeros - 1` computation.
        if !n.is_power_of_two() {
            return Err(ScryptParams::INVALID_ERR);
        }

        // For a power of two, log2 equals the number of trailing zero bits.
        let log_n = u8::try_from(n.trailing_zeros()).map_err(|_| ScryptParams::INVALID_ERR)?;

        scrypt::Params::new(
            log_n,
            params.block_size.into(),
            params.parallelization.into(),
            scrypt::Params::RECOMMENDED_LEN,
        )
        .map_err(|_| ScryptParams::INVALID_ERR)
    }
}