1use std::borrow::Borrow;
25use std::collections::hash_map::RandomState;
26use std::collections::{self, BTreeSet};
27use std::fmt::{Debug, Error, Formatter};
28use std::hash::{BuildHasher, Hash};
29use std::iter::{FromIterator, FusedIterator, Sum};
30use std::ops::{Add, Deref, Mul};
31
32use archery::{SharedPointer, SharedPointerKind};
33
34use crate::nodes::hamt::{hash_key, Drain as NodeDrain, HashValue, Iter as NodeIter, Node};
35use crate::ordset::GenericOrdSet;
36use crate::shared_ptr::DefaultSharedPtr;
37use crate::GenericVector;
38
/// Build a [`HashSet`](crate::hashset::HashSet) from a comma-separated list of
/// element expressions.
///
/// `hashset![]` produces an empty set; `hashset![a, b, c]` inserts each element
/// in order. A single trailing comma is accepted (`hashset![a, b,]`).
#[macro_export]
macro_rules! hashset {
    () => { $crate::hashset::HashSet::new() };

    ( $($x:expr),+ $(,)? ) => {{
        let mut set = $crate::hashset::HashSet::new();
        $(
            set.insert($x);
        )+
        set
    }};
}
73
/// A persistent hash set that uses the standard library's [`RandomState`]
/// hasher builder and the crate's default shared-pointer kind.
pub type HashSet<A> = GenericHashSet<A, RandomState, DefaultSharedPtr>;

