use crate::take::take;
use arrow_array::{
make_array, new_empty_array, new_null_array, Array, ArrayRef, BooleanArray, Int32Array, Scalar,
UnionArray,
};
use arrow_buffer::{bit_util, BooleanBuffer, MutableBuffer, NullBuffer, ScalarBuffer};
use arrow_data::layout;
use arrow_schema::{ArrowError, DataType, UnionFields};
use std::cmp::Ordering;
use std::sync::Arc;
/// Extracts the values of the union variant named `target` into a standalone array.
///
/// The result has `union_array.len()` elements. A slot whose type id selects `target`
/// carries the target child's value (including that child's own nulls); every other slot
/// is null. Sparse and dense unions are both supported; the mode is detected from whether
/// `union_array` has an offsets buffer.
///
/// # Errors
///
/// Returns [`ArrowError::InvalidArgumentError`] when no field of the union is named `target`.
///
/// # Panics
///
/// Hits `unreachable!` if the array's data type is not [`DataType::Union`], which a
/// well-formed `UnionArray` is not expected to report.
pub fn union_extract(union_array: &UnionArray, target: &str) -> Result<ArrayRef, ArrowError> {
    let fields = if let DataType::Union(fields, _) = union_array.data_type() {
        fields
    } else {
        unreachable!()
    };

    // Resolve the field name to its type id; the first field with a matching name wins.
    let target_type_id = match fields.iter().find(|(_, field)| field.name() == target) {
        Some((type_id, _)) => type_id,
        None => {
            return Err(ArrowError::InvalidArgumentError(format!(
                "field {target} not found on union"
            )))
        }
    };

    if union_array.offsets().is_some() {
        extract_dense(union_array, fields, target_type_id)
    } else {
        extract_sparse(union_array, fields, target_type_id)
    }
}
/// Sparse-mode implementation of `union_extract`.
///
/// In a sparse union every child is as long as the union, so the result is the target
/// child itself with an extra null wherever the slot's type id is not `target_type_id`.
fn extract_sparse(
    union_array: &UnionArray,
    fields: &UnionFields,
    target_type_id: i8,
) -> Result<ArrayRef, ArrowError> {
    let child = union_array.child(target_type_id);

    // The child can be handed back untouched when:
    // - it is the only field, so every slot already selects it;
    // - the union is empty, so the child is empty as well;
    // - every child value is already null, so masking could not change anything.
    let reuse_child = fields.len() == 1
        || union_array.is_empty()
        || child.null_count() == child.len()
        || child.data_type().is_null();
    if reuse_child {
        return Ok(Arc::clone(child));
    }

    // Whether the child's data type can carry a validity bitmap (evaluated lazily, only
    // in the branches that need it).
    let has_null_mask = || layout(child.data_type()).can_contain_null_mask;

    match eq_scalar(union_array.type_ids(), target_type_id) {
        // Every slot selects the target: nothing to mask.
        BoolValue::Scalar(true) => Ok(Arc::clone(child)),
        // No slot selects the target, and the child accepts a bitmap: keep the values but
        // mark every slot null.
        BoolValue::Scalar(false) if has_null_mask() => {
            // SAFETY: only the validity bitmap is replaced, with one of the child's own
            // length, and `has_null_mask` confirmed the data type supports a bitmap.
            let data = unsafe {
                child
                    .into_data()
                    .into_builder()
                    .nulls(Some(NullBuffer::new_null(child.len())))
                    .build_unchecked()
            };
            Ok(make_array(data))
        }
        // No slot selects the target and the child cannot carry a bitmap.
        BoolValue::Scalar(false) => Ok(new_null_array(child.data_type(), child.len())),
        // Mixed selection on a bitmap-capable child: a slot stays valid only if it is
        // selected and the child value itself is valid.
        BoolValue::Buffer(selected) if has_null_mask() => {
            let validity = match child.nulls() {
                Some(existing) if existing.null_count() > 0 => &selected & existing.inner(),
                _ => selected,
            };
            assert_eq!(validity.len(), child.len());
            // SAFETY: only the validity bitmap is replaced; its length was just asserted to
            // match the child, and `has_null_mask` confirmed the data type supports one.
            let data = unsafe {
                child
                    .into_data()
                    .into_builder()
                    .nulls(Some(validity.into()))
                    .build_unchecked()
            };
            Ok(make_array(data))
        }
        // Mixed selection on a child without a bitmap: pick the child value where selected,
        // a null scalar elsewhere.
        BoolValue::Buffer(selected) => {
            let mask = BooleanArray::new(selected, None);
            let null_value = Scalar::new(new_null_array(child.data_type(), 1));
            Ok(crate::zip::zip(&mask, child, &null_value)?)
        }
    }
}
/// Dense-mode implementation of `union_extract`.
///
/// In a dense union, slot `i` refers to `child(type_ids[i])` at index `offsets[i]`, and a
/// child's length is independent of the union's length. The branches below go from the
/// cheapest result to the most expensive: reuse or slice the target child, allocate an
/// all-null array, and only gather values with `take` when the selection is mixed.
fn extract_dense(
    union_array: &UnionArray,
    fields: &UnionFields,
    target_type_id: i8,
) -> Result<ArrayRef, ArrowError> {
    let target = union_array.child(target_type_id);
    // Expected to succeed: this function is meant to be called only for unions with offsets.
    let offsets = union_array.offsets().unwrap();
    if union_array.is_empty() {
        // Empty union: the result is empty as well.
        if target.is_empty() {
            // Reuse the already-empty child instead of allocating.
            Ok(Arc::clone(target))
        } else {
            Ok(new_empty_array(target.data_type()))
        }
    } else if target.is_empty() {
        // Non-empty union but empty target: no slot can point into the target, so every
        // output slot is null.
        Ok(new_null_array(target.data_type(), union_array.len()))
    } else if target.null_count() == target.len() || target.data_type().is_null() {
        // Every target value is null, so the output is all-null whatever the type ids and
        // offsets are; only the length must be adjusted to the union's.
        match target.len().cmp(&union_array.len()) {
            Ordering::Less => Ok(new_null_array(target.data_type(), union_array.len())),
            Ordering::Equal => Ok(Arc::clone(target)),
            Ordering::Greater => Ok(target.slice(0, union_array.len())),
        }
    } else if fields.len() == 1 || fields
        .iter()
        .filter(|(field_type_id, _)| *field_type_id != target_type_id)
        .all(|(sibling_type_id, _)| union_array.child(sibling_type_id).is_empty())
    {
        // With a single field, or with every sibling child empty, each type id must select
        // the target, so this is decided without scanning `type_ids`.
        Ok(extract_dense_all_selected(union_array, target, offsets)?)
    } else {
        match eq_scalar(union_array.type_ids(), target_type_id) {
            // Every slot selects the target even though some sibling is non-empty; that
            // sibling is simply not referenced by any slot.
            BoolValue::Scalar(true) => {
                Ok(extract_dense_all_selected(union_array, target, offsets)?)
            }
            // No slot selects the target: the output is `union_array.len()` nulls.
            BoolValue::Scalar(false) => {
                match (
                    target.len().cmp(&union_array.len()),
                    layout(target.data_type()).can_contain_null_mask
                ) {
                    // Target too short to reuse, or its type cannot carry a validity bitmap:
                    // allocate a fresh null array of the union's length.
                    (Ordering::Less, _) | (_, false) => {
                        Ok(new_null_array(target.data_type(), union_array.len()))
                    }
                    // Same length: keep the target's data and mark every slot null.
                    (Ordering::Equal, true) => {
                        // SAFETY: only the validity bitmap changes; its length
                        // (`union_array.len()`) equals `target.len()` in this arm, and
                        // `can_contain_null_mask` is true for this type.
                        let data = unsafe {
                            target
                                .into_data()
                                .into_builder()
                                .nulls(Some(NullBuffer::new_null(union_array.len())))
                                .build_unchecked()
                        };
                        Ok(make_array(data))
                    }
                    // Longer target: slice it down to the union's length, then mark every
                    // slot null.
                    (Ordering::Greater, true) => {
                        // SAFETY: only the validity bitmap changes; it has the sliced length
                        // (`union_array.len()`), and `can_contain_null_mask` is true for
                        // this type.
                        let data = unsafe {
                            target
                                .into_data()
                                .slice(0, union_array.len())
                                .into_builder()
                                .nulls(Some(NullBuffer::new_null(union_array.len())))
                                .build_unchecked()
                        };
                        Ok(make_array(data))
                    }
                }
            }
            // Mixed selection: gather `target[offsets[i]]`, using `selected` as the validity
            // of the indices so that slots selecting another field come out null.
            BoolValue::Buffer(selected) => {
                Ok(take(
                    target,
                    &Int32Array::new(offsets.clone(), Some(selected.into())),
                    None,
                )?)
            }
        }
    }
}
fn extract_dense_all_selected(
union_array: &UnionArray,
target: &Arc<dyn Array>,
offsets: &ScalarBuffer<i32>,
) -> Result<ArrayRef, ArrowError> {
let sequential =
target.len() - offsets[0] as usize >= union_array.len() && is_sequential(offsets);
if sequential && target.len() == union_array.len() {
Ok(Arc::clone(target))
} else if sequential && target.len() > union_array.len() {
Ok(target.slice(offsets[0] as usize, union_array.len()))
} else {
let indices = Int32Array::try_new(offsets.clone(), None)?;
Ok(take(target, &indices, None)?)
}
}
/// Number of type ids inspected per chunk when `eq_scalar` measures the leading run of
/// identical comparison results. Written as eight 64-bit words' worth of elements; being a
/// multiple of 64 means a run made of whole chunks already ends on a bitmap word boundary.
const EQ_SCALAR_CHUNK_SIZE: usize = 8 * 64;

