use crate::task::AtomicWaker;
use alloc::sync::Arc;
use core::fmt;
use core::pin::Pin;
use core::sync::atomic::{AtomicBool, Ordering};
use futures_core::future::Future;
use futures_core::task::{Context, Poll};
use futures_core::Stream;
use pin_project_lite::pin_project;
pin_project! {
/// A future or stream that can be remotely short-circuited through an
/// `AbortHandle`.
///
/// Build one with `Abortable::new` from a task and the `AbortRegistration`
/// half of `AbortHandle::new_pair`. Polled as a `Future`, it resolves to
/// `Result<T::Output, Aborted>`; polled as a `Stream`, it ends (yields `None`)
/// once aborted.
#[derive(Debug, Clone)]
#[must_use = "futures/streams do nothing unless you poll them"]
pub struct Abortable<T> {
// The wrapped future or stream. It is structurally pinned so it can be
// polled through `Pin<&mut Self>`.
#[pin]
task: T,
// Abort flag and waker shared with the paired `AbortHandle`. A derived
// `clone` copies this `Arc`, so clones observe the same abort state.
// NOTE(review): clones also share one `AtomicWaker`. If it keeps only the
// last registered waker, clones polled from different tasks may miss
// wake-ups. Confirm before relying on `Clone` across tasks.
inner: Arc<AbortInner>,
}
}
impl<T> Abortable<T> {
pub fn new(task: T, reg: AbortRegistration) -> Self {
Self { task, inner: reg.inner }
}
pub fn is_aborted(&self) -> bool {
self.inner.aborted.load(Ordering::Relaxed)
}
}
/// The registration half of the pair returned by `AbortHandle::new_pair`.
///
/// It is consumed by `Abortable::new` to link a task to the matching
/// `AbortHandle`.
#[derive(Debug)]
pub struct AbortRegistration {
// Abort state shared with the paired `AbortHandle`; moved into the
// `Abortable` by `Abortable::new`.
pub(crate) inner: Arc<AbortInner>,
}
/// A handle that aborts the `Abortable` task linked to it through
/// `AbortHandle::new_pair`.
///
/// Cloning the handle clones the inner `Arc`, so every clone controls the
/// same task.
#[derive(Debug, Clone)]
pub struct AbortHandle {
// Abort state shared with the paired registration / `Abortable`.
inner: Arc<AbortInner>,
}
impl AbortHandle {
pub fn new_pair() -> (Self, AbortRegistration) {
let inner =
Arc::new(AbortInner { waker: AtomicWaker::new(), aborted: AtomicBool::new(false) });
(Self { inner: inner.clone() }, AbortRegistration { inner })
}
}
/// State shared between an `AbortHandle` and the `Abortable` it controls.
#[derive(Debug)]
pub(crate) struct AbortInner {
// Waker of the task currently polling the `Abortable`. It is registered in
// `try_poll` and woken by `AbortHandle::abort`.
pub(crate) waker: AtomicWaker,
// Starts as `false` (see `new_pair`) and is set to `true` by
// `AbortHandle::abort`.
pub(crate) aborted: AtomicBool,
}
/// The error returned by an `Abortable` future that was aborted before it
/// completed.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct Aborted;

impl fmt::Display for Aborted {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Fixed message with no interpolation, so write it out verbatim.
        f.write_str("`Abortable` future has been aborted")
    }
}

#[cfg(feature = "std")]
impl std::error::Error for Aborted {}
impl<T> Abortable<T> {
    /// Abort-aware polling logic shared by the `Future` and `Stream` impls.
    ///
    /// * If the abort flag is already set, returns `Ready(Err(Aborted))`
    ///   without polling the inner task.
    /// * Otherwise polls the inner task once through `poll`. If that is
    ///   ready, returns `Ready(Ok(value))`.
    /// * If the inner task is pending, registers the current waker with the
    ///   shared `AtomicWaker` and then re-checks the flag. The result is
    ///   `Ready(Err(Aborted))` if an abort slipped in, and `Pending`
    ///   otherwise.
    ///
    /// `poll` is invoked at most once per call, so an `FnOnce` bound is enough
    /// (and accepts strictly more callers than `Fn`).
    fn try_poll<I>(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        poll: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll<I>,
    ) -> Poll<Result<I, Aborted>> {
        // Fast path: once aborted, never poll the inner task again.
        if self.is_aborted() {
            return Poll::Ready(Err(Aborted));
        }

        if let Poll::Ready(value) = poll(self.as_mut().project().task, cx) {
            return Poll::Ready(Ok(value));
        }

        // Register *before* re-checking the flag. An `abort()` that ran
        // between the first check and this registration would otherwise wake
        // a stale (or no) waker, and this task would never be woken again.
        self.inner.waker.register(cx.waker());

        if self.is_aborted() {
            return Poll::Ready(Err(Aborted));
        }

        Poll::Pending
    }
}
impl<Fut> Future for Abortable<Fut>
where
    Fut: Future,
{
    /// The inner future's output, or `Aborted` if the paired handle aborted
    /// first.
    type Output = Result<Fut::Output, Aborted>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Hand the inner future's own `poll` method straight to the shared
        // abort-aware helper rather than wrapping it in a closure.
        self.try_poll(cx, <Fut as Future>::poll)
    }
}
impl<St> Stream for Abortable<St>
where
    St: Stream,
{
    type Item = St::Item;

    /// Yields the inner stream's items unchanged. After an abort, the stream
    /// reports end-of-stream (`None`) instead of surfacing an error item.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match self.try_poll(cx, |inner, ctx| inner.poll_next(ctx)) {
            Poll::Ready(Ok(next)) => Poll::Ready(next),
            Poll::Ready(Err(Aborted)) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}
impl AbortHandle {
    /// Aborts the `Abortable` linked to this handle.
    ///
    /// Sets the shared abort flag and wakes the task last registered on the
    /// shared waker, so the next poll returns `Err(Aborted)` (future) or
    /// `None` (stream).
    pub fn abort(&self) {
        let shared = &*self.inner;
        // The flag must be published before waking. The woken task re-checks
        // it in `try_poll`, so waking first could let it observe `false` and
        // go back to sleep.
        shared.aborted.store(true, Ordering::Relaxed);
        shared.waker.wake();
    }
}