diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index be6cee9885aa7..0c0c1a6dc5588 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -40,8 +40,12 @@ use datafusion_expr::{ col, expr, lit, AggregateFunction, Between, BinaryExpr, BuiltinScalarFunction, Cast, Expr, ExprSchemable, GetFieldAccess, GetIndexedField, Like, Operator, TryCast, }; -use sqlparser::ast::{ArrayAgg, Expr as SQLExpr, JsonOperator, TrimWhereField, Value}; +use sqlparser::ast::{ + ArrayAgg, Expr as SQLExpr, FunctionArg, FunctionArgExpr, JsonOperator, + TrimWhereField, Value, +}; use sqlparser::parser::ParserError::ParserError; +use std::str::FromStr; impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(crate) fn sql_expr_to_logical_expr( @@ -176,7 +180,26 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ))) } - SQLExpr::Array(arr) => self.sql_array_literal(arr.elem, schema), + SQLExpr::Array(arr) => { + if arr.elem.is_empty() { + Ok(lit(ScalarValue::new_list(None, DataType::Utf8))) + } else { + let args = arr + .elem + .iter() + .map(|expr| { + FunctionArg::Unnamed(FunctionArgExpr::Expr(expr.clone())) + }) + .collect::>(); + + let args = + self.function_args_to_expr(args, schema, planner_context)?; + Ok(Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::from_str("make_array")?, + args, + ))) + } + } SQLExpr::Interval(interval) => { self.sql_interval_to_expr(false, interval, schema, planner_context) } diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index 158054ce6cce7..76e54a5ebfcb0 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -26,7 +26,6 @@ use datafusion_expr::{lit, Expr, Operator}; use log::debug; use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value}; use sqlparser::parser::ParserError::ParserError; -use std::collections::HashSet; impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(crate) fn parse_value( @@ -125,45 +124,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ))) } - pub(super) fn sql_array_literal( - &self, - elements: Vec, - schema: &DFSchema, - ) -> Result { - let mut values = Vec::with_capacity(elements.len()); - - for element in elements { - let value = self.sql_expr_to_logical_expr( - element, - schema, - &mut PlannerContext::new(), - )?; - match value { - Expr::Literal(scalar) => { - values.push(scalar); - } - _ => { - return not_impl_err!( - "Arrays with elements other than literal are not supported: {value}" - ); - } - } - } - - let data_types: HashSet = - values.iter().map(|e| e.get_datatype()).collect(); - - if data_types.is_empty() { - Ok(lit(ScalarValue::new_list(None, DataType::Utf8))) - } else if data_types.len() > 1 { - not_impl_err!("Arrays with different types are not supported: {data_types:?}") - } else { - let data_type = values[0].get_datatype(); - - Ok(lit(ScalarValue::new_list(Some(values), data_type))) - } - } - /// Convert a SQL interval expression to a DataFusion logical plan /// expression pub(super) fn sql_interval_to_expr( diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 8b4d9686a4f56..04879c931f28e 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -24,8 +24,7 @@ use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; use datafusion_common::plan_err; use datafusion_common::{ - assert_contains, config::ConfigOptions, DataFusionError, Result, ScalarValue, - TableReference, + config::ConfigOptions, DataFusionError, Result, ScalarValue, TableReference, }; use datafusion_expr::{ logical_plan::{LogicalPlan, Prepare}, @@ -1364,18 +1363,6 @@ fn select_interval_out_of_range() { ); } -#[test] -fn select_array_no_common_type() { - let sql = "SELECT [1, true, null]"; - let err = logical_plan(sql).expect_err("query should have failed"); - - // HashSet doesn't guarantee order - assert_contains!( - err.to_string(), - r#"Arrays with different types are not supported: "# - ); -} - #[test] fn recursive_ctes() { let sql = " @@ -1392,16 +1379,6 @@ fn recursive_ctes() { ); } -#[test] -fn select_array_non_literal_type() { - let sql = "SELECT [now()]"; - let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - r#"NotImplemented("Arrays with elements other than literal are not supported: now()")"#, - format!("{err:?}") - ); -} - #[test] fn select_simple_aggregate_with_groupby_and_column_is_in_aggregate_and_groupby() { quick_test( @@ -4230,6 +4207,15 @@ fn test_multi_grouping_sets() { quick_test(sql, expected); } +#[test] +fn test_array_with_binary_expr() { + let sql = "select [1>0,2>1]"; + + let expected = "Projection: make_array(Int64(1) > Int64(0), Int64(2) > Int64(1))\ + \n EmptyRelation"; + quick_test(sql, expected); +} + fn assert_field_not_found(err: DataFusionError, name: &str) { match err { DataFusionError::SchemaError { .. } => {