use std::time::{Duration, SystemTime};
use std::{borrow::Cow, path::Path};
use futures::FutureExt;
use reqwest::{Request, Response};
use reqwest_retry::RetryPolicy;
use rkyv::util::AlignedVec;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use tracing::{Instrument, debug, info_span, instrument, trace, warn};
use uv_cache::{CacheEntry, Freshness};
use uv_fs::write_atomic;
use uv_redacted::DisplaySafeUrl;
use crate::BaseClient;
use crate::base_client::is_transient_network_error;
use crate::error::ProblemDetails;
use crate::{
Error, ErrorKind,
httpcache::{AfterResponse, BeforeRequest, CachePolicy, CachePolicyBuilder},
rkyvutil::OwnedArchive,
};
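/// Try to parse a `ProblemDetails` payload from the response body, logging a
/// warning and returning `None` if the body cannot be read or parsed.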
async fn extract_problem_details(response: Response) -> Option<ProblemDetails> {
match response.bytes().await {
Ok(bytes) => match serde_json::from_slice(&bytes) {
Ok(details) => Some(details),
Err(err) => {
warn!("Failed to parse problem details: {err}");
None
}
},
Err(err) => {
warn!("Failed to read response body for problem details: {err}");
None
}
}
}
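/// A payload that can be stored in and loaded from the cache.
///
/// Implementations define how the payload is (de)serialized on disk and which
/// value is ultimately handed back to the caller.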
pub trait Cacheable: Sized {
type Target: Send + 'static;
fn from_aligned_bytes(bytes: AlignedVec) -> Result<Self::Target, Error>;
fn to_bytes(&self) -> Result<Cow<'_, [u8]>, Error>;
fn into_target(self) -> Self::Target;
}
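/// Wraps a serde payload so it can be cached as MessagePack via `rmp_serde`.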
#[derive(Debug, Deserialize, Serialize)]
#[serde(transparent)]
pub(crate) struct SerdeCacheable<T> {
inner: T,
}
impl<T: Serialize + DeserializeOwned + Send + 'static> Cacheable for SerdeCacheable<T> {
type Target = T;
fn from_aligned_bytes(bytes: AlignedVec) -> Result<T, Error> {
Ok(rmp_serde::from_slice::<T>(&bytes).map_err(ErrorKind::Decode)?)
}
fn to_bytes(&self) -> Result<Cow<'_, [u8]>, Error> {
Ok(Cow::from(
rmp_serde::to_vec(&self.inner).map_err(ErrorKind::Encode)?,
))
}
fn into_target(self) -> Self::Target {
self.inner
}
}
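// `OwnedArchive` payloads are cached directly as their underlying rkyv-archived bytes.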
impl<A> Cacheable for OwnedArchive<A>
where
A: rkyv::Archive + for<'a> rkyv::Serialize<crate::rkyvutil::Serializer<'a>> + Send + 'static,
A::Archived: rkyv::Portable
+ rkyv::Deserialize<A, crate::rkyvutil::Deserializer>
+ for<'a> rkyv::bytecheck::CheckBytes<crate::rkyvutil::Validator<'a>>,
{
type Target = Self;
fn from_aligned_bytes(bytes: AlignedVec) -> Result<Self, Error> {
Self::new(bytes)
}
fn to_bytes(&self) -> Result<Cow<'_, [u8]>, Error> {
Ok(Cow::from(Self::as_bytes(self)))
}
fn into_target(self) -> Self::Target {
self
}
}
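/// Either an error from the underlying client or an error from the caller-supplied
/// response callback, in both cases tracking how many retries have already been
/// attempted.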
pub enum CachedClientError<CallbackError: std::error::Error + 'static> {
Client {
retries: Option<u32>,
err: Error,
},
Callback {
retries: Option<u32>,
err: CallbackError,
},
}
impl<CallbackError: std::error::Error + 'static> CachedClientError<CallbackError> {
fn with_retries(self, retries: u32) -> Self {
match self {
Self::Client {
retries: existing_retries,
err,
} => Self::Client {
retries: Some(existing_retries.unwrap_or_default() + retries),
err,
},
Self::Callback {
retries: existing_retries,
err,
} => Self::Callback {
retries: Some(existing_retries.unwrap_or_default() + retries),
err,
},
}
}
fn retries(&self) -> Option<u32> {
match self {
Self::Client { retries, .. } => *retries,
Self::Callback { retries, .. } => *retries,
}
}
fn error(&self) -> &(dyn std::error::Error + 'static) {
match self {
Self::Client { err, .. } => err,
Self::Callback { err, .. } => err,
}
}
}
impl<CallbackError: std::error::Error + 'static> From<Error> for CachedClientError<CallbackError> {
fn from(error: Error) -> Self {
Self::Client {
retries: None,
err: error,
}
}
}
impl<CallbackError: std::error::Error + 'static> From<ErrorKind>
for CachedClientError<CallbackError>
{
fn from(error: ErrorKind) -> Self {
Self::Client {
retries: None,
err: error.into(),
}
}
}
impl<E: Into<Self> + std::error::Error + 'static> From<CachedClientError<E>> for Error {
fn from(error: CachedClientError<E>) -> Self {
match error {
CachedClientError::Client {
retries: Some(retries),
err,
} => Self::new(err.into_kind(), retries),
CachedClientError::Client { retries: None, err } => err,
CachedClientError::Callback {
retries: Some(retries),
err,
} => Self::new(err.into().into_kind(), retries),
CachedClientError::Callback { retries: None, err } => err.into(),
}
}
}
#[derive(Debug, Clone, Copy)]
pub enum CacheControl<'a> {
    /// Apply the cached response's own caching semantics without modification.
    None,
    /// Revalidate the cached response with the server before using it (sends
    /// `Cache-Control: no-cache` on the request).
    MustRevalidate,
    /// Reuse a cached response even if the cache policy considers it stale.
    AllowStale,
    /// Override the response's `Cache-Control` header with the given value
    /// before the cache policy is computed.
    Override(&'a str),
}
impl From<Freshness> for CacheControl<'_> {
fn from(value: Freshness) -> Self {
match value {
Freshness::Fresh => Self::None,
Freshness::Stale => Self::MustRevalidate,
Freshness::Missing => Self::None,
}
}
}
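/// An HTTP client that caches responses on disk.
///
/// Wraps a [`BaseClient`] and pairs each cached payload with an HTTP cache
/// policy, so responses can be reused, revalidated, or refreshed according to
/// standard caching semantics.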
#[derive(Debug, Clone)]
pub struct CachedClient(BaseClient);
impl CachedClient {
pub fn new(client: BaseClient) -> Self {
Self(client)
}
pub fn uncached(&self) -> &BaseClient {
&self.0
}
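    /// Fetch a serde-serializable payload, using the cache when possible.
    ///
    /// A thin wrapper around [`Self::get_cacheable`] that stores the payload as
    /// MessagePack via [`SerdeCacheable`].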
#[instrument(skip_all)]
pub async fn get_serde<
Payload: Serialize + DeserializeOwned + Send + 'static,
CallBackError: std::error::Error + 'static,
Callback: AsyncFn(Response) -> Result<Payload, CallBackError>,
>(
&self,
req: Request,
cache_entry: &CacheEntry,
cache_control: CacheControl<'_>,
response_callback: Callback,
) -> Result<Payload, CachedClientError<CallBackError>> {
let payload = self
.get_cacheable(req, cache_entry, cache_control, async |resp| {
let payload = response_callback(resp).await?;
Ok(SerdeCacheable { inner: payload })
})
.await?;
Ok(payload)
}
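    /// Fetch a [`Cacheable`] payload, either from the cache or from the network.
    ///
    /// A fresh cache entry is deserialized directly; a stale one is revalidated
    /// against the server; anything else (including a broken cache entry) results
    /// in a fresh request whose payload is written back to the cache via
    /// `response_callback`.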
#[instrument(skip_all)]
pub async fn get_cacheable<
Payload: Cacheable,
CallBackError: std::error::Error + 'static,
Callback: AsyncFn(Response) -> Result<Payload, CallBackError>,
>(
&self,
req: Request,
cache_entry: &CacheEntry,
cache_control: CacheControl<'_>,
response_callback: Callback,
) -> Result<Payload::Target, CachedClientError<CallBackError>> {
let fresh_req = req.try_clone().expect("HTTP request must be cloneable");
let cached_response = if let Some(cached) = Self::read_cache(cache_entry).await {
self.send_cached(req, cache_control, cached)
.boxed_local()
.await?
} else {
debug!("No cache entry for: {}", req.url());
let (response, cache_policy) = self.fresh_request(req, cache_control).await?;
CachedResponse::ModifiedOrNew {
response,
cache_policy,
}
};
match cached_response {
CachedResponse::FreshCache(cached) => match Payload::from_aligned_bytes(cached.data) {
Ok(payload) => Ok(payload),
Err(err) => {
warn!(
"Broken fresh cache entry (for payload) at {}, removing: {err}",
cache_entry.path().display()
);
self.resend_and_heal_cache(
fresh_req,
cache_entry,
cache_control,
response_callback,
)
.await
}
},
CachedResponse::NotModified { cached, new_policy } => {
let refresh_cache =
info_span!("refresh_cache", file = %cache_entry.path().display());
async {
let data_with_cache_policy_bytes =
DataWithCachePolicy::serialize(&new_policy, &cached.data)?;
write_atomic(cache_entry.path(), data_with_cache_policy_bytes)
.await
.map_err(ErrorKind::CacheWrite)?;
match Payload::from_aligned_bytes(cached.data) {
Ok(payload) => Ok(payload),
Err(err) => {
warn!(
"Broken fresh cache entry after revalidation \
(for payload) at {}, removing: {err}",
cache_entry.path().display()
);
self.resend_and_heal_cache(
fresh_req,
cache_entry,
cache_control,
response_callback,
)
.await
}
}
}
.instrument(refresh_cache)
.await
}
CachedResponse::ModifiedOrNew {
response,
cache_policy,
} => {
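                // A fresh (non-conditional) request should never yield a 304; if the
                // server returns one anyway, there is no body to parse, so drop the
                // cache entry and re-request.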
if response.status() == http::StatusCode::NOT_MODIFIED {
warn!("Server returned unusable 304 for: {}", fresh_req.url());
self.resend_and_heal_cache(
fresh_req,
cache_entry,
cache_control,
response_callback,
)
.await
} else {
self.run_response_callback(
cache_entry,
cache_policy,
response,
response_callback,
)
.await
}
}
}
}
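    /// Send the request unconditionally, ignoring any existing cache entry, but
    /// still run the callback and write the result (plus its cache policy) back
    /// to the cache.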
pub async fn skip_cache<
Payload: Serialize + DeserializeOwned + Send + 'static,
CallBackError: std::error::Error + 'static,
Callback: AsyncFnOnce(Response) -> Result<Payload, CallBackError>,
>(
&self,
req: Request,
cache_entry: &CacheEntry,
cache_control: CacheControl<'_>,
response_callback: Callback,
) -> Result<Payload, CachedClientError<CallBackError>> {
let (response, cache_policy) = self.fresh_request(req, cache_control).await?;
let payload = self
.run_response_callback(cache_entry, cache_policy, response, async |resp| {
let payload = response_callback(resp).await?;
Ok(SerdeCacheable { inner: payload })
})
.await?;
Ok(payload)
}
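    /// Remove the broken cache entry, re-send the request, and rebuild the cache
    /// entry from the fresh response.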
async fn resend_and_heal_cache<
Payload: Cacheable,
CallBackError: std::error::Error + 'static,
Callback: AsyncFnOnce(Response) -> Result<Payload, CallBackError>,
>(
&self,
req: Request,
cache_entry: &CacheEntry,
cache_control: CacheControl<'_>,
response_callback: Callback,
) -> Result<Payload::Target, CachedClientError<CallBackError>> {
let _ = fs_err::tokio::remove_file(&cache_entry.path()).await;
let (response, cache_policy) = self.fresh_request(req, cache_control).await?;
self.run_response_callback(cache_entry, cache_policy, response, response_callback)
.await
}
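    /// Run the callback on the response and, if a storable cache policy was
    /// produced, persist the serialized payload together with that policy at the
    /// cache entry.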
async fn run_response_callback<
Payload: Cacheable,
CallBackError: std::error::Error + 'static,
Callback: AsyncFnOnce(Response) -> Result<Payload, CallBackError>,
>(
&self,
cache_entry: &CacheEntry,
cache_policy: Option<Box<CachePolicy>>,
response: Response,
response_callback: Callback,
) -> Result<Payload::Target, CachedClientError<CallBackError>> {
let new_cache = info_span!("new_cache", file = %cache_entry.path().display());
let data = response_callback(response)
.boxed_local()
.await
.map_err(|err| CachedClientError::Callback { retries: None, err })?;
let Some(cache_policy) = cache_policy else {
return Ok(data.into_target());
};
async {
fs_err::tokio::create_dir_all(cache_entry.dir())
.await
.map_err(ErrorKind::CacheWrite)?;
let data_with_cache_policy_bytes =
DataWithCachePolicy::serialize(&cache_policy, &data.to_bytes()?)?;
write_atomic(cache_entry.path(), data_with_cache_policy_bytes)
.await
.map_err(ErrorKind::CacheWrite)?;
Ok(data.into_target())
}
.instrument(new_cache)
.await
}
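    /// Read and parse the cache entry from disk, returning `None` if it is
    /// missing and removing it if it is corrupt.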
#[instrument(name = "read_and_parse_cache", skip_all, fields(file = %cache_entry.path().display()
))]
async fn read_cache(cache_entry: &CacheEntry) -> Option<DataWithCachePolicy> {
match DataWithCachePolicy::from_path_async(cache_entry.path()).await {
Ok(data) => Some(data),
Err(err) => {
if err.is_file_not_exists() {
trace!("No cache entry exists for {}", cache_entry.path().display());
} else {
warn!(
"Broken cache policy entry at {}, removing: {err}",
cache_entry.path().display()
);
let _ = fs_err::tokio::remove_file(&cache_entry.path()).await;
}
None
}
}
}
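    /// Apply the cached policy to the outgoing request: either the cached
    /// response is still fresh, it must be revalidated (unless stale responses
    /// are allowed), or it no longer matches the request and a fresh request is
    /// sent instead.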
async fn send_cached(
&self,
mut req: Request,
cache_control: CacheControl<'_>,
cached: DataWithCachePolicy,
) -> Result<CachedResponse, Error> {
match cache_control {
CacheControl::None | CacheControl::AllowStale | CacheControl::Override(..) => {}
CacheControl::MustRevalidate => {
req.headers_mut().insert(
http::header::CACHE_CONTROL,
http::HeaderValue::from_static("no-cache"),
);
}
}
Ok(match cached.cache_policy.before_request(&mut req) {
BeforeRequest::Fresh => {
debug!("Found fresh response for: {}", req.url());
CachedResponse::FreshCache(cached)
}
BeforeRequest::Stale(new_cache_policy_builder) => match cache_control {
CacheControl::None | CacheControl::MustRevalidate | CacheControl::Override(_) => {
debug!("Found stale response for: {}", req.url());
self.send_cached_handle_stale(
req,
cache_control,
cached,
new_cache_policy_builder,
)
.await?
}
CacheControl::AllowStale => {
debug!("Found stale (but allowed) response for: {}", req.url());
CachedResponse::FreshCache(cached)
}
},
BeforeRequest::NoMatch => {
warn!(
"Cached response doesn't match current request for: {}",
req.url()
);
let (response, cache_policy) = self.fresh_request(req, cache_control).await?;
CachedResponse::ModifiedOrNew {
response,
cache_policy,
}
}
})
}
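    /// Send a revalidation request for a stale cache entry, returning either the
    /// cached data with an updated policy (`NotModified`) or the new response
    /// (`ModifiedOrNew`).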
async fn send_cached_handle_stale(
&self,
req: Request,
cache_control: CacheControl<'_>,
cached: DataWithCachePolicy,
new_cache_policy_builder: CachePolicyBuilder,
) -> Result<CachedResponse, Error> {
let url = DisplaySafeUrl::from_url(req.url().clone());
debug!("Sending revalidation request for: {url}");
let mut response = self
.0
.execute(req)
.instrument(info_span!("revalidation_request", url = url.as_str()))
.await
.map_err(|err| ErrorKind::from_reqwest_middleware(url.clone(), err))?;
if let Err(status_error) = response.error_for_status_ref() {
let problem_details = if response
.headers()
.get("content-type")
.and_then(|ct| ct.to_str().ok())
                .map(|ct| ct.starts_with("application/problem+json"))
.unwrap_or(false)
{
extract_problem_details(response).await
} else {
None
};
return Err(ErrorKind::from_reqwest_with_problem_details(
url.clone(),
status_error,
problem_details,
)
.into());
}
if let CacheControl::Override(header) = cache_control {
response.headers_mut().insert(
http::header::CACHE_CONTROL,
http::HeaderValue::from_str(header)
.expect("Cache-Control header must be valid UTF-8"),
);
}
match cached
.cache_policy
.after_response(new_cache_policy_builder, &response)
{
AfterResponse::NotModified(new_policy) => {
debug!("Found not-modified response for: {url}");
Ok(CachedResponse::NotModified {
cached,
new_policy: Box::new(new_policy),
})
}
AfterResponse::Modified(new_policy) => {
debug!("Found modified response for: {url}");
Ok(CachedResponse::ModifiedOrNew {
response,
cache_policy: new_policy
.to_archived()
.is_storable()
.then(|| Box::new(new_policy)),
})
}
}
}
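    /// Send the request without consulting the cache and build a cache policy
    /// from the response; the policy is `None` if the response isn't storable.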
#[instrument(skip_all, fields(url = req.url().as_str()))]
async fn fresh_request(
&self,
req: Request,
cache_control: CacheControl<'_>,
) -> Result<(Response, Option<Box<CachePolicy>>), Error> {
let url = DisplaySafeUrl::from_url(req.url().clone());
trace!("Sending fresh {} request for {}", req.method(), url);
let cache_policy_builder = CachePolicyBuilder::new(&req);
let mut response = self
.0
.execute(req)
.await
.map_err(|err| ErrorKind::from_reqwest_middleware(url.clone(), err))?;
if let CacheControl::Override(header) = cache_control {
response.headers_mut().insert(
http::header::CACHE_CONTROL,
http::HeaderValue::from_str(header)
.expect("Cache-Control header must be valid UTF-8"),
);
}
let retry_count = response
.extensions()
.get::<reqwest_retry::RetryCount>()
.map(|retries| retries.value());
if let Err(status_error) = response.error_for_status_ref() {
let problem_details = if response
.headers()
.get("content-type")
.and_then(|ct| ct.to_str().ok())
.map(|ct| ct.starts_with("application/problem+json"))
.unwrap_or(false)
{
extract_problem_details(response).await
} else {
None
};
return Err(CachedClientError::<Error>::Client {
retries: retry_count,
err: ErrorKind::from_reqwest_with_problem_details(
url,
status_error,
problem_details,
)
.into(),
}
.into());
}
let cache_policy = cache_policy_builder.build(&response);
let cache_policy = if cache_policy.to_archived().is_storable() {
Some(Box::new(cache_policy))
} else {
None
};
Ok((response, cache_policy))
}
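    /// Like [`Self::get_serde`], but additionally retries transient network
    /// errors that surface while the response body is being handled, on top of
    /// the middleware's own retries.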
#[instrument(skip_all)]
pub async fn get_serde_with_retry<
Payload: Serialize + DeserializeOwned + Send + 'static,
CallBackError: std::error::Error + 'static,
Callback: AsyncFn(Response) -> Result<Payload, CallBackError>,
>(
&self,
req: Request,
cache_entry: &CacheEntry,
cache_control: CacheControl<'_>,
response_callback: Callback,
) -> Result<Payload, CachedClientError<CallBackError>> {
let payload = self
.get_cacheable_with_retry(req, cache_entry, cache_control, async |resp| {
let payload = response_callback(resp).await?;
Ok(SerdeCacheable { inner: payload })
})
.await?;
Ok(payload)
}
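    /// Like [`Self::get_cacheable`], but retries transient network errors
    /// encountered after the initial response, respecting the client's retry
    /// policy and counting any retries already performed by the middleware.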
#[instrument(skip_all)]
pub async fn get_cacheable_with_retry<
Payload: Cacheable,
CallBackError: std::error::Error + 'static,
Callback: AsyncFn(Response) -> Result<Payload, CallBackError>,
>(
&self,
req: Request,
cache_entry: &CacheEntry,
cache_control: CacheControl<'_>,
response_callback: Callback,
) -> Result<Payload::Target, CachedClientError<CallBackError>> {
let mut past_retries = 0;
let start_time = SystemTime::now();
let retry_policy = self.uncached().retry_policy();
loop {
let fresh_req = req.try_clone().expect("HTTP request must be cloneable");
let result = self
.get_cacheable(fresh_req, cache_entry, cache_control, &response_callback)
.await;
let middleware_retries = match &result {
Err(err) => err.retries().unwrap_or_default(),
Ok(_) => 0,
};
if result
.as_ref()
.is_err_and(|err| is_transient_network_error(err.error()))
{
let total_retries = past_retries + middleware_retries;
let retry_decision = retry_policy.should_retry(start_time, total_retries);
if let reqwest_retry::RetryDecision::Retry { execute_after } = retry_decision {
let duration = execute_after
.duration_since(SystemTime::now())
.unwrap_or_else(|_| Duration::default());
debug!(
"Transient failure while handling response from {}; retrying after {:.1}s...",
req.url(),
duration.as_secs_f32(),
);
tokio::time::sleep(duration).await;
past_retries += 1;
continue;
}
}
if past_retries > 0 {
return result.map_err(|err| err.with_retries(past_retries));
}
return result;
}
}
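    /// Like [`Self::skip_cache`], but with the same transient-error retry loop
    /// as [`Self::get_cacheable_with_retry`].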
pub async fn skip_cache_with_retry<
Payload: Serialize + DeserializeOwned + Send + 'static,
CallBackError: std::error::Error + 'static,
Callback: AsyncFn(Response) -> Result<Payload, CallBackError>,
>(
&self,
req: Request,
cache_entry: &CacheEntry,
cache_control: CacheControl<'_>,
response_callback: Callback,
) -> Result<Payload, CachedClientError<CallBackError>> {
let mut past_retries = 0;
let start_time = SystemTime::now();
let retry_policy = self.uncached().retry_policy();
loop {
let fresh_req = req.try_clone().expect("HTTP request must be cloneable");
let result = self
.skip_cache(fresh_req, cache_entry, cache_control, &response_callback)
.await;
let middleware_retries = match &result {
Err(err) => err.retries().unwrap_or_default(),
_ => 0,
};
            if result
                .as_ref()
                .is_err_and(|err| is_transient_network_error(err.error()))
{
let total_retries = past_retries + middleware_retries;
let retry_decision = retry_policy.should_retry(start_time, total_retries);
if let reqwest_retry::RetryDecision::Retry { execute_after } = retry_decision {
let duration = execute_after
.duration_since(SystemTime::now())
.unwrap_or_else(|_| Duration::default());
debug!(
"Transient failure while handling response from {}; retrying after {}s...",
req.url(),
duration.as_secs(),
);
tokio::time::sleep(duration).await;
past_retries += 1;
continue;
}
}
if past_retries > 0 {
return result.map_err(|err| err.with_retries(past_retries));
}
return result;
}
}
}
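/// The outcome of checking the cache for a request: the cached response is
/// fresh, it was revalidated as unchanged, or a new (or modified) response must
/// be processed.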
#[derive(Debug)]
enum CachedResponse {
FreshCache(DataWithCachePolicy),
NotModified {
cached: DataWithCachePolicy,
new_policy: Box<CachePolicy>,
},
ModifiedOrNew {
response: Response,
cache_policy: Option<Box<CachePolicy>>,
},
}
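/// The payload of a cache entry: the cached data together with the
/// rkyv-archived HTTP cache policy that governs when it may be reused.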
#[derive(Debug)]
pub struct DataWithCachePolicy {
pub data: AlignedVec,
cache_policy: OwnedArchive<CachePolicy>,
}
impl DataWithCachePolicy {
async fn from_path_async(path: &Path) -> Result<Self, Error> {
let path = path.to_path_buf();
tokio::task::spawn_blocking(move || Self::from_path_sync(&path))
.await
            .expect("spawned blocking task should not panic")
}
#[instrument]
fn from_path_sync(path: &Path) -> Result<Self, Error> {
let file = fs_err::File::open(path).map_err(ErrorKind::Io)?;
Self::from_reader(file)
}
pub fn from_reader(mut rdr: impl std::io::Read) -> Result<Self, Error> {
let mut aligned_bytes = AlignedVec::new();
aligned_bytes
.extend_from_reader(&mut rdr)
.map_err(ErrorKind::Io)?;
Self::from_aligned_bytes(aligned_bytes)
}
fn from_aligned_bytes(mut bytes: AlignedVec) -> Result<Self, Error> {
let cache_policy = Self::deserialize_cache_policy(&mut bytes)?;
Ok(Self {
data: bytes,
cache_policy,
})
}
fn serialize(cache_policy: &CachePolicy, data: &[u8]) -> Result<Vec<u8>, Error> {
let mut buf = vec![];
Self::serialize_to_writer(cache_policy, data, &mut buf)?;
Ok(buf)
}
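    /// Write the on-disk cache format: the payload bytes, then the archived
    /// cache policy, then the policy's length as a little-endian `u64`. Storing
    /// the length last lets the payload remain at the front of the buffer when
    /// it is read back.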
fn serialize_to_writer(
cache_policy: &CachePolicy,
data: &[u8],
mut wtr: impl std::io::Write,
) -> Result<(), Error> {
let cache_policy_archived = OwnedArchive::from_unarchived(cache_policy)?;
let cache_policy_bytes = OwnedArchive::as_bytes(&cache_policy_archived);
wtr.write_all(data).map_err(ErrorKind::Io)?;
wtr.write_all(cache_policy_bytes).map_err(ErrorKind::Io)?;
let len = u64::try_from(cache_policy_bytes.len()).map_err(|_| {
let msg = format!(
"failed to represent {} (length of cache policy) in a u64",
cache_policy_bytes.len()
);
ErrorKind::Io(std::io::Error::other(msg))
})?;
wtr.write_all(&len.to_le_bytes()).map_err(ErrorKind::Io)?;
Ok(())
}
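    /// Split the archived cache policy off the end of the combined buffer,
    /// truncating `bytes` so that only the payload remains.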
fn deserialize_cache_policy(
bytes: &mut AlignedVec,
) -> Result<OwnedArchive<CachePolicy>, Error> {
let len = Self::deserialize_cache_policy_len(bytes)?;
let cache_policy_bytes_start = bytes.len() - (len + 8);
let cache_policy_bytes = &bytes[cache_policy_bytes_start..][..len];
let mut cache_policy_bytes_aligned = AlignedVec::with_capacity(len);
cache_policy_bytes_aligned.extend_from_slice(cache_policy_bytes);
assert!(
cache_policy_bytes_start <= bytes.len(),
"slicing cache policy should result in a truncation"
);
bytes.resize(cache_policy_bytes_start, 0);
OwnedArchive::new(cache_policy_bytes_aligned)
}
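    /// Read the cache policy's length from the trailing 8 bytes of the buffer
    /// and validate it against the buffer's size.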
fn deserialize_cache_policy_len(bytes: &[u8]) -> Result<usize, Error> {
let Some(cache_policy_len_start) = bytes.len().checked_sub(8) else {
let msg = format!(
"data-with-cache-policy buffer should be at least 8 bytes \
in length, but is {} bytes",
bytes.len(),
);
return Err(ErrorKind::ArchiveRead(msg).into());
};
let cache_policy_len_bytes = <[u8; 8]>::try_from(&bytes[cache_policy_len_start..])
.expect("cache policy length is 8 bytes");
let len_u64 = u64::from_le_bytes(cache_policy_len_bytes);
let Ok(len_usize) = usize::try_from(len_u64) else {
let msg = format!(
"data-with-cache-policy has cache policy length of {len_u64}, \
but overflows usize",
);
return Err(ErrorKind::ArchiveRead(msg).into());
};
if bytes.len() < len_usize + 8 {
let msg = format!(
"invalid cache entry: data-with-cache-policy has cache policy length of {}, \
but total buffer size is {}",
len_usize,
bytes.len(),
);
return Err(ErrorKind::ArchiveRead(msg).into());
}
Ok(len_usize)
}
}