use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::loom::sync::Mutex;
use bytes::{Buf, BytesMut};
use std::{
pin::Pin,
sync::Arc,
task::{self, Poll, Waker},
};
/// One end of an in-memory, bidirectional byte stream created by [`duplex`].
///
/// Bytes written to this stream become readable on the paired stream, and
/// bytes written to the paired stream become readable here. Dropping a
/// stream closes both pipes it participates in (see the `Drop` impl).
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
pub struct DuplexStream {
    /// Pipe this end reads from; the paired stream writes into it.
    read: Arc<Mutex<Pipe>>,
    /// Pipe this end writes into; the paired stream reads from it.
    write: Arc<Mutex<Pipe>>,
}
/// A bounded, one-directional byte buffer shared by exactly one writing
/// `DuplexStream` and one reading `DuplexStream`.
#[derive(Debug)]
struct Pipe {
    /// Bytes that have been written but not yet read.
    buffer: BytesMut,
    /// Set when either side closes: writer shutdown/drop (`close_write`) or
    /// reader drop (`close_read`). A closed pipe still lets the reader drain
    /// `buffer` before it reports EOF; writes to a closed pipe fail with
    /// `BrokenPipe`.
    is_closed: bool,
    /// Capacity limit: a write accepts at most `max_buf_size - buffer.len()`
    /// bytes and returns `Pending` when that amount is zero.
    max_buf_size: usize,
    /// Reader task to wake when data arrives or the writer closes.
    read_waker: Option<Waker>,
    /// Writer task to wake when space is freed or the reader goes away.
    write_waker: Option<Waker>,
}
/// Creates a pair of connected [`DuplexStream`]s backed by two in-memory pipes.
///
/// Bytes written to the first stream can be read from the second, and vice
/// versa. Each direction buffers at most `max_buf_size` bytes. When a
/// direction is full, writes return `Poll::Pending` until the peer reads.
///
/// Note: with `max_buf_size == 0` a write never makes progress. It stays
/// `Pending` until the peer is dropped, and then fails with `BrokenPipe`.
#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) {
    let make_pipe = || Arc::new(Mutex::new(Pipe::new(max_buf_size)));

    // Named by data direction: `into_first` carries bytes the second stream
    // writes and the first stream reads; `into_second` is the reverse.
    let into_first = make_pipe();
    let into_second = make_pipe();

    let first = DuplexStream {
        read: Arc::clone(&into_first),
        write: Arc::clone(&into_second),
    };
    let second = DuplexStream {
        read: into_second,
        write: into_first,
    };
    (first, second)
}
impl AsyncRead for DuplexStream {
    // `mut self` does not appear necessary for calling `lock()`. It is
    // presumably kept for compatibility with older rustc versions that
    // required it, hence the lint allowance. Confirm before removing.
    #[allow(unused_mut)]
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // Read from the pipe our peer writes into. The guard is released
        // when this function returns.
        let mut incoming = self.read.lock();
        Pin::new(&mut *incoming).poll_read(cx, buf)
    }
}
impl AsyncWrite for DuplexStream {
    // Every method below takes `mut self` and allows `unused_mut`.
    // `lock()` does not appear to need `mut`; this is presumably for
    // compatibility with older rustc versions. Confirm before removing.

    /// Writes into the pipe our peer reads from.
    #[allow(unused_mut)]
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let mut outgoing = self.write.lock();
        Pin::new(&mut *outgoing).poll_write(cx, buf)
    }

    /// Forwards the flush to the outgoing pipe.
    #[allow(unused_mut)]
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> Poll<std::io::Result<()>> {
        let mut outgoing = self.write.lock();
        Pin::new(&mut *outgoing).poll_flush(cx)
    }

    /// Closes the outgoing pipe, so the peer sees EOF once it has drained
    /// the buffered bytes.
    #[allow(unused_mut)]
    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> Poll<std::io::Result<()>> {
        let mut outgoing = self.write.lock();
        Pin::new(&mut *outgoing).poll_shutdown(cx)
    }
}
impl Drop for DuplexStream {
    /// Closes both pipes this end participates in, so the peer notices the
    /// drop.
    fn drop(&mut self) {
        // The peer reads from `write`. Closing it wakes a pending peer read,
        // and the peer sees EOF after draining any buffered bytes.
        self.write.lock().close_write();
        // The peer writes into `read`. Closing it wakes a pending peer write,
        // and the peer's writes now fail with `BrokenPipe`.
        // Each lock guard above is a statement temporary, so the two locks
        // are never held together. This avoids lock-order inversion when
        // both ends are dropped concurrently.
        self.read.lock().close_read();
    }
}
impl Pipe {
fn new(max_buf_size: usize) -> Self {
Pipe {
buffer: BytesMut::new(),
is_closed: false,
max_buf_size,
read_waker: None,
write_waker: None,
}
}
fn close_write(&mut self) {
self.is_closed = true;
if let Some(waker) = self.read_waker.take() {
waker.wake();
}
}
fn close_read(&mut self) {
self.is_closed = true;
if let Some(waker) = self.write_waker.take() {
waker.wake();
}
}
}
impl AsyncRead for Pipe {
    /// Moves as many buffered bytes into `buf` as it can hold.
    ///
    /// Behavior by state:
    /// - Empty and closed: returns `Ready(Ok(()))` without filling `buf`
    ///   (EOF).
    /// - Empty and open: registers the reader's waker and returns `Pending`.
    /// - Data was moved: wakes a writer that may be waiting for space.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        if !self.buffer.has_remaining() {
            if self.is_closed {
                return Poll::Ready(Ok(()));
            }
            self.read_waker = Some(cx.waker().clone());
            return Poll::Pending;
        }

        let copied = buf.remaining().min(self.buffer.remaining());
        buf.put_slice(&self.buffer[..copied]);
        self.buffer.advance(copied);

        // `buf` may have had no room left. Only wake the writer when space
        // was actually freed.
        if copied != 0 {
            if let Some(writer) = self.write_waker.take() {
                writer.wake();
            }
        }
        Poll::Ready(Ok(()))
    }
}
impl AsyncWrite for Pipe {
    /// Appends as much of `buf` as fits in the remaining capacity.
    ///
    /// Returns `BrokenPipe` once the pipe is closed. Returns `Pending`, after
    /// registering the writer's waker, when the buffer is full. Otherwise
    /// wakes a pending reader and reports how many bytes were accepted.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        if self.is_closed {
            return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
        }

        // Every write is clamped to the free space, so `buffer.len()` never
        // exceeds `max_buf_size` and this subtraction cannot underflow.
        let free = self.max_buf_size - self.buffer.len();
        if free == 0 {
            self.write_waker = Some(cx.waker().clone());
            return Poll::Pending;
        }

        let accepted = &buf[..buf.len().min(free)];
        self.buffer.extend_from_slice(accepted);
        if let Some(reader) = self.read_waker.take() {
            reader.wake();
        }
        Poll::Ready(Ok(accepted.len()))
    }

    /// Always succeeds immediately: written bytes are readable straight away.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Closes the write side. The reader sees EOF after draining the buffer,
    /// and further writes fail with `BrokenPipe`.
    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        _cx: &mut task::Context<'_>,
    ) -> Poll<std::io::Result<()>> {
        self.close_write();
        Poll::Ready(Ok(()))
    }
}