use std::mem;
use pin_project::pin_project;
use tokio::sync::{mpsc, watch};
use super::{task, Future, Never, Pin, Poll};
/// Value held by the shutdown `watch` channel.
///
/// `Open` is the only state that is ever stored: draining is announced by
/// closing the channel (dropping the sender), never by sending a new value.
#[derive(Copy, Clone)]
enum Action {
    Open,
}
/// Creates a linked drain pair.
///
/// The `Signal` half starts the drain (see `Signal::drain`); the `Watch` half
/// is cloned and used to wrap futures that should be notified. The `mpsc`
/// channel carries no values (`Never`) and exists only so `Draining` can tell
/// when every `Watch` has been dropped.
pub fn channel() -> (Signal, Watch) {
    let (shutdown_tx, shutdown_rx) = watch::channel(Action::Open);
    let (drained_tx, drained_rx) = mpsc::channel(1);

    let signal = Signal {
        drained_rx,
        _tx: shutdown_tx,
    };
    let watcher = Watch {
        drained_tx,
        rx: shutdown_rx,
    };

    (signal, watcher)
}
/// Controlling half returned by `channel`.
///
/// Consuming it with `Signal::drain` drops `_tx`, which closes the watch
/// channel that every `Watch`/`Watching` polls.
pub struct Signal {
    // Never yields an item (`Never` is uninhabited); it only reports, by
    // returning `None`, that every `Watch`-held sender has been dropped.
    drained_rx: mpsc::Receiver<Never>,
    // Never sent on; kept only so the watch channel stays open. Dropping it is
    // the drain notification.
    _tx: watch::Sender<Action>,
}
/// Future returned by `Signal::drain`.
///
/// Resolves to `()` once `drained_rx` reports that all of its senders have
/// been dropped; those senders are owned by `Watch` handles (and by the
/// `Watching` futures that hold a `Watch`).
pub struct Draining {
    drained_rx: mpsc::Receiver<Never>,
}
/// Cloneable handle returned by `channel`, used to wrap futures via
/// `Watch::watch`.
///
/// Every clone owns an `mpsc::Sender`, so `Draining` cannot resolve while any
/// `Watch` (or a `Watching` holding one) is still alive.
#[derive(Clone)]
pub struct Watch {
    // Cannot actually send anything (its item type is `Never`); held only so
    // that dropping it is observable by `Draining`.
    drained_tx: mpsc::Sender<Never>,
    // Closed when the `Signal`'s sender is dropped, which `Watching::poll`
    // treats as the drain notification.
    rx: watch::Receiver<Action>,
}
/// Future returned by `Watch::watch`.
///
/// Drives `future` to completion, calling the `on_drain` callback (stored in
/// `state`) once with the pinned future when the drain signal is observed.
#[allow(missing_debug_implementations)]
#[pin_project]
pub struct Watching<F, FN> {
    // The wrapped future; structurally pinned because it is polled in place.
    #[pin]
    future: F,
    // Whether the drain callback is still waiting to run.
    state: State<FN>,
    // Source of the drain signal, and owner of a `drained_tx` sender.
    // Fields drop in declaration order, so this is released only after
    // `future` has been dropped; `Draining` therefore cannot complete while
    // the wrapped future still exists.
    watch: Watch,
}
/// Progress of a `Watching` future with respect to its drain callback.
enum State<OnDrain> {
    /// No callback stored: it has already been called, or `Watching::poll`
    /// has temporarily taken it out while checking for the drain signal.
    Draining,
    /// Still watching for the drain signal; holds the not-yet-called callback.
    Watch(OnDrain),
}
impl Signal {
pub fn drain(self) -> Draining {
Draining {
drained_rx: self.drained_rx,
}
}
}
impl Future for Draining {
    type Output = ();