/// A persistent hash set, generic over the element type `A`, the hasher
/// builder `S`, and the shared-pointer kind `P` used for its internal nodes.
///
/// The element tree hangs off a shared root pointer. Mutating methods go
/// through `SharedPointer::make_mut`, so clones of a set share structure until
/// one of them is modified.
pub struct GenericHashSet<A, S, P: SharedPointerKind> {
    /// Hasher builder used to hash every element (see `hash_key`).
    hasher: S,
    /// Root node of the HAMT. `None` after `new`/`default`/`with_hasher`/`clear`;
    /// it can also be `Some` of a node that no longer holds any elements.
    root: Option<SharedPointer<Node<Value<A>, P>, P>>,
    /// Number of distinct elements, kept in sync by `insert`/`remove`/`retain`.
    size: usize,
}
104
/// Newtype wrapper stored in the HAMT nodes for each set element.
///
/// It exists so the set can implement the node layer's `HashValue` trait
/// (where the element is its own key) without putting a trait impl on `A`
/// itself.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
struct Value<A>(A);
107
108impl<A> Deref for Value<A> {
109 type Target = A;
110 fn deref(&self) -> &Self::Target {
111 &self.0
112 }
113}
114
115impl<A> HashValue for Value<A>
118where
119 A: Hash + Eq,
120{
121 type Key = A;
122
123 fn extract_key(&self) -> &Self::Key {
124 &self.0
125 }
126
127 fn ptr_eq(&self, _other: &Self) -> bool {
128 false
129 }
130}
131
132impl<A, S, P> GenericHashSet<A, S, P>
133where
134 A: Hash + Eq + Clone,
135 S: BuildHasher + Default + Clone,
136 P: SharedPointerKind,
137{
138 #[inline]
150 #[must_use]
151 pub fn unit(a: A) -> Self {
152 GenericHashSet::new().update(a)
153 }
154}
155
156impl<A, S, P: SharedPointerKind> GenericHashSet<A, S, P> {
157 #[must_use]
159 pub fn new() -> Self
160 where
161 S: Default,
162 {
163 Self::default()
164 }
165
166 #[inline]
183 #[must_use]
184 pub fn is_empty(&self) -> bool {
185 self.len() == 0
186 }
187
188 #[inline]
200 #[must_use]
201 pub fn len(&self) -> usize {
202 self.size
203 }
204
205 pub fn ptr_eq(&self, other: &Self) -> bool {
215 match (&self.root, &other.root) {
216 (Some(a), Some(b)) => SharedPointer::ptr_eq(a, b),
217 (None, None) => true,
218 _ => false,
219 }
220 }
221
222 #[inline]
224 #[must_use]
225 pub fn with_hasher(hasher: S) -> Self {
226 GenericHashSet {
227 size: 0,
228 root: None,
229 hasher,
230 }
231 }
232
233 #[must_use]
237 pub fn hasher(&self) -> &S {
238 &self.hasher
239 }
240
241 #[inline]
243 #[must_use]
244 pub fn new_from<A2>(&self) -> GenericHashSet<A2, S, P>
245 where
246 A2: Hash + Eq + Clone,
247 S: Clone,
248 {
249 GenericHashSet {
250 size: 0,
251 root: None,
252 hasher: self.hasher.clone(),
253 }
254 }
255
256 pub fn clear(&mut self) {
273 self.root = None;
274 self.size = 0;
275 }
276
277 #[must_use]
285 pub fn iter(&self) -> Iter<'_, A, P> {
286 Iter {
287 it: NodeIter::new(self.root.as_deref(), self.size),
288 }
289 }
290}
291
292impl<A, S, P> GenericHashSet<A, S, P>
293where
294 A: Hash + Eq,
295 S: BuildHasher,
296 P: SharedPointerKind,
297{
298 fn test_eq<S2: BuildHasher, P2: SharedPointerKind>(
299 &self,
300 other: &GenericHashSet<A, S2, P2>,
301 ) -> bool {
302 if self.len() != other.len() {
303 return false;
304 }
305 let mut seen = collections::HashSet::new();
306 for value in self.iter() {
307 if !other.contains(value) {
308 return false;
309 }
310 seen.insert(value);
311 }
312 for value in other.iter() {
313 if !seen.contains(&value) {
314 return false;
315 }
316 }
317 true
318 }
319
320 #[must_use]
324 pub fn contains<BA>(&self, a: &BA) -> bool
325 where
326 BA: Hash + Eq + ?Sized,
327 A: Borrow<BA>,
328 {
329 if let Some(root) = &self.root {
330 root.get(hash_key(&self.hasher, a), 0, a).is_some()
331 } else {
332 false
333 }
334 }
335
336 #[must_use]
341 pub fn is_subset<RS>(&self, other: RS) -> bool
342 where
343 RS: Borrow<Self>,
344 {
345 let o = other.borrow();
346 self.iter().all(|a| o.contains(a))
347 }
348
349 #[must_use]
355 pub fn is_proper_subset<RS>(&self, other: RS) -> bool
356 where
357 RS: Borrow<Self>,
358 {
359 self.len() != other.borrow().len() && self.is_subset(other)
360 }
361}
362
363impl<A, S, P> GenericHashSet<A, S, P>
364where
365 A: Hash + Eq + Clone,
366 S: BuildHasher + Clone,
367 P: SharedPointerKind,
368{
369 #[inline]
373 pub fn insert(&mut self, a: A) -> Option<A> {
374 let hash = hash_key(&self.hasher, &a);
375 let root = SharedPointer::make_mut(self.root.get_or_insert_with(Default::default));
376 match root.insert(hash, 0, Value(a)) {
377 None => {
378 self.size += 1;
379 None
380 }
381 Some(Value(old_value)) => Some(old_value),
382 }
383 }
384
385 pub fn remove<BA>(&mut self, a: &BA) -> Option<A>
389 where
390 BA: Hash + Eq + ?Sized,
391 A: Borrow<BA>,
392 {
393 let root = SharedPointer::make_mut(self.root.get_or_insert_with(Default::default));
394 let result = root.remove(hash_key(&self.hasher, a), 0, a);
395 if result.is_some() {
396 self.size -= 1;
397 }
398 result.map(|v| v.0)
399 }
400
401 #[must_use]
419 pub fn update(&self, a: A) -> Self {
420 let mut out = self.clone();
421 out.insert(a);
422 out
423 }
424
425 #[must_use]
430 pub fn without<BA>(&self, a: &BA) -> Self
431 where
432 BA: Hash + Eq + ?Sized,
433 A: Borrow<BA>,
434 {
435 let mut out = self.clone();
436 out.remove(a);
437 out
438 }
439
440 pub fn retain<F>(&mut self, mut f: F)
460 where
461 F: FnMut(&A) -> bool,
462 {
463 let Some(root) = &mut self.root else {
464 return;
465 };
466 let old_root = root.clone();
467 let root = SharedPointer::make_mut(root);
468 for (value, hash) in NodeIter::new(Some(&old_root), self.size) {
469 if !f(value) && root.remove(hash, 0, value).is_some() {
470 self.size -= 1;
471 }
472 }
473 }
474
475 #[must_use]
490 pub fn union(self, other: Self) -> Self {
491 let (mut to_mutate, to_consume) = if self.len() >= other.len() {
492 (self, other)
493 } else {
494 (other, self)
495 };
496 for value in to_consume {
497 to_mutate.insert(value);
498 }
499 to_mutate
500 }
501
502 #[must_use]
506 pub fn unions<I>(i: I) -> Self
507 where
508 I: IntoIterator<Item = Self>,
509 S: Default,
510 {
511 i.into_iter().fold(Self::default(), Self::union)
512 }
513
514 #[deprecated(
534 since = "2.0.1",
535 note = "to avoid conflicting behaviors between std and imbl, the `difference` alias for `symmetric_difference` will be removed."
536 )]
537 #[must_use]
538 pub fn difference(self, other: Self) -> Self {
539 self.symmetric_difference(other)
540 }
541
542 #[must_use]
557 pub fn symmetric_difference(mut self, other: Self) -> Self {
558 for value in other {
559 if self.remove(&value).is_none() {
560 self.insert(value);
561 }
562 }
563 self
564 }
565
566 #[must_use]
582 pub fn relative_complement(mut self, other: Self) -> Self {
583 for value in other {
584 let _ = self.remove(&value);
585 }
586 self
587 }
588
589 #[must_use]
604 pub fn intersection(self, other: Self) -> Self {
605 let mut out = self.new_from();
606 for value in other {
607 if self.contains(&value) {
608 out.insert(value);
609 }
610 }
611 out
612 }
613}
614
615impl<A, S, P: SharedPointerKind> Clone for GenericHashSet<A, S, P>
618where
619 A: Clone,
620 S: Clone,
621 P: SharedPointerKind,
622{
623 #[inline]
627 fn clone(&self) -> Self {
628 GenericHashSet {
629 hasher: self.hasher.clone(),
630 root: self.root.clone(),
631 size: self.size,
632 }
633 }
634}
635
636impl<A, S1, P1, S2, P2> PartialEq<GenericHashSet<A, S2, P2>> for GenericHashSet<A, S1, P1>
637where
638 A: Hash + Eq,
639 S1: BuildHasher,
640 S2: BuildHasher,
641 P1: SharedPointerKind,
642 P2: SharedPointerKind,
643{
644 fn eq(&self, other: &GenericHashSet<A, S2, P2>) -> bool {
645 self.test_eq(other)
646 }
647}
648
/// Set equality (defined by the `PartialEq` impl) is a full equivalence
/// relation whenever the element type's own `Eq` is one.
impl<A, S, P> Eq for GenericHashSet<A, S, P>
where
    A: Hash + Eq,
    S: BuildHasher,
    P: SharedPointerKind,
{
}
656
657impl<A, S, P> Default for GenericHashSet<A, S, P>
658where
659 S: Default,
660 P: SharedPointerKind,
661{
662 fn default() -> Self {
663 GenericHashSet {
664 hasher: Default::default(),
665 root: None,
666 size: 0,
667 }
668 }
669}
670
671impl<A, S, P> Add for GenericHashSet<A, S, P>
672where
673 A: Hash + Eq + Clone,
674 S: BuildHasher + Clone,
675 P: SharedPointerKind,
676{
677 type Output = GenericHashSet<A, S, P>;
678
679 fn add(self, other: Self) -> Self::Output {
680 self.union(other)
681 }
682}
683
684impl<A, S, P> Mul for GenericHashSet<A, S, P>
685where
686 A: Hash + Eq + Clone,
687 S: BuildHasher + Clone,
688 P: SharedPointerKind,
689{
690 type Output = GenericHashSet<A, S, P>;
691
692 fn mul(self, other: Self) -> Self::Output {
693 self.intersection(other)
694 }
695}
696
697impl<'a, A, S, P> Add for &'a GenericHashSet<A, S, P>
698where
699 A: Hash + Eq + Clone,
700 S: BuildHasher + Clone,
701 P: SharedPointerKind,
702{
703 type Output = GenericHashSet<A, S, P>;
704
705 fn add(self, other: Self) -> Self::Output {
706 self.clone().union(other.clone())
707 }
708}
709
710impl<'a, A, S, P> Mul for &'a GenericHashSet<A, S, P>
711where
712 A: Hash + Eq + Clone,
713 S: BuildHasher + Clone,
714 P: SharedPointerKind,
715{
716 type Output = GenericHashSet<A, S, P>;
717
718 fn mul(self, other: Self) -> Self::Output {
719 self.clone().intersection(other.clone())
720 }
721}
722
723impl<A, S, P: SharedPointerKind> Sum for GenericHashSet<A, S, P>
724where
725 A: Hash + Eq + Clone,
726 S: BuildHasher + Default + Clone,
727 P: SharedPointerKind,
728{
729 fn sum<I>(it: I) -> Self
730 where
731 I: Iterator<Item = Self>,
732 {
733 it.fold(Self::default(), |a, b| a + b)
734 }
735}
736
737impl<A, S, R, P: SharedPointerKind> Extend<R> for GenericHashSet<A, S, P>
738where
739 A: Hash + Eq + Clone + From<R>,
740 S: BuildHasher + Clone,
741{
742 fn extend<I>(&mut self, iter: I)
743 where
744 I: IntoIterator<Item = R>,
745 {
746 for value in iter {
747 self.insert(From::from(value));
748 }
749 }
750}
751
752impl<A, S, P> Debug for GenericHashSet<A, S, P>
753where
754 A: Hash + Eq + Debug,
755 S: BuildHasher,
756 P: SharedPointerKind,
757{
758 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
759 f.debug_set().entries(self.iter()).finish()
760 }
761}
762
/// Borrowing iterator over a set's elements, returned by `iter()`.
pub struct Iter<'a, A, P: SharedPointerKind> {
    // Underlying HAMT traversal; it yields `(&Value<A>, hash)` pairs.
    it: NodeIter<'a, Value<A>, P>,
}
769
770impl<'a, A, P: SharedPointerKind> Clone for Iter<'a, A, P> {
772 fn clone(&self) -> Self {
773 Iter {
774 it: self.it.clone(),
775 }
776 }
777}
778
779impl<'a, A, P> Iterator for Iter<'a, A, P>
780where
781 A: 'a,
782 P: SharedPointerKind,
783{
784 type Item = &'a A;
785
786 fn next(&mut self) -> Option<Self::Item> {
787 self.it.next().map(|(v, _)| &v.0)
788 }
789
790 fn size_hint(&self) -> (usize, Option<usize>) {
791 self.it.size_hint()
792 }
793}
794
// Relies on the node iterator reporting an exact `size_hint`; it is built with
// the set's element count in `iter()`.
impl<'a, A, P: SharedPointerKind> ExactSizeIterator for Iter<'a, A, P> {}

