1use core::{borrow::Borrow, cmp, fmt, hash, marker::PhantomData, ops::Deref};
4
5use munge::munge;
6use rancor::{Fallible, Source};
7
8use crate::{
9 primitive::FixedUsize,
10 seal::Seal,
11 ser::{Sharing, SharingExt, Writer, WriterExt as _},
12 traits::ArchivePointee,
13 ArchiveUnsized, Place, Portable, RelPtr, SerializeUnsized,
14};
15
/// Marker trait that selects how a family of archived shared pointers
/// behaves during validation.
///
/// A flavor is passed as the `F` type parameter of `ArchivedRc<T, F>` and
/// `ArchivedRcWeak<T, F>`.
pub trait Flavor: 'static {
    /// Whether validation tolerates cycles of shared pointers of this flavor.
    ///
    /// When `false`, reaching a shared pointer whose pointee is still being
    /// validated is reported as a cyclic-pointer error.
    const ALLOW_CYCLES: bool;
}

/// Flavor marker for `Rc`-style archived pointers.
///
/// Its only visible effect is `ALLOW_CYCLES = false`: cyclic pointer graphs
/// fail validation.
pub struct RcFlavor;

impl Flavor for RcFlavor {
    const ALLOW_CYCLES: bool = false;
}

/// Flavor marker for `Arc`-style archived pointers.
///
/// Its only visible effect is `ALLOW_CYCLES = false`: cyclic pointer graphs
/// fail validation.
pub struct ArcFlavor;

impl Flavor for ArcFlavor {
    const ALLOW_CYCLES: bool = false;
}
37
/// An archived shared pointer.
///
/// It is stored as a single relative pointer to the shared value. `F` is a
/// [`Flavor`] tag; in this module it is only consulted during validation
/// (through `Flavor::ALLOW_CYCLES`).
///
/// The type is `#[repr(transparent)]` over its `RelPtr<T>` (the other field
/// is a zero-sized `PhantomData`). `ArchivedRcWeak<T, F>` relies on this when
/// it reinterprets a weak pointer that has a target as an `ArchivedRc<T, F>`.
#[derive(Portable)]
#[rkyv(crate)]
#[repr(transparent)]
#[cfg_attr(
    feature = "bytecheck",
    derive(bytecheck::CheckBytes),
    bytecheck(verify)
)]
pub struct ArchivedRc<T: ArchivePointee + ?Sized, F> {
    // Relative pointer to the shared value.
    ptr: RelPtr<T>,
    // Carries the flavor type without storing any data.
    _phantom: PhantomData<F>,
}
56
impl<T: ArchivePointee + ?Sized, F> ArchivedRc<T, F> {
    /// Returns a reference to the shared value.
    pub fn get(&self) -> &T {
        // SAFETY: assumes the relative pointer targets a valid, initialized
        // `T` for as long as `self` is borrowed. The archive is expected to
        // have been built via `resolve_from_ref` and/or validated. Nothing is
        // re-checked here.
        unsafe { &*self.ptr.as_ptr() }
    }

