#![forbid(unsafe_code, future_incompatible)]
#![deny(
missing_docs,
missing_debug_implementations,
missing_copy_implementations,
nonstandard_style,
unused_qualifications,
unused_import_braces,
unused_extern_crates,
trivial_casts,
trivial_numeric_casts
)]
#![cfg_attr(docsrs, feature(doc_cfg))]
mod error;
mod managers;
use std::{
collections::HashMap, convert::TryFrom, fmt, str::FromStr, time::SystemTime,
};
use http::{header::CACHE_CONTROL, request, response, StatusCode};
use http_cache_semantics::{AfterResponse, BeforeRequest, CachePolicy};
use serde::{Deserialize, Serialize};
use url::Url;
pub use error::{CacheError, Result};
#[cfg(feature = "manager-cacache")]
pub use managers::cacache::CACacheManager;
#[cfg(feature = "manager-moka")]
pub use managers::moka::MokaManager;
#[cfg(feature = "manager-moka")]
#[cfg_attr(docsrs, doc(cfg(feature = "manager-moka")))]
pub use moka::future::{Cache as MokaCache, CacheBuilder as MokaCacheBuilder};
/// Name of the response header that reports whether the returned response
/// was served from the cache (`HIT`) or fetched from the network (`MISS`).
pub const XCACHE: &str = "x-cache";
/// Name of the response header that reports whether a matching entry was
/// found during the cache lookup (`HIT`) or not (`MISS`).
pub const XCACHELOOKUP: &str = "x-cache-lookup";
/// Cache outcome written into the `x-cache` / `x-cache-lookup` headers.
#[derive(Debug, Copy, Clone)]
pub enum HitOrMiss {
    /// The cache supplied the response (or held a matching entry).
    HIT,
    /// The cache did not supply the response (or held no matching entry).
    MISS,
}

impl fmt::Display for HitOrMiss {
    /// Writes the upper-case header value (`HIT` or `MISS`).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let label = match self {
            Self::HIT => "HIT",
            Self::MISS => "MISS",
        };
        f.write_str(label)
    }
}
/// HTTP protocol version of a cached response.
///
/// Serialized by its textual form (e.g. `"HTTP/1.1"`); changing these tags
/// would likely break deserialization of previously stored entries.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Deserialize, Serialize)]
#[non_exhaustive]
pub enum HttpVersion {
    /// HTTP/0.9, serialized as `"HTTP/0.9"`.
    #[serde(rename = "HTTP/0.9")]
    Http09,
    /// HTTP/1.0, serialized as `"HTTP/1.0"`.
    #[serde(rename = "HTTP/1.0")]
    Http10,
    /// HTTP/1.1, serialized as `"HTTP/1.1"`.
    #[serde(rename = "HTTP/1.1")]
    Http11,
    /// HTTP/2, serialized as `"HTTP/2.0"`.
    #[serde(rename = "HTTP/2.0")]
    H2,
    /// HTTP/3, serialized as `"HTTP/3.0"`.
    #[serde(rename = "HTTP/3.0")]
    H3,
}
/// A serializable HTTP response, as stored in and returned from the cache.
///
/// Field order is part of the serialized form, so fields are not reordered.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct HttpResponse {
    /// Raw response body bytes.
    pub body: Vec<u8>,
    /// Response headers keyed by name; only one value is kept per name.
    pub headers: HashMap<String, String>,
    /// Numeric HTTP status code.
    pub status: u16,
    /// URL this response is associated with.
    pub url: Url,
    /// Protocol version of the response.
    pub version: HttpVersion,
}
impl HttpResponse {
    /// Builds [`response::Parts`] carrying this response's status, protocol
    /// version and headers, for use with the cache policy.
    ///
    /// # Errors
    ///
    /// Fails if the status code is invalid, or if a stored header name or
    /// value is not a valid HTTP header name or value.
    pub fn parts(&self) -> Result<response::Parts> {
        let mut converted = response::Builder::new()
            .status(self.status)
            // Preserve the stored version; the builder would otherwise
            // default to HTTP/1.1.
            .version(self.version.into())
            .body(())?;
        let headers = converted.headers_mut();
        for (name, value) in &self.headers {
            headers.insert(
                http::header::HeaderName::from_str(name)?,
                http::HeaderValue::from_str(value)?,
            );
        }
        Ok(converted.into_parts().0)
    }

    /// Returns the warn-code parsed from the first three characters of the
    /// `warning` header, or `None` if the header is absent or not numeric.
    #[must_use]
    pub fn warning_code(&self) -> Option<usize> {
        self.headers.get("warning").and_then(|hdr| {
            hdr.chars().take(3).collect::<String>().parse().ok()
        })
    }

