[go: up one dir, main page]

float_eq_derive 0.4.1

Derive macro support for float_eq.
Documentation
//! Derive macros for the traits provided by the [float_eq] crate.
//!
//! [float_eq]: https://crates.io/crates/float_eq

extern crate proc_macro;

use proc_macro2::{Ident, Span, TokenStream};
use quote::{quote, TokenStreamExt};
use syn::{parse_macro_input, DeriveInput};

mod read;

/// Helper for deriving the various float_eq traits.
///
/// By default, this will derive [`FloatUlps`], [`FloatEq`], [`FloatDiff`] and
/// [`FloatEqDebug`] for a struct type. Attribute parameters are passed through
/// to the `#[float_eq(...)]` attribute, see the docs for each trait for more
/// details. You will need to provide an `ulps` parameter with a type name for
/// the derived `Ulps` type.
///
/// If the optional `all_epsilon` parameter is provided then [`FloatEqAll`] and
/// [`FloatEqAllDebug`] are also derived.
///
/// [Example usage] is available in the top level `float_eq` documentation.
///
/// [`FloatUlps`]: trait.FloatUlps.html
/// [`FloatEq`]: trait.FloatEq.html
/// [`FloatEqAll`]: trait.FloatEqAll.html
/// [`FloatDiff`]: trait.FloatDiff.html
/// [`FloatEqDebug`]: trait.FloatEqDebug.html
/// [`FloatEqAllDebug`]: trait.FloatEqAllDebug.html
/// [Example usage]: index.html#derivable
#[proc_macro_attribute]
pub fn derive_float_eq(
    args: proc_macro::TokenStream,
    item: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let args = parse_macro_input!(args as syn::AttributeArgs);
    let item = parse_macro_input!(item as syn::ItemStruct);

    let derive_all = args.iter().map(read::name_type_pair).any(|nv| {
        if let Ok(nv) = nv {
            nv.name == "all_epsilon"
        } else {
            false
        }
    });

    let mut trait_names = vec!["FloatUlps", "FloatEq", "FloatDiff", "FloatEqDebug"];
    if derive_all {
        trait_names.push("FloatEqAll");
        trait_names.push("FloatEqAllDebug");
    }

    let mut traits = TokenStream::new();
    trait_names.into_iter().for_each(|ty| {
        let ident = Ident::new(ty, Span::call_site());
        traits.append_all(quote! { float_eq::#ident, });
    });

    let tokens = quote! {
        #[derive(#traits)]
        #[float_eq(#(#args,)*)]
        #item
    };
    tokens.into()
}

#[doc(hidden)]
#[proc_macro_derive(FloatUlps, attributes(float_eq))]
pub fn derive_float_ulps(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    expand_float_ulps(input)
        .unwrap_or_else(|e| e.to_compile_error())
        .into()
}

fn expand_float_ulps(input: DeriveInput) -> Result<TokenStream, syn::Error> {
    let vis = &input.vis;
    let struct_name = &input.ident;
    let fields = read::all_fields_info("FloatUlps", &input)?;
    let params = read::float_eq_attr(&input)?;
    let ulps_name = params.ulps_type()?;

    let ulps_type = match fields.ty {
        read::FieldListType::Named => {
            let ulps_fields = fields.expand(|field| {
                let name = &field.name;
                let ty = &field.ty;
                quote! { #name: float_eq::Ulps<#ty> }
            });
            quote! {
                #vis struct #ulps_name {
                    #(#ulps_fields,)*
                }
            }
        }
        read::FieldListType::Tuple => {
            let ulps_fields = fields.expand(|field| {
                let ty = &field.ty;
                quote! { float_eq::Ulps<#ty> }
            });
            quote! {
                #vis struct #ulps_name( #(#ulps_fields,)* );
            }
        }
        read::FieldListType::Unit => quote! {
            #vis struct #ulps_name;
        },
    };

    let doc = format!(
        "Floating point ULPs representation derived from {}",
        struct_name.to_string()
    );
    Ok(quote! {
        #[doc = #doc]
        #[derive(Debug, PartialEq)]
        #ulps_type

        impl float_eq::FloatUlps for #struct_name {
            type Ulps = #ulps_name;
        }
    })
}

#[doc(hidden)]
#[proc_macro_derive(FloatDiff, attributes(float_eq))]
pub fn derive_float_diff(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    expand_float_diff(input)
        .unwrap_or_else(|e| e.to_compile_error())
        .into()
}

fn expand_float_diff(input: DeriveInput) -> Result<TokenStream, syn::Error> {
    let struct_name = &input.ident;
    let fields = read::all_fields_info("FloatDiff", &input)?;
    let params = read::float_eq_attr(&input)?;
    let ulps_name = params.ulps_type()?;

    let expand_fields = |method| {
        let method = Ident::new(method, Span::call_site());
        fields.expand(|field| {
            let name = &field.name;
            quote! { #name: self.#name.#method(&other.#name) }
        })
    };

    let abs_diff_fields = expand_fields("abs_diff");
    let ulps_diff_fields = expand_fields("ulps_diff");

    Ok(quote! {
        impl float_eq::FloatDiff for #struct_name {
            type Output = Self;

            #[inline]
            fn abs_diff(&self, other: &Self) -> Self {
                Self {
                    #(#abs_diff_fields,)*
                }
            }

            #[inline]
            fn ulps_diff(&self, other: &Self) -> Option<#ulps_name> {
                Some(
                    #ulps_name {
                        #(#ulps_diff_fields?,)*
                    }
                )
            }
        }
    })
}

