use core::cell::UnsafeCell;
use core::ffi::c_void;
use core::fmt;
use core::panic::{RefUnwindSafe, UnwindSafe};
use core::ptr::NonNull;
use core::sync::atomic::{AtomicIsize, Ordering};
use crate::generated::dispatch_once_t;
/// A primitive that runs a closure at most once, backed by libdispatch's
/// `dispatch_once_t` predicate.
///
/// The predicate starts out as `0` (see [`DispatchOnce::new`]). Each call to
/// [`DispatchOnce::call_once`] hands it to `dispatch_once_f`. After one closure
/// has run, later calls do not run their closures; those closures are dropped
/// instead.
///
/// This serves a similar purpose to [`std::sync::Once`].
#[cfg_attr(not(feature = "std"), doc = "[`std::sync::Once`]: #std-not-enabled")]
#[doc(alias = "dispatch_once_t")]
pub struct DispatchOnce {
    /// The raw `dispatch_once_t` predicate.
    ///
    /// It is wrapped in `UnsafeCell` because `call_once` takes `&self` yet
    /// passes a mutable pointer to it (`UnsafeCell::get`) into libdispatch.
    predicate: UnsafeCell<dispatch_once_t>,
}
/// `extern "C"` trampoline passed to `dispatch_once_f` as its work function.
///
/// `context` must point to an `Option<F>` that currently holds `Some`. The
/// closure is moved out of it, leaving `None` behind, and is then called.
extern "C" fn invoke_closure<F>(context: *mut c_void)
where
    F: FnOnce(),
{
    // SAFETY: `invoke_dispatch_once` passes the address of an `Option<F>`
    // living on its own stack frame. That frame is not otherwise touched while
    // libdispatch runs this callback, so a unique borrow is sound.
    let slot: &mut Option<F> = unsafe { &mut *context.cast::<Option<F>>() };

    // SAFETY: the slot is created as `Some(..)`. This relies on libdispatch
    // invoking the callback at most once per `dispatch_once_f` call, so the
    // value has not been taken yet.
    let work: F = unsafe { slot.take().unwrap_unchecked() };

    work();
}
/// Slow path of `DispatchOnce::call_once`: hands `closure` to
/// `DispatchOnce::once_f` for the given `predicate`.
///
/// The closure is parked in an `Option` on this stack frame, and its address
/// becomes the `context` argument for the `invoke_closure::<F>` trampoline.
/// The trampoline moves the closure out if it gets called. If it is never
/// called, the closure is dropped when this function returns.
#[cfg_attr(
    // DISPATCH_ONCE_INLINE_FASTPATH, see DispatchOnce::call_once below.
    // On these targets this function is only reached after the inline check
    // fails, so it is kept out of line and marked cold.
    any(target_arch = "x86", target_arch = "x86_64", target_vendor = "apple"),
    cold,
    inline(never)
)]
fn invoke_dispatch_once<F>(predicate: NonNull<dispatch_once_t>, closure: F)
where
    F: FnOnce(),
{
    let mut slot: Option<F> = Some(closure);
    let context = (&mut slot as *mut Option<F>).cast::<c_void>();

    // SAFETY: `predicate` is derived from the `UnsafeCell` owned by a live
    // `DispatchOnce` (see `call_once`). `context` points at `slot`, which
    // outlives this call and has exactly the `Option<F>` type that
    // `invoke_closure::<F>` expects.
    unsafe { DispatchOnce::once_f(predicate, context, invoke_closure::<F>) };
}
impl DispatchOnce {
    /// Creates a `DispatchOnce` whose closure has not run yet.
    ///
    /// This is a `const fn`, so it can be used to initialize a `static`.
    #[inline]
    // This type does not implement `Default`. NOTE(review): presumably that is
    // deliberate (construction goes through this `const fn`); confirm before
    // adding one.
    #[allow(clippy::new_without_default)]
    pub const fn new() -> Self {
        Self { predicate: UnsafeCell::new(0) }
    }