// NOTE(review): relies on `NodeIter` continuing to return `None` once exhausted
// — confirm in `nodes::hamt`.
impl<'a, A, P: SharedPointerKind> FusedIterator for Iter<'a, A, P> {}
798
/// Owning iterator over a set's elements, returned when a set is iterated by
/// value (`into_iter`).
pub struct ConsumingIter<A, P>
where
    A: Hash + Eq + Clone,
    P: SharedPointerKind,
{
    // Drains the set's root node, yielding `(Value<A>, hash)` pairs by value.
    it: NodeDrain<Value<A>, P>,
}
807
808impl<A, P> Iterator for ConsumingIter<A, P>
809where
810 A: Hash + Eq + Clone,
811 P: SharedPointerKind,
812{
813 type Item = A;
814
815 fn next(&mut self) -> Option<Self::Item> {
816 self.it.next().map(|(v, _)| v.0)
817 }
818
819 fn size_hint(&self) -> (usize, Option<usize>) {
820 self.it.size_hint()
821 }
822}
823
// Relies on the drain reporting an exact `size_hint`; it is built with the
// set's element count in `into_iter()`.
impl<A, P> ExactSizeIterator for ConsumingIter<A, P>
where
    A: Hash + Eq + Clone,
    P: SharedPointerKind,
{
}

