use crate::coop::CoopFutureExt;
use crate::sync::batch_semaphore::{AcquireError, Semaphore};
use std::cell::UnsafeCell;
use std::ops;
// Number of permits the lock's semaphore is created with. In this file a
// reader acquires one permit and a writer acquires `MAX_READS` permits, so
// this value is also the maximum number of read guards held at one time.
#[cfg(not(loom))]
const MAX_READS: usize = 32;
// Smaller permit count used when the crate is built with `cfg(loom)`.
// NOTE(review): presumably chosen to keep loom model checking tractable — confirm.
#[cfg(loom)]
const MAX_READS: usize = 10;
/// An asynchronous reader-writer lock guarding a value of type `T`.
///
/// Access is mediated by a semaphore holding `MAX_READS` permits: `read`
/// takes a single permit, `write` takes `MAX_READS` of them.
#[derive(Debug)]
pub struct RwLock<T> {
/// Semaphore whose permits coordinate readers and writers.
s: Semaphore,
/// The protected value; references to it are only handed out through guards.
c: UnsafeCell<T>,
}
/// Guard returned by `RwLock::read`, dereferencing to `&T`.
///
/// The permit it carries is handed back to the lock's semaphore when the
/// guard (and therefore its `permit` field) is dropped.
#[derive(Debug)]
pub struct RwLockReadGuard<'a, T> {
/// Permit held for the lifetime of the guard; released on drop.
permit: ReleasingPermit<'a, T>,
/// The lock whose value this guard exposes.
lock: &'a RwLock<T>,
}
/// Guard returned by `RwLock::write`, dereferencing to `&T` and `&mut T`.
///
/// The permits it carries are handed back to the lock's semaphore when the
/// guard (and therefore its `permit` field) is dropped.
#[derive(Debug)]
pub struct RwLockWriteGuard<'a, T> {
/// Permits held for the lifetime of the guard; released on drop.
permit: ReleasingPermit<'a, T>,
/// The lock whose value this guard exposes.
lock: &'a RwLock<T>,
}
// Records a number of permits acquired from a lock's semaphore and returns
// them to that semaphore when dropped (see the `Drop` impl).
#[derive(Debug)]
struct ReleasingPermit<'a, T> {
// How many permits were acquired and must be released on drop.
num_permits: u16,
// The lock whose semaphore the permits belong to.
lock: &'a RwLock<T>,
}
impl<'a, T> ReleasingPermit<'a, T> {
async fn acquire(
lock: &'a RwLock<T>,
num_permits: u16,
) -> Result<ReleasingPermit<'a, T>, AcquireError> {
lock.s.acquire(num_permits).cooperate().await?;
Ok(Self { num_permits, lock })
}
}
impl<T> Drop for ReleasingPermit<'_, T> {
    /// Returns the permits recorded in `num_permits` to the lock's semaphore.
    fn drop(&mut self) {
        // `u16` always fits in `usize`, so this conversion is lossless.
        let returned = usize::from(self.num_permits);
        self.lock.s.release(returned);
    }
}
// Compile-time checks of the auto-trait bounds on the lock, its guards, and
// the futures returned by `read` / `write`. Nothing is polled; the test passes
// as soon as it type-checks.
#[test]
#[cfg(not(loom))]
fn bounds() {
    fn assert_send<T: Send>() {}
    fn assert_sync<T: Sync>() {}
    fn assert_unpin<T: Unpin>() {}
    fn assert_send_sync_value<T: Send + Sync>(_value: T) {}

    // The lock itself.
    assert_send::<RwLock<u32>>();
    assert_sync::<RwLock<u32>>();
    assert_unpin::<RwLock<u32>>();

    // The read guard.
    assert_sync::<RwLockReadGuard<'_, u32>>();
    assert_unpin::<RwLockReadGuard<'_, u32>>();

    // The write guard.
    assert_sync::<RwLockWriteGuard<'_, u32>>();
    assert_unpin::<RwLockWriteGuard<'_, u32>>();

    // The futures produced by the async methods (created but never awaited).
    let lock = RwLock::new(0);
    assert_send_sync_value(lock.read());
    assert_send_sync_value(lock.write());
}
// Manual `Send`/`Sync` impls. `RwLock` stores its value in an `UnsafeCell`,
// which is never `Sync` on its own, so sharing must be asserted by hand.
//
// SAFETY(review): shared access to the `UnsafeCell` is only handed out through
// guards that hold semaphore permits. These impls also assume `Semaphore`
// itself is safe to send/share across threads — not visible here; confirm.
unsafe impl<T> Send for RwLock<T> where T: Send {}
unsafe impl<T> Sync for RwLock<T> where T: Send + Sync {}
// Guards may be shared between threads when `T` is both `Send` and `Sync`.
unsafe impl<'a, T> Sync for RwLockReadGuard<'a, T> where T: Send + Sync {}
unsafe impl<'a, T> Sync for RwLockWriteGuard<'a, T> where T: Send + Sync {}
impl<T> RwLock<T> {
    /// Creates a new, unlocked `RwLock` wrapping `value`.
    ///
    /// The internal semaphore starts with `MAX_READS` permits.
    pub fn new(value: T) -> RwLock<T> {
        let c = UnsafeCell::new(value);
        let s = Semaphore::new(MAX_READS);
        RwLock { s, c }
    }

