use std::any::Any;
#[cfg(test)]
use std::collections::HashMap;
use std::fmt::Display;
use std::{sync::Arc, vec};
use arrow_schema::*;
use datafusion_common::config::ConfigOptions;
use datafusion_common::file_options::file_type::FileType;
use datafusion_common::{plan_err, GetExt, Result, TableReference};
use datafusion_expr::planner::ExprPlanner;
use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF};
use datafusion_sql::planner::ContextProvider;
/// Stateless stand-in file type whose extension is always `csv`.
struct MockCsvType;
impl GetExt for MockCsvType {
    /// Reports the file extension handled by this mock type: always `"csv"`.
    fn get_ext(&self) -> String {
        String::from("csv")
    }
}
impl FileType for MockCsvType {
    /// Returns `self` as a `&dyn Any` trait object.
    fn as_any(&self) -> &dyn Any {
        let erased: &dyn Any = self;
        erased
    }
}
impl Display for MockCsvType {
    /// Renders the type as its file extension (the string from `get_ext`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let extension = self.get_ext();
        f.write_str(&extension)
    }
}
/// Session state backing [`MockContextProvider`]: registered functions,
/// expression planners and configuration options.
///
/// Build it with `Default::default()` and the `with_*` builder methods.
#[derive(Default)]
pub(crate) struct MockSessionState {
    /// Scalar UDFs, keyed by the exact name they were registered under.
    scalar_functions: HashMap<String, Arc<ScalarUDF>>,
    /// Aggregate UDFs, keyed by their lowercased name.
    aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
    /// Expression planners, kept in registration order.
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
    /// Options handed out by `ContextProvider::options`.
    pub config_options: ConfigOptions,
}
impl MockSessionState {
    /// Appends `expr_planner` to the list of expression planners and returns
    /// the updated state.
    pub fn with_expr_planner(mut self, expr_planner: Arc<dyn ExprPlanner>) -> Self {
        self.expr_planners.push(expr_planner);
        self
    }

    /// Registers `scalar_function` under its exact (case-preserved) name.
    /// Any function already registered under that name is replaced.
    pub fn with_scalar_function(mut self, scalar_function: Arc<ScalarUDF>) -> Self {
        let key = scalar_function.name().to_string();
        self.scalar_functions.insert(key, scalar_function);
        self
    }

