use crate::runtime::task::{self, JoinHandle, Task};
use crate::sync::AtomicWaker;
use crate::util::linked_list::LinkedList;
use std::cell::{Cell, RefCell};
use std::collections::VecDeque;
use std::fmt;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::Poll;
use pin_project_lite::pin_project;
cfg_rt_util! {
/// A set of tasks that are run on the thread that polls the set.
///
/// Futures spawned onto it do not need to be `Send` (see the bounds on
/// `spawn_local`). The set's tasks are executed in batches ("ticks") whenever
/// the set is driven: via `run_until`, `block_on`, or by awaiting the
/// `LocalSet` itself.
pub struct LocalSet {
/// Counter bumped on every `next_task` call; decides when the shared
/// (remote) queue is checked before the local one.
tick: Cell<u8>,
/// Scheduler state; installed as `CURRENT` while the set is being driven.
context: Context,
/// `*const ()` is neither `Send` nor `Sync`, so this marker makes
/// `LocalSet` `!Send` and `!Sync`.
_not_send: PhantomData<*const ()>,
}
}
/// Scheduler state of one `LocalSet`, exposed to tasks through `CURRENT`.
struct Context {
/// Owned-task list and local run queue; guarded by a `RefCell`, so only
/// reachable from the thread that holds the `LocalSet`.
tasks: RefCell<Tasks>,
/// State reachable through task scheduler handles (`Arc<Shared>`), possibly
/// from other threads, which is why it sits behind `Arc` + `Mutex`.
shared: Arc<Shared>,
}
/// Thread-local task bookkeeping of a `LocalSet`.
struct Tasks {
/// Every task bound to this set (pushed in `Schedule::bind`, unlinked in
/// `Schedule::release`, shut down when the set is dropped).
owned: LinkedList<Task<Arc<Shared>>>,
/// Local run queue: tasks scheduled from within this set's context.
queue: VecDeque<task::Notified<Arc<Shared>>>,
}
/// Part of a `LocalSet`'s state that task handles hold on to.
struct Shared {
/// Run queue for tasks scheduled while this set's context is NOT current
/// (see `Shared::schedule`).
queue: Mutex<VecDeque<task::Notified<Arc<Shared>>>>,
/// Waker registered by the last poll of the set; woken when a task is
/// pushed onto `queue`.
waker: AtomicWaker,
}
// Future returned by `LocalSet::run_until`: each poll polls `future` and, if it
// is still pending, runs one tick of `local_set`'s tasks.
pin_project! {
#[derive(Debug)]
struct RunUntil<'a, F> {
// The set whose tasks are driven alongside `future`.
local_set: &'a LocalSet,
// The caller's future; structurally pinned so it can be polled in place.
#[pin]
future: F,
}
}
// The `Context` of the `LocalSet` currently being driven on this thread; set for
// the duration of `LocalSet::with`. Read by the free `spawn_local`,
// `Shared::schedule`, and the `Schedule` impl for `Arc<Shared>`.
scoped_thread_local!(static CURRENT: Context);
cfg_rt_util! {
    /// Spawns a future on the `LocalSet` whose context is current on this thread.
    ///
    /// The future does not have to be `Send`. The new task is appended to that
    /// set's local run queue and is run by a later tick of the set.
    ///
    /// # Panics
    ///
    /// Panics if no `LocalSet` context is current on the calling thread.
    pub fn spawn_local<F>(future: F) -> JoinHandle<F::Output>
    where
        F: Future + 'static,
        F::Output: 'static,
    {
        CURRENT.with(move |maybe_cx| {
            let local_cx =
                maybe_cx.expect("`spawn_local` called from outside of a `task::LocalSet`");

            // SAFETY: NOTE(review): `joinable_local` is presumably unsafe because
            // `future` need not be `Send`; the task is only queued on the
            // current thread's `LocalSet` — confirm against its contract.
            let (notified, join_handle) = unsafe { task::joinable_local(future) };

            local_cx.tasks.borrow_mut().queue.push_back(notified);
            join_handle
        })
    }
}
/// Initial capacity of both the local and the shared run queue.
const INITIAL_CAPACITY: usize = 64;
/// Maximum number of tasks a single `LocalSet::tick` runs before returning
/// (the poller then re-wakes itself so the rest runs on a later poll).
const MAX_TASKS_PER_TICK: usize = 61;
/// Every call to `LocalSet::next_task` where the wrapping `u8` tick counter is a
/// multiple of this value checks the shared (remote) queue before the local
/// one. Since 256 is not a multiple of 31, the spacing is slightly irregular
/// where the counter wraps.
const REMOTE_FIRST_INTERVAL: u8 = 31;
impl LocalSet {
    /// Returns a new, empty `LocalSet`.
    ///
    /// Both run queues start with room for `INITIAL_CAPACITY` tasks.
    pub fn new() -> LocalSet {
        LocalSet {
            tick: Cell::new(0),
            context: Context {
                tasks: RefCell::new(Tasks {
                    owned: LinkedList::new(),
                    queue: VecDeque::with_capacity(INITIAL_CAPACITY),
                }),
                shared: Arc::new(Shared {
                    queue: Mutex::new(VecDeque::with_capacity(INITIAL_CAPACITY)),
                    waker: AtomicWaker::new(),
                }),
            },
            _not_send: PhantomData,
        }
    }

