#![deny(unsafe_code)]
use std::pin::{Pin, pin};
#[cfg(not(feature = "loom"))]
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::{self, AcqRel, Acquire, Relaxed};
#[cfg(feature = "loom")]
use loom::sync::atomic::AtomicUsize;
use crate::opcode::Opcode;
use crate::pager::SyncResult;
use crate::sync_primitive::SyncPrimitive;
use crate::wait_queue::{Entry, WaitQueue};
use crate::{Pager, pager};
/// A gate that callers wait on until it is permitted, rejected, opened, or sealed.
///
/// A default-constructed gate holds `0`, i.e. [`State::Controlled`] with no wait queue.
#[derive(Default, Debug)]
pub struct Gate {
    /// Packs the gate [`State`] (selected by `WaitQueue::DATA_MASK`) together with the
    /// wait-queue anchor address (selected by `WaitQueue::ADDR_MASK`) in one atomic word.
    state: AtomicUsize,
}
/// Observable state of a gate.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
#[repr(u8)]
pub enum State {
    /// Callers queue up until waiters are permitted, rejected, or the gate is opened/sealed.
    Controlled = 0,
    /// Entering fails immediately with `Error::Sealed`.
    Sealed = 1,
    /// Entering succeeds immediately with `Ok(State::Open)`.
    Open = 2,
}
/// Reasons a caller fails to pass a gate.
///
/// Every discriminant is a multiple of four so that it can share a byte with the two
/// [`State`] bits without overlapping them.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
#[repr(u8)]
pub enum Error {
    /// Queued waiters were released by `Gate::reject`.
    Rejected = 4,
    /// The gate is sealed, or was sealed while the caller was waiting.
    Sealed = 8,
    /// Released early because a queued entry was dropped before its result was consumed.
    SpuriousFailure = 12,
    /// Mapped from `pager::Error::NotRegistered`.
    NotRegistered = 16,
    /// Mapped from `pager::Error::WrongMode`.
    WrongMode = 20,
    /// Mapped from `pager::Error::NotReady`.
    NotReady = 24,
}
impl Gate {
    /// Mask selecting the [`State`] bits of a result byte produced by `Gate::into_u8`.
    ///
    /// The bits above it carry an [`Error`] code. Every `Error` discriminant is a multiple of
    /// four, so the two fields never overlap.
    const STATE_MASK: u8 = 0b11;

    /// Returns the current state of the gate, loaded with the memory ordering `mo`.
    #[inline]
    pub fn state(&self, mo: Ordering) -> State {
        State::from(self.state.load(mo) & WaitQueue::DATA_MASK)
    }

    /// Returns a [`State::Sealed`] or [`State::Open`] gate to [`State::Controlled`].
    ///
    /// Returns the state the gate was in before the reset. Returns `None` if the gate was
    /// already `Controlled`; in that case it is left untouched.
    #[inline]
    pub fn reset(&self) -> Option<State> {
        match self.state.fetch_update(Relaxed, Relaxed, |value| {
            let state = State::from(value & WaitQueue::DATA_MASK);
            if state == State::Controlled {
                None
            } else {
                // Waiters are only queued while `Controlled`. `open`/`seal` store a bare
                // state when leaving that state, so no wait-queue address can be present here.
                debug_assert_eq!(value & WaitQueue::ADDR_MASK, 0);
                // Store `Controlled`. The previous state is reported through the value
                // returned by `fetch_update`.
                Some((value & WaitQueue::ADDR_MASK) | u8::from(State::Controlled) as usize)
            }
        }) {
            Ok(prev) => Some(State::from(prev & WaitQueue::DATA_MASK)),
            Err(_) => None,
        }
    }

    /// Lets every waiter currently queued on a [`State::Controlled`] gate through.
    ///
    /// Each released waiter's `enter_*` call returns `Ok(State::Controlled)`. The gate stays
    /// `Controlled`, so later callers queue again. Returns the number of waiters released.
    ///
    /// # Errors
    ///
    /// Returns `Err(state)` with the current state if the gate is not `Controlled`.
    #[inline]
    pub fn permit(&self) -> Result<usize, State> {
        let (state, count) = self.wake_all(None, None);
        if state == State::Controlled {
            Ok(count)
        } else {
            debug_assert_eq!(count, 0);
            Err(state)
        }
    }

    /// Rejects every waiter currently queued on a [`State::Controlled`] gate.
    ///
    /// Each released waiter's `enter_*` call returns `Err(Error::Rejected)`. The gate stays
    /// `Controlled`. Returns the number of waiters rejected.
    ///
    /// # Errors
    ///
    /// Returns `Err(state)` with the current state if the gate is not `Controlled`.
    #[inline]
    pub fn reject(&self) -> Result<usize, State> {
        let (state, count) = self.wake_all(None, Some(Error::Rejected));
        if state == State::Controlled {
            Ok(count)
        } else {
            debug_assert_eq!(count, 0);
            Err(state)
        }
    }

