use ::{Configuration, ExitHandler, PanicHandler, StartHandler};
use coco::deque::{self, Worker, Stealer};
use job::{JobRef, StackJob};
#[cfg(rayon_unstable)]
use job::Job;
#[cfg(rayon_unstable)]
use internal::task::Task;
use latch::{LatchProbe, Latch, CountLatch, LockLatch, SpinLatch, TickleLatch};
use log::Event::*;
use rand::{self, Rng};
use sleep::Sleep;
use std::any::Any;
use std::error::Error;
use std::cell::{Cell, UnsafeCell};
use std::sync::{Arc, Mutex, Once, ONCE_INIT};
use std::thread;
use std::mem;
use std::fmt;
use std::u32;
use std::usize;
use unwind;
use util::leak;
/// Error returned by `init_global_registry` when the one-time global pool
/// initializer has already run (either through an earlier explicit call or
/// lazily through `global_registry`).
#[derive(PartialEq, Debug)]
struct GlobalPoolAlreadyInitialized;

impl fmt::Display for GlobalPoolAlreadyInitialized {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Reuse the `Error::description` text so both renderings agree.
        let message = self.description();
        f.write_str(message)
    }
}

impl Error for GlobalPoolAlreadyInitialized {
    fn description(&self) -> &str {
        const MESSAGE: &'static str = "The global thread pool has already been initialized.";
        MESSAGE
    }
}
/// Shared state of one thread pool: the per-worker bookkeeping, the
/// injection queue used by non-worker threads, and user-supplied handlers.
pub struct Registry {
    /// One entry per worker thread, indexed by the worker's `index`.
    thread_infos: Vec<ThreadInfo>,
    /// Mutex-protected injector end of the injection deque; jobs pushed here
    /// by `inject` are taken by workers through `job_uninjector`.
    state: Mutex<RegistryState>,
    /// Sleep/wake coordination for the workers (tickled when work arrives).
    sleep: Sleep,
    /// Stealing end of the injection deque, read by `pop_injected_job`.
    job_uninjector: Stealer<JobRef>,
    /// Called by `handle_panic` with panic payloads; `None` means abort.
    panic_handler: Option<Box<PanicHandler>>,
    /// Run by each worker (with its index) when it starts in `main_loop`.
    start_handler: Option<Box<StartHandler>>,
    /// Run by each worker (with its index) just before it exits `main_loop`.
    exit_handler: Option<Box<ExitHandler>>,
    /// Workers run jobs until this latch is set by `terminate`.
    terminate_latch: CountLatch,
}
/// Registry state that must be accessed under `Registry::state`'s mutex.
struct RegistryState {
    /// Owner (push) end of the injection deque; `Registry::inject` pushes
    /// onto it while holding the lock.
    job_injector: Worker<JobRef>,
}
/// The global registry, written once by `init_registry`. It is `None` before
/// that, and it stays `None` if that one-time initialization failed.
static mut THE_REGISTRY: Option<&'static Arc<Registry>> = None;
/// Guards the one-time initialization of `THE_REGISTRY`. It is shared by
/// `global_registry` (lazy, default configuration) and `init_global_registry`
/// (explicit configuration), so whichever runs first wins.
static THE_REGISTRY_SET: Once = ONCE_INIT;
/// Returns the global registry, lazily building it with the default
/// `Configuration` the first time it is needed.
///
/// # Panics
///
/// Panics if that lazy default initialization fails. It also panics if an
/// earlier `init_global_registry` call consumed the one-time initializer
/// but failed to build a registry.
fn global_registry() -> &'static Arc<Registry> {
    THE_REGISTRY_SET.call_once(|| {
        unsafe { init_registry(Configuration::new()).unwrap() }
    });
    let installed = unsafe { THE_REGISTRY };
    installed.expect("The global thread pool has not been initialized.")
}
/// Explicitly initializes the global registry from `config`.
///
/// Only the first initialization of the global pool takes effect. If it has
/// already happened (through an earlier call here, or lazily through
/// `global_registry`), `config` is dropped unused.
///
/// # Errors
///
/// - Returns the error from building the registry if that fails. The
///   one-time initializer is still consumed in that case, so later calls
///   cannot retry.
/// - Returns `GlobalPoolAlreadyInitialized` if initialization already ran.
pub fn init_global_registry(config: Configuration) -> Result<&'static Registry, Box<Error>> {
    // `call_once` cannot hand back a value, so record through captured
    // locals whether our closure was the one that ran, and its result.
    let mut called = false;
    let mut init_result = Ok(());
    THE_REGISTRY_SET.call_once(|| unsafe {
        init_result = init_registry(config);
        called = true;
    });
    if called {
        init_result.map(|()| &**global_registry())
    } else {
        Err(Box::new(GlobalPoolAlreadyInitialized))
    }
}
unsafe fn init_registry(config: Configuration) -> Result<(), Box<Error>> {
Registry::new(config).map(|registry| THE_REGISTRY = Some(leak(registry)))
}
struct Terminator<'a>(&'a Arc<Registry>);
impl<'a> Drop for Terminator<'a> {
fn drop(&mut self) {
self.0.terminate()
}
}
impl Registry {
pub fn new(mut configuration: Configuration) -> Result<Arc<Registry>, Box<Error>> {
let n_threads = configuration.get_num_threads();
let breadth_first = configuration.get_breadth_first();
let (inj_worker, inj_stealer) = deque::new();
let (workers, stealers): (Vec<_>, Vec<_>) = (0..n_threads).map(|_| deque::new()).unzip();
let registry = Arc::new(Registry {
thread_infos: stealers.into_iter()
.map(|s| ThreadInfo::new(s))
.collect(),
state: Mutex::new(RegistryState::new(inj_worker)),
sleep: Sleep::new(),
job_uninjector: inj_stealer,
terminate_latch: CountLatch::new(),
panic_handler: configuration.take_panic_handler(),
start_handler: configuration.take_start_handler(),
exit_handler: configuration.take_exit_handler(),
});
let t1000 = Terminator(®istry);
for (index, worker) in workers.into_iter().enumerate() {
let registry = registry.clone();
let mut b = thread::Builder::new();
if let Some(name) = configuration.get_thread_name(index) {
b = b.name(name);
}
if let Some(stack_size) = configuration.get_stack_size() {
b = b.stack_size(stack_size);
}
try!(b.spawn(move || unsafe { main_loop(worker, registry, index, breadth_first) }));
}
mem::forget(t1000);
Ok(registry.clone())
}
#[cfg(rayon_unstable)]
pub fn global() -> Arc<Registry> {
global_registry().clone()
}
pub fn current() -> Arc<Registry> {
unsafe {
let worker_thread = WorkerThread::current();
if worker_thread.is_null() {
global_registry().clone()
} else {
(*worker_thread).registry.clone()
}
}
}
pub fn current_num_threads() -> usize {
unsafe {
let worker_thread = WorkerThread::current();
if worker_thread.is_null() {
global_registry().num_threads()
} else {
(*worker_thread).registry.num_threads()
}
}
}
pub fn id(&self) -> RegistryId {
RegistryId { addr: self as *const Self as usize }
}
pub fn num_threads(&self) -> usize {
self.thread_infos.len()
}
pub fn handle_panic(&self, err: Box<Any + Send>) {
match self.panic_handler {
Some(ref handler) => {
let abort_guard = unwind::AbortIfPanic;
handler(err);
mem::forget(abort_guard);
}
None => {
let _ = unwind::AbortIfPanic; }
}
}
pub fn wait_until_primed(&self) {
for info in &self.thread_infos {
info.primed.wait();
}
}
#[cfg(test)]
pub fn wait_until_stopped(&self) {
for info in &self.thread_infos {
info.stopped.wait();
}
}
pub fn inject_or_push(&self, job_ref: JobRef) {
let worker_thread = WorkerThread::current();
unsafe {
if !worker_thread.is_null() && (*worker_thread).registry().id() == self.id() {
(*worker_thread).push(job_ref);
} else {
self.inject(&[job_ref]);
}
}
}
#[cfg(rayon_unstable)]
pub unsafe fn submit_task<T>(&self, task: Arc<T>)
where T: Task
{
let task_job = TaskJob::new(task);
let task_job_ref = TaskJob::into_job_ref(task_job);
return self.inject_or_push(task_job_ref);
struct TaskJob<T: Task> {
_data: T
}
impl<T: Task> TaskJob<T> {
fn new(arc: Arc<T>) -> Arc<Self> {
unsafe { mem::transmute(arc) }
}
pub fn into_task(this: Arc<TaskJob<T>>) -> Arc<T> {
unsafe { mem::transmute(this) }
}
unsafe fn into_job_ref(this: Arc<Self>) -> JobRef {
let this: *const Self = mem::transmute(this);
JobRef::new(this)
}
}
impl<T: Task> Job for TaskJob<T> {
unsafe fn execute(this: *const Self) {
let this: Arc<Self> = mem::transmute(this);
let task: Arc<T> = TaskJob::into_task(this);
Task::execute(task);
}
}
}
pub fn inject(&self, injected_jobs: &[JobRef]) {
log!(InjectJobs { count: injected_jobs.len() });
{
let state = self.state.lock().unwrap();
assert!(!self.terminate_latch.probe(), "inject() sees state.terminate as true");
for &job_ref in injected_jobs {
state.job_injector.push(job_ref);
}
}
self.sleep.tickle(usize::MAX);
}
fn pop_injected_job(&self, worker_index: usize) -> Option<JobRef> {
let stolen = self.job_uninjector.steal();
if stolen.is_some() {
log!(UninjectedWork { worker: worker_index });
}
stolen
}
pub fn in_worker<OP, R>(&self, op: OP) -> R
where OP: FnOnce(&WorkerThread, bool) -> R + Send, R: Send
{
unsafe {
let worker_thread = WorkerThread::current();
if worker_thread.is_null() {
self.in_worker_cold(op)
} else if (*worker_thread).registry().id() != self.id() {
self.in_worker_cross(&*worker_thread, op)
} else {
op(&*worker_thread, false)
}
}
}
#[cold]
unsafe fn in_worker_cold<OP, R>(&self, op: OP) -> R
where OP: FnOnce(&WorkerThread, bool) -> R + Send, R: Send
{
debug_assert!(WorkerThread::current().is_null());
let job = StackJob::new(|injected| {
let worker_thread = WorkerThread::current();
assert!(injected && !worker_thread.is_null());
op(&*worker_thread, true)
}, LockLatch::new());
self.inject(&[job.as_job_ref()]);
job.latch.wait();
job.into_result()
}
#[cold]
unsafe fn in_worker_cross<OP, R>(&self, current_thread: &WorkerThread, op: OP) -> R
where OP: FnOnce(&WorkerThread, bool) -> R + Send, R: Send
{
debug_assert!(current_thread.registry().id() != self.id());
let latch = TickleLatch::new(SpinLatch::new(), ¤t_thread.registry().sleep);
let job = StackJob::new(|injected| {
let worker_thread = WorkerThread::current();
assert!(injected && !worker_thread.is_null());
op(&*worker_thread, true)
}, latch);
self.inject(&[job.as_job_ref()]);
current_thread.wait_until(&job.latch);
job.into_result()
}
pub fn increment_terminate_count(&self) {
self.terminate_latch.increment();
}
pub fn terminate(&self) {
self.terminate_latch.set();
self.sleep.tickle(usize::MAX);
}
}
/// Opaque identity of a `Registry`, built from the registry's address by
/// `Registry::id`. Two ids compare equal iff they came from the same
/// registry address.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct RegistryId {
    addr: usize,
}
impl RegistryState {
    /// Wraps the owner end of the injection deque.
    pub fn new(job_injector: Worker<JobRef>) -> RegistryState {
        RegistryState { job_injector: job_injector }
    }
}
/// Per-worker bookkeeping kept in `Registry::thread_infos`.
struct ThreadInfo {
    /// Set by the worker at the start of `main_loop`; waited on by
    /// `Registry::wait_until_primed`.
    primed: LockLatch,
    /// Set by the worker after it observes termination in `main_loop`;
    /// waited on by `Registry::wait_until_stopped`.
    stopped: LockLatch,
    /// Stealing end of this worker's deque, used by other workers' `steal`.
    stealer: Stealer<JobRef>,
}
impl ThreadInfo {
    /// Creates bookkeeping for one worker: fresh `primed` and `stopped`
    /// latches plus the stealing end of that worker's deque.
    fn new(stealer: Stealer<JobRef>) -> ThreadInfo {
        let primed = LockLatch::new();
        let stopped = LockLatch::new();
        ThreadInfo { primed: primed, stopped: stopped, stealer: stealer }
    }
}
/// State of one worker thread. It lives on the worker's stack in `main_loop`
/// and is reachable through `WorkerThread::current`.
pub struct WorkerThread {
    /// Owner end of this worker's local job deque.
    worker: Worker<JobRef>,
    /// Position of this worker in `Registry::thread_infos`.
    index: usize,
    /// When true, `take_local_job` uses `steal` on the local deque instead
    /// of `pop`.
    breadth_first: bool,
    /// RNG used by `steal` to pick where to start scanning victims.
    rng: UnsafeCell<rand::XorShiftRng>,
    /// The registry (pool) this worker belongs to.
    registry: Arc<Registry>,
}
thread_local! {
    // Pointer to the `WorkerThread` running on this OS thread. It is null on
    // threads that are not registry workers and is set once by
    // `WorkerThread::set_current`.
    static WORKER_THREAD_STATE: Cell<*const WorkerThread> =
        Cell::new(0 as *const WorkerThread)
}
impl WorkerThread {
    /// Returns the `WorkerThread` registered for the current OS thread, or a
    /// null pointer if this thread is not a worker.
    #[inline]
    pub fn current() -> *const WorkerThread {
        WORKER_THREAD_STATE.with(|t| t.get())
    }