    /// Spawns a future (which need not be `Send`) onto this `LocalSet`.
    ///
    /// Unlike the free `spawn_local` function, this does not require the set's
    /// context to be current: the task is pushed straight onto this set's local
    /// run queue. It then wakes the waker registered by the last poll of the set
    /// (the same wake `Shared::schedule` performs for remotely scheduled tasks).
    pub fn spawn_local<F>(&self, future: F) -> JoinHandle<F::Output>
    where
        F: Future + 'static,
        F::Output: 'static,
    {
        // SAFETY: NOTE(review): `joinable_local` is presumably unsafe because
        // `future` need not be `Send`; the task is only queued on this `!Send`
        // set — confirm against its contract.
        let (task, handle) = unsafe { task::joinable_local(future) };
        self.context.tasks.borrow_mut().queue.push_back(task);
        // Without this, a task spawned while the set is parked (e.g. from a
        // future running concurrently with `run_until`) would sit in the queue
        // until something else happened to wake the set.
        self.context.shared.waker.wake();
        handle
    }

    /// Runs `future` to completion on `rt`, driving this set's tasks alongside it.
    ///
    /// Equivalent to `rt.block_on(self.run_until(future))`.
    pub fn block_on<F>(&self, rt: &mut crate::runtime::Runtime, future: F) -> F::Output
    where
        F: Future,
    {
        rt.block_on(self.run_until(future))
    }

    /// Runs `future` to completion, running a tick of this set's tasks each
    /// time it is polled and still pending.
    ///
    /// Tasks that have not finished when `future` completes stay in the set.
    pub async fn run_until<F>(&self, future: F) -> F::Output
    where
        F: Future,
    {
        let run_until = RunUntil {
            future,
            local_set: self,
        };
        run_until.await
    }

    /// Runs up to `MAX_TASKS_PER_TICK` queued tasks, each inside
    /// `crate::coop::budget`.
    ///
    /// Returns `true` if the limit was reached (there may still be queued work)
    /// and `false` if the run queues were exhausted first.
    fn tick(&self) -> bool {
        for _ in 0..MAX_TASKS_PER_TICK {
            match self.next_task() {
                Some(task) => crate::coop::budget(|| task.run()),
                None => return false,
            }
        }
        true
    }

    /// Pops the next task to run.
    ///
    /// The local queue is normally preferred. On every `REMOTE_FIRST_INTERVAL`-th
    /// call (per the wrapping `u8` counter in `self.tick`), the shared queue is
    /// checked first, presumably so remotely scheduled tasks are not starved.
    fn next_task(&self) -> Option<task::Notified<Arc<Shared>>> {
        let tick = self.tick.get();
        self.tick.set(tick.wrapping_add(1));

        if tick % REMOTE_FIRST_INTERVAL == 0 {
            self.pop_remote_task().or_else(|| self.pop_local_task())
        } else {
            self.pop_local_task().or_else(|| self.pop_remote_task())
        }
    }

    /// Pops from the local run queue. The `RefCell` borrow is released before
    /// returning, so it is never held while the shared queue is locked.
    fn pop_local_task(&self) -> Option<task::Notified<Arc<Shared>>> {
        self.context.tasks.borrow_mut().queue.pop_front()
    }

    /// Pops from the shared run queue. The mutex is unlocked before returning,
    /// so it is never held while the local queue is borrowed.
    fn pop_remote_task(&self) -> Option<task::Notified<Arc<Shared>>> {
        self.context.shared.queue.lock().unwrap().pop_front()
    }

