use std::{
cell::UnsafeCell,
convert::Infallible,
future::Future,
panic::{RefUnwindSafe, UnwindSafe},
pin::Pin,
ptr,
sync::atomic::{AtomicPtr, AtomicUsize, Ordering},
sync::Mutex,
task,
};
#[cfg(feature = "unpin")]
pub mod unpin;
/// A cell that is written at most once, with initialization driven by an async future.
///
/// `value` is only read through `&self` after `inner` reports `READY_BIT`; until then,
/// access is coordinated by the atomic state and wait queue in `inner`.
#[derive(Debug)]
pub struct OnceCell<T> {
// The stored value: `Some` after a completed initialization (or `new_with(Some(..))`).
value: UnsafeCell<Option<T>>,
// Synchronization state: ready/quick-init flags, reference count and lazily-allocated queue.
inner: Inner,
}
// SAFETY(review): sharing the cell lets one thread write the value during initialization and
// other threads obtain `&T`, which is why `Sync` requires both `T: Sync` and `T: Send`.
unsafe impl<T: Sync + Send> Sync for OnceCell<T> {}
unsafe impl<T: Send> Send for OnceCell<T> {}
// Unpin regardless of `T`: none of the `OnceCell` methods in this file hand out a pinned `T`.
impl<T> Unpin for OnceCell<T> {}
impl<T: RefUnwindSafe + UnwindSafe> RefUnwindSafe for OnceCell<T> {}
impl<T: UnwindSafe> UnwindSafe for OnceCell<T> {}
/// Shared synchronization state for `OnceCell` and `Lazy`.
#[derive(Debug)]
struct Inner {
// Bit-packed: `READY_BIT` (top bit) is set once the value is published, `QINIT_BIT`
// (next bit) is set while a quick initializer runs, and the remaining low bits count the
// live `QueueRef`s (incremented in `initialize`, decremented in `QueueRef::drop`).
state: AtomicUsize,
// Wait queue; null until first needed by `initialize`, and reset to null when freed.
queue: AtomicPtr<Queue>,
}
/// Heap-allocated wait queue shared by tasks that could not take the quick path.
struct Queue {
// `None`: no waker list is active; a poller that sees this with `QINIT_BIT` clear becomes
// the queue head. `Some`: wakers of tasks waiting for the current initializer.
wakers: Mutex<Option<Vec<task::Waker>>>,
}
/// A counted reference to `Inner`'s queue.
///
/// While one exists, the reference-count bits of `inner.state` are nonzero. `queue` may be
/// null if the cell was already ready when the reference was taken.
struct QueueRef<'a> {
inner: &'a Inner,
queue: *const Queue,
}
// SAFETY(review): the raw pointer only suppresses the auto traits. It points at a `Queue`,
// whose only field is a `Mutex`, and `inner` is a shared reference to atomics.
unsafe impl<'a> Sync for QueueRef<'a> {}
unsafe impl<'a> Send for QueueRef<'a> {}
/// Held by the task that claimed `QINIT_BIT` on the quick path (state was exactly `NEW`).
/// Its drop clears `QINIT_BIT` and wakes any tasks that queued up in the meantime.
#[derive(Debug)]
struct QuickInitGuard<'a>(&'a Inner);
/// Future returned by `Inner::initialize` on the queued path.
/// It resolves to `Some(QueueHead)` when this task should run the initializer, and to
/// `None` once the value is ready.
struct QueueWaiter<'a> {
// Taken (set to `None`) when the future resolves to `Some(QueueHead)`.
guard: Option<QueueRef<'a>>,
}
/// Marks the task currently responsible for initialization; its drop wakes all waiters.
struct QueueHead<'a> {
guard: QueueRef<'a>,
}
// Initial state: no flags set and no live `QueueRef`s.
const NEW: usize = 0x0;
// Second-highest bit: a quick (queue-less) initializer is running.
const QINIT_BIT: usize = 1 + (usize::MAX >> 2);
// Highest bit: the value has been written and may be read.
const READY_BIT: usize = 1 + (usize::MAX >> 1);
impl Inner {
    /// Creates the state of an uninitialized cell: no flags, no references, no queue.
    const fn new() -> Self {
        Inner { state: AtomicUsize::new(NEW), queue: AtomicPtr::new(ptr::null_mut()) }
    }

    /// Creates the state of a cell that already holds its value (`READY_BIT` set).
    const fn new_ready() -> Self {
        Inner { state: AtomicUsize::new(READY_BIT), queue: AtomicPtr::new(ptr::null_mut()) }
    }

