pub use crate::rt::thread::AccessError;
pub use crate::rt::yield_now;
use crate::rt::{self, Execution, Location};
#[doc(no_inline)]
pub use std::thread::panicking;
use std::marker::PhantomData;
use std::sync::{Arc, Mutex};
use std::{fmt, io};
use tracing::trace;
/// An owned handle to a (mock) thread created by [`spawn`] or [`Builder::spawn`].
///
/// Call [`JoinHandle::join`] to wait for the thread and retrieve the value its closure
/// returned.
pub struct JoinHandle<T> {
    // Shared slot that the spawned closure fills with `Some(Ok(value))` when it returns;
    // `join` takes the value back out.
    result: Arc<Mutex<Option<std::thread::Result<T>>>>,
    // Notified by the spawned thread after it has stored its result; `join` waits on it.
    notify: rt::Notify,
    // Handle (id and optional name) describing the spawned thread.
    thread: Thread,
}
/// A handle to a (mock) thread.
///
/// Obtained from [`current`] or [`JoinHandle::thread`].
#[derive(Clone, Debug)]
pub struct Thread {
    id: ThreadId,
    name: Option<String>,
}

impl Thread {
    /// Returns the identifier of this thread.
    pub fn id(&self) -> ThreadId {
        let Thread { id, .. } = self;
        *id
    }

    /// Returns the name this thread was spawned with, if any.
    pub fn name(&self) -> Option<&str> {
        self.name.as_deref()
    }

    /// Asks the runtime to unpark this thread.
    pub fn unpark(&self) {
        let target = self.id.id;
        rt::execution(move |execution| execution.threads.unpark(target));
    }
}
/// A unique identifier for a (mock) thread.
///
/// Formats as `ThreadId(<public id>)`, using the runtime's public id for the thread.
#[derive(Clone, Copy, Eq, Hash, PartialEq)]
pub struct ThreadId {
    id: crate::rt::thread::Id,
}

impl std::fmt::Debug for ThreadId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let public = self.id.public_id();
        write!(f, "ThreadId({})", public)
    }
}
/// A key for a (mock) thread-local value.
///
/// It mirrors the `with` / `try_with` API of `std::thread::LocalKey`. The value is created
/// lazily by `init` on first access and stored in the runtime's thread state (presumably one
/// slot per mock thread — TODO confirm against `rt::thread`).
pub struct LocalKey<T> {
    // Initializer run on first access. It is `pub` (but hidden) so that keys can be
    // constructed in a `static`, as `CURRENT_THREAD_KEY` in this module is.
    #[doc(hidden)]
    pub init: fn() -> T,
    // Marker tying the key to `T` without storing one; `fn(T)` is `Send + Sync` for any
    // `T`, so the key can live in a `static`.
    #[doc(hidden)]
    pub _p: PhantomData<fn(T)>,
}
/// Configures a (mock) thread before spawning it.
///
/// Use [`Builder::name`] and [`Builder::stack_size`], then [`Builder::spawn`].
#[derive(Debug)]
pub struct Builder {
    // Name given to the spawned thread's `Thread` handle.
    name: Option<String>,
    // Requested stack size, forwarded to `rt::spawn`.
    stack_size: Option<usize>,
}
/// Mock TLS slot holding each thread's own [`Thread`] handle.
///
/// The slot is always filled explicitly by `init_current`: when a thread is spawned, or
/// lazily from `current()`. The `init` function is therefore never expected to run, and
/// `unreachable!()` turns any such call into a panic.
static CURRENT_THREAD_KEY: LocalKey<Thread> = LocalKey {
    init: || unreachable!(),
    _p: PhantomData,
};
/// Builds the `Thread` handle for the thread the runtime currently reports as active.
///
/// Stores a copy of the handle in `CURRENT_THREAD_KEY` and returns the handle.
fn init_current(execution: &mut Execution, name: Option<String>) -> Thread {
    let current = Thread {
        id: ThreadId {
            id: execution.threads.active_id(),
        },
        name,
    };
    execution
        .threads
        .local_init(&CURRENT_THREAD_KEY, current.clone());
    current
}
/// Returns a handle to the (mock) thread that calls it.
///
/// Threads created through this module record their handle when they start. For any other
/// thread (presumably the model's initial thread — TODO confirm), an unnamed handle is
/// created and recorded on first call.
///
/// # Panics
///
/// Panics if the runtime reports the thread's local storage as inaccessible (`AccessError`),
/// e.g. when called during or after thread-local destruction.
pub fn current() -> Thread {
    rt::execution(|execution| {
        let existing = execution.threads.local(&CURRENT_THREAD_KEY);
        if let Some(slot) = existing {
            // A bare `unwrap` here gave no hint of what went wrong; name the condition.
            slot.expect(
                "cannot access the current (mock) thread during or after its TLS is destroyed",
            )
            .clone()
        } else {
            init_current(execution, None)
        }
    })
}
/// Spawns a new (mock) thread that runs `f`, returning a [`JoinHandle`] for it.
///
/// The thread is unnamed and uses the default stack size. The caller's source location is
/// recorded with the spawn, via `#[track_caller]`.
#[track_caller]
pub fn spawn<F, T>(f: F) -> JoinHandle<T>
where
    F: FnOnce() -> T,
    F: 'static,
    T: 'static,
{
    let caller = location!();
    spawn_internal(f, None, None, caller)
}
/// Parks the current (mock) thread, passing the caller's location to the runtime.
#[track_caller]
pub fn park() {
    let caller = location!();
    rt::park(caller);
}
/// Shared implementation behind [`spawn`] and [`Builder::spawn`].
///
/// The spawned thread does four things in order:
/// 1. records its own `Thread` handle;
/// 2. runs `f`;
/// 3. stores `Ok(output)` in the shared result slot;
/// 4. signals `notify`, attributing it to `location`.
fn spawn_internal<F, T>(
    f: F,
    name: Option<String>,
    stack_size: Option<usize>,
    location: Location,
) -> JoinHandle<T>
where
    F: FnOnce() -> T,
    F: 'static,
    T: 'static,
{
    let result: Arc<Mutex<Option<std::thread::Result<T>>>> = Arc::new(Mutex::new(None));
    let notify = rt::Notify::new(true, false);

    // The child gets its own copies; the originals end up in the returned handle.
    let child_name = name.clone();
    let child_result = Arc::clone(&result);
    let body = move || {
        rt::execution(|execution| {
            init_current(execution, child_name);
        });
        // Run the user closure before taking the lock, so the lock is not held while `f` runs.
        let output = f();
        *child_result.lock().unwrap() = Some(Ok(output));
        notify.notify(location);
    };
    let id = rt::spawn(stack_size, body);

    JoinHandle {
        thread: Thread {
            id: ThreadId { id },
            name,
        },
        result,
        notify,
    }
}
impl Builder {
    /// Creates a builder with no name and the default stack size.
    // `new` is offered without a `Default` impl, presumably to mirror `std::thread::Builder`.
    #[allow(clippy::new_without_default)]
    pub fn new() -> Builder {
        Builder {
            stack_size: None,
            name: None,
        }
    }