    /// Sets the `warning` header, replacing any existing one, to
    /// `<code> <agent> "<message>" "<http-date now>"`.
    ///
    /// The warn-agent is the host of `url`. When the URL has no host (for
    /// example `data:` or host-less `file:` URLs), the pseudonym `-` is used
    /// instead of panicking.
    pub fn add_warning(&mut self, url: &Url, code: usize, message: &str) {
        self.headers.insert(
            "warning".to_string(),
            format!(
                "{} {} {:?} \"{}\"",
                code,
                url.host_str().unwrap_or("-"),
                message,
                httpdate::fmt_http_date(SystemTime::now())
            ),
        );
    }

    /// Removes the `warning` header, if present.
    pub fn remove_warning(&mut self) {
        self.headers.remove("warning");
    }

    /// Copies every header in `parts` into this response.
    ///
    /// Existing entries with the same name are overwritten. For header names
    /// that repeat in `parts`, the last value wins.
    ///
    /// # Errors
    ///
    /// Fails if a header value is not representable as a visible-ASCII
    /// string.
    pub fn update_headers(&mut self, parts: &response::Parts) -> Result<()> {
        for (name, value) in &parts.headers {
            self.headers
                .insert(name.as_str().to_string(), value.to_str()?.to_string());
        }
        Ok(())
    }

    /// Returns `true` if the stored `cache-control` header contains
    /// `must-revalidate`, compared case-insensitively.
    ///
    /// The header name is looked up in lower case only.
    #[must_use]
    pub fn must_revalidate(&self) -> bool {
        self.headers.get(CACHE_CONTROL.as_str()).map_or(false, |val| {
            val.to_lowercase().contains("must-revalidate")
        })
    }

    /// Sets the [`XCACHE`] header to `hit_or_miss`.
    pub fn cache_status(&mut self, hit_or_miss: HitOrMiss) {
        self.headers.insert(XCACHE.to_string(), hit_or_miss.to_string());
    }

    /// Sets the [`XCACHELOOKUP`] header to `hit_or_miss`.
    pub fn cache_lookup_status(&mut self, hit_or_miss: HitOrMiss) {
        self.headers.insert(XCACHELOOKUP.to_string(), hit_or_miss.to_string());
    }
}
/// Storage backend used by [`HttpCache`] to persist responses and their
/// cache policies, keyed by request method and URL.
#[async_trait::async_trait]
pub trait CacheManager: Send + Sync + 'static {
    /// Returns the response and policy stored for `method` and `url`, or
    /// `Ok(None)` when nothing is stored.
    async fn get(
        &self,
        method: &str,
        url: &Url,
    ) -> Result<Option<(HttpResponse, CachePolicy)>>;
    /// Stores `res` with `policy` for `method` and `url`.
    ///
    /// Returns the response that [`HttpCache`] then hands back to its
    /// caller.
    async fn put(
        &self,
        method: &str,
        url: &Url,
        res: HttpResponse,
        policy: CachePolicy,
    ) -> Result<HttpResponse>;
    /// Removes the entry stored for `method` and `url`.
    async fn delete(&self, method: &str, url: &Url) -> Result<()>;
}
/// Adapter through which [`HttpCache`] inspects, adjusts and sends a single
/// client request.
#[async_trait::async_trait]
pub trait Middleware: Send {
    /// Returns `true` for GET and HEAD requests.
    ///
    /// Only those requests are looked up in, or written to, the cache.
    fn is_method_get_head(&self) -> bool;
    /// Builds the [`CachePolicy`] for this request and `response`.
    fn policy(&self, response: &HttpResponse) -> Result<CachePolicy>;
    /// Like [`Middleware::policy`], but with explicit [`CacheOptions`].
    fn policy_with_options(
        &self,
        response: &HttpResponse,
        options: CacheOptions,
    ) -> Result<CachePolicy>;
    /// Applies the headers in `parts` to the outgoing request.
    ///
    /// Used to send the revalidation request produced by the cache policy.
    fn update_headers(&mut self, parts: &request::Parts) -> Result<()>;
    /// Prepares the request to bypass stored freshness, for example by adding
    /// a `no-cache` directive.
    ///
    /// Called for [`CacheMode::NoCache`] when an entry exists.
    fn force_no_cache(&mut self) -> Result<()>;
    /// Returns the request head as [`request::Parts`]; it is passed to the
    /// cache policy.
    fn parts(&self) -> Result<request::Parts>;
    /// Returns the request URL, which is part of the cache key.
    fn url(&self) -> Result<Url>;
    /// Returns the request method; [`HttpCache`] upper-cases it before using
    /// it as part of the cache key.
    fn method(&self) -> Result<String>;
    /// Sends the request to the network and returns the response.
    async fn remote_fetch(&mut self) -> Result<HttpResponse>;
}
/// Controls how [`HttpCache::run`] uses and populates the cache.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheMode {
    /// Serves fresh entries directly, revalidates stale ones with the origin,
    /// and stores cacheable network responses.
    Default,
    /// Bypasses the cache: no lookup is performed and responses are not
    /// stored.
    NoStore,
    /// Skips the lookup and always fetches from the network; the response is
    /// not stored.
    Reload,
    /// Always fetches from the network.
    ///
    /// When an entry exists, the request is first passed through
    /// [`Middleware::force_no_cache`]. Cacheable responses are stored.
    NoCache,
    /// Serves any stored entry regardless of freshness, adding a `112`
    /// warning; on a miss, fetches from the network.
    ForceCache,
    /// Like `ForceCache`, but a miss yields a synthetic `504` response
    /// instead of a network request.
    OnlyIfCached,
}
impl TryFrom<http::Version> for HttpVersion {
    type Error = CacheError;

