use core::mem::forget;
use core::pin::Pin;
use core::sync::atomic::{AtomicUsize, Ordering};
use core::task::Poll;
use event_listener::{Event, EventListener};
use event_listener_strategy::{EventListenerFuture, Strategy};
use crate::futures::Lock;
use crate::Mutex;
/// Least-significant bit of `RawRwLock::state`: set while a writer holds the lock or is
/// waiting for the active readers to leave.
const WRITER_BIT: usize = 1 << 0;
/// Amount `RawRwLock::state` changes by per active reader; readers are counted in the bits
/// above `WRITER_BIT`.
const ONE_READER: usize = WRITER_BIT << 1;
/// A "raw" reader-writer lock: it stores no protected data, only the locking state.
pub(super) struct RawRwLock {
    /// Writer mutex. Writers (`write`/`try_write`) and upgradable readers
    /// (`upgradable_read`/`try_upgradable_read`) acquire it and `forget` the guard; it is
    /// released manually with `unlock_unchecked`.
    mutex: Mutex<()>,
    /// Notified (one listener) by `read_unlock`/`upgradable_read_unlock` when the last
    /// reader leaves; `RawWrite` and `RawUpgrade` wait on it.
    no_readers: Event,
    /// Notified by `write_unlock` and `downgrade_write` after `WRITER_BIT` is cleared;
    /// readers waiting in `RawRead` listen on it and forward the notification.
    no_writer: Event,
    /// Lock state: bit `WRITER_BIT` marks a writer holding or acquiring the lock; the
    /// remaining bits count active readers in units of `ONE_READER`.
    state: AtomicUsize,
}
impl RawRwLock {
    /// Creates an unlocked `RawRwLock` with no readers and no writer.
    #[inline]
    pub(super) const fn new() -> Self {
        RawRwLock {
            mutex: Mutex::new(()),
            no_readers: Event::new(),
            no_writer: Event::new(),
            state: AtomicUsize::new(0),
        }
    }

    /// Tries to acquire a read lock without waiting.
    ///
    /// Returns `true` iff a read lock was acquired. Fails whenever `WRITER_BIT` is set, i.e. a
    /// writer holds the lock or is waiting for readers to drain.
    pub(super) fn try_read(&self) -> bool {
        let mut state = self.state.load(Ordering::Acquire);

        loop {
            // A writer holds the lock or is trying to acquire it.
            if state & WRITER_BIT != 0 {
                return false;
            }

            // Abort rather than let the reader count overflow.
            if state > core::isize::MAX as usize {
                crate::abort();
            }

            match self.state.compare_exchange(
                state,
                state + ONE_READER,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return true,
                Err(s) => state = s,
            }
        }
    }