/// Outcome of comparing a slice of type ids against one target type id.
#[derive(PartialEq, Debug)]
enum BoolValue {
    /// All elements compared the same way: `true` when all matched (an empty slice also
    /// yields `true`), `false` when none matched.
    Scalar(bool),
    /// Mixed outcome: one bit per element, set where the element equals the target.
    Buffer(BooleanBuffer),
}
/// Compares every type id with `target` using the production chunk size
/// (`EQ_SCALAR_CHUNK_SIZE`); all the work is done by `eq_scalar_inner`.
fn eq_scalar(type_ids: &[i8], target: i8) -> BoolValue {
    let chunk_size = EQ_SCALAR_CHUNK_SIZE;
    eq_scalar_inner(chunk_size, type_ids, target)
}
/// Counts the elements in the leading chunks of `type_ids` for which `f` holds on every
/// element, stopping at the first chunk containing an element that fails `f`.
///
/// The result is therefore a sum of whole chunk lengths (only the final chunk may be
/// shorter than `chunk_size`). Inside a chunk, `f` is evaluated on every element without
/// short-circuiting, keeping the per-chunk loop branch-free.
///
/// Panics if `chunk_size` is zero (from `slice::chunks`).
fn count_first_run(chunk_size: usize, type_ids: &[i8], mut f: impl FnMut(i8) -> bool) -> usize {
    let mut counted = 0;
    for chunk in type_ids.chunks(chunk_size) {
        let whole_chunk_matches = chunk.iter().fold(true, |acc, &v| acc & f(v));
        if !whole_chunk_matches {
            break;
        }
        counted += chunk.len();
    }
    counted
}
/// Compares each element of `type_ids` with `target`.
///
/// Returns `BoolValue::Scalar(true)` when every element matches (including an empty
/// slice), `BoolValue::Scalar(false)` when none does, and otherwise a bitmap with one bit
/// per element, set where the element equals `target`.
///
/// `chunk_size` is the granularity of the leading-run scan done by `count_first_run`; the
/// tests pass a small value to exercise chunk boundaries.
fn eq_scalar_inner(chunk_size: usize, type_ids: &[i8], target: i8) -> BoolValue {
    let len = type_ids.len();

    // Measure the leading run of identical comparison results so that prefix of the
    // bitmap can be bulk-filled instead of computed bit by bit.
    let leading_matches = count_first_run(chunk_size, type_ids, |v| v == target);
    let (run_len, run_value) = match leading_matches {
        n if n == len => return BoolValue::Scalar(true),
        0 => {
            let leading_misses = count_first_run(chunk_size, type_ids, |v| v != target);
            if leading_misses == len {
                return BoolValue::Scalar(false);
            }
            (leading_misses, false)
        }
        n => (n, true),
    };

    // The tail is packed 64 bits at a time, so the bulk-filled prefix must end on a
    // 64-bit word boundary.
    let prefilled = run_len / 64 * 64;
    let mut bits =
        MutableBuffer::new(bit_util::ceil(len, 8)).with_bitset(prefilled / 8, run_value);

    let packed_words = type_ids[prefilled..].chunks(64).map(|word_ids| {
        word_ids
            .iter()
            .enumerate()
            .fold(0u64, |word, (bit, &v)| word | (u64::from(v == target) << bit))
    });
    bits.extend(packed_words);

    BoolValue::Buffer(BooleanBuffer::new(bits.into(), 0, len))
}
/// Number of offsets compared per fixed-size chunk by `is_sequential`.
const IS_SEQUENTIAL_CHUNK_SIZE: usize = 64;

