use std::any::{Any, TypeId};
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::hash::BuildHasherDefault;
use std::marker::PhantomData;

use crate::ident_hash::IdentHash;
use crate::init::Init;
use crate::shim::cell::UnsafeCell;
use crate::shim::sync::atomic::{AtomicUsize, Ordering};
use crate::shim::thread::yield_now;
#[cfg(feature = "tls")]
use crate::tls::LocalValue;
/// A map that stores at most one value per Rust type, keyed by `TypeId`.
///
/// The marker parameter `K` (one of the types in `kind`) fixes which bounds
/// stored values must satisfy, and therefore whether the map is `Send`/`Sync`.
pub struct TypeMap<K: kind::Kind> {
    /// Guards the one-time allocation of `map`.
    init: Init,
    /// The backing hash map; `None` until first use.
    map: UnsafeCell<Option<TypeIdMap>>,
    /// Spin-lock word: `0` when unlocked, `1` when held.
    mutex: AtomicUsize,
    /// Set by `freeze`; once `true`, inserts are refused and reads skip the lock.
    frozen: bool,
    /// Ties the map to its kind. `*mut K` makes the type `!Send` and `!Sync` by
    /// default, so each kind opts in with an explicit `unsafe impl`.
    _kind: PhantomData<*mut K>,
}
/// Marker types that select which bounds a `TypeMap`'s values must satisfy.
///
/// The module is private, so `Kind` cannot be named outside the enclosing
/// module and only the three kinds declared here exist.
mod kind {
    /// Implemented only by the marker types in this module.
    pub trait Kind {}

    /// Values must be `Send`.
    pub struct Send;

    /// Values must be `Send + Sync`.
    pub struct SendSync;

    /// Values carry no thread-safety bounds.
    pub struct Neither;

    impl Kind for Send {}
    impl Kind for SendSync {}
    impl Kind for Neither {}
}
/// A [`TypeMap`] whose values carry no thread-safety bounds; the map is neither `Send` nor `Sync`.
pub type TypeMapNeither = TypeMap<kind::Neither>;
/// A [`TypeMap`] whose values must be `Send`; the map is `Send` but not `Sync`.
pub type TypeMapSend = TypeMap<kind::Send>;
/// A [`TypeMap`] whose values must be `Send + Sync`; the map is `Send + Sync`.
pub type TypeMapSendSync = TypeMap<kind::SendSync>;
/// Names the `TypeMap` type whose values must carry the given bounds:
///
/// * `TypeMap![]` — no bounds (`TypeMapNeither`);
/// * `TypeMap![Send]` — values must be `Send` (`TypeMapSend`);
/// * `TypeMap![Send + Sync]` or `TypeMap![Sync + Send]` — values must be
///   `Send + Sync` (`TypeMapSendSync`).
#[macro_export]
macro_rules! TypeMap {
    () => { $crate::type_map::TypeMapNeither };
    (Send) => { $crate::type_map::TypeMapSend };
    (Send + Sync) => { $crate::type_map::TypeMapSendSync };
    (Sync + Send) => { $crate::type_map::TypeMapSendSync };
}
/// Expands to an empty `TypeMap` struct literal: map not yet allocated,
/// lock released, not frozen. It expands inline, so it can be used inside the
/// `const fn new()` constructor as well as the non-`const` ones.
macro_rules! new {
    () => {
        TypeMap {
            init: Init::new(),
            map: UnsafeCell::new(None),
            mutex: AtomicUsize::new(0),
            frozen: false,
            _kind: PhantomData,
        }
    };
}
/// Backing storage: one type-erased, boxed value per `TypeId`.
///
/// The explicit `'static` matches the implicit default for `Box<dyn Any>`.
/// NOTE(review): `IdentHash` presumably passes the `TypeId`'s own hash through
/// rather than re-hashing it — confirm in `crate::ident_hash`.
type TypeIdMap = HashMap<TypeId, Box<dyn Any + 'static>, BuildHasherDefault<IdentHash>>;
impl TypeMap<kind::SendSync> {
#[cfg(not(loom))]
pub const fn new() -> Self {
new!()
}
#[cfg(loom)]
pub fn new() -> Self {
new!()
}
#[inline]
pub fn set<T: Send + Sync + 'static>(&self, state: T) -> bool {
unsafe { self._set(state) }
}
#[inline]
#[cfg(feature = "tls")]
pub fn set_local<T, F>(&self, state_init: F) -> bool
where T: Send + 'static, F: Fn() -> T + Send + Sync + 'static
{
self.set::<LocalValue<T>>(LocalValue::new(state_init))
}
#[inline]
#[cfg(feature = "tls")]
pub fn try_get_local<T: Send + 'static>(&self) -> Option<&T> {
self.try_get::<LocalValue<T>>().map(|value| value.get())
}
#[inline]
#[cfg(feature = "tls")]
pub fn get_local<T: Send + 'static>(&self) -> &T {
self.try_get_local::<T>()
.expect("type_map::get_local(): get_local() called before set_local()")
}
}
unsafe impl Send for TypeMap<kind::SendSync> { }
unsafe impl Sync for TypeMap<kind::SendSync> { }
#[cfg(test)] static_assertions::assert_impl_all!(TypeMap![Send + Sync]: Send, Sync);
#[cfg(test)] static_assertions::assert_impl_all!(TypeMap![Sync + Send]: Send, Sync);
impl TypeMap<kind::Send> {
pub fn new() -> Self {
new!()
}
#[inline]
pub fn set<T: Send + 'static>(&self, state: T) -> bool {
unsafe { self._set(state) }
}
}
unsafe impl Send for TypeMap<kind::Send> { }
#[cfg(test)] static_assertions::assert_impl_all!(TypeMap![Send]: Send);
#[cfg(test)] static_assertions::assert_not_impl_any!(TypeMap![Send]: Sync);
#[cfg(test)] static_assertions::assert_not_impl_any!(TypeMap<kind::Send>: Sync);
impl TypeMap<kind::Neither> {
    /// Creates an empty map with no thread-safety bounds on its values.
    pub fn new() -> Self {
        new!()
    }