    /// Returns a future that acquires a read lock, waiting for any writer to finish.
    #[inline]
    pub(super) fn read(&self) -> RawRead<'_> {
        RawRead {
            lock: self,
            state: self.state.load(Ordering::Acquire),
            listener: EventListener::new(),
        }
    }

    /// Tries to acquire an upgradable read lock without waiting.
    ///
    /// An upgradable reader holds the writer mutex (excluding writers and other upgradable
    /// readers) and is also counted as one reader. Returns `true` iff it was acquired.
    pub(super) fn try_upgradable_read(&self) -> bool {
        let lock = if let Some(lock) = self.mutex.try_lock() {
            lock
        } else {
            return false;
        };
        // The mutex stays locked for the lifetime of the upgradable read and is released
        // manually through `unlock_unchecked`.
        forget(lock);

        let mut state = self.state.load(Ordering::Acquire);

        loop {
            // Validate every value we try to install, including ones reloaded after a failed
            // CAS (consistent with `try_read`).
            if state > core::isize::MAX as usize {
                crate::abort();
            }

            match self.state.compare_exchange(
                state,
                state + ONE_READER,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return true,
                Err(s) => state = s,
            }
        }
    }

    /// Returns a future that acquires an upgradable read lock (writer mutex + one reader).
    #[inline]
    pub(super) fn upgradable_read(&self) -> RawUpgradableRead<'_> {
        RawUpgradableRead {
            lock: self,
            acquire: self.mutex.lock(),
        }
    }

    /// Tries to acquire a write lock without waiting.
    ///
    /// Succeeds only if the writer mutex is free and there are no readers; on success the
    /// state becomes exactly `WRITER_BIT` and the mutex stays held.
    pub(super) fn try_write(&self) -> bool {
        let lock = if let Some(lock) = self.mutex.try_lock() {
            lock
        } else {
            return false;
        };

        if self
            .state
            .compare_exchange(0, WRITER_BIT, Ordering::AcqRel, Ordering::Acquire)
            .is_ok()
        {
            // Keep the mutex locked; `write_unlock` releases it.
            forget(lock);
            true
        } else {
            drop(lock);
            false
        }
    }

    /// Returns a future that acquires a write lock, waiting for the mutex and then for all
    /// readers to leave.
    #[inline]
    pub(super) fn write(&self) -> RawWrite<'_> {
        RawWrite {
            lock: self,
            no_readers: EventListener::new(),
            state: WriteState::Acquiring {
                lock: self.mutex.lock(),
            },
        }
    }

    /// Tries to turn an upgradable read lock into a write lock without waiting.
    ///
    /// Succeeds only if the caller's upgradable read is the sole reader.
    ///
    /// # Safety
    ///
    /// The caller must hold an upgradable read lock.
    pub(super) unsafe fn try_upgrade(&self) -> bool {
        self.state
            .compare_exchange(ONE_READER, WRITER_BIT, Ordering::AcqRel, Ordering::Acquire)
            .is_ok()
    }

    /// Starts upgrading an upgradable read lock into a write lock.
    ///
    /// Immediately trades the caller's reader slot for `WRITER_BIT` (blocking new readers);
    /// the returned future waits for the remaining readers to leave.
    ///
    /// # Safety
    ///
    /// The caller must hold an upgradable read lock.
    pub(super) unsafe fn upgrade(&self) -> RawUpgrade<'_> {
        // Subtracting `ONE_READER - WRITER_BIT` removes one reader and sets `WRITER_BIT`.
        self.state
            .fetch_sub(ONE_READER - WRITER_BIT, Ordering::SeqCst);

        RawUpgrade {
            lock: Some(self),
            listener: EventListener::new(),
        }
    }

    /// Downgrades an upgradable read lock to a plain read lock by releasing the mutex.
    ///
    /// # Safety
    ///
    /// The caller must hold an upgradable read lock.
    #[inline]
    pub(super) unsafe fn downgrade_upgradable_read(&self) {
        self.mutex.unlock_unchecked();
    }

    /// Downgrades a write lock to a plain read lock.
    ///
    /// # Safety
    ///
    /// The caller must hold a write lock.
    pub(super) unsafe fn downgrade_write(&self) {
        // Clear `WRITER_BIT` and register one reader in a single step.
        self.state
            .fetch_add(ONE_READER - WRITER_BIT, Ordering::SeqCst);
        self.mutex.unlock_unchecked();
        // Wake a reader blocked on the writer.
        self.no_writer.notify(1);
    }

    /// Downgrades a write lock to an upgradable read lock; the mutex stays held.
    ///
    /// # Safety
    ///
    /// The caller must hold a write lock.
    pub(super) unsafe fn downgrade_to_upgradable(&self) {
        // Clear `WRITER_BIT` and register one reader in a single step.
        self.state
            .fetch_add(ONE_READER - WRITER_BIT, Ordering::SeqCst);
        // Readers parked in `RawRead` only wake on `no_writer`. Without this notification they
        // would stay asleep even though `WRITER_BIT` is now clear (a lost wakeup). The woken
        // reader forwards the notification to the next waiting reader.
        self.no_writer.notify(1);
    }

    /// Releases a read lock, waking one waiter on `no_readers` if this was the last reader.
    ///
    /// # Safety
    ///
    /// The caller must hold a read lock.
    pub(super) unsafe fn read_unlock(&self) {
        if self.state.fetch_sub(ONE_READER, Ordering::SeqCst) & !WRITER_BIT == ONE_READER {
            self.no_readers.notify(1);
        }
    }

    /// Releases an upgradable read lock: drops its reader slot, then releases the mutex.
    ///
    /// # Safety
    ///
    /// The caller must hold an upgradable read lock.
    pub(super) unsafe fn upgradable_read_unlock(&self) {
        if self.state.fetch_sub(ONE_READER, Ordering::SeqCst) & !WRITER_BIT == ONE_READER {
            self.no_readers.notify(1);
        }
        self.mutex.unlock_unchecked();
    }

    /// Releases a write lock: clears `WRITER_BIT`, wakes a waiting reader, releases the mutex.
    ///
    /// # Safety
    ///
    /// The caller must hold a write lock (mutex held and `WRITER_BIT` set by the caller).
    pub(super) unsafe fn write_unlock(&self) {
        self.state.fetch_and(!WRITER_BIT, Ordering::SeqCst);
        self.no_writer.notify(1);
        self.mutex.unlock_unchecked();
    }
}
pin_project_lite::pin_project! {
    /// Future returned by [`RawRwLock::read`]; resolves once a read lock has been acquired.
    pub(super) struct RawRead<'a> {
        // The lock being read-locked.
        pub(super) lock: &'a RawRwLock,
        // Most recently observed value of `lock.state`; used as the CAS expected value.
        state: usize,
        // Listener on `lock.no_writer`, registered only after a writer has been observed.
        #[pin]
        listener: EventListener,
    }
}
impl<'a> EventListenerFuture for RawRead<'a> {
    type Output = ();