// NOTE(review): relies on `NodeDrain` continuing to return `None` once
// exhausted — confirm in `nodes::hamt`.
impl<A, P> FusedIterator for ConsumingIter<A, P>
where
    A: Hash + Eq + Clone,
    P: SharedPointerKind,
{
}
837
838impl<A, RA, S, P> FromIterator<RA> for GenericHashSet<A, S, P>
841where
842 A: Hash + Eq + Clone + From<RA>,
843 S: BuildHasher + Default + Clone,
844 P: SharedPointerKind,
845{
846 fn from_iter<T>(i: T) -> Self
847 where
848 T: IntoIterator<Item = RA>,
849 {
850 let mut set = Self::default();
851 for value in i {
852 set.insert(From::from(value));
853 }
854 set
855 }
856}
857
858impl<'a, A, S, P> IntoIterator for &'a GenericHashSet<A, S, P>
859where
860 A: Hash + Eq,
861 S: BuildHasher,
862 P: SharedPointerKind,
863{
864 type Item = &'a A;
865 type IntoIter = Iter<'a, A, P>;
866
867 fn into_iter(self) -> Self::IntoIter {
868 self.iter()
869 }
870}
871
872impl<A, S, P> IntoIterator for GenericHashSet<A, S, P>
873where
874 A: Hash + Eq + Clone,
875 S: BuildHasher,
876 P: SharedPointerKind,
877{
878 type Item = A;
879 type IntoIter = ConsumingIter<Self::Item, P>;
880
881 fn into_iter(self) -> Self::IntoIter {
882 ConsumingIter {
883 it: NodeDrain::new(self.root, self.size),
884 }
885 }
886}
887
888impl<'s, 'a, A, OA, SA, SB, P1, P2> From<&'s GenericHashSet<&'a A, SA, P1>>
891 for GenericHashSet<OA, SB, P2>
892where
893 A: ToOwned<Owned = OA> + Hash + Eq + ?Sized,
894 OA: Borrow<A> + Hash + Eq + Clone,
895 SA: BuildHasher,
896 SB: BuildHasher + Default + Clone,
897 P1: SharedPointerKind,
898 P2: SharedPointerKind,
899{
900 fn from(set: &GenericHashSet<&A, SA, P1>) -> Self {
901 set.iter().map(|a| (*a).to_owned()).collect()
902 }
903}
904
905impl<A, S, const N: usize, P> From<[A; N]> for GenericHashSet<A, S, P>
906where
907 A: Hash + Eq + Clone,
908 S: BuildHasher + Default + Clone,
909 P: SharedPointerKind,
910{
911 fn from(arr: [A; N]) -> Self {
912 IntoIterator::into_iter(arr).collect()
913 }
914}
915
916impl<'a, A, S, P> From<&'a [A]> for GenericHashSet<A, S, P>
917where
918 A: Hash + Eq + Clone,
919 S: BuildHasher + Default + Clone,
920 P: SharedPointerKind,
921{
922 fn from(slice: &'a [A]) -> Self {
923 slice.iter().cloned().collect()
924 }
925}
926
927impl<A, S, P> From<Vec<A>> for GenericHashSet<A, S, P>
928where
929 A: Hash + Eq + Clone,
930 S: BuildHasher + Default + Clone,
931 P: SharedPointerKind,
932{
933 fn from(vec: Vec<A>) -> Self {
934 vec.into_iter().collect()
935 }
936}
937
938impl<'a, A, S, P> From<&'a Vec<A>> for GenericHashSet<A, S, P>
939where
940 A: Hash + Eq + Clone,
941 S: BuildHasher + Default + Clone,
942 P: SharedPointerKind,
943{
944 fn from(vec: &Vec<A>) -> Self {
945 vec.iter().cloned().collect()
946 }
947}
948
949impl<A, S, P1, P2> From<GenericVector<A, P2>> for GenericHashSet<A, S, P1>
950where
951 A: Hash + Eq + Clone,
952 S: BuildHasher + Default + Clone,
953 P1: SharedPointerKind,
954 P2: SharedPointerKind,
955{
956 fn from(vector: GenericVector<A, P2>) -> Self {
957 vector.into_iter().collect()
958 }
959}
960
961impl<'a, A, S, P1, P2> From<&'a GenericVector<A, P2>> for GenericHashSet<A, S, P1>
962where
963 A: Hash + Eq + Clone,
964 S: BuildHasher + Default + Clone,
965 P1: SharedPointerKind,
966 P2: SharedPointerKind,
967{
968 fn from(vector: &GenericVector<A, P2>) -> Self {
969 vector.iter().cloned().collect()
970 }
971}
972
973impl<A, S, P> From<collections::HashSet<A>> for GenericHashSet<A, S, P>
974where
975 A: Eq + Hash + Clone,
976 S: BuildHasher + Default + Clone,
977 P: SharedPointerKind,
978{
979 fn from(hash_set: collections::HashSet<A>) -> Self {
980 hash_set.into_iter().collect()
981 }
982}
983
984impl<'a, A, S, P> From<&'a collections::HashSet<A>> for GenericHashSet<A, S, P>
985where
986 A: Eq + Hash + Clone,
987 S: BuildHasher + Default + Clone,
988 P: SharedPointerKind,
989{
990 fn from(hash_set: &collections::HashSet<A>) -> Self {
991 hash_set.iter().cloned().collect()
992 }
993}
994
995impl<'a, A, S, P> From<&'a BTreeSet<A>> for GenericHashSet<A, S, P>
996where
997 A: Hash + Eq + Clone,
998 S: BuildHasher + Default + Clone,
999 P: SharedPointerKind,
1000{
1001 fn from(btree_set: &BTreeSet<A>) -> Self {
1002 btree_set.iter().cloned().collect()
1003 }
1004}
1005
1006impl<A, S, P1, P2> From<GenericOrdSet<A, P2>> for GenericHashSet<A, S, P1>
1007where
1008 A: Ord + Hash + Eq + Clone,
1009 S: BuildHasher + Default + Clone,
1010 P1: SharedPointerKind,
1011 P2: SharedPointerKind,
1012{
1013 fn from(ordset: GenericOrdSet<A, P2>) -> Self {
1014 ordset.into_iter().collect()
1015 }
1016}
1017
1018impl<'a, A, S, P1, P2> From<&'a GenericOrdSet<A, P2>> for GenericHashSet<A, S, P1>
1019where
1020 A: Ord + Hash + Eq + Clone,
1021 S: BuildHasher + Default + Clone,
1022 P1: SharedPointerKind,
1023 P2: SharedPointerKind,
1024{
1025 fn from(ordset: &GenericOrdSet<A, P2>) -> Self {
1026 ordset.into_iter().cloned().collect()
1027 }
1028}
1029
// Legacy location of the `hash_set` proptest strategy. It is kept only as a
// deprecated, doc-hidden re-export of `crate::proptest::hash_set` for
// backward compatibility.
#[cfg(any(test, feature = "proptest"))]
#[doc(hidden)]
pub mod proptest {
    #[deprecated(
        since = "14.3.0",
        note = "proptest strategies have moved to imbl::proptest"
    )]
    pub use crate::proptest::hash_set;
}
1040
#[cfg(test)]
mod test {
    use super::proptest::*;
    use super::*;
    use crate::test::LolHasher;
    use ::proptest::num::i16;
    use ::proptest::proptest;
    use static_assertions::{assert_impl_all, assert_not_impl_any};
    use std::hash::BuildHasherDefault;

