use std::{
cell::UnsafeCell,
mem::{ManuallyDrop, MaybeUninit},
};
/// Per-sample storage whose active field is selected at compile time by
/// `DeferStore::ONLY_INPUTS` (which is `!std::mem::needs_drop::<O>()`):
///
/// - `inputs` is active when `O` has no drop glue, so only inputs are stored.
/// - `slots` is active otherwise, storing an input and an output per sample.
///
/// Both fields are wrapped in `ManuallyDrop` because union fields may not have
/// drop glue; the `Drop` impl drops whichever field is active.
pub(crate) union DeferStore<I, O> {
/// Input/output pairs; the active field when `O` needs dropping.
slots: ManuallyDrop<Vec<DeferSlot<I, O>>>,
/// Inputs only; the active field when `O` does not need dropping.
inputs: ManuallyDrop<Vec<DeferSlotItem<I>>>,
}
impl<I, O> Drop for DeferStore<I, O> {
    /// Releases the `Vec` held by whichever union field is active.
    #[inline]
    fn drop(&mut self) {
        if Self::ONLY_INPUTS {
            // SAFETY: `inputs` is the field initialized (see the `Default` impl)
            // when `ONLY_INPUTS` holds, and it is dropped exactly once, here.
            unsafe { ManuallyDrop::drop(&mut self.inputs) }
        } else {
            // SAFETY: `slots` is the field initialized (see the `Default` impl)
            // when `ONLY_INPUTS` is false, and it is dropped exactly once, here.
            unsafe { ManuallyDrop::drop(&mut self.slots) }
        }
    }
}
impl<I, O> Default for DeferStore<I, O> {
    /// Creates an empty store with the field that matches `Self::ONLY_INPUTS`
    /// initialized: `inputs` when `O` has no drop glue, `slots` otherwise.
    ///
    /// Initializing a union field is safe (only reading one requires `unsafe`),
    /// so no `unsafe` block is needed here.
    #[inline]
    fn default() -> Self {
        if Self::ONLY_INPUTS {
            Self { inputs: ManuallyDrop::new(Vec::new()) }
        } else {
            Self { slots: ManuallyDrop::new(Vec::new()) }
        }
    }
}
impl<I, O> DeferStore<I, O> {
    /// `true` when `O` has no drop glue. In that case only inputs are stored
    /// (`inputs` is the active field); otherwise `slots` is active.
    const ONLY_INPUTS: bool = !std::mem::needs_drop::<O>();

    /// Resets the active `Vec` to exactly `sample_size` uninitialized entries.
    ///
    /// Existing entries are discarded without running destructors for any
    /// values written into them, because `UnsafeCell<MaybeUninit<_>>` has no
    /// drop glue.
    #[inline]
    pub fn prepare(&mut self, sample_size: usize) {
        /// Clears `vec`, ensures capacity for at least `len` items, then sets
        /// its length to `len` without initializing anything.
        ///
        /// # Safety
        ///
        /// An uninitialized `T` must be a valid value of `T`.
        #[inline]
        unsafe fn reset_uninit<T>(vec: &mut Vec<T>, len: usize) {
            vec.clear();
            vec.reserve_exact(len);
            // SAFETY: capacity is at least `len` after `reserve_exact`, and the
            // caller guarantees uninitialized `T` is valid.
            unsafe { vec.set_len(len) }
        }

        if Self::ONLY_INPUTS {
            // SAFETY: `inputs` is the active field when `ONLY_INPUTS` holds, and
            // `DeferSlotItem<I>` (`UnsafeCell<MaybeUninit<I>>`) is valid uninitialized.
            unsafe { reset_uninit(&mut *self.inputs, sample_size) }
        } else {
            // SAFETY: `slots` is the active field when `ONLY_INPUTS` is false, and
            // `DeferSlot` consists only of `UnsafeCell<MaybeUninit<_>>` fields.
            unsafe { reset_uninit(&mut *self.slots, sample_size) }
        }
    }

