use crate::loom::cell::UnsafeCell;
use crate::loom::sync::atomic::AtomicU64;
use crate::loom::sync::atomic::Ordering;
use crate::runtime::context;
use crate::runtime::scheduler;
use crate::sync::AtomicWaker;
use crate::time::Instant;
use crate::util::linked_list;
use std::cell::UnsafeCell as StdUnsafeCell;
use std::task::{Context, Poll, Waker};
use std::{marker::PhantomPinned, pin::Pin, ptr::NonNull};
// Result carried by a timer entry once it fires (the driver may fail a
// timer, e.g. on shutdown, rather than complete it successfully).
type TimerResult = Result<(), crate::time::error::Error>;

// The `StateCell` state word is a tick timestamp, except for the reserved
// sentinel values below, which occupy the top of the `u64` range.

/// Sentinel: the timer is not registered with the driver (never registered,
/// or already fired/deregistered).
const STATE_DEREGISTERED: u64 = u64::MAX;
/// Sentinel: the timer has been marked for imminent firing
/// (`StateCell::mark_pending` succeeded) but `StateCell::fire` has not run.
const STATE_PENDING_FIRE: u64 = STATE_DEREGISTERED - 1;
/// Smallest reserved state value; anything below this is a plain timestamp.
const STATE_MIN_VALUE: u64 = STATE_PENDING_FIRE;
/// Largest tick value that can be stored without colliding with the
/// reserved sentinel states above.
pub(super) const MAX_SAFE_MILLIS_DURATION: u64 = STATE_MIN_VALUE - 1;
/// Timer state shared between the timer driver and the task that owns the
/// timer, safe for concurrent access.
pub(super) struct StateCell {
    /// Current state: either an expiration tick (`< STATE_MIN_VALUE`) or one
    /// of the `STATE_*` sentinel values.
    state: AtomicU64,
    /// Result reported when the timer fires. Only read after `state` has
    /// been observed (with `Acquire`) as `STATE_DEREGISTERED` — see
    /// `read_state`/`fire` for the publication protocol.
    result: UnsafeCell<TimerResult>,
    /// Waker to notify when the timer fires.
    waker: AtomicWaker,
}
impl Default for StateCell {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for StateCell {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Render the logical poll state rather than the raw atomic word.
        let snapshot = self.read_state();
        write!(f, "StateCell({:?})", snapshot)
    }
}
impl StateCell {
fn new() -> Self {
Self {
state: AtomicU64::new(STATE_DEREGISTERED),
result: UnsafeCell::new(Ok(())),
waker: AtomicWaker::new(),
}
}
fn is_pending(&self) -> bool {
self.state.load(Ordering::Relaxed) == STATE_PENDING_FIRE
}
fn when(&self) -> Option<u64> {
let cur_state = self.state.load(Ordering::Relaxed);
if cur_state == STATE_DEREGISTERED {
None
} else {
Some(cur_state)
}
}
fn poll(&self, waker: &Waker) -> Poll<TimerResult> {
self.waker.register_by_ref(waker);
self.read_state()
}
fn read_state(&self) -> Poll<TimerResult> {
let cur_state = self.state.load(Ordering::Acquire);
if cur_state == STATE_DEREGISTERED {
Poll::Ready(unsafe { self.result.with(|p| *p) })
} else {
Poll::Pending
}
}
unsafe fn mark_pending(&self, not_after: u64) -> Result<(), u64> {
let mut cur_state = self.state.load(Ordering::Relaxed);
loop {
assert!(
cur_state < STATE_MIN_VALUE,
"mark_pending called when the timer entry is in an invalid state"
);
if cur_state > not_after {
break Err(cur_state);
}
match self.state.compare_exchange_weak(
cur_state,
STATE_PENDING_FIRE,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => break Ok(()),
Err(actual_state) => cur_state = actual_state,
}
}
}
unsafe fn fire(&self, result: TimerResult) -> Option<Waker> {
let cur_state = self.state.load(Ordering::Relaxed);
if cur_state == STATE_DEREGISTERED {
return None;
}
unsafe { self.result.with_mut(|p| *p = result) };
self.state.store(STATE_DEREGISTERED, Ordering::Release);
self.waker.take_waker()
}
fn set_expiration(&self, timestamp: u64) {
debug_assert!(timestamp < STATE_MIN_VALUE);
self.state.store(timestamp, Ordering::Relaxed);
}
fn extend_expiration(&self, new_timestamp: u64) -> Result<(), ()> {
let mut prior = self.state.load(Ordering::Relaxed);
loop {
if new_timestamp < prior || prior >= STATE_MIN_VALUE {
return Err(());
}
match self.state.compare_exchange_weak(
prior,
new_timestamp,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => return Ok(()),
Err(true_prior) => prior = true_prior,
}
}
}
pub(super) fn might_be_registered(&self) -> bool {
self.state.load(Ordering::Relaxed) != u64::MAX
}
}
/// A timer entry as held by user-facing futures. Lazily creates its shared
/// `TimerShared` state on first use and deregisters from the driver on drop.
#[derive(Debug)]
pub(crate) struct TimerEntry {
    /// Handle to the scheduler, through which the time driver is reached.
    driver: scheduler::Handle,
    /// Shared driver/future state; `None` until first accessed via
    /// `TimerEntry::inner`. Wrapped in `StdUnsafeCell` so it can be
    /// initialized through `&self`.
    inner: StdUnsafeCell<Option<TimerShared>>,
    /// The deadline this entry was created with or last reset to.
    deadline: Instant,
    /// Whether `deadline` has been registered with the driver
    /// (see `poll_elapsed`/`reset`).
    registered: bool,
    /// The driver is handed raw pointers into `inner` (see `cancel`/`reset`),
    /// so the entry must not move once in use — hence `PhantomPinned`.
    _m: std::marker::PhantomPinned,
}
// SAFETY: NOTE(review): `StdUnsafeCell` removes the auto `Send`/`Sync`
// impls; these assertions rely on accesses to `inner` being synchronized by
// the owning task and the driver's locking — confirm against the driver.
unsafe impl Send for TimerEntry {}
unsafe impl Sync for TimerEntry {}
/// A raw, unowned pointer to a `TimerShared`, as stored in the driver's
/// lists. The pointee's liveness is the holder's responsibility — all
/// accessors on this type are `unsafe`.
#[derive(Debug)]
pub(crate) struct TimerHandle {
    inner: NonNull<TimerShared>,
}
/// Intrusive list of timer entries, linked through `TimerShared::pointers`.
pub(super) type EntryList = crate::util::linked_list::LinkedList<TimerShared, TimerShared>;
/// State shared between a `TimerEntry` and the timer driver.
pub(crate) struct TimerShared {
    /// Driver shard this entry belongs to; fixed at construction.
    shard_id: u32,
    /// Intrusive linked-list pointers (used via the `linked_list::Link`
    /// impl below).
    pointers: linked_list::Pointers<TimerShared>,
    /// The driver's cached copy of the expiration tick; may lag behind the
    /// authoritative value in `state` (see `sync_when`/`set_cached_when`).
    cached_when: AtomicU64,
    /// The atomic state machine plus result and waker storage.
    state: StateCell,
    /// The driver holds raw pointers to this struct, so it must stay pinned.
    _p: PhantomPinned,
}
// SAFETY: NOTE(review): `StateCell` holds an `UnsafeCell`; these assertions
// rely on the state protocol in `StateCell` (acquire/release publication of
// `result`) plus driver-side locking — confirm against the driver.
unsafe impl Send for TimerShared {}
unsafe impl Sync for TimerShared {}
impl std::fmt::Debug for TimerShared {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the log-relevant fields are shown; the intrusive list
        // pointers and shard id are omitted, matching the original output.
        let cached = self.cached_when.load(Ordering::Relaxed);

