[go: up one dir, main page]

imbl/
ser.rs

// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
4
5use archery::SharedPointerKind;
6use serde::de::{Deserialize, Deserializer, MapAccess, SeqAccess, Visitor};
7use serde::ser::{Serialize, SerializeMap, SerializeSeq, Serializer};
8use std::fmt;
9use std::hash::{BuildHasher, Hash};
10use std::marker::PhantomData;
11
12use crate::hashmap::GenericHashMap;
13use crate::hashset::GenericHashSet;
14use crate::ordmap::GenericOrdMap;
15use crate::ordset::GenericOrdSet;
16use crate::vector::GenericVector;
17
/// Stateless serde visitor that collects sequence elements of type `A`
/// into a `Vec<A>` and turns that vector into the target collection `S`.
///
/// The single marker field ties `S`, `A` and the `'de` lifetime to the type
/// without storing any data, so the visitor is zero-sized.
struct SeqVisitor<'de, S, A> {
    marker: PhantomData<(S, A, &'de ())>,
}

impl<'de, S, A> SeqVisitor<'de, S, A> {
    /// Builds a visitor; there is no runtime state to initialise.
    pub(crate) fn new() -> Self {
        Self {
            marker: PhantomData,
        }
    }
}
33
34impl<'de, S, A> Visitor<'de> for SeqVisitor<'de, S, A>
35where
36    S: From<Vec<A>>,
37    A: Deserialize<'de>,
38{
39    type Value = S;
40
41    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
42        formatter.write_str("a sequence")
43    }
44
45    fn visit_seq<Access>(self, mut access: Access) -> Result<Self::Value, Access::Error>
46    where
47        Access: SeqAccess<'de>,
48    {
49        let mut v: Vec<A> = match access.size_hint() {
50            None => Vec::new(),
51            Some(l) => Vec::with_capacity(l),
52        };
53        while let Some(i) = access.next_element()? {
54            v.push(i)
55        }
56        Ok(From::from(v))
57    }
58}
59
/// Stateless serde visitor that collects map entries `(K, V)` into a
/// `Vec<(K, V)>` and turns that vector into the target collection `S`.
///
/// The single marker field ties `S`, `K`, `V` and the `'de` lifetime to the
/// type without storing any data, so the visitor is zero-sized.
struct MapVisitor<'de, S, K, V> {
    marker: PhantomData<(S, K, V, &'de ())>,
}

