1use std::{
2 convert::TryFrom,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use crate::Result;
8use crate::{r2::js_object, Error};
9use futures_util::FutureExt;
10use js_sys::{
11 Boolean as JsBoolean, Error as JsError, JsString, Number as JsNumber, Object as JsObject,
12 Reflect, Uint8Array,
13};
14use std::convert::TryInto;
15use std::io::Error as IoError;
16use std::io::Result as IoResult;
17use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
18use wasm_bindgen::{JsCast, JsValue};
19use wasm_bindgen_futures::JsFuture;
20use web_sys::{
21 ReadableStream, ReadableStreamDefaultReader, WritableStream, WritableStreamDefaultWriter,
22};
23
/// Addresses reported once a [`Socket`] connection has been opened.
///
/// Built from the `remoteAddress` and `localAddress` properties of the value
/// resolved by [`Socket::opened`]; a field is `None` when the corresponding
/// property is missing or is not a string.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SocketInfo {
    /// The `remoteAddress` property, if it was a string.
    pub remote_address: Option<String>,
    /// The `localAddress` property, if it was a string.
    pub local_address: Option<String>,
}
29
30impl TryFrom<JsValue> for SocketInfo {
31 type Error = Error;
32 fn try_from(value: JsValue) -> Result<Self> {
33 let remote_address_value =
34 js_sys::Reflect::get(&value, &JsValue::from_str("remoteAddress"))?;
35 let local_address_value = js_sys::Reflect::get(&value, &JsValue::from_str("localAddress"))?;
36 Ok(Self {
37 remote_address: remote_address_value.as_string(),
38 local_address: local_address_value.as_string(),
39 })
40 }
41}
42
43#[derive(Debug, Default)]
44enum Reading {
45 #[default]
46 None,
47 Pending(JsFuture, ReadableStreamDefaultReader),
48 Ready(Vec<u8>),
49}
50
51#[derive(Debug, Default)]
52enum Writing {
53 Pending(JsFuture, WritableStreamDefaultWriter, usize),
54 #[default]
55 None,
56}
57
58#[derive(Debug, Default)]
59enum Closing {
60 Pending(JsFuture),
61 #[default]
62 None,
63}
64
65#[derive(Debug)]
67pub struct Socket {
68 inner: worker_sys::Socket,
69 writable: WritableStream,
70 readable: ReadableStream,
71 write: Option<Writing>,
72 read: Option<Reading>,
73 close: Option<Closing>,
74}
75
76unsafe impl Send for Socket {}
78unsafe impl Sync for Socket {}
79
80impl Socket {
81 fn new(inner: worker_sys::Socket) -> Self {
82 let writable = inner.writable().unwrap();
83 let readable = inner.readable().unwrap();
84 Socket {
85 inner,
86 writable,
87 readable,
88 read: None,
89 write: None,
90 close: None,
91 }
92 }
93
94 pub async fn close(&mut self) -> Result<()> {
96 JsFuture::from(self.inner.close()?).await?;
97 Ok(())
98 }
99
100 pub async fn closed(&self) -> Result<()> {
103 JsFuture::from(self.inner.closed()?).await?;
104 Ok(())
105 }
106
107 pub async fn opened(&self) -> Result<SocketInfo> {
108 let value = JsFuture::from(self.inner.opened()?).await?;
109 value.try_into()
110 }
111
112 pub fn start_tls(self) -> Socket {
118 let inner = self.inner.start_tls().unwrap();
119 Socket::new(inner)
120 }
121
122 pub fn builder() -> ConnectionBuilder {
123 ConnectionBuilder::default()
124 }
125
126 fn handle_write_future(
127 cx: &mut Context<'_>,
128 mut fut: JsFuture,
129 writer: WritableStreamDefaultWriter,
130 len: usize,
131 ) -> (Writing, Poll<IoResult<usize>>) {
132 match fut.poll_unpin(cx) {
133 Poll::Pending => (Writing::Pending(fut, writer, len), Poll::Pending),
134 Poll::Ready(res) => {
135 writer.release_lock();
136 match res {
137 Ok(_) => (Writing::None, Poll::Ready(Ok(len))),
138 Err(e) => (Writing::None, Poll::Ready(Err(js_value_to_std_io_error(e)))),
139 }
140 }
141 }
142 }
143}
144
145fn js_value_to_std_io_error(value: JsValue) -> IoError {
146 let s = if value.is_string() {
147 value.as_string().unwrap()
148 } else if let Some(value) = value.dyn_ref::<JsError>() {
149 value.to_string().into()
150 } else {
151 format!("Error interpreting JsError: {value:?}")
152 };
153 IoError::other(s)
154}
155impl AsyncRead for Socket {
156 fn poll_read(
157 mut self: Pin<&mut Self>,
158 cx: &mut Context<'_>,
159 buf: &mut ReadBuf<'_>,
160 ) -> Poll<IoResult<()>> {
161 fn handle_future(
162 cx: &mut Context<'_>,
163 buf: &mut ReadBuf<'_>,
164 mut fut: JsFuture,
165 reader: ReadableStreamDefaultReader,
166 ) -> (Reading, Poll<IoResult<()>>) {
167 match fut.poll_unpin(cx) {
168 Poll::Pending => (Reading::Pending(fut, reader), Poll::Pending),
169 Poll::Ready(res) => match res {
170 Ok(value) => {
171 reader.release_lock();
172 let done: JsBoolean = match Reflect::get(&value, &JsValue::from("done")) {
173 Ok(value) => value.into(),
174 Err(error) => {
175 let msg = format!("Unable to interpret field 'done' in ReadableStreamDefaultReader.read(): {error:?}");
176 return (Reading::None, Poll::Ready(Err(IoError::other(msg))));
177 }
178 };
179 if done.is_truthy() {
180 (Reading::None, Poll::Ready(Ok(())))
181 } else {
182 let arr: Uint8Array = match Reflect::get(
183 &value,
184 &JsValue::from("value"),
185 ) {
186 Ok(value) => value.into(),
187 Err(error) => {
188 let msg = format!("Unable to interpret field 'value' in ReadableStreamDefaultReader.read(): {error:?}");
189 return (Reading::None, Poll::Ready(Err(IoError::other(msg))));
190 }
191 };
192 let data = arr.to_vec();
193 handle_data(buf, data)
194 }
195 }
196 Err(e) => (Reading::None, Poll::Ready(Err(js_value_to_std_io_error(e)))),
197 },
198 }
199 }
200
201 let (new_reading, poll) = match self.read.take().unwrap_or_default() {
202 Reading::None => {
203 let reader: ReadableStreamDefaultReader =
204 match self.readable.get_reader().dyn_into() {
205 Ok(reader) => reader,
206 Err(error) => {
207 let msg = format!(
208 "Unable to cast JsObject to ReadableStreamDefaultReader: {error:?}"
209 );
210 return Poll::Ready(Err(IoError::other(msg)));
211 }
212 };
213
214 handle_future(cx, buf, JsFuture::from(reader.read()), reader)
215 }
216 Reading::Pending(fut, reader) => handle_future(cx, buf, fut, reader),
217 Reading::Ready(data) => handle_data(buf, data),
218 };
219 self.read = Some(new_reading);
220 poll
221 }
222}
223
224impl AsyncWrite for Socket {
225 fn poll_write(
226 mut self: Pin<&mut Self>,
227 cx: &mut Context<'_>,
228 buf: &[u8],
229 ) -> Poll<IoResult<usize>> {
230 let (new_writing, poll) = match self.write.take().unwrap_or_default() {
231 Writing::None => {
232 let obj = JsValue::from(Uint8Array::from(buf));
233 let writer: WritableStreamDefaultWriter = match self.writable.get_writer() {
234 Ok(writer) => writer,
235 Err(error) => {
236 let msg = format!("Could not retrieve Writer: {error:?}");
237 return Poll::Ready(Err(IoError::other(msg)));
238 }
239 };
240 Self::handle_write_future(
241 cx,
242 JsFuture::from(writer.write_with_chunk(&obj)),
243 writer,
244 buf.len(),
245 )
246 }
247 Writing::Pending(fut, writer, len) => Self::handle_write_future(cx, fut, writer, len),
248 };
249 self.write = Some(new_writing);
250 poll
251 }
252
253 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
254 let (new_writing, poll) = match self.write.take().unwrap_or_default() {
256 Writing::Pending(fut, writer, len) => {
257 let (writing, poll) = Self::handle_write_future(cx, fut, writer, len);
258 (writing, poll.map(|res| res.map(|_| ())))
260 }
261 writing => (writing, Poll::Ready(Ok(()))),
262 };
263 self.write = Some(new_writing);
264 poll
265 }
266
267 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
268 fn handle_future(cx: &mut Context<'_>, mut fut: JsFuture) -> (Closing, Poll<IoResult<()>>) {
269 match fut.poll_unpin(cx) {
270 Poll::Pending => (Closing::Pending(fut), Poll::Pending),
271 Poll::Ready(res) => match res {
272 Ok(_) => (Closing::None, Poll::Ready(Ok(()))),
273 Err(e) => (Closing::None, Poll::Ready(Err(js_value_to_std_io_error(e)))),
274 },
275 }
276 }
277 let (new_closing, poll) = match self.close.take().unwrap_or_default() {
278 Closing::None => handle_future(cx, JsFuture::from(self.writable.close())),
279 Closing::Pending(fut) => handle_future(cx, fut),
280 };
281 self.close = Some(new_closing);
282 poll
283 }
284}
285
/// TLS mode requested when opening a [`Socket`], sent as the
/// `secureTransport` connect option.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SecureTransport {
    /// Sends `secureTransport: "off"` (no TLS).
    Off,
    /// Sends `secureTransport: "on"` (TLS requested).
    On,
    /// Sends `secureTransport: "starttls"`; presumably paired with
    /// [`Socket::start_tls`] to upgrade the connection later (confirm against
    /// the runtime's `connect()` documentation).
    StartTls,
}
297
298#[derive(Debug, Clone)]
300pub struct SocketOptions {
301 pub secure_transport: SecureTransport,
303 pub allow_half_open: bool,
308}
309
310impl Default for SocketOptions {
311 fn default() -> Self {
312 SocketOptions {
313 secure_transport: SecureTransport::Off,
314 allow_half_open: false,
315 }
316 }
317}
318
/// A host name and port pair identifying a remote endpoint.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SocketAddress {
    /// Host name or address to connect to.
    pub hostname: String,
    /// Port number.
    pub port: u16,
}
327
328#[derive(Default, Debug, Clone)]
329pub struct ConnectionBuilder {
330 options: SocketOptions,
331}
332
333impl ConnectionBuilder {
334 pub fn new() -> Self {
336 ConnectionBuilder {
337 options: SocketOptions::default(),
338 }
339 }
340
341 pub fn allow_half_open(mut self, allow_half_open: bool) -> Self {
344 self.options.allow_half_open = allow_half_open;
345 self
346 }
347
348 pub fn secure_transport(mut self, secure_transport: SecureTransport) -> Self {
350 self.options.secure_transport = secure_transport;
351 self
352 }
353
354 pub fn connect(self, hostname: impl Into<String>, port: u16) -> Result<Socket> {
356 let address: JsValue = js_object!(
357 "hostname" => JsObject::from(JsString::from(hostname.into())),
358 "port" => JsNumber::from(port)
359 )
360 .into();
361
362 let options: JsValue = js_object!(
363 "allowHalfOpen" => JsBoolean::from(self.options.allow_half_open),
364 "secureTransport" => JsString::from(match self.options.secure_transport {
365 SecureTransport::On => "on",
366 SecureTransport::Off => "off",
367 SecureTransport::StartTls => "starttls",
368 })
369 )
370 .into();
371
372 let inner = worker_sys::connect(address, options)?;
373 Ok(Socket::new(inner))
374 }
375}
376
377fn handle_data(buf: &mut ReadBuf<'_>, mut data: Vec<u8>) -> (Reading, Poll<IoResult<()>>) {
379 let idx = buf.remaining().min(data.len());
380 let store = data.split_off(idx);
381 buf.put_slice(&data);
382 if store.is_empty() {
383 (Reading::None, Poll::Ready(Ok(())))
384 } else {
385 (Reading::Ready(store), Poll::Ready(Ok(())))
386 }
387}
388
#[cfg(feature = "tokio-postgres")]
pub mod postgres_tls {
    //! `tokio-postgres` TLS glue that upgrades the existing [`Socket`] with
    //! [`Socket::start_tls`] instead of performing TLS in Rust.
    use super::Socket;
    use futures_util::future::{ready, Ready};
    use std::error::Error;
    use std::fmt::{self, Display, Formatter};
    use tokio_postgres::tls::{ChannelBinding, TlsConnect, TlsStream};

