use super::error;
use futures_core::Stream;
use futures_util::{stream::FuturesUnordered, task::AtomicWaker};
pub use indexmap::Equivalent;
use indexmap::IndexMap;
use std::fmt;
use std::future::Future;
use std::hash::Hash;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use tower_service::Service;
use tracing::{debug, trace};
/// Drives a keyed set of services to readiness and tracks the ones that are ready.
///
/// Services start out *pending*. As `poll_pending` is driven, each pending service
/// that reports readiness moves into the *ready* set. From there it can be
/// re-checked with `check_ready`/`check_ready_index` and invoked with
/// `call_ready`/`call_ready_index`. Every service is paired with a cancelation
/// handle, so evicting or replacing a key can abandon a service that is still
/// pending.
pub struct ReadyCache<K: Eq + Hash, S, Req> {
    /// Futures that poll not-yet-ready services for readiness.
    pending: FuturesUnordered<Pending<K, S, Req>>,
    /// One cancel sender per pending key; firing it abandons that pending service.
    pending_cancel_txs: IndexMap<K, CancelTx>,
    /// Ready services, each kept together with its cancelation pair so the pair
    /// can be reused when the service goes back to pending.
    ready: IndexMap<K, (S, CancelPair)>,
}

// Implemented unconditionally: the cache pins its `FuturesUnordered` only
// through `Pin::new(&mut self.pending)` and never pins the services it stores.
impl<K, S, Req> Unpin for ReadyCache<K, S, Req> where K: Eq + Hash {}
/// Cancelation state shared by one `CancelTx` and its matching `CancelRx`.
#[derive(Debug)]
struct Cancel {
    /// Waker that a pending future registers with; it is woken when the sender fires.
    waker: AtomicWaker,
    /// Set to `true` once the sender fires.
    canceled: AtomicBool,
}
/// Receiving half of a cancelation pair. It is held by a `Pending` future, or
/// kept alongside a service in the ready set.
#[derive(Debug)]
struct CancelRx(Arc<Cancel>);
/// Sending half of a cancelation pair. Consuming it with `cancel` marks the
/// shared state as canceled.
#[derive(Debug)]
struct CancelTx(Arc<Cancel>);
/// A sender/receiver pair that share a single `Cancel`.
type CancelPair = (CancelTx, CancelRx);
/// Reasons a `Pending` future resolves without producing a ready service.
#[derive(Debug)]
enum PendingError<K, E> {
    /// The pending service was canceled (evicted or superseded) before it became ready.
    Canceled(K),
    /// The service's readiness check failed with error `E`.
    Inner(K, E),
}
pin_project_lite::pin_project! {
    // A future that resolves when the service in `ready` reports readiness,
    // fails, or is canceled through its `CancelRx`.
    //
    // Every field is an `Option` so the future can move its contents out when
    // it completes (see the `take()` calls in its `Future` impl). No field is
    // marked `#[pin]`.
    struct Pending<K, S, Req> {
        // Key that identifies this service in the cache.
        key: Option<K>,
        // Receiver used to observe cancelation.
        cancel: Option<CancelRx>,
        // The service being driven to readiness.
        ready: Option<S>,
        // Ties the `Req` type parameter to this struct without storing a request.
        _pd: std::marker::PhantomData<Req>,
    }
}
impl<K, S, Req> Default for ReadyCache<K, S, Req>
where
    K: Eq + Hash,
    S: Service<Req>,
{
    /// Creates a cache with no pending and no ready services.
    fn default() -> Self {
        let pending = FuturesUnordered::new();
        let pending_cancel_txs = IndexMap::default();
        let ready = IndexMap::default();
        Self {
            pending,
            pending_cancel_txs,
            ready,
        }
    }
}
impl<K, S, Req> fmt::Debug for ReadyCache<K, S, Req>
where
K: fmt::Debug + Eq + Hash,
S: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let Self {
pending,
pending_cancel_txs,
ready,
} = self;
f.debug_struct("ReadyCache")
.field("pending", pending)
.field("pending_cancel_txs", pending_cancel_txs)
.field("ready", ready)
.finish()
}
}
impl<K, S, Req> ReadyCache<K, S, Req>
where
    K: Eq + Hash,
{
    /// Returns the total number of services tracked, pending plus ready.
    ///
    /// A pending service that was canceled but has not yet been polled away
    /// by `poll_pending` is still counted.
    pub fn len(&self) -> usize {
        let ready = self.ready_len();
        let pending = self.pending_len();
        ready + pending
    }

    /// Returns `true` when the cache holds neither pending nor ready services.
    pub fn is_empty(&self) -> bool {
        if !self.ready.is_empty() {
            return false;
        }
        self.pending.is_empty()
    }

    /// Returns the number of services currently in the ready set.
    pub fn ready_len(&self) -> usize {
        self.ready.len()
    }

    /// Returns the number of futures currently in the pending set.
    pub fn pending_len(&self) -> usize {
        self.pending.len()
    }

    /// Returns `true` if a pending, un-canceled service is registered under `key`.
    pub fn pending_contains<Q: Hash + Equivalent<K>>(&self, key: &Q) -> bool {
        self.pending_cancel_txs.contains_key(key)
    }

    /// Looks up a ready service by key, returning its index, key, and service.
    pub fn get_ready<Q: Hash + Equivalent<K>>(&self, key: &Q) -> Option<(usize, &K, &S)> {
        let (index, k, (svc, _)) = self.ready.get_full(key)?;
        Some((index, k, svc))
    }

    /// Looks up a ready service by key, returning its index, key, and a mutable
    /// reference to the service.
    pub fn get_ready_mut<Q: Hash + Equivalent<K>>(
        &mut self,
        key: &Q,
    ) -> Option<(usize, &K, &mut S)> {
        let (index, k, (svc, _)) = self.ready.get_full_mut(key)?;
        Some((index, k, svc))
    }

    /// Returns the key and service stored at position `idx` of the ready set.
    pub fn get_ready_index(&self, idx: usize) -> Option<(&K, &S)> {
        let (k, (svc, _)) = self.ready.get_index(idx)?;
        Some((k, svc))
    }

    /// Returns mutable references to the key and service at position `idx`
    /// of the ready set.
    pub fn get_ready_index_mut(&mut self, idx: usize) -> Option<(&mut K, &mut S)> {
        let (k, (svc, _)) = self.ready.get_index_mut(idx)?;
        Some((k, svc))
    }

    /// Removes the service registered under `key` from both the pending and ready sets.
    ///
    /// A pending service is canceled and is discarded the next time it is
    /// polled. A ready service is dropped immediately. Because ready-set
    /// removal swaps in the last entry, the ready indices can change.
    ///
    /// Returns `true` if a service was found in either set.
    pub fn evict<Q: Hash + Equivalent<K>>(&mut self, key: &Q) -> bool {
        let was_pending = match self.pending_cancel_txs.swap_remove(key) {
            Some(tx) => {
                tx.cancel();
                true
            }
            None => false,
        };
        let was_ready = self.ready.swap_remove_full(key).is_some();
        was_ready || was_pending
    }
}
impl<K, S, Req> ReadyCache<K, S, Req>
where
    K: Clone + Eq + Hash,
    S: Service<Req>,
    S::Error: Into<crate::BoxError>,
{
    /// Pushes a new service onto the pending set under `key`.
    ///
    /// The service is promoted to the ready set as `poll_pending` is driven.
    /// If another service is already pending under the same key, that older
    /// service is canceled and is discarded the next time it is polled.
    ///
    /// This does not remove a service that is already in the *ready* set under
    /// `key`. When the new service becomes ready, it replaces the ready entry.
    /// If the old ready service is used, or found unready, before that
    /// happens, it is dropped instead of returning to the pending set.
    pub fn push(&mut self, key: K, svc: S) {
        let cancel = cancelable();
        self.push_pending(key, svc, cancel);
    }

    /// Adds `svc` to the pending set under `key`, using the given cancelation pair.
    ///
    /// The pair is passed in rather than created here, so a service moving
    /// from the ready set back to pending can reuse the pair it already holds.
    fn push_pending(&mut self, key: K, svc: S, (cancel_tx, cancel_rx): CancelPair) {
        if let Some(c) = self.pending_cancel_txs.insert(key.clone(), cancel_tx) {
            // Another service was already pending under this key. Cancel it so
            // only the newest one can be promoted.
            c.cancel();
        }
        self.pending.push(Pending {
            key: Some(key),
            cancel: Some(cancel_rx),
            ready: Some(svc),
            _pd: std::marker::PhantomData,
        });
    }

    /// Drives pending services toward readiness and promotes each one that
    /// becomes ready into the ready set.
    ///
    /// Returns `Poll::Ready(Ok(()))` once no pending futures remain, and
    /// `Poll::Pending` while at least one is still waiting. Canceled services
    /// are discarded silently.
    ///
    /// # Errors
    ///
    /// If a pending service fails its readiness check, it is removed from the
    /// cache and `Poll::Ready(Err(error::Failed(key, err)))` is returned right
    /// away. The other pending services stay queued for a later call.
    ///
    /// # Panics
    ///
    /// Panics if a service resolves (ready or failed) without a registered
    /// cancel sender, which would mean the cache's bookkeeping is inconsistent.
    pub fn poll_pending(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), error::Failed<K>>> {
        loop {
            match Pin::new(&mut self.pending).poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(Ok(())),
                Poll::Ready(Some(Ok((key, svc, cancel_rx)))) => {
                    trace!("endpoint ready");
                    // A canceled `Pending` resolves as `Canceled` before it
                    // checks readiness. So a service that reports ready must
                    // still have its sender registered.
                    let cancel_tx = self
                        .pending_cancel_txs
                        .swap_remove(&key)
                        .expect("services that become ready must have a pending cancelation");
                    // Keep the cancelation pair with the ready service so it
                    // can be reused if the service goes back to pending.
                    self.ready.insert(key, (svc, (cancel_tx, cancel_rx)));
                }
                Poll::Ready(Some(Err(PendingError::Canceled(_)))) => {
                    debug!("endpoint canceled");
                    // Whoever canceled this service already removed (or
                    // replaced) its sender in `pending_cancel_txs`, so there
                    // is nothing to clean up.
                }
                Poll::Ready(Some(Err(PendingError::Inner(key, e)))) => {
                    let cancel_tx = self.pending_cancel_txs.swap_remove(&key);
                    assert!(
                        cancel_tx.is_some(),
                        "services that return an error must have a pending cancelation"
                    );
                    return Err(error::Failed(key, e.into())).into();
                }
            }
        }
    }

    /// Checks whether the ready service registered under `key` is still ready.
    ///
    /// Returns `Ok(false)` if `key` is not in the ready set. Otherwise it
    /// behaves exactly like `check_ready_index` for that entry's index.
    ///
    /// # Errors
    ///
    /// Returns `error::Failed` if the service's readiness check fails. The
    /// service is removed from the cache.
    pub fn check_ready<Q: Hash + Equivalent<K>>(
        &mut self,
        cx: &mut Context<'_>,
        key: &Q,
    ) -> Result<bool, error::Failed<K>> {
        match self.ready.get_full_mut(key) {
            Some((index, _, _)) => self.check_ready_index(cx, index),
            None => Ok(false),
        }
    }

    /// Polls the ready service at `index` to confirm it is still ready.
    ///
    /// * `Ok(true)`: the service is still ready and stays at `index`.
    /// * `Ok(false)`: either `index` is out of range, or the service is no
    ///   longer ready. In the second case it leaves the ready set and returns
    ///   to the pending set, or is dropped if a replacement is already pending
    ///   under its key.
    ///
    /// Removal uses `swap_remove_index`. Whenever a service leaves the ready
    /// set, the last ready service moves into `index`, so any indices the
    /// caller holds may no longer be valid.
    ///
    /// # Errors
    ///
    /// If the service's `poll_ready` fails, the service is removed from the
    /// ready set and `error::Failed(key, err)` is returned.
    pub fn check_ready_index(
        &mut self,
        cx: &mut Context<'_>,
        index: usize,
    ) -> Result<bool, error::Failed<K>> {
        let svc = match self.ready.get_index_mut(index) {
            None => return Ok(false),
            Some((_, (svc, _))) => svc,
        };
        match svc.poll_ready(cx) {
            Poll::Ready(Ok(())) => Ok(true),
            Poll::Pending => {
                let (key, (svc, cancel)) = self
                    .ready
                    .swap_remove_index(index)
                    .expect("invalid ready index");
                // If a newer service is already pending under this key, drop
                // this one instead of re-queueing it.
                if !self.pending_contains(&key) {
                    self.push_pending(key, svc, cancel);
                }
                Ok(false)
            }
            Poll::Ready(Err(e)) => {
                let (key, _) = self
                    .ready
                    .swap_remove_index(index)
                    .expect("invalid ready index");
                Err(error::Failed(key, e.into()))
            }
        }
    }

    /// Calls the ready service registered under `key` with `req` and returns
    /// the response future.
    ///
    /// After the call, the service goes back to the pending set, or is dropped
    /// if a replacement is already pending under its key. Its readiness must
    /// be re-established before it can be called again.
    ///
    /// # Panics
    ///
    /// Panics if `key` is not in the ready set. Callers are expected to
    /// confirm readiness with `check_ready` first.
    pub fn call_ready<Q: Hash + Equivalent<K>>(&mut self, key: &Q, req: Req) -> S::Future {
        let (index, _, _) = self
            .ready
            .get_full_mut(key)
            .expect("check_ready was not called");
        self.call_ready_index(index, req)
    }

    /// Calls the ready service at `index` with `req` and returns the response
    /// future.
    ///
    /// After the call, the service goes back to the pending set, or is dropped
    /// if a replacement is already pending under its key. Removal uses
    /// `swap_remove_index`, so the last ready service moves into `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of range for the ready set. Callers are
    /// expected to confirm readiness with `check_ready_index` first.
    pub fn call_ready_index(&mut self, index: usize, req: Req) -> S::Future {
        let (key, (mut svc, cancel)) = self
            .ready
            .swap_remove_index(index)
            .expect("check_ready_index was not called");
        let fut = svc.call(req);
        // If a newer service is already pending under this key, drop this one
        // instead of re-queueing it.
        if !self.pending_contains(&key) {
            self.push_pending(key, svc, cancel);
        }
        fut
    }
}
/// Creates a new, un-canceled sender/receiver pair that share one `Cancel`.
fn cancelable() -> CancelPair {
    let shared = Arc::new(Cancel {
        waker: AtomicWaker::new(),
        canceled: AtomicBool::new(false),
    });
    let rx = CancelRx(Arc::clone(&shared));
    let tx = CancelTx(shared);
    (tx, rx)
}
impl CancelTx {
fn cancel(self) {
self.0.canceled.store(true, Ordering::SeqCst);
self.0.waker.wake();
}
}
/// Resolves once the wrapped service becomes ready, fails, or is canceled.
///
/// On readiness it yields the key, the service, and the cancel receiver, so
/// the cache can keep the cancelation pair together with the ready service.
impl<K, S, Req> Future for Pending<K, S, Req>
where
    S: Service<Req>,
{
    type Output = Result<(K, S, CancelRx), PendingError<K, S::Error>>;
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        // Check for cancelation before polling the service, so a service that
        // was evicted or superseded is never reported as ready.
        let CancelRx(cancel) = this.cancel.as_mut().expect("polled after complete");
        if cancel.canceled.load(Ordering::SeqCst) {
            let key = this.key.take().expect("polled after complete");
            return Err(PendingError::Canceled(key)).into();
        }
        match this
            .ready
            .as_mut()
            .expect("polled after ready")
            .poll_ready(cx)
        {
            Poll::Pending => {
                // Register this task with the shared waker. Firing the matching
                // `CancelTx` then re-polls this future, so it can observe the
                // cancelation check above.
                let CancelRx(cancel) = this.cancel.as_mut().expect("polled after complete");
                cancel.waker.register(cx.waker());
                // In the code visible here, cancel senders are only fired from
                // `&mut self` methods on the cache (`evict`, `push_pending`).
                // Those cannot run while the cache is polling this future; the
                // assert guards that expectation.
                assert!(
                    !cancel.canceled.load(Ordering::SeqCst),
                    "cancelation cannot be notified while polling a pending service"
                );
                Poll::Pending
            }
            Poll::Ready(Ok(())) => {
                // Move every field out. The future is spent after this.
                let key = this.key.take().expect("polled after complete");
                let cancel = this.cancel.take().expect("polled after complete");
                Ok((key, this.ready.take().expect("polled after ready"), cancel)).into()
            }
            Poll::Ready(Err(e)) => {
                // Only the key is moved out. The failed service stays in
                // `ready` and goes away when this future is dropped.
                // NOTE(review): "compete" in the message below is a typo for
                // "complete"; it is left unchanged because it is runtime text.
                let key = this.key.take().expect("polled after compete");
                Err(PendingError::Inner(key, e)).into()
            }
        }
    }
}
impl<K, S, Req> fmt::Debug for Pending<K, S, Req>
where
K: fmt::Debug,
S: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let Self {
key,
cancel,
ready,
_pd,
} = self;
f.debug_struct("Pending")
.field("key", key)
.field("cancel", cancel)
.field("ready", ready)
.finish()
}
}