    /// Projects a sealed `ArchivedRc` onto a sealed reference to its pointee.
    ///
    /// # Safety
    ///
    /// Several archived pointers may refer to the same value, so this creates
    /// what is effectively a `&mut T` to possibly-shared data. The caller
    /// must ensure that no other reference to the pointee (through another
    /// `ArchivedRc`, an `ArchivedRcWeak`, or otherwise) is used while the
    /// returned `Seal` is alive.
    pub unsafe fn get_seal_unchecked(this: Seal<'_, Self>) -> Seal<'_, T> {
        munge!(let Self { ptr, _phantom: _ } = this);
        // SAFETY: pointer validity is assumed exactly as in `get`. The
        // exclusivity of the resulting `&mut T` is the caller's obligation
        // (see `# Safety`).
        Seal::new(unsafe { &mut *RelPtr::as_mut_ptr(ptr) })
    }

    /// Writes an `ArchivedRc` pointing at the serialized copy of `value`
    /// into `out`.
    ///
    /// `resolver` is expected to come from `serialize_from_ref` (or
    /// `RcResolver::from_pos`) for the same value. Its position becomes the
    /// pointer target. `value.archived_metadata()` supplies the pointer
    /// metadata (e.g. a length for unsized pointees).
    pub fn resolve_from_ref<U: ArchiveUnsized<Archived = T> + ?Sized>(
        value: &U,
        resolver: RcResolver,
        out: Place<Self>,
    ) {
        munge!(let ArchivedRc { ptr, .. } = out);
        RelPtr::emplace_unsized(
            resolver.pos as usize,
            value.archived_metadata(),
            ptr,
        );
    }

    /// Serializes `value` through the serializer's sharing support.
    /// Returns a resolver that records the position `value` is archived at.
    ///
    /// # Errors
    ///
    /// Propagates any error from `serialize_shared` or from padding the
    /// output.
    pub fn serialize_from_ref<U, S>(
        value: &U,
        serializer: &mut S,
    ) -> Result<RcResolver, S::Error>
    where
        U: SerializeUnsized<S> + ?Sized,
        S: Fallible + Writer + Sharing + ?Sized,
        S::Error: Source,
    {
        let pos = serializer.serialize_shared(value)?;

        // If the returned position is still the current end of output, no
        // bytes were written for the value (e.g. a zero-sized value), and the
        // next shared value could be placed at the same position. Pad by one
        // byte to keep positions distinct — presumably because shared values
        // are identified by position.
        if serializer.pos() == pos {
            serializer.pad(1)?;
        }

        // NOTE(review): `as` silently truncates if `pos` does not fit in
        // `FixedUsize`.
        Ok(RcResolver {
            pos: pos as FixedUsize,
        })
    }
}
112
113impl<T, F> AsRef<T> for ArchivedRc<T, F>
114where
115 T: ArchivePointee + ?Sized,
116{
117 fn as_ref(&self) -> &T {
118 self.get()
119 }
120}
121
122impl<T, F> Borrow<T> for ArchivedRc<T, F>
123where
124 T: ArchivePointee + ?Sized,
125{
126 fn borrow(&self) -> &T {
127 self.get()
128 }
129}
130
131impl<T, F> fmt::Debug for ArchivedRc<T, F>
132where
133 T: ArchivePointee + fmt::Debug + ?Sized,
134{
135 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
136 self.get().fmt(f)
137 }
138}
139
140impl<T, F> Deref for ArchivedRc<T, F>
141where
142 T: ArchivePointee + ?Sized,
143{
144 type Target = T;
145 fn deref(&self) -> &Self::Target {
146 self.get()
147 }
148}
149
150impl<T, F> fmt::Display for ArchivedRc<T, F>
151where
152 T: ArchivePointee + fmt::Display + ?Sized,
153{
154 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
155 self.get().fmt(f)
156 }
157}
158
159impl<T, F> Eq for ArchivedRc<T, F> where T: ArchivePointee + Eq + ?Sized {}
160
161impl<T, F> hash::Hash for ArchivedRc<T, F>
162where
163 T: ArchivePointee + hash::Hash + ?Sized,
164{
165 fn hash<H: hash::Hasher>(&self, state: &mut H) {
166 self.get().hash(state)
167 }
168}
169
170impl<T, F> Ord for ArchivedRc<T, F>
171where
172 T: ArchivePointee + Ord + ?Sized,
173{
174 fn cmp(&self, other: &Self) -> cmp::Ordering {
175 self.get().cmp(other.get())
176 }
177}
178
179impl<T, TF, U, UF> PartialEq<ArchivedRc<U, UF>> for ArchivedRc<T, TF>
180where
181 T: ArchivePointee + PartialEq<U> + ?Sized,
182 U: ArchivePointee + ?Sized,
183{
184 fn eq(&self, other: &ArchivedRc<U, UF>) -> bool {
185 self.get().eq(other.get())
186 }
187}
188
189impl<T, TF, U, UF> PartialOrd<ArchivedRc<U, UF>> for ArchivedRc<T, TF>
190where
191 T: ArchivePointee + PartialOrd<U> + ?Sized,
192 U: ArchivePointee + ?Sized,
193{
194 fn partial_cmp(&self, other: &ArchivedRc<U, UF>) -> Option<cmp::Ordering> {
195 self.get().partial_cmp(other.get())
196 }
197}
198
199impl<T, F> fmt::Pointer for ArchivedRc<T, F> {
200 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201 fmt::Pointer::fmt(&self.ptr.base(), f)
202 }
203}
204
/// The resolver for an archived shared pointer: the output position at which
/// the shared value was serialized.
pub struct RcResolver {
    // Output position of the serialized value, narrowed to `FixedUsize`.
    pos: FixedUsize,
}
209
210impl RcResolver {
211 pub fn from_pos(pos: usize) -> Self {
216 Self {
217 pos: pos as FixedUsize,
218 }
219 }
220}
221
/// An archived weak shared pointer.
///
/// It has the same layout as `ArchivedRc<T, F>`: both are
/// `#[repr(transparent)]` over a `RelPtr<T>` plus a zero-sized
/// `PhantomData<F>`. An invalid relative pointer encodes a weak pointer with
/// no target. Otherwise the value can be viewed as an `ArchivedRc` through
/// `upgrade`.
#[derive(Portable)]
#[rkyv(crate)]
#[repr(transparent)]
#[cfg_attr(
    feature = "bytecheck",
    derive(bytecheck::CheckBytes),
    bytecheck(verify)
)]
pub struct ArchivedRcWeak<T: ArchivePointee + ?Sized, F> {
    // Relative pointer to the shared value, or the invalid sentinel when the
    // weak pointer has no target.
    ptr: RelPtr<T>,
    // Carries the flavor type without storing any data.
    _phantom: PhantomData<F>,
}
237
impl<T: ArchivePointee + ?Sized, F> ArchivedRcWeak<T, F> {
    /// Returns the strong-pointer view of this weak pointer, or `None` if it
    /// has no target (its relative pointer is invalid).
    pub fn upgrade(&self) -> Option<&ArchivedRc<T, F>> {
        if self.ptr.is_invalid() {
            None
        } else {
            // SAFETY: `ArchivedRcWeak<T, F>` and `ArchivedRc<T, F>` are both
            // `#[repr(transparent)]` over `RelPtr<T>` (plus a zero-sized
            // `PhantomData<F>`), so the reference can be reinterpreted. The
            // pointer is not the invalid sentinel, so the result carries the
            // usual `ArchivedRc` assumption that it targets a `T`.
            Some(unsafe { &*(self as *const Self).cast() })
        }
    }

