[go: up one dir, main page]

worker/
websocket.rs

1use crate::{Error, Method, Request, Result};
2use futures_channel::mpsc::UnboundedReceiver;
3use futures_util::Stream;
4use js_sys::Uint8Array;
5use serde::Serialize;
6use url::Url;
7use worker_sys::ext::WebSocketExt;
8
9#[cfg(not(feature = "http"))]
10use crate::Fetch;
11use std::pin::Pin;
12use std::rc::Rc;
13use std::task::{Context, Poll};
14use wasm_bindgen::convert::FromWasmAbi;
15use wasm_bindgen::prelude::Closure;
16use wasm_bindgen::JsCast;
17#[cfg(feature = "http")]
18use wasm_bindgen_futures::JsFuture;
19
20pub use crate::ws_events::*;
21pub use worker_sys::WebSocketRequestResponsePair;
22
/// Struct holding the values for a JavaScript `WebSocketPair`.
///
/// Both ends are obtained from a `worker_sys::WebSocketPair` in [`WebSocketPair::new`] and
/// wrapped in [`WebSocket`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WebSocketPair {
    /// The `client` end of the pair.
    pub client: WebSocket,
    /// The `server` end of the pair.
    pub server: WebSocket,
}

// NOTE(review): both fields wrap JS objects (`web_sys::WebSocket`), which are not `Send`/`Sync`
// on their own — that is why these impls have to be `unsafe`. They presumably rely on the
// wasm/Workers runtime running this code on a single thread; confirm before depending on it.
unsafe impl Send for WebSocketPair {}
unsafe impl Sync for WebSocketPair {}
32
33impl WebSocketPair {
34    /// Creates a new `WebSocketPair`.
35    pub fn new() -> Result<Self> {
36        let mut pair = worker_sys::WebSocketPair::new()?;
37        let client = pair.client()?.into();
38        let server = pair.server()?.into();
39        Ok(Self { client, server })
40    }
41}
42
/// Wrapper struct for underlying worker-sys `WebSocket`.
///
/// Holds a `web_sys::WebSocket` and exposes helpers for accepting, sending, closing and
/// streaming events. Use `From<web_sys::WebSocket>` to wrap a raw socket and
/// `AsRef<web_sys::WebSocket>` to borrow it back.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WebSocket {
    socket: web_sys::WebSocket,
}

