use std::ops::Deref;
use std::ptr::{self, NonNull, addr_of};
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::{self, Relaxed};
use crate::collectible::{Collectible, Link};
use crate::collector::Collector;
/// A heap-allocated value bundled with intrusive bookkeeping used both for
/// reference counting (while the value is live) and for chaining the object
/// into a collector's garbage list (after the last reference is dropped).
pub(super) struct RefCounted<T> {
    // The wrapped user value, exposed via `Deref` and the `get_mut_*` accessors.
    instance: T,
    // Dual-purpose field: viewed as an atomic reference count via `ref_cnt()`,
    // or as the next-pointer of the `Collectible` chain via
    // `next_ptr`/`set_next_ptr` — see `Link` for the actual representation.
    next_or_refcnt: Link,
}
impl<T> RefCounted<T> {
    /// Allocates a new `RefCounted` in the *shared* state on the heap and
    /// returns an owning raw pointer to it.
    ///
    /// NOTE(review): `Link::new_shared()` presumably initializes the count to
    /// `1` (odd values denote a live shared instance — `try_add_ref` and
    /// `add_ref` both require oddness) — confirm against `Link`.
    #[inline]
    pub(super) fn new_shared(instance: T) -> NonNull<RefCounted<T>> {
        let boxed = Box::new(Self {
            instance,
            next_or_refcnt: Link::new_shared(),
        });
        // SAFETY: `Box::into_raw` never returns a null pointer.
        unsafe { NonNull::new_unchecked(Box::into_raw(boxed)) }
    }

    /// Allocates a new `RefCounted` in the *unique* (exclusively owned) state
    /// on the heap and returns an owning raw pointer to it.
    ///
    /// NOTE(review): `Link::new_unique()` presumably initializes the count to
    /// `0` (see the assertion in `get_mut_unique`) — confirm against `Link`.
    #[inline]
    pub(super) fn new_unique(instance: T) -> NonNull<RefCounted<T>> {
        let boxed = Box::new(Self {
            instance,
            next_or_refcnt: Link::new_unique(),
        });
        // SAFETY: `Box::into_raw` never returns a null pointer.
        unsafe { NonNull::new_unchecked(Box::into_raw(boxed)) }
    }

    /// Tries to acquire an additional strong reference.
    ///
    /// Adds `2` to the counter only while the current value is odd, i.e. the
    /// instance is still in the shared state; returns `false` once the count
    /// has been zeroed (made even) by the final `drop_ref`. The supplied
    /// `order` is used as both the success and the failure ordering of the
    /// underlying read-modify-write.
    #[inline]
    pub(super) fn try_add_ref(&self, order: Ordering) -> bool {
        self.ref_cnt()
            .fetch_update(
                order,
                order,
                // Only odd counts are still claimable; `None` aborts the update.
                |r| {
                    if r & 1 == 1 { Some(r + 2) } else { None }
                },
            )
            .is_ok()
    }

    /// Returns a mutable reference to the payload if the caller holds the only
    /// remaining shared reference (count exactly `1`), otherwise `None`.
    ///
    /// `&mut self` already proves exclusive access to this memory, so a
    /// `Relaxed` load suffices for the count check.
    #[inline]
    pub(super) fn get_mut_shared(&mut self) -> Option<&mut T> {
        if self.ref_cnt().load(Relaxed) == 1 {
            Some(&mut self.instance)
        } else {
            None
        }
    }

    /// Returns a mutable reference to the payload of a uniquely owned
    /// instance; unique instances keep their count at `0` (debug-asserted).
    #[inline]
    pub(super) fn get_mut_unique(&mut self) -> &mut T {
        debug_assert_eq!(self.ref_cnt().load(Relaxed), 0);
        &mut self.instance
    }