    /// Runs `f` with this set's `Context` installed as `CURRENT`.
    fn with<T>(&self, f: impl FnOnce() -> T) -> T {
        CURRENT.set(&self.context, f)
    }
}
impl fmt::Debug for LocalSet {
    /// Formats the set opaquely as `LocalSet`, without exposing its tasks.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("LocalSet");
        builder.finish()
    }
}
impl Future for LocalSet {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
self.context.shared.waker.register_by_ref(cx.waker());
if self.with(|| self.tick()) {
cx.waker().wake_by_ref();
Poll::Pending
} else if self.context.tasks.borrow().owned.is_empty() {
Poll::Ready(())
} else {
Poll::Pending
}
}
}
impl Default for LocalSet {
fn default() -> LocalSet {
LocalSet::new()
}
}
impl Drop for LocalSet {
fn drop(&mut self) {
self.with(|| {
#[allow(clippy::while_let_loop)]
loop {
let task = match self.context.tasks.borrow_mut().owned.pop_back() {
Some(task) => task,
None => break,
};
task.shutdown();
}
for task in self.context.tasks.borrow_mut().queue.drain(..) {
task.shutdown();
}
for task in self.context.shared.queue.lock().unwrap().drain(..) {
task.shutdown();
}
assert!(self.context.tasks.borrow().owned.is_empty());
});
}
}
impl<T: Future> Future for RunUntil<'_, T> {
type Output = T::Output;
fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
let me = self.project();
me.local_set.with(|| {
me.local_set
.context
.shared
.waker
.register_by_ref(cx.waker());
let _no_blocking = crate::runtime::enter::disallow_blocking();
if let Poll::Ready(output) = me.future.poll(cx) {
return Poll::Ready(output);
}
if me.local_set.tick() {
cx.waker().wake_by_ref();
}
Poll::Pending
})
}
}
impl Shared {
    /// Queues `task` to be run by the `LocalSet` that owns this state.
    ///
    /// If that set's context is current on the calling thread, the task goes
    /// onto the set's local run queue. Otherwise it is pushed onto the
    /// mutex-protected shared queue and the registered waker is woken.
    fn schedule(&self, task: task::Notified<Arc<Self>>) {
        CURRENT.with(|maybe_cx| {
            let owning_cx = maybe_cx.filter(|cx| cx.shared.ptr_eq(self));

            if let Some(cx) = owning_cx {
                cx.tasks.borrow_mut().queue.push_back(task);
            } else {
                self.queue.lock().unwrap().push_back(task);
                self.waker.wake();
            }
        });
    }

    /// Returns `true` if `self` and `other` are the same `Shared` instance
    /// (address identity, not structural equality).
    fn ptr_eq(&self, other: &Shared) -> bool {
        std::ptr::eq(self, other)
    }
}
impl task::Schedule for Arc<Shared> {
    /// Records `task` in the owned list of the `LocalSet` whose context is
    /// current, and returns that set's shared state as the task's scheduler.
    ///
    /// # Panics
    ///
    /// Panics if no `LocalSet` context is current on this thread.
    fn bind(task: Task<Self>) -> Arc<Shared> {
        CURRENT.with(|maybe_cx| {
            let local_cx = maybe_cx.expect("scheduler context missing");
            local_cx.tasks.borrow_mut().owned.push_front(task);
            Arc::clone(&local_cx.shared)
        })
    }

    /// Unlinks `task` from the current set's owned list, returning it if it
    /// was linked there.
    ///
    /// # Panics
    ///
    /// Panics if no `LocalSet` context is current, or if the current set is not
    /// the one this scheduler handle belongs to.
    fn release(&self, task: &Task<Self>) -> Option<Task<Self>> {
        CURRENT.with(|maybe_cx| {
            let local_cx = maybe_cx.expect("scheduler context missing");
            assert!(local_cx.shared.ptr_eq(self));

            let header = std::ptr::NonNull::from(task.header());
            // SAFETY: NOTE(review): `LinkedList::remove` presumably requires that
            // `header` is either linked in this list or in none — the assert above
            // ties `task` to this set; confirm against its contract.
            unsafe { local_cx.tasks.borrow_mut().owned.remove(header) }
        })
    }

    /// Delegates to the inherent `Shared::schedule`. The call goes through an
    /// explicit `&Shared` so the inherent method, not this trait method, is
    /// selected.
    fn schedule(&self, task: task::Notified<Self>) {
        let shared: &Shared = self;
        shared.schedule(task);
    }
}