/// Production entry point for `is_sequential_generic`, comparing
/// `IS_SEQUENTIAL_CHUNK_SIZE` offsets per chunk.
fn is_sequential(offsets: &[i32]) -> bool {
    const CHUNK: usize = IS_SEQUENTIAL_CHUNK_SIZE;
    is_sequential_generic::<CHUNK>(offsets)
}
/// Returns `true` when `offsets` is a run of consecutive integers, i.e.
/// `offsets[i] == offsets[0] + i` for every `i`. An empty slice counts as sequential.
///
/// The slice is checked in fixed-size chunks of `N` elements. Each chunk's comparisons
/// are folded without short-circuiting, keeping the inner loop branch-free; the leftover
/// tail (fewer than `N` elements) is checked separately.
///
/// All arithmetic is done in `i64`. Offsets close to `i32::MAX` (possible when the
/// referenced child array is very long) would otherwise overflow `offset + i`. In debug
/// builds that overflow panics; in release builds it wraps and can report a false match.
///
/// Panics if `N == 0` (from `chunks_exact`).
fn is_sequential_generic<const N: usize>(offsets: &[i32]) -> bool {
    if offsets.is_empty() {
        return true;
    }

    let first = i64::from(offsets[0]);
    let last = i64::from(offsets[offsets.len() - 1]);

    // Cheap rejection: a consecutive run must end at `first + len - 1`. This also pins the
    // trailing remainder's position relative to `first`, so the remainder only needs an
    // internal check below.
    if first + offsets.len() as i64 - 1 != last {
        return false;
    }

    let chunks = offsets.chunks_exact(N);
    let remainder = chunks.remainder();

    let chunks_ok = chunks.enumerate().all(|(chunk_idx, chunk)| {
        // A fixed-size array view lets the compiler unroll/vectorize the fold.
        let chunk_array = <&[i32; N]>::try_from(chunk).unwrap();
        let start = i64::from(chunk_array[0]);

        // Values within the chunk are consecutive...
        let inner_ok = chunk_array
            .iter()
            .enumerate()
            .fold(true, |acc, (i, &offset)| {
                acc & (i64::from(offset) == start + i as i64)
            });

        // ...and the chunk starts where the overall run says it should.
        inner_ok && first + (chunk_idx * N) as i64 == start
    });

    let remainder_ok = match remainder.first() {
        None => true,
        Some(&head) => {
            let head = i64::from(head);
            remainder
                .iter()
                .enumerate()
                .fold(true, |acc, (i, &offset)| {
                    acc & (i64::from(offset) == head + i as i64)
                })
        }
    };

    chunks_ok && remainder_ok
}
#[cfg(test)]
mod tests {
    // Unit tests for the private helpers (`eq_scalar_inner`, `is_sequential_generic`) and
    // for `union_extract`. Each `sparse_*` / `dense_*` test name describes the scenario it
    // builds (union mode, which type ids are selected, target/union length relation).
    use super::{eq_scalar_inner, is_sequential_generic, union_extract, BoolValue};
    use arrow_array::{new_null_array, Array, Int32Array, NullArray, StringArray, UnionArray};
    use arrow_buffer::{BooleanBuffer, ScalarBuffer};
    use arrow_schema::{ArrowError, DataType, Field, UnionFields, UnionMode};
    use std::sync::Arc;

