#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))]
use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::Mutex;
use crate::util::linked_list::{self, LinkedList};
use crate::util::WakeList;
use std::cell::UnsafeCell;
use std::future::Future;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::ptr::NonNull;
use std::sync::atomic::Ordering::SeqCst;
use std::task::{Context, Poll, Waker};
/// Intrusive list of `Waiter` nodes registered by pending `Notified` futures.
type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>;
/// Notifies a single waiting task or all waiting tasks.
///
/// `state` packs two things: the low bits hold `EMPTY`, `WAITING` or
/// `NOTIFIED`, and the high bits count `notify_waiters` calls.
/// `waiters` holds the registered waiters. Every transition of `state` out of
/// `WAITING` happens while this lock is held.
#[derive(Debug)]
pub struct Notify {
    // Packed state bits plus `notify_waiters` call counter.
    state: AtomicUsize,
    // Registered waiters; newest at the front, notified from the back.
    waiters: Mutex<WaitList>,
}
/// Which notifier woke a waiter.
#[derive(Debug, Clone, Copy)]
enum NotificationType {
    // Set by `notify_waiters`.
    AllWaiters,
    // Set by `notify_one`. A dropped future that received this forwards it on.
    OneWaiter,
}
/// Intrusive wait-list node owned by a `Notified` future.
#[derive(Debug)]
struct Waiter {
    // Links to neighbouring nodes in `Notify::waiters`.
    pointers: linked_list::Pointers<Waiter>,
    // Waker stored on registration; taken by the notifier that pops this node.
    waker: Option<Waker>,
    // Set by a notifier (under the lock) when this waiter has been notified.
    notified: Option<NotificationType>,
    // Makes `Waiter` `!Unpin`: the list stores raw pointers to it, so it must not move.
    _p: PhantomPinned,
}
/// Future returned by `Notify::notified`.
///
/// The future owns its `Waiter` node. While it is in the `Waiting` state, that
/// node is linked into the notifier's wait list.
#[derive(Debug)]
pub struct Notified<'a> {
    // The `Notify` this future waits on.
    notify: &'a Notify,
    // Progress of this future.
    state: State,
    // Wait-list node; mutated through raw pointers by notifiers, hence `UnsafeCell`.
    waiter: UnsafeCell<Waiter>,
}
// SAFETY: `UnsafeCell<Waiter>` suppresses the auto traits. In this file the
// waiter is read or written only while the `notify.waiters` lock is held: by
// `poll`, `drop` and the notifiers. The remaining fields are `&Notify` and a
// plain enum.
// NOTE(review): this relies on that discipline being kept by every accessor.
unsafe impl<'a> Send for Notified<'a> {}
unsafe impl<'a> Sync for Notified<'a> {}
/// Progress of a `Notified` future.
#[derive(Debug)]
enum State {
    // Not yet registered. Holds the `notify_waiters` call count captured at creation.
    Init(usize),
    // Registered in the wait list.
    Waiting,
    // Completed.
    Done,
}
/// Number of low bits of `Notify::state` holding `EMPTY`/`WAITING`/`NOTIFIED`.
/// The remaining high bits count `notify_waiters` calls.
const NOTIFY_WAITERS_SHIFT: usize = 2;
/// Mask selecting the state bits.
const STATE_MASK: usize = (1 << NOTIFY_WAITERS_SHIFT) - 1;
/// Mask selecting the `notify_waiters` call-counter bits.
const NOTIFY_WAITERS_CALLS_MASK: usize = !STATE_MASK;

/// No permit stored and no waiter registered.
const EMPTY: usize = 0;
/// One or more waiters are registered in the wait list.
const WAITING: usize = 1;
/// A permit is stored for the next `Notified` to consume.
const NOTIFIED: usize = 2;

/// Replaces the state bits of `data` with `state`, keeping the call counter.
fn set_state(data: usize, state: usize) -> usize {
    (data & NOTIFY_WAITERS_CALLS_MASK) | (state & STATE_MASK)
}

