#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
use std::future::Future;
use std::io::{self, IoSlice, IoSliceMut, Read, Write};
use std::mem::ManuallyDrop;
use std::net::{Shutdown, SocketAddr, TcpListener, TcpStream, UdpSocket};
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, FromRawSocket, RawSocket};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Waker};
use std::time::{Duration, Instant};
#[cfg(unix)]
use std::{
os::unix::io::{AsRawFd, FromRawFd, RawFd},
os::unix::net::{SocketAddr as UnixSocketAddr, UnixDatagram, UnixListener, UnixStream},
path::Path,
};
use futures_lite::io::{AsyncRead, AsyncWrite};
use futures_lite::stream::{self, Stream};
use futures_lite::{future, pin};
use socket2::{Domain, Protocol, Socket, Type};
use crate::reactor::{Reactor, Source};
pub mod parking;
mod reactor;
/// A future that resolves once a fixed deadline has been reached.
///
/// When polled before the deadline, the timer registers the polling task's
/// waker through `Reactor::insert_timer`; the registration is removed again
/// when the timer fires, is reset, or is dropped.
#[derive(Debug)]
pub struct Timer {
    /// The id returned by `Reactor::insert_timer` together with the waker that
    /// was registered under it; `None` while no registration is outstanding.
    id_and_waker: Option<(usize, Waker)>,
    /// The instant at which this timer fires.
    when: Instant,
}
impl Timer {
    /// Creates a timer whose deadline is `dur` after the moment of this call.
    ///
    /// Nothing is registered with the reactor here; that only happens when the
    /// timer is polled before its deadline.
    pub fn new(dur: Duration) -> Timer {
        let when = Instant::now() + dur;
        Timer {
            id_and_waker: None,
            when,
        }
    }

    /// Moves the deadline to `dur` after the moment of this call.
    ///
    /// If the timer currently holds a reactor registration, the entry for the
    /// old deadline is removed and a new one is inserted for the new deadline
    /// with the same waker. A timer that holds no registration only has its
    /// deadline updated.
    pub fn reset(&mut self, dur: Duration) {
        // The reactor entry is looked up by (deadline, id), so it has to be
        // removed while `self.when` still holds the old deadline.
        if let Some((old_id, _)) = &self.id_and_waker {
            Reactor::get().remove_timer(self.when, *old_id);
        }

        self.when = Instant::now() + dur;

        // Re-register under the new deadline and remember the fresh id.
        if let Some((id_slot, waker)) = &mut self.id_and_waker {
            *id_slot = Reactor::get().insert_timer(self.when, waker);
        }
    }
}
impl Drop for Timer {
    /// Removes the timer's outstanding reactor registration, if it has one.
    fn drop(&mut self) {
        match self.id_and_waker.take() {
            Some((id, _)) => {
                Reactor::get().remove_timer(self.when, id);
            }
            // Never registered, or already deregistered: nothing to clean up.
            None => {}
        }
    }
}
impl Future for Timer {
    /// The deadline the timer was armed with (not the time it was observed).
    type Output = Instant;

    /// Resolves with the deadline once `Instant::now()` has reached it.
    ///
    /// While the deadline is still in the future, this makes sure the reactor
    /// holds a registration for the *current* task's waker and returns
    /// `Poll::Pending`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let deadline = self.when;

        if Instant::now() < deadline {
            let waker = cx.waker();

            // Work out whether an existing registration can be kept as is,
            // must be replaced (different waker), or does not exist yet.
            let stale_id = match &self.id_and_waker {
                Some((_, registered)) if registered.will_wake(waker) => return Poll::Pending,
                Some((id, _)) => Some(*id),
                None => None,
            };

            // Drop the entry that would wake the wrong task before inserting
            // the replacement.
            if let Some(id) = stale_id {
                Reactor::get().remove_timer(deadline, id);
            }
            let id = Reactor::get().insert_timer(deadline, waker);
            self.id_and_waker = Some((id, waker.clone()));
            return Poll::Pending;
        }

