use crate::loom::sync::atomic::AtomicU8;
use crate::loom::sync::Mutex;
use crate::util::linked_list::{self, LinkedList};
use std::cell::UnsafeCell;
use std::future::Future;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::ptr::NonNull;
use std::sync::atomic::Ordering::SeqCst;
use std::task::{Context, Poll, Waker};
/// Notifies a single task to wake up.
///
/// The atomic `state` short-circuits the common cases without locking; the
/// mutex-guarded `waiters` list holds the `Notified` futures that are waiting.
#[derive(Debug)]
pub struct Notify {
    /// One of `EMPTY`, `WAITING` or `NOTIFIED`. Transitions *out of*
    /// `WAITING` are only performed while `waiters` is locked; the
    /// `EMPTY <-> NOTIFIED` transitions may happen without the lock.
    state: AtomicU8,
    /// Intrusive list of registered waiters; each entry points into a pinned
    /// `Notified` future.
    waiters: Mutex<LinkedList<Waiter>>,
}
/// An entry in `Notify::waiters`, embedded inside a `Notified` future.
#[derive(Debug)]
struct Waiter {
    /// Intrusive linked-list pointers.
    pointers: linked_list::Pointers<Waiter>,
    /// Waker registered by the most recent poll; taken by `notify_locked`
    /// when this waiter is selected.
    waker: Option<Waker>,
    /// Set by `notify_locked` when this waiter is popped and notified.
    notified: bool,
    /// The list stores raw pointers to this value, so it must not move once
    /// it has been inserted.
    _p: PhantomPinned,
}
/// Future returned by `Notify::notified`; resolves once a notification has
/// been received.
#[derive(Debug)]
struct Notified<'a> {
    /// The `Notify` being waited on.
    notify: &'a Notify,
    /// Progress of this future through `Init -> Waiting -> Done`.
    state: State,
    /// Intrusive list entry. It is reached through raw pointers from
    /// `Notify::waiters`, hence the `UnsafeCell`.
    waiter: UnsafeCell<Waiter>,
}

