use super::ucred::{self, UCred};
use crate::raw::PollEvented;
use async_ready::{AsyncReadReady, AsyncWriteReady, TakeError};
use futures::io::{AsyncRead, AsyncWrite};
use futures::{ready, Future, Poll};
use std::fmt;
use std::io;
use std::net::Shutdown;
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::SocketAddr;
use std::path::Path;
use std::pin::Pin;
use std::task::Context;
/// Progress of a [`ConnectFuture`].
#[derive(Debug)]
enum State {
    /// The socket was created; the future waits for it to report write readiness.
    Waiting(UnixStream),
    /// Creating the socket failed; this error is returned on the first poll.
    Error(io::Error),
    /// The outcome has already been handed out by `poll`.
    Empty,
}

/// Future returned by [`UnixStream::connect`], resolving to the connected stream.
#[derive(Debug)]
pub struct ConnectFuture {
    inner: State,
}

/// A Unix domain stream socket whose reads and writes are driven through a
/// `PollEvented` wrapper around `mio_uds::UnixStream`.
pub struct UnixStream {
    io: PollEvented<mio_uds::UnixStream>,
}

// The poll methods treat `Pin<&mut UnixStream>` as a plain `&mut UnixStream`,
// so the type opts into `Unpin` explicitly.
impl Unpin for UnixStream {}
impl UnixStream {
pub fn connect(path: impl AsRef<Path>) -> ConnectFuture {
let res = mio_uds::UnixStream::connect(path).map(UnixStream::new);
let inner = match res {
Ok(stream) => State::Waiting(stream),
Err(e) => State::Error(e),
};
ConnectFuture { inner }
}
pub fn pair() -> io::Result<(UnixStream, UnixStream)> {
let (a, b) = mio_uds::UnixStream::pair()?;
let a = UnixStream::new(a);
let b = UnixStream::new(b);
Ok((a, b))
}
pub(crate) fn new(stream: mio_uds::UnixStream) -> UnixStream {
let io = PollEvented::new(stream);
UnixStream { io }
}
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.io.get_ref().local_addr()
}
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
self.io.get_ref().peer_addr()
}
pub fn peer_cred(&self) -> io::Result<UCred> {
ucred::get_peer_cred(self)
}
pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
self.io.get_ref().shutdown(how)
}
}
impl AsyncRead for UnixStream {
    /// Reads into `buf` by forwarding to the wrapped `PollEvented`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        // `UnixStream: Unpin`, so the pin can be unwrapped safely.
        let this = Pin::get_mut(self);
        Pin::new(&mut this.io).poll_read(cx, buf)
    }
}
impl AsyncWrite for UnixStream {
    /// Writes `buf` by forwarding to the wrapped `PollEvented`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = Pin::get_mut(self);
        Pin::new(&mut this.io).poll_write(cx, buf)
    }

    /// Flushes by forwarding to the wrapped `PollEvented`.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = Pin::get_mut(self);
        Pin::new(&mut this.io).poll_flush(cx)
    }

    /// Closes by forwarding to the wrapped `PollEvented`.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = Pin::get_mut(self);
        Pin::new(&mut this.io).poll_close(cx)
    }
}
impl AsyncReadReady for UnixStream {
    type Ok = mio::Ready;
    type Err = io::Error;

    /// Reports read readiness as observed by the wrapped `PollEvented`.
    fn poll_read_ready(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Self::Ok, Self::Err>> {
        let this = Pin::get_mut(self);
        Pin::new(&mut this.io).poll_read_ready(cx)
    }
}
impl AsyncWriteReady for UnixStream {
    type Ok = mio::Ready;
    type Err = io::Error;

    /// Reports write readiness as observed by the wrapped `PollEvented`.
    fn poll_write_ready(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Self::Ok, Self::Err>> {
        let this = Pin::get_mut(self);
        Pin::new(&mut this.io).poll_write_ready(cx)
    }
}
impl TakeError for UnixStream {
    type Ok = io::Error;
    type Err = io::Error;

    /// Delegates to `take_error` on the wrapped `mio_uds::UnixStream`,
    /// which reports any pending socket error.
    fn take_error(&self) -> Result<Option<Self::Ok>, Self::Err> {
        let socket = self.io.get_ref();
        socket.take_error()
    }
}
impl fmt::Debug for UnixStream {
    /// Formats exactly like the wrapped `mio_uds::UnixStream`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Debug::fmt(self.io.get_ref(), f)
    }
}
impl AsRawFd for UnixStream {
    /// Returns the raw descriptor of the wrapped socket. The descriptor is
    /// borrowed, not transferred; this `UnixStream` still owns it.
    fn as_raw_fd(&self) -> RawFd {
        let socket = self.io.get_ref();
        AsRawFd::as_raw_fd(socket)
    }
}
impl Future for ConnectFuture {
type Output = io::Result<UnixStream>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<UnixStream>> {
use std::mem;
match self.inner {
State::Waiting(ref mut stream) => {
ready!(stream.io.poll_write_ready(cx)?);
if let Some(e) = stream.io.get_ref().take_error()? {
return Poll::Ready(Err(e));
}
}
State::Error(_) => {
let e = match mem::replace(&mut self.inner, State::Empty) {
State::Error(e) => e,
_ => unreachable!(),
};
return Poll::Ready(Err(e));
}
State::Empty => panic!("can't poll stream twice"),
}
match mem::replace(&mut self.inner, State::Empty) {
State::Waiting(stream) => Poll::Ready(Ok(stream)),
_ => unreachable!(),
}
}
}