use alloc::sync::Arc;
use core::fmt;
use core::pin::Pin;
use core::sync::atomic::AtomicBool;
use core::sync::atomic::Ordering::SeqCst;
use futures_core::future::Future;
use futures_core::task::{Context, Poll, Waker};
use crate::lock::Lock;
/// Receiving half of a one-shot channel created by `channel`.
///
/// Polling it as a `Future` yields `Ok(value)` once the paired `Sender` has
/// stored a value and the channel is complete, and `Err(Canceled)` otherwise
/// once the channel is complete.
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[derive(Debug)]
pub struct Receiver<T> {
    // State shared with the paired `Sender`.
    inner: Arc<Inner<T>>,
}
/// Sending half of a one-shot channel created by `channel`.
///
/// A value is delivered with `Sender::send`, which consumes the sender.
/// Dropping the sender (including at the end of `send`) marks the channel
/// complete and wakes a waiting `Receiver`.
#[derive(Debug)]
pub struct Sender<T> {
    // State shared with the paired `Receiver`.
    inner: Arc<Inner<T>>,
}
// Both handles hold only an `Arc<Inner<T>>` and never project a pin onto it,
// so they are `Unpin` for every `T`. (`Arc` is itself `Unpin`, so these impls
// mainly make the guarantee explicit.)
impl<T> Unpin for Receiver<T> {}
impl<T> Unpin for Sender<T> {}
/// State shared (through an `Arc`) by one `Sender`/`Receiver` pair.
#[derive(Debug)]
struct Inner<T> {
    /// Set once either side is finished: by the sender when it is dropped
    /// (`drop_tx`), or by the receiver in `close_rx` / `drop_rx`. Nothing in
    /// this file ever resets it to `false`.
    complete: AtomicBool,
    /// The sent value, if any. Written by `send`, taken by `recv` / `try_recv`
    /// (or reclaimed by `send` if the receiver went away mid-send).
    data: Lock<Option<T>>,
    /// Waker of a task blocked polling the `Receiver`; woken by `drop_tx`.
    rx_task: Lock<Option<Waker>>,
    /// Waker of a task blocked in `Sender::poll_canceled`; woken by
    /// `close_rx` / `drop_rx`.
    tx_task: Lock<Option<Waker>>,
}
/// Creates a new one-shot channel and returns its connected halves.
///
/// Both halves point at the same freshly created shared state: not complete,
/// no value stored, no wakers registered.
pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
    let shared = Arc::new(Inner::new());
    let rx = Receiver { inner: Arc::clone(&shared) };
    let tx = Sender { inner: shared };
    (tx, rx)
}
// Protocol overview: every finishing action first sets `complete` and then
// *tries* to lock a waker slot. Every registering action first stores its
// waker and then re-checks `complete`. A failed `try_lock` (`None`) is never
// waited on; it is interpreted as "the other side is mid-operation and will
// observe `complete` itself".
// NOTE(review): these handshakes rely on `complete` being SeqCst and
// presumably on `crate::lock::Lock` providing comparably strong ordering —
// confirm against the `lock` module.
impl<T> Inner<T> {
    /// Creates empty shared state: not complete, no value, no wakers.
    fn new() -> Inner<T> {
        Inner {
            complete: AtomicBool::new(false),
            data: Lock::new(None),
            rx_task: Lock::new(None),
            tx_task: Lock::new(None),
        }
    }

    /// Stores `t` for the receiver.
    ///
    /// Returns `Err(t)`, handing the value back, when:
    /// - the channel is already complete;
    /// - the `data` lock is busy; or
    /// - the channel became complete while the value was being stored and
    ///   the value could still be reclaimed.
    ///
    /// Otherwise returns `Ok(())`.
    ///
    /// This does not set `complete` itself. `Sender::send` consumes the
    /// sender, whose `Drop` impl calls `drop_tx`; that sets the flag and wakes
    /// the receiver.
    ///
    /// # Panics
    /// Panics if `data` already holds a value.
    fn send(&self, t: T) -> Result<(), T> {
        // Fast path: the receiver has already closed or been dropped.
        if self.complete.load(SeqCst) {
            return Err(t)
        }
        if let Some(mut slot) = self.data.try_lock() {
            assert!(slot.is_none());
            *slot = Some(t);
            drop(slot);
            // The receiver may have closed or been dropped between the check
            // above and releasing the lock. If so, it may never read the
            // value, so try to take it back and report failure.
            if self.complete.load(SeqCst) {
                // A busy lock here means the receiver is inside `recv` /
                // `try_recv` and will take the value itself, so the send
                // counts as delivered.
                if let Some(mut slot) = self.data.try_lock() {
                    if let Some(t) = slot.take() {
                        return Err(t);
                    }
                }
            }
            Ok(())
        } else {
            // `recv` / `try_recv` lock `data` only after seeing `complete`
            // set. While the sender is alive, that can only come from the
            // receiver closing, so report failure.
            Err(t)
        }
    }

