1use crate::{Error, Method, Request, Result};
2use futures_channel::mpsc::UnboundedReceiver;
3use futures_util::Stream;
4use js_sys::Uint8Array;
5use serde::Serialize;
6use url::Url;
7use worker_sys::ext::WebSocketExt;
8
9#[cfg(not(feature = "http"))]
10use crate::Fetch;
11use std::pin::Pin;
12use std::rc::Rc;
13use std::task::{Context, Poll};
14use wasm_bindgen::convert::FromWasmAbi;
15use wasm_bindgen::prelude::Closure;
16use wasm_bindgen::JsCast;
17#[cfg(feature = "http")]
18use wasm_bindgen_futures::JsFuture;
19
20pub use crate::ws_events::*;
21pub use worker_sys::WebSocketRequestResponsePair;
22
23#[derive(Debug, Clone, PartialEq, Eq)]
25pub struct WebSocketPair {
26 pub client: WebSocket,
27 pub server: WebSocket,
28}
29
30unsafe impl Send for WebSocketPair {}
31unsafe impl Sync for WebSocketPair {}
32
33impl WebSocketPair {
34 pub fn new() -> Result<Self> {
36 let mut pair = worker_sys::WebSocketPair::new()?;
37 let client = pair.client()?.into();
38 let server = pair.server()?.into();
39 Ok(Self { client, server })
40 }
41}
42
43#[derive(Debug, Clone, PartialEq, Eq)]
45pub struct WebSocket {
46 socket: web_sys::WebSocket,
47}
48
49unsafe impl Send for WebSocket {}
50unsafe impl Sync for WebSocket {}
51
52impl WebSocket {
53 pub async fn connect(url: Url) -> Result<WebSocket> {
79 WebSocket::connect_with_protocols(url, None).await
80 }
81
82 pub async fn connect_with_protocols(
90 mut url: Url,
91 protocols: Option<Vec<&str>>,
92 ) -> Result<WebSocket> {
93 let scheme: String = match url.scheme() {
94 "ws" => "http".into(),
95 "wss" => "https".into(),
96 scheme => scheme.into(),
97 };
98
99 url.set_scheme(&scheme).unwrap();
102
103 let mut req = Request::new(url.as_str(), Method::Get)?;
104 req.headers_mut()?.set("upgrade", "websocket")?;
105
106 match protocols {
107 None => {}
108 Some(v) => {
109 req.headers_mut()?
110 .set("Sec-WebSocket-Protocol", v.join(",").as_str())?;
111 }
112 }
113
114 #[cfg(not(feature = "http"))]
115 let res = Fetch::Request(req).send().await?;
116 #[cfg(feature = "http")]
117 let res: crate::Response = fetch_with_request_raw(req).await?.into();
118
119 match res.websocket() {
120 Some(ws) => Ok(ws),
121 None => Err(Error::RustError("server did not accept".into())),
122 }
123 }
124
125 pub fn accept(&self) -> Result<()> {
127 self.socket.accept().map_err(Error::from)
128 }
129
130 pub fn send<T: Serialize>(&self, data: &T) -> Result<()> {
132 let value = serde_json::to_string(data)?;
133 self.send_with_str(value.as_str())
134 }
135
136 pub fn send_with_str<S: AsRef<str>>(&self, data: S) -> Result<()> {
138 self.socket
139 .send_with_str(data.as_ref())
140 .map_err(Error::from)
141 }
142
143 pub fn send_with_bytes<D: AsRef<[u8]>>(&self, bytes: D) -> Result<()> {
145 let uint8_array = Uint8Array::from(bytes.as_ref());
150 self.socket.send_with_array_buffer(&uint8_array.buffer())?;
151 Ok(())
152 }
153
154 pub fn close<S: AsRef<str>>(&self, code: Option<u16>, reason: Option<S>) -> Result<()> {
165 if let Some((code, reason)) = code.zip(reason) {
166 self.socket
167 .close_with_code_and_reason(code, reason.as_ref())
168 } else if let Some(code) = code {
169 self.socket.close_with_code(code)
170 } else {
171 self.socket.close()
172 }
173 .map_err(Error::from)
174 }
175
176 fn add_event_handler<T: FromWasmAbi + 'static, F: FnMut(T) + 'static>(
184 &self,
185 r#type: &str,
186 fun: F,
187 ) -> Result<Closure<dyn FnMut(T)>> {
188 let js_callback = Closure::wrap(Box::new(fun) as Box<dyn FnMut(T)>);
189 self.socket
190 .add_event_listener_with_callback(r#type, js_callback.as_ref().unchecked_ref())
191 .map_err(Error::from)?;
192
193 Ok(js_callback)
194 }
195
196 fn remove_event_handler<T: FromWasmAbi + 'static>(
200 &self,
201 r#type: &str,
202 js_callback: Closure<dyn FnMut(T)>,
203 ) -> Result<()> {
204 self.socket
205 .remove_event_listener_with_callback(r#type, js_callback.as_ref().unchecked_ref())
206 .map_err(Error::from)
207 }
208
209 pub fn events(&self) -> Result<EventStream<'_>> {
212 let (tx, rx) = futures_channel::mpsc::unbounded::<Result<WebsocketEvent>>();
213 let tx = Rc::new(tx);
214
215 let close_closure = self.add_event_handler("close", {
216 let tx = tx.clone();
217 move |event: web_sys::CloseEvent| {
218 tx.unbounded_send(Ok(WebsocketEvent::Close(event.into())))
219 .unwrap();
220 }
221 })?;
222 let message_closure = self.add_event_handler("message", {
223 let tx = tx.clone();
224 move |event: web_sys::MessageEvent| {
225 tx.unbounded_send(Ok(WebsocketEvent::Message(event.into())))
226 .unwrap();
227 }
228 })?;
229 let error_closure =
230 self.add_event_handler("error", move |event: web_sys::ErrorEvent| {
231 let error = event.error();
232 tx.unbounded_send(Err(error.into())).unwrap();
233 })?;
234
235 Ok(EventStream {
236 ws: self,
237 rx,
238 closed: false,
239 closures: Some((message_closure, error_closure, close_closure)),
240 })
241 }
242
243 pub fn serialize_attachment<T: Serialize>(&self, value: T) -> Result<()> {
244 self.socket
245 .serialize_attachment(serde_wasm_bindgen::to_value(&value)?)
246 .map_err(Error::from)
247 }
248
249 pub fn deserialize_attachment<T: serde::de::DeserializeOwned>(&self) -> Result<Option<T>> {
250 let value = self.socket.deserialize_attachment().map_err(Error::from)?;
251
252 if value.is_null() || value.is_undefined() {
253 return Ok(None);
254 }
255
256 serde_wasm_bindgen::from_value::<T>(value)
257 .map(Some)
258 .map_err(Error::from)
259 }
260}
261
/// A JS-callable closure taking a single event argument of type `T`.
type EvCallback<T> = Closure<dyn FnMut(T)>;

