use crate::loom::cell::UnsafeCell;
use crate::loom::sync::atomic::AtomicU64;
use crate::loom::sync::atomic::Ordering;
use crate::sync::AtomicWaker;
use crate::time::Instant;
use crate::util::linked_list;
use super::Handle;
use std::cell::UnsafeCell as StdUnsafeCell;
use std::task::{Context, Poll, Waker};
use std::{marker::PhantomPinned, pin::Pin, ptr::NonNull};
/// Outcome recorded when a timer fires: `Ok(())` on normal expiry, or a
/// time-driver error.
type TimerResult = Result<(), crate::time::error::Error>;
// A timer's lifecycle is packed into one `u64`. Every value below
// `STATE_MIN_VALUE` is an expiration tick. The top two values are reserved
// sentinels.

/// The timer is not registered with the driver. It is either freshly created
/// or has already fired, in which case its result is readable.
const STATE_DEREGISTERED: u64 = u64::MAX;
/// The driver has selected the timer for firing but has not fired it yet.
const STATE_PENDING_FIRE: u64 = u64::MAX - 1;
/// Smallest reserved sentinel. Valid expiration ticks are strictly below this.
const STATE_MIN_VALUE: u64 = STATE_PENDING_FIRE;
/// Cell shared between a timer and the time driver that tracks whether the
/// timer is registered, about to fire, or finished.
///
/// `state` holds either an expiration tick (`< STATE_MIN_VALUE`),
/// `STATE_PENDING_FIRE`, or `STATE_DEREGISTERED`. A new cell starts as
/// `STATE_DEREGISTERED` with result `Ok(())`.
pub(super) struct StateCell {
// Lifecycle word: an expiration tick or one of the sentinel constants.
state: AtomicU64,
// Written by `fire` before `STATE_DEREGISTERED` is published with Release.
// Read by `read_state` after an Acquire load observes that value.
result: UnsafeCell<TimerResult>,
// Waker registered by `poll` and taken by `fire`.
waker: CachePadded<AtomicWaker>,
}
impl Default for StateCell {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for StateCell {
    /// Renders as `StateCell(<poll>)`, where `<poll>` is the current value of
    /// `read_state`: `Pending` while registered, or `Ready(result)` once
    /// deregistered.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let snapshot = self.read_state();
        write!(f, "StateCell({:?})", snapshot)
    }
}
impl StateCell {
    /// Creates a cell in the `STATE_DEREGISTERED` state, with result `Ok(())`
    /// and no registered waker.
    fn new() -> Self {
        Self {
            state: AtomicU64::new(STATE_DEREGISTERED),
            result: UnsafeCell::new(Ok(())),
            waker: CachePadded(AtomicWaker::new()),
        }
    }

    /// Returns `true` if the state is exactly `STATE_PENDING_FIRE`
    /// (relaxed load).
    fn is_pending(&self) -> bool {
        self.state.load(Ordering::Relaxed) == STATE_PENDING_FIRE
    }

    /// Returns the raw state word, or `None` when deregistered.
    ///
    /// While the timer is pending fire this yields `Some(STATE_PENDING_FIRE)`,
    /// not a real tick.
    fn when(&self) -> Option<u64> {
        let cur_state = self.state.load(Ordering::Relaxed);
        if cur_state == STATE_DEREGISTERED {
            None
        } else {
            Some(cur_state)
        }
    }

    /// Registers `waker` and then reports whether the timer has completed.
    fn poll(&self, waker: &Waker) -> Poll<TimerResult> {
        // Register before reading the state. A racing `fire` stores the state
        // first and takes the waker afterwards. So either that `fire` sees
        // this waker, or the read below sees the fired state (or both).
        self.waker.0.register_by_ref(waker);
        self.read_state()
    }

    /// Returns `Ready(result)` once the state is `STATE_DEREGISTERED`,
    /// otherwise `Pending`.
    fn read_state(&self) -> Poll<TimerResult> {
        let cur_state = self.state.load(Ordering::Acquire);
        if cur_state == STATE_DEREGISTERED {
            // SAFETY: `fire` writes `result` before its Release store of
            // `STATE_DEREGISTERED`, and this Acquire load observed that store.
            // A never-registered cell still holds the `Ok(())` from `new`.
            Poll::Ready(unsafe { self.result.with(|p| *p) })
        } else {
            Poll::Pending
        }
    }

