mod afd;
mod port;
use afd::{base_socket, Afd, AfdPollInfo, AfdPollMask, HasAfdInfo, IoStatusBlock};
use port::{IoCompletionPort, OverlappedEntry};
use windows_sys::Win32::Foundation::{
BOOLEAN, ERROR_INVALID_HANDLE, ERROR_IO_PENDING, STATUS_CANCELLED,
};
use windows_sys::Win32::System::Threading::{
RegisterWaitForSingleObject, UnregisterWait, INFINITE, WT_EXECUTELONGFUNCTION,
WT_EXECUTEONLYONCE,
};
use crate::{Event, PollMode};
use concurrent_queue::ConcurrentQueue;
use pin_project_lite::pin_project;
use std::cell::UnsafeCell;
use std::collections::hash_map::{Entry, HashMap};
use std::ffi::c_void;
use std::fmt;
use std::io;
use std::marker::PhantomPinned;
use std::mem::{forget, MaybeUninit};
use std::os::windows::io::{
AsHandle, AsRawHandle, AsRawSocket, BorrowedHandle, BorrowedSocket, RawHandle, RawSocket,
};
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, MutexGuard, RwLock, Weak};
use std::time::{Duration, Instant};
/// Unwraps a `LockResult`, recovering the guard if the lock was poisoned by a panic in
/// another thread instead of propagating the poison.
macro_rules! lock {
    ($lock_result:expr) => {
        match $lock_result {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        }
    };
}
/// Poller backed by an I/O completion port.
///
/// Sockets are polled through `\Device\Afd` handles and waitable handles through
/// thread-pool waits; both report readiness as completion packets posted to `port`.
#[derive(Debug)]
pub(super) struct Poller {
    /// The completion port that every packet is posted to and dequeued from.
    port: Arc<IoCompletionPort<Packet>>,
    /// AFD handles shared between sockets. Held weakly so that handles no longer used by
    /// any socket can be pruned in `afd_handle`.
    afd: Mutex<Vec<Weak<Afd<Packet>>>>,
    /// Registered sockets, keyed by the socket value passed to `add`.
    sources: RwLock<HashMap<RawSocket, Packet>>,
    /// Registered waitable handles, keyed by the handle passed to `add_waitable`.
    waitables: RwLock<HashMap<RawHandle, Packet>>,
    /// Packet updates queued while no `wait` is in progress. Drained at the start of `wait`,
    /// or early when the bounded queue is full.
    pending_updates: ConcurrentQueue<Packet>,
    /// Set while `wait` is draining updates and blocked on the completion port. While set,
    /// `update_packet` applies updates immediately instead of queueing them.
    polling: AtomicBool,
    /// Packet posted by `notify` to wake up `wait`.
    notifier: Packet,
}
// NOTE(review): these impls are needed at least because `RawHandle` (a raw pointer, which is
// neither `Send` nor `Sync`) is used as a map key. In this file such handles are only
// used as opaque keys and passed to OS calls. Confirm the remaining fields are thread-safe.
unsafe impl Send for Poller {}
unsafe impl Sync for Poller {}
impl Poller {
    /// Creates a new poller with a fresh I/O completion port.
    ///
    /// # Errors
    ///
    /// Returns an "unsupported" error (built by `crate::unsupported_error`) when the
    /// undocumented ntdll functions cannot be loaded or `\Device\Afd` cannot be opened.
    /// Also propagates any error from creating the completion port.
    pub(super) fn new() -> io::Result<Self> {
        // Load the ntdll imports up front so that a missing API is reported once, here.
        if let Err(e) = afd::NtdllImports::force_load() {
            return Err(crate::unsupported_error(format!(
                "Failed to initialize unstable Windows functions: {}\nThis usually only happens for old Windows or Wine.",
                e
            )));
        }
        // Open (and immediately drop) an AFD handle only to check that the device works;
        // the handles actually used for polling are created lazily by `afd_handle`.
        Afd::<Packet>::new().map_err(|e| crate::unsupported_error(format!(
            "Failed to initialize \\Device\\Afd: {}\nThis usually only happens for old Windows or Wine.",
            e,
        )))?;
        let port = IoCompletionPort::new(0)?;
        tracing::trace!(handle = ?port, "new");
        Ok(Poller {
            port: Arc::new(port),
            afd: Mutex::new(vec![]),
            sources: RwLock::new(HashMap::new()),
            waitables: RwLock::new(HashMap::new()),
            // Bounded: when full, `update_packet` drains it before retrying the push.
            pending_updates: ConcurrentQueue::bounded(1024),
            polling: AtomicBool::new(false),
            notifier: Arc::pin(
                PacketInner::Wakeup {
                    _pinned: PhantomPinned,
                }
                .into(),
            ),
        })
    }

    /// Level-triggered modes are supported.
    pub(super) fn supports_level(&self) -> bool {
        true
    }

