From a8d36ca0a134cdf75307ca4e81243d45fd4215ff Mon Sep 17 00:00:00 2001 From: jsai28 <54253219+jsai28@users.noreply.github.com> Date: Sat, 22 Mar 2025 15:56:09 -0600 Subject: [PATCH 1/8] added test --- Cargo.lock | 3 ++ datafusion/wasmtest/Cargo.toml | 3 ++ datafusion/wasmtest/src/lib.rs | 71 +++++++++++++++++++++++++++++----- 3 files changed, 67 insertions(+), 10 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9f9263e529034..d6c759b8e71d8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2592,7 +2592,10 @@ dependencies = [ "datafusion-physical-plan", "datafusion-sql", "getrandom 0.2.15", + "insta", + "object_store", "tokio", + "url", "wasm-bindgen", "wasm-bindgen-test", ] diff --git a/datafusion/wasmtest/Cargo.toml b/datafusion/wasmtest/Cargo.toml index 94515c6754a7a..10eab025734c9 100644 --- a/datafusion/wasmtest/Cargo.toml +++ b/datafusion/wasmtest/Cargo.toml @@ -58,5 +58,8 @@ getrandom = { version = "0.2.8", features = ["js"] } wasm-bindgen = "0.2.99" [dev-dependencies] +insta = { workspace = true } +object_store = { workspace = true } tokio = { workspace = true } +url = { workspace = true } wasm-bindgen-test = "0.3.49" diff --git a/datafusion/wasmtest/src/lib.rs b/datafusion/wasmtest/src/lib.rs index df0d9d6cbf37e..6c7be9056eb43 100644 --- a/datafusion/wasmtest/src/lib.rs +++ b/datafusion/wasmtest/src/lib.rs @@ -82,6 +82,7 @@ pub fn basic_parse() { #[cfg(test)] mod test { use super::*; + use datafusion::execution::options::ParquetReadOptions; use datafusion::{ arrow::{ array::{ArrayRef, Int32Array, RecordBatch, StringArray}, @@ -90,12 +91,16 @@ mod test { datasource::MemTable, execution::context::SessionContext, }; + use datafusion_common::test_util::batches_to_string; use datafusion_execution::{ config::SessionConfig, disk_manager::DiskManagerConfig, runtime_env::RuntimeEnvBuilder, }; use datafusion_physical_plan::collect; use datafusion_sql::parser::DFParser; + use insta::assert_snapshot; + use object_store::{memory::InMemory, path::Path, ObjectStore}; + use url::Url; use wasm_bindgen_test::wasm_bindgen_test; wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); @@ -115,6 +120,22 @@ mod test { let session_config = SessionConfig::new().with_target_partitions(1); Arc::new(SessionContext::new_with_config_rt(session_config, rt)) } + + fn create_test_data() -> (Arc, RecordBatch) { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("value", DataType::Utf8, false), + ])); + + let data: Vec = vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["a", "b", "c"])), + ]; + + let batch = RecordBatch::try_new(schema.clone(), data).unwrap(); + (schema, batch) + } + #[wasm_bindgen_test(unsupported = tokio::test)] async fn basic_execute() { let sql = "SELECT 2 + 2;"; @@ -185,17 +206,22 @@ mod test { #[wasm_bindgen_test(unsupported = tokio::test)] async fn test_parquet_write() { - let schema = Arc::new(Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("value", DataType::Utf8, false), - ])); + let (schema, batch) = create_test_data(); + let mut buffer = Vec::new(); + let mut writer = datafusion::parquet::arrow::ArrowWriter::try_new( + &mut buffer, + schema.clone(), + None, + ) + .unwrap(); - let data: Vec = vec![ - Arc::new(Int32Array::from(vec![1])), - Arc::new(StringArray::from(vec!["a"])), - ]; + writer.write(&batch).unwrap(); + writer.close().unwrap(); + } - let batch = RecordBatch::try_new(schema.clone(), data).unwrap(); + #[wasm_bindgen_test(unsupported = tokio::test)] + async fn test_parquet_read_and_write() { + let (schema, batch) = create_test_data(); let mut buffer = Vec::new(); let mut writer = datafusion::parquet::arrow::ArrowWriter::try_new( &mut buffer, @@ -203,8 +229,33 @@ mod test { None, ) .unwrap(); - writer.write(&batch).unwrap(); writer.close().unwrap(); + + let session_ctx = SessionContext::new(); + let store = InMemory::new(); + + let path = Path::from("a.parquet"); + store.put(&path, buffer.into()).await.unwrap(); + + let url = Url::parse("memory://").unwrap(); + session_ctx.register_object_store(&url, Arc::new(store)); + + let df = session_ctx + .read_parquet("memory:///", ParquetReadOptions::new()) + .await + .unwrap(); + + let result = df.collect().await.unwrap(); + + assert_snapshot!(batches_to_string(&result), @r" + +----+-------+ + | id | value | + +----+-------+ + | 1 | a | + | 2 | b | + | 3 | c | + +----+-------+ + "); } } From 26b5322a071dba10b4dd6ae868550ace39c7e865 Mon Sep 17 00:00:00 2001 From: jsai28 <54253219+jsai28@users.noreply.github.com> Date: Sat, 17 May 2025 18:04:13 -0600 Subject: [PATCH 2/8] added parameterTest --- datafusion/sql/tests/sql_integration.rs | 70 ++++++++++++++++--------- 1 file changed, 44 insertions(+), 26 deletions(-) diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 6d08dd4fe5ffb..e86ba6cf272b4 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -55,6 +55,35 @@ use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; mod cases; mod common; +pub struct ParameterTest<'a> { + pub sql: &'a str, + pub expected_types: Vec<(&'a str, Option)>, + pub param_values: Vec, +} + +impl<'a> ParameterTest<'a> { + pub fn run(&self) -> String { + let plan = logical_plan(self.sql).unwrap(); + + // check parameter types + let actual_types = plan.get_parameter_types().unwrap(); + let expected_types: HashMap> = self + .expected_types + .iter() + .map(|(k, v)| (k.to_string(), v.clone())) + .collect(); + + assert_eq!(actual_types, expected_types); + + // replace params with values + let plan_with_params = plan.clone().with_param_values(self.param_values.clone()).unwrap(); + + format!( + "** Initial Plan:\n{plan}\n** Final Plan:\n{plan_with_params}" + ) + } +} + #[test] fn parse_decimals_1() { let sql = "SELECT 1"; @@ -4665,33 +4694,22 @@ fn test_prepare_statement_infer_types_from_join() { #[test] fn test_infer_types_from_predicate() { - let sql = "SELECT id, age FROM person WHERE age = $1"; - let plan = logical_plan(sql).unwrap(); - assert_snapshot!( - plan, - @r#" - Projection: person.id, person.age - Filter: person.age = $1 - TableScan: person - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([("$1".to_string(), Some(DataType::Int32))]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); + let test = ParameterTest { + sql: "SELECT id, age FROM person WHERE age = $1", + expected_types: vec![("$1", Some(DataType::Int32))], + param_values: vec![ScalarValue::Int32(Some(10))], + }; - assert_snapshot!( - plan_with_params, - @r" - Projection: person.id, person.age - Filter: person.age = Int32(10) - TableScan: person - " - ); + assert_snapshot!(test.run(), @r###" +** Initial Plan: +Projection: person.id, person.age + Filter: person.age = $1 + TableScan: person +** Final Plan: +Projection: person.id, person.age + Filter: person.age = Int32(10) + TableScan: person +"###); } #[test] From d51055ac7a6108fa4b88fed2ce1c0b58b285c07e Mon Sep 17 00:00:00 2001 From: jsai28 <54253219+jsai28@users.noreply.github.com> Date: Sat, 17 May 2025 18:09:52 -0600 Subject: [PATCH 3/8] cargo fmt --- datafusion/sql/tests/sql_integration.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index e86ba6cf272b4..f28c0438d0a80 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -76,11 +76,12 @@ impl<'a> ParameterTest<'a> { assert_eq!(actual_types, expected_types); // replace params with values - let plan_with_params = plan.clone().with_param_values(self.param_values.clone()).unwrap(); + let plan_with_params = plan + .clone() + .with_param_values(self.param_values.clone()) + .unwrap(); - format!( - "** Initial Plan:\n{plan}\n** Final Plan:\n{plan_with_params}" - ) + format!("** Initial Plan:\n{plan}\n** Final Plan:\n{plan_with_params}") } } From 6984e0871a3035fe5d6858e7e5bdfd90da527209 Mon Sep 17 00:00:00 2001 From: jsai28 <54253219+jsai28@users.noreply.github.com> Date: Sat, 17 May 2025 18:15:55 -0600 Subject: [PATCH 4/8] Update sql_integration.rs --- datafusion/sql/tests/sql_integration.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index f28c0438d0a80..2b6a6e34861fd 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -65,7 +65,6 @@ impl<'a> ParameterTest<'a> { pub fn run(&self) -> String { let plan = logical_plan(self.sql).unwrap(); - // check parameter types let actual_types = plan.get_parameter_types().unwrap(); let expected_types: HashMap> = self .expected_types @@ -75,7 +74,6 @@ impl<'a> ParameterTest<'a> { assert_eq!(actual_types, expected_types); - // replace params with values let plan_with_params = plan .clone() .with_param_values(self.param_values.clone()) From c1fcb20d5a071b934b7c6a1c62f1f4075c32b38e Mon Sep 17 00:00:00 2001 From: jsai28 <54253219+jsai28@users.noreply.github.com> Date: Mon, 19 May 2025 17:45:25 -0600 Subject: [PATCH 5/8] allow needless_lifetimes --- datafusion/sql/tests/sql_integration.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 2b6a6e34861fd..de07fcbe90fcc 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -61,6 +61,7 @@ pub struct ParameterTest<'a> { pub param_values: Vec, } +#[allow(clippy::needless_lifetimes)] impl<'a> ParameterTest<'a> { pub fn run(&self) -> String { let plan = logical_plan(self.sql).unwrap(); From 0578e69f8e25e9b8360ffbc524a2090e7f580d71 Mon Sep 17 00:00:00 2001 From: jsai28 <54253219+jsai28@users.noreply.github.com> Date: Wed, 21 May 2025 13:55:37 -0400 Subject: [PATCH 6/8] remove needless lifetime --- datafusion/sql/tests/sql_integration.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index de07fcbe90fcc..1b517f747bc32 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -61,8 +61,7 @@ pub struct ParameterTest<'a> { pub param_values: Vec, } -#[allow(clippy::needless_lifetimes)] -impl<'a> ParameterTest<'a> { +impl ParameterTest<'_> { pub fn run(&self) -> String { let plan = logical_plan(self.sql).unwrap(); From b07ac44ea9409085b793b37d0e546e1a400a0fbf Mon Sep 17 00:00:00 2001 From: jsai28 <54253219+jsai28@users.noreply.github.com> Date: Wed, 21 May 2025 23:45:57 -0400 Subject: [PATCH 7/8] update some tests --- datafusion/sql/tests/sql_integration.rs | 137 +++++++++--------------- 1 file changed, 51 insertions(+), 86 deletions(-) diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 1b517f747bc32..3771b78c2a8b2 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -4732,32 +4732,23 @@ fn test_prepare_statement_infer_types_from_predicate() { #[test] fn test_infer_types_from_between_predicate() { - let sql = "SELECT id, age FROM person WHERE age BETWEEN $1 AND $2"; + let test = ParameterTest { + sql: "SELECT id, age FROM person WHERE age BETWEEN $1 AND $2", + expected_types: vec![ + ("$1", Some(DataType::Int32)), + ("$2", Some(DataType::Int32)), + ], + param_values: vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))], + }; - let plan = logical_plan(sql).unwrap(); assert_snapshot!( - plan, - @r#" + test.run(), + @r" + ** Initial Plan: Projection: person.id, person.age Filter: person.age BETWEEN $1 AND $2 TableScan: person - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([ - ("$1".to_string(), Some(DataType::Int32)), - ("$2".to_string(), Some(DataType::Int32)), - ]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" + ** Final Plan: Projection: person.id, person.age Filter: person.age BETWEEN Int32(10) AND Int32(30) TableScan: person @@ -4790,12 +4781,16 @@ fn test_prepare_statement_infer_types_from_between_predicate() { #[test] fn test_infer_types_subquery() { - let sql = "SELECT id, age FROM person WHERE age = (select max(age) from person where id = $1)"; + let test = ParameterTest { + sql: "SELECT id, age FROM person WHERE age = (select max(age) from person where id = $1)", + expected_types: vec![("$1", Some(DataType::UInt32))], + param_values: vec![ScalarValue::UInt32(Some(10))] + }; - let plan = logical_plan(sql).unwrap(); assert_snapshot!( - plan, - @r#" + test.run(), + @r" + ** Initial Plan: Projection: person.id, person.age Filter: person.age = () Subquery: @@ -4804,20 +4799,7 @@ fn test_infer_types_subquery() { Filter: person.id = $1 TableScan: person TableScan: person - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([("$1".to_string(), Some(DataType::UInt32))]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::UInt32(Some(10))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" + ** Final Plan: Projection: person.id, person.age Filter: person.age = () Subquery: @@ -4826,7 +4808,7 @@ fn test_infer_types_subquery() { Filter: person.id = UInt32(10) TableScan: person TableScan: person - " + " ); } @@ -4857,38 +4839,29 @@ fn test_prepare_statement_infer_types_subquery() { #[test] fn test_update_infer() { - let sql = "update person set age=$1 where id=$2"; + let test = ParameterTest { + sql: "update person set age=$1 where id=$2", + expected_types: vec![ + ("$1", Some(DataType::Int32)), + ("$2", Some(DataType::UInt32)), + ], + param_values: vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))], + }; - let plan = logical_plan(sql).unwrap(); assert_snapshot!( - plan, - @r#" + test.run(), + @r" + ** Initial Plan: Dml: op=[Update] table=[person] Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, $1 AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 Filter: person.id = $2 TableScan: person - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([ - ("$1".to_string(), Some(DataType::Int32)), - ("$2".to_string(), Some(DataType::UInt32)), - ]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" + ** Final Plan: Dml: op=[Update] table=[person] Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, Int32(42) AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 Filter: person.id = UInt32(1) TableScan: person - " + " ); } @@ -4918,35 +4891,27 @@ fn test_prepare_statement_update_infer() { #[test] fn test_insert_infer() { - let sql = "insert into person (id, first_name, last_name) values ($1, $2, $3)"; - let plan = logical_plan(sql).unwrap(); + let test = ParameterTest { + sql: "insert into person (id, first_name, last_name) values ($1, $2, $3)", + expected_types: vec![ + ("$1", Some(DataType::UInt32)), + ("$2", Some(DataType::Utf8)), + ("$3", Some(DataType::Utf8)), + ], + param_values: vec![ + ScalarValue::UInt32(Some(1)), + ScalarValue::from("Alan"), + ScalarValue::from("Turing"), + ], + }; assert_snapshot!( - plan, + test.run(), @r#" + ** Initial Plan: Dml: op=[Insert Into] table=[person] Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀 Values: ($1, $2, $3) - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([ - ("$1".to_string(), Some(DataType::UInt32)), - ("$2".to_string(), Some(DataType::Utf8)), - ("$3".to_string(), Some(DataType::Utf8)), - ]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ - ScalarValue::UInt32(Some(1)), - ScalarValue::from("Alan"), - ScalarValue::from("Turing"), - ]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r#" + ** Final Plan: Dml: op=[Insert Into] table=[person] Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀 Values: (UInt32(1) AS $1, Utf8("Alan") AS $2, Utf8("Turing") AS $3) From db4e622e0f742b744dc0de61bd8d1318bab45890 Mon Sep 17 00:00:00 2001 From: jsai28 <54253219+jsai28@users.noreply.github.com> Date: Sat, 24 May 2025 11:37:44 -0400 Subject: [PATCH 8/8] move to params.rs --- datafusion/sql/tests/cases/params.rs | 412 ++++++++++-------------- datafusion/sql/tests/sql_integration.rs | 28 -- 2 files changed, 167 insertions(+), 273 deletions(-) diff --git a/datafusion/sql/tests/cases/params.rs b/datafusion/sql/tests/cases/params.rs index 7a4fbba1475e2..b3cc49c310718 100644 --- a/datafusion/sql/tests/cases/params.rs +++ b/datafusion/sql/tests/cases/params.rs @@ -22,6 +22,34 @@ use datafusion_expr::{LogicalPlan, Prepare, Statement}; use insta::assert_snapshot; use std::collections::HashMap; +pub struct ParameterTest<'a> { + pub sql: &'a str, + pub expected_types: Vec<(&'a str, Option)>, + pub param_values: Vec, +} + +impl ParameterTest<'_> { + pub fn run(&self) -> String { + let plan = logical_plan(self.sql).unwrap(); + + let actual_types = plan.get_parameter_types().unwrap(); + let expected_types: HashMap> = self + .expected_types + .iter() + .map(|(k, v)| (k.to_string(), v.clone())) + .collect(); + + assert_eq!(actual_types, expected_types); + + let plan_with_params = plan + .clone() + .with_param_values(self.param_values.clone()) + .unwrap(); + + format!("** Initial Plan:\n{plan}\n** Final Plan:\n{plan_with_params}") + } +} + fn generate_prepare_stmt_and_data_types(sql: &str) -> (LogicalPlan, String) { let plan = logical_plan(sql).unwrap(); let data_types = match &plan { @@ -311,31 +339,22 @@ fn test_prepare_statement_to_plan_params_as_constants() { #[test] fn test_infer_types_from_join() { - let sql = - "SELECT id, order_id FROM person JOIN orders ON id = customer_id and age = $1"; + let test = ParameterTest { + sql: + "SELECT id, order_id FROM person JOIN orders ON id = customer_id and age = $1", + expected_types: vec![("$1", Some(DataType::Int32))], + param_values: vec![ScalarValue::Int32(Some(10))], + }; - let plan = logical_plan(sql).unwrap(); assert_snapshot!( - plan, - @r#" + test.run(), + @r" + ** Initial Plan: Projection: person.id, orders.order_id Inner Join: Filter: person.id = orders.customer_id AND person.age = $1 TableScan: person TableScan: orders - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([("$1".to_string(), Some(DataType::Int32))]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" + ** Final Plan: Projection: person.id, orders.order_id Inner Join: Filter: person.id = orders.customer_id AND person.age = Int32(10) TableScan: person @@ -346,64 +365,46 @@ fn test_infer_types_from_join() { #[test] fn test_prepare_statement_infer_types_from_join() { - let sql = - "PREPARE my_plan AS SELECT id, order_id FROM person JOIN orders ON id = customer_id and age = $1"; + let test = ParameterTest { + sql: "PREPARE my_plan AS SELECT id, order_id FROM person JOIN orders ON id = customer_id and age = $1", + expected_types: vec![("$1", Some(DataType::Int32))], + param_values: vec![ScalarValue::Int32(Some(10))] + }; - let plan = logical_plan(sql).unwrap(); assert_snapshot!( - plan, + test.run(), @r#" + ** Initial Plan: Prepare: "my_plan" [Int32] Projection: person.id, orders.order_id Inner Join: Filter: person.id = orders.customer_id AND person.age = $1 TableScan: person TableScan: orders - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([("$1".to_string(), Some(DataType::Int32))]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" + ** Final Plan: Projection: person.id, orders.order_id Inner Join: Filter: person.id = orders.customer_id AND person.age = Int32(10) TableScan: person TableScan: orders - " + "# ); } #[test] fn test_infer_types_from_predicate() { - let sql = "SELECT id, age FROM person WHERE age = $1"; - let plan = logical_plan(sql).unwrap(); + let test = ParameterTest { + sql: "SELECT id, age FROM person WHERE age = $1", + expected_types: vec![("$1", Some(DataType::Int32))], + param_values: vec![ScalarValue::Int32(Some(10))], + }; + assert_snapshot!( - plan, - @r#" + test.run(), + @r" + ** Initial Plan: Projection: person.id, person.age Filter: person.age = $1 TableScan: person - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([("$1".to_string(), Some(DataType::Int32))]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" + ** Final Plan: Projection: person.id, person.age Filter: person.age = Int32(10) TableScan: person @@ -413,64 +414,46 @@ fn test_infer_types_from_predicate() { #[test] fn test_prepare_statement_infer_types_from_predicate() { - let sql = "PREPARE my_plan AS SELECT id, age FROM person WHERE age = $1"; - let plan = logical_plan(sql).unwrap(); + let test = ParameterTest { + sql: "PREPARE my_plan AS SELECT id, age FROM person WHERE age = $1", + expected_types: vec![("$1", Some(DataType::Int32))], + param_values: vec![ScalarValue::Int32(Some(10))], + }; assert_snapshot!( - plan, + test.run(), @r#" + ** Initial Plan: Prepare: "my_plan" [Int32] Projection: person.id, person.age Filter: person.age = $1 TableScan: person - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([("$1".to_string(), Some(DataType::Int32))]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" + ** Final Plan: Projection: person.id, person.age Filter: person.age = Int32(10) TableScan: person - " + "# ); } #[test] fn test_infer_types_from_between_predicate() { - let sql = "SELECT id, age FROM person WHERE age BETWEEN $1 AND $2"; + let test = ParameterTest { + sql: "SELECT id, age FROM person WHERE age BETWEEN $1 AND $2", + expected_types: vec![ + ("$1", Some(DataType::Int32)), + ("$2", Some(DataType::Int32)), + ], + param_values: vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))], + }; - let plan = logical_plan(sql).unwrap(); assert_snapshot!( - plan, - @r#" + test.run(), + @r" + ** Initial Plan: Projection: person.id, person.age Filter: person.age BETWEEN $1 AND $2 TableScan: person - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([ - ("$1".to_string(), Some(DataType::Int32)), - ("$2".to_string(), Some(DataType::Int32)), - ]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" + ** Final Plan: Projection: person.id, person.age Filter: person.age BETWEEN Int32(10) AND Int32(30) TableScan: person @@ -480,48 +463,42 @@ fn test_infer_types_from_between_predicate() { #[test] fn test_prepare_statement_infer_types_from_between_predicate() { - let sql = "PREPARE my_plan AS SELECT id, age FROM person WHERE age BETWEEN $1 AND $2"; - - let plan = logical_plan(sql).unwrap(); + let test = ParameterTest { + sql: "PREPARE my_plan AS SELECT id, age FROM person WHERE age BETWEEN $1 AND $2", + expected_types: vec![ + ("$1", Some(DataType::Int32)), + ("$2", Some(DataType::Int32)), + ], + param_values: vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))], + }; assert_snapshot!( - plan, + test.run(), @r#" + ** Initial Plan: Prepare: "my_plan" [Int32, Int32] Projection: person.id, person.age Filter: person.age BETWEEN $1 AND $2 TableScan: person - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([ - ("$1".to_string(), Some(DataType::Int32)), - ("$2".to_string(), Some(DataType::Int32)), - ]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" + ** Final Plan: Projection: person.id, person.age Filter: person.age BETWEEN Int32(10) AND Int32(30) TableScan: person - " + "# ); } #[test] fn test_infer_types_subquery() { - let sql = "SELECT id, age FROM person WHERE age = (select max(age) from person where id = $1)"; + let test = ParameterTest { + sql: "SELECT id, age FROM person WHERE age = (select max(age) from person where id = $1)", + expected_types: vec![("$1", Some(DataType::UInt32))], + param_values: vec![ScalarValue::UInt32(Some(10))] + }; - let plan = logical_plan(sql).unwrap(); assert_snapshot!( - plan, - @r#" + test.run(), + @r" + ** Initial Plan: Projection: person.id, person.age Filter: person.age = () Subquery: @@ -530,20 +507,7 @@ fn test_infer_types_subquery() { Filter: person.id = $1 TableScan: person TableScan: person - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([("$1".to_string(), Some(DataType::UInt32))]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::UInt32(Some(10))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" + ** Final Plan: Projection: person.id, person.age Filter: person.age = () Subquery: @@ -552,18 +516,22 @@ fn test_infer_types_subquery() { Filter: person.id = UInt32(10) TableScan: person TableScan: person - " + " ); } #[test] fn test_prepare_statement_infer_types_subquery() { - let sql = "PREPARE my_plan AS SELECT id, age FROM person WHERE age = (select max(age) from person where id = $1)"; + let test = ParameterTest { + sql: "PREPARE my_plan AS SELECT id, age FROM person WHERE age = (select max(age) from person where id = $1)", + expected_types: vec![("$1", Some(DataType::UInt32))], + param_values: vec![ScalarValue::UInt32(Some(10))] + }; - let plan = logical_plan(sql).unwrap(); assert_snapshot!( - plan, + test.run(), @r#" + ** Initial Plan: Prepare: "my_plan" [UInt32] Projection: person.id, person.age Filter: person.age = () @@ -573,20 +541,7 @@ fn test_prepare_statement_infer_types_subquery() { Filter: person.id = $1 TableScan: person TableScan: person - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([("$1".to_string(), Some(DataType::UInt32))]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::UInt32(Some(10))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" + ** Final Plan: Projection: person.id, person.age Filter: person.age = () Subquery: @@ -595,116 +550,91 @@ fn test_prepare_statement_infer_types_subquery() { Filter: person.id = UInt32(10) TableScan: person TableScan: person - " + "# ); } #[test] fn test_update_infer() { - let sql = "update person set age=$1 where id=$2"; + let test = ParameterTest { + sql: "update person set age=$1 where id=$2", + expected_types: vec![ + ("$1", Some(DataType::Int32)), + ("$2", Some(DataType::UInt32)), + ], + param_values: vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))], + }; - let plan = logical_plan(sql).unwrap(); assert_snapshot!( - plan, - @r#" + test.run(), + @r" + ** Initial Plan: Dml: op=[Update] table=[person] Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, $1 AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 Filter: person.id = $2 TableScan: person - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([ - ("$1".to_string(), Some(DataType::Int32)), - ("$2".to_string(), Some(DataType::UInt32)), - ]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" + ** Final Plan: Dml: op=[Update] table=[person] Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, Int32(42) AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 Filter: person.id = UInt32(1) TableScan: person - " + " ); } #[test] fn test_prepare_statement_update_infer() { - let sql = "PREPARE my_plan AS update person set age=$1 where id=$2"; + let test = ParameterTest { + sql: "PREPARE my_plan AS update person set age=$1 where id=$2", + expected_types: vec![ + ("$1", Some(DataType::Int32)), + ("$2", Some(DataType::UInt32)), + ], + param_values: vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))], + }; - let plan = logical_plan(sql).unwrap(); assert_snapshot!( - plan, + test.run(), @r#" + ** Initial Plan: Prepare: "my_plan" [Int32, UInt32] Dml: op=[Update] table=[person] Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, $1 AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 Filter: person.id = $2 TableScan: person - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([ - ("$1".to_string(), Some(DataType::Int32)), - ("$2".to_string(), Some(DataType::UInt32)), - ]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" + ** Final Plan: Dml: op=[Update] table=[person] Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, Int32(42) AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 Filter: person.id = UInt32(1) TableScan: person - " + "# ); } #[test] fn test_insert_infer() { - let sql = "insert into person (id, first_name, last_name) values ($1, $2, $3)"; - let plan = logical_plan(sql).unwrap(); + let test = ParameterTest { + sql: "insert into person (id, first_name, last_name) values ($1, $2, $3)", + expected_types: vec![ + ("$1", Some(DataType::UInt32)), + ("$2", Some(DataType::Utf8)), + ("$3", Some(DataType::Utf8)), + ], + param_values: vec![ + ScalarValue::UInt32(Some(1)), + ScalarValue::from("Alan"), + ScalarValue::from("Turing"), + ], + }; + assert_snapshot!( - plan, + test.run(), @r#" + ** Initial Plan: Dml: op=[Insert Into] table=[person] Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀 Values: ($1, $2, $3) - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([ - ("$1".to_string(), Some(DataType::UInt32)), - ("$2".to_string(), Some(DataType::Utf8)), - ("$3".to_string(), Some(DataType::Utf8)), - ]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ - ScalarValue::UInt32(Some(1)), - ScalarValue::from("Alan"), - ScalarValue::from("Turing"), - ]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r#" + ** Final Plan: Dml: op=[Insert Into] table=[person] Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀 Values: (UInt32(1) AS $1, Utf8("Alan") AS $2, Utf8("Turing") AS $3) @@ -714,36 +644,28 @@ fn test_insert_infer() { #[test] fn test_prepare_statement_insert_infer() { - let sql = "PREPARE my_plan AS insert into person (id, first_name, last_name) values ($1, $2, $3)"; - let plan = logical_plan(sql).unwrap(); + let test = ParameterTest { + sql: "PREPARE my_plan AS insert into person (id, first_name, last_name) values ($1, $2, $3)", + expected_types: vec![ + ("$1", Some(DataType::UInt32)), + ("$2", Some(DataType::Utf8)), + ("$3", Some(DataType::Utf8)), + ], + param_values: vec![ + ScalarValue::UInt32(Some(1)), + ScalarValue::from("Alan"), + ScalarValue::from("Turing"), + ] + }; assert_snapshot!( - plan, + test.run(), @r#" + ** Initial Plan: Prepare: "my_plan" [UInt32, Utf8, Utf8] Dml: op=[Insert Into] table=[person] Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀 Values: ($1, $2, $3) - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([ - ("$1".to_string(), Some(DataType::UInt32)), - ("$2".to_string(), Some(DataType::Utf8)), - ("$3".to_string(), Some(DataType::Utf8)), - ]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ - ScalarValue::UInt32(Some(1)), - ScalarValue::from("Alan"), - ScalarValue::from("Turing"), - ]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r#" + ** Final Plan: Dml: op=[Insert Into] table=[person] Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀 Values: (UInt32(1) AS $1, Utf8("Alan") AS $2, Utf8("Turing") AS $3) diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index e6c2d82134b55..365012b7f6b00 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -50,34 +50,6 @@ use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; mod cases; mod common; -pub struct ParameterTest<'a> { - pub sql: &'a str, - pub expected_types: Vec<(&'a str, Option)>, - pub param_values: Vec, -} - -impl ParameterTest<'_> { - pub fn run(&self) -> String { - let plan = logical_plan(self.sql).unwrap(); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types: HashMap> = self - .expected_types - .iter() - .map(|(k, v)| (k.to_string(), v.clone())) - .collect(); - - assert_eq!(actual_types, expected_types); - - let plan_with_params = plan - .clone() - .with_param_values(self.param_values.clone()) - .unwrap(); - - format!("** Initial Plan:\n{plan}\n** Final Plan:\n{plan_with_params}") - } -} - #[test] fn parse_decimals_1() { let sql = "SELECT 1";