    /// Stores `state` as this map's value for type `T`.
    ///
    /// Returns `true` if it was stored, or `false` if the map is frozen or
    /// already holds a `T` (the existing value is kept).
    #[inline]
    pub fn set<T>(&self, state: T) -> bool
    where
        T: 'static,
    {
        // SAFETY: this kind is neither `Send` nor `Sync` (asserted below), so the
        // map never crosses threads and `T` needs no thread-safety bounds.
        unsafe { self._set(state) }
    }
}

#[cfg(test)]
static_assertions::assert_not_impl_any!(TypeMap![]: Send, Sync);
#[cfg(test)]
static_assertions::assert_not_impl_any!(TypeMap<kind::Neither>: Send, Sync);
impl<K: kind::Kind> TypeMap<K> {
unsafe fn init_map_if_needed(&self) {
if self.init.needed() {
self.map.with_mut(|ptr| *ptr = Some(HashMap::<_, _, _>::default()));
self.init.mark_complete();
}
}
#[inline(always)]
#[allow(clippy::mut_from_ref)]
unsafe fn map_mut(&self) -> &mut TypeIdMap {
self.init_map_if_needed();
self.map.with_mut(|ptr| (*ptr).as_mut().unwrap())
}
#[inline(always)]
unsafe fn map_ref(&self) -> &TypeIdMap {
self.init_map_if_needed();
self.map.with(|ptr| (*ptr).as_ref().unwrap())
}
unsafe fn _set<T: 'static>(&self, state: T) -> bool {
if self.is_frozen() {
return false;
}
self.lock();
let map = self.map_mut();
let type_id = TypeId::of::<T>();
let already_set = map.contains_key(&type_id);
if !already_set {
map.insert(type_id, Box::new(state) as Box<dyn Any>);
}
self.unlock();
!already_set
}
unsafe fn with_map_ref<'a, F, T: 'a>(&'a self, f: F) -> T
where F: FnOnce(&'a TypeIdMap) -> T
{
if self.is_frozen() {
f(self.map_ref())
} else {
self.lock();
let result = f(self.map_ref());
self.unlock();
result
}
}
#[inline]
pub fn try_get<T: 'static>(&self) -> Option<&T> {
unsafe {
self.with_map_ref(|map| {
map.get(&TypeId::of::<T>()).and_then(|ptr| ptr.downcast_ref())
})
}
}
#[inline]
pub fn get<T: 'static>(&self) -> &T {
self.try_get()
.expect("type_map::get(): get() called before set() for given type")
}
#[inline(always)]
pub fn freeze(&mut self) {
self.frozen = true;
}
#[inline(always)]
pub fn is_frozen(&self) -> bool {
self.frozen
}
#[inline]
pub fn len(&self) -> usize {
unsafe { self.with_map_ref(|map| map.len()) }
}
#[inline]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
#[inline(always)]
fn lock(&self) {
while self.mutex.compare_exchange(0, 1, Ordering::AcqRel, Ordering::Relaxed).is_err() {
yield_now();
}
}
#[inline(always)]
fn unlock(&self) {
assert!(self.mutex.compare_exchange(1, 0, Ordering::AcqRel, Ordering::Relaxed).is_ok());
}
}
impl Default for TypeMap![Send + Sync] {
    /// Returns an empty map; same as `new()`.
    fn default() -> Self {
        Self::new()
    }
}

impl Default for TypeMap![Send] {
    /// Returns an empty map; same as `new()`.
    fn default() -> Self {
        Self::new()
    }
}

impl Default for TypeMap![] {
    /// Returns an empty map; same as `new()`.
    fn default() -> Self {
        Self::new()
    }
}
impl<K: kind::Kind> std::fmt::Debug for TypeMap<K> {
    /// Formats as `TypeMap { len: N }`. Stored values are type-erased, so only
    /// their count is shown. This calls `len()`, which locks unless the map is frozen.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.len();
        let mut builder = f.debug_struct("TypeMap");
        builder.field("len", &count).finish()
    }
}