use std::fmt;
use serde::{de, ser, Deserialize, Serialize};
use crate::http::private::cookie::Key;
use crate::request::{Outcome, Request, FromRequest};
/// Classifies the origin of a secret key.
///
/// Presumably mirrors the three states `SecretKey`'s `Display` impl reports
/// (`[zero]`, `[generated]`, `[provided]`).
///
/// NOTE(review): this enum is not referenced in the visible part of this file;
/// `SecretKey` tracks origin with a `provided: bool` field instead. Confirm
/// whether it is used elsewhere or is dead code.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
enum Kind {
/// Presumably an all-zero key (cf. `SecretKey::zero`).
Zero,
/// Presumably a randomly generated key (cf. `SecretKey::generate`).
Generated,
/// Presumably a key built from caller-supplied bytes (cf. `SecretKey::from`).
Provided
}
/// Secret key material wrapping a cookie [`Key`].
///
/// The `Display` and `Debug` impls in this file print only a label
/// (`[zero]`, `[provided]` or `[generated]`), never the key bytes.
/// Equality compares only the wrapped key, not the `provided` flag.
#[derive(Clone)]
#[cfg_attr(nightly, doc(cfg(feature = "secrets")))]
pub struct SecretKey {
/// The wrapped key material.
pub(crate) key: Key,
/// `true` when the key was built from caller-supplied bytes (`from`,
/// `derive_from`); `false` for the zero key and for generated keys.
provided: bool,
}
impl SecretKey {
    /// Returns the placeholder key made entirely of zero bytes.
    ///
    /// The result is flagged as not provided.
    pub(crate) fn zero() -> SecretKey {
        let key = Key::from(&[0; 64]);
        SecretKey { provided: false, key }
    }

    /// Wraps `master` directly as key material and marks the key as provided.
    ///
    /// NOTE(review): the deserializer in this file only calls this with at least
    /// 64 bytes; `Key::from` is expected to panic on shorter input. Confirm
    /// against the cookie crate's `Key::from` documentation.
    pub fn from(master: &[u8]) -> SecretKey {
        let key = Key::from(master);
        SecretKey { provided: true, key }
    }

    /// Derives key material from `material` and marks the key as provided.
    ///
    /// NOTE(review): the deserializer in this file only calls this with 32..64
    /// bytes; confirm the minimum length `Key::derive_from` accepts.
    pub fn derive_from(material: &[u8]) -> SecretKey {
        let key = Key::derive_from(material);
        SecretKey { provided: true, key }
    }

    /// Attempts to generate a fresh key, flagged as not provided.
    ///
    /// Returns `None` when `Key::try_generate` fails.
    pub fn generate() -> Option<SecretKey> {
        Key::try_generate().map(|key| SecretKey { provided: false, key })
    }

    /// Returns `true` if this key equals the all-zero placeholder key.
    pub fn is_zero(&self) -> bool {
        *self == Self::zero()
    }

    /// Returns `true` only for caller-supplied keys that are not all zeroes.
    pub fn is_provided(&self) -> bool {
        if !self.provided {
            return false;
        }

        !self.is_zero()
    }

    /// Serializes a 32-byte run of zeroes, regardless of `self`.
    ///
    /// The deserializer in this file maps all-zero input back to `zero()`, so
    /// the real key material is never emitted.
    pub(crate) fn serialize_zero<S>(&self, ser: S) -> Result<S::Ok, S::Error>
        where S: ser::Serializer
    {
        const ZEROES: [u8; 32] = [0; 32];
        ser.serialize_bytes(&ZEROES[..])
    }
}
impl PartialEq for SecretKey {
    /// Compares only the wrapped key material; the `provided` flag is not
    /// taken into account.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.key, &other.key);
        lhs == rhs
    }
}
#[crate::async_trait]
impl<'r> FromRequest<'r> for &'r SecretKey {
    /// This guard never fails.
    type Error = std::convert::Infallible;

    /// Borrows the secret key from the running application's configuration.
    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
        let config = req.rocket().config();
        Outcome::Success(&config.secret_key)
    }
}
impl<'de> Deserialize<'de> for SecretKey {
    /// Deserializes a secret key from a string or a byte sequence.
    ///
    /// Accepted forms:
    /// * a string of length 44 or 88, decoded as base64;
    /// * a string of length 64, decoded as hex;
    /// * raw bytes or a sequence of bytes.
    ///
    /// Once decoded, the bytes must be at least 32 long. All-zero input yields
    /// `SecretKey::zero()`. Input of 64 or more bytes is used directly through
    /// `SecretKey::from`. Anything shorter (32..64 bytes) goes through
    /// `SecretKey::derive_from`.
    fn deserialize<D: de::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
        use {binascii::{b64decode, hex2bin}, de::Unexpected::Str};

        struct Visitor;

        impl<'de> de::Visitor<'de> for Visitor {
            type Value = SecretKey;

            fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.write_str("256-bit base64 or hex string, or 32-byte slice")
            }

            fn visit_str<E: de::Error>(self, val: &str) -> Result<SecretKey, E> {
                let e = |s| E::invalid_value(Str(s), &"256-bit base64 or hex");
                // 96 bytes covers the largest accepted decode: 88 base64 chars -> at most 66 bytes.
                let mut buf = [0u8; 96];
                let bytes = match val.len() {
                    44 | 88 => b64decode(val.as_bytes(), &mut buf).map_err(|_| e(val))?,
                    64 => hex2bin(val.as_bytes(), &mut buf).map_err(|_| e(val))?,
                    n => return Err(E::invalid_length(n, &"44 or 88 for base64, 64 for hex")),
                };

                self.visit_bytes(bytes)
            }

            fn visit_bytes<E: de::Error>(self, bytes: &[u8]) -> Result<SecretKey, E> {
                if bytes.len() < 32 {
                    Err(E::invalid_length(bytes.len(), &"at least 32"))
                } else if bytes.iter().all(|b| *b == 0) {
                    // All zeroes round-trips with `SecretKey::serialize_zero`.
                    Ok(SecretKey::zero())
                } else if bytes.len() >= 64 {
                    Ok(SecretKey::from(bytes))
                } else {
                    Ok(SecretKey::derive_from(bytes))
                }
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
                where A: de::SeqAccess<'de>
            {
                // The size hint is supplied by the input format (e.g. a length
                // prefix) and may be arbitrarily large. Cap the up-front
                // allocation so a bogus hint cannot force a huge reservation.
                // A full key is 64 bytes, and longer sequences still work
                // because the Vec grows as elements arrive.
                const MAX_PREALLOC: usize = 64;

                let capacity = seq.size_hint().unwrap_or(0).min(MAX_PREALLOC);
                let mut bytes = Vec::<u8>::with_capacity(capacity);
                while let Some(byte) = seq.next_element()? {
                    bytes.push(byte);
                }

                self.visit_bytes(&bytes)
            }
        }

        de.deserialize_any(Visitor)
    }
}
impl fmt::Display for SecretKey {
    /// Writes a redacted label describing the key, never the key bytes.
    ///
    /// The zero key prints `[zero]`. Any other key prints `[provided]` or
    /// `[generated]`, depending on its `provided` flag.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match (self.is_zero(), self.provided) {
            (true, _) => "[zero]",
            (false, true) => "[provided]",
            (false, false) => "[generated]",
        };

        f.write_str(label)
    }
}
impl fmt::Debug for SecretKey {
    /// Delegates to the `Display` impl, so debug output is redacted as well.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self, f)
    }
}