[go: up one dir, main page]

keyed-set 0.4.0

Keyed Set: a hashbrown-based HashSet that indexes based on projections of its elements.
Documentation
use core::{
    hash::{BuildHasher, Hash, Hasher},
    marker::PhantomData,
};

use hashbrown::{
    hash_map::DefaultHashBuilder,
    raw::{RawIntoIter, RawIter},
};

/// A hash set that indexes its elements by a key projected out of each element.
///
/// Elements are stored in a raw `hashbrown` table. The hash of an element is
/// the hash of the key that `extractor` produces for it, computed with a
/// hasher built by `hash_builder`.
#[derive(Clone, Default)]
pub struct KeyedSet<T, Extractor, S = DefaultHashBuilder> {
    /// Raw table holding the elements themselves.
    inner: hashbrown::raw::RawTable<T>,
    /// Builds the hashers used to hash extracted keys.
    hash_builder: S,
    /// Projects each element onto the key it is indexed by.
    extractor: Extractor,
}

/// Iterating over a shared reference to the set yields shared references to
/// its elements.
impl<'a, T, Extractor, S> IntoIterator for &'a KeyedSet<T, Extractor, S> {
    type Item = &'a T;
    type IntoIter = Iter<'a, T>;
    /// Delegates to `KeyedSet::iter`.
    fn into_iter(self) -> Self::IntoIter {
        self.iter()
    }
}
/// Iterating over an exclusive reference to the set yields exclusive
/// references to its elements.
///
/// NOTE(review): unlike `KeyedSetGuard`, nothing on this path re-checks an
/// element's key after it has been mutated — callers presumably must not
/// change keys through these references; confirm intended contract.
impl<'a, T, Extractor, S> IntoIterator for &'a mut KeyedSet<T, Extractor, S> {
    type Item = &'a mut T;
    type IntoIter = IterMut<'a, T>;
    /// Delegates to `KeyedSet::iter_mut`.
    fn into_iter(self) -> Self::IntoIter {
        self.iter_mut()
    }
}
/// Projects a borrowed element of type `T` onto the key used to index it.
///
/// The lifetime `'a` is the lifetime of the borrow of the element, which
/// allows `Key` to borrow from the element it was extracted from.
pub trait KeyExtractor<'a, T> {
    /// The key type produced by the projection; it must be hashable.
    type Key: Hash;
    /// Returns the key of `from`.
    fn extract(&self, from: &'a T) -> Self::Key;
}
/// Any function or closure from `&'a T` to a hashable value is a key extractor.
impl<'a, T: 'a, U: Hash, F: Fn(&'a T) -> U> KeyExtractor<'a, T> for F {
    type Key = U;
    /// Calls the function on `from` and returns its result as the key.
    fn extract(&self, from: &'a T) -> Self::Key {
        self(from)
    }
}
/// The unit extractor uses a reference to the whole element as its key, so
/// the element type itself must be hashable.
impl<'a, T: 'a + Hash> KeyExtractor<'a, T> for () {
    type Key = &'a T;
    /// Returns `from` unchanged.
    fn extract(&self, from: &'a T) -> Self::Key {
        from
    }
}
impl<T, Extractor> KeyedSet<T, Extractor>
where
    Extractor: for<'a> KeyExtractor<'a, T>,
    for<'a> <Extractor as KeyExtractor<'a, T>>::Key: std::hash::Hash,
{
    pub fn new(extractor: Extractor) -> Self {
        Self {
            inner: Default::default(),
            hash_builder: Default::default(),
            extractor,
        }
    }
}

/// Formats the set as `KeyedSet {` followed by each element's `Debug`
/// representation and `", "`, then a closing `}`.
impl<T: std::fmt::Debug, Extractor, S> std::fmt::Debug for KeyedSet<T, Extractor, S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("KeyedSet {")?;
        // Every element, including the last, is followed by ", ".
        self.iter()
            .try_for_each(|element| write!(f, "{:?}, ", element))?;
        f.write_str("}")
    }
}

