#![doc(html_root_url = "https://docs.rs/arc-swap/0.1.3/arc-swap/")]
#![deny(missing_docs)]
use std::fmt::{Debug, Display, Formatter, Result as FmtResult};
use std::marker::PhantomData;
use std::mem;
use std::sync::atomic::{self, AtomicPtr, AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
/// Guard proving the holder is registered as a reader in one generation
/// group (the contained index is the group it incremented).
///
/// It must be consumed by `ArcSwap::gen_unlock`, which `mem::forget`s it;
/// the `Drop` impl below is a deliberate "drop bomb" so that any code path
/// that forgets to unlock aborts loudly instead of wedging writers forever.
struct GenLock(usize);

impl Drop for GenLock {
    fn drop(&mut self) {
        // Reaching this drop means some path failed to call `gen_unlock`.
        unreachable!("Forgot to unlock generation");
    }
}
const GEN_CNT: usize = 2;
/// Consumes the `Arc` and returns its raw payload pointer.
///
/// The strong reference is leaked into the returned pointer; the caller is
/// responsible for eventually reviving it with `Arc::from_raw` so the
/// allocation is not leaked.
fn strip<T>(arc: Arc<T>) -> *mut T {
    let raw: *const T = Arc::into_raw(arc);
    raw as *mut T
}
/// An atomic storage cell for an `Arc<T>`.
///
/// The stored `Arc` can be loaded, stored and swapped atomically through a
/// shared reference, so readers and writers may run concurrently.
pub struct ArcSwap<T> {
    // Raw pointer to the payload of the currently stored `Arc`; the cell
    // owns exactly one strong reference, represented by this pointer.
    ptr: AtomicPtr<T>,
    // Current reader generation; taken modulo `GEN_CNT` to pick a group.
    gen_idx: AtomicUsize,
    // Count of in-flight readers registered in each generation group.
    reader_group_cnts: [AtomicUsize; GEN_CNT],
    // Marks logical ownership of an `Arc<T>` (drop-check / variance).
    _phantom_arc: PhantomData<Arc<T>>,
}
impl<T> From<Arc<T>> for ArcSwap<T> {
fn from(arc: Arc<T>) -> Self {
let ptr = strip(arc);
Self {
ptr: AtomicPtr::new(ptr),
gen_idx: AtomicUsize::new(0),
reader_group_cnts: [AtomicUsize::new(0), AtomicUsize::new(0)],
_phantom_arc: PhantomData,
}
}
}
impl<T> Drop for ArcSwap<T> {
    /// Releases the strong reference held inside the cell.
    ///
    /// `&mut self` guarantees exclusive access, so no reader bookkeeping is
    /// needed here.
    fn drop(&mut self) {
        let raw = *self.ptr.get_mut();
        // SAFETY: `raw` originates from `Arc::into_raw` (via `strip`) and
        // stands for the single strong reference the cell owned.
        let owned = unsafe { Arc::from_raw(raw) };
        drop(owned);
    }
}
impl<T> Clone for ArcSwap<T> {
fn clone(&self) -> Self {
Self::from(self.load())
}
}
impl<T: Debug> Debug for ArcSwap<T> {
    /// Formats a snapshot of the current value, delegating to `T`'s `Debug`.
    fn fmt(&self, formatter: &mut Formatter) -> FmtResult {
        let snapshot = self.load();
        Debug::fmt(&snapshot, formatter)
    }
}
impl<T: Display> Display for ArcSwap<T> {
    /// Formats a snapshot of the current value, delegating to `T`'s
    /// `Display`.
    fn fmt(&self, formatter: &mut Formatter) -> FmtResult {
        let snapshot = self.load();
        Display::fmt(&snapshot, formatter)
    }
}
impl<T> ArcSwap<T> {
    /// Loads a snapshot of the currently stored `Arc`.
    ///
    /// The returned `Arc` keeps the value alive even if the cell is
    /// subsequently overwritten.
    #[inline]
    pub fn load(&self) -> Arc<T> {
        // Register this reader in the current generation group so a writer
        // in `wait_for_readers` knows the pointer may still be in use.
        let gen = self.gen_lock();
        let ptr = self.ptr.load(Ordering::Acquire);
        // Temporarily adopt the cell's strong reference...
        let arc = unsafe { Arc::from_raw(ptr) };
        // ...and immediately leak one extra clone back into raw form, so the
        // cell still owns a strong reference after we walk away with `arc`.
        Arc::into_raw(Arc::clone(&arc));
        self.gen_unlock(gen);
        arc
    }
    /// Replaces the stored value, dropping the previous one.
    #[inline]
    pub fn store(&self, arc: Arc<T>) {
        drop(self.swap(arc));
    }
    /// Atomically exchanges the stored `Arc` for `arc`, returning the
    /// previous one.
    #[inline]
    pub fn swap(&self, arc: Arc<T>) -> Arc<T> {
        let new = strip(arc);
        let old = self.ptr.swap(new, Ordering::SeqCst);
        // Ensure no reader is still mid-`load` on the old pointer before
        // handing its strong reference to the caller, who may drop it.
        self.wait_for_readers();
        unsafe { Arc::from_raw(old) }
    }
    /// Stores `new` only if the cell currently holds `current` (pointer
    /// identity, not value equality).
    ///
    /// Returns `(swapped, previous)`, where `previous` is the value observed
    /// at the moment of the comparison.
    #[inline]
    pub fn compare_and_swap(&self, current: Arc<T>, new: Arc<T>) -> (bool, Arc<T>) {
        let current = strip(current);
        let new = strip(new);
        // Act as a reader: `previous` is revived from the stored pointer
        // below, which must stay alive while we touch it.
        let gen = self.gen_lock();
        let previous = self.ptr.compare_and_swap(current, new, Ordering::SeqCst);
        let swapped = current == previous;
        // Adopt the strong reference that `previous` stands for.
        let previous = unsafe { Arc::from_raw(previous) };
        if swapped {
            // Success: the cell gave up its reference to `previous`; we keep
            // it and return it to the caller.
        } else {
            // Failure: the cell still owns `previous`, so leak one clone
            // back to compensate for the `from_raw` above.
            Arc::into_raw(Arc::clone(&previous));
        }
        self.gen_unlock(gen);
        if swapped {
            // The old pointer left the cell; wait until no reader can still
            // be using it before the caller might drop it.
            self.wait_for_readers();
        } else {
            // `new` was not installed; reclaim and drop its reference.
            drop(unsafe { Arc::from_raw(new) });
        }
        // The caller's `current` reference is always ours to release here.
        drop(unsafe { Arc::from_raw(current) });
        (swapped, previous)
    }
    /// Spins until each reader generation group has been observed empty at
    /// least once since this call started, at which point no reader can
    /// still hold a pointer swapped out before the call.
    #[inline]
    fn wait_for_readers(&self) {
        let mut seen_group = [false; GEN_CNT];
        while !seen_group.iter().all(|seen| *seen) {
            let gen = self.gen_idx.load(Ordering::Relaxed);
            let groups = [
                self.reader_group_cnts[0].load(Ordering::Acquire),
                self.reader_group_cnts[1].load(Ordering::Acquire),
            ];
            let next_gen = gen.wrapping_add(1);
            // If the *other* group is empty, advance the generation so new
            // readers register there and the current group can drain.
            if groups[next_gen % GEN_CNT] == 0 {
                self.gen_idx
                    .compare_and_swap(gen, next_gen, Ordering::Relaxed);
            }
            // Record every group seen empty in this pass.
            for i in 0..GEN_CNT {
                seen_group[i] = seen_group[i] || (groups[i] == 0);
            }
            atomic::spin_loop_hint();
        }
    }
    /// Registers the caller as a reader in the current generation group and
    /// returns a guard carrying the group index.
    #[inline]
    fn gen_lock(&self) -> GenLock {
        let gen = self.gen_idx.load(Ordering::Relaxed) % GEN_CNT;
        self.reader_group_cnts[gen].fetch_add(1, Ordering::SeqCst);
        GenLock(gen)
    }
    /// Deregisters a reader previously registered by `gen_lock`.
    #[inline]
    fn gen_unlock(&self, lock: GenLock) {
        let gen = lock.0;
        // Defuse the `GenLock` drop bomb: this is the only sanctioned way to
        // dispose of the guard.
        mem::forget(lock);
        self.reader_group_cnts[gen].fetch_sub(1, Ordering::Release);
    }
    /// Read-copy-update: repeatedly computes a replacement from the current
    /// value and tries to install it until no concurrent writer interferes.
    ///
    /// Returns the value that was replaced. `f` may run several times under
    /// contention, so it should be free of side effects.
    pub fn rcu<R, F>(&self, mut f: F) -> Arc<T>
    where
        F: FnMut(&Arc<T>) -> R,
        R: Into<Arc<T>>,
    {
        let mut cur = self.load();
        loop {
            let new = f(&cur).into();
            let (swapped, prev) = self.compare_and_swap(cur, new);
            if swapped {
                return prev;
            } else {
                // Lost a race: retry against the value another writer stored.
                cur = prev;
            }
        }
    }
    /// Like `rcu`, but returns the replaced value by move.
    ///
    /// Spins (yielding the thread) until all other strong references to the
    /// replaced `Arc` are gone so it can be unwrapped.
    /// NOTE(review): this can loop indefinitely if another thread keeps a
    /// clone of the replaced value alive.
    pub fn rcu_unwrap<R, F>(&self, mut f: F) -> T
    where
        F: FnMut(&T) -> R,
        R: Into<Arc<T>>,
    {
        let mut wrapped = self.rcu(|prev| f(&*prev));
        loop {
            match Arc::try_unwrap(wrapped) {
                Ok(val) => return val,
                Err(w) => {
                    // Some reader still holds it; put it back and retry.
                    wrapped = w;
                    thread::yield_now();
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    extern crate crossbeam_utils;

    use std::sync::atomic::AtomicUsize;
    use std::sync::Barrier;

    // Scoped threads let the tests borrow stack-local data across threads.
    use self::crossbeam_utils::scoped as thread;

    use super::*;

    /// Twenty readers spin until they observe the configuration published
    /// by a single writer; checks the update is seen intact by everyone.
    #[test]
    fn publish() {
        for _ in 0..100 {
            let config = ArcSwap::from(Arc::new(String::default()));
            let ended = AtomicUsize::new(0);
            thread::scope(|scope| {
                for _ in 0..20 {
                    scope.spawn(|| loop {
                        let cfg = config.load();
                        if !cfg.is_empty() {
                            assert_eq!(*cfg, "New configuration");
                            ended.fetch_add(1, Ordering::Relaxed);
                            return;
                        }
                        atomic::spin_loop_hint();
                    });
                }
                scope.spawn(|| {
                    let new_conf = Arc::new("New configuration".to_owned());
                    config.store(new_conf);
                });
            });
            assert_eq!(20, ended.load(Ordering::Relaxed));
            // One reference inside the cell plus the freshly loaded one.
            assert_eq!(2, Arc::strong_count(&config.load()));
            assert_eq!(0, Arc::weak_count(&config.load()));
        }
    }

    /// Single-threaded sanity check of `swap`/`load` reference counting.
    #[test]
    fn swap_load() {
        for _ in 0..100 {
            let arc = Arc::new(42);
            let arc_swap = ArcSwap::from(Arc::clone(&arc));
            assert_eq!(42, *arc_swap.load());
            // Loading has no lasting effect on the count.
            assert_eq!(42, *arc_swap.load());
            let new_arc = Arc::new(0);
            assert_eq!(42, *arc_swap.swap(Arc::clone(&new_arc)));
            assert_eq!(0, *arc_swap.load());
            // `new_arc` + the stored copy + the fresh load = 3 strong refs.
            assert_eq!(3, Arc::strong_count(&arc_swap.load()));
            assert_eq!(0, Arc::weak_count(&arc_swap.load()));
            // The original was handed back by `swap` and dropped above.
            assert_eq!(1, Arc::strong_count(&arc));
            assert_eq!(0, Arc::weak_count(&arc));
        }
    }

    /// Two writers publish monotonic per-writer sequences; readers verify
    /// the sequence numbers they observe never go backwards.
    #[test]
    fn multi_writers() {
        let first_value = Arc::new((0, 0));
        let shared = ArcSwap::from(Arc::clone(&first_value));
        const WRITER_CNT: usize = 2;
        const READER_CNT: usize = 3;
        const ITERATIONS: usize = 100;
        const SEQ: usize = 50;
        let barrier = Barrier::new(READER_CNT + WRITER_CNT);
        thread::scope(|scope| {
            for w in 0..WRITER_CNT {
                let barrier = &barrier;
                let shared = &shared;
                let first_value = &first_value;
                scope.spawn(move || {
                    for _ in 0..ITERATIONS {
                        barrier.wait();
                        // Reset to the sentinel value before each round.
                        shared.store(Arc::clone(&first_value));
                        barrier.wait();
                        for i in 0..SEQ {
                            shared.store(Arc::new((w, i + 1)));
                        }
                    }
                });
            }
            for _ in 0..READER_CNT {
                scope.spawn(|| {
                    for _ in 0..ITERATIONS {
                        barrier.wait();
                        barrier.wait();
                        let mut previous = [0; 2];
                        let mut last = Arc::clone(&first_value);
                        loop {
                            let cur = shared.load();
                            if Arc::ptr_eq(&last, &cur) {
                                // Nothing new published yet; spin.
                                atomic::spin_loop_hint();
                                continue;
                            }
                            let (w, s) = *cur;
                            // Each writer's sequence is strictly increasing.
                            assert!(previous[w] < s);
                            previous[w] = s;
                            last = cur;
                            if s == SEQ {
                                break;
                            }
                        }
                    }
                });
            }
        });
    }

    /// Checks `compare_and_swap` reference-count bookkeeping on both the
    /// success and the failure path.
    #[test]
    fn cas_ref_cnt() {
        const ITERATIONS: usize = 50;
        let shared = ArcSwap::from(Arc::new(0));
        for i in 0..ITERATIONS {
            let orig = shared.load();
            assert_eq!(i, *orig);
            assert_eq!(2, Arc::strong_count(&orig));
            let n1 = Arc::new(i + 1);
            // Successful swap: `prev` is the replaced value.
            let (swapped, prev) = shared.compare_and_swap(Arc::clone(&orig), Arc::clone(&n1));
            assert!(swapped);
            assert!(Arc::ptr_eq(&orig, &prev));
            assert_eq!(2, Arc::strong_count(&orig));
            assert_eq!(2, Arc::strong_count(&n1));
            assert_eq!(i + 1, *shared.load());
            let n2 = Arc::new(i);
            drop(prev);
            // Failing swap: `orig` no longer matches; `prev` is `n1`.
            let (swapped, prev) = shared.compare_and_swap(Arc::clone(&orig), Arc::clone(&n2));
            assert!(!swapped);
            assert!(Arc::ptr_eq(&n1, &prev));
            assert_eq!(1, Arc::strong_count(&orig));
            assert_eq!(3, Arc::strong_count(&n1));
            // `n2` was not installed, so its extra clone was released.
            assert_eq!(1, Arc::strong_count(&n2));
            assert_eq!(i + 1, *shared.load());
        }
    }

    /// Concurrent counter increments through `rcu` must not lose updates.
    #[test]
    fn rcu() {
        const ITERATIONS: usize = 50;
        const THREADS: usize = 10;
        let shared = ArcSwap::from(Arc::new(0));
        thread::scope(|scope| {
            for _ in 0..THREADS {
                scope.spawn(|| {
                    for _ in 0..ITERATIONS {
                        shared.rcu(|old| **old + 1);
                    }
                });
            }
        });
        assert_eq!(THREADS * ITERATIONS, *shared.load());
    }

    /// Same lost-update check as `rcu`, via the unwrapping variant.
    #[test]
    fn rcu_unwrap() {
        const ITERATIONS: usize = 50;
        const THREADS: usize = 10;
        let shared = ArcSwap::from(Arc::new(0));
        thread::scope(|scope| {
            for _ in 0..THREADS {
                scope.spawn(|| {
                    for _ in 0..ITERATIONS {
                        shared.rcu_unwrap(|old| *old + 1);
                    }
                });
            }
        });
        assert_eq!(THREADS * ITERATIONS, *shared.load());
    }
}