use crate::runtime::handle::Handle;
use crate::runtime::{blocking, driver, Callback, Runtime, Spawner};
use std::fmt;
use std::io;
use std::time::Duration;
/// Collects the configuration used to construct a `Runtime`.
///
/// Setter methods record values on this struct; `build` reads them when it
/// assembles the drivers, scheduler and blocking pool.
pub struct Builder {
/// Runtime flavor; decides which private build routine `build` dispatches to.
kind: Kind,
/// Whether the I/O driver is requested; copied into `driver::Cfg` by `get_cfg`.
enable_io: bool,
/// Whether the time driver is requested; copied into `driver::Cfg` by `get_cfg`.
enable_time: bool,
/// Copied into `driver::Cfg::start_paused` by `get_cfg`; set through `start_paused`.
start_paused: bool,
/// Explicit worker count for the multi-thread flavor. When `None`,
/// `build_threaded_runtime` falls back to `num_cpus`.
worker_threads: Option<usize>,
/// Thread cap handed to `blocking::create_blocking_pool` (the multi-thread
/// build adds the worker count on top). `new` initializes it to 512.
max_blocking_threads: usize,
/// Closure that yields a thread name; the default returns
/// "tokio-runtime-worker".
pub(super) thread_name: ThreadNameFn,
/// Optional stack size recorded by `thread_stack_size`. It is read outside
/// this file (presumably by the thread-spawning code; confirm).
pub(super) thread_stack_size: Option<usize>,
/// Callback recorded by `on_thread_start`; invoked outside this file.
pub(super) after_start: Option<Callback>,
/// Callback recorded by `on_thread_stop`; invoked outside this file.
pub(super) before_stop: Option<Callback>,
/// Callback recorded by `on_thread_park`; a clone is passed to the scheduler
/// when the runtime is built.
pub(super) before_park: Option<Callback>,
/// Callback recorded by `on_thread_unpark`; a clone is passed to the scheduler
/// when the runtime is built.
pub(super) after_unpark: Option<Callback>,
/// Optional duration recorded by `thread_keep_alive`. It is read outside this
/// file (presumably the idle timeout of blocking-pool threads; confirm).
pub(super) keep_alive: Option<Duration>,
}
/// Shared, thread-safe closure that returns a `String` thread name each time it
/// is called.
pub(crate) type ThreadNameFn = std::sync::Arc<dyn Fn() -> String + Send + Sync + 'static>;
/// Runtime flavor a `Builder` was created for.
pub(crate) enum Kind {
/// Built by `Builder::build_basic_runtime`.
CurrentThread,
/// Built by `Builder::build_threaded_runtime`; only present when the
/// `rt-multi-thread` feature is enabled.
#[cfg(feature = "rt-multi-thread")]
MultiThread,
}
impl Builder {
pub fn new_current_thread() -> Builder {
Builder::new(Kind::CurrentThread)
}
#[cfg(feature = "rt-multi-thread")]
#[cfg_attr(docsrs, doc(cfg(feature = "rt-multi-thread")))]
pub fn new_multi_thread() -> Builder {
Builder::new(Kind::MultiThread)
}
pub(crate) fn new(kind: Kind) -> Builder {
Builder {
kind,
enable_io: false,
enable_time: false,
start_paused: false,
worker_threads: None,
max_blocking_threads: 512,
thread_name: std::sync::Arc::new(|| "tokio-runtime-worker".into()),
thread_stack_size: None,
after_start: None,
before_stop: None,
before_park: None,
after_unpark: None,
keep_alive: None,
}
}
pub fn enable_all(&mut self) -> &mut Self {
#[cfg(any(feature = "net", feature = "process", all(unix, feature = "signal")))]
self.enable_io();
#[cfg(feature = "time")]
self.enable_time();
self
}
pub fn worker_threads(&mut self, val: usize) -> &mut Self {
assert!(val > 0, "Worker threads cannot be set to 0");
self.worker_threads = Some(val);
self
}
#[cfg_attr(docsrs, doc(alias = "max_threads"))]
pub fn max_blocking_threads(&mut self, val: usize) -> &mut Self {
assert!(val > 0, "Max blocking threads cannot be set to 0");
self.max_blocking_threads = val;
self
}
pub fn thread_name(&mut self, val: impl Into<String>) -> &mut Self {
let val = val.into();
self.thread_name = std::sync::Arc::new(move || val.clone());
self
}
pub fn thread_name_fn<F>(&mut self, f: F) -> &mut Self
where
F: Fn() -> String + Send + Sync + 'static,
{
self.thread_name = std::sync::Arc::new(f);
self
}
pub fn thread_stack_size(&mut self, val: usize) -> &mut Self {
self.thread_stack_size = Some(val);
self
}
#[cfg(not(loom))]
pub fn on_thread_start<F>(&mut self, f: F) -> &mut Self
where
F: Fn() + Send + Sync + 'static,
{
self.after_start = Some(std::sync::Arc::new(f));
self
}
#[cfg(not(loom))]
pub fn on_thread_stop<F>(&mut self, f: F) -> &mut Self
where
F: Fn() + Send + Sync + 'static,
{
self.before_stop = Some(std::sync::Arc::new(f));
self
}
#[cfg(not(loom))]
pub fn on_thread_park<F>(&mut self, f: F) -> &mut Self
where
F: Fn() + Send + Sync + 'static,
{
self.before_park = Some(std::sync::Arc::new(f));
self
}
#[cfg(not(loom))]
pub fn on_thread_unpark<F>(&mut self, f: F) -> &mut Self
where
F: Fn() + Send + Sync + 'static,
{
self.after_unpark = Some(std::sync::Arc::new(f));
self
}
pub fn build(&mut self) -> io::Result<Runtime> {
match &self.kind {
Kind::CurrentThread => self.build_basic_runtime(),
#[cfg(feature = "rt-multi-thread")]
Kind::MultiThread => self.build_threaded_runtime(),
}
}
fn get_cfg(&self) -> driver::Cfg {
driver::Cfg {
enable_pause_time: match self.kind {
Kind::CurrentThread => true,
#[cfg(feature = "rt-multi-thread")]
Kind::MultiThread => false,
},
enable_io: self.enable_io,
enable_time: self.enable_time,
start_paused: self.start_paused,
}
}
pub fn thread_keep_alive(&mut self, duration: Duration) -> &mut Self {
self.keep_alive = Some(duration);
self
}
fn build_basic_runtime(&mut self) -> io::Result<Runtime> {
use crate::runtime::{BasicScheduler, Kind};
let (driver, resources) = driver::Driver::new(self.get_cfg())?;
let scheduler =
BasicScheduler::new(driver, self.before_park.clone(), self.after_unpark.clone());
let spawner = Spawner::Basic(scheduler.spawner().clone());
let blocking_pool = blocking::create_blocking_pool(self, self.max_blocking_threads);
let blocking_spawner = blocking_pool.spawner().clone();
Ok(Runtime {
kind: Kind::CurrentThread(scheduler),
handle: Handle {
spawner,
io_handle: resources.io_handle,
time_handle: resources.time_handle,
signal_handle: resources.signal_handle,
clock: resources.clock,
blocking_spawner,
},
blocking_pool,
})
}
}
cfg_io_driver! {
    impl Builder {
        /// Requests the I/O driver for runtimes produced by this builder.
        ///
        /// Sets the `enable_io` flag, which `get_cfg` forwards to the driver
        /// configuration when the runtime is built.
        pub fn enable_io(&mut self) -> &mut Self {
            self.enable_io = true;
            self
        }
    }
}
cfg_time! {
    impl Builder {
        /// Requests the time driver for runtimes produced by this builder.
        ///
        /// Sets the `enable_time` flag, which `get_cfg` forwards to the driver
        /// configuration when the runtime is built.
        pub fn enable_time(&mut self) -> &mut Self {
            self.enable_time = true;
            self
        }
    }
}
cfg_test_util! {
    impl Builder {
        /// Chooses whether the runtime's clock starts out paused.
        ///
        /// The value is stored as-is and forwarded to the driver configuration
        /// by `get_cfg` when the runtime is built.
        pub fn start_paused(&mut self, start_paused: bool) -> &mut Self {
            self.start_paused = start_paused;
            self
        }
    }
}
cfg_rt_multi_thread! {
    impl Builder {
        /// Assembles a multi-thread runtime: driver, thread-pool scheduler,
        /// blocking pool and the handle tying them together, then launches the
        /// workers.
        ///
        /// The worker count is the configured `worker_threads`, or the value
        /// returned by `num_cpus` when none was set.
        ///
        /// # Errors
        ///
        /// Returns the `io::Error` from `driver::Driver::new` if the driver
        /// cannot be created.
        fn build_threaded_runtime(&mut self) -> io::Result<Runtime> {
            use crate::loom::sys::num_cpus;
            use crate::runtime::{Kind, ThreadPool};
            use crate::runtime::park::Parker;

            let core_threads = self.worker_threads.unwrap_or_else(num_cpus);

            let (driver, resources) = driver::Driver::new(self.get_cfg())?;

            let (scheduler, launch) = ThreadPool::new(
                core_threads,
                Parker::new(driver),
                self.before_park.clone(),
                self.after_unpark.clone(),
            );
            let spawner = Spawner::ThreadPool(scheduler.spawner().clone());

            // The pool's cap is the blocking budget plus the worker count
            // (presumably because workers are spawned through the same pool;
            // confirm against `blocking`). Saturate rather than overflow: a
            // caller may legitimately pass a huge cap such as `usize::MAX`,
            // and a plain `+` would panic in debug builds and wrap to a tiny
            // cap in release builds.
            let blocking_thread_cap = self.max_blocking_threads.saturating_add(core_threads);
            let blocking_pool = blocking::create_blocking_pool(self, blocking_thread_cap);
            let blocking_spawner = blocking_pool.spawner().clone();

            let handle = Handle {
                spawner,
                io_handle: resources.io_handle,
                time_handle: resources.time_handle,
                signal_handle: resources.signal_handle,
                clock: resources.clock,
                blocking_spawner,
            };

            // Keep the runtime context entered while the workers are launched;
            // the guard is dropped when this function returns.
            let _enter = crate::runtime::context::enter(handle.clone());
            launch.launch();

            Ok(Runtime {
                kind: Kind::ThreadPool(scheduler),
                handle,
                blocking_pool,
            })
        }
    }
}
impl fmt::Debug for Builder {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("Builder")
.field("worker_threads", &self.worker_threads)
.field("max_blocking_threads", &self.max_blocking_threads)
.field(
"thread_name",
&"<dyn Fn() -> String + Send + Sync + 'static>",
)
.field("thread_stack_size", &self.thread_stack_size)
.field("after_start", &self.after_start.as_ref().map(|_| "..."))
.field("before_stop", &self.before_stop.as_ref().map(|_| "..."))
.field("before_park", &self.before_park.as_ref().map(|_| "..."))
.field("after_unpark", &self.after_unpark.as_ref().map(|_| "..."))
.finish()
}
}