Skip to content

Commit 40a717b

Browse files
committed
add test
1 parent a62d0d6 commit 40a717b

File tree

1 file changed

+155
-0
lines changed

1 file changed

+155
-0
lines changed

datafusion/core/src/execution/session_state.rs

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1955,6 +1955,16 @@ mod tests {
19551955
use super::{SessionContextProvider, SessionStateBuilder};
19561956
use crate::datasource::MemTable;
19571957
use crate::execution::context::SessionState;
1958+
use crate::{
1959+
common::assert_contains,
1960+
config::ConfigOptions,
1961+
datasource::{empty::EmptyTable, provider_as_source},
1962+
logical_expr::{
1963+
planner::ExprPlanner, AggregateUDF, ScalarUDF, TableSource, WindowUDF,
1964+
},
1965+
physical_plan::ExecutionPlan,
1966+
sql::{planner::ContextProvider, ResolvedTableReference, TableReference},
1967+
};
19581968
use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray};
19591969
use arrow::datatypes::{DataType, Field, Schema};
19601970
use datafusion_catalog::MemoryCatalogProviderList;
@@ -1964,6 +1974,7 @@ mod tests {
19641974
use datafusion_expr::Expr;
19651975
use datafusion_optimizer::optimizer::OptimizerRule;
19661976
use datafusion_optimizer::Optimizer;
1977+
use datafusion_physical_plan::display::DisplayableExecutionPlan;
19671978
use datafusion_sql::planner::{PlannerContext, SqlToRel};
19681979
use std::collections::HashMap;
19691980
use std::sync::Arc;
@@ -2121,4 +2132,148 @@ mod tests {
21212132

21222133
Ok(())
21232134
}
2135+
2136+
/// This test demonstrates why it's more convenient and somewhat necessary to provide
2137+
/// an `expr_planners` method for `SessionState`.
2138+
#[tokio::test]
2139+
async fn test_with_expr_planners() -> Result<()> {
2140+
// A helper method for planning count wildcard with or without expr planners.
2141+
async fn plan_count_wildcard(
2142+
with_expr_planners: bool,
2143+
) -> Result<Arc<dyn ExecutionPlan>> {
2144+
let mut context_provider = MyContextProvider::new().with_table(
2145+
"t",
2146+
provider_as_source(Arc::new(EmptyTable::new(Schema::empty().into()))),
2147+
);
2148+
if with_expr_planners {
2149+
context_provider = context_provider.with_expr_planners();
2150+
}
2151+
2152+
let state = &context_provider.state;
2153+
let statement = state.sql_to_statement("select count(*) from t", "mysql")?;
2154+
let plan = SqlToRel::new(&context_provider).statement_to_plan(statement)?;
2155+
state.create_physical_plan(&plan).await
2156+
}
2157+
2158+
// Planning count wildcard without expr planners should fail.
2159+
let got = plan_count_wildcard(false).await;
2160+
assert_contains!(
2161+
got.unwrap_err().to_string(),
2162+
"Physical plan does not support logical expression Wildcard"
2163+
);
2164+
2165+
// Planning count wildcard with expr planners should succeed.
2166+
let got = plan_count_wildcard(true).await?;
2167+
let displayable = DisplayableExecutionPlan::new(got.as_ref());
2168+
assert_eq!(
2169+
displayable.indent(false).to_string(),
2170+
"ProjectionExec: expr=[0 as count(*)]\n PlaceholderRowExec\n"
2171+
);
2172+
2173+
Ok(())
2174+
}
2175+
2176+
/// A `ContextProvider` based on `SessionState`.
2177+
///
2178+
/// Almost all planning context are retrieved from the `SessionState`.
2179+
struct MyContextProvider {
2180+
/// The session state.
2181+
state: SessionState,
2182+
/// Registered tables.
2183+
tables: HashMap<ResolvedTableReference, Arc<dyn TableSource>>,
2184+
/// Controls whether to return expression planners when called `ContextProvider::expr_planners`.
2185+
return_expr_planners: bool,
2186+
}
2187+
2188+
impl MyContextProvider {
2189+
/// Creates a new `SessionContextProvider`.
2190+
pub fn new() -> Self {
2191+
Self {
2192+
state: SessionStateBuilder::default()
2193+
.with_default_features()
2194+
.build(),
2195+
tables: HashMap::new(),
2196+
return_expr_planners: false,
2197+
}
2198+
}
2199+
2200+
/// Registers a table.
2201+
///
2202+
/// The catalog and schema are provided by default.
2203+
pub fn with_table(mut self, table: &str, source: Arc<dyn TableSource>) -> Self {
2204+
self.tables.insert(
2205+
ResolvedTableReference {
2206+
catalog: "default".to_string().into(),
2207+
schema: "public".to_string().into(),
2208+
table: table.to_string().into(),
2209+
},
2210+
source,
2211+
);
2212+
self
2213+
}
2214+
2215+
/// Sets the `return_expr_planners` flag to true.
2216+
pub fn with_expr_planners(self) -> Self {
2217+
Self {
2218+
return_expr_planners: true,
2219+
..self
2220+
}
2221+
}
2222+
}
2223+
2224+
impl ContextProvider for MyContextProvider {
2225+
fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
2226+
let resolved_table_ref = ResolvedTableReference {
2227+
catalog: "default".to_string().into(),
2228+
schema: "public".to_string().into(),
2229+
table: name.table().to_string().into(),
2230+
};
2231+
let source = self.tables.get(&resolved_table_ref).cloned().unwrap();
2232+
Ok(source)
2233+
}
2234+
2235+
/// We use a `return_expr_planners` flag to demonstrate why it's necessary to
2236+
/// return the expression planners in the `SessionState`.
2237+
///
2238+
/// Note, the default implementation returns an empty slice.
2239+
fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
2240+
if self.return_expr_planners {
2241+
self.state.expr_planners()
2242+
} else {
2243+
&[]
2244+
}
2245+
}
2246+
2247+
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
2248+
self.state.scalar_functions().get(name).cloned()
2249+
}
2250+
2251+
fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
2252+
self.state.aggregate_functions().get(name).cloned()
2253+
}
2254+
2255+
fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
2256+
self.state.window_functions().get(name).cloned()
2257+
}
2258+
2259+
fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
2260+
None
2261+
}
2262+
2263+
fn options(&self) -> &ConfigOptions {
2264+
self.state.config_options()
2265+
}
2266+
2267+
fn udf_names(&self) -> Vec<String> {
2268+
self.state.scalar_functions().keys().cloned().collect()
2269+
}
2270+
2271+
fn udaf_names(&self) -> Vec<String> {
2272+
self.state.aggregate_functions().keys().cloned().collect()
2273+
}
2274+
2275+
fn udwf_names(&self) -> Vec<String> {
2276+
self.state.window_functions().keys().cloned().collect()
2277+
}
2278+
}
21242279
}

0 commit comments

Comments
 (0)