        // The deadline has passed: the registration is no longer needed.
        if let Some((id, _)) = self.id_and_waker.take() {
            Reactor::get().remove_timer(deadline, id);
        }
        Poll::Ready(deadline)
    }
}
/// Async adapter that pairs an I/O handle with its reactor registration.
///
/// The handle is boxed, so `Async<T>` is `Unpin` regardless of `T`.
#[derive(Debug)]
pub struct Async<T> {
    /// The registration returned by `Reactor::insert_io` for this handle.
    source: Arc<Source>,
    /// The wrapped I/O handle. It is `None` only after `into_inner` or `Drop`
    /// has taken it out.
    io: Option<Box<T>>,
}
#[cfg(unix)]
impl<T: AsRawFd> Async<T> {
    /// Wraps `io`, registering its raw file descriptor with the reactor.
    ///
    /// # Errors
    ///
    /// Returns the error from `Reactor::insert_io` if registration fails; `io`
    /// is dropped in that case.
    pub fn new(io: T) -> io::Result<Async<T>> {
        // Register first so that a failure leaves nothing half-constructed.
        let source = Reactor::get().insert_io(io.as_raw_fd())?;
        let io = Some(Box::new(io));
        Ok(Async { source, io })
    }
}
#[cfg(unix)]
impl<T: AsRawFd> AsRawFd for Async<T> {
    /// Returns the raw descriptor stored in the reactor registration
    /// (`source.raw`) rather than querying the wrapped handle again.
    fn as_raw_fd(&self) -> RawFd {
        self.source.raw
    }
}
#[cfg(windows)]
impl<T: AsRawSocket> Async<T> {
    /// Wraps `io`, registering its raw socket with the reactor.
    ///
    /// # Errors
    ///
    /// Returns the error from `Reactor::insert_io` if registration fails; `io`
    /// is dropped in that case.
    pub fn new(io: T) -> io::Result<Async<T>> {
        // Register first so that a failure leaves nothing half-constructed.
        let source = Reactor::get().insert_io(io.as_raw_socket())?;
        let io = Some(Box::new(io));
        Ok(Async { source, io })
    }
}
#[cfg(windows)]
impl<T: AsRawSocket> AsRawSocket for Async<T> {
    /// Returns the raw socket stored in the reactor registration
    /// (`source.raw`) rather than querying the wrapped handle again.
    fn as_raw_socket(&self) -> RawSocket {
        self.source.raw
    }
}
impl<T> Async<T> {
    /// Borrows the wrapped I/O handle.
    ///
    /// # Panics
    ///
    /// Panics if the handle has already been taken out; inside this file that
    /// only happens in `into_inner` and `Drop`, both of which end the value's
    /// life.
    pub fn get_ref(&self) -> &T {
        let boxed = self.io.as_ref().unwrap();
        &**boxed
    }

    /// Mutably borrows the wrapped I/O handle.
    ///
    /// # Panics
    ///
    /// Panics under the same condition as [`Async::get_ref`].
    pub fn get_mut(&mut self) -> &mut T {
        let boxed = self.io.as_mut().unwrap();
        &mut **boxed
    }

    /// Removes the handle from the reactor and returns it.
    ///
    /// # Errors
    ///
    /// If `Reactor::remove_io` fails, its error is returned and the handle is
    /// dropped instead of being handed back.
    pub fn into_inner(mut self) -> io::Result<T> {
        // Taking the handle out first makes the subsequent `Drop` of `self`
        // skip its own deregistration.
        let boxed = self.io.take().unwrap();
        let io = *boxed;
        Reactor::get().remove_io(&self.source)?;
        Ok(io)
    }

    /// Awaits the `readable` future of this handle's reactor registration.
    pub async fn readable(&self) -> io::Result<()> {
        let readiness = self.source.readable();
        readiness.await
    }

    /// Awaits the `writable` future of this handle's reactor registration.
    pub async fn writable(&self) -> io::Result<()> {
        let readiness = self.source.writable();
        readiness.await
    }