        let mut dbg = f.debug_struct("TimerShared");
        dbg.field("cached_when", &cached);
        dbg.field("state", &self.state);
        dbg.finish()
    }
}
// Generates `TimerShared::addr_of_pointers`, projecting from a
// `NonNull<TimerShared>` straight to the embedded list pointers.
// NOTE(review): presumably this macro exists to avoid materializing a
// `&TimerShared` reference (whose aliasing guarantees the driver cannot
// provide) — confirm against the macro's definition. `//` comments are used
// here (not `///`) so the macro's token input is unchanged.
generate_addr_of_methods! {
    impl<> TimerShared {
        unsafe fn addr_of_pointers(self: NonNull<Self>) -> NonNull<linked_list::Pointers<TimerShared>> {
            &self.pointers
        }
    }
}
impl TimerShared {
    /// Creates a new shared timer state bound to driver shard `shard_id`,
    /// starting in the deregistered state (`StateCell::default`).
    pub(super) fn new(shard_id: u32) -> Self {
        Self {
            shard_id,
            cached_when: AtomicU64::new(0),
            pointers: linked_list::Pointers::new(),
            state: StateCell::default(),
            _p: PhantomPinned,
        }
    }

    /// Returns the driver's cached copy of the expiration tick; may be
    /// stale relative to `true_when`.
    pub(super) fn cached_when(&self) -> u64 {
        self.cached_when.load(Ordering::Relaxed)
    }