impl<T, Extractor, S> KeyedSet<T, Extractor, S>
where
    Extractor: for<'a> KeyExtractor<'a, T>,
    for<'a> <Extractor as KeyExtractor<'a, T>>::Key: std::hash::Hash,
    S: BuildHasher,
{
    /// Hashes `key` with a fresh hasher obtained from this set's hash builder.
    fn hash_key<Q: Hash>(&self, key: &Q) -> u64 {
        let mut hasher = self.hash_builder.build_hasher();
        key.hash(&mut hasher);
        hasher.finish()
    }

    /// Inserts `value` into the set.
    ///
    /// If an element with an equal key is already present, it is replaced and
    /// returned; otherwise `None` is returned.
    pub fn insert(&mut self, value: T) -> Option<T>
    where
        for<'a, 'b> <Extractor as KeyExtractor<'a, T>>::Key:
            PartialEq<<Extractor as KeyExtractor<'b, T>>::Key>,
    {
        let key = self.extractor.extract(&value);
        let hash = self.hash_key(&key);
        match self
            .inner
            .get_mut(hash, |i| self.extractor.extract(i).eq(&key))
        {
            Some(bucket) => {
                // `key` may borrow from `value`; release it before `value` is moved.
                core::mem::drop(key);
                Some(core::mem::replace(bucket, value))
            }
            None => {
                core::mem::drop(key);
                // The table needs to re-hash existing elements if it grows.
                let hasher = make_hasher(&self.hash_builder, &self.extractor);
                self.inner.insert(hash, value, hasher);
                None
            }
        }
    }
    /// Looks up `key` and returns an [`Entry`] that is either the existing
    /// element or a vacant slot that can be filled from `key`.
    pub fn entry<'a, K>(&'a mut self, key: K) -> Entry<'a, T, Extractor, K, S>
    where
        K: std::hash::Hash,
        for<'z> <Extractor as KeyExtractor<'z, T>>::Key: std::hash::Hash + PartialEq<K>,
    {
        <Self as IEntry<T, Extractor, S, DefaultBorrower>>::entry(self, key)
    }
    /// Stores `value` in the set, overwriting any element with an equal key,
    /// and returns a mutable reference to the stored element.
    pub fn write(&mut self, value: T) -> &mut T
    where
        for<'a, 'b> <Extractor as KeyExtractor<'a, T>>::Key:
            PartialEq<<Extractor as KeyExtractor<'b, T>>::Key>,
    {
        let key = self.extractor.extract(&value);
        let hash = self.hash_key(&key);
        match self
            .inner
            .get_mut(hash, |i| self.extractor.extract(i).eq(&key))
        {
            Some(bucket) => {
                // `key` may borrow from `value`; release it before `value` is moved.
                core::mem::drop(key);
                *bucket = value;
                let slot: *mut T = bucket;
                // SAFETY: `slot` was just derived from a live `&mut T` pointing
                // into `self.inner`, and `self` stays exclusively borrowed for
                // as long as the returned reference lives. Going through a raw
                // pointer only detaches the borrow from this match arm so that
                // the other arm may still use `self.inner`.
                unsafe { &mut *slot }
            }
            None => {
                core::mem::drop(key);
                // The table needs to re-hash existing elements if it grows.
                let hasher = make_hasher(&self.hash_builder, &self.extractor);
                let bucket = self.inner.insert(hash, value, hasher);
                // SAFETY: `bucket` was just returned by `insert` on
                // `self.inner`, so it points at the element that was stored,
                // and `self` stays exclusively borrowed for as long as the
                // returned reference lives.
                unsafe { &mut *bucket.as_ptr() }
            }
        }
    }
    /// Returns a shared reference to the element whose key equals `key`, if any.
    pub fn get<K>(&self, key: &K) -> Option<&T>
    where
        K: std::hash::Hash,
        for<'a> <Extractor as KeyExtractor<'a, T>>::Key: std::hash::Hash + PartialEq<K>,
    {
        let hash = self.hash_key(key);
        self.inner.get(hash, |i| self.extractor.extract(i).eq(key))
    }
    /// Returns a guard giving mutable access to the element whose key equals
    /// `key`, if any.
    ///
    /// When the guard is dropped it checks that the element's key still equals
    /// `key`, and panics if it does not.
    pub fn get_mut<'a, K>(&'a mut self, key: &'a K) -> Option<KeyedSetGuard<'a, K, T, Extractor>>
    where
        K: std::hash::Hash,
        for<'z> <Extractor as KeyExtractor<'z, T>>::Key: std::hash::Hash + PartialEq<K>,
    {
        let hash = self.hash_key(key);
        self.inner
            .get_mut(hash, |i| self.extractor.extract(i).eq(key))
            .map(|guarded| KeyedSetGuard {
                guarded,
                key,
                extractor: &self.extractor,
            })
    }
    /// Returns a plain mutable reference to the element whose key equals
    /// `key`, if any.
    ///
    /// No check is made afterwards that the element's key is unchanged; the
    /// caller is responsible for not modifying the key through the reference.
    pub fn get_mut_unguarded<'a, K>(&'a mut self, key: &K) -> Option<&'a mut T>
    where
        K: std::hash::Hash,
        for<'z> <Extractor as KeyExtractor<'z, T>>::Key: std::hash::Hash + PartialEq<K>,
    {
        let hash = self.hash_key(key);
        self.inner
            .get_mut(hash, |i| self.extractor.extract(i).eq(key))
    }
}
/// Entry lookup parameterised by a `Borrower`, which turns the owned lookup
/// key `K` into the borrowed form that is hashed and compared against the
/// extracted keys.
pub trait IEntry<T, Extractor, S, Borrower = DefaultBorrower>
where
    Extractor: for<'a> KeyExtractor<'a, T>,
    for<'a> <Extractor as KeyExtractor<'a, T>>::Key: std::hash::Hash,
    S: BuildHasher,
{
    /// Returns the entry for `key`: the existing element if one matches the
    /// borrowed form of `key`, otherwise a vacant entry that owns `key`.
    fn entry<'a, K>(&'a mut self, key: K) -> Entry<'a, T, Extractor, K, S>
    where
        Borrower: IBorrower<K>,
        <Borrower as IBorrower<K>>::Borrowed: std::hash::Hash,
        for<'z> <Extractor as KeyExtractor<'z, T>>::Key:
            std::hash::Hash + PartialEq<<Borrower as IBorrower<K>>::Borrowed>;
}
impl<T, Extractor, S, Borrower> IEntry<T, Extractor, S, Borrower> for KeyedSet<T, Extractor, S>
where
    Extractor: for<'a> KeyExtractor<'a, T>,
    for<'a> <Extractor as KeyExtractor<'a, T>>::Key: std::hash::Hash,
    S: BuildHasher,
{
    /// Looks the borrowed form of `key` up in the set.
    ///
    /// Returns `Entry::OccupiedEntry` with a mutable reference to the matching
    /// element, or `Entry::Vacant` holding the set and `key` when no element
    /// matches.
    fn entry<'a, K>(&'a mut self, key: K) -> Entry<'a, T, Extractor, K, S>
    where
        Borrower: IBorrower<K>,
        <Borrower as IBorrower<K>>::Borrowed: std::hash::Hash,
        for<'z> <Extractor as KeyExtractor<'z, T>>::Key:
            std::hash::Hash + PartialEq<<Borrower as IBorrower<K>>::Borrowed>,
    {
        match self.get_mut_unguarded(Borrower::borrow(&key)) {
            Some(entry) => {
                let slot: *mut T = entry;
                // SAFETY: `slot` was just derived from a live `&mut T` pointing
                // into `self`, which is exclusively borrowed for `'a`. Going
                // through a raw pointer only detaches the borrow from this
                // match arm so that the other arm may still move `self`.
                Entry::OccupiedEntry(unsafe { &mut *slot })
            }
            None => Entry::Vacant(VacantEntry { set: self, key }),
        }
    }
}
/// The default `IBorrower`: borrows a key as itself.
pub struct DefaultBorrower;
/// Describes how a value of type `T` is viewed as a borrowed lookup key.
pub trait IBorrower<T> {
    /// The borrowed form of `T`.
    type Borrowed;
    /// Returns the borrowed form of `value`.
    fn borrow(value: &T) -> &Self::Borrowed;
}
/// Identity borrow: every `T` is borrowed as `&T`.
impl<T> IBorrower<T> for DefaultBorrower {
    type Borrowed = T;