    /// Runs a read-style operation on a shared borrow of the handle, retrying
    /// for as long as it reports `WouldBlock`.
    ///
    /// `op` is always attempted first. Any `Ok` value or any error other than
    /// `WouldBlock` is returned as is. After a `WouldBlock`, the method waits
    /// on `optimistic(self.readable())` before calling `op` again; an error
    /// from that wait is returned.
    pub async fn read_with<R>(&self, mut op: impl FnMut(&T) -> io::Result<R>) -> io::Result<R> {
        loop {
            match op(self.get_ref()) {
                Ok(value) => return Ok(value),
                Err(err) if err.kind() != io::ErrorKind::WouldBlock => return Err(err),
                // Not ready yet: fall through to the readiness wait.
                Err(_) => {}
            }
            optimistic(self.readable()).await?;
        }
    }

    /// Same as [`Async::read_with`], but `op` receives a mutable borrow of the
    /// handle.
    pub async fn read_with_mut<R>(
        &mut self,
        mut op: impl FnMut(&mut T) -> io::Result<R>,
    ) -> io::Result<R> {
        loop {
            match op(self.get_mut()) {
                Ok(value) => return Ok(value),
                Err(err) if err.kind() != io::ErrorKind::WouldBlock => return Err(err),
                // Not ready yet: fall through to the readiness wait.
                Err(_) => {}
            }
            optimistic(self.readable()).await?;
        }
    }

    /// Runs a write-style operation on a shared borrow of the handle, retrying
    /// for as long as it reports `WouldBlock`.
    ///
    /// Behaves like [`Async::read_with`] except that it waits on
    /// `optimistic(self.writable())` between attempts.
    pub async fn write_with<R>(&self, mut op: impl FnMut(&T) -> io::Result<R>) -> io::Result<R> {
        loop {
            match op(self.get_ref()) {
                Ok(value) => return Ok(value),
                Err(err) if err.kind() != io::ErrorKind::WouldBlock => return Err(err),
                // Not ready yet: fall through to the readiness wait.
                Err(_) => {}
            }
            optimistic(self.writable()).await?;
        }
    }

