[go: up one dir, main page]

worker/
socket.rs

1use std::{
2    convert::TryFrom,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use crate::Result;
8use crate::{r2::js_object, Error};
9use futures_util::FutureExt;
10use js_sys::{
11    Boolean as JsBoolean, Error as JsError, JsString, Number as JsNumber, Object as JsObject,
12    Reflect, Uint8Array,
13};
14use std::convert::TryInto;
15use std::io::Error as IoError;
16use std::io::Result as IoResult;
17use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
18use wasm_bindgen::{JsCast, JsValue};
19use wasm_bindgen_futures::JsFuture;
20use web_sys::{
21    ReadableStream, ReadableStreamDefaultReader, WritableStream, WritableStreamDefaultWriter,
22};
23
24#[derive(Debug)]
25pub struct SocketInfo {
26    pub remote_address: Option<String>,
27    pub local_address: Option<String>,
28}
29
30impl TryFrom<JsValue> for SocketInfo {
31    type Error = Error;
32    fn try_from(value: JsValue) -> Result<Self> {
33        let remote_address_value =
34            js_sys::Reflect::get(&value, &JsValue::from_str("remoteAddress"))?;
35        let local_address_value = js_sys::Reflect::get(&value, &JsValue::from_str("localAddress"))?;
36        Ok(Self {
37            remote_address: remote_address_value.as_string(),
38            local_address: local_address_value.as_string(),
39        })
40    }
41}
42
43#[derive(Debug, Default)]
44enum Reading {
45    #[default]
46    None,
47    Pending(JsFuture, ReadableStreamDefaultReader),
48    Ready(Vec<u8>),
49}
50
51#[derive(Debug, Default)]
52enum Writing {
53    Pending(JsFuture, WritableStreamDefaultWriter, usize),
54    #[default]
55    None,
56}
57
58#[derive(Debug, Default)]
59enum Closing {
60    Pending(JsFuture),
61    #[default]
62    None,
63}
64
/// Represents an outbound TCP connection from your Worker.
#[derive(Debug)]
pub struct Socket {
    /// The raw runtime socket handle.
    inner: worker_sys::Socket,
    /// Writable stream obtained from `inner` at construction; used by `AsyncWrite`.
    writable: WritableStream,
    /// Readable stream obtained from `inner` at construction; used by `AsyncRead`.
    readable: ReadableStream,
    /// In-flight write state between `poll_write` / `poll_flush` calls.
    write: Option<Writing>,
    /// In-flight read state (pending promise or leftover bytes) between `poll_read` calls.
    read: Option<Reading>,
    /// In-flight close state between `poll_shutdown` calls.
    close: Option<Closing>,
}

// This can only be done because workers are single threaded.
// SAFETY: the JS handles held here are not thread-safe. These impls rely on the
// Workers runtime running this wasm module on a single thread, so a `Socket` is
// never actually accessed from two threads.
// NOTE(review): this would be unsound on a multi-threaded wasm target — confirm
// that no such target is supported.
unsafe impl Send for Socket {}
unsafe impl Sync for Socket {}
79
80impl Socket {
81    fn new(inner: worker_sys::Socket) -> Self {
82        let writable = inner.writable().unwrap();
83        let readable = inner.readable().unwrap();
84        Socket {
85            inner,
86            writable,
87            readable,
88            read: None,
89            write: None,
90            close: None,
91        }
92    }
93
94    /// Closes the TCP socket. Both the readable and writable streams are forcibly closed.
95    pub async fn close(&mut self) -> Result<()> {
96        JsFuture::from(self.inner.close()?).await?;
97        Ok(())
98    }
99
100    /// This Future is resolved when the socket is closed
101    /// and is rejected if the socket encounters an error.
102    pub async fn closed(&self) -> Result<()> {
103        JsFuture::from(self.inner.closed()?).await?;
104        Ok(())
105    }
106
107    pub async fn opened(&self) -> Result<SocketInfo> {
108        let value = JsFuture::from(self.inner.opened()?).await?;
109        value.try_into()
110    }
111
112    /// Upgrades an insecure socket to a secure one that uses TLS,
113    /// returning a new Socket. Note that in order to call this method,
114    /// you must set [`secure_transport`](SocketOptions::secure_transport)
115    /// to [`StartTls`](SecureTransport::StartTls) when initially
116    /// calling [`connect`](connect) to create the socket.
117    pub fn start_tls(self) -> Socket {
118        let inner = self.inner.start_tls().unwrap();
119        Socket::new(inner)
120    }
121
122    pub fn builder() -> ConnectionBuilder {
123        ConnectionBuilder::default()
124    }
125
126    fn handle_write_future(
127        cx: &mut Context<'_>,
128        mut fut: JsFuture,
129        writer: WritableStreamDefaultWriter,
130        len: usize,
131    ) -> (Writing, Poll<IoResult<usize>>) {
132        match fut.poll_unpin(cx) {
133            Poll::Pending => (Writing::Pending(fut, writer, len), Poll::Pending),
134            Poll::Ready(res) => {
135                writer.release_lock();
136                match res {
137                    Ok(_) => (Writing::None, Poll::Ready(Ok(len))),
138                    Err(e) => (Writing::None, Poll::Ready(Err(js_value_to_std_io_error(e)))),
139                }
140            }
141        }
142    }
143}
144
145fn js_value_to_std_io_error(value: JsValue) -> IoError {
146    let s = if value.is_string() {
147        value.as_string().unwrap()
148    } else if let Some(value) = value.dyn_ref::<JsError>() {
149        value.to_string().into()
150    } else {
151        format!("Error interpreting JsError: {value:?}")
152    };
153    IoError::other(s)
154}
155impl AsyncRead for Socket {
156    fn poll_read(
157        mut self: Pin<&mut Self>,
158        cx: &mut Context<'_>,
159        buf: &mut ReadBuf<'_>,
160    ) -> Poll<IoResult<()>> {
161        fn handle_future(
162            cx: &mut Context<'_>,
163            buf: &mut ReadBuf<'_>,
164            mut fut: JsFuture,
165            reader: ReadableStreamDefaultReader,
166        ) -> (Reading, Poll<IoResult<()>>) {
167            match fut.poll_unpin(cx) {
168                Poll::Pending => (Reading::Pending(fut, reader), Poll::Pending),
169                Poll::Ready(res) => match res {
170                    Ok(value) => {
171                        reader.release_lock();
172                        let done: JsBoolean = match Reflect::get(&value, &JsValue::from("done")) {
173                            Ok(value) => value.into(),
174                            Err(error) => {
175                                let msg = format!("Unable to interpret field 'done' in ReadableStreamDefaultReader.read(): {error:?}");
176                                return (Reading::None, Poll::Ready(Err(IoError::other(msg))));
177                            }
178                        };
179                        if done.is_truthy() {
180                            (Reading::None, Poll::Ready(Ok(())))
181                        } else {
182                            let arr: Uint8Array = match Reflect::get(
183                                &value,
184                                &JsValue::from("value"),
185                            ) {
186                                Ok(value) => value.into(),
187                                Err(error) => {
188                                    let msg = format!("Unable to interpret field 'value' in ReadableStreamDefaultReader.read(): {error:?}");
189                                    return (Reading::None, Poll::Ready(Err(IoError::other(msg))));
190                                }
191                            };
192                            let data = arr.to_vec();
193                            handle_data(buf, data)
194                        }
195                    }
196                    Err(e) => (Reading::None, Poll::Ready(Err(js_value_to_std_io_error(e)))),
197                },
198            }
199        }
200
201        let (new_reading, poll) = match self.read.take().unwrap_or_default() {
202            Reading::None => {
203                let reader: ReadableStreamDefaultReader =
204                    match self.readable.get_reader().dyn_into() {
205                        Ok(reader) => reader,
206                        Err(error) => {
207                            let msg = format!(
208                                "Unable to cast JsObject to ReadableStreamDefaultReader: {error:?}"
209                            );
210                            return Poll::Ready(Err(IoError::other(msg)));
211                        }
212                    };
213
214                handle_future(cx, buf, JsFuture::from(reader.read()), reader)
215            }
216            Reading::Pending(fut, reader) => handle_future(cx, buf, fut, reader),
217            Reading::Ready(data) => handle_data(buf, data),
218        };
219        self.read = Some(new_reading);
220        poll
221    }
222}
223
224impl AsyncWrite for Socket {
225    fn poll_write(
226        mut self: Pin<&mut Self>,
227        cx: &mut Context<'_>,
228        buf: &[u8],
229    ) -> Poll<IoResult<usize>> {
230        let (new_writing, poll) = match self.write.take().unwrap_or_default() {
231            Writing::None => {
232                let obj = JsValue::from(Uint8Array::from(buf));
233                let writer: WritableStreamDefaultWriter = match self.writable.get_writer() {
234                    Ok(writer) => writer,
235                    Err(error) => {
236                        let msg = format!("Could not retrieve Writer: {error:?}");
237                        return Poll::Ready(Err(IoError::other(msg)));
238                    }
239                };
240                Self::handle_write_future(
241                    cx,
242                    JsFuture::from(writer.write_with_chunk(&obj)),
243                    writer,
244                    buf.len(),
245                )
246            }
247            Writing::Pending(fut, writer, len) => Self::handle_write_future(cx, fut, writer, len),
248        };
249        self.write = Some(new_writing);
250        poll
251    }
252
253    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
254        // Poll existing write future if it exists.
255        let (new_writing, poll) = match self.write.take().unwrap_or_default() {
256            Writing::Pending(fut, writer, len) => {
257                let (writing, poll) = Self::handle_write_future(cx, fut, writer, len);
258                // Map poll output to ()
259                (writing, poll.map(|res| res.map(|_| ())))
260            }
261            writing => (writing, Poll::Ready(Ok(()))),
262        };
263        self.write = Some(new_writing);
264        poll
265    }
266
267    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
268        fn handle_future(cx: &mut Context<'_>, mut fut: JsFuture) -> (Closing, Poll<IoResult<()>>) {
269            match fut.poll_unpin(cx) {
270                Poll::Pending => (Closing::Pending(fut), Poll::Pending),
271                Poll::Ready(res) => match res {
272                    Ok(_) => (Closing::None, Poll::Ready(Ok(()))),
273                    Err(e) => (Closing::None, Poll::Ready(Err(js_value_to_std_io_error(e)))),
274                },
275            }
276        }
277        let (new_closing, poll) = match self.close.take().unwrap_or_default() {
278            Closing::None => handle_future(cx, JsFuture::from(self.writable.close())),
279            Closing::Pending(fut) => handle_future(cx, fut),
280        };
281        self.close = Some(new_closing);
282        poll
283    }
284}
285
/// Secure transport options for outbound TCP connections.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SecureTransport {
    /// Do not use TLS.
    Off,
    /// Use TLS.
    On,
    /// Do not use TLS initially, but allow the socket to be upgraded to
    /// use TLS by calling [`Socket.start_tls`](Socket::start_tls).
    StartTls,
}