    /// Refreshes the cached tick from the authoritative state and returns it.
    ///
    /// # Safety
    ///
    /// NOTE(review): the `Relaxed` store implies the caller must provide
    /// external synchronization (presumably the driver's shard lock) —
    /// confirm at call sites.
    pub(super) unsafe fn sync_when(&self) -> u64 {
        let true_when = self.true_when();
        self.cached_when.store(true_when, Ordering::Relaxed);
        true_when
    }

    /// Updates only the cached tick, not the authoritative state.
    ///
    /// # Safety
    ///
    /// Caller must synchronize with other cache writers (store is `Relaxed`).
    unsafe fn set_cached_when(&self, when: u64) {
        self.cached_when.store(when, Ordering::Relaxed);
    }

    /// Returns the authoritative expiration tick.
    ///
    /// # Panics
    ///
    /// Panics if the timer is deregistered (`StateCell::when` is `None`).
    pub(super) fn true_when(&self) -> u64 {
        self.state.when().expect("Timer already fired")
    }

    /// Sets both the authoritative and cached expiration ticks to `t`.
    ///
    /// # Safety
    ///
    /// Both stores are unsynchronized; caller must ensure no concurrent
    /// state transitions.
    pub(super) unsafe fn set_expiration(&self, t: u64) {
        self.state.set_expiration(t);
        self.cached_when.store(t, Ordering::Relaxed);
    }

    /// Attempts to move the expiration later without driver interaction.
    /// The cached tick is deliberately left untouched — see `sync_when`.
    pub(super) fn extend_expiration(&self, t: u64) -> Result<(), ()> {
        self.state.extend_expiration(t)
    }

    /// Returns a raw handle to this entry for insertion into driver lists.
    pub(super) fn handle(&self) -> TimerHandle {
        TimerHandle {
            inner: NonNull::from(self),
        }
    }

    /// See `StateCell::might_be_registered`: `false` is definitive, `true`
    /// may be stale.
    pub(super) fn might_be_registered(&self) -> bool {
        self.state.might_be_registered()
    }

    /// The driver shard this entry is permanently assigned to.
    pub(super) fn shard_id(&self) -> u32 {
        self.shard_id
    }
}
// SAFETY: NOTE(review): `Link` presumably requires that `pointers` return
// stable storage and that handles map bijectively to targets; `TimerHandle`
// is a raw pointer, so keeping the pointee alive while listed is the
// driver's responsibility — confirm against `linked_list::Link`'s contract.
unsafe impl linked_list::Link for TimerShared {
    type Handle = TimerHandle;
    type Target = TimerShared;

    fn as_raw(handle: &Self::Handle) -> NonNull<Self::Target> {
        handle.inner
    }

    unsafe fn from_raw(ptr: NonNull<Self::Target>) -> Self::Handle {
        TimerHandle { inner: ptr }
    }