    /// Same as [`Async::write_with`], but `op` receives a mutable borrow of
    /// the handle.
    pub async fn write_with_mut<R>(
        &mut self,
        mut op: impl FnMut(&mut T) -> io::Result<R>,
    ) -> io::Result<R> {
        loop {
            match op(self.get_mut()) {
                Ok(value) => return Ok(value),
                Err(err) if err.kind() != io::ErrorKind::WouldBlock => return Err(err),
                // Not ready yet: fall through to the readiness wait.
                Err(_) => {}
            }
            optimistic(self.writable()).await?;
        }
    }
}
impl<T> Drop for Async<T> {
    /// Deregisters the handle from the reactor and then drops it.
    ///
    /// Does nothing when the handle has already been taken by `into_inner`.
    fn drop(&mut self) {
        if let Some(handle) = self.io.take() {
            // A destructor has no way to report failure, so a deregistration
            // error is deliberately discarded.
            let _ = Reactor::get().remove_io(&self.source);

            // Release the handle only after it has left the reactor.
            drop(handle);
        }
    }
}
impl<T: Read> AsyncRead for Async<T> {
    /// Polls a freshly created `read_with_mut` future once, reading into `buf`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        // `Async<T>` is `Unpin` (the handle is boxed), so the pin can be
        // unwrapped into a plain mutable borrow.
        let this = Pin::get_mut(self);
        let attempt = this.read_with_mut(|io| io.read(buf));
        poll_future(cx, attempt)
    }

    /// Polls a freshly created `read_with_mut` future once, reading into `bufs`.
    fn poll_read_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &mut [IoSliceMut<'_>],
    ) -> Poll<io::Result<usize>> {
        let this = Pin::get_mut(self);
        let attempt = this.read_with_mut(|io| io.read_vectored(bufs));
        poll_future(cx, attempt)
    }
}
impl<T> AsyncRead for &Async<T>
where
    for<'a> &'a T: Read,
{
    /// Polls a freshly created `read_with` future once, reading into `buf`
    /// through a shared borrow of the handle.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        // `Self` is itself a shared reference, so it can simply be copied out
        // of the pin.
        let this: &Async<T> = *self;
        let attempt = this.read_with(|io| (&*io).read(buf));
        poll_future(cx, attempt)
    }

    /// Polls a freshly created `read_with` future once, reading into `bufs`
    /// through a shared borrow of the handle.
    fn poll_read_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &mut [IoSliceMut<'_>],
    ) -> Poll<io::Result<usize>> {
        let this: &Async<T> = *self;
        let attempt = this.read_with(|io| (&*io).read_vectored(bufs));
        poll_future(cx, attempt)
    }
}
impl<T: Write> AsyncWrite for Async<T> {
    /// Polls a freshly created `write_with_mut` future once, writing `buf`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        // `Async<T>` is `Unpin` (the handle is boxed), so the pin can be
        // unwrapped into a plain mutable borrow.
        let this = Pin::get_mut(self);
        let attempt = this.write_with_mut(|io| io.write(buf));
        poll_future(cx, attempt)
    }

    /// Polls a freshly created `write_with_mut` future once, writing `bufs`.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        let this = Pin::get_mut(self);
        let attempt = this.write_with_mut(|io| io.write_vectored(bufs));
        poll_future(cx, attempt)
    }

    /// Polls a freshly created `write_with_mut` future once, flushing the handle.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = Pin::get_mut(self);
        let attempt = this.write_with_mut(|io| io.flush());
        poll_future(cx, attempt)
    }

    /// Shuts down the write half via `shutdown_write`; always completes
    /// immediately.
    fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
        let outcome = shutdown_write(self.source.raw);
        Poll::Ready(outcome)
    }
}
impl<T> AsyncWrite for &Async<T>
where
    for<'a> &'a T: Write,
{
    /// Polls a freshly created `write_with` future once, writing `buf`
    /// through a shared borrow of the handle.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        // `Self` is itself a shared reference, so it can simply be copied out
        // of the pin.
        let this: &Async<T> = *self;
        let attempt = this.write_with(|io| (&*io).write(buf));
        poll_future(cx, attempt)
    }

    /// Polls a freshly created `write_with` future once, writing `bufs`
    /// through a shared borrow of the handle.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        let this: &Async<T> = *self;
        let attempt = this.write_with(|io| (&*io).write_vectored(bufs));
        poll_future(cx, attempt)
    }

    /// Polls a freshly created `write_with` future once, flushing the handle
    /// through a shared borrow.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this: &Async<T> = *self;
        let attempt = this.write_with(|io| (&*io).flush());
        poll_future(cx, attempt)
    }

    /// Shuts down the write half via `shutdown_write`; always completes
    /// immediately.
    fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
        let outcome = shutdown_write(self.source.raw);
        Poll::Ready(outcome)
    }
}
impl Async<TcpListener> {
    /// Binds a TCP listener to `addr` and wraps it in `Async`.
    ///
    /// # Errors
    ///
    /// Returns any error from `TcpListener::bind` or from `Async::new`.
    pub fn bind<A: Into<SocketAddr>>(addr: A) -> io::Result<Async<TcpListener>> {
        let addr: SocketAddr = addr.into();
        let listener = TcpListener::bind(addr)?;
        Async::new(listener)
    }

    /// Accepts one incoming connection, returning the wrapped stream together
    /// with the peer's address.
    ///
    /// The accept call is driven through `read_with`, so it is retried while
    /// it reports `WouldBlock`.
    pub async fn accept(&self) -> io::Result<(Async<TcpStream>, SocketAddr)> {
        let (accepted, peer) = self.read_with(|listener| listener.accept()).await?;
        let accepted = Async::new(accepted)?;
        Ok((accepted, peer))
    }