    /// Sets the name the spawned thread will report through [`Thread::name`].
    pub fn name(self, name: String) -> Builder {
        Builder {
            name: Some(name),
            ..self
        }
    }

    /// Sets the stack size forwarded to the runtime when the thread is spawned.
    pub fn stack_size(self, size: usize) -> Builder {
        Builder {
            stack_size: Some(size),
            ..self
        }
    }

    /// Spawns a (mock) thread running `f` with this builder's configuration.
    ///
    /// This always returns `Ok`; the `io::Result` return type matches `std::thread::Builder`.
    #[track_caller]
    pub fn spawn<F, T>(self, f: F) -> io::Result<JoinHandle<T>>
    where
        F: FnOnce() -> T,
        F: Send + 'static,
        T: Send + 'static,
    {
        let Builder { name, stack_size } = self;
        let handle = spawn_internal(f, name, stack_size, location!());
        Ok(handle)
    }
}
impl<T> JoinHandle<T> {
#[track_caller]
pub fn join(self) -> std::thread::Result<T> {
self.notify.wait(location!());
self.result.lock().unwrap().take().unwrap()
}
pub fn thread(&self) -> &Thread {
&self.thread
}
}
/// Opaque debug representation: prints `JoinHandle` without exposing any state.
///
/// `T` is never formatted, so the impl does not require `T: Debug`. This matches
/// `std::thread::JoinHandle`, and keeps `JoinHandle<T>: Debug` for every `T` so that code
/// switching between std and loom keeps compiling.
impl<T> fmt::Debug for JoinHandle<T> {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("JoinHandle").finish()
    }
}
/// Compile-time check that `JoinHandle` stays `Send + Sync`; it is never called.
fn _assert_traits() {
    fn require_send_sync<T>()
    where
        T: Send + Sync,
    {
    }

    require_send_sync::<JoinHandle<()>>();
}
impl<T: 'static> LocalKey<T> {
    /// Runs `f` with a reference to this key's value on the current (mock) thread.
    ///
    /// If the value is not present yet, it is created with `init` first.
    ///
    /// # Panics
    ///
    /// Panics if [`LocalKey::try_with`] returns `Err(AccessError)`.
    pub fn with<F, R>(&'static self, f: F) -> R
    where
        F: FnOnce(&T) -> R,
    {
        let outcome = self.try_with(f);
        outcome.expect("cannot access a (mock) TLS value during or after it is destroyed")
    }

    /// Like [`LocalKey::with`], but returns `Err(AccessError)` instead of panicking when the
    /// runtime reports the value as inaccessible.
    pub fn try_with<F, R>(&'static self, f: F) -> Result<R, AccessError>
    where
        F: FnOnce(&T) -> R,
    {
        let slot = if let Some(existing) = unsafe { self.get() } {
            existing
        } else {
            // Run the initializer outside `rt::execution`, presumably so that `init` may
            // itself use the runtime.
            let fresh = (self.init)();
            rt::execution(|execution| {
                trace!("LocalKey::try_with");
                execution.threads.local_init(self, fresh);
            });
            // The slot was just initialized, so a missing value here is an internal bug.
            unsafe { self.get() }.expect("bug")
        };
        slot.map(f)
    }

    /// Looks up this key's value on the current (mock) thread.
    ///
    /// Returns `None` if the value has not been initialized, `Some(Err(_))` if the runtime
    /// reports it as inaccessible, and `Some(Ok(value))` otherwise.
    ///
    /// # Safety
    ///
    /// The returned reference is transmuted to outlive the `rt::execution` borrow it came
    /// from. NOTE(review): callers must not keep it past the point where the runtime may drop
    /// or move the stored value — confirm the exact invariant against `rt::thread`.
    unsafe fn get(&'static self) -> Option<Result<&T, AccessError>> {
        unsafe fn extend_lifetime<'short, 'long, U>(value: &'short U) -> &'long U {
            std::mem::transmute::<&'short U, &'long U>(value)
        }

        rt::execution(|execution| {
            trace!("LocalKey::get");
            execution
                .threads
                .local(self)
                .map(|slot| slot.map(|value| extend_lifetime(value)))
        })
    }
}
/// Opaque debug representation of a key.
///
/// Printed through `pad`, so width and alignment flags are honored.
impl<T: 'static> fmt::Debug for LocalKey<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let opaque = "LocalKey { .. }";
        f.pad(opaque)
    }
}