    /// Runs `work` unless a closure has already been run through this
    /// `DispatchOnce`.
    ///
    /// If a closure has already run, `work` is not called and is dropped
    /// before this method returns.
    ///
    /// Calling `call_once` on the same `DispatchOnce` from inside `work`
    /// traps the process.
    ///
    /// NOTE(review): `work` is invoked from an `extern "C"` trampoline, so a
    /// panic in it cannot unwind back to the caller. Confirm the resulting
    /// behavior (abort) for the crate's MSRV.
    #[inline]
    #[doc(alias = "dispatch_once")]
    #[doc(alias = "dispatch_once_f")]
    pub fn call_once<F>(&self, work: F)
    where
        F: FnOnce(),
    {
        // `UnsafeCell::get` never returns null, so this `unwrap` cannot fail.
        let predicate = NonNull::new(self.predicate.get()).unwrap();

        // DISPATCH_ONCE_INLINE_FASTPATH: on these targets the predicate is
        // inspected inline so completed calls can skip libdispatch entirely.
        let inline_fast_path = cfg!(any(
            target_arch = "x86",
            target_arch = "x86_64",
            target_vendor = "apple"
        ));

        if inline_fast_path {
            // SAFETY: `predicate` points at the predicate owned by `self`,
            // which is live for this borrow. It is read through an
            // `AtomicIsize` so the load cannot tear against a concurrent write
            // from libdispatch.
            // NOTE(review): assumes `dispatch_once_t` has the size and
            // alignment of `isize` on these targets; confirm against the
            // generated bindings.
            let done_flag: &AtomicIsize = unsafe { predicate.cast().as_ref() };

            // All bits set (`!0`) marks a completed predicate. `Acquire` is
            // presumably meant to pair with a release store inside libdispatch,
            // making the closure's effects visible here.
            if done_flag.load(Ordering::Acquire) == !0 {
                return;
            }
        }

        invoke_dispatch_once(predicate, work);
    }
}
// SAFETY: the predicate is only accessed through `DispatchOnce::once_f` (the
// `dispatch_once_f` primitive, which is designed for concurrent callers on a
// shared predicate; confirm against the libdispatch documentation) and, on the
// fast-path targets, through an atomic load in `call_once`. No other state is
// stored in the type.
unsafe impl Send for DispatchOnce {}
unsafe impl Sync for DispatchOnce {}
// `UnsafeCell` does not implement `RefUnwindSafe`, so that trait has to be
// asserted manually. `UnwindSafe` is also stated explicitly here.
// NOTE(review): soundness here relies on a panic in the user closure never
// leaving the predicate observably half-updated. The closure runs inside an
// `extern "C"` trampoline, so verify the unwinding behavior for the MSRV.
impl UnwindSafe for DispatchOnce {}
impl RefUnwindSafe for DispatchOnce {}
impl fmt::Debug for DispatchOnce {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The predicate's state is deliberately not read or shown. With no
        // fields recorded, this renders as `DispatchOnce { .. }`.
        let mut builder = f.debug_struct("DispatchOnce");
        builder.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use core::cell::Cell;
    use core::mem::ManuallyDrop;

    use super::*;

    /// A `static` `DispatchOnce` only runs the first closure it is given.
    #[test]
    fn test_static() {
        static ONCE: DispatchOnce = DispatchOnce::new();
        let mut num = 0;
        ONCE.call_once(|| num += 1);
        ONCE.call_once(|| num += 1);
        assert_eq!(num, 1);
    }

    /// Repeated calls on a local `DispatchOnce` run the closure exactly once.
    #[test]
    fn test_in_loop() {
        let once = DispatchOnce::new();
        let mut call_count = 0;
        for _ in 0..10 {
            once.call_once(|| call_count += 1);
        }
        assert_eq!(call_count, 1);
    }

    /// The "already run" state survives moving the `DispatchOnce`, including
    /// rebuilding it from its raw predicate value.
    #[test]
    fn test_move() {
        let once = DispatchOnce::new();
        let mut call_count = 0;
        for _ in 0..10 {
            once.call_once(|| call_count += 1);
        }

        #[allow(clippy::redundant_locals)]
        let once = once;
        for _ in 0..10 {
            once.call_once(|| call_count += 1);
        }

        let once = DispatchOnce {
            predicate: UnsafeCell::new(once.predicate.into_inner()),
        };
        for _ in 0..10 {
            once.call_once(|| call_count += 1);
        }

        assert_eq!(call_count, 1);
    }

    /// Racing callers on several threads still run exactly one closure.
    #[test]
    #[cfg(feature = "std")]
    fn test_threaded() {
        let once = DispatchOnce::new();
        let num = AtomicIsize::new(0);

        std::thread::scope(|scope| {
            for _ in 0..3 {
                scope.spawn(|| {
                    once.call_once(|| {
                        num.fetch_add(1, Ordering::Relaxed);
                    });
                });
            }
        });

        // Joining the scope orders all thread effects before this load.
        assert_eq!(num.load(Ordering::Relaxed), 1);
    }

    /// Increments the shared counter when dropped.
    #[derive(Clone)]
    struct DropTest<'a>(&'a Cell<usize>);

    impl Drop for DropTest<'_> {
        fn drop(&mut self) {
            self.0.set(self.0.get() + 1);
        }
    }

    /// The closure's captures are dropped exactly once, whether or not it runs.
    #[test]
    fn test_drop_in_closure() {
        let amount_of_drops = Cell::new(0);
        let once = DispatchOnce::new();

        // Runs: the capture is dropped inside the closure body.
        let tester = DropTest(&amount_of_drops);
        once.call_once(move || {
            let _tester = tester;
        });
        assert_eq!(amount_of_drops.get(), 1);

        // Does not run: the unused closure (and its capture) is still dropped.
        let tester = DropTest(&amount_of_drops);
        once.call_once(move || {
            let _tester = tester;
        });
        assert_eq!(amount_of_drops.get(), 2);
    }

    /// A capture leaked inside a run closure is not dropped, while the capture
    /// of a closure that never runs is.
    #[test]
    fn test_drop_in_closure_with_leak() {
        let amount_of_drops = Cell::new(0);
        let once = DispatchOnce::new();

        let tester = DropTest(&amount_of_drops);
        once.call_once(move || {
            let _tester = ManuallyDrop::new(tester);
        });
        assert_eq!(amount_of_drops.get(), 0);

        let tester = DropTest(&amount_of_drops);
        once.call_once(move || {
            let _tester = ManuallyDrop::new(tester);
        });
        assert_eq!(amount_of_drops.get(), 1);
    }

    /// Re-entering `call_once` on the same instance from inside the closure
    /// traps, so this test is ignored by default.
    #[test]
    #[ignore = "traps the process (as expected)"]
    fn test_recursive_invocation() {
        let once = DispatchOnce::new();
        once.call_once(|| {
            once.call_once(|| {});
        });
    }
}