impl<'de, S, K, V> MapVisitor<'de, S, K, V> {
    /// Builds a visitor; there is no runtime state to initialise.
    pub(crate) fn new() -> Self {
        Self {
            marker: PhantomData,
        }
    }
}
77
78impl<'de, S, K, V> Visitor<'de> for MapVisitor<'de, S, K, V>
79where
80    S: From<Vec<(K, V)>>,
81    K: Deserialize<'de>,
82    V: Deserialize<'de>,
83{
84    type Value = S;
85
86    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
87        formatter.write_str("a sequence")
88    }
89
90    fn visit_map<Access>(self, mut access: Access) -> Result<Self::Value, Access::Error>
91    where
92        Access: MapAccess<'de>,
93    {
94        let mut v: Vec<(K, V)> = match access.size_hint() {
95            None => Vec::new(),
96            Some(l) => Vec::with_capacity(l),
97        };
98        while let Some(i) = access.next_entry()? {
99            v.push(i)
100        }
101        Ok(From::from(v))
102    }
103}
104
// OrdSet
106
107impl<'de, A: Deserialize<'de> + Ord + Clone, P: SharedPointerKind> Deserialize<'de>
108    for GenericOrdSet<A, P>
109{
110    fn deserialize<D>(des: D) -> Result<Self, D::Error>
111    where
112        D: Deserializer<'de>,
113    {
114        des.deserialize_seq(SeqVisitor::new())
115    }
116}
117
118impl<A: Ord + Serialize, P: SharedPointerKind> Serialize for GenericOrdSet<A, P> {
119    fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error>
120    where
121        S: Serializer,
122    {
123        let mut s = ser.serialize_seq(Some(self.len()))?;
124        for i in self.iter() {
125            s.serialize_element(i)?;
126        }
127        s.end()
128    }
129}
130
// OrdMap
132
133impl<'de, K: Deserialize<'de> + Ord + Clone, V: Deserialize<'de> + Clone, P: SharedPointerKind>
134    Deserialize<'de> for GenericOrdMap<K, V, P>
135{
136    fn deserialize<D>(des: D) -> Result<Self, D::Error>
137    where
138        D: Deserializer<'de>,
139    {
140        des.deserialize_map(MapVisitor::<'de, GenericOrdMap<K, V, P>, K, V>::new())
141    }
142}
143
144impl<K: Serialize + Ord, V: Serialize, P: SharedPointerKind> Serialize for GenericOrdMap<K, V, P> {
145    fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error>
146    where
147        S: Serializer,
148    {
149        let mut s = ser.serialize_map(Some(self.len()))?;
150        for (k, v) in self.iter() {
151            s.serialize_entry(k, v)?;
152        }
153        s.end()
154    }
155}
156
// HashMap
158
159impl<'de, K, V, S, P: SharedPointerKind> Deserialize<'de> for GenericHashMap<K, V, S, P>
160where
161    K: Deserialize<'de> + Hash + Eq + Clone,
162    V: Deserialize<'de> + Clone,
163    S: BuildHasher + Default + Clone,
164    P: SharedPointerKind,
165{
166    fn deserialize<D>(des: D) -> Result<Self, D::Error>
167    where
168        D: Deserializer<'de>,
169    {
170        des.deserialize_map(MapVisitor::<'de, GenericHashMap<K, V, S, P>, K, V>::new())
171    }
172}
173
174impl<K, V, S, P> Serialize for GenericHashMap<K, V, S, P>
175where
176    K: Serialize + Hash + Eq,
177    V: Serialize,
178    S: BuildHasher + Default,
179    P: SharedPointerKind,
180{
181    fn serialize<Ser>(&self, ser: Ser) -> Result<Ser::Ok, Ser::Error>
182    where
183        Ser: Serializer,
184    {
185        let mut s = ser.serialize_map(Some(self.len()))?;
186        for (k, v) in self.iter() {
187            s.serialize_entry(k, v)?;
188        }
189        s.end()
190    }
191}
192
// HashSet
194
195impl<
196        'de,
197        A: Deserialize<'de> + Hash + Eq + Clone,
198        S: BuildHasher + Default + Clone,
199        P: SharedPointerKind,
200    > Deserialize<'de> for GenericHashSet<A, S, P>
201{
202    fn deserialize<D>(des: D) -> Result<Self, D::Error>
203    where
204        D: Deserializer<'de>,
205    {
206        des.deserialize_seq(SeqVisitor::new())
207    }
208}
209
210impl<A: Serialize + Hash + Eq, S: BuildHasher + Default, P: SharedPointerKind> Serialize
211    for GenericHashSet<A, S, P>
212{
213    fn serialize<Ser>(&self, ser: Ser) -> Result<Ser::Ok, Ser::Error>
214    where
215        Ser: Serializer,
216    {
217        let mut s = ser.serialize_seq(Some(self.len()))?;
218        for i in self.iter() {
219            s.serialize_element(i)?;
220        }
221        s.end()
222    }
223}
224
// Vector
226
227impl<'de, A: Clone + Deserialize<'de>, P: SharedPointerKind> Deserialize<'de>
228    for GenericVector<A, P>
229{
230    fn deserialize<D>(des: D) -> Result<Self, D::Error>
231    where
232        D: Deserializer<'de>,
233    {
234        des.deserialize_seq(SeqVisitor::<'de, GenericVector<A, P>, A>::new())
235    }
236}
237
238impl<A: Serialize, P: SharedPointerKind> Serialize for GenericVector<A, P> {
239    fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error>
240    where
241        S: Serializer,
242    {
243        let mut s = ser.serialize_seq(Some(self.len()))?;
244        for i in self.iter() {
245            s.serialize_element(i)?;
246        }
247        s.end()
248    }
249}
250
// Tests
252
#[cfg(test)]
mod test {
    use crate::{
        proptest::{hash_map, hash_set, ord_map, ord_set, vector},
        HashMap, HashSet, OrdMap, OrdSet, Vector,
    };
    use proptest::num::i32;
    use proptest::proptest;
    use serde_json::{from_str, to_string};

    // Each property encodes a randomly generated collection (0..100 elements)
    // as JSON, decodes it back, and requires the round trip to be lossless.
    proptest! {
        #[cfg_attr(miri, ignore)]
        #[test]
        fn ser_ordset(ref v in ord_set(i32::ANY, 0..100)) {
            let json = to_string(&v).unwrap();
            let decoded: OrdSet<i32> = from_str(&json).unwrap();
            assert_eq!(v, &decoded);
        }

        #[cfg_attr(miri, ignore)]
        #[test]
        fn ser_ordmap(ref v in ord_map(i32::ANY, i32::ANY, 0..100)) {
            let json = to_string(&v).unwrap();
            let decoded: OrdMap<i32, i32> = from_str(&json).unwrap();
            assert_eq!(v, &decoded);
        }

        #[cfg_attr(miri, ignore)]
        #[test]
        fn ser_hashmap(ref v in hash_map(i32::ANY, i32::ANY, 0..100)) {
            let json = to_string(&v).unwrap();
            let decoded: HashMap<i32, i32> = from_str(&json).unwrap();
            assert_eq!(v, &decoded);
        }

        #[cfg_attr(miri, ignore)]
        #[test]
        fn ser_hashset(ref v in hash_set(i32::ANY, 0..100)) {
            let json = to_string(&v).unwrap();
            let decoded: HashSet<i32> = from_str(&json).unwrap();
            assert_eq!(v, &decoded);
        }

        #[cfg_attr(miri, ignore)]
        #[test]
        fn ser_vector(ref v in vector(i32::ANY, 0..100)) {
            let json = to_string(&v).unwrap();
            let decoded: Vector<i32> = from_str(&json).unwrap();
            assert_eq!(v, &decoded);
        }
    }
}