    /// Polls whether the receiver has gone away.
    ///
    /// Returns `Ready(())` once `complete` is set. Otherwise it registers the
    /// current task's waker in `tx_task` (woken by `close_rx` / `drop_rx`)
    /// and returns `Pending`.
    fn poll_canceled(&self, cx: &mut Context<'_>) -> Poll<()> {
        if self.complete.load(SeqCst) {
            return Poll::Ready(())
        }
        let handle = cx.waker().clone();
        match self.tx_task.try_lock() {
            Some(mut p) => *p = Some(handle),
            // `close_rx` / `drop_rx` lock `tx_task` only after setting
            // `complete`, and `drop_tx` (run from `Sender::drop`) cannot run
            // while the sender is being polled. So contention means the
            // receiver is shutting down.
            None => return Poll::Ready(()),
        }
        // Re-check: the receiver may have set `complete` and tried to wake
        // `tx_task` before our waker was stored there.
        if self.complete.load(SeqCst) {
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }

    /// Reports whether `complete` is set.
    ///
    /// This is only reached through a live `Sender` (`drop_tx` runs in
    /// `Sender::drop`), so `true` means the receiver closed or was dropped.
    fn is_canceled(&self) -> bool {
        self.complete.load(SeqCst)
    }

    /// Sender-side teardown. Marks the channel complete, wakes a receiver
    /// task parked in `rx_task`, and discards any waker in `tx_task`.
    fn drop_tx(&self) {
        self.complete.store(true, SeqCst);
        // A busy `rx_task` lock means the receiver is either registering its
        // waker in `recv` (which re-checks `complete` afterwards) or running
        // `drop_rx`. Neither needs a wake-up from here.
        if let Some(mut slot) = self.rx_task.try_lock() {
            if let Some(task) = slot.take() {
                // Release the lock before waking the receiver.
                drop(slot);
                task.wake();
            }
        }
        // The sender is going away, so its own cancellation waker is no
        // longer needed.
        if let Some(mut slot) = self.tx_task.try_lock() {
            drop(slot.take());
        }
    }

    /// Receiver-initiated close.
    ///
    /// Sets `complete`, so a later `send` fails and `poll_canceled`
    /// resolves. Then it wakes a sender task registered in `tx_task`. A
    /// value already stored in `data` is left in place and can still be taken.
    fn close_rx(&self) {
        self.complete.store(true, SeqCst);
        // If `tx_task` is busy, the sender is registering in `poll_canceled`
        // and will re-check `complete` after storing its waker.
        if let Some(mut handle) = self.tx_task.try_lock() {
            if let Some(task) = handle.take() {
                // Release the lock before waking the sender.
                drop(handle);
                task.wake()
            }
        }
    }

    /// Non-blocking receive.
    ///
    /// - `Ok(None)` while `complete` is not set. A value stored by a sender
    ///   that has not yet been dropped is not reported yet.
    /// - `Ok(Some(value))` once complete and a value can be taken.
    /// - `Err(Canceled)` once complete with no value to take. This includes
    ///   every call after the value has been taken, and a busy `data` lock;
    ///   in the latter case `send`'s post-store re-check reclaims the value.
    fn try_recv(&self) -> Result<Option<T>, Canceled> {
        if self.complete.load(SeqCst) {
            if let Some(mut slot) = self.data.try_lock() {
                if let Some(data) = slot.take() {
                    return Ok(Some(data));
                }
            }
            Err(Canceled)
        } else {
            Ok(None)
        }
    }

    /// Polling logic behind `Receiver`'s `Future` impl.
    ///
    /// If not yet complete, it registers the current task's waker in
    /// `rx_task` (woken by `drop_tx`) and returns `Pending`. Once complete,
    /// it resolves exactly as `try_recv` does for the complete case.
    fn recv(&self, cx: &mut Context<'_>) -> Poll<Result<T, Canceled>> {
        let done = if self.complete.load(SeqCst) {
            true
        } else {
            let task = cx.waker().clone();
            match self.rx_task.try_lock() {
                Some(mut slot) => { *slot = Some(task); false },
                // While the receiver is being polled, only `drop_tx` can be
                // holding `rx_task`, and it sets `complete` before locking.
                None => true,
            }
        };
        // After registering, re-check `complete` in case the sender finished
        // while our waker was not yet stored and so could not be woken.
        if done || self.complete.load(SeqCst) {
            // A busy `data` lock means the sender is inside `send`; it will
            // see `complete` on its re-check and reclaim the value.
            if let Some(mut slot) = self.data.try_lock() {
                if let Some(data) = slot.take() {
                    return Poll::Ready(Ok(data));
                }
            }
            Poll::Ready(Err(Canceled))
        } else {
            Poll::Pending
        }
    }

    /// Receiver-side teardown. Marks the channel complete, discards the
    /// receiver's own waker, and wakes a sender waiting in `poll_canceled`.
    fn drop_rx(&self) {
        self.complete.store(true, SeqCst);
        // A busy `rx_task` lock means `drop_tx` is taking the waker; let it.
        if let Some(mut slot) = self.rx_task.try_lock() {
            let task = slot.take();
            // Release the lock first, presumably so the waker's destructor
            // never runs while the lock is held.
            drop(slot);
            drop(task);
        }
        // A busy `tx_task` lock means the sender is registering in
        // `poll_canceled` and will observe `complete` on its re-check.
        if let Some(mut handle) = self.tx_task.try_lock() {
            if let Some(task) = handle.take() {
                drop(handle);
                task.wake()
            }
        }
    }
}
impl<T> Sender<T> {
    /// Completes the channel by handing `t` to the paired `Receiver`.
    ///
    /// Consumes the sender. When it is dropped at the end of this call, the
    /// channel is marked complete and a waiting receiver is woken.
    ///
    /// # Errors
    /// Returns `Err(t)`, giving the value back, when it could not be handed
    /// over. For example, the receiver has already closed or been dropped.
    pub fn send(self, t: T) -> Result<(), T> {
        let shared = &self.inner;
        shared.send(t)
    }