    /// A `TlsConnect` implementation that hands TLS off to the runtime by
    /// calling [`Socket::start_tls`]. NOTE(review): the socket presumably needs
    /// to have been opened with `SecureTransport::StartTls`; confirm.
    #[derive(Debug, Clone, Default)]
    pub struct PassthroughTls;

    /// Error type for [`PassthroughTls`]; no code in this module constructs it.
    #[derive(Debug)]
    pub struct PassthroughTlsError;

    impl Display for PassthroughTlsError {
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            f.write_str("PassthroughTlsError")
        }
    }

    impl Error for PassthroughTlsError {}

    impl TlsConnect<Socket> for PassthroughTls {
        type Stream = Socket;
        type Error = PassthroughTlsError;
        type Future = Ready<Result<Socket, PassthroughTlsError>>;

        /// Upgrades `s` via `start_tls` and resolves immediately.
        fn connect(self, s: Self::Stream) -> Self::Future {
            ready(Ok(s.start_tls()))
        }
    }

    impl TlsStream for Socket {
        /// Reports that no channel-binding data is available.
        fn channel_binding(&self) -> ChannelBinding {
            ChannelBinding::none()
        }
    }
}
443
#[cfg(test)]
mod tests {
    use super::*;

    /// A chunk exactly the size of the buffer is consumed entirely.
    #[test]
    fn test_handle_data() {
        let mut arr = vec![0u8; 32];
        let mut buf = ReadBuf::new(&mut arr);
        let data = vec![1u8; 32];
        let (reading, poll) = handle_data(&mut buf, data);

        assert!(matches!(reading, Reading::None));
        assert!(matches!(poll, Poll::Ready(Ok(()))));
        assert_eq!(buf.remaining(), 0);
        assert_eq!(buf.filled(), &[1u8; 32][..]);
    }

