use pin_project_lite::pin_project;
use std::cell::RefCell;
use std::error::Error;
use std::future::Future;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, thread};
/// Declares one or more task-local keys.
///
/// Each declaration has the form `static NAME: Type`, optionally preceded by
/// attributes and a visibility. Declarations are separated by `;`, and the
/// `;` after the last declaration may be omitted. Every declaration is handed
/// to `__task_local_inner!`, which emits the actual `static`.
#[macro_export]
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
macro_rules! task_local {
// Base case: nothing left to declare.
() => {};
// One declaration terminated by `;`, followed by zero or more further
// declarations: emit this one, then recurse on the remainder.
($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty; $($rest:tt)*) => {
$crate::__task_local_inner!($(#[$attr])* $vis $name, $t);
$crate::task_local!($($rest)*);
};
// Final declaration written without a trailing `;`.
($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty) => {
$crate::__task_local_inner!($(#[$attr])* $vis $name, $t);
}
}
/// Implementation detail of `task_local!`; hidden from the documentation.
///
/// Expands to a single `static` named `$name` of type
/// `$crate::task::LocalKey<$t>`. The static is backed by a private
/// `thread_local!` cell holding `RefCell<Option<$t>>`, initialised to `None`.
///
/// Any attributes supplied with the declaration (doc comments, `#[allow(..)]`,
/// `#[cfg(..)]`, ...) are applied to the generated static.
#[doc(hidden)]
#[macro_export]
macro_rules! __task_local_inner {
    ($(#[$attr:meta])* $vis:vis $name:ident, $t:ty) => {
        // Forward the caller's attributes onto the generated item. Without
        // this repetition the captured attributes would be matched and then
        // silently discarded.
        $(#[$attr])*
        $vis static $name: $crate::task::LocalKey<$t> = {
            std::thread_local! {
                static __KEY: std::cell::RefCell<Option<$t>> = std::cell::RefCell::new(None);
            }
            $crate::task::LocalKey { inner: __KEY }
        };
    };
}
/// A key for task-local data of type `T`.
///
/// The current value lives in a thread-local `RefCell<Option<T>>`; `None`
/// means no value is currently set. Instances are created by the
/// `task_local!` macro.
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
pub struct LocalKey<T: 'static> {
// Public only so that the `__task_local_inner!` expansion can build the
// struct literal; hidden from the generated documentation.
#[doc(hidden)]
pub inner: thread::LocalKey<RefCell<Option<T>>>,
}
impl<T: 'static> LocalKey<T> {
    /// Wraps the future `f` so that `value` is installed in this key whenever
    /// the returned future is polled.
    ///
    /// The returned [`TaskLocalFuture`] swaps `value` into the key's storage
    /// for the duration of each poll of `f` and swaps the previous contents
    /// back afterwards.
    pub fn scope<F>(&'static self, value: T, f: F) -> TaskLocalFuture<T, F>
    where
        F: Future,
    {
        TaskLocalFuture {
            local: self,
            slot: Some(value),
            future: f,
            _pinned: PhantomPinned,
        }
    }

    /// Runs the closure `f` with `value` installed in this key and returns
    /// whatever `f` returns.
    ///
    /// The previous contents of the key are restored when `f` finishes.
    pub fn sync_scope<F, R>(&'static self, value: T, f: F) -> R
    where
        F: FnOnce() -> R,
    {
        // Reuse the swap-in / swap-out logic of `TaskLocalFuture` by building
        // one around a unit "future" that is never polled.
        let scope = TaskLocalFuture {
            local: self,
            slot: Some(value),
            future: (),
            _pinned: PhantomPinned,
        };
        crate::pin!(scope);
        scope.with_task(|_| f())
    }

    /// Calls `f` with a reference to the current value of this key and
    /// returns its result.
    ///
    /// # Panics
    ///
    /// Panics if no value is currently set for this key. Use
    /// [`try_with`](Self::try_with) for a non-panicking variant.
    pub fn with<F, R>(&'static self, f: F) -> R
    where
        F: FnOnce(&T) -> R,
    {
        // The message names the APIs that actually install a value; there is
        // no `LocalKey::set`.
        self.try_with(f).expect(
            "cannot access a Task Local Storage value \
             without setting it via `LocalKey::scope` or `LocalKey::sync_scope`",
        )
    }

    /// Calls `f` with a reference to the current value of this key.
    ///
    /// Returns `Ok` with the closure's result when a value is set, and
    /// `Err(AccessError)` (without calling `f`) when it is not.
    ///
    /// The underlying `RefCell` stays immutably borrowed while `f` runs.
    pub fn try_with<F, R>(&'static self, f: F) -> Result<R, AccessError>
    where
        F: FnOnce(&T) -> R,
    {
        self.inner.with(|v| match v.borrow().as_ref() {
            Some(val) => Ok(f(val)),
            None => Err(AccessError { _private: () }),
        })
    }
}
impl<T> LocalKey<T>
where
    T: Copy + 'static,
{
    /// Returns a copy of the value currently stored in this key.
    ///
    /// # Panics
    ///
    /// Panics under the same condition as `with`: when no value is set.
    pub fn get(&'static self) -> T {
        self.with(|current| *current)
    }
}
impl<T> fmt::Debug for LocalKey<T>
where
    T: 'static,
{
    /// Writes the fixed placeholder `LocalKey { .. }`; the stored value is
    /// never inspected. Width, fill and precision flags are honoured.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const PLACEHOLDER: &str = "LocalKey { .. }";
        // `str`'s `Display` implementation pads exactly like `Formatter::pad`.
        fmt::Display::fmt(PLACEHOLDER, f)
    }
}
pin_project! {
/// A future that holds a value of type `T` for a task-local key and an
/// inner future `F`.
///
/// Returned by `LocalKey::scope`. On every poll the held value is swapped
/// into the key's storage, the inner future is polled, and the storage is
/// swapped back.
pub struct TaskLocalFuture<T, F>
where
T: 'static
{
// The key whose storage is swapped on each poll.
local: &'static LocalKey<T>,
// Holds the scoped value between polls; emptied while a poll is running.
slot: Option<T>,
// The wrapped future (structurally pinned).
#[pin]
future: F,
// Marker field of type `PhantomPinned`.
#[pin]
_pinned: PhantomPinned,
}
}
impl<T: 'static, F> TaskLocalFuture<T, F> {
    /// Runs `f` on the pinned inner future while this future's value is
    /// installed in the task-local key.
    ///
    /// The value is moved out of `slot` into the key's storage before `f` is
    /// called. A drop guard moves it back into `slot` and restores whatever
    /// the storage held before, so the swap is undone even if `f` unwinds.
    fn with_task<F2: FnOnce(Pin<&mut F>) -> R, R>(self: Pin<&mut Self>, f: F2) -> R {
        /// Undoes the swap when dropped.
        struct RestoreOnDrop<'a, T: 'static> {
            key: &'static LocalKey<T>,
            slot: &'a mut Option<T>,
            previous: Option<T>,
        }

        impl<'a, T: 'static> Drop for RestoreOnDrop<'a, T> {
            fn drop(&mut self) {
                // Put the earlier contents back into the key and reclaim the
                // scoped value for the next poll.
                let scoped = self
                    .key
                    .inner
                    .with(|cell| cell.replace(self.previous.take()));
                *self.slot = scoped;
            }
        }

        let this = self.project();
        let key = *this.local;

        // Install the scoped value, remembering what it displaced.
        let scoped = this.slot.take();
        let previous = key.inner.with(|cell| cell.replace(scoped));

        let _restore = RestoreOnDrop {
            key,
            slot: this.slot,
            previous,
        };

        f(this.future)
    }
}
impl<T, F> Future for TaskLocalFuture<T, F>
where
    T: 'static,
    F: Future,
{
    type Output = F::Output;

    /// Polls the inner future through `with_task`, which makes the scoped
    /// value available in the task-local key for the duration of the poll.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.with_task(|inner| inner.poll(cx))
    }
}
/// Error returned by `LocalKey::try_with` when the task-local key has no
/// value set.
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct AccessError {
// Private unit field: keeps the struct from being constructed with a
// literal outside this module.
_private: (),
}
impl fmt::Debug for AccessError {
    /// Writes just the type name; the private marker field is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // A field-less `debug_struct` emits only the name, so write it directly.
        f.write_str("AccessError")
    }
}
impl fmt::Display for AccessError {
    /// Writes the fixed human-readable message for this error, honouring
    /// width, fill and precision flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `Formatter::pad` is what `str`'s `Display` implementation does.
        f.pad("task-local value not set")
    }
}
// Marks `AccessError` as a standard error type; every trait method keeps its
// default implementation.
impl Error for AccessError {}