    /// Returns `value` unchanged.
    fn borrow(value: &T) -> &Self::Borrowed {
        value
    }
}
/// Iteration and size queries; these need no bounds on the extractor or hasher.
impl<T, Extractor, S> KeyedSet<T, Extractor, S> {
    /// Returns an iterator yielding a shared reference to every element.
    pub fn iter(&self) -> Iter<T> {
        Iter {
            // SAFETY (review note): the raw iterator itself carries no borrow
            // of the table; the lifetime on the returned `Iter` is what keeps
            // `self` borrowed while it is in use — confirm against the
            // hashbrown `RawTable::iter` contract.
            inner: unsafe { self.inner.iter() },
            marker: PhantomData,
        }
    }
    /// Returns an iterator yielding an exclusive reference to every element.
    ///
    /// NOTE(review): nothing here re-checks element keys after mutation.
    pub fn iter_mut(&mut self) -> IterMut<T> {
        IterMut {
            // SAFETY (review note): as in `iter`, the lifetime on the returned
            // `IterMut` is what keeps `self` exclusively borrowed while the
            // raw iterator is in use — confirm against the hashbrown
            // `RawTable::iter` contract.
            inner: unsafe { self.inner.iter() },
            marker: PhantomData,
        }
    }
    /// Returns the number of elements in the set.
    pub fn len(&self) -> usize {
        self.inner.len()
    }
    /// Returns `true` if the set contains no elements.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }
}

