1use js_sys::Array;
2use std::convert::TryFrom;
3use wasm_bindgen::{JsCast, JsValue};
4use worker_sys::types::{SqlStorage as SqlStorageSys, SqlStorageCursor as SqlStorageCursorSys};
5
6use serde::de::DeserializeOwned;
7use serde_wasm_bindgen as swb;
8
9use crate::Error;
10use crate::Result;
11
/// A single value bound to, or read back from, a SQL storage query.
///
/// Values are bound through [`SqlStorage::exec`] and produced per column by
/// [`SqlCursor::raw`].
#[derive(Debug, Clone, PartialEq)]
pub enum SqlStorageValue {
    /// SQL `NULL`. Converts to JS `null`; both JS `null` and `undefined` convert back to it.
    Null,
    /// A boolean, passed to JavaScript as a JS boolean.
    Boolean(bool),
    /// A 64-bit integer.
    ///
    /// It is sent to JavaScript as a `Number` (an `f64`), so values outside the JavaScript
    /// safe-integer range lose precision. [`SqlStorageValue::try_from_i64`] is a checked
    /// constructor that rejects such values.
    Integer(i64),
    /// A 64-bit float.
    ///
    /// When reading from JavaScript, a number that is a safe integer becomes `Integer`,
    /// not `Float`.
    Float(f64),
    /// A UTF-8 string, passed to JavaScript as a JS string.
    String(String),
    /// Raw bytes.
    ///
    /// They are sent to JavaScript as a `Uint8Array`, and read back from either a
    /// `Uint8Array` or an `ArrayBuffer`.
    Blob(Vec<u8>),
}
32
33impl From<bool> for SqlStorageValue {
35 fn from(value: bool) -> Self {
36 SqlStorageValue::Boolean(value)
37 }
38}
39
40impl From<i32> for SqlStorageValue {
41 fn from(value: i32) -> Self {
42 SqlStorageValue::Integer(value as i64)
43 }
44}
45
46impl From<i64> for SqlStorageValue {
47 fn from(value: i64) -> Self {
48 SqlStorageValue::Integer(value)
49 }
50}
51
52impl SqlStorageValue {
53 pub fn try_from_i64(value: i64) -> Result<Self> {
58 if value >= js_sys::Number::MIN_SAFE_INTEGER as i64
59 && value <= js_sys::Number::MAX_SAFE_INTEGER as i64
60 {
61 Ok(SqlStorageValue::Integer(value))
62 } else {
63 Err(crate::Error::from(
64 "Value outside JavaScript safe integer range",
65 ))
66 }
67 }
68}
69
70impl From<f64> for SqlStorageValue {
71 fn from(value: f64) -> Self {
72 SqlStorageValue::Float(value)
73 }
74}
75
76impl From<String> for SqlStorageValue {
77 fn from(value: String) -> Self {
78 SqlStorageValue::String(value)
79 }
80}
81
82impl From<&str> for SqlStorageValue {
83 fn from(value: &str) -> Self {
84 SqlStorageValue::String(value.to_string())
85 }
86}
87
88impl From<Vec<u8>> for SqlStorageValue {
89 fn from(value: Vec<u8>) -> Self {
90 SqlStorageValue::Blob(value)
91 }
92}
93
94impl<T> From<Option<T>> for SqlStorageValue
95where
96 T: Into<SqlStorageValue>,
97{
98 fn from(value: Option<T>) -> Self {
99 match value {
100 Some(v) => v.into(),
101 None => SqlStorageValue::Null,
102 }
103 }
104}
105
106impl From<SqlStorageValue> for JsValue {
108 fn from(val: SqlStorageValue) -> Self {
109 match val {
110 SqlStorageValue::Null => JsValue::NULL,
111 SqlStorageValue::Boolean(b) => JsValue::from(b),
112 SqlStorageValue::Integer(i) => {
113 let js_value = JsValue::from(i as f64);
114
115 if !js_sys::Number::is_safe_integer(&js_value) {
116 crate::console_debug!(
117 "Warning: Converting {} to JsValue as Integer, \
118 but it is outside the JavaScript safe-integer range",
119 i
120 );
121 }
122 js_value
123 }
124 SqlStorageValue::Float(f) => JsValue::from(f),
125 SqlStorageValue::String(s) => JsValue::from(s),
126 SqlStorageValue::Blob(bytes) => {
127 let array = js_sys::Uint8Array::new_with_length(bytes.len() as u32);
129 array.copy_from(&bytes);
130 array.into()
131 }
132 }
133 }
134}
135
136impl TryFrom<JsValue> for SqlStorageValue {
137 type Error = crate::Error;
138
139 fn try_from(js_val: JsValue) -> Result<Self> {
140 if js_val.is_null() || js_val.is_undefined() {
141 Ok(SqlStorageValue::Null)
142 } else if let Some(bool_val) = js_val.as_bool() {
143 Ok(SqlStorageValue::Boolean(bool_val))
144 } else if let Some(str_val) = js_val.as_string() {
145 Ok(SqlStorageValue::String(str_val))
146 } else if let Some(num_val) = js_val.as_f64() {
147 if js_sys::Number::is_safe_integer(&js_val) {
148 Ok(SqlStorageValue::Integer(num_val as i64))
149 } else {
150 Ok(SqlStorageValue::Float(num_val))
151 }
152 } else {
153 js_val
154 .dyn_into::<js_sys::Uint8Array>()
155 .map(|uint8_array| {
156 let mut bytes = vec![0u8; uint8_array.length() as usize];
157 uint8_array.copy_to(&mut bytes);
158 SqlStorageValue::Blob(bytes)
159 })
160 .or_else(|js_val| {
161 js_val
162 .dyn_into::<js_sys::ArrayBuffer>()
163 .map(|array_buffer| {
164 let uint8_array = js_sys::Uint8Array::new(&array_buffer);
165 let mut bytes = vec![0u8; uint8_array.length() as usize];
166 uint8_array.copy_to(&mut bytes);
167 SqlStorageValue::Blob(bytes)
168 })
169 })
170 .map_err(|_| Error::from("Unsupported JavaScript value type"))
171 }
172 }
173}
174
/// Handle to a SQL storage object, wrapping the `worker_sys` SQL storage binding.
///
/// Queries are run with [`SqlStorage::exec`] (typed bindings) or
/// [`SqlStorage::exec_raw`] (raw `JsValue` bindings).
#[derive(Clone, Debug)]
pub struct SqlStorage {
    // The underlying JS binding that every query is forwarded to.
    inner: SqlStorageSys,
}