    /// Tries to move a registered timer whose tick is `<= not_after` into
    /// `STATE_PENDING_FIRE`.
    ///
    /// Returns `Err(current_tick)` without changing anything when the stored
    /// tick is later than `not_after`.
    ///
    /// # Panics
    /// Panics if the entry is already pending or deregistered.
    ///
    /// # Safety
    /// NOTE(review): callers (`TimerHandle::mark_pending`) appear to run under
    /// the driver lock. Confirm before relaxing that requirement.
    unsafe fn mark_pending(&self, not_after: u64) -> Result<(), u64> {
        let mut cur_state = self.state.load(Ordering::Relaxed);
        loop {
            // Checked in release builds as well. Without this check a
            // sentinel state would be returned as `Err(tick)`, and
            // `TimerHandle::mark_pending` would store it into `cached_when`
            // as if it were a real expiration tick.
            assert!(
                cur_state < STATE_MIN_VALUE,
                "mark_pending called when the timer entry is in an invalid state"
            );
            if cur_state > not_after {
                break Err(cur_state);
            }
            match self.state.compare_exchange(
                cur_state,
                STATE_PENDING_FIRE,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => break Ok(()),
                // The timer was concurrently re-timed (e.g. by
                // `extend_expiration`), so re-evaluate with the fresh value.
                Err(actual_state) => cur_state = actual_state,
            }
        }
    }

    /// Records `result`, publishes `STATE_DEREGISTERED`, and returns the
    /// registered waker, if any.
    ///
    /// Does nothing and returns `None` if the timer is already deregistered.
    ///
    /// # Safety
    /// The non-atomic write to `result` must not race with another `fire`.
    /// NOTE(review): presumably guaranteed by the driver lock; confirm.
    unsafe fn fire(&self, result: TimerResult) -> Option<Waker> {
        let cur_state = self.state.load(Ordering::Relaxed);
        if cur_state == STATE_DEREGISTERED {
            return None;
        }
        // SAFETY: see the function-level contract. The Release store below
        // makes this write visible to `read_state`'s Acquire load.
        unsafe { self.result.with_mut(|p| *p = result) };
        self.state.store(STATE_DEREGISTERED, Ordering::Release);
        self.waker.0.take_waker()
    }

    /// Unconditionally stores `timestamp` as the expiration tick, even if it
    /// is earlier than the current one or the timer was deregistered.
    fn set_expiration(&self, timestamp: u64) {
        debug_assert!(timestamp < STATE_MIN_VALUE);
        self.state.store(timestamp, Ordering::Relaxed);
    }

    /// Moves the expiration later (or keeps it equal) without touching the
    /// driver.
    ///
    /// Fails with `Err(())` if `new_timestamp` is earlier than the current
    /// tick, or if the timer is pending fire or deregistered.
    fn extend_expiration(&self, new_timestamp: u64) -> Result<(), ()> {
        let mut prior = self.state.load(Ordering::Relaxed);
        loop {
            if new_timestamp < prior || prior >= STATE_MIN_VALUE {
                return Err(());
            }
            match self.state.compare_exchange_weak(
                prior,
                new_timestamp,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return Ok(()),
                Err(true_prior) => prior = true_prior,
            }
        }
    }

