#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))]
use crate::loom::cell::UnsafeCell;
use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::Mutex;
use crate::util::linked_list::{self, GuardedLinkedList, LinkedList};
use crate::util::WakeList;
use std::future::Future;
use std::marker::PhantomPinned;
use std::panic::{RefUnwindSafe, UnwindSafe};
use std::pin::Pin;
use std::ptr::NonNull;
use std::sync::atomic::Ordering::{self, Acquire, Relaxed, Release, SeqCst};
use std::task::{Context, Poll, Waker};
/// Intrusive list of `Waiter` nodes registered with a `Notify`; nodes are
/// linked through their `pointers` field (see the `linked_list::Link` impl).
type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>;
/// Guarded form of `WaitList`, produced by `WaitList::into_guarded` in
/// `NotifyWaitersList::new` and anchored at a caller-provided guard node.
type GuardedWaitList = GuardedLinkedList<Waiter, <Waiter as linked_list::Link>::Target>;
/// Notifies a single waiting task (`notify_one`) or all waiting tasks
/// (`notify_waiters`).
///
/// `state` packs two values into one word: the low `NOTIFY_WAITERS_SHIFT` bits
/// hold one of `EMPTY`, `WAITING` or `NOTIFIED`, and the remaining high bits
/// count calls to `notify_waiters` (see `set_state` / `get_state` /
/// `get_num_notify_waiters_calls`).
#[derive(Debug)]
pub struct Notify {
    // Packed state word: state bits plus the `notify_waiters` call counter.
    state: AtomicUsize,
    // Waiters currently registered, oldest at the back; guarded by this mutex.
    waiters: Mutex<WaitList>,
}
/// List node embedded in each `Notified` future and linked into
/// `Notify::waiters` while that future is waiting.
#[derive(Debug)]
struct Waiter {
    // Intrusive linked-list pointers.
    pointers: linked_list::Pointers<Waiter>,
    // Waker of the waiting task. Accessed while holding `Notify::waiters`, or by
    // the owning future after a notification has been observed (the waiter is
    // unlinked by then).
    waker: UnsafeCell<Option<Waker>>,
    // Written with `Release` by a notifier after it unlinked this waiter and
    // took its waker.
    notification: AtomicNotification,
    // Lists refer to the node by raw pointer, so it must not be moved.
    _p: PhantomPinned,
}
impl Waiter {
    /// Builds a detached waiter: unlinked, without a waker, and with no
    /// notification recorded.
    fn new() -> Waiter {
        let pointers = linked_list::Pointers::new();
        let waker = UnsafeCell::new(None);
        let notification = AtomicNotification::none();

        Self {
            pointers,
            waker,
            notification,
            _p: PhantomPinned,
        }
    }
}
// Generates `Waiter::addr_of_pointers`, which maps a `NonNull<Waiter>` to a
// `NonNull` pointing at its `pointers` field. It is used by the
// `linked_list::Link` impl for `Waiter`.
generate_addr_of_methods! {
    impl<> Waiter {
        unsafe fn addr_of_pointers(self: NonNull<Self>) -> NonNull<linked_list::Pointers<Waiter>> {
            &self.pointers
        }
    }
}
// Raw encodings stored inside `AtomicNotification`. The non-zero values double
// as the discriminants of `Notification`.
const NOTIFICATION_NONE: usize = 0;
const NOTIFICATION_ONE: usize = 1;
const NOTIFICATION_ALL: usize = 2;
/// Atomic cell holding an optional `Notification`, encoded with the
/// `NOTIFICATION_*` constants.
#[derive(Debug)]
struct AtomicNotification(AtomicUsize);
impl AtomicNotification {
fn none() -> Self {
AtomicNotification(AtomicUsize::new(NOTIFICATION_NONE))
}
fn store_release(&self, notification: Notification) {
self.0.store(notification as usize, Release);
}
fn load(&self, ordering: Ordering) -> Option<Notification> {
match self.0.load(ordering) {
NOTIFICATION_NONE => None,
NOTIFICATION_ONE => Some(Notification::One),
NOTIFICATION_ALL => Some(Notification::All),
_ => unreachable!(),
}
}
fn clear(&self) {
self.0.store(NOTIFICATION_NONE, Relaxed);
}
}
/// Kind of notification delivered to a waiter.
#[derive(Debug, PartialEq, Eq)]
#[repr(usize)]
enum Notification {
    /// Delivered through `notify_locked` (e.g. by `notify_one`). If the
    /// receiving `Notified` is dropped before consuming it, `Notified::drop`
    /// forwards it to another waiter.
    One = NOTIFICATION_ONE,
    /// Delivered by `notify_waiters` to every waiter it drains.
    All = NOTIFICATION_ALL,
}
/// Wrapper around the guarded list that `notify_waiters` drains.
///
/// If it is dropped before being fully drained (`is_empty == false`), its
/// `Drop` impl re-acquires `notify.waiters` and pops every remaining waiter,
/// marking each as `Notification::All`. That way the list is empty before the
/// guard node it is anchored at goes away.
struct NotifyWaitersList<'a> {
    // Waiters moved out of `Notify::waiters`, anchored at a pinned guard node.
    list: GuardedWaitList,
    // Set once `pop_back_locked` has returned `None`.
    is_empty: bool,
    // The `Notify` whose `waiters` lock protects `list`.
    notify: &'a Notify,
}
impl<'a> NotifyWaitersList<'a> {
    /// Turns `unguarded_list` into a guarded list anchored at `guard`.
    ///
    /// The `'a` lifetime ties the returned wrapper to the pinned guard node and
    /// to `notify`, so neither can go away while the wrapper is alive.
    fn new(
        unguarded_list: WaitList,
        guard: Pin<&'a Waiter>,
        notify: &'a Notify,
    ) -> NotifyWaitersList<'a> {
        let guard_node: &'a Waiter = guard.get_ref();
        NotifyWaitersList {
            list: unguarded_list.into_guarded(NonNull::from(guard_node)),
            is_empty: false,
            notify,
        }
    }

    /// Pops the waiter at the back of the guarded list and records when the
    /// list has run dry.
    ///
    /// `_waiters` is never read. Presumably it exists so that callers must
    /// hand in the locked `Notify::waiters` (the only caller passes its guard).
    fn pop_back_locked(&mut self, _waiters: &mut WaitList) -> Option<NonNull<Waiter>> {
        let popped = self.list.pop_back();
        self.is_empty = self.is_empty || popped.is_none();
        popped
    }
}
impl Drop for NotifyWaitersList<'_> {
    /// Unlinks every waiter still in the guarded list and marks it as notified
    /// with `Notification::All`, so no entry outlives the guard node.
    fn drop(&mut self) {
        if self.is_empty {
            // `notify_waiters` already drained the list completely.
            return;
        }

        // The guarded list is protected by `Notify::waiters`; keep the lock
        // held (named binding, not `_`) until the end of this function.
        let _lock_guard = self.notify.waiters.lock();
        while let Some(node) = self.list.pop_back() {
            // SAFETY: the lock is held, and a `Notified` unlinks its waiter under
            // this same lock before it is dropped, so the node is still alive.
            // Only shared references to waiters are ever created.
            let waiter = unsafe { node.as_ref() };
            waiter.notification.store_release(Notification::All);
        }
    }
}
/// Future returned by `Notify::notified`; completes once a notification has
/// been received.
#[derive(Debug)]
pub struct Notified<'a> {
    // The `Notify` this future waits on.
    notify: &'a Notify,
    // Progress of this future (see `State`).
    state: State,
    // `notify_waiters` call count captured at creation. If the current count
    // differs, `notify_waiters` ran in the meantime and the future completes.
    notify_waiters_calls: usize,
    // Intrusive list node; pushed onto `notify.waiters` when entering `Waiting`.
    waiter: Waiter,
}

