use std::marker::PhantomData;
use std::collections::HashMap;
use std::hash::BuildHasherDefault;
use std::any::{Any, TypeId};
use crate::init::Init;
use crate::ident_hash::IdentHash;
use crate::cell::UnsafeCell;
use crate::sync::atomic::{AtomicUsize, Ordering};
use crate::thread::yield_now;
#[cfg(feature = "tls")]
use crate::tls::LocalValue;
/// A type-keyed store holding at most one value of each `'static` type.
///
/// The kind parameter `K` (one of the markers in `kind`) selects which thread-safety
/// bounds `set` imposes on stored values and whether the container itself is `Send`
/// and/or `Sync`; see the `Container!` macro and the `Container*` type aliases.
pub struct Container<K: kind::Kind> {
    /// Tracks whether `map` has been lazily created yet.
    init: Init,
    /// The type-keyed storage; starts as `None` and is filled in on first access.
    map: UnsafeCell<Option<TypeMap>>,
    /// Spin-lock state guarding `map`: 0 = unlocked, 1 = locked.
    mutex: AtomicUsize,
    /// Set by `freeze`; once true, `set` refuses writes and reads skip the spin lock.
    frozen: bool,
    /// Zero-sized marker carrying the kind `K`.
    _kind: PhantomData<K>
}
/// Zero-sized markers that select the thread-safety flavour of a `Container`.
mod kind {
    /// Implemented by every container kind marker; used as the bound on `Container<K>`.
    pub trait Kind {}

    /// Marker for containers that are `Send` but not `Sync`.
    pub struct Send;
    /// Marker for containers that are both `Send` and `Sync`.
    pub struct SendSync;
    /// Marker for containers that are neither `Send` nor `Sync`.
    pub struct Neither;

    impl Kind for Send {}
    impl Kind for SendSync {}
    impl Kind for Neither {}
}
/// A container that is `Send` but not `Sync`; values passed to `set` must be `Send`.
pub type ContainerSend = Container<kind::Send>;
/// A container that is both `Send` and `Sync`; values passed to `set` must be `Send + Sync`.
pub type ContainerSendSync = Container<kind::SendSync>;
/// A container that is neither `Send` nor `Sync`; any `'static` value may be passed to `set`.
pub type ContainerNeither = Container<kind::Neither>;
/// Names the concrete container type for the requested thread-safety bounds.
///
/// * `Container![]` expands to `ContainerNeither`.
/// * `Container![Send]` expands to `ContainerSend`.
/// * `Container![Send + Sync]` and `Container![Sync + Send]` both expand to
///   `ContainerSendSync`.
#[macro_export]
macro_rules! Container {
    () => ($crate::container::ContainerNeither);
    (Send) => ($crate::container::ContainerSend);
    (Send + Sync) => ($crate::container::ContainerSendSync);
    (Sync + Send) => ($crate::container::ContainerSendSync);
}
/// Expands to a `Container` struct literal that is empty (map not yet created),
/// unlocked (`mutex == 0`) and not frozen.
///
/// Every kind's `new` uses this macro, including the `const fn` variant, so the field
/// initializers must stay usable in a `const` context for non-loom builds. The kind `K`
/// is inferred from the expected type through `PhantomData`.
macro_rules! new {
    () => (
        Container {
            init: Init::new(),
            map: UnsafeCell::new(None),
            mutex: AtomicUsize::new(0),
            frozen: false,
            _kind: PhantomData,
        }
    )
}
/// Storage behind a `Container`: one type-erased value per `TypeId`.
// NOTE(review): hashed with `IdentHash`, presumably an identity-style hasher because a
// `TypeId` is already well distributed; confirm in `crate::ident_hash`.
type TypeMap = HashMap<TypeId, AnyObject, BuildHasherDefault<IdentHash>>;
/// A heap-allocated value of some `'static` type whose type has been erased.
///
/// The value is owned through the raw fat pointer `*mut dyn Any` returned by
/// `Box::into_raw`. This needs no assumptions about how Rust lays out fat pointers.
///
/// A raw pointer, unlike a `Box`, carries no uniqueness guarantee. References handed out
/// by `deanonymize` therefore point straight at the heap allocation and are unaffected
/// when the `AnyObject` itself is moved, for example during a map rehash.
///
/// The raw pointer also keeps `AnyObject` neither `Send` nor `Sync`. Thread-safety is
/// instead decided per container kind by explicit `unsafe impl`s on `Container`.
struct AnyObject {
    ptr: *mut dyn Any,
}

