use std::fmt;
use std::io;
use std::mem;
use std::net::{Shutdown, SocketAddr};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use async_ready::{AsyncReadReady, AsyncWriteReady};
use futures::io::{AsyncRead, AsyncWrite};
use futures::Future;
use mio;
use crate::raw::PollEvented;
/// An asynchronous TCP stream: a `mio::net::TcpStream` wrapped in `PollEvented`.
///
/// Obtain one from [`TcpStream::connect`] or by converting a `std::net::TcpStream`
/// with `TryFrom`; I/O goes through the `AsyncRead` / `AsyncWrite` impls.
pub struct TcpStream {
// The registered mio socket; every accessor and I/O method delegates to it.
io: PollEvented<mio::net::TcpStream>,
}
/// Future returned by [`TcpStream::connect`].
///
/// Resolves to the connected [`TcpStream`] once the socket is writable and reports
/// no pending socket error; otherwise it resolves to the `io::Error` encountered.
///
/// # Panics
///
/// Polling again after the future has produced its result panics.
#[must_use = "futures do nothing unless polled"]
#[derive(Debug)]
pub struct ConnectFuture {
// Current stage of the connect attempt; `poll` takes it out and may put it back.
inner: ConnectFutureState,
}
/// Internal state machine behind [`ConnectFuture`].
///
/// This is plain data, not a future itself, so it intentionally carries no
/// `#[must_use = "futures do nothing unless polled"]` attribute.
#[derive(Debug)]
enum ConnectFutureState {
    /// The connect call was issued; waiting for the socket to become writable.
    Waiting(TcpStream),
    /// The initial connect call failed; the error is yielded on the next poll.
    Error(io::Error),
    /// Placeholder left by `poll` after taking the state; polling in this state panics.
    Empty,
}
// The poll methods below take `&mut self.io` through `Pin<&mut Self>`, which
// requires `TcpStream: Unpin`. NOTE(review): those methods also call `Pin::new`
// on the `PollEvented`, so it must be `Unpin` too, and the auto-impl would
// likely cover this already. The impl is kept explicit to state the guarantee.
impl Unpin for TcpStream {}
impl TcpStream {
pub fn connect(addr: &SocketAddr) -> ConnectFuture {
use self::ConnectFutureState::*;
let inner = match mio::net::TcpStream::connect(addr) {
Ok(tcp) => Waiting(TcpStream::new(tcp)),
Err(e) => Error(e),
};
ConnectFuture { inner }
}
pub(crate) fn new(connected: mio::net::TcpStream) -> TcpStream {
let io = PollEvented::new(connected);
TcpStream { io }
}
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.io.get_ref().local_addr()
}
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
self.io.get_ref().peer_addr()
}
pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
self.io.get_ref().shutdown(how)
}
pub fn nodelay(&self) -> io::Result<bool> {
self.io.get_ref().nodelay()
}
pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
self.io.get_ref().set_nodelay(nodelay)
}
pub fn recv_buffer_size(&self) -> io::Result<usize> {
self.io.get_ref().recv_buffer_size()
}
pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> {
self.io.get_ref().set_recv_buffer_size(size)
}
pub fn send_buffer_size(&self) -> io::Result<usize> {
self.io.get_ref().send_buffer_size()
}
pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
self.io.get_ref().set_send_buffer_size(size)
}
pub fn keepalive(&self) -> io::Result<Option<Duration>> {
self.io.get_ref().keepalive()
}
pub fn set_keepalive(&self, keepalive: Option<Duration>) -> io::Result<()> {
self.io.get_ref().set_keepalive(keepalive)
}
pub fn ttl(&self) -> io::Result<u32> {
self.io.get_ref().ttl()
}
pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
self.io.get_ref().set_ttl(ttl)
}
pub fn linger(&self) -> io::Result<Option<Duration>> {
self.io.get_ref().linger()
}
pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
self.io.get_ref().set_linger(dur)
}
}
impl AsyncRead for TcpStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.io).poll_read(cx, buf)
}
}
impl AsyncWrite for TcpStream {
    /// Forwards the write to the wrapped `PollEvented`'s `poll_write`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        Pin::new(&mut this.io).poll_write(cx, buf)
    }

    /// Forwards the flush to the wrapped `PollEvented`'s `poll_flush`.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.io).poll_flush(cx)
    }

    /// Forwards the close to the wrapped `PollEvented`'s `poll_close`.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.io).poll_close(cx)
    }
}
impl AsyncReadReady for TcpStream {
    type Ok = mio::Ready;
    type Err = io::Error;