// NOTE(review): `Waiter` contains an `UnsafeCell` and raw list pointers, so
// these impls are asserted manually. They appear to rely on `waiter.waker` only
// being touched under `Notify::waiters` or after the waiter was unlinked and
// notified. Confirm this invariant whenever access patterns change.
unsafe impl<'a> Send for Notified<'a> {}
unsafe impl<'a> Sync for Notified<'a> {}
/// Lifecycle of a `Notified` future.
#[derive(Debug)]
enum State {
    /// Not yet registered with the `Notify`.
    Init,
    /// `waiter` has been pushed onto `Notify::waiters`. It may since have been
    /// popped by a notifier.
    Waiting,
    /// A notification was consumed; further polls return `Ready`.
    Done,
}
// Layout of `Notify::state`:
// - the low `NOTIFY_WAITERS_SHIFT` bits hold the state (`EMPTY`, `WAITING` or
//   `NOTIFIED`);
// - the remaining high bits count calls to `notify_waiters`.
// The counter is allowed to wrap around because it is only ever compared for
// equality.
const NOTIFY_WAITERS_SHIFT: usize = 2;
const STATE_MASK: usize = (1 << NOTIFY_WAITERS_SHIFT) - 1;
const NOTIFY_WAITERS_CALLS_MASK: usize = !STATE_MASK;