// SAFETY(review): the wrapped value is a JS object handle, which is not inherently
// thread-safe. These impls are presumably sound only because the wasm target this runs on
// is single-threaded, so the handle is never actually shared across threads — confirm
// before enabling any multi-threaded wasm target.
unsafe impl Send for SqlStorage {}
unsafe impl Sync for SqlStorage {}
186
187impl SqlStorage {
188 pub(crate) fn new(inner: SqlStorageSys) -> Self {
189 Self { inner }
190 }
191
192 pub fn database_size(&self) -> usize {
194 self.inner.database_size() as usize
195 }
196
197 pub fn exec(
202 &self,
203 query: &str,
204 bindings: impl Into<Option<Vec<SqlStorageValue>>>,
205 ) -> Result<SqlCursor> {
206 let array = Array::new();
207 if let Some(bindings) = bindings.into() {
208 for v in bindings {
209 array.push(&v.into());
210 }
211 }
212 let cursor = self.inner.exec(query, array).map_err(Error::from)?;
213 Ok(SqlCursor { inner: cursor })
214 }
215
216 pub fn exec_raw(
221 &self,
222 query: &str,
223 bindings: impl Into<Option<Vec<JsValue>>>,
224 ) -> Result<SqlCursor> {
225 let array = Array::new();
226 if let Some(bindings) = bindings.into() {
227 for v in bindings {
228 array.push(&v);
229 }
230 }
231 let cursor = self.inner.exec(query, array).map_err(Error::from)?;
232 Ok(SqlCursor { inner: cursor })
233 }
234}
235
236impl AsRef<JsValue> for SqlStorage {
237 fn as_ref(&self) -> &JsValue {
238 &self.inner
239 }
240}
241
/// Cursor over the result of [`SqlStorage::exec`] or [`SqlStorage::exec_raw`].
///
/// Rows can be read in three ways:
/// * typed, via [`SqlCursor::to_array`], [`SqlCursor::one`] or [`SqlCursor::next`];
/// * as `SqlStorageValue` vectors, via [`SqlCursor::raw`];
/// * as plain `JsValue`s, through this type's `Iterator` impl.
#[derive(Clone, Debug)]
pub struct SqlCursor {
    // The underlying JS cursor binding.
    inner: SqlStorageCursorSys,
}