    // Uniform inputs must produce `Scalar`; mixed inputs must produce a `Buffer` equal to a
    // naive per-element comparison. A tiny chunk size exercises the chunk/word boundaries.
    #[test]
    fn test_eq_scalar() {
        const ARRAY_LEN: usize = 64 * 4;
        const EQ_SCALAR_CHUNK_SIZE: usize = 3;
        fn eq_scalar(type_ids: &[i8], target: i8) -> BoolValue {
            eq_scalar_inner(EQ_SCALAR_CHUNK_SIZE, type_ids, target)
        }
        // Reference implementation: one bit per element, set where it equals `right`.
        fn cross_check(left: &[i8], right: i8) -> BooleanBuffer {
            BooleanBuffer::collect_bool(left.len(), |i| left[i] == right)
        }
        assert_eq!(eq_scalar(&[], 1), BoolValue::Scalar(true));
        assert_eq!(eq_scalar(&[1], 1), BoolValue::Scalar(true));
        assert_eq!(eq_scalar(&[2], 1), BoolValue::Scalar(false));
        let mut values = [1; ARRAY_LEN];
        assert_eq!(eq_scalar(&values, 1), BoolValue::Scalar(true));
        assert_eq!(eq_scalar(&values, 2), BoolValue::Scalar(false));
        // Every prefix length of a uniform slice stays scalar.
        for i in 1..ARRAY_LEN {
            assert_eq!(eq_scalar(&values[..i], 1), BoolValue::Scalar(true));
            assert_eq!(eq_scalar(&values[..i], 2), BoolValue::Scalar(false));
        }
        // A single differing element at every position yields a mixed bitmap.
        for i in 0..ARRAY_LEN {
            values[i] = 2;
            assert_eq!(
                eq_scalar(&values, 1),
                BoolValue::Buffer(cross_check(&values, 1))
            );
            assert_eq!(
                eq_scalar(&values, 2),
                BoolValue::Buffer(cross_check(&values, 2))
            );
            values[i] = 1;
        }
    }