    /// Registers `thread` as the worker for the current OS thread.
    ///
    /// # Safety
    ///
    /// `thread` must stay valid for as long as code on this thread may call
    /// `WorkerThread::current()` and dereference the result.
    ///
    /// # Panics
    ///
    /// Panics if a worker is already registered on this thread.
    unsafe fn set_current(thread: *const WorkerThread) {
        WORKER_THREAD_STATE.with(|t| {
            assert!(t.get().is_null());
            t.set(thread);
        });
    }

    /// The registry this worker belongs to.
    pub fn registry(&self) -> &Arc<Registry> {
        &self.registry
    }

    /// This worker's index within its registry's `thread_infos`.
    #[inline]
    pub fn index(&self) -> usize {
        self.index
    }

    /// Pushes `job` onto this worker's local deque, then tickles the
    /// registry's sleep state on behalf of this worker.
    #[inline]
    pub unsafe fn push(&self, job: JobRef) {
        self.worker.push(job);
        self.registry.sleep.tickle(self.index);
    }

    /// True if this worker's local deque currently holds no jobs.
    #[inline]
    pub fn local_deque_is_empty(&self) -> bool {
        self.worker.len() == 0
    }

    /// Takes the next job from the local deque. It uses `pop` normally, or
    /// `steal` from the deque when `breadth_first` is set.
    #[inline]
    pub unsafe fn take_local_job(&self) -> Option<JobRef> {
        if !self.breadth_first {
            self.worker.pop()
        } else {
            self.worker.steal()
        }
    }