    /// Sealed counterpart of [`upgrade`](Self::upgrade). Projects a sealed
    /// weak pointer onto a sealed `ArchivedRc`, or returns `None` if it has
    /// no target.
    pub fn upgrade_seal(
        this: Seal<'_, Self>,
    ) -> Option<Seal<'_, ArchivedRc<T, F>>> {
        // SAFETY: the unsealed `&mut Self` is only read (`ptr`) and then
        // re-sealed as an `ArchivedRc`. Nothing is moved or written through
        // it.
        let this = unsafe { this.unseal_unchecked() };
        if this.ptr.is_invalid() {
            None
        } else {
            // SAFETY: same layout argument as in `upgrade`.
            Some(Seal::new(unsafe { &mut *(this as *mut Self).cast() }))
        }
    }

    /// Writes an `ArchivedRcWeak` for `value` into `out`.
    ///
    /// `None` produces a weak pointer with no target. `Some` resolves the
    /// place as an `ArchivedRc` using `resolver`. `value` is expected to be
    /// `Some`/`None` consistently with what was passed to
    /// `serialize_from_ref`, because the resolver built for `None` carries
    /// only a placeholder position.
    pub fn resolve_from_ref<U: ArchiveUnsized<Archived = T> + ?Sized>(
        value: Option<&U>,
        resolver: RcWeakResolver,
        out: Place<Self>,
    ) {
        match value {
            None => {
                munge!(let ArchivedRcWeak { ptr, _phantom: _ } = out);
                // The invalid pointer is what `upgrade` reports as `None`.
                RelPtr::emplace_invalid(ptr);
            }
            Some(value) => {
                // SAFETY: `Self` and `ArchivedRc<T, F>` share the same
                // `#[repr(transparent)]` layout, so this place can be
                // resolved as a strong pointer.
                let out = unsafe { out.cast_unchecked::<ArchivedRc<T, F>>() };
                ArchivedRc::resolve_from_ref(value, resolver.inner, out);
            }
        }
    }

    /// Serializes the target of a weak pointer, if it has one.
    ///
    /// `None` writes nothing and returns a resolver with placeholder position
    /// `0`. `Some` serializes exactly like `ArchivedRc::serialize_from_ref`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `ArchivedRc::serialize_from_ref`.
    pub fn serialize_from_ref<U, S>(
        value: Option<&U>,
        serializer: &mut S,
    ) -> Result<RcWeakResolver, S::Error>
    where
        U: SerializeUnsized<S, Archived = T> + ?Sized,
        S: Fallible + Writer + Sharing + ?Sized,
        S::Error: Source,
    {
        Ok(match value {
            None => RcWeakResolver {
                inner: RcResolver { pos: 0 },
            },
            Some(r) => RcWeakResolver {
                inner: ArchivedRc::<T, F>::serialize_from_ref(r, serializer)?,
            },
        })
    }
}
300
301impl<T: ArchivePointee + fmt::Debug + ?Sized, F> fmt::Debug
302 for ArchivedRcWeak<T, F>
303{
304 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
305 write!(f, "(Weak)")
306 }
307}
308
/// The resolver for an archived weak pointer.
///
/// It wraps an `RcResolver`. When the weak pointer had no target, the inner
/// position is the unused placeholder `0`.
pub struct RcWeakResolver {
    // Resolver for the strong pointer that a targeted weak pointer is
    // resolved as.
    inner: RcResolver,
}
313
// Validation (`bytecheck::Verify`) for archived shared pointers. This module
// is only compiled when the `bytecheck` feature is enabled.
#[cfg(feature = "bytecheck")]
mod verify {
    use core::{any::TypeId, error::Error, fmt};