    /// Edge-triggered modes are rejected by `add`, `modify` and the waitable variants.
    pub(super) fn supports_edge(&self) -> bool {
        false
    }

    /// Registers `socket` with the given interest and poll mode.
    ///
    /// # Safety
    ///
    /// NOTE(review): the contract is not visible here. Presumably the caller must keep
    /// `socket` valid until it is removed with `delete`; confirm against the public API.
    ///
    /// # Errors
    ///
    /// - `InvalidInput` for edge-triggered modes.
    /// - `AlreadyExists` if the socket is already registered.
    /// - Errors from resolving the base socket, obtaining an AFD handle, or submitting the
    ///   initial poll.
    ///
    /// NOTE(review): if the initial update fails, the socket stays in `sources`.
    pub(super) unsafe fn add(
        &self,
        socket: RawSocket,
        interest: Event,
        mode: PollMode,
    ) -> io::Result<()> {
        let span = tracing::trace_span!(
            "add",
            handle = ?self.port,
            sock = ?socket,
            ev = ?interest,
        );
        let _enter = span.enter();
        // This backend reports `supports_edge() == false`.
        if matches!(mode, PollMode::Edge | PollMode::EdgeOneshot) {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "edge-triggered events are not supported",
            ));
        }
        let socket_state = {
            let state = SocketState {
                socket,
                base_socket: base_socket(socket)?,
                interest,
                interest_error: true,
                afd: self.afd_handle()?,
                mode,
                waiting_on_delete: false,
                status: SocketStatus::Idle,
            };
            Arc::pin(IoStatusBlock::from(PacketInner::Socket {
                packet: UnsafeCell::new(AfdPollInfo::default()),
                socket: Mutex::new(state),
            }))
        };
        // Register the packet. The write lock is released before the update below.
        {
            let mut sources = lock!(self.sources.write());
            match sources.entry(socket) {
                Entry::Vacant(v) => {
                    v.insert(Pin::<Arc<_>>::clone(&socket_state));
                }
                Entry::Occupied(_) => {
                    return Err(io::Error::from(io::ErrorKind::AlreadyExists));
                }
            }
        }
        self.update_packet(socket_state)
    }

    /// Changes the interest and poll mode of a registered socket.
    ///
    /// The packet is only resubmitted when `set_events` reports that the outstanding
    /// registration no longer matches.
    ///
    /// # Errors
    ///
    /// `InvalidInput` for edge-triggered modes, `NotFound` if the socket is not registered,
    /// or errors from resubmitting the packet.
    pub(super) fn modify(
        &self,
        socket: BorrowedSocket<'_>,
        interest: Event,
        mode: PollMode,
    ) -> io::Result<()> {
        let span = tracing::trace_span!(
            "modify",
            handle = ?self.port,
            sock = ?socket,
            ev = ?interest,
        );
        let _enter = span.enter();
        if matches!(mode, PollMode::Edge | PollMode::EdgeOneshot) {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "edge-triggered events are not supported",
            ));
        }
        // Clone the packet out so the read lock is not held during the update.
        let source = {
            let sources = lock!(self.sources.read());
            sources
                .get(&socket.as_raw_socket())
                .cloned()
                .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?
        };
        if source.as_ref().set_events(interest, mode) {
            self.update_packet(source)?;
        }
        Ok(())
    }

    /// Unregisters a socket; unknown sockets are ignored.
    ///
    /// Any outstanding AFD poll is cancelled through `begin_delete`.
    pub(super) fn delete(&self, socket: BorrowedSocket<'_>) -> io::Result<()> {
        let span = tracing::trace_span!(
            "remove",
            handle = ?self.port,
            sock = ?socket,
        );
        let _enter = span.enter();
        let source = {
            let mut sources = lock!(self.sources.write());
            match sources.remove(&socket.as_raw_socket()) {
                Some(s) => s,
                None => {
                    return Ok(());
                }
            }
        };
        source.begin_delete()
    }

    /// Registers a waitable handle with the given interest and poll mode.
    ///
    /// # Errors
    ///
    /// `InvalidInput` for edge-triggered modes, `AlreadyExists` if the handle is already
    /// registered, or errors from registering the thread-pool wait.
    pub(super) fn add_waitable(
        &self,
        handle: RawHandle,
        interest: Event,
        mode: PollMode,
    ) -> io::Result<()> {
        tracing::trace!(
            "add_waitable: handle={:?}, waitable={:p}, ev={:?}",
            self.port,
            handle,
            interest
        );
        if matches!(mode, PollMode::Edge | PollMode::EdgeOneshot) {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "edge-triggered events are not supported",
            ));
        }
        let handle_state = {
            let state = WaitableState {
                handle,
                // Weak, so a pending wait callback does not keep the port alive.
                port: Arc::downgrade(&self.port),
                interest,
                mode,
                status: WaitableStatus::Idle,
            };
            Arc::pin(IoStatusBlock::from(PacketInner::Waitable {
                handle: Mutex::new(state),
            }))
        };
        {
            let mut sources = lock!(self.waitables.write());
            match sources.entry(handle) {
                Entry::Vacant(v) => {
                    v.insert(Pin::<Arc<_>>::clone(&handle_state));
                }
                Entry::Occupied(_) => {
                    return Err(io::Error::from(io::ErrorKind::AlreadyExists));
                }
            }
        }
        self.update_packet(handle_state)
    }

    /// Changes the interest and poll mode of a registered waitable handle.
    ///
    /// # Errors
    ///
    /// `InvalidInput` for edge-triggered modes, `NotFound` if the handle is not registered,
    /// or errors from registering the wait.
    pub(crate) fn modify_waitable(
        &self,
        waitable: RawHandle,
        interest: Event,
        mode: PollMode,
    ) -> io::Result<()> {
        tracing::trace!(
            "modify_waitable: handle={:?}, waitable={:p}, ev={:?}",
            self.port,
            waitable,
            interest
        );
        if matches!(mode, PollMode::Edge | PollMode::EdgeOneshot) {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "edge-triggered events are not supported",
            ));
        }
        let source = {
            let sources = lock!(self.waitables.read());
            sources
                .get(&waitable)
                .cloned()
                .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?
        };
        // Only resubmit when no wait is currently registered.
        if source.as_ref().set_events(interest, mode) {
            self.update_packet(source)?;
        }
        Ok(())
    }

    /// Unregisters a waitable handle; unknown handles are ignored.
    pub(super) fn remove_waitable(&self, waitable: RawHandle) -> io::Result<()> {
        tracing::trace!("remove: handle={:?}, waitable={:p}", self.port, waitable);
        let source = {
            let mut sources = lock!(self.waitables.write());
            match sources.remove(&waitable) {
                Some(s) => s,
                None => {
                    return Ok(());
                }
            }
        };
        source.begin_delete()
    }

    /// Waits for events and stores them in `events`, replacing its previous contents.
    ///
    /// Returns when at least one event was collected, when `notify` woke the poller, or
    /// when the remaining timeout reaches zero. `None` waits indefinitely; so does a
    /// timeout too large to add to the current instant.
    ///
    /// # Errors
    ///
    /// Propagates errors from applying queued updates, waiting on the port, or
    /// resubmitting packets.
    pub(super) fn wait(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<()> {
        let span = tracing::trace_span!(
            "wait",
            handle = ?self.port,
            ?timeout,
        );
        let _enter = span.enter();
        let deadline = timeout.and_then(|timeout| Instant::now().checked_add(timeout));
        let mut notified = false;
        events.packets.clear();
        loop {
            let mut new_events = 0;
            // Mark this thread as polling so concurrent `update_packet` calls apply their
            // updates directly instead of queueing them.
            let was_polling = self.polling.swap(true, Ordering::SeqCst);
            debug_assert!(!was_polling);
            // Clears the flag again, including when one of the `?` below returns early.
            let guard = CallOnDrop(|| {
                let was_polling = self.polling.swap(false, Ordering::SeqCst);
                debug_assert!(was_polling);
            });
            self.drain_update_queue(false)?;
            // Time left until the deadline; shadows the caller's timeout for this iteration.
            let timeout = deadline.map(|t| t.saturating_duration_since(Instant::now()));
            let len = self.port.wait(&mut events.completions, timeout)?;
            tracing::trace!(
                handle = ?self.port,
                res = ?len,
                "new events");
            drop(guard);
            for entry in events.completions.drain(..) {
                let packet = entry.into_packet();
                match packet.feed_event(self)? {
                    FeedEventResult::NoEvent => {}
                    FeedEventResult::Event(event) => {
                        events.packets.push(event);
                        new_events += 1;
                    }
                    FeedEventResult::Notified => {
                        notified = true;
                    }
                }
            }
            // A zero remaining timeout means the deadline has passed.
            let timeout_is_empty =
                timeout.map_or(false, |t| t.as_secs() == 0 && t.subsec_nanos() == 0);
            if notified || new_events > 0 || timeout_is_empty {
                break;
            }
            tracing::trace!("wait: no events found, re-entering polling loop");
        }
        Ok(())
    }

    /// Wakes up `wait` by posting the notifier packet, which is reported as `Notified`.
    pub(super) fn notify(&self) -> io::Result<()> {
        self.port.post(0, 0, self.notifier.clone())
    }

    /// Posts a user-created completion packet; `wait` reports its stored event.
    pub(super) fn post(&self, packet: CompletionPacket) -> io::Result<()> {
        self.port.post(0, 0, packet.0)
    }

    /// Applies `packet`'s pending update now if `wait` is polling, otherwise queues it.
    ///
    /// If the bounded queue is full, it is drained (at most `capacity` entries) and the
    /// push is retried.
    fn update_packet(&self, mut packet: Packet) -> io::Result<()> {
        loop {
            if self.polling.load(Ordering::Acquire) {
                packet.update()?;
                return Ok(());
            }
            match self.pending_updates.push(packet) {
                Ok(()) => return Ok(()),
                // Take the packet back and make room before retrying.
                Err(p) => packet = p.into_inner(),
            }
            self.drain_update_queue(true)?;
        }
    }

    /// Applies queued packet updates.
    ///
    /// With `limit` set, at most `capacity` updates are applied; otherwise the queue is
    /// drained completely. Stops at the first error.
    fn drain_update_queue(&self, limit: bool) -> io::Result<()> {
        let max = if limit {
            // The queue is created with `bounded`, so it always has a capacity.
            self.pending_updates.capacity().unwrap()
        } else {
            std::usize::MAX
        };
        self.pending_updates
            .try_iter()
            .take(max)
            .try_for_each(|packet| packet.update())
    }

    /// Returns a shared AFD handle for a new socket.
    ///
    /// Reuses an existing handle that is referenced by fewer than `AFD_MAX_SIZE` owners and
    /// prunes dead entries along the way. Otherwise it opens a new handle, registers it
    /// with the port, and records it.
    fn afd_handle(&self) -> io::Result<Arc<Afd<Packet>>> {
        const AFD_MAX_SIZE: usize = 32;
        let mut afd_handles = lock!(self.afd.lock());
        let mut i = 0;
        while i < afd_handles.len() {
            let refcount = Weak::strong_count(&afd_handles[i]);
            match refcount {
                0 => {
                    // No owners left: drop the entry (`i` now indexes the swapped-in one).
                    afd_handles.swap_remove(i);
                }
                refcount if refcount >= AFD_MAX_SIZE => {
                    // Already shared by the maximum number of owners; try the next one.
                    i += 1;
                }
                _ => {
                    match afd_handles[i].upgrade() {
                        Some(afd) => return Ok(afd),
                        None => {
                            // The last owner went away between the count and the upgrade.
                            afd_handles.swap_remove(i);
                        }
                    }
                }
            }
        }
        let afd = Arc::new(Afd::new()?);
        self.port.register(&*afd, true)?;
        afd_handles.push(Arc::downgrade(&afd));
        Ok(afd)
    }
}
impl AsRawHandle for Poller {
    /// Returns the raw handle of the poller's I/O completion port.
    fn as_raw_handle(&self) -> RawHandle {
        let port = &self.port;
        port.as_raw_handle()
    }
}
impl AsHandle for Poller {
    /// Borrows the completion port handle for as long as `self` is borrowed.
    fn as_handle(&self) -> BorrowedHandle<'_> {
        let raw = self.as_raw_handle();
        // SAFETY: the handle belongs to the completion port owned (through `Arc`) by `self`.
        // Presumably the port only closes it when dropped, which cannot happen while
        // `self` is borrowed.
        unsafe { BorrowedHandle::borrow_raw(raw) }
    }
}
/// Output buffer filled by `Poller::wait`.
pub(super) struct Events {
    /// Events reported to the caller.
    packets: Vec<Event>,
    /// Scratch space for entries dequeued from the completion port during `wait`.
    completions: Vec<OverlappedEntry<Packet>>,
}
// NOTE(review): presumably required because `OverlappedEntry` wraps raw OS pointers; confirm.
unsafe impl Send for Events {}
impl Events {
    /// Creates a buffer with room for `cap` events and `cap` completion entries.
    pub fn with_capacity(cap: usize) -> Events {
        let packets = Vec::with_capacity(cap);
        let completions = Vec::with_capacity(cap);
        Events {
            packets,
            completions,
        }
    }
    /// Iterates over the events collected by the last `wait`.
    pub fn iter(&self) -> impl Iterator<Item = Event> + '_ {
        self.packets.as_slice().iter().copied()
    }
    /// Forgets the collected events; the completion scratch buffer is left as is.
    pub fn clear(&mut self) {
        self.packets.clear();
    }
    /// Returns the capacity of the event buffer.
    pub fn capacity(&self) -> usize {
        self.packets.capacity()
    }
}
/// Windows-specific readiness flags attached to an `Event`, stored as an AFD poll mask.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct EventExtra {
    flags: AfdPollMask,
}
impl EventExtra {
    /// Creates an `EventExtra` with no flags set.
    #[inline]
    pub const fn empty() -> EventExtra {
        let flags = AfdPollMask::empty();
        EventExtra { flags }
    }

