use std::{
any::{Any, TypeId},
borrow::Cow,
mem, slice,
sync::OnceLock,
};
use crate::{util::ty::TypeCast, Bencher};
/// Lazily-initialized storage for a benchmark's arguments.
///
/// Meant to live in a `static` (`runner` takes `&'static self`). The
/// arguments are produced, collected, and leaked on the first call to
/// `runner`, and then shared by every runner created afterwards.
pub struct BenchArgs {
/// Type-erased arguments and their display names, computed at most once.
args: OnceLock<ErasedArgsSlice>,
}
/// A copyable handle, created by `BenchArgs::runner`, that runs one benchmark
/// over the arguments stored in a `BenchArgs`.
#[derive(Clone, Copy)]
pub struct BenchArgsRunner {
/// Arguments shared with the owning `BenchArgs`.
args: &'static ErasedArgsSlice,
/// Monomorphized entry point (`bench::<T, B>`) that restores the concrete
/// argument type and benchmark closure type.
bench: fn(Bencher, &ErasedArgsSlice, arg_index: usize),
}
/// A slice of arguments of some type `T` (identified by `arg_type`) together
/// with one display name per argument.
///
/// The element type is erased so the slice can be stored in the non-generic
/// `BenchArgs`.
///
/// Invariant, established by `BenchArgs::runner` (the only constructor in
/// this module): `args` points to `len` values of the type identified by
/// `arg_type`, and `names` points to `len` `&'static str` values. Both
/// allocations are `'static` (static data or leaked boxes).
struct ErasedArgsSlice {
/// Pointer to the first argument, originally a `*const T`.
args: *const (),
/// Pointer to the first argument name.
names: *const &'static str,
/// Number of arguments, which is also the number of names.
len: usize,
/// `TypeId` of `T`, checked before `args` is reinterpreted as `*const T`.
arg_type: TypeId,
}
// SAFETY: the pointed-to data is `'static` and only ever read through shared
// slices. `BenchArgs::runner` requires the argument type to be `Send + Sync`,
// and the names are `&'static str`, which are `Send + Sync`.
unsafe impl Send for ErasedArgsSlice {}
unsafe impl Sync for ErasedArgsSlice {}
impl BenchArgs {
/// Creates storage whose arguments have not been computed yet.
pub const fn new() -> Self {
Self { args: OnceLock::new() }
}
/// Returns a runner that executes the benchmark of type `B` over the
/// arguments produced by `make_args`.
///
/// On the first call, `make_args` is invoked, its items are collected into a
/// leaked `'static` slice, and one display name is chosen per item. Later
/// calls reuse the stored arguments and never invoke `make_args` or
/// `arg_to_string`. If a later call uses a different item type than the
/// first, the returned runner panics with a type-mismatch error when run.
///
/// Only the *type* of `_bench_impl` is kept. The runner recreates the closure
/// from its type, so running it panics unless `B` is zero-sized (for example
/// a non-capturing closure).
///
/// Names are obtained without formatting when possible:
/// - from the original slice when the iterator is a `slice::Iter<&str>`,
/// - from the arguments themselves when they are `&str`,
/// - by borrowing `String`, `Box<str>` or `Cow<str>` arguments.
///
/// In all other cases the name comes from `arg_to_string`, and the resulting
/// string is leaked.
pub fn runner<I, B>(
&'static self,
make_args: impl FnOnce() -> I,
arg_to_string: impl Fn(&I::Item) -> String,
_bench_impl: B,
) -> BenchArgsRunner
where
I: IntoIterator,
I::Item: Any + Send + Sync,
B: FnOnce(Bencher, &I::Item) + Copy,
{
let args = self.args.get_or_init(|| {
let args_iter = make_args().into_iter();
// Grab the not-yet-consumed `&str` slice before the iterator is
// collected, so names can borrow it directly. `cast_ref` presumably
// yields `Some` only when the iterator's concrete type is
// `slice::Iter<&str>` (TODO: confirm against `util::ty::TypeCast`).
let args_strings: Option<&'static [&str]> =
args_iter.cast_ref::<slice::Iter<&str>>().map(|iter| iter.as_slice());
// Leaked so that both the arguments and names borrowed from them are `'static`.
let args: &'static [I::Item] = Box::leak(args_iter.collect());
let names: &'static [&str] = 'names: {
if let Some(args) = args_strings {
break 'names args;
}
// Arguments that are already `&str` serve as their own names.
if let Some(args) = args.cast_ref::<&[&str]>() {
break 'names args;
}
Box::leak(
args.iter()
.map(|arg| -> &str {
// Borrow owned string types instead of formatting them.
if let Some(arg) = arg.cast_ref::<String>() {
return arg;
}
if let Some(arg) = arg.cast_ref::<Box<str>>() {
return arg;
}
if let Some(arg) = arg.cast_ref::<Cow<str>>() {
return arg;
}
// Fallback: format the argument and leak it so the name is `'static`.
Box::leak(arg_to_string(arg).into_boxed_str())
})
.collect(),
)
};
ErasedArgsSlice {
// NOTE(review): `crate::black_box` presumably hides the pointer from
// the optimizer; confirm the intent.
args: crate::black_box(args.as_ptr().cast()),
names: names.as_ptr(),
// Each branch above yields exactly one name per argument, so `len`
// is valid for both `args` and `names`.
len: args.len(),
arg_type: TypeId::of::<I::Item>(),
}
});
// Pair the shared arguments with the entry point that knows the concrete
// argument and closure types.
BenchArgsRunner { args, bench: bench::<I::Item, B> }
}
}
impl Default for BenchArgs {
fn default() -> Self {
Self::new()
}
}
impl BenchArgsRunner {
#[inline]
pub(crate) fn bench(&self, bencher: Bencher, index: usize) {
(self.bench)(bencher, self.args, index)
}
#[inline]
pub(crate) fn arg_names(&self) -> &'static [&'static str] {
self.args.names()
}
}
impl ErasedArgsSlice {
    /// Returns the arguments as `&[T]` if `T` is the type they were erased
    /// from, or `None` otherwise.
    #[inline]
    fn typed_args<T: Any>(&self) -> Option<&[T]> {
        if self.arg_type != TypeId::of::<T>() {
            return None;
        }
        // SAFETY: `arg_type` identifies `T`, so by the struct invariant `args`
        // points to `len` initialized `T` values that live for `'static`.
        Some(unsafe { slice::from_raw_parts(self.args.cast::<T>(), self.len) })
    }

