#![deny(missing_docs)]
extern crate futures;
extern crate num_cpus;
use std::panic::{self, AssertUnwindSafe};
use std::sync::{Arc, Mutex};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::mpsc;
use std::thread;
use futures::{IntoFuture, Future, Poll, Async};
use futures::future::lazy;
use futures::sync::oneshot::{channel, Sender, Receiver};
use futures::executor::{self, Run, Executor};
/// A pool of worker threads that futures can be spawned onto.
///
/// Work submitted with `spawn` or `spawn_fn` is placed on a single queue
/// shared by all workers and is run by whichever worker takes it first.
///
/// `CpuPool` is a cheap handle. Cloning it shares the same worker threads.
/// Once the last handle is dropped, every worker is told to shut down.
pub struct CpuPool {
    inner: Arc<Inner>,
}
/// Configuration for a `CpuPool`.
///
/// It sets the number of worker threads, an optional thread-name prefix, and
/// optional hooks run on each worker thread at start-up and shutdown. Call
/// `create` to build the pool.
pub struct Builder {
    pool_size: usize,
    name_prefix: Option<String>,
    after_start: Option<Arc<Fn() + Send + Sync>>,
    before_stop: Option<Arc<Fn() + Send + Sync>>,
}
/// Task driven on a worker thread: it polls `fut` and sends the outcome
/// through `tx` to the paired `CpuFuture`.
struct MySender<F, T> {
    // The computation being driven. In `CpuPool::spawn` this is the user's
    // future wrapped in `catch_unwind`.
    fut: F,
    // Result channel. It is taken (left `None`) when the outcome is sent.
    tx: Option<Sender<T>>,
    // Set by `CpuFuture::forget`. While true, a dropped receiver does not stop
    // this task from being driven.
    keep_running_flag: Arc<AtomicBool>,
}
fn _assert() {
fn _assert_send<T: Send>() {}
fn _assert_sync<T: Sync>() {}
_assert_send::<CpuPool>();
_assert_sync::<CpuPool>();
}
/// State shared by every `CpuPool` handle and every worker thread.
struct Inner {
    // Sending half of the task queue, used to submit `Message`s.
    tx: Mutex<mpsc::Sender<Message>>,
    // Receiving half of the task queue. Workers take turns locking it to pull
    // the next message.
    rx: Mutex<mpsc::Receiver<Message>>,
    // Number of live `CpuPool` handles. This is tracked separately from the
    // `Arc` strong count because worker threads and spawned tasks also hold
    // `Arc<Inner>` clones.
    cnt: AtomicUsize,
    // Number of worker threads. Shutdown sends this many `Close` messages.
    size: usize,
    // Optional hook run on each worker thread when it starts.
    after_start: Option<Arc<Fn() + Send + Sync>>,
    // Optional hook run on each worker thread just before it exits.
    before_stop: Option<Arc<Fn() + Send + Sync>>,
}
/// The eventual result of a computation spawned on a `CpuPool`.
///
/// It resolves to the spawned future's item or error. If the spawned future
/// panicked, polling a `CpuFuture` re-raises that panic on the polling thread.
///
/// Dropping a `CpuFuture` cancels the computation. Call `forget` instead to
/// let it run to completion.
#[must_use = "dropping a CpuFuture cancels the spawned computation; call `forget` to let it finish"]
pub struct CpuFuture<T, E> {
    inner: Receiver<thread::Result<Result<T, E>>>,
    keep_running_flag: Arc<AtomicBool>,
}
/// Messages carried by the pool's shared task queue to the worker threads.
enum Message {
    /// A unit of spawned work. The receiving worker calls `run()` on it.
    Run(Run),
    /// Tells the receiving worker to leave its loop and exit.
    Close,
}
impl CpuPool {
    /// Creates a new thread pool with `size` worker threads.
    ///
    /// This is shorthand for `Builder::new().pool_size(size).create()`.
    ///
    /// # Panics
    ///
    /// Panics if `size` is 0, or if a worker thread cannot be spawned.
    pub fn new(size: usize) -> CpuPool {
        Builder::new().pool_size(size).create()
    }

    /// Creates a new thread pool with one worker thread per CPU, as reported
    /// by `num_cpus::get()`.
    ///
    /// # Panics
    ///
    /// Panics if a worker thread cannot be spawned.
    pub fn new_num_cpus() -> CpuPool {
        Builder::new().create()
    }