    /// Starts or joins initialization.
    ///
    /// When `try_quick` is true and the state is exactly `NEW`, this claims `QINIT_BIT` and
    /// returns `Err(QuickInitGuard)`. The caller then initializes without allocating a queue.
    ///
    /// Otherwise it registers a `QueueRef` (incrementing the reference count in `state`) and
    /// returns `Ok(QueueWaiter)` for the caller to await. If no queue exists and the cell is
    /// not ready, it allocates the shared `Queue` first.
    #[cold]
    fn initialize(&self, try_quick: bool) -> Result<QueueWaiter, QuickInitGuard> {
        if try_quick
            && self
                .state
                .compare_exchange(NEW, QINIT_BIT, Ordering::Acquire, Ordering::Relaxed)
                .is_ok()
        {
            return Err(QuickInitGuard(self));
        }

        // Take a reference before loading the queue pointer. Deallocation of the queue is
        // tied to the reference count dropping to zero after READY_BIT is set.
        let prev_state = self.state.fetch_add(1, Ordering::Acquire);
        let mut guard = QueueRef { inner: self, queue: self.queue.load(Ordering::Acquire) };
        if guard.queue.is_null() && prev_state & READY_BIT == 0 {
            // Not ready and no queue yet: allocate one and race to publish it.
            let new_queue = Box::into_raw(Box::new(Queue { wakers: Mutex::new(None) }));
            match self.queue.compare_exchange(
                ptr::null_mut(),
                new_queue,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_null) => guard.queue = new_queue,
                Err(actual) => {
                    // Another task published a queue first: use theirs and free ours.
                    guard.queue = actual;
                    // SAFETY: `new_queue` came from `Box::into_raw` above and was never shared.
                    drop(unsafe { Box::from_raw(new_queue) });
                }
            }
        }
        Ok(QueueWaiter { guard: Some(guard) })
    }

    /// Publishes the value by setting `READY_BIT` with `Release` ordering.
    ///
    /// The caller must have written the value before calling this.
    fn set_ready(&self) {
        let prev_state = self.state.fetch_or(READY_BIT, Ordering::Release);
        debug_assert_eq!(prev_state & READY_BIT, 0, "Invalid state: someone else set READY_BIT");
    }
}
impl<'a> Drop for QueueRef<'a> {
    /// Releases this reference. The last reference to go after the value is ready, with no
    /// quick initializer still active, frees the queue.
    fn drop(&mut self) {
        let prev_state = self.inner.state.fetch_sub(1, Ordering::Release);
        let curr_state = prev_state - 1;
        // Free only in the exact state READY_BIT: ready, no references, and QINIT_BIT clear.
        //
        // While QINIT_BIT is still set, the quick initializer's guard may concurrently take a
        // new reference, load the (still non-null) queue pointer and then dereference it.
        // Freeing in the `READY_BIT | QINIT_BIT` state would race with that and leave the
        // guard with a dangling pointer.
        //
        // Once the guard clears QINIT_BIT, the queue is freed either by the guard itself or by
        // the last reference reaching READY_BIT; otherwise `Inner`'s drop frees it.
        if curr_state == READY_BIT {
            let queue = self.inner.queue.swap(ptr::null_mut(), Ordering::Acquire);
            if !queue.is_null() {
                // SAFETY: the swap transferred sole ownership of this `Box::into_raw`
                // allocation to us; any other path swapping it out will observe null.
                drop(unsafe { Box::from_raw(queue) });
            }
        }
    }
}
impl<'a> Drop for QuickInitGuard<'a> {
    /// Ends a quick (queue-less) initialization: clears `QINIT_BIT` and wakes any tasks that
    /// queued up while it was running.
    fn drop(&mut self) {
        let prev_state = self.0.state.load(Ordering::Relaxed);
        // Fast path, only when the value is ready and no `QueueRef` exists.
        //
        // A failed or cancelled initialization deliberately takes the slow path below, even
        // when no references are currently alive. Waiters that registered during this quick
        // init and were then dropped may have left `Some(wakers)` in the queue, and only
        // `QueueHead`'s drop resets that list to `None`.
        //
        // If the list were left behind, a task that later joins via `initialize(false)` would
        // push its waker onto it and never be woken.
        if prev_state == QINIT_BIT | READY_BIT
            && self
                .0
                .state
                .compare_exchange(prev_state, READY_BIT, Ordering::Relaxed, Ordering::Relaxed)
                .is_ok()
        {
            let queue = self.0.queue.swap(ptr::null_mut(), Ordering::Relaxed);
            if !queue.is_null() {
                std::sync::atomic::fence(Ordering::Acquire);
                // SAFETY: the swap transferred sole ownership of this `Box::into_raw`
                // allocation to us; no reference can dereference it once READY_BIT is set.
                drop(unsafe { Box::from_raw(queue) });
            }
            return;
        }

        // Slow path: become the queue head we never had, then drop it to wake waiters.
        let waiter = self.0.initialize(false).expect("Got a QuickInitGuard in slow init");
        let guard = waiter.guard.expect("No guard available even without polling");
        if guard.queue.is_null() {
            // `initialize` allocates a queue whenever READY_BIT is clear, so a null queue means
            // the value is ready and no waiter ever parked on a queue.
            drop(guard);
        } else {
            // SAFETY: `guard` holds a reference for as long as this borrow is used.
            // NOTE(review): soundness also requires that `QueueRef::drop` never frees the queue
            // while QINIT_BIT is still set (state `READY_BIT | QINIT_BIT`); confirm.
            let queue = unsafe { &*guard.queue };
            let mut lock = queue.wakers.lock().unwrap();
            // `QueueHead`'s drop expects a waker list. Clear QINIT_BIT while holding the lock
            // so pollers never see `None` together with QINIT_BIT set after this point.
            lock.get_or_insert_with(Vec::new);
            self.0.state.fetch_and(!QINIT_BIT, Ordering::Relaxed);
            drop(lock);
            // Wakes every parked task and resets the list to `None`.
            drop(QueueHead { guard })
        }
    }
}
impl Drop for Inner {
    /// Frees the wait queue if one is still allocated.
    fn drop(&mut self) {
        // `&mut self` rules out any live `QueueRef` or guard, which all borrow `&Inner`.
        let queue = *self.queue.get_mut();
        if !queue.is_null() {
            // SAFETY: non-null queue pointers come from `Box::into_raw` in `initialize`, and
            // exclusive access means nobody else can free it concurrently.
            drop(unsafe { Box::from_raw(queue) });
        }
    }
}
/// Waiting for initialization via the queue.
///
/// Resolves to `None` once `READY_BIT` is observed. Resolves to `Some(QueueHead)` when this
/// task should run the initializer itself.
impl<'a> Future for QueueWaiter<'a> {
type Output = Option<QueueHead<'a>>;
fn poll(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> task::Poll<Option<QueueHead<'a>>> {
let guard = self.guard.as_ref().expect("Polled future after finished");
// Lock-free check: if the value is already published, the queue is not needed.
let state = guard.inner.state.load(Ordering::Acquire);
if state & READY_BIT != 0 {
return task::Poll::Ready(None);
}
// SAFETY: READY_BIT was clear, so `initialize` saw it clear too (the bit is never
// removed) and left a non-null queue in `guard`; our reference keeps it alive.
let queue = unsafe { &*guard.queue };
let mut lock = queue.wakers.lock().unwrap();
// Re-check under the lock: initialization may have completed before we acquired it.
let state = guard.inner.state.load(Ordering::Acquire);
if state & READY_BIT != 0 {
return task::Poll::Ready(None);
}
match lock.as_mut() {
// No active waker list and no quick initializer: this task becomes the head. An
// empty list is installed so other pollers queue behind us.
None if state & QINIT_BIT == 0 => {
*lock = Some(vec![]);
drop(lock);
task::Poll::Ready(Some(QueueHead { guard: self.guard.take().unwrap() }))
}
// A quick initializer is running without a list: start one with our waker.
None => {
let waker = cx.waker().clone();
*lock = Some(vec![waker]);
task::Poll::Pending
}
// Someone is initializing: register our waker unless an equivalent one is present.
Some(wakers) => {
let my_waker = cx.waker();
for waker in wakers.iter() {
if waker.will_wake(my_waker) {
return task::Poll::Pending;
}
}
wakers.push(my_waker.clone());
task::Poll::Pending
}
}
}
}
impl<'a> Drop for QueueHead<'a> {
fn drop(&mut self) {
if let Some(queue) = unsafe { self.guard.queue.as_ref() } {
let wakers = queue
.wakers
.lock()
.expect("Lock poisoned")
.take()
.expect("QueueHead dropped without a waker list");
for waker in wakers {
waker.wake();
}
}
}
}
impl<T> OnceCell<T> {
    /// Creates a new, uninitialized cell.
    pub const fn new() -> Self {
        Self { value: UnsafeCell::new(None), inner: Inner::new() }
    }

