[go: up one dir, main page]

worker/
streams.rs

1use std::{
2    pin::Pin,
3    task::{Context, Poll},
4};
5
6use futures_util::{Stream, TryStreamExt};
7use js_sys::{BigInt, Uint8Array};
8use pin_project::pin_project;
9use wasm_bindgen::{JsCast, JsValue};
10use wasm_streams::readable::IntoStream;
11use web_sys::ReadableStream;
12use worker_sys::FixedLengthStream as FixedLengthStreamSys;
13
14use crate::{Error, Result};
15
16#[pin_project]
17#[derive(Debug)]
18pub struct ByteStream {
19    #[pin]
20    pub(crate) inner: IntoStream<'static>,
21}
22
23impl Stream for ByteStream {
24    type Item = Result<Vec<u8>>;
25
26    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
27        let this = self.project();
28        let item = match futures_util::ready!(this.inner.poll_next(cx)) {
29            Some(res) => res.map(Uint8Array::from).map_err(Error::from),
30            None => return Poll::Ready(None),
31        };
32
33        Poll::Ready(match item {
34            Ok(value) => Some(Ok(value.to_vec())),
35            Err(e) if e.to_string() == "Error: aborted" => None,
36            Err(e) => Some(Err(e)),
37        })
38    }
39}
40
41#[pin_project]
42pub struct FixedLengthStream {
43    length: u64,
44    #[pin]
45    bytes_read: u64,
46    #[pin]
47    inner: Pin<Box<dyn Stream<Item = Result<Vec<u8>>> + 'static>>,
48}
49
50impl core::fmt::Debug for FixedLengthStream {
51    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
52        f.debug_struct("FixedLengthStream")
53            .field("length", &self.length)
54            .field("bytes_read", &self.bytes_read)
55            .finish()
56    }
57}
58
59impl FixedLengthStream {
60    pub fn wrap(stream: impl Stream<Item = Result<Vec<u8>>> + 'static, length: u64) -> Self {
61        Self {
62            length,
63            bytes_read: 0,
64            inner: Box::pin(stream),
65        }
66    }
67}
68
69impl Stream for FixedLengthStream {
70    type Item = Result<Vec<u8>>;
71
72    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
73        let mut this = self.project();
74        let item = if let Some(res) = futures_util::ready!(this.inner.poll_next(cx)) {
75            let chunk = match res {
76                Ok(chunk) => chunk,
77                Err(err) => return Poll::Ready(Some(Err(err))),
78            };
79
80            *this.bytes_read += chunk.len() as u64;
81
82            if *this.bytes_read > *this.length {
83                let err = Error::from(format!(
84                    "fixed length stream had different length than expected (expected {}, got {})",
85                    *this.length, *this.bytes_read,
86                ));
87                Some(Err(err))
88            } else {
89                Some(Ok(chunk))
90            }
91        } else if *this.bytes_read != *this.length {
92            let err = Error::from(format!(
93                "fixed length stream had different length than expected (expected {}, got {})",
94                *this.length, *this.bytes_read,
95            ));
96            Some(Err(err))
97        } else {
98            None
99        };
100
101        Poll::Ready(item)
102    }
103}
104
105impl From<FixedLengthStream> for FixedLengthStreamSys {
106    fn from(stream: FixedLengthStream) -> Self {
107        let raw = if stream.length < u32::MAX as u64 {
108            FixedLengthStreamSys::new(stream.length as u32).unwrap()
109        } else {
110            FixedLengthStreamSys::new_big_int(BigInt::from(stream.length)).unwrap()
111        };
112
113        let js_stream = stream
114            .map_ok(|item| -> Vec<u8> { item })
115            .map_ok(|chunk| {
116                let array = Uint8Array::new_with_length(chunk.len() as _);
117                array.copy_from(&chunk);
118
119                array.into()
120            })
121            .map_err(JsValue::from);
122
123        let stream: ReadableStream = wasm_streams::ReadableStream::from_stream(js_stream)
124            .as_raw()
125            .clone()
126            .unchecked_into();
127        let _ = stream.pipe_to(&raw.writable());
128
129        raw
130    }
131}