    /// Spawns the future `f` onto this pool.
    ///
    /// Returns a `CpuFuture` that resolves to `f`'s item or error. `f` is
    /// driven by the pool's worker threads, not by whoever polls the returned
    /// `CpuFuture`.
    ///
    /// If `f` panics while being polled, the panic is caught on the worker and
    /// re-raised when the returned `CpuFuture` is polled.
    ///
    /// Dropping the returned `CpuFuture` signals cancellation: the next time a
    /// worker polls the task, it finishes without polling `f` again. Calling
    /// `CpuFuture::forget` instead lets `f` run to completion.
    pub fn spawn<F>(&self, f: F) -> CpuFuture<F::Item, F::Error>
        where F: Future + Send + 'static,
              F::Item: Send + 'static,
              F::Error: Send + 'static,
    {
        let (tx, rx) = channel();
        let keep_running_flag = Arc::new(AtomicBool::new(false));
        let sender = MySender {
            // Catch panics so they are forwarded to the `CpuFuture` instead of
            // unwinding through the worker thread.
            fut: AssertUnwindSafe(f).catch_unwind(),
            tx: Some(tx),
            keep_running_flag: keep_running_flag.clone(),
        };
        executor::spawn(sender).execute(self.inner.clone());
        // The local `Arc` is no longer needed here, so move it instead of
        // cloning it again.
        CpuFuture { inner: rx, keep_running_flag: keep_running_flag }
    }

    /// Runs the closure `f` on this pool and drives the future it returns.
    ///
    /// This is equivalent to `self.spawn(lazy(f))`, so `f` itself is invoked
    /// on a worker thread rather than on the caller's thread. Panic and
    /// cancellation behavior is the same as for `spawn`.
    pub fn spawn_fn<F, R>(&self, f: F) -> CpuFuture<R::Item, R::Error>
        where F: FnOnce() -> R + Send + 'static,
              R: IntoFuture + 'static,
              R::Future: Send + 'static,
              R::Item: Send + 'static,
              R::Error: Send + 'static,
    {
        self.spawn(lazy(f))
    }
}
impl Inner {
    /// Pushes `msg` onto the task queue shared by the worker threads.
    ///
    /// Panics if the sender mutex is poisoned or the channel send fails.
    fn send(&self, msg: Message) {
        self.tx.lock().unwrap().send(msg).unwrap();
    }

    /// Main loop of a worker thread.
    ///
    /// Runs the `after_start` hook (if any), then executes queued work until
    /// a `Message::Close` is received, and finally runs the `before_stop`
    /// hook (if any).
    fn work(&self) {
        if let Some(ref on_start) = self.after_start {
            on_start();
        }
        loop {
            // The receiver guard is a temporary of this statement, so the lock
            // is released as soon as one message has been taken. Other workers
            // can then wait for the next message while this one runs.
            let msg = self.rx.lock().unwrap().recv().unwrap();
            match msg {
                Message::Run(r) => r.run(),
                Message::Close => break,
            }
        }
        if let Some(ref on_stop) = self.before_stop {
            on_stop();
        }
    }
}
impl Clone for CpuPool {
    /// Returns another handle to the same pool.
    ///
    /// It bumps the live-handle count that `Drop` uses to decide when to shut
    /// the workers down.
    fn clone(&self) -> CpuPool {
        let shared = self.inner.clone();
        shared.cnt.fetch_add(1, Ordering::Relaxed);
        CpuPool { inner: shared }
    }
}
impl Drop for CpuPool {
    /// Releases this handle.
    ///
    /// The last handle to go away sends one `Message::Close` per worker
    /// thread, so every worker leaves its loop.
    fn drop(&mut self) {
        let previous = self.inner.cnt.fetch_sub(1, Ordering::Relaxed);
        if previous != 1 {
            // Other handles are still alive; keep the workers running.
            return;
        }
        let workers = self.inner.size;
        for _ in 0..workers {
            self.inner.send(Message::Close);
        }
    }
}
impl Executor for Inner {
    /// Queues `run` so that one of the pool's worker threads executes it.
    fn execute(&self, run: Run) {
        let msg = Message::Run(run);
        self.send(msg);
    }
}
impl<T, E> CpuFuture<T, E> {
    /// Drops this future without cancelling the computation behind it.
    ///
    /// Normally, dropping a `CpuFuture` makes the pool stop driving the
    /// spawned future. After `forget`, the pool keeps driving it to
    /// completion, and the result is discarded.
    pub fn forget(self) {
        self.keep_running_flag.store(true, Ordering::SeqCst);
    }
}
impl<T: Send + 'static, E: Send + 'static> Future for CpuFuture<T, E> {
    type Item = T;
    type Error = E;

