use event_listener::{Event, EventListener};
use event_listener_strategy::{easy_wrapper, EventListenerFuture, Strategy};
use core::fmt;
use core::pin::Pin;
use core::task::Poll;
use crate::futures::Lock;
use crate::Mutex;
/// A synchronization point for a group of `n` tasks.
///
/// Each task calls [`Barrier::wait()`]. Tasks that arrive early wait until the
/// `n`-th task of the current generation arrives. That task is reported as the
/// leader, and every waiter of that generation is released.
#[derive(Debug)]
pub struct Barrier {
    /// Number of arrivals required to release a generation.
    n: usize,
    /// Arrival counter and generation id, guarded by the crate's mutex.
    state: Mutex<State>,
    /// Notified by the last arrival of a generation; early arrivals listen on it.
    event: Event,
}
/// Mutable bookkeeping shared by all waiters of a [`Barrier`].
#[derive(Debug)]
struct State {
    /// Arrivals so far in the current generation; reset to 0 by the last arrival.
    count: usize,
    /// Bumped (with wrapping) every time a generation is released, so a woken
    /// waiter can tell whether the generation it joined has completed.
    generation_id: u64,
}
impl Barrier {
    /// Creates a barrier that releases its waiters once `n` tasks have called
    /// [`Barrier::wait()`].
    ///
    /// With `n == 0` or `n == 1` no task ever has to wait: every caller is the
    /// last arrival of its own generation and is reported as the leader.
    pub const fn new(n: usize) -> Barrier {
        let initial = State {
            count: 0,
            generation_id: 0,
        };
        Self {
            n,
            state: Mutex::new(initial),
            event: Event::new(),
        }
    }