    /// Executes available jobs until `latch` is set. Returns immediately if
    /// the latch is already set.
    #[inline]
    pub unsafe fn wait_until<L: LatchProbe + ?Sized>(&self, latch: &L) {
        log!(WaitUntil { worker: self.index });
        if !latch.probe() {
            self.wait_until_cold(latch);
        }
    }

    /// Slow path of `wait_until`. While the latch is unset it looks for a
    /// job in this order: the local deque, then other workers' deques, then
    /// the injection deque. It runs any job found and otherwise reports
    /// "no work" to the sleep state.
    #[cold]
    unsafe fn wait_until_cold<L: LatchProbe + ?Sized>(&self, latch: &L) {
        // A panic escaping this loop is treated as fatal (the guard is only
        // forgotten on normal completion).
        let abort_guard = unwind::AbortIfPanic;

        let mut yields = 0;
        while !latch.probe() {
            if let Some(job) = self.take_local_job()
                .or_else(|| self.steal())
                .or_else(|| self.registry.pop_injected_job(self.index)) {
                yields = self.registry.sleep.work_found(self.index, yields);
                self.execute(job);
            } else {
                yields = self.registry.sleep.no_work_found(self.index, yields);
            }
        }

        // If we were sleepy, we are not anymore.
        self.registry.sleep.work_found(self.index, yields);

        log!(LatchSet { worker: self.index });

        // Successful completion: disarm the abort guard.
        mem::forget(abort_guard);
    }

