#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
use std::cell::UnsafeCell;
use std::fmt;
use std::ops::{Deref, DerefMut};
use std::process;
use std::sync::atomic::{AtomicUsize, Ordering};
use async_mutex::{Mutex, MutexGuard};
use event_listener::Event;
/// Bit of `RwLock::state` that is set while a writer holds the lock or is waiting for readers
/// to leave.
const WRITER_BIT: usize = 1 << 0;
/// Amount added to `RwLock::state` for each active reader; the reader count lives in the bits
/// above `WRITER_BIT`.
const ONE_READER: usize = 1 << 1;
/// An async reader-writer lock.
///
/// Any number of readers, or a single writer, may hold the lock at a time. A writer sets
/// `WRITER_BIT` before it waits for active readers to leave, and `read`/`try_read` refuse new
/// readers while that bit is set.
pub struct RwLock<T: ?Sized> {
    /// Serializes writers. A writer keeps this mutex (inside its write guard) for as long as it
    /// holds the write lock.
    mutex: Mutex<()>,
    /// Notified by the last reader to release its guard, so a waiting writer can proceed.
    no_readers: Event,
    /// Notified when a writer releases the lock, so readers waiting on it can retry.
    no_writer: Event,
    /// `WRITER_BIT` flags a writer; the remaining bits count active readers in units of
    /// `ONE_READER`.
    state: AtomicUsize,
    /// The protected value, accessed only through guards (or `get_mut`/`into_inner`).
    value: UnsafeCell<T>,
}
// SAFETY: moving the lock to another thread moves the owned `T` with it, so `T: Send` suffices.
unsafe impl<T: Send + ?Sized> Send for RwLock<T> {}
// SAFETY: through `&RwLock<T>` other threads can obtain `&T` (read guards, requiring `T: Sync`)
// and `&mut T` (write guards, requiring `T: Send`).
unsafe impl<T: Send + Sync + ?Sized> Sync for RwLock<T> {}
impl<T> RwLock<T> {
    /// Creates a new, unlocked reader-writer lock that protects `t`.
    pub fn new(t: T) -> RwLock<T> {
        let value = UnsafeCell::new(t);
        RwLock {
            state: AtomicUsize::new(0),
            mutex: Mutex::new(()),
            no_readers: Event::new(),
            no_writer: Event::new(),
            value,
        }
    }

    /// Consumes the lock and returns the protected value.
    ///
    /// Taking `self` by value guarantees no guards are outstanding, so no locking is needed.
    pub fn into_inner(self) -> T {
        let RwLock { value, .. } = self;
        value.into_inner()
    }
}
impl<T: ?Sized> RwLock<T> {
    /// Attempts to acquire a read lock without waiting.
    ///
    /// Returns `None` if a writer currently holds the lock or is waiting for readers to leave.
    /// New readers are refused in both cases.
    ///
    /// # Aborts
    ///
    /// Aborts the process if the internal reader counter grows past `isize::MAX`.
    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
        let mut state = self.state.load(Ordering::Acquire);