    /// Registers one reader once `WRITER_BIT` is observed clear; otherwise waits on
    /// `no_writer` and retries.
    fn poll_with_strategy<'x, S: Strategy<'x>>(
        self: Pin<&mut Self>,
        strategy: &mut S,
        cx: &mut S::Context,
    ) -> Poll<()> {
        let mut this = self.project();
        loop {
            if *this.state & WRITER_BIT == 0 {
                // No writer: abort on reader-count overflow, then try to add ourselves.
                if *this.state > core::isize::MAX as usize {
                    crate::abort();
                }
                match this.lock.state.compare_exchange(
                    *this.state,
                    *this.state + ONE_READER,
                    Ordering::AcqRel,
                    Ordering::Acquire,
                ) {
                    Ok(_) => return Poll::Ready(()),
                    // Lost the race; retry with the freshly observed state.
                    Err(s) => *this.state = s,
                }
            } else {
                // A writer is present. The first time, register a listener and re-check the
                // state with `SeqCst` (NOTE(review): presumably so a writer release between the
                // first load and `listen` is not missed — confirm). Afterwards, wait for a
                // notification, then pass it on to the next waiting reader.
                let load_ordering = if !this.listener.is_listening() {
                    this.listener.as_mut().listen(&this.lock.no_writer);
                    Ordering::SeqCst
                } else {
                    ready!(strategy.poll(this.listener.as_mut(), cx));
                    this.lock.no_writer.notify(1);
                    Ordering::Acquire
                };
                *this.state = this.lock.state.load(load_ordering);
            }
        }
    }
}
pin_project_lite::pin_project! {
    /// Future returned by [`RawRwLock::upgradable_read`]; resolves once an upgradable read
    /// lock (writer mutex + one reader slot) is held.
    pub(super) struct RawUpgradableRead<'a> {
        // The lock being acquired.
        pub(super) lock: &'a RawRwLock,
        // Pending acquisition of the lock's writer mutex.
        #[pin]
        acquire: Lock<'a, ()>,
    }
}
impl<'a> EventListenerFuture for RawUpgradableRead<'a> {
    type Output = ();

