use std::cell::UnsafeCell;
/// The sending half of a promise; call [`Sender::send`] once to fulfil it.
#[must_use = "You should call Sender::send with the result"]
pub struct Sender<T>(std::sync::mpsc::Sender<T>);

impl<T> Sender<T> {
    /// Delivers `value` to the connected promise, consuming this sender.
    ///
    /// If the receiving side has already been dropped, the value is silently
    /// discarded instead of reporting an error.
    pub fn send(self, value: T) {
        let Self(channel) = self;
        // A failed send only means nobody is listening any more; ignore it.
        let _ = channel.send(value);
    }
}
/// Records how the task behind a promise was spawned.
///
/// The `spawn_*` constructors set it. With the `smol` feature it selects
/// which executor gets ticked while polling or blocking.
#[derive(Copy, Clone)]
// Under some feature combinations no constructor builds `Local` or `Async`.
#[allow(dead_code)]
pub enum TaskType {
    /// Spawned through `Promise::spawn_local`.
    Local,
    /// Spawned through `Promise::spawn_async`.
    Async,
    /// No executor task to tick (e.g. `Promise::new`, `Promise::from_ready`,
    /// `Promise::spawn_thread`, `Promise::spawn_blocking`).
    None,
}
/// A value that becomes available later, usually produced by a background
/// task or thread started by one of the `spawn_*` constructors.
#[must_use]
pub struct Promise<T: Send + 'static> {
    /// Pending-or-ready slot, fed by the paired [`Sender`].
    data: PromiseImpl<T>,
    /// How the producing task was spawned; passed to every status query.
    task_type: TaskType,
    /// Handle of the tokio task producing the value, if any (used by `abort`).
    #[cfg(feature = "tokio")]
    join_handle: Option<tokio::task::JoinHandle<()>>,
    /// Handle of the async-std task producing the value, if any.
    #[cfg(feature = "async-std")]
    async_std_join_handle: Option<async_std::task::JoinHandle<()>>,
    /// The smol task producing the value, if any; kept alive with the promise.
    #[cfg(feature = "smol")]
    smol_task: Option<smol::Task<()>>,
}
// At most one executor backend may be enabled. Each arm pairs one feature with
// every feature listed after it, which covers all six conflicting combinations.
// The check is skipped when the `docsrs` cfg is set (presumably so docs.rs can
// build with all features enabled).
#[cfg(all(
    not(docsrs),
    any(
        all(
            feature = "tokio",
            any(feature = "smol", feature = "async-std", feature = "web")
        ),
        all(feature = "smol", any(feature = "async-std", feature = "web")),
        all(feature = "async-std", feature = "web"),
    )
))]
compile_error!(
    "You can only specify one of the executor features: 'tokio', 'smol', 'async-std' or 'web'"
);
// `PromiseImpl` writes to its `UnsafeCell` through `&self`, which relies on
// `Promise` never being shared across threads (`!Sync`). Moving it to another
// thread (`Send`) is fine.
static_assertions::assert_impl_all!(Promise<u32>: Send);
static_assertions::assert_not_impl_all!(Promise<u32>: Sync);
impl<T: Send + 'static> Promise<T> {
    /// Creates a connected [`Sender`] / [`Promise`] pair.
    ///
    /// The promise stays pending until [`Sender::send`] is called on the
    /// returned sender. No task is attached, so the task type is
    /// [`TaskType::None`].
    pub fn new() -> (Sender<T>, Self) {
        let (tx, rx) = std::sync::mpsc::channel();
        (
            Sender(tx),
            Self {
                data: PromiseImpl(UnsafeCell::new(PromiseStatus::Pending(rx))),
                task_type: TaskType::None,
                #[cfg(feature = "tokio")]
                join_handle: None,
                #[cfg(feature = "async-std")]
                async_std_join_handle: None,
                #[cfg(feature = "smol")]
                smol_task: None,
            },
        )
    }

    /// Creates a promise that already holds `value`.
    pub fn from_ready(value: T) -> Self {
        Self {
            data: PromiseImpl(UnsafeCell::new(PromiseStatus::Ready(value))),
            task_type: TaskType::None,
            #[cfg(feature = "tokio")]
            join_handle: None,
            #[cfg(feature = "async-std")]
            async_std_join_handle: None,
            #[cfg(feature = "smol")]
            smol_task: None,
        }
    }

    /// Spawns `future` on the enabled async executor and returns a promise for
    /// its output.
    ///
    /// Uses `tokio::task::spawn`, the crate's smol `EXECUTOR`, or
    /// `async_std::task::spawn`, depending on the enabled feature. The promise
    /// is marked [`TaskType::Async`].
    #[cfg(any(feature = "tokio", feature = "smol", feature = "async-std"))]
    pub fn spawn_async(future: impl std::future::Future<Output = T> + 'static + Send) -> Self {
        let (sender, mut promise) = Self::new();
        promise.task_type = TaskType::Async;
        #[cfg(feature = "tokio")]
        {
            promise.join_handle =
                Some(tokio::task::spawn(async move { sender.send(future.await) }));
        }
        #[cfg(feature = "smol")]
        {
            promise.smol_task =
                Some(crate::EXECUTOR.spawn(async move { sender.send(future.await) }));
        }
        #[cfg(feature = "async-std")]
        {
            promise.async_std_join_handle =
                Some(async_std::task::spawn(async move { sender.send(future.await) }));
        }
        promise
    }

    /// Spawns a (possibly non-`Send`) future on the current thread's executor.
    ///
    /// Uses `tokio::task::spawn_local`, `wasm_bindgen_futures::spawn_local`, or
    /// the crate's thread-local smol `LOCAL_EXECUTOR`, depending on the enabled
    /// feature. Per tokio's documentation, its `spawn_local` must be called
    /// within a `LocalSet`. The promise is marked [`TaskType::Local`].
    #[cfg(any(feature = "tokio", feature = "web", feature = "smol"))]
    pub fn spawn_local(future: impl std::future::Future<Output = T> + 'static) -> Self {
        // `mut` is needed for the handle fields, which not every feature assigns.
        #[allow(unused_mut)]
        let (sender, mut promise) = Self::new();
        promise.task_type = TaskType::Local;
        #[cfg(feature = "tokio")]
        {
            promise.join_handle = Some(tokio::task::spawn_local(async move {
                sender.send(future.await);
            }));
        }
        #[cfg(feature = "web")]
        {
            wasm_bindgen_futures::spawn_local(async move { sender.send(future.await) });
        }
        #[cfg(feature = "smol")]
        {
            promise.smol_task = Some(
                crate::LOCAL_EXECUTOR
                    .with(|exec| exec.spawn(async move { sender.send(future.await) })),
            );
        }
        promise
    }

    /// Runs the blocking closure `f` without stalling the async executor.
    ///
    /// * tokio: spawns a task that runs `f` inside `tokio::task::block_in_place`.
    ///   Per tokio's documentation, `block_in_place` panics on a current-thread
    ///   runtime.
    /// * async-std: uses `async_std::task::spawn_blocking`.
    ///
    /// The task type stays [`TaskType::None`].
    #[cfg(any(feature = "tokio", feature = "async-std"))]
    pub fn spawn_blocking<F>(f: F) -> Self
    where
        F: FnOnce() -> T + Send + 'static,
    {
        let (sender, mut promise) = Self::new();
        #[cfg(feature = "tokio")]
        {
            promise.join_handle = Some(tokio::task::spawn(async move {
                sender.send(tokio::task::block_in_place(f));
            }));
        }
        #[cfg(feature = "async-std")]
        {
            promise.async_std_join_handle = Some(async_std::task::spawn_blocking(move || {
                sender.send(f());
            }));
        }
        promise
    }

    /// Runs `f` on a new OS thread named `thread_name` and returns a promise
    /// for its result.
    ///
    /// # Panics
    /// Panics if the thread cannot be spawned.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn spawn_thread<F>(thread_name: impl Into<String>, f: F) -> Self
    where
        F: FnOnce() -> T + Send + 'static,
    {
        let (sender, promise) = Self::new();
        std::thread::Builder::new()
            .name(thread_name.into())
            .spawn(move || sender.send(f()))
            .expect("Failed to spawn thread");
        promise
    }

    /// Returns the value if it has arrived, without blocking.
    ///
    /// # Panics
    /// Panics if the connected [`Sender`] was dropped before sending a value.
    pub fn ready(&self) -> Option<&T> {
        match self.poll() {
            std::task::Poll::Pending => None,
            std::task::Poll::Ready(value) => Some(value),
        }
    }

    /// Mutable counterpart of [`Self::ready`]: returns the value if it has
    /// arrived, without blocking.
    pub fn ready_mut(&mut self) -> Option<&mut T> {
        match self.poll_mut() {
            std::task::Poll::Pending => None,
            std::task::Poll::Ready(value) => Some(value),
        }
    }

    /// Takes the value if it has arrived, otherwise hands the promise back.
    ///
    /// On `Err`, every task handle is carried over unchanged. A still-pending
    /// tokio task can therefore still be aborted via [`Self::abort`].
    ///
    /// # Panics
    /// Panics if the connected [`Sender`] was dropped before sending a value.
    pub fn try_take(self) -> Result<T, Self> {
        self.data.try_take().map_err(|data| Self {
            data,
            task_type: self.task_type,
            // Keep the handles: dropping a tokio `JoinHandle` detaches the task,
            // which would turn a later `abort()` into a silent no-op.
            #[cfg(feature = "tokio")]
            join_handle: self.join_handle,
            #[cfg(feature = "async-std")]
            async_std_join_handle: self.async_std_join_handle,
            #[cfg(feature = "smol")]
            smol_task: self.smol_task,
        })
    }

    /// Blocks the current thread until the value arrives and returns it.
    ///
    /// # Panics
    /// Panics if the connected [`Sender`] is dropped without sending a value.
    pub fn block_until_ready(&self) -> &T {
        self.data.block_until_ready(self.task_type)
    }

    /// Blocks the current thread until the value arrives and returns it mutably.
    ///
    /// # Panics
    /// Panics if the connected [`Sender`] is dropped without sending a value.
    pub fn block_until_ready_mut(&mut self) -> &mut T {
        self.data.block_until_ready_mut(self.task_type)
    }

    /// Blocks until the value arrives, then moves it out of the promise.
    ///
    /// # Panics
    /// Panics if the connected [`Sender`] is dropped without sending a value.
    pub fn block_and_take(self) -> T {
        self.data.block_until_ready(self.task_type);
        match self.data.0.into_inner() {
            PromiseStatus::Pending(_) => unreachable!(),
            PromiseStatus::Ready(value) => value,
        }
    }

    /// Non-blocking check: `Poll::Ready(&value)` once the value has arrived,
    /// otherwise `Poll::Pending`.
    ///
    /// # Panics
    /// Panics if the connected [`Sender`] was dropped before sending a value.
    pub fn poll(&self) -> std::task::Poll<&T> {
        self.data.poll(self.task_type)
    }

    /// Mutable counterpart of [`Self::poll`].
    pub fn poll_mut(&mut self) -> std::task::Poll<&mut T> {
        self.data.poll_mut(self.task_type)
    }

    /// Returns how the task behind this promise was spawned.
    pub fn task_type(&self) -> TaskType {
        self.task_type
    }

    /// Aborts the tokio task producing the value, if there is one, and drops
    /// the promise.
    ///
    /// Promises without a tokio task (e.g. from [`Self::new`] or
    /// [`Self::spawn_thread`]) are simply dropped.
    #[cfg(feature = "tokio")]
    pub fn abort(self) {
        if let Some(join_handle) = self.join_handle {
            join_handle.abort();
        }
    }
}
/// State of a promise: still waiting on its channel, or holding the value.
///
/// Every write in this file moves it from `Pending` to `Ready`; it never goes back.
enum PromiseStatus<T: Send + 'static> {
    /// No value yet; the receiver is checked on each poll or blocking wait.
    Pending(std::sync::mpsc::Receiver<T>),
    /// The value has arrived and is stored inline.
    Ready(T),
}

