use arrow::array::{Array, ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray};
use arrow::compute::try_binary;
use arrow::datatypes::DataType;
use arrow::error::ArrowError;
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::function::Hint;
use datafusion_expr::ColumnarValue;
use std::sync::Arc;
/// Generates a `pub(crate)` function named `$FUNC` that chooses a return type
/// from the type of a string-like argument.
///
/// The generated function maps:
/// - `LargeUtf8` / `LargeBinary` to the first type argument (the "large" result),
/// - `Utf8` / `Binary` / `Utf8View` / `BinaryView` to the second type argument,
/// - `Null` to `Null`,
/// - `Dictionary(_, v)` according to its value type `v`, which may be
///   `LargeUtf8`, `LargeBinary`, `Utf8`, `Binary` or `Null` (view value types
///   are not accepted inside a dictionary).
///
/// Any other type yields an execution error that names the function
/// (upper-cased `name`) and the offending type.
macro_rules! get_optimal_return_type {
    ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => {
        pub(crate) fn $FUNC(arg_type: &DataType, name: &str) -> Result<DataType> {
            let return_type = match arg_type {
                DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type,
                DataType::Utf8
                | DataType::Binary
                | DataType::Utf8View
                | DataType::BinaryView => $utf8Type,
                DataType::Null => DataType::Null,
                // Dictionaries are resolved through their value type.
                DataType::Dictionary(_, value_type) => match &**value_type {
                    DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type,
                    DataType::Utf8 | DataType::Binary => $utf8Type,
                    DataType::Null => DataType::Null,
                    unsupported => {
                        return datafusion_common::exec_err!(
                            "The {} function can only accept strings, but got {:?}.",
                            name.to_uppercase(),
                            unsupported
                        );
                    }
                },
                unsupported => {
                    return datafusion_common::exec_err!(
                        "The {} function can only accept strings, but got {:?}.",
                        name.to_uppercase(),
                        unsupported
                    );
                }
            };
            Ok(return_type)
        }
    };
}
// String-valued results: large inputs produce `LargeUtf8`, all others `Utf8`.
get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8);
// Integer-valued results: large inputs produce `Int64`, all others `Int32`.
get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32);
/// Adapts an array-based kernel `inner` into a function over [`ColumnarValue`]s.
///
/// Every argument is materialized as an array before `inner` is called. The
/// target length comes from the array arguments (the last array argument wins).
/// If every argument is a scalar, the target length is 1. `hints` controls
/// expansion per argument position: `Hint::Pad` expands to the target length,
/// and `Hint::AcceptsSingular` expands to a single row. Positions beyond
/// `hints.len()` are treated as `Hint::Pad`.
///
/// When every argument is a scalar, row 0 of the kernel output is returned as a
/// [`ColumnarValue::Scalar`]. Otherwise the output array is returned as-is.
///
/// # Errors
/// Propagates failures from converting arguments to arrays, from `inner`, and
/// from extracting the scalar result.
pub fn make_scalar_function<F>(
    inner: F,
    hints: Vec<Hint>,
) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue>
where
    F: Fn(&[ArrayRef]) -> Result<ArrayRef>,
{
    move |args: &[ColumnarValue]| {
        // Length of the last array argument; `None` means all inputs are scalars.
        let array_len = args
            .iter()
            .filter_map(|arg| match arg {
                ColumnarValue::Array(array) => Some(array.len()),
                ColumnarValue::Scalar(_) => None,
            })
            .last();
        let target_len = array_len.unwrap_or(1);

        // Arguments without an explicit hint default to padding.
        let padded_hints = hints.iter().chain(std::iter::repeat(&Hint::Pad));
        let mut arrays = Vec::with_capacity(args.len());
        for (arg, hint) in args.iter().zip(padded_hints) {
            let rows = match hint {
                Hint::Pad => target_len,
                Hint::AcceptsSingular => 1,
            };
            arrays.push(arg.to_array(rows)?);
        }

        let output = inner(&arrays);
        match array_len {
            Some(_) => output.map(ColumnarValue::Array),
            None => output
                .and_then(|array| ScalarValue::try_from_array(&array, 0))
                .map(ColumnarValue::Scalar),
        }
    }
}
pub fn calculate_binary_math<L, R, O, F>(
left: &dyn Array,
right: &ColumnarValue,
fun: F,
) -> Result<Arc<PrimitiveArray<O>>>
where
R: ArrowPrimitiveType,
L: ArrowPrimitiveType,
O: ArrowPrimitiveType,
F: Fn(L::Native, R::Native) -> Result<O::Native, ArrowError>,
R::Native: TryFrom<ScalarValue>,
{
let left = left.as_primitive::<L>();
let right = right.cast_to(&R::DATA_TYPE, None)?;
let result = match right {
ColumnarValue::Scalar(scalar) => {
let right = R::Native::try_from(scalar.clone()).map_err(|_| {
DataFusionError::NotImplemented(format!(
"Cannot convert scalar value {} to {}",
&scalar,
R::DATA_TYPE
))
})?;
left.try_unary::<_, O, _>(|lvalue| fun(lvalue, right))?
}
ColumnarValue::Array(right) => {
let right = right.as_primitive::<R>();
try_binary::<_, _, _, O>(left, right, &fun)?
}
};
Ok(Arc::new(result) as _)
}
/// Reduces a raw `Decimal128` integer `value` with the given `scale` to its
/// integral part by dividing by `10^scale`. The division truncates toward zero.
///
/// # Errors
/// Returns [`ArrowError::ComputeError`] when `scale` is negative, or when
/// `10^scale` does not fit in an `i128`.
pub fn decimal128_to_i128(value: i128, scale: i8) -> Result<i128, ArrowError> {
    match scale {
        0 => Ok(value),
        negative if negative < 0 => Err(ArrowError::ComputeError(
            "Negative scale is not supported".into(),
        )),
        _ => 10_i128
            .checked_pow(u32::from(scale.unsigned_abs()))
            .map(|divisor| value / divisor)
            .ok_or_else(|| {
                ArrowError::ComputeError(format!("Cannot get a power of {scale}"))
            }),
    }
}
/// Test helpers shared across the function modules, plus unit tests for the
/// helpers defined in this file.
#[cfg(test)]
pub mod test {
/// Runs a scalar UDF on `$ARGS` and checks the outcome against `$EXPECTED`.
///
/// Arguments:
/// - `$FUNC`: the function under test (used via `return_field_from_args` and
///   `invoke_with_args`).
/// - `$ARGS`: the `ColumnarValue` arguments.
/// - `$EXPECTED`: `Ok(Some(v))` if row 0 should equal `v`, `Ok(None)` if row 0
///   should be null, or `Err(e)` if an error is expected.
/// - `$EXPECTED_TYPE`: the Rust type of `v`.
/// - `$EXPECTED_DATA_TYPE`: the expected Arrow return type.
/// - `$ARRAY_TYPE`: the concrete array type the result is downcast to.
/// - `$CONFIG_OPTIONS` (optional): config passed to the invocation. The
///   six-argument form uses `Arc::new(ConfigOptions::default())`.
macro_rules! test_function {
($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident, $CONFIG_OPTIONS:expr) => {
let expected: Result<Option<$EXPECTED_TYPE>> = $EXPECTED;
let func = $FUNC;
let data_array = $ARGS.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
// Row count: length of the last array argument, or 1 when all arguments are scalars.
let cardinality = $ARGS
.iter()
.fold(Option::<usize>::None, |acc, arg| match arg {
ColumnarValue::Scalar(_) => acc,
ColumnarValue::Array(a) => Some(a.len()),
})
.unwrap_or(1);
// Scalar arguments are exposed to return-type inference; array arguments appear as None.
let scalar_arguments = $ARGS.iter().map(|arg| match arg {
ColumnarValue::Scalar(scalar) => Some(scalar.clone()),
ColumnarValue::Array(_) => None,
}).collect::<Vec<_>>();
let scalar_arguments_refs = scalar_arguments.iter().map(|arg| arg.as_ref()).collect::<Vec<_>>();
// Fields for return-type inference: nullable only if the argument actually holds nulls.
let nullables = $ARGS.iter().map(|arg| match arg {
ColumnarValue::Scalar(scalar) => scalar.is_null(),
ColumnarValue::Array(a) => a.null_count() > 0,
}).collect::<Vec<_>>();
let field_array = data_array.into_iter().zip(nullables).enumerate()
.map(|(idx, (data_type, nullable))| arrow::datatypes::Field::new(format!("field_{idx}"), data_type, nullable))
.map(std::sync::Arc::new)
.collect::<Vec<_>>();
let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs {
arg_fields: &field_array,
scalar_arguments: &scalar_arguments_refs,
});
// Fields passed to the invocation itself; these are always marked nullable.
let arg_fields = $ARGS.iter()
.enumerate()
.map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into())
.collect::<Vec<_>>();
match expected {
// Success case: the return type, the result's data type and row 0 must all match.
Ok(expected) => {
assert_eq!(return_field.is_ok(), true);
let return_field = return_field.unwrap();
let return_type = return_field.data_type();
assert_eq!(return_type, &$EXPECTED_DATA_TYPE);
let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{
args: $ARGS,
arg_fields,
number_rows: cardinality,
return_field,
config_options: $CONFIG_OPTIONS
});
assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err());
let result = result.unwrap().to_array(cardinality).expect("Failed to convert to array");
let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().expect("Failed to convert to type");
assert_eq!(result.data_type(), &$EXPECTED_DATA_TYPE);
match expected {
Some(v) => assert_eq!(result.value(0), v),
None => assert!(result.is_null(0)),
};
}
// Error case: the error may come from return-type inference or from invocation.
Err(expected_error) => {
if let Ok(return_field) = return_field {
// Inference succeeded, so the invocation itself must fail.
match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs {
args: $ARGS,
arg_fields,
number_rows: cardinality,
return_field,
config_options: $CONFIG_OPTIONS,
}) {
Ok(_) => assert!(false, "expected error"),
Err(error) => {
// NOTE(review): this checks that the *expected* message starts with the
// *actual* one, not the other way round — confirm this direction is intended.
assert!(expected_error
.strip_backtrace()
.starts_with(&error.strip_backtrace()));
}
}
} else if let Err(error) = return_field {
// Return-type inference failed; compare its message with the expected error.
datafusion_common::assert_contains!(
expected_error.strip_backtrace(),
error.strip_backtrace()
);
}
}
};
};
// Convenience form: same as above with default `ConfigOptions`.
($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => {
test_function!(
$FUNC,
$ARGS,
$EXPECTED,
$EXPECTED_TYPE,
$EXPECTED_DATA_TYPE,
$ARRAY_TYPE,
std::sync::Arc::new(datafusion_common::config::ConfigOptions::default())
)
};
}
use arrow::datatypes::DataType;
// Re-export the macro crate-wide so other modules' tests can call `test_function!`.
#[allow(unused_imports)]
pub(crate) use test_function;
use super::*;
/// Utf8 and Utf8View map to Int32; LargeUtf8 maps to Int64.
#[test]
fn string_to_int_type() {
let v = utf8_to_int_type(&DataType::Utf8, "test").unwrap();
assert_eq!(v, DataType::Int32);
let v = utf8_to_int_type(&DataType::Utf8View, "test").unwrap();
assert_eq!(v, DataType::Int32);
let v = utf8_to_int_type(&DataType::LargeUtf8, "test").unwrap();
assert_eq!(v, DataType::Int64);
}
/// Cases are `(value, scale, expected)`; `None` means an error is expected
/// (negative scale, or `10^scale` overflowing `i128`).
#[test]
fn test_decimal128_to_i128() {
let cases = [
(123, 0, Some(123)),
(1230, 1, Some(123)),
(123000, 3, Some(123)),
(1, 0, Some(1)),
(123, -3, None),
(123, i8::MAX, None),
(i128::MAX, 0, Some(i128::MAX)),
(i128::MAX, 3, Some(i128::MAX / 1000)),
];
for (value, scale, expected) in cases {
match decimal128_to_i128(value, scale) {
Ok(actual) => {
assert_eq!(
actual,
expected.expect("Got value but expected none"),
"{value} and {scale} vs {expected:?}"
);
}
Err(_) => assert!(expected.is_none()),
}
}
}
}