use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{parse_quote, DataEnum, Error, Fields, Generics, Ident};
use crate::{derive_try_from_bytes_inner, repr::EnumRepr, Trait};
/// Generates a fieldless `___ZerocopyTag` enum that mirrors the variants of
/// the input enum.
///
/// Every input variant is reproduced by name, together with its explicit
/// discriminant (`= expr`) when one was written. Variant fields are dropped.
/// The enum is preceded by the representation tokens derived from `repr`:
/// `#[repr(transparent)]` for `EnumRepr::Transparent`, or the compound repr
/// tokens otherwise.
pub(crate) fn generate_tag_enum(repr: &EnumRepr, data: &DataEnum) -> TokenStream {
    let repr_attr = match repr {
        EnumRepr::Transparent(span) => quote::quote_spanned! { *span => #[repr(transparent)] },
        EnumRepr::Compound(compound, _) => quote! { #compound },
    };

    let tag_variants = data.variants.iter().map(|variant| {
        let name = &variant.ident;
        match &variant.discriminant {
            Some((eq_token, value)) => quote! { #name #eq_token #value },
            None => quote! { #name },
        }
    });

    // The tag enum is only used for its discriminant values, so its variants
    // may never be constructed.
    quote! {
        #repr_attr
        #[allow(dead_code)]
        enum ___ZerocopyTag {
            #(#tag_variants,)*
        }
    }
}
/// Returns the name of the constant that holds the tag value of the variant
/// named `variant_ident` (these constants are emitted by
/// `generate_tag_consts`).
///
/// `Ident`'s `Display` output keeps the `r#` prefix of a raw identifier such
/// as `r#type`. Splicing that into the middle of a name would give
/// `"___ZEROCOPY_TAG_r#type"`, which is not a valid identifier and would make
/// `Ident::new` panic. The prefix is therefore stripped first. The resulting
/// name is never a keyword, so it does not need to be raw itself.
fn tag_ident(variant_ident: &Ident) -> Ident {
    let name = variant_ident.to_string();
    let name = name.strip_prefix("r#").unwrap_or(&name);
    Ident::new(&format!("___ZEROCOPY_TAG_{}", name), variant_ident.span())
}
fn generate_tag_consts(data: &DataEnum) -> TokenStream {
let tags = data.variants.iter().map(|v| {
let variant_ident = &v.ident;
let tag_ident = tag_ident(variant_ident);
quote! {
#[allow(non_upper_case_globals)]
const #tag_ident: ___ZerocopyTagPrimitive =
___ZerocopyTag::#variant_ident as ___ZerocopyTagPrimitive;
}
});
quote! {
#(#tags)*
}
}
/// Returns the name of the per-variant struct generated for the variant named
/// `variant_ident` (these structs are emitted by `generate_variant_structs`).
///
/// `Ident`'s `Display` output keeps the `r#` prefix of a raw identifier such
/// as `r#type`. Splicing that in would give
/// `"___ZerocopyVariantStruct_r#type"`, which is not a valid identifier and
/// would make `Ident::new` panic. The prefix is therefore stripped first.
fn variant_struct_ident(variant_ident: &Ident) -> Ident {
    let name = variant_ident.to_string();
    let name = name.strip_prefix("r#").unwrap_or(&name);
    Ident::new(&format!("___ZerocopyVariantStruct_{}", name), variant_ident.span())
}
fn generate_variant_structs(
enum_name: &Ident,
generics: &Generics,
data: &DataEnum,
) -> TokenStream {
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let phantom_ty = quote! {
core_reexport::marker::PhantomData<#enum_name #ty_generics>
};
let variant_structs = data.variants.iter().filter_map(|variant| {
if matches!(variant.fields, Fields::Unit) {
return None;
}
let variant_struct_ident = variant_struct_ident(&variant.ident);
let field_types = variant.fields.iter().map(|f| &f.ty);
let variant_struct = parse_quote! {
#[repr(C)]
#[allow(non_snake_case)]
struct #variant_struct_ident #impl_generics (
core_reexport::mem::MaybeUninit<___ZerocopyInnerTag>,
#(#field_types,)*
#phantom_ty,
) #where_clause;
};
let try_from_bytes_impl = derive_try_from_bytes_inner(&variant_struct, Trait::TryFromBytes)
.expect("derive_try_from_bytes_inner should not fail on synthesized type");
Some(quote! {
#variant_struct
#try_from_bytes_impl
})
});
quote! {
#(#variant_structs)*
}
}
/// Generates the `#[repr(C)]` union `___ZerocopyVariants`, with one field per
/// variant that has fields.
///
/// Each field holds that variant's struct (see `generate_variant_structs`)
/// wrapped in `ManuallyDrop`; Rust requires union fields to be `Copy` or
/// `ManuallyDrop<_>`. The union is declared with the enum's generic
/// parameters and its where clause. `Generics`' `ToTokens` output omits the
/// where clause, so it is emitted separately. Without it, the union's field
/// types would not be well-formed whenever the variant structs' where-clause
/// bounds are required.
fn generate_variants_union(generics: &Generics, data: &DataEnum) -> TokenStream {
    let (_, ty_generics, where_clause) = generics.split_for_impl();

    let fields = data.variants.iter().filter_map(|variant| {
        // Unit variants have no variant struct, so they get no union field.
        if matches!(variant.fields, Fields::Unit) {
            return None;
        }

        // The `__field_` prefix keeps these names distinct from
        // `__nonempty`. `Ident`'s `Display` output keeps the `r#` prefix of
        // a raw identifier (e.g. `r#type`), which is not valid in the middle
        // of a name and would make `Ident::new` panic, so strip it first.
        let variant_name = variant.ident.to_string();
        let variant_name = variant_name.strip_prefix("r#").unwrap_or(&variant_name);
        let field_name = Ident::new(&format!("__field_{}", variant_name), variant.ident.span());
        let variant_struct_ident = variant_struct_ident(&variant.ident);

        Some(quote! {
            #field_name: core_reexport::mem::ManuallyDrop<
                #variant_struct_ident #ty_generics
            >,
        })
    });

    quote! {
        #[repr(C)]
        #[allow(non_snake_case)]
        union ___ZerocopyVariants #generics #where_clause {
            #(#fields)*
            // Unions must have at least one field, but an enum may consist
            // solely of unit variants; this zero-sized field guarantees the
            // union is never empty.
            __nonempty: (),
        }
    }
}
/// Generates the body of `TryFromBytes::is_bit_valid` for an enum.
///
/// The emitted function first declares helper items local to its body:
/// - the `___ZerocopyTag` enum and the per-variant tag constants;
/// - the per-variant structs;
/// - the `___ZerocopyVariants` union;
/// - a `#[repr(C)]` struct `___ZerocopyRawEnum { tag, variants }`.
///
/// It then:
/// 1. reads a `___ZerocopyTagPrimitive` from the start of `candidate`;
/// 2. reinterprets `candidate` as `___ZerocopyRawEnum` and projects to its
///    `variants` field;
/// 3. matches the tag against the tag constants. A unit variant is valid as
///    soon as its tag matches. A variant with fields delegates to its variant
///    struct's `is_bit_valid`. A tag that matches no variant yields `false`.
///
/// For `#[repr(C)]` the tag type sits in `___ZerocopyRawEnum::tag` and each
/// variant struct's leading slot is `()`. For a primitive repr the tag type
/// sits in each variant struct's leading slot and `___ZerocopyRawEnum::tag`
/// is `()`.
///
/// Items declared inside a function body cannot use the enclosing impl's
/// generic parameters. `___ZerocopyRawEnum` therefore redeclares the enum's
/// generic parameters and where clause. The where clause is emitted
/// separately because `Generics`' `ToTokens` output omits it.
///
/// # Errors
///
/// Returns an error spanned at the call site if `repr` is neither
/// `#[repr(C)]` nor a primitive `#[repr(Int)]`.
pub(crate) fn derive_is_bit_valid(
    enum_ident: &Ident,
    repr: &EnumRepr,
    generics: &Generics,
    data: &DataEnum,
) -> Result<TokenStream, Error> {
    let trait_path = Trait::TryFromBytes.crate_path();
    let tag_enum = generate_tag_enum(repr, data);
    let tag_consts = generate_tag_consts(data);

    // Decide where the tag lives. `repr(C)` is checked first, so a combined
    // repr for which `is_c()` holds uses the outer position. The unused
    // position is filled with the zero-sized `()`.
    let (outer_tag_type, inner_tag_type) = if repr.is_c() {
        (quote! { ___ZerocopyTag }, quote! { () })
    } else if repr.is_primitive() {
        (quote! { () }, quote! { ___ZerocopyTag })
    } else {
        return Err(Error::new(
            Span::call_site(),
            "must have #[repr(C)] or #[repr(Int)] attribute in order to guarantee this type's memory layout",
        ));
    };

    let variant_structs = generate_variant_structs(enum_ident, generics, data);
    let variants_union = generate_variants_union(generics, data);

    let (_, ty_generics, where_clause) = generics.split_for_impl();

    let match_arms = data.variants.iter().map(|variant| {
        let tag_ident = tag_ident(&variant.ident);

        if matches!(variant.fields, Fields::Unit) {
            // A unit variant carries no data; a matching tag is sufficient.
            quote! {
                #tag_ident => true
            }
        } else {
            // Only variants with fields have a variant struct, so its name is
            // only needed on this branch.
            let variant_struct_ident = variant_struct_ident(&variant.ident);
            quote! {
                #tag_ident => {
                    // SAFETY: `___ZerocopyVariants` is a `repr(C)` union with
                    // a (`ManuallyDrop`-wrapped) field of this variant struct
                    // type. All `repr(C)` union fields start at offset 0, so
                    // the cast addresses a subset of the bytes addressed by
                    // `variants`. The result is derived from `p` and keeps its
                    // provenance. The tag matched this variant's constant,
                    // which is what justifies viewing the bytes as this
                    // variant. NOTE(review): confirm this covers every
                    // precondition `Ptr::cast_unsized` documents.
                    let variant = unsafe {
                        variants.cast_unsized(
                            |p: *mut ___ZerocopyVariants #ty_generics| {
                                p as *mut #variant_struct_ident #ty_generics
                            }
                        )
                    };
                    // SAFETY: these bytes are a sub-range of `candidate`'s,
                    // which the code re-asserted as initialized after the
                    // raw-enum cast, and nothing here writes to them.
                    // NOTE(review): assumes `cast_unsized` drops the
                    // initialization invariant, hence the re-assertion.
                    let variant = unsafe { variant.assume_initialized() };
                    <
                        #variant_struct_ident #ty_generics as #trait_path
                    >::is_bit_valid(variant)
                }
            }
        }
    });

    Ok(quote! {
        // Validates the tag against the enum's discriminants, then validates
        // the fields of the variant the tag selects.
        fn is_bit_valid<___ZerocopyAliasing>(
            mut candidate: ::zerocopy::Maybe<'_, Self, ___ZerocopyAliasing>,
        ) -> ::zerocopy::util::macro_util::core_reexport::primitive::bool
        where
            ___ZerocopyAliasing: ::zerocopy::pointer::invariant::Aliasing
                + ::zerocopy::pointer::invariant::AtLeast<::zerocopy::pointer::invariant::Shared>,
        {
            use ::zerocopy::util::macro_util::core_reexport;

            #tag_enum

            type ___ZerocopyTagPrimitive = ::zerocopy::util::macro_util::SizeToTag<
                { core_reexport::mem::size_of::<___ZerocopyTag>() },
            >;

            #tag_consts

            type ___ZerocopyOuterTag = #outer_tag_type;
            type ___ZerocopyInnerTag = #inner_tag_type;

            #variant_structs

            #variants_union

            #[repr(C)]
            struct ___ZerocopyRawEnum #generics #where_clause {
                tag: ___ZerocopyOuterTag,
                variants: ___ZerocopyVariants #ty_generics,
            }

            let tag = {
                // SAFETY: the cast addresses only the leading bytes of
                // `candidate`. The Rust reference's layout rules place the
                // discriminant there for both `repr(C)` and `repr(Int)`
                // enums. The pointer is derived from `candidate` and keeps
                // its provenance. The tag is a primitive integer, so it
                // contains no `UnsafeCell`. NOTE(review): assumes
                // `___ZerocopyTagPrimitive` is an integer no larger than
                // `___ZerocopyTag` (see `SizeToTag`).
                let tag_ptr = unsafe {
                    candidate.reborrow().cast_unsized(|p: *mut Self| {
                        p as *mut ___ZerocopyTagPrimitive
                    })
                };
                // SAFETY: `tag_ptr` is cast from `candidate`, whose bytes are
                // treated as initialized, and nothing has written to them.
                // NOTE(review): confirm `Maybe` guarantees initialized bytes.
                let tag_ptr = unsafe { tag_ptr.assume_initialized() };
                tag_ptr.bikeshed_recall_valid().read_unaligned::<::zerocopy::BecauseImmutable>()
            };

            // SAFETY: `___ZerocopyRawEnum` is built to mirror `Self`'s
            // layout: a `repr(C)` pair of the outer tag and a `repr(C)`
            // union of `repr(C)` variant structs. It has the same types at
            // the same offsets and possibly lower alignment. The cast
            // therefore addresses a subset of `candidate`'s bytes, keeps its
            // provenance, and preserves `UnsafeCell` locations.
            // NOTE(review): re-verify if any generated type changes.
            let raw_enum = unsafe {
                candidate.cast_unsized(|p: *mut Self| {
                    p as *mut ___ZerocopyRawEnum #ty_generics
                })
            };
            // SAFETY: the bytes are `candidate`'s, which are treated as
            // initialized, and the cast above does not write to them.
            let raw_enum = unsafe { raw_enum.assume_initialized() };
            // SAFETY: the closure only computes `addr_of_mut!((*p).variants)`.
            // That is a field of the pointee, so the result stays within
            // `raw_enum`'s referent, shares its provenance, and contains the
            // same `UnsafeCell`s at the same relative locations.
            let variants = unsafe {
                raw_enum.project(|p: *mut ___ZerocopyRawEnum #ty_generics| {
                    core_reexport::ptr::addr_of_mut!((*p).variants)
                })
            };

            // The tag constants carry the variants' (CamelCase) names.
            #[allow(non_upper_case_globals)]
            match tag {
                #(#match_arms,)*
                _ => false,
            }
        }
    })
}