#[doc(hidden)]
#[proc_macro_derive(FloatEq, attributes(float_eq))]
pub fn derive_float_eq_attribute(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    expand_float_eq(input)
        .unwrap_or_else(|e| e.to_compile_error())
        .into()
}

fn expand_float_eq(input: DeriveInput) -> Result<TokenStream, syn::Error> {
    let struct_name = &input.ident;
    let fields = read::all_fields_info("FloatEq", &input)?;
    let params = read::float_eq_attr(&input)?;
    let ulps_name = params.ulps_type()?;

    let expand_exprs = |method| {
        let mut expanded = fields.expand(|field| {
            let name = &field.name;
            let method = Ident::new(method, Span::call_site());
            quote! { self.#name.#method(&other.#name, &max_diff.#name) }
        });
        if expanded.is_empty() {
            expanded.push(quote! { true });
        }
        expanded
    };

    let eq_abs = expand_exprs("eq_abs");
    let eq_rel = expand_exprs("eq_rel");
    let eq_ulps = expand_exprs("eq_ulps");

    Ok(quote! {
        impl float_eq::FloatEq for #struct_name {
            type Epsilon = Self;

            #[inline]
            fn eq_abs(&self, other: &Self, max_diff: &Self) -> bool {
                #(#eq_abs)&&*
            }

            #[inline]
            fn eq_rel(&self, other: &Self, max_diff: &Self) -> bool {
                #(#eq_rel)&&*
            }

            #[inline]
            fn eq_ulps(&self, other: &Self, max_diff: &#ulps_name) -> bool {
                #(#eq_ulps)&&*
            }
        }
    })
}

#[doc(hidden)]
#[proc_macro_derive(FloatEqDebug, attributes(float_eq))]
pub fn derive_float_eq_debug(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    expand_float_eq_debug(input)
        .unwrap_or_else(|e| e.to_compile_error())
        .into()
}

fn expand_float_eq_debug(input: DeriveInput) -> Result<TokenStream, syn::Error> {
    let struct_name = &input.ident;
    let fields = read::all_fields_info("FloatEqDebug", &input)?;
    let params = read::float_eq_attr(&input)?;
    let ulps_name = params.ulps_type()?;

    let expand_fields = |method| {
        fields.expand(|field| {
            let name = &field.name;
            let method = Ident::new(method, Span::call_site());
            quote! { #name: self.#name.#method(&other.#name, &max_diff.#name) }
        })
    };

    let abs_eps_fields = expand_fields("debug_abs_epsilon");
    let rel_eps_fields = expand_fields("debug_rel_epsilon");
    let ulps_eps_fields = expand_fields("debug_ulps_epsilon");

    Ok(quote! {
        impl float_eq::FloatEqDebug for #struct_name {
            type DebugEpsilon = Self;

            #[inline]
            fn debug_abs_epsilon(&self, other: &Self, max_diff: &Self) -> Self {
                Self {
                    #(#abs_eps_fields,)*
                }
            }

            #[inline]
            fn debug_rel_epsilon(&self, other: &Self, max_diff: &Self) -> Self {
                Self {
                    #(#rel_eps_fields,)*
                }
            }

            #[inline]
            fn debug_ulps_epsilon(&self, other: &Self, max_diff: &#ulps_name) -> #ulps_name {
                #ulps_name {
                    #(#ulps_eps_fields,)*
                }
            }
        }
    })
}