    /// Checks whether the spawned computation has finished.
    ///
    /// A panic captured on the worker thread is resumed here, on the polling
    /// thread.
    fn poll(&mut self) -> Poll<T, E> {
        let state = self.inner.poll().expect("shouldn't be canceled");
        if let Async::Ready(outcome) = state {
            match outcome {
                Ok(result) => result.map(Async::Ready),
                Err(payload) => panic::resume_unwind(payload),
            }
        } else {
            Ok(Async::NotReady)
        }
    }
}
impl<F: Future> Future for MySender<F, Result<F::Item, F::Error>> {
    type Item = ();
    type Error = ();

    /// Drives the wrapped computation one step on a worker thread.
    ///
    /// If the paired `CpuFuture` has been dropped and `forget` was not called,
    /// this finishes immediately without polling `fut`. Otherwise it polls
    /// `fut`, and once `fut` resolves it sends the outcome (item or error)
    /// through `tx`.
    fn poll(&mut self) -> Poll<(), ()> {
        let receiver_gone = match self.tx.as_mut().unwrap().poll_cancel() {
            Ok(Async::Ready(_)) => true,
            _ => false,
        };
        if receiver_gone && !self.keep_running_flag.load(Ordering::SeqCst) {
            return Ok(().into());
        }

        let outcome = match self.fut.poll() {
            Ok(Async::NotReady) => return Ok(Async::NotReady),
            Ok(Async::Ready(item)) => Ok(item),
            Err(err) => Err(err),
        };
        // The receiver may already be gone (for example after `forget`), so a
        // failed send is deliberately ignored.
        let _ = self.tx.take().unwrap().send(outcome);
        Ok(Async::Ready(()))
    }
}
impl Builder {
    /// Creates a builder with default settings.
    ///
    /// The defaults are one worker thread per CPU (as reported by
    /// `num_cpus::get()`), unnamed threads, and no start or stop hooks.
    pub fn new() -> Builder {
        Builder {
            pool_size: num_cpus::get(),
            name_prefix: None,
            after_start: None,
            before_stop: None,
        }
    }

    /// Sets the number of worker threads the pool will spawn.
    ///
    /// `create` panics if this is 0.
    pub fn pool_size(&mut self, size: usize) -> &mut Self {
        self.pool_size = size;
        self
    }

    /// Sets a name prefix for the worker threads.
    ///
    /// Each thread is named the prefix followed by its index, counting from 0.
    /// For example, the prefix `"cpu-pool-"` yields `"cpu-pool-0"`,
    /// `"cpu-pool-1"`, and so on.
    pub fn name_prefix<S: Into<String>>(&mut self, name_prefix: S) -> &mut Self {
        self.name_prefix = Some(name_prefix.into());
        self
    }

    /// Registers `f` to run on each worker thread as soon as it starts,
    /// before it executes any spawned work.
    pub fn after_start<F>(&mut self, f: F) -> &mut Self
        where F: Fn() + Send + Sync + 'static
    {
        self.after_start = Some(Arc::new(f));
        self
    }

    /// Registers `f` to run on each worker thread just before it exits.
    ///
    /// A worker exits once it receives its shutdown message, which is sent
    /// after the last `CpuPool` handle is dropped.
    pub fn before_stop<F>(&mut self, f: F) -> &mut Self
        where F: Fn() + Send + Sync + 'static
    {
        self.before_stop = Some(Arc::new(f));
        self
    }

    /// Builds a `CpuPool` from this configuration and spawns its worker
    /// threads immediately.
    ///
    /// # Panics
    ///
    /// Panics if the configured pool size is 0, or if a worker thread cannot
    /// be spawned.
    pub fn create(&mut self) -> CpuPool {
        // Validate before allocating the queue or shared state.
        assert!(self.pool_size > 0, "CpuPool size must be greater than 0");

        let (tx, rx) = mpsc::channel();
        let pool = CpuPool {
            inner: Arc::new(Inner {
                tx: Mutex::new(tx),
                rx: Mutex::new(rx),
                // The handle returned from this function is the first live handle.
                cnt: AtomicUsize::new(1),
                size: self.pool_size,
                after_start: self.after_start.clone(),
                before_stop: self.before_stop.clone(),
            }),
        };
        for counter in 0..self.pool_size {
            let inner = pool.inner.clone();
            let mut thread_builder = thread::Builder::new();
            if let Some(ref name_prefix) = self.name_prefix {
                thread_builder = thread_builder.name(format!("{}{}", name_prefix, counter));
            }
            // The JoinHandle is dropped, which detaches the worker. Workers
            // are stopped via `Message::Close` rather than being joined.
            thread_builder
                .spawn(move || inner.work())
                .expect("failed to spawn CpuPool worker thread");
        }
        pool
    }
}