    /// Borrows the active storage as a slice.
    ///
    /// Returns `Ok` with input/output slots when `O` needs dropping, and `Err`
    /// with bare inputs when it does not (`Err` here is not an error condition).
    #[inline(always)]
    pub fn slots(&self) -> Result<&[DeferSlot<I, O>], &[DeferSlotItem<I>]> {
        if Self::ONLY_INPUTS {
            // SAFETY: `inputs` is the active field when `ONLY_INPUTS` holds.
            Err(unsafe { self.inputs.as_slice() })
        } else {
            // SAFETY: `slots` is the active field when `ONLY_INPUTS` is false.
            Ok(unsafe { self.slots.as_slice() })
        }
    }
}
/// One sample's input and output storage.
///
/// Both fields are `UnsafeCell<MaybeUninit<_>>`, so a slot is a valid value
/// while uninitialized, and its contents can be written through a shared
/// reference via `UnsafeCell::get`.
// NOTE(review): `repr(C)` fixes the field order (`input` first, then `output`);
// presumably some caller relies on this layout — confirm before changing it.
#[repr(C)]
pub(crate) struct DeferSlot<I, O> {
/// Input value for this sample; may be uninitialized.
pub input: DeferSlotItem<I>,
/// Output value for this sample; may be uninitialized.
pub output: DeferSlotItem<O>,
}
/// A possibly-uninitialized `T` that can be written through a shared reference.
type DeferSlotItem<T> = UnsafeCell<MaybeUninit<T>>;
#[cfg(test)]
mod tests {
use super::*;
/// Checks that a `DeferSlot` taken from uninitialized memory can have its
/// fields overwritten and then read back.
///
/// Taking `&mut` to the uninitialized slot is acceptable because every field is
/// `UnsafeCell<MaybeUninit<_>>`, which is valid when uninitialized. Assigning a
/// field drops the previous `UnsafeCell<MaybeUninit<_>>`, which has no drop glue.
#[test]
fn access_uninit_slot() {
let mut slot: MaybeUninit<DeferSlot<String, String>> = MaybeUninit::uninit();
let slot_ref = unsafe { slot.assume_init_mut() };
slot_ref.input = UnsafeCell::new(MaybeUninit::new(String::new()));
slot_ref.output = UnsafeCell::new(MaybeUninit::new(String::new()));
unsafe {
let slot = slot.assume_init();
// Both fields were written above, so `assume_init` on each is sound.
assert_eq!(slot.input.into_inner().assume_init(), "");
assert_eq!(slot.output.into_inner().assume_init(), "");
}
}
/// Checks that an output may hold a `&mut` into the input stored in the same
/// slot, that dropping the output in place can read and mutate that input, and
/// that the slot can be reused across iterations.
///
/// All writes go through raw pointers from `UnsafeCell::get` on a shared
/// `&DeferSlot`.
// NOTE(review): aliasing-sensitive; worth running under `cargo miri test`.
#[test]
fn access_aliased_input() {
struct Output<'i> {
input: &'i mut String,
}
impl Drop for Output<'_> {
// Observes the borrowed input, then mutates it during drop.
fn drop(&mut self) {
assert_eq!(self.input, "hello");
self.input.push_str(" world");
}
}
let slot: MaybeUninit<DeferSlot<String, Output>> = MaybeUninit::uninit();
let slot_ref = unsafe { slot.assume_init_ref() };
// Reuse the same slot several times: each iteration fully initializes and
// then fully drops both fields.
for _ in 0..5 {
unsafe {
let input_ptr = slot_ref.input.get().cast::<String>();
let output_ptr = slot_ref.output.get().cast::<Output>();
input_ptr.write("hello".to_owned());
// The output borrows the input that sits next to it in the slot.
output_ptr.write(Output { input: &mut *input_ptr });
assert_eq!((*output_ptr).input, "hello");
// Drop the output first: its `Drop` impl reads and appends to the input.
output_ptr.drop_in_place();
// The input is still live and reflects the output's mutation; drop it last.
assert_eq!(&*input_ptr, "hello world");
input_ptr.drop_in_place();
}
}
}
}