#![cfg_attr(docsrs, feature(doc_cfg))]
#![deny(missing_docs)]
extern crate proc_macro;
use proc_macro::TokenStream;
use proc_macro2::{Literal, TokenTree};
use quote::{format_ident, quote, ToTokens, TokenStreamExt};
use std::ops::Deref;
use syn::Result as SynResult;
/// Runs the annotated test through the `serial_test` serial locking helpers.
///
/// The function body is handed to `serial_test::local_serial_core`. Depending
/// on the signature, one of its `async` / `_with_return` variants is used
/// instead. Those runtime helpers perform the actual serialisation; this macro
/// only rewrites the item.
///
/// # Arguments
/// * `#[serial]` – no key; the single key `""` is used.
/// * `#[serial(a, b)]` – comma-separated identifier keys. Their order does not
///   matter.
/// * `path => "..."` – optional string, forwarded to the helper as
///   `Some("...")`.
///
/// The attribute may be placed on a function or on an inline module. On a
/// module, every function carrying an attribute whose path contains `test`
/// (e.g. `#[test]`, `#[tokio::test]`) is rewritten.
///
/// # Examples
/// ```ignore
/// #[test]
/// #[serial(database)]
/// fn uses_database() {}
/// ```
///
/// # Panics
/// Expansion panics, which surfaces as a compile error, in these cases:
/// * the arguments are malformed;
/// * the item is neither a function nor a module;
/// * an `async fn` is used without the `async` feature.
#[proc_macro_attribute]
pub fn serial(attr: TokenStream, input: TokenStream) -> TokenStream {
    local_serial_core(attr.into(), input.into()).into()
}
/// Runs the annotated test through the `serial_test` parallel locking helpers.
///
/// The function body is handed to `serial_test::local_parallel_core`.
/// Depending on the signature, one of its `async` / `_with_return` variants is
/// used instead. Those runtime helpers implement the coordination with
/// `#[serial]` tests; this macro only rewrites the item.
///
/// # Arguments
/// Same as `#[serial]`:
/// * identifier keys, comma separated;
/// * an optional `path => "..."`;
/// * no arguments means the key `""`.
///
/// Works on functions and on inline modules. In a module, only functions with
/// an attribute whose path contains `test` are rewritten.
///
/// # Examples
/// ```ignore
/// #[test]
/// #[parallel(database)]
/// fn reads_database() {}
/// ```
///
/// # Panics
/// Expansion panics, which surfaces as a compile error, in these cases:
/// * the arguments are malformed;
/// * the item is neither a function nor a module;
/// * an `async fn` is used without the `async` feature.
#[proc_macro_attribute]
pub fn parallel(attr: TokenStream, input: TokenStream) -> TokenStream {
    local_parallel_core(attr.into(), input.into()).into()
}
/// File-lock flavour of `#[serial]`.
///
/// The function body is handed to `serial_test::fs_serial_core`. Depending on
/// the signature, one of its `async` / `_with_return` variants is used
/// instead. Those helpers implement the file-based locking; this macro only
/// rewrites the item.
///
/// # Arguments
/// * identifier keys, comma separated;
/// * an optional `path => "some/file"`, forwarded to the helper as
///   `Some("some/file")`;
/// * no arguments means the key `""`.
///
/// Works on functions and on inline modules. In a module, only functions with
/// an attribute whose path contains `test` are rewritten.
///
/// # Examples
/// ```ignore
/// #[test]
/// #[file_serial(shared_file, path => "/tmp/my_lock")]
/// fn writes_shared_file() {}
/// ```
///
/// # Panics
/// Expansion panics, which surfaces as a compile error, in these cases:
/// * the arguments are malformed;
/// * the item is neither a function nor a module;
/// * an `async fn` is used without the `async` feature.
#[proc_macro_attribute]
#[cfg_attr(docsrs, doc(cfg(feature = "file_locks")))]
pub fn file_serial(attr: TokenStream, input: TokenStream) -> TokenStream {
    fs_serial_core(attr.into(), input.into()).into()
}
/// File-lock flavour of `#[parallel]`.
///
/// The function body is handed to `serial_test::fs_parallel_core`. Depending
/// on the signature, one of its `async` / `_with_return` variants is used
/// instead. Those helpers implement the file-based coordination; this macro
/// only rewrites the item.
///
/// # Arguments
/// * identifier keys, comma separated;
/// * an optional `path => "some/file"`, forwarded to the helper as
///   `Some("some/file")`;
/// * no arguments means the key `""`.
///
/// Works on functions and on inline modules. In a module, only functions with
/// an attribute whose path contains `test` are rewritten.
///
/// # Examples
/// ```ignore
/// #[test]
/// #[file_parallel(shared_file)]
/// fn reads_shared_file() {}
/// ```
///
/// # Panics
/// Expansion panics, which surfaces as a compile error, in these cases:
/// * the arguments are malformed;
/// * the item is neither a function nor a module;
/// * an `async fn` is used without the `async` feature.
#[proc_macro_attribute]
#[cfg_attr(docsrs, doc(cfg(feature = "file_locks")))]
pub fn file_parallel(attr: TokenStream, input: TokenStream) -> TokenStream {
    fs_parallel_core(attr.into(), input.into()).into()
}
/// `Option<T>` wrapper that quotes as an `::std::option::Option` expression.
///
/// Absolute paths are used so the generated expression does not depend on
/// which names are in scope at the macro call site.
#[derive(Default, Debug, Clone)]
struct QuoteOption<T>(Option<T>);