    /// Acquires the writer mutex, then registers one reader.
    fn poll_with_strategy<'x, S: Strategy<'x>>(
        self: Pin<&mut Self>,
        strategy: &mut S,
        cx: &mut S::Context,
    ) -> Poll<()> {
        let this = self.project();

        // The mutex stays held for the lifetime of the upgradable read and is released
        // manually via `unlock_unchecked`, so the guard is forgotten.
        let mutex_guard = ready!(this.acquire.poll_with_strategy(strategy, cx));
        forget(mutex_guard);

        let mut state = this.lock.state.load(Ordering::Acquire);
        loop {
            // Validate every value we try to install, including ones reloaded after a failed
            // CAS (consistent with `try_read` and `RawRead`).
            if state > core::isize::MAX as usize {
                crate::abort();
            }

            match this.lock.state.compare_exchange(
                state,
                state + ONE_READER,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return Poll::Ready(()),
                Err(s) => state = s,
            }
        }
    }
}
pin_project_lite::pin_project! {
    /// Future returned by [`RawRwLock::write`]; resolves once the write lock is held.
    pub(super) struct RawWrite<'a> {
        // The lock being write-locked.
        pub(super) lock: &'a RawRwLock,
        // Listener on `lock.no_readers`, used while waiting for readers to leave.
        #[pin]
        no_readers: EventListener,
        // Progress of the acquisition.
        #[pin]
        state: WriteState<'a>,
    }
    impl PinnedDrop for RawWrite<'_> {
        fn drop(this: Pin<&mut Self>) {
            let this = this.project();
            // Cancelled while waiting for readers: the mutex is held and `WRITER_BIT` is set,
            // so both must be released. In `Acquired` the caller owns the lock; in `Acquiring`
            // the mutex has not been obtained yet.
            if matches!(this.state.project(), WriteStateProj::WaitingReaders) {
                // SAFETY: in `WaitingReaders` this future holds the mutex and set `WRITER_BIT`.
                unsafe {
                    this.lock.write_unlock();
                }
            }
        }
    }
}
pin_project_lite::pin_project! {
    // Progress of a `RawWrite` future.
    #[project = WriteStateProj]
    #[project_replace = WriteStateProjReplace]
    enum WriteState<'a> {
        // Waiting for the writer mutex; it has not been obtained yet.
        Acquiring { #[pin] lock: Lock<'a, ()> },
        // Mutex held and `WRITER_BIT` set; waiting for active readers to leave.
        WaitingReaders,
        // The write lock has been handed to the caller.
        Acquired,
    }
}
impl<'a> EventListenerFuture for RawWrite<'a> {
    type Output = ();

