use super::error;
use futures_core::Stream;
use futures_util::stream::FuturesUnordered;
pub use indexmap::Equivalent;
use indexmap::IndexMap;
use std::fmt;
use std::future::Future;
use std::hash::Hash;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::sync::oneshot;
use tower_service::Service;
use tracing::{debug, trace};
/// A cache of services keyed by `K`, split into a *pending* set (services still
/// being driven to readiness) and a *ready* set (services whose most recent
/// readiness check reported ready).
///
/// Every pending service is paired with a oneshot cancelation channel so that it
/// can be evicted or replaced before it becomes ready.
pub struct ReadyCache<K, S, Req>
where
    K: Eq + Hash,
{
    /// Services that are not yet ready, each wrapped in a cancelable `Pending` future.
    pending: FuturesUnordered<Pending<K, S, Req>>,
    /// Cancelation senders for pending services that have not been canceled,
    /// keyed by service key.
    pending_cancel_txs: IndexMap<K, CancelTx>,
    /// Ready services by key. Each one keeps its cancelation pair so the pair can
    /// be reused if the service is returned to the pending set.
    ready: IndexMap<K, (S, CancelPair)>,
}
// `ReadyCache` never relies on its fields being pinned: its methods take `&mut self`
// and the pending set is polled through `Pin::new`. It is therefore declared `Unpin`
// even when `S` is not.
impl<S, K: Eq + Hash, Req> Unpin for ReadyCache<K, S, Req> {}
/// Receiving half of a pending service's cancelation channel; `Pending` polls it
/// to observe cancelation.
type CancelRx = oneshot::Receiver<()>;
/// Sending half of a pending service's cancelation channel; sending `()` cancels
/// the corresponding pending service.
type CancelTx = oneshot::Sender<()>;
/// A complete cancelation channel (sender and receiver).
type CancelPair = (CancelTx, CancelRx);
/// Error produced by a `Pending` future when it does not resolve to a ready service.
#[derive(Debug)]
enum PendingError<K, E> {
    /// The pending service with this key was canceled before it became ready.
    Canceled(K),
    /// The pending service with this key failed its readiness check with `E`.
    Inner(K, E),
}
/// A future that drives a single service to readiness unless it is canceled first.
///
/// Fields are wrapped in `Option` so they can be moved out with `take()` when the
/// future resolves; polling after resolution is not supported.
struct Pending<K, S, Req> {
    /// Key of the service; taken when the future resolves.
    key: Option<K>,
    /// Cancelation receiver, checked before the service itself is polled.
    cancel: Option<CancelRx>,
    /// The service being driven to readiness; taken when it becomes ready.
    ready: Option<S>,
    /// Ties this future to the `Req` parameter of `Service<Req>` without storing a `Req`.
    _pd: std::marker::PhantomData<Req>,
}
impl<K, S, Req> Default for ReadyCache<K, S, Req>
where
K: Eq + Hash,
S: Service<Req>,
{
fn default() -> Self {
Self {
ready: IndexMap::default(),
pending: FuturesUnordered::new(),
pending_cancel_txs: IndexMap::default(),
}
}
}
impl<K, S, Req> fmt::Debug for ReadyCache<K, S, Req>
where
    K: fmt::Debug + Eq + Hash,
    S: fmt::Debug,
{
    /// Formats the pending set, the pending cancelation senders, and the ready set.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("ReadyCache");
        out.field("pending", &self.pending);
        out.field("pending_cancel_txs", &self.pending_cancel_txs);
        out.field("ready", &self.ready);
        out.finish()
    }
}
impl<K, S, Req> ReadyCache<K, S, Req>
where
    K: Eq + Hash,
{
    /// Returns the total number of services in the cache, ready plus pending.
    ///
    /// The pending count includes services that have been canceled (for example
    /// by `evict`) but whose futures have not yet been reaped by `poll_pending`.
    pub fn len(&self) -> usize {
        self.ready_len() + self.pending_len()
    }

    /// Returns `true` if there are neither ready nor pending services.
    pub fn is_empty(&self) -> bool {
        self.ready.is_empty() && self.pending.is_empty()
    }

    /// Returns the number of services in the ready set.
    pub fn ready_len(&self) -> usize {
        self.ready.len()
    }

    /// Returns the number of futures in the pending set, including canceled
    /// ones that `poll_pending` has not yet reaped.
    pub fn pending_len(&self) -> usize {
        self.pending.len()
    }

    /// Returns `true` if a service for `key` is pending and has not been canceled.
    pub fn pending_contains<Q: Hash + Equivalent<K>>(&self, key: &Q) -> bool {
        self.pending_cancel_txs.contains_key(key)
    }

    /// Looks up the ready service for `key`, returning its index, key, and service.
    ///
    /// The index is only stable until the ready set is next modified: removals
    /// swap the last entry into the removed slot.
    pub fn get_ready<Q: Hash + Equivalent<K>>(&self, key: &Q) -> Option<(usize, &K, &S)> {
        self.ready.get_full(key).map(|(i, k, v)| (i, k, &v.0))
    }

    /// Like `get_ready`, but returns a mutable reference to the service.
    pub fn get_ready_mut<Q: Hash + Equivalent<K>>(
        &mut self,
        key: &Q,
    ) -> Option<(usize, &K, &mut S)> {
        self.ready
            .get_full_mut(key)
            .map(|(i, k, v)| (i, k, &mut v.0))
    }

    /// Returns the key and service at position `idx` of the ready set, if any.
    pub fn get_ready_index(&self, idx: usize) -> Option<(&K, &S)> {
        self.ready.get_index(idx).map(|(k, v)| (k, &v.0))
    }

    /// Returns mutable references to the key and service at position `idx` of
    /// the ready set, if any.
    ///
    /// Callers must not change the key in a way that alters its hash or
    /// equality; doing so breaks later lookups by key.
    pub fn get_ready_index_mut(&mut self, idx: usize) -> Option<(&mut K, &mut S)> {
        self.ready.get_index_mut(idx).map(|(k, v)| (k, &mut v.0))
    }

    /// Evicts the service for `key` from both the pending and the ready set.
    ///
    /// A pending service is canceled: its future resolves as canceled the next
    /// time `poll_pending` polls it. A key may be both pending (a replacement)
    /// and ready at the same time, so both sets are always checked.
    ///
    /// Returns `true` if a service was found in either set.
    ///
    /// # Panics
    ///
    /// Panics if the cancelation receiver of the pending service has been dropped.
    pub fn evict<Q: Hash + Equivalent<K>>(&mut self, key: &Q) -> bool {
        let canceled = match self.pending_cancel_txs.swap_remove(key) {
            Some(cancel_tx) => {
                cancel_tx.send(()).expect("cancel receiver lost");
                true
            }
            None => false,
        };
        // Evaluated unconditionally so the ready entry is removed even when a
        // pending service was also canceled.
        let removed = self.ready.swap_remove_full(key).is_some();
        canceled || removed
    }
}
impl<K, S, Req> ReadyCache<K, S, Req>
where
K: Clone + Eq + Hash,
S: Service<Req>,
<S as Service<Req>>::Error: Into<crate::BoxError>,
S::Error: Into<crate::BoxError>,
{
pub fn push(&mut self, key: K, svc: S) {
let cancel = oneshot::channel();
self.push_pending(key, svc, cancel);
}
fn push_pending(&mut self, key: K, svc: S, (cancel_tx, cancel_rx): CancelPair) {
if let Some(c) = self.pending_cancel_txs.insert(key.clone(), cancel_tx) {
c.send(()).expect("cancel receiver lost");
}
self.pending.push(Pending {
key: Some(key),
cancel: Some(cancel_rx),
ready: Some(svc),
_pd: std::marker::PhantomData,
});
}
pub fn poll_pending(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), error::Failed<K>>> {
loop {
match Pin::new(&mut self.pending).poll_next(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(None) => return Poll::Ready(Ok(())),
Poll::Ready(Some(Ok((key, svc, cancel_rx)))) => {
trace!("endpoint ready");
let cancel_tx = self.pending_cancel_txs.swap_remove(&key);
if let Some(cancel_tx) = cancel_tx {
self.ready.insert(key, (svc, (cancel_tx, cancel_rx)));
} else {
debug_assert!(cancel_tx.is_some());
debug!("canceled endpoint removed when ready");
}
}
Poll::Ready(Some(Err(PendingError::Canceled(_)))) => {
debug!("endpoint canceled");
}
Poll::Ready(Some(Err(PendingError::Inner(key, e)))) => {
let cancel_tx = self.pending_cancel_txs.swap_remove(&key);
if cancel_tx.is_some() {
return Err(error::Failed(key, e.into())).into();
} else {
debug_assert!(cancel_tx.is_some());
debug!("canceled endpoint removed on error");
}
}
}
}
}
pub fn check_ready<Q: Hash + Equivalent<K>>(
&mut self,
cx: &mut Context<'_>,
key: &Q,
) -> Result<bool, error::Failed<K>> {
match self.ready.get_full_mut(key) {
Some((index, _, _)) => self.check_ready_index(cx, index),
None => Ok(false),
}
}
pub fn check_ready_index(
&mut self,
cx: &mut Context<'_>,
index: usize,
) -> Result<bool, error::Failed<K>> {
let svc = match self.ready.get_index_mut(index) {
None => return Ok(false),
Some((_, (svc, _))) => svc,
};
match svc.poll_ready(cx) {
Poll::Ready(Ok(())) => Ok(true),
Poll::Pending => {
let (key, (svc, cancel)) = self
.ready
.swap_remove_index(index)
.expect("invalid ready index");
if !self.pending_contains(&key) {
self.push_pending(key, svc, cancel);
}
Ok(false)
}
Poll::Ready(Err(e)) => {
let (key, _) = self
.ready
.swap_remove_index(index)
.expect("invalid ready index");
Err(error::Failed(key, e.into()))
}
}
}
pub fn call_ready<Q: Hash + Equivalent<K>>(&mut self, key: &Q, req: Req) -> S::Future {
let (index, _, _) = self
.ready
.get_full_mut(key)
.expect("check_ready was not called");
self.call_ready_index(index, req)
}
pub fn call_ready_index(&mut self, index: usize, req: Req) -> S::Future {
let (key, (mut svc, cancel)) = self
.ready
.swap_remove_index(index)
.expect("check_ready_index was not called");
let fut = svc.call(req);
if !self.pending_contains(&key) {
self.push_pending(key, svc, cancel);
}
fut
}
}
// `Pending` never relies on its fields being pinned. Declaring it `Unpin` lets its
// `poll` access fields mutably through `Pin<&mut Self>` for any `K`, `S`, and `Req`.
impl<K, S, Req> Unpin for Pending<K, S, Req> {}
impl<K, S, Req> Future for Pending<K, S, Req>
where
    S: Service<Req>,
{
    type Output = Result<(K, S, CancelRx), PendingError<K, S::Error>>;

    /// Resolves once the service is ready, fails, or is canceled.
    ///
    /// Cancelation is checked *before* the service is polled, so a canceled
    /// service is not driven further. On success the key, the ready service, and
    /// the cancelation receiver are returned, so the caller can keep the channel
    /// for reuse.
    ///
    /// # Panics
    ///
    /// Panics if the cancelation sender was dropped without sending. It also
    /// panics (via `expect`) when a field needed for this poll has already been
    /// taken, i.e. when the future is polled after resolving.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let cancel = self.cancel.as_mut().expect("polled after complete");
        if let Poll::Ready(r) = Pin::new(cancel).poll(cx) {
            // The owning cache is expected to keep the sender alive until it
            // either sends on it or this future resolves; an `Err` means that
            // invariant was broken.
            assert!(r.is_ok(), "cancel sender lost");
            let key = self.key.take().expect("polled after complete");
            return Poll::Ready(Err(PendingError::Canceled(key)));
        }

        match self
            .ready
            .as_mut()
            .expect("polled after ready")
            .poll_ready(cx)
        {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(())) => {
                let key = self.key.take().expect("polled after complete");
                let cancel = self.cancel.take().expect("polled after complete");
                let svc = self.ready.take().expect("polled after ready");
                Poll::Ready(Ok((key, svc, cancel)))
            }
            Poll::Ready(Err(e)) => {
                let key = self.key.take().expect("polled after complete");
                Poll::Ready(Err(PendingError::Inner(key, e)))
            }
        }
    }
}
impl<K, S, Req> fmt::Debug for Pending<K, S, Req>
where
    K: fmt::Debug,
    S: fmt::Debug,
{
    /// Formats the key, cancelation receiver, and service; the `Req` marker is omitted.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut dbg = f.debug_struct("Pending");
        dbg.field("key", &self.key);
        dbg.field("cancel", &self.cancel);
        dbg.field("ready", &self.ready);
        dbg.finish()
    }
}