impl<T: ToTokens> ToTokens for QuoteOption<T> {
    /// Emits `::std::option::Option::Some(<inner>)`, or
    /// `::std::option::Option::None` when empty.
    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
        let expr = if let Some(value) = &self.0 {
            quote! { ::std::option::Option::Some(#value) }
        } else {
            quote! { ::std::option::Option::None }
        };
        tokens.append_all(expr);
    }
}
/// Arguments parsed from one of this crate's attributes by `get_config`.
#[derive(Default, Debug)]
struct Config {
    /// Lock key names, sorted. `get_config` substitutes a single `""` when the
    /// attribute has no keys.
    names: Vec<String>,
    /// Value of the optional `path => "..."` argument. It is emitted as an
    /// `Option` expression in the generated helper call.
    path: QuoteOption<String>,
}
fn string_from_literal(literal: Literal) -> String {
let string_literal = literal.to_string();
if !string_literal.starts_with('\"') || !string_literal.ends_with('\"') {
panic!("Expected a string literal, got '{}'", string_literal);
}
string_literal[1..string_literal.len() - 1].to_string()
}
/// Parses the arguments of a `serial` / `parallel` / `file_serial` /
/// `file_parallel` attribute into a [`Config`].
///
/// Accepted grammar is a comma-separated list, trailing comma allowed, of:
/// * an identifier, used as a lock key name;
/// * `path => "<string literal>"`. The `path` keyword is matched
///   case-insensitively, and a later `path` overrides an earlier one.
///
/// Key names are sorted, so argument order does not affect the output, and
/// de-duplicated, so a key is never requested twice for one test. With no
/// keys at all, the single key `""` is used.
///
/// # Panics
/// On any token that does not fit the grammar. The panic message becomes the
/// user's compile error.
fn get_config(attr: proc_macro2::TokenStream) -> Config {
    let mut tokens = attr.into_iter();
    let mut raw_args: Vec<String> = Vec::new();
    let mut path: Option<String> = None;
    // `tokens` is also advanced inside the body, so a `for` loop can't be used.
    while let Some(token) = tokens.next() {
        match token {
            TokenTree::Ident(id) if id.to_string().eq_ignore_ascii_case("path") => {
                path = Some(parse_path_argument(&mut tokens));
            }
            TokenTree::Ident(id) => raw_args.push(id.to_string()),
            x => panic!(
                "Expected literal as key args (or a 'path => '\"foo\"'), not {}",
                x
            ),
        }
        if let Some(separator) = tokens.next() {
            match separator {
                TokenTree::Punct(p) if p.as_char() == ',' => {}
                x => panic!("Expected , between args, not {}", x),
            }
        }
    }
    if raw_args.is_empty() {
        raw_args.push(String::new());
    }
    raw_args.sort();
    raw_args.dedup();
    Config {
        names: raw_args,
        path: QuoteOption(path),
    }
}

/// Consumes the `=> "<path>"` tokens that must follow a `path` key and returns
/// the path string.
///
/// # Panics
/// In either of these cases:
/// * fewer than three tokens remain;
/// * the tokens are not `=`, `>` and a string literal, in that order.
fn parse_path_argument(tokens: &mut impl Iterator<Item = TokenTree>) -> String {
    // Tuple fields are evaluated left to right, so this takes the next three
    // tokens in order.
    let (eq, gt, value) = match (tokens.next(), tokens.next(), tokens.next()) {
        (Some(eq), Some(gt), Some(value)) => (eq, gt, value),
        _ => panic!("Expected a '=> <path>' after 'path'"),
    };
    match eq {
        TokenTree::Punct(p) if p.as_char() == '=' => {}
        x => panic!("Expected = after path, not {}", x),
    }
    match gt {
        TokenTree::Punct(p) if p.as_char() == '>' => {}
        x => panic!("Expected > after path, not {}", x),
    }
    match value {
        TokenTree::Literal(literal) => string_from_literal(literal),
        x => panic!("Expected literals as path arg, not {}", x),
    }
}
/// Expansion behind `#[serial]`.
///
/// Parses the attribute arguments, then rewrites `input` to use the `local`
/// serial helpers.
fn local_serial_core(
    attr: proc_macro2::TokenStream,
    input: proc_macro2::TokenStream,
) -> proc_macro2::TokenStream {
    serial_setup(input, get_config(attr), "local")
}
/// Expansion behind `#[parallel]`.
///
/// Parses the attribute arguments, then rewrites `input` to use the `local`
/// parallel helpers.
fn local_parallel_core(
    attr: proc_macro2::TokenStream,
    input: proc_macro2::TokenStream,
) -> proc_macro2::TokenStream {
    parallel_setup(input, get_config(attr), "local")
}
/// Expansion behind `#[file_serial]`.
///
/// Parses the attribute arguments, then rewrites `input` to use the `fs`
/// serial helpers.
fn fs_serial_core(
    attr: proc_macro2::TokenStream,
    input: proc_macro2::TokenStream,
) -> proc_macro2::TokenStream {
    serial_setup(input, get_config(attr), "fs")
}
/// Expansion behind `#[file_parallel]`.
///
/// Parses the attribute arguments, then rewrites `input` to use the `fs`
/// parallel helpers.
fn fs_parallel_core(
    attr: proc_macro2::TokenStream,
    input: proc_macro2::TokenStream,
) -> proc_macro2::TokenStream {
    parallel_setup(input, get_config(attr), "fs")
}
/// Expands an attribute applied to either a single function or an inline
/// module.
///
/// * A function is rewritten as a whole by [`fn_setup`].
/// * In a module, each function recognised by [`is_test_fn`] is rewritten by
///   [`fn_setup`]. Other items are kept unchanged, and any `#[serial]`
///   attribute still present on the module itself is removed.
///
/// # Panics
/// In either of these cases:
/// * `input` parses as neither a function nor a module;
/// * a rewritten function fails to re-parse as an item (an internal bug).
fn core_setup(
    input: proc_macro2::TokenStream,
    config: &Config,
    prefix: &str,
    kind: &str,
) -> proc_macro2::TokenStream {
    if let Ok(ast) = syn::parse2::<syn::ItemFn>(input.clone()) {
        return fn_setup(ast, config, prefix, kind);
    }
    let mut ast: syn::ItemMod = match syn::parse2(input) {
        Ok(ast) => ast,
        Err(_) => panic!("Attribute applied to something other than mod or fn!"),
    };
    // Rewrite the module's items in place; `mem::take` avoids cloning them.
    if let Some((_brace, items)) = ast.content.as_mut() {
        *items = std::mem::take(items)
            .into_iter()
            .map(|item| match item {
                syn::Item::Fn(item_fn) if is_test_fn(&item_fn) => {
                    let tokens = fn_setup(item_fn, config, prefix, kind);
                    // Only format the (possibly large) token stream on failure.
                    syn::parse2(tokens.clone())
                        .unwrap_or_else(|err| panic!("tokens: {tokens}: {err:?}"))
                }
                other => other,
            })
            .collect();
    }
    ast.attrs.retain(|attr| {
        attr.meta
            .path()
            .segments
            .first()
            .map_or(true, |segment| segment.ident != "serial")
    });
    ast.into_token_stream()
}

/// Returns `true` when any attribute on `item_fn` has a path segment
/// containing `test`. This matches `#[test]`, `#[tokio::test]` and the like.
fn is_test_fn(item_fn: &syn::ItemFn) -> bool {
    item_fn.attrs.iter().any(|attr| {
        attr.meta
            .path()
            .segments
            .iter()
            .any(|segment| segment.ident.to_string().contains("test"))
    })
}
/// Rewrites one function so that its body runs inside a `serial_test`
/// runtime helper.
///
/// The helper name is `serial_test::<prefix>_[async_]<kind>_core[_with_return]`.
///
/// The generated function keeps the original attributes, visibility, name and
/// return type. It is always emitted with an empty parameter list.
///
/// * `async fn`: the original body moves into a nested
///   `async fn _<name>_internal()`. The resulting future is passed to the
///   helper and awaited.
/// * Plain `fn`: the body becomes a `|| { ... }` closure argument.
///
/// The helper also receives the key names and path from `config`. Without a
/// return type, the helper call is terminated with `;`.
///
/// # Panics
/// If the function is `async` but the crate was built without the `async`
/// feature.
fn fn_setup(
    ast: syn::ItemFn,
    config: &Config,
    prefix: &str,
    kind: &str,
) -> proc_macro2::TokenStream {
    let syn::ItemFn {
        attrs,
        vis,
        sig,
        block,
        ..
    } = ast;
    let is_async = sig.asyncness.is_some();
    if is_async && cfg!(not(feature = "async")) {
        panic!("async testing attempted with async feature disabled in serial_test!");
    }
    let name = sig.ident;
    #[cfg(all(feature = "test_logging", not(test)))]
    let print_name = {
        let print_str = format!("Starting {name}");
        quote! {
            println!(#print_str);
        }
    };
    #[cfg(any(not(feature = "test_logging"), test))]
    let print_name = quote! {};

    let return_type = match &sig.output {
        syn::ReturnType::Default => None,
        syn::ReturnType::Type(_rarrow, boxed) => Some(boxed.deref()),
    };
    let return_arrow = return_type.map(|ty| quote! { -> #ty });

    let core_fn = {
        let mut helper = if is_async {
            format!("{}_async_{}_core", prefix, kind)
        } else {
            format!("{}_{}_core", prefix, kind)
        };
        if return_type.is_some() {
            helper.push_str("_with_return");
        }
        format_ident!("{}", helper)
    };
    // With a return value the helper call is the tail expression; otherwise
    // it is a statement.
    let statement_end = if return_type.is_some() {
        None
    } else {
        Some(quote! { ; })
    };

    let names = &config.names;
    let path = &config.path;
    let inner = if is_async {
        let temp_fn = format_ident!("_{}_internal", name);
        quote! {
            async fn #temp_fn () #return_arrow
            #block
            #print_name
            serial_test::#core_fn(vec![#(#names ),*], #path, #temp_fn()).await #statement_end
        }
    } else {
        quote! {
            #print_name
            serial_test::#core_fn(vec![#(#names ),*], #path, || #block ) #statement_end
        }
    };
    let async_kw = if is_async { Some(quote! { async }) } else { None };

    quote! {
        #(#attrs)*
        #vis #async_kw fn #name () #return_arrow {
            #inner
        }
    }
}
/// Runs [`core_setup`] with the `"serial"` helper kind.
fn serial_setup(
    input: proc_macro2::TokenStream,
    config: Config,
    prefix: &str,
) -> proc_macro2::TokenStream {
    const KIND: &str = "serial";
    core_setup(input, &config, prefix, KIND)
}
/// Runs [`core_setup`] with the `"parallel"` helper kind.
fn parallel_setup(
    input: proc_macro2::TokenStream,
    config: Config,
    prefix: &str,
) -> proc_macro2::TokenStream {
    const KIND: &str = "parallel";
    core_setup(input, &config, prefix, KIND)
}
#[cfg(test)]
mod tests {
    use super::{fs_parallel_core, fs_serial_core, local_parallel_core, local_serial_core};
    use proc_macro2::TokenStream;
    use quote::quote;

    /// Installs `env_logger` at most once per process.
    ///
    /// `is_test(true)` lets the test harness capture log output instead of it
    /// going straight to stderr.
    fn init() {
        let _ = env_logger::builder().is_test(true).try_init();
    }

    /// Pretty-prints a single item, so streams can be compared independently
    /// of token spacing.
    fn unparse(input: TokenStream) -> String {
        let item = syn::parse2(input).unwrap();
        let file = syn::File {
            attrs: vec![],
            items: vec![item],
            shebang: None,
        };
        prettyplease::unparse(&file)
    }

    /// Asserts that two item token streams pretty-print identically.
    fn compare_streams(first: TokenStream, second: TokenStream) {
        assert_eq!(unparse(first), unparse(second));
    }

    #[test]
    fn test_serial() {
        init();
        let attrs = TokenStream::new();
        let input = quote! {
            #[test]
            fn foo() {}
        };
        let stream = local_serial_core(attrs, input);
        let compare = quote! {
            #[test]
            fn foo () {
                serial_test::local_serial_core(vec![""], ::std::option::Option::None, || {} );
            }
        };
        compare_streams(compare, stream);
    }

    #[test]
    fn test_parallel() {
        init();
        let attrs = TokenStream::new();
        let input = quote! {
            #[test]
            fn foo() {}
        };
        let stream = local_parallel_core(attrs, input);
        let compare = quote! {
            #[test]
            fn foo () {
                serial_test::local_parallel_core(vec![""], ::std::option::Option::None, || {} );
            }
        };
        compare_streams(compare, stream);
    }

    #[test]
    fn test_serial_with_pub() {
        init();
        let attrs = TokenStream::new();
        let input = quote! {
            #[test]
            pub fn foo() {}
        };
        let stream = local_serial_core(attrs, input);
        let compare = quote! {
            #[test]
            pub fn foo () {
                serial_test::local_serial_core(vec![""], ::std::option::Option::None, || {} );
            }
        };
        compare_streams(compare, stream);
    }

    #[test]
    fn test_other_attributes() {
        init();
        let attrs = TokenStream::new();
        let input = quote! {
            #[test]
            #[ignore]
            #[should_panic(expected = "Testing panic")]
            #[something_else]
            fn foo() {}
        };
        let stream = local_serial_core(attrs, input);
        let compare = quote! {
            #[test]
            #[ignore]
            #[should_panic(expected = "Testing panic")]
            #[something_else]
            fn foo () {
                serial_test::local_serial_core(vec![""], ::std::option::Option::None, || {} );
            }
        };
        compare_streams(compare, stream);
    }

    #[test]
    #[cfg(feature = "async")]
    fn test_serial_async() {
        init();
        let attrs = TokenStream::new();
        let input = quote! {
            async fn foo() {}
        };
        let stream = local_serial_core(attrs, input);
        let compare = quote! {
            async fn foo () {
                async fn _foo_internal () { }
                serial_test::local_async_serial_core(vec![""], ::std::option::Option::None, _foo_internal() ).await;
            }
        };
        assert_eq!(format!("{}", compare), format!("{}", stream));
    }

    #[test]
    #[cfg(feature = "async")]
    fn test_serial_async_return() {
        init();
        let attrs = TokenStream::new();
        let input = quote! {
            async fn foo() -> Result<(), ()> { Ok(()) }
        };
        let stream = local_serial_core(attrs, input);
        let compare = quote! {
            async fn foo () -> Result<(), ()> {
                async fn _foo_internal () -> Result<(), ()> { Ok(()) }
                serial_test::local_async_serial_core_with_return(vec![""], ::std::option::Option::None, _foo_internal() ).await
            }
        };
        assert_eq!(format!("{}", compare), format!("{}", stream));
    }

    #[test]
    fn test_file_serial() {
        init();
        let attrs = quote! { foo };
        let input = quote! {
            #[test]
            fn foo() {}
        };
        let stream = fs_serial_core(attrs, input);
        let compare = quote! {
            #[test]
            fn foo () {
                serial_test::fs_serial_core(vec!["foo"], ::std::option::Option::None, || {} );
            }
        };
        compare_streams(compare, stream);
    }

    #[test]
    fn test_file_serial_no_args() {
        init();
        let attrs = TokenStream::new();
        let input = quote! {
            #[test]
            fn foo() {}
        };
        let stream = fs_serial_core(attrs, input);
        let compare = quote! {
            #[test]
            fn foo () {
                serial_test::fs_serial_core(vec![""], ::std::option::Option::None, || {} );
            }
        };
        compare_streams(compare, stream);
    }

    #[test]
    fn test_file_serial_with_path() {
        init();
        let attrs = quote! { foo, path => "bar_path" };
        let input = quote! {
            #[test]
            fn foo() {}
        };
        let stream = fs_serial_core(attrs, input);
        let compare = quote! {
            #[test]
            fn foo () {
                serial_test::fs_serial_core(vec!["foo"], ::std::option::Option::Some("bar_path"), || {} );
            }
        };
        compare_streams(compare, stream);
    }

    #[test]
    fn test_file_parallel_with_path() {
        init();
        let attrs = quote! { foo, path => "bar_path" };
        let input = quote! {
            #[test]
            fn foo() {}
        };
        let stream = fs_parallel_core(attrs, input);
        let compare = quote! {
            #[test]
            fn foo () {
                serial_test::fs_parallel_core(vec!["foo"], ::std::option::Option::Some("bar_path"), || {} );
            }
        };
        compare_streams(compare, stream);
    }

    #[test]
    fn test_single_attr() {
        init();
        let attrs = quote! { one };
        let input = quote! {
            #[test]
            fn single() {}
        };
        let stream = local_serial_core(attrs, input);
        let compare = quote! {
            #[test]
            fn single () {
                serial_test::local_serial_core(vec!["one"], ::std::option::Option::None, || {} );
            }
        };
        compare_streams(compare, stream);
    }

    #[test]
    fn test_multiple_attr() {
        init();
        let attrs = quote! { two, one };
        let input = quote! {
            #[test]
            fn multiple() {}
        };
        let stream = local_serial_core(attrs, input);
        let compare = quote! {
            #[test]
            fn multiple () {
                serial_test::local_serial_core(vec!["one", "two"], ::std::option::Option::None, || {} );
            }
        };
        compare_streams(compare, stream);
    }

    #[test]
    fn test_mod() {
        init();
        let attrs = TokenStream::new();
        let input = quote! {
            #[cfg(test)]
            #[serial]
            mod serial_attr_tests {
                pub fn foo() {
                    println!("Nothing");
                }
                #[test]
                fn bar() {}
            }
        };
        let stream = local_serial_core(attrs, input);
        let compare = quote! {
            #[cfg(test)]
            mod serial_attr_tests {
                pub fn foo() {
                    println!("Nothing");
                }
                #[test]
                fn bar() {
                    serial_test::local_serial_core(vec![""], ::std::option::Option::None, || {} );
                }
            }
        };
        compare_streams(compare, stream);
    }

    #[test]
    fn test_later_test_mod() {
        init();
        let attrs = TokenStream::new();
        let input = quote! {
            #[cfg(test)]
            #[serial]
            mod serial_attr_tests {
                pub fn foo() {
                    println!("Nothing");
                }
                #[demo_library::test]
                fn bar() {}
            }
        };
        let stream = local_serial_core(attrs, input);
        let compare = quote! {
            #[cfg(test)]
            mod serial_attr_tests {
                pub fn foo() {
                    println!("Nothing");
                }
                #[demo_library::test]
                fn bar() {
                    serial_test::local_serial_core(vec![""], ::std::option::Option::None, || {} );
                }
            }
        };
        compare_streams(compare, stream);
    }

    #[test]
    #[cfg(feature = "async")]
    fn test_mod_with_async() {
        init();
        let attrs = TokenStream::new();
        let input = quote! {
            #[cfg(test)]
            #[serial]
            mod serial_attr_tests {
                #[demo_library::test]
                async fn foo() -> Result<(), ()> {
                    Ok(())
                }
                #[demo_library::test]
                #[ignore = "bla"]
                async fn bar() -> Result<(), ()> {
                    Ok(())
                }
            }
        };
        let stream = local_serial_core(attrs, input);
        let compare = quote! {
            #[cfg(test)]
            mod serial_attr_tests {
                #[demo_library::test]
                async fn foo() -> Result<(), ()> {
                    async fn _foo_internal() -> Result<(), ()> { Ok(())}
                    serial_test::local_async_serial_core_with_return(vec![""], ::std::option::Option::None, _foo_internal() ).await
                }
                #[demo_library::test]
                #[ignore = "bla"]
                async fn bar() -> Result<(), ()> {
                    async fn _bar_internal() -> Result<(), ()> { Ok(())}
                    serial_test::local_async_serial_core_with_return(vec![""], ::std::option::Option::None, _bar_internal() ).await
                }
            }
        };
        compare_streams(compare, stream);
    }
}