    /// Moves the gate to [`State::Open`] and releases every queued waiter with
    /// `Ok(State::Open)`.
    ///
    /// Later callers pass immediately. Returns the previous state and the number of waiters
    /// released.
    #[inline]
    pub fn open(&self) -> (State, usize) {
        self.wake_all(Some(State::Open), None)
    }

    /// Moves the gate to [`State::Sealed`] and releases every queued waiter with
    /// `Err(Error::Sealed)`.
    ///
    /// Later callers fail immediately. Returns the previous state and the number of waiters
    /// released.
    #[inline]
    pub fn seal(&self) -> (State, usize) {
        self.wake_all(Some(State::Sealed), Some(Error::Sealed))
    }

    /// Waits asynchronously until the gate lets this caller through.
    ///
    /// Equivalent to [`Gate::enter_async_with`] with a no-op `begin_wait`.
    ///
    /// # Errors
    ///
    /// Returns `Error::Sealed` for a sealed gate and `Error::Rejected` after `reject`.
    /// Returns `Error::SpuriousFailure` if waiters were released because a queued entry was
    /// dropped without consuming its result.
    #[inline]
    pub async fn enter_async(&self) -> Result<State, Error> {
        self.enter_async_with(|| ()).await
    }

    /// Waits asynchronously until the gate lets this caller through.
    ///
    /// `begin_wait` is forwarded to the wait-queue push. It is dropped without being called if
    /// the gate is `Sealed` or `Open` on entry.
    ///
    /// # Errors
    ///
    /// Same as [`Gate::enter_async`].
    #[inline]
    pub async fn enter_async_with<F: FnOnce()>(&self, begin_wait: F) -> Result<State, Error> {
        let mut pinned_pager = pin!(Pager::default());
        pinned_pager
            .wait_queue()
            .construct(self, Opcode::Wait(0), false);
        self.push_wait_queue_entry(&mut pinned_pager, begin_wait);
        pinned_pager.poll_async().await
    }

    /// Blocks the current thread until the gate lets it through.
    ///
    /// # Errors
    ///
    /// Same as [`Gate::enter_async`].
    #[inline]
    pub fn enter_sync(&self) -> Result<State, Error> {
        self.enter_sync_with(|| ())
    }

    /// Blocks the current thread until the gate lets it through.
    ///
    /// `begin_wait` is forwarded to the wait-queue push. It is dropped without being called if
    /// the gate is `Sealed` or `Open` on entry.
    ///
    /// # Errors
    ///
    /// Same as [`Gate::enter_async`].
    #[inline]
    pub fn enter_sync_with<F: FnOnce()>(&self, begin_wait: F) -> Result<State, Error> {
        let mut pinned_pager = pin!(Pager::default());
        pinned_pager
            .wait_queue()
            .construct(self, Opcode::Wait(0), true);
        self.push_wait_queue_entry(&mut pinned_pager, begin_wait);
        pinned_pager.poll_sync()
    }