/// Guard around a mutable reference to an element of a `KeyedSet`.
///
/// It dereferences to the element. When dropped, it checks that the key
/// extracted from the element still equals the key it was looked up with and
/// panics otherwise.
pub struct KeyedSetGuard<'a, K, T, Extractor>
where
    Extractor: for<'z> KeyExtractor<'z, T>,
    for<'z> <Extractor as KeyExtractor<'z, T>>::Key: std::hash::Hash + PartialEq<K>,
{
    /// The element being guarded.
    guarded: &'a mut T,
    /// The key the element was looked up with.
    key: &'a K,
    /// The extractor used to re-derive the element's key on drop.
    extractor: &'a Extractor,
}
/// Shared access to the guarded element.
impl<'a, K, T, Extractor> std::ops::Deref for KeyedSetGuard<'a, K, T, Extractor>
where
    Extractor: for<'z> KeyExtractor<'z, T>,
    for<'z> <Extractor as KeyExtractor<'z, T>>::Key: std::hash::Hash + PartialEq<K>,
{
    type Target = T;

    /// Returns a shared reference to the guarded element.
    fn deref(&self) -> &Self::Target {
        self.guarded
    }
}
/// Exclusive access to the guarded element; the key is re-checked when the
/// guard is dropped, not on each access.
impl<'a, K, T, Extractor> std::ops::DerefMut for KeyedSetGuard<'a, K, T, Extractor>
where
    Extractor: for<'z> KeyExtractor<'z, T>,
    for<'z> <Extractor as KeyExtractor<'z, T>>::Key: std::hash::Hash + PartialEq<K>,
{
    /// Returns an exclusive reference to the guarded element.
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.guarded
    }
}
/// On drop, verifies that the guarded element still maps to the key it was
/// looked up with.
impl<'a, K, T, Extractor> Drop for KeyedSetGuard<'a, K, T, Extractor>
where
    Extractor: for<'z> KeyExtractor<'z, T>,
    for<'z> <Extractor as KeyExtractor<'z, T>>::Key: std::hash::Hash + PartialEq<K>,
{
    /// # Panics
    ///
    /// Panics if the key extracted from the element no longer equals the
    /// lookup key.
    fn drop(&mut self) {
        let key_unchanged = self.extractor.extract(&*self.guarded).eq(self.key);
        assert!(key_unchanged, "KeyedSetGuard dropped with new value that would change the key, breaking the internal table's invariants.");
    }
}