    /// Maps an `http` crate version onto [`HttpVersion`].
    ///
    /// Any version this crate does not model yields
    /// [`CacheError::BadVersion`].
    fn try_from(value: http::Version) -> Result<Self> {
        match value {
            http::Version::HTTP_09 => Ok(Self::Http09),
            http::Version::HTTP_10 => Ok(Self::Http10),
            http::Version::HTTP_11 => Ok(Self::Http11),
            http::Version::HTTP_2 => Ok(Self::H2),
            http::Version::HTTP_3 => Ok(Self::H3),
            _ => Err(CacheError::BadVersion),
        }
    }
}
impl From<HttpVersion> for http::Version {
    /// Maps every [`HttpVersion`] variant onto its `http` crate equivalent.
    fn from(value: HttpVersion) -> Self {
        match value {
            HttpVersion::Http09 => http::Version::HTTP_09,
            HttpVersion::Http10 => http::Version::HTTP_10,
            HttpVersion::Http11 => http::Version::HTTP_11,
            HttpVersion::H2 => http::Version::HTTP_2,
            HttpVersion::H3 => http::Version::HTTP_3,
        }
    }
}
#[cfg(feature = "http-types")]
impl TryFrom<http_types::Version> for HttpVersion {
    type Error = CacheError;

    /// Maps an `http-types` version onto [`HttpVersion`].
    ///
    /// Any unrecognized version yields [`CacheError::BadVersion`].
    fn try_from(value: http_types::Version) -> Result<Self> {
        match value {
            http_types::Version::Http0_9 => Ok(Self::Http09),
            http_types::Version::Http1_0 => Ok(Self::Http10),
            http_types::Version::Http1_1 => Ok(Self::Http11),
            http_types::Version::Http2_0 => Ok(Self::H2),
            http_types::Version::Http3_0 => Ok(Self::H3),
            _ => Err(CacheError::BadVersion),
        }
    }
}
#[cfg(feature = "http-types")]
impl From<HttpVersion> for http_types::Version {
    /// Maps every [`HttpVersion`] variant onto its `http-types` equivalent.
    fn from(value: HttpVersion) -> Self {
        match value {
            HttpVersion::Http09 => http_types::Version::Http0_9,
            HttpVersion::Http10 => http_types::Version::Http1_0,
            HttpVersion::Http11 => http_types::Version::Http1_1,
            HttpVersion::H2 => http_types::Version::Http2_0,
            HttpVersion::H3 => http_types::Version::Http3_0,
        }
    }
}
pub use http_cache_semantics::CacheOptions;
/// Caching layer that drives a [`Middleware`] request through a
/// [`CacheManager`] storage backend.
#[derive(Debug, Clone)]
pub struct HttpCache<T: CacheManager> {
    /// How stored entries are used and when responses are stored; see
    /// [`CacheMode`].
    pub mode: CacheMode,
    /// Storage backend used to look up, store and delete entries.
    pub manager: T,
    /// Options used when building a [`CachePolicy`].
    ///
    /// When `None`, [`Middleware::policy`] is used instead of
    /// [`Middleware::policy_with_options`].
    pub options: Option<CacheOptions>,
}
// NOTE(review): `run` is public and calls both private helpers, so this
// allow looks unnecessary; kept pending confirmation of why it was added.
#[allow(dead_code)]
impl<T: CacheManager> HttpCache<T> {
    /// Handles one request through the cache according to
    /// [`HttpCache::mode`].
    ///
    /// Requests that are not GET/HEAD, or that run under
    /// [`CacheMode::NoStore`] or [`CacheMode::Reload`], skip the lookup and
    /// go straight to the network.
    ///
    /// When a stored entry is found, behavior depends on the mode:
    ///
    /// * `Default`: serves the entry if the policy says it is fresh,
    ///   otherwise revalidates it with the origin.
    /// * `NoCache`: calls [`Middleware::force_no_cache`] and fetches from the
    ///   network.
    /// * `ForceCache` / `OnlyIfCached`: serves the entry with a
    ///   `112 Disconnected operation` warning.
    ///
    /// On a miss, `OnlyIfCached` yields a synthetic `504` response; every
    /// other mode fetches from the network.
    ///
    /// # Errors
    ///
    /// Propagates errors from the middleware, the cache manager, and
    /// header/status conversion.
    pub async fn run(
        &self,
        mut middleware: impl Middleware,
    ) -> Result<HttpResponse> {
        let is_cacheable = middleware.is_method_get_head()
            && self.mode != CacheMode::NoStore
            && self.mode != CacheMode::Reload;
        if !is_cacheable {
            return self.remote_fetch(&mut middleware).await;
        }
        let method = middleware.method()?.to_uppercase();
        let url = middleware.url()?;
        if let Some((mut res, policy)) = self.manager.get(&method, &url).await?
        {
            res.cache_lookup_status(HitOrMiss::HIT);
            // RFC 7234 section 5.5: 1xx warn-codes must not outlive
            // validation, so they are stripped from the stored copy before
            // reuse.
            if let Some(warning_code) = res.warning_code() {
                if (100..200).contains(&warning_code) {
                    res.remove_warning();
                }
            }
            match self.mode {
                CacheMode::Default => {
                    self.conditional_fetch(middleware, res, policy).await
                }
                CacheMode::NoCache => {
                    middleware.force_no_cache()?;
                    let mut res = self.remote_fetch(&mut middleware).await?;
                    // An entry existed, so the lookup itself was a hit even
                    // though the body came from the network.
                    res.cache_lookup_status(HitOrMiss::HIT);
                    Ok(res)
                }
                CacheMode::ForceCache | CacheMode::OnlyIfCached => {
                    // Served without contacting the origin, regardless of
                    // freshness.
                    res.add_warning(
                        &res.url.clone(),
                        112,
                        "Disconnected operation",
                    );
                    res.cache_status(HitOrMiss::HIT);
                    Ok(res)
                }
                // NoStore / Reload were filtered out above.
                _ => self.remote_fetch(&mut middleware).await,
            }
        } else {
            match self.mode {
                CacheMode::OnlyIfCached => {
                    // Nothing stored, and this mode must not use the network.
                    let mut res = HttpResponse {
                        body: b"GatewayTimeout".to_vec(),
                        headers: HashMap::default(),
                        status: 504,
                        url: middleware.url()?,
                        version: HttpVersion::Http11,
                    };
                    res.cache_status(HitOrMiss::MISS);
                    res.cache_lookup_status(HitOrMiss::MISS);
                    Ok(res)
                }
                _ => self.remote_fetch(&mut middleware).await,
            }
        }
    }