/// Used to configure outbound TCP connections.
///
/// The default is no TLS ([`SecureTransport::Off`]) and `allow_half_open = false`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SocketOptions {
    /// Specifies whether or not to use TLS when creating the TCP socket.
    pub secure_transport: SecureTransport,
    /// Defines whether the writable side of the TCP socket will automatically
    /// close on end-of-file (EOF). When set to false, the writable side of the
    /// TCP socket will automatically close on EOF. When set to true, the
    /// writable side of the TCP socket will remain open on EOF.
    pub allow_half_open: bool,
}

impl Default for SocketOptions {
    /// Plain TCP (no TLS), with the writable side closing automatically on EOF.
    fn default() -> Self {
        SocketOptions {
            secure_transport: SecureTransport::Off,
            allow_half_open: false,
        }
    }
}
318
/// The host and port that you wish to connect to.
///
/// NOTE(review): `ConnectionBuilder::connect` takes the hostname and port as
/// separate arguments rather than this type.
#[derive(Debug, Clone)]
pub struct SocketAddress {
    /// The hostname to connect to. Example: `cloudflare.com`.
    pub hostname: String,
    /// The port number to connect to. Example: `5432`.
    pub port: u16,
}
327
328#[derive(Default, Debug, Clone)]
329pub struct ConnectionBuilder {
330    options: SocketOptions,
331}
332
333impl ConnectionBuilder {
334    /// Create a new `ConnectionBuilder` with default settings.
335    pub fn new() -> Self {
336        ConnectionBuilder {
337            options: SocketOptions::default(),
338        }
339    }
340
341    /// Set whether the writable side of the TCP socket will automatically
342    /// close on end-of-file (EOF).
343    pub fn allow_half_open(mut self, allow_half_open: bool) -> Self {
344        self.options.allow_half_open = allow_half_open;
345        self
346    }
347
348    // Specify whether or not to use TLS when creating the TCP socket.
349    pub fn secure_transport(mut self, secure_transport: SecureTransport) -> Self {
350        self.options.secure_transport = secure_transport;
351        self
352    }
353
354    /// Open the connection to `hostname` on port `port`, returning a [`Socket`](Socket).
355    pub fn connect(self, hostname: impl Into<String>, port: u16) -> Result<Socket> {
356        let address: JsValue = js_object!(
357            "hostname" => JsObject::from(JsString::from(hostname.into())),
358            "port" => JsNumber::from(port)
359        )
360        .into();
361
362        let options: JsValue = js_object!(
363            "allowHalfOpen" => JsBoolean::from(self.options.allow_half_open),
364            "secureTransport" => JsString::from(match self.options.secure_transport {
365                SecureTransport::On => "on",
366                SecureTransport::Off => "off",
367                SecureTransport::StartTls => "starttls",
368            })
369        )
370        .into();
371
372        let inner = worker_sys::connect(address, options)?;
373        Ok(Socket::new(inner))
374    }
375}
376
377// Writes as much as possible to buf, and stores the rest in internal buffer
378fn handle_data(buf: &mut ReadBuf<'_>, mut data: Vec<u8>) -> (Reading, Poll<IoResult<()>>) {
379    let idx = buf.remaining().min(data.len());
380    let store = data.split_off(idx);
381    buf.put_slice(&data);
382    if store.is_empty() {
383        (Reading::None, Poll::Ready(Ok(())))
384    } else {
385        (Reading::Ready(store), Poll::Ready(Ok(())))
386    }
387}
388
#[cfg(feature = "tokio-postgres")]
/// Implements [`TlsConnect`](tokio_postgres::tls::TlsConnect) for
/// [`Socket`](crate::Socket) to enable `tokio_postgres` connections
/// to databases using TLS.
pub mod postgres_tls {
    use super::Socket;
    use futures_util::future::{ready, Ready};
    use std::error::Error;
    use std::fmt::{self, Display, Formatter};
    use tokio_postgres::tls::{ChannelBinding, TlsConnect, TlsStream};

