[go: up one dir, main page]

rkyv/
rc.rs

1//! Archived versions of shared pointers.
2
3use core::{borrow::Borrow, cmp, fmt, hash, marker::PhantomData, ops::Deref};
4
5use munge::munge;
6use rancor::{Fallible, Source};
7
8use crate::{
9    primitive::FixedUsize,
10    seal::Seal,
11    ser::{Sharing, SharingExt, Writer, WriterExt as _},
12    traits::ArchivePointee,
13    ArchiveUnsized, Place, Portable, RelPtr, SerializeUnsized,
14};
15
/// Marker trait describing which kind of shared pointer an `ArchivedRc`
/// stands for.
pub trait Flavor: 'static {
    /// Controls how validation treats cycles of `ArchivedRc`s of this flavor.
    ///
    /// When `true`, a cycle is accepted during validation; when `false`,
    /// encountering a cycle makes validation fail.
    const ALLOW_CYCLES: bool;
}
23
24/// The flavor type for [`Rc`](crate::alloc::rc::Rc).
25pub struct RcFlavor;
26
27impl Flavor for RcFlavor {
28    const ALLOW_CYCLES: bool = false;
29}
30
31/// The flavor type for [`Arc`](crate::alloc::sync::Arc).
32pub struct ArcFlavor;
33
34impl Flavor for ArcFlavor {
35    const ALLOW_CYCLES: bool = false;
36}
37
38/// An archived `Rc`.
39///
40/// This is a thin wrapper around a [`RelPtr`] to the archived type paired with
41/// a "flavor" type. Because there may be many varieties of shared pointers and
42/// they may not be used together, the flavor helps check that memory is not
43/// being shared incorrectly during validation.
44#[derive(Portable)]
45#[rkyv(crate)]
46#[repr(transparent)]
47#[cfg_attr(
48    feature = "bytecheck",
49    derive(bytecheck::CheckBytes),
50    bytecheck(verify)
51)]
52pub struct ArchivedRc<T: ArchivePointee + ?Sized, F> {
53    ptr: RelPtr<T>,
54    _phantom: PhantomData<F>,
55}
56
57impl<T: ArchivePointee + ?Sized, F> ArchivedRc<T, F> {
58    /// Gets the value of the `ArchivedRc`.
59    pub fn get(&self) -> &T {
60        unsafe { &*self.ptr.as_ptr() }
61    }
62
63    /// Gets the sealed value of this `ArchivedRc`.
64    ///
65    /// # Safety
66    ///
67    /// Any other pointers to the same value must not be dereferenced for the
68    /// duration of the returned borrow.
69    pub unsafe fn get_seal_unchecked(this: Seal<'_, Self>) -> Seal<'_, T> {
70        munge!(let Self { ptr, _phantom: _ } = this);
71        Seal::new(unsafe { &mut *RelPtr::as_mut_ptr(ptr) })
72    }
73
74    /// Resolves an archived `Rc` from a given reference.
75    pub fn resolve_from_ref<U: ArchiveUnsized<Archived = T> + ?Sized>(
76        value: &U,
77        resolver: RcResolver,
78        out: Place<Self>,
79    ) {
80        munge!(let ArchivedRc { ptr, .. } = out);
81        RelPtr::emplace_unsized(
82            resolver.pos as usize,
83            value.archived_metadata(),
84            ptr,
85        );
86    }
87
88    /// Serializes an archived `Rc` from a given reference.
89    pub fn serialize_from_ref<U, S>(
90        value: &U,
91        serializer: &mut S,
92    ) -> Result<RcResolver, S::Error>
93    where
94        U: SerializeUnsized<S> + ?Sized,
95        S: Fallible + Writer + Sharing + ?Sized,
96        S::Error: Source,
97    {
98        let pos = serializer.serialize_shared(value)?;
99
100        // The positions of serialized `Rc` values must be unique. If we didn't
101        // write any data by serializing `value`, pad the serializer by a byte
102        // to ensure that our position will be unique.
103        if serializer.pos() == pos {
104            serializer.pad(1)?;
105        }
106
107        Ok(RcResolver {
108            pos: pos as FixedUsize,
109        })
110    }
111}
112
113impl<T, F> AsRef<T> for ArchivedRc<T, F>
114where
115    T: ArchivePointee + ?Sized,
116{
117    fn as_ref(&self) -> &T {
118        self.get()
119    }
120}
121
122impl<T, F> Borrow<T> for ArchivedRc<T, F>
123where
124    T: ArchivePointee + ?Sized,
125{
126    fn borrow(&self) -> &T {
127        self.get()
128    }
129}
130
131impl<T, F> fmt::Debug for ArchivedRc<T, F>
132where
133    T: ArchivePointee + fmt::Debug + ?Sized,
134{
135    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
136        self.get().fmt(f)
137    }
138}
139
140impl<T, F> Deref for ArchivedRc<T, F>
141where
142    T: ArchivePointee + ?Sized,
143{
144    type Target = T;
145    fn deref(&self) -> &Self::Target {
146        self.get()
147    }
148}
149
150impl<T, F> fmt::Display for ArchivedRc<T, F>
151where
152    T: ArchivePointee + fmt::Display + ?Sized,
153{
154    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
155        self.get().fmt(f)
156    }
157}
158
159impl<T, F> Eq for ArchivedRc<T, F> where T: ArchivePointee + Eq + ?Sized {}
160
161impl<T, F> hash::Hash for ArchivedRc<T, F>
162where
163    T: ArchivePointee + hash::Hash + ?Sized,
164{
165    fn hash<H: hash::Hasher>(&self, state: &mut H) {
166        self.get().hash(state)
167    }
168}
169
170impl<T, F> Ord for ArchivedRc<T, F>
171where
172    T: ArchivePointee + Ord + ?Sized,
173{
174    fn cmp(&self, other: &Self) -> cmp::Ordering {
175        self.get().cmp(other.get())
176    }
177}
178
179impl<T, TF, U, UF> PartialEq<ArchivedRc<U, UF>> for ArchivedRc<T, TF>
180where
181    T: ArchivePointee + PartialEq<U> + ?Sized,
182    U: ArchivePointee + ?Sized,
183{
184    fn eq(&self, other: &ArchivedRc<U, UF>) -> bool {
185        self.get().eq(other.get())
186    }
187}
188
189impl<T, TF, U, UF> PartialOrd<ArchivedRc<U, UF>> for ArchivedRc<T, TF>
190where
191    T: ArchivePointee + PartialOrd<U> + ?Sized,
192    U: ArchivePointee + ?Sized,
193{
194    fn partial_cmp(&self, other: &ArchivedRc<U, UF>) -> Option<cmp::Ordering> {
195        self.get().partial_cmp(other.get())
196    }
197}
198
199impl<T, F> fmt::Pointer for ArchivedRc<T, F> {
200    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201        fmt::Pointer::fmt(&self.ptr.base(), f)
202    }
203}
204
205/// The resolver for `Rc`.
206pub struct RcResolver {
207    pos: FixedUsize,
208}
209
210impl RcResolver {
211    /// Creates a new [`RcResolver`] from the position of a serialized value.
212    ///
213    /// In most cases, you won't need to create a [`RcResolver`] yourself and
214    /// can instead obtain it through [`ArchivedRc::serialize_from_ref`].
215    pub fn from_pos(pos: usize) -> Self {
216        Self {
217            pos: pos as FixedUsize,
218        }
219    }
220}
221
222/// An archived `rc::Weak`.
223///
224/// This is essentially just an optional [`ArchivedRc`].
225#[derive(Portable)]
226#[rkyv(crate)]
227#[repr(transparent)]
228#[cfg_attr(
229    feature = "bytecheck",
230    derive(bytecheck::CheckBytes),
231    bytecheck(verify)
232)]
233pub struct ArchivedRcWeak<T: ArchivePointee + ?Sized, F> {
234    ptr: RelPtr<T>,
235    _phantom: PhantomData<F>,
236}
237
238impl<T: ArchivePointee + ?Sized, F> ArchivedRcWeak<T, F> {
239    /// Attempts to upgrade the weak pointer to an `ArchivedArc`.
240    ///
241    /// Returns `None` if a null weak pointer was serialized.
242    pub fn upgrade(&self) -> Option<&ArchivedRc<T, F>> {
243        if self.ptr.is_invalid() {
244            None
245        } else {
246            Some(unsafe { &*(self as *const Self).cast() })
247        }
248    }
249
250    /// Attempts to upgrade a sealed weak pointer.
251    pub fn upgrade_seal(
252        this: Seal<'_, Self>,
253    ) -> Option<Seal<'_, ArchivedRc<T, F>>> {
254        let this = unsafe { this.unseal_unchecked() };
255        if this.ptr.is_invalid() {
256            None
257        } else {
258            Some(Seal::new(unsafe { &mut *(this as *mut Self).cast() }))
259        }
260    }
261
262    /// Resolves an archived `Weak` from a given optional reference.
263    pub fn resolve_from_ref<U: ArchiveUnsized<Archived = T> + ?Sized>(
264        value: Option<&U>,
265        resolver: RcWeakResolver,
266        out: Place<Self>,
267    ) {
268        match value {
269            None => {
270                munge!(let ArchivedRcWeak { ptr, _phantom: _ } = out);
271                RelPtr::emplace_invalid(ptr);
272            }
273            Some(value) => {
274                let out = unsafe { out.cast_unchecked::<ArchivedRc<T, F>>() };
275                ArchivedRc::resolve_from_ref(value, resolver.inner, out);
276            }
277        }
278    }
279
280    /// Serializes an archived `Weak` from a given optional reference.
281    pub fn serialize_from_ref<U, S>(
282        value: Option<&U>,
283        serializer: &mut S,
284    ) -> Result<RcWeakResolver, S::Error>
285    where
286        U: SerializeUnsized<S, Archived = T> + ?Sized,
287        S: Fallible + Writer + Sharing + ?Sized,
288        S::Error: Source,
289    {
290        Ok(match value {
291            None => RcWeakResolver {
292                inner: RcResolver { pos: 0 },
293            },
294            Some(r) => RcWeakResolver {
295                inner: ArchivedRc::<T, F>::serialize_from_ref(r, serializer)?,
296            },
297        })
298    }
299}
300
301impl<T: ArchivePointee + fmt::Debug + ?Sized, F> fmt::Debug
302    for ArchivedRcWeak<T, F>
303{
304    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
305        write!(f, "(Weak)")
306    }
307}
308
309/// The resolver for `rc::Weak`.
310pub struct RcWeakResolver {
311    inner: RcResolver,
312}
313
#[cfg(feature = "bytecheck")]
mod verify {
    use core::{any::TypeId, error::Error, fmt};

