diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 6d08dd4fe5ff..3771b78c2a8b 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -55,6 +55,34 @@ use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; mod cases; mod common; +pub struct ParameterTest<'a> { + pub sql: &'a str, + pub expected_types: Vec<(&'a str, Option)>, + pub param_values: Vec, +} + +impl ParameterTest<'_> { + pub fn run(&self) -> String { + let plan = logical_plan(self.sql).unwrap(); + + let actual_types = plan.get_parameter_types().unwrap(); + let expected_types: HashMap> = self + .expected_types + .iter() + .map(|(k, v)| (k.to_string(), v.clone())) + .collect(); + + assert_eq!(actual_types, expected_types); + + let plan_with_params = plan + .clone() + .with_param_values(self.param_values.clone()) + .unwrap(); + + format!("** Initial Plan:\n{plan}\n** Final Plan:\n{plan_with_params}") + } +} + #[test] fn parse_decimals_1() { let sql = "SELECT 1"; @@ -4665,33 +4693,22 @@ fn test_prepare_statement_infer_types_from_join() { #[test] fn test_infer_types_from_predicate() { - let sql = "SELECT id, age FROM person WHERE age = $1"; - let plan = logical_plan(sql).unwrap(); - assert_snapshot!( - plan, - @r#" - Projection: person.id, person.age - Filter: person.age = $1 - TableScan: person - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([("$1".to_string(), Some(DataType::Int32))]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); + let test = ParameterTest { + sql: "SELECT id, age FROM person WHERE age = $1", + expected_types: vec![("$1", Some(DataType::Int32))], + param_values: vec![ScalarValue::Int32(Some(10))], + }; - assert_snapshot!( - plan_with_params, - @r" - Projection: person.id, person.age - Filter: person.age = Int32(10) - TableScan: person - " - ); + assert_snapshot!(test.run(), @r###" +** Initial Plan: +Projection: person.id, person.age + Filter: person.age = $1 + TableScan: person +** Final Plan: +Projection: person.id, person.age + Filter: person.age = Int32(10) + TableScan: person +"###); } #[test] @@ -4715,32 +4732,23 @@ fn test_prepare_statement_infer_types_from_predicate() { #[test] fn test_infer_types_from_between_predicate() { - let sql = "SELECT id, age FROM person WHERE age BETWEEN $1 AND $2"; + let test = ParameterTest { + sql: "SELECT id, age FROM person WHERE age BETWEEN $1 AND $2", + expected_types: vec![ + ("$1", Some(DataType::Int32)), + ("$2", Some(DataType::Int32)), + ], + param_values: vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))], + }; - let plan = logical_plan(sql).unwrap(); assert_snapshot!( - plan, - @r#" + test.run(), + @r" + ** Initial Plan: Projection: person.id, person.age Filter: person.age BETWEEN $1 AND $2 TableScan: person - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([ - ("$1".to_string(), Some(DataType::Int32)), - ("$2".to_string(), Some(DataType::Int32)), - ]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" + ** Final Plan: Projection: person.id, person.age Filter: person.age BETWEEN Int32(10) AND Int32(30) TableScan: person @@ -4773,12 +4781,16 @@ fn test_prepare_statement_infer_types_from_between_predicate() { #[test] fn test_infer_types_subquery() { - let sql = "SELECT id, age FROM person WHERE age = (select max(age) from person where id = $1)"; + let test = ParameterTest { + sql: "SELECT id, age FROM person WHERE age = (select max(age) from person where id = $1)", + expected_types: vec![("$1", Some(DataType::UInt32))], + param_values: vec![ScalarValue::UInt32(Some(10))] + }; - let plan = logical_plan(sql).unwrap(); assert_snapshot!( - plan, - @r#" + test.run(), + @r" + ** Initial Plan: Projection: person.id, person.age Filter: person.age = () Subquery: @@ -4787,20 +4799,7 @@ fn test_infer_types_subquery() { Filter: person.id = $1 TableScan: person TableScan: person - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([("$1".to_string(), Some(DataType::UInt32))]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::UInt32(Some(10))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" + ** Final Plan: Projection: person.id, person.age Filter: person.age = () Subquery: @@ -4809,7 +4808,7 @@ fn test_infer_types_subquery() { Filter: person.id = UInt32(10) TableScan: person TableScan: person - " + " ); } @@ -4840,38 +4839,29 @@ fn test_prepare_statement_infer_types_subquery() { #[test] fn test_update_infer() { - let sql = "update person set age=$1 where id=$2"; + let test = ParameterTest { + sql: "update person set age=$1 where id=$2", + expected_types: vec![ + ("$1", Some(DataType::Int32)), + ("$2", Some(DataType::UInt32)), + ], + param_values: vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))], + }; - let plan = logical_plan(sql).unwrap(); assert_snapshot!( - plan, - @r#" + test.run(), + @r" + ** Initial Plan: Dml: op=[Update] table=[person] Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, $1 AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 Filter: person.id = $2 TableScan: person - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([ - ("$1".to_string(), Some(DataType::Int32)), - ("$2".to_string(), Some(DataType::UInt32)), - ]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" + ** Final Plan: Dml: op=[Update] table=[person] Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, Int32(42) AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 Filter: person.id = UInt32(1) TableScan: person - " + " ); } @@ -4901,35 +4891,27 @@ fn test_prepare_statement_update_infer() { #[test] fn test_insert_infer() { - let sql = "insert into person (id, first_name, last_name) values ($1, $2, $3)"; - let plan = logical_plan(sql).unwrap(); + let test = ParameterTest { + sql: "insert into person (id, first_name, last_name) values ($1, $2, $3)", + expected_types: vec![ + ("$1", Some(DataType::UInt32)), + ("$2", Some(DataType::Utf8)), + ("$3", Some(DataType::Utf8)), + ], + param_values: vec![ + ScalarValue::UInt32(Some(1)), + ScalarValue::from("Alan"), + ScalarValue::from("Turing"), + ], + }; assert_snapshot!( - plan, + test.run(), @r#" + ** Initial Plan: Dml: op=[Insert Into] table=[person] Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀 Values: ($1, $2, $3) - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([ - ("$1".to_string(), Some(DataType::UInt32)), - ("$2".to_string(), Some(DataType::Utf8)), - ("$3".to_string(), Some(DataType::Utf8)), - ]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ - ScalarValue::UInt32(Some(1)), - ScalarValue::from("Alan"), - ScalarValue::from("Turing"), - ]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r#" + ** Final Plan: Dml: op=[Insert Into] table=[person] Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀 Values: (UInt32(1) AS $1, Utf8("Alan") AS $2, Utf8("Turing") AS $3)