impl AnyObject {
    /// Moves `value` to the heap and erases its type.
    fn anonymize<T: 'static>(value: T) -> AnyObject {
        let boxed: Box<dyn Any> = Box::new(value);
        AnyObject { ptr: Box::into_raw(boxed) }
    }

    /// Returns a reference to the stored value if it has type `T`, otherwise `None`.
    fn deanonymize<T: 'static>(&self) -> Option<&T> {
        // SAFETY: `ptr` comes from `Box::into_raw` in `anonymize` and is freed only in
        // `Drop`, so it points to a live value for at least as long as `self` is borrowed.
        let any: &(dyn Any + 'static) = unsafe { &*self.ptr };
        any.downcast_ref()
    }
}

impl Drop for AnyObject {
    /// Drops the stored value and frees its heap allocation.
    fn drop(&mut self) {
        // SAFETY: `ptr` comes from `Box::into_raw` in `anonymize` and is reclaimed exactly
        // once, here.
        drop(unsafe { Box::from_raw(self.ptr) });
    }
}
impl Container<kind::SendSync> {
    /// Creates an empty, unfrozen container whose values must be `Send + Sync`.
    ///
    /// This is a `const fn` in non-loom builds, so it can initialize a `static`.
    #[cfg(not(loom))]
    pub const fn new() -> Self {
        new!()
    }

    /// Creates an empty, unfrozen container (loom builds; not `const`).
    #[cfg(loom)]
    pub fn new() -> Self {
        new!()
    }

    /// Stores `state` as the single value of type `T`.
    ///
    /// Returns `true` if it was stored. Returns `false` if the container is frozen or
    /// already holds a `T`; in that case `state` is dropped.
    #[inline]
    pub fn set<T: Send + Sync + 'static>(&self, state: T) -> bool {
        // SAFETY: the `Send + Sync` bound is exactly what a `Send + Sync` container needs.
        unsafe { Self::_set::<T>(self, state) }
    }

    /// Registers `state_init` as the initializer of a thread-local value of type `T`.
    ///
    /// Returns `false` if the container is frozen or a local `T` is already registered.
    #[inline]
    #[cfg(feature = "tls")]
    pub fn set_local<T, F>(&self, state_init: F) -> bool
    where
        T: Send + 'static,
        F: Fn() -> T + Send + Sync + 'static,
    {
        let local: LocalValue<T> = LocalValue::new(state_init);
        self.set::<LocalValue<T>>(local)
    }

    /// Returns this thread's `T` registered via `set_local`, or `None` if none was registered.
    #[inline]
    #[cfg(feature = "tls")]
    pub fn try_get_local<T: Send + 'static>(&self) -> Option<&T> {
        let local = self.try_get::<LocalValue<T>>()?;
        Some(local.get())
    }

    /// Returns this thread's `T` registered via `set_local`.
    ///
    /// # Panics
    ///
    /// Panics if `set_local` was never called for `T`.
    #[inline]
    #[cfg(feature = "tls")]
    pub fn get_local<T: Send + 'static>(&self) -> &T {
        match self.try_get_local::<T>() {
            Some(value) => value,
            None => panic!("container::get_local(): get_local() called before set_local()"),
        }
    }
}
// SAFETY: `Container<kind::SendSync>::set` only accepts `T: Send + Sync`, so every stored
// value may be sent to and shared with other threads. Access to the map is serialized by
// the `mutex` spin lock. The one exception is reads after `freeze`: `freeze` takes
// `&mut self`, and afterwards `set` refuses all writes.
unsafe impl Send for Container<kind::SendSync> { }
unsafe impl Sync for Container<kind::SendSync> { }
// Compile-time checks: both spellings of the macro name a `Send + Sync` type.
#[cfg(test)] static_assertions::assert_impl_all!(Container![Send + Sync]: Send, Sync);
#[cfg(test)] static_assertions::assert_impl_all!(Container![Sync + Send]: Send, Sync);
impl Container<kind::Send> {
    /// Creates an empty, unfrozen container whose values need only be `Send`.
    pub fn new() -> Self {
        new!()
    }