    /// Polls whether the receiving half has closed or been dropped.
    ///
    /// Returns `Poll::Ready(())` once it has. Otherwise it registers the
    /// current task to be woken when that happens and returns `Poll::Pending`.
    pub fn poll_canceled(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        let shared = &self.inner;
        shared.poll_canceled(cx)
    }

    /// Returns `true` if the receiving half has closed or been dropped.
    pub fn is_canceled(&self) -> bool {
        let shared = &self.inner;
        shared.is_canceled()
    }
}
impl<T> Drop for Sender<T> {
    /// Marks the channel complete and wakes a receiver waiting on it.
    fn drop(&mut self) {
        let shared = &self.inner;
        shared.drop_tx()
    }
}
/// Error produced by a `Receiver` when the channel has completed but no value
/// is available to take.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct Canceled;

impl fmt::Display for Canceled {
    /// Writes the fixed message `oneshot canceled`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("oneshot canceled")
    }
}

#[cfg(feature = "std")]
impl std::error::Error for Canceled {}
impl<T> Receiver<T> {
    /// Closes the receiving half without dropping it.
    ///
    /// Marks the channel complete, so a subsequent `Sender::send` hands its
    /// value back as `Err`. Wakes a sender waiting in `poll_canceled`. A
    /// value that was already stored can still be obtained via `try_recv` or
    /// by polling.
    pub fn close(&mut self) {
        let shared = &self.inner;
        shared.close_rx()
    }

    /// Attempts to take the value without blocking.
    ///
    /// Returns:
    /// - `Ok(None)` while the channel is not yet complete;
    /// - `Ok(Some(value))` once it is complete and a value is available;
    /// - `Err(Canceled)` when it is complete with no value to take.
    pub fn try_recv(&mut self) -> Result<Option<T>, Canceled> {
        let shared = &self.inner;
        shared.try_recv()
    }
}
impl<T> Future for Receiver<T> {
    type Output = Result<T, Canceled>;

    /// Resolves to `Ok(value)` once the sender has delivered a value and the
    /// channel is complete, or `Err(Canceled)` once it is complete without
    /// one. Until then it registers the current task and returns `Pending`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<T, Canceled>> {
        let shared = &self.inner;
        shared.recv(cx)
    }
}
impl<T> Drop for Receiver<T> {
    /// Marks the channel complete, drops this receiver's registered waker,
    /// and wakes a sender waiting in `poll_canceled`.
    fn drop(&mut self) {
        let shared = &self.inner;
        shared.drop_rx()
    }
}