use core::future::Future;
use core::pin::Pin;
use core::task::ready;
use core::task::Context;
use core::task::Poll;
use core::time::Duration;
use crate::backoff::BackoffBuilder;
use crate::sleep::MaybeSleeper;
use crate::Backoff;
use crate::DefaultSleeper;
use crate::Sleeper;
/// Extension trait that adds a `.retry(builder)` method to async operations.
///
/// It is implemented for every `FnMut() -> Fut` whose future resolves to
/// `Result<T, E>`, so a plain `async fn` or a closure returning an async block
/// can be retried directly with `op.retry(builder).await`.
pub trait Retryable<B, T, E, Fut, FutureFn>
where
    B: BackoffBuilder,
    Fut: Future<Output = Result<T, E>>,
    FutureFn: FnMut() -> Fut,
{
    /// Wraps `self` in a [`Retry`] future driven by the backoff built from `builder`.
    ///
    /// The returned future does nothing until it is polled.
    fn retry(self, builder: B) -> Retry<B::Backoff, T, E, Fut, FutureFn>;
}
impl<B, T, E, Fut, FutureFn> Retryable<B, T, E, Fut, FutureFn> for FutureFn
where
    FutureFn: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
    B: BackoffBuilder,
{
    fn retry(self, builder: B) -> Retry<B::Backoff, T, E, Fut, FutureFn> {
        // Materialise the backoff policy up front; the operation itself is not
        // invoked until the resulting future is polled.
        let backoff = builder.build();
        Retry::new(self, backoff)
    }
}
/// Future that retries an async operation according to a [`Backoff`] policy.
///
/// Created by [`Retryable::retry`] and customised with [`Retry::sleep`],
/// [`Retry::when`] and [`Retry::notify`]. Nothing happens until the future is
/// polled: the first poll calls `future_fn` to start an attempt. If an attempt
/// fails with an error accepted by the `retryable` predicate (every error, by
/// default) and the backoff yields another delay, `notify` is called with the
/// error and that delay, the sleeper waits, and a fresh attempt is started.
/// Otherwise the attempt's result (the success value or the last error) is
/// returned.
///
/// Type parameters:
/// - `B`: backoff policy producing the delay before each retry.
/// - `T` / `E`: success and error types of each attempt.
/// - `Fut`: future returned by one call to `FutureFn`.
/// - `FutureFn`: closure invoked once per attempt.
/// - `SF`: sleeper used to wait between attempts (defaults to [`DefaultSleeper`]).
/// - `RF`: predicate deciding whether an error is retried.
/// - `NF`: callback invoked before each sleep.
// `Future`'s own `#[must_use]` does not apply to concrete future types, so
// without this attribute a forgotten `.await` would compile silently.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Retry<
    B: Backoff,
    T,
    E,
    Fut: Future<Output = Result<T, E>>,
    FutureFn: FnMut() -> Fut,
    SF: MaybeSleeper = DefaultSleeper,
    RF = fn(&E) -> bool,
    NF = fn(&E, Duration),