    /// Creates a cell that starts initialized with `value` if it is `Some`, or uninitialized
    /// if it is `None`.
    pub const fn new_with(value: Option<T>) -> Self {
        let inner = match value {
            Some(_) => Inner::new_ready(),
            None => Inner::new(),
        };
        Self { value: UnsafeCell::new(value), inner }
    }

    /// Returns a reference to the value, running `init` to produce it if the cell is empty.
    ///
    /// Concurrent callers are serialized: at most one initializer runs at a time, and the
    /// others wait. If the running initializer is cancelled, a waiting caller runs its own.
    pub async fn get_or_init(&self, init: impl Future<Output = T>) -> &T {
        match self.get_or_try_init(async move { Ok::<T, Infallible>(init.await) }).await {
            Ok(t) => t,
            Err(e) => match e {},
        }
    }

    /// Fallible version of [`get_or_init`](Self::get_or_init).
    ///
    /// # Errors
    ///
    /// Returns the error from `init` if this call ran it and it failed. The cell is left
    /// uninitialized, and a waiting caller (if any) gets to run its own initializer.
    pub async fn get_or_try_init<E>(
        &self,
        init: impl Future<Output = Result<T, E>>,
    ) -> Result<&T, E> {
        let state = self.inner.state.load(Ordering::Acquire);
        if state & READY_BIT == 0 {
            self.init_slow(state == NEW, init).await?;
        }
        // SAFETY: READY_BIT has been observed (here or inside `init_slow`) or set by this task,
        // so the value has been written and is no longer mutated through `&self`.
        Ok(unsafe { (&*self.value.get()).as_ref().unwrap() })
    }