// NOTE(review): `UnsafeCell<Waiter>` (which holds raw list pointers) is not
// auto-Send/Sync. These impls appear to rely on the waiter's fields only being
// accessed while `notify.waiters` is locked (as `poll`, `drop` and
// `notify_locked` do). Re-check that invariant whenever access patterns change.
unsafe impl<'a> Send for Notified<'a> {}
unsafe impl<'a> Sync for Notified<'a> {}
/// Lifecycle of a `Notified` future.
#[derive(Debug)]
enum State {
    /// Not yet polled; the waiter is not in the list.
    Init,
    /// The waiter has been pushed onto `Notify::waiters`.
    Waiting,
    /// A notification was consumed; further polls return `Ready`.
    Done,
}
/// `Notify::state`: no stored notification and no waiters.
const EMPTY: u8 = 0;
/// `Notify::state`: the waiter list is non-empty.
const WAITING: u8 = 1;
/// `Notify::state`: a notification is stored; the next `Notified` consumes it.
const NOTIFIED: u8 = 2;
impl Notify {
pub fn new() -> Notify {
Notify {
state: AtomicU8::new(0),
waiters: Mutex::new(LinkedList::new()),
}
}
pub async fn notified(&self) {
Notified {
notify: self,
state: State::Init,
waiter: UnsafeCell::new(Waiter {
pointers: linked_list::Pointers::new(),
waker: None,
notified: false,
_p: PhantomPinned,
}),
}
.await
}
pub fn notify(&self) {
let mut curr = self.state.load(SeqCst);
while let EMPTY | NOTIFIED = curr {
let res = self.state.compare_exchange(curr, NOTIFIED, SeqCst, SeqCst);
match res {
Ok(_) => return,
Err(actual) => {
curr = actual;
}
}
}
let mut waiters = self.waiters.lock().unwrap();
curr = self.state.load(SeqCst);
if let Some(waker) = notify_locked(&mut waiters, &self.state, curr) {
drop(waiters);
waker.wake();
}
}
}
impl Default for Notify {
fn default() -> Notify {
Notify::new()
}
}
/// Delivers one notification while the `waiters` lock is held.
///
/// `curr` must be the value of `state` loaded while holding the lock.
///
/// * `EMPTY` / `NOTIFIED`: there are no waiters. The state is set to
///   `NOTIFIED` and `None` is returned.
/// * `WAITING`: the oldest waiter (list tail; waiters are pushed at the
///   front) is popped and flagged as notified. Its waker is returned so the
///   caller can wake it after releasing the lock. If the list is now empty,
///   the state becomes `EMPTY`.
///
/// # Panics
///
/// Panics if `curr` is not one of the three state constants, or if the state
/// is `WAITING` while the list is empty. Both indicate a broken invariant.
fn notify_locked(waiters: &mut LinkedList<Waiter>, state: &AtomicU8, curr: u8) -> Option<Waker> {
    // Every arm completes in a single pass, so no retry loop is needed.
    match curr {
        EMPTY | NOTIFIED => {
            // Even with the lock held, the lock-free paths can still flip the
            // state between `EMPTY` and `NOTIFIED`, but never into `WAITING`.
            // Either way the outcome must be `NOTIFIED`.
            if let Err(actual) = state.compare_exchange(curr, NOTIFIED, SeqCst, SeqCst) {
                assert!(actual == EMPTY || actual == NOTIFIED);
                state.store(NOTIFIED, SeqCst);
            }
            None
        }
        WAITING => {
            // Leaving `WAITING` requires the lock, which we hold, so the state
            // cannot change concurrently here.
            let mut waiter = waiters
                .pop_back()
                .expect("Notify state is WAITING but the waiter list is empty");

            // SAFETY: the `waiters` lock is held. Entries are unlinked (under
            // the lock) before their owning `Notified` is dropped, so the
            // pointer is live and not accessed elsewhere.
            let waiter = unsafe { waiter.as_mut() };

            assert!(!waiter.notified);
            waiter.notified = true;
            let waker = waiter.waker.take();

            if waiters.is_empty() {
                // Last waiter removed. A plain store suffices because
                // transitions out of `WAITING` require the lock.
                state.store(EMPTY, SeqCst);
            }

            waker
        }
        _ => unreachable!("invalid Notify state: {}", curr),
    }
}
impl Notified<'_> {
    /// Hand-written pin projection of a pinned `Notified` into its fields.
    ///
    /// `notify` and `state` are handed out unpinned. `waiter` is only exposed
    /// as `&UnsafeCell<Waiter>`, so the pinned entry can never be moved out
    /// through this projection.
    fn project(self: Pin<&mut Self>) -> (&Notify, &mut State, &UnsafeCell<Waiter>) {
        // Compile-time proof that the fields handed out unpinned are `Unpin`.
        // These must name the actual field types: `state` is a `State`.
        is_unpin::<&Notify>();
        is_unpin::<State>();

        // SAFETY: `notify` and `state` are `Unpin` (asserted above), and
        // `waiter` is only reachable by shared reference to its `UnsafeCell`,
        // so nothing returned here can move the pinned waiter.
        let me = unsafe { self.get_unchecked_mut() };
        (&me.notify, &mut me.state, &me.waiter)
    }
}
impl Future for Notified<'_> {
    type Output = ();

    /// Drives the `Init -> Waiting -> Done` state machine.
    ///
    /// * `Init`: first try to consume a stored notification without locking.
    ///   Otherwise take the lock and move `Notify::state` to `WAITING`,
    ///   consuming a notification that races in instead if one appears. Then
    ///   register the waker, push this waiter onto the list, and loop into
    ///   `Waiting`.
    /// * `Waiting`: under the lock, check whether `notify_locked` flagged this
    ///   waiter. If so, reset it and finish. Otherwise refresh the stored
    ///   waker if needed and return `Pending`.
    /// * `Done`: always `Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        use State::*;
        let (notify, state, waiter) = self.project();
        loop {
            match *state {
                Init => {
                    // Fast path: consume a stored notification without locking.
                    let res = notify
                        .state
                        .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst);
                    if res.is_ok() {
                        *state = Done;
                        return Poll::Ready(());
                    }
                    let mut waiters = notify.waiters.lock().unwrap();
                    // Reload under the lock. Only the lock-free
                    // EMPTY <-> NOTIFIED transitions can still race with us.
                    let mut curr = notify.state.load(SeqCst);
                    loop {
                        match curr {
                            EMPTY => {
                                let res = notify
                                    .state
                                    .compare_exchange(EMPTY, WAITING, SeqCst, SeqCst);
                                if let Err(actual) = res {
                                    // A concurrent `notify` stored a
                                    // notification; retry in the NOTIFIED arm.
                                    assert_eq!(actual, NOTIFIED);
                                    curr = actual;
                                } else {
                                    break;
                                }
                            }
                            // Other waiters are already registered.
                            WAITING => break,
                            NOTIFIED => {
                                let res = notify
                                    .state
                                    .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst);
                                match res {
                                    Ok(_) => {
                                        *state = Done;
                                        return Poll::Ready(());
                                    }
                                    Err(actual) => {
                                        // Another `Notified` consumed it via the
                                        // lock-free fast path; retry as EMPTY.
                                        assert_eq!(actual, EMPTY);
                                        curr = actual;
                                    }
                                }
                            }
                            _ => unreachable!(),
                        }
                    }
                    // SAFETY: the `waiters` lock is held and the entry is not
                    // yet linked, so nothing else can observe it.
                    unsafe {
                        (*waiter.get()).waker = Some(cx.waker().clone());
                    }
                    // SAFETY: `UnsafeCell::get` never returns null. The future
                    // is pinned, so the linked address stays valid until `drop`
                    // unlinks it.
                    waiters.push_front(unsafe { NonNull::new_unchecked(waiter.get()) });
                    *state = Waiting;
                    // The guard is released at the end of this arm; the
                    // `Waiting` arm below re-acquires it on the next iteration.
                }
                Waiting => {
                    // The waiter's fields may be written by `notify_locked`,
                    // so they are only accessed while holding the lock.
                    let waiters = notify.waiters.lock().unwrap();
                    // SAFETY: the lock is held.
                    let w = unsafe { &mut *waiter.get() };
                    if w.notified {
                        // `notify_locked` already popped this entry and took
                        // the waker; reset the remaining fields.
                        w.waker = None;
                        w.notified = false;
                        *state = Done;
                    } else {
                        // While not notified, the waker registered in `Init`
                        // is still present. Replace it only if this poll comes
                        // from a different task.
                        if !w.waker.as_ref().unwrap().will_wake(cx.waker()) {
                            w.waker = Some(cx.waker().clone());
                        }
                        return Poll::Pending;
                    }
                    // Explicitly end the critical section before looping into `Done`.
                    drop(waiters);
                }
                Done => {
                    return Poll::Ready(());
                }
            }
        }
    }
}
impl Drop for Notified<'_> {
fn drop(&mut self) {
use State::*;
let (notify, state, waiter) = unsafe { Pin::new_unchecked(self).project() };
if let Waiting = *state {
let mut notify_state = WAITING;
let mut waiters = notify.waiters.lock().unwrap();
unsafe { waiters.remove(NonNull::new_unchecked(waiter.get())) };
if waiters.is_empty() {
notify_state = EMPTY;
notify.state.store(EMPTY, SeqCst);
}
let notified = unsafe { (*waiter.get()).notified };
if notified {
if let Some(waker) = notify_locked(&mut waiters, ¬ify.state, notify_state) {
drop(waiters);
waker.wake();
}
}
}
}
}
// NOTE(review): the exact `Link` contract lives in `linked_list`. Here handles
// are plain `NonNull<Waiter>` pointers into pinned `Notified` futures, which
// stay in place until `Notified::drop` unlinks them.
unsafe impl linked_list::Link for Waiter {
    type Handle = NonNull<Waiter>;
    type Target = Waiter;

    fn as_raw(handle: &NonNull<Waiter>) -> NonNull<Waiter> {
        *handle
    }

    unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> {
        ptr
    }

    unsafe fn pointers(target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> {
        // Borrow only the `pointers` field through the raw pointer instead of
        // materialising a `&mut Waiter`. The latter would also claim unique
        // access to `waker`/`notified`, which other code may be referencing.
        NonNull::from(&mut (*target.as_ptr()).pointers)
    }
}
/// Compile-time assertion helper: a call only type-checks when `T: Unpin`.
/// It does nothing at runtime.
fn is_unpin<T>()
where
    T: Unpin,
{
}