    /// Returns a stream that yields the result of calling `accept` over and
    /// over, with the peer address discarded.
    ///
    /// The stream never ends by itself: every item, including an `Err`, is
    /// followed by another accept attempt.
    pub fn incoming(&self) -> impl Stream<Item = io::Result<Async<TcpStream>>> + Send + Unpin + '_ {
        let connections = stream::unfold(self, |listener| async move {
            let next = listener.accept().await.map(|(conn, _)| conn);
            Some((next, listener))
        });
        // Boxing is what makes the returned stream `Unpin`.
        Box::pin(connections)
    }
}
impl Async<TcpStream> {
    /// Opens a TCP connection to `addr` without blocking the calling thread.
    ///
    /// A non-blocking socket is created and `connect` is started on it; the
    /// method then waits for the socket to become writable and finally checks
    /// the socket's pending error to learn whether the connection succeeded.
    ///
    /// # Errors
    ///
    /// Returns an error if the socket cannot be created or configured, if
    /// `connect` fails with anything other than the platform's "in progress"
    /// error, if registration with the reactor fails, or if the socket reports
    /// a pending error once it is writable.
    pub async fn connect<A: Into<SocketAddr>>(addr: A) -> io::Result<Async<TcpStream>> {
        let addr: SocketAddr = addr.into();

        // The socket family has to match the address family.
        let domain = match addr {
            SocketAddr::V6(_) => Domain::ipv6(),
            SocketAddr::V4(_) => Domain::ipv4(),
        };
        let socket = Socket::new(domain, Type::stream(), Some(Protocol::tcp()))?;
        socket.set_nonblocking(true)?;

        // On a non-blocking socket, "in progress" is the expected outcome of
        // `connect` and is not treated as a failure.
        if let Err(err) = socket.connect(&addr.into()) {
            #[cfg(unix)]
            let in_progress = err.raw_os_error() == Some(libc::EINPROGRESS);
            #[cfg(windows)]
            let in_progress = err.kind() == io::ErrorKind::WouldBlock;

            if !in_progress {
                return Err(err);
            }
        }

        let stream = Async::new(socket.into_tcp_stream())?;
        stream.writable().await?;

        // Writability alone does not say whether the connect succeeded; the
        // socket's pending error does.
        if let Some(err) = stream.get_ref().take_error()? {
            return Err(err);
        }
        Ok(stream)
    }

    /// Reads data into `buf` without removing it from the socket's queue,
    /// retrying through `read_with` while the socket reports `WouldBlock`.
    pub async fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
        self.read_with(|stream| stream.peek(buf)).await
    }
}
impl Async<UdpSocket> {
    /// Binds a UDP socket to `addr` and wraps it in `Async`.
    ///
    /// # Errors
    ///
    /// Returns any error from `UdpSocket::bind` or from `Async::new`.
    pub fn bind<A: Into<SocketAddr>>(addr: A) -> io::Result<Async<UdpSocket>> {
        let addr: SocketAddr = addr.into();
        let socket = UdpSocket::bind(addr)?;
        Async::new(socket)
    }

    /// Receives one datagram into `buf`, returning its length and sender.
    /// Driven through `read_with`.
    pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
        self.read_with(|socket| socket.recv_from(buf)).await
    }

    /// Like `recv_from`, but leaves the datagram in the socket's queue.
    /// Driven through `read_with`.
    pub async fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
        self.read_with(|socket| socket.peek_from(buf)).await
    }

    /// Sends `buf` to `addr`, returning the number of bytes sent.
    /// Driven through `write_with`.
    pub async fn send_to<A: Into<SocketAddr>>(&self, buf: &[u8], addr: A) -> io::Result<usize> {
        // Convert once up front; `SocketAddr` is `Copy`, so every retry can
        // reuse the same value.
        let target: SocketAddr = addr.into();
        self.write_with(|socket| socket.send_to(buf, target)).await
    }

    /// Receives one datagram from the connected peer into `buf`.
    /// Driven through `read_with`.
    pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
        self.read_with(|socket| socket.recv(buf)).await
    }

    /// Like `recv`, but leaves the datagram in the socket's queue.
    /// Driven through `read_with`.
    pub async fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
        self.read_with(|socket| socket.peek(buf)).await
    }

    /// Sends `buf` to the connected peer, returning the number of bytes sent.
    /// Driven through `write_with`.
    pub async fn send(&self, buf: &[u8]) -> io::Result<usize> {
        self.write_with(|socket| socket.send(buf)).await
    }
}
#[cfg(unix)]
impl Async<UnixListener> {
    /// Binds a Unix listener to `path` and wraps it in `Async`.
    ///
    /// # Errors
    ///
    /// Returns any error from `UnixListener::bind` or from `Async::new`.
    pub fn bind<P: AsRef<Path>>(path: P) -> io::Result<Async<UnixListener>> {
        let owned_path = path.as_ref().to_path_buf();
        let listener = UnixListener::bind(owned_path)?;
        Async::new(listener)
    }