    unsafe fn pointers(
        target: NonNull<Self::Target>,
    ) -> NonNull<linked_list::Pointers<Self::Target>> {
        // Field projection without creating a `&TimerShared` reference.
        TimerShared::addr_of_pointers(target)
    }
}
impl TimerEntry {
    /// Creates a new timer entry for `deadline` on the runtime behind
    /// `handle`. The shared state is created lazily, on first use.
    ///
    /// # Panics
    ///
    /// Panics (via `handle.driver().time()`) if the runtime was built
    /// without a time driver; `#[track_caller]` attributes the panic to the
    /// caller.
    #[track_caller]
    pub(crate) fn new(handle: scheduler::Handle, deadline: Instant) -> Self {
        // Touch the time driver now so a missing driver panics here, at the
        // construction site, rather than on first poll.
        let _ = handle.driver().time();

        Self {
            driver: handle,
            inner: StdUnsafeCell::new(None),
            deadline,
            registered: false,
            _m: std::marker::PhantomPinned,
        }
    }

    /// Whether the lazily-created shared state exists yet.
    fn is_inner_init(&self) -> bool {
        unsafe { &*self.inner.get() }.is_some()
    }

    /// Returns the shared state, creating it (and picking a driver shard)
    /// on first call.
    fn inner(&self) -> &TimerShared {
        let inner = unsafe { &*self.inner.get() };
        if inner.is_none() {
            // First access: choose a shard for this entry and initialize
            // the shared state in place.
            let shard_size = self.driver.driver().time().inner.get_shard_size();
            let shard_id = generate_shard_id(shard_size);
            unsafe {
                // SAFETY: NOTE(review): writing through the cell while
                // `inner` (a shared borrow of the same cell) is live relies
                // on this `&self` access being externally unique — confirm.
                *self.inner.get() = Some(TimerShared::new(shard_id));
            }
        }
        return inner.as_ref().unwrap();
    }

    /// The deadline this entry was created with or last reset to.
    pub(crate) fn deadline(&self) -> Instant {
        self.deadline
    }

    /// Whether the timer has completed: it was registered at some point and
    /// is no longer registered with the driver.
    pub(crate) fn is_elapsed(&self) -> bool {
        self.is_inner_init() && !self.inner().state.might_be_registered() && self.registered
    }

    /// Cancels the timer, removing it from the driver. No-op for an entry
    /// whose shared state was never created.
    pub(crate) fn cancel(self: Pin<&mut Self>) {
        // Avoid creating the shared state just to deregister it.
        if !self.is_inner_init() {
            return;
        }
        // SAFETY: NOTE(review): relies on the entry being pinned, so the
        // pointer handed to the driver stays valid for the duration of
        // `clear_entry` — confirm against the driver's contract.
        unsafe { self.driver().clear_entry(NonNull::from(self.inner())) };
    }

    /// Moves the deadline to `new_time`. Tries a cheap lock-free extension
    /// first; if that fails and `reregister` is set, performs a full driver
    /// re-registration.
    pub(crate) fn reset(mut self: Pin<&mut Self>, new_time: Instant, reregister: bool) {
        let this = unsafe { self.as_mut().get_unchecked_mut() };
        this.deadline = new_time;
        this.registered = reregister;

        let tick = self.driver().time_source().deadline_to_tick(new_time);

        // Fast path: if the entry is still queued and the new deadline is
        // not earlier, the state can be extended without driver interaction.
        if self.inner().extend_expiration(tick).is_ok() {
            return;
        }

        if reregister {
            unsafe {
                self.driver()
                    .reregister(&self.driver.driver().io, tick, self.inner().into());
            }
        }
    }

    /// Polls for timer completion, registering the waker from `cx`.
    ///
    /// # Panics
    ///
    /// Panics if the runtime's driver has shut down.
    pub(crate) fn poll_elapsed(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), super::Error>> {
        assert!(
            !self.driver().is_shutdown(),
            "{}",
            crate::util::error::RUNTIME_SHUTTING_DOWN_ERROR
        );

        // Lazily register with the driver on first poll.
        if !self.registered {
            let deadline = self.deadline;
            self.as_mut().reset(deadline, true);
        }

        self.inner().state.poll(cx.waker())
    }

    /// The time-driver handle for this entry's runtime.
    pub(crate) fn driver(&self) -> &super::Handle {
        self.driver.driver().time()
    }