    /// Supply this to `connect_raw` in place of `NoTls` to specify TLS
    /// when using Workers.
    ///
    /// The socket must be created with `SecureTransport::StartTls`; the TLS
    /// upgrade itself is performed by the Workers runtime via `Socket::start_tls`.
    ///
    /// ```rust,ignore
    /// let config = tokio_postgres::config::Config::new();
    /// let socket = Socket::builder()
    ///     .secure_transport(SecureTransport::StartTls)
    ///     .connect("database_url", 5432)?;
    /// let _ = config.connect_raw(socket, PassthroughTls).await?;
    /// ```
    #[derive(Debug, Clone, Default)]
    pub struct PassthroughTls;

    #[derive(Debug)]
    /// Error type for PassthroughTls.
    /// Should never be returned.
    pub struct PassthroughTlsError;

    impl Error for PassthroughTlsError {}

    impl Display for PassthroughTlsError {
        fn fmt(&self, fmt: &mut Formatter<'_>) -> fmt::Result {
            fmt.write_str("PassthroughTlsError")
        }
    }

    impl TlsConnect<Socket> for PassthroughTls {
        type Stream = Socket;
        type Error = PassthroughTlsError;
        type Future = Ready<Result<Socket, PassthroughTlsError>>;

        /// Upgrades `s` in place via `Socket::start_tls` and yields the new socket.
        ///
        /// Never returns `Err`; if the runtime rejects the upgrade,
        /// `Socket::start_tls` panics instead.
        fn connect(self, s: Self::Stream) -> Self::Future {
            let tls = s.start_tls();
            ready(Ok(tls))
        }
    }