    /// Polls the wrapped `PollEvented` for read readiness.
    fn poll_read_ready(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Self::Ok, Self::Err>> {
        let io = &mut self.get_mut().io;
        Pin::new(io).poll_read_ready(cx)
    }
}
impl AsyncWriteReady for TcpStream {
    type Ok = mio::Ready;
    type Err = io::Error;

    /// Polls the wrapped `PollEvented` for write readiness.
    fn poll_write_ready(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Self::Ok, Self::Err>> {
        // Only shared access is needed here: the call goes through a `&self` method.
        let this: &Self = &*self;
        this.io.poll_write_ready(cx)
    }
}
impl fmt::Debug for TcpStream {
    /// Prints the underlying mio socket's `Debug` output and adds no wrapper name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let sock = self.io.get_ref();
        fmt::Debug::fmt(sock, f)
    }
}
impl Future for ConnectFuture {
type Output = io::Result<TcpStream>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<TcpStream>> {
match mem::replace(&mut self.inner, ConnectFutureState::Empty) {
ConnectFutureState::Waiting(stream) => {
if let Poll::Pending = stream.io.poll_write_ready(cx)? {
self.inner = ConnectFutureState::Waiting(stream);
return Poll::Pending;
}
if let Some(e) = stream.io.get_ref().take_error()? {
return Poll::Ready(Err(e));
}
Poll::Ready(Ok(stream))
}
ConnectFutureState::Error(e) => Poll::Ready(Err(e)),
ConnectFutureState::Empty => panic!("can't poll TCP stream twice"),
}
}
}
/// Converts a standard-library TCP stream into this crate's async [`TcpStream`].
///
/// The socket is handed to `mio::net::TcpStream::from_stream`, which does not
/// call `connect`, so the stream should already be connected.
///
/// # Errors
///
/// Returns the error from `from_stream` if the conversion fails.
impl std::convert::TryFrom<std::net::TcpStream> for TcpStream {
    type Error = io::Error;

    fn try_from(stream: std::net::TcpStream) -> Result<Self, Self::Error> {
        mio::net::TcpStream::from_stream(stream).map(Self::new)
    }
}
/// Starts a connection to `addr` and wraps the socket immediately.
///
/// Unlike [`TcpStream::connect`], this does not wait for write readiness or
/// check `take_error`. As a result, a connect attempt that later fails (for
/// example, a refused connection) is not reported by this conversion. Prefer
/// awaiting `TcpStream::connect(addr)` when the outcome of the connection matters.
///
/// # Errors
///
/// Returns any error produced immediately by `mio::net::TcpStream::connect`.
impl std::convert::TryFrom<&std::net::SocketAddr> for TcpStream {
    type Error = io::Error;

    fn try_from(addr: &std::net::SocketAddr) -> Result<Self, Self::Error> {
        // `addr` is already a reference; pass it directly instead of `&addr`.
        let tcp = mio::net::TcpStream::connect(addr)?;
        Ok(TcpStream::new(tcp))
    }
}
#[cfg(unix)]
mod sys {
    //! Unix-only raw file-descriptor access for [`TcpStream`].

    use super::TcpStream;
    use std::os::unix::prelude::*;

    impl AsRawFd for TcpStream {
        /// Returns the descriptor of the underlying mio socket; the stream keeps ownership.
        fn as_raw_fd(&self) -> RawFd {
            let sock = self.io.get_ref();
            AsRawFd::as_raw_fd(sock)
        }
    }
}