/// Iterator yielding elements by value; a thin wrapper around hashbrown's
/// `RawIntoIter`.
///
/// NOTE(review): nothing in the visible part of this file constructs it.
pub struct IntoIter<T>(RawIntoIter<T>);

/// The number of remaining elements is reported by the wrapped raw iterator.
impl<T> ExactSizeIterator for IntoIter<T> {
    /// Returns the wrapped iterator's remaining length.
    fn len(&self) -> usize {
        self.0.len()
    }
}
/// Yields the elements by value, forwarding to the wrapped raw iterator.
impl<T> Iterator for IntoIter<T> {
    type Item = T;
    /// Returns the next element, or `None` once the iterator is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        self.0.next()
    }
    /// Forwards the wrapped iterator's size hint.
    ///
    /// Without this override the default `(0, None)` would be reported, which
    /// contradicts the `ExactSizeIterator` implementation for this type.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.0.size_hint()
    }
}

/// Iterator over shared references to the elements of a `KeyedSet`.
pub struct Iter<'a, T> {
    /// Raw iterator over the table's buckets; it carries no lifetime itself.
    inner: RawIter<T>,
    /// Makes this type behave like the `&'a T` it yields: it ties the iterator
    /// to the borrow of the set and lets the auto traits (`Send`/`Sync`)
    /// depend on `T` instead of being inherited from the raw iterator alone.
    marker: PhantomData<&'a T>,
}
/// Yields a shared reference to each element in turn.
impl<'a, T: 'a> Iterator for Iter<'a, T> {
    type Item = &'a T;
    /// Returns a reference to the next element, or `None` once exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        // SAFETY: each bucket produced by the raw iterator refers to an
        // element of the table this iterator was created from, and that table
        // is kept borrowed for `'a` by whoever created this iterator.
        self.inner.next().map(|b| unsafe { b.as_ref() })
    }
    /// Forwards the raw iterator's size hint.
    ///
    /// Without this override the default `(0, None)` would be reported, which
    /// contradicts the `ExactSizeIterator` implementation for this type.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
/// The number of remaining elements is reported by the raw iterator.
impl<'a, T: 'a> ExactSizeIterator for Iter<'a, T> {
    /// Returns the raw iterator's remaining length.
    fn len(&self) -> usize {
        self.inner.len()
    }
}
/// Iterator over exclusive references to the elements of a `KeyedSet`.
pub struct IterMut<'a, T> {
    /// Raw iterator over the table's buckets; it carries no lifetime itself.
    inner: RawIter<T>,
    /// Makes this type behave like the `&'a mut T` it yields: it ties the
    /// iterator to the exclusive borrow of the set, makes it invariant in `T`,
    /// and lets the auto traits (`Send`/`Sync`) depend on `T`.
    marker: PhantomData<&'a mut T>,
}
/// Yields an exclusive reference to each element in turn.
impl<'a, T: 'a> Iterator for IterMut<'a, T> {
    type Item = &'a mut T;
    /// Returns an exclusive reference to the next element, or `None` once
    /// exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        // SAFETY: each bucket produced by the raw iterator refers to an
        // element of the table this iterator was created from, that table is
        // kept exclusively borrowed for `'a` by whoever created this iterator,
        // and the raw iterator visits every bucket at most once, so the
        // references handed out never alias.
        self.inner.next().map(|b| unsafe { b.as_mut() })
    }
    /// Forwards the raw iterator's size hint.
    ///
    /// Without this override the default `(0, None)` would be reported, which
    /// contradicts the `ExactSizeIterator` implementation for this type.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
