use std::cell::UnsafeCell;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::collections::HashMap;
use std::hash::BuildHasherDefault;
use std::any::{Any, TypeId};
use init::Init;
use ident_hash::IdentHash;
#[cfg(feature = "tls")]
use tls::LocalValue;
/// Backing storage: one type-erased, heap-allocated value per `TypeId`.
type TypeMap = HashMap<TypeId, *mut Any, BuildHasherDefault<IdentHash>>;

/// A container holding at most one value per Rust type.
///
/// Values are stored type-erased behind raw pointers and looked up by their
/// `TypeId`.
pub struct Container {
    /// Tracks whether the backing map has been allocated yet.
    init: Init,
    /// Raw pointer to the heap-allocated map; null until first use.
    map: UnsafeCell<*mut TypeMap>,
    /// Spin-lock flag guarding `map`: 0 = unlocked, 1 = locked.
    mutex: AtomicUsize,
}
impl Container {
    /// Creates a new, empty container.
    ///
    /// Nothing is allocated here: the map pointer starts out null and the
    /// backing map is allocated lazily by `ensure_map_initialized`.
    pub const fn new() -> Container {
        Container {
            init: Init::new(),
            map: UnsafeCell::new(0 as *mut _),
            mutex: AtomicUsize::new(0)
        }
    }

    /// Allocates the backing map the first time it is needed.
    ///
    /// NOTE(review): this relies on `Init::needed()` returning `true` for
    /// exactly one caller and holding back the others until
    /// `mark_complete()` has run — confirm against `Init`.
    #[inline(always)]
    fn ensure_map_initialized(&self) {
        if self.init.needed() {
            // SAFETY: under the assumption above, only the initializing caller
            // reaches this write, and nobody reads the pointer before
            // `mark_complete()`.
            unsafe {
                *self.map.get() = Box::into_raw(Box::new(HashMap::<_, _, _>::default()));
            }

            self.init.mark_complete();
        }
    }

    /// Acquires the spin lock by flipping `mutex` from 0 to 1, busy-waiting
    /// until that succeeds.
    #[inline(always)]
    fn lock(&self) {
        // `compare_exchange` replaces the deprecated `compare_and_swap`;
        // `Err` means another thread currently holds the lock.
        while self.mutex
            .compare_exchange(0, 1, Ordering::SeqCst, Ordering::SeqCst)
            .is_err() {}
    }

    /// Releases the spin lock by flipping `mutex` from 1 back to 0.
    ///
    /// # Panics
    ///
    /// Panics if the lock was not held.
    #[inline(always)]
    fn unlock(&self) {
        let released = self.mutex
            .compare_exchange(1, 0, Ordering::SeqCst, Ordering::SeqCst);
        assert!(released.is_ok());
    }

    /// Stores `state` as the value for type `T` unless one is already stored.
    ///
    /// Returns `true` if `state` was stored, or `false` if a value of type
    /// `T` was already present (in which case `state` is dropped and the
    /// existing value is left untouched).
    #[inline]
    pub fn set<T: Send + Sync + 'static>(&self, state: T) -> bool {
        self.ensure_map_initialized();
        let type_id = TypeId::of::<T>();

        self.lock();
        // SAFETY: the map pointer was set by `ensure_map_initialized`, and the
        // spin lock is held for as long as the `&mut` to the map is used.
        let already_set = unsafe {
            let map = &mut **self.map.get();
            let present = map.contains_key(&type_id);
            if !present {
                // Ownership moves into the map as a raw pointer; it is
                // reclaimed in `Drop`.
                let state_entry = Box::into_raw(Box::new(state) as Box<Any>);
                map.insert(type_id, state_entry);
            }
            present
        };
        self.unlock();

        !already_set
    }

    /// Returns a reference to the stored value of type `T`, or `None` if no
    /// value of that type has been `set`.
    #[inline]
    pub fn try_get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        self.ensure_map_initialized();
        let type_id = TypeId::of::<T>();

        self.lock();
        // Copy the raw value pointer out of the map *while the lock is held*.
        // Keeping a reference to the map slot past `unlock()` would be
        // unsound: a concurrent `set` may rehash the map and move the slot.
        //
        // SAFETY: the map pointer was set by `ensure_map_initialized` and the
        // spin lock is held for the duration of the lookup.
        let entry: Option<*mut Any> = unsafe {
            (**self.map.get()).get(&type_id).cloned()
        };
        self.unlock();

        // SAFETY: the pointer was produced by `Box::into_raw` in `set` for a
        // value stored under `TypeId::of::<T>()`, so it points to a `T`.
        // Entries are never removed before `Drop`, so the allocation outlives
        // the `&self` borrow.
        entry.map(|ptr| unsafe { &*(ptr as *const Any as *const T) })
    }

    /// Returns a reference to the stored value of type `T`.
    ///
    /// # Panics
    ///
    /// Panics if no value of type `T` has been `set`.
    #[inline]
    pub fn get<T: Send + Sync + 'static>(&self) -> &T {
        self.try_get()
            .expect("container::get(): get() called before set() for given type")
    }

    /// Stores a `LocalValue<T>` built from `state_init`, unless one is
    /// already stored.
    ///
    /// Returns `true` if the value was stored and `false` if a
    /// `LocalValue<T>` was already present.
    ///
    /// NOTE(review): `LocalValue` presumably invokes `state_init` once per
    /// thread — confirm in the `tls` module.
    #[cfg(feature = "tls")]
    #[inline]
    pub fn set_local<T, F>(&self, state_init: F) -> bool
        where T: Send + 'static, F: Fn() -> T + 'static
    {
        self.set::<LocalValue<T>>(LocalValue::new(state_init))
    }

    /// Returns the `T` obtained from the stored `LocalValue<T>`, or `None` if
    /// `set_local` has not been called for `T`.
    #[cfg(feature = "tls")]
    #[inline]
    pub fn try_get_local<T: Send + 'static>(&self) -> Option<&T> {
        self.try_get::<LocalValue<T>>().map(|value| value.get())
    }

    /// Returns the `T` obtained from the stored `LocalValue<T>`.
    ///
    /// # Panics
    ///
    /// Panics if `set_local` has not been called for `T`.
    #[cfg(feature = "tls")]
    #[inline]
    pub fn get_local<T: Send + 'static>(&self) -> &T {
        self.try_get_local::<T>()
            .expect("container::get_local(): get_local() called before set_local()")
    }
}
// SAFETY rationale (as far as this file shows): `set` and `try_get` take the
// `mutex` spin lock around their map insert/lookup, and both only accept
// `T: Send + Sync + 'static` values, so the stored values may be shared and
// moved across threads.
// NOTE(review): lazy allocation of the map relies on `Init` being
// thread-safe — confirm against the `init` module.
unsafe impl Sync for Container { }
unsafe impl Send for Container { }
impl Drop for Container {
fn drop(&mut self) {
if !self.init.has_completed() {
return
}
unsafe {
let map = &mut **self.map.get();
for value in map.values_mut() {
let mut boxed_any: Box<Any> = Box::from_raw(*value);
drop(&mut boxed_any);
}
let mut boxed_map: Box<HashMap<_, _, _>> = Box::from_raw(map);
drop(&mut boxed_map);
}
}
}