    /// Unconditionally acquires an additional strong reference.
    ///
    /// Implemented as a CAS loop instead of `fetch_add` so every observed
    /// value can be validated in debug builds: the count must stay odd
    /// (shared state) and must not overflow past `usize::MAX - 2`.
    /// NOTE(review): `Relaxed` mirrors the usual refcount-increment rationale
    /// (the caller already holds a reference), as in `Arc::clone` — confirm
    /// this matches the collector's synchronization model.
    #[inline]
    pub(super) fn add_ref(&self) {
        let mut current = self.ref_cnt().load(Relaxed);
        loop {
            debug_assert_eq!(current & 1, 1);
            debug_assert!(current <= usize::MAX - 2, "reference count overflow");
            match self
                .ref_cnt()
                .compare_exchange_weak(current, current + 2, Relaxed, Relaxed)
            {
                // `compare_exchange_weak` may fail spuriously; retry with the
                // freshly observed value either way.
                Ok(_) => break,
                Err(actual) => {
                    current = actual;
                }
            }
        }
    }

    /// Releases one strong reference, returning `true` if it was the last one.
    ///
    /// The count decreases by `2`; when the last shared reference is dropped
    /// (observed count `1`) the counter is set to `0` (even = no longer
    /// shared, so concurrent `try_add_ref` calls fail) and the caller is
    /// expected to retire the object.
    /// NOTE(review): all accesses are `Relaxed`; presumably the epoch-based
    /// collector establishes the ordering needed before destruction — confirm.
    #[inline]
    pub(super) fn drop_ref(&self) -> bool {
        let mut current = self.ref_cnt().load(Relaxed);
        loop {
            debug_assert_ne!(current, 0);
            // `current == 1` is the final shared reference: jump to `0`
            // rather than underflow to an invalid value.
            let new = if current <= 1 { 0 } else { current - 2 };
            match self
                .ref_cnt()
                .compare_exchange_weak(current, new, Relaxed, Relaxed)
            {
                Ok(_) => break,
                Err(actual) => {
                    current = actual;
                }
            }
        }
        // `current` holds the value we successfully swapped out.
        current == 1
    }

    /// Projects a raw `RefCounted` pointer to a raw pointer to its payload,
    /// propagating null.
    ///
    /// `addr_of!` computes the field address without materializing a `&T`,
    /// so no reference to possibly-shared data is created here.
    #[inline]
    pub(super) fn inst_ptr(self_ptr: *const Self) -> *const T {
        if self_ptr.is_null() {
            ptr::null()
        } else {
            // SAFETY: `self_ptr` is non-null here; the caller is responsible
            // for it pointing at a live `RefCounted<T>`.
            unsafe { addr_of!((*self_ptr).instance) }
        }
    }

    /// Views the dual-purpose link field as the atomic reference counter.
    #[inline]
    pub(super) fn ref_cnt(&self) -> &AtomicUsize {
        self.next_or_refcnt.ref_cnt()
    }

    /// Hands `ptr` to the current thread's `Collector` for deferred disposal
    /// as a `dyn Collectible`.
    /// NOTE(review): `Collector` is opaque here; presumably reclamation is
    /// delayed until no thread can still hold a protected reference — confirm.
    #[inline]
    pub(super) fn pass_to_collector(ptr: *mut Self) {
        Collector::collect(Collector::current().as_ptr(), ptr as *mut dyn Collectible);
    }
}
impl<T> Deref for RefCounted<T> {
    type Target = T;

    /// Grants transparent, read-only access to the wrapped value.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let Self { instance, .. } = self;
        instance
    }
}
impl<T> Collectible for RefCounted<T> {
    /// Reads the next entry of the intrusive garbage chain from the link field.
    #[inline]
    fn next_ptr(&self) -> Option<NonNull<dyn Collectible>> {
        let link = &self.next_or_refcnt;
        link.next_ptr()
    }

    /// Writes the next entry of the intrusive garbage chain into the link field.
    #[inline]
    fn set_next_ptr(&self, next_ptr: Option<NonNull<dyn Collectible>>) {
        let link = &self.next_or_refcnt;
        link.set_next_ptr(next_ptr);
    }
}