use crate::loom::sync::Mutex;
use crate::sync::watch;
#[cfg(all(tokio_unstable, feature = "tracing"))]
use crate::util::trace;
/// A barrier that lets a group of `n` tasks synchronize.
///
/// Each call to [`Barrier::wait`] waits until `n` calls have been made for the
/// current generation. All of those tasks are then released together, and the
/// barrier resets itself for the next generation.
#[derive(Debug)]
pub struct Barrier {
// Arrival count, generation number, and the sender used to release waiters.
state: Mutex<BarrierState>,
// Receiver that each non-leader waiter clones to watch for its generation.
// Keeping it alive here also guarantees at least one receiver exists, which the
// leader's `send(..).expect("there is at least one receiver")` relies on.
wait: watch::Receiver<usize>,
// Number of arrivals that completes a generation; `Barrier::new` clamps it to at least 1.
n: usize,
// Tracing resource span created in `Barrier::new` (only with `tokio_unstable` + `tracing`).
#[cfg(all(tokio_unstable, feature = "tracing"))]
resource_span: tracing::Span,
}
/// Mutable barrier bookkeeping, guarded by the mutex in `Barrier::state`.
#[derive(Debug)]
struct BarrierState {
// Publishes the number of each generation as its leader releases it; waiters watch this.
waker: watch::Sender<usize>,
// Tasks that have arrived in the current generation; reset to 0 when the generation is released.
arrived: usize,
// Current generation number; starts at 1 and is incremented each time a generation is released.
generation: usize,
}
impl Barrier {
    /// Creates a barrier that releases its waiters once `n` tasks have called
    /// [`Barrier::wait`].
    ///
    /// A value of `0` is treated as `1`. With `n == 1`, every call to `wait`
    /// completes immediately and reports itself as the leader.
    #[track_caller]
    pub fn new(n: usize) -> Barrier {
        // A zero-sized barrier could never release anyone; clamp it to one.
        let n = n.max(1);
        let (waker, wait) = crate::sync::watch::channel(0);

        #[cfg(all(tokio_unstable, feature = "tracing"))]
        let resource_span = {
            let location = std::panic::Location::caller();
            let span = tracing::trace_span!(
                "runtime.resource",
                concrete_type = "Barrier",
                kind = "Sync",
                loc.file = location.file(),
                loc.line = location.line(),
                loc.col = location.column(),
            );
            span.in_scope(|| {
                tracing::trace!(
                    target: "runtime::resource::state_update",
                    size = n,
                );
                tracing::trace!(
                    target: "runtime::resource::state_update",
                    arrived = 0,
                )
            });
            span
        };

        Barrier {
            state: Mutex::new(BarrierState {
                waker,
                arrived: 0,
                generation: 1,
            }),
            n,
            wait,
            #[cfg(all(tokio_unstable, feature = "tracing"))]
            resource_span,
        }
    }

    /// Waits until `n` tasks, including this one, have reached the barrier in
    /// the current generation.
    ///
    /// The task whose arrival completes the generation gets a
    /// [`BarrierWaitResult`] whose [`is_leader`](BarrierWaitResult::is_leader)
    /// is `true`. Every other task in that generation gets `false`.
    ///
    /// NOTE(review): nothing in this type decrements the arrival count if this
    /// future is dropped after the task has arrived. Confirm this is the intended
    /// cancellation behavior.
    pub async fn wait(&self) -> BarrierWaitResult {
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        let op = trace::async_op(
            || self.wait_internal(),
            self.resource_span.clone(),
            "Barrier::wait",
            "poll",
            false,
        );
        #[cfg(any(not(tokio_unstable), not(feature = "tracing")))]
        let op = self.wait_internal();
        op.await
    }

    /// Records this task's arrival under the state lock.
    ///
    /// Returns `None` when this arrival completed the generation. In that case
    /// the caller is the leader, has already published the generation, and has
    /// reset the barrier. Otherwise returns `Some(generation)`, the generation
    /// the caller must wait for.
    ///
    /// The lock is synchronous. Keeping it inside this non-async helper ensures
    /// the guard can never be held across an `.await`.
    fn arrive(&self) -> Option<usize> {
        let mut state = self.state.lock();
        let joined = state.generation;
        state.arrived += 1;
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        tracing::trace!(
            target: "runtime::resource::state_update",
            arrived = 1,
            arrived.op = "add",
        );
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        tracing::trace!(
            target: "runtime::resource::async_op::state_update",
            arrived = true,
        );

        if state.arrived != self.n {
            // Not the last one in: the caller has to wait for `joined`.
            return Some(joined);
        }

        #[cfg(all(tokio_unstable, feature = "tracing"))]
        tracing::trace!(
            target: "runtime::resource::async_op::state_update",
            is_leader = true,
        );
        // Leader: release everyone in this generation, then start the next one.
        state
            .waker
            .send(state.generation)
            .expect("there is at least one receiver");
        state.arrived = 0;
        state.generation += 1;
        None
    }

    /// Arrives at the barrier and, unless this task is the leader, waits for
    /// its generation to be released.
    async fn wait_internal(&self) -> BarrierWaitResult {
        crate::trace::async_trace_leaf().await;

        let target = match self.arrive() {
            Some(generation) => generation,
            None => return BarrierWaitResult(true),
        };

        let mut rx = self.wait.clone();
        loop {
            // `self.wait` is only ever cloned, never marked as seen. The first call
            // can therefore complete right away if any generation was published
            // earlier; the comparison below filters out such stale generations.
            // Errors are ignored: the sender lives in `self.state`, which `&self`
            // keeps alive.
            let _ = rx.changed().await;
            let released = *rx.borrow() >= target;
            if released {
                break;
            }
        }
        BarrierWaitResult(false)
    }
}
/// The outcome of a completed [`Barrier::wait`] call.
#[derive(Debug, Clone)]
pub struct BarrierWaitResult(bool);

impl BarrierWaitResult {
    /// Returns `true` if this task's arrival completed its barrier generation
    /// (the "leader"), and `false` for every other task released with it.
    pub fn is_leader(&self) -> bool {
        let Self(leader) = *self;
        leader
    }
}