    /// Accepts one incoming connection, returning the wrapped stream together
    /// with the peer's address.
    ///
    /// The accept call is driven through `read_with`, so it is retried while
    /// it reports `WouldBlock`.
    pub async fn accept(&self) -> io::Result<(Async<UnixStream>, UnixSocketAddr)> {
        let (accepted, peer) = self.read_with(|listener| listener.accept()).await?;
        let accepted = Async::new(accepted)?;
        Ok((accepted, peer))
    }

    /// Returns a stream that yields the result of calling `accept` over and
    /// over, with the peer address discarded.
    ///
    /// The stream never ends by itself: every item, including an `Err`, is
    /// followed by another accept attempt.
    pub fn incoming(
        &self,
    ) -> impl Stream<Item = io::Result<Async<UnixStream>>> + Send + Unpin + '_ {
        let connections = stream::unfold(self, |listener| async move {
            let next = listener.accept().await.map(|(conn, _)| conn);
            Some((next, listener))
        });
        // Boxing is what makes the returned stream `Unpin`.
        Box::pin(connections)
    }
}
#[cfg(unix)]
impl Async<UnixStream> {
    /// Opens a Unix stream connection to the socket at `path` without blocking
    /// the calling thread.
    ///
    /// A non-blocking socket is created and `connect` is started on it; the
    /// method then waits for the socket to become writable and verifies that
    /// the connection was really established before returning the stream.
    ///
    /// # Errors
    ///
    /// Returns an error if the socket cannot be created or configured, if
    /// `path` is not a valid socket address, if `connect` fails with anything
    /// other than `EINPROGRESS`, if registration with the reactor fails, or if
    /// the socket turns out not to be connected once it is writable.
    pub async fn connect<P: AsRef<Path>>(path: P) -> io::Result<Async<UnixStream>> {
        let socket = Socket::new(Domain::unix(), Type::stream(), None)?;

        // Start the connect in non-blocking mode; "in progress" is the
        // expected outcome and is not treated as a failure.
        socket.set_nonblocking(true)?;
        socket
            .connect(&socket2::SockAddr::unix(path)?)
            .or_else(|err| {
                if err.raw_os_error() == Some(libc::EINPROGRESS) {
                    Ok(())
                } else {
                    Err(err)
                }
            })?;
        let stream = Async::new(socket.into_unix_stream())?;

        // The socket is reported writable when the connect attempt finishes,
        // whether it succeeded or failed.
        stream.writable().await?;

        // Surface a failed connect the same way the TCP path does, by reading
        // the socket's pending error.
        if let Some(err) = stream.get_ref().take_error()? {
            return Err(err);
        }

        // A socket that is not actually connected has no peer address, so
        // this catches a failure that left no pending error behind.
        stream.get_ref().peer_addr()?;

        Ok(stream)
    }

    /// Creates a pair of connected Unix stream sockets, each wrapped in
    /// `Async`.
    ///
    /// # Errors
    ///
    /// Returns any error from `UnixStream::pair` or from `Async::new`.
    pub fn pair() -> io::Result<(Async<UnixStream>, Async<UnixStream>)> {
        let (stream1, stream2) = UnixStream::pair()?;
        Ok((Async::new(stream1)?, Async::new(stream2)?))
    }
}
#[cfg(unix)]
impl Async<UnixDatagram> {
    /// Binds a Unix datagram socket to `path` and wraps it in `Async`.
    ///
    /// # Errors
    ///
    /// Returns any error from `UnixDatagram::bind` or from `Async::new`.
    pub fn bind<P: AsRef<Path>>(path: P) -> io::Result<Async<UnixDatagram>> {
        let owned_path = path.as_ref().to_path_buf();
        let socket = UnixDatagram::bind(owned_path)?;
        Async::new(socket)
    }

