use arrow::{
array::{Array, ArrayRef, ListArray, PrimitiveArray},
buffer::OffsetBuffer,
datatypes::{Field, UInt8Type},
};
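/// Verifies that `Array::shrink_to_fit` releases the excess capacity left
/// over by repeated `concat`, leaving at most 1% memory overhead over the
/// raw data (as measured by the tracking allocator defined below).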
#[test]
fn test_shrink_to_fit_after_concat() {
let array_len = 6_000;
let num_concats = 100;
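    // Build `array_len` UInt8 values and wrap them in a single-element ListArray.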
let primitive_array: PrimitiveArray<UInt8Type> = (0..array_len)
.map(|v| (v % 255) as u8)
.collect::<Vec<_>>()
.into();
let primitive_array: ArrayRef = Arc::new(primitive_array);
let list_array: ArrayRef = Arc::new(ListArray::new(
Field::new_list_field(primitive_array.data_type().clone(), false).into(),
OffsetBuffer::from_lengths([primitive_array.len()]),
primitive_array.clone(),
None,
));
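    // Net bytes allocated globally and by this thread, respectively.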
let (concatenated, _bytes_allocated_globally, bytes_allocated_by_this_thread) =
memory_use(|| {
let mut concatenated = concatenate(num_concats, list_array.clone());
            concatenated.shrink_to_fit(); // This is the call under test.
concatenated
});
    let expected_len = num_concats * array_len;
    // `bytes_used` counts only the leaf UInt8 data: one byte per element.
    assert_eq!(bytes_used(concatenated.clone()), expected_len);
eprintln!("The concatenated array is {expected_len} B long. Amount of memory used by this thread: {bytes_allocated_by_this_thread} B");
assert!(
expected_len <= bytes_allocated_by_this_thread,
"We must allocate at least as much space as the concatenated array"
);
assert!(
bytes_allocated_by_this_thread <= expected_len + expected_len / 100,
"We shouldn't have more than 1% memory overhead. In fact, we are using {bytes_allocated_by_this_thread} B of memory for {expected_len} B of data"
);
}
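/// Concatenates `array` with itself until the result holds `num_times` copies.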
fn concatenate(num_times: usize, array: ArrayRef) -> ArrayRef {
let mut concatenated = array.clone();
for _ in 0..num_times - 1 {
        concatenated =
            arrow::compute::kernels::concat::concat(&[&*concatenated, &*array]).unwrap();
}
concatenated
}
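/// Returns the number of data bytes in the leaf values of `array`: walks down
/// through any `List` levels and counts only the innermost `UInt8` buffer,
/// ignoring offsets, validity bitmaps, and other metadata.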
fn bytes_used(mut array: ArrayRef) -> usize {
loop {
match array.data_type() {
arrow::datatypes::DataType::UInt8 => break,
arrow::datatypes::DataType::List(_) => {
let list = array.as_any().downcast_ref::<ListArray>().unwrap();
array = list.values().clone();
}
            other => unreachable!("unexpected data type in test array: {other:?}"),
}
}
array.len()
}
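
// ---------------------------------------------------------------------------
// A minimal tracking allocator used to measure the memory use of the code
// above. It wraps the system allocator and keeps per-thread and global
// counters of live (allocated minus freed) bytes.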
use std::{
alloc::Layout,
sync::{
atomic::{AtomicUsize, Ordering::Relaxed},
Arc,
},
};
/// Bytes currently allocated (and not yet freed) across all threads.
static LIVE_BYTES_GLOBAL: AtomicUsize = AtomicUsize::new(0);

thread_local! {
    /// Bytes currently allocated (and not yet freed) by this thread.
    static LIVE_BYTES_IN_THREAD: AtomicUsize = const { AtomicUsize::new(0) };
}
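/// A `GlobalAlloc` wrapper that forwards to the system allocator while
/// keeping the live-byte counters above up to date.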
pub struct TrackingAllocator {
allocator: std::alloc::System,
}
#[global_allocator]
pub static GLOBAL_ALLOCATOR: TrackingAllocator = TrackingAllocator {
allocator: std::alloc::System,
};
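// SAFETY: we only do book-keeping and defer all actual allocation work to
// the wrapped system allocator.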
#[allow(unsafe_code)]
unsafe impl std::alloc::GlobalAlloc for TrackingAllocator {
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        // SAFETY: defer the actual allocation to the wrapped system allocator.
        let ptr = unsafe { self.allocator.alloc(layout) };
        if !ptr.is_null() {
            // Only count allocations that actually succeeded.
            LIVE_BYTES_IN_THREAD.with(|bytes| bytes.fetch_add(layout.size(), Relaxed));
            LIVE_BYTES_GLOBAL.fetch_add(layout.size(), Relaxed);
        }
        ptr
    }
    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        // NOTE: this assumes buffers are freed on the thread that allocated
        // them; the test is single-threaded, so the per-thread counter stays
        // accurate.
        LIVE_BYTES_IN_THREAD.with(|bytes| bytes.fetch_sub(layout.size(), Relaxed));
        LIVE_BYTES_GLOBAL.fetch_sub(layout.size(), Relaxed);
        // SAFETY: defer the actual deallocation to the wrapped system allocator.
        unsafe { self.allocator.dealloc(ptr, layout) };
}
}
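/// Live bytes allocated (and not yet freed) by the current thread.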
fn live_bytes_local() -> usize {
LIVE_BYTES_IN_THREAD.with(|bytes| bytes.load(Relaxed))
}
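/// Live bytes allocated (and not yet freed) across all threads.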
fn live_bytes_global() -> usize {
LIVE_BYTES_GLOBAL.load(Relaxed)
}
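/// Runs `run` and returns `(result, global_bytes_delta, this_thread_bytes_delta)`,
/// where the deltas are the net bytes still allocated when `run` returns.
///
/// Illustrative sketch only; exact byte counts depend on allocator behavior:
///
/// ```ignore
/// let (v, _global, local) = memory_use(|| vec![0_u8; 1024]);
/// assert!(local >= 1024); // the Vec's buffer is still live
/// ```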
fn memory_use<R>(run: impl FnOnce() -> R) -> (R, usize, usize) {
    let used_bytes_start_local = live_bytes_local();
    let used_bytes_start_global = live_bytes_global();
    let ret = run();
    // Net new bytes still live after `run` returns. This would underflow if
    // the closure freed more than it allocated, which would itself be a bug.
    let bytes_used_local = live_bytes_local() - used_bytes_start_local;
    let bytes_used_global = live_bytes_global() - used_bytes_start_global;
(ret, bytes_used_global, bytes_used_local)
}