use crate::{array::PrimitiveArray, bitmap::Bitmap, buffer::MutableBuffer, types::Index};
use super::SortOptions;
/// Partially sorts `indices` so that its first `limit` positions hold, in order, the
/// indices whose values (`get(index)`) come first under `cmp` (reversed when
/// `descending`). Positions from `limit` onwards are left in an unspecified order.
///
/// A `limit` at or beyond `indices.len()` sorts the whole slice instead of panicking
/// (`select_nth_unstable_by` panics for an out-of-range index, including on an empty slice).
#[inline]
fn k_element_sort_inner<I: Index, T, G, F>(
    indices: &mut [I],
    get: G,
    descending: bool,
    limit: usize,
    mut cmp: F,
) where
    G: Fn(usize) -> T,
    F: FnMut(&T, &T) -> std::cmp::Ordering,
{
    // Generic over the comparator so each direction below gets its own monomorphized
    // closure, keeping the `descending` check out of the per-comparison path.
    fn select_and_sort<I, C>(indices: &mut [I], limit: usize, mut compare: C)
    where
        C: FnMut(&I, &I) -> std::cmp::Ordering,
    {
        if limit >= indices.len() {
            indices.sort_unstable_by(compare);
        } else {
            // Partition so the `limit` smallest elements precede position `limit`,
            // then order just that prefix.
            let (before, _, _) = indices.select_nth_unstable_by(limit, &mut compare);
            before.sort_unstable_by(compare);
        }
    }

    if descending {
        select_and_sort(indices, limit, |lhs: &I, rhs: &I| {
            let lhs = get(lhs.to_usize());
            let rhs = get(rhs.to_usize());
            cmp(&rhs, &lhs)
        });
    } else {
        select_and_sort(indices, limit, |lhs: &I, rhs: &I| {
            let lhs = get(lhs.to_usize());
            let rhs = get(rhs.to_usize());
            cmp(&lhs, &rhs)
        });
    }
}
/// Sorts `indices` by the values `get` returns for them, compared with `cmp`
/// (reversed when `descending`).
///
/// When `limit` is smaller than `indices.len()`, only the first `limit` positions are
/// guaranteed to hold the sorted prefix; the remaining positions are in an unspecified
/// order. A `limit` greater than or equal to the length sorts the whole slice.
#[inline]
fn sort_unstable_by<I, T, G, F>(
    indices: &mut [I],
    get: G,
    mut cmp: F,
    descending: bool,
    limit: usize,
) where
    I: Index,
    G: Fn(usize) -> T,
    F: FnMut(&T, &T) -> std::cmp::Ordering,
{
    // `<` rather than `!=`: a `limit` past the end must not reach the partial-sort path,
    // whose `select_nth_unstable_by` panics on an out-of-range index.
    if limit < indices.len() {
        k_element_sort_inner(indices, get, descending, limit, cmp);
        return;
    }

    // Two separate closures so the direction is decided once, not per comparison.
    if descending {
        indices.sort_unstable_by(|lhs, rhs| {
            let lhs = get(lhs.to_usize());
            let rhs = get(rhs.to_usize());
            cmp(&rhs, &lhs)
        });
    } else {
        indices.sort_unstable_by(|lhs, rhs| {
            let lhs = get(lhs.to_usize());
            let rhs = get(rhs.to_usize());
            cmp(&lhs, &rhs)
        });
    }
}
/// Returns the positions `0..length` ordered by the values `get` yields for them.
///
/// * `validity` — when present, positions whose validity bit is `false` are treated as
///   nulls. They are grouped at the front (`options.nulls_first`) or at the back of the
///   result, in their original position order. Only non-null positions are compared.
/// * `get(i)` — produces the value at position `i` for comparison.
/// * `cmp` — ordering of two values; `options.descending` reverses it.
/// * `length` — number of positions to order.
/// * `limit` — when `Some(k)`, only the first `min(k, length)` entries of the full
///   ordering are computed and returned; `None` returns all `length` entries.
///
/// The returned array has no validity of its own.
///
/// Panics if `I::range(0, length)` returns `None` (presumably when `length` is not
/// representable by `I` — TODO confirm against `Index::range`).
#[inline]
pub(super) fn indices_sorted_unstable_by<I, T, G, F>(
    validity: Option<&Bitmap>,
    get: G,
    cmp: F,
    length: usize,
    options: &SortOptions,
    limit: Option<usize>,
) -> PrimitiveArray<I>
where
    I: Index,
    G: Fn(usize) -> T,
    F: Fn(&T, &T) -> std::cmp::Ordering,
{
    let descending = options.descending;
    let limit = limit.unwrap_or(length).min(length);

    let mut indices = if let Some(validity) = validity {
        // Read once rather than once per element inside the partition loop.
        let null_count = validity.null_count();
        let valid_count = length.saturating_sub(null_count);

        // Write cursors for the next null and the next valid position. Nulls and valids
        // each fill one contiguous region; `nulls_first` decides which region comes first.
        let (mut next_null, mut next_valid) = if options.nulls_first {
            (0, null_count)
        } else {
            (valid_count, 0)
        };

        let mut indices = MutableBuffer::<I>::from_len_zeroed(length);
        for (is_valid, index) in validity.iter().zip(I::range(0, length).unwrap()) {
            if is_valid {
                indices[next_valid] = index;
                next_valid += 1;
            } else {
                indices[next_null] = index;
                next_null += 1;
            }
        }

        // Only the valid region is sorted, and only as far as the overall `limit` reaches
        // into it. When it does not reach the valid region, the nulls alone fill the result.
        let (valid_region, valid_limit) = if options.nulls_first {
            (null_count..length, limit.saturating_sub(null_count))
        } else {
            (0..valid_count, limit.min(valid_count))
        };
        if valid_limit > 0 {
            let valid_indices = &mut indices.as_mut_slice()[valid_region];
            sort_unstable_by(valid_indices, get, cmp, descending, valid_limit);
        }
        indices
    } else {
        let mut indices = MutableBuffer::from_trusted_len_iter(I::range(0, length).unwrap());
        sort_unstable_by(&mut indices, get, cmp, descending, limit);
        indices
    };

    // Entries past `limit` are either left unsorted by the partial sort or not requested.
    indices.truncate(limit);
    indices.shrink_to_fit();
    PrimitiveArray::<I>::from_data(I::DATA_TYPE, indices.into(), None)
}