    // Compile-time checks: the default set is Send + Sync for thread-safe
    // elements, is neither for raw-pointer elements, and is covariant in `T`.
    assert_impl_all!(HashSet<i32>: Send, Sync);
    assert_not_impl_any!(HashSet<*const i32>: Send, Sync);
    assert_covariant!(HashSet<T> in T);

    /// Regression test: two distinct values inserted under `LolHasher`
    /// (presumably a deliberately weak test hasher — see `crate::test`)
    /// must give a set of length 2.
    #[test]
    fn insert_failing() {
        let mut set: GenericHashSet<i16, BuildHasherDefault<LolHasher>, DefaultSharedPtr> =
            Default::default();
        set.insert(14658);
        assert_eq!(1, set.len());
        set.insert(-19198);
        assert_eq!(2, set.len());
    }

    /// A `HashSet<String>` built from a `HashSet<&str>` can be queried and
    /// shrunk through `&str` keys via `Borrow`.
    #[test]
    fn match_strings_with_string_slices() {
        let mut set: HashSet<String> = From::from(&hashset!["foo", "bar"]);
        set = set.without("bar");
        assert!(!set.contains("bar"));
        set.remove("foo");
        assert!(!set.contains("foo"));
    }

    /// `hashset!` accepts a trailing comma and builds the same set as without.
    #[test]
    fn macro_allows_trailing_comma() {
        let set1 = hashset! {"foo", "bar"};
        let set2 = hashset! {
            "foo",
            "bar",
        };
        assert_eq!(set1, set2);
    }

