use std::marker::PhantomData;
use std::ops::Deref;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use crate::registry::{Registry, WorkerThread};
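/// A latch is a one-shot signalling primitive: it starts out unset, and at
/// some point a call to `set()` marks it as set, releasing whoever is waiting
/// on it. `set` takes a raw pointer rather than `&self` because setting a
/// latch can allow the waiting thread to proceed and deallocate the latch;
/// implementations must not touch `this` after the point where that can
/// happen.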
pub(super) trait Latch {
unsafe fn set(this: *const Self);
}
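/// Access to the `CoreLatch` embedded in latch types that are built around
/// one (`SpinLatch` and `CountLatch`), so that code which only needs the core
/// state machine can operate on it directly.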
pub(super) trait AsCoreLatch {
fn as_core_latch(&self) -> &CoreLatch;
}
// Internal state machine for `CoreLatch`. The owning thread moves the state
// from `UNSET` to `SLEEPY` to `SLEEPING` as it prepares to block; `set()`
// moves it to the terminal `SET` state from any of those.
const UNSET: usize = 0; // latch not set, owning thread awake
const SLEEPY: usize = 1; // owning thread is preparing to sleep
const SLEEPING: usize = 2; // owning thread is asleep and must be notified on set
const SET: usize = 3; // latch has been set
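/// The core latch state machine underlying `SpinLatch` and `CountLatch`. In
/// addition to tracking whether the latch has been set, it tracks whether the
/// owning thread is going to sleep (or is already asleep) and therefore needs
/// to be notified when the latch is set.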
#[derive(Debug)]
pub(super) struct CoreLatch {
state: AtomicUsize,
}
impl CoreLatch {
#[inline]
fn new() -> Self {
Self {
state: AtomicUsize::new(UNSET),
}
}
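/// Returns the address of this latch as an integer, usable as a cheap,
/// stable identity for the latch.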
#[inline]
pub(super) fn addr(&self) -> usize {
self as *const CoreLatch as usize
}
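/// Invoked by the owning thread before it starts to go to sleep. Attempts
/// the `UNSET -> SLEEPY` transition and returns whether it succeeded; a
/// `false` result means the latch is in some other state (e.g. already set)
/// and the thread should not go to sleep.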
#[inline]
pub(super) fn get_sleepy(&self) -> bool {
self.state
.compare_exchange(UNSET, SLEEPY, Ordering::SeqCst, Ordering::Relaxed)
.is_ok()
}
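/// Invoked by the owning thread just before it actually falls asleep.
/// Attempts the `SLEEPY -> SLEEPING` transition and returns whether it
/// succeeded; `false` typically means the latch was set in the meantime.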
#[inline]
pub(super) fn fall_asleep(&self) -> bool {
self.state
.compare_exchange(SLEEPY, SLEEPING, Ordering::SeqCst, Ordering::Relaxed)
.is_ok()
}
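/// Invoked by the owning thread when it wakes up for some reason other than
/// the latch being set: if the state is still `SLEEPING`, roll it back to
/// `UNSET`. A latch that has been set is left as `SET`.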
#[inline]
pub(super) fn wake_up(&self) {
if !self.probe() {
let _ =
self.state
.compare_exchange(SLEEPING, UNSET, Ordering::SeqCst, Ordering::Relaxed);
}
}
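/// Sets the latch. Returns `true` if the owning thread was recorded as
/// sleeping, in which case the caller must wake it up. Kept private; the
/// public latch types wrap it with the appropriate notification logic.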
#[inline]
unsafe fn set(this: *const Self) -> bool {
let old_state = (*this).state.swap(SET, Ordering::AcqRel);
old_state == SLEEPING
}
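/// Tests whether the latch has been set.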
#[inline]
pub(super) fn probe(&self) -> bool {
self.state.load(Ordering::Acquire) == SET
}
}
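/// A latch tied to a specific worker thread. The owning thread spins (or
/// sleeps) on the core latch; the setter uses the recorded registry and
/// worker index to wake the owner if it has fallen asleep. The `cross` flag
/// marks latches that will be set from a different registry (thread pool),
/// which requires keeping the owner's registry alive while notifying it.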
pub(super) struct SpinLatch<'r> {
core_latch: CoreLatch,
registry: &'r Arc<Registry>,
target_worker_index: usize,
cross: bool,
}
impl<'r> SpinLatch<'r> {
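/// Creates a new spin latch owned by the given worker thread.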
#[inline]
pub(super) fn new(thread: &'r WorkerThread) -> SpinLatch<'r> {
SpinLatch {
core_latch: CoreLatch::new(),
registry: thread.registry(),
target_worker_index: thread.index(),
cross: false,
}
}
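/// Creates a spin latch that may be set from a different thread pool than
/// the owner's; see the `cross` handling in `Latch::set` below.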
#[inline]
pub(super) fn cross(thread: &'r WorkerThread) -> SpinLatch<'r> {
SpinLatch {
cross: true,
..SpinLatch::new(thread)
}
}
#[inline]
pub(super) fn probe(&self) -> bool {
self.core_latch.probe()
}
}
impl<'r> AsCoreLatch for SpinLatch<'r> {
#[inline]
fn as_core_latch(&self) -> &CoreLatch {
&self.core_latch
}
}
impl<'r> Latch for SpinLatch<'r> {
#[inline]
unsafe fn set(this: *const Self) {
// Read everything we need out of `*this` up front: once the core latch
// is set below, the owning thread may return and free the latch, so
// `this` must not be dereferenced after that point.
let cross_registry;
let registry: &Registry = if (*this).cross {
// Cross-registry latch: the owner's registry is not kept alive by the
// thread calling `set` and could be torn down as soon as the latch is
// set, so clone the `Arc` to keep it alive until the call to
// `notify_worker_latch_is_set` below has run.
cross_registry = Arc::clone((*this).registry);
&cross_registry
} else {
// Same-registry latch: the thread performing `set` already keeps the
// registry alive, so a plain reference suffices.
(*this).registry
};
let target_worker_index = (*this).target_worker_index;
// Set the core latch; if the owning thread was asleep, wake it up.
if CoreLatch::set(&(*this).core_latch) {
registry.notify_worker_latch_is_set(target_worker_index);
}
}
}
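/// A latch backed by a mutex and condition variable. Unlike `SpinLatch`, it
/// supports blocking `wait()` (and `wait_and_reset()`) and does not require
/// the waiter to be a worker thread.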
#[derive(Debug)]
pub(super) struct LockLatch {
m: Mutex<bool>,
v: Condvar,
}
impl LockLatch {
#[inline]
pub(super) fn new() -> LockLatch {
LockLatch {
m: Mutex::new(false),
v: Condvar::new(),
}
}
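/// Blocks until the latch is set, then resets it to unset so it can be
/// reused.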
pub(super) fn wait_and_reset(&self) {
let mut guard = self.m.lock().unwrap();
while !*guard {
guard = self.v.wait(guard).unwrap();
}
*guard = false;
}
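/// Blocks until the latch is set.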
pub(super) fn wait(&self) {
let mut guard = self.m.lock().unwrap();
while !*guard {
guard = self.v.wait(guard).unwrap();
}
}
}
impl Latch for LockLatch {
#[inline]
unsafe fn set(this: *const Self) {
let mut guard = (*this).m.lock().unwrap();
*guard = true;
(*this).v.notify_all();
}
}
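/// A counting latch wrapping a `CoreLatch`: the core latch only becomes set
/// once the counter, which starts at the initial count, has been decremented
/// to zero via `set`.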
#[derive(Debug)]
pub(super) struct CountLatch {
core_latch: CoreLatch,
counter: AtomicUsize,
}
impl CountLatch {
#[inline]
pub(super) fn new() -> CountLatch {
Self::with_count(1)
}
#[inline]
pub(super) fn with_count(n: usize) -> CountLatch {
CountLatch {
core_latch: CoreLatch::new(),
counter: AtomicUsize::new(n),
}
}
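/// Adds one to the count. May only be called while the latch is still unset
/// (enforced with a debug assertion).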
#[inline]
pub(super) fn increment(&self) {
debug_assert!(!self.core_latch.probe());
self.counter.fetch_add(1, Ordering::Relaxed);
}
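/// Decrements the count. When it hits zero, the core latch is set and `true`
/// is returned; the caller is then responsible for any wakeups (see
/// `set_and_tickle_one`).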
#[inline]
pub(super) unsafe fn set(this: *const Self) -> bool {
if (*this).counter.fetch_sub(1, Ordering::SeqCst) == 1 {
CoreLatch::set(&(*this).core_latch);
true
} else {
false
}
}
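/// Decrements the count and, if it hits zero, sets the latch and notifies the
/// given worker in `registry`, in case it is sleeping while waiting on this
/// latch.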
#[inline]
pub(super) unsafe fn set_and_tickle_one(
this: *const Self,
registry: &Registry,
target_worker_index: usize,
) {
if Self::set(this) {
registry.notify_worker_latch_is_set(target_worker_index);
}
}
}
impl AsCoreLatch for CountLatch {
#[inline]
fn as_core_latch(&self) -> &CoreLatch {
&self.core_latch
}
}
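/// A counting latch backed by a `LockLatch` rather than a `CoreLatch`, so
/// that waiters can block on `wait()` instead of spinning.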
#[derive(Debug)]
pub(super) struct CountLockLatch {
lock_latch: LockLatch,
counter: AtomicUsize,
}
impl CountLockLatch {
#[inline]
pub(super) fn with_count(n: usize) -> CountLockLatch {
CountLockLatch {
lock_latch: LockLatch::new(),
counter: AtomicUsize::new(n),
}
}
#[inline]
pub(super) fn increment(&self) {
let old_counter = self.counter.fetch_add(1, Ordering::Relaxed);
debug_assert!(old_counter != 0);
}
pub(super) fn wait(&self) {
self.lock_latch.wait();
}
}
impl Latch for CountLockLatch {
#[inline]
unsafe fn set(this: *const Self) {
if (*this).counter.fetch_sub(1, Ordering::SeqCst) == 1 {
LockLatch::set(&(*this).lock_latch);
}
}
}
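/// Wraps a reference to a latch so that the wrapper itself implements
/// `Latch`, forwarding `set` to the referenced latch; `Deref` gives access to
/// the underlying latch. The raw pointer plus `PhantomData` preserve the
/// borrow's lifetime.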
pub(super) struct LatchRef<'a, L> {
inner: *const L,
marker: PhantomData<&'a L>,
}
impl<L> LatchRef<'_, L> {
pub(super) fn new(inner: &L) -> LatchRef<'_, L> {
LatchRef {
inner,
marker: PhantomData,
}
}
}
unsafe impl<L: Sync> Sync for LatchRef<'_, L> {}
impl<L> Deref for LatchRef<'_, L> {
type Target = L;
fn deref(&self) -> &L {
unsafe { &*self.inner }
}
}
impl<L: Latch> Latch for LatchRef<'_, L> {
#[inline]
unsafe fn set(this: *const Self) {
L::set((*this).inner);
}
}