    /// Registers `pager` with the gate without waiting for the result.
    ///
    /// Returns `false` and leaves the pager untouched if it is already registered. Otherwise
    /// the pager's entry is queued, or resolved at once if the gate is `Sealed`/`Open`, and
    /// the method returns `true`.
    ///
    /// `is_sync` is forwarded to the wait-queue entry. `enter_sync*` passes `true` and
    /// `enter_async*` passes `false`; presumably it must match how the pager is later polled
    /// (confirm against `Pager`).
    #[inline]
    pub fn register_pager<'g>(
        &'g self,
        pager: &mut Pin<&mut Pager<'g, Self>>,
        is_sync: bool,
    ) -> bool {
        if pager.is_registered() {
            return false;
        }
        pager.wait_queue().construct(self, Opcode::Wait(0), is_sync);
        self.push_wait_queue_entry(pager, || ());
        true
    }

    /// Optionally moves the gate to `next_state`, detaches the wait queue, and delivers a
    /// result to every detached entry.
    ///
    /// Each entry receives `next_state`, or the unchanged state when `next_state` is `None`,
    /// combined with `error`. Returns the state observed before the update together with the
    /// number of entries that received a result.
    fn wake_all(&self, next_state: Option<State>, error: Option<Error>) -> (State, usize) {
        // The closure never returns `None`, so the update always succeeds. `Err` is matched
        // only to extract the previous value uniformly. The new value carries state bits only,
        // which clears the wait-queue address.
        match self.state.fetch_update(AcqRel, Acquire, |value| {
            if let Some(new_value) = next_state {
                Some(u8::from(new_value) as usize)
            } else {
                Some(value & WaitQueue::DATA_MASK)
            }
        }) {
            Ok(value) | Err(value) => {
                let mut count = 0;
                let prev_state = State::from(value & WaitQueue::DATA_MASK);
                let next_state = next_state.unwrap_or(prev_state);
                let result = Self::into_u8(next_state, error);
                // The previous value still holds the address of the queue that was detached.
                let anchor_ptr = WaitQueue::to_anchor_ptr(value);
                if !anchor_ptr.is_null() {
                    let tail_entry_ptr = WaitQueue::to_entry_ptr(anchor_ptr);
                    Entry::iter_forward(tail_entry_ptr, false, |entry, _| {
                        entry.set_result(result);
                        count += 1;
                        false
                    });
                }
                (prev_state, count)
            }
        }
    }

    /// Queues the pager's entry while the gate is `Controlled`, or resolves it immediately
    /// otherwise.
    ///
    /// A `Sealed` gate yields `Err(Error::Sealed)`; an `Open` gate yields `Ok(State::Open)`.
    ///
    /// `begin_wait` is handed to `try_push_wait_queue_entry`. If that returns it, the loop
    /// retries against a freshly loaded state; presumably this happens when a concurrent
    /// update won the race (confirm). `begin_wait` is dropped uncalled if the gate is not
    /// `Controlled`.
    #[inline]
    fn push_wait_queue_entry<F: FnOnce()>(
        &self,
        pager: &mut Pin<&mut Pager<Self>>,
        mut begin_wait: F,
    ) {
        loop {
            let state = self.state.load(Acquire);
            match State::from(state & WaitQueue::DATA_MASK) {
                State::Controlled => {
                    if let Some(returned) =
                        self.try_push_wait_queue_entry(pager.wait_queue(), state, begin_wait)
                    {
                        begin_wait = returned;
                        continue;
                    }
                }
                State::Sealed => {
                    pager
                        .wait_queue()
                        .entry()
                        .set_result(Self::into_u8(State::Sealed, Some(Error::Sealed)));
                }
                State::Open => {
                    pager
                        .wait_queue()
                        .entry()
                        .set_result(Self::into_u8(State::Open, None));
                }
            }
            break;
        }
    }

    /// Encodes `state` and an optional `error` into the result byte decoded by `to_result`.
    ///
    /// `None` encodes as error code `0`, which means success.
    #[inline]
    fn into_u8(state: State, error: Option<Error>) -> u8 {
        u8::from(state) | error.map_or(0_u8, u8::from)
    }
}
impl Drop for Gate {
    /// Seals the gate if a wait queue is still attached, so queued waiters are released.
    #[inline]
    fn drop(&mut self) {
        let has_queue = (self.state.load(Relaxed) & WaitQueue::ADDR_MASK) != 0;
        if has_queue {
            self.seal();
        }
    }
}
impl SyncPrimitive for Gate {
#[inline]
fn state(&self) -> &AtomicUsize {
&self.state
}
#[inline]
fn max_shared_owners() -> usize {
usize::MAX
}
#[inline]
fn drop_wait_queue_entry(entry: &Entry) {
if entry.try_consume_result().is_none() {
let this: &Self = entry.sync_primitive_ref();
this.wake_all(None, Some(Error::SpuriousFailure));
entry.acknowledge_result_sync();
}
}
}
impl SyncResult for Gate {
type Result = Result<State, Error>;
#[inline]
fn to_result(value: u8, pager_error: Option<pager::Error>) -> Self::Result {
if let Some(pager_error) = pager_error {
match pager_error {
pager::Error::NotRegistered => Err(Error::NotRegistered),
pager::Error::WrongMode => Err(Error::WrongMode),
pager::Error::NotReady => Err(Error::NotReady),
}
} else {
let state = State::from(value & Self::STATE_MASK);
let error = value & !(Self::STATE_MASK);
if error != 0 {
Err(Error::from(error))
} else {
Ok(state)
}
}
}
}
impl From<State> for u8 {
#[inline]
fn from(value: State) -> Self {
match value {
State::Controlled => 0_u8,
State::Sealed => 1_u8,
State::Open => 2_u8,
}
}
}
impl From<u8> for State {
#[inline]
fn from(value: u8) -> Self {
State::from(value as usize)
}
}
impl From<usize> for State {
#[inline]
fn from(value: usize) -> Self {
match value {
0 => State::Controlled,
1 => State::Sealed,
_ => State::Open,
}
}
}
impl From<Error> for u8 {
#[inline]
fn from(value: Error) -> Self {
match value {
Error::Rejected => 4_u8,
Error::Sealed => 8_u8,
Error::SpuriousFailure => 12_u8,
Error::NotRegistered => 16_u8,
Error::WrongMode => 20_u8,
Error::NotReady => 24_u8,
}
}
}
impl From<u8> for Error {
#[inline]
fn from(value: u8) -> Self {
Error::from(value as usize)
}
}
impl From<usize> for Error {
#[inline]
fn from(value: usize) -> Self {
match value {
4 => Error::Rejected,
8 => Error::Sealed,
12 => Error::SpuriousFailure,
16 => Error::NotRegistered,
20 => Error::WrongMode,
_ => Error::NotReady,
}
}
}