/// Extracts the state bits (`EMPTY`, `WAITING` or `NOTIFIED`) from `data`.
fn get_state(data: usize) -> usize {
    data & STATE_MASK
}

/// Extracts the `notify_waiters` call counter from `data`.
fn get_num_notify_waiters_calls(data: usize) -> usize {
    (data & NOTIFY_WAITERS_CALLS_MASK) >> NOTIFY_WAITERS_SHIFT
}

/// Returns `data` with the call counter incremented by one.
///
/// The counter is only ever compared for equality, so it wraps on overflow.
/// This matches `atomic_inc_num_notify_waiters_calls`, whose `fetch_add` wraps.
/// A plain `+` would panic in debug builds once the counter saturates, which
/// takes about 2^30 calls on 32-bit targets.
fn inc_num_notify_waiters_calls(data: usize) -> usize {
    data.wrapping_add(1 << NOTIFY_WAITERS_SHIFT)
}

/// Atomically increments the call counter stored in `data`, wrapping on overflow.
fn atomic_inc_num_notify_waiters_calls(data: &AtomicUsize) {
    data.fetch_add(1 << NOTIFY_WAITERS_SHIFT, SeqCst);
}
impl Notify {
pub fn new() -> Notify {
Notify {
state: AtomicUsize::new(0),
waiters: Mutex::new(LinkedList::new()),
}
}
#[cfg(all(feature = "parking_lot", not(all(loom, test))))]
#[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))]
pub const fn const_new() -> Notify {
Notify {
state: AtomicUsize::new(0),
waiters: Mutex::const_new(LinkedList::new()),
}
}
pub fn notified(&self) -> Notified<'_> {
let state = self.state.load(SeqCst);
Notified {
notify: self,
state: State::Init(state >> NOTIFY_WAITERS_SHIFT),
waiter: UnsafeCell::new(Waiter {
pointers: linked_list::Pointers::new(),
waker: None,
notified: None,
_p: PhantomPinned,
}),
}
}
#[cfg_attr(docsrs, doc(alias = "notify"))]
pub fn notify_one(&self) {
let mut curr = self.state.load(SeqCst);
while let EMPTY | NOTIFIED = get_state(curr) {
let new = set_state(curr, NOTIFIED);
let res = self.state.compare_exchange(curr, new, SeqCst, SeqCst);
match res {
Ok(_) => return,
Err(actual) => {
curr = actual;
}
}
}
let mut waiters = self.waiters.lock();
curr = self.state.load(SeqCst);
if let Some(waker) = notify_locked(&mut waiters, &self.state, curr) {
drop(waiters);
waker.wake();
}
}
pub fn notify_waiters(&self) {
let mut wakers = WakeList::new();
let mut waiters = self.waiters.lock();
let curr = self.state.load(SeqCst);
if let EMPTY | NOTIFIED = get_state(curr) {
atomic_inc_num_notify_waiters_calls(&self.state);
return;
}
'outer: loop {
while wakers.can_push() {
match waiters.pop_back() {
Some(mut waiter) => {
let waiter = unsafe { waiter.as_mut() };
assert!(waiter.notified.is_none());
waiter.notified = Some(NotificationType::AllWaiters);
if let Some(waker) = waiter.waker.take() {
wakers.push(waker);
}
}
None => {
break 'outer;
}
}
}
drop(waiters);
wakers.wake_all();
waiters = self.waiters.lock();
}
let new = set_state(inc_num_notify_waiters_calls(curr), EMPTY);
self.state.store(new, SeqCst);
drop(waiters);
wakers.wake_all();
}
}
impl Default for Notify {
fn default() -> Notify {
Notify::new()
}
}
/// Delivers one notification while the caller holds the `waiters` lock.
///
/// `curr` is the caller's latest load of `state`.
/// - With no registered waiter (`EMPTY`/`NOTIFIED`), a permit is stored and
///   `None` is returned.
/// - Otherwise the oldest waiter (popped from the back; registration pushes at
///   the front) is marked `OneWaiter` and its waker is returned for the caller
///   to wake after releasing the lock.
fn notify_locked(waiters: &mut WaitList, state: &AtomicUsize, curr: usize) -> Option<Waker> {
    match get_state(curr) {
        EMPTY | NOTIFIED => {
            // The state bits can still change lock-free (a `notify_one` fast
            // path setting `NOTIFIED`, or a poll consuming it). They cannot
            // become `WAITING` while we hold the lock.
            match state.compare_exchange(curr, set_state(curr, NOTIFIED), SeqCst, SeqCst) {
                Ok(_) => None,
                Err(actual) => {
                    let actual_state = get_state(actual);
                    assert!(actual_state == EMPTY || actual_state == NOTIFIED);
                    state.store(set_state(actual, NOTIFIED), SeqCst);
                    None
                }
            }
        }
        WAITING => {
            // Leaving `WAITING` requires the lock we hold, so the list is
            // expected to be non-empty here.
            let mut waiter = waiters.pop_back().unwrap();
            // SAFETY: the lock is held and the node was linked, hence still alive.
            let waiter = unsafe { waiter.as_mut() };

            assert!(waiter.notified.is_none());

            waiter.notified = Some(NotificationType::OneWaiter);
            let waker = waiter.waker.take();

            if waiters.is_empty() {
                // The last waiter was removed, so leave `WAITING`.
                state.store(set_state(curr, EMPTY), SeqCst);
            }

            waker
        }
        _ => unreachable!(),
    }
}
impl Notified<'_> {
    /// Pin projection: splits a pinned `Notified` into its fields.
    ///
    /// `notify` and `state` are returned unpinned. `waiter` is only exposed as
    /// a shared `&UnsafeCell`, so the node, whose address may be linked into
    /// the wait list, is never moved.
    fn project(self: Pin<&mut Self>) -> (&Notify, &mut State, &UnsafeCell<Waiter>) {
        unsafe {
            // Compile-time assertions that these types are `Unpin`.
            // NOTE(review): `AtomicUsize` is not a field of `Notified`. This check
            // looks vestigial (perhaps `State` was intended). Confirm.
            is_unpin::<&Notify>();
            is_unpin::<AtomicUsize>();

            // SAFETY: no field is moved out; `waiter` is only handed out by
            // shared reference.
            let me = self.get_unchecked_mut();
            (me.notify, &mut me.state, &me.waiter)
        }
    }
}
impl Future for Notified<'_> {
    type Output = ();

    /// Drives the `Init` -> `Waiting` -> `Done` state machine.
    ///
    /// - `Init`: first tries to consume a stored permit without locking. Then,
    ///   under the lock, it completes if `notify_waiters` ran since creation.
    ///   Otherwise it registers the waiter and returns `Pending`.
    /// - `Waiting`: completes once a notifier has set `notified`. Otherwise it
    ///   refreshes the stored waker.
    /// - `Done`: ready.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        use State::*;

        let (notify, state, waiter) = self.project();

        loop {
            match *state {
                Init(initial_notify_waiters_calls) => {
                    // Fast path: consume a stored permit (NOTIFIED -> EMPTY) lock-free.
                    let curr = notify.state.load(SeqCst);

                    let res = notify.state.compare_exchange(
                        set_state(curr, NOTIFIED),
                        set_state(curr, EMPTY),
                        SeqCst,
                        SeqCst,
                    );

                    if res.is_ok() {
                        *state = Done;
                        return Poll::Ready(());
                    }

                    // Cloned before taking the lock (presumably to keep the
                    // critical section short).
                    let waker = cx.waker().clone();

                    let mut waiters = notify.waiters.lock();

                    // Re-read under the lock: the counter bits only change with it held.
                    let mut curr = notify.state.load(SeqCst);

                    // A `notify_waiters` call since this future was created completes it.
                    if get_num_notify_waiters_calls(curr) != initial_notify_waiters_calls {
                        *state = Done;
                        return Poll::Ready(());
                    }

                    // Either move the state to WAITING or grab a permit that raced in.
                    loop {
                        match get_state(curr) {
                            EMPTY => {
                                let res = notify.state.compare_exchange(
                                    set_state(curr, EMPTY),
                                    set_state(curr, WAITING),
                                    SeqCst,
                                    SeqCst,
                                );

                                if let Err(actual) = res {
                                    // With the lock held, only a lock-free
                                    // `notify_one` can change EMPTY (to NOTIFIED).
                                    assert_eq!(get_state(actual), NOTIFIED);
                                    curr = actual;
                                } else {
                                    break;
                                }
                            }
                            WAITING => break,
                            NOTIFIED => {
                                // A permit is stored; try to consume it.
                                let res = notify.state.compare_exchange(
                                    set_state(curr, NOTIFIED),
                                    set_state(curr, EMPTY),
                                    SeqCst,
                                    SeqCst,
                                );
                                match res {
                                    Ok(_) => {
                                        *state = Done;
                                        return Poll::Ready(());
                                    }
                                    Err(actual) => {
                                        // Another poller's fast path consumed it first.
                                        assert_eq!(get_state(actual), EMPTY);
                                        curr = actual;
                                    }
                                }
                            }
                            _ => unreachable!(),
                        }
                    }

                    // Register: store the waker, then link the node (lock still held).
                    unsafe {
                        (*waiter.get()).waker = Some(waker);
                    }

                    waiters.push_front(unsafe { NonNull::new_unchecked(waiter.get()) });

                    *state = Waiting;

                    return Poll::Pending;
                }
                Waiting => {
                    // Notifiers write `notified`/`waker` under this lock.
                    let waiters = notify.waiters.lock();

                    let w = unsafe { &mut *waiter.get() };

                    if w.notified.is_some() {
                        // Already unlinked by the notifier; reset the node and finish.
                        w.waker = None;
                        w.notified = None;

                        *state = Done;
                    } else {
                        // Not notified yet: keep the most recent waker.
                        if !w.waker.as_ref().unwrap().will_wake(cx.waker()) {
                            w.waker = Some(cx.waker().clone());
                        }

                        return Poll::Pending;
                    }

                    // Release the lock, then loop into `Done`.
                    drop(waiters);
                }
                Done => {
                    return Poll::Ready(());
                }
            }
        }
    }
}
impl Drop for Notified<'_> {
fn drop(&mut self) {
use State::*;
let (notify, state, waiter) = unsafe { Pin::new_unchecked(self).project() };
if let Waiting = *state {
let mut waiters = notify.waiters.lock();
let mut notify_state = notify.state.load(SeqCst);
unsafe { waiters.remove(NonNull::new_unchecked(waiter.get())) };
if waiters.is_empty() {
if let WAITING = get_state(notify_state) {
notify_state = set_state(notify_state, EMPTY);
notify.state.store(notify_state, SeqCst);
}
}
if let Some(NotificationType::OneWaiter) = unsafe { (*waiter.get()).notified } {
if let Some(waker) = notify_locked(&mut waiters, ¬ify.state, notify_state) {
drop(waiters);
waker.wake();
}
}
}
}
}
// SAFETY: `Waiter` is `!Unpin` (via `PhantomPinned`). A node is only linked
// while its `Notified` is `Waiting`, and `Notified::drop` unlinks it under the
// lock before the storage goes away. Raw handles therefore stay valid while
// linked.
unsafe impl linked_list::Link for Waiter {
    // Handles are raw, non-owning pointers; the owning `Notified` keeps the node alive.
    type Handle = NonNull<Waiter>;
    type Target = Waiter;

    fn as_raw(handle: &NonNull<Waiter>) -> NonNull<Waiter> {
        *handle
    }

    unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> {
        ptr
    }

    // Returns a pointer to the intrusive link fields embedded in the node.
    unsafe fn pointers(mut target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> {
        // NOTE(review): `as_mut()` materialises a `&mut Waiter` for the whole
        // node. `std::ptr::addr_of_mut!` would avoid that if the MSRV allows.
        NonNull::from(&mut target.as_mut().pointers)
    }
}
/// Compile-time assertion helper: only instantiable when `T: Unpin`.
fn is_unpin<T>()
where
    T: Unpin,
{
}