use crate::job::{ArcJob, StackJob};
use crate::registry::{Registry, WorkerThread};
use crate::scope::ScopeLatch;
use std::fmt;
use std::marker::PhantomData;
use std::sync::Arc;
mod test;
/// Executes `op` once within every thread of the current registry and
/// collects the results.
///
/// The registry is the one returned by `Registry::current()`. One invocation
/// of `op` is scheduled for each of its `num_threads()` threads. Each
/// invocation receives a [`BroadcastContext`] describing the worker thread it
/// runs on. This call blocks until every invocation has finished, then
/// returns one result per invocation.
pub fn broadcast<OP, R>(op: OP) -> Vec<R>
where
    OP: Fn(BroadcastContext<'_>) -> R + Sync,
    R: Send,
{
    let registry = Registry::current();
    // SAFETY: `broadcast_in` is an `unsafe fn`; the registry handed to it is
    // the live, current one. NOTE(review): the exact precondition lives with
    // `broadcast_in` / `Registry` — confirm it is satisfied by `current()`.
    unsafe { broadcast_in(op, &registry) }
}
/// Queues `op` to run once within every thread of the current registry,
/// without waiting for any of those runs to happen.
///
/// Each run receives a [`BroadcastContext`] for the worker thread executing
/// it. Because the call returns immediately, `op` must be `Send + 'static`.
pub fn spawn_broadcast<OP>(op: OP)
where
    OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
{
    let registry = Registry::current();
    // SAFETY: `spawn_broadcast_in` is an `unsafe fn`; the registry handed to
    // it is the live, current one. NOTE(review): confirm the precondition
    // documented with `spawn_broadcast_in` / `Registry`.
    unsafe { spawn_broadcast_in(op, &registry) }
}
/// Information handed to a broadcast closure about the worker thread that is
/// currently running it.
pub struct BroadcastContext<'a> {
    /// The worker thread executing this broadcast job.
    worker: &'a WorkerThread,

    /// `dyn Fn()` carries no `Send`/`Sync` bound, so `&mut dyn Fn()` is
    /// neither `Send` nor `Sync`. This marker therefore keeps the context from
    /// being sent to or shared with other threads.
    _marker: PhantomData<&'a mut dyn Fn()>,
}
impl<'a> BroadcastContext<'a> {
    /// Builds a context for the worker thread running the caller and passes
    /// it to `f`, returning whatever `f` returns.
    ///
    /// # Panics
    ///
    /// Panics if `WorkerThread::current()` yields a null pointer (presumably
    /// meaning the caller is not running on a pool thread).
    pub(super) fn with<R>(f: impl FnOnce(BroadcastContext<'_>) -> R) -> R {
        // Keep this local's name: the assert's panic message embeds it.
        let worker_thread = WorkerThread::current();
        assert!(!worker_thread.is_null());

        // SAFETY: the pointer is non-null (checked above). NOTE(review): its
        // validity for the duration of `f` is assumed to follow from it
        // pointing at the current thread's own worker — confirm in `registry`.
        let worker = unsafe { &*worker_thread };
        let context = BroadcastContext {
            worker,
            _marker: PhantomData,
        };
        f(context)
    }

    /// The index of the worker thread running this job.
    #[inline]
    pub fn index(&self) -> usize {
        self.worker.index()
    }

    /// The total number of threads in the registry that owns this worker.
    #[inline]
    pub fn num_threads(&self) -> usize {
        let registry = self.worker.registry();
        registry.num_threads()
    }
}
impl fmt::Debug for BroadcastContext<'_> {
    /// Shows the worker index, the thread count and the owning pool's id.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let registry = self.worker.registry();
        f.debug_struct("BroadcastContext")
            .field("index", &self.worker.index())
            .field("num_threads", &registry.num_threads())
            .field("pool_id", &registry.id())
            .finish()
    }
}
/// Runs `op` once per thread of `registry` and returns the results. Blocks
/// until every job has finished.
///
/// One `StackJob` is created for each of `registry.num_threads()` threads.
/// All of them share a single `ScopeLatch` whose count equals the number of
/// jobs. They are handed to `registry.inject_broadcast` together.
///
/// # Safety
///
/// NOTE(review): the precise contract is not visible in this file. The jobs
/// live in this stack frame and the injected job refs are presumably
/// pointers into them. So the caller presumably must guarantee that the
/// registry is still running and will execute every injected job, allowing
/// `latch.wait` to return. Confirm against `Registry::inject_broadcast` and
/// `StackJob`.
pub(super) unsafe fn broadcast_in<OP, R>(op: OP, registry: &Arc<Registry>) -> Vec<R>
where
    OP: Fn(BroadcastContext<'_>) -> R + Sync,
    R: Send,
{
    // Job body: takes a `bool` flag and runs `op` with a context for the
    // executing worker. Every job borrows this same `f` (see `&f` below), so
    // `op` is presumably shared across threads, which is why `Fn + Sync`
    // suffices.
    let f = move |injected: bool| {
        // Broadcast jobs are expected to run with the flag set; this is only
        // checked in debug builds.
        debug_assert!(injected);
        BroadcastContext::with(&op)
    };
    let n_threads = registry.num_threads();
    // Raw pointer -> `Option<&WorkerThread>`: `None` when the pointer is null
    // (the caller is not itself a worker thread).
    let current_thread = WorkerThread::current().as_ref();
    // One count per job; presumably each finished job decrements it.
    let latch = ScopeLatch::with_count(n_threads, current_thread);
    let jobs: Vec<_> = (0..n_threads).map(|_| StackJob::new(&f, &latch)).collect();
    let job_refs = jobs.iter().map(|job| job.as_job_ref());
    registry.inject_broadcast(job_refs);
    // Wait for all jobs before touching `jobs` again. The injected refs were
    // produced from borrows of them, so they must not be moved or dropped
    // before this point.
    latch.wait(current_thread);
    // Results are collected in job order (0..n_threads). NOTE(review):
    // `into_result` is defined elsewhere. Presumably it re-raises a panic
    // captured while the job ran.
    jobs.into_iter().map(|job| job.into_result()).collect()
}
/// Queues `op` to run once on each thread of `registry` and returns without
/// waiting for any of those runs.
///
/// A single `ArcJob` wrapping `op` is shared by all threads. One static job
/// ref to it is injected for each of `registry.num_threads()` threads.
///
/// # Safety
///
/// NOTE(review): the contract is not spelled out here. Presumably the caller
/// must ensure `registry` has not terminated, so that the injected jobs are
/// eventually executed. Confirm against `Registry`.
pub(super) unsafe fn spawn_broadcast_in<OP>(op: OP, registry: &Arc<Registry>)
where
    OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
{
    let job = ArcJob::new({
        // The job captures its own `Arc` handle to the registry. This function
        // returns right after injection, so the job cannot rely on the
        // caller's `registry` borrow.
        let registry = Arc::clone(registry);
        move || {
            // A panic from `op` goes through `registry.catch_unwind`; the
            // returned value is ignored here.
            registry.catch_unwind(|| BroadcastContext::with(&op));
            // Pairs with one `increment_terminate_count()` below. Presumably
            // this lets the registry terminate only after every thread has run
            // the job.
            registry.terminate();
        }
    });
    let n_threads = registry.num_threads();
    // `map` is lazy: each job ref is created, and the terminate count bumped,
    // only as `inject_broadcast` consumes this iterator.
    let job_refs = (0..n_threads).map(|_| {
        registry.increment_terminate_count();
        ArcJob::as_static_job_ref(&job)
    });
    registry.inject_broadcast(job_refs);
}