    /// Stores `state` as the single value of type `T`.
    ///
    /// Returns `true` if it was stored. Returns `false` if the container is frozen or
    /// already holds a `T`; in that case `state` is dropped.
    #[inline]
    pub fn set<T: Send + 'static>(&self, state: T) -> bool {
        // SAFETY: a `Send`-only container requires exactly `T: Send` of its values.
        unsafe { Self::_set::<T>(self, state) }
    }
}
// SAFETY: `Container<kind::Send>::set` only accepts `T: Send`, so moving the container and
// its values to another thread is sound. No `Sync` impl is provided because stored values
// need not be `Sync`; the assertions below check that it stays that way.
unsafe impl Send for Container<kind::Send> { }
#[cfg(test)] static_assertions::assert_impl_all!(Container![Send]: Send);
#[cfg(test)] static_assertions::assert_not_impl_any!(Container![Send]: Sync);
#[cfg(test)] static_assertions::assert_not_impl_any!(Container<kind::Send>: Sync);
impl Container<kind::Neither> {
    /// Creates an empty, unfrozen container that accepts any `'static` value.
    pub fn new() -> Self {
        new!()
    }

    /// Stores `state` as the single value of type `T`.
    ///
    /// Returns `true` if it was stored. Returns `false` if the container is frozen or
    /// already holds a `T`; in that case `state` is dropped.
    #[inline]
    pub fn set<T: 'static>(&self, state: T) -> bool {
        // SAFETY: this kind is neither `Send` nor `Sync`, so stored values never leave the
        // thread that owns the container and need no extra bounds.
        unsafe { Self::_set::<T>(self, state) }
    }
}
// `Container<kind::Neither>` stores values that may be neither `Send` nor `Sync`, so it has
// no `unsafe impl`s; these assertions check that it does not pick up either auto trait.
#[cfg(test)] static_assertions::assert_not_impl_any!(Container![]: Send, Sync);
#[cfg(test)] static_assertions::assert_not_impl_any!(Container<kind::Neither>: Send, Sync);
impl<K: kind::Kind> Container<K> {
    /// Creates the backing map the first time it is needed.
    ///
    /// # Safety
    ///
    /// Writes to `self.map` through the `UnsafeCell` when `self.init` reports that
    /// initialization is still needed.
    // NOTE(review): on the frozen read path this runs without the spin lock, so soundness
    // relies on `Init::needed`/`Init::mark_complete` admitting a single initializer and
    // making other callers wait; confirm in `crate::init`.
    unsafe fn init_map_if_needed(&self) {
        if self.init.needed() {
            self.map.with_mut(|ptr| *ptr = Some(HashMap::<_, _, _>::default()));
            self.init.mark_complete();
        }
    }

    /// Returns the map mutably, creating it first if necessary.
    ///
    /// # Safety
    ///
    /// The caller must hold the spin lock while using the returned reference, so that no
    /// other access to the map overlaps with it.
    #[inline(always)]
    unsafe fn map_mut(&self) -> &mut TypeMap {
        self.init_map_if_needed();
        self.map.with_mut(|ptr| (*ptr).as_mut().unwrap())
    }

    /// Returns the map immutably, creating it first if necessary.
    ///
    /// # Safety
    ///
    /// The caller must either hold the spin lock or know the container is frozen, since
    /// `set` refuses all writes once frozen.
    #[inline(always)]
    unsafe fn map_ref(&self) -> &TypeMap {
        self.init_map_if_needed();
        self.map.with(|ptr| (*ptr).as_ref().unwrap())
    }

    /// Stores `state` under `TypeId::of::<T>()`.
    ///
    /// Returns `true` if `state` was stored. Returns `false` if the container is frozen or
    /// already holds a value of type `T`. A rejected `state` is dropped only after the spin
    /// lock is released, so its destructor may safely use this container again.
    ///
    /// # Safety
    ///
    /// `T` must meet the thread-safety bounds that the kind `K` promises. The public `set`
    /// methods enforce this through their trait bounds.
    unsafe fn _set<T: 'static>(&self, state: T) -> bool {
        use std::collections::hash_map::Entry;

        if self.is_frozen() {
            return false;
        }