/// Interior-mutable holder for a [`PromiseStatus`], so that `&self` polling
/// can record the received value.
struct PromiseImpl<T: Send + 'static>(UnsafeCell<PromiseStatus<T>>);
impl<T: Send + 'static> PromiseImpl<T> {
#[allow(unused_variables)]
fn poll_mut(&mut self, task_type: TaskType) -> std::task::Poll<&mut T> {
let inner = self.0.get_mut();
match inner {
PromiseStatus::Pending(rx) => {
#[cfg(all(feature = "smol", feature = "smol_tick_poll"))]
Self::tick(task_type);
if let Ok(value) = rx.try_recv() {
*inner = PromiseStatus::Ready(value);
match inner {
PromiseStatus::Ready(ref mut value) => std::task::Poll::Ready(value),
PromiseStatus::Pending(_) => unreachable!(),
}
} else {
std::task::Poll::Pending
}
}
PromiseStatus::Ready(ref mut value) => std::task::Poll::Ready(value),
}
}
fn try_take(self) -> Result<T, Self> {
let inner = self.0.into_inner();
match inner {
PromiseStatus::Pending(ref rx) => match rx.try_recv() {
Ok(value) => Ok(value),
Err(std::sync::mpsc::TryRecvError::Empty) => {
Err(PromiseImpl(UnsafeCell::new(inner)))
}
Err(std::sync::mpsc::TryRecvError::Disconnected) => {
panic!("The Promise Sender was dropped")
}
},
PromiseStatus::Ready(value) => Ok(value),
}
}
#[allow(unsafe_code)]
#[allow(unused_variables)]
fn poll(&self, task_type: TaskType) -> std::task::Poll<&T> {
let this = unsafe {
self.0.get().as_mut().expect("UnsafeCell should be valid")
};
match this {
PromiseStatus::Pending(rx) => {
#[cfg(all(feature = "smol", feature = "smol_tick_poll"))]
Self::tick(task_type);
match rx.try_recv() {
Ok(value) => {
*this = PromiseStatus::Ready(value);
match this {
PromiseStatus::Ready(ref value) => std::task::Poll::Ready(value),
PromiseStatus::Pending(_) => unreachable!(),
}
}
Err(std::sync::mpsc::TryRecvError::Empty) => std::task::Poll::Pending,
Err(std::sync::mpsc::TryRecvError::Disconnected) => {
panic!("The Promise Sender was dropped")
}
}
}
PromiseStatus::Ready(ref value) => std::task::Poll::Ready(value),
}
}
#[allow(unused_variables)]
fn block_until_ready_mut(&mut self, task_type: TaskType) -> &mut T {
#[cfg(feature = "smol")]
while self.poll(task_type).is_pending() {
#[cfg(not(feature = "smol_tick_poll"))]
Self::tick(task_type);
}
let inner = self.0.get_mut();
match inner {
PromiseStatus::Pending(rx) => {
let value = rx.recv().expect("The Promise Sender was dropped");
*inner = PromiseStatus::Ready(value);
match inner {
PromiseStatus::Ready(ref mut value) => value,
PromiseStatus::Pending(_) => unreachable!(),
}
}
PromiseStatus::Ready(ref mut value) => value,
}
}
#[allow(unsafe_code)]
#[allow(unused_variables)]
fn block_until_ready(&self, task_type: TaskType) -> &T {
#[cfg(feature = "smol")]
while self.poll(task_type).is_pending() {
#[cfg(not(feature = "smol_tick_poll"))]
Self::tick(task_type);
}
let this = unsafe {
self.0.get().as_mut().expect("UnsafeCell should be valid")
};
match this {
PromiseStatus::Pending(rx) => {
let value = rx.recv().expect("The Promise Sender was dropped");
*this = PromiseStatus::Ready(value);
match this {
PromiseStatus::Ready(ref value) => value,
PromiseStatus::Pending(_) => unreachable!(),
}
}
PromiseStatus::Ready(ref value) => value,
}
}
#[cfg(feature = "smol")]
fn tick(task_type: TaskType) {
match task_type {
TaskType::Local => {
crate::tick_local();
}
TaskType::Async => {
crate::tick();
}
TaskType::None => (),
};
}
}