#![allow(unsafe_code)]
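//! The raw, lock-free hash trie implementation.
//!
//! Keys are hashed once and the hash is consumed `LEVEL_BITS` bits per level.
//! Every level is an `Inner` node holding `LEVEL_CELLS` atomic pointers; a
//! pointer leads either to another `Inner` node or, when tagged with the
//! `DATA` flag, to a `Data` leaf with the actual payloads. The other tag bit
//! (`CONDEMNED`) marks pointers that are about to be pruned. Old nodes are
//! reclaimed through `crossbeam-epoch`.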
use std::borrow::Borrow;
use std::hash::{BuildHasher, Hash, Hasher};
use std::marker::PhantomData;
use std::mem;
use std::sync::atomic::Ordering;
use arrayvec::ArrayVec;
use bitflags::bitflags;
use crossbeam_epoch::{Atomic, Guard, Owned, Shared};
use smallvec::SmallVec;
pub mod config;
pub mod debug;
pub mod iterator;
use self::config::Config;
use crate::existing_or_new::ExistingOrNew;
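// How the 64-bit hash is sliced into trie levels: LEVEL_BITS bits per level,
// LEVEL_MASK selecting them, LEVEL_CELLS pointers per inner node and
// MAX_LEVELS levels before the hash runs out.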
pub(crate) const LEVEL_BITS: usize = 4;
pub(crate) const LEVEL_MASK: u64 = 0b1111;
pub(crate) const LEVEL_CELLS: usize = 16;
pub(crate) const MAX_LEVELS: usize = mem::size_of::<u64>() * 8 / LEVEL_BITS;
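// Flags stored in the otherwise unused low (alignment) bits of node pointers:
// * DATA: the pointer leads to a `Data` leaf, not to an `Inner` node.
// * CONDEMNED: the pointer is about to be pruned away and must not be
//   replaced in place; whoever meets it helps with the pruning and restarts.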
bitflags! {
struct NodeFlags: usize {
const CONDEMNED = 0b01;
const DATA = 0b10;
}
}
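/// Extracts the `NodeFlags` from the tag bits of a node pointer.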
fn nf(node: Shared<Inner>) -> NodeFlags {
NodeFlags::from_bits(node.tag()).expect("Invalid node flags")
}
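/// Reinterprets a `DATA`-tagged pointer as a reference to the `Data` leaf
/// behind it.
///
/// # Safety
///
/// The pointer must have been produced by `owned_data` and must still be
/// alive (protected by the current epoch pin).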
unsafe fn load_data<'a, C: Config>(node: Shared<'a, Inner>) -> &'a Data<C> {
assert!(
nf(node).contains(NodeFlags::DATA),
"Tried to load data from inner node pointer"
);
(node.as_raw() as usize as *const Data<C>)
.as_ref()
.expect("A null pointer with data flag found")
}
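/// Moves the `Data` leaf to the heap and disguises it as an `Inner` pointer
/// tagged with the `DATA` flag, so it can be stored in the trie.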
fn owned_data<C: Config>(data: Data<C>) -> Owned<Inner> {
unsafe {
Owned::<Inner>::from_raw(Box::into_raw(Box::new(data)) as usize as *mut _)
.with_tag(NodeFlags::DATA.bits())
}
}
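/// Frees a `Data` leaf previously created by `owned_data`.
///
/// # Safety
///
/// The pointer must carry the `DATA` flag and must no longer be reachable by
/// other threads.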
unsafe fn drop_data<C: Config>(ptr: Shared<Inner>) {
assert!(
nf(ptr).contains(NodeFlags::DATA),
"Tried to drop data from inner node pointer"
);
drop(Owned::from_raw(ptr.as_raw() as usize as *mut Data<C>));
}
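/// One inner level of the trie: an array of tagged pointers to child `Inner`
/// nodes or `Data` leaves.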
#[derive(Default)]
struct Inner([Atomic<Inner>; LEVEL_CELLS]);
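/// A leaf of the trie. It contains more than one payload only if full 64-bit
/// hashes collide.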
type Data<C> = SmallVec<[<C as Config>::Payload; 2]>;
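/// The value-producing state `traverse` carries around.
///
/// `insert` starts in `Created` (the payload already exists), while
/// `get_or_insert_with` starts in `Future` so the constructor runs only if the
/// payload turns out to be needed. `Empty` exists solely as the transient
/// state inside `payload`.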
enum TraverseState<C: Config, F> {
    Empty,
    Created(C::Payload),
    Future { key: C::Key, constructor: F },
}
impl<C: Config, F: FnOnce(C::Key) -> C::Payload> TraverseState<C, F> {
fn key(&self) -> &C::Key {
match self {
TraverseState::Empty => unreachable!("Not supposed to live in the empty state"),
TraverseState::Created(payload) => payload.borrow(),
TraverseState::Future { key, .. } => key,
}
}
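    /// Produces the payload, constructing it on first use, and keeps a clone
    /// in `self`.
    ///
    /// The `Empty` placeholder lets the `Future` constructor be moved out and
    /// replaced by `Created`, so it runs at most once even when the CAS in
    /// `traverse` has to be retried.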
fn payload(&mut self) -> C::Payload {
let (new_val, result) = match mem::replace(self, TraverseState::Empty) {
TraverseState::Empty => unreachable!("Not supposed to live in the empty state"),
TraverseState::Created(payload) => (TraverseState::Created(payload.clone()), payload),
TraverseState::Future { key, constructor } => {
let payload = constructor(key);
let created = TraverseState::Created(payload.clone());
(created, payload)
}
};
*self = new_val;
result
}
fn data_owned(&mut self) -> Owned<Inner> {
let mut data = Data::<C>::new();
data.push(self.payload());
owned_data::<C>(data)
}
}
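/// Whether `traverse` should overwrite an existing payload (`insert`) or keep
/// it (`get_or_insert_with`).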
#[derive(Copy, Clone, Eq, PartialEq)]
enum TraverseMode {
Overwrite,
IfMissing,
}
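/// What `prune` did to the node: collapsed it to nothing (`Null`), contracted
/// it to its single remaining leaf (`Singleton`), replaced it with a cleaned
/// copy (`Copy`), or lost the race for the parent pointer (`CasFail`).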
#[derive(Copy, Clone, Eq, PartialEq)]
enum PruneResult {
Null,
Singleton,
Copy,
CasFail,
}
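/// The core of the concurrent trie, parameterized by a `Config` (which picks
/// the payload and key types) and a `BuildHasher`.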
pub struct Raw<C: Config, S> {
hash_builder: S,
root: Atomic<Inner>,
_data: PhantomData<C::Payload>,
}
impl<C, S> Raw<C, S>
where
C: Config,
S: BuildHasher,
{
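    /// Creates an empty trie that hashes keys with `hash_builder`.
    ///
    /// The asserts make sure both node types leave enough unused alignment
    /// bits to hold the pointer flags.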
pub fn with_hasher(hash_builder: S) -> Self {
assert!(
mem::align_of::<Data<C>>().trailing_zeros() >= NodeFlags::all().bits().count_ones(),
"BUG: Alignment of Data<Payload> is not large enough to store the internal flags",
);
assert!(
mem::align_of::<Inner>().trailing_zeros() >= NodeFlags::all().bits().count_ones(),
"BUG: Alignment of Inner not large enough to store internal flags",
);
Self {
hash_builder,
root: Atomic::null(),
_data: PhantomData,
}
}
fn hash<Q>(&self, key: &Q) -> u64
where
Q: ?Sized + Hash,
{
let mut hasher = self.hash_builder.build_hasher();
key.hash(&mut hasher);
hasher.finish()
}
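    /// Inserts a payload, returning the previous payload stored under the
    /// same key, if any.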
pub fn insert<'s, 'p, 'r>(
&'s self,
payload: C::Payload,
pin: &'p Guard,
) -> Option<&'r C::Payload>
where
's: 'r,
'p: 'r,
{
self.traverse(
TraverseState::<C, fn(C::Key) -> C::Payload>::Created(payload),
TraverseMode::Overwrite,
pin,
)
.map(ExistingOrNew::into_inner)
}
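    /// Removes or contracts the inner node `child` hanging off `parent`.
    ///
    /// All grandchild pointers are condemned first so nothing new can be
    /// installed below `child`; `parent` is then pointed at null (nothing
    /// left), at the single remaining data leaf, or at a cleaned copy of the
    /// node, and the old `child` is scheduled for destruction.
    ///
    /// # Safety
    ///
    /// `child` must be a non-data node that was loaded from `parent` under
    /// the same `pin`.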
unsafe fn prune(pin: &Guard, parent: &Atomic<Inner>, child: Shared<Inner>) -> PruneResult {
assert!(
!nf(child).contains(NodeFlags::DATA),
"Child passed to prune must not be data"
);
let inner = child.as_ref().expect("Null child node passed to prune");
let mut allow_contract = true;
let mut child_cnt = 0;
let mut last_leaf = None;
let mut new_child = Inner::default();
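        // Condemn every grandchild pointer and copy it (with the CONDEMNED
        // flag stripped) into the prospective replacement node, counting what
        // is left below.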
for (new, grandchild) in new_child.0.iter_mut().zip(&inner.0) {
let gc = grandchild.fetch_or(NodeFlags::CONDEMNED.bits(), Ordering::AcqRel, pin);
let flags = nf(gc) & !NodeFlags::CONDEMNED;
let gc = gc.with_tag(flags.bits());
if gc.is_null() {
} else if flags.contains(NodeFlags::DATA) {
last_leaf.replace(gc);
let gc = load_data::<C>(gc);
child_cnt += gc.len();
} else {
allow_contract = false;
child_cnt += 1;
}
*new = Atomic::from(gc);
}
let mut cleanup = None;
let (insert, prune_result) = match (allow_contract, child_cnt, last_leaf) {
(true, 1, Some(child)) => (child, PruneResult::Singleton),
(_, 0, None) => (Shared::null(), PruneResult::Null),
_ => {
let new = Owned::new(new_child).into_shared(pin);
cleanup = Some(new);
(new, PruneResult::Copy)
}
};
assert_eq!(
0,
child.tag(),
"Attempt to replace condemned pointer or prune data node"
);
let result = parent
.compare_and_set(child, insert, (Ordering::Release, Ordering::Relaxed), pin)
.is_ok();
if result {
pin.defer_destroy(child);
prune_result
} else {
drop(cleanup.map(|c| Shared::into_owned(c)));
PruneResult::CasFail
}
}
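    /// The shared workhorse of `insert` and `get_or_insert_with`.
    ///
    /// Walks the trie, consuming `LEVEL_BITS` of the hash per level:
    /// * a null slot receives a freshly created data leaf,
    /// * a data leaf with a different key is pushed one level deeper (or
    ///   merged into a collision leaf once the hash is exhausted),
    /// * a condemned pointer means a prune is in progress, so the traversal
    ///   helps and restarts from the root,
    /// * a failed CAS simply retries the current step.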
fn traverse<'s, 'p, 'r, F>(
&'s self,
mut state: TraverseState<C, F>,
mode: TraverseMode,
pin: &'p Guard,
) -> Option<ExistingOrNew<&'r C::Payload>>
where
's: 'r,
'p: 'r,
F: FnOnce(C::Key) -> C::Payload,
{
let hash = self.hash(state.key());
let mut shift = 0;
let mut current = &self.root;
let mut parent = None;
loop {
let node = current.load_consume(&pin);
let flags = nf(node);
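            // Tries to swap the current pointer for `with`. On success it may
            // schedule the old data leaf for destruction; on failure it frees
            // the replacement that was never installed.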
let replace = |with: Owned<Inner>, delete_previous| {
let result = current.compare_and_set_weak(
node,
with,
(Ordering::Release, Ordering::Relaxed),
pin,
);
match result {
Ok(new) if !node.is_null() && delete_previous => {
assert!(flags.contains(NodeFlags::DATA));
let node = Shared::from(node.as_raw() as usize as *const Data<C>);
unsafe { pin.defer_destroy(node) };
Some(new)
}
Ok(new) => Some(new),
Err(e) => {
if NodeFlags::from_bits(e.new.tag())
.expect("Invalid flags")
.contains(NodeFlags::DATA)
{
unsafe { drop_data::<C>(e.new.into_shared(&pin)) };
}
None
}
}
};
if flags.contains(NodeFlags::CONDEMNED) {
unsafe {
let (parent, child) = parent.expect("Condemned the root!");
Self::prune(&pin, parent, child);
}
shift = 0;
current = &self.root;
parent = None;
} else if node.is_null() {
if let Some(new) = replace(state.data_owned(), true) {
if mode == TraverseMode::Overwrite {
return None;
} else {
let new = unsafe { load_data::<C>(new) };
return Some(ExistingOrNew::New(&new[0]));
}
}
} else if flags.contains(NodeFlags::DATA) {
let data = unsafe { load_data::<C>(node) };
assert!(!data.is_empty(), "Empty data nodes must not be kept around");
if data[0].borrow() != state.key() && shift < mem::size_of_val(&hash) * 8 {
assert!(data.len() == 1, "Collision node not deep enough");
let other_hash = self.hash(data[0].borrow());
let other_bits = (other_hash >> shift) & LEVEL_MASK;
let mut inner = Inner::default();
inner.0[other_bits as usize] = Atomic::from(node);
let split = Owned::new(inner);
replace(split, false);
} else {
let mut result = data
.iter()
.find(|l| (*l).borrow().borrow() == state.key())
.map(ExistingOrNew::Existing);
if result.is_none() || mode == TraverseMode::Overwrite {
let mut new = Data::<C>::with_capacity(data.len() + 1);
new.extend(
data.iter()
.filter(|l| (*l).borrow() != state.key())
.cloned(),
);
new.push(state.payload());
new.shrink_to_fit();
let new = owned_data::<C>(new);
if let Some(new) = replace(new, true) {
if result.is_none() && mode == TraverseMode::IfMissing {
let new = unsafe { load_data::<C>(new) };
result = Some(ExistingOrNew::New(new.last().unwrap()));
}
} else {
continue;
}
}
return result;
}
} else {
let inner = unsafe { node.as_ref().expect("We just checked this is not NULL") };
let bits = (hash >> shift) & LEVEL_MASK;
shift += LEVEL_BITS;
parent = Some((current, node));
current = &inner.0[bits as usize];
}
}
}
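    /// Looks up the payload stored under `key`, if any.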
pub fn get<'r, 's, 'p, Q>(&'s self, key: &Q, pin: &'p Guard) -> Option<&'r C::Payload>
where
's: 'r,
        'p: 'r,
Q: ?Sized + Eq + Hash,
C::Key: Borrow<Q>,
{
let mut current = &self.root;
let mut hash = self.hash(key);
loop {
let node = current.load_consume(pin);
let flags = nf(node);
if node.is_null() {
return None;
} else if flags.contains(NodeFlags::DATA) {
return unsafe { load_data::<C>(node) }
.iter()
.find(|l| (*l).borrow().borrow() == key);
} else {
let inner = unsafe { node.as_ref().expect("We just checked this is not NULL") };
let bits = hash & LEVEL_MASK;
hash >>= LEVEL_BITS;
current = &inner.0[bits as usize];
}
}
}
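    /// Looks up `key`, inserting a payload built by `create` if it is missing.
    ///
    /// The constructor runs at most once, and only if the key appears to be
    /// absent at that moment (a concurrent insert may still win the race).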
pub fn get_or_insert_with<'s, 'p, 'r, F>(
&'s self,
key: C::Key,
create: F,
pin: &'p Guard,
) -> ExistingOrNew<&'r C::Payload>
where
's: 'r,
'p: 'r,
F: FnOnce(C::Key) -> C::Payload,
{
let state = TraverseState::Future {
key,
constructor: create,
};
self.traverse(state, TraverseMode::IfMissing, pin)
.expect("Should have created one for me")
}
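    /// Removes the payload stored under `key`, returning it if it was there.
    ///
    /// Afterwards the recorded path is walked back towards the root and inner
    /// nodes left with at most one child are pruned.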
pub fn remove<'r, 's, 'p, Q>(&'s self, key: &Q, pin: &'p Guard) -> Option<&'r C::Payload>
where
's: 'r,
'p: 'r,
Q: ?Sized + Eq + Hash,
C::Key: Borrow<Q>,
{
let mut current = &self.root;
let hash = self.hash(key);
let mut shift = 0;
let mut levels: ArrayVec<[_; MAX_LEVELS]> = ArrayVec::new();
let deleted = loop {
let node = current.load_consume(&pin);
let flags = nf(node);
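            // Tries to swap the current data leaf for `with` (a smaller leaf
            // or null), destroying the old leaf on success and the unused
            // replacement on failure.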
let replace = |with: Shared<_>| {
let result = current.compare_and_set_weak(
node,
with,
(Ordering::Release, Ordering::Relaxed),
&pin,
);
match result {
Ok(_) => {
assert!(flags.contains(NodeFlags::DATA));
unsafe {
let node = Shared::from(node.as_raw() as usize as *const Data<C>);
pin.defer_destroy(node);
}
true
}
Err(ref e) if !e.new.is_null() => {
assert!(nf(e.new).contains(NodeFlags::DATA));
unsafe { drop_data::<C>(e.new) };
false
}
Err(_) => false,
}
};
if node.is_null() {
return None;
} else if flags.contains(NodeFlags::CONDEMNED) {
unsafe {
let (current, node) = levels.pop().expect("Condemned the root");
Self::prune(&pin, current, node);
}
levels.clear();
shift = 0;
current = &self.root;
} else if flags.contains(NodeFlags::DATA) {
let data = unsafe { load_data::<C>(node) };
let mut deleted = None;
let new = data
.iter()
.filter(|l| {
if (*l).borrow().borrow() == key {
deleted = Some(*l);
false
} else {
true
}
})
.cloned()
.collect::<Data<C>>();
if deleted.is_some() {
let new = if new.is_empty() {
Shared::null()
} else {
owned_data::<C>(new).into_shared(&pin)
};
if !replace(new) {
continue;
}
}
break deleted;
} else {
let inner = unsafe { node.as_ref().expect("We just checked for NULL") };
levels.push((current, node));
let bits = (hash >> shift) & LEVEL_MASK;
shift += LEVEL_BITS;
current = &inner.0[bits as usize];
}
};
if deleted.is_some() {
for (parent, child) in levels.into_iter().rev() {
let inner = unsafe { child.as_ref().expect("We just checked for NULL") };
let non_null = inner
.0
.iter()
.filter(|ptr| !ptr.load(Ordering::Relaxed, &pin).is_null())
.count();
if non_null > 1 {
break;
}
if let PruneResult::Copy = unsafe { Self::prune(&pin, parent, child) } {
break;
}
}
}
deleted
}
}
impl<C: Config, S> Raw<C, S> {
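    /// Checks whether the trie holds anything at all.
    ///
    /// The unprotected guard is fine here: only the root pointer itself is
    /// inspected, never dereferenced.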
pub fn is_empty(&self) -> bool {
unsafe {
self.root
.load(Ordering::Relaxed, &crossbeam_epoch::unprotected())
.is_null()
}
}
pub fn hash_builder(&self) -> &S {
&self.hash_builder
}
}
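// Dropping the trie means exclusive access, so an unprotected "guard" is
// enough and everything can be freed recursively right away.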
impl<C: Config, S> Drop for Raw<C, S> {
fn drop(&mut self) {
unsafe fn drop_recursive<C: Config>(node: &Atomic<Inner>) {
let pin = crossbeam_epoch::unprotected();
let extract = node.load(Ordering::Relaxed, &pin);
let flags = nf(extract);
if extract.is_null() {
} else if flags.contains(NodeFlags::DATA) {
drop_data::<C>(extract);
} else {
let owned = extract.into_owned();
for sub in &owned.0 {
drop_recursive::<C>(sub);
}
drop(owned);
}
}
unsafe { drop_recursive::<C>(&self.root) };
}
}
#[cfg(test)]
pub(crate) mod tests {
use std::ptr;
use super::config::Trivial as TrivialConfig;
use super::*;
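    /// A hasher sending every key to 0, forcing all entries into a single
    /// collision leaf.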
pub(crate) struct NoHasher;
impl Hasher for NoHasher {
fn finish(&self) -> u64 {
0
}
fn write(&mut self, _: &[u8]) {}
}
impl BuildHasher for NoHasher {
type Hasher = NoHasher;
fn build_hasher(&self) -> NoHasher {
NoHasher
}
}
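    /// A hasher XOR-ing every written byte into all eight bytes of the hash,
    /// so consecutive small keys land in distinct slots of the first level.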
#[derive(Clone, Copy, Debug, Default)]
pub(crate) struct SplatHasher(u64);
impl Hasher for SplatHasher {
fn finish(&self) -> u64 {
self.0
}
fn write(&mut self, value: &[u8]) {
for val in value {
for idx in 0..mem::size_of::<u64>() {
self.0 ^= u64::from(*val) << (8 * idx);
}
}
}
}
pub(crate) struct MakeSplatHasher;
impl BuildHasher for MakeSplatHasher {
type Hasher = SplatHasher;
fn build_hasher(&self) -> SplatHasher {
SplatHasher::default()
}
}
#[test]
fn splat_hasher() {
let mut hasher = MakeSplatHasher.build_hasher();
hasher.write_u8(0);
assert_eq!(0, hasher.finish());
hasher.write_u8(8);
assert_eq!(0x0808_0808_0808_0808, hasher.finish());
}
#[test]
fn consts_consistent() {
assert!(LEVEL_CELLS.is_power_of_two());
assert_eq!(LEVEL_BITS, LEVEL_MASK.count_ones() as usize);
assert_eq!(LEVEL_BITS, (!LEVEL_MASK).trailing_zeros() as usize);
assert_eq!(LEVEL_CELLS, 2usize.pow(LEVEL_BITS as u32));
}
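    // Fill every slot of the first level, condemn one of its pointers by hand
    // and check that a later insert helps with the pruning: the root must be
    // replaced by a copy and all values must stay reachable.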
#[test]
fn prune_on_insert() {
let mut map = Raw::<TrivialConfig<u8>, _>::with_hasher(MakeSplatHasher);
let pin = crossbeam_epoch::pin();
for i in 0..LEVEL_CELLS as u8 {
assert!(map.insert(i, &pin).is_none());
}
eprintln!("{}", debug::PrintShape(&map));
let root = map.root.load(Ordering::Relaxed, &pin);
let flags = nf(root);
assert_eq!(
NodeFlags::empty(),
flags,
"Root should be non-condemned inner node"
);
assert!(!root.is_null());
let old_root = root.as_raw();
let root = unsafe { root.deref() };
for ptr in &root.0 {
let ptr = ptr.load(Ordering::Relaxed, &pin);
assert!(!ptr.is_null());
let flags = nf(ptr);
assert_eq!(
NodeFlags::DATA,
flags,
"Expected a data node, found {:?}",
ptr
);
}
root.0[0].fetch_or(NodeFlags::CONDEMNED.bits(), Ordering::Relaxed, &pin);
let old = map.insert(0, &pin);
assert_eq!(0, *old.unwrap());
map.assert_pruned();
let new_root = map.root.load(Ordering::Relaxed, &pin).as_raw();
assert!(!ptr::eq(old_root, new_root), "Condemned node not replaced");
for i in 0..LEVEL_CELLS as u8 {
assert_eq!(i, *map.get(&i, &pin).unwrap());
}
}
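    // Builds a map whose root is an empty inner node with a condemned slot,
    // the kind of leftover an interrupted prune could leave behind.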
fn with_leftover() -> Raw<TrivialConfig<u8>, MakeSplatHasher> {
let map = Raw::<TrivialConfig<u8>, _>::with_hasher(MakeSplatHasher);
let pin = crossbeam_epoch::pin();
let i = Inner::default();
i.0[0].fetch_or(NodeFlags::CONDEMNED.bits(), Ordering::Relaxed, &pin);
map.root.store(Owned::new(i), Ordering::Relaxed);
assert!(iterator::Iter::new(&map).next().is_none());
assert!(!map.is_empty());
map
}
#[test]
fn prune_on_insert_empty() {
let mut map = with_leftover();
let pin = crossbeam_epoch::pin();
let old_root = map.root.load(Ordering::Relaxed, &pin).as_raw();
assert!(map.insert(0, &pin).is_none());
map.assert_pruned();
let new_root = map.root.load(Ordering::Relaxed, &pin);
let new_flags = nf(new_root);
assert_eq!(NodeFlags::DATA, new_flags);
assert!(
!ptr::eq(old_root, new_root.as_raw()),
"Condemned node not replaced"
);
}
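    // A remove that runs into a condemned pointer helps with the pruning and
    // then finishes without finding anything to delete.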
#[test]
fn prune_on_remove() {
let map = Raw::<TrivialConfig<u8>, _>::with_hasher(MakeSplatHasher);
let pin = crossbeam_epoch::pin();
let i_inner = Inner::default();
let i_outer = Inner::default();
i_outer.0[0].store(
Owned::new(i_inner).with_tag(NodeFlags::CONDEMNED.bits()),
Ordering::Relaxed,
);
map.root.store(Owned::new(i_outer), Ordering::Relaxed);
assert!(iterator::Iter::new(&map).next().is_none());
assert!(!map.is_empty());
assert!(map.remove(&0, &pin).is_none());
eprintln!("{}", debug::PrintShape(&map));
assert_eq!(0, map.root.load(Ordering::Relaxed, &pin).tag());
}
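    // Illustrative usage sketches (test names chosen here, not part of the
    // original suite). They only exercise calls already used in the tests
    // above: insert, get, remove and get_or_insert_with under an epoch pin.
    #[test]
    fn insert_get_remove_roundtrip() {
        let map = Raw::<TrivialConfig<u8>, _>::with_hasher(MakeSplatHasher);
        let pin = crossbeam_epoch::pin();
        // A fresh key: nothing was stored under it before.
        assert!(map.insert(42, &pin).is_none());
        assert_eq!(42, *map.get(&42, &pin).unwrap());
        // Removing hands the payload back and leaves the trie empty again.
        assert_eq!(42, *map.remove(&42, &pin).unwrap());
        assert!(map.get(&42, &pin).is_none());
        assert!(map.is_empty());
    }
    #[test]
    fn get_or_insert_with_constructs_lazily() {
        let map = Raw::<TrivialConfig<u8>, _>::with_hasher(MakeSplatHasher);
        let pin = crossbeam_epoch::pin();
        // Missing key: the constructor runs and the new payload is returned.
        assert_eq!(7, *map.get_or_insert_with(7, |k| k, &pin).into_inner());
        // Present key: the constructor must not run at all.
        let found = map
            .get_or_insert_with(7, |_| panic!("constructor must not run"), &pin)
            .into_inner();
        assert_eq!(7, *found);
    }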
}