    /// The runtime clock; only compiled in for unstable tracing builds.
    #[cfg(all(tokio_unstable, feature = "tracing"))]
    pub(crate) fn clock(&self) -> &super::Clock {
        self.driver.driver().clock()
    }
}
impl TimerHandle {
    /// Returns the driver-cached expiration tick.
    ///
    /// # Safety
    ///
    /// The pointed-to `TimerShared` must still be alive.
    pub(super) unsafe fn cached_when(&self) -> u64 {
        unsafe { self.inner.as_ref().cached_when() }
    }

    /// Refreshes and returns the cached expiration tick.
    ///
    /// # Safety
    ///
    /// The pointee must be alive; see `TimerShared::sync_when` for its
    /// synchronization requirements.
    pub(super) unsafe fn sync_when(&self) -> u64 {
        unsafe { self.inner.as_ref().sync_when() }
    }

    /// Whether the entry is in the pending-fire state.
    ///
    /// # Safety
    ///
    /// The pointee must still be alive.
    pub(super) unsafe fn is_pending(&self) -> bool {
        unsafe { self.inner.as_ref().state.is_pending() }
    }

    /// Sets the authoritative and cached expiration ticks.
    ///
    /// # Safety
    ///
    /// The pointee must be alive; see `TimerShared::set_expiration`.
    pub(super) unsafe fn set_expiration(&self, tick: u64) {
        self.inner.as_ref().set_expiration(tick);
    }

    /// Attempts the timestamp -> pending-fire transition, keeping the
    /// cached tick consistent with the outcome.
    ///
    /// # Safety
    ///
    /// The pointee must be alive; see `StateCell::mark_pending` for the
    /// state preconditions.
    pub(super) unsafe fn mark_pending(&self, not_after: u64) -> Result<(), u64> {
        match self.inner.as_ref().state.mark_pending(not_after) {
            Ok(()) => {
                // NOTE(review): the cache is parked at `u64::MAX` on success
                // — presumably so pending entries compare as
                // furthest-in-the-future; confirm against the wheel logic.
                self.inner.as_ref().set_cached_when(u64::MAX);
                Ok(())
            }
            Err(tick) => {
                // Transition refused; cache the tick actually observed.
                self.inner.as_ref().set_cached_when(tick);
                Err(tick)
            }
        }
    }

    /// Fires the timer with `completed_state`, returning a waker to invoke.
    /// Consumes the handle.
    ///
    /// # Safety
    ///
    /// The pointee must be alive; see `StateCell::fire` for the exclusivity
    /// requirements.
    pub(super) unsafe fn fire(self, completed_state: TimerResult) -> Option<Waker> {
        self.inner.as_ref().state.fire(completed_state)
    }
}
impl Drop for TimerEntry {
    fn drop(&mut self) {
        // Deregister from the driver so it never dereferences a dangling
        // pointer to this entry's shared state.
        // SAFETY: `self` is being dropped and will never move again, so
        // pinning it here is sound.
        unsafe { Pin::new_unchecked(self) }.as_mut().cancel();
    }
}
cfg_rt! {
    /// Picks the driver shard for a new timer entry: the current worker's
    /// index when called from a runtime worker thread, otherwise a random
    /// shard.
    fn generate_shard_id(shard_size: u32) -> u32 {
        let id = context::with_scheduler(|ctx| match ctx {
            // Current-thread runtime: single worker, always shard 0.
            Some(scheduler::Context::CurrentThread(_ctx)) => 0,
            #[cfg(feature = "rt-multi-thread")]
            Some(scheduler::Context::MultiThread(ctx)) => ctx.get_worker_index() as u32,
            #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))]
            Some(scheduler::Context::MultiThreadAlt(ctx)) => ctx.get_worker_index() as u32,
            // Not on a runtime thread: fall back to a random shard.
            None => context::thread_rng_n(shard_size),
        });
        // Worker indices may exceed the shard count; wrap into range.
        id % shard_size
    }
}
cfg_not_rt! {
    /// Without the `rt` feature there is no scheduler context to consult,
    /// so shard selection is always random.
    fn generate_shard_id(shard_size: u32) -> u32 {
        context::thread_rng_n(shard_size)
    }
}