diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs
index 24ec9b7be323..89dc4956abb1 100644
--- a/datafusion/common/src/lib.rs
+++ b/datafusion/common/src/lib.rs
@@ -47,6 +47,7 @@ pub mod file_options;
 pub mod format;
 pub mod hash_utils;
 pub mod instant;
+pub mod metadata;
 pub mod nested_struct;
 mod null_equality;
 pub mod parsers;
diff --git a/datafusion/common/src/metadata.rs b/datafusion/common/src/metadata.rs
new file mode 100644
index 000000000000..9520ed016f7c
--- /dev/null
+++ b/datafusion/common/src/metadata.rs
@@ -0,0 +1,363 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::{collections::BTreeMap, sync::Arc};
+
+use arrow::datatypes::{DataType, Field};
+use hashbrown::HashMap;
+
+use crate::{error::_plan_err, DataFusionError, ScalarValue};
+
+/// A [`ScalarValue`] with optional [`FieldMetadata`]
+#[derive(Debug, Clone)]
+pub struct Literal {
+    pub value: ScalarValue,
+    pub metadata: Option<FieldMetadata>,
+}
+
+impl Literal {
+    /// Create a new Literal from a scalar value with optional [`FieldMetadata`]
+    pub fn new(value: ScalarValue, metadata: Option<FieldMetadata>) -> Self {
+        Self { value, metadata }
+    }
+
+    /// Access the underlying [ScalarValue] storage
+    pub fn value(&self) -> &ScalarValue {
+        &self.value
+    }
+
+    /// Access the [FieldMetadata] attached to this value, if any
+    pub fn metadata(&self) -> Option<&FieldMetadata> {
+        self.metadata.as_ref()
+    }
+
+    /// Consume self and return components
+    pub fn into_inner(self) -> (ScalarValue, Option<FieldMetadata>) {
+        (self.value, self.metadata)
+    }
+
+    /// Cast this literal's storage type
+    ///
+    /// This operation assumes that if the underlying [ScalarValue] can be cast
+    /// to a given type, any extension type represented by the metadata is also
+    /// valid for that type.
+    pub fn cast_storage_to(
+        &self,
+        target_type: &DataType,
+    ) -> Result<Literal, DataFusionError> {
+        let new_value = self.value().cast_to(target_type)?;
+        Ok(Literal::new(new_value, self.metadata.clone()))
+    }
+}
+
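Taken together, the new `Literal` wrapper keeps a `ScalarValue` and its field metadata moving as a unit. A minimal sketch of the intended usage, using only the API added above (the extension key/value below are illustrative, not part of this PR):

```rust
use std::collections::BTreeMap;

use arrow::datatypes::DataType;
use datafusion_common::metadata::{FieldMetadata, Literal};
use datafusion_common::ScalarValue;

fn main() -> datafusion_common::Result<()> {
    // Attach extension-style metadata to a scalar value
    let metadata = FieldMetadata::new(BTreeMap::from([(
        "ARROW:extension:name".to_string(),
        "example.tag".to_string(),
    )]));
    let literal = Literal::new(ScalarValue::Int32(Some(42)), Some(metadata));

    // Casting changes only the storage type; the metadata travels along
    let widened = literal.cast_storage_to(&DataType::Int64)?;
    assert_eq!(widened.value(), &ScalarValue::Int64(Some(42)));
    assert_eq!(widened.metadata(), literal.metadata());
    Ok(())
}
```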
+/// Assert equality of data types where one or both sides may have field metadata
+///
+/// This currently treats absent metadata (e.g., one side was a DataType) and
+/// empty metadata (e.g., one side was a field where the field had no metadata)
+/// as equal, and uses byte-for-byte comparison for the keys and values of the
+/// fields, even though this is potentially too strict for some cases (e.g.,
+/// extension types where extension metadata is represented by JSON, or cases
+/// where field metadata is orthogonal to the interpretation of the data type).
+///
+/// Returns a planning error with suitably formatted type representations if
+/// actual and expected do not compare as equal.
+pub fn check_metadata_with_storage_equal(
+    actual: (
+        &DataType,
+        Option<&std::collections::HashMap<String, String>>,
+    ),
+    expected: (
+        &DataType,
+        Option<&std::collections::HashMap<String, String>>,
+    ),
+    what: &str,
+    context: &str,
+) -> Result<(), DataFusionError> {
+    if actual.0 != expected.0 {
+        return _plan_err!(
+            "Expected {what} of type {}, got {}{context}",
+            format_type_and_metadata(expected.0, expected.1),
+            format_type_and_metadata(actual.0, actual.1)
+        );
+    }
+
+    let metadata_equal = match (actual.1, expected.1) {
+        (None, None) => true,
+        (None, Some(expected_metadata)) => expected_metadata.is_empty(),
+        (Some(actual_metadata), None) => actual_metadata.is_empty(),
+        (Some(actual_metadata), Some(expected_metadata)) => {
+            actual_metadata == expected_metadata
+        }
+    };
+
+    if !metadata_equal {
+        return _plan_err!(
+            "Expected {what} of type {}, got {}{context}",
+            format_type_and_metadata(expected.0, expected.1),
+            format_type_and_metadata(actual.0, actual.1)
+        );
+    }
+
+    Ok(())
+}
+
+/// Given a data type represented by storage and optional metadata, generate
+/// a user-facing string
+///
+/// This function exists to reduce the number of Field debug strings that are
+/// used to communicate type information in error messages and plan explain
+/// renderings.
+pub fn format_type_and_metadata(
+    data_type: &DataType,
+    metadata: Option<&std::collections::HashMap<String, String>>,
+) -> String {
+    match metadata {
+        Some(metadata) if !metadata.is_empty() => {
+            format!("{data_type}<{metadata:?}>")
+        }
+        _ => data_type.to_string(),
+    }
+}
+
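The two helpers above are meant to be used together: `format_type_and_metadata` renders a storage type plus its metadata map, and the equality check treats absent and empty metadata as interchangeable. A short sketch of that behavior (the `tz`/`UTC` pair is an arbitrary example):

```rust
use std::collections::HashMap;

use arrow::datatypes::DataType;
use datafusion_common::metadata::{
    check_metadata_with_storage_equal, format_type_and_metadata,
};

fn main() {
    let meta = HashMap::from([("tz".to_string(), "UTC".to_string())]);

    // Rendered as the storage type followed by the metadata map
    assert_eq!(
        format_type_and_metadata(&DataType::Utf8, Some(&meta)),
        format!("{}<{:?}>", DataType::Utf8, meta)
    );

    // Absent and empty metadata compare as equal...
    assert!(check_metadata_with_storage_equal(
        (&DataType::Utf8, None),
        (&DataType::Utf8, Some(&HashMap::new())),
        "parameter",
        "",
    )
    .is_ok());

    // ...but differing keys/values produce a plan error
    assert!(check_metadata_with_storage_equal(
        (&DataType::Utf8, Some(&meta)),
        (&DataType::Utf8, None),
        "parameter",
        "",
    )
    .is_err());
}
```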
+/// Literal metadata
+///
+/// Stores metadata associated with a literal expression
+/// and is designed to be fast to `clone`.
+///
+/// This structure is used to store metadata associated with a literal expression, and it
+/// corresponds to the `metadata` field on [`Field`].
+///
+/// # Example: Create [`FieldMetadata`] from a [`Field`]
+/// ```
+/// # use std::collections::HashMap;
+/// # use datafusion_common::metadata::FieldMetadata;
+/// # use arrow::datatypes::{Field, DataType};
+/// # let field = Field::new("c1", DataType::Int32, true)
+/// #   .with_metadata(HashMap::from([("foo".to_string(), "bar".to_string())]));
+/// // Create a new `FieldMetadata` instance from a `Field`
+/// let metadata = FieldMetadata::new_from_field(&field);
+/// // There is also a `From` impl:
+/// let metadata = FieldMetadata::from(&field);
+/// ```
+///
+/// # Example: Update a [`Field`] with [`FieldMetadata`]
+/// ```
+/// # use datafusion_common::metadata::FieldMetadata;
+/// # use arrow::datatypes::{Field, DataType};
+/// # let field = Field::new("c1", DataType::Int32, true);
+/// # let metadata = FieldMetadata::new_from_field(&field);
+/// // Add any metadata from `FieldMetadata` to `Field`
+/// let updated_field = metadata.add_to_field(field);
+/// ```
+///
+#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
+pub struct FieldMetadata {
+    /// The inner metadata of a literal expression, which is a map of string
+    /// keys to string values.
+    ///
+    /// Note this is not a `HashMap` because `HashMap` does not provide
+    /// implementations for traits like `Debug` and `Hash`.
+    inner: Arc<BTreeMap<String, String>>,
+}
+
+impl Default for FieldMetadata {
+    fn default() -> Self {
+        Self::new_empty()
+    }
+}
+
+impl FieldMetadata {
+    /// Create a new empty metadata instance.
+    pub fn new_empty() -> Self {
+        Self {
+            inner: Arc::new(BTreeMap::new()),
+        }
+    }
+
+    /// Merges two optional `FieldMetadata` instances, overwriting any existing
+    /// keys in `m` with keys from `n` if present.
+    ///
+    /// This function is commonly used in alias operations, particularly for literals
+    /// with metadata. When creating an alias expression, the metadata from the original
+    /// expression (such as a literal) is combined with any metadata specified on the alias.
+    ///
+    /// # Arguments
+    ///
+    /// * `m` - The first metadata (typically from the original expression like a literal)
+    /// * `n` - The second metadata (typically from the alias definition)
+    ///
+    /// # Merge Strategy
+    ///
+    /// - If both metadata instances exist, they are merged with `n` taking precedence
+    /// - Keys from `n` will overwrite keys from `m` if they have the same name
+    /// - If only one metadata instance exists, it is returned unchanged
+    /// - If neither exists, `None` is returned
+    ///
+    /// # Example usage
+    /// ```rust
+    /// use datafusion_common::metadata::FieldMetadata;
+    /// use std::collections::BTreeMap;
+    ///
+    /// // Create metadata for a literal expression
+    /// let literal_metadata = Some(FieldMetadata::from(BTreeMap::from([
+    ///     ("source".to_string(), "constant".to_string()),
+    ///     ("type".to_string(), "int".to_string()),
+    /// ])));
+    ///
+    /// // Create metadata for an alias
+    /// let alias_metadata = Some(FieldMetadata::from(BTreeMap::from([
+    ///     ("description".to_string(), "answer".to_string()),
+    ///     ("source".to_string(), "user".to_string()), // This will override literal's "source"
+    /// ])));
+    ///
+    /// // Merge the metadata
+    /// let merged = FieldMetadata::merge_options(
+    ///     literal_metadata.as_ref(),
+    ///     alias_metadata.as_ref(),
+    /// );
+    ///
+    /// // Result contains: {"source": "user", "type": "int", "description": "answer"}
+    /// assert!(merged.is_some());
+    /// ```
+    pub fn merge_options(
+        m: Option<&FieldMetadata>,
+        n: Option<&FieldMetadata>,
+    ) -> Option<FieldMetadata> {
+        match (m, n) {
+            (Some(m), Some(n)) => {
+                let mut merged = m.clone();
+                merged.extend(n.clone());
+                Some(merged)
+            }
+            (Some(m), None) => Some(m.clone()),
+            (None, Some(n)) => Some(n.clone()),
+            (None, None) => None,
+        }
+    }
+
+    /// Create a new metadata instance from a `Field`'s metadata.
+    pub fn new_from_field(field: &Field) -> Self {
+        let inner = field
+            .metadata()
+            .iter()
+            .map(|(k, v)| (k.to_string(), v.to_string()))
+            .collect();
+        Self {
+            inner: Arc::new(inner),
+        }
+    }
+
+    /// Create a new metadata instance from a map of string keys to string values.
+    pub fn new(inner: BTreeMap<String, String>) -> Self {
+        Self {
+            inner: Arc::new(inner),
+        }
+    }
+
+    /// Get the inner metadata as a reference to a `BTreeMap`.
+    pub fn inner(&self) -> &BTreeMap<String, String> {
+        &self.inner
+    }
+
+    /// Return the inner metadata
+    pub fn into_inner(self) -> Arc<BTreeMap<String, String>> {
+        self.inner
+    }
+
+    /// Adds metadata from `other` into `self`, overwriting any existing keys.
+    pub fn extend(&mut self, other: Self) {
+        if other.is_empty() {
+            return;
+        }
+        let other = Arc::unwrap_or_clone(other.into_inner());
+        Arc::make_mut(&mut self.inner).extend(other);
+    }
+
+    /// Returns true if the metadata is empty.
+    pub fn is_empty(&self) -> bool {
+        self.inner.is_empty()
+    }
+
+    /// Returns the number of key-value pairs in the metadata.
+    pub fn len(&self) -> usize {
+        self.inner.len()
+    }
+
+    /// Convert this `FieldMetadata` into a `HashMap<String, String>`
+    pub fn to_hashmap(&self) -> std::collections::HashMap<String, String> {
+        self.inner
+            .iter()
+            .map(|(k, v)| (k.to_string(), v.to_string()))
+            .collect()
+    }
+
+    /// Updates the metadata on the Field with this metadata, if it is not empty.
+    pub fn add_to_field(&self, field: Field) -> Field {
+        if self.inner.is_empty() {
+            return field;
+        }
+
+        field.with_metadata(self.to_hashmap())
+    }
+}
+
+impl From<&Field> for FieldMetadata {
+    fn from(field: &Field) -> Self {
+        Self::new_from_field(field)
+    }
+}
+
+impl From<BTreeMap<String, String>> for FieldMetadata {
+    fn from(inner: BTreeMap<String, String>) -> Self {
+        Self::new(inner)
+    }
+}
+
+impl From<std::collections::HashMap<String, String>> for FieldMetadata {
+    fn from(map: std::collections::HashMap<String, String>) -> Self {
+        Self::new(map.into_iter().collect())
+    }
+}
+
+/// From reference
+impl From<&std::collections::HashMap<String, String>> for FieldMetadata {
+    fn from(map: &std::collections::HashMap<String, String>) -> Self {
+        let inner = map
+            .iter()
+            .map(|(k, v)| (k.to_string(), v.to_string()))
+            .collect();
+        Self::new(inner)
+    }
+}
+
+/// From hashbrown map
+impl From<HashMap<String, String>> for FieldMetadata {
+    fn from(map: HashMap<String, String>) -> Self {
+        let inner = map.into_iter().collect();
+        Self::new(inner)
+    }
+}
+
+impl From<&HashMap<String, String>> for FieldMetadata {
+    fn from(map: &HashMap<String, String>) -> Self {
+        let inner = map
+            .into_iter()
+            .map(|(k, v)| (k.to_string(), v.to_string()))
+            .collect();
+        Self::new(inner)
+    }
+}
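A sketch of the round trip these `From` impls and `add_to_field` enable, assuming nothing beyond the API added above (`unit`/`ms` is an arbitrary key/value):

```rust
use std::collections::HashMap;

use arrow::datatypes::{DataType, Field};
use datafusion_common::metadata::FieldMetadata;

fn main() {
    // Capture metadata from an existing Field...
    let field = Field::new("c1", DataType::Int32, true)
        .with_metadata(HashMap::from([("unit".to_string(), "ms".to_string())]));
    let metadata = FieldMetadata::from(&field);

    // ...and stamp it onto another Field
    let target = metadata.add_to_field(Field::new("c2", DataType::Int64, true));
    assert_eq!(target.metadata(), field.metadata());
}
```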
diff --git a/datafusion/common/src/param_value.rs b/datafusion/common/src/param_value.rs
index 7582cff56f87..5b831c9f3047 100644
--- a/datafusion/common/src/param_value.rs
+++ b/datafusion/common/src/param_value.rs
@@ -16,22 +16,37 @@
 // under the License.
 
 use crate::error::{_plan_datafusion_err, _plan_err};
+use crate::metadata::{check_metadata_with_storage_equal, Literal};
 use crate::{Result, ScalarValue};
-use arrow::datatypes::DataType;
+use arrow::datatypes::{DataType, Field, FieldRef};
 use std::collections::HashMap;
 
 /// The parameter value corresponding to the placeholder
 #[derive(Debug, Clone)]
 pub enum ParamValues {
     /// For positional query parameters, like `SELECT * FROM test WHERE a > $1 AND b = $2`
-    List(Vec<ScalarValue>),
+    List(Vec<Literal>),
     /// For named query parameters, like `SELECT * FROM test WHERE a > $foo AND b = $goo`
-    Map(HashMap<String, ScalarValue>),
+    Map(HashMap<String, Literal>),
 }
 
 impl ParamValues {
-    /// Verify parameter list length and type
+    /// Verify parameter list length and DataType
+    ///
+    /// Use [`ParamValues::verify_fields`] to ensure field metadata is considered when
+    /// computing type equality.
+    #[deprecated]
     pub fn verify(&self, expect: &[DataType]) -> Result<()> {
+        // make dummy Fields
+        let expect = expect
+            .iter()
+            .map(|dt| Field::new("", dt.clone(), true).into())
+            .collect::<Vec<_>>();
+        self.verify_fields(&expect)
+    }
+
+    /// Verify parameter list length and type
+    pub fn verify_fields(&self, expect: &[FieldRef]) -> Result<()> {
         match self {
             ParamValues::List(list) => {
                 // Verify if the number of params matches the number of values
@@ -45,15 +60,16 @@ impl ParamValues {
 
                 // Verify if the types of the params matches the types of the values
                 let iter = expect.iter().zip(list.iter());
-                for (i, (param_type, value)) in iter.enumerate() {
-                    if *param_type != value.data_type() {
-                        return _plan_err!(
-                            "Expected parameter of type {}, got {:?} at index {}",
-                            param_type,
-                            value.data_type(),
-                            i
-                        );
-                    }
+                for (i, (param_type, lit)) in iter.enumerate() {
+                    check_metadata_with_storage_equal(
+                        (
+                            &lit.value.data_type(),
+                            lit.metadata.as_ref().map(|m| m.to_hashmap()).as_ref(),
+                        ),
+                        (param_type.data_type(), Some(param_type.metadata())),
+                        "parameter",
+                        &format!(" at index {i}"),
+                    )?;
                 }
                 Ok(())
             }
@@ -65,7 +81,7 @@ impl ParamValues {
         }
     }
 
-    pub fn get_placeholders_with_values(&self, id: &str) -> Result<ScalarValue> {
+    pub fn get_placeholders_with_values(&self, id: &str) -> Result<Literal> {
         match self {
             ParamValues::List(list) => {
                 if id.is_empty() {
@@ -99,7 +115,7 @@ impl ParamValues {
 
 impl From<Vec<ScalarValue>> for ParamValues {
     fn from(value: Vec<ScalarValue>) -> Self {
-        Self::List(value)
+        Self::List(value.into_iter().map(|v| Literal::new(v, None)).collect())
     }
 }
 
@@ -108,8 +124,10 @@ impl<K> From<Vec<(K, ScalarValue)>> for ParamValues
 where
     K: Into<String>,
 {
     fn from(value: Vec<(K, ScalarValue)>) -> Self {
-        let value: HashMap<String, ScalarValue> =
-            value.into_iter().map(|(k, v)| (k.into(), v)).collect();
+        let value: HashMap<String, Literal> = value
+            .into_iter()
+            .map(|(k, v)| (k.into(), Literal::new(v, None)))
+            .collect();
         Self::Map(value)
     }
 }
 
@@ -119,8 +137,10 @@ impl<K> From<HashMap<K, ScalarValue>> for ParamValues
 where
     K: Into<String>,
 {
     fn from(value: HashMap<K, ScalarValue>) -> Self {
-        let value: HashMap<String, ScalarValue> =
-            value.into_iter().map(|(k, v)| (k.into(), v)).collect();
+        let value: HashMap<String, Literal> = value
+            .into_iter()
+            .map(|(k, v)| (k.into(), Literal::new(v, None)))
+            .collect();
         Self::Map(value)
     }
 }
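How this surfaces through `ParamValues`: positional parameters are now `Literal`s, and `verify_fields` checks metadata in addition to storage type. A sketch under the new API (the key names are illustrative):

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field};
use datafusion_common::metadata::{FieldMetadata, Literal};
use datafusion_common::{ParamValues, ScalarValue};

fn main() -> datafusion_common::Result<()> {
    let metadata = FieldMetadata::from(std::collections::HashMap::from([(
        "some_key".to_string(),
        "some_value".to_string(),
    )]));

    let params = ParamValues::List(vec![Literal::new(
        ScalarValue::Int32(Some(7)),
        Some(metadata.clone()),
    )]);

    // The expected field carries the same storage type *and* metadata
    let expected = vec![Arc::new(
        metadata.add_to_field(Field::new("", DataType::Int32, true)),
    )];
    params.verify_fields(&expected)?;

    // A bare Int32 field has no metadata, so verification fails
    let bare = vec![Arc::new(Field::new("", DataType::Int32, true))];
    assert!(params.verify_fields(&bare).is_err());
    Ok(())
}
```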
diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs
index a8148b80495e..6f27975c171c 100644
--- a/datafusion/core/src/execution/context/mod.rs
+++ b/datafusion/core/src/execution/context/mod.rs
@@ -64,12 +64,13 @@ use datafusion_catalog::{
     DynamicFileCatalog, TableFunction, TableFunctionImpl, UrlTableFactory,
 };
 use datafusion_common::config::ConfigOptions;
+use datafusion_common::metadata::Literal;
 use datafusion_common::{
     config::{ConfigExtension, TableOptions},
     exec_datafusion_err, exec_err, internal_datafusion_err, not_impl_err,
     plan_datafusion_err, plan_err,
     tree_node::{TreeNodeRecursion, TreeNodeVisitor},
-    DFSchema, DataFusionError, ParamValues, ScalarValue, SchemaReference, TableReference,
+    DFSchema, DataFusionError, ParamValues, SchemaReference, TableReference,
 };
 pub use datafusion_execution::config::SessionConfig;
 use datafusion_execution::registry::SerializerRegistry;
@@ -708,15 +709,15 @@ impl SessionContext {
             LogicalPlan::Statement(Statement::Prepare(Prepare {
                 name,
                 input,
-                data_types,
+                fields,
             })) => {
                 // The number of parameters must match the specified data types length.
-                if !data_types.is_empty() {
+                if !fields.is_empty() {
                     let param_names = input.get_parameter_names()?;
-                    if param_names.len() != data_types.len() {
+                    if param_names.len() != fields.len() {
                         return plan_err!(
                             "Prepare specifies {} data types but query has {} parameters",
-                            data_types.len(),
+                            fields.len(),
                             param_names.len()
                         );
                     }
@@ -726,7 +727,7 @@ impl SessionContext {
                 // not currently feasible. This is because `now()` would be optimized to a
                 // constant value, causing each EXECUTE to yield the same result, which is
                 // incorrect behavior.
-                self.state.write().store_prepared(name, data_types, input)?;
+                self.state.write().store_prepared(name, fields, input)?;
                 self.return_empty_dataframe()
             }
             LogicalPlan::Statement(Statement::Execute(execute)) => {
@@ -1238,28 +1239,28 @@ impl SessionContext {
         })?;
 
         // Only allow literals as parameters for now.
-        let mut params: Vec<ScalarValue> = parameters
+        let mut params: Vec<Literal> = parameters
             .into_iter()
             .map(|e| match e {
-                Expr::Literal(scalar, _) => Ok(scalar),
+                Expr::Literal(scalar, metadata) => Ok(Literal::new(scalar, metadata)),
                 _ => not_impl_err!("Unsupported parameter type: {}", e),
             })
             .collect::<Result<_>>()?;
 
         // If the prepared statement provides data types, cast the params to those types.
-        if !prepared.data_types.is_empty() {
-            if params.len() != prepared.data_types.len() {
+        if !prepared.fields.is_empty() {
+            if params.len() != prepared.fields.len() {
                 return exec_err!(
                     "Prepared statement '{}' expects {} parameters, but {} provided",
                     name,
-                    prepared.data_types.len(),
+                    prepared.fields.len(),
                     params.len()
                 );
             }
 
             params = params
                 .into_iter()
-                .zip(prepared.data_types.iter())
-                .map(|(e, dt)| e.cast_to(dt))
+                .zip(prepared.fields.iter())
+                .map(|(e, dt)| -> Result<_> { e.cast_storage_to(dt.data_type()) })
                .collect::<Result<_>>()?;
         }
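End to end, the same machinery is reachable from SQL, since `execute_logical_plan` now stores parameter `Field`s and EXECUTE casts bound literals with `cast_storage_to`. A sketch assuming the `datafusion` and `tokio` crates; `my_query` is an arbitrary statement name:

```rust
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    let ctx = SessionContext::new();

    // PREPARE stores the statement (and its parameter fields) in the session
    ctx.sql("PREPARE my_query(INT) AS SELECT $1 + 1").await?;

    // EXECUTE binds literal parameters, casting their storage to the
    // prepared fields
    let results = ctx.sql("EXECUTE my_query(41)").await?.collect().await?;
    assert_eq!(results[0].num_rows(), 1);
    Ok(())
}
```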
diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs
index 6749ddd7ab8d..44eeb2df3f46 100644
--- a/datafusion/core/src/execution/session_state.rs
+++ b/datafusion/core/src/execution/session_state.rs
@@ -30,7 +30,7 @@ use crate::datasource::provider_as_source;
 use crate::execution::context::{EmptySerializerRegistry, FunctionFactory, QueryPlanner};
 use crate::execution::SessionStateDefaults;
 use crate::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
-use arrow::datatypes::DataType;
+use arrow_schema::{DataType, FieldRef};
 use datafusion_catalog::information_schema::{
     InformationSchemaProvider, INFORMATION_SCHEMA,
 };
@@ -115,11 +115,11 @@ use uuid::Uuid;
 /// # #[tokio::main]
 /// # async fn main() -> Result<()> {
 /// let state = SessionStateBuilder::new()
-///     .with_config(SessionConfig::new())  
+///     .with_config(SessionConfig::new())
 ///     .with_runtime_env(Arc::new(RuntimeEnv::default()))
 ///     .with_default_features()
 ///     .build();
-/// Ok(())  
+/// Ok(())
 /// # }
 /// ```
 ///
@@ -872,12 +872,12 @@ impl SessionState {
     pub(crate) fn store_prepared(
         &mut self,
         name: String,
-        data_types: Vec<DataType>,
+        fields: Vec<FieldRef>,
         plan: Arc<LogicalPlan>,
     ) -> datafusion_common::Result<()> {
         match self.prepared_plans.entry(name) {
             Entry::Vacant(e) => {
-                e.insert(Arc::new(PreparedPlan { data_types, plan }));
+                e.insert(Arc::new(PreparedPlan { fields, plan }));
                 Ok(())
             }
             Entry::Occupied(e) => {
@@ -1322,7 +1322,7 @@ impl SessionStateBuilder {
     /// let url = Url::try_from("file://").unwrap();
     /// let object_store = object_store::local::LocalFileSystem::new();
     /// let state = SessionStateBuilder::new()
-    ///     .with_config(SessionConfig::new())  
+    ///     .with_config(SessionConfig::new())
     ///     .with_object_store(&url, Arc::new(object_store))
     ///     .with_default_features()
     ///     .build();
@@ -2011,7 +2011,7 @@ impl SimplifyInfo for SessionSimplifyProvider<'_> {
 #[derive(Debug)]
 pub(crate) struct PreparedPlan {
     /// Data types of the parameters
-    pub(crate) data_types: Vec<DataType>,
+    pub(crate) fields: Vec<FieldRef>,
     /// The prepared logical plan
     pub(crate) plan: Arc<LogicalPlan>,
 }
diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs
index aa538f6dee81..979ada2bc6bb 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -33,6 +33,7 @@ use arrow::error::ArrowError;
 use arrow::util::pretty::pretty_format_batches;
 use arrow_schema::{SortOptions, TimeUnit};
 use datafusion::{assert_batches_eq, dataframe};
+use datafusion_common::metadata::FieldMetadata;
 use datafusion_functions_aggregate::count::{count_all, count_all_window};
 use datafusion_functions_aggregate::expr_fn::{
     array_agg, avg, avg_distinct, count, count_distinct, max, median, min, sum,
@@ -65,15 +66,13 @@ use datafusion_catalog::TableProvider;
 use datafusion_common::test_util::{batches_to_sort_string, batches_to_string};
 use datafusion_common::{
     assert_contains, internal_datafusion_err, Constraint, Constraints, DFSchema,
-    DataFusionError, ParamValues, ScalarValue, TableReference, UnnestOptions,
+    DataFusionError, ScalarValue, TableReference, UnnestOptions,
 };
 use datafusion_common_runtime::SpawnedTask;
 use datafusion_datasource::file_format::format_as_file_type;
 use datafusion_execution::config::SessionConfig;
 use datafusion_execution::runtime_env::RuntimeEnv;
-use datafusion_expr::expr::{
-    FieldMetadata, GroupingSet, NullTreatment, Sort, WindowFunction,
-};
+use datafusion_expr::expr::{GroupingSet, NullTreatment, Sort, WindowFunction};
 use datafusion_expr::var_provider::{VarProvider, VarType};
 use datafusion_expr::{
     cast, col, create_udf, exists, in_subquery, lit, out_ref_col, placeholder,
@@ -2465,7 +2464,7 @@ async fn filtered_aggr_with_param_values() -> Result<()> {
     let df = ctx
         .sql("select count (c2) filter (where c3 > $1) from table1")
         .await?
-        .with_param_values(ParamValues::List(vec![ScalarValue::from(10u64)]));
+        .with_param_values(vec![ScalarValue::from(10u64)]);
 
     let df_results = df?.collect().await?;
     assert_snapshot!(
diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs
index 98c3e3ccee8a..a8485cd8fa9e 100644
--- a/datafusion/core/tests/sql/select.rs
+++ b/datafusion/core/tests/sql/select.rs
@@ -15,8 +15,11 @@
 // specific language governing permissions and limitations
 // under the License.
+use std::collections::HashMap; + use super::*; -use datafusion_common::ScalarValue; +use datafusion::assert_batches_eq; +use datafusion_common::{metadata::Literal, ParamValues, ScalarValue}; use insta::assert_snapshot; #[tokio::test] @@ -317,6 +320,50 @@ async fn test_named_parameter_not_bound() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_query_parameters_with_metadata() -> Result<()> { + let ctx = SessionContext::new(); + + let df = ctx.sql("SELECT $1, $2").await.unwrap(); + + let metadata1 = HashMap::from([("some_key".to_string(), "some_value".to_string())]); + let metadata2 = + HashMap::from([("some_other_key".to_string(), "some_other_value".to_string())]); + + let df_with_params_replaced = df + .with_param_values(ParamValues::List(vec![ + Literal::new(ScalarValue::UInt32(Some(1)), Some(metadata1.clone().into())), + Literal::new( + ScalarValue::Utf8(Some("two".to_string())), + Some(metadata2.clone().into()), + ), + ])) + .unwrap(); + + // df_with_params_replaced.schema() is not correct here + // https://github.com/apache/datafusion/issues/18102 + let batches = df_with_params_replaced.clone().collect().await.unwrap(); + let schema = batches[0].schema(); + + assert_eq!(schema.field(0).data_type(), &DataType::UInt32); + assert_eq!(schema.field(0).metadata(), &metadata1); + assert_eq!(schema.field(1).data_type(), &DataType::Utf8); + assert_eq!(schema.field(1).metadata(), &metadata2); + + assert_batches_eq!( + [ + "+----+-----+", + "| $1 | $2 |", + "+----+-----+", + "| 1 | two |", + "+----+-----+", + ], + &batches + ); + + Ok(()) +} + #[tokio::test] async fn test_version_function() { let expected_version = format!( diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index f1af66de9b59..fb1371da6ceb 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -34,13 +34,13 @@ use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionS use datafusion::prelude::*; use datafusion::{execution::registry::FunctionRegistry, test_util}; use datafusion_common::cast::{as_float64_array, as_int32_array}; +use datafusion_common::metadata::FieldMetadata; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::utils::take_function_args; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_datafusion_err, exec_err, not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::expr::FieldMetadata; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ lit_with_metadata, Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 282b3f6a0f55..acdd2cffafe1 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -18,7 +18,7 @@ //! 
Logical Expressions: [`Expr`]
 
 use std::cmp::Ordering;
-use std::collections::{BTreeMap, HashSet};
+use std::collections::HashSet;
 use std::fmt::{self, Display, Formatter, Write};
 use std::hash::{Hash, Hasher};
 use std::mem;
@@ -32,6 +32,7 @@ use crate::{ExprSchemable, Operator, Signature, WindowFrame, WindowUDF};
 
 use arrow::datatypes::{DataType, Field, FieldRef};
 use datafusion_common::cse::{HashNode, NormalizeEq, Normalizeable};
+use datafusion_common::metadata::FieldMetadata;
 use datafusion_common::tree_node::{
     Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion,
 };
@@ -447,235 +448,6 @@ impl<'a> TreeNodeContainer<'a, Self> for Expr {
     }
 }
 
-/// Literal metadata
-///
-/// Stores metadata associated with a literal expressions
-/// and is designed to be fast to `clone`.
-///
-/// This structure is used to store metadata associated with a literal expression, and it
-/// corresponds to the `metadata` field on [`Field`].
-///
-/// # Example: Create [`FieldMetadata`] from a [`Field`]
-/// ```
-/// # use std::collections::HashMap;
-/// # use datafusion_expr::expr::FieldMetadata;
-/// # use arrow::datatypes::{Field, DataType};
-/// # let field = Field::new("c1", DataType::Int32, true)
-/// #   .with_metadata(HashMap::from([("foo".to_string(), "bar".to_string())]));
-/// // Create a new `FieldMetadata` instance from a `Field`
-/// let metadata = FieldMetadata::new_from_field(&field);
-/// // There is also a `From` impl:
-/// let metadata = FieldMetadata::from(&field);
-/// ```
-///
-/// # Example: Update a [`Field`] with [`FieldMetadata`]
-/// ```
-/// # use datafusion_expr::expr::FieldMetadata;
-/// # use arrow::datatypes::{Field, DataType};
-/// # let field = Field::new("c1", DataType::Int32, true);
-/// # let metadata = FieldMetadata::new_from_field(&field);
-/// // Add any metadata from `FieldMetadata` to `Field`
-/// let updated_field = metadata.add_to_field(field);
-/// ```
-///
-#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
-pub struct FieldMetadata {
-    /// The inner metadata of a literal expression, which is a map of string
-    /// keys to string values.
-    ///
-    /// Note this is not a `HashMap` because `HashMap` does not provide
-    /// implementations for traits like `Debug` and `Hash`.
-    inner: Arc<BTreeMap<String, String>>,
-}
-
-impl Default for FieldMetadata {
-    fn default() -> Self {
-        Self::new_empty()
-    }
-}
-
-impl FieldMetadata {
-    /// Create a new empty metadata instance.
-    pub fn new_empty() -> Self {
-        Self {
-            inner: Arc::new(BTreeMap::new()),
-        }
-    }
-
-    /// Merges two optional `FieldMetadata` instances, overwriting any existing
-    /// keys in `m` with keys from `n` if present.
-    ///
-    /// This function is commonly used in alias operations, particularly for literals
-    /// with metadata. When creating an alias expression, the metadata from the original
-    /// expression (such as a literal) is combined with any metadata specified on the alias.
-    ///
-    /// # Arguments
-    ///
-    /// * `m` - The first metadata (typically from the original expression like a literal)
-    /// * `n` - The second metadata (typically from the alias definition)
-    ///
-    /// # Merge Strategy
-    ///
-    /// - If both metadata instances exist, they are merged with `n` taking precedence
-    /// - Keys from `n` will overwrite keys from `m` if they have the same name
-    /// - If only one metadata instance exists, it is returned unchanged
-    /// - If neither exists, `None` is returned
-    ///
-    /// # Example usage
-    /// ```rust
-    /// use datafusion_expr::expr::FieldMetadata;
-    /// use std::collections::BTreeMap;
-    ///
-    /// // Create metadata for a literal expression
-    /// let literal_metadata = Some(FieldMetadata::from(BTreeMap::from([
-    ///     ("source".to_string(), "constant".to_string()),
-    ///     ("type".to_string(), "int".to_string()),
-    /// ])));
-    ///
-    /// // Create metadata for an alias
-    /// let alias_metadata = Some(FieldMetadata::from(BTreeMap::from([
-    ///     ("description".to_string(), "answer".to_string()),
-    ///     ("source".to_string(), "user".to_string()), // This will override literal's "source"
-    /// ])));
-    ///
-    /// // Merge the metadata
-    /// let merged = FieldMetadata::merge_options(
-    ///     literal_metadata.as_ref(),
-    ///     alias_metadata.as_ref(),
-    /// );
-    ///
-    /// // Result contains: {"source": "user", "type": "int", "description": "answer"}
-    /// assert!(merged.is_some());
-    /// ```
-    pub fn merge_options(
-        m: Option<&FieldMetadata>,
-        n: Option<&FieldMetadata>,
-    ) -> Option<FieldMetadata> {
-        match (m, n) {
-            (Some(m), Some(n)) => {
-                let mut merged = m.clone();
-                merged.extend(n.clone());
-                Some(merged)
-            }
-            (Some(m), None) => Some(m.clone()),
-            (None, Some(n)) => Some(n.clone()),
-            (None, None) => None,
-        }
-    }
-
-    /// Create a new metadata instance from a `Field`'s metadata.
-    pub fn new_from_field(field: &Field) -> Self {
-        let inner = field
-            .metadata()
-            .iter()
-            .map(|(k, v)| (k.to_string(), v.to_string()))
-            .collect();
-        Self {
-            inner: Arc::new(inner),
-        }
-    }
-
-    /// Create a new metadata instance from a map of string keys to string values.
-    pub fn new(inner: BTreeMap<String, String>) -> Self {
-        Self {
-            inner: Arc::new(inner),
-        }
-    }
-
-    /// Get the inner metadata as a reference to a `BTreeMap`.
-    pub fn inner(&self) -> &BTreeMap<String, String> {
-        &self.inner
-    }
-
-    /// Return the inner metadata
-    pub fn into_inner(self) -> Arc<BTreeMap<String, String>> {
-        self.inner
-    }
-
-    /// Adds metadata from `other` into `self`, overwriting any existing keys.
-    pub fn extend(&mut self, other: Self) {
-        if other.is_empty() {
-            return;
-        }
-        let other = Arc::unwrap_or_clone(other.into_inner());
-        Arc::make_mut(&mut self.inner).extend(other);
-    }
-
-    /// Returns true if the metadata is empty.
-    pub fn is_empty(&self) -> bool {
-        self.inner.is_empty()
-    }
-
-    /// Returns the number of key-value pairs in the metadata.
-    pub fn len(&self) -> usize {
-        self.inner.len()
-    }
-
-    /// Convert this `FieldMetadata` into a `HashMap<String, String>`
-    pub fn to_hashmap(&self) -> std::collections::HashMap<String, String> {
-        self.inner
-            .iter()
-            .map(|(k, v)| (k.to_string(), v.to_string()))
-            .collect()
-    }
-
-    /// Updates the metadata on the Field with this metadata, if it is not empty.
-    pub fn add_to_field(&self, field: Field) -> Field {
-        if self.inner.is_empty() {
-            return field;
-        }
-
-        field.with_metadata(self.to_hashmap())
-    }
-}
-
-impl From<&Field> for FieldMetadata {
-    fn from(field: &Field) -> Self {
-        Self::new_from_field(field)
-    }
-}
-
-impl From<BTreeMap<String, String>> for FieldMetadata {
-    fn from(inner: BTreeMap<String, String>) -> Self {
-        Self::new(inner)
-    }
-}
-
-impl From<std::collections::HashMap<String, String>> for FieldMetadata {
-    fn from(map: std::collections::HashMap<String, String>) -> Self {
-        Self::new(map.into_iter().collect())
-    }
-}
-
-/// From reference
-impl From<&std::collections::HashMap<String, String>> for FieldMetadata {
-    fn from(map: &std::collections::HashMap<String, String>) -> Self {
-        let inner = map
-            .iter()
-            .map(|(k, v)| (k.to_string(), v.to_string()))
-            .collect();
-        Self::new(inner)
-    }
-}
-
-/// From hashbrown map
-impl From<HashMap<String, String>> for FieldMetadata {
-    fn from(map: HashMap<String, String>) -> Self {
-        let inner = map.into_iter().collect();
-        Self::new(inner)
-    }
-}
-
-impl From<&HashMap<String, String>> for FieldMetadata {
-    fn from(map: &HashMap<String, String>) -> Self {
-        let inner = map
-            .into_iter()
-            .map(|(k, v)| (k.to_string(), v.to_string()))
-            .collect();
-        Self::new(inner)
-    }
-}
-
 /// The metadata used in [`Field::metadata`].
 ///
 /// This represents the metadata associated with an Arrow [`Field`]. The metadata consists of key-value pairs.
@@ -1370,13 +1142,21 @@ pub struct Placeholder {
     /// The identifier of the parameter, including the leading `$` (e.g, `"$1"` or `"$foo"`)
     pub id: String,
     /// The type the parameter will be filled in with
-    pub data_type: Option<DataType>,
+    pub field: Option<FieldRef>,
 }
 
 impl Placeholder {
     /// Create a new Placeholder expression
     pub fn new(id: String, data_type: Option<DataType>) -> Self {
-        Self { id, data_type }
+        Self {
+            id,
+            field: data_type.map(|dt| Arc::new(Field::new("", dt, true))),
+        }
+    }
+
+    /// Create a new Placeholder expression from a Field
+    pub fn new_with_metadata(id: String, field: Option<FieldRef>) -> Self {
+        Self { id, field }
     }
 }
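A sketch of building a typed placeholder directly via the new constructor (the extension key/value are illustrative):

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field};
use datafusion_expr::expr::Placeholder;
use datafusion_expr::Expr;

fn main() {
    // A typed placeholder now carries a full Field, so extension metadata
    // and nullability survive planning
    let field = Field::new("", DataType::Utf8, true).with_metadata(
        std::collections::HashMap::from([(
            "ARROW:extension:name".to_string(),
            "example.uuid".to_string(),
        )]),
    );
    let expr = Expr::Placeholder(Placeholder::new_with_metadata(
        "$1".to_string(),
        Some(Arc::new(field)),
    ));
    assert_eq!(expr.to_string(), "$1");
}
```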
@@ -1843,7 +1623,7 @@ impl Expr {
     /// ```
     /// # use datafusion_expr::col;
     /// # use std::collections::HashMap;
-    /// # use datafusion_expr::expr::FieldMetadata;
+    /// # use datafusion_common::metadata::FieldMetadata;
     /// let metadata = HashMap::from([("key".to_string(), "value".to_string())]);
     /// let metadata = FieldMetadata::from(metadata);
     /// let expr = col("foo").alias_with_metadata("bar", Some(metadata));
@@ -1875,7 +1655,7 @@ impl Expr {
     /// ```
     /// # use datafusion_expr::col;
     /// # use std::collections::HashMap;
-    /// # use datafusion_expr::expr::FieldMetadata;
+    /// # use datafusion_common::metadata::FieldMetadata;
     /// let metadata = HashMap::from([("key".to_string(), "value".to_string())]);
     /// let metadata = FieldMetadata::from(metadata);
     /// let expr = col("foo").alias_qualified_with_metadata(Some("tbl"), "bar", Some(metadata));
@@ -2886,19 +2666,23 @@ impl HashNode for Expr {
     }
 }
 
-// Modifies expr if it is a placeholder with datatype of right
+// Modifies expr to match the DataType, metadata, and nullability of other if it is
+// a placeholder with previously unspecified type information (i.e., most placeholders)
 fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> Result<()> {
-    if let Expr::Placeholder(Placeholder { id: _, data_type }) = expr {
-        if data_type.is_none() {
-            let other_dt = other.get_type(schema);
-            match other_dt {
+    if let Expr::Placeholder(Placeholder { id: _, field }) = expr {
+        if field.is_none() {
+            let other_field = other.to_field(schema);
+            match other_field {
                 Err(e) => {
                     Err(e.context(format!(
                         "Can not find type of {other} needed to infer type of {expr}"
                     )))?;
                 }
-                Ok(dt) => {
-                    *data_type = Some(dt);
+                Ok((_, other_field)) => {
+                    // We can't infer the nullability of the future parameter that might
+                    // be bound, so ensure this is set to true
+                    *field =
+                        Some(other_field.as_ref().clone().with_nullable(true).into());
                 }
             }
         };
@@ -3715,8 +3499,8 @@ pub fn physical_name(expr: &Expr) -> Result<String> {
 mod test {
     use crate::expr_fn::col;
     use crate::{
-        case, lit, qualified_wildcard, wildcard, wildcard_with_options, ColumnarValue,
-        ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Volatility,
+        case, lit, placeholder, qualified_wildcard, wildcard, wildcard_with_options,
+        ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Volatility,
     };
     use arrow::datatypes::{Field, Schema};
     use sqlparser::ast;
@@ -3730,15 +3514,15 @@ mod test {
         let param_placeholders = vec![
             Expr::Placeholder(Placeholder {
                 id: "$1".to_string(),
-                data_type: None,
+                field: None,
             }),
             Expr::Placeholder(Placeholder {
                 id: "$2".to_string(),
-                data_type: None,
+                field: None,
             }),
             Expr::Placeholder(Placeholder {
                 id: "$3".to_string(),
-                data_type: None,
+                field: None,
             }),
         ];
         let in_list = Expr::InList(InList {
@@ -3764,8 +3548,8 @@ mod test {
             match expr {
                 Expr::Placeholder(placeholder) => {
                     assert_eq!(
-                        placeholder.data_type,
-                        Some(DataType::Int32),
+                        placeholder.field.unwrap().data_type(),
+                        &DataType::Int32,
                         "Placeholder {} should infer Int32",
                         placeholder.id
                     );
@@ -3789,7 +3573,7 @@ mod test {
             expr: Box::new(col("name")),
             pattern: Box::new(Expr::Placeholder(Placeholder {
                 id: "$1".to_string(),
-                data_type: None,
+                field: None,
             })),
             negated: false,
             case_insensitive: false,
@@ -3802,7 +3586,7 @@ mod test {
         match inferred_expr {
             Expr::Like(like) => match *like.pattern {
                 Expr::Placeholder(placeholder) => {
-                    assert_eq!(placeholder.data_type, Some(DataType::Utf8));
+                    assert_eq!(placeholder.field.unwrap().data_type(), &DataType::Utf8);
                 }
                 _ => panic!("Expected Placeholder"),
             },
@@ -3817,8 +3601,8 @@ mod test {
             Expr::SimilarTo(like) => match *like.pattern {
                 Expr::Placeholder(placeholder) => {
                     assert_eq!(
-                        placeholder.data_type,
-                        Some(DataType::Utf8),
+                        placeholder.field.unwrap().data_type(),
+                        &DataType::Utf8,
                         "Placeholder {} should infer Utf8",
                         placeholder.id
                     );
                 }
                 _ => panic!("Expected Placeholder"),
             },
         }
     }
+    #[test]
+    fn infer_placeholder_with_metadata() {
+        // name == $1, where name is a non-nullable string
+        let schema =
+            Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, false)
+                .with_metadata(
+                    [("some_key".to_string(), "some_value".to_string())].into(),
+                )]));
+        let df_schema = DFSchema::try_from(schema).unwrap();
+
+        let expr = binary_expr(col("name"), Operator::Eq, placeholder("$1"));
+
+        let (inferred_expr, _) = expr.infer_placeholder_types(&df_schema).unwrap();
+        match inferred_expr {
+            Expr::BinaryExpr(BinaryExpr { right, .. }) => match *right {
+                Expr::Placeholder(placeholder) => {
+                    assert_eq!(
+                        placeholder.field.as_ref().unwrap().data_type(),
+                        &DataType::Utf8
+                    );
+                    assert_eq!(
+                        placeholder.field.as_ref().unwrap().metadata(),
+                        df_schema.field(0).metadata()
+                    );
+                    // Inferred placeholder should still be nullable
+                    assert!(placeholder.field.as_ref().unwrap().is_nullable());
+                }
+                _ => panic!("Expected Placeholder"),
+            },
+            _ => panic!("Expected BinaryExpr"),
+        }
+    }
+
     #[test]
     fn format_case_when() -> Result<()> {
         let expr = case(col("a"))
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 4666411dd540..c777c4978f99 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -119,13 +119,13 @@ pub fn ident(name: impl Into<String>) -> Expr {
 ///
 /// ```rust
 /// # use datafusion_expr::{placeholder};
-/// let p = placeholder("$0"); // $0, refers to parameter 1
-/// assert_eq!(p.to_string(), "$0")
+/// let p = placeholder("$1"); // $1, refers to parameter 1
+/// assert_eq!(p.to_string(), "$1")
 /// ```
 pub fn placeholder(id: impl Into<String>) -> Expr {
     Expr::Placeholder(Placeholder {
         id: id.into(),
-        data_type: None,
+        field: None,
     })
 }
diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs
index e803e3534130..4f1da2c442a3 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -17,8 +17,8 @@
 
 use super::{Between, Expr, Like};
 use crate::expr::{
-    AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, FieldMetadata,
-    InList, InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction,
+    AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, InList,
+    InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction,
     WindowFunctionParams,
 };
 use crate::type_coercion::functions::{
@@ -28,6 +28,7 @@ use crate::udf::ReturnFieldArgs;
 use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition};
 use arrow::compute::can_cast_types;
 use arrow::datatypes::{DataType, Field, FieldRef};
+use datafusion_common::metadata::FieldMetadata;
 use datafusion_common::{
     not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema,
     Result, Spans, TableReference,
@@ -104,9 +105,9 @@ impl ExprSchemable for Expr {
     fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType> {
         match self {
             Expr::Alias(Alias { expr, name, .. }) => match &**expr {
-                Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type {
+                Expr::Placeholder(Placeholder { field, .. }) => match &field {
                     None => schema.data_type(&Column::from_name(name)).cloned(),
-                    Some(dt) => Ok(dt.clone()),
+                    Some(field) => Ok(field.data_type().clone()),
                 },
                 _ => expr.get_type(schema),
             },
@@ -211,9 +212,9 @@ impl ExprSchemable for Expr {
             )
             .get_result_type(),
             Expr::Like { .. } | Expr::SimilarTo { .. } => Ok(DataType::Boolean),
-            Expr::Placeholder(Placeholder { data_type, .. }) => {
-                if let Some(dtype) = data_type {
-                    Ok(dtype.clone())
+            Expr::Placeholder(Placeholder { field, .. }) => {
+                if let Some(field) = field {
+                    Ok(field.data_type().clone())
                 } else {
                     // If the placeholder's type hasn't been specified, treat it as
                     // null (unspecified placeholders generate an error during planning)
@@ -309,10 +310,12 @@ impl ExprSchemable for Expr {
                 window_function,
             )
             .map(|(_, nullable)| nullable),
-            Expr::ScalarVariable(_, _)
-            | Expr::TryCast { .. }
-            | Expr::Unnest(_)
-            | Expr::Placeholder(_) => Ok(true),
+            Expr::Placeholder(Placeholder { id: _, field }) => {
+                Ok(field.as_ref().map(|f| f.is_nullable()).unwrap_or(true))
+            }
+            Expr::ScalarVariable(_, _) | Expr::TryCast { .. } | Expr::Unnest(_) => {
+                Ok(true)
+            }
             Expr::IsNull(_)
             | Expr::IsNotNull(_)
             | Expr::IsTrue(_)
@@ -433,18 +436,18 @@ impl ExprSchemable for Expr {
             ..
         }) => {
             let field = match &**expr {
-                Expr::Placeholder(Placeholder { data_type, .. }) => {
-                    match &data_type {
-                        None => schema
-                            .data_type_and_nullable(&Column::from_name(name))
-                            .map(|(d, n)| Field::new(&schema_name, d.clone(), n)),
-                        Some(dt) => Ok(Field::new(
-                            &schema_name,
-                            dt.clone(),
-                            expr.nullable(schema)?,
-                        )),
-                    }
-                }
+                Expr::Placeholder(Placeholder { field, .. }) => match &field {
+                    // TODO: This seems to be pulling metadata/types from the schema
+                    // based on the Alias destination name? Is this correct?
+                    None => schema
+                        .field_from_column(&Column::from_name(name))
+                        .map(|f| f.clone().with_name(&schema_name)),
+                    Some(field) => Ok(field
+                        .as_ref()
+                        .clone()
+                        .with_name(&schema_name)
+                        .with_nullable(expr.nullable(schema)?)),
+                },
                 _ => expr.to_field(schema).map(|(_, f)| f.as_ref().clone()),
             }?;
@@ -594,6 +597,10 @@ impl ExprSchemable for Expr {
             .to_field(schema)
             .map(|(_, f)| f.as_ref().clone().with_data_type(data_type.clone()))
             .map(Arc::new),
+        Expr::Placeholder(Placeholder {
+            id: _,
+            field: Some(field),
+        }) => Ok(field.as_ref().clone().with_name(&schema_name).into()),
         Expr::Like(_)
         | Expr::SimilarTo(_)
         | Expr::Not(_)
@@ -776,10 +783,12 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subquery> {
 => {{
@@ -905,7 +914,7 @@
         let schema = DFSchema::from_unqualified_fields(
             vec![meta.add_to_field(Field::new("foo", DataType::Int32, true))].into(),
-            std::collections::HashMap::new(),
+            HashMap::new(),
         )
         .unwrap();
@@ -921,6 +930,53 @@ mod tests {
         assert_eq!(meta, outer_ref.metadata(&schema).unwrap());
     }
 
+    #[test]
+    fn test_expr_placeholder() {
+        let schema = MockExprSchema::new();
+
+        let mut placeholder_meta = HashMap::new();
+        placeholder_meta.insert("bar".to_string(), "buzz".to_string());
+        let placeholder_meta = FieldMetadata::from(placeholder_meta);
+
+        let expr = Expr::Placeholder(Placeholder::new_with_metadata(
+            "".to_string(),
+            Some(
+                Field::new("", DataType::Utf8, true)
+                    .with_metadata(placeholder_meta.to_hashmap())
+                    .into(),
+            ),
+        ));
+
+        assert!(expr.nullable(&schema).unwrap());
+        assert_eq!(
+            expr.data_type_and_nullable(&schema).unwrap(),
+            (DataType::Utf8, true)
+        );
+        assert_eq!(placeholder_meta, expr.metadata(&schema).unwrap());
+
+        let expr_alias = expr.alias("a placeholder by any other name");
+        assert_eq!(
+            expr_alias.data_type_and_nullable(&schema).unwrap(),
+            (DataType::Utf8, true)
+        );
+        assert_eq!(placeholder_meta, expr_alias.metadata(&schema).unwrap());
+
+        // Non-nullable placeholder field should remain non-nullable
+        let expr = Expr::Placeholder(Placeholder::new_with_metadata(
+            "".to_string(),
+            Some(Field::new("", DataType::Utf8, false).into()),
+        ));
+        assert_eq!(
+            expr.data_type_and_nullable(&schema).unwrap(),
+            (DataType::Utf8, false)
+        );
+        let expr_alias = expr.alias("a placeholder by any other name");
+        assert_eq!(
+            expr_alias.data_type_and_nullable(&schema).unwrap(),
+            (DataType::Utf8, false)
+        );
+    }
+
     #[derive(Debug)]
     struct MockExprSchema {
         field: Field,
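The nullability rule in `rewrite_placeholder` is observable through `Expr::infer_placeholder_types`; a sketch mirroring the new tests, assuming only the APIs changed in this PR:

```rust
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::DFSchema;
use datafusion_expr::{col, placeholder, Expr};

fn main() -> datafusion_common::Result<()> {
    // name is declared non-nullable...
    let schema = Schema::new(vec![Field::new("name", DataType::Utf8, false)]);
    let df_schema = DFSchema::try_from(schema)?;

    let expr = col("name").eq(placeholder("$1"));
    let (expr, _) = expr.infer_placeholder_types(&df_schema)?;

    // ...but the inferred placeholder field is still nullable, since the
    // parameter that will eventually be bound may be NULL
    if let Expr::BinaryExpr(bin) = expr {
        if let Expr::Placeholder(p) = *bin.right {
            let field = p.field.unwrap();
            assert_eq!(field.data_type(), &DataType::Utf8);
            assert!(field.is_nullable());
        }
    }
    Ok(())
}
```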
diff --git a/datafusion/expr/src/literal.rs b/datafusion/expr/src/literal.rs
index c4bd43bc0a62..335d7b471f5f 100644
--- a/datafusion/expr/src/literal.rs
+++ b/datafusion/expr/src/literal.rs
@@ -17,9 +17,8 @@
 
 //! Literal module contains foundational types that are used to represent literals in DataFusion.
 
-use crate::expr::FieldMetadata;
 use crate::Expr;
-use datafusion_common::ScalarValue;
+use datafusion_common::{metadata::FieldMetadata, ScalarValue};
 
 /// Create a literal expression
 pub fn lit<T: Literal>(n: T) -> Expr {
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index 7a283b0420d3..a430add3f786 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -25,7 +25,7 @@ use std::iter::once;
 use std::sync::Arc;
 
 use crate::dml::CopyTo;
-use crate::expr::{Alias, FieldMetadata, PlannedReplaceSelectItem, Sort as SortExpr};
+use crate::expr::{Alias, PlannedReplaceSelectItem, Sort as SortExpr};
 use crate::expr_rewriter::{
     coerce_plan_expr_for_schema, normalize_col,
     normalize_col_with_schemas_and_ambiguity_check, normalize_cols, normalize_sorts,
@@ -50,9 +50,10 @@ use crate::{
 
 use super::dml::InsertOp;
 use arrow::compute::can_cast_types;
-use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
+use arrow::datatypes::{DataType, Field, FieldRef, Fields, Schema, SchemaRef};
 use datafusion_common::display::ToStringifiedPlan;
 use datafusion_common::file_options::file_type::FileType;
+use datafusion_common::metadata::FieldMetadata;
 use datafusion_common::{
     exec_err, get_target_functional_dependencies, internal_datafusion_err, not_impl_err,
     plan_datafusion_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef,
@@ -622,11 +623,11 @@ impl LogicalPlanBuilder {
     }
 
     /// Make a builder for a prepare logical plan from the builder's plan
-    pub fn prepare(self, name: String, data_types: Vec<DataType>) -> Result<Self> {
+    pub fn prepare(self, name: String, fields: Vec<FieldRef>) -> Result<Self> {
         Ok(Self::new(LogicalPlan::Statement(Statement::Prepare(
             Prepare {
                 name,
-                data_types,
+                fields,
                 input: self.plan,
             },
         ))))
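A sketch of driving the updated builder API, which now takes parameter `FieldRef`s instead of `DataType`s (the name `q` and the `VALUES` input are arbitrary):

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field};
use datafusion_common::Result;
use datafusion_expr::{col, lit, placeholder, LogicalPlanBuilder};

fn main() -> Result<()> {
    // SELECT column1 FROM (VALUES (1)) WHERE column1 = $1
    let plan = LogicalPlanBuilder::values(vec![vec![lit(1i32)]])?
        .filter(col("column1").eq(placeholder("$1")))?
        .build()?;

    // The prepared statement records full parameter Fields (metadata and
    // nullability included), not just DataTypes
    let prepared = LogicalPlanBuilder::new(plan)
        .prepare(
            "q".to_string(),
            vec![Arc::new(Field::new("", DataType::Int32, true))],
        )?
        .build()?;
    println!("{}", prepared.display_indent());
    Ok(())
}
```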
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 05a2564464c5..5e70cbf1ed91 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -51,9 +51,10 @@ use crate::{
     WindowFunctionDefinition,
 };
 
-use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef};
 use datafusion_common::cse::{NormalizeEq, Normalizeable};
 use datafusion_common::format::ExplainFormat;
+use datafusion_common::metadata::check_metadata_with_storage_equal;
 use datafusion_common::tree_node::{
     Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion,
 };
@@ -1098,15 +1099,13 @@ impl LogicalPlan {
                 }))
             }
             LogicalPlan::Statement(Statement::Prepare(Prepare {
-                name,
-                data_types,
-                ..
+                name, fields, ..
             })) => {
                 self.assert_no_expressions(expr)?;
                 let input = self.only_input(inputs)?;
                 Ok(LogicalPlan::Statement(Statement::Prepare(Prepare {
                     name: name.clone(),
-                    data_types: data_types.clone(),
+                    fields: fields.clone(),
                     input: Arc::new(input),
                 })))
             }
@@ -1282,7 +1281,7 @@ impl LogicalPlan {
             if let LogicalPlan::Statement(Statement::Prepare(prepare_lp)) =
                 plan_with_values
             {
-                param_values.verify(&prepare_lp.data_types)?;
+                param_values.verify_fields(&prepare_lp.fields)?;
                 // try and take ownership of the input if is not shared, clone otherwise
                 Arc::unwrap_or_clone(prepare_lp.input)
             } else {
@@ -1463,8 +1462,10 @@ impl LogicalPlan {
             let original_name = name_preserver.save(&e);
             let transformed_expr = e.transform_up(|e| {
                 if let Expr::Placeholder(Placeholder { id, .. }) = e {
-                    let value = param_values.get_placeholders_with_values(&id)?;
-                    Ok(Transformed::yes(Expr::Literal(value, None)))
+                    let (value, metadata) = param_values
+                        .get_placeholders_with_values(&id)?
+                        .into_inner();
+                    Ok(Transformed::yes(Expr::Literal(value, metadata)))
                 } else {
                     Ok(Transformed::no(e))
                 }
             })
@@ -1494,24 +1495,43 @@ impl LogicalPlan {
     }
 
     /// Walk the logical plan, find any `Placeholder` tokens, and return a map of their IDs and DataTypes
+    ///
+    /// Note that this will drop any extension or field metadata attached to parameters. Use
+    /// [`LogicalPlan::get_parameter_fields`] to keep extension metadata.
     pub fn get_parameter_types(
         &self,
     ) -> Result<HashMap<String, Option<DataType>>, DataFusionError> {
-        let mut param_types: HashMap<String, Option<DataType>> = HashMap::new();
+        let mut parameter_fields = self.get_parameter_fields()?;
+        Ok(parameter_fields
+            .drain()
+            .map(|(name, maybe_field)| {
+                (name, maybe_field.map(|field| field.data_type().clone()))
+            })
+            .collect())
+    }
+
+    /// Walk the logical plan, find any `Placeholder` tokens, and return a map of their IDs and FieldRefs
+    pub fn get_parameter_fields(
+        &self,
+    ) -> Result<HashMap<String, Option<FieldRef>>, DataFusionError> {
+        let mut param_types: HashMap<String, Option<FieldRef>> = HashMap::new();
         self.apply_with_subqueries(|plan| {
             plan.apply_expressions(|expr| {
                 expr.apply(|expr| {
-                    if let Expr::Placeholder(Placeholder { id, data_type }) = expr {
+                    if let Expr::Placeholder(Placeholder { id, field }) = expr {
                         let prev = param_types.get(id);
-                        match (prev, data_type) {
-                            (Some(Some(prev)), Some(dt)) => {
-                                if prev != dt {
-                                    plan_err!("Conflicting types for {id}")?;
-                                }
+                        match (prev, field) {
+                            (Some(Some(prev)), Some(field)) => {
+                                check_metadata_with_storage_equal(
+                                    (field.data_type(), Some(field.metadata())),
+                                    (prev.data_type(), Some(prev.metadata())),
+                                    "parameter",
+                                    &format!(": Conflicting types for id {id}"),
+                                )?;
                             }
-                            (_, Some(dt)) => {
-                                param_types.insert(id.clone(), Some(dt.clone()));
+                            (_, Some(field)) => {
+                                param_types.insert(id.clone(), Some(Arc::clone(field)));
                             }
                             _ => {
                                 param_types.insert(id.clone(), None);
@@ -4231,6 +4251,7 @@ mod tests {
         binary_expr, col, exists, in_subquery, lit, placeholder, scalar_subquery,
         GroupingSet,
     };
+    use datafusion_common::metadata::Literal;
    use datafusion_common::tree_node::{
         TransformedResult, TreeNodeRewriter, TreeNodeVisitor,
     };
@@ -4771,6 +4792,38 @@ mod tests {
             .expect_err("unexpectedly succeeded to replace an invalid placeholder");
     }
 
+    #[test]
+    fn test_replace_placeholder_mismatched_metadata() {
+        let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
+
+        // Create a prepared statement with explicit fields that do not have metadata
+        let plan = table_scan(TableReference::none(), &schema, None)
+            .unwrap()
+            .filter(col("id").eq(placeholder("$1")))
+            .unwrap()
+            .build()
+            .unwrap();
+        let prepared_builder = LogicalPlanBuilder::new(plan)
+            .prepare(
+                "".to_string(),
+                vec![Field::new("", DataType::Int32, true).into()],
+            )
+            .unwrap();
+
+        // Attempt to bind a parameter with metadata
+        let mut scalar_meta = HashMap::new();
+        scalar_meta.insert("some_key".to_string(), "some_value".to_string());
+        let param_values = ParamValues::List(vec![Literal::new(
+            ScalarValue::Int32(Some(42)),
+            Some(scalar_meta.into()),
+        )]);
+        prepared_builder
+            .plan()
+            .clone()
+            .with_param_values(param_values)
+            .expect_err("prepared field metadata mismatch unexpectedly succeeded");
+    }
+
     #[test]
     fn test_nullable_schema_after_grouping_set() {
         let schema = Schema::new(vec![
@@ -5143,7 +5196,7 @@
             .unwrap();
 
         // Check that the placeholder parameters have not received a DataType.
-        let params = plan.get_parameter_types().unwrap();
+        let params = plan.get_parameter_fields().unwrap();
         assert_eq!(params.len(), 1);
 
         let parameter_type = params.clone().get(placeholder_value).unwrap().clone();
diff --git a/datafusion/expr/src/logical_plan/statement.rs b/datafusion/expr/src/logical_plan/statement.rs
index 6d3fe9fa75ac..bfc6b53d1136 100644
--- a/datafusion/expr/src/logical_plan/statement.rs
+++ b/datafusion/expr/src/logical_plan/statement.rs
@@ -15,7 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::datatypes::DataType;
+use arrow::datatypes::FieldRef;
+use datafusion_common::metadata::format_type_and_metadata;
 use datafusion_common::{DFSchema, DFSchemaRef};
 use itertools::Itertools as _;
 use std::fmt::{self, Display};
@@ -108,10 +109,18 @@ impl Statement {
             }) => {
                 write!(f, "SetVariable: set {variable:?} to {value:?}")
             }
-            Statement::Prepare(Prepare {
-                name, data_types, ..
-            }) => {
-                write!(f, "Prepare: {name:?} [{}]", data_types.iter().join(", "))
+            Statement::Prepare(Prepare { name, fields, .. }) => {
+                write!(
+                    f,
+                    "Prepare: {name:?} [{}]",
+                    fields
+                        .iter()
+                        .map(|f| format_type_and_metadata(
+                            f.data_type(),
+                            Some(f.metadata())
+                        ))
+                        .join(", ")
+                )
             }
             Statement::Execute(Execute { name, parameters, ..
@@ -192,7 +201,7 @@ pub struct Prepare {
     /// The name of the statement
     pub name: String,
     /// Data types of the parameters ([`Expr::Placeholder`])
-    pub data_types: Vec<DataType>,
+    pub fields: Vec<FieldRef>,
     /// The logical plan of the statements
     pub input: Arc<LogicalPlan>,
 }
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index c40906239073..204ce14e37d8 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -29,6 +29,7 @@ use std::sync::Arc;
 
 use datafusion_common::{
     cast::{as_large_list_array, as_list_array},
+    metadata::FieldMetadata,
     tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
 };
 use datafusion_common::{
@@ -57,7 +58,6 @@ use crate::simplify_expressions::unwrap_cast::{
     unwrap_cast_in_comparison_for_binary,
 };
 use crate::simplify_expressions::SimplifyInfo;
-use datafusion_expr::expr::FieldMetadata;
 use datafusion_expr_common::casts::try_cast_literal_to_type;
 use indexmap::IndexSet;
 use regex::Regex;
diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs
index 6e425ee439d6..94e91d43a1c4 100644
--- a/datafusion/physical-expr/src/expressions/literal.rs
+++ b/datafusion/physical-expr/src/expressions/literal.rs
@@ -28,8 +28,8 @@ use arrow::{
     datatypes::{DataType, Schema},
     record_batch::RecordBatch,
 };
+use datafusion_common::metadata::FieldMetadata;
 use datafusion_common::{Result, ScalarValue};
-use datafusion_expr::expr::FieldMetadata;
 use datafusion_expr::Expr;
 use datafusion_expr_common::columnar_value::ColumnarValue;
 use datafusion_expr_common::interval_arithmetic::Interval;
diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs
index 73df60c42e96..7790380dffd5 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -25,13 +25,12 @@ use crate::{
 };
 
 use arrow::datatypes::Schema;
 use datafusion_common::config::ConfigOptions;
+use datafusion_common::metadata::FieldMetadata;
 use datafusion_common::{
     exec_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, ToDFSchema,
 };
 use datafusion_expr::execution_props::ExecutionProps;
-use datafusion_expr::expr::{
-    Alias, Cast, FieldMetadata, InList, Placeholder, ScalarFunction,
-};
+use datafusion_expr::expr::{Alias, Cast, InList, Placeholder, ScalarFunction};
 use datafusion_expr::var_provider::is_system_variables;
 use datafusion_expr::var_provider::VarType;
 use datafusion_expr::{
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 11103472ae2a..8dd35188d0f4 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -181,6 +181,7 @@ message PrepareNode {
   string name = 1;
   repeated datafusion_common.ArrowType data_types = 2;
   LogicalPlanNode input = 3;
+  repeated datafusion_common.Field fields = 4;
 }
 
 message CreateCatalogSchemaNode {
@@ -413,6 +414,8 @@ message Wildcard {
 message PlaceholderNode {
   string id = 1;
   datafusion_common.ArrowType data_type = 2;
+  optional bool nullable = 3;
+  map<string, string> metadata = 4;
 }
 
 message LogicalExprList {
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index b34da2c312de..4cf834d0601e 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -18434,6 +18434,12 @@ impl serde::Serialize for PlaceholderNode {
         if self.data_type.is_some() {
             len += 1;
         }
+        if self.nullable.is_some() {
+            len += 1;
+        }
+        if !self.metadata.is_empty() {
+            len += 1;
+        }
         let mut struct_ser = serializer.serialize_struct("datafusion.PlaceholderNode", len)?;
         if !self.id.is_empty() {
             struct_ser.serialize_field("id", &self.id)?;
@@ -18441,6 +18447,12 @@ impl serde::Serialize for PlaceholderNode {
         if let Some(v) = self.data_type.as_ref() {
             struct_ser.serialize_field("dataType", v)?;
         }
+        if let Some(v) = self.nullable.as_ref() {
+            struct_ser.serialize_field("nullable", v)?;
+        }
+        if !self.metadata.is_empty() {
+            struct_ser.serialize_field("metadata", &self.metadata)?;
+        }
         struct_ser.end()
     }
 }
@@ -18454,12 +18466,16 @@ impl<'de> serde::Deserialize<'de> for PlaceholderNode {
             "id",
             "data_type",
             "dataType",
+            "nullable",
+            "metadata",
         ];
 
         #[allow(clippy::enum_variant_names)]
         enum GeneratedField {
             Id,
             DataType,
+            Nullable,
+            Metadata,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
             fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
@@ -18483,6 +18499,8 @@ impl<'de> serde::Deserialize<'de> for PlaceholderNode {
                     match value {
                         "id" => Ok(GeneratedField::Id),
                         "dataType" | "data_type" => Ok(GeneratedField::DataType),
+                        "nullable" => Ok(GeneratedField::Nullable),
+                        "metadata" => Ok(GeneratedField::Metadata),
                         _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
                     }
                 }
@@ -18504,6 +18522,8 @@ impl<'de> serde::Deserialize<'de> for PlaceholderNode {
             {
                 let mut id__ = None;
                 let mut data_type__ = None;
+                let mut nullable__ = None;
+                let mut metadata__ = None;
                 while let Some(k) = map_.next_key()? {
                     match k {
                         GeneratedField::Id => {
@@ -18518,11 +18538,27 @@ impl<'de> serde::Deserialize<'de> for PlaceholderNode {
                             }
                             data_type__ = map_.next_value()?;
                         }
+                        GeneratedField::Nullable => {
+                            if nullable__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("nullable"));
+                            }
+                            nullable__ = map_.next_value()?;
+                        }
+                        GeneratedField::Metadata => {
+                            if metadata__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("metadata"));
+                            }
+                            metadata__ = Some(
+                                map_.next_value::<std::collections::HashMap<_, _>>()?
+                            );
+ ); + } } } Ok(PlaceholderNode { id: id__.unwrap_or_default(), data_type: data_type__, + nullable: nullable__, + metadata: metadata__.unwrap_or_default(), }) } } @@ -18889,6 +18925,9 @@ impl serde::Serialize for PrepareNode { if self.input.is_some() { len += 1; } + if !self.fields.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.PrepareNode", len)?; if !self.name.is_empty() { struct_ser.serialize_field("name", &self.name)?; @@ -18899,6 +18938,9 @@ impl serde::Serialize for PrepareNode { if let Some(v) = self.input.as_ref() { struct_ser.serialize_field("input", v)?; } + if !self.fields.is_empty() { + struct_ser.serialize_field("fields", &self.fields)?; + } struct_ser.end() } } @@ -18913,6 +18955,7 @@ impl<'de> serde::Deserialize<'de> for PrepareNode { "data_types", "dataTypes", "input", + "fields", ]; #[allow(clippy::enum_variant_names)] @@ -18920,6 +18963,7 @@ impl<'de> serde::Deserialize<'de> for PrepareNode { Name, DataTypes, Input, + Fields, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -18944,6 +18988,7 @@ impl<'de> serde::Deserialize<'de> for PrepareNode { "name" => Ok(GeneratedField::Name), "dataTypes" | "data_types" => Ok(GeneratedField::DataTypes), "input" => Ok(GeneratedField::Input), + "fields" => Ok(GeneratedField::Fields), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18966,6 +19011,7 @@ impl<'de> serde::Deserialize<'de> for PrepareNode { let mut name__ = None; let mut data_types__ = None; let mut input__ = None; + let mut fields__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Name => { @@ -18986,12 +19032,19 @@ impl<'de> serde::Deserialize<'de> for PrepareNode { } input__ = map_.next_value()?; } + GeneratedField::Fields => { + if fields__.is_some() { + return Err(serde::de::Error::duplicate_field("fields")); + } + fields__ = Some(map_.next_value()?); + } } } Ok(PrepareNode { name: name__.unwrap_or_default(), data_types: data_types__.unwrap_or_default(), input: input__, + fields: fields__.unwrap_or_default(), }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 2e1c482db65c..6bdbd46476fb 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -282,6 +282,8 @@ pub struct PrepareNode { pub data_types: ::prost::alloc::vec::Vec, #[prost(message, optional, boxed, tag = "3")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "4")] + pub fields: ::prost::alloc::vec::Vec, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct CreateCatalogSchemaNode { @@ -653,6 +655,13 @@ pub struct PlaceholderNode { pub id: ::prost::alloc::string::String, #[prost(message, optional, tag = "2")] pub data_type: ::core::option::Option, + #[prost(bool, optional, tag = "3")] + pub nullable: ::core::option::Option, + #[prost(map = "string, string", tag = "4")] + pub metadata: ::std::collections::HashMap< + ::prost::alloc::string::String, + ::prost::alloc::string::String, + >, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct LogicalExprList { diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index ec6415adc4c9..c85b1316c5bf 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -17,6 +17,7 @@ use std::sync::Arc; +use arrow::datatypes::Field; use 
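Because prost decodes absent optional fields as `None` and an absent map as empty, plans serialized before this change still round-trip. A hedged sketch of what an old writer's bytes decode to (type path assumed from the generated code above):

use std::collections::HashMap;
use datafusion_proto::protobuf::PlaceholderNode;

// Fields 3 and 4 were never written by older versions, so decoding yields:
fn decoded_from_old_writer() -> PlaceholderNode {
    PlaceholderNode {
        id: "$1".to_string(),
        data_type: None,
        nullable: None,           // parse_expr below applies unwrap_or(true)
        metadata: HashMap::new(), // treated as "no extension metadata"
    }
}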
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs
index ec6415adc4c9..c85b1316c5bf 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -17,6 +17,7 @@
 use std::sync::Arc;
 
+use arrow::datatypes::Field;
 use datafusion::execution::registry::FunctionRegistry;
 use datafusion_common::{
     exec_datafusion_err, internal_err, plan_datafusion_err, NullEquality,
@@ -628,12 +629,22 @@ pub fn parse_expr(
         ExprType::Rollup(RollupNode { expr }) => Ok(Expr::GroupingSet(
             GroupingSet::Rollup(parse_exprs(expr, registry, codec)?),
         )),
-        ExprType::Placeholder(PlaceholderNode { id, data_type }) => match data_type {
+        ExprType::Placeholder(PlaceholderNode {
+            id,
+            data_type,
+            nullable,
+            metadata,
+        }) => match data_type {
             None => Ok(Expr::Placeholder(Placeholder::new(id.clone(), None))),
-            Some(data_type) => Ok(Expr::Placeholder(Placeholder::new(
-                id.clone(),
-                Some(data_type.try_into()?),
-            ))),
+            Some(data_type) => {
+                let field =
+                    Field::new("", data_type.try_into()?, nullable.unwrap_or(true))
+                        .with_metadata(metadata.clone());
+                Ok(Expr::Placeholder(Placeholder::new_with_metadata(
+                    id.clone(),
+                    Some(field.into()),
+                )))
+            }
         },
     }
 }
diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs
index fd9e07914b07..9f250c86dcf8 100644
--- a/datafusion/proto/src/logical_plan/mod.rs
+++ b/datafusion/proto/src/logical_plan/mod.rs
@@ -33,7 +33,7 @@ use crate::{
 };
 use crate::protobuf::{proto_error, ToProtoError};
-use arrow::datatypes::{DataType, Schema, SchemaBuilder, SchemaRef};
+use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, SchemaRef};
 use datafusion::datasource::cte_worktable::CteWorkTable;
 use datafusion::datasource::file_format::arrow::ArrowFormat;
 #[cfg(feature = "avro")]
@@ -888,9 +888,33 @@ impl AsLogicalPlan for LogicalPlanNode {
                     .iter()
                     .map(DataType::try_from)
                     .collect::<Result<Vec<_>, _>>()?;
-                LogicalPlanBuilder::from(input)
-                    .prepare(prepare.name.clone(), data_types)?
-                    .build()
+                let fields: Vec<Field> = prepare
+                    .fields
+                    .iter()
+                    .map(Field::try_from)
+                    .collect::<Result<Vec<_>, _>>()?;
+
+                // If the fields are empty, this plan may have been generated by
+                // an earlier version of DataFusion; in that case the DataTypes
+                // can be used to construct the plan.
+                if fields.is_empty() {
+                    LogicalPlanBuilder::from(input)
+                        .prepare(
+                            prepare.name.clone(),
+                            data_types
+                                .into_iter()
+                                .map(|dt| Field::new("", dt, true).into())
+                                .collect(),
+                        )?
+                        .build()
+                } else {
+                    LogicalPlanBuilder::from(input)
+                        .prepare(
+                            prepare.name.clone(),
+                            fields.into_iter().map(|f| f.into()).collect(),
+                        )?
+                        .build()
+                }
             }
             LogicalPlanType::DropView(dropview) => {
                 Ok(LogicalPlan::Ddl(DdlStatement::DropView(DropView {
@@ -1621,7 +1645,7 @@ impl AsLogicalPlan for LogicalPlanNode {
             }
             LogicalPlan::Statement(Statement::Prepare(Prepare {
                 name,
-                data_types,
+                fields,
                 input,
             })) => {
                 let input =
@@ -1630,11 +1654,17 @@ impl AsLogicalPlan for LogicalPlanNode {
                     logical_plan_type: Some(LogicalPlanType::Prepare(Box::new(
                         protobuf::PrepareNode {
                             name: name.clone(),
-                            data_types: data_types
+                            input: Some(Box::new(input)),
+                            // Store the DataTypes for reading by older DataFusion
+                            data_types: fields
                                 .iter()
-                                .map(|t| t.try_into())
+                                .map(|f| f.data_type().try_into())
+                                .collect::<Result<Vec<_>, _>>()?,
+                            // Store the Fields for current and future DataFusion
+                            fields: fields
+                                .iter()
+                                .map(|f| f.as_ref().try_into())
                                 .collect::<Result<Vec<_>, _>>()?,
-                            input: Some(Box::new(input)),
                         },
                     ))),
                 })
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index 6238c2f1cdde..2774b5b6ba7c 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -608,18 +608,20 @@ pub fn serialize_expr(
                 })),
             }
         }
-        Expr::Placeholder(Placeholder { id, data_type }) => {
-            let data_type = match data_type {
-                Some(data_type) => Some(data_type.try_into()?),
-                None => None,
-            };
-            protobuf::LogicalExprNode {
-                expr_type: Some(ExprType::Placeholder(PlaceholderNode {
-                    id: id.clone(),
-                    data_type,
-                })),
-            }
-        }
+        Expr::Placeholder(Placeholder { id, field }) => protobuf::LogicalExprNode {
+            expr_type: Some(ExprType::Placeholder(PlaceholderNode {
+                id: id.clone(),
+                data_type: match field {
+                    Some(field) => Some(field.data_type().try_into()?),
+                    None => None,
+                },
+                nullable: field.as_ref().map(|f| f.is_nullable()),
+                metadata: field
+                    .as_ref()
+                    .map(|f| f.metadata().clone())
+                    .unwrap_or(HashMap::new()),
+            })),
+        },
     };
 
     Ok(expr_node)
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index 3d51038eba72..7964eef07338 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -1071,6 +1071,37 @@ async fn roundtrip_logical_plan_with_view_scan() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn roundtrip_logical_plan_prepared_statement_with_metadata() -> Result<()> {
+    let ctx = SessionContext::new();
+
+    let plan = ctx
+        .sql("SELECT $1")
+        .await
+        .unwrap()
+        .into_optimized_plan()
+        .unwrap();
+    let prepared = LogicalPlanBuilder::new(plan)
+        .prepare(
+            "".to_string(),
+            vec![Field::new("", DataType::Int32, true)
+                .with_metadata(
+                    [("some_key".to_string(), "some_value".to_string())].into(),
+                )
+                .into()],
+        )
+        .unwrap()
+        .plan()
+        .clone();
+
+    let bytes = logical_plan_to_bytes(&prepared)?;
+    let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?;
+    assert_eq!(format!("{prepared}"), format!("{logical_round_trip}"));
+    // Also check exact equality in case metadata isn't reflected in the display string
+    assert_eq!(&prepared, &logical_round_trip);
+    Ok(())
+}
+
 pub mod proto {
     #[derive(Clone, PartialEq, ::prost::Message)]
     pub struct TopKPlanProto {
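For downstream code, the visible API change is that `Placeholder` now carries an optional `FieldRef` instead of an optional `DataType`. A sketch of the updated match pattern (a hypothetical helper, not part of this patch):

use arrow::datatypes::DataType;
use datafusion_expr::{expr::Placeholder, Expr};

// Recover the storage type of a placeholder through its new `field` member.
fn placeholder_type(expr: &Expr) -> Option<&DataType> {
    match expr {
        Expr::Placeholder(Placeholder { field, .. }) => {
            field.as_ref().map(|f| f.data_type())
        }
        _ => None,
    }
}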
diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs
index c9ef4377d43b..3040bdba3b64 100644
--- a/datafusion/proto/tests/cases/serialize.rs
+++ b/datafusion/proto/tests/cases/serialize.rs
@@ -18,10 +18,11 @@
 use std::sync::Arc;
 
 use arrow::array::ArrayRef;
-use arrow::datatypes::DataType;
+use arrow::datatypes::{DataType, Field};
 use datafusion::execution::FunctionRegistry;
 use datafusion::prelude::SessionContext;
+use datafusion_expr::expr::Placeholder;
 use datafusion_expr::{col, create_udf, lit, ColumnarValue};
 use datafusion_expr::{Expr, Volatility};
 use datafusion_functions::string;
@@ -136,6 +137,21 @@ fn roundtrip_qualified_alias() {
     assert_eq!(qual_alias, roundtrip_expr(&qual_alias));
 }
 
+#[test]
+fn roundtrip_placeholder_with_metadata() {
+    let expr = Expr::Placeholder(Placeholder::new_with_metadata(
+        "placeholder_id".to_string(),
+        Some(
+            Field::new("", DataType::Utf8, false)
+                .with_metadata(
+                    [("some_key".to_string(), "some_value".to_string())].into(),
+                )
+                .into(),
+        ),
+    ));
+    assert_eq!(expr, roundtrip_expr(&expr));
+}
+
 #[test]
 fn roundtrip_deeply_nested_binary_expr() {
     // We need more stack space so this doesn't overflow in dev builds
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs
index 23426701409e..5cff1e349b18 100644
--- a/datafusion/sql/src/expr/mod.rs
+++ b/datafusion/sql/src/expr/mod.rs
@@ -287,7 +287,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
                     schema,
                     planner_context,
                 )?),
-                self.convert_data_type(&data_type)?,
+                self.convert_data_type(&data_type)?.data_type().clone(),
             )))
         }
@@ -297,7 +297,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
             SQLExpr::TypedString {
                 data_type,
                 value,
                 uses_odbc_syntax: _,
             }) => Ok(Expr::Cast(Cast::new(
                 Box::new(lit(value.into_string().unwrap())),
-                self.convert_data_type(&data_type)?,
+                self.convert_data_type(&data_type)?.data_type().clone(),
             ))),
 
             SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(
@@ -974,7 +974,7 @@
         // numeric constants are treated as seconds (rather than as nanoseconds)
         // to align with postgres / duckdb semantics
-        let expr = match &dt {
+        let expr = match dt.data_type() {
             DataType::Timestamp(TimeUnit::Nanosecond, tz)
                 if expr.get_type(schema)? == DataType::Int64 =>
             {
@@ -986,7 +986,12 @@
             _ => expr,
         };
 
-        Ok(Expr::Cast(Cast::new(Box::new(expr), dt)))
+        // Currently drops metadata attached to the type
+        // https://github.com/apache/datafusion/issues/18060
+        Ok(Expr::Cast(Cast::new(
+            Box::new(expr),
+            dt.data_type().clone(),
+        )))
     }
 
     /// Extracts the root expression and access chain from a compound expression.
diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs
index 7075a1afd9dd..d71f3e3f9ea9 100644
--- a/datafusion/sql/src/expr/value.rs
+++ b/datafusion/sql/src/expr/value.rs
@@ -20,7 +20,7 @@ use arrow::compute::kernels::cast_utils::{
     parse_interval_month_day_nano_config, IntervalParseConfig, IntervalUnit,
 };
 use arrow::datatypes::{
-    i256, DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
+    i256, FieldRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
 };
 use bigdecimal::num_bigint::BigInt;
 use bigdecimal::{BigDecimal, Signed, ToPrimitive};
@@ -45,7 +45,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
     pub(crate) fn parse_value(
         &self,
         value: Value,
-        param_data_types: &[DataType],
+        param_data_types: &[FieldRef],
     ) -> Result<Expr> {
         match value {
             Value::Number(n, _) => self.parse_sql_number(&n, false),
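The function touched next resolves `$n` against the prepared parameter list; with this change the whole `FieldRef` (type, nullability, and metadata) travels with the placeholder rather than the bare `DataType`. A rough standalone sketch of that lookup, assuming the `new_with_metadata` constructor introduced by this diff:

use arrow::datatypes::FieldRef;
use datafusion_expr::expr::Placeholder;

// `$2` maps to index 1 of the PREPARE parameter fields; an out-of-range or
// malformed id leaves the field as None.
fn placeholder_for(param: &str, param_fields: &[FieldRef]) -> Placeholder {
    let field = param[1..]
        .parse::<usize>()
        .ok()
        .and_then(|i| i.checked_sub(1))
        .and_then(|i| param_fields.get(i))
        .cloned();
    Placeholder::new_with_metadata(param.to_string(), field)
}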
@@ -108,7 +108,7 @@
     /// number 1, 2, ... etc. For example, `$1` is the first placeholder; `$2` is the second one and so on.
     fn create_placeholder_expr(
         param: String,
-        param_data_types: &[DataType],
+        param_data_types: &[FieldRef],
     ) -> Result<Expr> {
         // Parse the placeholder as a number, since numbered placeholders are the only form sqlparser and postgres support
         let index = param[1..].parse::<usize>();
@@ -133,7 +133,7 @@
         // Data type of the parameter
         debug!("type of param {param} param_data_types[idx]: {param_type:?}");
 
-        Ok(Expr::Placeholder(Placeholder::new(
+        Ok(Expr::Placeholder(Placeholder::new_with_metadata(
             param,
             param_type.cloned(),
         )))
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index e93c5e066b66..05bb1121131b 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -256,7 +256,7 @@ impl IdentNormalizer {
 pub struct PlannerContext {
     /// Data types for numbered parameters ($1, $2, etc), if supplied
     /// in `PREPARE` statement
-    prepare_param_data_types: Arc<Vec<DataType>>,
+    prepare_param_data_types: Arc<Vec<FieldRef>>,
     /// Map of CTE name to logical plan of the WITH clause.
     /// Use `Arc` to allow cheap cloning
     ctes: HashMap<String, Arc<LogicalPlan>>,
@@ -290,7 +290,7 @@ impl PlannerContext {
     /// Update the PlannerContext with provided prepare_param_data_types
     pub fn with_prepare_param_data_types(
         mut self,
-        prepare_param_data_types: Vec<DataType>,
+        prepare_param_data_types: Vec<FieldRef>,
     ) -> Self {
         self.prepare_param_data_types = prepare_param_data_types.into();
         self
@@ -347,7 +347,7 @@ impl PlannerContext {
     }
 
     /// Return the types of parameters (`$1`, `$2`, etc) if known
-    pub fn prepare_param_data_types(&self) -> &[DataType] {
+    pub fn prepare_param_data_types(&self) -> &[FieldRef] {
         &self.prepare_param_data_types
     }
@@ -433,11 +433,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 .options
                 .iter()
                 .any(|x| x.option == ColumnOption::NotNull);
-            fields.push(Field::new(
-                self.ident_normalizer.normalize(column.name),
-                data_type,
-                !not_nullable,
-            ));
+            fields.push(
+                data_type
+                    .as_ref()
+                    .clone()
+                    .with_name(self.ident_normalizer.normalize(column.name))
+                    .with_nullable(!not_nullable),
+            );
         }
 
         Ok(Schema::new(fields))
@@ -587,11 +589,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         })
     }
 
-    pub(crate) fn convert_data_type(&self, sql_type: &SQLDataType) -> Result<DataType> {
+    pub(crate) fn convert_data_type(&self, sql_type: &SQLDataType) -> Result<FieldRef> {
         // First check if any of the registered type_planner can handle this type
         if let Some(type_planner) = self.context_provider.get_type_planner() {
             if let Some(data_type) = type_planner.plan_type(sql_type)? {
-                return Ok(data_type);
+                return Ok(Field::new("", data_type, true).into());
             }
         }
@@ -600,7 +602,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_sql_type)) => {
                 // Arrays may be multi-dimensional.
                 let inner_data_type = self.convert_data_type(inner_sql_type)?;
-                Ok(DataType::new_list(inner_data_type, true))
+
+                // Lists are allowed to have an arbitrarily named element field;
+                // however, a name other than 'item' fails an `==` check against
+                // a list created more idiomatically in arrow-rs, which causes
+                // issues. We use the low-level constructor here to preserve
+                // extension metadata from the child type.
+                Ok(Field::new(
+                    "",
+                    DataType::List(
+                        inner_data_type.as_ref().clone().with_name("item").into(),
+                    ),
+                    true,
+                )
+                .into())
             }
             SQLDataType::Array(ArrayElemTypeDef::SquareBracket(
                 inner_sql_type,
@@ -608,19 +624,32 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             )) => {
                 let inner_data_type = self.convert_data_type(inner_sql_type)?;
                 if let Some(array_size) = maybe_array_size {
-                    Ok(DataType::new_fixed_size_list(
-                        inner_data_type,
-                        *array_size as i32,
+                    Ok(Field::new(
+                        "",
+                        DataType::FixedSizeList(
+                            inner_data_type.as_ref().clone().with_name("item").into(),
+                            *array_size as i32,
+                        ),
                         true,
-                    ))
+                    )
+                    .into())
                 } else {
-                    Ok(DataType::new_list(inner_data_type, true))
+                    Ok(Field::new(
+                        "",
+                        DataType::List(
+                            inner_data_type.as_ref().clone().with_name("item").into(),
+                        ),
+                        true,
+                    )
+                    .into())
                 }
             }
             SQLDataType::Array(ArrayElemTypeDef::None) => {
                 not_impl_err!("Arrays with unspecified type is not supported")
             }
-            other => self.convert_simple_data_type(other),
+            other => {
+                Ok(Field::new("", self.convert_simple_data_type(other)?, true).into())
+            }
         }
     }
@@ -739,11 +768,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                     Some(ident) => ident.clone(),
                     None => Ident::new(format!("c{idx}")),
                 };
-                Ok(Arc::new(Field::new(
-                    self.ident_normalizer.normalize(field_name),
-                    data_type,
-                    true,
-                )))
+                Ok(data_type.as_ref().clone().with_name(self.ident_normalizer.normalize(field_name)))
             })
             .collect::<Result<Vec<_>>>()?;
         Ok(DataType::Struct(Fields::from(fields)))
diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs
index 0e868e8c2689..4338285d678b 100644
--- a/datafusion/sql/src/statement.rs
+++ b/datafusion/sql/src/statement.rs
@@ -29,7 +29,7 @@ use crate::planner::{
 };
 use crate::utils::normalize_ident;
 
-use arrow::datatypes::{DataType, Fields};
+use arrow::datatypes::{Field, FieldRef, Fields};
 use datafusion_common::error::_plan_err;
 use datafusion_common::parsers::CompressionTypeVariant;
 use datafusion_common::{
@@ -730,14 +730,14 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
                 statement,
             } => {
                 // Convert parser data types to DataFusion data types
-                let mut data_types: Vec<DataType> = data_types
+                let mut fields: Vec<FieldRef> = data_types
                     .into_iter()
                     .map(|t| self.convert_data_type(&t))
                     .collect::<Result<Vec<_>>>()?;
 
                 // Create planner context with parameters
-                let mut planner_context = PlannerContext::new()
-                    .with_prepare_param_data_types(data_types.clone());
+                let mut planner_context =
+                    PlannerContext::new().with_prepare_param_data_types(fields.clone());
 
                 // Build logical plan for inner statement of the prepare statement
                 let plan = self.sql_statement_to_plan_with_context_impl(
@@ -745,21 +745,21 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
                     &mut planner_context,
                 )?;
 
-                if data_types.is_empty() {
-                    let map_types = plan.get_parameter_types()?;
+                if fields.is_empty() {
+                    let map_types = plan.get_parameter_fields()?;
                     let param_types: Vec<_> = (1..=map_types.len())
                         .filter_map(|i| {
                             let key = format!("${i}");
                             map_types.get(&key).and_then(|opt| opt.clone())
                         })
                         .collect();
-                    data_types.extend(param_types.iter().cloned());
+                    fields.extend(param_types.iter().cloned());
                     planner_context.with_prepare_param_data_types(param_types);
                 }
 
                 Ok(LogicalPlan::Statement(PlanStatement::Prepare(Prepare {
                     name: ident_to_string(&name),
-                    data_types,
+                    fields,
                     input: Arc::new(plan),
                 })))
             }
@@ -1203,7 +1203,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
                 Ok(OperateFunctionArg {
                     name: arg.name,
                     default_expr,
-                    data_type,
+                    data_type: data_type.data_type().clone(),
                 })
             })
             .collect::<Result<Vec<_>>>();
@@ -1221,7 +1221,9 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
         //
        // Convert resulting expression to data fusion expression
        //
        let arg_types = args.as_ref().map(|arg| {
-           arg.iter().map(|t| t.data_type.clone()).collect::<Vec<_>>()
+           arg.iter()
+               .map(|t| Arc::new(Field::new("", t.data_type.clone(), true)))
+               .collect::<Vec<_>>()
        });
        let mut planner_context = PlannerContext::new()
            .with_prepare_param_data_types(arg_types.unwrap_or_default());
@@ -1264,7 +1266,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
            or_replace,
            temporary,
            name,
-           return_type,
+           return_type: return_type.map(|f| f.data_type().clone()),
            args,
            params,
            schema: DFSchemaRef::new(DFSchema::empty()),
@@ -1998,10 +2000,10 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
                )?;
                // Update placeholder's datatype to the type of the target column
                if let Expr::Placeholder(placeholder) = &mut expr {
-                   placeholder.data_type = placeholder
-                       .data_type
+                   placeholder.field = placeholder
+                       .field
                        .take()
-                       .or_else(|| Some(field.data_type().clone()));
+                       .or_else(|| Some(Arc::clone(field)));
                }
                // Cast to target column type, if necessary
                expr.cast_to(field.data_type(), source.schema())?
@@ -2105,8 +2107,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
                        idx + 1
                    )
                })?;
-               let dt = field.data_type().clone();
-               let _ = prepare_param_data_types.insert(name, dt);
+               let _ = prepare_param_data_types.insert(name, Arc::clone(field));
            }
        }
    }
diff --git a/datafusion/sql/tests/cases/params.rs b/datafusion/sql/tests/cases/params.rs
index 343a90af3efb..4e8758762af0 100644
--- a/datafusion/sql/tests/cases/params.rs
+++ b/datafusion/sql/tests/cases/params.rs
@@ -16,8 +16,8 @@
 // under the License.
 
 use crate::logical_plan;
-use arrow::datatypes::DataType;
-use datafusion_common::{assert_contains, ParamValues, ScalarValue};
+use arrow::datatypes::{DataType, Field, FieldRef};
+use datafusion_common::{assert_contains, metadata::Literal, ParamValues, ScalarValue};
 use datafusion_expr::{LogicalPlan, Prepare, Statement};
 use insta::assert_snapshot;
 use itertools::Itertools as _;
@@ -51,11 +51,39 @@ impl ParameterTest<'_> {
     }
 }
 
+pub struct ParameterTestWithMetadata<'a> {
+    pub sql: &'a str,
+    pub expected_types: Vec<(&'a str, Option<FieldRef>)>,
+    pub param_values: Vec<Literal>,
+}
+
+impl ParameterTestWithMetadata<'_> {
+    pub fn run(&self) -> String {
+        let plan = logical_plan(self.sql).unwrap();
+
+        let actual_types = plan.get_parameter_fields().unwrap();
+        let expected_types: HashMap<String, Option<FieldRef>> = self
+            .expected_types
+            .iter()
+            .map(|(k, v)| ((*k).to_string(), v.clone()))
+            .collect();
+
+        assert_eq!(actual_types, expected_types);
+
+        let plan_with_params = plan
+            .clone()
+            .with_param_values(ParamValues::List(self.param_values.clone()))
+            .unwrap();
+
+        format!("** Initial Plan:\n{plan}\n** Final Plan:\n{plan_with_params}")
+    }
+}
+
 fn generate_prepare_stmt_and_data_types(sql: &str) -> (LogicalPlan, String) {
     let plan = logical_plan(sql).unwrap();
     let data_types = match &plan {
-        LogicalPlan::Statement(Statement::Prepare(Prepare { data_types, .. })) => {
-            data_types.iter().join(", ").to_string()
+        LogicalPlan::Statement(Statement::Prepare(Prepare { fields, .. })) => {
+            fields.iter().map(|f| f.data_type()).join(", ").to_string()
         }
         _ => panic!("Expected a Prepare statement"),
     };
@@ -704,6 +732,147 @@ fn test_prepare_statement_to_plan_one_param() {
     );
 }
 
+#[test]
+fn test_update_infer_with_metadata() {
+    // The uuid field is inferred as nullable because it appears only in the
+    // filter (not in the update values, where its nullability would be inferred)
+    let uuid_field = Field::new("", DataType::FixedSizeBinary(16), true).with_metadata(
+        [("ARROW:extension:name".to_string(), "arrow.uuid".to_string())].into(),
+    );
+    let uuid_bytes = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
+    let expected_types = vec![
+        (
+            "$1",
+            Some(Field::new("last_name", DataType::Utf8, false).into()),
+        ),
+        ("$2", Some(uuid_field.clone().with_name("id").into())),
+    ];
+    let param_values = vec![
+        Literal::new(ScalarValue::from("Turing"), None),
+        Literal::new(
+            ScalarValue::FixedSizeBinary(16, Some(uuid_bytes)),
+            Some(uuid_field.metadata().into()),
+        ),
+    ];
+
+    // Check a normal update
+    let test = ParameterTestWithMetadata {
+        sql: "update person_with_uuid_extension set last_name=$1 where id=$2",
+        expected_types: expected_types.clone(),
+        param_values: param_values.clone(),
+    };
+
+    assert_snapshot!(
+        test.run(),
+        @r#"
+    ** Initial Plan:
+    Dml: op=[Update] table=[person_with_uuid_extension]
+      Projection: person_with_uuid_extension.id AS id, person_with_uuid_extension.first_name AS first_name, $1 AS last_name
+        Filter: person_with_uuid_extension.id = $2
+          TableScan: person_with_uuid_extension
+    ** Final Plan:
+    Dml: op=[Update] table=[person_with_uuid_extension]
+      Projection: person_with_uuid_extension.id AS id, person_with_uuid_extension.first_name AS first_name, Utf8("Turing") AS last_name
+        Filter: person_with_uuid_extension.id = FixedSizeBinary(16, "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16") FieldMetadata { inner: {"ARROW:extension:name": "arrow.uuid"} }
+          TableScan: person_with_uuid_extension
+    "#
+    );
+
+    // Check a prepared update
+    let test = ParameterTestWithMetadata {
+        sql: "PREPARE my_plan AS update person_with_uuid_extension set last_name=$1 where id=$2",
+        expected_types,
+        param_values
+    };
+
+    assert_snapshot!(
+        test.run(),
+        @r#"
+    ** Initial Plan:
+    Prepare: "my_plan" [Utf8, FixedSizeBinary(16)<{"ARROW:extension:name": "arrow.uuid"}>]
+      Dml: op=[Update] table=[person_with_uuid_extension]
+        Projection: person_with_uuid_extension.id AS id, person_with_uuid_extension.first_name AS first_name, $1 AS last_name
+          Filter: person_with_uuid_extension.id = $2
+            TableScan: person_with_uuid_extension
+    ** Final Plan:
+    Dml: op=[Update] table=[person_with_uuid_extension]
+      Projection: person_with_uuid_extension.id AS id, person_with_uuid_extension.first_name AS first_name, Utf8("Turing") AS last_name
+        Filter: person_with_uuid_extension.id = FixedSizeBinary(16, "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16") FieldMetadata { inner: {"ARROW:extension:name": "arrow.uuid"} }
+          TableScan: person_with_uuid_extension
+    "#
+    );
+}
+
+#[test]
+fn test_insert_infer_with_metadata() {
+    let uuid_field = Field::new("", DataType::FixedSizeBinary(16), false).with_metadata(
+        [("ARROW:extension:name".to_string(), "arrow.uuid".to_string())].into(),
+    );
+    let uuid_bytes = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
+    let expected_types = vec![
+        ("$1", Some(uuid_field.clone().with_name("id").into())),
+        (
+            "$2",
+            Some(Field::new("first_name", DataType::Utf8, false).into()),
+        ),
+        (
+            "$3",
+            Some(Field::new("last_name", DataType::Utf8, false).into()),
+        ),
+    ];
+    let param_values = vec![
+        Literal::new(
+            ScalarValue::FixedSizeBinary(16, Some(uuid_bytes)),
+            Some(uuid_field.metadata().into()),
+        ),
+        Literal::new(ScalarValue::from("Alan"), None),
+        Literal::new(ScalarValue::from("Turing"), None),
+    ];
+
+    // Check a normal insert
+    let test = ParameterTestWithMetadata {
+        sql: "insert into person_with_uuid_extension (id, first_name, last_name) values ($1, $2, $3)",
+        expected_types: expected_types.clone(),
+        param_values: param_values.clone()
+    };
+
+    assert_snapshot!(
+        test.run(),
+        @r#"
+    ** Initial Plan:
+    Dml: op=[Insert Into] table=[person_with_uuid_extension]
+      Projection: column1 AS id, column2 AS first_name, column3 AS last_name
+        Values: ($1, $2, $3)
+    ** Final Plan:
+    Dml: op=[Insert Into] table=[person_with_uuid_extension]
+      Projection: column1 AS id, column2 AS first_name, column3 AS last_name
+        Values: (FixedSizeBinary(16, "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16") FieldMetadata { inner: {"ARROW:extension:name": "arrow.uuid"} } AS $1, Utf8("Alan") AS $2, Utf8("Turing") AS $3)
+    "#
+    );
+
+    // Check a prepared insert
+    let test = ParameterTestWithMetadata {
+        sql: "PREPARE my_plan AS insert into person_with_uuid_extension (id, first_name, last_name) values ($1, $2, $3)",
+        expected_types,
+        param_values
+    };
+
+    assert_snapshot!(
+        test.run(),
+        @r#"
+    ** Initial Plan:
+    Prepare: "my_plan" [FixedSizeBinary(16)<{"ARROW:extension:name": "arrow.uuid"}>, Utf8, Utf8]
+      Dml: op=[Insert Into] table=[person_with_uuid_extension]
+        Projection: column1 AS id, column2 AS first_name, column3 AS last_name
+          Values: ($1, $2, $3)
+    ** Final Plan:
+    Dml: op=[Insert Into] table=[person_with_uuid_extension]
+      Projection: column1 AS id, column2 AS first_name, column3 AS last_name
+        Values: (FixedSizeBinary(16, "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16") FieldMetadata { inner: {"ARROW:extension:name": "arrow.uuid"} } AS $1, Utf8("Alan") AS $2, Utf8("Turing") AS $3)
+    "#
+    );
+}
+
 #[test]
 fn test_prepare_statement_to_plan_data_type() {
     let sql = "PREPARE my_plan(DOUBLE) AS SELECT id, age FROM person WHERE age = $1";
diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs
index ee1b761970de..5d9fd9f2c374 100644
--- a/datafusion/sql/tests/common/mod.rs
+++ b/datafusion/sql/tests/common/mod.rs
@@ -151,6 +151,14 @@ impl ContextProvider for MockContextProvider {
                 ),
                 Field::new("😀", DataType::Int32, false),
             ])),
+            "person_with_uuid_extension" => Ok(Schema::new(vec![
+                Field::new("id", DataType::FixedSizeBinary(16), false).with_metadata(
+                    [("ARROW:extension:name".to_string(), "arrow.uuid".to_string())]
+                        .into(),
+                ),
+                Field::new("first_name", DataType::Utf8, false),
+                Field::new("last_name", DataType::Utf8, false),
+            ])),
             "orders" => Ok(Schema::new(vec![
                 Field::new("order_id", DataType::UInt32, false),
                 Field::new("customer_id", DataType::UInt32, false),
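Taken together, these pieces let a bound parameter keep its extension type end to end. A hedged sketch of the user-facing flow, assuming a plan built from the UPDATE statement exercised in the tests above (`Literal`, `get_parameter_fields`, and the `Literal`-carrying `ParamValues::List` are the additions from this diff):

use datafusion_common::{metadata::Literal, ParamValues, Result, ScalarValue};
use datafusion_expr::LogicalPlan;

// Bind `update ... set last_name=$1 where id=$2`; $2 was inferred from the
// `id` column, so it carries the arrow.uuid extension metadata along.
fn bind_uuid_params(plan: LogicalPlan) -> Result<LogicalPlan> {
    let fields = plan.get_parameter_fields()?;
    // Reuse the inferred field's metadata for the bound literal (panics if $2
    // is unknown; a real caller would handle the miss).
    let uuid_metadata = fields["$2"].as_ref().map(|f| f.metadata().into());

    plan.with_param_values(ParamValues::List(vec![
        Literal::new(ScalarValue::from("Turing"), None),
        Literal::new(
            ScalarValue::FixedSizeBinary(16, Some(vec![0u8; 16])),
            uuid_metadata,
        ),
    ]))
}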