    /// Runs `job`, then tickles the sleep state on behalf of this worker.
    pub unsafe fn execute(&self, job: JobRef) {
        job.execute();
        self.registry.sleep.tickle(self.index);
    }

    /// Tries to steal one job from another worker of the same registry.
    ///
    /// It scans the other workers round-robin, starting at a random index,
    /// and returns the first job obtained. Returns `None` for single-thread
    /// registries.
    unsafe fn steal(&self) -> Option<JobRef> {
        // We only steal once our own deque is drained. Check this without
        // popping, so debug builds don't mutate the deque.
        debug_assert!(self.local_deque_is_empty());

        let num_threads = self.registry.thread_infos.len();
        if num_threads <= 1 {
            return None;
        }
        // The random start is computed in `u32`, so the count must fit.
        assert!(num_threads < (u32::MAX as usize),
                "we do not support more than u32::MAX worker threads");

        let start = {
            let rng = &mut *self.rng.get();
            rng.next_u32() % num_threads as u32
        } as usize;
        (start .. num_threads)
            .chain(0 .. start)
            .filter(|&i| i != self.index)
            .filter_map(|victim_index| {
                let victim = &self.registry.thread_infos[victim_index];
                let stolen = victim.stealer.steal();
                if stolen.is_some() {
                    log!(StoleWork { worker: self.index, victim: victim_index });
                }
                stolen
            })
            .next()
    }
}
/// Body of a worker thread.
///
/// It registers the `WorkerThread` for this OS thread and signals `primed`.
/// It then runs the optional start handler, executes jobs until the
/// registry's terminate latch is set, signals `stopped`, and finally runs
/// the optional exit handler. Panics from the start/exit handlers are
/// routed to `Registry::handle_panic`.
unsafe fn main_loop(worker: Worker<JobRef>,
                    registry: Arc<Registry>,
                    index: usize,
                    breadth_first: bool) {
    let worker_thread = WorkerThread {
        worker: worker,
        breadth_first: breadth_first,
        index: index,
        rng: UnsafeCell::new(rand::weak_rng()),
        registry: registry.clone(),
    };
    WorkerThread::set_current(&worker_thread);

    // Let the registry know this worker is ready for work.
    registry.thread_infos[index].primed.set();

    // A panic escaping the worker machinery itself is fatal. User-code
    // panics from the handlers are caught below and redirected.
    let abort_guard = unwind::AbortIfPanic;

    if let Some(ref handler) = registry.start_handler {
        if let Err(err) = unwind::halt_unwinding(|| handler(index)) {
            registry.handle_panic(err);
        }
    }

    // Run jobs until the registry is terminated.
    worker_thread.wait_until(&registry.terminate_latch);

    // Nothing should be left in the local deque. Check without popping so
    // debug and release builds leave the deque in the same state.
    debug_assert!(worker_thread.local_deque_is_empty());

    registry.thread_infos[index].stopped.set();

    // Normal shutdown: disarm the abort guard before the exit handler.
    mem::forget(abort_guard);

    if let Some(ref handler) = registry.exit_handler {
        if let Err(err) = unwind::halt_unwinding(|| handler(index)) {
            registry.handle_panic(err);
        }
    }
}
/// Runs `op` on a worker thread and returns its result.
///
/// If the caller is already a worker thread (of any registry), `op` runs
/// inline with `false`. Otherwise `op` is injected into the global registry,
/// the caller blocks until a worker has run it, and `op` receives `true`.
pub fn in_worker<OP, R>(op: OP) -> R
    where OP: FnOnce(&WorkerThread, bool) -> R + Send, R: Send
{
    unsafe {
        let current = WorkerThread::current();
        if current.is_null() {
            // Not inside any pool: hand the work to the global pool.
            return global_registry().in_worker_cold(op);
        }
        op(&*current, false)
    }
}