/// No waiter is registered and no notification is stored.
const EMPTY: usize = 0;
/// At least one waiter is registered in `Notify::waiters`.
const WAITING: usize = 1;
/// A notification is stored and will be consumed by the next waiter.
const NOTIFIED: usize = 2;

/// Returns `data` with its state bits replaced by `state`, keeping the
/// `notify_waiters` call counter intact.
fn set_state(data: usize, state: usize) -> usize {
    (data & NOTIFY_WAITERS_CALLS_MASK) | (state & STATE_MASK)
}

/// Extracts the state bits (`EMPTY`, `WAITING` or `NOTIFIED`) from `data`.
fn get_state(data: usize) -> usize {
    data & STATE_MASK
}

/// Extracts the `notify_waiters` call counter from `data`.
fn get_num_notify_waiters_calls(data: usize) -> usize {
    (data & NOTIFY_WAITERS_CALLS_MASK) >> NOTIFY_WAITERS_SHIFT
}

/// Returns `data` with the `notify_waiters` call counter incremented by one.
///
/// The addition wraps on overflow. This matches
/// `atomic_inc_num_notify_waiters_calls`, whose atomic `fetch_add` always wraps.
/// A checked `+` would panic in debug builds once the counter reaches the top
/// of the word (after ~2^30 calls on 32-bit targets). The state bits are never
/// touched, because the increment only affects bits at or above
/// `NOTIFY_WAITERS_SHIFT`.
fn inc_num_notify_waiters_calls(data: usize) -> usize {
    data.wrapping_add(1 << NOTIFY_WAITERS_SHIFT)
}

/// Atomically increments the `notify_waiters` call counter stored in `data`.
fn atomic_inc_num_notify_waiters_calls(data: &AtomicUsize) {
    data.fetch_add(1 << NOTIFY_WAITERS_SHIFT, SeqCst);
}
impl Notify {
    /// Creates a `Notify` with no waiters, no stored notification and a
    /// `notify_waiters` call counter of zero.
    pub fn new() -> Notify {
        Notify {
            state: AtomicUsize::new(0),
            waiters: Mutex::new(LinkedList::new()),
        }
    }

    /// `const` counterpart of [`Notify::new`].
    ///
    /// Not available when compiled with `loom` under test.
    #[cfg(not(all(loom, test)))]
    pub const fn const_new() -> Notify {
        Notify {
            state: AtomicUsize::new(0),
            waiters: Mutex::const_new(LinkedList::new()),
        }
    }