    /// Fetches the response from the network and tags it as a cache MISS.
    ///
    /// The response is stored when all of the following hold, and the value
    /// returned by [`CacheManager::put`] is then returned:
    ///
    /// * the request is GET/HEAD;
    /// * the mode is neither `NoStore` nor `Reload`;
    /// * the status is `200`;
    /// * the policy is storable.
    ///
    /// For any other method, the GET entry for the URL is deleted on a
    /// best-effort basis.
    async fn remote_fetch(
        &self,
        middleware: &mut impl Middleware,
    ) -> Result<HttpResponse> {
        let mut res = middleware.remote_fetch().await?;
        res.cache_status(HitOrMiss::MISS);
        res.cache_lookup_status(HitOrMiss::MISS);
        let policy = match self.options {
            Some(options) => middleware.policy_with_options(&res, options)?,
            None => middleware.policy(&res)?,
        };
        let is_get_head = middleware.is_method_get_head();
        let is_cacheable = is_get_head
            && self.mode != CacheMode::NoStore
            && self.mode != CacheMode::Reload
            && res.status == 200
            && policy.is_storable();
        let url = middleware.url()?;
        let method = middleware.method()?.to_uppercase();
        if is_cacheable {
            return self.manager.put(&method, &url, res, policy).await;
        }
        if !is_get_head {
            // Presumably invalidation after a state-changing request (cf.
            // RFC 7234 section 4.4); failures are deliberately ignored.
            self.manager.delete("GET", &url).await.ok();
        }
        Ok(res)
    }

