Skip to content
Open
6 changes: 3 additions & 3 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ pub enum Expr {
/// A named reference to a qualified field in a schema.
Column(Column),
/// A named reference to a variable in a registry.
ScalarVariable(DataType, Vec<String>),
ScalarVariable(FieldRef, Vec<String>),
/// A constant value along with associated [`FieldMetadata`].
Literal(ScalarValue, Option<FieldMetadata>),
/// A binary expression such as "age > 21"
Expand Down Expand Up @@ -2529,8 +2529,8 @@ impl HashNode for Expr {
Expr::Column(column) => {
column.hash(state);
}
Expr::ScalarVariable(data_type, name) => {
data_type.hash(state);
Expr::ScalarVariable(field, name) => {
field.hash(state);
name.hash(state);
}
Expr::Literal(scalar_value, _) => {
Expand Down
14 changes: 5 additions & 9 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ impl ExprSchemable for Expr {
Expr::Negative(expr) => expr.get_type(schema),
Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
Expr::OuterReferenceColumn(field, _) => Ok(field.data_type().clone()),
Expr::ScalarVariable(ty, _) => Ok(ty.clone()),
Expr::ScalarVariable(field, _) => Ok(field.data_type().clone()),
Expr::Literal(l, _) => Ok(l.data_type()),
Expr::Case(case) => {
for (_, then_expr) in &case.when_then_expr {
Expand Down Expand Up @@ -312,12 +312,8 @@ impl ExprSchemable for Expr {
window_function,
)
.map(|(_, nullable)| nullable),
Expr::Placeholder(Placeholder { id: _, field }) => {
Ok(field.as_ref().map(|f| f.is_nullable()).unwrap_or(true))
}
Expr::ScalarVariable(_, _) | Expr::TryCast { .. } | Expr::Unnest(_) => {
Ok(true)
}
Expr::ScalarVariable(field, _) => Ok(field.is_nullable()),
Expr::TryCast { .. } | Expr::Unnest(_) | Expr::Placeholder(_) => Ok(true),
Expr::IsNull(_)
| Expr::IsNotNull(_)
| Expr::IsTrue(_)
Expand Down Expand Up @@ -451,8 +447,8 @@ impl ExprSchemable for Expr {
Expr::OuterReferenceColumn(field, _) => {
Ok(Arc::new(field.as_ref().clone().with_name(&schema_name)))
}
Expr::ScalarVariable(ty, _) => {
Ok(Arc::new(Field::new(&schema_name, ty.clone(), true)))
Expr::ScalarVariable(field, _) => {
Ok(Arc::new(field.as_ref().clone().with_name(&schema_name)))
}
Expr::Literal(l, metadata) => {
let mut field = Field::new(&schema_name, l.data_type(), l.is_null());
Expand Down
13 changes: 12 additions & 1 deletion datafusion/expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::{
AggregateUDF, Expr, GetFieldAccess, ScalarUDF, SortExpr, TableSource, WindowFrame,
WindowFunctionDefinition, WindowUDF,
};
use arrow::datatypes::{DataType, Field, SchemaRef};
use arrow::datatypes::{DataType, Field, FieldRef, SchemaRef};
use datafusion_common::{
config::ConfigOptions, file_options::file_type::FileType, not_impl_err, DFSchema,
Result, TableReference,
Expand Down Expand Up @@ -103,6 +103,17 @@ pub trait ContextProvider {
/// A user defined variable is typically accessed via `@var_name`
fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType>;

/// Return metadata about a system/user-defined variable, if any.
///
/// By default, this wraps [`Self::get_variable_type`] in an Arrow [`Field`]
/// with nullable set to `true` and no metadata. Implementations that can
/// provide richer information (such as nullability or extension metadata)
/// should override this method.
fn get_variable_field(&self, variable_names: &[String]) -> Option<FieldRef> {
self.get_variable_type(variable_names)
.map(|data_type| data_type.into_nullable_field_ref())
}

/// Return overall configuration options
fn options(&self) -> &ConfigOptions;

Expand Down
22 changes: 16 additions & 6 deletions datafusion/sql/src/expr/identifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,18 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
if id.value.starts_with('@') {
// TODO: figure out if ScalarVariables should be insensitive.
let var_names = vec![id.value];
let ty = self
let field = self
.context_provider
.get_variable_type(&var_names)
.get_variable_field(&var_names)
.or_else(|| {
self.context_provider
.get_variable_type(&var_names)
.map(|ty| ty.into_nullable_field_ref())
})
.ok_or_else(|| {
plan_datafusion_err!("variable {var_names:?} has no type information")
})?;
Ok(Expr::ScalarVariable(ty, var_names))
Ok(Expr::ScalarVariable(field, var_names))
} else {
// Don't use `col()` here because it will try to
// interpret names with '.' as if they were
Expand Down Expand Up @@ -113,13 +118,18 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
.into_iter()
.map(|id| self.ident_normalizer.normalize(id))
.collect();
let ty = self
let field = self
.context_provider
.get_variable_type(&var_names)
.get_variable_field(&var_names)
.or_else(|| {
self.context_provider
.get_variable_type(&var_names)
.map(|ty| ty.into_nullable_field_ref())
})
.ok_or_else(|| {
exec_datafusion_err!("variable {var_names:?} has no type information")
})?;
Ok(Expr::ScalarVariable(ty, var_names))
Ok(Expr::ScalarVariable(field, var_names))
} else {
let ids = ids
.into_iter()
Expand Down
7 changes: 5 additions & 2 deletions datafusion/sql/src/unparser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2169,12 +2169,15 @@ mod tests {
r#"TRY_CAST(a AS INTEGER UNSIGNED)"#,
),
(
Expr::ScalarVariable(Int8, vec![String::from("@a")]),
Expr::ScalarVariable(
Int8.into_nullable_field_ref(),
vec![String::from("@a")],
),
r#"@a"#,
),
(
Expr::ScalarVariable(
Int8,
Int8.into_nullable_field_ref(),
vec![String::from("@root"), String::from("foo")],
),
r#"@root.foo"#,
Expand Down
Loading