> {
    /// Source of delays between attempts; yielding `None` ends retrying.
    backoff: B,
    /// Returns `true` if the given error should be retried.
    retryable: RF,
    /// Called with the error and the chosen delay before each sleep.
    notify: NF,
    /// Produces a fresh attempt future each time it is called.
    future_fn: FutureFn,
    /// Creates the future that waits out each backoff delay.
    sleep_fn: SF,
    /// Current phase: idle, polling an attempt, or sleeping between attempts.
    state: State<T, E, Fut, SF::Sleep>,
}
impl<B, T, E, Fut, FutureFn> Retry<B, T, E, Fut, FutureFn>
where
    B: Backoff,
    Fut: Future<Output = Result<T, E>>,
    FutureFn: FnMut() -> Fut,
{
    /// Builds a retry future with the default configuration.
    ///
    /// Every error is considered retryable, notifications are no-ops, and the
    /// [`DefaultSleeper`] is used. The state starts out idle, so `future_fn`
    /// is not called until the first poll.
    fn new(future_fn: FutureFn, backoff: B) -> Self {
        let retry_every_error: fn(&E) -> bool = |_| true;
        let ignore_notification: fn(&E, Duration) = |_, _| {};

        Self {
            backoff,
            retryable: retry_every_error,
            notify: ignore_notification,
            future_fn,
            sleep_fn: DefaultSleeper::default(),
            state: State::Idle,
        }
    }
}
impl<B, T, E, Fut, FutureFn, SF, RF, NF> Retry<B, T, E, Fut, FutureFn, SF, RF, NF>
where
    B: Backoff,
    Fut: Future<Output = Result<T, E>>,
    FutureFn: FnMut() -> Fut,
    SF: MaybeSleeper,
    RF: FnMut(&E) -> bool,
    NF: FnMut(&E, Duration),
{
    /// Uses `sleep_fn` to wait between attempts instead of the current sleeper.
    ///
    /// The sleeper's future type is part of the internal state, so the state is
    /// reset to idle and whatever the old state held is dropped. The backoff,
    /// predicate, callback and operation are carried over unchanged.
    pub fn sleep<SN: Sleeper>(self, sleep_fn: SN) -> Retry<B, T, E, Fut, FutureFn, SN, RF, NF> {
        // `..` leaves the old sleeper and state inside `self`; they are dropped
        // when this method returns.
        let Retry {
            backoff,
            retryable,
            notify,
            future_fn,
            ..
        } = self;
        Retry {
            backoff,
            retryable,
            notify,
            future_fn,
            sleep_fn,
            state: State::Idle,
        }
    }

    /// Sets the predicate that decides whether a failed attempt is retried.
    ///
    /// When `retryable` returns `false` for an error, the future resolves with
    /// that error right away without consulting the backoff. A future built by
    /// `Retryable::retry` retries every error until this is called.
    pub fn when<RN: FnMut(&E) -> bool>(
        self,
        retryable: RN,
    ) -> Retry<B, T, E, Fut, FutureFn, SF, RN, NF> {
        // The previous predicate stays in `self` and is dropped on return.
        let Retry {
            backoff,
            notify,
            future_fn,
            sleep_fn,
            state,
            ..
        } = self;
        Retry {
            backoff,
            retryable,
            notify,
            future_fn,
            sleep_fn,
            state,
        }
    }

    /// Sets a callback run with the error and the upcoming delay before each sleep.
    ///
    /// It fires only when another attempt is about to be scheduled. It does not
    /// fire for the error that is finally returned, whether that error was
    /// rejected by the predicate or the backoff ran out.
    pub fn notify<NN: FnMut(&E, Duration)>(
        self,
        notify: NN,
    ) -> Retry<B, T, E, Fut, FutureFn, SF, RF, NN> {
        // The previous callback stays in `self` and is dropped on return.
        let Retry {
            backoff,
            retryable,
            future_fn,
            sleep_fn,
            state,
            ..
        } = self;
        Retry {
            backoff,
            retryable,
            notify,
            future_fn,
            sleep_fn,
            state,
        }
    }
}
/// Phase of the retry state machine driven by `Retry::poll`.
#[derive(Default)]
enum State<T, E, Fut, SleepFut>
where
    Fut: Future<Output = Result<T, E>>,
    SleepFut: Future<Output = ()>,
{
    /// No attempt in flight; the next poll calls the operation to start one.
    #[default]
    Idle,
    /// An attempt's future is being driven to completion.
    Polling(Fut),
    /// Waiting for the backoff delay to elapse before the next attempt.
    Sleeping(SleepFut),
}
impl<B, T, E, Fut, FutureFn, SF, RF, NF> Future for Retry<B, T, E, Fut, FutureFn, SF, RF, NF>
where
    B: Backoff,
    Fut: Future<Output = Result<T, E>>,
    FutureFn: FnMut() -> Fut,
    SF: Sleeper,
    RF: FnMut(&E) -> bool,
    NF: FnMut(&E, Duration),
{
    type Output = Result<T, E>;

    /// Drives the retry state machine until it finishes or an inner future is pending.
    ///
    /// - `Idle`: call `future_fn` to start an attempt and switch to `Polling`.
    /// - `Polling`: poll the attempt. `Ok` is returned immediately. An `Err` is
    ///   returned immediately if the `retryable` predicate rejects it or the
    ///   backoff yields `None`. Otherwise `notify` is called with the error and
    ///   the delay, and the state switches to `Sleeping`.
    /// - `Sleeping`: wait for the sleep future, then go back to `Idle`.
    ///
    /// NOTE(review): after returning `Poll::Ready`, the finished attempt future
    /// is left in `State::Polling`. Polling again would re-poll a completed
    /// future, so callers must honour the usual "don't poll after completion"
    /// `Future` contract.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: `*this` is never moved, and nothing is ever moved out of
        // `this.state`. The futures stored there are accessed only through the
        // `Pin::new_unchecked` calls below. They are replaced only by assigning
        // a new `State`, which drops the old value in place. This gives the same
        // structural-pinning guarantees a `pin_project` projection would. The
        // other fields (`future_fn`, `retryable`, `backoff`, `notify`,
        // `sleep_fn`) are never treated as pinned.
        let this = unsafe { self.get_unchecked_mut() };
        loop {
            match &mut this.state {
                State::Idle => {
                    // Start a new attempt.
                    let fut = (this.future_fn)();
                    this.state = State::Polling(fut);
                    continue;
                }
                State::Polling(fut) => {
                    // SAFETY: `fut` lives inside `this.state`, which is pinned
                    // through the outer `Pin<&mut Self>`. It is never moved
                    // out and is only dropped in place when `this.state` is
                    // overwritten.
                    let mut fut = unsafe { Pin::new_unchecked(fut) };
                    match ready!(fut.as_mut().poll(cx)) {
                        Ok(v) => return Poll::Ready(Ok(v)),
                        Err(err) => {
                            // Errors rejected by the predicate are returned as-is,
                            // without consuming a backoff step.
                            if !(this.retryable)(&err) {
                                return Poll::Ready(Err(err));
                            }
                            match this.backoff.next() {
                                // Backoff exhausted: surface the last error.
                                None => return Poll::Ready(Err(err)),
                                Some(dur) => {
                                    (this.notify)(&err, dur);
                                    // Overwriting the state drops the finished
                                    // attempt future in place.
                                    this.state = State::Sleeping(this.sleep_fn.sleep(dur));
                                    continue;
                                }
                            }
                        }
                    }
                }
                State::Sleeping(sl) => {
                    // SAFETY: same reasoning as for the attempt future above.
                    // `sl` is owned by the pinned `this.state` and is never moved
                    // out, only dropped in place on the next state change.
                    let mut sl = unsafe { Pin::new_unchecked(sl) };
                    ready!(sl.as_mut().poll(cx));
                    // Delay elapsed; the next loop iteration starts a fresh attempt.
                    this.state = State::Idle;
                    continue;
                }
            }
        }
    }
}
#[cfg(all(test, any(feature = "tokio-sleep", feature = "gloo-timers-sleep")))]
mod default_sleeper_tests {
    use alloc::string::ToString;
    use core::time::Duration;