/// A stream of events from a [`WebSocket`], created by `WebSocket::events`.
///
/// Message and close events are yielded as `Ok(_)`, and error events as `Err(_)`. After a
/// close event has been yielded, the stream returns `None`. Dropping the stream removes the
/// event listeners it holds.
#[pin_project::pin_project(PinnedDrop)]
#[derive(Debug)]
pub struct EventStream<'ws> {
    // Socket the listeners were registered on; used again to remove them on drop.
    ws: &'ws WebSocket,
    // Receiving end of the channel that the registered listeners send into.
    #[pin]
    rx: UnboundedReceiver<Result<WebsocketEvent>>,
    // Set once a close event has been yielded; afterwards the stream yields `None`.
    closed: bool,
    // The registered listener closures (message, error, close), kept alive while registered.
    // Taken, leaving `None`, when the stream is dropped.
    closures: Option<(
        EvCallback<web_sys::MessageEvent>,
        EvCallback<web_sys::ErrorEvent>,
        EvCallback<web_sys::CloseEvent>,
    )>,
}
304
305impl Stream for EventStream<'_> {
306 type Item = Result<WebsocketEvent>;
307
308 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
309 let this = self.project();
310
311 if *this.closed {
312 return Poll::Ready(None);
313 }
314
315 let item = futures_util::ready!(this.rx.poll_next(cx));
317
318 if let Some(item) = &item {
320 if matches!(&item, Ok(WebsocketEvent::Close(_))) {
321 *this.closed = true;
322 }
323 }
324
325 Poll::Ready(item)
326 }
327}
328
329#[pin_project::pinned_drop]
332impl PinnedDrop for EventStream<'_> {
333 fn drop(self: Pin<&'_ mut Self>) {
334 let this = self.project();
335
336 let (message_closure, error_closure, close_closure) =
339 std::mem::take(this.closures).expect("double drop on worker::EventStream");
340
341 this.ws
342 .remove_event_handler("message", message_closure)
343 .expect("could not remove message handler");
344 this.ws
345 .remove_event_handler("error", error_closure)
346 .expect("could not remove error handler");
347 this.ws
348 .remove_event_handler("close", close_closure)
349 .expect("could not remove close handler");
350 }
351}
352
353impl From<web_sys::WebSocket> for WebSocket {
354 fn from(socket: web_sys::WebSocket) -> Self {
355 Self { socket }
356 }
357}
358
359impl AsRef<web_sys::WebSocket> for WebSocket {
360 fn as_ref(&self) -> &web_sys::WebSocket {
361 &self.socket
362 }
363}
364
365pub mod ws_events {
366 use serde::de::DeserializeOwned;
367 use wasm_bindgen::JsValue;
368
369 use crate::Error;
370
371 #[derive(Debug, Clone)]
373 pub enum WebsocketEvent {
374 Message(MessageEvent),
375 Close(CloseEvent),
376 }
377
378 #[derive(Debug, Clone, PartialEq, Eq)]
380 pub struct MessageEvent {
381 event: web_sys::MessageEvent,
382 }
383
384 impl From<web_sys::MessageEvent> for MessageEvent {
385 fn from(event: web_sys::MessageEvent) -> Self {
386 Self { event }
387 }
388 }
389
390 impl AsRef<web_sys::MessageEvent> for MessageEvent {
391 fn as_ref(&self) -> &web_sys::MessageEvent {
392 &self.event
393 }
394 }
395
396 impl MessageEvent {
397 fn data(&self) -> JsValue {
399 self.event.data()
400 }
401
402 pub fn text(&self) -> Option<String> {
403 let value = self.data();
404 value.as_string()
405 }
406
407 pub fn bytes(&self) -> Option<Vec<u8>> {
408 let value = self.data();
409 if value.is_object() {
410 Some(js_sys::Uint8Array::new(&value).to_vec())
411 } else {
412 None
413 }
414 }
415
416 pub fn json<T: DeserializeOwned>(&self) -> crate::Result<T> {
417 let text = match self.text() {
418 Some(text) => text,
419 None => return Err(Error::from("data of message event is not text")),
420 };
421
422 serde_json::from_str(&text).map_err(Error::from)
423 }
424 }
425
426 #[derive(Debug, Clone, PartialEq, Eq)]
428 pub struct CloseEvent {
429 event: web_sys::CloseEvent,
430 }
431
432 impl CloseEvent {
433 pub fn reason(&self) -> String {
434 self.event.reason()
435 }
436
437 pub fn code(&self) -> u16 {
438 self.event.code()
439 }
440
441 pub fn was_clean(&self) -> bool {
442 self.event.was_clean()
443 }
444 }
445
446 impl From<web_sys::CloseEvent> for CloseEvent {
447 fn from(event: web_sys::CloseEvent) -> Self {
448 Self { event }
449 }
450 }
451
452 impl AsRef<web_sys::CloseEvent> for CloseEvent {
453 fn as_ref(&self) -> &web_sys::CloseEvent {
454 &self.event
455 }
456 }
457}
458
#[cfg(feature = "http")]
/// Sends `request` through the worker global scope's `fetch` and returns the raw
/// `web_sys::Response`.
///
/// # Errors
/// Fails if the fetch promise rejects, or if its result cannot be cast to a `Response`.
async fn fetch_with_request_raw(request: crate::Request) -> Result<web_sys::Response> {
    // Confine the global-scope handle to this block so it is dropped before the `.await`
    // below; only the `SendFuture` wrapper is kept across the suspension point.
    let pending = {
        let global_scope = js_sys::global().unchecked_into::<web_sys::WorkerGlobalScope>();
        let promise = global_scope.fetch_with_request(request.inner());
        crate::send::SendFuture::new(JsFuture::from(promise))
    };
    let raw = pending.await?;
    let response: web_sys::Response = raw.dyn_into()?;
    Ok(response)
}