    /// Creates a Unix datagram socket that is not bound to any address and
    /// wraps it in `Async`.
    pub fn unbound() -> io::Result<Async<UnixDatagram>> {
        let socket = UnixDatagram::unbound()?;
        Async::new(socket)
    }

    /// Creates a pair of connected Unix datagram sockets, each wrapped in
    /// `Async`.
    pub fn pair() -> io::Result<(Async<UnixDatagram>, Async<UnixDatagram>)> {
        let (first, second) = UnixDatagram::pair()?;
        let first = Async::new(first)?;
        let second = Async::new(second)?;
        Ok((first, second))
    }

    /// Receives one datagram into `buf`, returning its length and sender.
    /// Driven through `read_with`.
    pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, UnixSocketAddr)> {
        self.read_with(|socket| socket.recv_from(buf)).await
    }

    /// Sends `buf` to the socket at `path`, returning the number of bytes
    /// sent. Driven through `write_with`.
    pub async fn send_to<P: AsRef<Path>>(&self, buf: &[u8], path: P) -> io::Result<usize> {
        self.write_with(|socket| socket.send_to(buf, &path)).await
    }

    /// Receives one datagram from the connected peer into `buf`.
    /// Driven through `read_with`.
    pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
        self.read_with(|socket| socket.recv(buf)).await
    }

    /// Sends `buf` to the connected peer, returning the number of bytes sent.
    /// Driven through `write_with`.
    pub async fn send(&self, buf: &[u8]) -> io::Result<usize> {
        self.write_with(|socket| socket.send(buf)).await
    }
}
fn poll_future<T>(cx: &mut Context<'_>, fut: impl Future<Output = T>) -> Poll<T> {
pin!(fut);
fut.poll(cx)
}
/// Gives `fut` a single chance to complete and then stops waiting for it.
///
/// The first poll is forwarded to `fut` and its result (ready or pending) is
/// returned unchanged. Any later poll resolves with `Ok(())` without touching
/// `fut` again, so a caller that is woken for any reason goes straight back to
/// retrying its operation.
async fn optimistic(fut: impl Future<Output = io::Result<()>>) -> io::Result<()> {
    pin!(fut);
    let mut first_poll = true;

    future::poll_fn(|cx| {
        if first_poll {
            first_poll = false;
            return fut.as_mut().poll(cx);
        }
        Poll::Ready(Ok(()))
    })
    .await
}
/// Shuts down the write half of the socket behind `raw` without taking
/// ownership of the descriptor.
///
/// The descriptor is viewed as a `TcpStream` only to reach `shutdown`; it is
/// never closed here.
///
/// # Errors
///
/// Only a `NotConnected` error is reported. Every other failure of the
/// underlying `shutdown` call is ignored and turned into `Ok(())`.
pub fn shutdown_write(#[cfg(unix)] raw: RawFd, #[cfg(windows)] raw: RawSocket) -> io::Result<()> {
    // SAFETY: the temporary `TcpStream` is wrapped in `ManuallyDrop`
    // immediately, so it never closes `raw`, which stays owned by the caller.
    // It is used for a single `shutdown` call and then forgotten.
    #[cfg(unix)]
    let stream = ManuallyDrop::new(unsafe { TcpStream::from_raw_fd(raw) });
    #[cfg(windows)]
    let stream = ManuallyDrop::new(unsafe { TcpStream::from_raw_socket(raw) });

    if let Err(err) = stream.shutdown(Shutdown::Write) {
        if err.kind() == io::ErrorKind::NotConnected {
            return Err(err);
        }
    }
    Ok(())
}