use std::sync::Arc;
use crate::{
bitmap::Bitmap,
datatypes::{DataType, Field},
error::Result,
ffi,
};
use super::{ffi::ToFfi, new_empty_array, new_null_array, Array, FromFfi};
/// An array of struct values, stored as one child array per field plus an
/// optional validity bitmap.
///
/// `from_data` checks that there is at least one field, that there is exactly
/// one child array per field, that all child arrays have the same length, and
/// that the validity bitmap (when present) has that same length.
#[derive(Debug, Clone)]
pub struct StructArray {
// Logical type of the array: a `DataType::Struct`, possibly wrapped in
// `DataType::Extension` (see `get_fields`).
data_type: DataType,
// One child array per field, in field order.
values: Vec<Arc<dyn Array>>,
// Optional validity bitmap; `None` means no bitmap is stored.
validity: Option<Bitmap>,
}
impl StructArray {
    /// Creates a [`StructArray`] whose children are the empty arrays of each
    /// field's data type.
    ///
    /// The fields are resolved through [`StructArray::get_fields`], so a
    /// `DataType::Extension` wrapping a struct is accepted, consistently with
    /// [`StructArray::from_data`].
    ///
    /// # Panics
    /// Panics if `data_type` does not resolve to a struct type, or if any of
    /// the checks in [`StructArray::from_data`] fails.
    pub fn new_empty(data_type: DataType) -> Self {
        let values = Self::get_fields(&data_type)
            .iter()
            .map(|field| new_empty_array(field.data_type().clone()).into())
            .collect();
        Self::from_data(data_type, values, None)
    }

    /// Creates a [`StructArray`] of `length` slots whose children are built
    /// with `new_null_array` and whose validity is `Bitmap::new_zeroed(length)`.
    ///
    /// The fields are resolved through [`StructArray::get_fields`], so a
    /// `DataType::Extension` wrapping a struct is accepted, consistently with
    /// [`StructArray::from_data`].
    ///
    /// # Panics
    /// Panics if `data_type` does not resolve to a struct type, or if any of
    /// the checks in [`StructArray::from_data`] fails.
    pub fn new_null(data_type: DataType, length: usize) -> Self {
        let values = Self::get_fields(&data_type)
            .iter()
            .map(|field| new_null_array(field.data_type().clone(), length).into())
            .collect();
        Self::from_data(data_type, values, Some(Bitmap::new_zeroed(length)))
    }

    /// Builds a [`StructArray`] from its data type, child arrays and optional
    /// validity bitmap.
    ///
    /// # Panics
    /// Panics if:
    /// * `data_type` does not resolve to a struct type (see
    ///   [`StructArray::get_fields`]),
    /// * the struct has no fields,
    /// * the number of fields differs from the number of `values`,
    /// * the `values` do not all have the same length,
    /// * `validity` is `Some` and its length differs from the values' length.
    pub fn from_data(
        data_type: DataType,
        values: Vec<Arc<dyn Array>>,
        validity: Option<Bitmap>,
    ) -> Self {
        let fields = Self::get_fields(&data_type);
        assert!(!fields.is_empty());
        assert_eq!(fields.len(), values.len());
        // `values` is non-empty here: it has as many entries as `fields`.
        assert!(values.iter().all(|x| x.len() == values[0].len()));
        if let Some(ref validity) = validity {
            assert_eq!(values[0].len(), validity.len());
        }
        Self {
            data_type,
            values,
            validity,
        }
    }

    /// Deconstructs the array into its fields, child arrays and validity.
    ///
    /// When the data type is a plain `DataType::Struct` the fields are moved
    /// out without copying. For any other accepted type (an extension wrapping
    /// a struct) the fields are resolved through [`StructArray::get_fields`]
    /// and cloned, instead of hitting an `unreachable!()`.
    pub fn into_data(self) -> (Vec<Field>, Vec<Arc<dyn Array>>, Option<Bitmap>) {
        let Self {
            data_type,
            values,
            validity,
        } = self;
        let fields = match data_type {
            DataType::Struct(fields) => fields,
            // `from_data` also accepts extension types wrapping a struct.
            other => Self::get_fields(&other).to_vec(),
        };
        (fields, values, validity)
    }

    /// Returns a new [`StructArray`] covering `length` slots starting at
    /// `offset`.
    ///
    /// # Panics
    /// Panics if `offset + length` exceeds the length of the array, including
    /// when that sum overflows `usize`.
    pub fn slice(&self, offset: usize, length: usize) -> Self {
        // `checked_add` keeps an overflowing sum from wrapping around and
        // slipping past the bounds check in builds without overflow checks.
        assert!(
            offset
                .checked_add(length)
                .map_or(false, |end| end <= self.len()),
            "offset + length may not exceed length of array"
        );
        // SAFETY: the assertion above guarantees `offset + length <= self.len()`.
        unsafe { self.slice_unchecked(offset, length) }
    }

    /// Returns a new [`StructArray`] covering `length` slots starting at
    /// `offset`, without checking bounds. The validity bitmap and every child
    /// array are sliced with the same `offset` and `length`.
    ///
    /// # Safety
    /// The caller must ensure that `offset + length <= self.len()`.
    pub unsafe fn slice_unchecked(&self, offset: usize, length: usize) -> Self {
        let validity = self
            .validity
            .clone()
            .map(|x| x.slice_unchecked(offset, length));
        Self {
            data_type: self.data_type.clone(),
            values: self
                .values
                .iter()
                .map(|x| x.slice_unchecked(offset, length).into())
                .collect(),
            validity,
        }
    }