    /// Completes once the `drained_rx` channel reports that all senders (held
    /// by `Watch` handles) are gone. No item can ever arrive because the item
    /// type `Never` is uninhabited.
    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
        let receiver = &mut self.drained_rx;
        match receiver.poll_recv(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(None) => Poll::Ready(()),
            Poll::Ready(Some(impossible)) => match impossible {},
        }
    }
}
impl Watch {
    /// Wraps `future` so that `on_drain` is called once, with the pinned
    /// future, when the paired `Signal` is drained (or otherwise dropped).
    ///
    /// The returned `Watching` takes ownership of this `Watch`, keeping its
    /// drain-completion sender alive until the `Watching` itself is dropped.
    pub fn watch<F, FN>(self, future: F, on_drain: FN) -> Watching<F, FN>
    where
        F: Future,
        FN: FnOnce(Pin<&mut F>),
    {
        let state = State::Watch(on_drain);
        Watching {
            watch: self,
            state,
            future,
        }
    }
}
impl<F, FN> Future for Watching<F, FN>
where
    F: Future,
    FN: FnOnce(Pin<&mut F>),
{
    type Output = F::Output;

    /// Polls the wrapped future, first checking for the drain signal while the
    /// `on_drain` callback has not yet run.
    ///
    /// When the watch channel reports closed (`Ready(None)`, which happens once
    /// the `Signal` side drops its sender), `on_drain` is called exactly once
    /// with the pinned inner future. Every poll, before or after that, ends by
    /// delegating to the inner future and returning its result.
    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
        let mut me = self.project();
        loop {
            // Take the state by value so the `FnOnce` callback can be moved out
            // and called; it is put back below if no drain was observed.
            match mem::replace(me.state, State::Draining) {
                State::Watch(on_drain) => {
                    match me.watch.rx.poll_recv_ref(cx) {
                        Poll::Ready(None) => {
                            // Sender dropped: drain has been triggered. The state
                            // stays `Draining`, so the callback never runs again,
                            // and the next loop iteration polls the inner future.
                            on_drain(me.future.as_mut());
                        }
                        // Channel still open (the current value was observed, or
                        // nothing changed): restore the callback and just drive
                        // the inner future.
                        //
                        // NOTE(review): the `_` pattern does not move the received
                        // value out of the scrutinee, so it lives until the end of
                        // this `match`, i.e. across `me.future.poll(cx)`. If
                        // `poll_recv_ref` hands back a lock guard, it is held
                        // during the inner poll. This is presumably harmless since
                        // nothing visible here sends on the channel, but confirm
                        // for the tokio version in use.
                        Poll::Ready(Some(_ /* Action::Open */)) | Poll::Pending => {
                            *me.state = State::Watch(on_drain);
                            return me.future.poll(cx);
                        }
                    }
                }
                // Callback already ran: only the inner future matters now.
                State::Draining => return me.future.poll(cx),
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Hand-driven future used to observe how `Watching` treats its inner future.
    ///
    /// - `draining` is set by the drain callback.
    /// - `finished` decides whether the next poll resolves.
    /// - `poll_cnt` counts calls to `poll`.
    #[derive(Default)]
    struct TestMe {
        draining: bool,
        finished: bool,
        poll_cnt: usize,
    }

    impl Future for TestMe {
        type Output = ();
        fn poll(mut self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
            self.poll_cnt += 1;
            if !self.finished {
                return Poll::Pending;
            }
            Poll::Ready(())
        }
    }

    /// Drain callback shared by the tests: records that the notification arrived.
    fn mark_draining(mut fut: Pin<&mut TestMe>) {
        fut.draining = true;
    }

    #[test]
    fn watch() {
        let mut task = tokio_test::task::spawn(());
        task.enter(|cx, _| {
            let (signal, handle) = channel();
            let mut watching = handle.watch(TestMe::default(), mark_draining);

            // Before any drain, every poll reaches the inner future.
            assert_eq!(watching.future.poll_cnt, 0);
            assert!(Pin::new(&mut watching).poll(cx).is_pending());
            assert_eq!(watching.future.poll_cnt, 1);
            assert!(Pin::new(&mut watching).poll(cx).is_pending());
            assert_eq!(watching.future.poll_cnt, 2);

            // Starting the drain does not touch the inner future by itself...
            let mut draining = signal.drain();
            assert!(!watching.future.draining);
            assert_eq!(watching.future.poll_cnt, 2);

            // ...the next poll runs the callback and still polls the future.
            assert!(Pin::new(&mut watching).poll(cx).is_pending());
            assert_eq!(watching.future.poll_cnt, 3);
            assert!(watching.future.draining);

            // The drain stays pending while the watched future is alive.
            assert!(Pin::new(&mut draining).poll(cx).is_pending());
            watching.future.finished = true;
            assert!(Pin::new(&mut watching).poll(cx).is_ready());
            assert_eq!(watching.future.poll_cnt, 4);

            drop(watching);
            assert!(Pin::new(&mut draining).poll(cx).is_ready());
        })
    }

    #[test]
    fn watch_clones() {
        let mut task = tokio_test::task::spawn(());
        task.enter(|cx, _| {
            let (signal, handle) = channel();
            let first = handle.clone().watch(TestMe::default(), mark_draining);
            let second = handle.watch(TestMe::default(), mark_draining);

            let mut draining = signal.drain();
            assert!(Pin::new(&mut draining).poll(cx).is_pending());

            // One remaining clone is enough to keep the drain pending.
            drop(first);
            assert!(Pin::new(&mut draining).poll(cx).is_pending());

            drop(second);
            assert!(Pin::new(&mut draining).poll(cx).is_ready());
        });
    }
}