    /// Conservative relaxed check. `false` means the timer is definitely
    /// deregistered. `true` means it may still be registered or pending.
    pub(super) fn might_be_registered(&self) -> bool {
        self.state.load(Ordering::Relaxed) != STATE_DEREGISTERED
    }
}
/// Timer-side handle to a single timer registration with the time driver.
///
/// The entry registers itself for `initial_deadline` on its first poll (see
/// `poll_elapsed`) and deregisters itself on drop (see `cancel`).
#[derive(Debug)]
pub(super) struct TimerEntry {
// Time-driver handle used for (re)registration, cancellation and the
// shutdown check in `poll_elapsed`.
driver: Handle,
// Shared state. The driver reaches it through raw `TimerHandle` pointers
// (built from `NonNull::from(self.inner())`), so it sits in an `UnsafeCell`.
inner: StdUnsafeCell<TimerShared>,
// Deadline to register on first poll. Set to `None` by `reset`.
initial_deadline: Option<Instant>,
// Makes the entry `!Unpin`. Raw pointers into `inner` are handed to the
// driver, so the entry must not move after it has been polled.
_m: std::marker::PhantomPinned,
}
// SAFETY (assumed): `inner` is an `UnsafeCell` that is also reachable by the
// driver. These impls rely on every access to it being coordinated through
// atomics in `StateCell` or the driver's own synchronization.
// NOTE(review): that coordination lives outside this file; confirm.
unsafe impl Send for TimerEntry {}
unsafe impl Sync for TimerEntry {}
/// Non-owning raw pointer to a `TimerShared`, used as the element handle of
/// the intrusive `EntryList`.
///
/// It does not keep the entry alive, which is why every accessor on it is
/// `unsafe`.
#[derive(Debug)]
pub(crate) struct TimerHandle {
inner: NonNull<TimerShared>,
}
pub(super) type EntryList = crate::util::linked_list::LinkedList<TimerShared, TimerShared>;
/// State shared between a `TimerEntry` and the time driver. The driver
/// addresses it through `TimerHandle` raw pointers.
#[derive(Debug)]
pub(crate) struct TimerShared {
// Lifecycle state, result and waker.
state: StateCell,
// Driver bookkeeping (cached tick, list links), kept in its own cache-padded
// slot.
driver_state: CachePadded<TimerSharedPadded>,
// Makes the type `!Unpin`, because list links point at this value.
_p: PhantomPinned,
}
impl TimerShared {
    /// Builds fresh shared state. The timer starts deregistered, with cached
    /// tick 0 and unlinked list pointers.
    pub(super) fn new() -> Self {
        Self {
            state: StateCell::default(),
            driver_state: CachePadded(TimerSharedPadded::new()),
            _p: PhantomPinned,
        }
    }

    /// Returns the driver's cached expiration tick (relaxed load).
    pub(super) fn cached_when(&self) -> u64 {
        let padded = &self.driver_state.0;
        padded.cached_when.load(Ordering::Relaxed)
    }

    /// Reads the true expiration tick, copies it into the cache and returns it.
    ///
    /// # Panics
    /// Panics (via `true_when`) if the timer is deregistered.
    ///
    /// # Safety
    /// NOTE(review): only the driver is expected to call this, presumably with
    /// its lock held and the entry unlinked; confirm.
    pub(super) unsafe fn sync_when(&self) -> u64 {
        let when = self.true_when();
        unsafe { self.set_cached_when(when) };
        when
    }

    /// Overwrites the cached expiration tick (relaxed store).
    ///
    /// # Safety
    /// Same expectations as `sync_when`.
    unsafe fn set_cached_when(&self, when: u64) {
        let padded = &self.driver_state.0;
        padded.cached_when.store(when, Ordering::Relaxed);
    }

    /// Returns the tick stored in the state word.
    ///
    /// # Panics
    /// Panics with "Timer already fired" if the timer is deregistered.
    pub(super) fn true_when(&self) -> u64 {
        self.state.when().expect("Timer already fired")
    }

    /// Sets both the true and the cached expiration tick to `t`,
    /// unconditionally.
    ///
    /// # Safety
    /// Same expectations as `sync_when`.
    pub(super) unsafe fn set_expiration(&self, t: u64) {
        self.state.set_expiration(t);
        unsafe { self.set_cached_when(t) };
    }

    /// Pushes the true expiration later without driver involvement. See
    /// `StateCell::extend_expiration` for when this fails.
    pub(super) fn extend_expiration(&self, t: u64) -> Result<(), ()> {
        self.state.extend_expiration(t)
    }

    /// Returns a raw, non-owning `TimerHandle` that points at `self`.
    pub(super) fn handle(&self) -> TimerHandle {
        let inner = NonNull::from(self);
        TimerHandle { inner }
    }

