use broadcaster::BroadcastChannel;
use crate::sync::Mutex;
/// An asynchronous barrier that makes a fixed number of tasks wait for each
/// other before any of them continues.
///
/// Each call to [`Barrier::wait`] registers one arrival; once `n` arrivals
/// have been counted, every waiting task is released and the barrier starts
/// a new generation, so it can be reused.
#[cfg(feature = "unstable")]
#[cfg_attr(feature = "docs", doc(cfg(unstable)))]
#[derive(Debug)]
pub struct Barrier {
    /// Arrival count, generation counter and the broadcast sender, guarded by
    /// the crate's async mutex (locked with `.lock().await` in `wait`).
    state: Mutex<BarrierState>,
    /// Handle on the same broadcast channel as `state.waker`; it is never read
    /// from directly, only cloned by non-leader tasks to subscribe to the
    /// release message.
    wait: BroadcastChannel<(usize, usize)>,
    /// Number of arrivals required to release the barrier (clamped to at least
    /// 1 by `Barrier::new`).
    n: usize,
}
/// Mutable state of a [`Barrier`]; only touched while holding `Barrier::state`.
#[derive(Debug)]
struct BarrierState {
    /// Channel the leader uses to broadcast `(generation_id, count)` when it
    /// releases the current generation.
    waker: BroadcastChannel<(usize, usize)>,
    /// Number of tasks that have arrived in the current generation.
    count: usize,
    /// Identifier of the current generation; starts at 1 and is advanced with
    /// `wrapping_add(1)` every time the barrier releases.
    generation_id: usize,
}
/// Result of [`Barrier::wait`].
///
/// Wraps a flag that is `true` only for the task whose arrival released the
/// barrier (the "leader"); query it with `is_leader`.
#[cfg(feature = "unstable")]
#[cfg_attr(feature = "docs", doc(cfg(unstable)))]
#[derive(Debug, Clone)]
pub struct BarrierWaitResult(bool);
impl Barrier {
    /// Creates a barrier that releases its waiting tasks once `n` tasks have
    /// called [`wait`](Barrier::wait).
    ///
    /// A value of `0` is treated the same as `1`: every call to `wait` then
    /// completes the group on its own and is reported as the leader.
    pub fn new(n: usize) -> Barrier {
        // Treat `n == 0` exactly like `n == 1`.
        let n = n.max(1);

        let waker = BroadcastChannel::new();
        // Second handle on the same channel, kept outside the mutex so that
        // waiting tasks can subscribe by cloning it.
        //
        // NOTE(review): neither `wait` nor `state.waker` is ever read from
        // directly. If BroadcastChannel queues every sent message on every
        // handle without a bound, each release leaves a message buffered on
        // both handles for the barrier's whole lifetime — confirm against the
        // broadcaster documentation.
        let wait = waker.clone();

        Barrier {
            state: Mutex::new(BarrierState {
                waker,
                count: 0,
                generation_id: 1,
            }),
            n,
            wait,
        }
    }

    /// Waits until `n` tasks (the value given to [`Barrier::new`]) have called
    /// this method, then releases all of them.
    ///
    /// The barrier is reusable: the arrival that completes a group resets the
    /// count and advances to a new generation.
    ///
    /// Exactly one task per generation — the one whose arrival completed the
    /// group — receives a [`BarrierWaitResult`] whose `is_leader()` is `true`;
    /// every other task receives `false`.
    ///
    /// # Panics
    ///
    /// Panics if the internal broadcast channel reports a closed sender while
    /// waiting, or a failed send when releasing the group.
    pub async fn wait(&self) -> BarrierWaitResult {
        let mut lock = self.state.lock().await;
        let local_gen = lock.generation_id;
        lock.count += 1;

        if lock.count < self.n {
            // Subscribe *before* releasing the lock. The leader broadcasts while
            // holding the lock, so this receiver already exists when the
            // release message for `local_gen` is sent (assuming a clone only
            // sees messages sent after it was created).
            let mut wait = self.wait.clone();
            let mut generation_id = lock.generation_id;
            let mut count = lock.count;
            drop(lock);

            // NOTE(review): if this future is dropped while waiting, the
            // `count += 1` above is never undone, so this generation will be
            // released with one fewer live waiter.
            while local_gen == generation_id && count < self.n {
                let (g, c) = wait.recv().await.expect("sender has not been closed");
                generation_id = g;
                count = c;
            }
            BarrierWaitResult(false)
        } else {
            // Last arrival: reset for the next generation and wake every
            // subscribed waiter.
            lock.count = 0;
            lock.generation_id = lock.generation_id.wrapping_add(1);
            lock.waker
                .send(&(lock.generation_id, lock.count))
                .await
                .expect("there should be at least one receiver");
            BarrierWaitResult(true)
        }
    }
}
impl BarrierWaitResult {
pub fn is_leader(&self) -> bool {
self.0
}
}
#[cfg(test)]
mod test {
    use futures::channel::mpsc::unbounded;
    use futures::sink::SinkExt;
    use futures::stream::StreamExt;

    use crate::sync::{Arc, Barrier};
    use crate::task;

    /// Runs `N` tasks through a barrier many times. Each round checks that no
    /// task is released before the last one arrives and that exactly one task
    /// is reported as the leader.
    #[test]
    fn test_barrier() {
        for _ in 0..1_000 {
            task::block_on(async move {
                const N: usize = 10;

                let shared = Arc::new(Barrier::new(N));
                let (sender, mut outcomes) = unbounded();

                // Spawn N - 1 waiters; each reports its leader flag once released.
                for _ in 0..N - 1 {
                    let barrier = Arc::clone(&shared);
                    let mut reporter = sender.clone();
                    task::spawn(async move {
                        let outcome = barrier.wait().await;
                        reporter.send(outcome.is_leader()).await.unwrap();
                    });
                }

                // Only N - 1 of N tasks have arrived, so nothing may be released yet.
                assert!(outcomes.try_next().is_err());

                // Our arrival completes the group; collect everyone's flag and
                // make sure the leader role was handed out exactly once.
                let mut saw_leader = shared.wait().await.is_leader();
                for _ in 0..N - 1 {
                    let was_leader = outcomes.next().await.unwrap();
                    if was_leader {
                        assert!(!saw_leader);
                        saw_leader = true;
                    }
                }
                assert!(saw_leader);
            });
        }
    }
}