use std::{
future::Future,
io,
pin::Pin,
task::{Context, Poll},
};
use bytes::Bytes;
use proto::{Chunk, Chunks, ConnectionError, ReadableError, StreamId};
use thiserror::Error;
use tokio::io::ReadBuf;
use crate::{
connection::{ClosedStream, ConnectionRef},
VarInt,
};
/// A stream on a connection from which data can be read.
///
/// If the handle is dropped before all data has been read, the `Drop` impl stops the
/// stream with error code 0.
#[derive(Debug)]
pub struct RecvStream {
/// Handle to the state of the connection this stream belongs to.
conn: ConnectionRef,
/// Identifier of this stream within the connection.
stream: StreamId,
/// Flag supplied at construction; when set, reads first verify via `check_0rtt`.
is_0rtt: bool,
/// Set once the stream finished, was stopped locally, or reported a reset error;
/// further reads then return `Ok(None)` without touching the connection.
all_data_read: bool,
/// Reset code held back by a read that also returned data; surfaced by the next read.
reset: Option<VarInt>,
}
impl RecvStream {
/// Creates a handle for stream `stream` of connection `conn`.
///
/// The new handle starts with no data marked as read and no reset pending.
pub(crate) fn new(conn: ConnectionRef, stream: StreamId, is_0rtt: bool) -> Self {
    let (all_data_read, reset) = (false, None);
    Self {
        conn,
        stream,
        is_0rtt,
        all_data_read,
        reset,
    }
}
/// Reads bytes from the stream into `buf`.
///
/// Resolves to `Ok(Some(n))` with the number of bytes placed in `buf`, or to `Ok(None)`
/// when `buf` is non-empty and nothing was read because the stream has no more data.
pub async fn read(&mut self, buf: &mut [u8]) -> Result<Option<usize>, ReadError> {
    let buf = ReadBuf::new(buf);
    let pending = Read { stream: self, buf };
    pending.await
}
/// Reads until `buf` is completely filled.
///
/// Resolves to `Err(ReadExactError::FinishedEarly(_))` if a read makes no progress before
/// `buf` is full, and forwards any underlying `ReadError`.
pub async fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), ReadExactError> {
    let buf = ReadBuf::new(buf);
    let pending = ReadExact { stream: self, buf };
    pending.await
}
/// Poll-based variant of reading into a byte slice.
///
/// On success yields the number of bytes written to the front of `buf`; `0` for a
/// non-empty `buf` means no more data could be read. Returns `Poll::Pending` when the
/// underlying read is not ready.
pub fn poll_read(
    &mut self,
    cx: &mut Context,
    buf: &mut [u8],
) -> Poll<Result<usize, ReadError>> {
    let mut dest = ReadBuf::new(buf);
    match self.poll_read_buf(cx, &mut dest) {
        Poll::Pending => Poll::Pending,
        Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
        Poll::Ready(Ok(())) => Poll::Ready(Ok(dest.filled().len())),
    }
}
/// Copies ordered stream data into `buf` until it is full or no more chunks are available.
///
/// Returns `Ready(Ok(()))` immediately when `buf` has no remaining capacity. Otherwise the
/// outcome is whatever `poll_read_generic` reports, with the payload discarded; callers
/// inspect `buf` to learn how much was written.
fn poll_read_buf(
    &mut self,
    cx: &mut Context,
    buf: &mut ReadBuf<'_>,
) -> Poll<Result<(), ReadError>> {
    if buf.remaining() == 0 {
        return Poll::Ready(Ok(()));
    }

    let outcome = self.poll_read_generic(cx, true, |chunks| {
        let mut copied_any = false;
        while buf.remaining() > 0 {
            match chunks.next(buf.remaining()) {
                Ok(Some(chunk)) => {
                    buf.put_slice(&chunk.bytes);
                    copied_any = true;
                }
                // End of stream or an error: report it together with whether any
                // bytes were already copied during this call.
                res => {
                    let progress = if copied_any { Some(()) } else { None };
                    return ReadStatus::from((progress, res.err()));
                }
            }
        }
        // The buffer was non-empty on entry, so reaching this point means data was copied.
        ReadStatus::Readable(())
    });
    outcome.map(|res| res.map(|_| ()))
}
/// Reads the next chunk of at most `max_length` bytes from the stream.
///
/// `ordered` is passed through to the underlying chunk reader. Resolves to `Ok(None)` when
/// no further chunk will be produced.
pub async fn read_chunk(
    &mut self,
    max_length: usize,
    ordered: bool,
) -> Result<Option<Chunk>, ReadError> {
    let next_chunk = ReadChunk {
        stream: self,
        max_length,
        ordered,
    };
    next_chunk.await
}
/// Poll-based single-chunk read used by `read_chunk` and `read_to_end`.
///
/// Requests at most `max_length` bytes from the chunk reader and maps the result onto a
/// `ReadStatus` for `poll_read_generic` to interpret.
fn poll_read_chunk(
    &mut self,
    cx: &mut Context,
    max_length: usize,
    ordered: bool,
) -> Poll<Result<Option<Chunk>, ReadError>> {
    self.poll_read_generic(cx, ordered, |chunks| {
        match chunks.next(max_length) {
            Ok(Some(chunk)) => ReadStatus::Readable(chunk),
            Ok(None) => ReadStatus::Finished(None),
            Err(err) => ReadStatus::Failed(None, err),
        }
    })
}
/// Reads up to `bufs.len()` chunks, storing each one in the next slot of `bufs`.
///
/// Resolves to `Ok(Some(n))` with the number of slots filled (`0` if `bufs` is empty), or
/// `Ok(None)` when the stream ended before any chunk was stored.
pub async fn read_chunks(&mut self, bufs: &mut [Bytes]) -> Result<Option<usize>, ReadError> {
    let fill = ReadChunks { stream: self, bufs };
    fill.await
}
/// Poll-based multi-chunk read backing `read_chunks`.
///
/// Performs an ordered read, moving each chunk's bytes into consecutive slots of `bufs`
/// until every slot is used or the chunk reader stops producing data.
fn poll_read_chunks(
    &mut self,
    cx: &mut Context,
    bufs: &mut [Bytes],
) -> Poll<Result<Option<usize>, ReadError>> {
    if bufs.is_empty() {
        return Poll::Ready(Ok(Some(0)));
    }

    self.poll_read_generic(cx, true, |chunks| {
        let mut filled = 0;
        while filled < bufs.len() {
            match chunks.next(usize::MAX) {
                Ok(Some(chunk)) => {
                    bufs[filled] = chunk.bytes;
                    filled += 1;
                }
                // End of stream or an error: hand back the slots filled so far, if any.
                res => {
                    let progress = if filled == 0 { None } else { Some(filled) };
                    return (progress, res.err()).into();
                }
            }
        }
        // `bufs` is non-empty here, so at least one slot was filled.
        ReadStatus::Readable(filled)
    })
}
/// Reads all remaining data of the stream into a single `Vec<u8>`.
///
/// Chunks are fetched with unordered reads and assembled by offset once the stream ends.
/// Resolves to `Err(ReadToEndError::TooLong)` when the data does not fit within
/// `size_limit` bytes, and forwards any underlying `ReadError`.
pub async fn read_to_end(&mut self, size_limit: usize) -> Result<Vec<u8>, ReadToEndError> {
    ReadToEnd {
        stream: self,
        size_limit,
        read: Vec::new(),
        // Start at the largest offset so the first chunk's offset always becomes the minimum.
        start: u64::MAX,
        end: 0,
    }
    .await
}
/// Stops the stream with `error_code`.
///
/// For a 0-RTT stream whose `check_0rtt` fails, this is a no-op that returns `Ok(())`.
/// Otherwise the stop is issued on the connection, the connection is woken, and the
/// stream is marked as fully read so later reads return `Ok(None)`.
///
/// # Errors
///
/// Propagates the error returned by the underlying `stop` call.
pub fn stop(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
    let mut conn = self.conn.state.lock("RecvStream::stop");
    let zero_rtt_rejected = self.is_0rtt && conn.check_0rtt().is_err();
    if !zero_rtt_rejected {
        conn.inner.recv_stream(self.stream).stop(error_code)?;
        conn.wake();
        self.all_data_read = true;
    }
    Ok(())
}
/// Returns the `is_0rtt` flag this stream was created with.
pub fn is_0rtt(&self) -> bool {
self.is_0rtt
}
/// Returns the identifier of this stream.
pub fn id(&self) -> StreamId {
self.stream
}
/// Shared driver behind every read flavor of this stream.
///
/// `read_fn` is handed the chunk reader and returns a `ReadStatus` describing what it
/// managed to read; this function turns that status into a poll result and updates the
/// stream's bookkeeping (`all_data_read`, `reset`, registered waker).
///
/// Returns `Ready(Ok(Some(_)))` when data was read, `Ready(Ok(None))` when the stream has
/// no more data, `Ready(Err(_))` on failure, and `Pending` when the read is blocked.
fn poll_read_generic<T, U>(
&mut self,
cx: &mut Context,
ordered: bool,
mut read_fn: T,
) -> Poll<Result<Option<U>, ReadError>>
where
T: FnMut(&mut Chunks) -> ReadStatus<U>,
{
use proto::ReadError::*;
// Once the stream finished, was stopped, or reported a reset, every read is an immediate `None`.
if self.all_data_read {
return Poll::Ready(Ok(None));
}
let mut conn = self.conn.state.lock("RecvStream::poll_read");
if self.is_0rtt {
conn.check_0rtt().map_err(|()| ReadError::ZeroRttRejected)?;
}
// A reset code stored by an earlier call (one that also returned data) is reported now,
// without reading from the connection again.
let status = match self.reset.take() {
Some(code) => ReadStatus::Failed(None, Reset(code)),
None => {
let mut recv = conn.inner.recv_stream(self.stream);
// A failure to start reading is converted into `ReadError` by `?`.
let mut chunks = recv.read(ordered)?;
let status = read_fn(&mut chunks);
// NOTE(review): presumably `should_transmit` signals that the read produced something
// the connection must send; confirm against the `proto` crate.
if chunks.finalize().should_transmit() {
conn.wake();
}
status
}
};
match status {
ReadStatus::Readable(read) => Poll::Ready(Ok(Some(read))),
ReadStatus::Finished(read) => {
self.all_data_read = true;
Poll::Ready(Ok(read))
}
ReadStatus::Failed(read, Blocked) => match read {
// Some data was read before blocking: return it instead of waiting.
Some(val) => Poll::Ready(Ok(Some(val))),
None => {
// Nothing read and the connection has failed: report that rather than waiting.
if let Some(ref x) = conn.error {
return Poll::Ready(Err(ReadError::ConnectionLost(x.clone())));
}
// Register this task's waker under the stream id before returning `Pending`.
conn.blocked_readers.insert(self.stream, cx.waker().clone());
Poll::Pending
}
},
ReadStatus::Failed(read, Reset(error_code)) => match read {
None => {
self.all_data_read = true;
Poll::Ready(Err(ReadError::Reset(error_code)))
}
// Data was read before the reset was seen: return the data now and keep the
// reset code so the next call reports it.
done => {
self.reset = Some(error_code);
Poll::Ready(Ok(done))
}
},
}
}
}
/// Outcome reported by the read closures passed to `RecvStream::poll_read_generic`.
enum ReadStatus<T> {
/// Data was read and the chunk reader did not report an end or an error.
Readable(T),
/// The chunk reader reported no more data; carries whatever was read before that, if anything.
Finished(Option<T>),
/// The chunk reader returned an error; carries whatever was read before that, if anything.
Failed(Option<T>, proto::ReadError),
}
impl<T> From<(Option<T>, Option<proto::ReadError>)> for ReadStatus<T> {
fn from(status: (Option<T>, Option<proto::ReadError>)) -> Self {
match status {
(read, None) => Self::Finished(read),
(read, Some(e)) => Self::Failed(read, e),
}
}
}
/// Future produced by `RecvStream::read_to_end`.
#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
struct ReadToEnd<'a> {
/// Stream being drained.
stream: &'a mut RecvStream,
/// Chunks received so far, each paired with its offset.
read: Vec<(Bytes, u64)>,
/// Lowest chunk offset seen so far.
start: u64,
/// Highest chunk end (offset + length) seen so far; `0` means nothing was received.
end: u64,
/// Maximum number of bytes the caller is willing to receive.
size_limit: usize,
}
impl Future for ReadToEnd<'_> {
    type Output = Result<Vec<u8>, ReadToEndError>;

    /// Collects chunks with unordered reads until the stream ends, then assembles them into
    /// one contiguous buffer ordered by offset.
    ///
    /// Resolves to `Err(ReadToEndError::TooLong)` as soon as the span covered by the received
    /// chunks exceeds `size_limit`, and forwards any `ReadError` from the stream.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        loop {
            match ready!(self.stream.poll_read_chunk(cx, usize::MAX, false))? {
                Some(chunk) => {
                    let chunk_end = chunk.bytes.len() as u64 + chunk.offset;
                    self.start = self.start.min(chunk.offset);
                    self.end = self.end.max(chunk_end);
                    // Check the limit against the whole span seen so far, not only this
                    // chunk's end. Chunks arrive unordered, so a low-offset chunk may follow
                    // a high-offset one and widen the final buffer beyond `size_limit`.
                    if self.end - self.start > self.size_limit as u64 {
                        return Poll::Ready(Err(ReadToEndError::TooLong));
                    }
                    self.read.push((chunk.bytes, chunk.offset));
                }
                None => {
                    if self.end == 0 {
                        // Nothing was ever received.
                        return Poll::Ready(Ok(Vec::new()));
                    }
                    let start = self.start;
                    // The span is bounded by `size_limit` (a `usize`), so the cast cannot truncate.
                    let mut buffer = vec![0; (self.end - start) as usize];
                    for (data, offset) in self.read.drain(..) {
                        let offset = (offset - start) as usize;
                        buffer[offset..offset + data.len()].copy_from_slice(&data);
                    }
                    return Poll::Ready(Ok(buffer));
                }
            }
        }
    }
}
/// Errors produced by `RecvStream::read_to_end`.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ReadToEndError {
/// A read from the stream failed.
#[error("read error: {0}")]
Read(#[from] ReadError),
/// The stream held more data than the caller's size limit allows.
#[error("stream too long")]
TooLong,
}
#[cfg(feature = "futures-io")]
impl futures_io::AsyncRead for RecvStream {
    /// Reads into `buf` and yields the number of bytes written; stream errors are converted
    /// to `io::Error`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        let mut dest = ReadBuf::new(buf);
        match this.poll_read_buf(cx, &mut dest) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(err)) => Poll::Ready(Err(io::Error::from(err))),
            Poll::Ready(Ok(())) => Poll::Ready(Ok(dest.filled().len())),
        }
    }
}
impl tokio::io::AsyncRead for RecvStream {
    /// Reads into the caller's `ReadBuf`; stream errors are converted to `io::Error`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        match this.poll_read_buf(cx, buf) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(err)) => Poll::Ready(Err(io::Error::from(err))),
            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
        }
    }
}
impl Drop for RecvStream {
    /// Unregisters any waker stored for this stream and, unless the connection has failed,
    /// the 0-RTT check fails, or all data was already read, stops the stream with code 0.
    fn drop(&mut self) {
        let mut conn = self.conn.state.lock("RecvStream::drop");

        // Remove a waker possibly left behind by a blocked read.
        conn.blocked_readers.remove(&self.stream);

        let unusable = conn.error.is_some() || (self.is_0rtt && conn.check_0rtt().is_err());
        if unusable || self.all_data_read {
            return;
        }

        // Best effort: an error from `stop` is deliberately ignored.
        let _ = conn.inner.recv_stream(self.stream).stop(0u32.into());
        conn.wake();
    }
}
/// Errors that can occur while reading from a `RecvStream`.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ReadError {
/// The peer reset the stream; carries the reset error code.
#[error("stream reset by peer: error {0}")]
Reset(VarInt),
/// The connection was lost; carries the connection error.
#[error("connection lost")]
ConnectionLost(#[from] ConnectionError),
/// The stream is closed.
#[error("closed stream")]
ClosedStream,
/// An ordered read was attempted after an unordered read.
#[error("ordered read after unordered read")]
IllegalOrderedRead,
/// The stream is a 0-RTT stream and the 0-RTT check failed.
#[error("0-RTT rejected")]
ZeroRttRejected,
}
impl From<ReadableError> for ReadError {
    /// Maps each `ReadableError` variant onto the `ReadError` variant of the same name.
    fn from(err: ReadableError) -> Self {
        match err {
            ReadableError::IllegalOrderedRead => Self::IllegalOrderedRead,
            ReadableError::ClosedStream => Self::ClosedStream,
        }
    }
}
impl From<ReadError> for io::Error {
fn from(x: ReadError) -> Self {
use self::ReadError::*;
let kind = match x {
Reset { .. } | ZeroRttRejected => io::ErrorKind::ConnectionReset,
ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected,
IllegalOrderedRead => io::ErrorKind::InvalidInput,
};
Self::new(kind, x)
}
}
/// Future produced by `RecvStream::read`.
#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
struct Read<'a> {
/// Stream to read from.
stream: &'a mut RecvStream,
/// Destination buffer wrapping the caller's slice.
buf: ReadBuf<'a>,
}
impl<'a> Future for Read<'a> {
    type Output = Result<Option<usize>, ReadError>;