        let guard = self.lock_guard();
        // One hash lookup via the entry API instead of `contains_key` + `insert`.
        let inserted = match self.map_mut().entry(TypeId::of::<T>()) {
            Entry::Vacant(slot) => {
                slot.insert(AnyObject::anonymize(state));
                true
            }
            Entry::Occupied(_) => false,
        };
        // Release the lock before a rejected `state` is dropped at function exit.
        drop(guard);
        inserted
    }

    /// Runs `f` on the map, holding the spin lock unless the container is frozen.
    ///
    /// The lock is released when `f` returns or unwinds.
    ///
    /// # Safety
    ///
    /// Whatever `f` returns may outlive the lock. Callers must only return data that stays
    /// valid across later insertions: values reached through their own heap allocation, or
    /// plain counts.
    unsafe fn with_map_ref<'a, F, T: 'a>(&'a self, f: F) -> T
    where
        F: FnOnce(&'a TypeMap) -> T,
    {
        if self.is_frozen() {
            return f(self.map_ref());
        }

        let _guard = self.lock_guard();
        f(self.map_ref())
    }

    /// Returns the value of type `T` stored with `set`, or `None` if there is none.
    #[inline]
    pub fn try_get<T: 'static>(&self) -> Option<&T> {
        // SAFETY: `with_map_ref` takes the lock unless the container is frozen. The returned
        // reference points into the value's own heap allocation, which is not moved when the
        // map grows; nothing in this type ever removes an entry.
        unsafe {
            self.with_map_ref(|map| {
                map.get(&TypeId::of::<T>()).and_then(|ptr| ptr.deanonymize())
            })
        }
    }

    /// Returns the value of type `T` stored with `set`.
    ///
    /// # Panics
    ///
    /// Panics if no value of type `T` has been stored.
    #[inline]
    pub fn get<T: 'static>(&self) -> &T {
        self.try_get()
            .expect("container::get(): get() called before set() for given type")
    }

    /// Freezes the container.
    ///
    /// Afterwards `set` always returns `false` and reads no longer take the spin lock.
    /// Taking `&mut self` guarantees no other access is in progress at the switch-over.
    #[inline(always)]
    pub fn freeze(&mut self) {
        self.frozen = true;
    }

    /// Returns `true` once `freeze` has been called.
    #[inline(always)]
    pub fn is_frozen(&self) -> bool {
        self.frozen
    }

    /// Returns the number of values stored.
    pub fn len(&self) -> usize {
        // SAFETY: only the entry count escapes the closure.
        unsafe { self.with_map_ref(|map| map.len()) }
    }

    /// Returns `true` if no values have been stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Spins, yielding the thread between attempts, until `mutex` flips from 0 to 1.
    #[inline(always)]
    fn lock(&self) {
        while self.mutex.compare_exchange(0, 1, Ordering::AcqRel, Ordering::Relaxed).is_err() {
            yield_now()
        }
    }

    /// Flips `mutex` from 1 back to 0; panics if the lock was not held.
    #[inline(always)]
    fn unlock(&self) {
        assert!(self.mutex.compare_exchange(1, 0, Ordering::AcqRel, Ordering::Relaxed).is_ok());
    }

    /// Acquires the spin lock and returns a guard that releases it when dropped.
    ///
    /// The guard also runs during unwinding, so a panic while the lock is held cannot
    /// leave the container locked forever.
    #[inline]
    fn lock_guard(&self) -> LockGuard<'_, K> {
        self.lock();
        LockGuard { container: self }
    }
}

/// RAII guard returned by `Container::lock_guard`; releases the spin lock on drop.
struct LockGuard<'a, K: kind::Kind> {
    container: &'a Container<K>,
}

impl<K: kind::Kind> Drop for LockGuard<'_, K> {
    fn drop(&mut self) {
        self.container.unlock();
    }
}
/// Same as `new()`: an empty, unfrozen `Send + Sync` container.
impl Default for Container![Send + Sync] {
    fn default() -> Self {
        Self::new()
    }
}

/// Same as `new()`: an empty, unfrozen `Send`-only container.
impl Default for Container![Send] {
    fn default() -> Self {
        Self::new()
    }
}

/// Same as `new()`: an empty, unfrozen container that is neither `Send` nor `Sync`.
impl Default for Container![] {
    fn default() -> Self {
        Self::new()
    }
}
impl<K: kind::Kind> std::fmt::Debug for Container<K> {
    /// Formats as `Container { len: N }`, where `N` is the number of stored values.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let stored = self.len();
        let mut out = f.debug_struct("Container");
        out.field("len", &stored);
        out.finish()
    }
}