    /// Conservative check. `false` means the timer is definitely not
    /// registered.
    pub(super) fn might_be_registered(&self) -> bool {
        self.state.might_be_registered()
    }
}
/// Driver-side bookkeeping for a timer. It is wrapped in `CachePadded`
/// inside `TimerShared`.
struct TimerSharedPadded {
// Tick under which the driver currently files this entry. Set to `u64::MAX`
// by `TimerHandle::mark_pending` once the entry is pending fire.
cached_when: AtomicU64,
// NOTE(review): nothing in this file writes this field after `new` sets it
// to 0. It only appears in the `Debug` output as "when". Confirm whether it
// is dead.
true_when: AtomicU64,
// Intrusive list links used by `Link::pointers` for `EntryList`.
pointers: StdUnsafeCell<linked_list::Pointers<TimerShared>>,
}
impl std::fmt::Debug for TimerSharedPadded {
    /// Prints the two tick fields. The list pointers are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let when = self.true_when.load(Ordering::Relaxed);
        let cached_when = self.cached_when.load(Ordering::Relaxed);
        let mut out = f.debug_struct("TimerSharedPadded");
        out.field("when", &when);
        out.field("cached_when", &cached_when);
        out.finish()
    }
}
impl TimerSharedPadded {
    /// Zeroed ticks and unlinked list pointers.
    fn new() -> Self {
        let cached_when = AtomicU64::new(0);
        let true_when = AtomicU64::new(0);
        let pointers = StdUnsafeCell::new(linked_list::Pointers::new());
        Self {
            cached_when,
            true_when,
            pointers,
        }
    }
}
// SAFETY (assumed): `TimerShared` contains `UnsafeCell`s (`result`,
// `pointers`) that are reachable from several threads through
// `TimerHandle`. These impls rely on the driver serializing access to them.
// NOTE(review): that serialization lives outside this file; confirm.
unsafe impl Send for TimerShared {}
unsafe impl Sync for TimerShared {}
// Lets `TimerShared` values live in the intrusive `EntryList`. Handles are
// raw `TimerHandle`s and the links are stored in `driver_state.0.pointers`.
unsafe impl linked_list::Link for TimerShared {
    type Handle = TimerHandle;
    type Target = TimerShared;

    fn as_raw(handle: &Self::Handle) -> NonNull<Self::Target> {
        let TimerHandle { inner } = handle;
        *inner
    }

    unsafe fn from_raw(ptr: NonNull<Self::Target>) -> Self::Handle {
        TimerHandle { inner: ptr }
    }

    unsafe fn pointers(
        target: NonNull<Self::Target>,
    ) -> NonNull<linked_list::Pointers<Self::Target>> {
        // SAFETY: the list contract requires `target` to point at a live
        // `TimerShared`.
        let shared: &TimerShared = unsafe { target.as_ref() };
        let raw_links = shared.driver_state.0.pointers.get();
        NonNull::new(raw_links).unwrap()
    }
}
impl TimerEntry {
    /// Creates an unregistered entry on `handle`'s driver. It registers for
    /// `deadline` on its first `poll_elapsed`.
    pub(crate) fn new(handle: &Handle, deadline: Instant) -> Self {
        Self {
            driver: handle.clone(),
            inner: StdUnsafeCell::new(TimerShared::new()),
            initial_deadline: Some(deadline),
            _m: std::marker::PhantomPinned,
        }
    }

    /// Shared view of the driver-visible state.
    fn inner(&self) -> &TimerShared {
        let cell_ptr = self.inner.get();
        // SAFETY (assumed): only shared access is created here, and mutation
        // of `TimerShared` goes through its atomics and cells.
        unsafe { &*cell_ptr }
    }

    /// `true` once the first registration has happened and the timer is
    /// definitely deregistered (fired or cancelled).
    pub(crate) fn is_elapsed(&self) -> bool {
        self.initial_deadline.is_none() && !self.inner().might_be_registered()
    }

    /// Removes this entry from the driver via `clear_entry`.
    pub(crate) fn cancel(self: Pin<&mut Self>) {
        let shared = NonNull::from(self.inner());
        unsafe { self.driver.clear_entry(shared) };
    }