    /// A chunk larger than the buffer fills it and keeps the tail, in order.
    #[test]
    fn test_handle_large_data() {
        let mut arr = vec![0u8; 32];
        let mut buf = ReadBuf::new(&mut arr);
        let data: Vec<u8> = (0..64).collect();
        let (reading, _) = handle_data(&mut buf, data);

        assert_eq!(buf.remaining(), 0);
        assert_eq!(buf.filled(), &(0..32).collect::<Vec<u8>>()[..]);
        let Reading::Ready(store) = reading else {
            panic!("expected leftover bytes to be kept");
        };
        assert_eq!(store, (32..64).collect::<Vec<u8>>());
    }

    /// A chunk smaller than the buffer is copied in full with nothing kept.
    #[test]
    fn test_handle_small_data() {
        let mut arr = vec![0u8; 32];
        let mut buf = ReadBuf::new(&mut arr);
        let data = vec![1u8; 16];
        let (reading, _) = handle_data(&mut buf, data);

        assert!(matches!(reading, Reading::None));
        assert_eq!(buf.remaining(), 16);
        assert_eq!(buf.filled().len(), 16);
    }

    /// With no room in the buffer every byte is kept for the next read.
    #[test]
    fn test_handle_data_zero_capacity() {
        let mut arr: [u8; 0] = [];
        let mut buf = ReadBuf::new(&mut arr);
        let (reading, poll) = handle_data(&mut buf, vec![7u8; 4]);

        assert!(matches!(poll, Poll::Ready(Ok(()))));
        assert!(buf.filled().is_empty());
        let Reading::Ready(store) = reading else {
            panic!("expected all bytes to be kept");
        };
        assert_eq!(store, vec![7u8; 4]);
    }
}