use std::pin::{Pin, pin};
use std::ptr::{addr_of, null, with_exposed_provenance};
#[cfg(not(feature = "loom"))]
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release};
use std::thread;
#[cfg(feature = "loom")]
use loom::sync::atomic::AtomicUsize;
use crate::opcode::Opcode;
use crate::wait_queue::{Entry, WaitQueue};
/// Shared wait-queue logic for synchronization primitives whose entire state is a single
/// `AtomicUsize`.
///
/// The state word is split with the `WaitQueue` masks used throughout this trait:
/// - `ADDR_MASK`: exposed address of the anchor of the most recently pushed wait queue
///   (the tail of the waiter list), or 0 when nobody is waiting;
/// - `LOCKED_FLAG`: set while one thread owns the right to edit/process the waiter list;
/// - `DATA_MASK`: the count of currently acquired resources, which `process_wait_queue`
///   compares against `max_shared_owners()` before granting queued waiters.
///
/// A newly pushed entry links to the previous tail through its "next entry anchor", so
/// following next links from the tail reaches older entries and ends at the head (the
/// oldest entry). `process_wait_queue` grants waiters starting from the head.
pub(crate) trait SyncPrimitive: Sized {
/// Returns the packed state word of this primitive.
fn state(&self) -> &AtomicUsize;
/// Upper bound on the `DATA_MASK` count that may be reached when granting queued waiters
/// (see the admission check in `process_wait_queue`).
fn max_shared_owners() -> usize;
/// Implementor-specific hook that receives a wait-queue entry; it is not called by the
/// default methods of this trait.
/// NOTE(review): presumably invoked when an entry is dropped — confirm at the call sites.
fn drop_wait_queue_entry(entry: &Entry);
/// Returns the address of `self`, exposing its provenance.
#[inline]
fn addr(&self) -> usize {
let self_ptr: *const Self = addr_of!(*self);
self_ptr.expose_provenance()
}
/// Tries to publish `wait_queue` as the new tail of the waiter list described by `state`.
///
/// The entry is first linked to the current tail (the anchor encoded in `state`). Then a
/// single compare-exchange replaces the address bits of the state word with this queue's
/// anchor address, keeping every non-address bit of `state`.
///
/// Returns `None` once the entry is published; by then it has been marked pollable and
/// `begin_wait` has been called. Returns `Some(begin_wait)` if the state word no longer
/// equals `state`, handing the callback back so the caller can retry or give up.
#[must_use]
fn try_push_wait_queue_entry<F: FnOnce()>(
&self,
wait_queue: Pin<&WaitQueue>,
state: usize,
begin_wait: F,
) -> Option<F> {
let anchor_ptr = wait_queue.anchor_ptr().0;
let anchor_addr = anchor_ptr.expose_provenance();
// The anchor address must lie entirely inside `ADDR_MASK` so that it can share the word
// with the flag and count bits.
debug_assert_eq!(anchor_addr & (!WaitQueue::ADDR_MASK), 0);
let tail_anchor_ptr = WaitQueue::to_anchor_ptr(state);
// Link to the current tail *before* publishing: as soon as the CAS below succeeds, other
// threads may start walking the list from this entry.
wait_queue
.entry()
.update_next_entry_anchor_ptr(tail_anchor_ptr);
let next_state = (state & (!WaitQueue::ADDR_MASK)) | anchor_addr;
if self
.state()
.compare_exchange(state, next_state, AcqRel, Acquire)
.is_ok()
{
wait_queue.entry().set_pollable();
begin_wait();
None
} else {
Some(begin_wait)
}
}
/// Enqueues the calling thread and waits synchronously for its entry to be resolved.
///
/// A `WaitQueue` is pinned on this stack frame, constructed for `self` and `opcode`, and
/// pushed with `try_push_wait_queue_entry`. If the push loses a race (the state word no
/// longer equals `state`), `begin_wait` is returned in `Err` without waiting. Otherwise
/// the `u8` produced by the entry's `poll_result_sync` is returned. `process_wait_queue`
/// in this trait resolves granted entries with the value 0.
fn wait_resources_sync<F: FnOnce()>(
&self,
state: usize,
opcode: Opcode,
begin_wait: F,
) -> Result<u8, F> {
// Waiting is only expected when the observed state has queued waiters or a non-zero count.
debug_assert!(state & WaitQueue::ADDR_MASK != 0 || state & WaitQueue::DATA_MASK != 0);
let pinned_wait_queue = pin!(WaitQueue::default());
pinned_wait_queue.as_ref().construct(self, opcode, true);
if let Some(returned) =
self.try_push_wait_queue_entry(pinned_wait_queue.as_ref(), state, begin_wait)
{
return Err(returned);
}
Ok(pinned_wait_queue.entry().poll_result_sync())
}
/// Releases `opcode.acquired_count()` from the count in the state word.
///
/// `state` is the caller's last observed value; it is refreshed on every failed
/// compare-exchange. Returns `false`, without modifying anything, as soon as
/// `opcode.can_release(state)` is false. Returns `true` after a successful release.
///
/// If waiters are queued and no thread holds `LOCKED_FLAG`, the release also takes the
/// flag. This thread then runs `process_wait_queue` to hand the freed capacity to waiters.
fn release_loop(&self, mut state: usize, opcode: Opcode) -> bool {
while opcode.can_release(state) {
if state & WaitQueue::ADDR_MASK == 0
|| state & WaitQueue::LOCKED_FLAG == WaitQueue::LOCKED_FLAG
{
// No waiters, or another thread holds `LOCKED_FLAG`: only drop the count. A lock
// holder inside `process_wait_queue` notices the lower count because its unlocking
// `fetch_update` succeeds only when the count is unchanged.
match self.state().compare_exchange(
state,
state - opcode.acquired_count(),
Release,
Relaxed,
) {
Ok(_) => return true,
Err(new_state) => state = new_state,
}
} else {
// Waiters exist and the list is unlocked: release and take `LOCKED_FLAG` in one
// step, then process the waiter list as the lock holder.
let next_state = (state | WaitQueue::LOCKED_FLAG) - opcode.acquired_count();
if let Err(new_state) = self
.state()
.compare_exchange(state, next_state, AcqRel, Relaxed)
{
state = new_state;
continue;
}
self.process_wait_queue(next_state);
return true;
}
}
false
}
/// Grants queued waiters as much capacity as the count allows, then clears `LOCKED_FLAG`.
///
/// Must be called with `LOCKED_FLAG` set in `state`, i.e. by the thread that took the flag.
/// Each round of the outer loop:
/// 1. Determines the head (oldest entry). On the first round it walks next links from the
///    tail. Later rounds reuse the previous head and call `Entry::set_prev_ptr` on the
///    current tail instead.
///    NOTE(review): presumably this links back-pointers of entries pushed meanwhile — confirm.
/// 2. Walks from the head toward the tail and grants each entry while the running count
///    (the current count plus what this round has already granted) permits it. A zero
///    running count always admits the entry; otherwise `running + desired_count()` must
///    not exceed `max_shared_owners()`. The first entry that does not fit has its next
///    link cleared, which detaches the granted entries, and it becomes the new head.
/// 3. If the walk reaches and admits the tail, a single compare-exchange sets the state to
///    just the new count, clearing the address bits and `LOCKED_FLAG`. If the state changed
///    in the meantime, the tail is left ungranted, becomes the new head, and another round
///    runs.
/// 4. Otherwise, the granted count is added and `LOCKED_FLAG` cleared in one update. That
///    update succeeds only if the count is unchanged since this round read it. If the count
///    changed (a concurrent release), the granted count is added with the flag kept, and
///    another round runs.
/// 5. Granted entries receive `set_result(0)` only after the state word includes their
///    count.
fn process_wait_queue(&self, mut state: usize) {
let mut head_entry_ptr: *const Entry = null();
let mut unlocked = false;
while !unlocked {
debug_assert_eq!(state & WaitQueue::LOCKED_FLAG, WaitQueue::LOCKED_FLAG);
let anchor_ptr = WaitQueue::to_anchor_ptr(state);
let tail_entry_ptr = WaitQueue::to_entry_ptr(anchor_ptr);
if head_entry_ptr.is_null() {
// First round: follow next links from the tail until the entry without a next
// entry, which is the head.
Entry::iter_forward(tail_entry_ptr, true, |entry, next_entry| {
head_entry_ptr = Entry::ref_to_ptr(entry);
next_entry.is_none()
});
} else {
Entry::set_prev_ptr(tail_entry_ptr);
}
let data = state & WaitQueue::DATA_MASK;
// Sum of `acquired_count()` granted to non-tail entries in this round.
let mut transferred = 0;
// Most recently granted entry; the starting point for delivering results below.
let mut resolved_entry_ptr: *const Entry = null();
// Set when granting the tail failed because the state word changed.
let mut reset_failed = false;
Entry::iter_backward(head_entry_ptr, |entry, prev_entry| {
let desired = entry.opcode().desired_count();
if data + transferred == 0
|| data + transferred + desired <= Self::max_shared_owners()
{
let acquired = entry.opcode().acquired_count();
debug_assert!(acquired <= desired);
if prev_entry.is_some() {
// Not the tail: grant it and continue toward the tail.
transferred += acquired;
resolved_entry_ptr = Entry::ref_to_ptr(entry);
false
} else {
// The tail fits as well: try to empty the list and unlock in one step.
debug_assert_eq!(tail_entry_ptr, addr_of!(*entry));
if self
.state()
.compare_exchange(state, data + transferred + acquired, AcqRel, Acquire)
.is_err()
{
// The state changed (new waiter or concurrent release): detach the
// entries granted so far and restart from the tail next round.
entry.update_next_entry_anchor_ptr(null());
head_entry_ptr = Entry::ref_to_ptr(entry);
reset_failed = true;
return true;
}
unlocked = true;
resolved_entry_ptr = Entry::ref_to_ptr(entry);
true
}
} else {
// Does not fit: detach the already-granted entries behind it and keep it
// as the new head.
entry.update_next_entry_anchor_ptr(null());
head_entry_ptr = Entry::ref_to_ptr(entry);
true
}
});
debug_assert!(!reset_failed || !unlocked);
if !reset_failed && !unlocked {
// Publish the granted count and clear `LOCKED_FLAG` (the new value keeps only the
// address bits and the count), but only if the count is unchanged; otherwise
// another round is needed.
unlocked = self
.state()
.fetch_update(AcqRel, Acquire, |new_state| {
let new_data = new_state & WaitQueue::DATA_MASK;
debug_assert!(new_data <= data);
debug_assert!(new_data + transferred <= WaitQueue::DATA_MASK);
if new_data == data {
Some((new_state & WaitQueue::ADDR_MASK) | (new_data + transferred))
} else {
None
}
})
.is_ok();
}
if !unlocked {
// Still holding `LOCKED_FLAG`: account for this round's grants and loop with the
// refreshed state.
state = self.state().fetch_add(transferred, AcqRel) + transferred;
}
// Deliver results only after the granted count is reflected in the state word.
Entry::iter_forward(resolved_entry_ptr, false, |entry, _next_entry| {
entry.set_result(0);
false
});
}
}
/// Unlinks `entry_ptr_to_remove` from the waiter list; the caller must hold `LOCKED_FLAG`.
///
/// Returns `(state, removed)`, where `removed` is `false` if the entry is not in the list,
/// and `state` is the state word as left by this call:
/// - the entry had a newer neighbour: only links change, and the input `state` is returned;
/// - it was the tail with older entries: the address bits now point at the next entry's
///   wait queue, with `LOCKED_FLAG` still set;
/// - it was the only waiter: the state becomes just the count, so `LOCKED_FLAG` is cleared.
///
/// In every case except the last, the caller still holds `LOCKED_FLAG`. A failed
/// compare-exchange (new waiter or concurrent release) retries with the observed state.
fn remove_wait_queue_entry(
&self,
mut state: usize,
entry_ptr_to_remove: *const Entry,
) -> (usize, bool) {
// NOTE(review): `result` is initialised once, outside the loop. If a retry iteration
// failed to find the entry, the stale `Err` would be reused and the loop would never
// end. This relies on the entry staying in the list while `LOCKED_FLAG` is held.
let mut result = Ok((state, false));
loop {
debug_assert_eq!(state & WaitQueue::LOCKED_FLAG, WaitQueue::LOCKED_FLAG);
debug_assert_ne!(state & WaitQueue::ADDR_MASK, 0);
let anchor_ptr = WaitQueue::to_anchor_ptr(state);
let tail_entry_ptr = WaitQueue::to_entry_ptr(anchor_ptr);
Entry::iter_forward(tail_entry_ptr, true, |entry, next_entry| {
if Entry::ref_to_ptr(entry) == entry_ptr_to_remove {
let prev_entry_ptr = entry.prev_entry_ptr();
if let Some(next_entry) = next_entry {
next_entry.update_prev_entry_ptr(prev_entry_ptr);
}
// SAFETY (assumed): `prev_entry_ptr` is null or points to an entry still linked
// into this list, which cannot be unlinked concurrently while `LOCKED_FLAG` is
// held.
result = if let Some(prev_entry) = unsafe { prev_entry_ptr.as_ref() } {
// Interior or head entry: bypass it; the state word is unaffected.
prev_entry.update_next_entry_anchor_ptr(entry.next_entry_anchor_ptr());
Ok((state, true))
} else if let Some(next_entry) = next_entry {
// Tail with older entries: the next entry's wait queue becomes the tail.
// NOTE(review): presumably the expose/with_exposed_provenance round trip
// gives the container pointer exposed provenance — confirm.
let next_entry_addr = Entry::ref_to_ptr(next_entry).expose_provenance();
let next_entry_ptr = with_exposed_provenance(next_entry_addr);
let new_tail_ptr = Entry::to_wait_queue_ptr(next_entry_ptr);
// SAFETY (assumed): `next_entry` is a live linked entry, so the wait queue
// containing it is live as well.
let new_anchor_ptr = unsafe { (*new_tail_ptr).anchor_ptr().0 };
debug_assert_eq!(new_anchor_ptr.addr() & (!WaitQueue::ADDR_MASK), 0);
let next_state =
(state & (!WaitQueue::ADDR_MASK)) | new_anchor_ptr.expose_provenance();
debug_assert_eq!(
next_state & WaitQueue::LOCKED_FLAG,
WaitQueue::LOCKED_FLAG
);
self.state()
.compare_exchange(state, next_state, AcqRel, Acquire)
.map(|_| (next_state, true))
} else {
// Sole waiter: keep only the count, clearing the address and `LOCKED_FLAG`.
let next_state = state & WaitQueue::DATA_MASK;
self.state()
.compare_exchange(state, next_state, AcqRel, Acquire)
.map(|_| (next_state, true))
};
true
} else {
false
}
});
match result {
Ok((state, removed)) => return (state, removed),
Err(new_state) => state = new_state,
}
}
}
/// Withdraws `entry` from the waiter list of the primitive it belongs to.
///
/// NOTE(review): presumably used when a waiter abandons its wait before consuming the
/// result — confirm at the call sites.
///
/// The function spins (yielding) while another thread holds `LOCKED_FLAG`. It then either
/// takes the flag and unlinks the entry, or finds the waiter list empty. After unlinking,
/// it runs `process_wait_queue` if the flag is still held. If the entry was not unlinked
/// here, the only other path in this trait that detaches entries is `process_wait_queue`
/// granting them. In that case the function waits until the entry's result is finalized
/// and then gives the acquired resources back through `release_loop`.
fn force_remove_wait_queue_entry(entry: &Entry) {
let this: &Self = entry.sync_primitive_ref();
let this_ptr: *const Entry = addr_of!(*entry);
let mut state = this.state().load(Acquire);
// True when the entry was already detached (granted) rather than removed here.
let mut need_completion = false;
loop {
if state & WaitQueue::LOCKED_FLAG == WaitQueue::LOCKED_FLAG {
thread::yield_now();
state = this.state().load(Acquire);
} else if state & WaitQueue::ADDR_MASK == 0 {
// No waiters remain, so the entry has already been taken off the list.
need_completion = true;
break;
} else if let Err(new_state) = this.state().compare_exchange(
state,
state | WaitQueue::LOCKED_FLAG,
AcqRel,
Acquire,
) {
state = new_state;
} else {
let (new_state, removed) =
this.remove_wait_queue_entry(state | WaitQueue::LOCKED_FLAG, this_ptr);
// Still holding `LOCKED_FLAG`: process the remaining waiters, which also
// releases the flag.
if new_state & WaitQueue::LOCKED_FLAG == WaitQueue::LOCKED_FLAG {
this.process_wait_queue(new_state);
}
if !removed {
need_completion = true;
}
break;
}
}
if need_completion {
while !entry.result_finalized() {
thread::yield_now();
}
// NOTE(review): `state` is the last value observed above and may be stale.
// `release_loop` refreshes it on a failed compare-exchange, but its initial
// `can_release` check uses this value. The returned `bool` is ignored.
this.release_loop(state, entry.opcode());
}
}
}