    /// Returns a future that completes once this `Notify` is notified.
    ///
    /// The current `notify_waiters` call count is captured here. Any later
    /// `notify_waiters` call therefore completes the returned future, even if
    /// the future had not been polled yet when that call happened.
    pub fn notified(&self) -> Notified<'_> {
        let state = self.state.load(SeqCst);
        Notified {
            notify: self,
            state: State::Init,
            notify_waiters_calls: get_num_notify_waiters_calls(state),
            waiter: Waiter::new(),
        }
    }

    /// Wakes the longest-waiting task. If no task is waiting, the notification
    /// is stored and consumed by the next waiter instead.
    #[cfg_attr(docsrs, doc(alias = "notify"))]
    pub fn notify_one(&self) {
        let mut curr = self.state.load(SeqCst);
        // Lock-free fast path: with no registered waiter, only record
        // NOTIFIED. A NOTIFIED -> NOTIFIED exchange is still performed.
        while let EMPTY | NOTIFIED = get_state(curr) {
            let new = set_state(curr, NOTIFIED);
            let res = self.state.compare_exchange(curr, new, SeqCst, SeqCst);
            match res {
                Ok(_) => return,
                Err(actual) => {
                    curr = actual;
                }
            }
        }
        // State is WAITING: the lock is required to pop a waiter.
        let mut waiters = self.waiters.lock();
        // Reload under the lock; leaving WAITING requires this lock.
        curr = self.state.load(SeqCst);
        if let Some(waker) = notify_locked(&mut waiters, &self.state, curr) {
            // Release the lock before running arbitrary waker code.
            drop(waiters);
            waker.wake();
        }
    }

    /// Wakes every task currently waiting on this `Notify`.
    ///
    /// No notification is stored when nobody is waiting; only the
    /// `notify_waiters` call counter is bumped. That makes every `Notified`
    /// created before this call complete when it is next polled.
    pub fn notify_waiters(&self) {
        let mut waiters = self.waiters.lock();
        // Load with the lock held; the state can only leave WAITING under it.
        let curr = self.state.load(SeqCst);
        if matches!(get_state(curr), EMPTY | NOTIFIED) {
            // Nobody is registered: just bump the call counter.
            atomic_inc_num_notify_waiters_calls(&self.state);
            return;
        }
        // Bump the counter and transition WAITING -> EMPTY in a single store.
        let new_state = set_state(inc_num_notify_waiters_calls(curr), EMPTY);
        self.state.store(new_state, SeqCst);
        // The guarded list links to this node by pointer, so it must stay
        // pinned and alive until `list` is dropped.
        let guard = Waiter::new();
        pin!(guard);
        // Move every current waiter into a guarded list, leaving
        // `self.waiters` empty. Dropping `list` (including during a panic
        // unwind) marks any waiter still in it as notified.
        let mut list = NotifyWaitersList::new(std::mem::take(&mut *waiters), guard.as_ref(), self);
        let mut wakers = WakeList::new();
        'outer: loop {
            // Collect wakers in batches, as long as `wakers` has room.
            while wakers.can_push() {
                match list.pop_back_locked(&mut waiters) {
                    Some(waiter) => {
                        // SAFETY: the lock is held, so the node cannot be
                        // dropped concurrently (`Notified::drop` unlinks under
                        // this lock); only shared references are created.
                        let waiter = unsafe { waiter.as_ref() };
                        // SAFETY: the lock is held, so the waker may be accessed.
                        if let Some(waker) =
                            unsafe { waiter.waker.with_mut(|waker| (*waker).take()) }
                        {
                            wakers.push(waker);
                        }
                        // The waiter is unlinked; publish the notification.
                        waiter.notification.store_release(Notification::All);
                    }
                    None => {
                        break 'outer;
                    }
                }
            }
            // Batch full: wake outside the lock, then re-acquire it and continue.
            // If a waker panics, `list`'s destructor still releases the rest.
            drop(waiters);
            wakers.wake_all();
            waiters = self.waiters.lock();
        }
        // Release the lock before waking the final batch.
        drop(waiters);
        wakers.wake_all();
    }
}
impl Default for Notify {
fn default() -> Notify {
Notify::new()
}
}
// Declares `Notify` unwind-safe explicitly; auto-derivation does not apply
// because of its `Mutex` and raw-pointer wait list.
// NOTE(review): the justification is not visible here — presumably every state
// transition leaves the list and state word consistent even if a waker panics
// (see the `NotifyWaitersList` destructor). Confirm before relying on it.
impl UnwindSafe for Notify {}
impl RefUnwindSafe for Notify {}
/// Delivers a single notification; the caller must hold the `Notify::waiters`
/// lock that guards `waiters`.
///
/// `curr` is the state word as read by the caller under that lock.
/// - With no registered waiter (`EMPTY`/`NOTIFIED`), the notification is stored
///   in `state` as `NOTIFIED` and `None` is returned.
/// - With a waiter (`WAITING`), the waiter at the back of `waiters` (the oldest
///   one, since new waiters are pushed at the front) is popped and marked
///   `Notification::One`. Its waker, if it had one, is returned so the caller
///   can wake it after releasing the lock.
///
/// # Panics
///
/// Panics if the state is `WAITING` but `waiters` is empty, if an unexpected
/// state is observed after a failed exchange, or on an unknown state value.
fn notify_locked(waiters: &mut WaitList, state: &AtomicUsize, curr: usize) -> Option<Waker> {
    match get_state(curr) {
        EMPTY | NOTIFIED => {
            let res = state.compare_exchange(curr, set_state(curr, NOTIFIED), SeqCst, SeqCst);
            match res {
                Ok(_) => None,
                Err(actual) => {
                    // WAITING is only entered with the lock held, so a racing
                    // change is expected to be a lock-free EMPTY/NOTIFIED flip
                    // (asserted). Overwrite it with NOTIFIED.
                    let actual_state = get_state(actual);
                    assert!(actual_state == EMPTY || actual_state == NOTIFIED);
                    state.store(set_state(actual, NOTIFIED), SeqCst);
                    None
                }
            }
        }
        WAITING => {
            // The state cannot leave WAITING concurrently: that requires the
            // lock, which the caller holds.
            let waiter = waiters.pop_back().unwrap();
            // SAFETY: only shared references to waiters are created, and the
            // node stays alive while the lock is held.
            let waiter = unsafe { waiter.as_ref() };
            // SAFETY: the lock is held, so the waker may be accessed.
            let waker = unsafe { waiter.waker.with_mut(|waker| (*waker).take()) };
            // Publish only after the waker was taken; the waiter is unlinked.
            waiter.notification.store_release(Notification::One);
            if waiters.is_empty() {
                // Last waiter removed: WAITING -> EMPTY. A plain store suffices
                // because leaving WAITING requires the lock.
                state.store(set_state(curr, EMPTY), SeqCst);
            }
            waker
        }
        _ => unreachable!(),
    }
}
impl Notified<'_> {
    /// Polls the future once without a waker.
    ///
    /// Returns `true` if the future has completed. Otherwise the waiter is (or
    /// remains) registered in the `Notify`'s wait list and `false` is returned.
    pub fn enable(self: Pin<&mut Self>) -> bool {
        self.poll_notified(None).is_ready()
    }

    /// Splits the pinned future into its fields.
    ///
    /// Only `waiter` needs to stay pinned, and it is handed out as a shared
    /// reference. The remaining fields are `Unpin`, which the `is_unpin` calls
    /// check at compile time.
    fn project(self: Pin<&mut Self>) -> (&Notify, &mut State, &usize, &Waiter) {
        unsafe {
            // Compile-time assertions that these field types are `Unpin`.
            is_unpin::<&Notify>();
            is_unpin::<State>();
            is_unpin::<usize>();
            // SAFETY: `waiter` is never moved out; it is only exposed as `&Waiter`.
            let me = self.get_unchecked_mut();
            (
                me.notify,
                &mut me.state,
                &me.notify_waiters_calls,
                &me.waiter,
            )
        }
    }

    /// State machine shared by `Future::poll` and `enable`.
    ///
    /// `waker` is `None` when called from `enable`. In that case the waiter is
    /// registered (or kept registered) without installing or updating a waker.
    fn poll_notified(self: Pin<&mut Self>, waker: Option<&Waker>) -> Poll<()> {
        let (notify, state, notify_waiters_calls, waiter) = self.project();
        'outer_loop: loop {
            match *state {
                State::Init => {
                    // Optimistically consume a stored notification without the lock.
                    let curr = notify.state.load(SeqCst);
                    let res = notify.state.compare_exchange(
                        set_state(curr, NOTIFIED),
                        set_state(curr, EMPTY),
                        SeqCst,
                        SeqCst,
                    );
                    if res.is_ok() {
                        *state = State::Done;
                        continue 'outer_loop;
                    }
                    // Clone the waker before locking: cloning can run arbitrary code.
                    let waker = waker.cloned();
                    let mut waiters = notify.waiters.lock();
                    // Reload the state with the lock held.
                    let mut curr = notify.state.load(SeqCst);
                    // `notify_waiters` ran since this future was created: done.
                    if get_num_notify_waiters_calls(curr) != *notify_waiters_calls {
                        *state = State::Done;
                        continue 'outer_loop;
                    }
                    // Move to WAITING, or consume a notification that raced in.
                    // Lock-free EMPTY <-> NOTIFIED flips (the `notify_one` fast
                    // path, other futures' optimistic exchange) can still happen
                    // concurrently, hence the retry loop.
                    loop {
                        match get_state(curr) {
                            EMPTY => {
                                let res = notify.state.compare_exchange(
                                    set_state(curr, EMPTY),
                                    set_state(curr, WAITING),
                                    SeqCst,
                                    SeqCst,
                                );
                                if let Err(actual) = res {
                                    assert_eq!(get_state(actual), NOTIFIED);
                                    curr = actual;
                                } else {
                                    break;
                                }
                            }
                            WAITING => break,
                            NOTIFIED => {
                                // Try to consume the stored notification.
                                let res = notify.state.compare_exchange(
                                    set_state(curr, NOTIFIED),
                                    set_state(curr, EMPTY),
                                    SeqCst,
                                    SeqCst,
                                );
                                match res {
                                    Ok(_) => {
                                        *state = State::Done;
                                        continue 'outer_loop;
                                    }
                                    Err(actual) => {
                                        assert_eq!(get_state(actual), EMPTY);
                                        curr = actual;
                                    }
                                }
                            }
                            _ => unreachable!(),
                        }
                    }
                    let mut old_waker = None;
                    if waker.is_some() {
                        // SAFETY: the lock is held and this waiter is not linked
                        // yet, so nothing else accesses its waker.
                        unsafe {
                            old_waker =
                                waiter.waker.with_mut(|v| std::mem::replace(&mut *v, waker));
                        }
                    }
                    // New waiters go to the front; `notify_locked` pops from the
                    // back, so waiters are served oldest-first.
                    waiters.push_front(NonNull::from(waiter));
                    *state = State::Waiting;
                    // Drop any previous waker only after releasing the lock.
                    drop(waiters);
                    drop(old_waker);
                    return Poll::Pending;
                }
                State::Waiting => {
                    #[cfg(tokio_taskdump)]
                    if let Some(waker) = waker {
                        let mut ctx = Context::from_waker(waker);
                        ready!(crate::trace::trace_leaf(&mut ctx));
                    }
                    // Lock-free check. This `Acquire` pairs with the notifier's
                    // `store_release`, made after it unlinked the waiter and took
                    // its waker.
                    if waiter.notification.load(Acquire).is_some() {
                        // SAFETY: the waiter is unlinked and no notifier touches
                        // it again, so this future has exclusive access to `waker`.
                        drop(unsafe { waiter.waker.with_mut(|waker| (*waker).take()) });
                        waiter.notification.clear();
                        *state = State::Done;
                        return Poll::Ready(());
                    }
                    // Not notified yet, so the waiter may still be linked; the
                    // lock is required to touch its waker.
                    let mut old_waker = None;
                    let mut waiters = notify.waiters.lock();
                    // Notifiers set `notification` with this lock held, so the
                    // mutex already provides the ordering; `Relaxed` suffices.
                    if waiter.notification.load(Relaxed).is_some() {
                        // SAFETY: the lock is held and the waiter is unlinked.
                        old_waker = unsafe { waiter.waker.with_mut(|waker| (*waker).take()) };
                        waiter.notification.clear();
                        drop(waiters);
                        drop(old_waker);
                        *state = State::Done;
                        return Poll::Ready(());
                    }
                    let curr = notify.state.load(SeqCst);
                    if get_num_notify_waiters_calls(curr) != *notify_waiters_calls {
                        // A `notify_waiters` call is in progress. This waiter
                        // was linked before the counter changed and has not been
                        // popped yet, so it sits in that call's guarded list.
                        // Treat it as notified and unlink it here.
                        // SAFETY: the lock is held, so the waker may be modified.
                        old_waker = unsafe { waiter.waker.with_mut(|waker| (*waker).take()) };
                        // SAFETY: the lock is held, giving exclusive access to
                        // whichever list currently contains the node.
                        unsafe { waiters.remove(NonNull::from(waiter)) };
                        *state = State::Done;
                    } else {
                        // Still waiting. Replace the stored waker only if the new
                        // one would wake a different task.
                        // SAFETY: the lock is held, so the waker may be modified.
                        unsafe {
                            waiter.waker.with_mut(|v| {
                                if let Some(waker) = waker {
                                    let should_update = match &*v {
                                        Some(current_waker) => !current_waker.will_wake(waker),
                                        None => true,
                                    };
                                    if should_update {
                                        old_waker = std::mem::replace(&mut *v, Some(waker.clone()));
                                    }
                                }
                            });
                        }
                        drop(waiters);
                        drop(old_waker);
                        return Poll::Pending;
                    }
                    // Release the lock, then drop the taken waker outside of it;
                    // the next `'outer_loop` iteration handles `Done`.
                    drop(waiters);
                    drop(old_waker);
                }
                State::Done => {
                    #[cfg(tokio_taskdump)]
                    if let Some(waker) = waker {
                        let mut ctx = Context::from_waker(waker);
                        ready!(crate::trace::trace_leaf(&mut ctx));
                    }
                    return Poll::Ready(());
                }
            }
        }
    }
}
impl Future for Notified<'_> {
    type Output = ();

    /// Registers the current task's waker and resolves once a notification
    /// has been received.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        let task_waker = cx.waker();
        self.poll_notified(Some(task_waker))
    }
}
impl Drop for Notified<'_> {
    /// Unlinks this future's waiter if it is still registered.
    ///
    /// If the waiter had received a `notify_one` notification
    /// (`Notification::One`) that was never consumed, the notification is
    /// handed on via `notify_locked`: to another waiter, or stored as
    /// `NOTIFIED`. This keeps it from being lost.
    fn drop(&mut self) {
        // SAFETY: the future only enters `Waiting` inside `poll_notified`,
        // which takes `Pin<&mut Self>`, so a waiting future is already pinned
        // and this is its final use.
        let (notify, state, _, waiter) = unsafe { Pin::new_unchecked(self).project() };
        if matches!(*state, State::Waiting) {
            let mut waiters = notify.waiters.lock();
            let mut notify_state = notify.state.load(SeqCst);
            // Notifiers write `notification` only while holding this lock, so
            // `Relaxed` is sufficient here.
            let notification = waiter.notification.load(Relaxed);
            // SAFETY: the lock is held, giving exclusive access to every list the
            // node may be in: `waiters`, or the guarded list of an in-progress
            // `notify_waiters`.
            // NOTE(review): presumed to be a no-op when the node was already
            // popped by a notifier — confirm against `LinkedList::remove`.
            unsafe { waiters.remove(NonNull::from(waiter)) };
            if waiters.is_empty() && get_state(notify_state) == WAITING {
                // Last registered waiter is gone: WAITING -> EMPTY.
                notify_state = set_state(notify_state, EMPTY);
                notify.state.store(notify_state, SeqCst);
            }
            // An unconsumed `notify_one` notification must not be lost.
            if notification == Some(Notification::One) {
                // Fixed: the source had the mis-encoded token `¬ify.state`
                // (an HTML `&not` entity), which is not valid Rust.
                if let Some(waker) = notify_locked(&mut waiters, &notify.state, notify_state) {
                    // Release the lock before running arbitrary waker code.
                    drop(waiters);
                    waker.wake();
                }
            }
        }
    }
}
// NOTE(review): implementing `Link` is `unsafe`. `Waiter` is `!Unpin`
// (`PhantomPinned`) and the lists only hold raw `NonNull` pointers to it. In
// this file, list membership only changes while `Notify::waiters` is locked.
// Confirm these invariants against the `linked_list::Link` safety contract.
unsafe impl linked_list::Link for Waiter {
    // The list hands out raw pointers; it never owns the waiters.
    type Handle = NonNull<Waiter>;
    type Target = Waiter;

    fn as_raw(handle: &NonNull<Waiter>) -> NonNull<Waiter> {
        *handle
    }

    unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> {
        ptr
    }

    // Projects to the node's `pointers` field via the macro-generated helper.
    unsafe fn pointers(target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> {
        Waiter::addr_of_pointers(target)
    }
}
/// Compile-time assertion helper: a call `is_unpin::<T>()` only type-checks
/// when `T: Unpin`. It does nothing at runtime.
fn is_unpin<T>()
where
    T: Unpin,
{
}