    use bytecheck::{
        rancor::{Fallible, Source},
        CheckBytes, Verify,
    };
    use rancor::fail;

    use crate::{
        rc::{ArchivedRc, ArchivedRcWeak, Flavor},
        traits::{ArchivePointee, LayoutRaw},
        validation::{
            shared::ValidationState, ArchiveContext, ArchiveContextExt,
            SharedContext,
        },
    };

    /// Raised when validation reaches a shared pointer whose target is still
    /// being validated, and the pointer's flavor does not allow cycles.
    #[derive(Debug)]
    struct CyclicSharedPointerError;

    impl fmt::Display for CyclicSharedPointerError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("encountered cyclic shared pointers while validating")
        }
    }

    impl Error for CyclicSharedPointerError {}

    // SAFETY: `verify` only returns `Ok` once the pointee has been checked
    // with `T::check_bytes`, or when the shared-pointer bookkeeping reports it
    // as already finished / in progress (the latter only if cycles are
    // allowed for `F`).
    unsafe impl<T, F, C> Verify<C> for ArchivedRc<T, F>
    where
        T: ArchivePointee + CheckBytes<C> + LayoutRaw + ?Sized + 'static,
        T::ArchivedMetadata: CheckBytes<C>,
        F: Flavor,
        C: Fallible + ArchiveContext + SharedContext + ?Sized,
        C::Error: Source,
    {
        fn verify(&self, context: &mut C) -> Result<(), C::Error> {
            let target = self.ptr.as_ptr_wrapping();
            // Shared targets are tracked by their address together with the
            // full `ArchivedRc<T, F>` type id.
            let address = target as *const u8 as usize;
            let type_id = TypeId::of::<ArchivedRc<T, F>>();

            match context.start_shared(address, type_id)? {
                ValidationState::Started => {
                    // NOTE(review): assumes `in_subtree` confirms `target`
                    // lies within the archive before running the closure,
                    // which is what makes `check_bytes` sound — confirm.
                    context.in_subtree(target, |ctx| unsafe {
                        T::check_bytes(target, ctx)
                    })?;
                    context.finish_shared(address, type_id)?;
                }
                ValidationState::Pending if !F::ALLOW_CYCLES => {
                    fail!(CyclicSharedPointerError)
                }
                ValidationState::Pending | ValidationState::Finished => {}
            }

            Ok(())
        }
    }

    // SAFETY: a null weak pointer has nothing to check; otherwise validation
    // is delegated entirely to the `ArchivedRc` implementation above.
    unsafe impl<T, F, C> Verify<C> for ArchivedRcWeak<T, F>
    where
        T: ArchivePointee + CheckBytes<C> + LayoutRaw + ?Sized + 'static,
        T::ArchivedMetadata: CheckBytes<C>,
        F: Flavor,
        C: Fallible + ArchiveContext + SharedContext + ?Sized,
        C::Error: Source,
    {
        fn verify(&self, context: &mut C) -> Result<(), C::Error> {
            if self.ptr.is_invalid() {
                return Ok(());
            }

            // SAFETY: `ArchivedRc` and `ArchivedRcWeak` are
            // `repr(transparent)` and so have the same layout as each other.
            let strong = unsafe {
                &*(self as *const Self as *const ArchivedRc<T, F>)
            };
            <ArchivedRc<T, F> as Verify<C>>::verify(strong, context)
        }
    }
}