// NOTE(review): `web_sys::WebSocket` is a JS handle that is not `Send`/`Sync` by itself, hence
// the `unsafe`. These impls presumably rely on the wasm/Workers runtime being single-threaded;
// confirm before depending on them.
unsafe impl Send for WebSocket {}
unsafe impl Sync for WebSocket {}
51
52impl WebSocket {
53    /// Attempts to establish a [`WebSocket`] connection to the provided [`Url`].
54    ///
55    /// # Example:
56    /// ```rust,ignore
57    /// let ws = WebSocket::connect("wss://echo.zeb.workers.dev/".parse()?).await?;
58    ///
59    /// // It's important that we call this before we send our first message, otherwise we will
60    /// // not have any event listeners on the socket to receive the echoed message.
61    /// let mut event_stream = ws.events()?;
62    ///
63    /// ws.accept()?;
64    /// ws.send_with_str("Hello, world!")?;
65    ///
66    /// while let Some(event) = event_stream.next().await {
67    ///     let event = event?;
68    ///
69    ///     if let WebsocketEvent::Message(msg) = event {
70    ///         if let Some(text) = msg.text() {
71    ///             return Response::ok(text);
72    ///         }
73    ///     }
74    /// }
75    ///
76    /// Response::error("never got a message echoed back :(", 500)
77    /// ```
78    pub async fn connect(url: Url) -> Result<WebSocket> {
79        WebSocket::connect_with_protocols(url, None).await
80    }
81
82    /// Attempts to establish a [`WebSocket`] connection to the provided [`Url`] and protocol.
83    ///
84    /// # Example:
85    /// ```rust,ignore
86    /// let ws = WebSocket::connect_with_protocols("wss://echo.zeb.workers.dev/".parse()?, Some(vec!["GiggleBytes"])).await?;
87    ///
88    /// ```
89    pub async fn connect_with_protocols(
90        mut url: Url,
91        protocols: Option<Vec<&str>>,
92    ) -> Result<WebSocket> {
93        let scheme: String = match url.scheme() {
94            "ws" => "http".into(),
95            "wss" => "https".into(),
96            scheme => scheme.into(),
97        };
98
99        // With fetch we can only make requests to http(s) urls, but Workers will allow us to upgrade
100        // those connections into websockets if we use the `Upgrade` header.
101        url.set_scheme(&scheme).unwrap();
102
103        let mut req = Request::new(url.as_str(), Method::Get)?;
104        req.headers_mut()?.set("upgrade", "websocket")?;
105
106        match protocols {
107            None => {}
108            Some(v) => {
109                req.headers_mut()?
110                    .set("Sec-WebSocket-Protocol", v.join(",").as_str())?;
111            }
112        }
113
114        #[cfg(not(feature = "http"))]
115        let res = Fetch::Request(req).send().await?;
116        #[cfg(feature = "http")]
117        let res: crate::Response = fetch_with_request_raw(req).await?.into();
118
119        match res.websocket() {
120            Some(ws) => Ok(ws),
121            None => Err(Error::RustError("server did not accept".into())),
122        }
123    }
124
125    /// Accepts the connection, allowing for messages to be sent to and from the `WebSocket`.
126    pub fn accept(&self) -> Result<()> {
127        self.socket.accept().map_err(Error::from)
128    }
129
130    /// Serialize data into a string using serde and send it through the `WebSocket`
131    pub fn send<T: Serialize>(&self, data: &T) -> Result<()> {
132        let value = serde_json::to_string(data)?;
133        self.send_with_str(value.as_str())
134    }
135
136    /// Sends a raw string through the `WebSocket`
137    pub fn send_with_str<S: AsRef<str>>(&self, data: S) -> Result<()> {
138        self.socket
139            .send_with_str(data.as_ref())
140            .map_err(Error::from)
141    }
142
143    /// Sends raw binary data through the `WebSocket`.
144    pub fn send_with_bytes<D: AsRef<[u8]>>(&self, bytes: D) -> Result<()> {
145        // This clone to Uint8Array must happen, because workerd
146        // will not clone the supplied buffer and will send it asynchronously.
147        // Rust believes that the lifetime ends when `send` returns, and frees
148        // the memory, causing corruption.
149        let uint8_array = Uint8Array::from(bytes.as_ref());
150        self.socket.send_with_array_buffer(&uint8_array.buffer())?;
151        Ok(())
152    }
153
154    /// Closes this channel.
155    /// This method translates to three different underlying method calls based of the
156    /// parameters passed.
157    ///
158    /// If the following parameters are Some:
159    /// * `code` and `reason` -> `close_with_code_and_reason`
160    /// * `code`              -> `close_with_code`
161    /// * `reason` or `none`  -> `close`
162    ///
163    /// Effectively, if only `reason` is `Some`, the `reason` argument will be ignored.
164    pub fn close<S: AsRef<str>>(&self, code: Option<u16>, reason: Option<S>) -> Result<()> {
165        if let Some((code, reason)) = code.zip(reason) {
166            self.socket
167                .close_with_code_and_reason(code, reason.as_ref())
168        } else if let Some(code) = code {
169            self.socket.close_with_code(code)
170        } else {
171            self.socket.close()
172        }
173        .map_err(Error::from)
174    }
175
176    /// Internal utility method to avoid verbose code.
177    /// This method registers a closure in the underlying JS environment, which calls back into the
178    /// Rust/wasm environment.
179    ///
180    /// Since this is a 'long living closure', we need to keep the lifetime of this closure until
181    /// we remove any references to the closure. So the caller of this must not drop the closure
182    /// until they call [`Self::remove_event_handler`].
183    fn add_event_handler<T: FromWasmAbi + 'static, F: FnMut(T) + 'static>(
184        &self,
185        r#type: &str,
186        fun: F,
187    ) -> Result<Closure<dyn FnMut(T)>> {
188        let js_callback = Closure::wrap(Box::new(fun) as Box<dyn FnMut(T)>);
189        self.socket
190            .add_event_listener_with_callback(r#type, js_callback.as_ref().unchecked_ref())
191            .map_err(Error::from)?;
192
193        Ok(js_callback)
194    }
195
196    /// Internal utility method to avoid verbose code.
197    /// This method registers a closure in the underlying JS environment, which calls back
198    /// into the Rust/wasm environment.
199    fn remove_event_handler<T: FromWasmAbi + 'static>(
200        &self,
201        r#type: &str,
202        js_callback: Closure<dyn FnMut(T)>,
203    ) -> Result<()> {
204        self.socket
205            .remove_event_listener_with_callback(r#type, js_callback.as_ref().unchecked_ref())
206            .map_err(Error::from)
207    }
208
209    /// Gets an implementation [`Stream`](futures::Stream) that yields events from the inner
210    /// WebSocket.
211    pub fn events(&self) -> Result<EventStream<'_>> {
212        let (tx, rx) = futures_channel::mpsc::unbounded::<Result<WebsocketEvent>>();
213        let tx = Rc::new(tx);
214
215        let close_closure = self.add_event_handler("close", {
216            let tx = tx.clone();
217            move |event: web_sys::CloseEvent| {
218                tx.unbounded_send(Ok(WebsocketEvent::Close(event.into())))
219                    .unwrap();
220            }
221        })?;
222        let message_closure = self.add_event_handler("message", {
223            let tx = tx.clone();
224            move |event: web_sys::MessageEvent| {
225                tx.unbounded_send(Ok(WebsocketEvent::Message(event.into())))
226                    .unwrap();
227            }
228        })?;
229        let error_closure =
230            self.add_event_handler("error", move |event: web_sys::ErrorEvent| {
231                let error = event.error();
232                tx.unbounded_send(Err(error.into())).unwrap();
233            })?;
234
235        Ok(EventStream {
236            ws: self,
237            rx,
238            closed: false,
239            closures: Some((message_closure, error_closure, close_closure)),
240        })
241    }
242
243    pub fn serialize_attachment<T: Serialize>(&self, value: T) -> Result<()> {
244        self.socket
245            .serialize_attachment(serde_wasm_bindgen::to_value(&value)?)
246            .map_err(Error::from)
247    }
248
249    pub fn deserialize_attachment<T: serde::de::DeserializeOwned>(&self) -> Result<Option<T>> {
250        let value = self.socket.deserialize_attachment().map_err(Error::from)?;
251
252        if value.is_null() || value.is_undefined() {
253            return Ok(None);
254        }
255
256        serde_wasm_bindgen::from_value::<T>(value)
257            .map(Some)
258            .map_err(Error::from)
259    }
260}
261
/// A wasm-bindgen closure that JS invokes as an event listener with a single event argument.
type EvCallback<T> = Closure<dyn FnMut(T)>;