    /// Returns the display names of the arguments, one per argument.
    ///
    /// The names (and the slice holding them) are `'static`, independent of
    /// how long `self` is borrowed.
    #[inline]
    fn names(&self) -> &'static [&'static str] {
        // SAFETY: by the struct invariant `names` points to `len` `&'static str`
        // values stored in static or leaked memory.
        unsafe { slice::from_raw_parts(self.names, self.len) }
    }
}
/// Type-restoring entry point stored in `BenchArgsRunner`: runs the benchmark
/// closure type `B` on the argument at `arg_index`.
///
/// The closure value itself was discarded by `BenchArgs::runner`; it is
/// recreated here from its type, which is only possible when `B` is
/// zero-sized.
///
/// # Panics
///
/// - If `erased_args` does not hold values of type `T`.
/// - If `B` is not zero-sized.
/// - If `arg_index` is out of bounds.
fn bench<T, B>(bencher: Bencher, erased_args: &ErasedArgsSlice, arg_index: usize)
where
    T: Any,
    B: FnOnce(Bencher, &T) + Copy,
{
    // Kept out of line so the type-mismatch path does not bloat the hot path.
    #[cold]
    #[inline(never)]
    fn type_mismatch<T>() -> ! {
        unreachable!("incorrect type '{}'", std::any::type_name::<T>())
    }

    let Some(typed_args) = erased_args.typed_args::<T>() else {
        type_mismatch::<T>()
    };

    // Checked outside the `unsafe` block: this is the precondition that makes
    // the `mem::zeroed` below sound.
    assert_eq!(size_of::<B>(), 0, "benchmark closure expected to be zero-sized");

    // SAFETY: `B` is zero-sized (asserted above), so there are no bytes to
    // initialize. This assumes `B` is inhabited, which holds for closures and
    // fn items (the callables `BenchArgs::runner` presumably receives).
    let bench_impl: B = unsafe { mem::zeroed() };

    bench_impl(bencher, &typed_args[arg_index]);
}
#[cfg(test)]
mod tests {
use super::*;
/// Checks that argument names reuse existing string data instead of
/// formatting or copying it.
mod optimizations {
use std::borrow::Borrow;
use super::*;
/// Asserts that `a` and `b` hold equal strings that also share the same
/// underlying memory (same data pointer).
fn test_eq_ptr<A: Borrow<str>, B: Borrow<str>>(a: &[A], b: &[B]) {
assert_eq!(a.len(), b.len());
for (a, b) in a.iter().zip(b) {
let a = a.borrow();
let b = b.borrow();
assert_eq!(a, b);
assert_eq!(a.as_ptr(), b.as_ptr());
}
}
/// A `&[&str]` source: names must be the original static slice itself.
#[test]
fn str_slice() {
static ARGS: BenchArgs = BenchArgs::new();
static ORIG_ARGS: &[&str] = &["a", "b"];
let runner = ARGS.runner(|| ORIG_ARGS, ToString::to_string, |_, _| {});
let typed_args: Vec<&str> =
runner.args.typed_args::<&&str>().unwrap().iter().copied().copied().collect();
let names = runner.arg_names();
assert_eq!(names, ORIG_ARGS);
assert_eq!(names, typed_args);
assert_eq!(names.as_ptr(), ORIG_ARGS.as_ptr());
// `typed_args` is a freshly collected `Vec`, so its buffer differs.
assert_ne!(names.as_ptr(), typed_args.as_ptr());
}
/// `&str` arguments: the names slice is the argument slice itself.
#[test]
fn str_array() {
static ARGS: BenchArgs = BenchArgs::new();
let runner = ARGS.runner(|| ["a", "b"], ToString::to_string, |_, _| {});
let typed_args = runner.args.typed_args::<&str>().unwrap();
let names = runner.arg_names();
assert_eq!(names, ["a", "b"]);
assert_eq!(names, typed_args);
assert_eq!(names.as_ptr(), typed_args.as_ptr());
}
/// `String` arguments: each name borrows the argument's buffer.
#[test]
fn string_array() {
static ARGS: BenchArgs = BenchArgs::new();
let runner =
ARGS.runner(|| ["a".to_owned(), "b".to_owned()], ToString::to_string, |_, _| {});
let typed_args = runner.args.typed_args::<String>().unwrap();
let names = runner.arg_names();
assert_eq!(names, ["a", "b"]);
test_eq_ptr(names, typed_args);
}
/// `Box<str>` arguments: each name borrows the argument's buffer.
#[test]
fn box_str_array() {
static ARGS: BenchArgs = BenchArgs::new();
let runner = ARGS.runner(
|| ["a".to_owned().into_boxed_str(), "b".to_owned().into_boxed_str()],
ToString::to_string,
|_, _| {},
);
let typed_args = runner.args.typed_args::<Box<str>>().unwrap();
let names = runner.arg_names();
assert_eq!(names, ["a", "b"]);
test_eq_ptr(names, typed_args);
}
/// `Cow<str>` arguments (owned and borrowed): each name borrows the
/// argument's string data.
#[test]
fn cow_str_array() {
static ARGS: BenchArgs = BenchArgs::new();
let runner = ARGS.runner(
|| [Cow::Owned("a".to_owned()), Cow::Borrowed("b")],
ToString::to_string,
|_, _| {},
);
let typed_args = runner.args.typed_args::<Cow<str>>().unwrap();
let names = runner.arg_names();
assert_eq!(names, ["a", "b"]);
test_eq_ptr(names, typed_args);
}
}
}