    /// Serves a stored entry in [`CacheMode::Default`], revalidating it with
    /// the origin when the policy reports it as stale.
    ///
    /// * Fresh: the stored response is returned, with headers updated from
    ///   the policy.
    /// * Stale: the request is sent, carrying the policy's revalidation
    ///   headers when `matches` is true. The outcome then depends on the
    ///   origin's reply:
    ///   * `304`: the stored entry is refreshed with the new policy and
    ///     re-stored.
    ///   * `200`: the new response is stored only if its policy is
    ///     storable; either way it is returned.
    ///   * `5xx` with `must-revalidate`: the stored entry is served with a
    ///     `111` warning.
    ///   * any other status: the stored entry is served unchanged.
    ///   * network error: the error is propagated if the entry has
    ///     `must-revalidate`; otherwise the stored entry is served with a
    ///     `111` warning.
    async fn conditional_fetch(
        &self,
        mut middleware: impl Middleware,
        mut cached_res: HttpResponse,
        mut policy: CachePolicy,
    ) -> Result<HttpResponse> {
        let before_req =
            policy.before_request(&middleware.parts()?, SystemTime::now());
        match before_req {
            BeforeRequest::Fresh(parts) => {
                cached_res.update_headers(&parts)?;
                cached_res.cache_status(HitOrMiss::HIT);
                cached_res.cache_lookup_status(HitOrMiss::HIT);
                return Ok(cached_res);
            }
            BeforeRequest::Stale { request: parts, matches } => {
                if matches {
                    middleware.update_headers(&parts)?;
                }
            }
        }
        let req_url = middleware.url()?;
        match middleware.remote_fetch().await {
            Ok(mut cond_res) => {
                let status = StatusCode::from_u16(cond_res.status)?;
                // NOTE(review): the stale entry is served on a 5xx only when
                // it carries `must-revalidate`. That looks inverted relative
                // to the `Err` branch below, and to RFC 7234 section
                // 5.2.2.1. Left unchanged because callers/tests may rely on
                // it; confirm the intended behavior.
                if status.is_server_error() && cached_res.must_revalidate() {
                    cached_res.add_warning(
                        &req_url,
                        111,
                        "Revalidation failed",
                    );
                    cached_res.cache_status(HitOrMiss::HIT);
                    Ok(cached_res)
                } else if cond_res.status == 304 {
                    let after_res = policy.after_response(
                        &middleware.parts()?,
                        &cond_res.parts()?,
                        SystemTime::now(),
                    );
                    match after_res {
                        AfterResponse::Modified(new_policy, parts)
                        | AfterResponse::NotModified(new_policy, parts) => {
                            policy = new_policy;
                            cached_res.update_headers(&parts)?;
                        }
                    }
                    cached_res.cache_status(HitOrMiss::HIT);
                    cached_res.cache_lookup_status(HitOrMiss::HIT);
                    let method = middleware.method()?.to_uppercase();
                    let res = self
                        .manager
                        .put(&method, &req_url, cached_res, policy)
                        .await?;
                    Ok(res)
                } else if cond_res.status == 200 {
                    let policy = match self.options {
                        Some(options) => middleware
                            .policy_with_options(&cond_res, options)?,
                        None => middleware.policy(&cond_res)?,
                    };
                    cond_res.cache_status(HitOrMiss::MISS);
                    cond_res.cache_lookup_status(HitOrMiss::HIT);
                    // The origin may now forbid storing (e.g. `no-store`).
                    // Mirror the `is_storable` check in `remote_fetch`
                    // instead of storing unconditionally.
                    if policy.is_storable() {
                        let method = middleware.method()?.to_uppercase();
                        let res = self
                            .manager
                            .put(&method, &req_url, cond_res, policy)
                            .await?;
                        Ok(res)
                    } else {
                        Ok(cond_res)
                    }
                } else {
                    cached_res.cache_status(HitOrMiss::HIT);
                    Ok(cached_res)
                }
            }
            Err(e) => {
                if cached_res.must_revalidate() {
                    Err(e)
                } else {
                    cached_res.add_warning(
                        &req_url,
                        111,
                        "Revalidation failed",
                    );
                    cached_res.cache_status(HitOrMiss::HIT);
                    Ok(cached_res)
                }
            }
        }
    }
}