        loop {
            if state & WRITER_BIT != 0 {
                return None;
            }

            // Make sure the number of readers cannot overflow.
            if state > std::isize::MAX as usize {
                process::abort();
            }

            // Register one more reader; on contention, retry with the freshly observed state.
            match self.state.compare_exchange(
                state,
                state + ONE_READER,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return Some(RwLockReadGuard(self)),
                Err(s) => state = s,
            }
        }
    }

    /// Acquires a read lock, waiting while a writer holds the lock or is waiting to acquire it.
    ///
    /// # Aborts
    ///
    /// Aborts the process if the internal reader counter grows past `isize::MAX`.
    pub async fn read(&self) -> RwLockReadGuard<'_, T> {
        // Set once this call has been woken by a "no writer" notification.
        let mut woken = false;

        loop {
            if let Some(guard) = self.try_read() {
                if woken {
                    // Pass the wake-up on to the next queued reader. Doing this only after the
                    // read lock has actually been acquired (not right after waking) prevents
                    // queued readers from endlessly re-notifying each other while a new writer
                    // holds the lock.
                    self.no_writer.notify(1);
                }
                return guard;
            }

            // A writer is present. Register interest first and only then re-check, so that a
            // writer releasing the lock in between is not missed.
            let listener = self.no_writer.listen();
            if self.state.load(Ordering::SeqCst) & WRITER_BIT != 0 {
                listener.await;
                woken = true;
            }
        }
    }

    /// Attempts to acquire the write lock without waiting.
    ///
    /// Returns `None` if another writer holds (or is acquiring) the lock, or if any reader is
    /// active.
    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
        let lock = self.mutex.try_lock()?;

        // Holding the mutex means no other writer can set `WRITER_BIT`, so the state can only
        // be zero if there are also no readers.
        if self
            .state
            .compare_exchange(0, WRITER_BIT, Ordering::AcqRel, Ordering::Acquire)
            .is_ok()
        {
            Some(RwLockWriteGuard(self, lock))
        } else {
            None
        }
    }

    /// Acquires the write lock.
    ///
    /// First waits for other writers (via the internal mutex). It then sets `WRITER_BIT`, which
    /// refuses new readers, and waits for the active readers to release their guards. If the
    /// returned future is dropped while waiting for readers, the guard created here clears
    /// `WRITER_BIT` again.
    pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
        let lock = self.mutex.lock().await;

        // Announce the writer, and build the guard right away so that cancellation of this
        // future undoes the announcement through the guard's `Drop`.
        self.state.fetch_or(WRITER_BIT, Ordering::SeqCst);
        let guard = RwLockWriteGuard(self, lock);

        // Wait until only the writer bit remains, i.e. every reader has left.
        while self.state.load(Ordering::SeqCst) != WRITER_BIT {
            let listener = self.no_readers.listen();

            // Re-check after registering so a reader leaving in between is not missed.
            if self.state.load(Ordering::SeqCst) != WRITER_BIT {
                listener.await;
            }
        }

        guard
    }

    /// Returns a mutable reference to the protected value.
    ///
    /// No locking is needed because `&mut self` statically excludes every other access.
    pub fn get_mut(&mut self) -> &mut T {
        // SAFETY: `&mut self` is exclusive, so no guard borrowing this lock can be alive.
        unsafe { &mut *self.value.get() }
    }
}
impl<T: fmt::Debug + ?Sized> fmt::Debug for RwLock<T> {
    /// Formats as `RwLock { value: .. }`.
    ///
    /// Shows the value when a read lock can be taken without waiting, and the placeholder
    /// `<locked>` otherwise.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        /// Placeholder whose `Debug` output is the literal text `<locked>`.
        struct Locked;