    /// Retargets the timer at `new_time`.
    ///
    /// When the timer is still registered and the new tick is not earlier,
    /// the state is updated in place. Otherwise the entry is re-registered
    /// with the driver.
    pub(crate) fn reset(mut self: Pin<&mut Self>, new_time: Instant) {
        // SAFETY: only a plain `Option<Instant>` field is overwritten;
        // nothing is moved out of the pinned entry.
        let this = unsafe { self.as_mut().get_unchecked_mut() };
        this.initial_deadline = None;

        let tick = self.driver.time_source().deadline_to_tick(new_time);

        if self.inner().extend_expiration(tick).is_err() {
            unsafe {
                self.driver.reregister(tick, self.inner().into());
            }
        }
    }

    /// Polls for expiry. The first poll registers `initial_deadline`.
    ///
    /// # Panics
    /// Panics if the driver has shut down.
    pub(crate) fn poll_elapsed(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), super::Error>> {
        if self.driver.is_shutdown() {
            panic!("{}", crate::util::error::RUNTIME_SHUTTING_DOWN_ERROR);
        }

        if let Some(deadline) = self.initial_deadline {
            self.as_mut().reset(deadline);
        }

        // `inner()` needs only `&self`, which `Pin<&mut Self>` derefs to.
        self.inner().state.poll(cx.waker())
    }
}
// All methods dereference the raw `inner` pointer.
//
// # Safety (for every method)
// The pointed-to `TimerShared` must still be alive. NOTE(review): callers are
// presumably the driver, holding its lock, with the entry in the required
// list state. Confirm against the driver.
impl TimerHandle {
    /// Returns the cached expiration tick.
    pub(super) unsafe fn cached_when(&self) -> u64 {
        unsafe { self.inner.as_ref().cached_when() }
    }

    /// Copies the true tick into the cache and returns it.
    pub(super) unsafe fn sync_when(&self) -> u64 {
        unsafe { self.inner.as_ref().sync_when() }
    }

    /// `true` if the entry is in `STATE_PENDING_FIRE`.
    pub(super) unsafe fn is_pending(&self) -> bool {
        unsafe { self.inner.as_ref().state.is_pending() }
    }

    /// Forces both the true and the cached expiration tick to `tick`.
    pub(super) unsafe fn set_expiration(&self, tick: u64) {
        unsafe { self.inner.as_ref().set_expiration(tick) };
    }

    /// Tries to mark the entry pending fire if its tick is `<= not_after`.
    ///
    /// On success the cached tick becomes `u64::MAX`. On `Err(tick)` the
    /// cache is updated to the entry's current tick.
    pub(super) unsafe fn mark_pending(&self, not_after: u64) -> Result<(), u64> {
        let shared = unsafe { self.inner.as_ref() };
        match unsafe { shared.state.mark_pending(not_after) } {
            Ok(()) => {
                // `u64::MAX` is not a valid tick. NOTE(review): presumably it
                // marks "on the pending list" for the driver; confirm.
                unsafe { shared.set_cached_when(u64::MAX) };
                Ok(())
            }
            Err(tick) => {
                unsafe { shared.set_cached_when(tick) };
                Err(tick)
            }
        }
    }

    /// Completes the timer with `completed_state` and returns its waker, if any.
    ///
    /// Consumes the handle because the entry may be freed once it is
    /// deregistered.
    pub(super) unsafe fn fire(self, completed_state: TimerResult) -> Option<Waker> {
        unsafe { self.inner.as_ref().state.fire(completed_state) }
    }
}
impl Drop for TimerEntry {
    /// Deregisters the entry from the driver before its memory goes away.
    fn drop(&mut self) {
        // SAFETY: a value being dropped is never moved again, so treating it
        // as pinned for the duration of `cancel` upholds the pin contract.
        let pinned = unsafe { Pin::new_unchecked(self) };
        pinned.cancel()
    }
}
/// Wrapper that over-aligns `T`: 128 bytes on x86_64, 64 bytes elsewhere.
/// Presumably this keeps the wrapped value off cache lines shared with
/// neighbouring fields, to avoid false sharing.
#[cfg_attr(target_arch = "x86_64", repr(align(128)))]
#[cfg_attr(not(target_arch = "x86_64"), repr(align(64)))]
#[derive(Debug, Default)]
struct CachePadded<T>(T);