    /// Regression test for issue 60: draining a cloned set with `into_iter`
    /// must yield exactly the inserted elements, across 1000 hasher seeds.
    #[test]
    fn issue_60_drain_iterator_memory_corruption() {
        use crate::test::MetroHashBuilder;
        for i in 0..1000 {
            let mut lhs = vec![0, 1, 2];
            lhs.sort_unstable();

            let hasher = MetroHashBuilder::new(i);
            let mut iset: GenericHashSet<_, MetroHashBuilder, DefaultSharedPtr> =
                GenericHashSet::with_hasher(hasher.clone());
            // The inner `i` shadows the seed counter only inside this loop.
            for &i in &lhs {
                iset.insert(i);
            }

            let mut rhs: Vec<_> = iset.clone().into_iter().collect();
            rhs.sort_unstable();

            if lhs != rhs {
                println!("iteration: {}", i);
                println!("seed: {}", hasher.seed());
                println!("lhs: {}: {:?}", lhs.len(), &lhs);
                println!("rhs: {}: {:?}", rhs.len(), &rhs);
                panic!();
            }
        }
    }

    proptest! {
        // The `hash_set` strategy respects its requested size range.
        #[test]
        fn proptest_a_set(ref s in hash_set(".*", 10..100)) {
            assert!(s.len() < 100);
            assert!(s.len() >= 10);
        }
    }
}
1118}