    /// Slow path of `get_or_try_init`: runs `init` as the sole initializer, or waits for
    /// another task's initialization to finish.
    #[cold]
    async fn init_slow<E>(
        &self,
        try_quick: bool,
        init: impl Future<Output = Result<T, E>>,
    ) -> Result<(), E> {
        match self.inner.initialize(try_quick) {
            Err(guard) => {
                // Quick path: we own QINIT_BIT, so we are the only writer.
                let value = init.await?;
                // SAFETY: no other task writes the value while we hold the quick-init guard,
                // and readers only look at it after `set_ready` below.
                unsafe {
                    *self.value.get() = Some(value);
                }
                self.inner.set_ready();
                drop(guard);
            }
            Ok(guard) => {
                // `None` means another task completed initialization while we waited.
                if let Some(init_lock) = guard.await {
                    let value = init.await?;
                    // SAFETY: as queue head we are the only writer, and readers only look at
                    // the value after `set_ready` below.
                    unsafe {
                        *self.value.get() = Some(value);
                    }
                    init_lock.guard.inner.set_ready();
                    // `init_lock` drops here, waking every waiting task.
                }
            }
        }
        Ok(())
    }

    /// Returns the value if the cell has been initialized, without waiting.
    pub fn get(&self) -> Option<&T> {
        let state = self.inner.state.load(Ordering::Acquire);
        if state & READY_BIT == 0 {
            None
        } else {
            // SAFETY: READY_BIT observed with Acquire, so the write of the value is visible.
            unsafe { (&*self.value.get()).as_ref() }
        }
    }

    /// Returns a mutable reference to the value, if present.
    pub fn get_mut(&mut self) -> Option<&mut T> {
        self.value.get_mut().as_mut()
    }

    /// Takes the value out of the cell, leaving it uninitialized.
    ///
    /// A later `get_or_init`/`get_or_try_init` runs its initializer again instead of
    /// panicking on the now-missing value.
    pub fn take(&mut self) -> Option<T> {
        // `&mut self` means no task is mid-initialization (they all borrow `&self`), so the
        // synchronization state can be reset wholesale. Dropping the old `Inner` frees its queue.
        self.inner = Inner::new();
        self.value.get_mut().take()
    }