    /// Returns a clone of this array with its validity replaced by `validity`.
    ///
    /// # Panics
    /// Panics if `validity` is `Some` and its length differs from the length
    /// of the array.
    pub fn with_validity(&self, validity: Option<Bitmap>) -> Self {
        if matches!(&validity, Some(bitmap) if bitmap.len() != self.len()) {
            panic!("validity must have the same length as the array")
        }
        let mut arr = self.clone();
        arr.validity = validity;
        arr
    }
}
impl StructArray {
    /// Returns the validity bitmap of this array, if it has one.
    #[inline]
    pub fn validity(&self) -> Option<&Bitmap> {
        Option::as_ref(&self.validity)
    }

    /// Returns the child arrays, one per field.
    pub fn values(&self) -> &[Arc<dyn Array>] {
        self.values.as_slice()
    }

    /// Returns the fields of this array's data type, as resolved by
    /// [`StructArray::get_fields`].
    pub fn fields(&self) -> &[Field] {
        let data_type = &self.data_type;
        Self::get_fields(data_type)
    }
}
impl StructArray {
    /// Returns the fields of a struct `data_type`, looking through any number
    /// of `DataType::Extension` wrappers.
    ///
    /// # Panics
    /// Panics if the innermost data type is not `DataType::Struct`.
    pub fn get_fields(data_type: &DataType) -> &[Field] {
        let mut current = data_type;
        loop {
            match current {
                DataType::Struct(fields) => return fields,
                // Step into the wrapped type and inspect it on the next pass.
                DataType::Extension(_, inner, _) => current = inner,
                _ => panic!("Wrong datatype passed to Struct."),
            }
        }
    }
}
impl Array for StructArray {
#[inline]
fn as_any(&self) -> &dyn std::any::Any {
self
}
#[inline]
fn len(&self) -> usize {
self.values[0].len()
}
#[inline]
fn data_type(&self) -> &DataType {
&self.data_type
}
#[inline]
fn validity(&self) -> Option<&Bitmap> {
self.validity.as_ref()
}
fn slice(&self, offset: usize, length: usize) -> Box<dyn Array> {
Box::new(self.slice(offset, length))
}
unsafe fn slice_unchecked(&self, offset: usize, length: usize) -> Box<dyn Array> {
Box::new(self.slice_unchecked(offset, length))
}
fn with_validity(&self, validity: Option<Bitmap>) -> Box<dyn Array> {
Box::new(self.with_validity(validity))
}
}
impl std::fmt::Display for StructArray {
    /// Writes a `StructArray{` header line, then one `name: column,` line per
    /// field paired with its child array, then a closing `}`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "StructArray{{")?;
        self.fields()
            .iter()
            .zip(self.values())
            .try_for_each(|(field, column)| writeln!(f, "{}: {},", field.name(), column))?;
        write!(f, "}}")
    }
}
unsafe impl ToFfi for StructArray {
    /// Returns a single-entry buffer list: the pointer of the validity bitmap,
    /// or `None` when the array has no validity.
    fn buffers(&self) -> Vec<Option<std::ptr::NonNull<u8>>> {
        let validity_ptr = self.validity.as_ref().map(|bitmap| bitmap.as_ptr());
        vec![validity_ptr]
    }

    /// Always reports an offset of zero.
    // NOTE(review): presumably any offset is expected to be carried by the
    // children / applied on import rather than on the struct — confirm.
    fn offset(&self) -> usize {
        0
    }

    /// Returns clones of the `Arc` handles to the child arrays.
    fn children(&self) -> Vec<Arc<dyn Array>> {
        self.values.iter().cloned().collect()
    }
}
impl<A: ffi::ArrowArrayRef> FromFfi<A> for StructArray {
    /// Builds a [`StructArray`] from an FFI array reference.
    ///
    /// The data type of the imported field is preserved as-is, so an
    /// extension type wrapping a struct is kept rather than being flattened to
    /// a plain `DataType::Struct`. One child is imported per struct field.
    /// When the FFI array reports a non-zero offset, the validity bitmap is
    /// sliced to `(offset, length)`.
    ///
    /// # Errors
    /// Returns an error if reading the validity, fetching a child, or
    /// converting a child fails.
    ///
    /// # Panics
    /// Panics if the field's data type does not resolve to a struct, or if
    /// the imported parts fail the checks in [`StructArray::from_data`].
    ///
    /// # Safety
    /// This function calls the unsafe `validity` accessor of `array`; the
    /// caller must uphold whatever contract `ffi::ArrowArrayRef` requires of
    /// the underlying FFI data.
    unsafe fn try_from_ffi(array: A) -> Result<Self> {
        // Clone the full data type (not just the field list) so extension
        // metadata survives the round trip.
        let data_type = array.field().data_type().clone();
        let num_children = Self::get_fields(&data_type).len();
        let length = array.array().len();
        let offset = array.array().offset();
        let mut validity = unsafe { array.validity() }?;
        let values = (0..num_children)
            .map(|index| {
                let child = array.child(index)?;
                Ok(ffi::try_from(child)?.into())
            })
            .collect::<Result<Vec<Arc<dyn Array>>>>()?;
        if offset > 0 {
            validity = validity.map(|x| x.slice(offset, length))
        }
        Ok(Self::from_data(data_type, values, validity))
    }
}