// SAFETY(review): like `SqlStorage`, this wraps a JS object handle. The impls are presumably
// sound only because the wasm target is single-threaded — confirm before enabling any
// multi-threaded wasm target.
unsafe impl Send for SqlCursor {}
unsafe impl Sync for SqlCursor {}
250
/// Typed row iterator returned by [`SqlCursor::next`].
///
/// Each item is one row deserialized into `T` with `serde_wasm_bindgen`.
#[derive(Debug)]
pub struct SqlCursorIterator<T> {
    // Clone of the cursor being iterated. NOTE(review): presumably this shares the same
    // underlying JS cursor as the handle it was cloned from, so advancing one advances both
    // — confirm.
    cursor: SqlCursor,
    // Records the row type `T` without storing a value of it.
    _phantom: std::marker::PhantomData<T>,
}
260
261impl<T> Iterator for SqlCursorIterator<T>
262where
263 T: DeserializeOwned,
264{
265 type Item = Result<T>;
266
267 fn next(&mut self) -> Option<Self::Item> {
268 let result = self.cursor.inner.next();
269
270 let done = js_sys::Reflect::get(&result, &JsValue::from("done"))
271 .ok()
272 .and_then(|v| v.as_bool())
273 .unwrap_or(true);
274
275 if done {
276 None
277 } else {
278 let value = js_sys::Reflect::get(&result, &JsValue::from("value"))
279 .map_err(Error::from)
280 .and_then(|js_val| swb::from_value(js_val).map_err(Error::from));
281 Some(value)
282 }
283 }
284}
285
/// Iterator returned by [`SqlCursor::raw`].
///
/// Yields each row as a positional `Vec<SqlStorageValue>`.
#[derive(Debug)]
pub struct SqlCursorRawIterator {
    // JS iterator obtained from the cursor's `raw()` call.
    inner: js_sys::Iterator,
}
295
296impl Iterator for SqlCursorRawIterator {
297 type Item = Result<Vec<SqlStorageValue>>;
298
299 fn next(&mut self) -> Option<Self::Item> {
300 match self.inner.next() {
301 Ok(iterator_next) => {
302 if iterator_next.done() {
303 None
304 } else {
305 let js_val = iterator_next.value();
306 let array_result = js_array_to_sql_storage_values(js_val);
307 Some(array_result)
308 }
309 }
310 Err(e) => Some(Err(Error::from(e))),
311 }
312 }
313}
314
315fn js_array_to_sql_storage_values(js_val: JsValue) -> Result<Vec<SqlStorageValue>> {
316 let array = js_sys::Array::from(&js_val);
317 let mut values = Vec::with_capacity(array.length() as usize);
318
319 for i in 0..array.length() {
320 let item = array.get(i);
321 let sql_value = SqlStorageValue::try_from(item)?;
322 values.push(sql_value);
323 }
324
325 Ok(values)
326}
327
328impl SqlCursor {
329 pub fn to_array<T>(&self) -> Result<Vec<T>>
331 where
332 T: DeserializeOwned,
333 {
334 let arr = self.inner.to_array();
335 let mut out = Vec::with_capacity(arr.length() as usize);
336 for val in arr.iter() {
337 out.push(swb::from_value(val)?);
338 }
339 Ok(out)
340 }
341
342 pub fn one<T>(&self) -> Result<T>
344 where
345 T: DeserializeOwned,
346 {
347 let val = self.inner.one();
348 Ok(swb::from_value(val)?)
349 }
350
351 pub fn column_names(&self) -> Vec<String> {
353 self.inner
354 .column_names()
355 .iter()
356 .map(|v| v.as_string().unwrap_or_default())
357 .collect()
358 }
359
360 pub fn rows_read(&self) -> usize {
362 self.inner.rows_read() as usize
363 }
364
365 pub fn rows_written(&self) -> usize {
367 self.inner.rows_written() as usize
368 }
369
370 pub fn next<T>(&self) -> SqlCursorIterator<T>
375 where
376 T: DeserializeOwned,
377 {
378 SqlCursorIterator {
379 cursor: self.clone(),
380 _phantom: std::marker::PhantomData,
381 }
382 }
383
384 pub fn raw(&self) -> SqlCursorRawIterator {
389 SqlCursorRawIterator {
390 inner: self.inner.raw(),
391 }
392 }
393}
394
395impl Iterator for SqlCursor {
396 type Item = Result<JsValue>;
397
398 fn next(&mut self) -> Option<Self::Item> {
399 let result = self.inner.next();
400
401 let done = js_sys::Reflect::get(&result, &JsValue::from("done"))
403 .ok()
404 .and_then(|v| v.as_bool())
405 .unwrap_or(true);
406
407 if done {
408 None
409 } else {
410 let value = js_sys::Reflect::get(&result, &JsValue::from("value")).map_err(Error::from);
412 Some(value)
413 }
414 }
415}