/// The number of remaining elements is reported by the raw iterator.
impl<'a, T: 'a> ExactSizeIterator for IterMut<'a, T> {
    /// Returns the raw iterator's remaining length.
    fn len(&self) -> usize {
        self.inner.len()
    }
}

/// An entry for a key that was not found in the set.
pub struct VacantEntry<'a, T: 'a, Extractor, K, S> {
    /// The set in which the lookup was made.
    pub set: &'a mut KeyedSet<T, Extractor, S>,
    /// The key that was looked up; it is handed to the constructor closure on
    /// insertion.
    pub key: K,
}
/// Result of looking a key up with `entry`.
pub enum Entry<'a, T, Extractor, K, S = DefaultHashBuilder> {
    /// No element matched the key.
    Vacant(VacantEntry<'a, T, Extractor, K, S>),
    /// An element matched the key; holds a mutable reference to it.
    OccupiedEntry(&'a mut T),
}

impl<'a, T: 'a, Extractor, S, K> Entry<'a, T, Extractor, K, S>
where
    S: BuildHasher,
    for<'z> Extractor: KeyExtractor<'z, T>,
    for<'z, 'b> <Extractor as KeyExtractor<'z, T>>::Key:
        PartialEq<<Extractor as KeyExtractor<'b, T>>::Key>,
{
    /// Returns the element of an occupied entry, or builds one from the key
    /// with `f`, stores it in the set and returns a reference to it.
    pub fn get_or_insert_with(self, f: impl FnOnce(K) -> T) -> &'a mut T {
        // An occupied entry already holds the answer; `f` is never called.
        let vacant = match self {
            Entry::OccupiedEntry(existing) => return existing,
            Entry::Vacant(vacant) => vacant,
        };
        vacant.insert_with(f)
    }
    /// Like `get_or_insert_with`, building the missing element by converting
    /// the key with `Into`.
    pub fn get_or_insert_with_into(self) -> &'a mut T
    where
        K: Into<T>,
    {
        self.get_or_insert_with(<K as Into<T>>::into)
    }
}
impl<'a, K, T, Extractor, S> VacantEntry<'a, T, Extractor, K, S>
where
    S: BuildHasher,
    for<'z> Extractor: KeyExtractor<'z, T>,
    for<'z, 'b> <Extractor as KeyExtractor<'z, T>>::Key:
        PartialEq<<Extractor as KeyExtractor<'b, T>>::Key>,
{
    pub fn insert_with<F: FnOnce(K) -> T>(self, f: F) -> &'a mut T {
        self.set.write(f(self.key))
    }
}

/// Builds the closure the raw table uses to hash an element: it extracts the
/// element's key with `extractor` and hashes that key with a fresh hasher
/// from `hash_builder`.
fn make_hasher<'a, S: BuildHasher, Extractor, T>(
    hash_builder: &'a S,
    extractor: &'a Extractor,
) -> impl Fn(&T) -> u64 + 'a
where
    Extractor: for<'b> KeyExtractor<'b, T>,
    for<'b> <Extractor as KeyExtractor<'b, T>>::Key: std::hash::Hash,
{
    move |element| {
        let projected = extractor.extract(element);
        let mut state = hash_builder.build_hasher();
        Hash::hash(&projected, &mut state);
        state.finish()
    }
}

/// Smoke test: insertion, replacement, lookup and the entry API on a set of
/// pairs keyed by their first component.
#[test]
fn test() {
    let mut set = KeyedSet::new(|value: &(u64, u64)| value.0);
    assert_eq!(set.len(), 0);

    // A second insertion under the same key replaces the first element.
    set.insert((0, 0));
    let displaced = set.insert((0, 1));
    assert_eq!(displaced, Some((0, 0)));
    assert_eq!(set.len(), 1);

    // Lookups go through the projected key only.
    let found = set.get(&0);
    assert_eq!(found, Some(&(0, 1)));
    let missing = set.get(&1);
    assert!(missing.is_none());

    // A vacant entry builds its element from the key.
    let created = set.entry(12).get_or_insert_with(|k| (k, k));
    assert_eq!(*created, (12, 12));
    dbg!(&set);
}