diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 94dcd2a86150..08f3646fb7eb 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -316,7 +316,7 @@ pub enum Expr { /// A named reference to a qualified field in a schema. Column(Column), /// A named reference to a variable in a registry. - ScalarVariable(DataType, Vec<String>), + ScalarVariable(FieldRef, Vec<String>), /// A constant value along with associated [`FieldMetadata`]. Literal(ScalarValue, Option<FieldMetadata>), /// A binary expression such as "age > 21" @@ -2529,8 +2529,8 @@ impl HashNode for Expr { Expr::Column(column) => { column.hash(state); } - Expr::ScalarVariable(data_type, name) => { - data_type.hash(state); + Expr::ScalarVariable(field, name) => { + field.hash(state); name.hash(state); } Expr::Literal(scalar_value, _) => { diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 9e8d6080b82c..b7e968419571 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -116,7 +116,7 @@ impl ExprSchemable for Expr { Expr::Negative(expr) => expr.get_type(schema), Expr::Column(c) => Ok(schema.data_type(c)?.clone()), Expr::OuterReferenceColumn(field, _) => Ok(field.data_type().clone()), - Expr::ScalarVariable(ty, _) => Ok(ty.clone()), + Expr::ScalarVariable(field, _) => Ok(field.data_type().clone()), Expr::Literal(l, _) => Ok(l.data_type()), Expr::Case(case) => { for (_, then_expr) in &case.when_then_expr { @@ -312,12 +312,8 @@ impl ExprSchemable for Expr { window_function, ) .map(|(_, nullable)| nullable), - Expr::Placeholder(Placeholder { id: _, field }) => { - Ok(field.as_ref().map(|f| f.is_nullable()).unwrap_or(true)) - } - Expr::ScalarVariable(_, _) | Expr::TryCast { .. } | Expr::Unnest(_) => { - Ok(true) - } + Expr::ScalarVariable(field, _) => Ok(field.is_nullable()), + Expr::TryCast { .. 
} | Expr::Unnest(_) | Expr::Placeholder(_) => Ok(true), Expr::IsNull(_) | Expr::IsNotNull(_) | Expr::IsTrue(_) @@ -451,8 +447,8 @@ impl ExprSchemable for Expr { Expr::OuterReferenceColumn(field, _) => { Ok(Arc::new(field.as_ref().clone().with_name(&schema_name))) } - Expr::ScalarVariable(ty, _) => { - Ok(Arc::new(Field::new(&schema_name, ty.clone(), true))) + Expr::ScalarVariable(field, _) => { + Ok(Arc::new(field.as_ref().clone().with_name(&schema_name))) } Expr::Literal(l, metadata) => { let mut field = Field::new(&schema_name, l.data_type(), l.is_null()); diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 25a0f83947ee..39df2e9dedfb 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -25,7 +25,7 @@ use crate::{ AggregateUDF, Expr, GetFieldAccess, ScalarUDF, SortExpr, TableSource, WindowFrame, WindowFunctionDefinition, WindowUDF, }; -use arrow::datatypes::{DataType, Field, SchemaRef}; +use arrow::datatypes::{DataType, Field, FieldRef, SchemaRef}; use datafusion_common::{ config::ConfigOptions, file_options::file_type::FileType, not_impl_err, DFSchema, Result, TableReference, @@ -103,6 +103,17 @@ pub trait ContextProvider { /// A user defined variable is typically accessed via `@var_name` fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType>; + /// Return metadata about a system/user-defined variable, if any. + /// + /// By default, this wraps [`Self::get_variable_type`] in an Arrow [`Field`] + /// with nullable set to `true` and no metadata. Implementations that can + /// provide richer information (such as nullability or extension metadata) + /// should override this method. 
+ fn get_variable_field(&self, variable_names: &[String]) -> Option<FieldRef> { + self.get_variable_type(variable_names) + .map(|data_type| data_type.into_nullable_field_ref()) + } + /// Return overall configuration options fn options(&self) -> &ConfigOptions; diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index 3c57d195ade6..e408a5eca3c6 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -39,13 +39,18 @@ impl<S: ContextProvider> SqlToRel<'_, S> { if id.value.starts_with('@') { // TODO: figure out if ScalarVariables should be insensitive. let var_names = vec![id.value]; - let ty = self + let field = self .context_provider - .get_variable_type(&var_names) + .get_variable_field(&var_names) + .or_else(|| { + self.context_provider + .get_variable_type(&var_names) + .map(|ty| ty.into_nullable_field_ref()) + }) .ok_or_else(|| { plan_datafusion_err!("variable {var_names:?} has no type information") })?; - Ok(Expr::ScalarVariable(ty, var_names)) + Ok(Expr::ScalarVariable(field, var_names)) } else { // Don't use `col()` here because it will try to // interpret names with '.' 
as if they were @@ -113,13 +118,18 @@ impl<S: ContextProvider> SqlToRel<'_, S> { .into_iter() .map(|id| self.ident_normalizer.normalize(id)) .collect(); - let ty = self + let field = self .context_provider - .get_variable_type(&var_names) + .get_variable_field(&var_names) + .or_else(|| { + self.context_provider + .get_variable_type(&var_names) + .map(|ty| ty.into_nullable_field_ref()) + }) .ok_or_else(|| { exec_datafusion_err!("variable {var_names:?} has no type information") })?; - Ok(Expr::ScalarVariable(ty, var_names)) + Ok(Expr::ScalarVariable(field, var_names)) } else { let ids = ids .into_iter() diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index a7fe8efa153c..8c1032eae49c 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -2169,12 +2169,15 @@ mod tests { r#"TRY_CAST(a AS INTEGER UNSIGNED)"#, ), ( - Expr::ScalarVariable(Int8, vec![String::from("@a")]), + Expr::ScalarVariable( + Int8.into_nullable_field_ref(), + vec![String::from("@a")], + ), r#"@a"#, ), ( Expr::ScalarVariable( - Int8, + Int8.into_nullable_field_ref(), vec![String::from("@root"), String::from("foo")], ), r#"@root.foo"#,