    /// Consumes the cell and returns its value, if any.
    pub fn into_inner(self) -> Option<T> {
        self.value.into_inner()
    }
}
/// Contents of a `Lazy`: the pending future, or its output once it has completed.
#[derive(Debug)]
enum LazyState<T, F> {
// The future has not produced its value yet. It is polled in place (pinned) by `init_slow`.
Running(F),
// The future's output.
Ready(T),
}
/// A value computed on first access by driving a stored future to completion.
#[derive(Debug)]
pub struct Lazy<T, F> {
// Either the pending future or its output; only read as `Ready` after READY_BIT is set.
value: UnsafeCell<LazyState<T, F>>,
// Synchronization state shared with the `OnceCell` implementation.
inner: Inner,
}
// SAFETY(review): a shared `Lazy` lets one thread poll `F` in place and write `T` while other
// threads later read `&T`, hence the `Sync + Send` bounds on both parameters.
unsafe impl<T: Sync + Send, F: Sync + Send> Sync for Lazy<T, F> {}
unsafe impl<T: Send, F: Send> Send for Lazy<T, F> {}
impl<T: Unpin, F: Unpin> Unpin for Lazy<T, F> {}
impl<T: RefUnwindSafe + UnwindSafe, F: RefUnwindSafe + UnwindSafe> RefUnwindSafe for Lazy<T, F> {}
impl<T: UnwindSafe, F: UnwindSafe> UnwindSafe for Lazy<T, F> {}
impl<T, F> Lazy<T, F>
where
    F: Future<Output = T>,
{
    /// Creates a lazy value that will be produced by `future` on first access.
    pub fn new(future: F) -> Self {
        Self::from_future(future)
    }

    /// Returns a pinned reference to the value, driving the stored future first if needed.
    pub async fn get(self: Pin<&Self>) -> Pin<&T> {
        let state = self.inner.state.load(Ordering::Acquire);
        let already_ready = state & READY_BIT != 0;
        if !already_ready {
            self.init_slow(state == NEW).await;
        }
        // SAFETY: READY_BIT has been observed or set by this task by now, so the slot holds
        // `Ready` and is not mutated through a shared reference anymore.
        let slot = unsafe { &*self.value.get() };
        match slot {
            LazyState::Ready(v) => unsafe { Pin::new_unchecked(v) },
            _ => unreachable!(),
        }
    }

    /// Slow path of `get`: drives the future as the sole initializer, or waits for the task
    /// that is doing so.
    #[cold]
    async fn init_slow(self: Pin<&Self>, try_quick: bool) {
        match self.inner.initialize(try_quick) {
            Err(quick_guard) => {
                self.run_future().await;
                self.inner.set_ready();
                drop(quick_guard);
            }
            Ok(waiter) => match waiter.await {
                Some(head) => {
                    self.run_future().await;
                    head.guard.inner.set_ready();
                    // `head` drops at the end of this arm, waking the queued tasks.
                }
                // Another task finished initialization while we waited.
                None => {}
            },
        }
    }

    /// Polls the stored future in place until it completes, then stores its output.
    ///
    /// Only called by the task that currently owns initialization.
    async fn run_future(self: Pin<&Self>) {
        // SAFETY: the caller is the only initializer, so nobody else accesses the slot.
        // The future lives inside a pinned `Lazy` and is never moved, so re-pinning is sound.
        let fut = unsafe {
            match &mut *self.value.get() {
                LazyState::Running(f) => Pin::new_unchecked(f),
                _ => unreachable!(),
            }
        };
        let output = fut.await;
        // SAFETY: still the only initializer; readers wait for READY_BIT, set by the caller.
        unsafe {
            *self.value.get() = LazyState::Ready(output);
        }
    }
}
impl<T, F> Lazy<T, F> {
    /// Creates a lazy value whose contents will be produced by `future` on first access.
    pub const fn from_future(future: F) -> Self {
        Self { value: UnsafeCell::new(LazyState::Running(future)), inner: Inner::new() }
    }

    /// Creates a lazy value that is already initialized with `value`.
    pub const fn with_value(value: T) -> Self {
        Self { value: UnsafeCell::new(LazyState::Ready(value)), inner: Inner::new_ready() }
    }

    /// Returns the value if it has already been computed, without waiting.
    pub fn try_get(&self) -> Option<&T> {
        let ready = self.inner.state.load(Ordering::Acquire) & READY_BIT != 0;
        if !ready {
            return None;
        }
        // SAFETY: READY_BIT observed with Acquire, so the `Ready` write is visible.
        match unsafe { &*self.value.get() } {
            LazyState::Ready(v) => Some(v),
            _ => unreachable!(),
        }
    }

    /// Returns a pinned mutable reference to the value, if it has been computed.
    pub fn try_get_mut(self: Pin<&mut Self>) -> Option<Pin<&mut T>> {
        // SAFETY: nothing is moved out; the reference to `T` is immediately re-pinned.
        let this = unsafe { self.get_unchecked_mut() };
        if let LazyState::Ready(v) = this.value.get_mut() {
            Some(unsafe { Pin::new_unchecked(v) })
        } else {
            None
        }
    }

    /// Returns a mutable reference to the value, if it has been computed (unpinned access).
    pub fn try_get_mut_unpin(&mut self) -> Option<&mut T> {
        if let LazyState::Ready(v) = self.value.get_mut() {
            Some(v)
        } else {
            None
        }
    }

    /// Consumes the `Lazy` and returns the value, or `None` if the future never completed.
    pub fn into_inner(self) -> Option<T> {
        if let LazyState::Ready(v) = self.value.into_inner() {
            Some(v)
        } else {
            None
        }
    }
}