        impl fmt::Debug for Locked {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.write_str("<locked>")
            }
        }

        // Keep the read guard (if any) alive until formatting has finished.
        let guard = self.try_read();
        let mut out = f.debug_struct("RwLock");
        match &guard {
            Some(g) => out.field("value", &&**g),
            None => out.field("value", &Locked),
        };
        out.finish()
    }
}
impl<T> From<T> for RwLock<T> {
fn from(val: T) -> RwLock<T> {
RwLock::new(val)
}
}
// `Default` has `Sized` as a supertrait (and `RwLock::new` needs `T: Sized`), so a `?Sized`
// bound here would be misleading; this accepts exactly the same types.
impl<T: Default> Default for RwLock<T> {
    /// Creates an unlocked lock holding `T::default()`.
    fn default() -> RwLock<T> {
        RwLock::new(T::default())
    }
}
/// RAII guard returned by [`RwLock::read`] and [`RwLock::try_read`].
///
/// Gives shared access to the protected value through `Deref`. Dropping it releases this
/// reader's share of the lock.
#[must_use = "if unused the RwLock will immediately unlock"]
pub struct RwLockReadGuard<'a, T: ?Sized>(&'a RwLock<T>);
// SAFETY: a read guard only ever exposes `&T`, and releasing it is an atomic decrement, so it may
// be moved to another thread whenever `&T` may be, i.e. when `T: Sync`.
unsafe impl<T: Sync + ?Sized> Send for RwLockReadGuard<'_, T> {}
// SAFETY: sharing `&RwLockReadGuard` only exposes `&T`, which requires `T: Sync`.
unsafe impl<T: Sync + ?Sized> Sync for RwLockReadGuard<'_, T> {}
impl<T: ?Sized> Drop for RwLockReadGuard<'_, T> {
    /// Releases this reader and wakes a waiting writer if it was the last reader.
    fn drop(&mut self) {
        let lock = self.0;
        let before = lock.state.fetch_sub(ONE_READER, Ordering::SeqCst);

        // Ignore the writer bit: if exactly one reader was counted, it was this one, so none
        // remain now.
        if (before & !WRITER_BIT) == ONE_READER {
            lock.no_readers.notify(1);
        }
    }
}
impl<T: fmt::Debug + ?Sized> fmt::Debug for RwLockReadGuard<'_, T> {
    /// Formats the protected value exactly as `T`'s own `Debug` implementation does.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value: &T = self;
        <T as fmt::Debug>::fmt(value, f)
    }
}
impl<T: fmt::Display + ?Sized> fmt::Display for RwLockReadGuard<'_, T> {
    /// Formats the protected value exactly as `T`'s own `Display` implementation does.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value: &T = self;
        <T as fmt::Display>::fmt(value, f)
    }
}
impl<T: ?Sized> Deref for RwLockReadGuard<'_, T> {
    type Target = T;

    /// Returns a shared reference to the protected value.
    fn deref(&self) -> &T {
        let ptr = self.0.value.get();
        // SAFETY: a read guard is only created after its reader was counted in `state`. A writer
        // only proceeds once `state == WRITER_BIT` (no readers), so no `&mut T` can coexist with
        // this reference.
        unsafe { &*ptr }
    }
}
/// RAII guard returned by [`RwLock::write`] and [`RwLock::try_write`].
///
/// Gives shared and exclusive access to the protected value through `Deref`/`DerefMut`.
/// Dropping it clears the writer bit, wakes a waiting reader, and then releases the internal
/// writer mutex.
#[must_use = "if unused the RwLock will immediately unlock"]
pub struct RwLockWriteGuard<'a, T: ?Sized>(&'a RwLock<T>, MutexGuard<'a, ()>);
// SAFETY: the guard grants exclusive `&mut T` access, so moving it to another thread is sound
// when `T: Send`. NOTE(review): this also assumes `async_mutex::MutexGuard<()>` may be released
// from any thread; confirm against that crate.
unsafe impl<T: Send + ?Sized> Send for RwLockWriteGuard<'_, T> {}
// SAFETY: sharing `&RwLockWriteGuard` only exposes `&T`, which requires `T: Sync`.
unsafe impl<T: Sync + ?Sized> Sync for RwLockWriteGuard<'_, T> {}
impl<T: ?Sized> Drop for RwLockWriteGuard<'_, T> {
    /// Releases the write lock.
    ///
    /// Clears `WRITER_BIT`, then wakes a reader waiting for the writer to leave. The inner
    /// `MutexGuard` field is dropped after this method returns, letting the next writer in.
    fn drop(&mut self) {
        let lock = self.0;
        lock.state.fetch_and(!WRITER_BIT, Ordering::SeqCst);
        lock.no_writer.notify(1);
    }
}
impl<T: fmt::Debug + ?Sized> fmt::Debug for RwLockWriteGuard<'_, T> {
    /// Formats the protected value exactly as `T`'s own `Debug` implementation does.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value: &T = self;
        <T as fmt::Debug>::fmt(value, f)
    }
}
impl<T: fmt::Display + ?Sized> fmt::Display for RwLockWriteGuard<'_, T> {
    /// Formats the protected value exactly as `T`'s own `Display` implementation does.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value: &T = self;
        <T as fmt::Display>::fmt(value, f)
    }
}
impl<T: ?Sized> Deref for RwLockWriteGuard<'_, T> {
    type Target = T;

    /// Returns a shared reference to the protected value.
    fn deref(&self) -> &T {
        let ptr = self.0.value.get();
        // SAFETY: a write guard is only handed out while its writer holds the mutex, has set
        // `WRITER_BIT` and has observed no readers; the bit keeps new readers out. A shared
        // reborrow through `&self` cannot alias a live `&mut T` obtained via this guard.
        unsafe { &*ptr }
    }
}
impl<T: ?Sized> DerefMut for RwLockWriteGuard<'_, T> {
    /// Returns an exclusive reference to the protected value.
    fn deref_mut(&mut self) -> &mut T {
        let ptr = self.0.value.get();
        // SAFETY: a write guard is only handed out while its writer holds the mutex, has set
        // `WRITER_BIT` and has observed no readers, so nothing else can reach the value. The
        // `&mut self` borrow makes this reference unique among users of the guard.
        unsafe { &mut *ptr }
    }
}