    use bytecheck::{
        rancor::{Fallible, Source},
        CheckBytes, Verify,
    };
    use rancor::fail;

    use crate::{
        rc::{ArchivedRc, ArchivedRcWeak, Flavor},
        traits::{ArchivePointee, LayoutRaw},
        validation::{
            shared::ValidationState, ArchiveContext, ArchiveContextExt,
            SharedContext,
        },
    };

    /// Error raised when a shared pointer is reached again while its pointee
    /// is still being validated (a cycle in the pointer graph) and the
    /// pointer's flavor does not allow cycles.
    #[derive(Debug)]
    struct CyclicSharedPointerError;

    impl fmt::Display for CyclicSharedPointerError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "encountered cyclic shared pointers while validating")
        }
    }

    impl Error for CyclicSharedPointerError {}

    // `Verify` is an unsafe trait: returning `Ok` vouches for validity. This
    // impl only succeeds in three cases: the pointee was checked here
    // (`Started`), it was checked earlier (`Finished`), or it is being
    // checked higher up the stack and the flavor tolerates cycles
    // (`Pending`).
    unsafe impl<T, F, C> Verify<C> for ArchivedRc<T, F>
    where
        T: ArchivePointee + CheckBytes<C> + LayoutRaw + ?Sized + 'static,
        T::ArchivedMetadata: CheckBytes<C>,
        F: Flavor,
        C: Fallible + ArchiveContext + SharedContext + ?Sized,
        C::Error: Source,
    {
        fn verify(&self, context: &mut C) -> Result<(), C::Error> {
            // Target pointer computed without assuming it is in bounds;
            // presumably it is range-checked by `in_subtree` before the
            // pointee is read.
            let ptr = self.ptr.as_ptr_wrapping();
            // Shared pointees are tracked by (address, pointer type).
            let type_id = TypeId::of::<ArchivedRc<T, F>>();

            let addr = ptr as *const u8 as usize;
            match context.start_shared(addr, type_id)? {
                // First visit: check the pointee inside its subtree, then
                // record it as fully validated.
                ValidationState::Started => {
                    // SAFETY: assumes `in_subtree` confirms `ptr` lies
                    // within the archive before invoking the closure — TODO
                    // confirm against `ArchiveContextExt::in_subtree`.
                    context.in_subtree(ptr, |context| unsafe {
                        T::check_bytes(ptr, context)
                    })?;
                    context.finish_shared(addr, type_id)?;
                }
                // Reached again before that check finished: a cycle.
                ValidationState::Pending => {
                    if !F::ALLOW_CYCLES {
                        fail!(CyclicSharedPointerError)
                    }
                }
                // Already validated through another pointer.
                ValidationState::Finished => (),
            }

            Ok(())
        }
    }

    // A weak pointer with a target is validated exactly like the strong
    // pointer it can be upgraded to; one without a target is always valid.
    unsafe impl<T, F, C> Verify<C> for ArchivedRcWeak<T, F>
    where
        T: ArchivePointee + CheckBytes<C> + LayoutRaw + ?Sized + 'static,
        T::ArchivedMetadata: CheckBytes<C>,
        F: Flavor,
        C: Fallible + ArchiveContext + SharedContext + ?Sized,
        C::Error: Source,
    {
        fn verify(&self, context: &mut C) -> Result<(), C::Error> {
            if self.ptr.is_invalid() {
                Ok(())
            } else {
                let rc = unsafe {
                    // SAFETY: `ArchivedRcWeak<T, F>` and `ArchivedRc<T, F>`
                    // are both `#[repr(transparent)]` over `RelPtr<T>`, so
                    // the reference can be reinterpreted.
                    &*(self as *const Self).cast::<ArchivedRc<T, F>>()
                };
                rc.verify(context)
            }
        }
    }
}