    /// Returns a future that completes once `n` tasks, this one included, have
    /// arrived in the same generation.
    ///
    /// The task whose arrival completes the group receives a result with
    /// `is_leader() == true`; all other tasks receive `false`.
    pub fn wait(&self) -> BarrierWait<'_> {
        // The lock future is only polled once the returned future is polled.
        let pending_lock = self.state.lock();
        let listener = EventListener::new();
        let inner = BarrierWaitInner {
            barrier: self,
            lock: Some(pending_lock),
            evl: listener,
            state: WaitState::Initial,
        };
        BarrierWait::_new(inner)
    }

    /// Synchronous counterpart of [`Barrier::wait()`]: runs the same future to
    /// completion on the calling thread via the wrapper's `wait()` method.
    #[cfg(all(feature = "std", not(target_family = "wasm")))]
    pub fn wait_blocking(&self) -> BarrierWaitResult {
        let pending = self.wait();
        pending.wait()
    }
}
easy_wrapper! {
    /// The future returned by [`Barrier::wait()`].
    ///
    /// Resolves to a [`BarrierWaitResult`] once the generation this call joined
    /// has been released.
    #[must_use = "futures do nothing unless you `.await` or poll them"]
    pub struct BarrierWait<'a>(BarrierWaitInner<'a> => BarrierWaitResult);
    #[cfg(all(feature = "std", not(target_family = "wasm")))]
    pub(crate) wait();
}
pin_project_lite::pin_project! {
    // State machine behind `BarrierWait`, driven by
    // `EventListenerFuture::poll_with_strategy`.
    struct BarrierWaitInner<'a> {
        // The barrier being waited on.
        barrier: &'a Barrier,
        // Pending acquisition of `barrier.state`. It is `Some` when created in
        // `wait()` and again on entry to `Reacquiring`. It is reset to `None`
        // once the guard has been obtained.
        #[pin]
        lock: Option<Lock<'a, State>>,
        // Listener registered on `barrier.event` before entering `Waiting`.
        #[pin]
        evl: EventListener,
        // Current step of the wait protocol.
        state: WaitState,
    }
}
impl fmt::Debug for BarrierWait<'_> {
    /// Prints an opaque `BarrierWait { .. }`; the internal state is not exposed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("BarrierWait").finish_non_exhaustive()
    }
}
/// Progress of a single [`BarrierWait`] future.
enum WaitState {
    /// Acquiring the state mutex for the first time to register this arrival.
    Initial,
    /// Arrived in generation `local_gen`; waiting for `barrier.event` to fire.
    Waiting { local_gen: u64 },
    /// Woken by the event; re-acquiring the mutex to check whether generation
    /// `local_gen` has been released.
    Reacquiring { local_gen: u64 },
}
impl EventListenerFuture for BarrierWaitInner<'_> {
    type Output = BarrierWaitResult;

    /// Drives the barrier wait protocol.
    ///
    /// - `Initial`: take the state mutex and count this arrival. The last
    ///   arrival resets the count, advances the generation, notifies every
    ///   listener and returns as leader. Earlier arrivals start listening.
    /// - `Waiting`: wait for `barrier.event`, then start re-acquiring the mutex.
    /// - `Reacquiring`: if the generation this task joined is still open, the
    ///   wake-up did not release it, so listen again. Otherwise return as a
    ///   non-leader.
    fn poll_with_strategy<'a, S: Strategy<'a>>(
        self: Pin<&mut Self>,
        strategy: &mut S,
        cx: &mut S::Context,
    ) -> Poll<Self::Output> {
        let mut this = self.project();

        loop {
            match this.state {
                WaitState::Initial => {
                    let mut state = ready!(this
                        .lock
                        .as_mut()
                        .as_pin_mut()
                        .expect("BarrierWait: lock future must be set in the `Initial` state")
                        .poll_with_strategy(strategy, cx));
                    // The `Lock` future has completed; the guard does not borrow it.
                    this.lock.as_mut().set(None);

                    let local_gen = state.generation_id;
                    state.count += 1;

                    if state.count < this.barrier.n {
                        // Not everyone has arrived yet: wait for the leader's notification.
                        this.evl.as_mut().listen(&this.barrier.event);
                        *this.state = WaitState::Waiting { local_gen };
                    } else {
                        // Last arrival: open a fresh generation and wake every waiter.
                        state.count = 0;
                        state.generation_id = state.generation_id.wrapping_add(1);
                        this.barrier.event.notify(usize::MAX);
                        return Poll::Ready(BarrierWaitResult { is_leader: true });
                    }
                }

                WaitState::Waiting { local_gen } => {
                    ready!(strategy.poll(this.evl.as_mut(), cx));

                    // Woken: look at the shared state again to see if we were released.
                    this.lock.as_mut().set(Some(this.barrier.state.lock()));
                    *this.state = WaitState::Reacquiring {
                        local_gen: *local_gen,
                    };
                }

                WaitState::Reacquiring { local_gen } => {
                    let state = ready!(this
                        .lock
                        .as_mut()
                        .as_pin_mut()
                        .expect("BarrierWait: lock future must be set in the `Reacquiring` state")
                        .poll_with_strategy(strategy, cx));
                    this.lock.as_mut().set(None);

                    if *local_gen == state.generation_id && state.count < this.barrier.n {
                        // Our generation is still open, so this wake-up did not release us.
                        this.evl.as_mut().listen(&this.barrier.event);
                        *this.state = WaitState::Waiting {
                            local_gen: *local_gen,
                        };
                    } else {
                        // The generation advanced: released, but not as the leader.
                        return Poll::Ready(BarrierWaitResult { is_leader: false });
                    }
                }
            }
        }
    }
}
/// The outcome of awaiting [`Barrier::wait()`] (or calling `wait_blocking()`).
#[derive(Debug, Clone)]
pub struct BarrierWaitResult {
    /// Whether this task's arrival was the one that released the barrier.
    is_leader: bool,
}

impl BarrierWaitResult {
    /// Returns `true` if this task was the leader of its barrier generation.
    ///
    /// Exactly one task per released group is the leader: the last one to
    /// arrive. Every other task in that group observes `false`.
    #[must_use]
    pub fn is_leader(&self) -> bool {
        self.is_leader
    }
}