/// A [`Stream`](futures_util::Stream) that yields [`WebsocketEvent`](crate::ws_events::WebsocketEvent)s
/// emitted by the inner [`WebSocket`](crate::WebSocket). Once a `WebsocketEvent::Close` has been
/// yielded the stream is finished and every later poll returns `None`. Dropping the stream
/// unregisters the listeners created by [`WebSocket::events`].
///
/// # Example
/// ```rust,ignore
/// use futures::StreamExt;
///
/// let pair = WebSocketPair::new()?;
/// let server = pair.server;
///
/// server.accept()?;
///
/// // Spawn a future for handling the stream of events from the websocket.
/// wasm_bindgen_futures::spawn_local(async move {
///     let mut event_stream = server.events().expect("could not open stream");
///
///     while let Some(event) = event_stream.next().await {
///         match event.expect("received error in websocket") {
///             WebsocketEvent::Message(msg) => console_log!("{:#?}", msg),
///             WebsocketEvent::Close(event) => console_log!("Closed!"),
///         }
///     }
/// });
/// ```
#[pin_project::pin_project(PinnedDrop)]
#[derive(Debug)]
pub struct EventStream<'ws> {
    /// The socket the listeners were registered on; used again to unregister them on drop.
    ws: &'ws WebSocket,
    /// Receives the events pushed by the registered listener closures.
    #[pin]
    rx: UnboundedReceiver<Result<WebsocketEvent>>,
    /// Set once a close event has been yielded; the stream then only returns `None`.
    closed: bool,
    /// The closures backing the registered `message`, `error` and `close` listeners. They are
    /// kept alive here and taken out when the stream is dropped, so that the listeners can be
    /// removed from the websocket.
    closures: Option<(
        EvCallback<web_sys::MessageEvent>,
        EvCallback<web_sys::ErrorEvent>,
        EvCallback<web_sys::CloseEvent>,
    )>,
}
304
305impl Stream for EventStream<'_> {
306    type Item = Result<WebsocketEvent>;
307
308    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
309        let this = self.project();
310
311        if *this.closed {
312            return Poll::Ready(None);
313        }
314
315        // Poll the inner receiver to check if theres any events from our event callbacks.
316        let item = futures_util::ready!(this.rx.poll_next(cx));
317
318        // Mark the stream as closed if we get a close event and yield None next iteration.
319        if let Some(item) = &item {
320            if matches!(&item, Ok(WebsocketEvent::Close(_))) {
321                *this.closed = true;
322            }
323        }
324
325        Poll::Ready(item)
326    }
327}
328
329// Because we don't want to receive messages once our stream is done we need to remove all our
330// listeners when we go to drop the stream.
331#[pin_project::pinned_drop]
332impl PinnedDrop for EventStream<'_> {
333    fn drop(self: Pin<&'_ mut Self>) {
334        let this = self.project();
335
336        // remove_event_handler takes an owned closure, so we'll do this little hack and wrap our
337        // closures in an Option. This should never panic because we should never call drop twice.
338        let (message_closure, error_closure, close_closure) =
339            std::mem::take(this.closures).expect("double drop on worker::EventStream");
340
341        this.ws
342            .remove_event_handler("message", message_closure)
343            .expect("could not remove message handler");
344        this.ws
345            .remove_event_handler("error", error_closure)
346            .expect("could not remove error handler");
347        this.ws
348            .remove_event_handler("close", close_closure)
349            .expect("could not remove close handler");
350    }
351}
352
353impl From<web_sys::WebSocket> for WebSocket {
354    fn from(socket: web_sys::WebSocket) -> Self {
355        Self { socket }
356    }
357}
358
359impl AsRef<web_sys::WebSocket> for WebSocket {
360    fn as_ref(&self) -> &web_sys::WebSocket {
361        &self.socket
362    }
363}
364
365pub mod ws_events {
366    use serde::de::DeserializeOwned;
367    use wasm_bindgen::JsValue;
368
369    use crate::Error;
370
371    /// Events that can be yielded by a [`EventStream`](crate::EventStream).
372    #[derive(Debug, Clone)]
373    pub enum WebsocketEvent {
374        Message(MessageEvent),
375        Close(CloseEvent),
376    }
377
378    /// Wrapper/Utility struct for the `web_sys::MessageEvent`
379    #[derive(Debug, Clone, PartialEq, Eq)]
380    pub struct MessageEvent {
381        event: web_sys::MessageEvent,
382    }
383
384    impl From<web_sys::MessageEvent> for MessageEvent {
385        fn from(event: web_sys::MessageEvent) -> Self {
386            Self { event }
387        }
388    }
389
390    impl AsRef<web_sys::MessageEvent> for MessageEvent {
391        fn as_ref(&self) -> &web_sys::MessageEvent {
392            &self.event
393        }
394    }
395
396    impl MessageEvent {
397        /// Gets the data/payload from the message.
398        fn data(&self) -> JsValue {
399            self.event.data()
400        }
401
402        pub fn text(&self) -> Option<String> {
403            let value = self.data();
404            value.as_string()
405        }
406
407        pub fn bytes(&self) -> Option<Vec<u8>> {
408            let value = self.data();
409            if value.is_object() {
410                Some(js_sys::Uint8Array::new(&value).to_vec())
411            } else {
412                None
413            }
414        }
415
416        pub fn json<T: DeserializeOwned>(&self) -> crate::Result<T> {
417            let text = match self.text() {
418                Some(text) => text,
419                None => return Err(Error::from("data of message event is not text")),
420            };
421
422            serde_json::from_str(&text).map_err(Error::from)
423        }
424    }
425
426    /// Wrapper/Utility struct for the `web_sys::CloseEvent`
427    #[derive(Debug, Clone, PartialEq, Eq)]
428    pub struct CloseEvent {
429        event: web_sys::CloseEvent,
430    }
431
432    impl CloseEvent {
433        pub fn reason(&self) -> String {
434            self.event.reason()
435        }
436
437        pub fn code(&self) -> u16 {
438            self.event.code()
439        }
440
441        pub fn was_clean(&self) -> bool {
442            self.event.was_clean()
443        }
444    }
445
446    impl From<web_sys::CloseEvent> for CloseEvent {
447        fn from(event: web_sys::CloseEvent) -> Self {
448            Self { event }
449        }
450    }
451
452    impl AsRef<web_sys::CloseEvent> for CloseEvent {
453        fn as_ref(&self) -> &web_sys::CloseEvent {
454            &self.event
455        }
456    }
457}
458
/// Sends `request` through the global `fetch` and returns the raw `web_sys::Response`.
///
/// TODO: Convert WebSocket to use `http` types and `reqwest`.
#[cfg(feature = "http")]
async fn fetch_with_request_raw(request: crate::Request) -> Result<web_sys::Response> {
    // The JS global is confined to this inner scope so it is not held across the `.await`
    // below; only the `SendFuture` wrapper is (presumably to keep this future `Send`).
    let pending = {
        let global_scope = js_sys::global().unchecked_into::<web_sys::WorkerGlobalScope>();
        crate::send::SendFuture::new(JsFuture::from(
            global_scope.fetch_with_request(request.inner()),
        ))
    };
    let response = pending.await?;
    response.dyn_into().map_err(Error::from)
}
469}