Skip to content

Use make_array to handle SQLExpr::Array. #7427

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,12 @@ use datafusion_expr::{
col, expr, lit, AggregateFunction, Between, BinaryExpr, BuiltinScalarFunction, Cast,
Expr, ExprSchemable, GetFieldAccess, GetIndexedField, Like, Operator, TryCast,
};
use sqlparser::ast::{ArrayAgg, Expr as SQLExpr, JsonOperator, TrimWhereField, Value};
use sqlparser::ast::{
ArrayAgg, Expr as SQLExpr, FunctionArg, FunctionArgExpr, JsonOperator,
TrimWhereField, Value,
};
use sqlparser::parser::ParserError::ParserError;
use std::str::FromStr;

impl<'a, S: ContextProvider> SqlToRel<'a, S> {
pub(crate) fn sql_expr_to_logical_expr(
Expand Down Expand Up @@ -176,7 +180,26 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
)))
}

SQLExpr::Array(arr) => self.sql_array_literal(arr.elem, schema),
SQLExpr::Array(arr) => {
if arr.elem.is_empty() {
Ok(lit(ScalarValue::new_list(None, DataType::Utf8)))
} else {
let args = arr
.elem
.iter()
.map(|expr| {
FunctionArg::Unnamed(FunctionArgExpr::Expr(expr.clone()))
})
.collect::<Vec<FunctionArg>>();

let args =
self.function_args_to_expr(args, schema, planner_context)?;
Ok(Expr::ScalarFunction(ScalarFunction::new(
BuiltinScalarFunction::from_str("make_array")?,
args,
)))
}
}
SQLExpr::Interval(interval) => {
self.sql_interval_to_expr(false, interval, schema, planner_context)
}
Expand Down
40 changes: 0 additions & 40 deletions datafusion/sql/src/expr/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ use datafusion_expr::{lit, Expr, Operator};
use log::debug;
use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value};
use sqlparser::parser::ParserError::ParserError;
use std::collections::HashSet;

impl<'a, S: ContextProvider> SqlToRel<'a, S> {
pub(crate) fn parse_value(
Expand Down Expand Up @@ -125,45 +124,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
)))
}

pub(super) fn sql_array_literal(
&self,
elements: Vec<SQLExpr>,
schema: &DFSchema,
) -> Result<Expr> {
let mut values = Vec::with_capacity(elements.len());

for element in elements {
let value = self.sql_expr_to_logical_expr(
element,
schema,
&mut PlannerContext::new(),
)?;
match value {
Expr::Literal(scalar) => {
values.push(scalar);
}
_ => {
return not_impl_err!(
"Arrays with elements other than literal are not supported: {value}"
);
}
}
}

let data_types: HashSet<DataType> =
values.iter().map(|e| e.get_datatype()).collect();

if data_types.is_empty() {
Ok(lit(ScalarValue::new_list(None, DataType::Utf8)))
} else if data_types.len() > 1 {
not_impl_err!("Arrays with different types are not supported: {data_types:?}")
} else {
let data_type = values[0].get_datatype();

Ok(lit(ScalarValue::new_list(Some(values), data_type)))
}
}

/// Convert a SQL interval expression to a DataFusion logical plan
/// expression
pub(super) fn sql_interval_to_expr(
Expand Down
34 changes: 10 additions & 24 deletions datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect};

use datafusion_common::plan_err;
use datafusion_common::{
assert_contains, config::ConfigOptions, DataFusionError, Result, ScalarValue,
TableReference,
config::ConfigOptions, DataFusionError, Result, ScalarValue, TableReference,
};
use datafusion_expr::{
logical_plan::{LogicalPlan, Prepare},
Expand Down Expand Up @@ -1364,18 +1363,6 @@ fn select_interval_out_of_range() {
);
}

#[test]
fn select_array_no_common_type() {
let sql = "SELECT [1, true, null]";
let err = logical_plan(sql).expect_err("query should have failed");

// HashSet doesn't guarantee order
assert_contains!(
err.to_string(),
r#"Arrays with different types are not supported: "#
);
}

#[test]
fn recursive_ctes() {
let sql = "
Expand All @@ -1392,16 +1379,6 @@ fn recursive_ctes() {
);
}

#[test]
fn select_array_non_literal_type() {
let sql = "SELECT [now()]";
let err = logical_plan(sql).expect_err("query should have failed");
assert_eq!(
r#"NotImplemented("Arrays with elements other than literal are not supported: now()")"#,
format!("{err:?}")
);
}

#[test]
fn select_simple_aggregate_with_groupby_and_column_is_in_aggregate_and_groupby() {
quick_test(
Expand Down Expand Up @@ -4230,6 +4207,15 @@ fn test_multi_grouping_sets() {
quick_test(sql, expected);
}

#[test]
fn test_array_with_binary_expr() {
let sql = "select [1>0,2>1]";

let expected = "Projection: make_array(Int64(1) > Int64(0), Int64(2) > Int64(1))\
\n EmptyRelation";
quick_test(sql, expected);
}

fn assert_field_not_found(err: DataFusionError, name: &str) {
match err {
DataFusionError::SchemaError { .. } => {
Expand Down