    // Sequential and non-sequential offsets of lengths around the chunk size (3), with a
    // break placed at each position.
    #[test]
    fn test_is_sequential() {
        const CHUNK_SIZE: usize = 3;
        fn is_sequential(v: &[i32]) -> bool {
            is_sequential_generic::<CHUNK_SIZE>(v)
        }
        assert!(is_sequential(&[]));
        assert!(is_sequential(&[1]));
        assert!(is_sequential(&[1, 2]));
        assert!(is_sequential(&[1, 2, 3]));
        assert!(is_sequential(&[1, 2, 3, 4]));
        assert!(is_sequential(&[1, 2, 3, 4, 5]));
        assert!(is_sequential(&[1, 2, 3, 4, 5, 6]));
        assert!(is_sequential(&[1, 2, 3, 4, 5, 6, 7]));
        assert!(is_sequential(&[1, 2, 3, 4, 5, 6, 7, 8]));
        // Decreasing runs.
        assert!(!is_sequential(&[8, 7]));
        assert!(!is_sequential(&[8, 7, 6]));
        assert!(!is_sequential(&[8, 7, 6, 5]));
        assert!(!is_sequential(&[8, 7, 6, 5, 4]));
        assert!(!is_sequential(&[8, 7, 6, 5, 4, 3]));
        assert!(!is_sequential(&[8, 7, 6, 5, 4, 3, 2]));
        assert!(!is_sequential(&[8, 7, 6, 5, 4, 3, 2, 1]));
        // A single out-of-place 0 at every position.
        assert!(!is_sequential(&[0, 2]));
        assert!(!is_sequential(&[1, 0]));
        assert!(!is_sequential(&[0, 2, 3]));
        assert!(!is_sequential(&[1, 0, 3]));
        assert!(!is_sequential(&[1, 2, 0]));
        assert!(!is_sequential(&[0, 2, 3, 4]));
        assert!(!is_sequential(&[1, 0, 3, 4]));
        assert!(!is_sequential(&[1, 2, 0, 4]));
        assert!(!is_sequential(&[1, 2, 3, 0]));
        assert!(!is_sequential(&[0, 2, 3, 4, 5]));
        assert!(!is_sequential(&[1, 0, 3, 4, 5]));
        assert!(!is_sequential(&[1, 2, 0, 4, 5]));
        assert!(!is_sequential(&[1, 2, 3, 0, 5]));
        assert!(!is_sequential(&[1, 2, 3, 4, 0]));
        assert!(!is_sequential(&[0, 2, 3, 4, 5, 6]));
        assert!(!is_sequential(&[1, 0, 3, 4, 5, 6]));
        assert!(!is_sequential(&[1, 2, 0, 4, 5, 6]));
        assert!(!is_sequential(&[1, 2, 3, 0, 5, 6]));
        assert!(!is_sequential(&[1, 2, 3, 4, 0, 6]));
        assert!(!is_sequential(&[1, 2, 3, 4, 5, 0]));
        assert!(!is_sequential(&[0, 2, 3, 4, 5, 6, 7]));
        assert!(!is_sequential(&[1, 0, 3, 4, 5, 6, 7]));
        assert!(!is_sequential(&[1, 2, 0, 4, 5, 6, 7]));
        assert!(!is_sequential(&[1, 2, 3, 0, 5, 6, 7]));
        assert!(!is_sequential(&[1, 2, 3, 4, 0, 6, 7]));
        assert!(!is_sequential(&[1, 2, 3, 4, 5, 0, 7]));
        assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 0]));
        assert!(!is_sequential(&[0, 2, 3, 4, 5, 6, 7, 8]));
        assert!(!is_sequential(&[1, 0, 3, 4, 5, 6, 7, 8]));
        assert!(!is_sequential(&[1, 2, 0, 4, 5, 6, 7, 8]));
        assert!(!is_sequential(&[1, 2, 3, 0, 5, 6, 7, 8]));
        assert!(!is_sequential(&[1, 2, 3, 4, 0, 6, 7, 8]));
        assert!(!is_sequential(&[1, 2, 3, 4, 5, 0, 7, 8]));
        assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 0, 8]));
        assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 7, 0]));
        // Gaps that skip a value.
        assert!(!is_sequential(&[1, 2, 3, 5]));
        assert!(!is_sequential(&[1, 2, 3, 5, 6]));
        assert!(!is_sequential(&[1, 2, 3, 5, 6, 7]));
        assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 8]));
        assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 8, 9]));
    }

    // Union with a single Utf8 field "str" (type id 1).
    fn str1() -> UnionFields {
        UnionFields::new(vec![1], vec![Field::new("str", DataType::Utf8, true)])
    }

    // Union with fields "str" (type id 1, Utf8) and "int" (type id 3, Int32).
    fn str1_int3() -> UnionFields {
        UnionFields::new(
            vec![1, 3],
            vec![
                Field::new("str", DataType::Utf8, true),
                Field::new("int", DataType::Int32, true),
            ],
        )
    }

    // Single-field sparse union: the child is returned as-is.
    #[test]
    fn sparse_1_1_single_field() {
        let union = UnionArray::try_new(
            str1(),
            ScalarBuffer::from(vec![1, 1]), None, vec![
                Arc::new(StringArray::from(vec!["a", "b"])), ],
        )
        .unwrap();
        let expected = StringArray::from(vec!["a", "b"]);
        let extracted = union_extract(&union, "str").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Empty sparse union: the result is empty.
    #[test]
    fn sparse_1_2_empty() {
        let union = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![]), None, vec![
                Arc::new(StringArray::new_null(0)),
                Arc::new(Int32Array::new_null(0)),
            ],
        )
        .unwrap();
        let expected = StringArray::new_null(0);
        let extracted = union_extract(&union, "str").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Target of data type Null: returned as-is.
    #[test]
    fn sparse_1_3a_null_target() {
        let union = UnionArray::try_new(
            UnionFields::new(
                vec![1, 3],
                vec![
                    Field::new("str", DataType::Utf8, true),
                    Field::new("null", DataType::Null, true), ],
            ),
            ScalarBuffer::from(vec![1]), None, vec![
                Arc::new(StringArray::new_null(1)),
                Arc::new(NullArray::new(1)), ],
        )
        .unwrap();
        let expected = NullArray::new(1);
        let extracted = union_extract(&union, "null").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Target whose values are all null: returned as-is.
    #[test]
    fn sparse_1_3b_null_target() {
        let union = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![1]), None, vec![
                Arc::new(StringArray::new_null(1)), Arc::new(Int32Array::new_null(1)),
            ],
        )
        .unwrap();
        let expected = StringArray::new_null(1);
        let extracted = union_extract(&union, "str").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Every type id selects the target: the child is returned as-is.
    #[test]
    fn sparse_2_all_types_match() {
        let union = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![3, 3]), None, vec![
                Arc::new(StringArray::new_null(2)),
                Arc::new(Int32Array::from(vec![1, 4])), ],
        )
        .unwrap();
        let expected = Int32Array::from(vec![1, 4]);
        let extracted = union_extract(&union, "int").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // No type id selects the Int32 target: all nulls.
    #[test]
    fn sparse_3_1_none_match_target_can_contain_null_mask() {
        let union = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![1, 1, 1, 1]), None, vec![
                Arc::new(StringArray::new_null(4)),
                Arc::new(Int32Array::from(vec![None, Some(4), None, Some(8)])), ],
        )
        .unwrap();
        let expected = Int32Array::new_null(4);
        let extracted = union_extract(&union, "int").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Union with fields "str" (type id 1) and "union" (type id 3) of the given data type.
    fn str1_union3(union3_datatype: DataType) -> UnionFields {
        UnionFields::new(
            vec![1, 3],
            vec![
                Field::new("str", DataType::Utf8, true),
                Field::new("union", union3_datatype, true),
            ],
        )
    }

    // No type id selects a nested-union target: the result equals `new_null_array`.
    #[test]
    fn sparse_3_2_none_match_cant_contain_null_mask_union_target() {
        let target_fields = str1();
        let target_type = DataType::Union(target_fields.clone(), UnionMode::Sparse);
        let union = UnionArray::try_new(
            str1_union3(target_type.clone()),
            ScalarBuffer::from(vec![1, 1]), None, vec![
                Arc::new(StringArray::new_null(2)),
                Arc::new(
                    UnionArray::try_new(
                        target_fields.clone(),
                        ScalarBuffer::from(vec![1, 1]),
                        None,
                        vec![Arc::new(StringArray::from(vec!["a", "b"]))],
                    )
                    .unwrap(),
                ),
            ],
        )
        .unwrap();
        let expected = new_null_array(&target_type, 2);
        let extracted = union_extract(&union, "union").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Partial selection of a target with its own nulls: both sources of nulls are kept.
    #[test]
    fn sparse_4_1_1_target_with_nulls() {
        let union = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![3, 3, 1, 1]), None, vec![
                Arc::new(StringArray::new_null(4)),
                Arc::new(Int32Array::from(vec![None, Some(4), None, Some(8)])), ],
        )
        .unwrap();
        let expected = Int32Array::from(vec![None, Some(4), None, None]);
        let extracted = union_extract(&union, "int").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Partial selection of a target without nulls: unselected slots become null.
    #[test]
    fn sparse_4_1_2_target_without_nulls() {
        let union = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![1, 3, 3]), None, vec![
                Arc::new(StringArray::new_null(3)),
                Arc::new(Int32Array::from(vec![2, 4, 8])), ],
        )
        .unwrap();
        let expected = Int32Array::from(vec![None, Some(4), Some(8)]);
        let extracted = union_extract(&union, "int").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Partial selection of a nested-union target.
    #[test]
    fn sparse_4_2_some_match_target_cant_contain_null_mask() {
        let target_fields = str1();
        let target_type = DataType::Union(target_fields.clone(), UnionMode::Sparse);
        let union = UnionArray::try_new(
            str1_union3(target_type),
            ScalarBuffer::from(vec![3, 1]), None, vec![
                Arc::new(StringArray::new_null(2)),
                Arc::new(
                    UnionArray::try_new(
                        target_fields.clone(),
                        ScalarBuffer::from(vec![1, 1]),
                        None,
                        vec![Arc::new(StringArray::from(vec!["a", "b"]))],
                    )
                    .unwrap(),
                ),
            ],
        )
        .unwrap();
        let expected = UnionArray::try_new(
            target_fields,
            ScalarBuffer::from(vec![1, 1]),
            None,
            vec![Arc::new(StringArray::from(vec![Some("a"), None]))],
        )
        .unwrap();
        let extracted = union_extract(&union, "union").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Dense: empty union and empty target.
    #[test]
    fn dense_1_1_both_empty() {
        let union = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![]), Some(ScalarBuffer::from(vec![])), vec![
                Arc::new(StringArray::new_null(0)), Arc::new(Int32Array::new_null(0)),
            ],
        )
        .unwrap();
        let expected = StringArray::new_null(0);
        let extracted = union_extract(&union, "str").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Dense: empty union, non-empty target; the result is still empty.
    #[test]
    fn dense_1_2_empty_union_target_non_empty() {
        let union = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![]), Some(ScalarBuffer::from(vec![])), vec![
                Arc::new(StringArray::new_null(1)), Arc::new(Int32Array::new_null(0)),
            ],
        )
        .unwrap();
        let expected = StringArray::new_null(0);
        let extracted = union_extract(&union, "str").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Dense: non-empty union, empty target; all nulls.
    #[test]
    fn dense_2_non_empty_union_target_empty() {
        let union = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![3, 3]), Some(ScalarBuffer::from(vec![0, 1])), vec![
                Arc::new(StringArray::new_null(0)), Arc::new(Int32Array::new_null(2)),
            ],
        )
        .unwrap();
        let expected = StringArray::new_null(2);
        let extracted = union_extract(&union, "str").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Dense, all-null target shorter than the union: a fresh null array is built.
    #[test]
    fn dense_3_1_null_target_smaller_len() {
        let union = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![3, 3]), Some(ScalarBuffer::from(vec![0, 0])), vec![
                Arc::new(StringArray::new_null(1)), Arc::new(Int32Array::new_null(2)),
            ],
        )
        .unwrap();
        let expected = StringArray::new_null(2);
        let extracted = union_extract(&union, "str").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Dense, all-null target of the union's length: reused.
    #[test]
    fn dense_3_2_null_target_equal_len() {
        let union = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![3, 3]), Some(ScalarBuffer::from(vec![0, 0])), vec![
                Arc::new(StringArray::new_null(2)), Arc::new(Int32Array::new_null(2)),
            ],
        )
        .unwrap();
        let expected = StringArray::new_null(2);
        let extracted = union_extract(&union, "str").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Dense, all-null target longer than the union: sliced to the union's length.
    #[test]
    fn dense_3_3_null_target_bigger_len() {
        let union = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![3, 3]), Some(ScalarBuffer::from(vec![0, 0])), vec![
                Arc::new(StringArray::new_null(3)), Arc::new(Int32Array::new_null(3)),
            ],
        )
        .unwrap();
        let expected = StringArray::new_null(2);
        let extracted = union_extract(&union, "str").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Dense, single field, sequential offsets, equal length.
    #[test]
    fn dense_4_1a_single_type_sequential_offsets_equal_len() {
        let union = UnionArray::try_new(
            str1(),
            ScalarBuffer::from(vec![1, 1]), Some(ScalarBuffer::from(vec![0, 1])), vec![
                Arc::new(StringArray::from(vec!["a1", "b2"])), ],
        )
        .unwrap();
        let expected = StringArray::from(vec!["a1", "b2"]);
        let extracted = union_extract(&union, "str").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Dense, single field, sequential offsets, longer target.
    #[test]
    fn dense_4_2a_single_type_sequential_offsets_bigger() {
        let union = UnionArray::try_new(
            str1(),
            ScalarBuffer::from(vec![1, 1]), Some(ScalarBuffer::from(vec![0, 1])), vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), ],
        )
        .unwrap();
        let expected = StringArray::from(vec!["a1", "b2"]);
        let extracted = union_extract(&union, "str").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Dense, single field, non-sequential offsets: values are gathered.
    #[test]
    fn dense_4_3a_single_type_non_sequential() {
        let union = UnionArray::try_new(
            str1(),
            ScalarBuffer::from(vec![1, 1]), Some(ScalarBuffer::from(vec![0, 2])), vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), ],
        )
        .unwrap();
        let expected = StringArray::from(vec!["a1", "c3"]);
        let extracted = union_extract(&union, "str").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Dense, empty sibling, sequential offsets, equal length.
    #[test]
    fn dense_4_1b_empty_siblings_sequential_equal_len() {
        let union = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![1, 1]), Some(ScalarBuffer::from(vec![0, 1])), vec![
                Arc::new(StringArray::from(vec!["a", "b"])), Arc::new(Int32Array::new_null(0)), ],
        )
        .unwrap();
        let expected = StringArray::from(vec!["a", "b"]);
        let extracted = union_extract(&union, "str").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Dense, empty sibling, sequential offsets, longer target.
    #[test]
    fn dense_4_2b_empty_siblings_sequential_bigger_len() {
        let union = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![1, 1]), Some(ScalarBuffer::from(vec![0, 1])), vec![
                Arc::new(StringArray::from(vec!["a", "b", "c"])), Arc::new(Int32Array::new_null(0)), ],
        )
        .unwrap();
        let expected = StringArray::from(vec!["a", "b"]);
        let extracted = union_extract(&union, "str").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Dense, empty sibling, non-sequential offsets.
    #[test]
    fn dense_4_3b_empty_sibling_non_sequential() {
        let union = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![1, 1]), Some(ScalarBuffer::from(vec![0, 2])), vec![
                Arc::new(StringArray::from(vec!["a", "b", "c"])), Arc::new(Int32Array::new_null(0)), ],
        )
        .unwrap();
        let expected = StringArray::from(vec!["a", "c"]);
        let extracted = union_extract(&union, "str").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Dense, non-empty sibling but all type ids select the target; sequential, equal length.
    #[test]
    fn dense_4_1c_all_types_match_sequential_equal_len() {
        let union = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![1, 1]), Some(ScalarBuffer::from(vec![0, 1])), vec![
                Arc::new(StringArray::from(vec!["a1", "b2"])), Arc::new(Int32Array::new_null(2)), ],
        )
        .unwrap();
        let expected = StringArray::from(vec!["a1", "b2"]);
        let extracted = union_extract(&union, "str").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Dense, all type ids select the target; sequential, longer target.
    #[test]
    fn dense_4_2c_all_types_match_sequential_bigger_len() {
        let union = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![1, 1]), Some(ScalarBuffer::from(vec![0, 1])), vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "b3"])), Arc::new(Int32Array::new_null(2)), ],
        )
        .unwrap();
        let expected = StringArray::from(vec!["a1", "b2"]);
        let extracted = union_extract(&union, "str").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Dense, all type ids select the target; non-sequential offsets.
    #[test]
    fn dense_4_3c_all_types_match_non_sequential() {
        let union = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![1, 1]), Some(ScalarBuffer::from(vec![0, 2])), vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "b3"])),
                Arc::new(Int32Array::new_null(2)), ],
        )
        .unwrap();
        let expected = StringArray::from(vec!["a1", "b3"]);
        let extracted = union_extract(&union, "str").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Dense, no type id selects the target, target shorter than the union: all nulls.
    #[test]
    fn dense_5_1a_none_match_less_len() {
        let union = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![3, 3, 3, 3, 3]), Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), Arc::new(Int32Array::from(vec![1, 2])),
            ],
        )
        .unwrap();
        let expected = StringArray::new_null(5);
        let extracted = union_extract(&union, "str").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Dense, no type id selects a nested-union target: equals `new_null_array`.
    #[test]
    fn dense_5_1b_cant_contain_null_mask() {
        let target_fields = str1();
        let target_type = DataType::Union(target_fields.clone(), UnionMode::Sparse);
        let union = UnionArray::try_new(
            str1_union3(target_type.clone()),
            ScalarBuffer::from(vec![1, 1, 1, 1, 1]), Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), Arc::new(
                    UnionArray::try_new(
                        target_fields.clone(),
                        ScalarBuffer::from(vec![1]),
                        None,
                        vec![Arc::new(StringArray::from(vec!["a"]))],
                    )
                    .unwrap(),
                ), ],
        )
        .unwrap();
        let expected = new_null_array(&target_type, 5);
        let extracted = union_extract(&union, "union").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Dense, no type id selects the target, equal length: values kept, all masked null.
    #[test]
    fn dense_5_2_none_match_equal_len() {
        let union = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![3, 3, 3, 3, 3]), Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "c3", "d4", "e5"])), Arc::new(Int32Array::from(vec![1, 2])),
            ],
        )
        .unwrap();
        let expected = StringArray::new_null(5);
        let extracted = union_extract(&union, "str").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Dense, no type id selects the target, longer target: sliced and masked null.
    #[test]
    fn dense_5_3_none_match_greater_len() {
        let union = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![3, 3, 3, 3, 3]), Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "c3", "d4", "e5", "f6"])), Arc::new(Int32Array::from(vec![1, 2])), ],
        )
        .unwrap();
        let expected = StringArray::new_null(5);
        let extracted = union_extract(&union, "str").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Dense, mixed selection: selected slots gathered via offsets, others null.
    #[test]
    fn dense_6_some_matches() {
        let union = UnionArray::try_new(
            str1_int3(),
            ScalarBuffer::from(vec![3, 3, 1, 1, 1]), Some(ScalarBuffer::from(vec![0, 1, 0, 1, 2])), vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), Arc::new(Int32Array::from(vec![1, 2])),
            ],
        )
        .unwrap();
        let expected = Int32Array::from(vec![Some(1), Some(2), None, None, None]);
        let extracted = union_extract(&union, "int").unwrap();
        assert_eq!(extracted.into_data(), expected.into_data());
    }

    // Sparse union without fields: looking up any name is an InvalidArgumentError.
    #[test]
    fn empty_sparse_union() {
        let union = UnionArray::try_new(
            UnionFields::empty(),
            ScalarBuffer::from(vec![]),
            None,
            vec![],
        )
        .unwrap();
        assert_eq!(
            union_extract(&union, "a").unwrap_err().to_string(),
            ArrowError::InvalidArgumentError("field a not found on union".into()).to_string()
        );
    }

    // Dense union without fields: looking up any name is an InvalidArgumentError.
    #[test]
    fn empty_dense_union() {
        let union = UnionArray::try_new(
            UnionFields::empty(),
            ScalarBuffer::from(vec![]),
            Some(ScalarBuffer::from(vec![])),
            vec![],
        )
        .unwrap();
        assert_eq!(
            union_extract(&union, "a").unwrap_err().to_string(),
            ArrowError::InvalidArgumentError("field a not found on union".into()).to_string()
        );
    }
}