    impl TlsStream for Socket {
        /// Channel binding is not available through the Workers socket API.
        fn channel_binding(&self) -> ChannelBinding {
            ChannelBinding::none()
        }
    }
}
443
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `data_len` bytes of `1u8` into a fresh 32-byte `ReadBuf` through
    /// `handle_data`, returning the new state plus `(remaining, filled)` lengths.
    fn feed(data_len: usize) -> (Reading, usize, usize) {
        let mut storage = vec![0u8; 32];
        let mut buf = ReadBuf::new(&mut storage);
        let (reading, _) = handle_data(&mut buf, vec![1u8; data_len]);
        (reading, buf.remaining(), buf.filled().len())
    }

    #[test]
    fn test_handle_data() {
        let (reading, remaining, filled) = feed(32);
        assert!(matches!(reading, Reading::None));
        assert_eq!(remaining, 0);
        assert_eq!(filled, 32);
    }

    #[test]
    fn test_handle_large_data() {
        let (reading, remaining, filled) = feed(64);
        assert!(matches!(reading, Reading::Ready(store) if store.len() == 32));
        assert_eq!(remaining, 0);
        assert_eq!(filled, 32);
    }

    #[test]
    fn test_handle_small_data() {
        let (reading, remaining, filled) = feed(16);
        assert!(matches!(reading, Reading::None));
        assert_eq!(remaining, 16);
        assert_eq!(filled, 16);
    }
}