    /// Registers `aggregate_function` under its lowercased name.
    /// Any function already registered under that key is replaced.
    pub fn with_aggregate_function(
        mut self,
        aggregate_function: Arc<AggregateUDF>,
    ) -> Self {
        // Unlike scalar UDFs, aggregates are keyed in lowercase, so only
        // lowercase lookups will find them.
        // NOTE(review): presumably because the planner passes normalized
        // identifiers here — confirm before unifying with scalar keys.
        let key = aggregate_function.name().to_lowercase();
        self.aggregate_functions.insert(key, aggregate_function);
        self
    }
}
/// `ContextProvider` that serves a fixed set of hard-coded table schemas and
/// the functions registered in its [`MockSessionState`].
pub(crate) struct MockContextProvider {
    /// Functions, expression planners and config options exposed to the planner.
    pub(crate) state: MockSessionState,
}
impl ContextProvider for MockContextProvider {
fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
let schema = match name.table() {
"test" => Ok(Schema::new(vec![
Field::new("t_date32", DataType::Date32, false),
Field::new("t_date64", DataType::Date64, false),
])),
"j1" => Ok(Schema::new(vec![
Field::new("j1_id", DataType::Int32, false),
Field::new("j1_string", DataType::Utf8, false),
])),
"j2" => Ok(Schema::new(vec![
Field::new("j2_id", DataType::Int32, false),
Field::new("j2_string", DataType::Utf8, false),
])),
"j3" => Ok(Schema::new(vec![
Field::new("j3_id", DataType::Int32, false),
Field::new("j3_string", DataType::Utf8, false),
])),
"test_decimal" => Ok(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("price", DataType::Decimal128(10, 2), false),
])),
"person" => Ok(Schema::new(vec![
Field::new("id", DataType::UInt32, false),
Field::new("first_name", DataType::Utf8, false),
Field::new("last_name", DataType::Utf8, false),
Field::new("age", DataType::Int32, false),
Field::new("state", DataType::Utf8, false),
Field::new("salary", DataType::Float64, false),
Field::new(
"birth_date",
DataType::Timestamp(TimeUnit::Nanosecond, None),
false,
),
Field::new("😀", DataType::Int32, false),
])),
"person_quoted_cols" => Ok(Schema::new(vec![
Field::new("id", DataType::UInt32, false),
Field::new("First Name", DataType::Utf8, false),
Field::new("Last Name", DataType::Utf8, false),
Field::new("Age", DataType::Int32, false),
Field::new("State", DataType::Utf8, false),
Field::new("Salary", DataType::Float64, false),
Field::new(
"Birth Date",
DataType::Timestamp(TimeUnit::Nanosecond, None),
false,
),
Field::new("😀", DataType::Int32, false),
])),
"orders" => Ok(Schema::new(vec![
Field::new("order_id", DataType::UInt32, false),
Field::new("customer_id", DataType::UInt32, false),
Field::new("o_item_id", DataType::Utf8, false),
Field::new("qty", DataType::Int32, false),
Field::new("price", DataType::Float64, false),
Field::new("delivered", DataType::Boolean, false),
])),
"array" => Ok(Schema::new(vec![
Field::new(
"left",
DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
false,
),
Field::new(
"right",
DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
false,
),
])),
"lineitem" => Ok(Schema::new(vec![
Field::new("l_item_id", DataType::UInt32, false),
Field::new("l_description", DataType::Utf8, false),
Field::new("price", DataType::Float64, false),
])),
"aggregate_test_100" => Ok(Schema::new(vec![
Field::new("c1", DataType::Utf8, false),
Field::new("c2", DataType::UInt32, false),
Field::new("c3", DataType::Int8, false),
Field::new("c4", DataType::Int16, false),
Field::new("c5", DataType::Int32, false),
Field::new("c6", DataType::Int64, false),
Field::new("c7", DataType::UInt8, false),
Field::new("c8", DataType::UInt16, false),
Field::new("c9", DataType::UInt32, false),
Field::new("c10", DataType::UInt64, false),
Field::new("c11", DataType::Float32, false),
Field::new("c12", DataType::Float64, false),
Field::new("c13", DataType::Utf8, false),
])),
"UPPERCASE_test" => Ok(Schema::new(vec![
Field::new("Id", DataType::UInt32, false),
Field::new("lower", DataType::UInt32, false),
])),
"unnest_table" => Ok(Schema::new(vec![
Field::new(
"array_col",
DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
false,
),
Field::new(
"struct_col",
DataType::Struct(Fields::from(vec![
Field::new("field1", DataType::Int64, true),
Field::new("field2", DataType::Utf8, true),
])),
false,
),
])),
_ => plan_err!("No table named: {} found", name.table()),
};
match schema {
Ok(t) => Ok(Arc::new(EmptyTable::new(Arc::new(t)))),
Err(e) => Err(e),
}
}
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
self.state.scalar_functions.get(name).cloned()
}
fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
self.state.aggregate_functions.get(name).cloned()
}
fn get_variable_type(&self, _: &[String]) -> Option<DataType> {
unimplemented!()
}
fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
None
}
fn options(&self) -> &ConfigOptions {
&self.state.config_options
}
fn get_file_type(
&self,
_ext: &str,
) -> Result<Arc<dyn datafusion_common::file_options::file_type::FileType>> {
Ok(Arc::new(MockCsvType {}))
}
fn create_cte_work_table(
&self,
_name: &str,
schema: SchemaRef,
) -> Result<Arc<dyn TableSource>> {
Ok(Arc::new(EmptyTable::new(schema)))
}
fn udf_names(&self) -> Vec<String> {
self.state.scalar_functions.keys().cloned().collect()
}
fn udaf_names(&self) -> Vec<String> {
self.state.aggregate_functions.keys().cloned().collect()
}
fn udwf_names(&self) -> Vec<String> {
Vec::new()
}
fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
&self.state.expr_planners
}
}
/// Schema-only `TableSource` that stores nothing but its schema.
///
/// Backs every mock table and every CTE work table handed out by
/// `MockContextProvider`.
struct EmptyTable {
    /// Schema reported through `TableSource::schema`.
    table_schema: SchemaRef,
}
impl EmptyTable {
fn new(table_schema: SchemaRef) -> Self {
Self { table_schema }
}
}
impl TableSource for EmptyTable {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn schema(&self) -> SchemaRef {
Arc::clone(&self.table_schema)
}
}