    /// Returns whether the `ABORT` flag (reported as hang-up) is set.
    #[inline]
    pub fn is_hup(&self) -> bool {
        let hup = AfdPollMask::ABORT;
        self.flags.intersects(hup)
    }

    /// Sets or clears the `ABORT` (hang-up) flag.
    #[inline]
    pub fn set_hup(&mut self, active: bool) {
        let hup = AfdPollMask::ABORT;
        self.flags.set(hup, active);
    }

    /// Returns whether the `RECEIVE_EXPEDITED` (priority data) flag is set.
    #[inline]
    pub fn is_pri(&self) -> bool {
        let pri = AfdPollMask::RECEIVE_EXPEDITED;
        self.flags.intersects(pri)
    }

    /// Sets or clears the `RECEIVE_EXPEDITED` (priority data) flag.
    #[inline]
    pub fn set_pri(&mut self, active: bool) {
        let pri = AfdPollMask::RECEIVE_EXPEDITED;
        self.flags.set(pri, active);
    }
}
/// A user-created packet that can be posted to the poller and is reported as its event.
#[derive(Debug, Clone)]
pub struct CompletionPacket(Packet);
impl CompletionPacket {
    /// Wraps `event` in a new custom completion packet.
    pub fn new(event: Event) -> Self {
        let inner = PacketInner::Custom { event };
        Self(Arc::pin(IoStatusBlock::from(inner)))
    }
    /// Returns the event carried by this packet.
    pub fn event(&self) -> &Event {
        // `new` is the only constructor, so the payload is always `Custom`.
        if let PacketInnerProj::Custom { event } = self.0.as_ref().data().project_ref() {
            event
        } else {
            unreachable!()
        }
    }
}
/// A completion packet: a pinned, reference-counted `IoStatusBlock` around `PacketInner`.
type Packet = Pin<Arc<PacketUnwrapped>>;
/// The pinned payload of a `Packet`.
type PacketUnwrapped = IoStatusBlock<PacketInner>;
pin_project! {
    // The kinds of packets that travel through the completion port.
    #[project_ref = PacketInnerProj]
    #[project = PacketInnerProjMut]
    enum PacketInner {
        // A socket polled through an AFD handle.
        Socket {
            // AFD poll buffer; exposed pinned through `HasAfdInfo::afd_info`.
            #[pin]
            packet: UnsafeCell<AfdPollInfo>,
            socket: Mutex<SocketState>
        },
        // A waitable handle watched by a thread-pool wait.
        Waitable {
            handle: Mutex<WaitableState>
        },
        // A user packet posted through `Poller::post`; reported as `event`.
        Custom {
            event: Event,
        },
        // The notifier packet posted by `Poller::notify`.
        Wakeup { #[pin] _pinned: PhantomPinned },
    }
}
// NOTE(review): needed because `UnsafeCell` is not `Sync` and `WaitableState` holds a
// `RawHandle` (a raw pointer). Soundness relies on the AFD buffer only being accessed when
// no poll is in flight; confirm.
unsafe impl Send for PacketInner {}
unsafe impl Sync for PacketInner {}
impl fmt::Debug for PacketInner {
    /// Formats the packet kind and its state; the AFD poll buffer is shown as a placeholder.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Socket { socket, .. } => {
                let mut builder = f.debug_struct("Socket");
                builder.field("packet", &"..");
                builder.field("socket", socket);
                builder.finish()
            }
            Self::Waitable { handle } => {
                let mut builder = f.debug_struct("Waitable");
                builder.field("handle", handle);
                builder.finish()
            }
            Self::Custom { event } => {
                let mut builder = f.debug_struct("Custom");
                builder.field("event", event);
                builder.finish()
            }
            Self::Wakeup { .. } => f.write_str("Wakeup { .. }"),
        }
    }
}
impl HasAfdInfo for PacketInner {
    /// Returns the pinned AFD poll buffer of a socket packet.
    ///
    /// # Panics
    ///
    /// Panics if called on a packet that is not a socket packet.
    fn afd_info(self: Pin<&Self>) -> Pin<&UnsafeCell<AfdPollInfo>> {
        if let PacketInnerProj::Socket { packet, .. } = self.project_ref() {
            packet
        } else {
            unreachable!()
        }
    }
}
impl PacketUnwrapped {
fn set_events(self: Pin<&Self>, interest: Event, mode: PollMode) -> bool {
match self.data().project_ref() {
PacketInnerProj::Socket { socket, .. } => {
let mut socket = lock!(socket.lock());
socket.interest = interest;
socket.mode = mode;
socket.interest_error = true;
match socket.status {
SocketStatus::Polling { flags } => {
let our_flags = event_to_afd_mask(socket.interest, socket.interest_error);
our_flags != flags
}
_ => true,
}
}
PacketInnerProj::Waitable { handle } => {
let mut handle = lock!(handle.lock());
handle.interest = interest;
handle.mode = mode;
handle.status.is_idle()
}
_ => true,
}
}
fn update(self: Pin<Arc<Self>>) -> io::Result<()> {
let mut socket = match self.as_ref().data().project_ref() {
PacketInnerProj::Socket { socket, .. } => lock!(socket.lock()),
PacketInnerProj::Waitable { handle } => {
let mut handle = lock!(handle.lock());
if !handle.interest.readable && !handle.interest.writable {
return Ok(());
}
if !handle.status.is_idle() {
return Ok(());
}
let packet = self.clone();
let wait_handle = WaitHandle::new(
handle.handle,
move || {
let mut handle = match packet.as_ref().data().project_ref() {
PacketInnerProj::Waitable { handle } => lock!(handle.lock()),
_ => unreachable!(),
};
let iocp = match handle.port.upgrade() {
Some(iocp) => iocp,
None => return,
};
handle.status = WaitableStatus::Idle;
drop(handle);
if let Err(e) = iocp.post(0, 0, packet) {
tracing::error!("failed to post completion packet: {}", e);
}
},
None,
false,
)?;
handle.status = WaitableStatus::Waiting(wait_handle);
return Ok(());
}
_ => return Err(io::Error::new(io::ErrorKind::Other, "invalid socket state")),
};
if socket.waiting_on_delete {
return Ok(());
}
match socket.status {
SocketStatus::Polling { flags } => {
let our_flags = event_to_afd_mask(socket.interest, socket.interest_error);
if our_flags != flags {
return self.cancel(socket);
}
Ok(())
}
SocketStatus::Cancelled => {
Ok(())
}
SocketStatus::Idle => {
let mask = event_to_afd_mask(socket.interest, socket.interest_error);
let result = socket.afd.poll(self.clone(), socket.base_socket, mask);
match result {
Ok(()) => {}
Err(err)
if err.raw_os_error() == Some(ERROR_IO_PENDING as i32)
|| err.kind() == io::ErrorKind::WouldBlock =>
{
}
Err(err) if err.raw_os_error() == Some(ERROR_INVALID_HANDLE as i32) => {
}
Err(err) => return Err(err),
}
socket.status = SocketStatus::Polling { flags: mask };
Ok(())
}
}
}
fn feed_event(self: Pin<Arc<Self>>, poller: &Poller) -> io::Result<FeedEventResult> {
let inner = self.as_ref().data().project_ref();
let (afd_info, socket) = match inner {
PacketInnerProj::Socket { packet, socket } => (packet, socket),
PacketInnerProj::Custom { event } => {
return Ok(FeedEventResult::Event(*event));
}
PacketInnerProj::Wakeup { .. } => {
return Ok(FeedEventResult::Notified);
}
PacketInnerProj::Waitable { handle } => {
let mut handle = lock!(handle.lock());
let event = handle.interest;
if matches!(handle.mode, PollMode::Oneshot) {
handle.interest = Event::none(handle.interest.key);
}
drop(handle);
poller.update_packet(self)?;
return Ok(FeedEventResult::Event(event));
}
};
let mut socket_state = lock!(socket.lock());
let mut event = Event::none(socket_state.interest.key);
socket_state.status = SocketStatus::Idle;
if socket_state.waiting_on_delete {
return Ok(FeedEventResult::NoEvent);
}
unsafe {
let iosb = &mut *self.as_ref().iosb().get();
match iosb.Anonymous.Status {
STATUS_CANCELLED => {
}
status if status < 0 => {
event.readable = true;
event.writable = true;
}
_ => {
let afd_data = &*afd_info.get();
if afd_data.handle_count() >= 1 {
let events = afd_data.events();
if events.intersects(AfdPollMask::LOCAL_CLOSE) {
let source = lock!(poller.sources.write())
.remove(&socket_state.socket)
.unwrap();
return source.begin_delete().map(|()| FeedEventResult::NoEvent);
}
let (readable, writable) = afd_mask_to_event(events);
event.readable = readable;
event.writable = writable;
event.extra.flags = events;
}
}
}
}
event.readable &= socket_state.interest.readable;
event.writable &= socket_state.interest.writable;
let return_value = if event.readable
|| event.writable
|| event
.extra
.flags
.intersects(socket_state.interest.extra.flags)
{
if matches!(socket_state.mode, PollMode::Oneshot) {
socket_state.interest = Event::none(socket_state.interest.key);
socket_state.interest_error = false;
}
FeedEventResult::Event(event)
} else {
FeedEventResult::NoEvent
};
drop(socket_state);
poller.update_packet(self)?;
Ok(return_value)
}
fn begin_delete(self: Pin<Arc<Self>>) -> io::Result<()> {
let mut socket = match self.as_ref().data().project_ref() {
PacketInnerProj::Socket { socket, .. } => lock!(socket.lock()),
PacketInnerProj::Waitable { handle } => {
let mut handle = lock!(handle.lock());
handle.status = WaitableStatus::Cancelled;
return Ok(());
}
_ => panic!("can't delete packet that doesn't belong to a socket"),
};
if !socket.waiting_on_delete {
socket.waiting_on_delete = true;
if matches!(socket.status, SocketStatus::Polling { .. }) {
self.cancel(socket)?;
}
}
Ok(())
}
fn cancel(self: &Pin<Arc<Self>>, mut socket: MutexGuard<'_, SocketState>) -> io::Result<()> {
assert!(matches!(socket.status, SocketStatus::Polling { .. }));
unsafe {
socket.afd.cancel(self)?;
}
socket.status = SocketStatus::Cancelled;
Ok(())
}
}
/// Mutable per-socket state, guarded by the packet's mutex.
#[derive(Debug)]
struct SocketState {
    /// The socket as registered with `add`; also the key in `Poller::sources`.
    socket: RawSocket,
    /// The value returned by `base_socket(socket)`; this is what AFD polls are issued on.
    base_socket: RawSocket,
    /// Current interest (key, readable/writable, extra flags).
    interest: Event,
    /// Whether error conditions are polled even without read/write interest. Set by
    /// `add`/`set_events`, cleared after a oneshot delivery.
    interest_error: bool,
    /// Poll mode requested for this socket.
    mode: PollMode,
    /// Shared AFD handle used to poll and cancel.
    afd: Arc<Afd<Packet>>,
    /// Set once removal has started; no new polls are issued and events are dropped.
    waiting_on_delete: bool,
    /// Whether an AFD poll is in flight for this socket.
    status: SocketStatus,
}
/// Lifecycle of a socket's AFD poll request.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum SocketStatus {
    /// No poll is in flight.
    Idle,
    /// A poll is in flight for the given mask.
    Polling {
        flags: AfdPollMask,
    },
    /// The in-flight poll was cancelled; its completion has not been processed yet.
    Cancelled,
}
/// Mutable per-waitable state, guarded by the packet's mutex.
#[derive(Debug)]
struct WaitableState {
    /// The handle the thread-pool wait is registered on.
    handle: RawHandle,
    /// Completion port to post to. Weak, so a pending callback does not keep the poller's
    /// port alive; the callback gives up if it cannot upgrade.
    port: Weak<IoCompletionPort<Packet>>,
    /// Current interest; a wait is registered only if readable or writable is set.
    interest: Event,
    /// Poll mode requested for this handle.
    mode: PollMode,
    /// Registration status of the thread-pool wait.
    status: WaitableStatus,
}
/// Registration status of a waitable's thread-pool wait.
#[derive(Debug)]
enum WaitableStatus {
    /// No wait is registered.
    Idle,
    /// A wait is registered; dropping the `WaitHandle` unregisters it.
    Waiting(WaitHandle),
    /// The waitable was removed from the poller.
    Cancelled,
}
impl WaitableStatus {
    /// Returns `true` when no wait is registered and the waitable has not been removed.
    fn is_idle(&self) -> bool {
        match self {
            WaitableStatus::Idle => true,
            WaitableStatus::Waiting(_) | WaitableStatus::Cancelled => false,
        }
    }
}
/// Outcome of processing one dequeued completion packet.
#[derive(Debug)]
enum FeedEventResult {
    /// Nothing to report (filtered out, cancelled, or source removed).
    NoEvent,
    /// An event to hand to the caller of `wait`.
    Event(Event),
    /// The poller was woken up by `notify`.
    Notified,
}
/// A wait registered with `RegisterWaitForSingleObject`; unregistered when dropped.
#[derive(Debug)]
struct WaitHandle(RawHandle);
impl Drop for WaitHandle {
    fn drop(&mut self) {
        let registration = self.0;
        // The result is ignored: there is nothing useful to do on failure during drop.
        // `UnregisterWait` does not wait for a callback that is already running.
        // NOTE(review): if the wait never fired, the boxed callback handed to the thread pool
        // is never reclaimed (only the callback trampoline frees it). Looks like a leak; confirm.
        unsafe {
            let _ = UnregisterWait(registration as _);
        }
    }
}
impl WaitHandle {
    /// Registers a one-shot thread-pool wait on `handle` that runs `callback` once.
    ///
    /// The callback runs when the handle is signaled or, if `timeout` is `Some`, when the
    /// timeout elapses (converted with `dur2timeout`; `None` means `INFINITE`). Setting
    /// `long_wait` adds `WT_EXECUTELONGFUNCTION` to the registration flags.
    ///
    /// # Errors
    ///
    /// Returns the OS error if `RegisterWaitForSingleObject` fails. The boxed callback is
    /// then reclaimed and dropped instead of leaking.
    fn new<F>(
        handle: RawHandle,
        callback: F,
        timeout: Option<Duration>,
        long_wait: bool,
    ) -> io::Result<Self>
    where
        F: FnOnce() + Send + Sync + 'static,
    {
        /// Aborts the process when dropped, i.e. when the callback unwinds.
        struct AbortOnDrop;
        impl Drop for AbortOnDrop {
            fn drop(&mut self) {
                std::process::abort();
            }
        }
        /// Trampoline invoked by the thread pool; `context` is the boxed callback.
        unsafe extern "system" fn wait_callback<F: FnOnce() + Send + Sync + 'static>(
            context: *mut c_void,
            _timer_fired: BOOLEAN,
        ) {
            // A panic must not unwind across the `extern "system"` boundary. The guard aborts
            // instead, and is forgotten once the callback returns normally.
            let _guard = AbortOnDrop;
            let callback = Box::from_raw(context as *mut F);
            callback();
            forget(_guard);
        }
        let mut wait_handle = MaybeUninit::<RawHandle>::uninit();
        let mut flags = WT_EXECUTEONLYONCE;
        if long_wait {
            flags |= WT_EXECUTELONGFUNCTION;
        }
        // Ownership of the callback is handed to the thread pool as a raw pointer;
        // `wait_callback` turns it back into a `Box` when it runs.
        let context = Box::into_raw(Box::new(callback));
        let res = unsafe {
            RegisterWaitForSingleObject(
                wait_handle.as_mut_ptr().cast::<_>(),
                handle as _,
                Some(wait_callback::<F>),
                context as _,
                timeout.map_or(INFINITE, dur2timeout),
                flags,
            )
        };
        if res == 0 {
            // Capture the error before running destructors that might overwrite it.
            let err = io::Error::last_os_error();
            // SAFETY: registration failed, so the thread pool never took `context` and
            // `wait_callback` will not run. We are its only owner and must free it here.
            drop(unsafe { Box::from_raw(context) });
            return Err(err);
        }
        // SAFETY: the registration succeeded, so the OS has written the wait handle.
        let wait_handle = unsafe { wait_handle.assume_init() };
        Ok(Self(wait_handle))
    }
}
#[inline]
fn event_to_afd_mask(event: Event, error: bool) -> afd::AfdPollMask {
event_properties_to_afd_mask(event.readable, event.writable, error) | event.extra.flags
}
/// Translates read/write/error interest into AFD poll bits.
///
/// - Any interest at all also watches `ABORT` and `CONNECT_FAIL`.
/// - Read interest adds receive, accept, disconnect and expedited-receive.
/// - Write interest adds `SEND`.
#[inline]
fn event_properties_to_afd_mask(readable: bool, writable: bool, error: bool) -> afd::AfdPollMask {
    use afd::AfdPollMask as AfdPoll;
    let mut mask = AfdPoll::empty();
    if readable {
        mask |= AfdPoll::RECEIVE
            | AfdPoll::ACCEPT
            | AfdPoll::DISCONNECT
            | AfdPoll::RECEIVE_EXPEDITED;
    }
    if writable {
        mask |= AfdPoll::SEND;
    }
    let wants_anything = readable || writable || error;
    if wants_anything {
        mask |= AfdPoll::ABORT | AfdPoll::CONNECT_FAIL;
    }
    mask
}
/// Maps reported AFD poll bits to `(readable, writable)`.
///
/// Abort and connect-failure conditions count as both readable and writable, so the
/// caller's next I/O call observes the error.
#[inline]
fn afd_mask_to_event(mask: afd::AfdPollMask) -> (bool, bool) {
    use afd::AfdPollMask as AfdPoll;
    let failed = mask.intersects(AfdPoll::ABORT | AfdPoll::CONNECT_FAIL);
    let readable = failed
        || mask.intersects(
            AfdPoll::RECEIVE | AfdPoll::ACCEPT | AfdPoll::DISCONNECT | AfdPoll::RECEIVE_EXPEDITED,
        );
    let writable = failed || mask.intersects(AfdPoll::SEND);
    (readable, writable)
}
/// Converts a `Duration` into a Win32 millisecond timeout.
///
/// Any sub-millisecond remainder is rounded up so short waits are not truncated to zero.
/// Durations that do not fit in `u32` milliseconds become `INFINITE`.
fn dur2timeout(dur: Duration) -> u32 {
    let nanos = u64::from(dur.subsec_nanos());
    let whole_millis = nanos / 1_000_000;
    let round_up = u64::from(nanos % 1_000_000 != 0);
    let total = dur
        .as_secs()
        .checked_mul(1000)
        .and_then(|ms| ms.checked_add(whole_millis))
        .and_then(|ms| ms.checked_add(round_up));
    match total.map(u32::try_from) {
        Some(Ok(ms)) => ms,
        _ => INFINITE,
    }
}
/// Scope guard that runs the wrapped closure when it is dropped.
struct CallOnDrop<F: FnMut()>(F);
impl<F> Drop for CallOnDrop<F>
where
    F: FnMut(),
{
    fn drop(&mut self) {
        let callback = &mut self.0;
        callback();
    }
}