#[doc(hidden)]
#[proc_macro_derive(FloatEqAll, attributes(float_eq))]
pub fn derive_float_eq_all(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    expand_float_eq_all(input)
        .unwrap_or_else(|e| e.to_compile_error())
        .into()
}

fn expand_float_eq_all(input: DeriveInput) -> Result<TokenStream, syn::Error> {
    let struct_name = &input.ident;
    let fields = read::all_fields_info("FloatEqAll", &input)?;
    let params = read::float_eq_attr(&input)?;
    let all_epsilon = params.all_epsilon_type()?;

    let expand_exprs = |method| {
        let mut expanded = fields.expand(|field| {
            let name = &field.name;
            let method = Ident::new(method, Span::call_site());
            quote! { self.#name.#method(&other.#name, max_diff) }
        });
        if expanded.is_empty() {
            expanded.push(quote! { true });
        }
        expanded
    };

    let eq_abs = expand_exprs("eq_abs_all");
    let eq_rel = expand_exprs("eq_rel_all");
    let eq_ulps = expand_exprs("eq_ulps_all");

    Ok(quote! {
        impl float_eq::FloatEqAll for #struct_name {
            type AllEpsilon = #all_epsilon;

            #[inline]
            fn eq_abs_all(&self, other: &Self, max_diff: &#all_epsilon) -> bool {
                #(#eq_abs)&&*
            }

            #[inline]
            fn eq_rel_all(&self, other: &Self, max_diff: &#all_epsilon) -> bool {
                #(#eq_rel)&&*
            }

            #[inline]
            fn eq_ulps_all(&self, other: &Self, max_diff: &::float_eq::Ulps<Self::AllEpsilon>) -> bool {
                #(#eq_ulps)&&*
            }
        }
    })
}

#[doc(hidden)]
#[proc_macro_derive(FloatEqAllDebug, attributes(float_eq))]
pub fn derive_float_eq_all_debug(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    expand_float_eq_all_debug(input)
        .unwrap_or_else(|e| e.to_compile_error())
        .into()
}

fn expand_float_eq_all_debug(input: DeriveInput) -> Result<TokenStream, syn::Error> {
    let struct_name = &input.ident;
    let fields = read::all_fields_info("FloatEqAllDebug", &input)?;
    let params = read::float_eq_attr(&input)?;
    let all_epsilon = params.all_epsilon_type()?;

    let expand_fields = |method| {
        fields.expand(|field| {
            let name = &field.name;
            let method = Ident::new(method, Span::call_site());
            quote! { #name: self.#name.#method(&other.#name, max_diff) }
        })
    };

    let abs_eps_fields = expand_fields("debug_abs_all_epsilon");
    let rel_eps_fields = expand_fields("debug_rel_all_epsilon");
    let ulps_eps_fields = expand_fields("debug_ulps_all_epsilon");

    Ok(quote! {
        impl float_eq::FloatEqAllDebug for #struct_name {
            type AllDebugEpsilon = Self;

            #[inline]
            fn debug_abs_all_epsilon(&self, other: &Self, max_diff: &#all_epsilon) -> Self {
                Self {
                    #(#abs_eps_fields,)*
                }
            }

            #[inline]
            fn debug_rel_all_epsilon(&self, other: &Self, max_diff: &#all_epsilon) -> Self {
                Self {
                    #(#rel_eps_fields,)*
                }
            }

            #[inline]
            fn debug_ulps_all_epsilon(
                &self,
                other: &Self,
                max_diff: &::float_eq::Ulps<Self::AllEpsilon>
            ) -> float_eq::Ulps<Self::AllDebugEpsilon> {
                float_eq::Ulps::<Self::AllDebugEpsilon> {
                    #(#ulps_eps_fields,)*
                }
            }
        }
    })
}