    /// Acquires shared read access, waiting until one permit is available.
    ///
    /// The returned guard dereferences to `&T` and gives its permit back
    /// when dropped.
    pub async fn read(&self) -> RwLockReadGuard<'_, T> {
        match ReleasingPermit::acquire(self, 1).await {
            Ok(permit) => RwLockReadGuard { permit, lock: self },
            // An acquire error is treated as impossible here.
            // NOTE(review): presumably because the semaphore is never closed — confirm.
            Err(_) => unreachable!(),
        }
    }

    /// Acquires exclusive write access, waiting until `MAX_READS` permits
    /// (every permit the semaphore was created with) are available.
    ///
    /// The returned guard dereferences to `&T` / `&mut T` and gives its
    /// permits back when dropped.
    pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
        match ReleasingPermit::acquire(self, MAX_READS as u16).await {
            Ok(permit) => RwLockWriteGuard { permit, lock: self },
            // An acquire error is treated as impossible here.
            // NOTE(review): presumably because the semaphore is never closed — confirm.
            Err(_) => unreachable!(),
        }
    }

    /// Consumes the lock and returns the wrapped value.
    ///
    /// Taking `self` by value means no guard can still be borrowing the lock.
    pub fn into_inner(self) -> T {
        let cell = self.c;
        cell.into_inner()
    }
}
impl<T> ops::Deref for RwLockReadGuard<'_, T> {
    type Target = T;

    /// Gives shared access to the value protected by the lock.
    fn deref(&self) -> &T {
        let value: *mut T = self.lock.c.get();
        // SAFETY: the pointer comes from the lock's `UnsafeCell`, which the
        // guard borrows for `'a`. This relies on the guard holding a permit
        // while a writer needs all `MAX_READS` permits, so no `&mut T` is
        // handed out concurrently.
        unsafe { &*value }
    }
}
impl<T> ops::Deref for RwLockWriteGuard<'_, T> {
    type Target = T;

    /// Gives shared access to the value protected by the lock.
    fn deref(&self) -> &T {
        let value: *mut T = self.lock.c.get();
        // SAFETY: the pointer comes from the lock's `UnsafeCell`, which the
        // guard borrows for `'a`. This relies on the write guard holding
        // `MAX_READS` permits, so no other guard is expected to exist.
        unsafe { &*value }
    }
}
impl<T> ops::DerefMut for RwLockWriteGuard<'_, T> {
    /// Gives exclusive access to the value protected by the lock.
    fn deref_mut(&mut self) -> &mut T {
        let value: *mut T = self.lock.c.get();
        // SAFETY: the pointer comes from the lock's `UnsafeCell`, which the
        // guard borrows for `'a`. This relies on the write guard holding
        // `MAX_READS` permits, so no other guard is expected to exist, and
        // `&mut self` prevents aliasing through this guard.
        unsafe { &mut *value }
    }
}
impl<T> From<T> for RwLock<T> {
fn from(s: T) -> Self {
Self::new(s)
}
}
impl<T> Default for RwLock<T>
where
T: Default,
{
fn default() -> Self {
Self::new(T::default())
}
}