    /// Acquires the writer mutex, sets `WRITER_BIT` to block new readers, then waits until
    /// no readers remain.
    ///
    /// # Panics
    ///
    /// Panics if polled again after returning `Ready`.
    fn poll_with_strategy<'x, S: Strategy<'x>>(
        self: Pin<&mut Self>,
        strategy: &mut S,
        cx: &mut S::Context,
    ) -> Poll<()> {
        let mut this = self.project();
        loop {
            match this.state.as_mut().project() {
                WriteStateProj::Acquiring { lock } => {
                    // Writers serialize on the mutex; it is released later by `write_unlock`.
                    let mutex_guard = ready!(lock.poll_with_strategy(strategy, cx));
                    forget(mutex_guard);

                    // `fetch_or` returns the state *before* `WRITER_BIT` was set.
                    let prev_state = this.lock.state.fetch_or(WRITER_BIT, Ordering::SeqCst);

                    // No readers were active, so the write lock is ours immediately. (The old
                    // `prev_state == WRITER_BIT` test could never hold, because every holder of
                    // `WRITER_BIT` holds the mutex we just acquired. That forced every
                    // uncontended write through a needless listener registration.)
                    if prev_state & !WRITER_BIT == 0 {
                        this.state.as_mut().set(WriteState::Acquired);
                        return Poll::Ready(());
                    }

                    this.no_readers.as_mut().listen(&this.lock.no_readers);
                    this.state.as_mut().set(WriteState::WaitingReaders);
                }
                WriteStateProj::WaitingReaders => {
                    // `Acquire` while a listener is registered, `SeqCst` otherwise.
                    let load_ordering = if this.no_readers.is_listening() {
                        Ordering::Acquire
                    } else {
                        Ordering::SeqCst
                    };
                    // Only our own `WRITER_BIT` left: all readers have gone.
                    if this.lock.state.load(load_ordering) == WRITER_BIT {
                        this.state.as_mut().set(WriteState::Acquired);
                        return Poll::Ready(());
                    }
                    if !this.no_readers.is_listening() {
                        this.no_readers.as_mut().listen(&this.lock.no_readers);
                    } else {
                        ready!(strategy.poll(this.no_readers.as_mut(), cx));
                    };
                }
                WriteStateProj::Acquired => panic!("Write lock already acquired"),
            }
        }
    }
}
pin_project_lite::pin_project! {
    /// Future returned by [`RawRwLock::upgrade`]; resolves to the lock once the upgrading
    /// reader is its only holder.
    pub(super) struct RawUpgrade<'a> {
        // `Some` until the future completes; taken on completion so drop does not unlock.
        lock: Option<&'a RawRwLock>,
        // Listener on the lock's `no_readers` event, registered lazily.
        #[pin]
        listener: EventListener,
    }
    impl PinnedDrop for RawUpgrade<'_> {
        fn drop(this: Pin<&mut Self>) {
            let this = this.project();
            // Dropped before completion: `upgrade` already set `WRITER_BIT` and the mutex is
            // held, so release the whole lock via `write_unlock`.
            if let Some(lock) = this.lock {
                // SAFETY: `lock` is still `Some`, so this future owns the write state set up by
                // `upgrade` (mutex held, `WRITER_BIT` set).
                unsafe {
                    lock.write_unlock();
                }
            }
        }
    }
}
impl<'a> EventListenerFuture for RawUpgrade<'a> {
    type Output = &'a RawRwLock;

    /// Waits until `state == WRITER_BIT` (no readers left), then resolves to the lock.
    /// `WRITER_BIT` was already set by `upgrade`, so no new readers join while this waits.
    ///
    /// # Panics
    ///
    /// Panics if polled again after returning `Ready`.
    fn poll_with_strategy<'x, S: Strategy<'x>>(
        self: Pin<&mut Self>,
        strategy: &mut S,
        cx: &mut S::Context,
    ) -> Poll<&'a RawRwLock> {
        let mut this = self.project();
        let lock = this.lock.expect("cannot poll future after completion");
        loop {
            // `SeqCst` before a listener is registered, `Acquire` afterwards.
            let load_ordering = if this.listener.is_listening() {
                Ordering::Acquire
            } else {
                Ordering::SeqCst
            };
            let state = lock.state.load(load_ordering);
            // Only our `WRITER_BIT` remains: exclusive access acquired.
            if state == WRITER_BIT {
                break;
            }
            // Register a listener first, re-check, and only then wait on it.
            if !this.listener.is_listening() {
                this.listener.as_mut().listen(&lock.no_readers);
            } else {
                ready!(strategy.poll(this.listener.as_mut(), cx));
            };
        }
        // Clearing `lock` marks the future complete and disarms the PinnedDrop unlock.
        Poll::Ready(this.lock.take().unwrap())
    }
}
impl RawUpgrade<'_> {
    /// Returns `true` once this future has resolved, i.e. its lock reference has been taken.
    #[inline]
    pub(super) fn is_ready(&self) -> bool {
        let still_pending = self.lock.is_some();
        !still_pending
    }
}