    /// Performs one read into the buffer. Yields `None` when nothing was written to a buffer
    /// with non-zero capacity, otherwise `Some` with the number of bytes written.
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let this = self.get_mut();
        ready!(this.stream.poll_read_buf(cx, &mut this.buf))?;

        let filled = this.buf.filled().len();
        // An empty result only signals the end of data if the caller offered room to read into.
        let no_more_data = filled == 0 && this.buf.capacity() != 0;
        Poll::Ready(Ok(if no_more_data { None } else { Some(filled) }))
    }
}
/// Future produced by `RecvStream::read_exact`.
#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
struct ReadExact<'a> {
/// Stream to read from.
stream: &'a mut RecvStream,
/// Destination buffer wrapping the caller's slice; it must be filled completely.
buf: ReadBuf<'a>,
}
impl<'a> Future for ReadExact<'a> {
    type Output = Result<(), ReadExactError>;

    /// Keeps reading until the buffer is full.
    ///
    /// Resolves to `Err(ReadExactError::FinishedEarly(n))` when a read makes no progress
    /// before the buffer is full, where `n` is the total number of bytes read into the
    /// buffer across all polls of this future. Underlying `ReadError`s are forwarded.
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let this = self.get_mut();
        let mut remaining = this.buf.remaining();
        while remaining > 0 {
            ready!(this.stream.poll_read_buf(cx, &mut this.buf))?;
            let new = this.buf.remaining();
            if new == remaining {
                // Take the count from the buffer itself. A count derived from the space
                // left at the start of this poll would omit bytes read in earlier polls
                // that returned `Pending`.
                let read = this.buf.filled().len();
                return Poll::Ready(Err(ReadExactError::FinishedEarly(read)));
            }
            remaining = new;
        }
        Poll::Ready(Ok(()))
    }
}
/// Errors produced by `RecvStream::read_exact`.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ReadExactError {
/// The stream finished before the buffer was filled; carries a count of bytes read.
#[error("stream finished early ({0} bytes read)")]
FinishedEarly(usize),
/// A read from the stream failed.
#[error(transparent)]
ReadError(#[from] ReadError),
}
/// Future produced by `RecvStream::read_chunk`.
#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
struct ReadChunk<'a> {
/// Stream to read from.
stream: &'a mut RecvStream,
/// Upper bound on the size of the returned chunk.
max_length: usize,
/// Ordering mode passed to the underlying chunk reader.
ordered: bool,
}
impl<'a> Future for ReadChunk<'a> {
    type Output = Result<Option<Chunk>, ReadError>;

    /// Delegates to `RecvStream::poll_read_chunk` with the stored parameters.
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let this = self.get_mut();
        this.stream
            .poll_read_chunk(cx, this.max_length, this.ordered)
    }
}
/// Future produced by `RecvStream::read_chunks`.
#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
struct ReadChunks<'a> {
/// Stream to read from.
stream: &'a mut RecvStream,
/// Caller-provided slots that receive the chunks.
bufs: &'a mut [Bytes],
}
impl<'a> Future for ReadChunks<'a> {
    type Output = Result<Option<usize>, ReadError>;

    /// Delegates to `RecvStream::poll_read_chunks` with the stored slots.
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let Self { stream, bufs } = self.get_mut();
        stream.poll_read_chunks(cx, bufs)
    }
}