    use tokio::sync::Mutex;
    #[cfg(not(target_arch = "wasm32"))]
    use tokio::test;
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    use super::*;
    use crate::ExponentialBuilder;

    // Operation that fails on every attempt.
    async fn always_error() -> anyhow::Result<()> {
        Err(anyhow::anyhow!("test_query meets error"))
    }

    // Exhausting the default backoff surfaces the operation's own error.
    #[test]
    async fn test_retry() {
        let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));
        let err = always_error.retry(backoff).await.unwrap_err();
        assert_eq!("test_query meets error", err.to_string());
    }

    // An error rejected by `when` is returned after the very first attempt.
    #[test]
    async fn test_retry_with_not_retryable_error() {
        let attempts = Mutex::new(0);
        let op = || async {
            *attempts.lock().await += 1;
            Err::<(), anyhow::Error>(anyhow::anyhow!("not retryable"))
        };
        let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));

        let err = op
            .retry(backoff)
            .when(|e| e.to_string() == "retryable")
            .await
            .unwrap_err();

        assert_eq!("not retryable", err.to_string());
        assert_eq!(*attempts.lock().await, 1);
    }

    // An error accepted by `when` is retried until the backoff is exhausted:
    // one initial attempt plus three retries.
    #[test]
    async fn test_retry_with_retryable_error() {
        let attempts = Mutex::new(0);
        let op = || async {
            *attempts.lock().await += 1;
            Err::<(), anyhow::Error>(anyhow::anyhow!("retryable"))
        };
        let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));

        let err = op
            .retry(backoff)
            .when(|e| e.to_string() == "retryable")
            .await
            .unwrap_err();

        assert_eq!("retryable", err.to_string());
        assert_eq!(*attempts.lock().await, 4);
    }

    // `when` and `notify` accept `FnMut` closures that mutate captured state.
    // `when` sees every failure; `notify` fires only before each retry.
    #[test]
    async fn test_fn_mut_when_and_notify() {
        let mut when_calls = 0usize;
        let mut notify_calls = 0usize;
        let op = || async { Err::<(), anyhow::Error>(anyhow::anyhow!("retryable")) };
        let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));

        let err = op
            .retry(backoff)
            .when(|_| {
                when_calls += 1;
                true
            })
            .notify(|_, _| notify_calls += 1)
            .await
            .unwrap_err();

        assert_eq!("retryable", err.to_string());
        assert_eq!(when_calls, 4);
        assert_eq!(notify_calls, 3);
    }
}
#[cfg(test)]
mod custom_sleeper_tests {
    use alloc::string::ToString;
    use core::future::ready;
    use core::time::Duration;

    #[cfg(not(target_arch = "wasm32"))]
    use tokio::test;
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    use super::*;
    use crate::ExponentialBuilder;

    // Operation that fails on every attempt.
    async fn always_error() -> anyhow::Result<()> {
        Err(anyhow::anyhow!("test_query meets error"))
    }

    // A closure sleeper whose future resolves immediately is accepted, and the
    // operation's error is still surfaced once retrying stops.
    #[test]
    async fn test_retry_with_sleep() {
        let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));

        let err = always_error
            .retry(backoff)
            .sleep(|_| ready(()))
            .await
            .unwrap_err();

        assert_eq!("test_query meets error", err.to_string());
    }
}