1use std::{
2 pin::Pin,
3 task::{Context, Poll},
4};
5
6use futures_util::{Stream, TryStreamExt};
7use js_sys::{BigInt, Uint8Array};
8use pin_project::pin_project;
9use wasm_bindgen::{JsCast, JsValue};
10use wasm_streams::readable::IntoStream;
11use web_sys::ReadableStream;
12use worker_sys::FixedLengthStream as FixedLengthStreamSys;
13
14use crate::{Error, Result};
15
16#[pin_project]
17#[derive(Debug)]
18pub struct ByteStream {
19 #[pin]
20 pub(crate) inner: IntoStream<'static>,
21}
22
23impl Stream for ByteStream {
24 type Item = Result<Vec<u8>>;
25
26 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
27 let this = self.project();
28 let item = match futures_util::ready!(this.inner.poll_next(cx)) {
29 Some(res) => res.map(Uint8Array::from).map_err(Error::from),
30 None => return Poll::Ready(None),
31 };
32
33 Poll::Ready(match item {
34 Ok(value) => Some(Ok(value.to_vec())),
35 Err(e) if e.to_string() == "Error: aborted" => None,
36 Err(e) => Some(Err(e)),
37 })
38 }
39}
40
41#[pin_project]
42pub struct FixedLengthStream {
43 length: u64,
44 #[pin]
45 bytes_read: u64,
46 #[pin]
47 inner: Pin<Box<dyn Stream<Item = Result<Vec<u8>>> + 'static>>,
48}
49
50impl core::fmt::Debug for FixedLengthStream {
51 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
52 f.debug_struct("FixedLengthStream")
53 .field("length", &self.length)
54 .field("bytes_read", &self.bytes_read)
55 .finish()
56 }
57}
58
59impl FixedLengthStream {
60 pub fn wrap(stream: impl Stream<Item = Result<Vec<u8>>> + 'static, length: u64) -> Self {
61 Self {
62 length,
63 bytes_read: 0,
64 inner: Box::pin(stream),
65 }
66 }
67}
68
69impl Stream for FixedLengthStream {
70 type Item = Result<Vec<u8>>;
71
72 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
73 let mut this = self.project();
74 let item = if let Some(res) = futures_util::ready!(this.inner.poll_next(cx)) {
75 let chunk = match res {
76 Ok(chunk) => chunk,
77 Err(err) => return Poll::Ready(Some(Err(err))),
78 };
79
80 *this.bytes_read += chunk.len() as u64;
81
82 if *this.bytes_read > *this.length {
83 let err = Error::from(format!(
84 "fixed length stream had different length than expected (expected {}, got {})",
85 *this.length, *this.bytes_read,
86 ));
87 Some(Err(err))
88 } else {
89 Some(Ok(chunk))
90 }
91 } else if *this.bytes_read != *this.length {
92 let err = Error::from(format!(
93 "fixed length stream had different length than expected (expected {}, got {})",
94 *this.length, *this.bytes_read,
95 ));
96 Some(Err(err))
97 } else {
98 None
99 };
100
101 Poll::Ready(item)
102 }
103}
104
105impl From<FixedLengthStream> for FixedLengthStreamSys {
106 fn from(stream: FixedLengthStream) -> Self {
107 let raw = if stream.length < u32::MAX as u64 {
108 FixedLengthStreamSys::new(stream.length as u32).unwrap()
109 } else {
110 FixedLengthStreamSys::new_big_int(BigInt::from(stream.length)).unwrap()
111 };
112
113 let js_stream = stream
114 .map_ok(|item| -> Vec<u8> { item })
115 .map_ok(|chunk| {
116 let array = Uint8Array::new_with_length(chunk.len() as _);
117 array.copy_from(&chunk);
118
119 array.into()
120 })
121 .map_err(JsValue::from);
122
123 let stream: ReadableStream = wasm_streams::ReadableStream::from_stream(js_stream)
124 .as_raw()
125 .clone()
126 .unchecked_into();
127 let _ = stream.pipe_to(&raw.writable());
128
129 raw
130 }
131}