diff --git a/benchmarks/src/sort.rs b/benchmarks/src/sort.rs
index 8b2b02670449..cbbd3b54ea9e 100644
--- a/benchmarks/src/sort.rs
+++ b/benchmarks/src/sort.rs
@@ -70,28 +70,31 @@ impl RunOpt {
         let sort_cases = vec![
             (
                 "sort utf8",
-                LexOrdering::new(vec![PhysicalSortExpr {
+                [PhysicalSortExpr {
                     expr: col("request_method", &schema)?,
                     options: Default::default(),
-                }]),
+                }]
+                .into(),
             ),
             (
                 "sort int",
-                LexOrdering::new(vec![PhysicalSortExpr {
+                [PhysicalSortExpr {
                     expr: col("response_bytes", &schema)?,
                     options: Default::default(),
-                }]),
+                }]
+                .into(),
             ),
             (
                 "sort decimal",
-                LexOrdering::new(vec![PhysicalSortExpr {
+                [PhysicalSortExpr {
                     expr: col("decimal_price", &schema)?,
                     options: Default::default(),
-                }]),
+                }]
+                .into(),
             ),
             (
                 "sort integer tuple",
-                LexOrdering::new(vec![
+                [
                     PhysicalSortExpr {
                         expr: col("request_bytes", &schema)?,
                         options: Default::default(),
@@ -100,11 +103,12 @@ impl RunOpt {
                         expr: col("response_bytes", &schema)?,
                         options: Default::default(),
                     },
-                ]),
+                ]
+                .into(),
             ),
             (
                 "sort utf8 tuple",
-                LexOrdering::new(vec![
+                [
                     // sort utf8 tuple
                     PhysicalSortExpr {
                         expr: col("service", &schema)?,
@@ -122,11 +126,12 @@ impl RunOpt {
                         expr: col("image", &schema)?,
                         options: Default::default(),
                     },
-                ]),
+                ]
+                .into(),
             ),
             (
                 "sort mixed tuple",
-                LexOrdering::new(vec![
+                [
                     PhysicalSortExpr {
                         expr: col("service", &schema)?,
                         options: Default::default(),
@@ -139,7 +144,8 @@ impl RunOpt {
                         expr: col("decimal_price", &schema)?,
                         options: Default::default(),
                     },
-                ]),
+                ]
+                .into(),
             ),
         ];
         for (title, expr) in sort_cases {
diff --git a/datafusion-examples/examples/custom_file_format.rs b/datafusion-examples/examples/custom_file_format.rs
index ac1e64351768..e9a4d71b1633 100644
--- a/datafusion-examples/examples/custom_file_format.rs
+++ b/datafusion-examples/examples/custom_file_format.rs
@@ -21,27 +21,24 @@ use arrow::{
     array::{AsArray, RecordBatch, StringArray, UInt8Array},
     datatypes::{DataType, Field, Schema, SchemaRef, UInt64Type},
 };
-use datafusion::physical_expr::LexRequirement;
 use datafusion::{
     catalog::Session,
     common::{GetExt, Statistics},
-};
-use datafusion::{
-    datasource::physical_plan::FileSource, execution::session_state::SessionStateBuilder,
-};
-use datafusion::{
     datasource::{
         file_format::{
             csv::CsvFormatFactory, file_compression_type::FileCompressionType,
             FileFormat, FileFormatFactory,
         },
-        physical_plan::{FileScanConfig, FileSinkConfig},
+        physical_plan::{FileScanConfig, FileSinkConfig, FileSource},
         MemTable,
     },
     error::Result,
+    execution::session_state::SessionStateBuilder,
+    physical_expr_common::sort_expr::LexRequirement,
     physical_plan::ExecutionPlan,
     prelude::SessionContext,
 };
+
 use object_store::{ObjectMeta, ObjectStore};
 use tempfile::tempdir;
diff --git a/datafusion-examples/examples/function_factory.rs b/datafusion-examples/examples/function_factory.rs
index e712f4ea8eaa..21da35963345 100644
--- a/datafusion-examples/examples/function_factory.rs
+++ b/datafusion-examples/examples/function_factory.rs
@@ -150,10 +150,6 @@ impl ScalarUDFImpl for ScalarFunctionWrapper {
         Ok(ExprSimplifyResult::Simplified(replacement))
     }
 
-    fn aliases(&self) -> &[String] {
-        &[]
-    }
-
     fn output_ordering(&self, _input: &[ExprProperties]) -> Result<SortProperties> {
         Ok(SortProperties::Unordered)
     }
diff --git a/datafusion/catalog/src/listing_schema.rs b/datafusion/catalog/src/listing_schema.rs
index cc2c2ee606b3..2e4eac964b18 100644
--- a/datafusion/catalog/src/listing_schema.rs
+++ b/datafusion/catalog/src/listing_schema.rs
@@ -25,9 +25,7 @@ use std::sync::{Arc, Mutex};
 use crate::{SchemaProvider, TableProvider, TableProviderFactory};
 use crate::Session;
 
-use datafusion_common::{
-    Constraints, DFSchema, DataFusionError, HashMap, TableReference,
-};
+use datafusion_common::{DFSchema, DataFusionError, HashMap, TableReference};
 use datafusion_expr::CreateExternalTable;
 
 use async_trait::async_trait;
@@ -143,7 +141,7 @@ impl ListingSchemaProvider {
                     order_exprs: vec![],
                     unbounded: false,
                     options: Default::default(),
-                    constraints: Constraints::empty(),
+                    constraints: Default::default(),
                     column_defaults: Default::default(),
                 },
             )
diff --git a/datafusion/catalog/src/memory/table.rs b/datafusion/catalog/src/memory/table.rs
index 81243e2c4889..e996e1974d9e 100644
--- a/datafusion/catalog/src/memory/table.rs
+++ b/datafusion/catalog/src/memory/table.rs
@@ -23,25 +23,22 @@ use std::fmt::Debug;
 use std::sync::Arc;
 
 use crate::TableProvider;
-use datafusion_common::error::Result;
-use datafusion_expr::Expr;
-use datafusion_expr::TableType;
-use datafusion_physical_expr::create_physical_sort_exprs;
-use datafusion_physical_plan::repartition::RepartitionExec;
-use datafusion_physical_plan::{
-    common, ExecutionPlan, ExecutionPlanProperties, Partitioning,
-};
 
 use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
+use datafusion_common::error::Result;
 use datafusion_common::{not_impl_err, plan_err, Constraints, DFSchema, SchemaExt};
 use datafusion_common_runtime::JoinSet;
-use datafusion_datasource::memory::MemSink;
-use datafusion_datasource::memory::MemorySourceConfig;
+use datafusion_datasource::memory::{MemSink, MemorySourceConfig};
 use datafusion_datasource::sink::DataSinkExec;
 use datafusion_datasource::source::DataSourceExec;
 use datafusion_expr::dml::InsertOp;
-use datafusion_expr::SortExpr;
+use datafusion_expr::{Expr, SortExpr, TableType};
+use datafusion_physical_expr::{create_physical_sort_exprs, LexOrdering};
+use datafusion_physical_plan::repartition::RepartitionExec;
+use datafusion_physical_plan::{
+    common, ExecutionPlan, ExecutionPlanProperties, Partitioning,
+};
 use datafusion_session::Session;
 
 use async_trait::async_trait;
@@ -89,7 +86,7 @@ impl MemTable {
                 .into_iter()
                 .map(|e| Arc::new(RwLock::new(e)))
                 .collect::<Vec<_>>(),
-            constraints: Constraints::empty(),
+            constraints: Constraints::default(),
             column_defaults: HashMap::new(),
             sort_order: Arc::new(Mutex::new(vec![])),
         })
@@ -239,16 +236,13 @@ impl TableProvider for MemTable {
 
         if !sort_order.is_empty() {
             let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
-            let file_sort_order = sort_order
-                .iter()
-                .map(|sort_exprs| {
-                    create_physical_sort_exprs(
-                        sort_exprs,
-                        &df_schema,
-                        state.execution_props(),
-                    )
-                })
-                .collect::<Result<Vec<_>>>()?;
+            let eqp = state.execution_props();
+            let mut file_sort_order = vec![];
+            for sort_exprs in sort_order.iter() {
+                let physical_exprs =
+                    create_physical_sort_exprs(sort_exprs, &df_schema, eqp)?;
+                file_sort_order.extend(LexOrdering::new(physical_exprs));
+            }
             source = source.try_with_sort_information(file_sort_order)?;
         }
diff --git a/datafusion/catalog/src/stream.rs b/datafusion/catalog/src/stream.rs
index fbfab513229e..99c432b738e5 100644
--- a/datafusion/catalog/src/stream.rs
+++ b/datafusion/catalog/src/stream.rs
@@ -256,7 +256,7 @@ impl StreamConfig {
         Self {
             source,
             order: vec![],
-            constraints: Constraints::empty(),
+            constraints: Constraints::default(),
         }
     }
 
@@ -350,15 +350,10 @@ impl TableProvider for StreamTable {
         input: Arc<dyn ExecutionPlan>,
         _insert_op: InsertOp,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        let ordering = match self.0.order.first() {
-            Some(x) => {
-                let schema = self.0.source.schema();
-                let orders = create_ordering(schema, std::slice::from_ref(x))?;
-                let ordering = orders.into_iter().next().unwrap();
-                Some(ordering.into_iter().map(Into::into).collect())
-            }
-            None => None,
-        };
+        let schema = self.0.source.schema();
+        let orders = create_ordering(schema, &self.0.order)?;
+        // It is sufficient to pass only one of the equivalent orderings:
+        let ordering = orders.into_iter().next().map(Into::into);
 
         Ok(Arc::new(DataSinkExec::new(
             input,
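The `StreamTable` hunk above (and the `ListingTable` hunk later in this diff) replace first-element special-casing with the same idiom: the orderings computed for a table are equivalent, so forwarding just the first one, converted into the requirement type, is enough. A minimal sketch of that idiom, assuming only what the hunks show (a `Vec<LexOrdering>` of equivalent orderings and a `From<LexOrdering>` conversion for `LexRequirement`):

```rust
use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement};

/// All entries of `orderings` are assumed equivalent, so any single entry
/// carries the full requirement; an empty vector simply yields `None`
/// instead of an error path.
fn first_ordering_as_requirement(
    orderings: Vec<LexOrdering>,
) -> Option<LexRequirement> {
    orderings.into_iter().next().map(Into::into)
}
```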
diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs
index 77e00d6dcda2..737a9c35f705 100644
--- a/datafusion/common/src/functional_dependencies.rs
+++ b/datafusion/common/src/functional_dependencies.rs
@@ -36,35 +36,31 @@ pub enum Constraint {
 }
 
 /// This object encapsulates a list of functional constraints:
-#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
+#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, PartialOrd)]
 pub struct Constraints {
     inner: Vec<Constraint>,
 }
 
 impl Constraints {
-    /// Create empty constraints
-    pub fn empty() -> Self {
-        Constraints::new_unverified(vec![])
-    }
-
     /// Create a new [`Constraints`] object from the given `constraints`.
-    /// Users should use the [`Constraints::empty`] or [`SqlToRel::new_constraint_from_table_constraints`] functions
-    /// for constructing [`Constraints`]. This constructor is for internal
-    /// purposes only and does not check whether the argument is valid. The user
-    /// is responsible for supplying a valid vector of [`Constraint`] objects.
+    /// Users should use the [`Constraints::default`] or [`SqlToRel::new_constraint_from_table_constraints`]
+    /// functions for constructing [`Constraints`] instances. This constructor
+    /// is for internal purposes only and does not check whether the argument
+    /// is valid. The user is responsible for supplying a valid vector of
+    /// [`Constraint`] objects.
     ///
     /// [`SqlToRel::new_constraint_from_table_constraints`]: https://docs.rs/datafusion/latest/datafusion/sql/planner/struct.SqlToRel.html#method.new_constraint_from_table_constraints
     pub fn new_unverified(constraints: Vec<Constraint>) -> Self {
         Self { inner: constraints }
     }
 
-    /// Check whether constraints is empty
-    pub fn is_empty(&self) -> bool {
-        self.inner.is_empty()
+    /// Extends the current constraints with the given `other` constraints.
+    pub fn extend(&mut self, other: Constraints) {
+        self.inner.extend(other.inner);
     }
 
-    /// Projects constraints using the given projection indices.
-    /// Returns None if any of the constraint columns are not included in the projection.
+    /// Projects constraints using the given projection indices. Returns `None`
+    /// if any of the constraint columns are not included in the projection.
     pub fn project(&self, proj_indices: &[usize]) -> Option<Self> {
         let projected = self
             .inner
@@ -74,14 +70,14 @@ impl Constraints {
                 Constraint::PrimaryKey(indices) => {
                     let new_indices =
                         update_elements_with_matching_indices(indices, proj_indices);
-                    // Only keep constraint if all columns are preserved
+                    // Only keep the constraint if all columns are preserved:
                     (new_indices.len() == indices.len())
                         .then_some(Constraint::PrimaryKey(new_indices))
                 }
                 Constraint::Unique(indices) => {
                     let new_indices =
                         update_elements_with_matching_indices(indices, proj_indices);
-                    // Only keep constraint if all columns are preserved
+                    // Only keep the constraint if all columns are preserved:
                     (new_indices.len() == indices.len())
                         .then_some(Constraint::Unique(new_indices))
                 }
@@ -93,15 +89,9 @@ impl Constraints {
     }
 }
 
-impl Default for Constraints {
-    fn default() -> Self {
-        Constraints::empty()
-    }
-}
-
 impl IntoIterator for Constraints {
     type Item = Constraint;
-    type IntoIter = IntoIter<Constraint>;
+    type IntoIter = IntoIter<Self::Item>;
 
     fn into_iter(self) -> Self::IntoIter {
         self.inner.into_iter()
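With `Constraints::empty` gone in favor of the derived `Default` above, construction plus the new `extend` and the documented `project` behavior can be summarized in a short sketch (the index values are illustrative only):

```rust
use datafusion_common::{Constraint, Constraints};

fn constraints_example() {
    // The derived `Default` now produces the empty constraint set that
    // `Constraints::empty()` used to return.
    let mut constraints = Constraints::default();

    // `new_unverified` still trusts the caller to supply valid constraints.
    let pk = Constraints::new_unverified(vec![Constraint::PrimaryKey(vec![0, 2])]);

    // The new `extend` concatenates the underlying constraint lists.
    constraints.extend(pk);

    // Per the doc comment in the hunk above, `project` keeps a constraint
    // only when all of its columns survive: projecting onto columns [0, 2]
    // preserves the primary key, while projecting onto [0, 1] drops
    // column 2 and, with it, the constraint.
    assert!(constraints.project(&[0, 2]).is_some());
    assert!(constraints.project(&[0, 1]).is_none());
}
```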
diff --git a/datafusion/core/benches/physical_plan.rs b/datafusion/core/benches/physical_plan.rs
index 0a65c52f72de..e4838572f60f 100644
--- a/datafusion/core/benches/physical_plan.rs
+++ b/datafusion/core/benches/physical_plan.rs
@@ -50,11 +50,8 @@ fn sort_preserving_merge_operator(
 
     let sort = sort
         .iter()
-        .map(|name| PhysicalSortExpr {
-            expr: col(name, &schema).unwrap(),
-            options: Default::default(),
-        })
-        .collect::<LexOrdering>();
+        .map(|name| PhysicalSortExpr::new_default(col(name, &schema).unwrap()));
+    let sort = LexOrdering::new(sort).unwrap();
 
     let exec = MemorySourceConfig::try_new_exec(
         &batches.into_iter().map(|rb| vec![rb]).collect::<Vec<_>>(),
diff --git a/datafusion/core/benches/sort.rs b/datafusion/core/benches/sort.rs
index e1bc478b36f0..276151e253f7 100644
--- a/datafusion/core/benches/sort.rs
+++ b/datafusion/core/benches/sort.rs
@@ -71,7 +71,6 @@ use std::sync::Arc;
 use arrow::array::StringViewArray;
 use arrow::{
     array::{DictionaryArray, Float64Array, Int64Array, StringArray},
-    compute::SortOptions,
     datatypes::{Int32Type, Schema},
     record_batch::RecordBatch,
 };
@@ -272,14 +271,11 @@ impl BenchCase {
 
 /// Make sort exprs for each column in `schema`
 fn make_sort_exprs(schema: &Schema) -> LexOrdering {
-    schema
+    let sort_exprs = schema
         .fields()
         .iter()
-        .map(|f| PhysicalSortExpr {
-            expr: col(f.name(), schema).unwrap(),
-            options: SortOptions::default(),
-        })
-        .collect()
+        .map(|f| PhysicalSortExpr::new_default(col(f.name(), schema).unwrap()));
+    LexOrdering::new(sort_exprs).unwrap()
 }
 
 /// Create streams of int64 (where approximately 1/3 values is repeated)
diff --git a/datafusion/core/benches/spm.rs b/datafusion/core/benches/spm.rs
index 8613525cb248..d13407864297 100644
--- a/datafusion/core/benches/spm.rs
+++ b/datafusion/core/benches/spm.rs
@@ -21,7 +21,6 @@ use arrow::array::{ArrayRef, Int32Array, Int64Array, RecordBatch, StringArray};
 use datafusion_execution::TaskContext;
 use datafusion_physical_expr::expressions::col;
 use datafusion_physical_expr::PhysicalSortExpr;
-use datafusion_physical_expr_common::sort_expr::LexOrdering;
 use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
 use datafusion_physical_plan::{collect, ExecutionPlan};
 
@@ -70,7 +69,7 @@ fn generate_spm_for_round_robin_tie_breaker(
     let partitiones = vec![rbs.clone(); partition_count];
 
     let schema = rb.schema();
-    let sort = LexOrdering::new(vec![
+    let sort = [
         PhysicalSortExpr {
             expr: col("b", &schema).unwrap(),
             options: Default::default(),
@@ -79,7 +78,8 @@ fn generate_spm_for_round_robin_tie_breaker(
             expr: col("c", &schema).unwrap(),
             options: Default::default(),
         },
-    ]);
+    ]
+    .into();
 
     let exec = MemorySourceConfig::try_new_exec(&partitiones, schema, None).unwrap();
     SortPreservingMergeExec::new(sort, exec)
diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs
index 3c87d3ee2329..493175d30b0c 100644
--- a/datafusion/core/src/datasource/listing/table.rs
+++ b/datafusion/core/src/datasource/listing/table.rs
@@ -17,43 +17,41 @@
 
 //! The table implementation.
 
-use super::helpers::{expr_applicable_for_cols, pruned_partition_list};
-use super::{ListingTableUrl, PartitionedFile};
 use std::collections::HashMap;
 use std::{any::Any, str::FromStr, sync::Arc};
 
+use super::helpers::{expr_applicable_for_cols, pruned_partition_list};
+use super::{ListingTableUrl, PartitionedFile};
 use crate::datasource::{
     create_ordering,
     file_format::{file_compression_type::FileCompressionType, FileFormat},
     physical_plan::FileSinkConfig,
 };
 use crate::execution::context::SessionState;
-use datafusion_catalog::TableProvider;
-use datafusion_common::{config_err, DataFusionError, Result};
-use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder};
-use datafusion_datasource::schema_adapter::DefaultSchemaAdapterFactory;
-use datafusion_execution::config::SessionConfig;
-use datafusion_expr::dml::InsertOp;
-use datafusion_expr::{Expr, TableProviderFilterPushDown};
-use datafusion_expr::{SortExpr, TableType};
-use datafusion_physical_plan::empty::EmptyExec;
-use datafusion_physical_plan::{ExecutionPlan, Statistics};
 
-use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, SchemaRef};
+use arrow::datatypes::{DataType, Field, SchemaBuilder, SchemaRef};
+use arrow_schema::Schema;
+use datafusion_catalog::{Session, TableProvider};
+use datafusion_common::stats::Precision;
 use datafusion_common::{
-    config_datafusion_err, internal_err, plan_err, project_schema, Constraints, SchemaExt,
+    config_datafusion_err, config_err, internal_err, plan_err, project_schema,
+    Constraints, DataFusionError, Result, SchemaExt,
 };
+use datafusion_datasource::compute_all_files_statistics;
+use datafusion_datasource::file_groups::FileGroup;
+use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder};
+use datafusion_datasource::schema_adapter::DefaultSchemaAdapterFactory;
 use datafusion_execution::cache::{
     cache_manager::FileStatisticsCache, cache_unit::DefaultFileStatisticsCache,
 };
-use datafusion_physical_expr::{LexOrdering, PhysicalSortRequirement};
+use datafusion_execution::config::SessionConfig;
+use datafusion_expr::dml::InsertOp;
+use datafusion_expr::{Expr, SortExpr, TableProviderFilterPushDown, TableType};
+use datafusion_physical_expr_common::sort_expr::LexOrdering;
+use datafusion_physical_plan::empty::EmptyExec;
+use datafusion_physical_plan::{ExecutionPlan, Statistics};
 
 use async_trait::async_trait;
-use datafusion_catalog::Session;
-use datafusion_common::stats::Precision;
-use datafusion_datasource::compute_all_files_statistics;
-use datafusion_datasource::file_groups::FileGroup;
-use datafusion_physical_expr_common::sort_expr::LexRequirement;
 use futures::{future, stream, Stream, StreamExt, TryStreamExt};
 use itertools::Itertools;
 use object_store::ObjectStore;
@@ -797,7 +795,7 @@ impl ListingTable {
             options,
             definition: None,
             collected_statistics: Arc::new(DefaultFileStatisticsCache::default()),
-            constraints: Constraints::empty(),
+            constraints: Constraints::default(),
             column_defaults: HashMap::new(),
         };
 
@@ -1053,25 +1051,9 @@ impl TableProvider for ListingTable {
             file_extension: self.options().format.get_ext(),
         };
 
-        let order_requirements = if !self.options().file_sort_order.is_empty() {
-            // Multiple sort orders in outer vec are equivalent, so we pass only the first one
-            let orderings = self.try_create_output_ordering()?;
-            let Some(ordering) = orderings.first() else {
-                return internal_err!(
-                    "Expected ListingTable to have a sort order, but none found!"
-                );
-            };
-            // Converts Vec<LexOrdering> into type required by execution plan to specify its required input ordering
-            Some(LexRequirement::new(
-                ordering
-                    .into_iter()
-                    .cloned()
-                    .map(PhysicalSortRequirement::from)
-                    .collect::<Vec<_>>(),
-            ))
-        } else {
-            None
-        };
+        let orderings = self.try_create_output_ordering()?;
+        // It is sufficient to pass only one of the equivalent orderings:
+        let order_requirements = orderings.into_iter().next().map(Into::into);
 
         self.options()
             .format
@@ -1277,7 +1259,10 @@ mod tests {
     use crate::datasource::{provider_as_source, DefaultTableSource, MemTable};
     use crate::execution::options::ArrowReadOptions;
     use crate::prelude::*;
-    use crate::test::{columns, object_store::register_test_store};
+    use crate::test::columns;
+    use crate::test::object_store::{
+        ensure_head_concurrency, make_test_store_and_state, register_test_store,
+    };
 
     use arrow::compute::SortOptions;
     use arrow::record_batch::RecordBatch;
@@ -1286,10 +1271,8 @@ mod tests {
     use datafusion_common::{assert_contains, ScalarValue};
     use datafusion_expr::{BinaryExpr, LogicalPlanBuilder, Operator};
     use datafusion_physical_expr::PhysicalSortExpr;
-    use datafusion_physical_plan::collect;
-    use datafusion_physical_plan::ExecutionPlanProperties;
+    use datafusion_physical_plan::{collect, ExecutionPlanProperties};
 
-    use crate::test::object_store::{ensure_head_concurrency, make_test_store_and_state};
     use tempfile::TempDir;
     use url::Url;
 
@@ -1419,7 +1402,7 @@ mod tests {
 
         // (file_sort_order, expected_result)
         let cases = vec![
-            (vec![], Ok(vec![])),
+            (vec![], Ok(Vec::<LexOrdering>::new())),
             // sort expr, but non column
             (
                 vec![vec![
@@ -1430,15 +1413,13 @@ mod tests {
             // ok with one column
             (
                 vec![vec![col("string_col").sort(true, false)]],
-                Ok(vec![LexOrdering::new(
-                    vec![PhysicalSortExpr {
+                Ok(vec![[PhysicalSortExpr {
                     expr: physical_col("string_col", &schema).unwrap(),
                     options: SortOptions {
                         descending: false,
                         nulls_first: false,
                     },
-                }],
-                )
+                }].into(),
                 ])
             ),
             // ok with two columns, different options
@@ -1447,16 +1428,14 @@ mod tests {
                     col("string_col").sort(true, false),
                     col("int_col").sort(false, true),
                 ]],
-                Ok(vec![LexOrdering::new(
-                    vec![
+                Ok(vec![[
                     PhysicalSortExpr::new_default(physical_col("string_col", &schema).unwrap())
                         .asc()
                         .nulls_last(),
                     PhysicalSortExpr::new_default(physical_col("int_col", &schema).unwrap())
                         .desc()
                         .nulls_first()
-                ],
-                )
+                ].into(),
                 ])
             ),
         ];
diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs
index 71686c61a8f7..6a88ad88e5d4 100644
--- a/datafusion/core/src/datasource/listing_table_factory.rs
+++ b/datafusion/core/src/datasource/listing_table_factory.rs
@@ -201,7 +201,7 @@ mod tests {
             order_exprs: vec![],
             unbounded: false,
             options: HashMap::from([("format.has_header".into(), "true".into())]),
-            constraints: Constraints::empty(),
+            constraints: Constraints::default(),
             column_defaults: HashMap::new(),
         };
         let table_provider = factory.create(&state, &cmd).await.unwrap();
@@ -241,7 +241,7 @@ mod tests {
             order_exprs: vec![],
             unbounded: false,
             options,
-            constraints: Constraints::empty(),
+            constraints: Constraints::default(),
             column_defaults: HashMap::new(),
         };
         let table_provider = factory.create(&state, &cmd).await.unwrap();
diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs
index 61d1fee79472..e8b9b8b2384e 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -55,8 +55,7 @@ use crate::physical_plan::{
     displayable, windows, ExecutionPlan, ExecutionPlanProperties, InputOrderMode,
     Partitioning, PhysicalExpr, WindowExpr,
 };
-use datafusion_physical_plan::empty::EmptyExec;
-use datafusion_physical_plan::recursive_query::RecursiveQueryExec;
+use crate::schema_equivalence::schema_satisfied_by;
 
 use arrow::array::{builder::StringBuilder, RecordBatch};
 use arrow::compute::SortOptions;
@@ -69,6 +68,7 @@ use datafusion_common::{
     exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema,
     ScalarValue,
 };
+use datafusion_datasource::file_groups::FileGroup;
 use datafusion_datasource::memory::MemorySourceConfig;
 use datafusion_expr::dml::{CopyTo, InsertOp};
 use datafusion_expr::expr::{
@@ -84,19 +84,21 @@ use datafusion_expr::{
 };
 use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
 use datafusion_physical_expr::expressions::{Column, Literal};
-use datafusion_physical_expr::LexOrdering;
+use datafusion_physical_expr::{
+    create_physical_sort_exprs, LexOrdering, PhysicalSortExpr,
+};
 use datafusion_physical_optimizer::PhysicalOptimizerRule;
+use datafusion_physical_plan::empty::EmptyExec;
 use datafusion_physical_plan::execution_plan::InvariantLevel;
 use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
+use datafusion_physical_plan::recursive_query::RecursiveQueryExec;
 use datafusion_physical_plan::unnest::ListUnnest;
+use sqlparser::ast::NullTreatment;
 
-use crate::schema_equivalence::schema_satisfied_by;
 use async_trait::async_trait;
-use datafusion_datasource::file_groups::FileGroup;
 use futures::{StreamExt, TryStreamExt};
 use itertools::{multiunzip, Itertools};
 use log::{debug, trace};
-use sqlparser::ast::NullTreatment;
 use tokio::sync::Mutex;
 
 /// Physical query planner that converts a `LogicalPlan` to an
@@ -822,13 +824,17 @@ impl DefaultPhysicalPlanner {
             }) => {
                 let physical_input = children.one()?;
                 let input_dfschema = input.as_ref().schema();
-                let sort_expr = create_physical_sort_exprs(
+                let sort_exprs = create_physical_sort_exprs(
                     expr,
                     input_dfschema,
                     session_state.execution_props(),
                 )?;
-                let new_sort =
-                    SortExec::new(sort_expr, physical_input).with_fetch(*fetch);
+                let Some(ordering) = LexOrdering::new(sort_exprs) else {
+                    return internal_err!(
+                        "SortExec requires at least one sort expression"
+                    );
+                };
+                let new_sort = SortExec::new(ordering, physical_input).with_fetch(*fetch);
                 Arc::new(new_sort)
             }
             LogicalPlan::Subquery(_) => todo!(),
@@ -1538,7 +1544,7 @@ pub fn create_window_expr_with_name(
                 name,
                 &physical_args,
                 &partition_by,
-                order_by.as_ref(),
+                &order_by,
                 window_frame,
                 physical_schema,
                 ignore_nulls,
@@ -1566,8 +1572,8 @@ type AggregateExprWithOptionalArgs = (
     Arc<AggregateFunctionExpr>,
     // The filter clause, if any
    Option<Arc<dyn PhysicalExpr>>,
-    // Ordering requirements, if any
-    Option<LexOrdering>,
+    // Expressions in the ORDER BY clause
+    Vec<PhysicalSortExpr>,
 );
 
 /// Create an aggregate expression with a name from a logical expression
@@ -1611,22 +1617,19 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
             let ignore_nulls = null_treatment.unwrap_or(NullTreatment::RespectNulls)
                 == NullTreatment::IgnoreNulls;
 
-            let (agg_expr, filter, order_by) = {
-                let physical_sort_exprs = match order_by {
-                    Some(exprs) => Some(create_physical_sort_exprs(
+            let (agg_expr, filter, order_bys) = {
+                let order_bys = match order_by {
+                    Some(exprs) => create_physical_sort_exprs(
                         exprs,
                         logical_input_schema,
                         execution_props,
-                    )?),
-                    None => None,
+                    )?,
+                    None => vec![],
                 };
 
-                let ordering_reqs: LexOrdering =
-                    physical_sort_exprs.clone().unwrap_or_default();
-
                 let agg_expr =
                     AggregateExprBuilder::new(func.to_owned(), physical_args.to_vec())
-                        .order_by(ordering_reqs)
+                        .order_by(order_bys.clone())
                         .schema(Arc::new(physical_input_schema.to_owned()))
                         .alias(name)
                         .human_display(human_displan)
@@ -1635,10 +1638,10 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
                         .build()
                         .map(Arc::new)?;
 
-                (agg_expr, filter, physical_sort_exprs)
+                (agg_expr, filter, order_bys)
             };
 
-            Ok((agg_expr, filter, order_by))
+            Ok((agg_expr, filter, order_bys))
         }
         other => internal_err!("Invalid aggregate expression '{other:?}'"),
    }
@@ -1674,14 +1677,6 @@ pub fn create_aggregate_expr_and_maybe_filter(
     )
 }
 
-#[deprecated(
-    since = "47.0.0",
-    note = "use datafusion::{create_physical_sort_expr, create_physical_sort_exprs}"
-)]
-pub use datafusion_physical_expr::{
-    create_physical_sort_expr, create_physical_sort_exprs,
-};
-
 impl DefaultPhysicalPlanner {
     /// Handles capturing the various plans for EXPLAIN queries
     ///
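The planner hunk above shows the intended use of the reworked constructor: `LexOrdering::new` takes an iterator of `PhysicalSortExpr` and returns `Option<LexOrdering>`, so the degenerate "no sort expressions" case must be handled explicitly, while statically non-empty arrays convert infallibly via `.into()`. A minimal sketch under those assumptions, with a hypothetical single-column schema:

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_physical_expr::expressions::col;
use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr};

fn build_orderings() -> Result<()> {
    // Hypothetical schema, for illustration only.
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));

    // A statically non-empty array converts with the infallible `.into()`:
    let _ordering: LexOrdering =
        [PhysicalSortExpr::new_default(col("a", &schema)?)].into();

    // A possibly empty collection goes through `LexOrdering::new`, which
    // returns `None` instead of silently producing an empty ordering (cf.
    // the `let Some(ordering) = ... else` guard in the planner hunk above).
    let sort_exprs: Vec<PhysicalSortExpr> = vec![];
    assert!(LexOrdering::new(sort_exprs).is_none());

    Ok(())
}
```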
diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs
index aa36de1e555f..2deca63ea9ea 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -2581,7 +2581,7 @@ async fn test_count_wildcard_on_sort() -> Result<()> {
     | |               TableScan: t1 projection=[b] |
     | physical_plan | ProjectionExec: expr=[b@0 as b, count(*)@1 as count(*)] |
     | |   SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST] |
-    | |     SortExec: expr=[count(Int64(1))@2 ASC NULLS LAST], preserve_partitioning=[true] |
+    | |     SortExec: expr=[count(*)@1 ASC NULLS LAST], preserve_partitioning=[true] |
     | |       ProjectionExec: expr=[b@0 as b, count(Int64(1))@1 as count(*), count(Int64(1))@1 as count(Int64(1))] |
     | |         AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(Int64(1))] |
     | |           CoalesceBatchesExec: target_batch_size=8192 |
diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
index 940b85bb96cf..52a6cc481131 100644
--- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
@@ -17,6 +17,7 @@
 
 use std::sync::Arc;
 
+use super::record_batch_generator::get_supported_types_columns;
 use crate::fuzz_cases::aggregation_fuzzer::query_builder::QueryBuilder;
 use crate::fuzz_cases::aggregation_fuzzer::{
     AggregationFuzzerBuilder, DatasetGeneratorConfig,
@@ -26,39 +27,35 @@ use arrow::array::{
     types::Int64Type, Array, ArrayRef, AsArray, Int32Array, Int64Array, RecordBatch,
     StringArray,
 };
-use arrow::compute::{concat_batches, SortOptions};
+use arrow::compute::concat_batches;
 use arrow::datatypes::DataType;
 use arrow::util::pretty::pretty_format_batches;
 use arrow_schema::{Field, Schema, SchemaRef};
-use datafusion::common::Result;
 use datafusion::datasource::memory::MemorySourceConfig;
 use datafusion::datasource::source::DataSourceExec;
 use datafusion::datasource::MemTable;
-use datafusion::physical_expr::aggregate::AggregateExprBuilder;
-use datafusion::physical_plan::aggregates::{
-    AggregateExec, AggregateMode, PhysicalGroupBy,
-};
-use datafusion::physical_plan::{collect, displayable, ExecutionPlan};
 use datafusion::prelude::{DataFrame, SessionConfig, SessionContext};
 use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
-use datafusion_common::HashMap;
+use datafusion_common::{HashMap, Result};
 use datafusion_common_runtime::JoinSet;
+use datafusion_execution::memory_pool::FairSpillPool;
+use datafusion_execution::runtime_env::RuntimeEnvBuilder;
+use datafusion_execution::TaskContext;
 use datafusion_functions_aggregate::sum::sum_udaf;
+use datafusion_physical_expr::aggregate::AggregateExprBuilder;
 use datafusion_physical_expr::expressions::{col, lit, Column};
 use datafusion_physical_expr::PhysicalSortExpr;
-use datafusion_physical_expr_common::sort_expr::LexOrdering;
+use datafusion_physical_plan::aggregates::{
+    AggregateExec, AggregateMode, PhysicalGroupBy,
+};
+use datafusion_physical_plan::metrics::MetricValue;
 use datafusion_physical_plan::InputOrderMode;
+use datafusion_physical_plan::{collect, displayable, ExecutionPlan};
 use test_utils::{add_empty_batches, StringBatchGenerator};
 
-use datafusion_execution::memory_pool::FairSpillPool;
-use datafusion_execution::runtime_env::RuntimeEnvBuilder;
-use datafusion_execution::TaskContext;
-use datafusion_physical_plan::metrics::MetricValue;
 use rand::rngs::StdRng;
 use rand::{random, rng, Rng, SeedableRng};
 
-use super::record_batch_generator::get_supported_types_columns;
-
 // ========================================================================
 // The new aggregation fuzz tests based on [`AggregationFuzzer`]
 // ========================================================================
@@ -309,13 +306,9 @@ async fn run_aggregate_test(input1: Vec<RecordBatch>, group_by_columns: Vec<&str
     let schema = input1[0].schema();
     let session_config = SessionConfig::new().with_batch_size(50);
     let ctx = SessionContext::new_with_config(session_config);
-    let mut sort_keys = LexOrdering::default();
-    for ordering_col in ["a", "b", "c"] {
-        sort_keys.push(PhysicalSortExpr {
-            expr: col(ordering_col, &schema).unwrap(),
-            options: SortOptions::default(),
-        })
-    }
+    let sort_keys = ["a", "b", "c"].map(|ordering_col| {
+        PhysicalSortExpr::new_default(col(ordering_col, &schema).unwrap())
+    });
 
     let concat_input_record = concat_batches(&schema, &input1).unwrap();
 
@@ -329,7 +322,7 @@ async fn run_aggregate_test(input1: Vec<RecordBatch>, group_by_columns: Vec<&str
     let running_source = DataSourceExec::from_data_source(
         MemorySourceConfig::try_new(&[input1.clone()], schema.clone(), None)
             .unwrap()
-            .try_with_sort_information(vec![sort_keys])
+            .try_with_sort_information(vec![sort_keys.into()])
             .unwrap(),
     );
diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs
index 82bfe199234e..753a74995d8f 100644
--- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs
@@ -149,14 +149,14 @@ impl DatasetGenerator {
         for sort_keys in self.sort_keys_set.clone() {
             let sort_exprs = sort_keys
                 .iter()
-                .map(|key| {
-                    let col_expr = col(key, schema)?;
-                    Ok(PhysicalSortExpr::new_default(col_expr))
-                })
-                .collect::<Result<LexOrdering>>()?;
-            let sorted_batch = sort_batch(&base_batch, sort_exprs.as_ref(), None)?;
-
-            let batches = stagger_batch(sorted_batch);
+                .map(|key| col(key, schema).map(PhysicalSortExpr::new_default))
+                .collect::<Result<Vec<_>>>()?;
+            let batch = if let Some(ordering) = LexOrdering::new(sort_exprs) {
+                sort_batch(&base_batch, &ordering, None)?
+            } else {
+                base_batch.clone()
+            };
+            let batches = stagger_batch(batch);
             let dataset = Dataset::new(batches, sort_keys);
             datasets.push(dataset);
         }
diff --git a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs
index d12d0a130c0c..0d500fd7f441 100644
--- a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs
+++ b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs
@@ -16,13 +16,16 @@
 // under the License.
 
 use crate::fuzz_cases::equivalence::utils::{
-    convert_to_orderings, create_random_schema, create_test_params, create_test_schema_2,
+    create_random_schema, create_test_params, create_test_schema_2,
     generate_table_for_eq_properties, generate_table_for_orderings,
     is_table_same_after_sort, TestScalarUDF,
 };
 use arrow::compute::SortOptions;
 use datafusion_common::Result;
 use datafusion_expr::{Operator, ScalarUDF};
+use datafusion_physical_expr::equivalence::{
+    convert_to_orderings, convert_to_sort_exprs,
+};
 use datafusion_physical_expr::expressions::{col, BinaryExpr};
 use datafusion_physical_expr::ScalarFunctionExpr;
 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
@@ -55,26 +58,25 @@ fn test_ordering_satisfy_with_equivalence_random() -> Result<()> {
             col("f", &test_schema)?,
         ];
 
-        for n_req in 0..=col_exprs.len() {
+        for n_req in 1..=col_exprs.len() {
             for exprs in col_exprs.iter().combinations(n_req) {
-                let requirement = exprs
+                let sort_exprs = exprs
                     .into_iter()
-                    .map(|expr| PhysicalSortExpr {
-                        expr: Arc::clone(expr),
-                        options: SORT_OPTIONS,
-                    })
-                    .collect::<LexOrdering>();
+                    .map(|expr| PhysicalSortExpr::new(Arc::clone(expr), SORT_OPTIONS));
+                let Some(ordering) = LexOrdering::new(sort_exprs) else {
+                    unreachable!("Test should always produce non-degenerate orderings");
+                };
                 let expected = is_table_same_after_sort(
-                    requirement.clone(),
-                    table_data_with_properties.clone(),
+                    ordering.clone(),
+                    &table_data_with_properties,
                 )?;
                 let err_msg = format!(
-                    "Error in test case requirement:{requirement:?}, expected: {expected:?}, eq_properties {eq_properties}"
+                    "Error in test case requirement:{ordering:?}, expected: {expected:?}, eq_properties {eq_properties}"
                 );
                 // Check whether ordering_satisfy API result and
                 // experimental result matches.
                 assert_eq!(
-                    eq_properties.ordering_satisfy(requirement.as_ref()),
+                    eq_properties.ordering_satisfy(ordering)?,
                     expected,
                     "{err_msg}"
                 );
@@ -125,27 +127,26 @@ fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> {
             a_plus_b,
         ];
 
-        for n_req in 0..=exprs.len() {
+        for n_req in 1..=exprs.len() {
             for exprs in exprs.iter().combinations(n_req) {
-                let requirement = exprs
+                let sort_exprs = exprs
                     .into_iter()
-                    .map(|expr| PhysicalSortExpr {
-                        expr: Arc::clone(expr),
-                        options: SORT_OPTIONS,
-                    })
-                    .collect::<LexOrdering>();
+                    .map(|expr| PhysicalSortExpr::new(Arc::clone(expr), SORT_OPTIONS));
+                let Some(ordering) = LexOrdering::new(sort_exprs) else {
+                    unreachable!("Test should always produce non-degenerate orderings");
+                };
                 let expected = is_table_same_after_sort(
-                    requirement.clone(),
-                    table_data_with_properties.clone(),
+                    ordering.clone(),
+                    &table_data_with_properties,
                 )?;
                 let err_msg = format!(
-                    "Error in test case requirement:{requirement:?}, expected: {expected:?}, eq_properties: {eq_properties}",
+                    "Error in test case requirement:{ordering:?}, expected: {expected:?}, eq_properties: {eq_properties}",
                 );
                 // Check whether ordering_satisfy API result and
                 // experimental result matches.
                 assert_eq!(
-                    eq_properties.ordering_satisfy(requirement.as_ref()),
+                    eq_properties.ordering_satisfy(ordering)?,
                     (expected | false),
                     "{err_msg}"
                 );
@@ -300,25 +301,19 @@ fn test_ordering_satisfy_with_equivalence() -> Result<()> {
     ];
 
     for (cols, expected) in requirements {
-        let err_msg = format!("Error in test case:{cols:?}");
-        let required = cols
-            .into_iter()
-            .map(|(expr, options)| PhysicalSortExpr {
-                expr: Arc::clone(expr),
-                options,
-            })
-            .collect::<LexOrdering>();
+        let err_msg = format!("Error in test case: {cols:?}");
+        let sort_exprs = convert_to_sort_exprs(&cols);
+        let Some(ordering) = LexOrdering::new(sort_exprs) else {
+            unreachable!("Test should always produce non-degenerate orderings");
+        };
 
         // Check expected result with experimental result.
         assert_eq!(
-            is_table_same_after_sort(
-                required.clone(),
-                table_data_with_properties.clone()
-            )?,
+            is_table_same_after_sort(ordering.clone(), &table_data_with_properties)?,
            expected
         );
         assert_eq!(
-            eq_properties.ordering_satisfy(required.as_ref()),
+            eq_properties.ordering_satisfy(ordering)?,
             expected,
             "{err_msg}"
         );
@@ -371,7 +366,7 @@ fn test_ordering_satisfy_on_data() -> Result<()> {
         (col_d, option_asc),
     ];
     let ordering = convert_to_orderings(&[ordering])[0].clone();
-    assert!(!is_table_same_after_sort(ordering, batch.clone())?);
+    assert!(!is_table_same_after_sort(ordering, &batch)?);
 
     // [a ASC, b ASC, d ASC] cannot be deduced
     let ordering = vec![
@@ -380,12 +375,12 @@ fn test_ordering_satisfy_on_data() -> Result<()> {
         (col_d, option_asc),
     ];
     let ordering = convert_to_orderings(&[ordering])[0].clone();
-    assert!(!is_table_same_after_sort(ordering, batch.clone())?);
+    assert!(!is_table_same_after_sort(ordering, &batch)?);
 
     // [a ASC, b ASC] can be deduced
     let ordering = vec![(col_a, option_asc), (col_b, option_asc)];
     let ordering = convert_to_orderings(&[ordering])[0].clone();
-    assert!(is_table_same_after_sort(ordering, batch.clone())?);
+    assert!(is_table_same_after_sort(ordering, &batch)?);
 
     Ok(())
 }
diff --git a/datafusion/core/tests/fuzz_cases/equivalence/projection.rs b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs
index 38e66387a02c..d776796a1b75 100644
--- a/datafusion/core/tests/fuzz_cases/equivalence/projection.rs
+++ b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs
@@ -87,10 +87,7 @@ fn project_orderings_random() -> Result<()> {
                 // Since ordered section satisfies schema, we expect
                 // that result will be same after sort (e.g sort was unnecessary).
                 assert!(
-                    is_table_same_after_sort(
-                        ordering.clone(),
-                        projected_batch.clone(),
-                    )?,
+                    is_table_same_after_sort(ordering.clone(), &projected_batch)?,
                     "{}",
                     err_msg
                 );
@@ -147,8 +144,7 @@ fn ordering_satisfy_after_projection_random() -> Result<()> {
             for proj_exprs in proj_exprs.iter().combinations(n_req) {
                 let proj_exprs = proj_exprs
                     .into_iter()
-                    .map(|(expr, name)| (Arc::clone(expr), name.to_string()))
-                    .collect::<Vec<_>>();
+                    .map(|(expr, name)| (Arc::clone(expr), name.to_string()));
                 let (projected_batch, projected_eq) = apply_projection(
                     proj_exprs.clone(),
                     &table_data_with_properties,
@@ -156,33 +152,34 @@ fn ordering_satisfy_after_projection_random() -> Result<()> {
                 )?;
 
                 let projection_mapping =
-                    ProjectionMapping::try_new(&proj_exprs, &test_schema)?;
+                    ProjectionMapping::try_new(proj_exprs, &test_schema)?;
 
                 let projected_exprs = projection_mapping
                     .iter()
-                    .map(|(_source, target)| Arc::clone(target))
+                    .flat_map(|(_, targets)| {
+                        targets.iter().map(|(target, _)| Arc::clone(target))
+                    })
                     .collect::<Vec<_>>();
 
-                for n_req in 0..=projected_exprs.len() {
+                for n_req in 1..=projected_exprs.len() {
                     for exprs in projected_exprs.iter().combinations(n_req) {
-                        let requirement = exprs
-                            .into_iter()
-                            .map(|expr| PhysicalSortExpr {
-                                expr: Arc::clone(expr),
-                                options: SORT_OPTIONS,
-                            })
-                            .collect::<LexOrdering>();
-                        let expected = is_table_same_after_sort(
-                            requirement.clone(),
-                            projected_batch.clone(),
-                        )?;
+                        let sort_exprs = exprs.into_iter().map(|expr| {
+                            PhysicalSortExpr::new(Arc::clone(expr), SORT_OPTIONS)
+                        });
+                        let Some(ordering) = LexOrdering::new(sort_exprs) else {
+                            unreachable!(
+                                "Test should always produce non-degenerate orderings"
+                            );
+                        };
+                        let expected =
+                            is_table_same_after_sort(ordering.clone(), &projected_batch)?;
                         let err_msg = format!(
-                            "Error in test case requirement:{requirement:?}, expected: {expected:?}, eq_properties: {eq_properties}, projected_eq: {projected_eq}, projection_mapping: {projection_mapping:?}"
+                            "Error in test case requirement:{ordering:?}, expected: {expected:?}, eq_properties: {eq_properties}, projected_eq: {projected_eq}, projection_mapping: {projection_mapping:?}"
                         );
                         // Check whether ordering_satisfy API result and
                         // experimental result matches.
                         assert_eq!(
-                            projected_eq.ordering_satisfy(requirement.as_ref()),
+                            projected_eq.ordering_satisfy(ordering)?,
                             expected,
                             "{err_msg}"
                         );
+use std::sync::Arc; + use crate::fuzz_cases::equivalence::utils::{ create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort, TestScalarUDF, }; + use datafusion_common::Result; use datafusion_expr::{Operator, ScalarUDF}; use datafusion_physical_expr::expressions::{col, BinaryExpr}; -use datafusion_physical_expr::{PhysicalExprRef, ScalarFunctionExpr}; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_expr::{LexOrdering, ScalarFunctionExpr}; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; + use itertools::Itertools; -use std::sync::Arc; #[test] fn test_find_longest_permutation_random() -> Result<()> { @@ -47,13 +49,13 @@ fn test_find_longest_permutation_random() -> Result<()> { Arc::clone(&test_fun), vec![col_a], &test_schema, - )?) as PhysicalExprRef; + )?) as _; let a_plus_b = Arc::new(BinaryExpr::new( col("a", &test_schema)?, Operator::Plus, col("b", &test_schema)?, - )) as Arc; + )) as _; let exprs = [ col("a", &test_schema)?, col("b", &test_schema)?, @@ -68,16 +70,16 @@ fn test_find_longest_permutation_random() -> Result<()> { for n_req in 0..=exprs.len() { for exprs in exprs.iter().combinations(n_req) { let exprs = exprs.into_iter().cloned().collect::>(); - let (ordering, indices) = eq_properties.find_longest_permutation(&exprs); + let (ordering, indices) = + eq_properties.find_longest_permutation(&exprs)?; // Make sure that find_longest_permutation return values are consistent let ordering2 = indices .iter() .zip(ordering.iter()) - .map(|(&idx, sort_expr)| PhysicalSortExpr { - expr: Arc::clone(&exprs[idx]), - options: sort_expr.options, + .map(|(&idx, sort_expr)| { + PhysicalSortExpr::new(Arc::clone(&exprs[idx]), sort_expr.options) }) - .collect::(); + .collect::>(); assert_eq!( ordering, ordering2, "indices and lexicographical ordering do not match" @@ -89,11 +91,11 @@ fn test_find_longest_permutation_random() -> Result<()> { assert_eq!(ordering.len(), indices.len(), "{err_msg}"); // Since ordered section satisfies schema, we expect // that result will be same after sort (e.g sort was unnecessary). + let Some(ordering) = LexOrdering::new(ordering) else { + continue; + }; assert!( - is_table_same_after_sort( - ordering.clone(), - table_data_with_properties.clone(), - )?, + is_table_same_after_sort(ordering, &table_data_with_properties)?, "{}", err_msg ); diff --git a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs index a906648f872d..ef80925d9991 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs @@ -15,55 +15,50 @@ // specific language governing permissions and limitations // under the License. 
-use datafusion::physical_plan::expressions::col; -use datafusion::physical_plan::expressions::Column; -use datafusion_physical_expr::{ConstExpr, EquivalenceProperties, PhysicalSortExpr}; use std::any::Any; use std::cmp::Ordering; use std::sync::Arc; use arrow::array::{ArrayRef, Float32Array, Float64Array, RecordBatch, UInt32Array}; -use arrow::compute::SortOptions; -use arrow::compute::{lexsort_to_indices, take_record_batch, SortColumn}; +use arrow::compute::{lexsort_to_indices, take_record_batch, SortColumn, SortOptions}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::utils::{compare_rows, get_row_at_idx}; -use datafusion_common::{exec_err, plan_datafusion_err, DataFusionError, Result}; +use datafusion_common::{exec_err, plan_err, DataFusionError, Result}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; -use datafusion_physical_expr::equivalence::{EquivalenceClass, ProjectionMapping}; +use datafusion_physical_expr::equivalence::{ + convert_to_orderings, EquivalenceClass, ProjectionMapping, +}; +use datafusion_physical_expr::{ConstExpr, EquivalenceProperties}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_plan::expressions::{col, Column}; use itertools::izip; use rand::prelude::*; +/// Projects the input schema based on the given projection mapping. pub fn output_schema( mapping: &ProjectionMapping, input_schema: &Arc, ) -> Result { - // Calculate output schema - let fields: Result> = mapping - .iter() - .map(|(source, target)| { - let name = target - .as_any() - .downcast_ref::() - .ok_or_else(|| plan_datafusion_err!("Expects to have column"))? - .name(); - let field = Field::new( - name, - source.data_type(input_schema)?, - source.nullable(input_schema)?, - ); - - Ok(field) - }) - .collect(); + // Calculate output schema: + let mut fields = vec![]; + for (source, targets) in mapping.iter() { + let data_type = source.data_type(input_schema)?; + let nullable = source.nullable(input_schema)?; + for (target, _) in targets.iter() { + let Some(column) = target.as_any().downcast_ref::() else { + return plan_err!("Expects to have column"); + }; + fields.push(Field::new(column.name(), data_type.clone(), nullable)); + } + } let output_schema = Arc::new(Schema::new_with_metadata( - fields?, + fields, input_schema.metadata().clone(), )); @@ -100,9 +95,9 @@ pub fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperti let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); // Define a and f are aliases - eq_properties.add_equal_conditions(col_a, col_f)?; + eq_properties.add_equal_conditions(Arc::clone(col_a), Arc::clone(col_f))?; // Column e has constant value. 
- eq_properties = eq_properties.with_constants([ConstExpr::from(col_e)]); + eq_properties.add_constants([ConstExpr::from(Arc::clone(col_e))])?; // Randomly order columns for sorting let mut rng = StdRng::seed_from_u64(seed); @@ -114,18 +109,18 @@ pub fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperti }; while !remaining_exprs.is_empty() { - let n_sort_expr = rng.random_range(0..remaining_exprs.len() + 1); + let n_sort_expr = rng.random_range(1..remaining_exprs.len() + 1); remaining_exprs.shuffle(&mut rng); - let ordering = remaining_exprs - .drain(0..n_sort_expr) - .map(|expr| PhysicalSortExpr { - expr: Arc::clone(expr), - options: options_asc, - }) - .collect(); + let ordering = + remaining_exprs + .drain(0..n_sort_expr) + .map(|expr| PhysicalSortExpr { + expr: Arc::clone(expr), + options: options_asc, + }); - eq_properties.add_new_orderings([ordering]); + eq_properties.add_ordering(ordering); } Ok((test_schema, eq_properties)) @@ -133,12 +128,12 @@ pub fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperti // Apply projection to the input_data, return projected equivalence properties and record batch pub fn apply_projection( - proj_exprs: Vec<(Arc, String)>, + proj_exprs: impl IntoIterator, String)>, input_data: &RecordBatch, input_eq_properties: &EquivalenceProperties, ) -> Result<(RecordBatch, EquivalenceProperties)> { let input_schema = input_data.schema(); - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + let projection_mapping = ProjectionMapping::try_new(proj_exprs, &input_schema)?; let output_schema = output_schema(&projection_mapping, &input_schema)?; let num_rows = input_data.num_rows(); @@ -168,49 +163,49 @@ fn add_equal_conditions_test() -> Result<()> { ])); let mut eq_properties = EquivalenceProperties::new(schema); - let col_a_expr = Arc::new(Column::new("a", 0)) as Arc; - let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; - let col_c_expr = Arc::new(Column::new("c", 2)) as Arc; - let col_x_expr = Arc::new(Column::new("x", 3)) as Arc; - let col_y_expr = Arc::new(Column::new("y", 4)) as Arc; + let col_a = Arc::new(Column::new("a", 0)) as _; + let col_b = Arc::new(Column::new("b", 1)) as _; + let col_c = Arc::new(Column::new("c", 2)) as _; + let col_x = Arc::new(Column::new("x", 3)) as _; + let col_y = Arc::new(Column::new("y", 4)) as _; // a and b are aliases - eq_properties.add_equal_conditions(&col_a_expr, &col_b_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_a), Arc::clone(&col_b))?; assert_eq!(eq_properties.eq_group().len(), 1); // This new entry is redundant, size shouldn't increase - eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_b), Arc::clone(&col_a))?; assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 2); - assert!(eq_groups.contains(&col_a_expr)); - assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_a)); + assert!(eq_groups.contains(&col_b)); // b and c are aliases. 
Existing equivalence class should expand, // however there shouldn't be any new equivalence class - eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_b), Arc::clone(&col_c))?; assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 3); - assert!(eq_groups.contains(&col_a_expr)); - assert!(eq_groups.contains(&col_b_expr)); - assert!(eq_groups.contains(&col_c_expr)); + assert!(eq_groups.contains(&col_a)); + assert!(eq_groups.contains(&col_b)); + assert!(eq_groups.contains(&col_c)); // This is a new set of equality. Hence equivalent class count should be 2. - eq_properties.add_equal_conditions(&col_x_expr, &col_y_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_x), Arc::clone(&col_y))?; assert_eq!(eq_properties.eq_group().len(), 2); // This equality bridges distinct equality sets. // Hence equivalent class count should decrease from 2 to 1. - eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_x), Arc::clone(&col_a))?; assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 5); - assert!(eq_groups.contains(&col_a_expr)); - assert!(eq_groups.contains(&col_b_expr)); - assert!(eq_groups.contains(&col_c_expr)); - assert!(eq_groups.contains(&col_x_expr)); - assert!(eq_groups.contains(&col_y_expr)); + assert!(eq_groups.contains(&col_a)); + assert!(eq_groups.contains(&col_b)); + assert!(eq_groups.contains(&col_c)); + assert!(eq_groups.contains(&col_x)); + assert!(eq_groups.contains(&col_y)); Ok(()) } @@ -226,7 +221,7 @@ fn add_equal_conditions_test() -> Result<()> { /// already sorted according to `required_ordering` to begin with. 
pub fn is_table_same_after_sort( mut required_ordering: LexOrdering, - batch: RecordBatch, + batch: &RecordBatch, ) -> Result { // Clone the original schema and columns let original_schema = batch.schema(); @@ -327,7 +322,7 @@ pub fn create_test_params() -> Result<(SchemaRef, EquivalenceProperties)> { let col_f = &col("f", &test_schema)?; let col_g = &col("g", &test_schema)?; let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); - eq_properties.add_equal_conditions(col_a, col_c)?; + eq_properties.add_equal_conditions(Arc::clone(col_a), Arc::clone(col_c))?; let option_asc = SortOptions { descending: false, @@ -350,7 +345,7 @@ pub fn create_test_params() -> Result<(SchemaRef, EquivalenceProperties)> { ], ]; let orderings = convert_to_orderings(&orderings); - eq_properties.add_new_orderings(orderings); + eq_properties.add_orderings(orderings); Ok((test_schema, eq_properties)) } @@ -376,7 +371,7 @@ pub fn generate_table_for_eq_properties( // Fill constant columns for constant in eq_properties.constants() { - let col = constant.expr().as_any().downcast_ref::().unwrap(); + let col = constant.expr.as_any().downcast_ref::().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); let arr = Arc::new(Float64Array::from_iter_values(vec![0 as f64; n_elem])) as ArrayRef; @@ -461,7 +456,7 @@ pub fn generate_table_for_orderings( let batch = RecordBatch::try_from_iter(arrays)?; // Sort batch according to first ordering expression - let sort_columns = get_sort_columns(&batch, orderings[0].as_ref())?; + let sort_columns = get_sort_columns(&batch, &orderings[0])?; let sort_indices = lexsort_to_indices(&sort_columns, None)?; let mut batch = take_record_batch(&batch, &sort_indices)?; @@ -494,29 +489,6 @@ pub fn generate_table_for_orderings( Ok(batch) } -// Convert each tuple to PhysicalSortExpr -pub fn convert_to_sort_exprs( - in_data: &[(&Arc, SortOptions)], -) -> LexOrdering { - in_data - .iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: Arc::clone(*expr), - options: *options, - }) - .collect() -} - -// Convert each inner tuple to PhysicalSortExpr -pub fn convert_to_orderings( - orderings: &[Vec<(&Arc, SortOptions)>], -) -> Vec { - orderings - .iter() - .map(|sort_exprs| convert_to_sort_exprs(sort_exprs)) - .collect() -} - // Utility function to generate random f64 array fn generate_random_f64_array( n_elems: usize, diff --git a/datafusion/core/tests/fuzz_cases/merge_fuzz.rs b/datafusion/core/tests/fuzz_cases/merge_fuzz.rs index 92f375525066..b92dec64e3f1 100644 --- a/datafusion/core/tests/fuzz_cases/merge_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/merge_fuzz.rs @@ -31,7 +31,6 @@ use datafusion::physical_plan::{ sorts::sort_preserving_merge::SortPreservingMergeExec, }; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; use test_utils::{batches_to_vec, partitions_to_sorted_vec, stagger_batch_with_seed}; @@ -109,13 +108,14 @@ async fn run_merge_test(input: Vec>) { .expect("at least one batch"); let schema = first_batch.schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { + let sort = [PhysicalSortExpr { expr: col("x", &schema).unwrap(), options: SortOptions { descending: false, nulls_first: true, }, - }]); + }] + .into(); let exec = MemorySourceConfig::try_new_exec(&input, schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, exec)); diff --git a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs index 
703b8715821a..27becaa96f62 100644 --- a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs @@ -232,18 +232,15 @@ impl SortTest { .expect("at least one batch"); let schema = first_batch.schema(); - let sort_ordering = LexOrdering::new( - self.sort_columns - .iter() - .map(|c| PhysicalSortExpr { - expr: col(c, &schema).unwrap(), - options: SortOptions { - descending: false, - nulls_first: true, - }, - }) - .collect(), - ); + let sort_ordering = + LexOrdering::new(self.sort_columns.iter().map(|c| PhysicalSortExpr { + expr: col(c, &schema).unwrap(), + options: SortOptions { + descending: false, + nulls_first: true, + }, + })) + .unwrap(); let exec = MemorySourceConfig::try_new_exec(&input, schema, None).unwrap(); let sort = Arc::new(SortExec::new(sort_ordering, exec)); diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs index cf6867758edc..99b20790fc46 100644 --- a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs @@ -23,6 +23,8 @@ mod sp_repartition_fuzz_tests { use arrow::compute::{concat_batches, lexsort, SortColumn, SortOptions}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion::datasource::memory::MemorySourceConfig; + use datafusion::datasource::source::DataSourceExec; use datafusion::physical_plan::{ collect, metrics::{BaselineMetrics, ExecutionPlanMetricsSet}, @@ -34,19 +36,16 @@ mod sp_repartition_fuzz_tests { }; use datafusion::prelude::SessionContext; use datafusion_common::Result; - use datafusion_execution::{ - config::SessionConfig, memory_pool::MemoryConsumer, SendableRecordBatchStream, - }; - use datafusion_physical_expr::{ - equivalence::{EquivalenceClass, EquivalenceProperties}, - expressions::{col, Column}, - ConstExpr, PhysicalExpr, PhysicalSortExpr, + use datafusion_execution::{config::SessionConfig, memory_pool::MemoryConsumer}; + use datafusion_physical_expr::equivalence::{ + EquivalenceClass, EquivalenceProperties, }; + use datafusion_physical_expr::expressions::{col, Column}; + use datafusion_physical_expr::ConstExpr; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use test_utils::add_empty_batches; - use datafusion::datasource::memory::MemorySourceConfig; - use datafusion::datasource::source::DataSourceExec; - use datafusion_physical_expr_common::sort_expr::LexOrdering; use itertools::izip; use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng}; @@ -80,9 +79,9 @@ mod sp_repartition_fuzz_tests { let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); // Define a and f are aliases - eq_properties.add_equal_conditions(col_a, col_f)?; + eq_properties.add_equal_conditions(Arc::clone(col_a), Arc::clone(col_f))?; // Column e has constant value. 
- eq_properties = eq_properties.with_constants([ConstExpr::from(col_e)]); + eq_properties.add_constants([ConstExpr::from(Arc::clone(col_e))])?; // Randomly order columns for sorting let mut rng = StdRng::seed_from_u64(seed); @@ -94,18 +93,18 @@ }; while !remaining_exprs.is_empty() { - let n_sort_expr = rng.random_range(0..remaining_exprs.len() + 1); + let n_sort_expr = rng.random_range(1..remaining_exprs.len() + 1); remaining_exprs.shuffle(&mut rng); - let ordering = remaining_exprs - .drain(0..n_sort_expr) - .map(|expr| PhysicalSortExpr { - expr: expr.clone(), - options: options_asc, - }) - .collect(); + let ordering = + remaining_exprs + .drain(0..n_sort_expr) + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: options_asc, + }); - eq_properties.add_new_orderings([ordering]); + eq_properties.add_ordering(ordering); } Ok((test_schema, eq_properties)) @@ -151,7 +150,7 @@ // Fill constant columns for constant in eq_properties.constants() { - let col = constant.expr().as_any().downcast_ref::().unwrap(); + let col = constant.expr.as_any().downcast_ref::().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); let arr = Arc::new(UInt64Array::from_iter_values(vec![0; n_elem])) as ArrayRef; @@ -227,21 +226,21 @@ let table_data_with_properties = generate_table_for_eq_properties(&eq_properties, N_ELEM, N_DISTINCT)?; let schema = table_data_with_properties.schema(); - let streams: Vec = (0..N_PARTITION) + let streams = (0..N_PARTITION) .map(|_idx| { let batch = table_data_with_properties.clone(); Box::pin(RecordBatchStreamAdapter::new( schema.clone(), futures::stream::once(async { Ok(batch) }), - )) as SendableRecordBatchStream + )) as _ }) .collect::>(); - // Returns concatenated version of the all available orderings - let exprs = eq_properties - .oeq_class() - .output_ordering() - .unwrap_or_default(); + // Returns a concatenated version of all available orderings: + let Some(exprs) = eq_properties.oeq_class().output_ordering() else { + // We should always have an ordering due to the way we generate the schema: + unreachable!("No ordering found in eq_properties: {:?}", eq_properties); + }; let context = SessionContext::new().task_ctx(); let mem_reservation = @@ -347,20 +346,16 @@ let schema = input1[0].schema(); let session_config = SessionConfig::new().with_batch_size(50); let ctx = SessionContext::new_with_config(session_config); - let mut sort_keys = LexOrdering::default(); - for ordering_col in ["a", "b", "c"] { - sort_keys.push(PhysicalSortExpr { - expr: col(ordering_col, &schema).unwrap(), - options: SortOptions::default(), - }) - } + let sort_keys = ["a", "b", "c"].map(|ordering_col| { + PhysicalSortExpr::new_default(col(ordering_col, &schema).unwrap()) + }); let concat_input_record = concat_batches(&schema, &input1).unwrap(); let running_source = Arc::new( - MemorySourceConfig::try_new(&[input1.clone()], schema.clone(), None) + MemorySourceConfig::try_new(&[input1], schema.clone(), None) .unwrap() - .try_with_sort_information(vec![sort_keys.clone()]) + .try_with_sort_information(vec![sort_keys.clone().into()]) .unwrap(), ); let running_source = Arc::new(DataSourceExec::new(running_source)); @@ -381,7 +376,7 @@ sort_preserving_repartition_exec_hash(intermediate, hash_exprs.clone()) }; - let final_plan = sort_preserving_merge_exec(sort_keys.clone(), intermediate); + let final_plan = 
sort_preserving_merge_exec(sort_keys.into(), intermediate); let task_ctx = ctx.task_ctx(); let collected_running = collect(final_plan, task_ctx.clone()).await.unwrap(); @@ -428,10 +423,9 @@ } fn sort_preserving_merge_exec( - sort_exprs: impl IntoIterator, + sort_exprs: LexOrdering, input: Arc, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); Arc::new(SortPreservingMergeExec::new(sort_exprs, input)) } diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 5bd2e457b42a..316d3ba5a926 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -252,7 +252,6 @@ async fn bounded_window_causal_non_causal() -> Result<()> { ]; let partitionby_exprs = vec![]; - let orderby_exprs = LexOrdering::default(); // Window frame starts with "UNBOUNDED PRECEDING": let start_bound = WindowFrameBound::Preceding(ScalarValue::UInt64(None)); @@ -285,7 +284,7 @@ fn_name.to_string(), &args, &partitionby_exprs, - orderby_exprs.as_ref(), + &[], Arc::new(window_frame), &extended_schema, false, @@ -594,7 +593,7 @@ async fn run_window_test( let ctx = SessionContext::new_with_config(session_config); let (window_fn, args, fn_name) = get_random_function(&schema, &mut rng, is_linear); let window_frame = get_random_window_frame(&mut rng, is_linear); - let mut orderby_exprs = LexOrdering::default(); + let mut orderby_exprs = vec![]; for column in &orderby_columns { orderby_exprs.push(PhysicalSortExpr { expr: col(column, &schema)?, @@ -602,13 +601,13 @@ }) } if orderby_exprs.len() > 1 && !window_frame.can_accept_multi_orderby() { - orderby_exprs = LexOrdering::new(orderby_exprs[0..1].to_vec()); + orderby_exprs.truncate(1); } let mut partitionby_exprs = vec![]; for column in &partition_by_columns { partitionby_exprs.push(col(column, &schema)?); } - let mut sort_keys = LexOrdering::default(); + let mut sort_keys = vec![]; for partition_by_expr in &partitionby_exprs { sort_keys.push(PhysicalSortExpr { expr: partition_by_expr.clone(), @@ -622,7 +621,7 @@ } let concat_input_record = concat_batches(&schema, &input1)?; - let source_sort_keys = LexOrdering::new(vec![ + let source_sort_keys: LexOrdering = [ PhysicalSortExpr { expr: col("a", &schema)?, options: Default::default(), @@ -635,7 +634,8 @@ expr: col("c", &schema)?, options: Default::default(), }, - ]); + ] + .into(); let mut exec1 = DataSourceExec::from_data_source( MemorySourceConfig::try_new(&[vec![concat_input_record]], schema.clone(), None)? .try_with_sort_information(vec![source_sort_keys.clone()])?, @@ -643,7 +643,9 @@ // Table is ordered according to ORDER BY a, b, c. In the linear test we use PARTITION BY b, ORDER BY a. // For WindowAggExec to produce the correct result, it needs the table to be ordered by b, a. Hence, add a sort. if is_linear { - exec1 = Arc::new(SortExec::new(sort_keys, exec1)) as _; + if let Some(ordering) = LexOrdering::new(sort_keys) { + exec1 = Arc::new(SortExec::new(ordering, exec1)) as _; + } } let extended_schema = schema_add_window_field(&args, &schema, &window_fn, &fn_name)?; @@ -654,7 +656,7 @@ fn_name.clone(), &args, &partitionby_exprs, - orderby_exprs.as_ref(), + &orderby_exprs.clone(), Arc::new(window_frame.clone()), &extended_schema, false, @@ -663,8 +665,8 @@ false, )?) 
as _; let exec2 = DataSourceExec::from_data_source( - MemorySourceConfig::try_new(&[input1.clone()], schema.clone(), None)? - .try_with_sort_information(vec![source_sort_keys.clone()])?, + MemorySourceConfig::try_new(&[input1], schema, None)? + .try_with_sort_information(vec![source_sort_keys])?, ); let running_window_exec = Arc::new(BoundedWindowAggExec::try_new( vec![create_window_expr( @@ -672,7 +674,7 @@ async fn run_window_test( fn_name, &args, &partitionby_exprs, - orderby_exprs.as_ref(), + &orderby_exprs, Arc::new(window_frame.clone()), &extended_schema, false, diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index 7695cc0969d8..4ac779d9c220 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -890,16 +890,11 @@ impl Scenario { descending: false, nulls_first: false, }; - let sort_information = vec![LexOrdering::new(vec![ - PhysicalSortExpr { - expr: col("a", &schema).unwrap(), - options, - }, - PhysicalSortExpr { - expr: col("b", &schema).unwrap(), - options, - }, - ])]; + let sort_information = vec![[ + PhysicalSortExpr::new(col("a", &schema).unwrap(), options), + PhysicalSortExpr::new(col("b", &schema).unwrap(), options), + ] + .into()]; let table = SortedTableProvider::new(batches, sort_information); Arc::new(table) diff --git a/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs index 568be0d18f24..de8f7b36a3a7 100644 --- a/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs @@ -136,7 +136,7 @@ fn aggregations_not_combined() -> datafusion_common::Result<()> { let plan = final_aggregate_exec( repartition_exec(partial_aggregate_exec( - parquet_exec(&schema), + parquet_exec(schema.clone()), PhysicalGroupBy::default(), aggr_expr.clone(), )), @@ -157,7 +157,7 @@ fn aggregations_not_combined() -> datafusion_common::Result<()> { let plan = final_aggregate_exec( partial_aggregate_exec( - parquet_exec(&schema), + parquet_exec(schema), PhysicalGroupBy::default(), aggr_expr1, ), @@ -183,7 +183,7 @@ fn aggregations_combined() -> datafusion_common::Result<()> { let plan = final_aggregate_exec( partial_aggregate_exec( - parquet_exec(&schema), + parquet_exec(schema), PhysicalGroupBy::default(), aggr_expr.clone(), ), @@ -215,11 +215,8 @@ fn aggregations_with_group_combined() -> datafusion_common::Result<()> { vec![(col("c", &schema)?, "c".to_string())]; let partial_group_by = PhysicalGroupBy::new_single(groups); - let partial_agg = partial_aggregate_exec( - parquet_exec(&schema), - partial_group_by, - aggr_expr.clone(), - ); + let partial_agg = + partial_aggregate_exec(parquet_exec(schema), partial_group_by, aggr_expr.clone()); let groups: Vec<(Arc, String)> = vec![(col("c", &partial_agg.schema())?, "c".to_string())]; @@ -245,11 +242,8 @@ fn aggregations_with_limit_combined() -> datafusion_common::Result<()> { vec![(col("c", &schema)?, "c".to_string())]; let partial_group_by = PhysicalGroupBy::new_single(groups); - let partial_agg = partial_aggregate_exec( - parquet_exec(&schema), - partial_group_by, - aggr_expr.clone(), - ); + let partial_agg = + partial_aggregate_exec(parquet_exec(schema), partial_group_by, aggr_expr.clone()); let groups: Vec<(Arc, String)> = vec![(col("c", &partial_agg.schema())?, "c".to_string())]; diff --git a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs 
b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs index 4034800c30cb..705450086525 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs @@ -20,11 +20,10 @@ use std::ops::Deref; use std::sync::Arc; use crate::physical_optimizer::test_utils::{ - check_integrity, coalesce_partitions_exec, repartition_exec, schema, - sort_merge_join_exec, sort_preserving_merge_exec, -}; -use crate::physical_optimizer::test_utils::{ - parquet_exec_with_sort, parquet_exec_with_stats, + check_integrity, coalesce_partitions_exec, parquet_exec_with_sort, + parquet_exec_with_stats, repartition_exec, schema, sort_exec, + sort_exec_with_preserve_partitioning, sort_merge_join_exec, + sort_preserving_merge_exec, union_exec, }; use arrow::array::{RecordBatch, UInt64Array, UInt8Array}; @@ -44,12 +43,11 @@ use datafusion_common::ScalarValue; use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_expr::{JoinType, Operator}; -use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; -use datafusion_physical_expr::PhysicalExpr; -use datafusion_physical_expr::{ - expressions::binary, expressions::lit, LexOrdering, PhysicalSortExpr, +use datafusion_physical_expr::expressions::{binary, lit, BinaryExpr, Column, Literal}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{ + LexOrdering, OrderingRequirements, PhysicalSortExpr, }; -use datafusion_physical_expr_common::sort_expr::LexRequirement; use datafusion_physical_optimizer::enforce_distribution::*; use datafusion_physical_optimizer::enforce_sorting::EnforceSorting; use datafusion_physical_optimizer::output_requirements::OutputRequirements; @@ -65,13 +63,11 @@ use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::joins::utils::JoinOn; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::projection::ProjectionExec; -use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::union::UnionExec; -use datafusion_physical_plan::ExecutionPlanProperties; -use datafusion_physical_plan::PlanProperties; use datafusion_physical_plan::{ - get_plan_string, DisplayAs, DisplayFormatType, Statistics, + get_plan_string, DisplayAs, DisplayFormatType, ExecutionPlanProperties, + PlanProperties, Statistics, }; /// Models operators like BoundedWindowExec that require an input @@ -147,12 +143,8 @@ impl ExecutionPlan for SortRequiredExec { } // model that it requires the output ordering of its input - fn required_input_ordering(&self) -> Vec> { - if self.expr.is_empty() { - vec![None] - } else { - vec![Some(LexRequirement::from(self.expr.clone()))] - } + fn required_input_ordering(&self) -> Vec> { + vec![Some(OrderingRequirements::from(self.expr.clone()))] } fn with_new_children( @@ -181,7 +173,7 @@ impl ExecutionPlan for SortRequiredExec { } fn parquet_exec() -> Arc { - parquet_exec_with_sort(vec![]) + parquet_exec_with_sort(schema(), vec![]) } fn parquet_exec_multiple() -> Arc { @@ -327,16 +319,6 @@ fn filter_exec(input: Arc) -> Arc { Arc::new(FilterExec::try_new(predicate, input).unwrap()) } -fn sort_exec( - sort_exprs: LexOrdering, - input: Arc, - preserve_partitioning: bool, -) -> Arc { - let new_sort = SortExec::new(sort_exprs, input) - 
.with_preserve_partitioning(preserve_partitioning); - Arc::new(new_sort) -} - fn limit_exec(input: Arc) -> Arc { Arc::new(GlobalLimitExec::new( Arc::new(LocalLimitExec::new(input, 100)), @@ -345,10 +327,6 @@ fn limit_exec(input: Arc) -> Arc { )) } -fn union_exec(input: Vec>) -> Arc { - Arc::new(UnionExec::new(input)) -} - fn sort_required_exec_with_req( input: Arc, sort_exprs: LexOrdering, @@ -1742,10 +1720,11 @@ fn smj_join_key_ordering() -> Result<()> { fn merge_does_not_need_sort() -> Result<()> { // see https://github.com/apache/datafusion/issues/4331 let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); // Scan some sorted parquet files let exec = parquet_exec_multiple_sorted(vec![sort_key.clone()]); @@ -1942,11 +1921,12 @@ fn repartition_unsorted_limit() -> Result<()> { #[test] fn repartition_sorted_limit() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); - let plan = limit_exec(sort_exec(sort_key, parquet_exec(), false)); + }] + .into(); + let plan = limit_exec(sort_exec(sort_key, parquet_exec())); let expected = &[ "GlobalLimitExec: skip=0, fetch=100", @@ -1966,12 +1946,13 @@ fn repartition_sorted_limit() -> Result<()> { #[test] fn repartition_sorted_limit_with_filter() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let plan = sort_required_exec_with_req( - filter_exec(sort_exec(sort_key.clone(), parquet_exec(), false)), + filter_exec(sort_exec(sort_key.clone(), parquet_exec())), sort_key, ); @@ -2049,10 +2030,11 @@ fn repartition_ignores_union() -> Result<()> { fn repartition_through_sort_preserving_merge() -> Result<()> { // sort preserving merge with non-sorted input let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let plan = sort_preserving_merge_exec(sort_key, parquet_exec()); // need resort as the data was not sorted correctly @@ -2072,10 +2054,11 @@ fn repartition_through_sort_preserving_merge() -> Result<()> { fn repartition_ignores_sort_preserving_merge() -> Result<()> { // sort preserving merge already sorted input, let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let plan = sort_preserving_merge_exec( sort_key.clone(), parquet_exec_multiple_sorted(vec![sort_key]), @@ -2107,11 +2090,15 @@ fn repartition_ignores_sort_preserving_merge() -> Result<()> { fn repartition_ignores_sort_preserving_merge_with_union() -> Result<()> { // 2 sorted parquet files unioned (partitions are concatenated, sort is preserved) let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", 
&schema)?, options: SortOptions::default(), - }]); - let input = union_exec(vec![parquet_exec_with_sort(vec![sort_key.clone()]); 2]); + }] + .into(); + let input = union_exec(vec![ + parquet_exec_with_sort(schema, vec![sort_key.clone()]); + 2 + ]); let plan = sort_preserving_merge_exec(sort_key, input); // Test: run EnforceDistribution, then EnforceSort. @@ -2145,12 +2132,13 @@ fn repartition_does_not_destroy_sort() -> Result<()> { // SortRequired // Parquet(sorted) let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("d", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("d", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let plan = sort_required_exec_with_req( - filter_exec(parquet_exec_with_sort(vec![sort_key.clone()])), + filter_exec(parquet_exec_with_sort(schema, vec![sort_key.clone()])), sort_key, ); @@ -2183,12 +2171,13 @@ fn repartition_does_not_destroy_sort_more_complex() -> Result<()> { // Parquet(unsorted) let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input1 = sort_required_exec_with_req( - parquet_exec_with_sort(vec![sort_key.clone()]), + parquet_exec_with_sort(schema, vec![sort_key.clone()]), sort_key, ); let input2 = filter_exec(parquet_exec()); @@ -2219,18 +2208,19 @@ fn repartition_transitively_with_projection() -> Result<()> { let schema = schema(); let proj_exprs = vec![( Arc::new(BinaryExpr::new( - col("a", &schema).unwrap(), + col("a", &schema)?, Operator::Plus, - col("b", &schema).unwrap(), - )) as Arc, + col("b", &schema)?, + )) as _, "sum".to_string(), )]; // non sorted input let proj = Arc::new(ProjectionExec::try_new(proj_exprs, parquet_exec())?); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("sum", &proj.schema()).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("sum", &proj.schema())?, options: SortOptions::default(), - }]); + }] + .into(); let plan = sort_preserving_merge_exec(sort_key, proj); // Test: run EnforceDistribution, then EnforceSort. 
@@ -2262,10 +2252,11 @@ fn repartition_transitively_with_projection() -> Result<()> { #[test] fn repartition_ignores_transitively_with_projection() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let alias = vec![ ("a".to_string(), "a".to_string()), ("b".to_string(), "b".to_string()), @@ -2297,10 +2288,11 @@ fn repartition_ignores_transitively_with_projection() -> Result<()> { #[test] fn repartition_transitively_past_sort_with_projection() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let alias = vec![ ("a".to_string(), "a".to_string()), ("b".to_string(), "b".to_string()), @@ -2308,10 +2300,9 @@ fn repartition_transitively_past_sort_with_projection() -> Result<()> { ]; let plan = sort_preserving_merge_exec( sort_key.clone(), - sort_exec( + sort_exec_with_preserve_partitioning( sort_key, projection_exec_with_alias(parquet_exec(), alias), - true, ), ); @@ -2332,11 +2323,12 @@ fn repartition_transitively_past_sort_with_projection() -> Result<()> { #[test] fn repartition_transitively_past_sort_with_filter() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); - let plan = sort_exec(sort_key, filter_exec(parquet_exec()), false); + }] + .into(); + let plan = sort_exec(sort_key, filter_exec(parquet_exec())); // Test: run EnforceDistribution, then EnforceSort. let expected = &[ @@ -2368,10 +2360,11 @@ fn repartition_transitively_past_sort_with_filter() -> Result<()> { #[cfg(feature = "parquet")] fn repartition_transitively_past_sort_with_projection_and_filter() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let plan = sort_exec( sort_key, projection_exec_with_alias( @@ -2382,7 +2375,6 @@ fn repartition_transitively_past_sort_with_projection_and_filter() -> Result<()> ("c".to_string(), "c".to_string()), ], ), - false, ); // Test: run EnforceDistribution, then EnforceSort. 
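// The dominant rewrite in the hunks above and below: a fixed-size array of
// `PhysicalSortExpr` converts directly into a `LexOrdering` via `.into()`,
// replacing `LexOrdering::new(vec![...])`. A minimal sketch of the pattern;
// `expr_a`/`expr_b` are placeholders standing in for real column expressions.
use std::sync::Arc;
use arrow::compute::SortOptions;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};

fn make_ordering(
    expr_a: Arc<dyn PhysicalExpr>,
    expr_b: Arc<dyn PhysicalExpr>,
) -> LexOrdering {
    [
        PhysicalSortExpr::new(expr_a, SortOptions::default()),
        // `new_default` is shorthand for default `SortOptions`, as used above:
        PhysicalSortExpr::new_default(expr_b),
    ]
    .into()
}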
@@ -2453,10 +2445,11 @@ fn parallelization_single_partition() -> Result<()> { #[test] fn parallelization_multiple_files() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let plan = filter_exec(parquet_exec_multiple_sorted(vec![sort_key.clone()])); let plan = sort_required_exec_with_req(plan, sort_key); @@ -2643,12 +2636,13 @@ fn parallelization_two_partitions_into_four() -> Result<()> { #[test] fn parallelization_sorted_limit() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); - let plan_parquet = limit_exec(sort_exec(sort_key.clone(), parquet_exec(), false)); - let plan_csv = limit_exec(sort_exec(sort_key, csv_exec(), false)); + }] + .into(); + let plan_parquet = limit_exec(sort_exec(sort_key.clone(), parquet_exec())); + let plan_csv = limit_exec(sort_exec(sort_key, csv_exec())); let test_config = TestConfig::default(); @@ -2686,16 +2680,14 @@ fn parallelization_sorted_limit() -> Result<()> { #[test] fn parallelization_limit_with_filter() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); - let plan_parquet = limit_exec(filter_exec(sort_exec( - sort_key.clone(), - parquet_exec(), - false, - ))); - let plan_csv = limit_exec(filter_exec(sort_exec(sort_key, csv_exec(), false))); + }] + .into(); + let plan_parquet = + limit_exec(filter_exec(sort_exec(sort_key.clone(), parquet_exec()))); + let plan_csv = limit_exec(filter_exec(sort_exec(sort_key, csv_exec()))); let test_config = TestConfig::default(); @@ -2840,14 +2832,15 @@ fn parallelization_union_inputs() -> Result<()> { #[test] fn parallelization_prior_to_sort_preserving_merge() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); // sort preserving merge already sorted input, let plan_parquet = sort_preserving_merge_exec( sort_key.clone(), - parquet_exec_with_sort(vec![sort_key.clone()]), + parquet_exec_with_sort(schema, vec![sort_key.clone()]), ); let plan_csv = sort_preserving_merge_exec(sort_key.clone(), csv_exec_with_sort(vec![sort_key])); @@ -2881,13 +2874,17 @@ fn parallelization_prior_to_sort_preserving_merge() -> Result<()> { #[test] fn parallelization_sort_preserving_merge_with_union() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); // 2 sorted parquet files unioned (partitions are concatenated, sort is preserved) let input_parquet = - union_exec(vec![parquet_exec_with_sort(vec![sort_key.clone()]); 2]); + union_exec(vec![ + parquet_exec_with_sort(schema, vec![sort_key.clone()]); + 2 + ]); let input_csv = union_exec(vec![csv_exec_with_sort(vec![sort_key.clone()]); 2]); let plan_parquet = 
sort_preserving_merge_exec(sort_key.clone(), input_parquet); let plan_csv = sort_preserving_merge_exec(sort_key, input_csv); @@ -2954,14 +2951,15 @@ fn parallelization_sort_preserving_merge_with_union() -> Result<()> { #[test] fn parallelization_does_not_benefit() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); // SortRequired // Parquet(sorted) let plan_parquet = sort_required_exec_with_req( - parquet_exec_with_sort(vec![sort_key.clone()]), + parquet_exec_with_sort(schema, vec![sort_key.clone()]), sort_key.clone(), ); let plan_csv = @@ -2999,22 +2997,26 @@ fn parallelization_ignores_transitively_with_projection_parquet() -> Result<()> { // sorted input let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); // Projection(a as a2, c as c2) let alias_pairs: Vec<(String, String)> = vec![ ("a".to_string(), "a2".to_string()), ("c".to_string(), "c2".to_string()), ]; - let proj_parquet = - projection_exec_with_alias(parquet_exec_with_sort(vec![sort_key]), alias_pairs); - let sort_key_after_projection = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c2", &proj_parquet.schema()).unwrap(), + let proj_parquet = projection_exec_with_alias( + parquet_exec_with_sort(schema, vec![sort_key]), + alias_pairs, + ); + let sort_key_after_projection = [PhysicalSortExpr { + expr: col("c2", &proj_parquet.schema())?, options: SortOptions::default(), - }]); + }] + .into(); let plan_parquet = sort_preserving_merge_exec(sort_key_after_projection, proj_parquet); let expected = &[ @@ -3045,10 +3047,11 @@ fn parallelization_ignores_transitively_with_projection_csv() -> Result<()> { // sorted input let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); // Projection(a as a2, c as c2) let alias_pairs: Vec<(String, String)> = vec![ @@ -3058,10 +3061,11 @@ fn parallelization_ignores_transitively_with_projection_csv() -> Result<()> { let proj_csv = projection_exec_with_alias(csv_exec_with_sort(vec![sort_key]), alias_pairs); - let sort_key_after_projection = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c2", &proj_csv.schema()).unwrap(), + let sort_key_after_projection = [PhysicalSortExpr { + expr: col("c2", &proj_csv.schema())?, options: SortOptions::default(), - }]); + }] + .into(); let plan_csv = sort_preserving_merge_exec(sort_key_after_projection, proj_csv); let expected = &[ "SortPreservingMergeExec: [c2@1 ASC]", @@ -3114,10 +3118,11 @@ fn remove_redundant_roundrobins() -> Result<()> { #[test] fn remove_unnecessary_spm_after_filter() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = 
sort_preserving_merge_exec(sort_key, filter_exec(input)); @@ -3144,10 +3149,11 @@ fn remove_unnecessary_spm_after_filter() -> Result<()> { #[test] fn preserve_ordering_through_repartition() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("d", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("d", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_preserving_merge_exec(sort_key, filter_exec(input)); @@ -3169,10 +3175,11 @@ fn preserve_ordering_through_repartition() -> Result<()> { #[test] fn do_not_preserve_ordering_through_repartition() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_preserving_merge_exec(sort_key, filter_exec(input)); @@ -3208,10 +3215,11 @@ fn do_not_preserve_ordering_through_repartition() -> Result<()> { #[test] fn no_need_for_sort_after_filter() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_preserving_merge_exec(sort_key, filter_exec(input)); @@ -3233,16 +3241,18 @@ fn no_need_for_sort_after_filter() -> Result<()> { #[test] fn do_not_preserve_ordering_through_repartition2() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key]); - let sort_req = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_req = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let physical_plan = sort_preserving_merge_exec(sort_req, filter_exec(input)); let test_config = TestConfig::default(); @@ -3278,10 +3288,11 @@ fn do_not_preserve_ordering_through_repartition2() -> Result<()> { #[test] fn do_not_preserve_ordering_through_repartition3() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key]); let physical_plan = filter_exec(input); @@ -3300,10 +3311,11 @@ fn do_not_preserve_ordering_through_repartition3() -> Result<()> { #[test] fn do_not_put_sort_when_input_is_invalid() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec(); let physical_plan = sort_required_exec_with_req(filter_exec(input), sort_key); let expected = &[ @@ -3337,10 +3349,11 @@ fn 
do_not_put_sort_when_input_is_invalid() -> Result<()> { #[test] fn put_sort_when_input_is_valid() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_required_exec_with_req(filter_exec(input), sort_key); @@ -3374,12 +3387,13 @@ fn put_sort_when_input_is_valid() -> Result<()> { #[test] fn do_not_add_unnecessary_hash() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let alias = vec![("a".to_string(), "a".to_string())]; - let input = parquet_exec_with_sort(vec![sort_key]); + let input = parquet_exec_with_sort(schema, vec![sort_key]); let physical_plan = aggregate_exec_with_alias(input, alias); // TestConfig: @@ -3400,10 +3414,11 @@ fn do_not_add_unnecessary_hash() -> Result<()> { #[test] fn do_not_add_unnecessary_hash2() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let alias = vec![("a".to_string(), "a".to_string())]; let input = parquet_exec_multiple_sorted(vec![sort_key]); let aggregate = aggregate_exec_with_alias(input, alias.clone()); @@ -3489,11 +3504,8 @@ async fn test_distribute_sort_parquet() -> Result<()> { ); let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), - options: SortOptions::default(), - }]); - let physical_plan = sort_exec(sort_key, parquet_exec_with_stats(10000 * 8192), false); + let sort_key = [PhysicalSortExpr::new_default(col("c", &schema)?)].into(); + let physical_plan = sort_exec(sort_key, parquet_exec_with_stats(10000 * 8192)); // prior to optimization, this is the starting plan let starting = &[ @@ -3583,16 +3595,12 @@ fn test_replace_order_preserving_variants_with_fetch() -> Result<()> { // Create a base plan let parquet_exec = parquet_exec(); - let sort_expr = PhysicalSortExpr { - expr: Arc::new(Column::new("id", 0)), - options: SortOptions::default(), - }; - - let ordering = LexOrdering::new(vec![sort_expr]); + let sort_expr = PhysicalSortExpr::new_default(Arc::new(Column::new("id", 0))); // Create a SortPreservingMergeExec with fetch=5 let spm_exec = Arc::new( - SortPreservingMergeExec::new(ordering, parquet_exec.clone()).with_fetch(Some(5)), + SortPreservingMergeExec::new([sort_expr].into(), parquet_exec.clone()) + .with_fetch(Some(5)), ); // Create distribution context diff --git a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs index f7668c8aab11..80cfdb9e2490 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs @@ -18,13 +18,14 @@ use std::sync::Arc; use crate::physical_optimizer::test_utils::{ - aggregate_exec, bounded_window_exec, check_integrity, coalesce_batches_exec, - coalesce_partitions_exec, create_test_schema, create_test_schema2, - create_test_schema3, filter_exec, global_limit_exec, hash_join_exec, limit_exec, - 
local_limit_exec, memory_exec, parquet_exec, repartition_exec, sort_exec, - sort_exec_with_fetch, sort_expr, sort_expr_options, sort_merge_join_exec, - sort_preserving_merge_exec, sort_preserving_merge_exec_with_fetch, - spr_repartition_exec, stream_exec_ordered, union_exec, RequirementsTestExec, + aggregate_exec, bounded_window_exec, bounded_window_exec_with_partition, + check_integrity, coalesce_batches_exec, coalesce_partitions_exec, create_test_schema, + create_test_schema2, create_test_schema3, filter_exec, global_limit_exec, + hash_join_exec, local_limit_exec, memory_exec, parquet_exec, parquet_exec_with_sort, + projection_exec, repartition_exec, sort_exec, sort_exec_with_fetch, sort_expr, + sort_expr_options, sort_merge_join_exec, sort_preserving_merge_exec, + sort_preserving_merge_exec_with_fetch, spr_repartition_exec, stream_exec_ordered, + union_exec, RequirementsTestExec, }; use arrow::compute::SortOptions; @@ -32,89 +33,53 @@ use arrow::datatypes::{DataType, SchemaRef}; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{TreeNode, TransformedResult}; use datafusion_common::{Result, ScalarValue}; +use datafusion_datasource::file_scan_config::FileScanConfigBuilder; +use datafusion_datasource::source::DataSourceExec; +use datafusion_expr_common::operator::Operator; use datafusion_expr::{JoinType, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition}; use datafusion_execution::object_store::ObjectStoreUrl; +use datafusion_functions_aggregate::average::avg_udaf; +use datafusion_functions_aggregate::count::count_udaf; +use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; -use datafusion_physical_expr::expressions::{col, Column, NotExpr}; -use datafusion_physical_expr::Partitioning; -use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion_physical_expr_common::sort_expr::{ + LexOrdering, PhysicalSortExpr, PhysicalSortRequirement, OrderingRequirements +}; +use datafusion_physical_expr::{Distribution, Partitioning}; +use datafusion_physical_expr::expressions::{col, BinaryExpr, Column, NotExpr}; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::windows::{create_window_expr, BoundedWindowAggExec, WindowAggExec}; use datafusion_physical_plan::{displayable, get_plan_string, ExecutionPlan, InputOrderMode}; -use datafusion::datasource::physical_plan::{CsvSource, ParquetSource}; +use datafusion::datasource::physical_plan::CsvSource; use datafusion::datasource::listing::PartitionedFile; use datafusion_physical_optimizer::enforce_sorting::{EnforceSorting, PlanWithCorrespondingCoalescePartitions, PlanWithCorrespondingSort, parallelize_sorts, ensure_sorting}; use datafusion_physical_optimizer::enforce_sorting::replace_with_order_preserving_variants::{replace_with_order_preserving_variants, OrderPreservationContext}; use datafusion_physical_optimizer::enforce_sorting::sort_pushdown::{SortPushDown, assign_initial_requirements, pushdown_sorts}; use datafusion_physical_optimizer::enforce_distribution::EnforceDistribution; +use datafusion_physical_optimizer::output_requirements::OutputRequirementExec; use 
datafusion_physical_optimizer::PhysicalOptimizerRule; -use datafusion_functions_aggregate::average::avg_udaf; -use datafusion_functions_aggregate::count::count_udaf; -use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf}; -use datafusion_datasource::file_scan_config::FileScanConfigBuilder; -use datafusion_datasource::source::DataSourceExec; use rstest::rstest; -/// Create a csv exec for tests -fn csv_exec_ordered( - schema: &SchemaRef, - sort_exprs: impl IntoIterator, -) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - - let config = FileScanConfigBuilder::new( - ObjectStoreUrl::parse("test:///").unwrap(), - schema.clone(), - Arc::new(CsvSource::new(true, 0, b'"')), - ) - .with_file(PartitionedFile::new("file_path".to_string(), 100)) - .with_output_ordering(vec![sort_exprs]) - .build(); - - DataSourceExec::from_data_source(config) -} - -/// Created a sorted parquet exec -pub fn parquet_exec_sorted( - schema: &SchemaRef, - sort_exprs: impl IntoIterator, -) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - - let source = Arc::new(ParquetSource::default()); - let config = FileScanConfigBuilder::new( - ObjectStoreUrl::parse("test:///").unwrap(), - schema.clone(), - source, - ) - .with_file(PartitionedFile::new("x".to_string(), 100)) - .with_output_ordering(vec![sort_exprs]) - .build(); - - DataSourceExec::from_data_source(config) -} - /// Create a sorted Csv exec fn csv_exec_sorted( schema: &SchemaRef, sort_exprs: impl IntoIterator, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - - let config = FileScanConfigBuilder::new( + let mut builder = FileScanConfigBuilder::new( ObjectStoreUrl::parse("test:///").unwrap(), schema.clone(), Arc::new(CsvSource::new(false, 0, 0)), ) - .with_file(PartitionedFile::new("x".to_string(), 100)) - .with_output_ordering(vec![sort_exprs]) - .build(); + .with_file(PartitionedFile::new("x".to_string(), 100)); + if let Some(ordering) = LexOrdering::new(sort_exprs) { + builder = builder.with_output_ordering(vec![ordering]); + } + let config = builder.build(); DataSourceExec::from_data_source(config) } @@ -163,7 +128,7 @@ macro_rules! assert_optimized { plan_with_pipeline_fixer, false, true, - &config, + &config, ) }) .data() @@ -210,24 +175,27 @@ async fn test_remove_unnecessary_sort5() -> Result<()> { let left_schema = create_test_schema2()?; let right_schema = create_test_schema3()?; let left_input = memory_exec(&left_schema); - let parquet_sort_exprs = vec![sort_expr("a", &right_schema)]; - let right_input = parquet_exec_sorted(&right_schema, parquet_sort_exprs); - + let parquet_ordering = [sort_expr("a", &right_schema)].into(); + let right_input = + parquet_exec_with_sort(right_schema.clone(), vec![parquet_ordering]); let on = vec![( Arc::new(Column::new_with_schema("col_a", &left_schema)?) as _, Arc::new(Column::new_with_schema("c", &right_schema)?) 
as _, )]; let join = hash_join_exec(left_input, right_input, on, None, &JoinType::Inner)?; - let physical_plan = sort_exec(vec![sort_expr("a", &join.schema())], join); + let physical_plan = sort_exec([sort_expr("a", &join.schema())].into(), join); - let expected_input = ["SortExec: expr=[a@2 ASC], preserve_partitioning=[false]", + let expected_input = [ + "SortExec: expr=[a@2 ASC], preserve_partitioning=[false]", " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col_a@0, c@2)]", " DataSourceExec: partitions=1, partition_sizes=[0]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet"]; - - let expected_optimized = ["HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col_a@0, c@2)]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet", + ]; + let expected_optimized = [ + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col_a@0, c@2)]", " DataSourceExec: partitions=1, partition_sizes=[0]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -236,41 +204,40 @@ async fn test_remove_unnecessary_sort5() -> Result<()> { #[tokio::test] async fn test_do_not_remove_sort_with_limit() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering: LexOrdering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort = sort_exec(sort_exprs.clone(), source1); - let limit = limit_exec(sort); - - let parquet_sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let source2 = parquet_exec_sorted(&schema, parquet_sort_exprs); - + ] + .into(); + let sort = sort_exec(ordering.clone(), source1); + let limit = local_limit_exec(sort, 100); + let parquet_ordering = [sort_expr("nullable_col", &schema)].into(); + let source2 = parquet_exec_with_sort(schema, vec![parquet_ordering]); let union = union_exec(vec![source2, limit]); let repartition = repartition_exec(union); - let physical_plan = sort_preserving_merge_exec(sort_exprs, repartition); + let physical_plan = sort_preserving_merge_exec(ordering, repartition); - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", + let expected_input = [ + "SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", " UnionExec", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " GlobalLimitExec: skip=0, fetch=100", - " LocalLimitExec: fetch=100", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - + " LocalLimitExec: fetch=100", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; // We should keep the bottom `SortExec`. 
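// The fallible constructor seen in `csv_exec_sorted` above is the other half of
// this API change: `LexOrdering::new` now takes an iterator of sort expressions
// and returns an `Option`, yielding `None` for empty input instead of an empty
// ordering. A sketch of how the patch consumes it; `exprs` stands in for a
// possibly-empty list.
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};

fn maybe_ordering(exprs: Vec<PhysicalSortExpr>) -> Option<LexOrdering> {
    // Callers either skip the sort entirely when this is `None` (as in
    // `window_fuzz.rs`) or `.unwrap()` when non-emptiness is guaranteed
    // (as in `sort_fuzz.rs`).
    LexOrdering::new(exprs)
}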
- let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", + let expected_optimized = [ + "SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", " UnionExec", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " GlobalLimitExec: skip=0, fetch=100", - " LocalLimitExec: fetch=100", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + " LocalLimitExec: fetch=100", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -279,18 +246,15 @@ async fn test_do_not_remove_sort_with_limit() -> Result<()> { #[tokio::test] async fn test_union_inputs_sorted() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let sort = sort_exec(sort_exprs.clone(), source1); - - let source2 = parquet_exec_sorted(&schema, sort_exprs.clone()); - + let source1 = parquet_exec(schema.clone()); + let ordering: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let sort = sort_exec(ordering.clone(), source1); + let source2 = parquet_exec_with_sort(schema, vec![ordering.clone()]); let union = union_exec(vec![source2, sort]); - let physical_plan = sort_preserving_merge_exec(sort_exprs, union); + let physical_plan = sort_preserving_merge_exec(ordering, union); // one input to the union is already sorted, one is not. 
- let expected_input = vec![ + let expected_input = [ "SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", @@ -298,8 +262,7 @@ async fn test_union_inputs_sorted() -> Result<()> { " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", ]; // should not add a sort at the output of the union, input plan should not be changed - let expected_optimized = expected_input.clone(); - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + assert_optimized!(expected_input, expected_input, physical_plan, true); Ok(()) } @@ -307,22 +270,20 @@ async fn test_union_inputs_sorted() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let sort = sort_exec(sort_exprs.clone(), source1); - - let parquet_sort_exprs = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let sort = sort_exec(ordering.clone(), source1); + let parquet_ordering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let source2 = parquet_exec_sorted(&schema, parquet_sort_exprs); - + ] + .into(); + let source2 = parquet_exec_with_sort(schema, vec![parquet_ordering]); let union = union_exec(vec![source2, sort]); - let physical_plan = sort_preserving_merge_exec(sort_exprs, union); + let physical_plan = sort_preserving_merge_exec(ordering, union); // one input to the union is already sorted, one is not. - let expected_input = vec![ + let expected_input = [ "SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC], file_type=parquet", @@ -330,8 +291,7 @@ async fn test_union_inputs_different_sorted() -> Result<()> { " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", ]; // should not add a sort at the output of the union, input plan should not be changed - let expected_optimized = expected_input.clone(); - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + assert_optimized!(expected_input, expected_input, physical_plan, true); Ok(()) } @@ -339,35 +299,36 @@ async fn test_union_inputs_different_sorted() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted2() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs = vec![ + let source1 = parquet_exec(schema.clone()); + let sort_exprs: LexOrdering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; + ] + .into(); let sort = sort_exec(sort_exprs.clone(), source1); - - let parquet_sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let source2 = parquet_exec_sorted(&schema, parquet_sort_exprs); - + let parquet_ordering = [sort_expr("nullable_col", &schema)].into(); + let source2 = parquet_exec_with_sort(schema, vec![parquet_ordering]); let union = union_exec(vec![source2, sort]); let physical_plan = sort_preserving_merge_exec(sort_exprs, union); // Input is an invalid plan. 
In this case, the rule should add the required sorting in the appropriate places. // The first DataSourceExec has an output ordering (nullable_col@0 ASC). However, it doesn't satisfy the // required ordering of the SortPreservingMergeExec. - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", + let expected_input = [ + "SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " UnionExec", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -376,40 +337,42 @@ #[tokio::test] async fn test_union_inputs_different_sorted3() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering1 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort1 = sort_exec(sort_exprs1, source1.clone()); - let sort_exprs2 = vec![sort_expr("nullable_col", &schema)]; - let sort2 = sort_exec(sort_exprs2, source1); - - let parquet_sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let source2 = parquet_exec_sorted(&schema, parquet_sort_exprs.clone()); - + ] + .into(); + let sort1 = sort_exec(ordering1, source1.clone()); + let ordering2 = [sort_expr("nullable_col", &schema)].into(); + let sort2 = sort_exec(ordering2, source1); + let parquet_ordering: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let source2 = parquet_exec_with_sort(schema, vec![parquet_ordering.clone()]); let union = union_exec(vec![sort1, source2, sort2]); - let physical_plan = sort_preserving_merge_exec(parquet_sort_exprs, union); + let physical_plan = sort_preserving_merge_exec(parquet_ordering, union); // First input to the union is not Sorted (its SortExec is finer than the ordering required by the SortPreservingMergeExec above). // Second input to the union is already Sorted (it matches the ordering required by the SortPreservingMergeExec above). // Third input to the union is not Sorted (its SortExec matches the ordering required by the SortPreservingMergeExec above). 
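// Before the expected plans that follow: the `expected_input`/`expected_optimized`
// arrays in these tests are compared against rendered plans. A condensed sketch of
// the comparison step, assuming `assert_optimized!` ultimately diffs the output of
// `get_plan_string` (the macro also runs the optimizer passes, which this sketch
// omits); `assert_plan_matches` is a hypothetical helper, not part of the patch.
use std::sync::Arc;
use datafusion_physical_plan::{get_plan_string, ExecutionPlan};

fn assert_plan_matches(plan: &Arc<dyn ExecutionPlan>, expected: &[&str]) {
    // `get_plan_string` renders one indented line per operator, matching the
    // string arrays used throughout this file.
    let actual = get_plan_string(plan);
    assert_eq!(actual, expected, "plan rendering mismatch");
}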
- let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", + let expected_input = [ + "SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; // should adjust sorting in the first input of the union such that it is not unnecessarily fine - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", + let expected_optimized = [ + "SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -418,40 +381,42 @@ async fn test_union_inputs_different_sorted3() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted4() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering1 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort_exprs2 = vec![sort_expr("nullable_col", &schema)]; - let sort1 = sort_exec(sort_exprs2.clone(), source1.clone()); - let sort2 = sort_exec(sort_exprs2.clone(), source1); - - let source2 = parquet_exec_sorted(&schema, sort_exprs2); - + ] + .into(); + let ordering2: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let sort1 = sort_exec(ordering2.clone(), source1.clone()); + let sort2 = sort_exec(ordering2.clone(), source1); + let source2 = parquet_exec_with_sort(schema, vec![ordering2]); let union = union_exec(vec![sort1, source2, sort2]); - let physical_plan = sort_preserving_merge_exec(sort_exprs1, union); + let physical_plan = sort_preserving_merge_exec(ordering1, union); // Ordering requirement of the `SortPreservingMergeExec` is not met. // Should modify the plan to ensure that all three inputs to the // `UnionExec` satisfy the ordering, OR add a single sort after // the `UnionExec` (both of which are equally good for this example). 
- let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", + let expected_input = [ + "SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -460,13 +425,13 @@ async fn test_union_inputs_different_sorted4() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted5() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering1 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort_exprs2 = vec![ + ] + .into(); + let ordering2 = [ sort_expr("nullable_col", &schema), sort_expr_options( "non_nullable_col", @@ -476,29 +441,33 @@ async fn test_union_inputs_different_sorted5() -> Result<()> { nulls_first: false, }, ), - ]; - let sort_exprs3 = vec![sort_expr("nullable_col", &schema)]; - let sort1 = sort_exec(sort_exprs1, source1.clone()); - let sort2 = sort_exec(sort_exprs2, source1); - + ] + .into(); + let ordering3 = [sort_expr("nullable_col", &schema)].into(); + let sort1 = sort_exec(ordering1, source1.clone()); + let sort2 = sort_exec(ordering2, source1); let union = union_exec(vec![sort1, sort2]); - let physical_plan = sort_preserving_merge_exec(sort_exprs3, union); + let physical_plan = sort_preserving_merge_exec(ordering3, union); // The `UnionExec` doesn't preserve any of the inputs ordering in the // example below. However, we should be able to change the unnecessarily // fine `SortExec`s below with required `SortExec`s that are absolutely necessary. 
- let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", + let expected_input = [ + "SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -507,22 +476,20 @@ async fn test_union_inputs_different_sorted5() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted6() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![sort_expr("nullable_col", &schema)]; - let sort1 = sort_exec(sort_exprs1, source1.clone()); - let sort_exprs2 = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering1 = [sort_expr("nullable_col", &schema)].into(); + let sort1 = sort_exec(ordering1, source1.clone()); + let ordering2 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; + ] + .into(); let repartition = repartition_exec(source1); - let spm = sort_preserving_merge_exec(sort_exprs2, repartition); - - let parquet_sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let source2 = parquet_exec_sorted(&schema, parquet_sort_exprs.clone()); - + let spm = sort_preserving_merge_exec(ordering2, repartition); + let parquet_ordering: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let source2 = parquet_exec_with_sort(schema, vec![parquet_ordering.clone()]); let union = union_exec(vec![sort1, source2, spm]); - let physical_plan = sort_preserving_merge_exec(parquet_sort_exprs, union); + let physical_plan = sort_preserving_merge_exec(parquet_ordering, union); // The plan is not valid as it is -- the input ordering requirement // of the `SortPreservingMergeExec` under the third child of the @@ -530,24 +497,28 @@ async fn test_union_inputs_different_sorted6() -> Result<()> { // At the same time, this ordering requirement is unnecessarily fine. // The final plan should be valid AND the ordering of the third child // shouldn't be finer than necessary. 
- let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", + let expected_input = [ + "SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", " SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; // Should adjust the requirement in the third input of the union so // that it is not unnecessarily fine. - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", + let expected_optimized = [ + "SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -556,33 +527,36 @@ async fn test_union_inputs_different_sorted6() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted7() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering1: LexOrdering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort_exprs3 = vec![sort_expr("nullable_col", &schema)]; - let sort1 = sort_exec(sort_exprs1.clone(), source1.clone()); - let sort2 = sort_exec(sort_exprs1, source1); - + ] + .into(); + let sort1 = sort_exec(ordering1.clone(), source1.clone()); + let sort2 = sort_exec(ordering1, source1); let union = union_exec(vec![sort1, sort2]); - let physical_plan = sort_preserving_merge_exec(sort_exprs3, union); + let ordering2 = [sort_expr("nullable_col", &schema)].into(); + let physical_plan = sort_preserving_merge_exec(ordering2, union); // Union has unnecessarily fine ordering below it. We should be able to replace them with absolutely necessary ordering. 
- let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", + let expected_input = [ + "SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; // Union preserves the inputs ordering and we should not change any of the SortExecs under UnionExec - let expected_output = ["SortPreservingMergeExec: [nullable_col@0 ASC]", + let expected_output = [ + "SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; assert_optimized!(expected_input, expected_output, physical_plan, true); Ok(()) @@ -591,13 +565,13 @@ async fn test_union_inputs_different_sorted7() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted8() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering1 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort_exprs2 = vec![ + ] + .into(); + let ordering2 = [ sort_expr_options( "nullable_col", &schema, @@ -614,55 +588,454 @@ async fn test_union_inputs_different_sorted8() -> Result<()> { nulls_first: false, }, ), - ]; - let sort1 = sort_exec(sort_exprs1, source1.clone()); - let sort2 = sort_exec(sort_exprs2, source1); - + ] + .into(); + let sort1 = sort_exec(ordering1, source1.clone()); + let sort2 = sort_exec(ordering2, source1); let physical_plan = union_exec(vec![sort1, sort2]); // The `UnionExec` doesn't preserve any of the inputs ordering in the // example below. - let expected_input = ["UnionExec", + let expected_input = [ + "UnionExec", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " SortExec: expr=[nullable_col@0 DESC NULLS LAST, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; // Since `UnionExec` doesn't preserve ordering in the plan above. // We shouldn't keep SortExecs in the plan. 
- let expected_optimized = ["UnionExec", + let expected_optimized = [ + "UnionExec", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } #[tokio::test] -async fn test_window_multi_path_sort() -> Result<()> { +async fn test_soft_hard_requirements_remove_soft_requirement() -> Result<()> { let schema = create_test_schema()?; + let source = parquet_exec(schema.clone()); + let sort_exprs = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort = sort_exec(sort_exprs, source); + let partition_bys = &[col("nullable_col", &schema)?]; + let physical_plan = + bounded_window_exec_with_partition("nullable_col", vec![], partition_bys, sort); - let sort_exprs1 = vec![ - sort_expr("nullable_col", &schema), - sort_expr("non_nullable_col", &schema), + let expected_input = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", ]; - let sort_exprs2 = vec![sort_expr("nullable_col", &schema)]; - // reverse sorting of sort_exprs2 - let sort_exprs3 = vec![sort_expr_options( + // TODO When sort pushdown respects to the alternatives, and removes soft SortExecs this should be changed + // let expected_optimized = [ + // "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + let expected_optimized = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); + Ok(()) +} + +#[tokio::test] +async fn test_soft_hard_requirements_remove_soft_requirement_without_pushdowns( +) -> Result<()> { + let schema = create_test_schema()?; + let source = parquet_exec(schema.clone()); + let ordering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort = sort_exec(ordering, source.clone()); + let proj_exprs = vec![( + Arc::new(BinaryExpr::new( + col("nullable_col", &schema)?, + Operator::Plus, + col("non_nullable_col", &schema)?, + )) as _, 
+ "count".to_string(), + )]; + let partition_bys = &[col("nullable_col", &schema)?]; + let bounded_window = + bounded_window_exec_with_partition("nullable_col", vec![], partition_bys, sort); + let physical_plan = projection_exec(proj_exprs, bounded_window)?; + + let expected_input = [ + "ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as count]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + // TODO When sort pushdown respects to the alternatives, and removes soft SortExecs this should be changed + // let expected_optimized = [ + // "ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as count]", + // " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + let expected_optimized = [ + "ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as count]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); + + let ordering = [sort_expr_options( "nullable_col", &schema, SortOptions { descending: true, nulls_first: false, }, + )] + .into(); + let sort = sort_exec(ordering, source); + let proj_exprs = vec![( + Arc::new(BinaryExpr::new( + col("nullable_col", &schema)?, + Operator::Plus, + col("non_nullable_col", &schema)?, + )) as _, + "nullable_col".to_string(), )]; - let source1 = parquet_exec_sorted(&schema, sort_exprs1); - let source2 = parquet_exec_sorted(&schema, sort_exprs2); - let sort1 = sort_exec(sort_exprs3.clone(), source1); - let sort2 = sort_exec(sort_exprs3.clone(), source2); + let partition_bys = &[col("nullable_col", &schema)?]; + let projection = projection_exec(proj_exprs, sort)?; + let physical_plan = bounded_window_exec_with_partition( + "nullable_col", + vec![], + partition_bys, + projection, + ); + let expected_input = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + // TODO 
When sort pushdown respects the alternatives and removes soft SortExecs, this should be changed + // let expected_optimized = [ + // "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + let expected_optimized = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", + " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); + Ok(()) +} + +#[tokio::test] +async fn test_soft_hard_requirements_multiple_soft_requirements() -> Result<()> { + let schema = create_test_schema()?; + let source = parquet_exec(schema.clone()); + let ordering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort = sort_exec(ordering, source.clone()); + let proj_exprs = vec![( + Arc::new(BinaryExpr::new( + col("nullable_col", &schema)?, + Operator::Plus, + col("non_nullable_col", &schema)?, + )) as _, + "nullable_col".to_string(), + )]; + let partition_bys = &[col("nullable_col", &schema)?]; + let projection = projection_exec(proj_exprs, sort)?; + let bounded_window = bounded_window_exec_with_partition( + "nullable_col", + vec![], + partition_bys, + projection, + ); + let physical_plan = bounded_window_exec_with_partition( + "count", + vec![], + partition_bys, + bounded_window, + ); + + let expected_input = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + // TODO: When sort pushdown respects the alternatives and removes soft SortExecs, this should be changed + // let expected_optimized = [ + // "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: 
CurrentRow, is_causal: false }], mode=[Linear]", + // " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + let expected_optimized = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", + " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); + + let ordering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort = sort_exec(ordering, source); + let proj_exprs = vec![( + Arc::new(BinaryExpr::new( + col("nullable_col", &schema)?, + Operator::Plus, + col("non_nullable_col", &schema)?, + )) as _, + "nullable_col".to_string(), + )]; + let partition_bys = &[col("nullable_col", &schema)?]; + let projection = projection_exec(proj_exprs, sort)?; + let bounded_window = bounded_window_exec_with_partition( + "nullable_col", + vec![], + partition_bys, + projection, + ); + + let ordering2: LexOrdering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort2 = sort_exec(ordering2.clone(), bounded_window); + let sort3 = sort_exec(ordering2, sort2); + let physical_plan = + bounded_window_exec_with_partition("count", vec![], partition_bys, sort3); + + let expected_input = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, 
non_nullable_col], file_type=parquet", + ]; + // TODO When sort pushdown respects to the alternatives, and removes soft SortExecs this should be changed + // let expected_optimized = [ + // "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + let expected_optimized = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", + " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); + Ok(()) +} +#[tokio::test] +async fn test_soft_hard_requirements_multiple_sorts() -> Result<()> { + let schema = create_test_schema()?; + let source = parquet_exec(schema.clone()); + let ordering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort = sort_exec(ordering, source); + let proj_exprs = vec![( + Arc::new(BinaryExpr::new( + col("nullable_col", &schema)?, + Operator::Plus, + col("non_nullable_col", &schema)?, + )) as _, + "nullable_col".to_string(), + )]; + let partition_bys = &[col("nullable_col", &schema)?]; + let projection = projection_exec(proj_exprs, sort)?; + let bounded_window = bounded_window_exec_with_partition( + "nullable_col", + vec![], + partition_bys, + projection, + ); + let ordering2: LexOrdering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort2 = sort_exec(ordering2.clone(), bounded_window); + let physical_plan = sort_exec(ordering2, sort2); + + let expected_input = [ + "SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " ProjectionExec: 
expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + // TODO: When sort pushdown respects the alternatives and removes soft SortExecs, this should be changed + // let expected_optimized = [ + // "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + let expected_optimized = [ + "SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", + " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); + Ok(()) +} + +#[tokio::test] +async fn test_soft_hard_requirements_with_multiple_soft_requirements_and_output_requirement( +) -> Result<()> { + let schema = create_test_schema()?; + let source = parquet_exec(schema.clone()); + let ordering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort = sort_exec(ordering, source); + let partition_bys1 = &[col("nullable_col", &schema)?]; + let bounded_window = + bounded_window_exec_with_partition("nullable_col", vec![], partition_bys1, sort); + let partition_bys2 = &[col("non_nullable_col", &schema)?]; + let bounded_window2 = bounded_window_exec_with_partition( + "non_nullable_col", + vec![], + partition_bys2, + bounded_window, + ); + let requirement = [PhysicalSortRequirement::new( + col("non_nullable_col", &schema)?, + Some(SortOptions::new(false, true)), + )] + .into(); + let physical_plan = Arc::new(OutputRequirementExec::new( + bounded_window2, + Some(OrderingRequirements::new(requirement)), + Distribution::SinglePartition, + )); + + let expected_input = [ + "OutputRequirementExec", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 
group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + // TODO: When sort pushdown respects the alternatives and removes soft SortExecs, this should be changed + // let expected_optimized = [ + // "OutputRequirementExec", + // " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + // " SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false]", + // " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + let expected_optimized = [ + "OutputRequirementExec", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); + Ok(()) +} + +#[tokio::test] +async fn test_window_multi_path_sort() -> Result<()> { + let schema = create_test_schema()?; + let ordering1 = [ + sort_expr("nullable_col", &schema), + sort_expr("non_nullable_col", &schema), + ] + .into(); + let ordering2 = [sort_expr("nullable_col", &schema)].into(); + // Reverse of the above + let ordering3: LexOrdering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let source1 = parquet_exec_with_sort(schema.clone(), vec![ordering1]); + let source2 = parquet_exec_with_sort(schema, vec![ordering2]); + let sort1 = sort_exec(ordering3.clone(), source1); + let sort2 = sort_exec(ordering3.clone(), source2); let union = union_exec(vec![sort1, sort2]); - let spm = sort_preserving_merge_exec(sort_exprs3.clone(), union); - let physical_plan = bounded_window_exec("nullable_col", sort_exprs3, spm); + let spm = sort_preserving_merge_exec(ordering3.clone(), union); + let physical_plan = bounded_window_exec("nullable_col", ordering3, spm); // The `WindowAggExec` gets its sorting from multiple children jointly.
// During the removal of `SortExec`s, it should be able to remove the @@ -675,13 +1048,15 @@ async fn test_window_multi_path_sort() -> Result<()> { " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC], file_type=parquet", " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", + ]; let expected_optimized = [ "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", " SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -690,35 +1065,38 @@ async fn test_window_multi_path_sort() -> Result<()> { #[tokio::test] async fn test_window_multi_path_sort2() -> Result<()> { let schema = create_test_schema()?; - - let sort_exprs1 = LexOrdering::new(vec![ + let ordering1: LexOrdering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]); - let sort_exprs2 = vec![sort_expr("nullable_col", &schema)]; - let source1 = parquet_exec_sorted(&schema, sort_exprs2.clone()); - let source2 = parquet_exec_sorted(&schema, sort_exprs2.clone()); - let sort1 = sort_exec(sort_exprs1.clone(), source1); - let sort2 = sort_exec(sort_exprs1.clone(), source2); - + ] + .into(); + let ordering2: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let source1 = parquet_exec_with_sort(schema.clone(), vec![ordering2.clone()]); + let source2 = parquet_exec_with_sort(schema, vec![ordering2.clone()]); + let sort1 = sort_exec(ordering1.clone(), source1); + let sort2 = sort_exec(ordering1.clone(), source2); let union = union_exec(vec![sort1, sort2]); - let spm = Arc::new(SortPreservingMergeExec::new(sort_exprs1, union)) as _; - let physical_plan = bounded_window_exec("nullable_col", sort_exprs2, spm); + let spm = Arc::new(SortPreservingMergeExec::new(ordering1, union)) as _; + let physical_plan = bounded_window_exec("nullable_col", ordering2, spm); // The `WindowAggExec` can get its required sorting from the leaf nodes directly. 
// The unnecessary SortExecs should be removed - let expected_input = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + let expected_input = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet"]; - let expected_optimized = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", + ]; + let expected_optimized = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -727,13 +1105,13 @@ async fn test_window_multi_path_sort2() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted_with_limit() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering1 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort_exprs2 = vec![ + ] + .into(); + let ordering2 = [ sort_expr("nullable_col", &schema), sort_expr_options( "non_nullable_col", @@ -743,34 +1121,37 @@ async fn test_union_inputs_different_sorted_with_limit() -> Result<()> { nulls_first: false, }, ), - ]; - let sort_exprs3 = vec![sort_expr("nullable_col", &schema)]; - let sort1 = sort_exec(sort_exprs1, source1.clone()); - - let sort2 = sort_exec(sort_exprs2, source1); - let limit = local_limit_exec(sort2); - let limit = global_limit_exec(limit); - + ] + .into(); 
+ let sort1 = sort_exec(ordering1, source1.clone()); + let sort2 = sort_exec(ordering2, source1); + let limit = local_limit_exec(sort2, 100); + let limit = global_limit_exec(limit, 0, Some(100)); let union = union_exec(vec![sort1, limit]); - let physical_plan = sort_preserving_merge_exec(sort_exprs3, union); + let ordering3 = [sort_expr("nullable_col", &schema)].into(); + let physical_plan = sort_preserving_merge_exec(ordering3, union); // Should not change the unnecessarily fine `SortExec`s because there is `LimitExec` - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", + let expected_input = [ + "SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " GlobalLimitExec: skip=0, fetch=100", " LocalLimitExec: fetch=100", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " GlobalLimitExec: skip=0, fetch=100", " LocalLimitExec: fetch=100", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -781,13 +1162,13 @@ async fn test_sort_merge_join_order_by_left() -> Result<()> { let left_schema = create_test_schema()?; let right_schema = create_test_schema2()?; - let left = parquet_exec(&left_schema); - let right = parquet_exec(&right_schema); + let left = parquet_exec(left_schema); + let right = parquet_exec(right_schema); // Join on (nullable_col == col_a) let join_on = vec![( - Arc::new(Column::new_with_schema("nullable_col", &left.schema()).unwrap()) as _, - Arc::new(Column::new_with_schema("col_a", &right.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("nullable_col", &left.schema())?) as _, + Arc::new(Column::new_with_schema("col_a", &right.schema())?) 
as _, )]; let join_types = vec![ @@ -801,11 +1182,12 @@ async fn test_sort_merge_join_order_by_left() -> Result<()> { for join_type in join_types { let join = sort_merge_join_exec(left.clone(), right.clone(), &join_on, &join_type); - let sort_exprs = vec![ + let ordering = [ sort_expr("nullable_col", &join.schema()), sort_expr("non_nullable_col", &join.schema()), - ]; - let physical_plan = sort_preserving_merge_exec(sort_exprs.clone(), join); + ] + .into(); + let physical_plan = sort_preserving_merge_exec(ordering, join); let join_plan = format!( "SortMergeJoin: join_type={join_type}, on=[(nullable_col@0, col_a@0)]" @@ -853,13 +1235,13 @@ async fn test_sort_merge_join_order_by_right() -> Result<()> { let left_schema = create_test_schema()?; let right_schema = create_test_schema2()?; - let left = parquet_exec(&left_schema); - let right = parquet_exec(&right_schema); + let left = parquet_exec(left_schema); + let right = parquet_exec(right_schema); // Join on (nullable_col == col_a) let join_on = vec![( - Arc::new(Column::new_with_schema("nullable_col", &left.schema()).unwrap()) as _, - Arc::new(Column::new_with_schema("col_a", &right.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("nullable_col", &left.schema())?) as _, + Arc::new(Column::new_with_schema("col_a", &right.schema())?) as _, )]; let join_types = vec![ @@ -872,11 +1254,12 @@ async fn test_sort_merge_join_order_by_right() -> Result<()> { for join_type in join_types { let join = sort_merge_join_exec(left.clone(), right.clone(), &join_on, &join_type); - let sort_exprs = vec![ + let ordering = [ sort_expr("col_a", &join.schema()), sort_expr("col_b", &join.schema()), - ]; - let physical_plan = sort_preserving_merge_exec(sort_exprs, join); + ] + .into(); + let physical_plan = sort_preserving_merge_exec(ordering, join); let join_plan = format!( "SortMergeJoin: join_type={join_type}, on=[(nullable_col@0, col_a@0)]" @@ -925,58 +1308,65 @@ async fn test_sort_merge_join_complex_order_by() -> Result<()> { let left_schema = create_test_schema()?; let right_schema = create_test_schema2()?; - let left = parquet_exec(&left_schema); - let right = parquet_exec(&right_schema); + let left = parquet_exec(left_schema); + let right = parquet_exec(right_schema); // Join on (nullable_col == col_a) let join_on = vec![( - Arc::new(Column::new_with_schema("nullable_col", &left.schema()).unwrap()) as _, - Arc::new(Column::new_with_schema("col_a", &right.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("nullable_col", &left.schema())?) as _, + Arc::new(Column::new_with_schema("col_a", &right.schema())?) 
as _, )]; let join = sort_merge_join_exec(left, right, &join_on, &JoinType::Inner); // order by (col_b, col_a) - let sort_exprs1 = vec![ + let ordering = [ sort_expr("col_b", &join.schema()), sort_expr("col_a", &join.schema()), - ]; - let physical_plan = sort_preserving_merge_exec(sort_exprs1, join.clone()); + ] + .into(); + let physical_plan = sort_preserving_merge_exec(ordering, join.clone()); - let expected_input = ["SortPreservingMergeExec: [col_b@3 ASC, col_a@2 ASC]", + let expected_input = [ + "SortPreservingMergeExec: [col_b@3 ASC, col_a@2 ASC]", " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet"]; - + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet", + ]; // can not push down the sort requirements, need to add SortExec - let expected_optimized = ["SortExec: expr=[col_b@3 ASC, col_a@2 ASC], preserve_partitioning=[false]", + let expected_optimized = [ + "SortExec: expr=[col_b@3 ASC, nullable_col@0 ASC], preserve_partitioning=[false]", " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); // order by (nullable_col, col_b, col_a) - let sort_exprs2 = vec![ + let ordering2 = [ sort_expr("nullable_col", &join.schema()), sort_expr("col_b", &join.schema()), sort_expr("col_a", &join.schema()), - ]; - let physical_plan = sort_preserving_merge_exec(sort_exprs2, join); + ] + .into(); + let physical_plan = sort_preserving_merge_exec(ordering2, join); - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, col_b@3 ASC, col_a@2 ASC]", + let expected_input = [ + "SortPreservingMergeExec: [nullable_col@0 ASC, col_b@3 ASC, col_a@2 ASC]", " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet"]; - - // can not push down the sort requirements, need to add SortExec - let expected_optimized = ["SortExec: expr=[nullable_col@0 ASC, col_b@3 ASC, col_a@2 ASC], preserve_partitioning=[false]", - " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet", + ]; + // Can push down the sort requirements since col_a = nullable_col + let expected_optimized = [ + "SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", + " SortExec: 
expr=[nullable_col@0 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + " SortExec: expr=[col_a@0 ASC, col_b@1 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -985,33 +1375,34 @@ async fn test_sort_merge_join_complex_order_by() -> Result<()> { #[tokio::test] async fn test_multilayer_coalesce_partitions() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); + let source1 = parquet_exec(schema.clone()); let repartition = repartition_exec(source1); - let coalesce = Arc::new(CoalescePartitionsExec::new(repartition)) as _; + let coalesce = coalesce_partitions_exec(repartition) as _; // Add a dummy layer propagating Sort above, to test whether the sort can be removed across multiple layers let filter = filter_exec( - Arc::new(NotExpr::new( - col("non_nullable_col", schema.as_ref()).unwrap(), - )), + Arc::new(NotExpr::new(col("non_nullable_col", schema.as_ref())?)), coalesce, ); - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let physical_plan = sort_exec(sort_exprs, filter); + let ordering = [sort_expr("nullable_col", &schema)].into(); + let physical_plan = sort_exec(ordering, filter); // CoalescePartitionsExec and SortExec are not directly consecutive. In this case // we should be able to parallelize the sort as well (given that the executors in between don't require a // single partition). - let expected_input = ["SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", + let expected_input = [ + "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " FilterExec: NOT non_nullable_col@1", " CoalescePartitionsExec", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [nullable_col@0 ASC]", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[true]", " FilterExec: NOT non_nullable_col@1", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -1020,26 +1411,30 @@ async fn test_multilayer_coalesce_partitions() -> Result<()> { #[tokio::test] async fn test_with_lost_ordering_bounded() -> Result<()> { let schema = create_test_schema3()?; - let sort_exprs = vec![sort_expr("a", &schema)]; + let sort_exprs = [sort_expr("a", &schema)]; let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec(source); let repartition_hash = Arc::new(RepartitionExec::try_new( repartition_rr, - Partitioning::Hash(vec![col("c", &schema).unwrap()], 10), + Partitioning::Hash(vec![col("c", &schema)?], 10), )?)
as _; let coalesce_partitions = coalesce_partitions_exec(repartition_hash); - let physical_plan = sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions); + let physical_plan = sort_exec([sort_expr("a", &schema)].into(), coalesce_partitions); - let expected_input = ["SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + let expected_input = [ + "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false"]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [a@0 ASC]", " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -1051,20 +1446,20 @@ async fn test_with_lost_ordering_unbounded_bounded( #[values(false, true)] source_unbounded: bool, ) -> Result<()> { let schema = create_test_schema3()?; - let sort_exprs = vec![sort_expr("a", &schema)]; + let sort_exprs = [sort_expr("a", &schema)]; // create either bounded or unbounded source let source = if source_unbounded { - stream_exec_ordered(&schema, sort_exprs) + stream_exec_ordered(&schema, sort_exprs.clone().into()) } else { - csv_exec_ordered(&schema, sort_exprs) + csv_exec_sorted(&schema, sort_exprs.clone()) }; let repartition_rr = repartition_exec(source); let repartition_hash = Arc::new(RepartitionExec::try_new( repartition_rr, - Partitioning::Hash(vec![col("c", &schema).unwrap()], 10), + Partitioning::Hash(vec![col("c", &schema)?], 10), )?) 
as _; let coalesce_partitions = coalesce_partitions_exec(repartition_hash); - let physical_plan = sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions); + let physical_plan = sort_exec(sort_exprs.into(), coalesce_partitions); // Expected inputs unbounded and bounded let expected_input_unbounded = vec![ @@ -1079,7 +1474,7 @@ async fn test_with_lost_ordering_unbounded_bounded( " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[file_path]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=true", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false", ]; // Expected unbounded result (same for with and without flag) @@ -1096,14 +1491,14 @@ async fn test_with_lost_ordering_unbounded_bounded( " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[file_path]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=true", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false", ]; let expected_optimized_bounded_parallelize_sort = vec![ "SortPreservingMergeExec: [a@0 ASC]", " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[file_path]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=true", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false", ]; let (expected_input, expected_optimized, expected_optimized_sort_parallelize) = if source_unbounded { @@ -1138,20 +1533,24 @@ async fn test_with_lost_ordering_unbounded_bounded( #[tokio::test] async fn test_do_not_pushdown_through_spm() -> Result<()> { let schema = create_test_schema3()?; - let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; + let sort_exprs = [sort_expr("a", &schema), sort_expr("b", &schema)]; let source = csv_exec_sorted(&schema, sort_exprs.clone()); let repartition_rr = repartition_exec(source); - let spm = sort_preserving_merge_exec(sort_exprs, repartition_rr); - let physical_plan = sort_exec(vec![sort_expr("b", &schema)], spm); + let spm = sort_preserving_merge_exec(sort_exprs.into(), repartition_rr); + let physical_plan = sort_exec([sort_expr("b", &schema)].into(), spm); - let expected_input = ["SortExec: expr=[b@1 ASC], preserve_partitioning=[false]", + let expected_input = [ + "SortExec: expr=[b@1 ASC], preserve_partitioning=[false]", " SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false",]; - let expected_optimized = ["SortExec: expr=[b@1 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false", + ]; + let 
expected_optimized = [ + "SortExec: expr=[b@1 ASC], preserve_partitioning=[false]", " SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false",]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, false); Ok(()) @@ -1160,27 +1559,31 @@ async fn test_do_not_pushdown_through_spm() -> Result<()> { #[tokio::test] async fn test_pushdown_through_spm() -> Result<()> { let schema = create_test_schema3()?; - let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; + let sort_exprs = [sort_expr("a", &schema), sort_expr("b", &schema)]; let source = csv_exec_sorted(&schema, sort_exprs.clone()); let repartition_rr = repartition_exec(source); - let spm = sort_preserving_merge_exec(sort_exprs, repartition_rr); + let spm = sort_preserving_merge_exec(sort_exprs.into(), repartition_rr); let physical_plan = sort_exec( - vec![ + [ sort_expr("a", &schema), sort_expr("b", &schema), sort_expr("c", &schema), - ], + ] + .into(), spm, ); - let expected_input = ["SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[false]", + let expected_input = [ + "SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[false]", " SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false",]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false", + ]; let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", " SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false",]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, false); Ok(()) @@ -1189,13 +1592,12 @@ async fn test_pushdown_through_spm() -> Result<()> { #[tokio::test] async fn test_window_multi_layer_requirement() -> Result<()> { let schema = create_test_schema3()?; - let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; + let sort_exprs = [sort_expr("a", &schema), sort_expr("b", &schema)]; let source = csv_exec_sorted(&schema, vec![]); - let sort = sort_exec(sort_exprs.clone(), source); + let sort = sort_exec(sort_exprs.clone().into(), source); let repartition = repartition_exec(sort); let repartition = spr_repartition_exec(repartition); - let spm = sort_preserving_merge_exec(sort_exprs.clone(), repartition); - + let spm = sort_preserving_merge_exec(sort_exprs.clone().into(), repartition); let physical_plan = bounded_window_exec("a", sort_exprs, spm); let expected_input = [ @@ -1221,15 +1623,15 @@ async fn test_window_multi_layer_requirement() -> Result<()> { #[tokio::test] async fn 
test_not_replaced_with_partial_sort_for_bounded_input() -> Result<()> {
     let schema = create_test_schema3()?;
-    let input_sort_exprs = vec![sort_expr("b", &schema), sort_expr("c", &schema)];
-    let parquet_input = parquet_exec_sorted(&schema, input_sort_exprs);
-
+    let parquet_ordering = [sort_expr("b", &schema), sort_expr("c", &schema)].into();
+    let parquet_input = parquet_exec_with_sort(schema.clone(), vec![parquet_ordering]);
     let physical_plan = sort_exec(
-        vec![
+        [
             sort_expr("a", &schema),
             sort_expr("b", &schema),
             sort_expr("c", &schema),
-        ],
+        ]
+        .into(),
         parquet_input,
     );
     let expected_input = [
@@ -1333,8 +1735,8 @@ macro_rules! assert_optimized {
 async fn test_remove_unnecessary_sort() -> Result<()> {
     let schema = create_test_schema()?;
     let source = memory_exec(&schema);
-    let input = sort_exec(vec![sort_expr("non_nullable_col", &schema)], source);
-    let physical_plan = sort_exec(vec![sort_expr("nullable_col", &schema)], input);
+    let input = sort_exec([sort_expr("non_nullable_col", &schema)].into(), source);
+    let physical_plan = sort_exec([sort_expr("nullable_col", &schema)].into(), input);

     let expected_input = [
         "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]",
@@ -1354,57 +1756,54 @@ async fn test_remove_unnecessary_sort() -> Result<()> {
 async fn test_remove_unnecessary_sort_window_multilayer() -> Result<()> {
     let schema = create_test_schema()?;
     let source = memory_exec(&schema);
-
-    let sort_exprs = vec![sort_expr_options(
+    let ordering: LexOrdering = [sort_expr_options(
         "non_nullable_col",
         &source.schema(),
         SortOptions {
             descending: true,
             nulls_first: true,
         },
-    )];
-    let sort = sort_exec(sort_exprs.clone(), source);
+    )]
+    .into();
+    let sort = sort_exec(ordering.clone(), source);
     // Add a dummy layer propagating the sort above, to test whether the sort can be removed from below multiple layers
-    let coalesce_batches = coalesce_batches_exec(sort);
-
-    let window_agg =
-        bounded_window_exec("non_nullable_col", sort_exprs, coalesce_batches);
-
-    let sort_exprs = vec![sort_expr_options(
+    let coalesce_batches = coalesce_batches_exec(sort, 128);
+    let window_agg = bounded_window_exec("non_nullable_col", ordering, coalesce_batches);
+    let ordering2: LexOrdering = [sort_expr_options(
         "non_nullable_col",
         &window_agg.schema(),
         SortOptions {
             descending: false,
             nulls_first: false,
         },
-    )];
-
-    let sort = sort_exec(sort_exprs.clone(), window_agg);
-
+    )]
+    .into();
+    let sort = sort_exec(ordering2.clone(), window_agg);
     // Add a dummy layer propagating the sort above, to test whether the sort can be removed from below multiple layers
     let filter = filter_exec(
-        Arc::new(NotExpr::new(
-            col("non_nullable_col", schema.as_ref()).unwrap(),
-        )),
+        Arc::new(NotExpr::new(col("non_nullable_col", schema.as_ref())?)),
         sort,
     );
+    let physical_plan = bounded_window_exec("non_nullable_col", ordering2, filter);

-    let physical_plan = bounded_window_exec("non_nullable_col", sort_exprs, filter);
-
-    let expected_input = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]",
+    let expected_input = [
+        "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]",
         "  FilterExec: 
NOT non_nullable_col@1", " SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false]", " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " CoalesceBatchesExec: target_batch_size=128", " SortExec: expr=[non_nullable_col@1 DESC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]"]; + " DataSourceExec: partitions=1, partition_sizes=[0]" + ]; - let expected_optimized = ["WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + let expected_optimized = [ + "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", " FilterExec: NOT non_nullable_col@1", " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " CoalesceBatchesExec: target_batch_size=128", " SortExec: expr=[non_nullable_col@1 DESC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]"]; + " DataSourceExec: partitions=1, partition_sizes=[0]" + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -1414,10 +1813,8 @@ async fn test_remove_unnecessary_sort_window_multilayer() -> Result<()> { async fn test_add_required_sort() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - - let physical_plan = sort_preserving_merge_exec(sort_exprs, source); + let ordering = [sort_expr("nullable_col", &schema)].into(); + let physical_plan = sort_preserving_merge_exec(ordering, source); let expected_input = [ "SortPreservingMergeExec: [nullable_col@0 ASC]", @@ -1436,13 +1833,12 @@ async fn test_add_required_sort() -> Result<()> { async fn test_remove_unnecessary_sort1() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let sort = sort_exec(sort_exprs.clone(), source); - let spm = sort_preserving_merge_exec(sort_exprs, sort); + let ordering: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let sort = sort_exec(ordering.clone(), source); + let spm = sort_preserving_merge_exec(ordering.clone(), sort); + let sort = sort_exec(ordering.clone(), spm); + let physical_plan = sort_preserving_merge_exec(ordering, sort); - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let sort = sort_exec(sort_exprs.clone(), spm); - let physical_plan = sort_preserving_merge_exec(sort_exprs, sort); let expected_input = [ "SortPreservingMergeExec: [nullable_col@0 ASC]", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", @@ -1463,19 +1859,18 @@ async fn test_remove_unnecessary_sort1() -> Result<()> { async fn test_remove_unnecessary_sort2() -> Result<()> { let schema = 
create_test_schema()?; let source = memory_exec(&schema); - let sort_exprs = vec![sort_expr("non_nullable_col", &schema)]; - let sort = sort_exec(sort_exprs.clone(), source); - let spm = sort_preserving_merge_exec(sort_exprs, sort); - - let sort_exprs = vec![ + let ordering: LexOrdering = [sort_expr("non_nullable_col", &schema)].into(); + let sort = sort_exec(ordering.clone(), source); + let spm = sort_preserving_merge_exec(ordering, sort); + let ordering2: LexOrdering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort2 = sort_exec(sort_exprs.clone(), spm); - let spm2 = sort_preserving_merge_exec(sort_exprs, sort2); - - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let sort3 = sort_exec(sort_exprs, spm2); + ] + .into(); + let sort2 = sort_exec(ordering2.clone(), spm); + let spm2 = sort_preserving_merge_exec(ordering2, sort2); + let ordering3 = [sort_expr("nullable_col", &schema)].into(); + let sort3 = sort_exec(ordering3, spm2); let physical_plan = repartition_exec(repartition_exec(sort3)); let expected_input = [ @@ -1488,7 +1883,6 @@ async fn test_remove_unnecessary_sort2() -> Result<()> { " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", " DataSourceExec: partitions=1, partition_sizes=[0]", ]; - let expected_optimized = [ "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", @@ -1503,21 +1897,20 @@ async fn test_remove_unnecessary_sort2() -> Result<()> { async fn test_remove_unnecessary_sort3() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let sort_exprs = vec![sort_expr("non_nullable_col", &schema)]; - let sort = sort_exec(sort_exprs.clone(), source); - let spm = sort_preserving_merge_exec(sort_exprs, sort); - - let sort_exprs = LexOrdering::new(vec![ + let ordering: LexOrdering = [sort_expr("non_nullable_col", &schema)].into(); + let sort = sort_exec(ordering.clone(), source); + let spm = sort_preserving_merge_exec(ordering, sort); + let repartition_exec = repartition_exec(spm); + let ordering2: LexOrdering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]); - let repartition_exec = repartition_exec(spm); + ] + .into(); let sort2 = Arc::new( - SortExec::new(sort_exprs.clone(), repartition_exec) + SortExec::new(ordering2.clone(), repartition_exec) .with_preserve_partitioning(true), ) as _; - let spm2 = sort_preserving_merge_exec(sort_exprs, sort2); - + let spm2 = sort_preserving_merge_exec(ordering2, sort2); let physical_plan = aggregate_exec(spm2); // When removing a `SortPreservingMergeExec`, make sure that partitioning @@ -1532,7 +1925,6 @@ async fn test_remove_unnecessary_sort3() -> Result<()> { " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", " DataSourceExec: partitions=1, partition_sizes=[0]", ]; - let expected_optimized = [ "AggregateExec: mode=Final, gby=[], aggr=[]", " CoalescePartitionsExec", @@ -1548,34 +1940,29 @@ async fn test_remove_unnecessary_sort3() -> Result<()> { async fn test_remove_unnecessary_sort4() -> Result<()> { let schema = create_test_schema()?; let source1 = repartition_exec(memory_exec(&schema)); - let source2 = repartition_exec(memory_exec(&schema)); let union = union_exec(vec![source1, source2]); - - let sort_exprs = LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]); - // let sort = sort_exec(sort_exprs.clone(), union); - let sort = Arc::new( - 
SortExec::new(sort_exprs.clone(), union).with_preserve_partitioning(true), - ) as _; - let spm = sort_preserving_merge_exec(sort_exprs, sort); - + let ordering: LexOrdering = [sort_expr("non_nullable_col", &schema)].into(); + let sort = + Arc::new(SortExec::new(ordering.clone(), union).with_preserve_partitioning(true)) + as _; + let spm = sort_preserving_merge_exec(ordering, sort); let filter = filter_exec( - Arc::new(NotExpr::new( - col("non_nullable_col", schema.as_ref()).unwrap(), - )), + Arc::new(NotExpr::new(col("non_nullable_col", schema.as_ref())?)), spm, ); - - let sort_exprs = vec![ + let ordering2 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let physical_plan = sort_exec(sort_exprs, filter); + ] + .into(); + let physical_plan = sort_exec(ordering2, filter); // When removing a `SortPreservingMergeExec`, make sure that partitioning // requirements are not violated. In some cases, we may need to replace // it with a `CoalescePartitionsExec` instead of directly removing it. - let expected_input = ["SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", + let expected_input = [ + "SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " FilterExec: NOT non_nullable_col@1", " SortPreservingMergeExec: [non_nullable_col@1 ASC]", " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[true]", @@ -1583,16 +1970,18 @@ async fn test_remove_unnecessary_sort4() -> Result<()> { " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " DataSourceExec: partitions=1, partition_sizes=[0]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[0]"]; - - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", + " DataSourceExec: partitions=1, partition_sizes=[0]", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[true]", " FilterExec: NOT non_nullable_col@1", " UnionExec", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " DataSourceExec: partitions=1, partition_sizes=[0]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[0]"]; + " DataSourceExec: partitions=1, partition_sizes=[0]", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -1602,18 +1991,17 @@ async fn test_remove_unnecessary_sort4() -> Result<()> { async fn test_remove_unnecessary_sort6() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let input = Arc::new( - SortExec::new( - LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]), - source, - ) - .with_fetch(Some(2)), + let input = sort_exec_with_fetch( + [sort_expr("non_nullable_col", &schema)].into(), + Some(2), + source, ); let physical_plan = sort_exec( - vec![ + [ sort_expr("non_nullable_col", &schema), sort_expr("nullable_col", &schema), - ], + ] + .into(), input, ); @@ -1635,21 +2023,19 @@ async fn test_remove_unnecessary_sort6() -> Result<()> { async fn test_remove_unnecessary_sort7() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let input = Arc::new(SortExec::new( - LexOrdering::new(vec![ + let input = sort_exec( + [ sort_expr("non_nullable_col", 
&schema), sort_expr("nullable_col", &schema), - ]), + ] + .into(), source, - )); - - let physical_plan = Arc::new( - SortExec::new( - LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]), - input, - ) - .with_fetch(Some(2)), - ) as Arc; + ); + let physical_plan = sort_exec_with_fetch( + [sort_expr("non_nullable_col", &schema)].into(), + Some(2), + input, + ); let expected_input = [ "SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC], preserve_partitioning=[false], sort_prefix=[non_nullable_col@1 ASC]", @@ -1670,16 +2056,14 @@ async fn test_remove_unnecessary_sort7() -> Result<()> { async fn test_remove_unnecessary_sort8() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let input = Arc::new(SortExec::new( - LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]), - source, - )); + let input = sort_exec([sort_expr("non_nullable_col", &schema)].into(), source); let limit = Arc::new(LocalLimitExec::new(input, 2)); let physical_plan = sort_exec( - vec![ + [ sort_expr("non_nullable_col", &schema), sort_expr("nullable_col", &schema), - ], + ] + .into(), limit, ); @@ -1703,13 +2087,9 @@ async fn test_remove_unnecessary_sort8() -> Result<()> { async fn test_do_not_pushdown_through_limit() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - // let input = sort_exec(vec![sort_expr("non_nullable_col", &schema)], source); - let input = Arc::new(SortExec::new( - LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]), - source, - )); + let input = sort_exec([sort_expr("non_nullable_col", &schema)].into(), source); let limit = Arc::new(GlobalLimitExec::new(input, 0, Some(5))) as _; - let physical_plan = sort_exec(vec![sort_expr("nullable_col", &schema)], limit); + let physical_plan = sort_exec([sort_expr("nullable_col", &schema)].into(), limit); let expected_input = [ "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", @@ -1732,12 +2112,11 @@ async fn test_do_not_pushdown_through_limit() -> Result<()> { async fn test_remove_unnecessary_spm1() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let input = - sort_preserving_merge_exec(vec![sort_expr("non_nullable_col", &schema)], source); - let input2 = - sort_preserving_merge_exec(vec![sort_expr("non_nullable_col", &schema)], input); + let ordering: LexOrdering = [sort_expr("non_nullable_col", &schema)].into(); + let input = sort_preserving_merge_exec(ordering.clone(), source); + let input2 = sort_preserving_merge_exec(ordering, input); let physical_plan = - sort_preserving_merge_exec(vec![sort_expr("nullable_col", &schema)], input2); + sort_preserving_merge_exec([sort_expr("nullable_col", &schema)].into(), input2); let expected_input = [ "SortPreservingMergeExec: [nullable_col@0 ASC]", @@ -1759,7 +2138,7 @@ async fn test_remove_unnecessary_spm2() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); let input = sort_preserving_merge_exec_with_fetch( - vec![sort_expr("non_nullable_col", &schema)], + [sort_expr("non_nullable_col", &schema)].into(), source, 100, ); @@ -1782,12 +2161,13 @@ async fn test_remove_unnecessary_spm2() -> Result<()> { async fn test_change_wrong_sorting() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let sort_exprs = vec![ + let sort_exprs = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), ]; - let sort = sort_exec(vec![sort_exprs[0].clone()], source); - let 
physical_plan = sort_preserving_merge_exec(sort_exprs, sort); + let sort = sort_exec([sort_exprs[0].clone()].into(), source); + let physical_plan = sort_preserving_merge_exec(sort_exprs.into(), sort); + let expected_input = [ "SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", @@ -1806,13 +2186,13 @@ async fn test_change_wrong_sorting() -> Result<()> { async fn test_change_wrong_sorting2() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let sort_exprs = vec![ + let sort_exprs = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), ]; - let spm1 = sort_preserving_merge_exec(sort_exprs.clone(), source); - let sort2 = sort_exec(vec![sort_exprs[0].clone()], spm1); - let physical_plan = sort_preserving_merge_exec(vec![sort_exprs[1].clone()], sort2); + let spm1 = sort_preserving_merge_exec(sort_exprs.clone().into(), source); + let sort2 = sort_exec([sort_exprs[0].clone()].into(), spm1); + let physical_plan = sort_preserving_merge_exec([sort_exprs[1].clone()].into(), sort2); let expected_input = [ "SortPreservingMergeExec: [non_nullable_col@1 ASC]", @@ -1833,31 +2213,31 @@ async fn test_change_wrong_sorting2() -> Result<()> { async fn test_multiple_sort_window_exec() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - - let sort_exprs1 = vec![sort_expr("nullable_col", &schema)]; - let sort_exprs2 = vec![ + let ordering1 = [sort_expr("nullable_col", &schema)]; + let sort1 = sort_exec(ordering1.clone().into(), source); + let window_agg1 = bounded_window_exec("non_nullable_col", ordering1.clone(), sort1); + let ordering2 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), ]; + let window_agg2 = bounded_window_exec("non_nullable_col", ordering2, window_agg1); + let physical_plan = bounded_window_exec("non_nullable_col", ordering1, window_agg2); - let sort1 = sort_exec(sort_exprs1.clone(), source); - let window_agg1 = bounded_window_exec("non_nullable_col", sort_exprs1.clone(), sort1); - let window_agg2 = bounded_window_exec("non_nullable_col", sort_exprs2, window_agg1); - // let filter_exec = sort_exec; - let physical_plan = bounded_window_exec("non_nullable_col", sort_exprs1, window_agg2); - - let expected_input = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + let expected_input = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " SortExec: expr=[nullable_col@0 ASC], 
preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]"]; - - let expected_optimized = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " DataSourceExec: partitions=1, partition_sizes=[0]", + ]; + let expected_optimized = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]"]; + " DataSourceExec: partitions=1, partition_sizes=[0]", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -1871,15 +2251,12 @@ async fn test_multiple_sort_window_exec() -> Result<()> { // EnforceDistribution may invalidate ordering invariant. 
async fn test_commutativity() -> Result<()> { let schema = create_test_schema()?; - let config = ConfigOptions::new(); - let memory_exec = memory_exec(&schema); - let sort_exprs = LexOrdering::new(vec![sort_expr("nullable_col", &schema)]); + let sort_exprs = [sort_expr("nullable_col", &schema)]; let window = bounded_window_exec("nullable_col", sort_exprs.clone(), memory_exec); let repartition = repartition_exec(window); + let orig_plan = sort_exec(sort_exprs.into(), repartition); - let orig_plan = - Arc::new(SortExec::new(sort_exprs, repartition)) as Arc; let actual = get_plan_string(&orig_plan); let expected_input = vec![ "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", @@ -1892,26 +2269,25 @@ async fn test_commutativity() -> Result<()> { "\n**Original Plan Mismatch\n\nexpected:\n\n{expected_input:#?}\nactual:\n\n{actual:#?}\n\n" ); - let mut plan = orig_plan.clone(); + let config = ConfigOptions::new(); let rules = vec![ Arc::new(EnforceDistribution::new()) as Arc, Arc::new(EnforceSorting::new()) as Arc, ]; + let mut first_plan = orig_plan.clone(); for rule in rules { - plan = rule.optimize(plan, &config)?; + first_plan = rule.optimize(first_plan, &config)?; } - let first_plan = plan.clone(); - let mut plan = orig_plan.clone(); let rules = vec![ Arc::new(EnforceSorting::new()) as Arc, Arc::new(EnforceDistribution::new()) as Arc, Arc::new(EnforceSorting::new()) as Arc, ]; + let mut second_plan = orig_plan.clone(); for rule in rules { - plan = rule.optimize(plan, &config)?; + second_plan = rule.optimize(second_plan, &config)?; } - let second_plan = plan.clone(); assert_eq!(get_plan_string(&first_plan), get_plan_string(&second_plan)); Ok(()) @@ -1922,15 +2298,15 @@ async fn test_coalesce_propagate() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); let repartition = repartition_exec(source); - let coalesce_partitions = Arc::new(CoalescePartitionsExec::new(repartition)); + let coalesce_partitions = coalesce_partitions_exec(repartition); let repartition = repartition_exec(coalesce_partitions); - let sort_exprs = LexOrdering::new(vec![sort_expr("nullable_col", &schema)]); + let ordering: LexOrdering = [sort_expr("nullable_col", &schema)].into(); // Add local sort let sort = Arc::new( - SortExec::new(sort_exprs.clone(), repartition).with_preserve_partitioning(true), + SortExec::new(ordering.clone(), repartition).with_preserve_partitioning(true), ) as _; - let spm = sort_preserving_merge_exec(sort_exprs.clone(), sort); - let sort = sort_exec(sort_exprs, spm); + let spm = sort_preserving_merge_exec(ordering.clone(), sort); + let sort = sort_exec(ordering, spm); let physical_plan = sort.clone(); // Sort Parallelize rule should end Coalesce + Sort linkage when Sort is Global Sort @@ -1958,15 +2334,15 @@ async fn test_coalesce_propagate() -> Result<()> { #[tokio::test] async fn test_replace_with_partial_sort2() -> Result<()> { let schema = create_test_schema3()?; - let input_sort_exprs = vec![sort_expr("a", &schema), sort_expr("c", &schema)]; - let unbounded_input = stream_exec_ordered(&schema, input_sort_exprs); - + let input_ordering = [sort_expr("a", &schema), sort_expr("c", &schema)].into(); + let unbounded_input = stream_exec_ordered(&schema, input_ordering); let physical_plan = sort_exec( - vec![ + [ sort_expr("a", &schema), sort_expr("c", &schema), sort_expr("d", &schema), - ], + ] + .into(), unbounded_input, ); @@ -1985,76 +2361,65 @@ async fn test_replace_with_partial_sort2() -> Result<()> { #[tokio::test] async fn 
test_push_with_required_input_ordering_prohibited() -> Result<()> {
-    // SortExec: expr=[b]            <-- can't push this down
-    //  RequiredInputOrder expr=[a]  <-- this requires input sorted by a, and preserves the input order
-    //    SortExec: expr=[a]
-    //      DataSourceExec
     let schema = create_test_schema3()?;
-    let sort_exprs_a = LexOrdering::new(vec![sort_expr("a", &schema)]);
-    let sort_exprs_b = LexOrdering::new(vec![sort_expr("b", &schema)]);
+    let ordering_a: LexOrdering = [sort_expr("a", &schema)].into();
+    let ordering_b: LexOrdering = [sort_expr("b", &schema)].into();
     let plan = memory_exec(&schema);
-    let plan = sort_exec(sort_exprs_a.clone(), plan);
+    let plan = sort_exec(ordering_a.clone(), plan);
     let plan = RequirementsTestExec::new(plan)
-        .with_required_input_ordering(sort_exprs_a)
+        .with_required_input_ordering(Some(ordering_a))
         .with_maintains_input_order(true)
         .into_arc();
-    let plan = sort_exec(sort_exprs_b, plan);
+    let plan = sort_exec(ordering_b, plan);

     let expected_input = [
-        "SortExec: expr=[b@1 ASC], preserve_partitioning=[false]",
-        "  RequiredInputOrderingExec",
+        "SortExec: expr=[b@1 ASC], preserve_partitioning=[false]", // <-- can't push this down
+        "  RequiredInputOrderingExec", // <-- this requires input sorted by a, and preserves the input order
         "    SortExec: expr=[a@0 ASC], preserve_partitioning=[false]",
         "      DataSourceExec: partitions=1, partition_sizes=[0]",
     ];
     // Should not be able to push down the sort
-    let expected_no_change = expected_input;
-    assert_optimized!(expected_input, expected_no_change, plan, true);
+    assert_optimized!(expected_input, expected_input, plan, true);

     Ok(())
 }

-// Test that the sort is pushed through when the required input ordering is satisfied
+// Test that the sort is pushed through when the required input ordering is satisfied
 #[tokio::test]
 async fn test_push_with_required_input_ordering_allowed() -> Result<()> {
-    // SortExec: expr=[a,b]          <-- can push this down (as it is compatible with the required input ordering)
-    //  RequiredInputOrder expr=[a]  <-- this requires input sorted by a, and preserves the input order
-    //    SortExec: expr=[a]
-    //      DataSourceExec
     let schema = create_test_schema3()?;
-    let sort_exprs_a = LexOrdering::new(vec![sort_expr("a", &schema)]);
-    let sort_exprs_ab =
-        LexOrdering::new(vec![sort_expr("a", &schema), sort_expr("b", &schema)]);
+    let ordering_a: LexOrdering = [sort_expr("a", &schema)].into();
+    let ordering_ab = [sort_expr("a", &schema), sort_expr("b", &schema)].into();
     let plan = memory_exec(&schema);
-    let plan = sort_exec(sort_exprs_a.clone(), plan);
+    let plan = sort_exec(ordering_a.clone(), plan);
     let plan = RequirementsTestExec::new(plan)
-        .with_required_input_ordering(sort_exprs_a)
+        .with_required_input_ordering(Some(ordering_a))
         .with_maintains_input_order(true)
         .into_arc();
-    let plan = sort_exec(sort_exprs_ab, plan);
+    let plan = sort_exec(ordering_ab, plan);

     let expected_input = [
-        "SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false]",
-        "  RequiredInputOrderingExec",
+        "SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false]", // <-- can push this down (as it is compatible with the required input ordering)
+        "  RequiredInputOrderingExec", // <-- this requires input sorted by a, and preserves the input order
         "    SortExec: expr=[a@0 ASC], preserve_partitioning=[false]",
         "      DataSourceExec: partitions=1, partition_sizes=[0]",
     ];
-    // should able to push shorts
-    let expected = [
+    // Should be able to push down
+    let expected_optimized = [
         "RequiredInputOrderingExec",
         "  SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false]",
         "    DataSourceExec: partitions=1, 
partition_sizes=[0]", ]; - assert_optimized!(expected_input, expected, plan, true); + assert_optimized!(expected_input, expected_optimized, plan, true); Ok(()) } #[tokio::test] async fn test_replace_with_partial_sort() -> Result<()> { let schema = create_test_schema3()?; - let input_sort_exprs = vec![sort_expr("a", &schema)]; - let unbounded_input = stream_exec_ordered(&schema, input_sort_exprs); - + let input_ordering = [sort_expr("a", &schema)].into(); + let unbounded_input = stream_exec_ordered(&schema, input_ordering); let physical_plan = sort_exec( - vec![sort_expr("a", &schema), sort_expr("c", &schema)], + [sort_expr("a", &schema), sort_expr("c", &schema)].into(), unbounded_input, ); @@ -2073,38 +2438,38 @@ async fn test_replace_with_partial_sort() -> Result<()> { #[tokio::test] async fn test_not_replaced_with_partial_sort_for_unbounded_input() -> Result<()> { let schema = create_test_schema3()?; - let input_sort_exprs = vec![sort_expr("b", &schema), sort_expr("c", &schema)]; - let unbounded_input = stream_exec_ordered(&schema, input_sort_exprs); - + let input_ordering = [sort_expr("b", &schema), sort_expr("c", &schema)].into(); + let unbounded_input = stream_exec_ordered(&schema, input_ordering); let physical_plan = sort_exec( - vec![ + [ sort_expr("a", &schema), sort_expr("b", &schema), sort_expr("c", &schema), - ], + ] + .into(), unbounded_input, ); let expected_input = [ "SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[false]", " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC]" ]; - let expected_no_change = expected_input; - assert_optimized!(expected_input, expected_no_change, physical_plan, true); + assert_optimized!(expected_input, expected_input, physical_plan, true); Ok(()) } #[tokio::test] async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { let input_schema = create_test_schema()?; - let sort_exprs = vec![sort_expr_options( + let ordering = [sort_expr_options( "nullable_col", &input_schema, SortOptions { descending: false, nulls_first: false, }, - )]; - let source = parquet_exec_sorted(&input_schema, sort_exprs); + )] + .into(); + let source = parquet_exec_with_sort(input_schema.clone(), vec![ordering]) as _; // Function definition - Alias of the resulting column - Arguments of the function #[derive(Clone)] @@ -3307,7 +3672,7 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { case.func.1, &case.func.2, &partition_by, - &LexOrdering::default(), + &[], case.window_frame, input_schema.as_ref(), false, @@ -3341,7 +3706,8 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { ) }) .collect::>(); - let physical_plan = sort_exec(sort_expr, window_exec); + let ordering = LexOrdering::new(sort_expr).unwrap(); + let physical_plan = sort_exec(ordering, window_exec); assert_optimized!( case.initial_plan, @@ -3358,11 +3724,11 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { #[test] fn test_removes_unused_orthogonal_sort() -> Result<()> { let schema = create_test_schema3()?; - let input_sort_exprs = vec![sort_expr("b", &schema), sort_expr("c", &schema)]; - let unbounded_input = stream_exec_ordered(&schema, input_sort_exprs.clone()); - - let orthogonal_sort = sort_exec(vec![sort_expr("a", &schema)], unbounded_input); - let output_sort = sort_exec(input_sort_exprs, orthogonal_sort); // same sort as data source + let input_ordering: LexOrdering = + [sort_expr("b", &schema), sort_expr("c", 
&schema)].into(); + let unbounded_input = stream_exec_ordered(&schema, input_ordering.clone()); + let orthogonal_sort = sort_exec([sort_expr("a", &schema)].into(), unbounded_input); + let output_sort = sort_exec(input_ordering, orthogonal_sort); // same sort as data source // Test scenario/input has an orthogonal sort: let expected_input = [ @@ -3370,7 +3736,7 @@ fn test_removes_unused_orthogonal_sort() -> Result<()> { " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC]" ]; - assert_eq!(get_plan_string(&output_sort), expected_input,); + assert_eq!(get_plan_string(&output_sort), expected_input); // Test: should remove orthogonal sort, and the uppermost (unneeded) sort: let expected_optimized = [ @@ -3384,12 +3750,12 @@ fn test_removes_unused_orthogonal_sort() -> Result<()> { #[test] fn test_keeps_used_orthogonal_sort() -> Result<()> { let schema = create_test_schema3()?; - let input_sort_exprs = vec![sort_expr("b", &schema), sort_expr("c", &schema)]; - let unbounded_input = stream_exec_ordered(&schema, input_sort_exprs.clone()); - + let input_ordering: LexOrdering = + [sort_expr("b", &schema), sort_expr("c", &schema)].into(); + let unbounded_input = stream_exec_ordered(&schema, input_ordering.clone()); let orthogonal_sort = - sort_exec_with_fetch(vec![sort_expr("a", &schema)], Some(3), unbounded_input); // has fetch, so this orthogonal sort changes the output - let output_sort = sort_exec(input_sort_exprs, orthogonal_sort); + sort_exec_with_fetch([sort_expr("a", &schema)].into(), Some(3), unbounded_input); // has fetch, so this orthogonal sort changes the output + let output_sort = sort_exec(input_ordering, orthogonal_sort); // Test scenario/input has an orthogonal sort: let expected_input = [ @@ -3397,7 +3763,7 @@ fn test_keeps_used_orthogonal_sort() -> Result<()> { " SortExec: TopK(fetch=3), expr=[a@0 ASC], preserve_partitioning=[false]", " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC]" ]; - assert_eq!(get_plan_string(&output_sort), expected_input,); + assert_eq!(get_plan_string(&output_sort), expected_input); // Test: should keep the orthogonal sort, since it modifies the output: let expected_optimized = expected_input; @@ -3409,15 +3775,17 @@ fn test_keeps_used_orthogonal_sort() -> Result<()> { #[test] fn test_handles_multiple_orthogonal_sorts() -> Result<()> { let schema = create_test_schema3()?; - let input_sort_exprs = vec![sort_expr("b", &schema), sort_expr("c", &schema)]; - let unbounded_input = stream_exec_ordered(&schema, input_sort_exprs.clone()); - - let orthogonal_sort_0 = sort_exec(vec![sort_expr("c", &schema)], unbounded_input); // has no fetch, so can be removed + let input_ordering: LexOrdering = + [sort_expr("b", &schema), sort_expr("c", &schema)].into(); + let unbounded_input = stream_exec_ordered(&schema, input_ordering.clone()); + let ordering0: LexOrdering = [sort_expr("c", &schema)].into(); + let orthogonal_sort_0 = sort_exec(ordering0.clone(), unbounded_input); // has no fetch, so can be removed + let ordering1: LexOrdering = [sort_expr("a", &schema)].into(); let orthogonal_sort_1 = - sort_exec_with_fetch(vec![sort_expr("a", &schema)], Some(3), orthogonal_sort_0); // has fetch, so this orthogonal sort changes the output - let orthogonal_sort_2 = sort_exec(vec![sort_expr("c", &schema)], orthogonal_sort_1); // has no fetch, so can be removed - let 
orthogonal_sort_3 = sort_exec(vec![sort_expr("a", &schema)], orthogonal_sort_2); // has no fetch, so can be removed - let output_sort = sort_exec(input_sort_exprs, orthogonal_sort_3); // final sort + sort_exec_with_fetch(ordering1.clone(), Some(3), orthogonal_sort_0); // has fetch, so this orthogonal sort changes the output + let orthogonal_sort_2 = sort_exec(ordering0, orthogonal_sort_1); // has no fetch, so can be removed + let orthogonal_sort_3 = sort_exec(ordering1, orthogonal_sort_2); // has no fetch, so can be removed + let output_sort = sort_exec(input_ordering, orthogonal_sort_3); // final sort // Test scenario/input has an orthogonal sort: let expected_input = [ @@ -3428,7 +3796,7 @@ fn test_handles_multiple_orthogonal_sorts() -> Result<()> { " SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC]", ]; - assert_eq!(get_plan_string(&output_sort), expected_input,); + assert_eq!(get_plan_string(&output_sort), expected_input); // Test: should keep only the needed orthogonal sort, and remove the unneeded ones: let expected_optimized = [ @@ -3445,13 +3813,14 @@ fn test_handles_multiple_orthogonal_sorts() -> Result<()> { fn test_parallelize_sort_preserves_fetch() -> Result<()> { // Create a schema let schema = create_test_schema3()?; - let parquet_exec = parquet_exec(&schema); - let coalesced = Arc::new(CoalescePartitionsExec::new(parquet_exec.clone())); - let top_coalesced = - Arc::new(CoalescePartitionsExec::new(coalesced.clone()).with_fetch(Some(10))); + let parquet_exec = parquet_exec(schema); + let coalesced = coalesce_partitions_exec(parquet_exec.clone()); + let top_coalesced = coalesce_partitions_exec(coalesced.clone()) + .with_fetch(Some(10)) + .unwrap(); let requirements = PlanWithCorrespondingCoalescePartitions::new( - top_coalesced.clone(), + top_coalesced, true, vec![PlanWithCorrespondingCoalescePartitions::new( coalesced, diff --git a/datafusion/core/tests/physical_optimizer/limit_pushdown.rs b/datafusion/core/tests/physical_optimizer/limit_pushdown.rs index dd2c1960a658..56d48901f284 100644 --- a/datafusion/core/tests/physical_optimizer/limit_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/limit_pushdown.rs @@ -17,28 +17,26 @@ use std::sync::Arc; +use crate::physical_optimizer::test_utils::{ + coalesce_batches_exec, coalesce_partitions_exec, global_limit_exec, local_limit_exec, + sort_exec, sort_preserving_merge_exec, stream_exec, +}; + use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::config::ConfigOptions; use datafusion_common::error::Result; -use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::Operator; -use datafusion_physical_expr::expressions::BinaryExpr; -use datafusion_physical_expr::expressions::{col, lit}; -use datafusion_physical_expr::{Partitioning, PhysicalSortExpr}; +use datafusion_physical_expr::expressions::{col, lit, BinaryExpr}; +use datafusion_physical_expr::Partitioning; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use datafusion_physical_optimizer::limit_pushdown::LimitPushdown; use datafusion_physical_optimizer::PhysicalOptimizerRule; -use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; -use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::filter::FilterExec; 
-use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::repartition::RepartitionExec; -use datafusion_physical_plan::sorts::sort::SortExec; -use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; -use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; -use datafusion_physical_plan::{get_plan_string, ExecutionPlan, ExecutionPlanProperties}; +use datafusion_physical_plan::{get_plan_string, ExecutionPlan}; fn create_schema() -> SchemaRef { Arc::new(Schema::new(vec![ @@ -48,48 +46,6 @@ fn create_schema() -> SchemaRef { ])) } -fn streaming_table_exec(schema: SchemaRef) -> Result> { - Ok(Arc::new(StreamingTableExec::try_new( - Arc::clone(&schema), - vec![Arc::new(DummyStreamPartition { schema }) as _], - None, - None, - true, - None, - )?)) -} - -fn global_limit_exec( - input: Arc, - skip: usize, - fetch: Option, -) -> Arc { - Arc::new(GlobalLimitExec::new(input, skip, fetch)) -} - -fn local_limit_exec( - input: Arc, - fetch: usize, -) -> Arc { - Arc::new(LocalLimitExec::new(input, fetch)) -} - -fn sort_exec( - sort_exprs: impl IntoIterator, - input: Arc, -) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(SortExec::new(sort_exprs, input)) -} - -fn sort_preserving_merge_exec( - sort_exprs: impl IntoIterator, - input: Arc, -) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(SortPreservingMergeExec::new(sort_exprs, input)) -} - fn projection_exec( schema: SchemaRef, input: Arc, @@ -118,16 +74,6 @@ fn filter_exec( )?)) } -fn coalesce_batches_exec(input: Arc) -> Arc { - Arc::new(CoalesceBatchesExec::new(input, 8192)) -} - -fn coalesce_partitions_exec( - local_limit: Arc, -) -> Arc { - Arc::new(CoalescePartitionsExec::new(local_limit)) -} - fn repartition_exec( streaming_table: Arc, ) -> Result> { @@ -141,24 +87,11 @@ fn empty_exec(schema: SchemaRef) -> Arc { Arc::new(EmptyExec::new(schema)) } -#[derive(Debug)] -struct DummyStreamPartition { - schema: SchemaRef, -} -impl PartitionStream for DummyStreamPartition { - fn schema(&self) -> &SchemaRef { - &self.schema - } - fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { - unreachable!() - } -} - #[test] fn transforms_streaming_table_exec_into_fetching_version_when_skip_is_zero() -> Result<()> { let schema = create_schema(); - let streaming_table = streaming_table_exec(schema)?; + let streaming_table = stream_exec(&schema); let global_limit = global_limit_exec(streaming_table, 0, Some(5)); let initial = get_plan_string(&global_limit); @@ -183,7 +116,7 @@ fn transforms_streaming_table_exec_into_fetching_version_when_skip_is_zero() -> fn transforms_streaming_table_exec_into_fetching_version_and_keeps_the_global_limit_when_skip_is_nonzero( ) -> Result<()> { let schema = create_schema(); - let streaming_table = streaming_table_exec(schema)?; + let streaming_table = stream_exec(&schema); let global_limit = global_limit_exec(streaming_table, 2, Some(5)); let initial = get_plan_string(&global_limit); @@ -209,10 +142,10 @@ fn transforms_streaming_table_exec_into_fetching_version_and_keeps_the_global_li fn transforms_coalesce_batches_exec_into_fetching_version_and_removes_local_limit( ) -> Result<()> { let schema = create_schema(); - let streaming_table = streaming_table_exec(Arc::clone(&schema))?; + let streaming_table = stream_exec(&schema); let repartition = repartition_exec(streaming_table)?; let filter = filter_exec(schema, repartition)?; - let 
coalesce_batches = coalesce_batches_exec(filter); + let coalesce_batches = coalesce_batches_exec(filter, 8192); let local_limit = local_limit_exec(coalesce_batches, 5); let coalesce_partitions = coalesce_partitions_exec(local_limit); let global_limit = global_limit_exec(coalesce_partitions, 0, Some(5)); @@ -247,7 +180,7 @@ fn transforms_coalesce_batches_exec_into_fetching_version_and_removes_local_limi #[test] fn pushes_global_limit_exec_through_projection_exec() -> Result<()> { let schema = create_schema(); - let streaming_table = streaming_table_exec(Arc::clone(&schema))?; + let streaming_table = stream_exec(&schema); let filter = filter_exec(Arc::clone(&schema), streaming_table)?; let projection = projection_exec(schema, filter)?; let global_limit = global_limit_exec(projection, 0, Some(5)); @@ -279,8 +212,8 @@ fn pushes_global_limit_exec_through_projection_exec() -> Result<()> { fn pushes_global_limit_exec_through_projection_exec_and_transforms_coalesce_batches_exec_into_fetching_version( ) -> Result<()> { let schema = create_schema(); - let streaming_table = streaming_table_exec(Arc::clone(&schema)).unwrap(); - let coalesce_batches = coalesce_batches_exec(streaming_table); + let streaming_table = stream_exec(&schema); + let coalesce_batches = coalesce_batches_exec(streaming_table, 8192); let projection = projection_exec(schema, coalesce_batches)?; let global_limit = global_limit_exec(projection, 0, Some(5)); @@ -310,18 +243,17 @@ fn pushes_global_limit_exec_through_projection_exec_and_transforms_coalesce_batc #[test] fn pushes_global_limit_into_multiple_fetch_plans() -> Result<()> { let schema = create_schema(); - let streaming_table = streaming_table_exec(Arc::clone(&schema)).unwrap(); - let coalesce_batches = coalesce_batches_exec(streaming_table); + let streaming_table = stream_exec(&schema); + let coalesce_batches = coalesce_batches_exec(streaming_table, 8192); let projection = projection_exec(Arc::clone(&schema), coalesce_batches)?; let repartition = repartition_exec(projection)?; - let sort = sort_exec( - vec![PhysicalSortExpr { - expr: col("c1", &schema)?, - options: SortOptions::default(), - }], - repartition, - ); - let spm = sort_preserving_merge_exec(sort.output_ordering().unwrap().to_vec(), sort); + let ordering: LexOrdering = [PhysicalSortExpr { + expr: col("c1", &schema)?, + options: SortOptions::default(), + }] + .into(); + let sort = sort_exec(ordering.clone(), repartition); + let spm = sort_preserving_merge_exec(ordering, sort); let global_limit = global_limit_exec(spm, 0, Some(5)); let initial = get_plan_string(&global_limit); @@ -357,7 +289,7 @@ fn pushes_global_limit_into_multiple_fetch_plans() -> Result<()> { fn keeps_pushed_local_limit_exec_when_there_are_multiple_input_partitions() -> Result<()> { let schema = create_schema(); - let streaming_table = streaming_table_exec(Arc::clone(&schema))?; + let streaming_table = stream_exec(&schema); let repartition = repartition_exec(streaming_table)?; let filter = filter_exec(schema, repartition)?; let coalesce_partitions = coalesce_partitions_exec(filter); diff --git a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs index f9810eab8f59..409d392f7819 100644 --- a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs @@ -30,9 +30,8 @@ use datafusion::prelude::SessionContext; use datafusion_common::Result; use 
datafusion_execution::config::SessionConfig; use datafusion_expr::Operator; -use datafusion_physical_expr::expressions::cast; -use datafusion_physical_expr::{expressions, expressions::col, PhysicalSortExpr}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_expr::expressions::{self, cast, col}; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use datafusion_physical_plan::{ aggregates::{AggregateExec, AggregateMode}, collect, @@ -236,12 +235,13 @@ async fn test_distinct_cols_different_than_group_by_cols() -> Result<()> { #[test] fn test_has_order_by() -> Result<()> { - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema()).unwrap(), + let schema = schema(); + let sort_key = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); - let source = parquet_exec_with_sort(vec![sort_key]); - let schema = source.schema(); + }] + .into(); + let source = parquet_exec_with_sort(schema.clone(), vec![sort_key]); // `SELECT a FROM DataSourceExec WHERE a > 1 GROUP BY a LIMIT 10;`, Single AggregateExec // the `a > 1` filter is applied in the AggregateExec @@ -263,7 +263,7 @@ fn test_has_order_by() -> Result<()> { "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], ordering_mode=Sorted", "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet", ]; - let plan: Arc = Arc::new(limit_exec); + let plan = Arc::new(limit_exec) as _; assert_plan_matches_expected(&plan, &expected)?; Ok(()) } diff --git a/datafusion/core/tests/physical_optimizer/partition_statistics.rs b/datafusion/core/tests/physical_optimizer/partition_statistics.rs index 62f04f2fe740..90124e0fcfc7 100644 --- a/datafusion/core/tests/physical_optimizer/partition_statistics.rs +++ b/datafusion/core/tests/physical_optimizer/partition_statistics.rs @@ -17,6 +17,8 @@ #[cfg(test)] mod test { + use std::sync::Arc; + use arrow::array::{Int32Array, RecordBatch}; use arrow_schema::{DataType, Field, Schema, SortOptions}; use datafusion::datasource::listing::ListingTable; @@ -32,7 +34,7 @@ mod test { use datafusion_physical_expr::aggregate::AggregateExprBuilder; use datafusion_physical_expr::expressions::{binary, col, lit, Column}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; - use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; + use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; @@ -49,8 +51,8 @@ mod test { execute_stream_partitioned, get_plan_string, ExecutionPlan, ExecutionPlanProperties, }; + use futures::TryStreamExt; - use std::sync::Arc; /// Creates a test table with statistics from the test data directory. 
/// @@ -258,17 +260,12 @@ mod test { async fn test_statistics_by_partition_of_sort() -> Result<()> { let scan_1 = create_scan_exec_with_statistics(None, Some(1)).await; // Add sort execution plan - let sort = SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::new(Column::new("id", 0)), - options: SortOptions { - descending: false, - nulls_first: false, - }, - }]), - scan_1, - ); - let sort_exec: Arc = Arc::new(sort.clone()); + let ordering = [PhysicalSortExpr::new( + Arc::new(Column::new("id", 0)), + SortOptions::new(false, false), + )]; + let sort = SortExec::new(ordering.clone().into(), scan_1); + let sort_exec: Arc = Arc::new(sort); let statistics = (0..sort_exec.output_partitioning().partition_count()) .map(|idx| sort_exec.partition_statistics(Some(idx))) .collect::>>()?; @@ -284,17 +281,7 @@ mod test { let scan_2 = create_scan_exec_with_statistics(None, Some(2)).await; // Add sort execution plan let sort_exec: Arc = Arc::new( - SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::new(Column::new("id", 0)), - options: SortOptions { - descending: false, - nulls_first: false, - }, - }]), - scan_2, - ) - .with_preserve_partitioning(true), + SortExec::new(ordering.into(), scan_2).with_preserve_partitioning(true), ); let expected_statistic_partition_1 = create_partition_statistics(2, 110, 3, 4, true); diff --git a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs index 7c00d323a8e6..f2958deb577e 100644 --- a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs @@ -25,21 +25,22 @@ use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::physical_plan::CsvSource; use datafusion::datasource::source::DataSourceExec; use datafusion_common::config::ConfigOptions; -use datafusion_common::Result; -use datafusion_common::{JoinSide, JoinType, ScalarValue}; +use datafusion_common::{JoinSide, JoinType, Result, ScalarValue}; +use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::{ Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_physical_expr::expressions::{ binary, cast, col, BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, }; -use datafusion_physical_expr::ScalarFunctionExpr; -use datafusion_physical_expr::{ - Distribution, Partitioning, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, +use datafusion_physical_expr::{Distribution, Partitioning, ScalarFunctionExpr}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{ + OrderingRequirements, PhysicalSortExpr, PhysicalSortRequirement, }; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; use datafusion_physical_optimizer::output_requirements::OutputRequirementExec; use datafusion_physical_optimizer::projection_pushdown::ProjectionPushdown; use datafusion_physical_optimizer::PhysicalOptimizerRule; @@ -54,13 +55,10 @@ use datafusion_physical_plan::projection::{update_expr, ProjectionExec}; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::sort::SortExec; use 
datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; -use datafusion_physical_plan::streaming::PartitionStream; -use datafusion_physical_plan::streaming::StreamingTableExec; +use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; use datafusion_physical_plan::union::UnionExec; use datafusion_physical_plan::{get_plan_string, ExecutionPlan}; -use datafusion_datasource::file_scan_config::FileScanConfigBuilder; -use datafusion_expr_common::columnar_value::ColumnarValue; use itertools::Itertools; /// Mocked UDF @@ -519,7 +517,7 @@ fn test_streaming_table_after_projection() -> Result<()> { }) as _], Some(&vec![0_usize, 2, 4, 3]), vec![ - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: Arc::new(Column::new("e", 2)), options: SortOptions::default(), @@ -528,11 +526,13 @@ fn test_streaming_table_after_projection() -> Result<()> { expr: Arc::new(Column::new("a", 0)), options: SortOptions::default(), }, - ]), - LexOrdering::new(vec![PhysicalSortExpr { + ] + .into(), + [PhysicalSortExpr { expr: Arc::new(Column::new("d", 3)), options: SortOptions::default(), - }]), + }] + .into(), ] .into_iter(), true, @@ -579,7 +579,7 @@ fn test_streaming_table_after_projection() -> Result<()> { assert_eq!( result.projected_output_ordering().into_iter().collect_vec(), vec![ - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: Arc::new(Column::new("e", 1)), options: SortOptions::default(), @@ -588,11 +588,13 @@ fn test_streaming_table_after_projection() -> Result<()> { expr: Arc::new(Column::new("a", 2)), options: SortOptions::default(), }, - ]), - LexOrdering::new(vec![PhysicalSortExpr { + ] + .into(), + [PhysicalSortExpr { expr: Arc::new(Column::new("d", 0)), options: SortOptions::default(), - }]), + }] + .into(), ] ); assert!(result.is_infinite()); @@ -652,21 +654,24 @@ fn test_projection_after_projection() -> Result<()> { fn test_output_req_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); let sort_req: Arc = Arc::new(OutputRequirementExec::new( - csv.clone(), - Some(LexRequirement::new(vec![ - PhysicalSortRequirement { - expr: Arc::new(Column::new("b", 1)), - options: Some(SortOptions::default()), - }, - PhysicalSortRequirement { - expr: Arc::new(BinaryExpr::new( - Arc::new(Column::new("c", 2)), - Operator::Plus, - Arc::new(Column::new("a", 0)), - )), - options: Some(SortOptions::default()), - }, - ])), + csv, + Some(OrderingRequirements::new( + [ + PhysicalSortRequirement::new( + Arc::new(Column::new("b", 1)), + Some(SortOptions::default()), + ), + PhysicalSortRequirement::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Plus, + Arc::new(Column::new("a", 0)), + )), + Some(SortOptions::default()), + ), + ] + .into(), + )), Distribution::HashPartitioned(vec![ Arc::new(Column::new("a", 0)), Arc::new(Column::new("b", 1)), @@ -699,20 +704,23 @@ fn test_output_req_after_projection() -> Result<()> { ]; assert_eq!(get_plan_string(&after_optimize), expected); - let expected_reqs = LexRequirement::new(vec![ - PhysicalSortRequirement { - expr: Arc::new(Column::new("b", 2)), - options: Some(SortOptions::default()), - }, - PhysicalSortRequirement { - expr: Arc::new(BinaryExpr::new( - Arc::new(Column::new("c", 0)), - Operator::Plus, - Arc::new(Column::new("new_a", 1)), - )), - options: Some(SortOptions::default()), - }, - ]); + let expected_reqs = OrderingRequirements::new( + [ + PhysicalSortRequirement::new( + Arc::new(Column::new("b", 2)), + Some(SortOptions::default()), + ), + PhysicalSortRequirement::new( + 
Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 0)), + Operator::Plus, + Arc::new(Column::new("new_a", 1)), + )), + Some(SortOptions::default()), + ), + ] + .into(), + ); assert_eq!( after_optimize .as_any() @@ -795,15 +803,15 @@ fn test_filter_after_projection() -> Result<()> { Arc::new(Column::new("a", 0)), )), )); - let filter: Arc = Arc::new(FilterExec::try_new(predicate, csv)?); - let projection: Arc = Arc::new(ProjectionExec::try_new( + let filter = Arc::new(FilterExec::try_new(predicate, csv)?); + let projection = Arc::new(ProjectionExec::try_new( vec![ (Arc::new(Column::new("a", 0)), "a_new".to_string()), (Arc::new(Column::new("b", 1)), "b".to_string()), (Arc::new(Column::new("d", 3)), "d".to_string()), ], filter.clone(), - )?); + )?) as _; let initial = get_plan_string(&projection); let expected_initial = [ @@ -880,7 +888,7 @@ fn test_join_after_projection() -> Result<()> { None, StreamJoinPartitionMode::SinglePartition, )?); - let projection: Arc = Arc::new(ProjectionExec::try_new( + let projection = Arc::new(ProjectionExec::try_new( vec![ (Arc::new(Column::new("c", 2)), "c_from_left".to_string()), (Arc::new(Column::new("b", 1)), "b_from_left".to_string()), @@ -889,7 +897,7 @@ fn test_join_after_projection() -> Result<()> { (Arc::new(Column::new("c", 7)), "c_from_right".to_string()), ], join, - )?); + )?) as _; let initial = get_plan_string(&projection); let expected_initial = [ "ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, a@5 as a_from_right, c@7 as c_from_right]", @@ -945,7 +953,7 @@ fn test_join_after_required_projection() -> Result<()> { let left_csv = create_simple_csv_exec(); let right_csv = create_simple_csv_exec(); - let join: Arc = Arc::new(SymmetricHashJoinExec::try_new( + let join = Arc::new(SymmetricHashJoinExec::try_new( left_csv, right_csv, vec![(Arc::new(Column::new("b", 1)), Arc::new(Column::new("c", 2)))], @@ -994,7 +1002,7 @@ fn test_join_after_required_projection() -> Result<()> { None, StreamJoinPartitionMode::SinglePartition, )?); - let projection: Arc = Arc::new(ProjectionExec::try_new( + let projection = Arc::new(ProjectionExec::try_new( vec![ (Arc::new(Column::new("a", 5)), "a".to_string()), (Arc::new(Column::new("b", 6)), "b".to_string()), @@ -1008,7 +1016,7 @@ fn test_join_after_required_projection() -> Result<()> { (Arc::new(Column::new("e", 4)), "e".to_string()), ], join, - )?); + )?) as _; let initial = get_plan_string(&projection); let expected_initial = [ "ProjectionExec: expr=[a@5 as a, b@6 as b, c@7 as c, d@8 as d, e@9 as e, a@0 as a, b@1 as b, c@2 as c, d@3 as d, e@4 as e]", @@ -1061,7 +1069,7 @@ fn test_nested_loop_join_after_projection() -> Result<()> { Field::new("c", DataType::Int32, true), ]); - let join: Arc = Arc::new(NestedLoopJoinExec::try_new( + let join = Arc::new(NestedLoopJoinExec::try_new( left_csv, right_csv, Some(JoinFilter::new( @@ -1071,12 +1079,12 @@ fn test_nested_loop_join_after_projection() -> Result<()> { )), &JoinType::Inner, None, - )?); + )?) as _; - let projection: Arc = Arc::new(ProjectionExec::try_new( + let projection = Arc::new(ProjectionExec::try_new( vec![(col_left_c, "c".to_string())], Arc::clone(&join), - )?); + )?) 
as _; let initial = get_plan_string(&projection); let expected_initial = [ "ProjectionExec: expr=[c@2 as c]", @@ -1104,7 +1112,7 @@ fn test_hash_join_after_projection() -> Result<()> { let left_csv = create_simple_csv_exec(); let right_csv = create_simple_csv_exec(); - let join: Arc = Arc::new(HashJoinExec::try_new( + let join = Arc::new(HashJoinExec::try_new( left_csv, right_csv, vec![(Arc::new(Column::new("b", 1)), Arc::new(Column::new("c", 2)))], @@ -1152,7 +1160,7 @@ fn test_hash_join_after_projection() -> Result<()> { PartitionMode::Auto, true, )?); - let projection: Arc = Arc::new(ProjectionExec::try_new( + let projection = Arc::new(ProjectionExec::try_new( vec![ (Arc::new(Column::new("c", 2)), "c_from_left".to_string()), (Arc::new(Column::new("b", 1)), "b_from_left".to_string()), @@ -1160,7 +1168,7 @@ fn test_hash_join_after_projection() -> Result<()> { (Arc::new(Column::new("c", 7)), "c_from_right".to_string()), ], join.clone(), - )?); + )?) as _; let initial = get_plan_string(&projection); let expected_initial = [ "ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, c@7 as c_from_right]", " HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" @@ -1174,7 +1182,7 @@ fn test_hash_join_after_projection() -> Result<()> { let expected = ["ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, c@3 as c_from_right]", " HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2, projection=[a@0, b@1, c@2, c@7]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false"]; assert_eq!(get_plan_string(&after_optimize), expected); - let projection: Arc = Arc::new(ProjectionExec::try_new( + let projection = Arc::new(ProjectionExec::try_new( vec![ (Arc::new(Column::new("a", 0)), "a".to_string()), (Arc::new(Column::new("b", 1)), "b".to_string()), @@ -1197,7 +1205,7 @@ fn test_hash_join_after_projection() -> Result<()> { #[test] fn test_repartition_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); - let repartition: Arc = Arc::new(RepartitionExec::try_new( + let repartition = Arc::new(RepartitionExec::try_new( csv, Partitioning::Hash( vec![ @@ -1208,14 +1216,14 @@ fn test_repartition_after_projection() -> Result<()> { 6, ), )?); - let projection: Arc = Arc::new(ProjectionExec::try_new( + let projection = Arc::new(ProjectionExec::try_new( vec![ (Arc::new(Column::new("b", 1)), "b_new".to_string()), (Arc::new(Column::new("a", 0)), "a".to_string()), (Arc::new(Column::new("d", 3)), "d_new".to_string()), ], repartition, - )?); + )?) 
as _; let initial = get_plan_string(&projection); let expected_initial = [ "ProjectionExec: expr=[b@1 as b_new, a@0 as a, d@3 as d_new]", @@ -1257,31 +1265,26 @@ fn test_repartition_after_projection() -> Result<()> { #[test] fn test_sort_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); - let sort_req: Arc = Arc::new(SortExec::new( - LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(BinaryExpr::new( - Arc::new(Column::new("c", 2)), - Operator::Plus, - Arc::new(Column::new("a", 0)), - )), - options: SortOptions::default(), - }, - ]), - csv.clone(), - )); - let projection: Arc = Arc::new(ProjectionExec::try_new( + let sort_exec = SortExec::new( + [ + PhysicalSortExpr::new_default(Arc::new(Column::new("b", 1))), + PhysicalSortExpr::new_default(Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Plus, + Arc::new(Column::new("a", 0)), + ))), + ] + .into(), + csv, + ); + let projection = Arc::new(ProjectionExec::try_new( vec![ (Arc::new(Column::new("c", 2)), "c".to_string()), (Arc::new(Column::new("a", 0)), "new_a".to_string()), (Arc::new(Column::new("b", 1)), "b".to_string()), ], - sort_req.clone(), - )?); + Arc::new(sort_exec), + )?) as _; let initial = get_plan_string(&projection); let expected_initial = [ @@ -1307,31 +1310,26 @@ fn test_sort_after_projection() -> Result<()> { #[test] fn test_sort_preserving_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); - let sort_req: Arc = Arc::new(SortPreservingMergeExec::new( - LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(BinaryExpr::new( - Arc::new(Column::new("c", 2)), - Operator::Plus, - Arc::new(Column::new("a", 0)), - )), - options: SortOptions::default(), - }, - ]), - csv.clone(), - )); - let projection: Arc = Arc::new(ProjectionExec::try_new( + let sort_exec = SortPreservingMergeExec::new( + [ + PhysicalSortExpr::new_default(Arc::new(Column::new("b", 1))), + PhysicalSortExpr::new_default(Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Plus, + Arc::new(Column::new("a", 0)), + ))), + ] + .into(), + csv, + ); + let projection = Arc::new(ProjectionExec::try_new( vec![ (Arc::new(Column::new("c", 2)), "c".to_string()), (Arc::new(Column::new("a", 0)), "new_a".to_string()), (Arc::new(Column::new("b", 1)), "b".to_string()), ], - sort_req.clone(), - )?); + Arc::new(sort_exec), + )?) as _; let initial = get_plan_string(&projection); let expected_initial = [ @@ -1357,16 +1355,15 @@ fn test_sort_preserving_after_projection() -> Result<()> { #[test] fn test_union_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); - let union: Arc = - Arc::new(UnionExec::new(vec![csv.clone(), csv.clone(), csv])); - let projection: Arc = Arc::new(ProjectionExec::try_new( + let union = Arc::new(UnionExec::new(vec![csv.clone(), csv.clone(), csv])); + let projection = Arc::new(ProjectionExec::try_new( vec![ (Arc::new(Column::new("c", 2)), "c".to_string()), (Arc::new(Column::new("a", 0)), "new_a".to_string()), (Arc::new(Column::new("b", 1)), "b".to_string()), ], union.clone(), - )?); + )?) 
as _; let initial = get_plan_string(&projection); let expected_initial = [ diff --git a/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs index 71b9757604ec..f7c134cae34e 100644 --- a/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs @@ -18,8 +18,10 @@ use std::sync::Arc; use crate::physical_optimizer::test_utils::{ - check_integrity, create_test_schema3, sort_preserving_merge_exec, - stream_exec_ordered_with_projection, + check_integrity, coalesce_batches_exec, coalesce_partitions_exec, + create_test_schema3, parquet_exec_with_sort, sort_exec, + sort_exec_with_preserve_partitioning, sort_preserving_merge_exec, + sort_preserving_merge_exec_with_fetch, stream_exec_ordered_with_projection, }; use datafusion::prelude::SessionContext; @@ -27,30 +29,26 @@ use arrow::array::{ArrayRef, Int32Array}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use datafusion_common::tree_node::{TransformedResult, TreeNode}; +use datafusion_common::{assert_contains, Result}; +use datafusion_common::config::ConfigOptions; +use datafusion_datasource::source::DataSourceExec; use datafusion_execution::TaskContext; +use datafusion_expr::{JoinType, Operator}; +use datafusion_physical_expr::expressions::{self, col, Column}; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_optimizer::enforce_sorting::replace_with_order_preserving_variants::{ + plan_with_order_breaking_variants, plan_with_order_preserving_variants, replace_with_order_preserving_variants, OrderPreservationContext +}; use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; -use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; -use datafusion_physical_plan::collect; use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; use datafusion::datasource::memory::MemorySourceConfig; use datafusion_physical_plan::repartition::RepartitionExec; -use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::{ - displayable, get_plan_string, ExecutionPlan, Partitioning, + collect, displayable, get_plan_string, ExecutionPlan, Partitioning, }; -use datafusion::datasource::source::DataSourceExec; -use datafusion_common::tree_node::{TransformedResult, TreeNode}; -use datafusion_common::{assert_contains, Result}; -use datafusion_expr::{JoinType, Operator}; -use datafusion_physical_expr::expressions::{self, col, Column}; -use datafusion_physical_expr::PhysicalSortExpr; -use datafusion_physical_optimizer::enforce_sorting::replace_with_order_preserving_variants::{plan_with_order_breaking_variants, plan_with_order_preserving_variants, replace_with_order_preserving_variants, OrderPreservationContext}; -use datafusion_common::config::ConfigOptions; -use crate::physical_optimizer::enforce_sorting::parquet_exec_sorted; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use object_store::memory::InMemory; use object_store::ObjectStore; use rstest::rstest; @@ -192,16 +190,15 @@ async fn test_replace_multiple_input_repartition_1( #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { 
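    // Editor's note (illustrative, not part of the patch): the test rewrites in
    // this file switch from `Vec<PhysicalSortExpr>` to `LexOrdering` values built
    // directly from arrays, relying on the `From`/`Into` conversions this PR
    // provides. A minimal sketch of the new idiom, using the helpers these tests
    // already define:
    //
    //     let ordering: LexOrdering = [sort_expr("a", &schema)].into();
    //     let sort = sort_exec_with_preserve_partitioning(ordering.clone(), input);
    //     let physical_plan = sort_preserving_merge_exec(ordering, sort);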
let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; + let sort_exprs: LexOrdering = [sort_expr("a", &schema)].into(); let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) + stream_exec_ordered_with_projection(&schema, sort_exprs.clone()) } else { - memory_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, sort_exprs.clone()) }; let repartition = repartition_exec_hash(repartition_exec_round_robin(source)); - let sort = sort_exec(vec![sort_expr("a", &schema)], repartition, true); - - let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); + let sort = sort_exec_with_preserve_partitioning(sort_exprs.clone(), repartition); + let physical_plan = sort_preserving_merge_exec(sort_exprs, sort); // Expected inputs unbounded and bounded let expected_input_unbounded = [ @@ -261,27 +258,22 @@ async fn test_with_inter_children_change_only( #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr_default("a", &schema)]; + let ordering: LexOrdering = [sort_expr_default("a", &schema)].into(); let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) + stream_exec_ordered_with_projection(&schema, ordering.clone()) } else { - memory_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, ordering.clone()) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); - let sort = sort_exec( - vec![sort_expr_default("a", &coalesce_partitions.schema())], - coalesce_partitions, - false, - ); + let sort = sort_exec(ordering.clone(), coalesce_partitions); let repartition_rr2 = repartition_exec_round_robin(sort); let repartition_hash2 = repartition_exec_hash(repartition_rr2); let filter = filter_exec(repartition_hash2); - let sort2 = sort_exec(vec![sort_expr_default("a", &filter.schema())], filter, true); + let sort2 = sort_exec_with_preserve_partitioning(ordering.clone(), filter); - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("a", &sort2.schema())], sort2); + let physical_plan = sort_preserving_merge_exec(ordering, sort2); // Expected inputs unbounded and bounded let expected_input_unbounded = [ @@ -364,18 +356,17 @@ async fn test_replace_multiple_input_repartition_2( #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; + let ordering: LexOrdering = [sort_expr("a", &schema)].into(); let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) + stream_exec_ordered_with_projection(&schema, ordering.clone()) } else { - memory_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, ordering.clone()) }; let repartition_rr = repartition_exec_round_robin(source); let filter = filter_exec(repartition_rr); let repartition_hash = repartition_exec_hash(filter); - let sort = sort_exec(vec![sort_expr("a", &schema)], repartition_hash, true); - - let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); + let sort = sort_exec_with_preserve_partitioning(ordering.clone(), repartition_hash); + let physical_plan = sort_preserving_merge_exec(ordering, sort); // Expected inputs unbounded and bounded let expected_input_unbounded = [ @@ -440,19 +431,19 @@ async fn 
test_replace_multiple_input_repartition_with_extra_steps( #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; + let ordering: LexOrdering = [sort_expr("a", &schema)].into(); let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) + stream_exec_ordered_with_projection(&schema, ordering.clone()) } else { - memory_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, ordering.clone()) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); - let coalesce_batches_exec: Arc = coalesce_batches_exec(filter); - let sort = sort_exec(vec![sort_expr("a", &schema)], coalesce_batches_exec, true); - - let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); + let coalesce_batches_exec = coalesce_batches_exec(filter, 8192); + let sort = + sort_exec_with_preserve_partitioning(ordering.clone(), coalesce_batches_exec); + let physical_plan = sort_preserving_merge_exec(ordering, sort); // Expected inputs unbounded and bounded let expected_input_unbounded = [ @@ -522,20 +513,20 @@ async fn test_replace_multiple_input_repartition_with_extra_steps_2( #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; + let ordering: LexOrdering = [sort_expr("a", &schema)].into(); let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) + stream_exec_ordered_with_projection(&schema, ordering.clone()) } else { - memory_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, ordering.clone()) }; let repartition_rr = repartition_exec_round_robin(source); - let coalesce_batches_exec_1 = coalesce_batches_exec(repartition_rr); + let coalesce_batches_exec_1 = coalesce_batches_exec(repartition_rr, 8192); let repartition_hash = repartition_exec_hash(coalesce_batches_exec_1); let filter = filter_exec(repartition_hash); - let coalesce_batches_exec_2 = coalesce_batches_exec(filter); - let sort = sort_exec(vec![sort_expr("a", &schema)], coalesce_batches_exec_2, true); - - let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); + let coalesce_batches_exec_2 = coalesce_batches_exec(filter, 8192); + let sort = + sort_exec_with_preserve_partitioning(ordering.clone(), coalesce_batches_exec_2); + let physical_plan = sort_preserving_merge_exec(ordering, sort); // Expected inputs unbounded and bounded let expected_input_unbounded = [ @@ -610,19 +601,17 @@ async fn test_not_replacing_when_no_need_to_preserve_sorting( #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; + let ordering: LexOrdering = [sort_expr("a", &schema)].into(); let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) + stream_exec_ordered_with_projection(&schema, ordering) } else { - memory_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, ordering) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); - let coalesce_batches_exec: Arc = coalesce_batches_exec(filter); - - let physical_plan: Arc = - coalesce_partitions_exec(coalesce_batches_exec); + let 
coalesce_batches_exec = coalesce_batches_exec(filter, 8192); + let physical_plan = coalesce_partitions_exec(coalesce_batches_exec); // Expected inputs unbounded and bounded let expected_input_unbounded = [ @@ -683,20 +672,19 @@ async fn test_with_multiple_replacable_repartitions( #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; + let ordering: LexOrdering = [sort_expr("a", &schema)].into(); let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) + stream_exec_ordered_with_projection(&schema, ordering.clone()) } else { - memory_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, ordering.clone()) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); - let coalesce_batches = coalesce_batches_exec(filter); + let coalesce_batches = coalesce_batches_exec(filter, 8192); let repartition_hash_2 = repartition_exec_hash(coalesce_batches); - let sort = sort_exec(vec![sort_expr("a", &schema)], repartition_hash_2, true); - - let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); + let sort = sort_exec_with_preserve_partitioning(ordering.clone(), repartition_hash_2); + let physical_plan = sort_preserving_merge_exec(ordering, sort); // Expected inputs unbounded and bounded let expected_input_unbounded = [ @@ -770,23 +758,21 @@ async fn test_not_replace_with_different_orderings( #[values(false, true)] source_unbounded: bool, #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { + use datafusion_physical_expr::LexOrdering; + let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; + let ordering_a = [sort_expr("a", &schema)].into(); let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) + stream_exec_ordered_with_projection(&schema, ordering_a) } else { - memory_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, ordering_a) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); - let sort = sort_exec( - vec![sort_expr_default("c", &repartition_hash.schema())], - repartition_hash, - true, - ); - - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("c", &sort.schema())], sort); + let ordering_c: LexOrdering = + [sort_expr_default("c", &repartition_hash.schema())].into(); + let sort = sort_exec_with_preserve_partitioning(ordering_c.clone(), repartition_hash); + let physical_plan = sort_preserving_merge_exec(ordering_c, sort); // Expected inputs unbounded and bounded let expected_input_unbounded = [ @@ -843,17 +829,16 @@ async fn test_with_lost_ordering( #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; + let ordering: LexOrdering = [sort_expr("a", &schema)].into(); let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) + stream_exec_ordered_with_projection(&schema, ordering.clone()) } else { - memory_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, ordering.clone()) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); - let physical_plan = - 
sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions, false); + let physical_plan = sort_exec(ordering, coalesce_partitions); // Expected inputs unbounded and bounded let expected_input_unbounded = [ @@ -912,28 +897,26 @@ async fn test_with_lost_and_kept_ordering( #[values(false, true)] source_unbounded: bool, #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { + use datafusion_physical_expr::LexOrdering; + let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; + let ordering_a = [sort_expr("a", &schema)].into(); let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) + stream_exec_ordered_with_projection(&schema, ordering_a) } else { - memory_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, ordering_a) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); - let sort = sort_exec( - vec![sort_expr_default("c", &coalesce_partitions.schema())], - coalesce_partitions, - false, - ); + let ordering_c: LexOrdering = + [sort_expr_default("c", &coalesce_partitions.schema())].into(); + let sort = sort_exec(ordering_c.clone(), coalesce_partitions); let repartition_rr2 = repartition_exec_round_robin(sort); let repartition_hash2 = repartition_exec_hash(repartition_rr2); let filter = filter_exec(repartition_hash2); - let sort2 = sort_exec(vec![sort_expr_default("c", &filter.schema())], filter, true); - - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("c", &sort2.schema())], sort2); + let sort2 = sort_exec_with_preserve_partitioning(ordering_c.clone(), filter); + let physical_plan = sort_preserving_merge_exec(ordering_c, sort2); // Expected inputs unbounded and bounded let expected_input_unbounded = [ @@ -1019,22 +1002,22 @@ async fn test_with_multiple_child_trees( ) -> Result<()> { let schema = create_test_schema()?; - let left_sort_exprs = vec![sort_expr("a", &schema)]; + let left_ordering = [sort_expr("a", &schema)].into(); let left_source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, left_sort_exprs) + stream_exec_ordered_with_projection(&schema, left_ordering) } else { - memory_exec_sorted(&schema, left_sort_exprs) + memory_exec_sorted(&schema, left_ordering) }; let left_repartition_rr = repartition_exec_round_robin(left_source); let left_repartition_hash = repartition_exec_hash(left_repartition_rr); let left_coalesce_partitions = Arc::new(CoalesceBatchesExec::new(left_repartition_hash, 4096)); - let right_sort_exprs = vec![sort_expr("a", &schema)]; + let right_ordering = [sort_expr("a", &schema)].into(); let right_source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, right_sort_exprs) + stream_exec_ordered_with_projection(&schema, right_ordering) } else { - memory_exec_sorted(&schema, right_sort_exprs) + memory_exec_sorted(&schema, right_ordering) }; let right_repartition_rr = repartition_exec_round_robin(right_source); let right_repartition_hash = repartition_exec_hash(right_repartition_rr); @@ -1043,14 +1026,9 @@ async fn test_with_multiple_child_trees( let hash_join_exec = hash_join_exec(left_coalesce_partitions, right_coalesce_partitions); - let sort = sort_exec( - vec![sort_expr_default("a", &hash_join_exec.schema())], - hash_join_exec, - true, - ); - - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("a", &sort.schema())], sort); + let ordering: 
LexOrdering = [sort_expr_default("a", &hash_join_exec.schema())].into(); + let sort = sort_exec_with_preserve_partitioning(ordering.clone(), hash_join_exec); + let physical_plan = sort_preserving_merge_exec(ordering, sort); // Expected inputs unbounded and bounded let expected_input_unbounded = [ @@ -1149,18 +1127,6 @@ fn sort_expr_options( } } -fn sort_exec( - sort_exprs: impl IntoIterator, - input: Arc, - preserve_partitioning: bool, -) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new( - SortExec::new(sort_exprs, input) - .with_preserve_partitioning(preserve_partitioning), - ) -} - fn repartition_exec_round_robin(input: Arc) -> Arc { Arc::new(RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(8)).unwrap()) } @@ -1188,14 +1154,6 @@ fn filter_exec(input: Arc) -> Arc { Arc::new(FilterExec::try_new(predicate, input).unwrap()) } -fn coalesce_batches_exec(input: Arc) -> Arc { - Arc::new(CoalesceBatchesExec::new(input, 8192)) -} - -fn coalesce_partitions_exec(input: Arc) -> Arc { - Arc::new(CoalescePartitionsExec::new(input)) -} - fn hash_join_exec( left: Arc, right: Arc, @@ -1233,7 +1191,7 @@ fn create_test_schema() -> Result { // projection parameter is given static due to testing needs fn memory_exec_sorted( schema: &SchemaRef, - sort_exprs: impl IntoIterator, + ordering: LexOrdering, ) -> Arc { pub fn make_partition(schema: &SchemaRef, sz: i32) -> RecordBatch { let values = (0..sz).collect::>(); @@ -1249,7 +1207,6 @@ fn memory_exec_sorted( let rows = 5; let partitions = 1; - let sort_exprs = sort_exprs.into_iter().collect(); Arc::new({ let data: Vec> = (0..partitions) .map(|_| vec![make_partition(schema, rows)]) @@ -1258,7 +1215,7 @@ fn memory_exec_sorted( DataSourceExec::new(Arc::new( MemorySourceConfig::try_new(&data, schema.clone(), Some(projection)) .unwrap() - .try_with_sort_information(vec![sort_exprs]) + .try_with_sort_information(vec![ordering]) .unwrap(), )) }) @@ -1268,12 +1225,11 @@ fn memory_exec_sorted( fn test_plan_with_order_preserving_variants_preserves_fetch() -> Result<()> { // Create a schema let schema = create_test_schema3()?; - let parquet_sort_exprs = vec![crate::physical_optimizer::test_utils::sort_expr( - "a", &schema, - )]; - let parquet_exec = parquet_exec_sorted(&schema, parquet_sort_exprs); - let coalesced = - Arc::new(CoalescePartitionsExec::new(parquet_exec.clone()).with_fetch(Some(10))); + let parquet_sort_exprs = vec![[sort_expr("a", &schema)].into()]; + let parquet_exec = parquet_exec_with_sort(schema, parquet_sort_exprs); + let coalesced = coalesce_partitions_exec(parquet_exec.clone()) + .with_fetch(Some(10)) + .unwrap(); // Test sort's fetch is greater than coalesce fetch, return error because it's not reasonable let requirements = OrderPreservationContext::new( @@ -1315,17 +1271,15 @@ fn test_plan_with_order_preserving_variants_preserves_fetch() -> Result<()> { #[test] fn test_plan_with_order_breaking_variants_preserves_fetch() -> Result<()> { let schema = create_test_schema3()?; - let parquet_sort_exprs = vec![crate::physical_optimizer::test_utils::sort_expr( - "a", &schema, - )]; - let parquet_exec = parquet_exec_sorted(&schema, parquet_sort_exprs.clone()); - let spm = SortPreservingMergeExec::new( - LexOrdering::new(parquet_sort_exprs), + let parquet_sort_exprs: LexOrdering = [sort_expr("a", &schema)].into(); + let parquet_exec = parquet_exec_with_sort(schema, vec![parquet_sort_exprs.clone()]); + let spm = sort_preserving_merge_exec_with_fetch( + parquet_sort_exprs, parquet_exec.clone(), - ) - .with_fetch(Some(10)); 
+ 10, + ); let requirements = OrderPreservationContext::new( - Arc::new(spm), + spm, true, vec![OrderPreservationContext::new( parquet_exec.clone(), diff --git a/datafusion/core/tests/physical_optimizer/sanity_checker.rs b/datafusion/core/tests/physical_optimizer/sanity_checker.rs index a73d084a081f..bf9de28d2905 100644 --- a/datafusion/core/tests/physical_optimizer/sanity_checker.rs +++ b/datafusion/core/tests/physical_optimizer/sanity_checker.rs @@ -30,6 +30,7 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::{JoinType, Result}; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::Partitioning; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_optimizer::sanity_checker::SanityCheckPlan; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::repartition::RepartitionExec; @@ -410,16 +411,17 @@ fn assert_plan(plan: &dyn ExecutionPlan, expected_lines: Vec<&str>) { async fn test_bounded_window_agg_sort_requirement() -> Result<()> { let schema = create_test_schema(); let source = memory_exec(&schema); - let sort_exprs = vec![sort_expr_options( + let ordering: LexOrdering = [sort_expr_options( "c9", &source.schema(), SortOptions { descending: false, nulls_first: false, }, - )]; - let sort = sort_exec(sort_exprs.clone(), source); - let bw = bounded_window_exec("c9", sort_exprs, sort); + )] + .into(); + let sort = sort_exec(ordering.clone(), source); + let bw = bounded_window_exec("c9", ordering, sort); assert_plan(bw.as_ref(), vec![ "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " SortExec: expr=[c9@0 ASC NULLS LAST], preserve_partitioning=[false]", @@ -458,7 +460,7 @@ async fn test_bounded_window_agg_no_sort_requirement() -> Result<()> { async fn test_global_limit_single_partition() -> Result<()> { let schema = create_test_schema(); let source = memory_exec(&schema); - let limit = global_limit_exec(source); + let limit = global_limit_exec(source, 0, Some(100)); assert_plan( limit.as_ref(), @@ -477,7 +479,7 @@ async fn test_global_limit_single_partition() -> Result<()> { async fn test_global_limit_multi_partition() -> Result<()> { let schema = create_test_schema(); let source = memory_exec(&schema); - let limit = global_limit_exec(repartition_exec(source)); + let limit = global_limit_exec(repartition_exec(source), 0, Some(100)); assert_plan( limit.as_ref(), @@ -497,7 +499,7 @@ async fn test_global_limit_multi_partition() -> Result<()> { async fn test_local_limit() -> Result<()> { let schema = create_test_schema(); let source = memory_exec(&schema); - let limit = local_limit_exec(source); + let limit = local_limit_exec(source, 100); assert_plan( limit.as_ref(), @@ -518,12 +520,12 @@ async fn test_sort_merge_join_satisfied() -> Result<()> { let source1 = memory_exec(&schema1); let source2 = memory_exec(&schema2); let sort_opts = SortOptions::default(); - let sort_exprs1 = vec![sort_expr_options("c9", &source1.schema(), sort_opts)]; - let sort_exprs2 = vec![sort_expr_options("a", &source2.schema(), sort_opts)]; - let left = sort_exec(sort_exprs1, source1); - let right = sort_exec(sort_exprs2, source2); - let left_jcol = col("c9", &left.schema()).unwrap(); - let right_jcol = col("a", &right.schema()).unwrap(); + let ordering1 = [sort_expr_options("c9", 
&source1.schema(), sort_opts)].into(); + let ordering2 = [sort_expr_options("a", &source2.schema(), sort_opts)].into(); + let left = sort_exec(ordering1, source1); + let right = sort_exec(ordering2, source2); + let left_jcol = col("c9", &left.schema())?; + let right_jcol = col("a", &right.schema())?; let left = Arc::new(RepartitionExec::try_new( left, Partitioning::Hash(vec![left_jcol.clone()], 10), @@ -562,15 +564,16 @@ async fn test_sort_merge_join_order_missing() -> Result<()> { let schema2 = create_test_schema2(); let source1 = memory_exec(&schema1); let right = memory_exec(&schema2); - let sort_exprs1 = vec![sort_expr_options( + let ordering1 = [sort_expr_options( "c9", &source1.schema(), SortOptions::default(), - )]; - let left = sort_exec(sort_exprs1, source1); + )] + .into(); + let left = sort_exec(ordering1, source1); // Missing sort of the right child here.. - let left_jcol = col("c9", &left.schema()).unwrap(); - let right_jcol = col("a", &right.schema()).unwrap(); + let left_jcol = col("c9", &left.schema())?; + let right_jcol = col("a", &right.schema())?; let left = Arc::new(RepartitionExec::try_new( left, Partitioning::Hash(vec![left_jcol.clone()], 10), @@ -610,16 +613,16 @@ async fn test_sort_merge_join_dist_missing() -> Result<()> { let source1 = memory_exec(&schema1); let source2 = memory_exec(&schema2); let sort_opts = SortOptions::default(); - let sort_exprs1 = vec![sort_expr_options("c9", &source1.schema(), sort_opts)]; - let sort_exprs2 = vec![sort_expr_options("a", &source2.schema(), sort_opts)]; - let left = sort_exec(sort_exprs1, source1); - let right = sort_exec(sort_exprs2, source2); + let ordering1 = [sort_expr_options("c9", &source1.schema(), sort_opts)].into(); + let ordering2 = [sort_expr_options("a", &source2.schema(), sort_opts)].into(); + let left = sort_exec(ordering1, source1); + let right = sort_exec(ordering2, source2); let right = Arc::new(RepartitionExec::try_new( right, Partitioning::RoundRobinBatch(10), )?); - let left_jcol = col("c9", &left.schema()).unwrap(); - let right_jcol = col("a", &right.schema()).unwrap(); + let left_jcol = col("c9", &left.schema())?; + let right_jcol = col("a", &right.schema())?; let left = Arc::new(RepartitionExec::try_new( left, Partitioning::Hash(vec![left_jcol.clone()], 10), diff --git a/datafusion/core/tests/physical_optimizer/test_utils.rs b/datafusion/core/tests/physical_optimizer/test_utils.rs index 955486a31030..ebb623dc6110 100644 --- a/datafusion/core/tests/physical_optimizer/test_utils.rs +++ b/datafusion/core/tests/physical_optimizer/test_utils.rs @@ -40,10 +40,10 @@ use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::{WindowFrame, WindowFunctionDefinition}; use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; -use datafusion_physical_expr::expressions::col; -use datafusion_physical_expr::{expressions, PhysicalExpr}; +use datafusion_physical_expr::expressions::{self, col}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{ - LexOrdering, LexRequirement, PhysicalSortExpr, + LexOrdering, OrderingRequirements, PhysicalSortExpr, }; use datafusion_physical_optimizer::limited_distinct_aggregation::LimitedDistinctAggregation; use datafusion_physical_optimizer::PhysicalOptimizerRule; @@ -56,6 +56,7 @@ use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::joins::utils::{JoinFilter, JoinOn}; use 
datafusion_physical_plan::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec}; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; @@ -69,10 +70,10 @@ use datafusion_physical_plan::{ }; /// Create a non sorted parquet exec -pub fn parquet_exec(schema: &SchemaRef) -> Arc { +pub fn parquet_exec(schema: SchemaRef) -> Arc { let config = FileScanConfigBuilder::new( ObjectStoreUrl::parse("test:///").unwrap(), - schema.clone(), + schema, Arc::new(ParquetSource::default()), ) .with_file(PartitionedFile::new("x".to_string(), 100)) @@ -83,11 +84,12 @@ pub fn parquet_exec(schema: &SchemaRef) -> Arc { /// Create a single parquet file that is sorted pub(crate) fn parquet_exec_with_sort( + schema: SchemaRef, output_ordering: Vec, ) -> Arc { let config = FileScanConfigBuilder::new( ObjectStoreUrl::parse("test:///").unwrap(), - schema(), + schema, Arc::new(ParquetSource::default()), ) .with_file(PartitionedFile::new("x".to_string(), 100)) @@ -243,14 +245,23 @@ pub fn bounded_window_exec( sort_exprs: impl IntoIterator, input: Arc, ) -> Arc { - let sort_exprs: LexOrdering = sort_exprs.into_iter().collect(); + bounded_window_exec_with_partition(col_name, sort_exprs, &[], input) +} + +pub fn bounded_window_exec_with_partition( + col_name: &str, + sort_exprs: impl IntoIterator, + partition_by: &[Arc], + input: Arc, +) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect::>(); let schema = input.schema(); let window_expr = create_window_expr( &WindowFunctionDefinition::AggregateUDF(count_udaf()), "count".to_owned(), &[col(col_name, &schema).unwrap()], - &[], - sort_exprs.as_ref(), + partition_by, + &sort_exprs, Arc::new(WindowFrame::new(Some(false))), schema.as_ref(), false, @@ -276,36 +287,37 @@ pub fn filter_exec( } pub fn sort_preserving_merge_exec( - sort_exprs: impl IntoIterator, + ordering: LexOrdering, input: Arc, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(SortPreservingMergeExec::new(sort_exprs, input)) + Arc::new(SortPreservingMergeExec::new(ordering, input)) } pub fn sort_preserving_merge_exec_with_fetch( - sort_exprs: impl IntoIterator, + ordering: LexOrdering, input: Arc, fetch: usize, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(SortPreservingMergeExec::new(sort_exprs, input).with_fetch(Some(fetch))) + Arc::new(SortPreservingMergeExec::new(ordering, input).with_fetch(Some(fetch))) } pub fn union_exec(input: Vec>) -> Arc { Arc::new(UnionExec::new(input)) } -pub fn limit_exec(input: Arc) -> Arc { - global_limit_exec(local_limit_exec(input)) -} - -pub fn local_limit_exec(input: Arc) -> Arc { - Arc::new(LocalLimitExec::new(input, 100)) +pub fn local_limit_exec( + input: Arc, + fetch: usize, +) -> Arc { + Arc::new(LocalLimitExec::new(input, fetch)) } -pub fn global_limit_exec(input: Arc) -> Arc { - Arc::new(GlobalLimitExec::new(input, 0, Some(100))) +pub fn global_limit_exec( + input: Arc, + skip: usize, + fetch: Option, +) -> Arc { + Arc::new(GlobalLimitExec::new(input, skip, fetch)) } pub fn repartition_exec(input: Arc) -> Arc { @@ -335,30 +347,46 @@ pub fn aggregate_exec(input: Arc) -> Arc { ) } -pub fn coalesce_batches_exec(input: Arc) -> Arc { - Arc::new(CoalesceBatchesExec::new(input, 128)) +pub fn coalesce_batches_exec( + input: Arc, + batch_size: usize, 
+) -> Arc { + Arc::new(CoalesceBatchesExec::new(input, batch_size)) } pub fn sort_exec( - sort_exprs: impl IntoIterator, + ordering: LexOrdering, + input: Arc, +) -> Arc { + sort_exec_with_fetch(ordering, None, input) +} + +pub fn sort_exec_with_preserve_partitioning( + ordering: LexOrdering, input: Arc, ) -> Arc { - sort_exec_with_fetch(sort_exprs, None, input) + Arc::new(SortExec::new(ordering, input).with_preserve_partitioning(true)) } pub fn sort_exec_with_fetch( - sort_exprs: impl IntoIterator, + ordering: LexOrdering, fetch: Option, input: Arc, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(SortExec::new(sort_exprs, input).with_fetch(fetch)) + Arc::new(SortExec::new(ordering, input).with_fetch(fetch)) +} + +pub fn projection_exec( + expr: Vec<(Arc, String)>, + input: Arc, +) -> Result> { + Ok(Arc::new(ProjectionExec::try_new(expr, input)?)) } /// A test [`ExecutionPlan`] whose requirements can be configured. #[derive(Debug)] pub struct RequirementsTestExec { - required_input_ordering: LexOrdering, + required_input_ordering: Option, maintains_input_order: bool, input: Arc, } @@ -366,7 +394,7 @@ pub struct RequirementsTestExec { impl RequirementsTestExec { pub fn new(input: Arc) -> Self { Self { - required_input_ordering: LexOrdering::default(), + required_input_ordering: None, maintains_input_order: true, input, } @@ -375,7 +403,7 @@ impl RequirementsTestExec { /// sets the required input ordering pub fn with_required_input_ordering( mut self, - required_input_ordering: LexOrdering, + required_input_ordering: Option, ) -> Self { self.required_input_ordering = required_input_ordering; self @@ -420,9 +448,11 @@ impl ExecutionPlan for RequirementsTestExec { self.input.properties() } - fn required_input_ordering(&self) -> Vec> { - let requirement = LexRequirement::from(self.required_input_ordering.clone()); - vec![Some(requirement)] + fn required_input_ordering(&self) -> Vec> { + vec![self + .required_input_ordering + .as_ref() + .map(|ordering| OrderingRequirements::from(ordering.clone()))] } fn maintains_input_order(&self) -> Vec { @@ -501,13 +531,28 @@ impl PartitionStream for TestStreamPartition { } } -/// Create an unbounded stream exec +/// Create an unbounded stream table without data ordering. +pub fn stream_exec(schema: &SchemaRef) -> Arc { + Arc::new( + StreamingTableExec::try_new( + Arc::clone(schema), + vec![Arc::new(TestStreamPartition { + schema: Arc::clone(schema), + }) as _], + None, + vec![], + true, + None, + ) + .unwrap(), + ) +} + +/// Create an unbounded stream table with data ordering. pub fn stream_exec_ordered( schema: &SchemaRef, - sort_exprs: impl IntoIterator, + ordering: LexOrdering, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new( StreamingTableExec::try_new( Arc::clone(schema), @@ -515,7 +560,7 @@ pub fn stream_exec_ordered( schema: Arc::clone(schema), }) as _], None, - vec![sort_exprs], + vec![ordering], true, None, ) @@ -523,12 +568,11 @@ pub fn stream_exec_ordered( ) } -// Creates a stream exec source for the test purposes +/// Create an unbounded stream table with data ordering and built-in projection. 
pub fn stream_exec_ordered_with_projection( schema: &SchemaRef, - sort_exprs: impl IntoIterator, + ordering: LexOrdering, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); let projection: Vec = vec![0, 2, 3]; Arc::new( @@ -538,7 +582,7 @@ pub fn stream_exec_ordered_with_projection( schema: Arc::clone(schema), }) as _], Some(&projection), - vec![sort_exprs], + vec![ordering], true, None, ) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 25458efa4fa5..a2437d2da814 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -973,10 +973,6 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { Ok(ExprSimplifyResult::Simplified(replacement)) } - - fn aliases(&self) -> &[String] { - &[] - } } impl ScalarFunctionWrapper { diff --git a/datafusion/datasource/benches/split_groups_by_statistics.rs b/datafusion/datasource/benches/split_groups_by_statistics.rs index 3876b0b1217b..d51fdfc0a6e9 100644 --- a/datafusion/datasource/benches/split_groups_by_statistics.rs +++ b/datafusion/datasource/benches/split_groups_by_statistics.rs @@ -15,14 +15,16 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; +use std::time::Duration; + use arrow::datatypes::{DataType, Field, Schema}; -use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use datafusion_datasource::file_scan_config::FileScanConfig; use datafusion_datasource::{generate_test_files, verify_sort_integrity}; -use datafusion_physical_expr::PhysicalSortExpr; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use std::sync::Arc; -use std::time::Duration; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; + +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; pub fn compare_split_groups_by_statistics_algorithms(c: &mut Criterion) { let file_schema = Arc::new(Schema::new(vec![Field::new( @@ -31,13 +33,8 @@ pub fn compare_split_groups_by_statistics_algorithms(c: &mut Criterion) { false, )])); - let sort_expr = PhysicalSortExpr { - expr: Arc::new(datafusion_physical_expr::expressions::Column::new( - "value", 0, - )), - options: arrow::compute::SortOptions::default(), - }; - let sort_ordering = LexOrdering::from(vec![sort_expr]); + let sort_expr = PhysicalSortExpr::new_default(Arc::new(Column::new("value", 0))); + let sort_ordering = LexOrdering::from([sort_expr]); // Small, medium, large number of files let file_counts = [10, 100, 1000]; diff --git a/datafusion/datasource/src/file_scan_config.rs b/datafusion/datasource/src/file_scan_config.rs index b6cdd05e5230..174ff06ebcc9 100644 --- a/datafusion/datasource/src/file_scan_config.rs +++ b/datafusion/datasource/src/file_scan_config.rs @@ -43,18 +43,18 @@ use arrow::{ buffer::Buffer, datatypes::{ArrowNativeType, DataType, Field, Schema, SchemaRef, UInt16Type}, }; +use datafusion_common::config::ConfigOptions; use datafusion_common::{ - config::ConfigOptions, exec_err, ColumnStatistics, Constraints, Result, Statistics, + exec_err, ColumnStatistics, Constraints, DataFusionError, Result, ScalarValue, + Statistics, }; -use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_execution::{ object_store::ObjectStoreUrl, SendableRecordBatchStream, TaskContext, }; -use datafusion_physical_expr::PhysicalExpr; 
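// Editor's note (not part of the patch): the import rewrites below follow the
// PR-wide pattern of replacing `datafusion_physical_expr` re-exports with the
// underlying `datafusion_physical_expr_common` paths for `PhysicalExpr`,
// `PhysicalSortExpr`, and `LexOrdering`.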
-use datafusion_physical_expr::{
-    expressions::Column, EquivalenceProperties, LexOrdering, Partitioning,
-    PhysicalSortExpr,
-};
+use datafusion_physical_expr::expressions::Column;
+use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
+use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
+use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
 use datafusion_physical_plan::filter_pushdown::FilterPushdownPropagation;
 use datafusion_physical_plan::{
     display::{display_orderings, ProjectSchemaDisplay},
@@ -62,6 +62,7 @@ use datafusion_physical_plan::{
     projection::{all_alias_free_columns, new_projections_for_columns, ProjectionExec},
     DisplayAs, DisplayFormatType, ExecutionPlan,
 };
+
 use log::{debug, warn};

 /// The base configurations for a [`DataSourceExec`], the physical plan for
@@ -547,7 +548,7 @@ impl DataSource for FileScanConfig {
     fn eq_properties(&self) -> EquivalenceProperties {
         let (schema, constraints, _, orderings) = self.project();
-        EquivalenceProperties::new_with_orderings(schema, orderings.as_slice())
+        EquivalenceProperties::new_with_orderings(schema, orderings)
             .with_constraints(constraints)
     }
@@ -659,7 +660,7 @@ impl FileScanConfig {
             object_store_url,
             file_schema,
             file_groups: vec![],
-            constraints: Constraints::empty(),
+            constraints: Constraints::default(),
             projection: None,
             limit: None,
             table_partition_cols: vec![],
@@ -748,10 +749,7 @@ impl FileScanConfig {
     pub fn projected_constraints(&self) -> Constraints {
         let indexes = self.projection_indices();
-
-        self.constraints
-            .project(&indexes)
-            .unwrap_or_else(Constraints::empty)
+        self.constraints.project(&indexes).unwrap_or_default()
     }

     /// Set the projection of the files
@@ -1435,16 +1433,16 @@ fn get_projected_output_ordering(
 ) -> Vec<LexOrdering> {
     let mut all_orderings = vec![];
     for output_ordering in &base_config.output_ordering {
-        let mut new_ordering = LexOrdering::default();
+        let mut new_ordering = vec![];
         for PhysicalSortExpr { expr, options } in output_ordering.iter() {
             if let Some(col) = expr.as_any().downcast_ref::<Column>() {
                 let name = col.name();
                 if let Some((idx, _)) = projected_schema.column_with_name(name) {
                     // Compute the new sort expression (with correct index) after projection:
-                    new_ordering.push(PhysicalSortExpr {
-                        expr: Arc::new(Column::new(name, idx)),
-                        options: *options,
-                    });
+                    new_ordering.push(PhysicalSortExpr::new(
+                        Arc::new(Column::new(name, idx)),
+                        *options,
+                    ));
                     continue;
                 }
             }
@@ -1453,11 +1451,9 @@ fn get_projected_output_ordering(
             break;
         }
-        // do not push empty entries
-        // otherwise we may have `Some(vec![])` at the output ordering.
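// Editor's note (illustrative, not part of the patch): the replacement hunk
// below relies on the reworked constructor, where `LexOrdering::new` takes the
// collected sort expressions and returns `Option<LexOrdering>` (`None` for an
// empty input), so explicit emptiness checks collapse into a `let`-`else`.
// A minimal sketch of the pattern, assuming `exprs: Vec<PhysicalSortExpr>`:
//
//     let Some(ordering) = LexOrdering::new(exprs) else {
//         continue; // nothing to record for this output ordering
//     };
//     all_orderings.push(ordering);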
- if new_ordering.is_empty() { + let Some(new_ordering) = LexOrdering::new(new_ordering) else { continue; - } + }; // Check if any file groups are not sorted if base_config.file_groups.iter().any(|group| { @@ -1518,41 +1514,17 @@ pub fn wrap_partition_value_in_dict(val: ScalarValue) -> ScalarValue { #[cfg(test)] mod tests { + use super::*; use crate::{ generate_test_files, test_util::MockSource, tests::aggr_test_schema, verify_sort_integrity, }; - use super::*; - use arrow::{ - array::{Int32Array, RecordBatch}, - compute::SortOptions, - }; - + use arrow::array::{Int32Array, RecordBatch}; use datafusion_common::stats::Precision; - use datafusion_common::{assert_batches_eq, DFSchema}; - use datafusion_expr::{execution_props::ExecutionProps, SortExpr}; - use datafusion_physical_expr::create_physical_expr; - use std::collections::HashMap; - - fn create_physical_sort_expr( - e: &SortExpr, - input_dfschema: &DFSchema, - execution_props: &ExecutionProps, - ) -> Result { - let SortExpr { - expr, - asc, - nulls_first, - } = e; - Ok(PhysicalSortExpr { - expr: create_physical_expr(expr, input_dfschema, execution_props)?, - options: SortOptions { - descending: !asc, - nulls_first: *nulls_first, - }, - }) - } + use datafusion_common::{assert_batches_eq, internal_err}; + use datafusion_expr::SortExpr; + use datafusion_physical_expr::create_physical_sort_expr; /// Returns the column names on the schema pub fn columns(schema: &Schema) -> Vec { @@ -2069,7 +2041,7 @@ mod tests { )))) .collect::>(), )); - let sort_order = LexOrdering::from( + let Some(sort_order) = LexOrdering::new( case.sort .into_iter() .map(|expr| { @@ -2080,7 +2052,9 @@ mod tests { ) }) .collect::>>()?, - ); + ) else { + return internal_err!("This test should always use an ordering"); + }; let partitioned_files = FileGroup::new( case.files.into_iter().map(From::from).collect::>(), @@ -2243,13 +2217,15 @@ mod tests { wrap_partition_type_in_dict(DataType::Utf8), false, )]) - .with_constraints(Constraints::empty()) .with_statistics(Statistics::new_unknown(&file_schema)) .with_file_groups(vec![FileGroup::new(vec![PartitionedFile::new( "test.parquet".to_string(), 1024, )])]) - .with_output_ordering(vec![LexOrdering::default()]) + .with_output_ordering(vec![[PhysicalSortExpr::new_default(Arc::new( + Column::new("date", 0), + ))] + .into()]) .with_file_compression_type(FileCompressionType::UNCOMPRESSED) .with_newlines_in_values(true) .build(); @@ -2393,13 +2369,11 @@ mod tests { let exec_props = ExecutionProps::new(); let df_schema = DFSchema::try_from_qualified_schema("test", schema.as_ref())?; let sort_expr = [col("value").sort(true, false)]; - - let physical_sort_exprs: Vec<_> = sort_expr - .iter() - .map(|expr| create_physical_sort_expr(expr, &df_schema, &exec_props).unwrap()) - .collect(); - - let sort_ordering = LexOrdering::from(physical_sort_exprs); + let sort_ordering = sort_expr + .map(|expr| { + create_physical_sort_expr(&expr, &df_schema, &exec_props).unwrap() + }) + .into(); // Test case parameters struct TestCase { diff --git a/datafusion/datasource/src/memory.rs b/datafusion/datasource/src/memory.rs index 54cea71843ee..98c7bd273fdf 100644 --- a/datafusion/datasource/src/memory.rs +++ b/datafusion/datasource/src/memory.rs @@ -24,7 +24,17 @@ use std::sync::Arc; use crate::sink::DataSink; use crate::source::{DataSource, DataSourceExec}; -use async_trait::async_trait; + +use arrow::array::{RecordBatch, RecordBatchOptions}; +use arrow::datatypes::{Schema, SchemaRef}; +use datafusion_common::{internal_err, plan_err, project_schema, 
Result, ScalarValue}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::equivalence::{ + OrderingEquivalenceClass, ProjectionMapping, +}; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::utils::collect_columns; +use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; use datafusion_physical_plan::memory::MemoryStream; use datafusion_physical_plan::projection::{ all_alias_free_columns, new_projections_for_columns, ProjectionExec, @@ -34,14 +44,7 @@ use datafusion_physical_plan::{ PhysicalExpr, SendableRecordBatchStream, Statistics, }; -use arrow::array::{RecordBatch, RecordBatchOptions}; -use arrow::datatypes::{Schema, SchemaRef}; -use datafusion_common::{internal_err, plan_err, project_schema, Result, ScalarValue}; -use datafusion_execution::TaskContext; -use datafusion_physical_expr::equivalence::ProjectionMapping; -use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::utils::collect_columns; -use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; +use async_trait::async_trait; use futures::StreamExt; use itertools::Itertools; use tokio::sync::RwLock; @@ -181,7 +184,7 @@ impl DataSource for MemorySourceConfig { fn eq_properties(&self) -> EquivalenceProperties { EquivalenceProperties::new_with_orderings( Arc::clone(&self.projected_schema), - self.sort_information.as_slice(), + self.sort_information.clone(), ) } @@ -424,24 +427,21 @@ impl MemorySourceConfig { // If there is a projection on the source, we also need to project orderings if let Some(projection) = &self.projection { + let base_schema = self.original_schema(); + let proj_exprs = projection.iter().map(|idx| { + let name = base_schema.field(*idx).name(); + (Arc::new(Column::new(name, *idx)) as _, name.to_string()) + }); + let projection_mapping = + ProjectionMapping::try_new(proj_exprs, &base_schema)?; let base_eqp = EquivalenceProperties::new_with_orderings( - self.original_schema(), - &sort_information, + Arc::clone(&base_schema), + sort_information, ); - let proj_exprs = projection - .iter() - .map(|idx| { - let base_schema = self.original_schema(); - let name = base_schema.field(*idx).name(); - (Arc::new(Column::new(name, *idx)) as _, name.to_string()) - }) - .collect::>(); - let projection_mapping = - ProjectionMapping::try_new(&proj_exprs, &self.original_schema())?; - sort_information = base_eqp - .project(&projection_mapping, Arc::clone(&self.projected_schema)) - .into_oeq_class() - .into_inner(); + let proj_eqp = + base_eqp.project(&projection_mapping, Arc::clone(&self.projected_schema)); + let oeq_class: OrderingEquivalenceClass = proj_eqp.into(); + sort_information = oeq_class.into(); } self.sort_information = sort_information; @@ -463,7 +463,7 @@ impl MemorySourceConfig { target_partitions: usize, output_ordering: LexOrdering, ) -> Result>>> { - if !self.eq_properties().ordering_satisfy(&output_ordering) { + if !self.eq_properties().ordering_satisfy(output_ordering)? 
{ Ok(None) } else { let total_num_batches = @@ -765,22 +765,22 @@ mod memory_source_tests { use crate::memory::MemorySourceConfig; use crate::source::DataSourceExec; - use datafusion_physical_plan::ExecutionPlan; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::Result; use datafusion_physical_expr::expressions::col; - use datafusion_physical_expr::PhysicalSortExpr; - use datafusion_physical_expr_common::sort_expr::LexOrdering; + use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; + use datafusion_physical_plan::ExecutionPlan; #[test] - fn test_memory_order_eq() -> datafusion_common::Result<()> { + fn test_memory_order_eq() -> Result<()> { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int64, false), Field::new("b", DataType::Int64, false), Field::new("c", DataType::Int64, false), ])); - let sort1 = LexOrdering::new(vec![ + let sort1: LexOrdering = [ PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions::default(), @@ -789,13 +789,14 @@ mod memory_source_tests { expr: col("b", &schema)?, options: SortOptions::default(), }, - ]); - let sort2 = LexOrdering::new(vec![PhysicalSortExpr { + ] + .into(); + let sort2: LexOrdering = [PhysicalSortExpr { expr: col("c", &schema)?, options: SortOptions::default(), - }]); - let mut expected_output_order = LexOrdering::default(); - expected_output_order.extend(sort1.clone()); + }] + .into(); + let mut expected_output_order = sort1.clone(); expected_output_order.extend(sort2.clone()); let sort_information = vec![sort1.clone(), sort2.clone()]; @@ -817,19 +818,17 @@ mod memory_source_tests { #[cfg(test)] mod tests { + use super::*; use crate::test_util::col; use crate::tests::{aggr_test_schema, make_partition}; - use super::*; - use arrow::array::{ArrayRef, Int32Array, Int64Array, StringArray}; - use arrow::compute::SortOptions; - use datafusion_physical_expr::PhysicalSortExpr; - use datafusion_physical_plan::expressions::lit; - use arrow::datatypes::{DataType, Field}; use datafusion_common::assert_batches_eq; use datafusion_common::stats::{ColumnStatistics, Precision}; + use datafusion_physical_expr::PhysicalSortExpr; + use datafusion_physical_plan::expressions::lit; + use futures::StreamExt; #[tokio::test] @@ -1310,10 +1309,8 @@ mod tests { #[test] fn test_repartition_with_sort_information() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), - options: SortOptions::default(), - }]); + let sort_key: LexOrdering = + [PhysicalSortExpr::new_default(col("c", &schema)?)].into(); let has_sort = vec![sort_key.clone()]; let output_ordering = Some(sort_key); @@ -1360,10 +1357,8 @@ mod tests { #[test] fn test_repartition_with_batch_ordering_not_matching_sizing() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), - options: SortOptions::default(), - }]); + let sort_key: LexOrdering = + [PhysicalSortExpr::new_default(col("c", &schema)?)].into(); let has_sort = vec![sort_key.clone()]; let output_ordering = Some(sort_key); diff --git a/datafusion/datasource/src/sink.rs b/datafusion/datasource/src/sink.rs index 0552370d8ed0..d92d168f8e5f 100644 --- a/datafusion/datasource/src/sink.rs +++ b/datafusion/datasource/src/sink.rs @@ -22,20 +22,18 @@ use std::fmt; use std::fmt::Debug; use std::sync::Arc; -use datafusion_physical_plan::metrics::MetricsSet; -use 
datafusion_physical_plan::stream::RecordBatchStreamAdapter; -use datafusion_physical_plan::ExecutionPlanProperties; -use datafusion_physical_plan::{ - execute_input_stream, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, - PlanProperties, SendableRecordBatchStream, -}; - use arrow::array::{ArrayRef, RecordBatch, UInt64Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::{internal_err, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{Distribution, EquivalenceProperties}; -use datafusion_physical_expr_common::sort_expr::LexRequirement; +use datafusion_physical_expr_common::sort_expr::{LexRequirement, OrderingRequirements}; +use datafusion_physical_plan::metrics::MetricsSet; +use datafusion_physical_plan::stream::RecordBatchStreamAdapter; +use datafusion_physical_plan::{ + execute_input_stream, DisplayAs, DisplayFormatType, ExecutionPlan, + ExecutionPlanProperties, Partitioning, PlanProperties, SendableRecordBatchStream, +}; use async_trait::async_trait; use futures::StreamExt; @@ -184,10 +182,10 @@ impl ExecutionPlan for DataSinkExec { vec![Distribution::SinglePartition; self.children().len()] } - fn required_input_ordering(&self) -> Vec> { + fn required_input_ordering(&self) -> Vec> { // The required input ordering is set externally (e.g. by a `ListingTable`). - // Otherwise, there is no specific requirement (i.e. `sort_expr` is `None`). - vec![self.sort_order.as_ref().cloned()] + // Otherwise, there is no specific requirement (i.e. `sort_order` is `None`). + vec![self.sort_order.as_ref().cloned().map(Into::into)] } fn maintains_input_order(&self) -> Vec { diff --git a/datafusion/datasource/src/statistics.rs b/datafusion/datasource/src/statistics.rs index b42d3bb361b7..db9af0ff7675 100644 --- a/datafusion/datasource/src/statistics.rs +++ b/datafusion/datasource/src/statistics.rs @@ -20,24 +20,25 @@ //! Currently, this module houses code to sort file groups if they are non-overlapping with //! respect to the required sort order. See [`MinMaxStatistics`] -use futures::{Stream, StreamExt}; use std::sync::Arc; use crate::file_groups::FileGroup; use crate::PartitionedFile; use arrow::array::RecordBatch; +use arrow::compute::SortColumn; use arrow::datatypes::SchemaRef; -use arrow::{ - compute::SortColumn, - row::{Row, Rows}, -}; +use arrow::row::{Row, Rows}; use datafusion_common::stats::Precision; -use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; -use datafusion_physical_expr::{expressions::Column, PhysicalSortExpr}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_common::{ + plan_datafusion_err, plan_err, DataFusionError, Result, ScalarValue, +}; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use datafusion_physical_plan::{ColumnStatistics, Statistics}; +use futures::{Stream, StreamExt}; + /// A normalized representation of file min/max statistics that allows for efficient sorting & comparison. /// The min/max values are ordered by [`Self::sort_order`]. /// Furthermore, any columns that are reversed in the sort order have their min/max values swapped. 
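A small worked example of that swap, with hypothetical numbers: a file spanning values 1 through 9 on a column sorted descending stores the pair reversed, so comparisons on the normalized minimums follow the sort order instead of the numeric order (`normalize` is illustrative only):

// Under DESC, the file's largest value is its first row in sort order.
fn normalize(file_min: i64, file_max: i64, descending: bool) -> (i64, i64) {
    if descending {
        (file_max, file_min)
    } else {
        (file_min, file_max)
    }
}
// normalize(1, 9, true) == (9, 1)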
@@ -71,9 +72,7 @@ impl MinMaxStatistics { projection: Option<&[usize]>, // Indices of projection in full table schema (None = all columns) files: impl IntoIterator, ) -> Result { - use datafusion_common::ScalarValue; - - let statistics_and_partition_values = files + let Some(statistics_and_partition_values) = files .into_iter() .map(|file| { file.statistics @@ -81,9 +80,9 @@ impl MinMaxStatistics { .zip(Some(file.partition_values.as_slice())) }) .collect::>>() - .ok_or_else(|| { - DataFusionError::Plan("Parquet file missing statistics".to_string()) - })?; + else { + return plan_err!("Parquet file missing statistics"); + }; // Helper function to get min/max statistics for a given column of projected_schema let get_min_max = |i: usize| -> Result<(Vec, Vec)> { @@ -96,9 +95,7 @@ impl MinMaxStatistics { .get_value() .cloned() .zip(s.column_statistics[i].max_value.get_value().cloned()) - .ok_or_else(|| { - DataFusionError::Plan("statistics not found".to_string()) - }) + .ok_or_else(|| plan_datafusion_err!("statistics not found")) } else { let partition_value = &pv[i - s.column_statistics.len()]; Ok((partition_value.clone(), partition_value.clone())) @@ -109,27 +106,28 @@ impl MinMaxStatistics { .unzip()) }; - let sort_columns = sort_columns_from_physical_sort_exprs(projected_sort_order) - .ok_or(DataFusionError::Plan( - "sort expression must be on column".to_string(), - ))?; + let Some(sort_columns) = + sort_columns_from_physical_sort_exprs(projected_sort_order) + else { + return plan_err!("sort expression must be on column"); + }; // Project the schema & sort order down to just the relevant columns let min_max_schema = Arc::new( projected_schema .project(&(sort_columns.iter().map(|c| c.index()).collect::>()))?, ); - let min_max_sort_order = LexOrdering::from( - sort_columns - .iter() - .zip(projected_sort_order.iter()) - .enumerate() - .map(|(i, (col, sort))| PhysicalSortExpr { - expr: Arc::new(Column::new(col.name(), i)), - options: sort.options, - }) - .collect::>(), - ); + + let min_max_sort_order = projected_sort_order + .iter() + .zip(sort_columns.iter()) + .enumerate() + .map(|(idx, (sort_expr, col))| { + let expr = Arc::new(Column::new(col.name(), idx)); + PhysicalSortExpr::new(expr, sort_expr.options) + }); + // Safe to `unwrap` as we know that sort columns are non-empty: + let min_max_sort_order = LexOrdering::new(min_max_sort_order).unwrap(); let (min_values, max_values): (Vec<_>, Vec<_>) = sort_columns .iter() @@ -189,25 +187,23 @@ impl MinMaxStatistics { .map_err(|e| e.context("create sort fields"))?; let converter = RowConverter::new(sort_fields)?; - let sort_columns = sort_columns_from_physical_sort_exprs(sort_order).ok_or( - DataFusionError::Plan("sort expression must be on column".to_string()), - )?; + let Some(sort_columns) = sort_columns_from_physical_sort_exprs(sort_order) else { + return plan_err!("sort expression must be on column"); + }; // swap min/max if they're reversed in the ordering let (new_min_cols, new_max_cols): (Vec<_>, Vec<_>) = sort_order .iter() .zip(sort_columns.iter().copied()) .map(|(sort_expr, column)| { - if sort_expr.options.descending { - max_values - .column_by_name(column.name()) - .zip(min_values.column_by_name(column.name())) + let maxes = max_values.column_by_name(column.name()); + let mins = min_values.column_by_name(column.name()); + let opt_value = if sort_expr.options.descending { + maxes.zip(mins) } else { - min_values - .column_by_name(column.name()) - .zip(max_values.column_by_name(column.name())) - } - .ok_or_else(|| { + mins.zip(maxes) + 
}; + opt_value.ok_or_else(|| { plan_datafusion_err!( "missing column in MinMaxStatistics::new: '{}'", column.name() @@ -285,7 +281,7 @@ fn sort_columns_from_physical_sort_exprs( sort_order .iter() .map(|expr| expr.expr.as_any().downcast_ref::()) - .collect::>>() + .collect() } /// Get all files as well as the file level summary statistics (no statistic for partition columns). diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 691f5684a11c..17253f644740 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -988,7 +988,7 @@ impl LogicalPlan { Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( CreateMemoryTable { input: Arc::new(input), - constraints: Constraints::empty(), + constraints: Constraints::default(), name: name.clone(), if_not_exists: *if_not_exists, or_replace: *or_replace, diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs index 673908a4d7e7..f310f31be352 100644 --- a/datafusion/expr/src/test/function_stub.rs +++ b/datafusion/expr/src/test/function_stub.rs @@ -179,10 +179,6 @@ impl AggregateUDFImpl for Sum { unreachable!("stub should not have state_fields()") } - fn aliases(&self) -> &[String] { - &[] - } - fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { false } @@ -344,10 +340,6 @@ impl AggregateUDFImpl for Min { not_impl_err!("no impl for stub") } - fn aliases(&self) -> &[String] { - &[] - } - fn create_groups_accumulator( &self, _args: AccumulatorArgs, @@ -429,10 +421,6 @@ impl AggregateUDFImpl for Max { not_impl_err!("no impl for stub") } - fn aliases(&self) -> &[String] { - &[] - } - fn create_groups_accumulator( &self, _args: AccumulatorArgs, @@ -494,6 +482,7 @@ impl AggregateUDFImpl for Avg { fn state_fields(&self, _args: StateFieldsArgs) -> Result> { not_impl_err!("no impl for stub") } + fn aliases(&self) -> &[String] { &self.aliases } diff --git a/datafusion/ffi/src/plan_properties.rs b/datafusion/ffi/src/plan_properties.rs index 587e667a4775..832e82dda35b 100644 --- a/datafusion/ffi/src/plan_properties.rs +++ b/datafusion/ffi/src/plan_properties.rs @@ -188,12 +188,12 @@ impl TryFrom for PlanProperties { let proto_output_ordering = PhysicalSortExprNodeCollection::decode(df_result!(ffi_orderings)?.as_ref()) .map_err(|e| DataFusionError::External(Box::new(e)))?; - let orderings = Some(parse_physical_sort_exprs( + let sort_exprs = parse_physical_sort_exprs( &proto_output_ordering.physical_sort_expr_nodes, &default_ctx, &schema, &codex, - )?); + )?; let partitioning_vec = unsafe { df_result!((ffi_props.output_partitioning)(&ffi_props))? 
}; @@ -211,11 +211,10 @@ impl TryFrom for PlanProperties { .to_string(), ))?; - let eq_properties = match orderings { - Some(ordering) => { - EquivalenceProperties::new_with_orderings(Arc::new(schema), &[ordering]) - } - None => EquivalenceProperties::new(Arc::new(schema)), + let eq_properties = if sort_exprs.is_empty() { + EquivalenceProperties::new(Arc::new(schema)) + } else { + EquivalenceProperties::new_with_orderings(Arc::new(schema), [sort_exprs]) }; let emission_type: EmissionType = @@ -300,10 +299,7 @@ impl From for EmissionType { #[cfg(test)] mod tests { - use datafusion::{ - physical_expr::{LexOrdering, PhysicalSortExpr}, - physical_plan::Partitioning, - }; + use datafusion::{physical_expr::PhysicalSortExpr, physical_plan::Partitioning}; use super::*; @@ -313,13 +309,12 @@ mod tests { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); + let mut eqp = EquivalenceProperties::new(Arc::clone(&schema)); + let _ = eqp.reorder([PhysicalSortExpr::new_default( + datafusion::physical_plan::expressions::col("a", &schema)?, + )]); let original_props = PlanProperties::new( - EquivalenceProperties::new(Arc::clone(&schema)).with_reorder( - LexOrdering::new(vec![PhysicalSortExpr { - expr: datafusion::physical_plan::expressions::col("a", &schema)?, - options: Default::default(), - }]), - ), + eqp, Partitioning::RoundRobinBatch(3), EmissionType::Incremental, Boundedness::Bounded, diff --git a/datafusion/ffi/src/udaf/accumulator_args.rs b/datafusion/ffi/src/udaf/accumulator_args.rs index 699af1d5c5e0..874a2ac8b82e 100644 --- a/datafusion/ffi/src/udaf/accumulator_args.rs +++ b/datafusion/ffi/src/udaf/accumulator_args.rs @@ -25,8 +25,10 @@ use abi_stable::{ use arrow::{datatypes::Schema, ffi::FFI_ArrowSchema}; use arrow_schema::FieldRef; use datafusion::{ - error::DataFusionError, logical_expr::function::AccumulatorArgs, - physical_expr::LexOrdering, physical_plan::PhysicalExpr, prelude::SessionContext, + error::DataFusionError, + logical_expr::function::AccumulatorArgs, + physical_expr::{PhysicalExpr, PhysicalSortExpr}, + prelude::SessionContext, }; use datafusion_proto::{ physical_plan::{ @@ -62,7 +64,7 @@ impl TryFrom> for FFI_AccumulatorArgs { let codec = DefaultPhysicalExtensionCodec {}; let ordering_req = - serialize_physical_sort_exprs(args.ordering_req.to_owned(), &codec)?; + serialize_physical_sort_exprs(args.order_bys.to_owned(), &codec)?; let expr = serialize_physical_exprs(args.exprs, &codec)?; @@ -94,7 +96,7 @@ pub struct ForeignAccumulatorArgs { pub return_field: FieldRef, pub schema: Schema, pub ignore_nulls: bool, - pub ordering_req: LexOrdering, + pub order_bys: Vec, pub is_reversed: bool, pub name: String, pub is_distinct: bool, @@ -115,9 +117,7 @@ impl TryFrom for ForeignAccumulatorArgs { let default_ctx = SessionContext::new(); let codex = DefaultPhysicalExtensionCodec {}; - // let proto_ordering_req = - // rresult_return!(PhysicalSortExprNodeCollection::decode(ordering_req.as_ref())); - let ordering_req = parse_physical_sort_exprs( + let order_bys = parse_physical_sort_exprs( &proto_def.ordering_req, &default_ctx, &schema, @@ -130,7 +130,7 @@ impl TryFrom for ForeignAccumulatorArgs { return_field, schema, ignore_nulls: proto_def.ignore_nulls, - ordering_req, + order_bys, is_reversed: value.is_reversed, name: value.name.to_string(), is_distinct: proto_def.distinct, @@ -145,7 +145,7 @@ impl<'a> From<&'a ForeignAccumulatorArgs> for AccumulatorArgs<'a> { return_field: Arc::clone(&value.return_field), schema: &value.schema, ignore_nulls: 
value.ignore_nulls, - ordering_req: &value.ordering_req, + order_bys: &value.order_bys, is_reversed: value.is_reversed, name: value.name.as_str(), is_distinct: value.is_distinct, @@ -159,10 +159,8 @@ mod tests { use super::{FFI_AccumulatorArgs, ForeignAccumulatorArgs}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion::{ - error::Result, - logical_expr::function::AccumulatorArgs, - physical_expr::{LexOrdering, PhysicalSortExpr}, - physical_plan::expressions::col, + error::Result, logical_expr::function::AccumulatorArgs, + physical_expr::PhysicalSortExpr, physical_plan::expressions::col, }; #[test] @@ -172,10 +170,7 @@ mod tests { return_field: Field::new("f", DataType::Float64, true).into(), schema: &schema, ignore_nulls: false, - ordering_req: &LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema)?, - options: Default::default(), - }]), + order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)], is_reversed: false, name: "round_trip", is_distinct: true, diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs index 2529ed7a06dc..eb7a408ab178 100644 --- a/datafusion/ffi/src/udaf/mod.rs +++ b/datafusion/ffi/src/udaf/mod.rs @@ -588,10 +588,8 @@ impl From for FFI_AggregateOrderSensitivity { mod tests { use arrow::datatypes::Schema; use datafusion::{ - common::create_array, - functions_aggregate::sum::Sum, - physical_expr::{LexOrdering, PhysicalSortExpr}, - physical_plan::expressions::col, + common::create_array, functions_aggregate::sum::Sum, + physical_expr::PhysicalSortExpr, physical_plan::expressions::col, scalar::ScalarValue, }; @@ -644,10 +642,7 @@ mod tests { return_field: Field::new("f", DataType::Float64, true).into(), schema: &schema, ignore_nulls: true, - ordering_req: &LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema)?, - options: Default::default(), - }]), + order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)], is_reversed: false, name: "round_trip", is_distinct: true, @@ -698,10 +693,7 @@ mod tests { return_field: Field::new("f", DataType::Float64, true).into(), schema: &schema, ignore_nulls: true, - ordering_req: &LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema)?, - options: Default::default(), - }]), + order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)], is_reversed: false, name: "round_trip", is_distinct: true, diff --git a/datafusion/functions-aggregate-common/src/accumulator.rs b/datafusion/functions-aggregate-common/src/accumulator.rs index 01b16f1b0a8c..39303889f0fe 100644 --- a/datafusion/functions-aggregate-common/src/accumulator.rs +++ b/datafusion/functions-aggregate-common/src/accumulator.rs @@ -19,7 +19,7 @@ use arrow::datatypes::{DataType, FieldRef, Schema}; use datafusion_common::Result; use datafusion_expr_common::accumulator::Accumulator; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use std::sync::Arc; /// [`AccumulatorArgs`] contains information about how an aggregate @@ -50,9 +50,7 @@ pub struct AccumulatorArgs<'a> { /// ```sql /// SELECT FIRST_VALUE(column1 ORDER BY column2) FROM t; /// ``` - /// - /// If no `ORDER BY` is specified, `ordering_req` will be empty. 
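With the field becoming a plain slice (see the swap just below), consumers can rely on ordinary slice idioms instead of `LexOrdering` methods. A sketch mirroring the `approx_percentile_cont` change later in this patch; `is_descending` is a hypothetical helper:

use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;

// Was the first ORDER BY column requested in descending order?
fn is_descending(order_bys: &[PhysicalSortExpr]) -> bool {
    order_bys
        .first()
        .map(|sort_expr| sort_expr.options.descending)
        .unwrap_or(false)
}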
- pub ordering_req: &'a LexOrdering, + pub order_bys: &'a [PhysicalSortExpr], /// Whether the aggregation is running in reverse order pub is_reversed: bool, diff --git a/datafusion/functions-aggregate-common/src/utils.rs b/datafusion/functions-aggregate-common/src/utils.rs index 229d9a900105..2f20e916743b 100644 --- a/datafusion/functions-aggregate-common/src/utils.rs +++ b/datafusion/functions-aggregate-common/src/utils.rs @@ -30,7 +30,7 @@ use arrow::{ }; use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr_common::accumulator::Accumulator; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; /// Convert scalar values from an accumulator into arrays. pub fn get_accum_scalar_values_as_arrays( @@ -87,13 +87,13 @@ pub fn adjust_output_array(data_type: &DataType, array: ArrayRef) -> Result Vec { - ordering_req + order_bys .iter() .zip(data_types.iter()) .map(|(sort_expr, dtype)| { diff --git a/datafusion/functions-aggregate/benches/count.rs b/datafusion/functions-aggregate/benches/count.rs index cffa50bdda12..80cb65be2ed7 100644 --- a/datafusion/functions-aggregate/benches/count.rs +++ b/datafusion/functions-aggregate/benches/count.rs @@ -15,20 +15,21 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use arrow::array::{ArrayRef, BooleanArray}; use arrow::datatypes::{DataType, Field, Int32Type, Schema}; use arrow::util::bench_util::{ create_boolean_array, create_dict_from_values, create_primitive_array, create_string_array_with_len, }; -use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use datafusion_expr::{ - function::AccumulatorArgs, Accumulator, AggregateUDFImpl, GroupsAccumulator, -}; + +use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::{Accumulator, AggregateUDFImpl, GroupsAccumulator}; use datafusion_functions_aggregate::count::Count; use datafusion_physical_expr::expressions::col; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use std::sync::Arc; + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; fn prepare_group_accumulator() -> Box { let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Int32, true)])); @@ -36,7 +37,7 @@ fn prepare_group_accumulator() -> Box { return_field: Field::new("f", DataType::Int64, true).into(), schema: &schema, ignore_nulls: false, - ordering_req: &LexOrdering::default(), + order_bys: &[], is_reversed: false, name: "COUNT(f)", is_distinct: false, @@ -59,7 +60,7 @@ fn prepare_accumulator() -> Box { return_field: Arc::new(Field::new_list_field(DataType::Int64, true)), schema: &schema, ignore_nulls: false, - ordering_req: &LexOrdering::default(), + order_bys: &[], is_reversed: false, name: "COUNT(f)", is_distinct: true, diff --git a/datafusion/functions-aggregate/benches/sum.rs b/datafusion/functions-aggregate/benches/sum.rs index 25df78b15f11..4517db6b1510 100644 --- a/datafusion/functions-aggregate/benches/sum.rs +++ b/datafusion/functions-aggregate/benches/sum.rs @@ -15,15 +15,17 @@ // specific language governing permissions and limitations // under the License. 
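`ordering_fields` in the utils hunk above pairs each ORDER BY expression with the data type it evaluates to when building aggregate state fields. A simplified sketch of that pairing; the helper name and the exact `Field` construction are illustrative, not the crate's API:

use arrow::datatypes::{DataType, Field};
use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;

// Pair each ORDER BY expression with its evaluated data type.
fn ordering_state_fields(
    order_bys: &[PhysicalSortExpr],
    data_types: &[DataType],
) -> Vec<Field> {
    order_bys
        .iter()
        .zip(data_types.iter())
        .map(|(sort_expr, dtype)| {
            // Nullable, since ordering values may be absent for some rows.
            Field::new(sort_expr.expr.to_string(), dtype.clone(), true)
        })
        .collect()
}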
+use std::sync::Arc; + use arrow::array::{ArrayRef, BooleanArray}; use arrow::datatypes::{DataType, Field, Int64Type, Schema}; use arrow::util::bench_util::{create_boolean_array, create_primitive_array}; -use criterion::{black_box, criterion_group, criterion_main, Criterion}; + use datafusion_expr::{function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator}; use datafusion_functions_aggregate::sum::Sum; use datafusion_physical_expr::expressions::col; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use std::sync::Arc; + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; fn prepare_accumulator(data_type: &DataType) -> Box { let field = Field::new("f", data_type.clone(), true).into(); @@ -32,7 +34,7 @@ fn prepare_accumulator(data_type: &DataType) -> Box { return_field: field, schema: &schema, ignore_nulls: false, - ordering_req: &LexOrdering::default(), + order_bys: &[], is_reversed: false, name: "SUM(f)", is_distinct: false, diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 024c0a823fa9..9b0d62e936bc 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -143,7 +143,7 @@ impl ApproxPercentileCont { let percentile = validate_input_percentile_expr(&args.exprs[1])?; let is_descending = args - .ordering_req + .order_bys .first() .map(|sort_expr| sort_expr.options.descending) .unwrap_or(false); diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 71278767a83f..8c426d1f884a 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -17,6 +17,11 @@ //! 
`ARRAY_AGG` aggregate implementation: [`ArrayAgg`] +use std::cmp::Ordering; +use std::collections::{HashSet, VecDeque}; +use std::mem::{size_of, size_of_val}; +use std::sync::Arc; + use arrow::array::{ new_empty_array, Array, ArrayRef, AsArray, BooleanArray, ListArray, StructArray, }; @@ -25,20 +30,16 @@ use arrow::datatypes::{DataType, Field, FieldRef, Fields}; use datafusion_common::cast::as_list_array; use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder}; -use datafusion_common::{exec_err, ScalarValue}; -use datafusion_common::{internal_err, Result}; +use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; -use datafusion_expr::{Accumulator, Signature, Volatility}; -use datafusion_expr::{AggregateUDFImpl, Documentation}; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, +}; use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays; use datafusion_functions_aggregate_common::utils::ordering_fields; use datafusion_macros::user_doc; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; -use std::cmp::Ordering; -use std::collections::{HashSet, VecDeque}; -use std::mem::{size_of, size_of_val}; -use std::sync::Arc; make_udaf_expr_and_func!( ArrayAgg, @@ -94,10 +95,6 @@ impl AggregateUDFImpl for ArrayAgg { "array_agg" } - fn aliases(&self) -> &[String] { - &[] - } - fn signature(&self) -> &Signature { &self.signature } @@ -165,17 +162,15 @@ impl AggregateUDFImpl for ArrayAgg { // ARRAY_AGG(DISTINCT concat(col, '') ORDER BY concat(col, '')) <- Valid // ARRAY_AGG(DISTINCT col ORDER BY other_col) <- Invalid // ARRAY_AGG(DISTINCT col ORDER BY concat(col, '')) <- Invalid - if acc_args.ordering_req.len() > 1 { - return exec_err!("In an aggregate with DISTINCT, ORDER BY expressions must appear in argument list"); - } - let mut sort_option: Option = None; - if let Some(order) = acc_args.ordering_req.first() { - if !order.expr.eq(&acc_args.exprs[0]) { - return exec_err!("In an aggregate with DISTINCT, ORDER BY expressions must appear in argument list"); + let sort_option = match acc_args.order_bys { + [single] if single.expr.eq(&acc_args.exprs[0]) => Some(single.options), + [] => None, + _ => { + return exec_err!( + "In an aggregate with DISTINCT, ORDER BY expressions must appear in argument list" + ); } - sort_option = Some(order.options) - } - + }; return Ok(Box::new(DistinctArrayAggAccumulator::try_new( &data_type, sort_option, @@ -183,15 +178,14 @@ impl AggregateUDFImpl for ArrayAgg { )?)); } - if acc_args.ordering_req.is_empty() { + let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else { return Ok(Box::new(ArrayAggAccumulator::try_new( &data_type, ignore_nulls, )?)); - } + }; - let ordering_dtypes = acc_args - .ordering_req + let ordering_dtypes = ordering .iter() .map(|e| e.expr.data_type(acc_args.schema)) .collect::>>()?; @@ -199,7 +193,7 @@ impl AggregateUDFImpl for ArrayAgg { OrderSensitiveArrayAggAccumulator::try_new( &data_type, &ordering_dtypes, - acc_args.ordering_req.clone(), + ordering, acc_args.is_reversed, ignore_nulls, ) @@ -538,6 +532,31 @@ impl OrderSensitiveArrayAggAccumulator { ignore_nulls, }) } + + fn evaluate_orderings(&self) -> Result { + let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]); + + let column_wise_ordering_values = if self.ordering_values.is_empty() { + fields + .iter() + .map(|f| 
new_empty_array(f.data_type())) + .collect::>() + } else { + (0..fields.len()) + .map(|i| { + let column_values = self.ordering_values.iter().map(|x| x[i].clone()); + ScalarValue::iter_to_array(column_values) + }) + .collect::>()? + }; + + let ordering_array = StructArray::try_new( + Fields::from(fields), + column_wise_ordering_values, + None, + )?; + Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar()) + } } impl Accumulator for OrderSensitiveArrayAggAccumulator { @@ -692,33 +711,6 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { } } -impl OrderSensitiveArrayAggAccumulator { - fn evaluate_orderings(&self) -> Result { - let fields = ordering_fields(self.ordering_req.as_ref(), &self.datatypes[1..]); - let num_columns = fields.len(); - - let mut column_wise_ordering_values = vec![]; - for i in 0..num_columns { - let column_values = self - .ordering_values - .iter() - .map(|x| x[i].clone()) - .collect::>(); - let array = if column_values.is_empty() { - new_empty_array(fields[i].data_type()) - } else { - ScalarValue::iter_to_array(column_values.into_iter())? - }; - column_wise_ordering_values.push(array); - } - - let struct_field = Fields::from(fields); - let ordering_array = - StructArray::try_new(struct_field, column_wise_ordering_values, None)?; - Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar()) - } -} - #[cfg(test)] mod tests { use super::*; @@ -726,7 +718,7 @@ mod tests { use datafusion_common::cast::as_generic_string_array; use datafusion_common::internal_err; use datafusion_physical_expr::expressions::Column; - use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; + use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use std::sync::Arc; #[test] @@ -991,7 +983,7 @@ mod tests { struct ArrayAggAccumulatorBuilder { return_field: FieldRef, distinct: bool, - ordering: LexOrdering, + order_bys: Vec, schema: Schema, } @@ -1004,7 +996,7 @@ mod tests { Self { return_field: Field::new("f", data_type.clone(), true).into(), distinct: false, - ordering: Default::default(), + order_bys: vec![], schema: Schema { fields: Fields::from(vec![Field::new( "col", @@ -1022,13 +1014,14 @@ mod tests { } fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self { - self.ordering.extend([PhysicalSortExpr::new( + let new_order = PhysicalSortExpr::new( Arc::new( Column::new_with_schema(col, &self.schema) .expect("column not available in schema"), ), sort_options, - )]); + ); + self.order_bys.push(new_order); self } @@ -1037,7 +1030,7 @@ mod tests { return_field: Arc::clone(&self.return_field), schema: &self.schema, ignore_nulls: false, - ordering_req: &self.ordering, + order_bys: &self.order_bys, is_reversed: false, name: "", is_distinct: self.distinct, diff --git a/datafusion/functions-aggregate/src/bool_and_or.rs b/datafusion/functions-aggregate/src/bool_and_or.rs index e5de6d76217f..d779e0a399b5 100644 --- a/datafusion/functions-aggregate/src/bool_and_or.rs +++ b/datafusion/functions-aggregate/src/bool_and_or.rs @@ -179,10 +179,6 @@ impl AggregateUDFImpl for BoolAnd { } } - fn aliases(&self) -> &[String] { - &[] - } - fn order_sensitivity(&self) -> AggregateOrderSensitivity { AggregateOrderSensitivity::Insensitive } @@ -319,10 +315,6 @@ impl AggregateUDFImpl for BoolOr { } } - fn aliases(&self) -> &[String] { - &[] - } - fn order_sensitivity(&self) -> AggregateOrderSensitivity { AggregateOrderSensitivity::Insensitive } diff --git a/datafusion/functions-aggregate/src/count.rs 
b/datafusion/functions-aggregate/src/count.rs index f375a68d9458..6fb2ed4528a9 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -347,10 +347,6 @@ impl AggregateUDFImpl for Count { }) } - fn aliases(&self) -> &[String] { - &[] - } - fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { // groups accumulator only supports `COUNT(c1)`, not // `COUNT(c1, c2)`, etc @@ -762,13 +758,14 @@ impl Accumulator for DistinctCountAccumulator { #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; + use arrow::array::{Int32Array, NullArray}; use arrow::datatypes::{DataType, Field, Int32Type, Schema}; use datafusion_expr::function::AccumulatorArgs; use datafusion_physical_expr::expressions::Column; - use datafusion_physical_expr::LexOrdering; - use std::sync::Arc; #[test] fn count_accumulator_nulls() -> Result<()> { @@ -803,7 +800,7 @@ mod tests { ignore_nulls: false, is_reversed: false, return_field: Arc::new(Field::new_list_field(DataType::Int64, true)), - ordering_req: &LexOrdering::default(), + order_bys: &[], }; let inner_dict = arrow::array::DictionaryArray::::from_iter([ diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index e8022245dba5..42c0a57fbf28 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -45,7 +45,7 @@ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, ExprFunctionExt, - GroupsAccumulator, Signature, SortExpr, Volatility, + GroupsAccumulator, ReversedUDAF, Signature, SortExpr, Volatility, }; use datafusion_functions_aggregate_common::utils::get_sort_options; use datafusion_macros::user_doc; @@ -149,24 +149,26 @@ impl AggregateUDFImpl for FirstValue { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let ordering_dtypes = acc_args - .ordering_req + let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else { + return TrivialFirstValueAccumulator::try_new( + acc_args.return_field.data_type(), + acc_args.ignore_nulls, + ) + .map(|acc| Box::new(acc) as _); + }; + let ordering_dtypes = ordering .iter() .map(|e| e.expr.data_type(acc_args.schema)) .collect::>>()?; - - // When requirement is empty, or it is signalled by outside caller that - // the ordering requirement is/will be satisfied. 
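The `accumulator` hunk above makes the ordering-aware machinery conditional: an empty ORDER BY list now short-circuits into the trivial fast path. A sketch of that dispatch, using a hypothetical enum purely for illustration:

use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};

// Which accumulator flavor serves a given FIRST_VALUE call?
enum FirstValueFlavor {
    Trivial,       // no ORDER BY: take the first row as batches arrive
    OrderingAware, // ORDER BY present: track ordering values across batches
}

fn choose(order_bys: &[PhysicalSortExpr]) -> FirstValueFlavor {
    match LexOrdering::new(order_bys.to_vec()) {
        None => FirstValueFlavor::Trivial,
        Some(_ordering) => FirstValueFlavor::OrderingAware,
    }
}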
- let requirement_satisfied = - acc_args.ordering_req.is_empty() || self.requirement_satisfied; - FirstValueAccumulator::try_new( acc_args.return_field.data_type(), &ordering_dtypes, - acc_args.ordering_req.clone(), + ordering, acc_args.ignore_nulls, ) - .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) + .map(|acc| { + Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ + }) } fn state_fields(&self, args: StateFieldsArgs) -> Result> { @@ -176,60 +178,60 @@ impl AggregateUDFImpl for FirstValue { true, ) .into()]; - fields.extend(args.ordering_fields.to_vec()); + fields.extend(args.ordering_fields.iter().cloned()); fields.push(Field::new("is_set", DataType::Boolean, true).into()); Ok(fields) } fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { - // TODO: extract to function use DataType::*; - matches!( - args.return_field.data_type(), - Int8 | Int16 - | Int32 - | Int64 - | UInt8 - | UInt16 - | UInt32 - | UInt64 - | Float16 - | Float32 - | Float64 - | Decimal128(_, _) - | Decimal256(_, _) - | Date32 - | Date64 - | Time32(_) - | Time64(_) - | Timestamp(_, _) - ) + !args.order_bys.is_empty() + && matches!( + args.return_field.data_type(), + Int8 | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Float16 + | Float32 + | Float64 + | Decimal128(_, _) + | Decimal256(_, _) + | Date32 + | Date64 + | Time32(_) + | Time64(_) + | Timestamp(_, _) + ) } fn create_groups_accumulator( &self, args: AccumulatorArgs, ) -> Result> { - // TODO: extract to function - fn create_accumulator( + fn create_accumulator( args: AccumulatorArgs, - ) -> Result> - where - T: ArrowPrimitiveType + Send, - { - let ordering_dtypes = args - .ordering_req + ) -> Result> { + let Some(ordering) = LexOrdering::new(args.order_bys.to_vec()) else { + return internal_err!("Groups accumulator must have an ordering."); + }; + + let ordering_dtypes = ordering .iter() .map(|e| e.expr.data_type(args.schema)) .collect::>>()?; - Ok(Box::new(FirstPrimitiveGroupsAccumulator::::try_new( - args.ordering_req.clone(), + FirstPrimitiveGroupsAccumulator::::try_new( + ordering, args.ignore_nulls, args.return_field.data_type(), &ordering_dtypes, true, - )?)) + ) + .map(|acc| Box::new(acc) as _) } match args.return_field.data_type() { @@ -277,19 +279,13 @@ impl AggregateUDFImpl for FirstValue { create_accumulator::(args) } - _ => { - internal_err!( - "GroupsAccumulator not supported for first_value({})", - args.return_field.data_type() - ) - } + _ => internal_err!( + "GroupsAccumulator not supported for first_value({})", + args.return_field.data_type() + ), } } - fn aliases(&self) -> &[String] { - &[] - } - fn with_beneficial_ordering( self: Arc, beneficial_ordering: bool, @@ -303,8 +299,8 @@ impl AggregateUDFImpl for FirstValue { AggregateOrderSensitivity::Beneficial } - fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { - datafusion_expr::ReversedUDAF::Reversed(last_value_udaf()) + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Reversed(last_value_udaf()) } fn documentation(&self) -> Option<&Documentation> { @@ -350,8 +346,6 @@ where pick_first_in_group: bool, // derived from `ordering_req`. sort_options: Vec, - // Stores whether incoming data already satisfies the ordering requirement. - input_requirement_satisfied: bool, // Ignore null values. 
ignore_nulls: bool, /// The output type @@ -370,20 +364,17 @@ where ordering_dtypes: &[DataType], pick_first_in_group: bool, ) -> Result { - let requirement_satisfied = ordering_req.is_empty(); - let default_orderings = ordering_dtypes .iter() .map(ScalarValue::try_from) - .collect::>>()?; + .collect::>()?; - let sort_options = get_sort_options(ordering_req.as_ref()); + let sort_options = get_sort_options(&ordering_req); Ok(Self { null_builder: BooleanBufferBuilder::new(0), ordering_req, sort_options, - input_requirement_satisfied: requirement_satisfied, ignore_nulls, default_orderings, data_type: data_type.clone(), @@ -396,18 +387,6 @@ where }) } - fn need_update(&self, group_idx: usize) -> bool { - if !self.is_sets.get_bit(group_idx) { - return true; - } - - if self.ignore_nulls && !self.null_builder.get_bit(group_idx) { - return true; - } - - !self.input_requirement_satisfied - } - fn should_update_state( &self, group_idx: usize, @@ -573,17 +552,12 @@ where let group_idx = *group_idx; let passed_filter = opt_filter.is_none_or(|x| x.value(idx_in_val)); - let is_set = is_set_arr.is_none_or(|x| x.value(idx_in_val)); if !passed_filter || !is_set { continue; } - if !self.need_update(group_idx) { - continue; - } - if self.ignore_nulls && vals.is_null(idx_in_val) { continue; } @@ -720,7 +694,7 @@ where let (is_set_arr, val_and_order_cols) = match values.split_last() { Some(result) => result, - None => return internal_err!("Empty row in FISRT_VALUE"), + None => return internal_err!("Empty row in FIRST_VALUE"), }; let is_set_arr = as_boolean_array(is_set_arr)?; @@ -782,14 +756,96 @@ where } } } + +/// This accumulator is used when there is no ordering specified for the +/// `FIRST_VALUE` aggregation. It simply returns the first value it sees +/// according to the pre-existing ordering of the input data, and provides +/// a fast path for this case without needing to maintain any ordering state. +#[derive(Debug)] +pub struct TrivialFirstValueAccumulator { + first: ScalarValue, + // Whether we have seen the first value yet. + is_set: bool, + // Ignore null values. + ignore_nulls: bool, +} + +impl TrivialFirstValueAccumulator { + /// Creates a new `TrivialFirstValueAccumulator` for the given `data_type`. + pub fn try_new(data_type: &DataType, ignore_nulls: bool) -> Result { + ScalarValue::try_from(data_type).map(|first| Self { + first, + is_set: false, + ignore_nulls, + }) + } +} + +impl Accumulator for TrivialFirstValueAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![self.first.clone(), ScalarValue::from(self.is_set)]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if !self.is_set { + // Get first entry according to the pre-existing ordering (0th index): + let value = &values[0]; + let mut first_idx = None; + if self.ignore_nulls { + // If ignoring nulls, find the first non-null value. + for i in 0..value.len() { + if !value.is_null(i) { + first_idx = Some(i); + break; + } + } + } else if !value.is_empty() { + // If not ignoring nulls, return the first value if it exists. + first_idx = Some(0); + } + if let Some(first_idx) = first_idx { + let mut row = get_row_at_idx(values, first_idx)?; + self.first = row.swap_remove(0); + self.first.compact(); + self.is_set = true; + } + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // FIRST_VALUE(first1, first2, first3, ...) + // Second index contains is_set flag. 
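+ // State layout is [first_value, is_set]; filtering on the `is_set` flags below + // drops rows from empty partitions so they cannot supply the merged first value.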
+ if !self.is_set { + let flags = states[1].as_boolean(); + let filtered_states = + filter_states_according_to_is_set(&states[0..1], flags)?; + if let Some(first) = filtered_states.first() { + if !first.is_empty() { + self.first = ScalarValue::try_from_array(first, 0)?; + self.is_set = true; + } + } + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + Ok(self.first.clone()) + } + + fn size(&self) -> usize { + size_of_val(self) - size_of_val(&self.first) + self.first.size() + } +} + #[derive(Debug)] pub struct FirstValueAccumulator { first: ScalarValue, - // At the beginning, `is_set` is false, which means `first` is not seen yet. - // Once we see the first value, we set the `is_set` flag and do not update `first` anymore. + // Whether we have seen the first value yet. is_set: bool, - // Stores ordering values, of the aggregator requirement corresponding to first value - // of the aggregator. These values are used during merging of multiple partitions. + // Stores values of the ordering columns corresponding to the first value. + // These values are used during merging of multiple partitions. orderings: Vec, // Stores the applicable ordering requirement. ordering_req: LexOrdering, @@ -810,14 +866,13 @@ impl FirstValueAccumulator { let orderings = ordering_dtypes .iter() .map(ScalarValue::try_from) - .collect::>>()?; - let requirement_satisfied = ordering_req.is_empty(); + .collect::>()?; ScalarValue::try_from(data_type).map(|first| Self { first, is_set: false, orderings, ordering_req, - requirement_satisfied, + requirement_satisfied: false, ignore_nulls, }) } @@ -830,10 +885,9 @@ impl FirstValueAccumulator { // Updates state with the values in the given row. fn update_with_new_row(&mut self, mut row: Vec) { // Ensure any Array based scalars hold a single value to reduce memory pressure - row.iter_mut().for_each(|s| { + for s in row.iter_mut() { s.compact(); - }); - + } self.first = row.remove(0); self.orderings = row; self.is_set = true; @@ -886,30 +940,24 @@ impl Accumulator for FirstValueAccumulator { fn state(&mut self) -> Result> { let mut result = vec![self.first.clone()]; result.extend(self.orderings.iter().cloned()); - result.push(ScalarValue::Boolean(Some(self.is_set))); + result.push(ScalarValue::from(self.is_set)); Ok(result) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if !self.is_set { - if let Some(first_idx) = self.get_first_idx(values)? { - let row = get_row_at_idx(values, first_idx)?; + if let Some(first_idx) = self.get_first_idx(values)? { + let row = get_row_at_idx(values, first_idx)?; + if !self.is_set + || (!self.requirement_satisfied + && compare_rows( + &self.orderings, + &row[1..], + &get_sort_options(&self.ordering_req), + )? + .is_gt()) + { self.update_with_new_row(row); } - } else if !self.requirement_satisfied { - if let Some(first_idx) = self.get_first_idx(values)? { - let row = get_row_at_idx(values, first_idx)?; - let orderings = &row[1..]; - if compare_rows( - &self.orderings, - orderings, - &get_sort_options(self.ordering_req.as_ref()), - )?
- .is_gt() - { - self.update_with_new_row(row); - } - } } Ok(()) } @@ -922,10 +970,8 @@ impl Accumulator for FirstValueAccumulator { let filtered_states = filter_states_according_to_is_set(&states[0..is_set_idx], flags)?; // 1..is_set_idx range corresponds to ordering section - let sort_columns = convert_to_sort_cols( - &filtered_states[1..is_set_idx], - self.ordering_req.as_ref(), - ); + let sort_columns = + convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req); let comparator = LexicographicalComparator::try_new(&sort_columns)?; let min = (0..filtered_states[0].len()).min_by(|&a, &b| comparator.compare(a, b)); @@ -934,7 +980,7 @@ impl Accumulator for FirstValueAccumulator { let mut first_row = get_row_at_idx(&filtered_states, first_idx)?; // When collecting orderings, we exclude the is_set flag from the state. let first_ordering = &first_row[1..is_set_idx]; - let sort_options = get_sort_options(self.ordering_req.as_ref()); + let sort_options = get_sort_options(&self.ordering_req); // Either there is no existing value, or there is an earlier version in new data. if !self.is_set || compare_rows(&self.orderings, first_ordering, &sort_options)?.is_gt() @@ -1029,47 +1075,40 @@ impl AggregateUDFImpl for LastValue { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let ordering_dtypes = acc_args - .ordering_req + let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else { + return TrivialLastValueAccumulator::try_new( + acc_args.return_field.data_type(), + acc_args.ignore_nulls, + ) + .map(|acc| Box::new(acc) as _); + }; + let ordering_dtypes = ordering .iter() .map(|e| e.expr.data_type(acc_args.schema)) .collect::>>()?; - - let requirement_satisfied = - acc_args.ordering_req.is_empty() || self.requirement_satisfied; - LastValueAccumulator::try_new( acc_args.return_field.data_type(), &ordering_dtypes, - acc_args.ordering_req.clone(), + ordering, acc_args.ignore_nulls, ) - .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) + .map(|acc| { + Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ + }) } fn state_fields(&self, args: StateFieldsArgs) -> Result> { - let StateFieldsArgs { - name, - input_fields, - return_field: _, - ordering_fields, - is_distinct: _, - } = args; let mut fields = vec![Field::new( - format_state_name(name, "last_value"), - input_fields[0].data_type().clone(), + format_state_name(args.name, "last_value"), + args.return_field.data_type().clone(), true, ) .into()]; - fields.extend(ordering_fields.to_vec()); + fields.extend(args.ordering_fields.iter().cloned()); fields.push(Field::new("is_set", DataType::Boolean, true).into()); Ok(fields) } - fn aliases(&self) -> &[String] { - &[] - } - fn with_beneficial_ordering( self: Arc, beneficial_ordering: bool, @@ -1083,8 +1122,8 @@ impl AggregateUDFImpl for LastValue { AggregateOrderSensitivity::Beneficial } - fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { - datafusion_expr::ReversedUDAF::Reversed(first_value_udaf()) + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Reversed(first_value_udaf()) } fn documentation(&self) -> Option<&Documentation> { @@ -1093,26 +1132,27 @@ impl AggregateUDFImpl for LastValue { fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { use DataType::*; - matches!( - args.return_field.data_type(), - Int8 | Int16 - | Int32 - | Int64 - | UInt8 - | UInt16 - | UInt32 - | UInt64 - | Float16 - | Float32 - | Float64 - | Decimal128(_, _) - | Decimal256(_, _) - | Date32 - | Date64 - 
| Time32(_) - | Time64(_) - | Timestamp(_, _) - ) + !args.order_bys.is_empty() + && matches!( + args.return_field.data_type(), + Int8 | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Float16 + | Float32 + | Float64 + | Decimal128(_, _) + | Decimal256(_, _) + | Date32 + | Date64 + | Time32(_) + | Time64(_) + | Timestamp(_, _) + ) } fn create_groups_accumulator( @@ -1125,14 +1165,17 @@ impl AggregateUDFImpl for LastValue { where T: ArrowPrimitiveType + Send, { - let ordering_dtypes = args - .ordering_req + let Some(ordering) = LexOrdering::new(args.order_bys.to_vec()) else { + return internal_err!("Groups accumulator must have an ordering."); + }; + + let ordering_dtypes = ordering .iter() .map(|e| e.expr.data_type(args.schema)) .collect::>>()?; Ok(Box::new(FirstPrimitiveGroupsAccumulator::::try_new( - args.ordering_req.clone(), + ordering, args.ignore_nulls, args.return_field.data_type(), &ordering_dtypes, @@ -1195,6 +1238,85 @@ impl AggregateUDFImpl for LastValue { } } +/// This accumulator is used when there is no ordering specified for the +/// `LAST_VALUE` aggregation. It simply updates the last value it sees +/// according to the pre-existing ordering of the input data, and provides +/// a fast path for this case without needing to maintain any ordering state. +#[derive(Debug)] +pub struct TrivialLastValueAccumulator { + last: ScalarValue, + // The `is_set` flag keeps track of whether the last value is finalized. + // This information is used to discriminate genuine NULLs and NULLS that + // occur due to empty partitions. + is_set: bool, + // Ignore null values. + ignore_nulls: bool, +} + +impl TrivialLastValueAccumulator { + /// Creates a new `TrivialLastValueAccumulator` for the given `data_type`. + pub fn try_new(data_type: &DataType, ignore_nulls: bool) -> Result { + ScalarValue::try_from(data_type).map(|last| Self { + last, + is_set: false, + ignore_nulls, + }) + } +} + +impl Accumulator for TrivialLastValueAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![self.last.clone(), ScalarValue::from(self.is_set)]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + // Get last entry according to the pre-existing ordering (0th index): + let value = &values[0]; + let mut last_idx = None; + if self.ignore_nulls { + // If ignoring nulls, find the last non-null value. + for i in (0..value.len()).rev() { + if !value.is_null(i) { + last_idx = Some(i); + break; + } + } + } else if !value.is_empty() { + // If not ignoring nulls, return the last value if it exists. + last_idx = Some(value.len() - 1); + } + if let Some(last_idx) = last_idx { + let mut row = get_row_at_idx(values, last_idx)?; + self.last = row.swap_remove(0); + self.last.compact(); + self.is_set = true; + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // LAST_VALUE(last1, last2, last3, ...) + // Second index contains is_set flag. 
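+ // State layout is [last_value, is_set]; filtering on the `is_set` flags below + // drops rows from empty partitions before the merged last value is chosen.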
+ let flags = states[1].as_boolean(); + let filtered_states = filter_states_according_to_is_set(&states[0..1], flags)?; + if let Some(last) = filtered_states.last() { + if !last.is_empty() { + self.last = ScalarValue::try_from_array(last, 0)?; + self.is_set = true; + } + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + Ok(self.last.clone()) + } + + fn size(&self) -> usize { + size_of_val(self) - size_of_val(&self.last) + self.last.size() + } +} + #[derive(Debug)] struct LastValueAccumulator { last: ScalarValue, @@ -1202,6 +1324,8 @@ struct LastValueAccumulator { // This information is used to discriminate genuine NULLs and NULLS that // occur due to empty partitions. is_set: bool, + // Stores values of the ordering columns corresponding to the last value. + // These values are used during merging of multiple partitions. orderings: Vec, // Stores the applicable ordering requirement. ordering_req: LexOrdering, @@ -1222,14 +1346,13 @@ impl LastValueAccumulator { let orderings = ordering_dtypes .iter() .map(ScalarValue::try_from) - .collect::>>()?; - let requirement_satisfied = ordering_req.is_empty(); + .collect::>()?; ScalarValue::try_from(data_type).map(|last| Self { last, is_set: false, orderings, ordering_req, - requirement_satisfied, + requirement_satisfied: false, ignore_nulls, }) } @@ -1237,10 +1360,9 @@ impl LastValueAccumulator { // Updates state with the values in the given row. fn update_with_new_row(&mut self, mut row: Vec) { // Ensure any Array based scalars hold have a single value to reduce memory pressure - row.iter_mut().for_each(|s| { + for s in row.iter_mut() { s.compact(); - }); - + } self.last = row.remove(0); self.orderings = row; self.is_set = true; @@ -1264,6 +1386,7 @@ impl LastValueAccumulator { return Ok((!value.is_empty()).then_some(value.len() - 1)); } } + let sort_columns = ordering_values .iter() .zip(self.ordering_req.iter()) @@ -1295,31 +1418,27 @@ impl Accumulator for LastValueAccumulator { fn state(&mut self) -> Result> { let mut result = vec![self.last.clone()]; result.extend(self.orderings.clone()); - result.push(ScalarValue::Boolean(Some(self.is_set))); + result.push(ScalarValue::from(self.is_set)); Ok(result) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if !self.is_set || self.requirement_satisfied { - if let Some(last_idx) = self.get_last_idx(values)? { - let row = get_row_at_idx(values, last_idx)?; - self.update_with_new_row(row); - } - } else if let Some(last_idx) = self.get_last_idx(values)? { + if let Some(last_idx) = self.get_last_idx(values)? { let row = get_row_at_idx(values, last_idx)?; let orderings = &row[1..]; // Update when there is a more recent entry - if compare_rows( - &self.orderings, - orderings, - &get_sort_options(self.ordering_req.as_ref()), - )? - .is_lt() + if !self.is_set + || self.requirement_satisfied + || compare_rows( + &self.orderings, + orderings, + &get_sort_options(&self.ordering_req), + )?
+ .is_lt() { self.update_with_new_row(row); } } - Ok(()) } @@ -1331,10 +1450,8 @@ impl Accumulator for LastValueAccumulator { let filtered_states = filter_states_according_to_is_set(&states[0..is_set_idx], flags)?; // 1..is_set_idx range corresponds to ordering section - let sort_columns = convert_to_sort_cols( - &filtered_states[1..is_set_idx], - self.ordering_req.as_ref(), - ); + let sort_columns = + convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req); let comparator = LexicographicalComparator::try_new(&sort_columns)?; let max = (0..filtered_states[0].len()).max_by(|&a, &b| comparator.compare(a, b)); @@ -1343,7 +1460,7 @@ impl Accumulator for LastValueAccumulator { let mut last_row = get_row_at_idx(&filtered_states, last_idx)?; // When collecting orderings, we exclude the is_set flag from the state. let last_ordering = &last_row[1..is_set_idx]; - let sort_options = get_sort_options(self.ordering_req.as_ref()); + let sort_options = get_sort_options(&self.ordering_req); // Either there is no existing value, or there is a newer (latest) // version in the new data: if !self.is_set @@ -1382,7 +1499,7 @@ fn filter_states_according_to_is_set( states .iter() .map(|state| compute::filter(state, flags).map_err(|e| arrow_datafusion_err!(e))) - .collect::>>() + .collect() } /// Combines array refs and their corresponding orderings to construct `SortColumn`s. @@ -1393,7 +1510,7 @@ fn convert_to_sort_cols(arrs: &[ArrayRef], sort_exprs: &LexOrdering) -> Vec>() + .collect() } #[cfg(test)] @@ -1411,18 +1528,10 @@ mod tests { #[test] fn test_first_last_value_value() -> Result<()> { - let mut first_accumulator = FirstValueAccumulator::try_new( - &DataType::Int64, - &[], - LexOrdering::default(), - false, - )?; - let mut last_accumulator = LastValueAccumulator::try_new( - &DataType::Int64, - &[], - LexOrdering::default(), - false, - )?; + let mut first_accumulator = + TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?; + let mut last_accumulator = + TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?; // first value in the tuple is start of the range (inclusive), // second value in the tuple is end of the range (exclusive) let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)]; @@ -1459,22 +1568,14 @@ mod tests { .collect::>(); // FirstValueAccumulator - let mut first_accumulator = FirstValueAccumulator::try_new( - &DataType::Int64, - &[], - LexOrdering::default(), - false, - )?; + let mut first_accumulator = + TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?; first_accumulator.update_batch(&[Arc::clone(&arrs[0])])?; let state1 = first_accumulator.state()?; - let mut first_accumulator = FirstValueAccumulator::try_new( - &DataType::Int64, - &[], - LexOrdering::default(), - false, - )?; + let mut first_accumulator = + TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?; first_accumulator.update_batch(&[Arc::clone(&arrs[1])])?; let state2 = first_accumulator.state()?; @@ -1489,34 +1590,22 @@ mod tests { ])?); } - let mut first_accumulator = FirstValueAccumulator::try_new( - &DataType::Int64, - &[], - LexOrdering::default(), - false, - )?; + let mut first_accumulator = + TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?; first_accumulator.merge_batch(&states)?; let merged_state = first_accumulator.state()?; assert_eq!(merged_state.len(), state1.len()); // LastValueAccumulator - let mut last_accumulator = LastValueAccumulator::try_new( - &DataType::Int64, - &[], - LexOrdering::default(), - false, - )?; + let mut 
last_accumulator = + TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?; last_accumulator.update_batch(&[Arc::clone(&arrs[0])])?; let state1 = last_accumulator.state()?; - let mut last_accumulator = LastValueAccumulator::try_new( - &DataType::Int64, - &[], - LexOrdering::default(), - false, - )?; + let mut last_accumulator = + TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?; last_accumulator.update_batch(&[Arc::clone(&arrs[1])])?; let state2 = last_accumulator.state()?; @@ -1531,12 +1620,8 @@ mod tests { ])?); } - let mut last_accumulator = LastValueAccumulator::try_new( - &DataType::Int64, - &[], - LexOrdering::default(), - false, - )?; + let mut last_accumulator = + TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?; last_accumulator.merge_batch(&states)?; let merged_state = last_accumulator.state()?; @@ -1546,7 +1631,7 @@ mod tests { } #[test] - fn test_frist_group_acc() -> Result<()> { + fn test_first_group_acc() -> Result<()> { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int64, true), Field::new("b", DataType::Int64, true), @@ -1555,13 +1640,13 @@ mod tests { Field::new("e", DataType::Boolean, true), ])); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { + let sort_keys = [PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]); + }]; let mut group_acc = FirstPrimitiveGroupsAccumulator::::try_new( - sort_key, + sort_keys.into(), true, &DataType::Int64, &[DataType::Int64], @@ -1649,13 +1734,13 @@ mod tests { Field::new("e", DataType::Boolean, true), ])); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { + let sort_keys = [PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]); + }]; let mut group_acc = FirstPrimitiveGroupsAccumulator::::try_new( - sort_key, + sort_keys.into(), true, &DataType::Int64, &[DataType::Int64], @@ -1730,13 +1815,13 @@ mod tests { Field::new("e", DataType::Boolean, true), ])); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { + let sort_keys = [PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]); + }]; let mut group_acc = FirstPrimitiveGroupsAccumulator::::try_new( - sort_key, + sort_keys.into(), true, &DataType::Int64, &[DataType::Int64], @@ -1798,10 +1883,8 @@ mod tests { #[test] fn test_first_list_acc_size() -> Result<()> { fn size_after_batch(values: &[ArrayRef]) -> Result { - let mut first_accumulator = FirstValueAccumulator::try_new( + let mut first_accumulator = TrivialFirstValueAccumulator::try_new( &DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))), - &[], - LexOrdering::default(), false, )?; @@ -1826,10 +1909,8 @@ mod tests { #[test] fn test_last_list_acc_size() -> Result<()> { fn size_after_batch(values: &[ArrayRef]) -> Result { - let mut last_accumulator = LastValueAccumulator::try_new( + let mut last_accumulator = TrivialLastValueAccumulator::try_new( &DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))), - &[], - LexOrdering::default(), false, )?; diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index bfaea4b2398c..5c3d265d1d6b 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -214,10 +214,6 @@ impl AggregateUDFImpl for Median { } } - fn aliases(&self) -> &[String] { - &[] - } - fn documentation(&self) -> Option<&Documentation> { self.doc() } diff --git 
a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index bb46aa540461..2f8ac93a195e 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -239,10 +239,6 @@ impl AggregateUDFImpl for Max { )?)) } - fn aliases(&self) -> &[String] { - &[] - } - fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { use DataType::*; matches!( @@ -1211,10 +1207,6 @@ impl AggregateUDFImpl for Min { )?)) } - fn aliases(&self) -> &[String] { - &[] - } - fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { use DataType::*; matches!( diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index 1525b2f991a1..fac0d33ceabe 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -86,7 +86,7 @@ pub fn nth_value( description = "The position (nth) of the value to retrieve, based on the ordering." ) )] -/// Expression for a `NTH_VALUE(... ORDER BY ..., ...)` aggregation. In a multi +/// Expression for a `NTH_VALUE(..., ... ORDER BY ...)` aggregation. In a multi /// partition setting, partial aggregations are computed for every partition, /// and then their results are merged. #[derive(Debug)] @@ -148,20 +148,21 @@ impl AggregateUDFImpl for NthValueAgg { } }; - let ordering_dtypes = acc_args - .ordering_req + let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else { + return TrivialNthValueAccumulator::try_new( + n, + acc_args.return_field.data_type(), + ) + .map(|acc| Box::new(acc) as _); + }; + let ordering_dtypes = ordering .iter() .map(|e| e.expr.data_type(acc_args.schema)) .collect::>>()?; let data_type = acc_args.exprs[0].data_type(acc_args.schema)?; - NthValueAccumulator::try_new( - n, - &data_type, - &ordering_dtypes, - acc_args.ordering_req.clone(), - ) - .map(|acc| Box::new(acc) as _) + NthValueAccumulator::try_new(n, &data_type, &ordering_dtypes, ordering) + .map(|acc| Box::new(acc) as _) } fn state_fields(&self, args: StateFieldsArgs) -> Result> { @@ -182,10 +183,6 @@ impl AggregateUDFImpl for NthValueAgg { Ok(fields.into_iter().map(Arc::new).collect()) } - fn aliases(&self) -> &[String] { - &[] - } - fn reverse_expr(&self) -> ReversedUDAF { ReversedUDAF::Reversed(nth_value_udaf()) } @@ -195,6 +192,126 @@ impl AggregateUDFImpl for NthValueAgg { } } +#[derive(Debug)] +pub struct TrivialNthValueAccumulator { + /// The `N` value. + n: i64, + /// Stores entries in the `NTH_VALUE` result. + values: VecDeque, + /// Data types of the value. + datatype: DataType, +} + +impl TrivialNthValueAccumulator { + /// Create a new order-insensitive NTH_VALUE accumulator based on the given + /// item data type. + pub fn try_new(n: i64, datatype: &DataType) -> Result { + if n == 0 { + // n cannot be 0 + return internal_err!("Nth value indices are 1 based. 0 is invalid index"); + } + Ok(Self { + n, + values: VecDeque::new(), + datatype: datatype.clone(), + }) + } + + /// Updates state, with the `values`. Fetch contains missing number of entries for state to be complete + /// None represents all of the new `values` need to be added to the state. 
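As the constructor check above enforces, `N` is 1-based and signed: positive values count from the start of the input, negative values from the end, and `append_new_data` below copies only as many rows as are still needed. The index arithmetic that `evaluate` applies later in this hunk can be sketched in isolation; the free function `nth_value_idx` is illustrative only:

// Resolves a 1-based, signed NTH_VALUE index against `len` collected values.
// `None` means the requested position is out of range, so the result is NULL.
// A zero index never reaches this point; construction rejects it.
fn nth_value_idx(n: i64, len: usize) -> Option<usize> {
    let n_required = n.unsigned_abs() as usize;
    if n > 0 {
        // Count from the start: NTH_VALUE(x, 1) is the first value.
        let forward_idx = n_required - 1;
        (forward_idx < len).then_some(forward_idx)
    } else {
        // Count from the end: NTH_VALUE(x, -1) is the last value.
        len.checked_sub(n_required)
    }
}

fn main() {
    assert_eq!(nth_value_idx(1, 3), Some(0));
    assert_eq!(nth_value_idx(-1, 3), Some(2));
    assert_eq!(nth_value_idx(5, 3), None); // fewer than 5 rows collected
}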
+ fn append_new_data( + &mut self, + values: &[ArrayRef], + fetch: Option, + ) -> Result<()> { + let n_row = values[0].len(); + let n_to_add = if let Some(fetch) = fetch { + std::cmp::min(fetch, n_row) + } else { + n_row + }; + for index in 0..n_to_add { + let mut row = get_row_at_idx(values, index)?; + self.values.push_back(row.swap_remove(0)); + // At index 1, we have n index argument, which is constant. + } + Ok(()) + } +} + +impl Accumulator for TrivialNthValueAccumulator { + /// Updates its state with the `values`. Assumes data in the `values` satisfies the required + /// ordering for the accumulator (across consecutive batches, not just batch-wise). + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if !values.is_empty() { + let n_required = self.n.unsigned_abs() as usize; + let from_start = self.n > 0; + if from_start { + // direction is from start + let n_remaining = n_required.saturating_sub(self.values.len()); + self.append_new_data(values, Some(n_remaining))?; + } else { + // direction is from end + self.append_new_data(values, None)?; + let start_offset = self.values.len().saturating_sub(n_required); + if start_offset > 0 { + self.values.drain(0..start_offset); + } + } + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if !states.is_empty() { + // First entry in the state is the aggregation result. + let n_required = self.n.unsigned_abs() as usize; + let array_agg_res = ScalarValue::convert_array_to_scalar_vec(&states[0])?; + for v in array_agg_res.into_iter() { + self.values.extend(v); + if self.values.len() > n_required { + // There is enough data collected, can stop merging: + break; + } + } + } + Ok(()) + } + + fn state(&mut self) -> Result> { + let mut values_cloned = self.values.clone(); + let values_slice = values_cloned.make_contiguous(); + Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable( + values_slice, + &self.datatype, + ))]) + } + + fn evaluate(&mut self) -> Result { + let n_required = self.n.unsigned_abs() as usize; + let from_start = self.n > 0; + let nth_value_idx = if from_start { + // index is from start + let forward_idx = n_required - 1; + (forward_idx < self.values.len()).then_some(forward_idx) + } else { + // index is from end + self.values.len().checked_sub(n_required) + }; + if let Some(idx) = nth_value_idx { + Ok(self.values[idx].clone()) + } else { + ScalarValue::try_from(self.datatype.clone()) + } + } + + fn size(&self) -> usize { + size_of_val(self) + ScalarValue::size_of_vec_deque(&self.values) + - size_of_val(&self.values) + + size_of::() + } +} + #[derive(Debug)] pub struct NthValueAccumulator { /// The `N` value. @@ -236,6 +353,64 @@ impl NthValueAccumulator { ordering_req, }) } + + fn evaluate_orderings(&self) -> Result { + let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]); + + let mut column_wise_ordering_values = vec![]; + let num_columns = fields.len(); + for i in 0..num_columns { + let column_values = self + .ordering_values + .iter() + .map(|x| x[i].clone()) + .collect::>(); + let array = if column_values.is_empty() { + new_empty_array(fields[i].data_type()) + } else { + ScalarValue::iter_to_array(column_values.into_iter())? 
+ }; + column_wise_ordering_values.push(array); + } + + let struct_field = Fields::from(fields); + let ordering_array = + StructArray::try_new(struct_field, column_wise_ordering_values, None)?; + + Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar()) + } + + fn evaluate_values(&self) -> ScalarValue { + let mut values_cloned = self.values.clone(); + let values_slice = values_cloned.make_contiguous(); + ScalarValue::List(ScalarValue::new_list_nullable( + values_slice, + &self.datatypes[0], + )) + } + + /// Updates state, with the `values`. Fetch contains missing number of entries for state to be complete + /// None represents all of the new `values` need to be added to the state. + fn append_new_data( + &mut self, + values: &[ArrayRef], + fetch: Option, + ) -> Result<()> { + let n_row = values[0].len(); + let n_to_add = if let Some(fetch) = fetch { + std::cmp::min(fetch, n_row) + } else { + n_row + }; + for index in 0..n_to_add { + let row = get_row_at_idx(values, index)?; + self.values.push_back(row[0].clone()); + // At index 1, we have n index argument. + // Ordering values cover starting from 2nd index to end + self.ordering_values.push_back(row[2..].to_vec()); + } + Ok(()) + } } impl Accumulator for NthValueAccumulator { @@ -269,91 +444,60 @@ impl Accumulator for NthValueAccumulator { if states.is_empty() { return Ok(()); } - // First entry in the state is the aggregation result. - let array_agg_values = &states[0]; - let n_required = self.n.unsigned_abs() as usize; - if self.ordering_req.is_empty() { - let array_agg_res = - ScalarValue::convert_array_to_scalar_vec(array_agg_values)?; - for v in array_agg_res.into_iter() { - self.values.extend(v); - if self.values.len() > n_required { - // There is enough data collected can stop merging - break; - } - } - } else if let Some(agg_orderings) = states[1].as_list_opt::() { - // 2nd entry stores values received for ordering requirement columns, for each aggregation value inside NTH_VALUE list. - // For each `StructArray` inside NTH_VALUE list, we will receive an `Array` that stores - // values received from its ordering requirement expression. (This information is necessary for during merging). - - // Stores NTH_VALUE results coming from each partition - let mut partition_values: Vec> = vec![]; - // Stores ordering requirement expression results coming from each partition - let mut partition_ordering_values: Vec>> = vec![]; - - // Existing values should be merged also. 
- partition_values.push(self.values.clone()); - - partition_ordering_values.push(self.ordering_values.clone()); - - let array_agg_res = - ScalarValue::convert_array_to_scalar_vec(array_agg_values)?; - - for v in array_agg_res.into_iter() { - partition_values.push(v.into()); - } - - let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?; - - let ordering_values = orderings.into_iter().map(|partition_ordering_rows| { - // Extract value from struct to ordering_rows for each group/partition - partition_ordering_rows.into_iter().map(|ordering_row| { - if let ScalarValue::Struct(s) = ordering_row { - let mut ordering_columns_per_row = vec![]; - - for column in s.columns() { - let sv = ScalarValue::try_from_array(column, 0)?; - ordering_columns_per_row.push(sv); - } - - Ok(ordering_columns_per_row) - } else { - exec_err!( - "Expects to receive ScalarValue::Struct(Some(..), _) but got: {:?}", - ordering_row.data_type() - ) - } - }).collect::>>() - }).collect::>>()?; - for ordering_values in ordering_values.into_iter() { - partition_ordering_values.push(ordering_values.into()); - } - - let sort_options = self - .ordering_req - .iter() - .map(|sort_expr| sort_expr.options) - .collect::>(); - let (new_values, new_orderings) = merge_ordered_arrays( - &mut partition_values, - &mut partition_ordering_values, - &sort_options, - )?; - self.values = new_values.into(); - self.ordering_values = new_orderings.into(); - } else { + // Second entry stores values received for ordering requirement columns + // for each aggregation value inside NTH_VALUE list. For each `StructArray` + // inside this list, we will receive an `Array` that stores values received + // from its ordering requirement expression. This information is necessary + // during merging. + let Some(agg_orderings) = states[1].as_list_opt::() else { return exec_err!("Expects to receive a list array"); + }; + + // Stores NTH_VALUE results coming from each partition + let mut partition_values = vec![self.values.clone()]; + // First entry in the state is the aggregation result. 
+ let array_agg_res = ScalarValue::convert_array_to_scalar_vec(&states[0])?; + for v in array_agg_res.into_iter() { + partition_values.push(v.into()); + } + // Stores ordering requirement expression results coming from each partition: + let mut partition_ordering_values = vec![self.ordering_values.clone()]; + let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?; + // Extract value from struct to ordering_rows for each group/partition: + for partition_ordering_rows in orderings.into_iter() { + let ordering_values = partition_ordering_rows.into_iter().map(|ordering_row| { + let ScalarValue::Struct(s_array) = ordering_row else { + return exec_err!( + "Expects to receive ScalarValue::Struct(Some(..), _) but got: {:?}", + ordering_row.data_type() + ); + }; + s_array + .columns() + .iter() + .map(|column| ScalarValue::try_from_array(column, 0)) + .collect() + }).collect::>>()?; + partition_ordering_values.push(ordering_values); } + + let sort_options = self + .ordering_req + .iter() + .map(|sort_expr| sort_expr.options) + .collect::>(); + let (new_values, new_orderings) = merge_ordered_arrays( + &mut partition_values, + &mut partition_ordering_values, + &sort_options, + )?; + self.values = new_values.into(); + self.ordering_values = new_orderings.into(); Ok(()) } fn state(&mut self) -> Result> { - let mut result = vec![self.evaluate_values()]; - if !self.ordering_req.is_empty() { - result.push(self.evaluate_orderings()?); - } - Ok(result) + Ok(vec![self.evaluate_values(), self.evaluate_orderings()?]) } fn evaluate(&mut self) -> Result { @@ -396,63 +540,3 @@ impl Accumulator for NthValueAccumulator { total } } - -impl NthValueAccumulator { - fn evaluate_orderings(&self) -> Result { - let fields = ordering_fields(self.ordering_req.as_ref(), &self.datatypes[1..]); - - let mut column_wise_ordering_values = vec![]; - let num_columns = fields.len(); - for i in 0..num_columns { - let column_values = self - .ordering_values - .iter() - .map(|x| x[i].clone()) - .collect::>(); - let array = if column_values.is_empty() { - new_empty_array(fields[i].data_type()) - } else { - ScalarValue::iter_to_array(column_values.into_iter())? - }; - column_wise_ordering_values.push(array); - } - - let struct_field = Fields::from(fields); - let ordering_array = - StructArray::try_new(struct_field, column_wise_ordering_values, None)?; - - Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar()) - } - - fn evaluate_values(&self) -> ScalarValue { - let mut values_cloned = self.values.clone(); - let values_slice = values_cloned.make_contiguous(); - ScalarValue::List(ScalarValue::new_list_nullable( - values_slice, - &self.datatypes[0], - )) - } - - /// Updates state, with the `values`. Fetch contains missing number of entries for state to be complete - /// None represents all of the new `values` need to be added to the state. - fn append_new_data( - &mut self, - values: &[ArrayRef], - fetch: Option, - ) -> Result<()> { - let n_row = values[0].len(); - let n_to_add = if let Some(fetch) = fetch { - std::cmp::min(fetch, n_row) - } else { - n_row - }; - for index in 0..n_to_add { - let row = get_row_at_idx(values, index)?; - self.values.push_back(row[0].clone()); - // At index 1, we have n index argument. 
- // Ordering values cover starting from 2nd index to end - self.ordering_values.push_back(row[2..].to_vec()); - } - Ok(()) - } -} diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index f948df840e73..bf6d21a808e7 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -393,7 +393,6 @@ mod tests { use datafusion_expr::AggregateUDF; use datafusion_functions_aggregate_common::utils::get_accum_scalar_values_as_arrays; use datafusion_physical_expr::expressions::col; - use datafusion_physical_expr_common::sort_expr::LexOrdering; use std::sync::Arc; #[test] @@ -445,7 +444,7 @@ mod tests { return_field: Field::new("f", DataType::Float64, true).into(), schema, ignore_nulls: false, - ordering_req: &LexOrdering::default(), + order_bys: &[], name: "a", is_distinct: false, is_reversed: false, @@ -456,7 +455,7 @@ mod tests { return_field: Field::new("f", DataType::Float64, true).into(), schema, ignore_nulls: false, - ordering_req: &LexOrdering::default(), + order_bys: &[], name: "a", is_distinct: false, is_reversed: false, diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 4682e574bfa2..09199e19cffc 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -17,12 +17,15 @@ //! [`StringAgg`] accumulator for the `string_agg` function +use std::any::Any; +use std::mem::size_of_val; + use crate::array_agg::ArrayAgg; + use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; -use datafusion_common::Result; -use datafusion_common::{internal_err, not_impl_err, ScalarValue}; +use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue}; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility, @@ -30,8 +33,6 @@ use datafusion_expr::{ use datafusion_functions_aggregate_common::accumulator::StateFieldsArgs; use datafusion_macros::user_doc; use datafusion_physical_expr::expressions::Literal; -use std::any::Any; -use std::mem::size_of_val; make_udaf_expr_and_func!( StringAgg, @@ -271,7 +272,7 @@ mod tests { use arrow::datatypes::{Fields, Schema}; use datafusion_common::internal_err; use datafusion_physical_expr::expressions::Column; - use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; + use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use std::sync::Arc; #[test] @@ -413,7 +414,7 @@ mod tests { struct StringAggAccumulatorBuilder { sep: String, distinct: bool, - ordering: LexOrdering, + order_bys: Vec, schema: Schema, } @@ -422,7 +423,7 @@ mod tests { Self { sep: sep.to_string(), distinct: Default::default(), - ordering: Default::default(), + order_bys: vec![], schema: Schema { fields: Fields::from(vec![Field::new( "col", @@ -439,7 +440,7 @@ mod tests { } fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self { - self.ordering.extend([PhysicalSortExpr::new( + self.order_bys.extend([PhysicalSortExpr::new( Arc::new( Column::new_with_schema(col, &self.schema) .expect("column not available in schema"), @@ -454,7 +455,7 @@ mod tests { return_field: Field::new("f", DataType::LargeUtf8, true).into(), schema: &self.schema, ignore_nulls: false, - ordering_req: &self.ordering, + order_bys: &self.order_bys, is_reversed: 
false, name: "", is_distinct: self.distinct, diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index 37d208ffb03a..9495e087d250 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -220,10 +220,6 @@ impl AggregateUDFImpl for Sum { } } - fn aliases(&self) -> &[String] { - &[] - } - fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { !args.is_distinct } diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index 021000dc100b..743cdeb1d3f0 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -232,6 +232,7 @@ impl ScalarUDFImpl for DatePartFunc { fn aliases(&self) -> &[String] { &self.aliases } + fn documentation(&self) -> Option<&Documentation> { self.doc() } diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index 3e89242aba26..219a9b576423 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -165,6 +165,7 @@ impl ScalarUDFImpl for ToCharFunc { fn aliases(&self) -> &[String] { &self.aliases } + fn documentation(&self) -> Option<&Documentation> { self.doc() } diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs index 2572e8679484..07edfb70f4aa 100644 --- a/datafusion/physical-expr-common/src/sort_expr.rs +++ b/datafusion/physical-expr-common/src/sort_expr.rs @@ -17,20 +17,20 @@ //! Sort expressions -use crate::physical_expr::{fmt_sql, PhysicalExpr}; -use std::fmt; -use std::fmt::{Display, Formatter}; +use std::cmp::Ordering; +use std::fmt::{self, Display, Formatter}; use std::hash::{Hash, Hasher}; -use std::ops::{Deref, Index, Range, RangeFrom, RangeTo}; -use std::sync::{Arc, LazyLock}; +use std::ops::{Deref, DerefMut}; +use std::sync::Arc; use std::vec::IntoIter; +use crate::physical_expr::{fmt_sql, PhysicalExpr}; + use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::datatypes::Schema; use arrow::record_batch::RecordBatch; -use datafusion_common::Result; +use datafusion_common::{HashSet, Result}; use datafusion_expr_common::columnar_value::ColumnarValue; -use itertools::Itertools; /// Represents Sort operation for a column in a RecordBatch /// @@ -77,7 +77,7 @@ use itertools::Itertools; /// .nulls_last(); /// assert_eq!(sort_expr.to_string(), "a DESC NULLS LAST"); /// ``` -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq)] pub struct PhysicalSortExpr { /// Physical expression representing the column to sort pub expr: Arc, @@ -96,6 +96,15 @@ impl PhysicalSortExpr { Self::new(expr, SortOptions::default()) } + /// Reverses the sort expression. For instance, `[a ASC NULLS LAST]` turns + /// into `[a DESC NULLS FIRST]`. Such reversals are useful in planning, e.g. + /// when constructing equivalent window expressions. 
+ pub fn reverse(&self) -> Self { + let mut result = self.clone(); + result.options = !result.options; + result + } + /// Set the sort sort options to ASC pub fn asc(mut self) -> Self { self.options.descending = false; @@ -129,23 +138,58 @@ impl PhysicalSortExpr { to_str(&self.options) ) } -} -/// Access the PhysicalSortExpr as a PhysicalExpr -impl AsRef for PhysicalSortExpr { - fn as_ref(&self) -> &(dyn PhysicalExpr + 'static) { - self.expr.as_ref() + /// Evaluates the sort expression into a `SortColumn` that can be passed + /// into the arrow sort kernel. + pub fn evaluate_to_sort_column(&self, batch: &RecordBatch) -> Result { + let array_to_sort = match self.expr.evaluate(batch)? { + ColumnarValue::Array(array) => array, + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(batch.num_rows())?, + }; + Ok(SortColumn { + values: array_to_sort, + options: Some(self.options), + }) + } + + /// Checks whether this sort expression satisfies the given `requirement`. + /// If sort options are unspecified in `requirement`, only expressions are + /// compared for inequality. See [`options_compatible`] for details on + /// how sort options compare with one another. + pub fn satisfy( + &self, + requirement: &PhysicalSortRequirement, + schema: &Schema, + ) -> bool { + self.expr.eq(&requirement.expr) + && requirement.options.is_none_or(|opts| { + options_compatible( + &self.options, + &opts, + self.expr.nullable(schema).unwrap_or(true), + ) + }) + } + + /// Checks whether this sort expression satisfies the given `sort_expr`. + /// See [`options_compatible`] for details on how sort options compare with + /// one another. + pub fn satisfy_expr(&self, sort_expr: &Self, schema: &Schema) -> bool { + self.expr.eq(&sort_expr.expr) + && options_compatible( + &self.options, + &sort_expr.options, + self.expr.nullable(schema).unwrap_or(true), + ) } } impl PartialEq for PhysicalSortExpr { - fn eq(&self, other: &PhysicalSortExpr) -> bool { + fn eq(&self, other: &Self) -> bool { self.options == other.options && self.expr.eq(&other.expr) } } -impl Eq for PhysicalSortExpr {} - impl Hash for PhysicalSortExpr { fn hash(&self, state: &mut H) { self.expr.hash(state); @@ -159,38 +203,20 @@ impl Display for PhysicalSortExpr { } } -impl PhysicalSortExpr { - /// evaluate the sort expression into SortColumn that can be passed into arrow sort kernel - pub fn evaluate_to_sort_column(&self, batch: &RecordBatch) -> Result { - let value_to_sort = self.expr.evaluate(batch)?; - let array_to_sort = match value_to_sort { - ColumnarValue::Array(array) => array, - ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(batch.num_rows())?, - }; - Ok(SortColumn { - values: array_to_sort, - options: Some(self.options), - }) - } - - /// Checks whether this sort expression satisfies the given `requirement`. - /// If sort options are unspecified in `requirement`, only expressions are - /// compared for inequality. - pub fn satisfy( - &self, - requirement: &PhysicalSortRequirement, - schema: &Schema, - ) -> bool { +/// Returns whether the given two [`SortOptions`] are compatible. Here, +/// compatibility means that they are either exactly equal, or they differ only +/// in whether NULL values come in first/last, which is immaterial because the +/// column in question is not nullable (specified by the `nullable` parameter). 
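The definition follows in this hunk; the rule is easiest to check with a pair of options that differ only in NULL placement. A standalone mirror of the helper, renamed `options_compatible_sketch` to make clear it is illustrative rather than the patched function itself:

use arrow::compute::SortOptions;

// For non-nullable columns only the direction has to match, because NULLS
// FIRST vs. NULLS LAST cannot be observed without NULLs.
fn options_compatible_sketch(
    lhs: &SortOptions,
    rhs: &SortOptions,
    nullable: bool,
) -> bool {
    if nullable {
        lhs == rhs
    } else {
        lhs.descending == rhs.descending
    }
}

fn main() {
    let asc_nulls_first = SortOptions { descending: false, nulls_first: true };
    let asc_nulls_last = SortOptions { descending: false, nulls_first: false };
    // Same direction, different NULL placement: compatible only when the
    // column is known to be non-nullable.
    assert!(!options_compatible_sketch(&asc_nulls_first, &asc_nulls_last, true));
    assert!(options_compatible_sketch(&asc_nulls_first, &asc_nulls_last, false));
}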
+pub fn options_compatible( + options_lhs: &SortOptions, + options_rhs: &SortOptions, + nullable: bool, +) -> bool { + if nullable { + options_lhs == options_rhs + } else { // If the column is not nullable, NULLS FIRST/LAST is not important. - let nullable = self.expr.nullable(schema).unwrap_or(true); - self.expr.eq(&requirement.expr) - && if nullable { - requirement.options.is_none_or(|opts| self.options == opts) - } else { - requirement - .options - .is_none_or(|opts| self.options.descending == opts.descending) - } + options_lhs.descending == options_rhs.descending } } @@ -222,28 +248,8 @@ pub struct PhysicalSortRequirement { pub options: Option, } -impl From for PhysicalSortExpr { - /// If options is `None`, the default sort options `ASC, NULLS LAST` is used. - /// - /// The default is picked to be consistent with - /// PostgreSQL: - fn from(value: PhysicalSortRequirement) -> Self { - let options = value.options.unwrap_or(SortOptions { - descending: false, - nulls_first: false, - }); - PhysicalSortExpr::new(value.expr, options) - } -} - -impl From for PhysicalSortRequirement { - fn from(value: PhysicalSortExpr) -> Self { - PhysicalSortRequirement::new(value.expr, Some(value.options)) - } -} - impl PartialEq for PhysicalSortRequirement { - fn eq(&self, other: &PhysicalSortRequirement) -> bool { + fn eq(&self, other: &Self) -> bool { self.options == other.options && self.expr.eq(&other.expr) } } @@ -293,37 +299,16 @@ impl PhysicalSortRequirement { Self { expr, options } } - /// Replace the required expression for this requirement with the new one - pub fn with_expr(mut self, expr: Arc) -> Self { - self.expr = expr; - self - } - /// Returns whether this requirement is equal or more specific than `other`. - pub fn compatible(&self, other: &PhysicalSortRequirement) -> bool { + pub fn compatible(&self, other: &Self) -> bool { self.expr.eq(&other.expr) && other .options .is_none_or(|other_opts| self.options == Some(other_opts)) } - - #[deprecated(since = "43.0.0", note = "use LexRequirement::from_lex_ordering")] - pub fn from_sort_exprs<'a>( - ordering: impl IntoIterator, - ) -> LexRequirement { - let ordering = ordering.into_iter().cloned().collect(); - LexRequirement::from_lex_ordering(ordering) - } - #[deprecated(since = "43.0.0", note = "use LexOrdering::from_lex_requirement")] - pub fn to_sort_exprs( - requirements: impl IntoIterator, - ) -> LexOrdering { - let requirements = requirements.into_iter().collect(); - LexOrdering::from_lex_requirement(requirements) - } } -/// Returns the SQL string representation of the given [SortOptions] object. +/// Returns the SQL string representation of the given [`SortOptions`] object. #[inline] fn to_str(options: &SortOptions) -> &str { match (options.descending, options.nulls_first) { @@ -334,162 +319,135 @@ fn to_str(options: &SortOptions) -> &str { } } -///`LexOrdering` contains a `Vec`, which represents -/// a lexicographical ordering. +// Cross-conversion utilities between `PhysicalSortExpr` and `PhysicalSortRequirement` +impl From for PhysicalSortRequirement { + fn from(value: PhysicalSortExpr) -> Self { + Self::new(value.expr, Some(value.options)) + } +} + +impl From for PhysicalSortExpr { + /// The default sort options `ASC, NULLS LAST` when the requirement does + /// not specify sort options. This default is consistent with PostgreSQL. 
+ /// + /// Reference: + fn from(value: PhysicalSortRequirement) -> Self { + let options = value + .options + .unwrap_or_else(|| SortOptions::new(false, false)); + Self::new(value.expr, options) + } +} + +/// This object represents a lexicographical ordering and contains a vector +/// of `PhysicalSortExpr` objects. /// -/// For example, `vec![a ASC, b DESC]` represents a lexicographical ordering +/// For example, a `vec![a ASC, b DESC]` represents a lexicographical ordering /// that first sorts by column `a` in ascending order, then by column `b` in /// descending order. -#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] +/// +/// # Invariants +/// +/// The following always hold true for a `LexOrdering`: +/// +/// 1. It is non-degenerate, meaning it contains at least one element. +/// 2. It is duplicate-free, meaning it does not contain multiple entries for +/// the same column. +#[derive(Debug, Clone, PartialEq, Eq)] pub struct LexOrdering { - inner: Vec, -} - -impl AsRef for LexOrdering { - fn as_ref(&self) -> &LexOrdering { - self - } + /// Vector of sort expressions representing the lexicographical ordering. + exprs: Vec, + /// Set of expressions in the lexicographical ordering, used to ensure + /// that the ordering is duplicate-free. Note that the elements in this + /// set are the same underlying physical expressions as in `exprs`. + set: HashSet>, } impl LexOrdering { - /// Creates a new [`LexOrdering`] from a vector - pub fn new(inner: Vec) -> Self { - Self { inner } - } - - /// Return an empty LexOrdering (no expressions) - pub fn empty() -> &'static LexOrdering { - static EMPTY_ORDER: LazyLock = LazyLock::new(LexOrdering::default); - &EMPTY_ORDER - } - - /// Returns the number of elements that can be stored in the LexOrdering - /// without reallocating. - pub fn capacity(&self) -> usize { - self.inner.capacity() - } - - /// Clears the LexOrdering, removing all elements. - pub fn clear(&mut self) { - self.inner.clear() - } - - /// Takes ownership of the actual vector of `PhysicalSortExpr`s in the LexOrdering. - pub fn take_exprs(self) -> Vec { - self.inner - } - - /// Returns `true` if the LexOrdering contains `expr` - pub fn contains(&self, expr: &PhysicalSortExpr) -> bool { - self.inner.contains(expr) - } - - /// Add all elements from `iter` to the LexOrdering. - pub fn extend>(&mut self, iter: I) { - self.inner.extend(iter) - } - - /// Remove all elements from the LexOrdering where `f` evaluates to `false`. - pub fn retain(&mut self, f: F) - where - F: FnMut(&PhysicalSortExpr) -> bool, - { - self.inner.retain(f) - } - - /// Returns `true` if the LexOrdering contains no elements. - pub fn is_empty(&self) -> bool { - self.inner.is_empty() - } - - /// Returns an iterator over each `&PhysicalSortExpr` in the LexOrdering. - pub fn iter(&self) -> core::slice::Iter { - self.inner.iter() + /// Creates a new [`LexOrdering`] from the given vector of sort expressions. + /// If the vector is empty, returns `None`. + pub fn new(exprs: impl IntoIterator) -> Option { + let (non_empty, ordering) = Self::construct(exprs); + non_empty.then_some(ordering) } - /// Returns the number of elements in the LexOrdering. - pub fn len(&self) -> usize { - self.inner.len() - } - - /// Removes the last element from the LexOrdering and returns it, or `None` if it is empty. - pub fn pop(&mut self) -> Option { - self.inner.pop() - } - - /// Appends an element to the back of the LexOrdering. 
- pub fn push(&mut self, physical_sort_expr: PhysicalSortExpr) { - self.inner.push(physical_sort_expr) + /// Appends an element to the back of the `LexOrdering`. + pub fn push(&mut self, sort_expr: PhysicalSortExpr) { + if self.set.insert(Arc::clone(&sort_expr.expr)) { + self.exprs.push(sort_expr); + } } - /// Truncates the LexOrdering, keeping only the first `len` elements. - pub fn truncate(&mut self, len: usize) { - self.inner.truncate(len) + /// Add all elements from `iter` to the `LexOrdering`. + pub fn extend(&mut self, sort_exprs: impl IntoIterator) { + for sort_expr in sort_exprs { + self.push(sort_expr); + } } - /// Merge the contents of `other` into `self`, removing duplicates. - pub fn merge(mut self, other: LexOrdering) -> Self { - self.inner = self.inner.into_iter().chain(other).unique().collect(); - self + /// Returns the leading `PhysicalSortExpr` of the `LexOrdering`. Note that + /// this function does not return an `Option`, as a `LexOrdering` is always + /// non-degenerate (i.e. it contains at least one element). + pub fn first(&self) -> &PhysicalSortExpr { + // Can safely `unwrap` because `LexOrdering` is non-degenerate: + self.exprs.first().unwrap() } - /// Converts a `LexRequirement` into a `LexOrdering`. - /// - /// This function converts [`PhysicalSortRequirement`] to [`PhysicalSortExpr`] - /// for each entry in the input. - /// - /// If the required ordering is `None` for an entry in `requirement`, the - /// default ordering `ASC, NULLS LAST` is used (see - /// [`PhysicalSortExpr::from`]). - pub fn from_lex_requirement(requirement: LexRequirement) -> LexOrdering { - requirement - .into_iter() - .map(PhysicalSortExpr::from) - .collect() + /// Returns the number of elements that can be stored in the `LexOrdering` + /// without reallocating. + pub fn capacity(&self) -> usize { + self.exprs.capacity() } - /// Collapse a `LexOrdering` into a new duplicate-free `LexOrdering` based on expression. - /// - /// This function filters duplicate entries that have same physical - /// expression inside, ignoring [`SortOptions`]. For example: - /// - /// `vec![a ASC, a DESC]` collapses to `vec![a ASC]`. - pub fn collapse(self) -> Self { - let mut output = LexOrdering::default(); - for item in self { - if !output.iter().any(|req| req.expr.eq(&item.expr)) { - output.push(item); - } + /// Truncates the `LexOrdering`, keeping only the first `len` elements. + /// Returns `true` if truncation made a change, `false` otherwise. Negative + /// cases happen in two scenarios: (1) When `len` is greater than or equal + /// to the number of expressions inside this `LexOrdering`, making truncation + /// a no-op, or (2) when `len` is `0`, making truncation impossible. + pub fn truncate(&mut self, len: usize) -> bool { + if len == 0 || len >= self.exprs.len() { + return false; } - output - } - - /// Transforms each `PhysicalSortExpr` in the `LexOrdering` - /// in place using the provided closure `f`. - pub fn transform(&mut self, f: F) - where - F: FnMut(&mut PhysicalSortExpr), - { - self.inner.iter_mut().for_each(f); + for PhysicalSortExpr { expr, .. } in self.exprs[len..].iter() { + self.set.remove(expr); + } + self.exprs.truncate(len); + true } -} -impl From> for LexOrdering { - fn from(value: Vec) -> Self { - Self::new(value) + /// Constructs a new `LexOrdering` from the given sort requirements w/o + /// enforcing non-degeneracy. This function is used internally and is not + /// meant (or safe) for external use. 
+ fn construct(exprs: impl IntoIterator) -> (bool, Self) { + let mut set = HashSet::new(); + let exprs = exprs + .into_iter() + .filter_map(|s| set.insert(Arc::clone(&s.expr)).then_some(s)) + .collect(); + (!set.is_empty(), Self { exprs, set }) } } -impl From for LexOrdering { - fn from(value: LexRequirement) -> Self { - Self::from_lex_requirement(value) +impl PartialOrd for LexOrdering { + /// There is a partial ordering among `LexOrdering` objects. For example, the + /// ordering `[a ASC]` is coarser (less) than ordering `[a ASC, b ASC]`. + /// If two orderings do not share a prefix, they are incomparable. + fn partial_cmp(&self, other: &Self) -> Option { + self.iter() + .zip(other.iter()) + .all(|(lhs, rhs)| lhs == rhs) + .then(|| self.len().cmp(&other.len())) } } -/// Convert a `LexOrdering` into a `Arc[]` for fast copies -impl From for Arc<[PhysicalSortExpr]> { - fn from(value: LexOrdering) -> Self { - value.inner.into() +impl From<[PhysicalSortExpr; N]> for LexOrdering { + fn from(value: [PhysicalSortExpr; N]) -> Self { + // TODO: Replace this assertion with a condition on the generic parameter + // when Rust supports it. + assert!(N > 0); + let (non_empty, ordering) = Self::construct(value); + debug_assert!(non_empty); + ordering } } @@ -497,14 +455,14 @@ impl Deref for LexOrdering { type Target = [PhysicalSortExpr]; fn deref(&self) -> &Self::Target { - self.inner.as_slice() + self.exprs.as_slice() } } impl Display for LexOrdering { fn fmt(&self, f: &mut Formatter) -> fmt::Result { let mut first = true; - for sort_expr in &self.inner { + for sort_expr in &self.exprs { if first { first = false; } else { @@ -516,162 +474,250 @@ impl Display for LexOrdering { } } -impl FromIterator for LexOrdering { - fn from_iter>(iter: T) -> Self { - let mut lex_ordering = LexOrdering::default(); - - for i in iter { - lex_ordering.push(i); - } +impl IntoIterator for LexOrdering { + type Item = PhysicalSortExpr; + type IntoIter = IntoIter; - lex_ordering + fn into_iter(self) -> Self::IntoIter { + self.exprs.into_iter() } } -impl Index for LexOrdering { - type Output = PhysicalSortExpr; +impl<'a> IntoIterator for &'a LexOrdering { + type Item = &'a PhysicalSortExpr; + type IntoIter = std::slice::Iter<'a, PhysicalSortExpr>; - fn index(&self, index: usize) -> &Self::Output { - &self.inner[index] + fn into_iter(self) -> Self::IntoIter { + self.exprs.iter() } } -impl Index> for LexOrdering { - type Output = [PhysicalSortExpr]; - - fn index(&self, range: Range) -> &Self::Output { - &self.inner[range] +impl From for Vec { + fn from(ordering: LexOrdering) -> Self { + ordering.exprs } } -impl Index> for LexOrdering { - type Output = [PhysicalSortExpr]; +/// This object represents a lexicographical ordering requirement and contains +/// a vector of `PhysicalSortRequirement` objects. +/// +/// For example, a `vec![a Some(ASC), b None]` represents a lexicographical +/// requirement that firsts imposes an ordering by column `a` in ascending +/// order, then by column `b` in *any* (ascending or descending) order. The +/// ordering is non-degenerate, meaning it contains at least one element, and +/// it is duplicate-free, meaning it does not contain multiple entries for the +/// same column. +/// +/// Note that a `LexRequirement` need not enforce the uniqueness of its sort +/// expressions after construction like a `LexOrdering` does, because it provides +/// no mutation methods. If such methods become necessary, we will need to +/// enforce uniqueness like the latter object. 
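Before the requirement counterpart is defined below, the practical effect of the new `LexOrdering` invariants is worth spelling out: construction can fail on empty input, duplicate expressions are silently dropped, and fixed-size arrays convert infallibly. A usage sketch against the revised API; it assumes, as DataFusion's `PhysicalExpr` equality machinery provides, that two equal `Column`s hash and compare as equal:

use std::sync::Arc;

use arrow::compute::SortOptions;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};

fn main() {
    let a_asc = PhysicalSortExpr::new(
        Arc::new(Column::new("a", 0)),
        SortOptions { descending: false, nulls_first: false },
    );
    let a_desc = PhysicalSortExpr::new(
        Arc::new(Column::new("a", 0)),
        SortOptions { descending: true, nulls_first: true },
    );

    // Degenerate (empty) orderings are unrepresentable: `new` returns None.
    assert!(LexOrdering::new(Vec::<PhysicalSortExpr>::new()).is_none());

    // Duplicate expressions are dropped regardless of their sort options,
    // so only the leading `a ASC NULLS LAST` survives.
    let ordering = LexOrdering::new(vec![a_asc.clone(), a_desc]).unwrap();
    assert_eq!(ordering.len(), 1);
    assert_eq!(ordering.first(), &a_asc);
}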
+#[derive(Debug, Clone, PartialEq)] +pub struct LexRequirement { + reqs: Vec, +} - fn index(&self, range_from: RangeFrom) -> &Self::Output { - &self.inner[range_from] +impl LexRequirement { + /// Creates a new [`LexRequirement`] from the given vector of sort expressions. + /// If the vector is empty, returns `None`. + pub fn new(reqs: impl IntoIterator) -> Option { + let (non_empty, requirements) = Self::construct(reqs); + non_empty.then_some(requirements) + } + + /// Returns the leading `PhysicalSortRequirement` of the `LexRequirement`. + /// Note that this function does not return an `Option`, as a `LexRequirement` + /// is always non-degenerate (i.e. it contains at least one element). + pub fn first(&self) -> &PhysicalSortRequirement { + // Can safely `unwrap` because `LexRequirement` is non-degenerate: + self.reqs.first().unwrap() + } + + /// Constructs a new `LexRequirement` from the given sort requirements w/o + /// enforcing non-degeneracy. This function is used internally and is not + /// meant (or safe) for external use. + fn construct( + reqs: impl IntoIterator, + ) -> (bool, Self) { + let mut set = HashSet::new(); + let reqs = reqs + .into_iter() + .filter_map(|r| set.insert(Arc::clone(&r.expr)).then_some(r)) + .collect(); + (!set.is_empty(), Self { reqs }) } } -impl Index> for LexOrdering { - type Output = [PhysicalSortExpr]; - - fn index(&self, range_to: RangeTo) -> &Self::Output { - &self.inner[range_to] +impl From<[PhysicalSortRequirement; N]> for LexRequirement { + fn from(value: [PhysicalSortRequirement; N]) -> Self { + // TODO: Replace this assertion with a condition on the generic parameter + // when Rust supports it. + assert!(N > 0); + let (non_empty, requirement) = Self::construct(value); + debug_assert!(non_empty); + requirement } } -impl IntoIterator for LexOrdering { - type Item = PhysicalSortExpr; - type IntoIter = IntoIter; +impl Deref for LexRequirement { + type Target = [PhysicalSortRequirement]; - fn into_iter(self) -> Self::IntoIter { - self.inner.into_iter() + fn deref(&self) -> &Self::Target { + self.reqs.as_slice() } } -///`LexOrderingRef` is an alias for the type &`[PhysicalSortExpr]`, which represents -/// a reference to a lexicographical ordering. -#[deprecated(since = "43.0.0", note = "use &LexOrdering instead")] -pub type LexOrderingRef<'a> = &'a [PhysicalSortExpr]; +impl IntoIterator for LexRequirement { + type Item = PhysicalSortRequirement; + type IntoIter = IntoIter; -///`LexRequirement` is an struct containing a `Vec`, which -/// represents a lexicographical ordering requirement. 
-#[derive(Debug, Default, Clone, PartialEq)] -pub struct LexRequirement { - pub inner: Vec, + fn into_iter(self) -> Self::IntoIter { + self.reqs.into_iter() + } } -impl LexRequirement { - pub fn new(inner: Vec) -> Self { - Self { inner } - } +impl<'a> IntoIterator for &'a LexRequirement { + type Item = &'a PhysicalSortRequirement; + type IntoIter = std::slice::Iter<'a, PhysicalSortRequirement>; - pub fn is_empty(&self) -> bool { - self.inner.is_empty() + fn into_iter(self) -> Self::IntoIter { + self.reqs.iter() } +} - pub fn iter(&self) -> impl Iterator { - self.inner.iter() +impl From for Vec { + fn from(requirement: LexRequirement) -> Self { + requirement.reqs } +} - pub fn push(&mut self, physical_sort_requirement: PhysicalSortRequirement) { - self.inner.push(physical_sort_requirement) +// Cross-conversion utilities between `LexOrdering` and `LexRequirement` +impl From for LexRequirement { + fn from(value: LexOrdering) -> Self { + // Can construct directly as `value` is non-degenerate: + let (non_empty, requirements) = + Self::construct(value.into_iter().map(Into::into)); + debug_assert!(non_empty); + requirements } +} - /// Create a new [`LexRequirement`] from a [`LexOrdering`] - /// - /// Returns [`LexRequirement`] that requires the exact - /// sort of the [`PhysicalSortExpr`]s in `ordering` - pub fn from_lex_ordering(ordering: LexOrdering) -> Self { - Self::new( - ordering - .into_iter() - .map(PhysicalSortRequirement::from) - .collect(), - ) +impl From for LexOrdering { + fn from(value: LexRequirement) -> Self { + // Can construct directly as `value` is non-degenerate: + let (non_empty, ordering) = Self::construct(value.into_iter().map(Into::into)); + debug_assert!(non_empty); + ordering } +} - /// Constructs a duplicate-free `LexOrderingReq` by filtering out - /// duplicate entries that have same physical expression inside. - /// - /// For example, `vec![a Some(ASC), a Some(DESC)]` collapses to `vec![a - /// Some(ASC)]`. - pub fn collapse(self) -> Self { - let mut output = Vec::::new(); - for item in self { - if !output.iter().any(|req| req.expr.eq(&item.expr)) { - output.push(item); +/// Represents a plan's input ordering requirements. Vector elements represent +/// alternative ordering requirements in the order of preference. The list of +/// alternatives can be either hard or soft, depending on whether the operator +/// can work without an input ordering. +/// +/// # Invariants +/// +/// The following always hold true for a `OrderingRequirements`: +/// +/// 1. It is non-degenerate, meaning it contains at least one ordering. The +/// absence of an input ordering requirement is represented by a `None` value +/// in `ExecutionPlan` APIs, which return an `Option`. +#[derive(Debug, Clone, PartialEq)] +pub enum OrderingRequirements { + /// The operator is not able to work without one of these requirements. + Hard(Vec), + /// The operator can benefit from these input orderings when available, + /// but can still work in the absence of any input ordering. + Soft(Vec), +} + +impl OrderingRequirements { + /// Creates a new instance from the given alternatives. If an empty list of + /// alternatives are given, returns `None`. 
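A short interaction with the constructors defined below, using `Column` expressions for brevity; every call is one introduced in this hunk, so this is a sketch of the new API rather than pre-existing behavior:

use std::sync::Arc;

use arrow::compute::SortOptions;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr_common::sort_expr::{
    LexRequirement, OrderingRequirements, PhysicalSortRequirement,
};

fn main() {
    // Require `a` ordered in either direction (options left unspecified).
    let req_a = PhysicalSortRequirement::new(Arc::new(Column::new("a", 0)), None);
    let requirement = LexRequirement::new(vec![req_a]).unwrap();

    // A hard requirement: the operator cannot run without an input ordering.
    let mut reqs = OrderingRequirements::new(requirement.clone());

    // Offer `b` in ascending order as an alternative the optimizer may
    // satisfy instead.
    let req_b = PhysicalSortRequirement::new(
        Arc::new(Column::new("b", 1)),
        Some(SortOptions::default()),
    );
    reqs.add_alternative(LexRequirement::new(vec![req_b]).unwrap());

    // The original requirement remains the preferred alternative.
    assert_eq!(reqs.first(), &requirement);
}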
+ pub fn new_alternatives( + alternatives: impl IntoIterator, + soft: bool, + ) -> Option { + let alternatives = alternatives.into_iter().collect::>(); + (!alternatives.is_empty()).then(|| { + if soft { + Self::Soft(alternatives) + } else { + Self::Hard(alternatives) } - } - LexRequirement::new(output) + }) } -} -impl From for LexRequirement { - fn from(value: LexOrdering) -> Self { - Self::from_lex_ordering(value) + /// Creates a new instance with a single hard requirement. + pub fn new(requirement: LexRequirement) -> Self { + Self::Hard(vec![requirement]) } -} -impl Deref for LexRequirement { - type Target = [PhysicalSortRequirement]; + /// Creates a new instance with a single soft requirement. + pub fn new_soft(requirement: LexRequirement) -> Self { + Self::Soft(vec![requirement]) + } - fn deref(&self) -> &Self::Target { - self.inner.as_slice() + /// Adds an alternative requirement to the list of alternatives. + pub fn add_alternative(&mut self, requirement: LexRequirement) { + match self { + Self::Hard(alts) | Self::Soft(alts) => alts.push(requirement), + } } -} -impl FromIterator for LexRequirement { - fn from_iter>(iter: T) -> Self { - let mut lex_requirement = LexRequirement::new(vec![]); + /// Returns the first (i.e. most preferred) `LexRequirement` among + /// alternative requirements. + pub fn into_single(self) -> LexRequirement { + match self { + Self::Hard(mut alts) | Self::Soft(mut alts) => alts.swap_remove(0), + } + } - for i in iter { - lex_requirement.inner.push(i); + /// Returns a reference to the first (i.e. most preferred) `LexRequirement` + /// among alternative requirements. + pub fn first(&self) -> &LexRequirement { + match self { + Self::Hard(alts) | Self::Soft(alts) => &alts[0], } + } - lex_requirement + /// Returns all alternatives as a vector of `LexRequirement` objects and a + /// boolean value indicating softness/hardness of the requirements. + pub fn into_alternatives(self) -> (Vec, bool) { + match self { + Self::Hard(alts) => (alts, false), + Self::Soft(alts) => (alts, true), + } } } -impl IntoIterator for LexRequirement { - type Item = PhysicalSortRequirement; - type IntoIter = IntoIter; +impl From for OrderingRequirements { + fn from(requirement: LexRequirement) -> Self { + Self::new(requirement) + } +} - fn into_iter(self) -> Self::IntoIter { - self.inner.into_iter() +impl From for OrderingRequirements { + fn from(ordering: LexOrdering) -> Self { + Self::new(ordering.into()) } } -impl<'a> IntoIterator for &'a LexOrdering { - type Item = &'a PhysicalSortExpr; - type IntoIter = std::slice::Iter<'a, PhysicalSortExpr>; +impl Deref for OrderingRequirements { + type Target = [LexRequirement]; - fn into_iter(self) -> Self::IntoIter { - self.inner.iter() + fn deref(&self) -> &Self::Target { + match &self { + Self::Hard(alts) | Self::Soft(alts) => alts.as_slice(), + } } } -///`LexRequirementRef` is an alias for the type &`[PhysicalSortRequirement]`, which -/// represents a reference to a lexicographical ordering requirement. 
-/// #[deprecated(since = "43.0.0", note = "use &LexRequirement instead")] -pub type LexRequirementRef<'a> = &'a [PhysicalSortRequirement]; +impl DerefMut for OrderingRequirements { + fn deref_mut(&mut self) -> &mut Self::Target { + match self { + Self::Hard(alts) | Self::Soft(alts) => alts.as_mut_slice(), + } + } +} diff --git a/datafusion/physical-expr-common/src/utils.rs b/datafusion/physical-expr-common/src/utils.rs index 114007bfa6af..05b216ab75eb 100644 --- a/datafusion/physical-expr-common/src/utils.rs +++ b/datafusion/physical-expr-common/src/utils.rs @@ -17,16 +17,14 @@ use std::sync::Arc; +use crate::physical_expr::PhysicalExpr; +use crate::tree_node::ExprContext; + use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; - use datafusion_common::Result; use datafusion_expr_common::sort_properties::ExprProperties; -use crate::physical_expr::PhysicalExpr; -use crate::sort_expr::{LexOrdering, PhysicalSortExpr}; -use crate::tree_node::ExprContext; - /// Represents a [`PhysicalExpr`] node with associated properties (order and /// range) in a context where properties are tracked. pub type ExprPropertiesNode = ExprContext; @@ -93,16 +91,6 @@ pub fn scatter(mask: &BooleanArray, truthy: &dyn Array) -> Result { Ok(make_array(data)) } -/// Reverses the ORDER BY expression, which is useful during equivalent window -/// expression construction. For instance, 'ORDER BY a ASC, NULLS LAST' turns into -/// 'ORDER BY a DESC, NULLS FIRST'. -pub fn reverse_order_bys(order_bys: &LexOrdering) -> LexOrdering { - order_bys - .iter() - .map(|e| PhysicalSortExpr::new(Arc::clone(&e.expr), !e.options)) - .collect() -} - #[cfg(test)] mod tests { use std::sync::Arc; diff --git a/datafusion/physical-expr/src/aggregate.rs b/datafusion/physical-expr/src/aggregate.rs index be04b9c6b8ea..9175c01274cb 100644 --- a/datafusion/physical-expr/src/aggregate.rs +++ b/datafusion/physical-expr/src/aggregate.rs @@ -52,8 +52,7 @@ use datafusion_functions_aggregate_common::accumulator::{ }; use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; -use datafusion_physical_expr_common::utils::reverse_order_bys; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; /// Builder for physical [`AggregateFunctionExpr`] /// @@ -70,7 +69,7 @@ pub struct AggregateExprBuilder { /// Arrow Schema for the aggregate function schema: SchemaRef, /// The physical order by expressions - ordering_req: LexOrdering, + order_bys: Vec, /// Whether to ignore null values ignore_nulls: bool, /// Whether is distinct aggregate function @@ -87,7 +86,7 @@ impl AggregateExprBuilder { alias: None, human_display: String::default(), schema: Arc::new(Schema::empty()), - ordering_req: LexOrdering::default(), + order_bys: vec![], ignore_nulls: false, is_distinct: false, is_reversed: false, @@ -128,17 +127,20 @@ impl AggregateExprBuilder { /// # impl AggregateUDFImpl for FirstValueUdf { /// # fn as_any(&self) -> &dyn Any { /// # unimplemented!() - /// # } + /// # } + /// # /// # fn name(&self) -> &str { /// # unimplemented!() - /// } + /// # } + /// # /// # fn signature(&self) -> &Signature { /// # unimplemented!() - /// # } + /// # } + /// # /// # fn return_type(&self, args: &[DataType]) -> Result { /// # unimplemented!() /// # } - /// # + /// # /// # fn accumulator(&self, acc_args: AccumulatorArgs) 
-> Result> { /// # unimplemented!() /// # } @@ -146,7 +148,7 @@ impl AggregateExprBuilder { /// # fn state_fields(&self, args: StateFieldsArgs) -> Result> { /// # unimplemented!() /// # } - /// # + /// # /// # fn documentation(&self) -> Option<&Documentation> { /// # unimplemented!() /// # } @@ -169,16 +171,16 @@ impl AggregateExprBuilder { /// }]; /// /// let first_value = AggregateUDF::from(FirstValueUdf::new()); - /// + /// /// let aggregate_expr = AggregateExprBuilder::new( /// Arc::new(first_value), /// args /// ) - /// .order_by(order_by.into()) + /// .order_by(order_by) /// .alias("first_a_by_x") /// .ignore_nulls() /// .build()?; - /// + /// /// Ok(()) /// } /// ``` @@ -192,7 +194,7 @@ impl AggregateExprBuilder { alias, human_display, schema, - ordering_req, + order_bys, ignore_nulls, is_distinct, is_reversed, @@ -201,17 +203,12 @@ impl AggregateExprBuilder { return internal_err!("args should not be empty"); } - let mut ordering_fields = vec![]; - - if !ordering_req.is_empty() { - let ordering_types = ordering_req - .iter() - .map(|e| e.expr.data_type(&schema)) - .collect::>>()?; + let ordering_types = order_bys + .iter() + .map(|e| e.expr.data_type(&schema)) + .collect::>>()?; - ordering_fields = - utils::ordering_fields(ordering_req.as_ref(), &ordering_types); - } + let ordering_fields = utils::ordering_fields(&order_bys, &ordering_types); let input_exprs_fields = args .iter() @@ -242,7 +239,7 @@ impl AggregateExprBuilder { name, human_display, schema: Arc::unwrap_or_clone(schema), - ordering_req, + order_bys, ignore_nulls, ordering_fields, is_distinct, @@ -267,8 +264,8 @@ impl AggregateExprBuilder { self } - pub fn order_by(mut self, order_by: LexOrdering) -> Self { - self.ordering_req = order_by; + pub fn order_by(mut self, order_bys: Vec) -> Self { + self.order_bys = order_bys; self } @@ -318,7 +315,7 @@ pub struct AggregateFunctionExpr { human_display: String, schema: Schema, // The physical order by expressions - ordering_req: LexOrdering, + order_bys: Vec, // Whether to ignore null values ignore_nulls: bool, // fields used for order sensitive aggregation functions @@ -388,7 +385,7 @@ impl AggregateFunctionExpr { return_field: Arc::clone(&self.return_field), schema: &self.schema, ignore_nulls: self.ignore_nulls, - ordering_req: self.ordering_req.as_ref(), + order_bys: self.order_bys.as_ref(), is_distinct: self.is_distinct, name: &self.name, is_reversed: self.is_reversed, @@ -411,31 +408,24 @@ impl AggregateFunctionExpr { self.fun.state_fields(args) } - /// Order by requirements for the aggregate function - /// By default it is `None` (there is no requirement) - /// Order-sensitive aggregators, such as `FIRST_VALUE(x ORDER BY y)` should implement this - pub fn order_bys(&self) -> Option<&LexOrdering> { - if self.ordering_req.is_empty() { - return None; - } - - if !self.order_sensitivity().is_insensitive() { - return Some(self.ordering_req.as_ref()); + /// Returns the ORDER BY expressions for the aggregate function. + pub fn order_bys(&self) -> &[PhysicalSortExpr] { + if self.order_sensitivity().is_insensitive() { + &[] + } else { + &self.order_bys } - - None } /// Indicates whether aggregator can produce the correct result with any /// arbitrary input ordering. By default, we assume that aggregate expressions /// are order insensitive. 
pub fn order_sensitivity(&self) -> AggregateOrderSensitivity { - if !self.ordering_req.is_empty() { - // If there is requirement, use the sensitivity of the implementation - self.fun.order_sensitivity() - } else { - // If no requirement, aggregator is order insensitive + if self.order_bys.is_empty() { AggregateOrderSensitivity::Insensitive + } else { + // If there is an ORDER BY clause, use the sensitivity of the implementation: + self.fun.order_sensitivity() } } @@ -463,7 +453,7 @@ impl AggregateFunctionExpr { }; AggregateExprBuilder::new(Arc::new(updated_fn), self.args.to_vec()) - .order_by(self.ordering_req.clone()) + .order_by(self.order_bys.clone()) .schema(Arc::new(self.schema.clone())) .alias(self.name().to_string()) .with_ignore_nulls(self.ignore_nulls) @@ -479,7 +469,7 @@ impl AggregateFunctionExpr { return_field: Arc::clone(&self.return_field), schema: &self.schema, ignore_nulls: self.ignore_nulls, - ordering_req: self.ordering_req.as_ref(), + order_bys: self.order_bys.as_ref(), is_distinct: self.is_distinct, name: &self.name, is_reversed: self.is_reversed, @@ -548,7 +538,7 @@ impl AggregateFunctionExpr { return_field: Arc::clone(&self.return_field), schema: &self.schema, ignore_nulls: self.ignore_nulls, - ordering_req: self.ordering_req.as_ref(), + order_bys: self.order_bys.as_ref(), is_distinct: self.is_distinct, name: &self.name, is_reversed: self.is_reversed, @@ -567,7 +557,7 @@ impl AggregateFunctionExpr { return_field: Arc::clone(&self.return_field), schema: &self.schema, ignore_nulls: self.ignore_nulls, - ordering_req: self.ordering_req.as_ref(), + order_bys: self.order_bys.as_ref(), is_distinct: self.is_distinct, name: &self.name, is_reversed: self.is_reversed, @@ -585,18 +575,16 @@ impl AggregateFunctionExpr { ReversedUDAF::NotSupported => None, ReversedUDAF::Identical => Some(self.clone()), ReversedUDAF::Reversed(reverse_udf) => { - let reverse_ordering_req = reverse_order_bys(self.ordering_req.as_ref()); let mut name = self.name().to_string(); // If the function is changed, we need to reverse order_by clause as well // i.e. First(a order by b asc null first) -> Last(a order by b desc null last) - if self.fun().name() == reverse_udf.name() { - } else { + if self.fun().name() != reverse_udf.name() { replace_order_by_clause(&mut name); } replace_fn_name_clause(&mut name, self.fun.name(), reverse_udf.name()); AggregateExprBuilder::new(reverse_udf, self.args.to_vec()) - .order_by(reverse_ordering_req) + .order_by(self.order_bys.iter().map(|e| e.reverse()).collect()) .schema(Arc::new(self.schema.clone())) .alias(name) .with_ignore_nulls(self.ignore_nulls) @@ -612,14 +600,11 @@ impl AggregateFunctionExpr { /// These expressions are (1)function arguments, (2) order by expressions. pub fn all_expressions(&self) -> AggregatePhysicalExpressions { let args = self.expressions(); - let order_bys = self + let order_by_exprs = self .order_bys() - .cloned() - .unwrap_or_else(LexOrdering::default); - let order_by_exprs = order_bys .iter() .map(|sort_expr| Arc::clone(&sort_expr.expr)) - .collect::>(); + .collect(); AggregatePhysicalExpressions { args, order_by_exprs, diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 98b1299a2ec6..0e722f3dc824 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -15,30 +15,61 @@ // specific language governing permissions and limitations // under the License. 
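The reversed-aggregate hunk above flips each sort expression inline instead of calling the deleted `reverse_order_bys` utility. A minimal standalone sketch of that idiom, assuming only `PhysicalSortExpr::reverse` (which the hunk itself relies on to negate both the sort direction and the null ordering):

```rust
use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;

/// Hypothetical stand-in for the removed `reverse_order_bys` helper:
/// `a ASC NULLS LAST` becomes `a DESC NULLS FIRST`, and so on.
fn reversed_order_bys(order_bys: &[PhysicalSortExpr]) -> Vec<PhysicalSortExpr> {
    order_bys.iter().map(|e| e.reverse()).collect()
}
```

Callers that previously imported the utility can inline the `map`/`collect` exactly as the builder code above does.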
-use super::{add_offset_to_expr, ProjectionMapping};
-use crate::{
-    expressions::Column, LexOrdering, LexRequirement, PhysicalExpr, PhysicalExprRef,
-    PhysicalSortExpr, PhysicalSortRequirement,
-};
-use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
-use datafusion_common::{JoinType, ScalarValue};
-use datafusion_physical_expr_common::physical_expr::format_physical_expr_list;
 use std::fmt::Display;
+use std::ops::Deref;
 use std::sync::Arc;
 use std::vec::IntoIter;

+use super::projection::ProjectionTargets;
+use super::ProjectionMapping;
+use crate::expressions::Literal;
+use crate::physical_expr::add_offset_to_expr;
+use crate::{PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, PhysicalSortRequirement};
+
+use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
+use datafusion_common::{HashMap, JoinType, Result, ScalarValue};
+use datafusion_physical_expr_common::physical_expr::format_physical_expr_list;
+
 use indexmap::{IndexMap, IndexSet};

-/// A structure representing a expression known to be constant in a physical execution plan.
+/// Represents whether a constant expression's value is uniform or varies across
+/// partitions. Has two variants:
+/// - `Heterogeneous`: The constant expression may have different values for
+///   different partitions.
+/// - `Uniform(Option<ScalarValue>)`: The constant expression has the same value
+///   across all partitions, or is `None` if the value is unknown.
+#[derive(Clone, Debug, Default, Eq, PartialEq)]
+pub enum AcrossPartitions {
+    #[default]
+    Heterogeneous,
+    Uniform(Option<ScalarValue>),
+}
+
+impl Display for AcrossPartitions {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            AcrossPartitions::Heterogeneous => write!(f, "(heterogeneous)"),
+            AcrossPartitions::Uniform(value) => {
+                if let Some(val) = value {
+                    write!(f, "(uniform: {val})")
+                } else {
+                    write!(f, "(uniform: unknown)")
+                }
+            }
+        }
+    }
+}
+
+/// A structure representing an expression known to be constant in a physical
+/// execution plan.
 ///
-/// The `ConstExpr` struct encapsulates an expression that is constant during the execution
-/// of a query. For example if a predicate like `A = 5` applied earlier in the plan `A` would
-/// be known constant
+/// The `ConstExpr` struct encapsulates an expression that is constant during
+/// the execution of a query. For example, if a filter like `A = 5` appears
+/// earlier in the plan, `A` would become a constant in subsequent operations.
 ///
 /// # Fields
 ///
 /// - `expr`: Constant expression for a node in the physical plan.
-///
 /// - `across_partitions`: A boolean flag indicating whether the constant
 ///   expression is the same across partitions. If set to `true`, the constant
 ///   expression has same value for all partitions.
 ///   If set to `false`, the
@@ -50,108 +81,37 @@ use indexmap::{IndexMap, IndexSet};
 /// # use datafusion_physical_expr::ConstExpr;
 /// # use datafusion_physical_expr::expressions::lit;
 /// let col = lit(5);
-/// // Create a constant expression from a physical expression ref
-/// let const_expr = ConstExpr::from(&col);
-/// // create a constant expression from a physical expression
+/// // Create a constant expression from a physical expression:
 /// let const_expr = ConstExpr::from(col);
 /// ```
-// TODO: Consider refactoring the `across_partitions` and `value` fields into an enum:
-//
-// ```
-// enum PartitionValues {
-//     Uniform(Option<ScalarValue>),            // Same value across all partitions
-//     Heterogeneous(Vec<Option<ScalarValue>>)  // Different values per partition
-// }
-// ```
-//
-// This would provide more flexible representation of partition values.
-// Note: This is a breaking change for the equivalence API and should be
-// addressed in a separate issue/PR.
-#[derive(Debug, Clone)]
+#[derive(Clone, Debug)]
 pub struct ConstExpr {
-    /// The expression that is known to be constant (e.g. a `Column`)
-    expr: Arc<dyn PhysicalExpr>,
-    /// Does the constant have the same value across all partitions? See
-    /// struct docs for more details
-    across_partitions: AcrossPartitions,
-}
-
-#[derive(PartialEq, Clone, Debug)]
-/// Represents whether a constant expression's value is uniform or varies across partitions.
-///
-/// The `AcrossPartitions` enum is used to describe the nature of a constant expression
-/// in a physical execution plan:
-///
-/// - `Heterogeneous`: The constant expression may have different values for different partitions.
-/// - `Uniform(Option<ScalarValue>)`: The constant expression has the same value across all partitions,
-///   or is `None` if the value is not specified.
-pub enum AcrossPartitions {
-    Heterogeneous,
-    Uniform(Option<ScalarValue>),
-}
-
-impl Default for AcrossPartitions {
-    fn default() -> Self {
-        Self::Heterogeneous
-    }
-}
-
-impl PartialEq for ConstExpr {
-    fn eq(&self, other: &Self) -> bool {
-        self.across_partitions == other.across_partitions && self.expr.eq(&other.expr)
-    }
+    /// The expression that is known to be constant (e.g. a `Column`).
+    pub expr: Arc<dyn PhysicalExpr>,
+    /// Indicates whether the constant has the same value across all partitions.
+    pub across_partitions: AcrossPartitions,
 }

+// TODO: The `ConstExpr` definition above can be in an inconsistent state where
+//       `expr` is a literal but `across_partitions` is not `Uniform`. Consider
+//       a refactor to ensure that `ConstExpr` is always in a consistent state
+//       (either by changing type definition, or by API constraints).
 impl ConstExpr {
-    /// Create a new constant expression from a physical expression.
+    /// Create a new constant expression from a physical expression, specifying
+    /// whether the constant expression is the same across partitions.
     ///
-    /// Note you can also use `ConstExpr::from` to create a constant expression
-    /// from a reference as well
-    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
-        Self {
-            expr,
-            // By default, assume constant expressions are not same across partitions.
-            across_partitions: Default::default(),
+    /// Note that you can also use `ConstExpr::from` to create a constant
+    /// expression from just a physical expression, with the *safe* assumption
+    /// of heterogeneous values across partitions unless the expression is a
+    /// literal.
+    pub fn new(expr: Arc<dyn PhysicalExpr>, across_partitions: AcrossPartitions) -> Self {
+        let mut result = ConstExpr::from(expr);
+        // Override the across-partitions specification if the expression is not
+        // a literal.
+ if result.across_partitions == AcrossPartitions::Heterogeneous { + result.across_partitions = across_partitions; } - } - - /// Set the `across_partitions` flag - /// - /// See struct docs for more details - pub fn with_across_partitions(mut self, across_partitions: AcrossPartitions) -> Self { - self.across_partitions = across_partitions; - self - } - - /// Is the expression the same across all partitions? - /// - /// See struct docs for more details - pub fn across_partitions(&self) -> AcrossPartitions { - self.across_partitions.clone() - } - - pub fn expr(&self) -> &Arc { - &self.expr - } - - pub fn owned_expr(self) -> Arc { - self.expr - } - - pub fn map(&self, f: F) -> Option - where - F: Fn(&Arc) -> Option>, - { - let maybe_expr = f(&self.expr); - maybe_expr.map(|expr| Self { - expr, - across_partitions: self.across_partitions.clone(), - }) - } - - /// Returns true if this constant expression is equal to the given expression - pub fn eq_expr(&self, other: impl AsRef) -> bool { - self.expr.as_ref() == other.as_ref() + result } /// Returns a [`Display`]able list of `ConstExpr`. @@ -175,47 +135,36 @@ impl ConstExpr { } } +impl PartialEq for ConstExpr { + fn eq(&self, other: &Self) -> bool { + self.across_partitions == other.across_partitions && self.expr.eq(&other.expr) + } +} + impl Display for ConstExpr { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.expr)?; - match &self.across_partitions { - AcrossPartitions::Heterogeneous => { - write!(f, "(heterogeneous)")?; - } - AcrossPartitions::Uniform(value) => { - if let Some(val) = value { - write!(f, "(uniform: {val})")?; - } else { - write!(f, "(uniform: unknown)")?; - } - } - } - Ok(()) + write!(f, "{}", self.across_partitions) } } impl From> for ConstExpr { fn from(expr: Arc) -> Self { - Self::new(expr) - } -} - -impl From<&Arc> for ConstExpr { - fn from(expr: &Arc) -> Self { - Self::new(Arc::clone(expr)) + // By default, assume constant expressions are not same across partitions. + // However, if we have a literal, it will have a single value that is the + // same across all partitions. + let across = if let Some(lit) = expr.as_any().downcast_ref::() { + AcrossPartitions::Uniform(Some(lit.value().clone())) + } else { + AcrossPartitions::Heterogeneous + }; + Self { + expr, + across_partitions: across, + } } } -/// Checks whether `expr` is among in the `const_exprs`. -pub fn const_exprs_contains( - const_exprs: &[ConstExpr], - expr: &Arc, -) -> bool { - const_exprs - .iter() - .any(|const_expr| const_expr.expr.eq(expr)) -} - /// An `EquivalenceClass` is a set of [`Arc`]s that are known /// to have the same value for all tuples in a relation. These are generated by /// equality predicates (e.g. `a = b`), typically equi-join conditions and @@ -223,259 +172,361 @@ pub fn const_exprs_contains( /// /// Two `EquivalenceClass`es are equal if they contains the same expressions in /// without any ordering. -#[derive(Debug, Clone)] +#[derive(Clone, Debug, Default, Eq, PartialEq)] pub struct EquivalenceClass { - /// The expressions in this equivalence class. The order doesn't - /// matter for equivalence purposes - /// - exprs: IndexSet>, -} - -impl PartialEq for EquivalenceClass { - /// Returns true if other is equal in the sense - /// of bags (multi-sets), disregarding their orderings. - fn eq(&self, other: &Self) -> bool { - self.exprs.eq(&other.exprs) - } + /// The expressions in this equivalence class. The order doesn't matter for + /// equivalence purposes. 
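A minimal sketch of the construction semantics that the `From` impl above establishes. The crate paths follow the doc example earlier in this file; the `AcrossPartitions` re-export location is an assumption:

```rust
use datafusion_physical_expr::equivalence::AcrossPartitions; // assumed path
use datafusion_physical_expr::expressions::lit;
use datafusion_physical_expr::ConstExpr;

fn const_expr_semantics() {
    // A literal is detected as uniform across partitions, carrying its value:
    let c = ConstExpr::from(lit(5));
    assert!(matches!(c.across_partitions, AcrossPartitions::Uniform(Some(_))));
}
```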
+ pub(crate) exprs: IndexSet>, + /// Indicates whether the expressions in this equivalence class have a + /// constant value. A `Some` value indicates constant-ness. + pub(crate) constant: Option, } impl EquivalenceClass { - /// Create a new empty equivalence class - pub fn new_empty() -> Self { - Self { - exprs: IndexSet::new(), - } - } - - // Create a new equivalence class from a pre-existing `Vec` - pub fn new(exprs: Vec>) -> Self { - Self { - exprs: exprs.into_iter().collect(), + // Create a new equivalence class from a pre-existing collection. + pub fn new(exprs: impl IntoIterator>) -> Self { + let mut class = Self::default(); + for expr in exprs { + class.push(expr); } - } - - /// Return the inner vector of expressions - pub fn into_vec(self) -> Vec> { - self.exprs.into_iter().collect() + class } /// Return the "canonical" expression for this class (the first element) - /// if any - fn canonical_expr(&self) -> Option> { - self.exprs.iter().next().cloned() + /// if non-empty. + pub fn canonical_expr(&self) -> Option<&Arc> { + self.exprs.iter().next() } /// Insert the expression into this class, meaning it is known to be equal to - /// all other expressions in this class + /// all other expressions in this class. pub fn push(&mut self, expr: Arc) { + if let Some(lit) = expr.as_any().downcast_ref::() { + let expr_across = AcrossPartitions::Uniform(Some(lit.value().clone())); + if let Some(across) = self.constant.as_mut() { + // TODO: Return an error if constant values do not agree. + if *across == AcrossPartitions::Heterogeneous { + *across = expr_across; + } + } else { + self.constant = Some(expr_across); + } + } self.exprs.insert(expr); } - /// Inserts all the expressions from other into this class + /// Inserts all the expressions from other into this class. pub fn extend(&mut self, other: Self) { - for expr in other.exprs { - // use push so entries are deduplicated - self.push(expr); + self.exprs.extend(other.exprs); + match (&self.constant, &other.constant) { + (Some(across), Some(_)) => { + // TODO: Return an error if constant values do not agree. + if across == &AcrossPartitions::Heterogeneous { + self.constant = other.constant; + } + } + (None, Some(_)) => self.constant = other.constant, + (_, None) => {} } } - /// Returns true if this equivalence class contains t expression - pub fn contains(&self, expr: &Arc) -> bool { - self.exprs.contains(expr) - } - - /// Returns true if this equivalence class has any entries in common with `other` + /// Returns whether this equivalence class has any entries in common with + /// `other`. pub fn contains_any(&self, other: &Self) -> bool { - self.exprs.iter().any(|e| other.contains(e)) - } - - /// return the number of items in this class - pub fn len(&self) -> usize { - self.exprs.len() - } - - /// return true if this class is empty - pub fn is_empty(&self) -> bool { - self.exprs.is_empty() + self.exprs.intersection(&other.exprs).next().is_some() } - /// Iterate over all elements in this class, in some arbitrary order - pub fn iter(&self) -> impl Iterator> { - self.exprs.iter() + /// Returns whether this equivalence class is trivial, meaning that it is + /// either empty, or contains a single expression that is not a constant. + /// Such classes are not useful, and can be removed from equivalence groups. 
+ pub fn is_trivial(&self) -> bool { + self.exprs.is_empty() || (self.exprs.len() == 1 && self.constant.is_none()) } - /// Return a new equivalence class that have the specified offset added to - /// each expression (used when schemas are appended such as in joins) - pub fn with_offset(&self, offset: usize) -> Self { - let new_exprs = self + /// Adds the given offset to all columns in the expressions inside this + /// class. This is used when schemas are appended, e.g. in joins. + pub fn try_with_offset(&self, offset: isize) -> Result { + let mut cls = Self::default(); + for expr_result in self .exprs .iter() .cloned() .map(|e| add_offset_to_expr(e, offset)) - .collect(); - Self::new(new_exprs) + { + cls.push(expr_result?); + } + Ok(cls) + } +} + +impl Deref for EquivalenceClass { + type Target = IndexSet>; + + fn deref(&self) -> &Self::Target { + &self.exprs + } +} + +impl IntoIterator for EquivalenceClass { + type Item = Arc; + type IntoIter = as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.exprs.into_iter() } } impl Display for EquivalenceClass { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "[{}]", format_physical_expr_list(&self.exprs)) + write!(f, "{{")?; + write!(f, "members: {}", format_physical_expr_list(&self.exprs))?; + if let Some(across) = &self.constant { + write!(f, ", constant: {across}")?; + } + write!(f, "}}") + } +} + +impl From for Vec> { + fn from(cls: EquivalenceClass) -> Self { + cls.exprs.into_iter().collect() } } -/// A collection of distinct `EquivalenceClass`es -#[derive(Debug, Clone)] +type AugmentedMapping<'a> = IndexMap< + &'a Arc, + (&'a ProjectionTargets, Option<&'a EquivalenceClass>), +>; + +/// A collection of distinct `EquivalenceClass`es. This object supports fast +/// lookups of expressions and their equivalence classes. +#[derive(Clone, Debug, Default)] pub struct EquivalenceGroup { + /// A mapping from expressions to their equivalence class key. + map: HashMap, usize>, + /// The equivalence classes in this group. classes: Vec, } impl EquivalenceGroup { - /// Creates an empty equivalence group. - pub fn empty() -> Self { - Self { classes: vec![] } - } - /// Creates an equivalence group from the given equivalence classes. - pub fn new(classes: Vec) -> Self { - let mut result = Self { classes }; - result.remove_redundant_entries(); - result - } - - /// Returns how many equivalence classes there are in this group. - pub fn len(&self) -> usize { - self.classes.len() + pub fn new(classes: impl IntoIterator) -> Self { + classes.into_iter().collect::>().into() } - /// Checks whether this equivalence group is empty. - pub fn is_empty(&self) -> bool { - self.len() == 0 + /// Adds `expr` as a constant expression to this equivalence group. + pub fn add_constant(&mut self, const_expr: ConstExpr) { + // If the expression is already in an equivalence class, we should + // adjust the constant-ness of the class if necessary: + if let Some(idx) = self.map.get(&const_expr.expr) { + let cls = &mut self.classes[*idx]; + if let Some(across) = cls.constant.as_mut() { + // TODO: Return an error if constant values do not agree. 
+ if *across == AcrossPartitions::Heterogeneous { + *across = const_expr.across_partitions; + } + } else { + cls.constant = Some(const_expr.across_partitions); + } + return; + } + // If the expression is not in any equivalence class, but has the same + // constant value with some class, add it to that class: + if let AcrossPartitions::Uniform(_) = &const_expr.across_partitions { + for (idx, cls) in self.classes.iter_mut().enumerate() { + if cls + .constant + .as_ref() + .is_some_and(|across| const_expr.across_partitions.eq(across)) + { + self.map.insert(Arc::clone(&const_expr.expr), idx); + cls.push(const_expr.expr); + return; + } + } + } + // Otherwise, create a new class with the expression as the only member: + let mut new_class = EquivalenceClass::new(std::iter::once(const_expr.expr)); + if new_class.constant.is_none() { + new_class.constant = Some(const_expr.across_partitions); + } + Self::update_lookup_table(&mut self.map, &new_class, self.classes.len()); + self.classes.push(new_class); } - /// Returns an iterator over the equivalence classes in this group. - pub fn iter(&self) -> impl Iterator { - self.classes.iter() + /// Removes constant expressions that may change across partitions. + /// This method should be used when merging data from different partitions. + /// Returns whether any change was made to the equivalence group. + pub fn clear_per_partition_constants(&mut self) -> bool { + let (mut idx, mut change) = (0, false); + while idx < self.classes.len() { + let cls = &mut self.classes[idx]; + if let Some(AcrossPartitions::Heterogeneous) = cls.constant { + change = true; + if cls.len() == 1 { + // If this class becomes trivial, remove it entirely: + self.remove_class_at_idx(idx); + continue; + } else { + cls.constant = None; + } + } + idx += 1; + } + change } - /// Adds the equality `left` = `right` to this equivalence group. - /// New equality conditions often arise after steps like `Filter(a = b)`, - /// `Alias(a, a as b)` etc. + /// Adds the equality `left` = `right` to this equivalence group. New + /// equality conditions often arise after steps like `Filter(a = b)`, + /// `Alias(a, a as b)` etc. Returns whether the given equality defines + /// a new equivalence class. pub fn add_equal_conditions( &mut self, - left: &Arc, - right: &Arc, - ) { - let mut first_class = None; - let mut second_class = None; - for (idx, cls) in self.classes.iter().enumerate() { - if cls.contains(left) { - first_class = Some(idx); - } - if cls.contains(right) { - second_class = Some(idx); - } - } + left: Arc, + right: Arc, + ) -> bool { + let first_class = self.map.get(&left).copied(); + let second_class = self.map.get(&right).copied(); match (first_class, second_class) { (Some(mut first_idx), Some(mut second_idx)) => { // If the given left and right sides belong to different classes, // we should unify/bridge these classes. - if first_idx != second_idx { - // By convention, make sure `second_idx` is larger than `first_idx`. - if first_idx > second_idx { - (first_idx, second_idx) = (second_idx, first_idx); + match first_idx.cmp(&second_idx) { + // The equality is already known, return and signal this: + std::cmp::Ordering::Equal => return false, + // Swap indices to ensure `first_idx` is the lesser index. + std::cmp::Ordering::Greater => { + std::mem::swap(&mut first_idx, &mut second_idx); } - // Remove the class at `second_idx` and merge its values with - // the class at `first_idx`. The convention above makes sure - // that `first_idx` is still valid after removing `second_idx`. 
- let other_class = self.classes.swap_remove(second_idx); - self.classes[first_idx].extend(other_class); + _ => {} } + // Remove the class at `second_idx` and merge its values with + // the class at `first_idx`. The convention above makes sure + // that `first_idx` is still valid after removing `second_idx`. + let other_class = self.remove_class_at_idx(second_idx); + // Update the lookup table for the second class: + Self::update_lookup_table(&mut self.map, &other_class, first_idx); + self.classes[first_idx].extend(other_class); } (Some(group_idx), None) => { // Right side is new, extend left side's class: - self.classes[group_idx].push(Arc::clone(right)); + self.map.insert(Arc::clone(&right), group_idx); + self.classes[group_idx].push(right); } (None, Some(group_idx)) => { // Left side is new, extend right side's class: - self.classes[group_idx].push(Arc::clone(left)); + self.map.insert(Arc::clone(&left), group_idx); + self.classes[group_idx].push(left); } (None, None) => { // None of the expressions is among existing classes. // Create a new equivalence class and extend the group. - self.classes.push(EquivalenceClass::new(vec![ - Arc::clone(left), - Arc::clone(right), - ])); + let class = EquivalenceClass::new([left, right]); + Self::update_lookup_table(&mut self.map, &class, self.classes.len()); + self.classes.push(class); + return true; } } + false } - /// Removes redundant entries from this group. - fn remove_redundant_entries(&mut self) { - // Remove duplicate entries from each equivalence class: - self.classes.retain_mut(|cls| { - // Keep groups that have at least two entries as singleton class is - // meaningless (i.e. it contains no non-trivial information): - cls.len() > 1 - }); - // Unify/bridge groups that have common expressions: - self.bridge_classes() + /// Removes the equivalence class at the given index from this group. + fn remove_class_at_idx(&mut self, idx: usize) -> EquivalenceClass { + // Remove the class at the given index: + let cls = self.classes.swap_remove(idx); + // Remove its entries from the lookup table: + for expr in cls.iter() { + self.map.remove(expr); + } + // Update the lookup table for the moved class: + if idx < self.classes.len() { + Self::update_lookup_table(&mut self.map, &self.classes[idx], idx); + } + cls + } + + /// Updates the entry in lookup table for the given equivalence class with + /// the given index. + fn update_lookup_table( + map: &mut HashMap, usize>, + cls: &EquivalenceClass, + idx: usize, + ) { + for expr in cls.iter() { + map.insert(Arc::clone(expr), idx); + } + } + + /// Removes redundant entries from this group. Returns whether any change + /// was made to the equivalence group. + fn remove_redundant_entries(&mut self) -> bool { + // First, remove trivial equivalence classes: + let mut change = false; + for idx in (0..self.classes.len()).rev() { + if self.classes[idx].is_trivial() { + self.remove_class_at_idx(idx); + change = true; + } + } + // Then, unify/bridge groups that have common expressions: + self.bridge_classes() || change } /// This utility function unifies/bridges classes that have common expressions. /// For example, assume that we have [`EquivalenceClass`]es `[a, b]` and `[b, c]`. /// Since both classes contain `b`, columns `a`, `b` and `c` are actually all /// equal and belong to one class. This utility converts merges such classes. 
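Bridging is observable through the public constructor, since construction (via the `From<Vec<EquivalenceClass>>` impl further below) removes redundant entries and unifies overlapping classes. A sketch in the style of this file's tests; module paths are assumptions:

```rust
use std::sync::Arc;

use datafusion_physical_expr::equivalence::{EquivalenceClass, EquivalenceGroup}; // assumed paths
use datafusion_physical_expr::expressions::Column;

fn bridging_demo() {
    let a = Arc::new(Column::new("a", 0)) as _;
    let b1 = Arc::new(Column::new("b", 1)) as _;
    let b2 = Arc::new(Column::new("b", 1)) as _;
    let c = Arc::new(Column::new("c", 2)) as _;
    // `[a, b]` and `[b, c]` share `b`, so construction bridges them into a
    // single class `[a, b, c]`:
    let group = EquivalenceGroup::new([
        EquivalenceClass::new([a, b1]),
        EquivalenceClass::new([b2, c]),
    ]);
    // `len` comes via the `Deref<Target = [EquivalenceClass]>` impl below:
    assert_eq!(group.len(), 1);
}
```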
- fn bridge_classes(&mut self) { - let mut idx = 0; - while idx < self.classes.len() { - let mut next_idx = idx + 1; - let start_size = self.classes[idx].len(); - while next_idx < self.classes.len() { - if self.classes[idx].contains_any(&self.classes[next_idx]) { - let extension = self.classes.swap_remove(next_idx); + /// Returns whether any change was made to the equivalence group. + fn bridge_classes(&mut self) -> bool { + let (mut idx, mut change) = (0, false); + 'scan: while idx < self.classes.len() { + for other_idx in (idx + 1..self.classes.len()).rev() { + if self.classes[idx].contains_any(&self.classes[other_idx]) { + let extension = self.remove_class_at_idx(other_idx); + Self::update_lookup_table(&mut self.map, &extension, idx); self.classes[idx].extend(extension); - } else { - next_idx += 1; + change = true; + continue 'scan; } } - if self.classes[idx].len() > start_size { - continue; - } idx += 1; } + change } /// Extends this equivalence group with the `other` equivalence group. - pub fn extend(&mut self, other: Self) { + /// Returns whether any equivalence classes were unified/bridged as a + /// result of the extension process. + pub fn extend(&mut self, other: Self) -> bool { + for (idx, cls) in other.classes.iter().enumerate() { + // Update the lookup table for the new class: + Self::update_lookup_table(&mut self.map, cls, idx); + } self.classes.extend(other.classes); - self.remove_redundant_entries(); + self.bridge_classes() } - /// Normalizes the given physical expression according to this group. - /// The expression is replaced with the first expression in the equivalence - /// class it matches with (if any). + /// Normalizes the given physical expression according to this group. The + /// expression is replaced with the first (canonical) expression in the + /// equivalence class it matches with (if any). pub fn normalize_expr(&self, expr: Arc) -> Arc { expr.transform(|expr| { - for cls in self.iter() { - if cls.contains(&expr) { - // The unwrap below is safe because the guard above ensures - // that the class is not empty. - return Ok(Transformed::yes(cls.canonical_expr().unwrap())); - } - } - Ok(Transformed::no(expr)) + let cls = self.get_equivalence_class(&expr); + let Some(canonical) = cls.and_then(|cls| cls.canonical_expr()) else { + return Ok(Transformed::no(expr)); + }; + Ok(Transformed::yes(Arc::clone(canonical))) }) .data() .unwrap() // The unwrap above is safe because the closure always returns `Ok`. } - /// Normalizes the given sort expression according to this group. - /// The underlying physical expression is replaced with the first expression - /// in the equivalence class it matches with (if any). If the underlying - /// expression does not belong to any equivalence class in this group, returns - /// the sort expression as is. + /// Normalizes the given sort expression according to this group. The + /// underlying physical expression is replaced with the first expression in + /// the equivalence class it matches with (if any). If the underlying + /// expression does not belong to any equivalence class in this group, + /// returns the sort expression as is. pub fn normalize_sort_expr( &self, mut sort_expr: PhysicalSortExpr, @@ -484,11 +535,29 @@ impl EquivalenceGroup { sort_expr } - /// Normalizes the given sort requirement according to this group. - /// The underlying physical expression is replaced with the first expression - /// in the equivalence class it matches with (if any). 
If the underlying
-    /// expression does not belong to any equivalence class in this group, returns
-    /// the given sort requirement as is.
+    /// Normalizes the given sort expressions (i.e. `sort_exprs`) by:
+    /// - Replacing sections that belong to some equivalence class in the
+    ///   equivalence group with the first entry in the matching equivalence
+    ///   class.
+    /// - Removing expressions that have a constant value.
+    ///
+    /// If columns `a` and `b` are known to be equal, `d` is known to be a
+    /// constant, and `sort_exprs` is `[b ASC, d DESC, c ASC, a ASC]`, this
+    /// function would return `[a ASC, c ASC, a ASC]`.
+    pub fn normalize_sort_exprs<'a>(
+        &'a self,
+        sort_exprs: impl IntoIterator<Item = PhysicalSortExpr> + 'a,
+    ) -> impl Iterator<Item = PhysicalSortExpr> + 'a {
+        sort_exprs
+            .into_iter()
+            .map(|sort_expr| self.normalize_sort_expr(sort_expr))
+            .filter(|sort_expr| self.is_expr_constant(&sort_expr.expr).is_none())
+    }
+
+    /// Normalizes the given sort requirement according to this group. The
+    /// underlying physical expression is replaced with the first expression in
+    /// the equivalence class it matches with (if any). If the underlying
+    /// expression does not belong to any equivalence class in this group,
+    /// returns the given sort requirement as is.
     pub fn normalize_sort_requirement(
         &self,
         mut sort_requirement: PhysicalSortRequirement,
@@ -497,44 +566,76 @@ impl EquivalenceGroup {
         sort_requirement
     }

-    /// This function applies the `normalize_expr` function for all expressions
-    /// in `exprs` and returns the corresponding normalized physical expressions.
-    pub fn normalize_exprs(
-        &self,
-        exprs: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>,
-    ) -> Vec<Arc<dyn PhysicalExpr>> {
-        exprs
+    /// Normalizes the given sort requirements (i.e. `sort_reqs`) by:
+    /// - Replacing sections that belong to some equivalence class in the
+    ///   equivalence group with the first entry in the matching equivalence
+    ///   class.
+    /// - Removing expressions that have a constant value.
+    ///
+    /// If columns `a` and `b` are known to be equal, `d` is known to be a
+    /// constant, and `sort_reqs` is `[b ASC, d DESC, c ASC, a ASC]`, this
+    /// function would return `[a ASC, c ASC, a ASC]`.
+    pub fn normalize_sort_requirements<'a>(
+        &'a self,
+        sort_reqs: impl IntoIterator<Item = PhysicalSortRequirement> + 'a,
+    ) -> impl Iterator<Item = PhysicalSortRequirement> + 'a {
+        sort_reqs
             .into_iter()
-            .map(|expr| self.normalize_expr(expr))
-            .collect()
+            .map(|req| self.normalize_sort_requirement(req))
+            .filter(|req| self.is_expr_constant(&req.expr).is_none())
     }

-    /// This function applies the `normalize_sort_expr` function for all sort
-    /// expressions in `sort_exprs` and returns the corresponding normalized
-    /// sort expressions.
-    pub fn normalize_sort_exprs(&self, sort_exprs: &LexOrdering) -> LexOrdering {
-        // Convert sort expressions to sort requirements:
-        let sort_reqs = LexRequirement::from(sort_exprs.clone());
-        // Normalize the requirements:
-        let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs);
-        // Convert sort requirements back to sort expressions:
-        LexOrdering::from(normalized_sort_reqs)
+    /// Perform an indirect projection of `expr` by consulting the equivalence
+    /// classes.
+    fn project_expr_indirect(
+        aug_mapping: &AugmentedMapping,
+        expr: &Arc<dyn PhysicalExpr>,
+    ) -> Option<Arc<dyn PhysicalExpr>> {
+        // The given expression is not inside the mapping, so we try to project
+        // indirectly using equivalence classes.
+        for (targets, eq_class) in aug_mapping.values() {
+            // If we match an equivalent expression to a source expression in
+            // the mapping, then we can project.
For example, if we have the + // mapping `(a as a1, a + c)` and the equivalence `a == b`, + // expression `b` projects to `a1`. + if eq_class.as_ref().is_some_and(|cls| cls.contains(expr)) { + let (target, _) = targets.first(); + return Some(Arc::clone(target)); + } + } + // Project a non-leaf expression by projecting its children. + let children = expr.children(); + if children.is_empty() { + // A leaf expression should be inside the mapping. + return None; + } + children + .into_iter() + .map(|child| { + // First, we try to project children with an exact match. If + // we are unable to do this, we consult equivalence classes. + if let Some((targets, _)) = aug_mapping.get(child) { + // If we match the source, we can project directly: + let (target, _) = targets.first(); + Some(Arc::clone(target)) + } else { + Self::project_expr_indirect(aug_mapping, child) + } + }) + .collect::>>() + .map(|children| Arc::clone(expr).with_new_children(children).unwrap()) } - /// This function applies the `normalize_sort_requirement` function for all - /// requirements in `sort_reqs` and returns the corresponding normalized - /// sort requirements. - pub fn normalize_sort_requirements( - &self, - sort_reqs: &LexRequirement, - ) -> LexRequirement { - LexRequirement::new( - sort_reqs - .iter() - .map(|sort_req| self.normalize_sort_requirement(sort_req.clone())) - .collect(), - ) - .collapse() + fn augment_projection_mapping<'a>( + &'a self, + mapping: &'a ProjectionMapping, + ) -> AugmentedMapping<'a> { + mapping + .iter() + .map(|(k, v)| { + let eq_class = self.get_equivalence_class(k); + (k, (v, eq_class)) + }) + .collect() } /// Projects `expr` according to the given projection mapping. @@ -544,81 +645,118 @@ impl EquivalenceGroup { mapping: &ProjectionMapping, expr: &Arc, ) -> Option> { - // First, we try to project expressions with an exact match. If we are - // unable to do this, we consult equivalence classes. - if let Some(target) = mapping.target_expr(expr) { + if let Some(targets) = mapping.get(expr) { // If we match the source, we can project directly: - return Some(target); + let (target, _) = targets.first(); + Some(Arc::clone(target)) } else { - // If the given expression is not inside the mapping, try to project - // expressions considering the equivalence classes. - for (source, target) in mapping.iter() { - // If we match an equivalent expression to `source`, then we can - // project. For example, if we have the mapping `(a as a1, a + c)` - // and the equivalence class `(a, b)`, expression `b` projects to `a1`. - if self - .get_equivalence_class(source) - .is_some_and(|group| group.contains(expr)) - { - return Some(Arc::clone(target)); - } - } - } - // Project a non-leaf expression by projecting its children. - let children = expr.children(); - if children.is_empty() { - // Leaf expression should be inside mapping. - return None; + let aug_mapping = self.augment_projection_mapping(mapping); + Self::project_expr_indirect(&aug_mapping, expr) } - children - .into_iter() - .map(|child| self.project_expr(mapping, child)) - .collect::>>() - .map(|children| Arc::clone(expr).with_new_children(children).unwrap()) + } + + /// Projects `expressions` according to the given projection mapping. + /// This function is similar to [`Self::project_expr`], but projects multiple + /// expressions at once more efficiently than calling `project_expr` for each + /// expression. 
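The worked example in the comment above (`b` projecting to `a1` through `a == b`) can be made concrete. A sketch that builds the mapping the same way the projection test at the bottom of this file does; the crate paths and the `Vec<(expr, index)>` to `ProjectionTargets` conversion are assumptions:

```rust
use std::sync::Arc;

use datafusion_physical_expr::equivalence::{EquivalenceGroup, ProjectionMapping}; // assumed paths
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::PhysicalExpr;

fn indirect_projection_demo() {
    let a = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
    let b = Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>;
    let a1 = Arc::new(Column::new("a1", 0)) as Arc<dyn PhysicalExpr>;
    // Mapping `a -> a1` only; `b` is absent from the mapping:
    let mapping = [(Arc::clone(&a), vec![(Arc::clone(&a1), 0)].into())]
        .into_iter()
        .collect::<ProjectionMapping>();
    let mut group = EquivalenceGroup::default();
    group.add_equal_conditions(a, Arc::clone(&b));
    // `b` projects indirectly to `a1` through the equivalence `a == b`:
    let projected = group.project_expr(&mapping, &b);
    assert!(projected.is_some_and(|expr| expr.eq(&a1)));
}
```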
+ pub fn project_expressions<'a>( + &'a self, + mapping: &'a ProjectionMapping, + expressions: impl IntoIterator> + 'a, + ) -> impl Iterator>> + 'a { + let mut aug_mapping = None; + expressions.into_iter().map(move |expr| { + if let Some(targets) = mapping.get(expr) { + // If we match the source, we can project directly: + let (target, _) = targets.first(); + Some(Arc::clone(target)) + } else { + let aug_mapping = aug_mapping + .get_or_insert_with(|| self.augment_projection_mapping(mapping)); + Self::project_expr_indirect(aug_mapping, expr) + } + }) } /// Projects this equivalence group according to the given projection mapping. pub fn project(&self, mapping: &ProjectionMapping) -> Self { - let projected_classes = self.iter().filter_map(|cls| { - let new_class = cls - .iter() - .filter_map(|expr| self.project_expr(mapping, expr)) - .collect::>(); - (new_class.len() > 1).then_some(EquivalenceClass::new(new_class)) + let projected_classes = self.iter().map(|cls| { + let new_exprs = self.project_expressions(mapping, cls.iter()); + EquivalenceClass::new(new_exprs.flatten()) }); // The key is the source expression, and the value is the equivalence // class that contains the corresponding target expression. - let mut new_classes: IndexMap<_, _> = IndexMap::new(); - for (source, target) in mapping.iter() { + let mut new_constants = vec![]; + let mut new_classes = IndexMap::<_, EquivalenceClass>::new(); + for (source, targets) in mapping.iter() { // We need to find equivalent projected expressions. For example, // consider a table with columns `[a, b, c]` with `a` == `b`, and // projection `[a + c, b + c]`. To conclude that `a + c == b + c`, // we first normalize all source expressions in the mapping, then // merge all equivalent expressions into the classes. let normalized_expr = self.normalize_expr(Arc::clone(source)); - new_classes - .entry(normalized_expr) - .or_insert_with(EquivalenceClass::new_empty) - .push(Arc::clone(target)); + let cls = new_classes.entry(normalized_expr).or_default(); + for (target, _) in targets.iter() { + cls.push(Arc::clone(target)); + } + // Save new constants arising from the projection: + if let Some(across) = self.is_expr_constant(source) { + for (target, _) in targets.iter() { + let const_expr = ConstExpr::new(Arc::clone(target), across.clone()); + new_constants.push(const_expr); + } + } } - // Only add equivalence classes with at least two members as singleton - // equivalence classes are meaningless. - let new_classes = new_classes - .into_iter() - .filter_map(|(_, cls)| (cls.len() > 1).then_some(cls)); - let classes = projected_classes.chain(new_classes).collect(); - Self::new(classes) + // Union projected classes with new classes to make up the result: + let classes = projected_classes + .chain(new_classes.into_values()) + .filter(|cls| !cls.is_trivial()); + let mut result = Self::new(classes); + // Add new constants arising from the projection to the equivalence group: + for constant in new_constants { + result.add_constant(constant); + } + result + } + + /// Returns a `Some` value if the expression is constant according to + /// equivalence group, and `None` otherwise. The `Some` variant contains + /// an `AcrossPartitions` value indicating whether the expression is + /// constant across partitions, and its actual value (if available). 
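Per the contract just described, a bare literal always reports as a uniform constant with a known value, while a column with no equivalence or constant information reports as non-constant. A minimal sketch, with paths assumed as in the earlier examples:

```rust
use std::sync::Arc;

use datafusion_physical_expr::equivalence::{AcrossPartitions, EquivalenceGroup}; // assumed paths
use datafusion_physical_expr::expressions::{lit, Column};

fn constantness_demo() {
    let group = EquivalenceGroup::default();
    // A literal is constant with a known, partition-uniform value:
    assert!(matches!(
        group.is_expr_constant(&lit(3)),
        Some(AcrossPartitions::Uniform(Some(_)))
    ));
    // An unconstrained column is not constant:
    let col = Arc::new(Column::new("a", 0)) as _;
    assert!(group.is_expr_constant(&col).is_none());
}
```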
+ pub fn is_expr_constant( + &self, + expr: &Arc, + ) -> Option { + if let Some(lit) = expr.as_any().downcast_ref::() { + return Some(AcrossPartitions::Uniform(Some(lit.value().clone()))); + } + if let Some(cls) = self.get_equivalence_class(expr) { + if cls.constant.is_some() { + return cls.constant.clone(); + } + } + // TODO: This function should be able to return values of non-literal + // complex constants as well; e.g. it should return `8` for the + // expression `3 + 5`, not an unknown `heterogenous` value. + let children = expr.children(); + if children.is_empty() { + return None; + } + for child in children { + self.is_expr_constant(child)?; + } + Some(AcrossPartitions::Heterogeneous) } /// Returns the equivalence class containing `expr`. If no equivalence class /// contains `expr`, returns `None`. - fn get_equivalence_class( + pub fn get_equivalence_class( &self, expr: &Arc, ) -> Option<&EquivalenceClass> { - self.iter().find(|cls| cls.contains(expr)) + self.map.get(expr).map(|idx| &self.classes[*idx]) } /// Combine equivalence groups of the given join children. @@ -628,18 +766,16 @@ impl EquivalenceGroup { join_type: &JoinType, left_size: usize, on: &[(PhysicalExprRef, PhysicalExprRef)], - ) -> Self { - match join_type { + ) -> Result { + let group = match join_type { JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { let mut result = Self::new( - self.iter() - .cloned() - .chain( - right_equivalences - .iter() - .map(|cls| cls.with_offset(left_size)), - ) - .collect(), + self.iter().cloned().chain( + right_equivalences + .iter() + .map(|cls| cls.try_with_offset(left_size as _)) + .collect::>>()?, + ), ); // In we have an inner join, expressions in the "on" condition // are equal in the resulting table. @@ -647,36 +783,23 @@ impl EquivalenceGroup { for (lhs, rhs) in on.iter() { let new_lhs = Arc::clone(lhs); // Rewrite rhs to point to the right side of the join: - let new_rhs = Arc::clone(rhs) - .transform(|expr| { - if let Some(column) = - expr.as_any().downcast_ref::() - { - let new_column = Arc::new(Column::new( - column.name(), - column.index() + left_size, - )) - as _; - return Ok(Transformed::yes(new_column)); - } - - Ok(Transformed::no(expr)) - }) - .data() - .unwrap(); - result.add_equal_conditions(&new_lhs, &new_rhs); + let new_rhs = + add_offset_to_expr(Arc::clone(rhs), left_size as _)?; + result.add_equal_conditions(new_lhs, new_rhs); } } result } JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => self.clone(), JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(), - } + }; + Ok(group) } - /// Checks if two expressions are equal either directly or through equivalence classes. - /// For complex expressions (e.g. a + b), checks that the expression trees are structurally - /// identical and their leaf nodes are equivalent either directly or through equivalence classes. + /// Checks if two expressions are equal directly or through equivalence + /// classes. For complex expressions (e.g. `a + b`), checks that the + /// expression trees are structurally identical and their leaf nodes are + /// equivalent either directly or through equivalence classes. pub fn exprs_equal( &self, left: &Arc, @@ -726,16 +849,19 @@ impl EquivalenceGroup { .zip(right_children) .all(|(left_child, right_child)| self.exprs_equal(left_child, right_child)) } +} + +impl Deref for EquivalenceGroup { + type Target = [EquivalenceClass]; - /// Return the inner classes of this equivalence group. 
- pub fn into_inner(self) -> Vec { - self.classes + fn deref(&self) -> &Self::Target { + &self.classes } } impl IntoIterator for EquivalenceGroup { type Item = EquivalenceClass; - type IntoIter = IntoIter; + type IntoIter = IntoIter; fn into_iter(self) -> Self::IntoIter { self.classes.into_iter() @@ -756,11 +882,28 @@ impl Display for EquivalenceGroup { } } +impl From> for EquivalenceGroup { + fn from(classes: Vec) -> Self { + let mut result = Self { + map: classes + .iter() + .enumerate() + .flat_map(|(idx, cls)| { + cls.iter().map(move |expr| (Arc::clone(expr), idx)) + }) + .collect(), + classes, + }; + result.remove_redundant_entries(); + result + } +} + #[cfg(test)] mod tests { use super::*; use crate::equivalence::tests::create_test_params; - use crate::expressions::{binary, col, lit, BinaryExpr, Literal}; + use crate::expressions::{binary, col, lit, BinaryExpr, Column, Literal}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{Result, ScalarValue}; @@ -786,16 +929,25 @@ mod tests { for (entries, expected) in test_cases { let entries = entries .into_iter() - .map(|entry| entry.into_iter().map(lit).collect::>()) + .map(|entry| { + entry.into_iter().map(|idx| { + let c = Column::new(format!("col_{idx}").as_str(), idx); + Arc::new(c) as _ + }) + }) .map(EquivalenceClass::new) .collect::>(); let expected = expected .into_iter() - .map(|entry| entry.into_iter().map(lit).collect::>()) + .map(|entry| { + entry.into_iter().map(|idx| { + let c = Column::new(format!("col_{idx}").as_str(), idx); + Arc::new(c) as _ + }) + }) .map(EquivalenceClass::new) .collect::>(); - let mut eq_groups = EquivalenceGroup::new(entries.clone()); - eq_groups.bridge_classes(); + let eq_groups: EquivalenceGroup = entries.clone().into(); let eq_groups = eq_groups.classes; let err_msg = format!( "error in test entries: {entries:?}, expected: {expected:?}, actual:{eq_groups:?}" @@ -810,58 +962,45 @@ mod tests { #[test] fn test_remove_redundant_entries_eq_group() -> Result<()> { + let c = |idx| Arc::new(Column::new(format!("col_{idx}").as_str(), idx)) as _; let entries = [ - EquivalenceClass::new(vec![lit(1), lit(1), lit(2)]), - // This group is meaningless should be removed - EquivalenceClass::new(vec![lit(3), lit(3)]), - EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]), + EquivalenceClass::new([c(1), c(1), lit(20)]), + EquivalenceClass::new([lit(30), lit(30)]), + EquivalenceClass::new([c(2), c(3), c(4)]), ]; // Given equivalences classes are not in succinct form. // Expected form is the most plain representation that is functionally same. 
let expected = [ - EquivalenceClass::new(vec![lit(1), lit(2)]), - EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]), + EquivalenceClass::new([c(1), lit(20)]), + EquivalenceClass::new([lit(30)]), + EquivalenceClass::new([c(2), c(3), c(4)]), ]; - let mut eq_groups = EquivalenceGroup::new(entries.to_vec()); - eq_groups.remove_redundant_entries(); - - let eq_groups = eq_groups.classes; - assert_eq!(eq_groups.len(), expected.len()); - assert_eq!(eq_groups.len(), 2); - - assert_eq!(eq_groups[0], expected[0]); - assert_eq!(eq_groups[1], expected[1]); + let eq_groups = EquivalenceGroup::new(entries); + assert_eq!(eq_groups.classes, expected); Ok(()) } #[test] fn test_schema_normalize_expr_with_equivalence() -> Result<()> { - let col_a = &Column::new("a", 0); - let col_b = &Column::new("b", 1); - let col_c = &Column::new("c", 2); + let col_a = Arc::new(Column::new("a", 0)) as Arc; + let col_b = Arc::new(Column::new("b", 1)) as _; + let col_c = Arc::new(Column::new("c", 2)) as _; // Assume that column a and c are aliases. - let (_test_schema, eq_properties) = create_test_params()?; - - let col_a_expr = Arc::new(col_a.clone()) as Arc; - let col_b_expr = Arc::new(col_b.clone()) as Arc; - let col_c_expr = Arc::new(col_c.clone()) as Arc; - // Test cases for equivalence normalization, - // First entry in the tuple is argument, second entry is expected result after normalization. + let (_, eq_properties) = create_test_params()?; + // Test cases for equivalence normalization. First entry in the tuple is + // the argument, second entry is expected result after normalization. let expressions = vec![ // Normalized version of the column a and c should go to a // (by convention all the expressions inside equivalence class are mapped to the first entry // in this case a is the first entry in the equivalence class.) 
- (&col_a_expr, &col_a_expr), - (&col_c_expr, &col_a_expr), + (Arc::clone(&col_a), Arc::clone(&col_a)), + (col_c, col_a), // Cannot normalize column b - (&col_b_expr, &col_b_expr), + (Arc::clone(&col_b), Arc::clone(&col_b)), ]; let eq_group = eq_properties.eq_group(); for (expr, expected_eq) in expressions { - assert!( - expected_eq.eq(&eq_group.normalize_expr(Arc::clone(expr))), - "error in test: expr: {expr:?}" - ); + assert!(expected_eq.eq(&eq_group.normalize_expr(expr))); } Ok(()) @@ -869,21 +1008,15 @@ mod tests { #[test] fn test_contains_any() { - let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) - as Arc; - let lit_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) - as Arc; - let lit2 = - Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; - let lit1 = - Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; - let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; - - let cls1 = - EquivalenceClass::new(vec![Arc::clone(&lit_true), Arc::clone(&lit_false)]); - let cls2 = - EquivalenceClass::new(vec![Arc::clone(&lit_true), Arc::clone(&col_b_expr)]); - let cls3 = EquivalenceClass::new(vec![Arc::clone(&lit2), Arc::clone(&lit1)]); + let lit_true = Arc::new(Literal::new(ScalarValue::from(true))) as _; + let lit_false = Arc::new(Literal::new(ScalarValue::from(false))) as _; + let col_a_expr = Arc::new(Column::new("a", 0)) as _; + let col_b_expr = Arc::new(Column::new("b", 1)) as _; + let col_c_expr = Arc::new(Column::new("c", 2)) as _; + + let cls1 = EquivalenceClass::new([Arc::clone(&lit_true), col_a_expr]); + let cls2 = EquivalenceClass::new([lit_true, col_b_expr]); + let cls3 = EquivalenceClass::new([col_c_expr, lit_false]); // lit_true is common assert!(cls1.contains_any(&cls2)); @@ -902,21 +1035,19 @@ mod tests { } // Create test columns - let col_a = Arc::new(Column::new("a", 0)) as Arc; - let col_b = Arc::new(Column::new("b", 1)) as Arc; - let col_x = Arc::new(Column::new("x", 2)) as Arc; - let col_y = Arc::new(Column::new("y", 3)) as Arc; + let col_a = Arc::new(Column::new("a", 0)) as _; + let col_b = Arc::new(Column::new("b", 1)) as _; + let col_x = Arc::new(Column::new("x", 2)) as _; + let col_y = Arc::new(Column::new("y", 3)) as _; // Create test literals - let lit_1 = - Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; - let lit_2 = - Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; + let lit_1 = Arc::new(Literal::new(ScalarValue::from(1))) as _; + let lit_2 = Arc::new(Literal::new(ScalarValue::from(2))) as _; // Create equivalence group with classes (a = x) and (b = y) - let eq_group = EquivalenceGroup::new(vec![ - EquivalenceClass::new(vec![Arc::clone(&col_a), Arc::clone(&col_x)]), - EquivalenceClass::new(vec![Arc::clone(&col_b), Arc::clone(&col_y)]), + let eq_group = EquivalenceGroup::new([ + EquivalenceClass::new([Arc::clone(&col_a), Arc::clone(&col_x)]), + EquivalenceClass::new([Arc::clone(&col_b), Arc::clone(&col_y)]), ]); let test_cases = vec![ @@ -966,12 +1097,12 @@ mod tests { Arc::clone(&col_a), Operator::Plus, Arc::clone(&col_b), - )) as Arc, + )) as _, right: Arc::new(BinaryExpr::new( Arc::clone(&col_x), Operator::Plus, Arc::clone(&col_y), - )) as Arc, + )) as _, expected: true, description: "Binary expressions with equivalent operands should be equal", @@ -981,12 +1112,12 @@ mod tests { Arc::clone(&col_a), Operator::Plus, Arc::clone(&col_b), - )) as Arc, + )) as _, right: Arc::new(BinaryExpr::new( Arc::clone(&col_x), Operator::Plus, Arc::clone(&col_a), - )) as Arc, + )) as _, expected: false, 
description: "Binary expressions with non-equivalent operands should not be equal", @@ -996,12 +1127,12 @@ mod tests { Arc::clone(&col_a), Operator::Plus, Arc::clone(&lit_1), - )) as Arc, + )) as _, right: Arc::new(BinaryExpr::new( Arc::clone(&col_x), Operator::Plus, Arc::clone(&lit_1), - )) as Arc, + )) as _, expected: true, description: "Binary expressions with equivalent column and same literal should be equal", }, @@ -1014,7 +1145,7 @@ mod tests { )), Operator::Multiply, Arc::clone(&lit_1), - )) as Arc, + )) as _, right: Arc::new(BinaryExpr::new( Arc::new(BinaryExpr::new( Arc::clone(&col_x), @@ -1023,7 +1154,7 @@ mod tests { )), Operator::Multiply, Arc::clone(&lit_1), - )) as Arc, + )) as _, expected: true, description: "Nested binary expressions with equivalent operands should be equal", }, @@ -1057,36 +1188,36 @@ mod tests { Field::new("b", DataType::Int32, false), Field::new("c", DataType::Int32, false), ])); - let mut group = EquivalenceGroup::empty(); - group.add_equal_conditions(&col("a", &schema)?, &col("b", &schema)?); + let mut group = EquivalenceGroup::default(); + group.add_equal_conditions(col("a", &schema)?, col("b", &schema)?); let projected_schema = Arc::new(Schema::new(vec![ Field::new("a+c", DataType::Int32, false), Field::new("b+c", DataType::Int32, false), ])); - let mapping = ProjectionMapping { - map: vec![ - ( - binary( - col("a", &schema)?, - Operator::Plus, - col("c", &schema)?, - &schema, - )?, - col("a+c", &projected_schema)?, - ), - ( - binary( - col("b", &schema)?, - Operator::Plus, - col("c", &schema)?, - &schema, - )?, - col("b+c", &projected_schema)?, - ), - ], - }; + let mapping = [ + ( + binary( + col("a", &schema)?, + Operator::Plus, + col("c", &schema)?, + &schema, + )?, + vec![(col("a+c", &projected_schema)?, 0)].into(), + ), + ( + binary( + col("b", &schema)?, + Operator::Plus, + col("c", &schema)?, + &schema, + )?, + vec![(col("b+c", &projected_schema)?, 1)].into(), + ), + ] + .into_iter() + .collect::(); let projected = group.project(&mapping); diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index ef98b4812265..ecb73be256d4 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -15,12 +15,13 @@ // specific language governing permissions and limitations // under the License. +use std::borrow::Borrow; use std::sync::Arc; -use crate::expressions::Column; -use crate::{LexRequirement, PhysicalExpr}; +use crate::PhysicalExpr; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use arrow::compute::SortOptions; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; mod class; mod ordering; @@ -34,50 +35,34 @@ pub use properties::{ calculate_union, join_equivalence_properties, EquivalenceProperties, }; -/// This function constructs a duplicate-free `LexOrderingReq` by filtering out -/// duplicate entries that have same physical expression inside. For example, -/// `vec![a Some(ASC), a Some(DESC)]` collapses to `vec![a Some(ASC)]`. -/// -/// It will also filter out entries that are ordered if the next entry is; -/// for instance, `vec![floor(a) Some(ASC), a Some(ASC)]` will be collapsed to -/// `vec![a Some(ASC)]`. -#[deprecated(since = "45.0.0", note = "Use LexRequirement::collapse")] -pub fn collapse_lex_req(input: LexRequirement) -> LexRequirement { - input.collapse() +// Convert each tuple to a `PhysicalSortExpr` and construct a vector. 
+pub fn convert_to_sort_exprs>>( + args: &[(T, SortOptions)], +) -> Vec { + args.iter() + .map(|(expr, options)| PhysicalSortExpr::new(Arc::clone(expr.borrow()), *options)) + .collect() } -/// Adds the `offset` value to `Column` indices inside `expr`. This function is -/// generally used during the update of the right table schema in join operations. -pub fn add_offset_to_expr( - expr: Arc, - offset: usize, -) -> Arc { - expr.transform_down(|e| match e.as_any().downcast_ref::() { - Some(col) => Ok(Transformed::yes(Arc::new(Column::new( - col.name(), - offset + col.index(), - )))), - None => Ok(Transformed::no(e)), - }) - .data() - .unwrap() - // Note that we can safely unwrap here since our transform always returns - // an `Ok` value. +// Convert each vector of tuples to a `LexOrdering`. +pub fn convert_to_orderings>>( + args: &[Vec<(T, SortOptions)>], +) -> Vec { + args.iter() + .filter_map(|sort_exprs| LexOrdering::new(convert_to_sort_exprs(sort_exprs))) + .collect() } #[cfg(test)] mod tests { - use super::*; - use crate::expressions::col; - use crate::PhysicalSortExpr; + use crate::expressions::{col, Column}; + use crate::{LexRequirement, PhysicalSortExpr}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use datafusion_common::{plan_datafusion_err, Result}; - use datafusion_physical_expr_common::sort_expr::{ - LexOrdering, PhysicalSortRequirement, - }; + use datafusion_common::{plan_err, Result}; + use datafusion_physical_expr_common::sort_expr::PhysicalSortRequirement; /// Converts a string to a physical sort expression /// @@ -114,27 +99,21 @@ mod tests { mapping: &ProjectionMapping, input_schema: &Arc, ) -> Result { - // Calculate output schema - let fields: Result> = mapping - .iter() - .map(|(source, target)| { - let name = target - .as_any() - .downcast_ref::() - .ok_or_else(|| plan_datafusion_err!("Expects to have column"))? - .name(); - let field = Field::new( - name, - source.data_type(input_schema)?, - source.nullable(input_schema)?, - ); - - Ok(field) - }) - .collect(); + // Calculate output schema: + let mut fields = vec![]; + for (source, targets) in mapping.iter() { + let data_type = source.data_type(input_schema)?; + let nullable = source.nullable(input_schema)?; + for (target, _) in targets.iter() { + let Some(column) = target.as_any().downcast_ref::() else { + return plan_err!("Expects to have column"); + }; + fields.push(Field::new(column.name(), data_type.clone(), nullable)); + } + } let output_schema = Arc::new(Schema::new_with_metadata( - fields?, + fields, input_schema.metadata().clone(), )); @@ -163,15 +142,15 @@ mod tests { /// Column [a=c] (e.g they are aliases). 
pub fn create_test_params() -> Result<(SchemaRef, EquivalenceProperties)> { let test_schema = create_test_schema()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let col_g = &col("g", &test_schema)?; + let col_a = col("a", &test_schema)?; + let col_b = col("b", &test_schema)?; + let col_c = col("c", &test_schema)?; + let col_d = col("d", &test_schema)?; + let col_e = col("e", &test_schema)?; + let col_f = col("f", &test_schema)?; + let col_g = col("g", &test_schema)?; let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); - eq_properties.add_equal_conditions(col_a, col_c)?; + eq_properties.add_equal_conditions(Arc::clone(&col_a), Arc::clone(&col_c))?; let option_asc = SortOptions { descending: false, @@ -194,68 +173,19 @@ mod tests { ], ]; let orderings = convert_to_orderings(&orderings); - eq_properties.add_new_orderings(orderings); + eq_properties.add_orderings(orderings); Ok((test_schema, eq_properties)) } - // Convert each tuple to PhysicalSortRequirement + // Convert each tuple to a `PhysicalSortRequirement` and construct a + // a `LexRequirement` from them. pub fn convert_to_sort_reqs( - in_data: &[(&Arc, Option)], + args: &[(&Arc, Option)], ) -> LexRequirement { - in_data - .iter() - .map(|(expr, options)| { - PhysicalSortRequirement::new(Arc::clone(*expr), *options) - }) - .collect() - } - - // Convert each tuple to PhysicalSortExpr - pub fn convert_to_sort_exprs( - in_data: &[(&Arc, SortOptions)], - ) -> LexOrdering { - in_data - .iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: Arc::clone(*expr), - options: *options, - }) - .collect() - } - - // Convert each inner tuple to PhysicalSortExpr - pub fn convert_to_orderings( - orderings: &[Vec<(&Arc, SortOptions)>], - ) -> Vec { - orderings - .iter() - .map(|sort_exprs| convert_to_sort_exprs(sort_exprs)) - .collect() - } - - // Convert each tuple to PhysicalSortExpr - pub fn convert_to_sort_exprs_owned( - in_data: &[(Arc, SortOptions)], - ) -> LexOrdering { - LexOrdering::new( - in_data - .iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: Arc::clone(expr), - options: *options, - }) - .collect(), - ) - } - - // Convert each inner tuple to PhysicalSortExpr - pub fn convert_to_orderings_owned( - orderings: &[Vec<(Arc, SortOptions)>], - ) -> Vec { - orderings - .iter() - .map(|sort_exprs| convert_to_sort_exprs_owned(sort_exprs)) - .collect() + let exprs = args.iter().map(|(expr, options)| { + PhysicalSortRequirement::new(Arc::clone(*expr), *options) + }); + LexRequirement::new(exprs).unwrap() } #[test] @@ -269,49 +199,49 @@ mod tests { ])); let mut eq_properties = EquivalenceProperties::new(schema); - let col_a_expr = Arc::new(Column::new("a", 0)) as Arc; - let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; - let col_c_expr = Arc::new(Column::new("c", 2)) as Arc; - let col_x_expr = Arc::new(Column::new("x", 3)) as Arc; - let col_y_expr = Arc::new(Column::new("y", 4)) as Arc; + let col_a = Arc::new(Column::new("a", 0)) as _; + let col_b = Arc::new(Column::new("b", 1)) as _; + let col_c = Arc::new(Column::new("c", 2)) as _; + let col_x = Arc::new(Column::new("x", 3)) as _; + let col_y = Arc::new(Column::new("y", 4)) as _; // a and b are aliases - eq_properties.add_equal_conditions(&col_a_expr, &col_b_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_a), Arc::clone(&col_b))?; 
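The hunk above is representative of the whole patch: `add_equal_conditions` now takes its two expressions by value as `Arc<dyn PhysicalExpr>` (and returns a `Result`), so call sites `Arc::clone` any handle they want to keep using. A minimal sketch of the new calling convention; the two-column schema and column names are illustrative only, not part of the patch:

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_physical_expr::expressions::col;
use datafusion_physical_expr::EquivalenceProperties;

fn alias_columns() -> Result<()> {
    // Hypothetical two-column schema, used only for illustration.
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ]));
    let col_a = col("a", &schema)?;
    let col_b = col("b", &schema)?;

    let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
    // The updated API consumes owned `Arc`s, so keep local handles alive
    // with `Arc::clone`, exactly as the migrated tests do.
    eq_properties.add_equal_conditions(Arc::clone(&col_a), Arc::clone(&col_b))?;
    // `a` and `b` now live in a single equivalence class.
    assert_eq!(eq_properties.eq_group().len(), 1);
    Ok(())
}
```

Passing owned `Arc`s presumably lets the equivalence group take shared ownership directly instead of cloning internally, which matches the broader move away from `&Arc<dyn PhysicalExpr>` parameters throughout this patch.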
        assert_eq!(eq_properties.eq_group().len(), 1);

        // This new entry is redundant, size shouldn't increase
-        eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr)?;
+        eq_properties.add_equal_conditions(Arc::clone(&col_b), Arc::clone(&col_a))?;
        assert_eq!(eq_properties.eq_group().len(), 1);
        let eq_groups = eq_properties.eq_group().iter().next().unwrap();
        assert_eq!(eq_groups.len(), 2);
-        assert!(eq_groups.contains(&col_a_expr));
-        assert!(eq_groups.contains(&col_b_expr));
+        assert!(eq_groups.contains(&col_a));
+        assert!(eq_groups.contains(&col_b));

        // b and c are aliases. Existing equivalence class should expand,
        // however there shouldn't be any new equivalence class
-        eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr)?;
+        eq_properties.add_equal_conditions(Arc::clone(&col_b), Arc::clone(&col_c))?;
        assert_eq!(eq_properties.eq_group().len(), 1);
        let eq_groups = eq_properties.eq_group().iter().next().unwrap();
        assert_eq!(eq_groups.len(), 3);
-        assert!(eq_groups.contains(&col_a_expr));
-        assert!(eq_groups.contains(&col_b_expr));
-        assert!(eq_groups.contains(&col_c_expr));
+        assert!(eq_groups.contains(&col_a));
+        assert!(eq_groups.contains(&col_b));
+        assert!(eq_groups.contains(&col_c));

        // This is a new set of equality. Hence equivalent class count should be 2.
-        eq_properties.add_equal_conditions(&col_x_expr, &col_y_expr)?;
+        eq_properties.add_equal_conditions(Arc::clone(&col_x), Arc::clone(&col_y))?;
        assert_eq!(eq_properties.eq_group().len(), 2);

        // This equality bridges distinct equality sets.
        // Hence equivalent class count should decrease from 2 to 1.
-        eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr)?;
+        eq_properties.add_equal_conditions(Arc::clone(&col_x), Arc::clone(&col_a))?;
        assert_eq!(eq_properties.eq_group().len(), 1);
        let eq_groups = eq_properties.eq_group().iter().next().unwrap();
        assert_eq!(eq_groups.len(), 5);
-        assert!(eq_groups.contains(&col_a_expr));
-        assert!(eq_groups.contains(&col_b_expr));
-        assert!(eq_groups.contains(&col_c_expr));
-        assert!(eq_groups.contains(&col_x_expr));
-        assert!(eq_groups.contains(&col_y_expr));
+        assert!(eq_groups.contains(&col_a));
+        assert!(eq_groups.contains(&col_b));
+        assert!(eq_groups.contains(&col_c));
+        assert!(eq_groups.contains(&col_x));
+        assert!(eq_groups.contains(&col_y));

        Ok(())
    }
diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs
index 819f8905bda5..875c2a76e5eb 100644
--- a/datafusion/physical-expr/src/equivalence/ordering.rs
+++ b/datafusion/physical-expr/src/equivalence/ordering.rs
@@ -16,115 +16,83 @@
 // under the License.

 use std::fmt::Display;
-use std::hash::Hash;
+use std::ops::Deref;
 use std::sync::Arc;
 use std::vec::IntoIter;

-use crate::equivalence::add_offset_to_expr;
-use crate::{LexOrdering, PhysicalExpr};
+use crate::expressions::with_new_schema;
+use crate::{add_offset_to_physical_sort_exprs, LexOrdering, PhysicalExpr};

 use arrow::compute::SortOptions;
-use datafusion_common::HashSet;
+use arrow::datatypes::SchemaRef;
+use datafusion_common::{HashSet, Result};
+use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;

-/// An `OrderingEquivalenceClass` object keeps track of different alternative
-/// orderings than can describe a schema. For example, consider the following table:
+/// An `OrderingEquivalenceClass` keeps track of distinct alternative orderings
+/// that can describe a table. For example, consider the following table:
 ///
 /// ```text
-/// |a|b|c|d|
-/// |1|4|3|1|
-/// |2|3|3|2|
-/// |3|1|2|2|
-/// |3|2|1|3|
+/// ┌───┬───┬───┬───┐
+/// │ a │ b │ c │ d │
+/// ├───┼───┼───┼───┤
+/// │ 1 │ 4 │ 3 │ 1 │
+/// │ 2 │ 3 │ 3 │ 2 │
+/// │ 3 │ 1 │ 2 │ 2 │
+/// │ 3 │ 2 │ 1 │ 3 │
+/// └───┴───┴───┴───┘
 /// ```
 ///
-/// Here, both `vec![a ASC, b ASC]` and `vec![c DESC, d ASC]` describe the table
+/// Here, both `[a ASC, b ASC]` and `[c DESC, d ASC]` describe the table
 /// ordering. In this case, we say that these orderings are equivalent.
-#[derive(Debug, Clone, Eq, PartialEq, Hash, Default)]
+///
+/// An `OrderingEquivalenceClass` is a set of such equivalent orderings, which
+/// is represented by a vector of `LexOrdering`s. The set does not store any
+/// redundant information by enforcing the invariant that no suffix of an
+/// ordering in the equivalence class is a prefix of another ordering in the
+/// equivalence class. The set can be empty, which means that there are no
+/// orderings that describe the table.
+#[derive(Clone, Debug, Default, Eq, PartialEq)]
 pub struct OrderingEquivalenceClass {
     orderings: Vec<LexOrdering>,
 }

 impl OrderingEquivalenceClass {
-    /// Creates new empty ordering equivalence class.
-    pub fn empty() -> Self {
-        Default::default()
-    }
-
     /// Clears (empties) this ordering equivalence class.
     pub fn clear(&mut self) {
         self.orderings.clear();
     }

-    /// Creates new ordering equivalence class from the given orderings
-    ///
-    /// Any redundant entries are removed
-    pub fn new(orderings: Vec<LexOrdering>) -> Self {
-        let mut result = Self { orderings };
+    /// Creates a new ordering equivalence class from the given orderings,
+    /// removing any redundant entries.
+    pub fn new(
+        orderings: impl IntoIterator<Item = impl IntoIterator<Item = PhysicalSortExpr>>,
+    ) -> Self {
+        let mut result = Self {
+            orderings: orderings.into_iter().filter_map(LexOrdering::new).collect(),
+        };
         result.remove_redundant_entries();
         result
     }

-    /// Converts this OrderingEquivalenceClass to a vector of orderings.
-    pub fn into_inner(self) -> Vec<LexOrdering> {
-        self.orderings
-    }
-
-    /// Checks whether `ordering` is a member of this equivalence class.
-    pub fn contains(&self, ordering: &LexOrdering) -> bool {
-        self.orderings.contains(ordering)
-    }
-
-    /// Adds `ordering` to this equivalence class.
-    #[allow(dead_code)]
-    #[deprecated(
-        since = "45.0.0",
-        note = "use OrderingEquivalenceClass::add_new_ordering instead"
-    )]
-    fn push(&mut self, ordering: LexOrdering) {
-        self.add_new_ordering(ordering)
-    }
-
-    /// Checks whether this ordering equivalence class is empty.
-    pub fn is_empty(&self) -> bool {
-        self.len() == 0
-    }
-
-    /// Returns an iterator over the equivalent orderings in this class.
-    ///
-    /// Note this class also implements [`IntoIterator`] to return an iterator
-    /// over owned [`LexOrdering`]s.
-    pub fn iter(&self) -> impl Iterator<Item = &LexOrdering> {
-        self.orderings.iter()
-    }
-
-    /// Returns how many equivalent orderings there are in this class.
-    pub fn len(&self) -> usize {
-        self.orderings.len()
-    }
-
-    /// Extend this ordering equivalence class with the `other` class.
-    pub fn extend(&mut self, other: Self) {
-        self.orderings.extend(other.orderings);
+    /// Extend this ordering equivalence class with the given orderings.
+    pub fn extend(&mut self, orderings: impl IntoIterator<Item = LexOrdering>) {
+        self.orderings.extend(orderings);
         // Make sure that there are no redundant orderings:
         self.remove_redundant_entries();
     }

-    /// Adds new orderings into this ordering equivalence class
-    pub fn add_new_orderings(
+    /// Adds new orderings into this ordering equivalence class.
+    pub fn add_orderings(
         &mut self,
-        orderings: impl IntoIterator<Item = LexOrdering>,
+        sort_exprs: impl IntoIterator<Item = impl IntoIterator<Item = PhysicalSortExpr>>,
     ) {
-        self.orderings.extend(orderings);
+        self.orderings
+            .extend(sort_exprs.into_iter().filter_map(LexOrdering::new));
         // Make sure that there are no redundant orderings:
         self.remove_redundant_entries();
     }

-    /// Adds a single ordering to the existing ordering equivalence class.
-    pub fn add_new_ordering(&mut self, ordering: LexOrdering) {
-        self.add_new_orderings([ordering]);
-    }
-
-    /// Removes redundant orderings from this equivalence class.
+    /// Removes redundant orderings from this ordering equivalence class.
     ///
     /// For instance, if we already have the ordering `[a ASC, b ASC, c DESC]`,
     /// then there is no need to keep ordering `[a ASC, b ASC]` in the state.
@@ -133,82 +101,72 @@
         while work {
             work = false;
             let mut idx = 0;
-            while idx < self.orderings.len() {
+            'outer: while idx < self.orderings.len() {
                 let mut ordering_idx = idx + 1;
-                let mut removal = self.orderings[idx].is_empty();
                 while ordering_idx < self.orderings.len() {
-                    work |= self.resolve_overlap(idx, ordering_idx);
-                    if self.orderings[idx].is_empty() {
-                        removal = true;
-                        break;
+                    if let Some(remove) = self.resolve_overlap(idx, ordering_idx) {
+                        work = true;
+                        if remove {
+                            self.orderings.swap_remove(idx);
+                            continue 'outer;
+                        }
                     }
-                    work |= self.resolve_overlap(ordering_idx, idx);
-                    if self.orderings[ordering_idx].is_empty() {
-                        self.orderings.swap_remove(ordering_idx);
-                    } else {
-                        ordering_idx += 1;
+                    if let Some(remove) = self.resolve_overlap(ordering_idx, idx) {
+                        work = true;
+                        if remove {
+                            self.orderings.swap_remove(ordering_idx);
+                            continue;
+                        }
                     }
+                    ordering_idx += 1;
                 }
-                if removal {
-                    self.orderings.swap_remove(idx);
-                } else {
-                    idx += 1;
-                }
+                idx += 1;
             }
         }
     }

     /// Trims `orderings[idx]` if some suffix of it overlaps with a prefix of
-    /// `orderings[pre_idx]`. Returns `true` if there is any overlap, `false` otherwise.
+    /// `orderings[pre_idx]`. If there is an overlap, returns `Some(true)` when
+    /// the ordering at `idx` becomes empty after trimming (and should thus be
+    /// removed), and `Some(false)` otherwise. If there is no overlap, returns
+    /// `None`.
     ///
     /// For example, if `orderings[idx]` is `[a ASC, b ASC, c DESC]` and
     /// `orderings[pre_idx]` is `[b ASC, c DESC]`, then the function will trim
     /// `orderings[idx]` to `[a ASC]`.
-    fn resolve_overlap(&mut self, idx: usize, pre_idx: usize) -> bool {
+    fn resolve_overlap(&mut self, idx: usize, pre_idx: usize) -> Option<bool> {
         let length = self.orderings[idx].len();
         let other_length = self.orderings[pre_idx].len();
         for overlap in 1..=length.min(other_length) {
             if self.orderings[idx][length - overlap..]
                 == self.orderings[pre_idx][..overlap]
             {
-                self.orderings[idx].truncate(length - overlap);
-                return true;
+                return Some(!self.orderings[idx].truncate(length - overlap));
             }
         }
-        false
+        None
     }

     /// Returns the concatenation of all the orderings. This enables merge
     /// operations to preserve all equivalent orderings simultaneously.
     pub fn output_ordering(&self) -> Option<LexOrdering> {
-        let output_ordering = self
-            .orderings
-            .iter()
-            .flatten()
-            .cloned()
-            .collect::<LexOrdering>()
-            .collapse();
-        (!output_ordering.is_empty()).then_some(output_ordering)
+        self.orderings.iter().cloned().reduce(|mut cat, o| {
+            cat.extend(o);
+            cat
+        })
     }

-    // Append orderings in `other` to all existing orderings in this equivalence
-    // class.
+    // Append orderings in `other` to all existing orderings in this ordering
+    // equivalence class.
     pub fn join_suffix(mut self, other: &Self) -> Self {
         let n_ordering = self.orderings.len();
-        // Replicate entries before cross product
+        // Replicate entries before cross product:
         let n_cross = std::cmp::max(n_ordering, other.len() * n_ordering);
-        self.orderings = self
-            .orderings
-            .iter()
-            .cloned()
-            .cycle()
-            .take(n_cross)
-            .collect();
-        // Suffix orderings of other to the current orderings.
+        self.orderings = self.orderings.into_iter().cycle().take(n_cross).collect();
+        // Append sort expressions of `other` to the current orderings:
         for (outer_idx, ordering) in other.iter().enumerate() {
-            for idx in 0..n_ordering {
-                // Calculate cross product index
-                let idx = outer_idx * n_ordering + idx;
+            let base = outer_idx * n_ordering;
+            // Use the cross product index:
+            for idx in base..(base + n_ordering) {
                 self.orderings[idx].extend(ordering.iter().cloned());
             }
         }
@@ -217,12 +175,40 @@
     /// Adds `offset` value to the index of each expression inside this
     /// ordering equivalence class.
-    pub fn add_offset(&mut self, offset: usize) {
-        for ordering in self.orderings.iter_mut() {
-            ordering.transform(|sort_expr| {
-                sort_expr.expr = add_offset_to_expr(Arc::clone(&sort_expr.expr), offset);
-            })
+    pub fn add_offset(&mut self, offset: isize) -> Result<()> {
+        let orderings = std::mem::take(&mut self.orderings);
+        for ordering_result in orderings
+            .into_iter()
+            .map(|o| add_offset_to_physical_sort_exprs(o, offset))
+        {
+            self.orderings.extend(LexOrdering::new(ordering_result?));
         }
+        Ok(())
+    }
+
+    /// Transforms this `OrderingEquivalenceClass` by mapping columns in the
+    /// original schema to columns in the new schema by index. The new schema
+    /// and the original schema need to be aligned; i.e. they should have the
+    /// same number of columns, and fields at the same index have the same type
+    /// in both schemas.
+    pub fn with_new_schema(mut self, schema: &SchemaRef) -> Result<Self> {
+        self.orderings = self
+            .orderings
+            .into_iter()
+            .map(|ordering| {
+                ordering
+                    .into_iter()
+                    .map(|mut sort_expr| {
+                        sort_expr.expr = with_new_schema(sort_expr.expr, schema)?;
+                        Ok(sort_expr)
+                    })
+                    .collect::<Result<Vec<_>>>()
+                    // The following `unwrap` is safe because the vector will always
+                    // be non-empty.
+                    .map(|v| LexOrdering::new(v).unwrap())
+            })
+            .collect::<Result<_>>()?;
+        Ok(self)
     }

     /// Gets sort options associated with this expression if it is a leading
@@ -257,31 +243,6 @@
     /// added as a constant during `ordering_satisfy_requirement()` iterations
     /// after the corresponding prefix requirement is satisfied.
     ///
-    /// ### Example Scenarios
-    ///
-    /// In these scenarios, we assume that all expressions share the same sort
-    /// properties.
-    ///
-    /// #### Case 1: Sort Requirement `[a, c]`
-    ///
-    /// **Existing Orderings:** `[[a, b, c], [a, d]]`, **Constants:** `[]`
-    /// 1. `ordering_satisfy_single()` returns `true` because the requirement
-    ///    `a` is satisfied by `[a, b, c].first()`.
-    /// 2. `a` is added as a constant for the next iteration.
-    /// 3.
The normalized orderings become `[[b, c], [d]]`. - /// 4. `ordering_satisfy_single()` returns `false` for `c`, as neither - /// `[b, c]` nor `[d]` satisfies `c`. - /// - /// #### Case 2: Sort Requirement `[a, d]` - /// - /// **Existing Orderings:** `[[a, b, c], [a, d]]`, **Constants:** `[]` - /// 1. `ordering_satisfy_single()` returns `true` because the requirement - /// `a` is satisfied by `[a, b, c].first()`. - /// 2. `a` is added as a constant for the next iteration. - /// 3. The normalized orderings become `[[b, c], [d]]`. - /// 4. `ordering_satisfy_single()` returns `true` for `d`, as `[d]` satisfies - /// `d`. - /// /// ### Future Improvements /// /// This function may become unnecessary if any of the following improvements @@ -296,15 +257,14 @@ impl OrderingEquivalenceClass { ]; for ordering in self.iter() { - if let Some(leading_ordering) = ordering.first() { - if leading_ordering.expr.eq(expr) { - let opt = ( - leading_ordering.options.descending, - leading_ordering.options.nulls_first, - ); - constantness_defining_pairs[0].remove(&opt); - constantness_defining_pairs[1].remove(&opt); - } + let leading_ordering = ordering.first(); + if leading_ordering.expr.eq(expr) { + let opt = ( + leading_ordering.options.descending, + leading_ordering.options.nulls_first, + ); + constantness_defining_pairs[0].remove(&opt); + constantness_defining_pairs[1].remove(&opt); } } @@ -314,10 +274,26 @@ impl OrderingEquivalenceClass { } } -/// Convert the `OrderingEquivalenceClass` into an iterator of LexOrderings +impl Deref for OrderingEquivalenceClass { + type Target = [LexOrdering]; + + fn deref(&self) -> &Self::Target { + self.orderings.as_slice() + } +} + +impl From> for OrderingEquivalenceClass { + fn from(orderings: Vec) -> Self { + let mut result = Self { orderings }; + result.remove_redundant_entries(); + result + } +} + +/// Convert the `OrderingEquivalenceClass` into an iterator of `LexOrdering`s. 
impl IntoIterator for OrderingEquivalenceClass { type Item = LexOrdering; - type IntoIter = IntoIter; + type IntoIter = IntoIter; fn into_iter(self) -> Self::IntoIter { self.orderings.into_iter() @@ -334,8 +310,13 @@ impl Display for OrderingEquivalenceClass { for ordering in iter { write!(f, ", [{ordering}]")?; } - write!(f, "]")?; - Ok(()) + write!(f, "]") + } +} + +impl From for Vec { + fn from(oeq_class: OrderingEquivalenceClass) -> Self { + oeq_class.orderings } } @@ -343,12 +324,10 @@ impl Display for OrderingEquivalenceClass { mod tests { use std::sync::Arc; - use crate::equivalence::tests::{ - convert_to_orderings, convert_to_sort_exprs, create_test_schema, - }; + use crate::equivalence::tests::create_test_schema; use crate::equivalence::{ - EquivalenceClass, EquivalenceGroup, EquivalenceProperties, - OrderingEquivalenceClass, + convert_to_orderings, convert_to_sort_exprs, EquivalenceClass, EquivalenceGroup, + EquivalenceProperties, OrderingEquivalenceClass, }; use crate::expressions::{col, BinaryExpr, Column}; use crate::utils::tests::TestScalarUDF; @@ -361,7 +340,6 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::Result; use datafusion_expr::{Operator, ScalarUDF}; - use datafusion_physical_expr_common::sort_expr::LexOrdering; #[test] fn test_ordering_satisfy() -> Result<()> { @@ -369,11 +347,11 @@ mod tests { Field::new("a", DataType::Int64, true), Field::new("b", DataType::Int64, true), ])); - let crude = LexOrdering::new(vec![PhysicalSortExpr { + let crude = vec![PhysicalSortExpr { expr: Arc::new(Column::new("a", 0)), options: SortOptions::default(), - }]); - let finer = LexOrdering::new(vec![ + }]; + let finer = vec![ PhysicalSortExpr { expr: Arc::new(Column::new("a", 0)), options: SortOptions::default(), @@ -382,20 +360,18 @@ mod tests { expr: Arc::new(Column::new("b", 1)), options: SortOptions::default(), }, - ]); + ]; // finer ordering satisfies, crude ordering should return true let eq_properties_finer = EquivalenceProperties::new_with_orderings( Arc::clone(&input_schema), - &[finer.clone()], + [finer.clone()], ); - assert!(eq_properties_finer.ordering_satisfy(crude.as_ref())); + assert!(eq_properties_finer.ordering_satisfy(crude.clone())?); // Crude ordering doesn't satisfy finer ordering. 
should return false - let eq_properties_crude = EquivalenceProperties::new_with_orderings( - Arc::clone(&input_schema), - &[crude.clone()], - ); - assert!(!eq_properties_crude.ordering_satisfy(finer.as_ref())); + let eq_properties_crude = + EquivalenceProperties::new_with_orderings(Arc::clone(&input_schema), [crude]); + assert!(!eq_properties_crude.ordering_satisfy(finer)?); Ok(()) } @@ -663,29 +639,20 @@ mod tests { format!("error in test orderings: {orderings:?}, eq_group: {eq_group:?}, constants: {constants:?}, reqs: {reqs:?}, expected: {expected:?}"); let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); let orderings = convert_to_orderings(&orderings); - eq_properties.add_new_orderings(orderings); - let eq_group = eq_group + eq_properties.add_orderings(orderings); + let classes = eq_group .into_iter() - .map(|eq_class| { - let eq_classes = eq_class.into_iter().cloned().collect::>(); - EquivalenceClass::new(eq_classes) - }) - .collect::>(); - let eq_group = EquivalenceGroup::new(eq_group); - eq_properties.add_equivalence_group(eq_group); + .map(|eq_class| EquivalenceClass::new(eq_class.into_iter().cloned())); + let eq_group = EquivalenceGroup::new(classes); + eq_properties.add_equivalence_group(eq_group)?; let constants = constants.into_iter().map(|expr| { - ConstExpr::from(expr) - .with_across_partitions(AcrossPartitions::Uniform(None)) + ConstExpr::new(Arc::clone(expr), AcrossPartitions::Uniform(None)) }); - eq_properties = eq_properties.with_constants(constants); + eq_properties.add_constants(constants)?; let reqs = convert_to_sort_exprs(&reqs); - assert_eq!( - eq_properties.ordering_satisfy(reqs.as_ref()), - expected, - "{err_msg}" - ); + assert_eq!(eq_properties.ordering_satisfy(reqs)?, expected, "{err_msg}"); } Ok(()) @@ -706,7 +673,7 @@ mod tests { }; // a=c (e.g they are aliases). let mut eq_properties = EquivalenceProperties::new(test_schema); - eq_properties.add_equal_conditions(col_a, col_c)?; + eq_properties.add_equal_conditions(Arc::clone(col_a), Arc::clone(col_c))?; let orderings = vec![ vec![(col_a, options)], @@ -716,7 +683,7 @@ mod tests { let orderings = convert_to_orderings(&orderings); // Column [a ASC], [e ASC], [d ASC, f ASC] are all valid orderings for the schema. - eq_properties.add_new_orderings(orderings); + eq_properties.add_orderings(orderings); // First entry in the tuple is required ordering, second entry is the expected flag // that indicates whether this required ordering is satisfied. @@ -740,11 +707,7 @@ mod tests { let err_msg = format!("error in test reqs: {reqs:?}, expected: {expected:?}",); let reqs = convert_to_sort_exprs(&reqs); - assert_eq!( - eq_properties.ordering_satisfy(reqs.as_ref()), - expected, - "{err_msg}" - ); + assert_eq!(eq_properties.ordering_satisfy(reqs)?, expected, "{err_msg}"); } Ok(()) @@ -854,7 +817,7 @@ mod tests { // ------- TEST CASE 5 --------- // Empty ordering ( - vec![vec![]], + vec![], // No ordering in the state (empty ordering is ignored). 
            vec![],
        ),

@@ -973,8 +936,7 @@
        for (orderings, expected) in test_cases {
            let orderings = convert_to_orderings(&orderings);
            let expected = convert_to_orderings(&expected);
-            let actual = OrderingEquivalenceClass::new(orderings.clone());
-            let actual = actual.orderings;
+            let actual = OrderingEquivalenceClass::from(orderings.clone());
            let err_msg = format!(
                "orderings: {orderings:?}, expected: {expected:?}, actual :{actual:?}"
            );
diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs
index efbb50bc40e1..38bb1fef8074 100644
--- a/datafusion/physical-expr/src/equivalence/projection.rs
+++ b/datafusion/physical-expr/src/equivalence/projection.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.

+use std::ops::Deref;
 use std::sync::Arc;

 use crate::expressions::Column;
@@ -24,13 +25,52 @@
 use arrow::datatypes::SchemaRef;
 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
 use datafusion_common::{internal_err, Result};

+use indexmap::IndexMap;
+
+/// Stores target expressions, along with their indices, that are associated
+/// with a source expression in a projection mapping.
+#[derive(Clone, Debug, Default)]
+pub struct ProjectionTargets {
+    /// A non-empty vector of pairs of target expressions and their indices.
+    /// Consider using a special non-empty collection type in the future (e.g.
+    /// if Rust provides one in the standard library).
+    exprs_indices: Vec<(Arc<dyn PhysicalExpr>, usize)>,
+}
+
+impl ProjectionTargets {
+    /// Returns the first target expression and its index.
+    pub fn first(&self) -> &(Arc<dyn PhysicalExpr>, usize) {
+        // Since the vector is non-empty, we can safely unwrap:
+        self.exprs_indices.first().unwrap()
+    }
+
+    /// Adds a target expression and its index to the list of targets.
+    pub fn push(&mut self, target: (Arc<dyn PhysicalExpr>, usize)) {
+        self.exprs_indices.push(target);
+    }
+}
+
+impl Deref for ProjectionTargets {
+    type Target = [(Arc<dyn PhysicalExpr>, usize)];
+
+    fn deref(&self) -> &Self::Target {
+        &self.exprs_indices
+    }
+}
+
+impl From<Vec<(Arc<dyn PhysicalExpr>, usize)>> for ProjectionTargets {
+    fn from(exprs_indices: Vec<(Arc<dyn PhysicalExpr>, usize)>) -> Self {
+        Self { exprs_indices }
+    }
+}
+
 /// Stores the mapping between source expressions and target expressions for a
 /// projection.
-#[derive(Debug, Clone)]
+#[derive(Clone, Debug)]
 pub struct ProjectionMapping {
     /// Mapping between source expressions and target expressions.
     /// Vector indices correspond to the indices after projection.
-    pub map: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)>,
+    map: IndexMap<Arc<dyn PhysicalExpr>, ProjectionTargets>,
 }

 impl ProjectionMapping {
@@ -42,44 +82,46 @@
     /// projection mapping would be:
     ///
     /// ```text
-    /// [0]: (c + d, col("c + d"))
-    /// [1]: (a + b, col("a + b"))
+    /// [0]: (c + d, [(col("c + d"), 0)])
+    /// [1]: (a + b, [(col("a + b"), 1)])
     /// ```
     ///
     /// where `col("c + d")` means the column named `"c + d"`.
     pub fn try_new(
-        expr: &[(Arc<dyn PhysicalExpr>, String)],
+        expr: impl IntoIterator<Item = (Arc<dyn PhysicalExpr>, String)>,
         input_schema: &SchemaRef,
     ) -> Result<Self> {
         // Construct a map from the input expressions to the output expression of the projection:
-        expr.iter()
-            .enumerate()
-            .map(|(expr_idx, (expression, name))| {
-                let target_expr = Arc::new(Column::new(name, expr_idx)) as _;
-                Arc::clone(expression)
-                    .transform_down(|e| match e.as_any().downcast_ref::<Column>() {
-                        Some(col) => {
-                            // Sometimes, an expression and its name in the input_schema
-                            // doesn't match. This can cause problems, so we make sure
-                            // that the expression name matches with the name in `input_schema`.
-                            // Conceptually, `source_expr` and `expression` should be the same.
-                            let idx = col.index();
-                            let matching_input_field = input_schema.field(idx);
-                            if col.name() != matching_input_field.name() {
-                                return internal_err!("Input field name {} does not match with the projection expression {}",
-                                    matching_input_field.name(),col.name())
-                            }
-                            let matching_input_column =
-                                Column::new(matching_input_field.name(), idx);
-                            Ok(Transformed::yes(Arc::new(matching_input_column)))
-                        }
-                        None => Ok(Transformed::no(e)),
-                    })
-                    .data()
-                    .map(|source_expr| (source_expr, target_expr))
+        let mut map = IndexMap::<_, ProjectionTargets>::new();
+        for (expr_idx, (expr, name)) in expr.into_iter().enumerate() {
+            let target_expr = Arc::new(Column::new(&name, expr_idx)) as _;
+            let source_expr = expr
+                .transform_down(|e| match e.as_any().downcast_ref::<Column>() {
+                    Some(col) => {
+                        // Sometimes, an expression and its name in the input_schema
+                        // don't match. This can cause problems, so we make sure
+                        // that the expression name matches with the name in `input_schema`.
+                        // Conceptually, `source_expr` and `expression` should be the same.
+                        let idx = col.index();
+                        let matching_field = input_schema.field(idx);
+                        let matching_name = matching_field.name();
+                        if col.name() != matching_name {
+                            return internal_err!(
+                                "Input field name {} does not match with the projection expression {}",
+                                matching_name,
+                                col.name()
+                            );
+                        }
+                        let matching_column = Column::new(matching_name, idx);
+                        Ok(Transformed::yes(Arc::new(matching_column)))
+                    }
+                    None => Ok(Transformed::no(e)),
                })
-            .collect::<Result<Vec<_>>>()
-            .map(|map| Self { map })
+                .data()?;
+            map.entry(source_expr)
+                .or_default()
+                .push((target_expr, expr_idx));
+        }
+        Ok(Self { map })
     }

     /// Constructs a subset mapping using the provided indices.
@@ -87,61 +129,38 @@
     /// This is used when the output is a subset of the input without any
     /// other transformations. The indices are for columns in the schema.
     pub fn from_indices(indices: &[usize], schema: &SchemaRef) -> Result<Self> {
-        let projection_exprs = project_index_to_exprs(indices, schema);
-        ProjectionMapping::try_new(&projection_exprs, schema)
+        let projection_exprs = indices.iter().map(|index| {
+            let field = schema.field(*index);
+            let column = Arc::new(Column::new(field.name(), *index));
+            (column as _, field.name().clone())
+        });
+        ProjectionMapping::try_new(projection_exprs, schema)
     }
+}

-    /// Iterate over pairs of (source, target) expressions
-    pub fn iter(
-        &self,
-    ) -> impl Iterator<Item = &(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> + '_ {
-        self.map.iter()
-    }
+impl Deref for ProjectionMapping {
+    type Target = IndexMap<Arc<dyn PhysicalExpr>, ProjectionTargets>;

-    /// This function returns the target expression for a given source expression.
-    ///
-    /// # Arguments
-    ///
-    /// * `expr` - Source physical expression.
-    ///
-    /// # Returns
-    ///
-    /// An `Option` containing the target for the given source expression,
-    /// where a `None` value means that `expr` is not inside the mapping.
- pub fn target_expr( - &self, - expr: &Arc, - ) -> Option> { - self.map - .iter() - .find(|(source, _)| source.eq(expr)) - .map(|(_, target)| Arc::clone(target)) + fn deref(&self) -> &Self::Target { + &self.map } } -fn project_index_to_exprs( - projection_index: &[usize], - schema: &SchemaRef, -) -> Vec<(Arc, String)> { - projection_index - .iter() - .map(|index| { - let field = schema.field(*index); - ( - Arc::new(Column::new(field.name(), *index)) as Arc, - field.name().to_owned(), - ) - }) - .collect::>() +impl FromIterator<(Arc, ProjectionTargets)> for ProjectionMapping { + fn from_iter, ProjectionTargets)>>( + iter: T, + ) -> Self { + Self { + map: IndexMap::from_iter(iter), + } + } } #[cfg(test)] mod tests { use super::*; - use crate::equivalence::tests::{ - convert_to_orderings, convert_to_orderings_owned, output_schema, - }; - use crate::equivalence::EquivalenceProperties; + use crate::equivalence::tests::output_schema; + use crate::equivalence::{convert_to_orderings, EquivalenceProperties}; use crate::expressions::{col, BinaryExpr}; use crate::utils::tests::TestScalarUDF; use crate::{PhysicalExprRef, ScalarFunctionExpr}; @@ -608,13 +627,12 @@ mod tests { let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); let orderings = convert_to_orderings(&orderings); - eq_properties.add_new_orderings(orderings); + eq_properties.add_orderings(orderings); let proj_exprs = proj_exprs .into_iter() - .map(|(expr, name)| (Arc::clone(expr), name)) - .collect::>(); - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; + .map(|(expr, name)| (Arc::clone(expr), name)); + let projection_mapping = ProjectionMapping::try_new(proj_exprs, &schema)?; let output_schema = output_schema(&projection_mapping, &schema)?; let expected = expected @@ -628,7 +646,7 @@ mod tests { .collect::>() }) .collect::>(); - let expected = convert_to_orderings_owned(&expected); + let expected = convert_to_orderings(&expected); let projected_eq = eq_properties.project(&projection_mapping, output_schema); let orderings = projected_eq.oeq_class(); @@ -686,9 +704,8 @@ mod tests { ]; let proj_exprs = proj_exprs .into_iter() - .map(|(expr, name)| (Arc::clone(expr), name)) - .collect::>(); - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; + .map(|(expr, name)| (Arc::clone(expr), name)); + let projection_mapping = ProjectionMapping::try_new(proj_exprs, &schema)?; let output_schema = output_schema(&projection_mapping, &schema)?; let col_a_new = &col("a_new", &output_schema)?; @@ -812,7 +829,7 @@ mod tests { let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); let orderings = convert_to_orderings(orderings); - eq_properties.add_new_orderings(orderings); + eq_properties.add_orderings(orderings); let expected = convert_to_orderings(expected); @@ -866,9 +883,8 @@ mod tests { ]; let proj_exprs = proj_exprs .into_iter() - .map(|(expr, name)| (Arc::clone(expr), name)) - .collect::>(); - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; + .map(|(expr, name)| (Arc::clone(expr), name)); + let projection_mapping = ProjectionMapping::try_new(proj_exprs, &schema)?; let output_schema = output_schema(&projection_mapping, &schema)?; let col_a_plus_b_new = &col("a+b", &output_schema)?; @@ -953,11 +969,11 @@ mod tests { for (orderings, equal_columns, expected) in test_cases { let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); for (lhs, rhs) in equal_columns { - eq_properties.add_equal_conditions(lhs, rhs)?; + 
eq_properties.add_equal_conditions(Arc::clone(lhs), Arc::clone(rhs))?; } let orderings = convert_to_orderings(&orderings); - eq_properties.add_new_orderings(orderings); + eq_properties.add_orderings(orderings); let expected = convert_to_orderings(&expected); diff --git a/datafusion/physical-expr/src/equivalence/properties/dependency.rs b/datafusion/physical-expr/src/equivalence/properties/dependency.rs index fa52ae8686f7..4554e36f766d 100644 --- a/datafusion/physical-expr/src/equivalence/properties/dependency.rs +++ b/datafusion/physical-expr/src/equivalence/properties/dependency.rs @@ -16,28 +16,28 @@ // under the License. use std::fmt::{self, Display}; +use std::ops::{Deref, DerefMut}; use std::sync::Arc; +use super::expr_refers; use crate::{LexOrdering, PhysicalSortExpr}; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use indexmap::IndexSet; -use indexmap::IndexMap; +use indexmap::{IndexMap, IndexSet}; use itertools::Itertools; -use super::{expr_refers, ExprWrapper}; - // A list of sort expressions that can be calculated from a known set of /// dependencies. #[derive(Debug, Default, Clone, PartialEq, Eq)] pub struct Dependencies { - inner: IndexSet, + sort_exprs: IndexSet, } impl Display for Dependencies { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "[")?; - let mut iter = self.inner.iter(); + let mut iter = self.sort_exprs.iter(); if let Some(dep) = iter.next() { write!(f, "{dep}")?; } @@ -49,38 +49,34 @@ impl Display for Dependencies { } impl Dependencies { - /// Create a new empty `Dependencies` instance. - fn new() -> Self { + // Creates a new `Dependencies` instance from the given sort expressions. + pub fn new(sort_exprs: impl IntoIterator) -> Self { Self { - inner: IndexSet::new(), + sort_exprs: sort_exprs.into_iter().collect(), } } +} - /// Create a new `Dependencies` from an iterator of `PhysicalSortExpr`. - pub fn new_from_iter(iter: impl IntoIterator) -> Self { - Self { - inner: iter.into_iter().collect(), - } - } +impl Deref for Dependencies { + type Target = IndexSet; - /// Insert a new dependency into the set. - pub fn insert(&mut self, sort_expr: PhysicalSortExpr) { - self.inner.insert(sort_expr); + fn deref(&self) -> &Self::Target { + &self.sort_exprs } +} - /// Iterator over dependencies in the set - pub fn iter(&self) -> impl Iterator + Clone { - self.inner.iter() +impl DerefMut for Dependencies { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.sort_exprs } +} - /// Return the inner set of dependencies - pub fn into_inner(self) -> IndexSet { - self.inner - } +impl IntoIterator for Dependencies { + type Item = PhysicalSortExpr; + type IntoIter = as IntoIterator>::IntoIter; - /// Returns true if there are no dependencies - fn is_empty(&self) -> bool { - self.inner.is_empty() + fn into_iter(self) -> Self::IntoIter { + self.sort_exprs.into_iter() } } @@ -133,26 +129,25 @@ impl<'a> DependencyEnumerator<'a> { let node = dependency_map .get(referred_sort_expr) .expect("`referred_sort_expr` should be inside `dependency_map`"); - // Since we work on intermediate nodes, we are sure `val.target_sort_expr` - // exists. - let target_sort_expr = node.target_sort_expr.as_ref().unwrap(); + // Since we work on intermediate nodes, we are sure `node.target` exists. + let target = node.target.as_ref().unwrap(); // An empty dependency means the referred_sort_expr represents a global ordering. // Return its projected version, which is the target_expression. 
if node.dependencies.is_empty() { - return vec![LexOrdering::new(vec![target_sort_expr.clone()])]; + return vec![[target.clone()].into()]; }; node.dependencies .iter() .flat_map(|dep| { - let mut orderings = if self.insert(target_sort_expr, dep) { + let mut orderings = if self.insert(target, dep) { self.construct_orderings(dep, dependency_map) } else { vec![] }; for ordering in orderings.iter_mut() { - ordering.push(target_sort_expr.clone()) + ordering.push(target.clone()); } orderings }) @@ -178,70 +173,55 @@ impl<'a> DependencyEnumerator<'a> { /// # Note on IndexMap Rationale /// /// Using `IndexMap` (which preserves insert order) to ensure consistent results -/// across different executions for the same query. We could have used -/// `HashSet`, `HashMap` in place of them without any loss of functionality. +/// across different executions for the same query. We could have used `HashSet` +/// and `HashMap` instead without any loss of functionality. /// /// As an example, if existing orderings are /// 1. `[a ASC, b ASC]` -/// 2. `[c ASC]` for +/// 2. `[c ASC]` /// /// Then both the following output orderings are valid /// 1. `[a ASC, b ASC, c ASC]` /// 2. `[c ASC, a ASC, b ASC]` /// -/// (this are both valid as they are concatenated versions of the alternative -/// orderings). When using `HashSet`, `HashMap` it is not guaranteed to generate -/// consistent result, among the possible 2 results in the example above. -#[derive(Debug)] +/// These are both valid as they are concatenated versions of the alternative +/// orderings. Had we used `HashSet`/`HashMap`, we couldn't guarantee to generate +/// the same result among the possible two results in the example above. +#[derive(Debug, Default)] pub struct DependencyMap { - inner: IndexMap, + map: IndexMap, } impl DependencyMap { - pub fn new() -> Self { - Self { - inner: IndexMap::new(), - } - } - - /// Insert a new dependency `sort_expr` --> `dependency` into the map. - /// - /// If `target_sort_expr` is none, a new entry is created with empty dependencies. + /// Insert a new dependency of `sort_expr` (i.e. `dependency`) into the map + /// along with its target sort expression. 
pub fn insert( &mut self, - sort_expr: &PhysicalSortExpr, - target_sort_expr: Option<&PhysicalSortExpr>, - dependency: Option<&PhysicalSortExpr>, + sort_expr: PhysicalSortExpr, + target_sort_expr: Option, + dependency: Option, ) { - self.inner - .entry(sort_expr.clone()) - .or_insert_with(|| DependencyNode { - target_sort_expr: target_sort_expr.cloned(), - dependencies: Dependencies::new(), - }) - .insert_dependency(dependency) - } - - /// Iterator over (sort_expr, DependencyNode) pairs - pub fn iter(&self) -> impl Iterator { - self.inner.iter() + let entry = self.map.entry(sort_expr); + let node = entry.or_insert_with(|| DependencyNode { + target: target_sort_expr, + dependencies: Dependencies::default(), + }); + node.dependencies.extend(dependency); } +} - /// iterator over all sort exprs - pub fn sort_exprs(&self) -> impl Iterator { - self.inner.keys() - } +impl Deref for DependencyMap { + type Target = IndexMap; - /// Return the dependency node for the given sort expression, if any - pub fn get(&self, sort_expr: &PhysicalSortExpr) -> Option<&DependencyNode> { - self.inner.get(sort_expr) + fn deref(&self) -> &Self::Target { + &self.map } } impl Display for DependencyMap { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, "DependencyMap: {{")?; - for (sort_expr, node) in self.inner.iter() { + for (sort_expr, node) in self.map.iter() { writeln!(f, " {sort_expr} --> {node}")?; } writeln!(f, "}}") @@ -256,29 +236,20 @@ impl Display for DependencyMap { /// /// # Fields /// -/// - `target_sort_expr`: An optional `PhysicalSortExpr` representing the target -/// sort expression associated with the node. It is `None` if the sort expression +/// - `target`: An optional `PhysicalSortExpr` representing the target sort +/// expression associated with the node. It is `None` if the sort expression /// cannot be projected. /// - `dependencies`: A [`Dependencies`] containing dependencies on other sort /// expressions that are referred to by the target sort expression. #[derive(Debug, Clone, PartialEq, Eq)] pub struct DependencyNode { - pub target_sort_expr: Option, - pub dependencies: Dependencies, -} - -impl DependencyNode { - /// Insert dependency to the state (if exists). - fn insert_dependency(&mut self, dependency: Option<&PhysicalSortExpr>) { - if let Some(dep) = dependency { - self.dependencies.insert(dep.clone()); - } - } + pub(crate) target: Option, + pub(crate) dependencies: Dependencies, } impl Display for DependencyNode { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if let Some(target) = &self.target_sort_expr { + if let Some(target) = &self.target { write!(f, "(target: {target}, ")?; } else { write!(f, "(")?; @@ -307,12 +278,12 @@ pub fn referred_dependencies( source: &Arc, ) -> Vec { // Associate `PhysicalExpr`s with `PhysicalSortExpr`s that contain them: - let mut expr_to_sort_exprs = IndexMap::::new(); + let mut expr_to_sort_exprs = IndexMap::<_, Dependencies>::new(); for sort_expr in dependency_map - .sort_exprs() + .keys() .filter(|sort_expr| expr_refers(source, &sort_expr.expr)) { - let key = ExprWrapper(Arc::clone(&sort_expr.expr)); + let key = Arc::clone(&sort_expr.expr); expr_to_sort_exprs .entry(key) .or_default() @@ -322,16 +293,10 @@ pub fn referred_dependencies( // Generate all valid dependencies for the source. For example, if the source // is `a + b` and the map is `[a -> (a ASC, a DESC), b -> (b ASC)]`, we get // `vec![HashSet(a ASC, b ASC), HashSet(a DESC, b ASC)]`. 
- let dependencies = expr_to_sort_exprs + expr_to_sort_exprs .into_values() - .map(Dependencies::into_inner) - .collect::>(); - dependencies - .iter() .multi_cartesian_product() - .map(|referred_deps| { - Dependencies::new_from_iter(referred_deps.into_iter().cloned()) - }) + .map(Dependencies::new) .collect() } @@ -378,46 +343,39 @@ pub fn construct_prefix_orderings( /// # Parameters /// /// * `dependencies` - Set of relevant expressions. -/// * `dependency_map` - Map of dependencies for expressions that may appear in `dependencies` +/// * `dependency_map` - Map of dependencies for expressions that may appear in +/// `dependencies`. /// /// # Returns /// -/// A vector of lexical orderings (`Vec`) representing all valid orderings -/// based on the given dependencies. +/// A vector of lexical orderings (`Vec`) representing all valid +/// orderings based on the given dependencies. pub fn generate_dependency_orderings( dependencies: &Dependencies, dependency_map: &DependencyMap, ) -> Vec { // Construct all the valid prefix orderings for each expression appearing - // in the projection: - let relevant_prefixes = dependencies + // in the projection. Note that if relevant prefixes are empty, there is no + // dependency, meaning that dependent is a leading ordering. + dependencies .iter() - .flat_map(|dep| { + .filter_map(|dep| { let prefixes = construct_prefix_orderings(dep, dependency_map); (!prefixes.is_empty()).then_some(prefixes) }) - .collect::>(); - - // No dependency, dependent is a leading ordering. - if relevant_prefixes.is_empty() { - // Return an empty ordering: - return vec![LexOrdering::default()]; - } - - relevant_prefixes - .into_iter() + // Generate all possible valid orderings: .multi_cartesian_product() .flat_map(|prefix_orderings| { + let length = prefix_orderings.len(); prefix_orderings - .iter() - .permutations(prefix_orderings.len()) - .map(|prefixes| { - prefixes - .into_iter() - .flat_map(|ordering| ordering.clone()) - .collect() + .into_iter() + .permutations(length) + .filter_map(|prefixes| { + prefixes.into_iter().reduce(|mut acc, ordering| { + acc.extend(ordering); + acc + }) }) - .collect::>() }) .collect() } @@ -429,10 +387,10 @@ mod tests { use super::*; use crate::equivalence::tests::{ - convert_to_sort_exprs, convert_to_sort_reqs, create_test_params, - create_test_schema, output_schema, parse_sort_expr, + convert_to_sort_reqs, create_test_params, create_test_schema, output_schema, + parse_sort_expr, }; - use crate::equivalence::ProjectionMapping; + use crate::equivalence::{convert_to_sort_exprs, ProjectionMapping}; use crate::expressions::{col, BinaryExpr, CastExpr, Column}; use crate::{ConstExpr, EquivalenceProperties, ScalarFunctionExpr}; @@ -441,9 +399,11 @@ mod tests { use datafusion_common::{Constraint, Constraints, Result}; use datafusion_expr::sort_properties::SortProperties; use datafusion_expr::Operator; - use datafusion_functions::string::concat; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + use datafusion_physical_expr_common::sort_expr::{ + LexRequirement, PhysicalSortRequirement, + }; #[test] fn project_equivalence_properties_test() -> Result<()> { @@ -463,7 +423,7 @@ mod tests { (Arc::clone(&col_a), "a3".to_string()), (Arc::clone(&col_a), "a4".to_string()), ]; - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + let projection_mapping = ProjectionMapping::try_new(proj_exprs, &input_schema)?; let out_schema = output_schema(&projection_mapping, &input_schema)?; // a as a1, a as a2, a as a3, a as a3 
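As the updated call sites in this test show, `ProjectionMapping::try_new` now accepts any `impl IntoIterator<Item = (Arc<dyn PhysicalExpr>, String)>` instead of a slice, so projection pairs can be fed in without an intermediate `Vec`. A minimal sketch of the new shape, assuming a one-column schema (all names here are illustrative, not from the patch):

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_physical_expr::equivalence::ProjectionMapping;
use datafusion_physical_expr::expressions::col;

fn build_mapping() -> Result<()> {
    // Hypothetical input schema with a single column `a`.
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    // Project `a` twice under two aliases; the iterator is consumed directly.
    let proj_exprs = [
        (col("a", &schema)?, "a1".to_string()),
        (col("a", &schema)?, "a2".to_string()),
    ];
    let mapping = ProjectionMapping::try_new(proj_exprs, &schema)?;
    // Both targets share the same source expression, so the map (which now
    // groups targets per source via `ProjectionTargets`) has a single entry.
    assert_eq!(mapping.len(), 1);
    Ok(())
}
```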
@@ -473,7 +433,7 @@ mod tests { (Arc::clone(&col_a), "a3".to_string()), (Arc::clone(&col_a), "a4".to_string()), ]; - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + let projection_mapping = ProjectionMapping::try_new(proj_exprs, &input_schema)?; // a as a1, a as a2, a as a3, a as a3 let col_a1 = &col("a1", &out_schema)?; @@ -506,20 +466,20 @@ mod tests { let mut input_properties = EquivalenceProperties::new(Arc::clone(&input_schema)); // add equivalent ordering [a, b, c, d] - input_properties.add_new_ordering(LexOrdering::new(vec![ + input_properties.add_ordering([ parse_sort_expr("a", &input_schema), parse_sort_expr("b", &input_schema), parse_sort_expr("c", &input_schema), parse_sort_expr("d", &input_schema), - ])); + ]); // add equivalent ordering [a, c, b, d] - input_properties.add_new_ordering(LexOrdering::new(vec![ + input_properties.add_ordering([ parse_sort_expr("a", &input_schema), parse_sort_expr("c", &input_schema), parse_sort_expr("b", &input_schema), // NB b and c are swapped parse_sort_expr("d", &input_schema), - ])); + ]); // simply project all the columns in order let proj_exprs = vec![ @@ -528,7 +488,7 @@ mod tests { (col("c", &input_schema)?, "c".to_string()), (col("d", &input_schema)?, "d".to_string()), ]; - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + let projection_mapping = ProjectionMapping::try_new(proj_exprs, &input_schema)?; let out_properties = input_properties.project(&projection_mapping, input_schema); assert_eq!( @@ -541,8 +501,6 @@ mod tests { #[test] fn test_normalize_ordering_equivalence_classes() -> Result<()> { - let sort_options = SortOptions::default(); - let schema = Schema::new(vec![ Field::new("a", DataType::Int32, true), Field::new("b", DataType::Int32, true), @@ -553,35 +511,19 @@ mod tests { let col_c_expr = col("c", &schema)?; let mut eq_properties = EquivalenceProperties::new(Arc::new(schema.clone())); - eq_properties.add_equal_conditions(&col_a_expr, &col_c_expr)?; - let others = vec![ - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(&col_b_expr), - options: sort_options, - }]), - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(&col_c_expr), - options: sort_options, - }]), - ]; - eq_properties.add_new_orderings(others); + eq_properties.add_equal_conditions(col_a_expr, Arc::clone(&col_c_expr))?; + eq_properties.add_orderings([ + vec![PhysicalSortExpr::new_default(Arc::clone(&col_b_expr))], + vec![PhysicalSortExpr::new_default(Arc::clone(&col_c_expr))], + ]); let mut expected_eqs = EquivalenceProperties::new(Arc::new(schema)); - expected_eqs.add_new_orderings([ - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(&col_b_expr), - options: sort_options, - }]), - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(&col_c_expr), - options: sort_options, - }]), + expected_eqs.add_orderings([ + vec![PhysicalSortExpr::new_default(col_b_expr)], + vec![PhysicalSortExpr::new_default(col_c_expr)], ]); - let oeq_class = eq_properties.oeq_class().clone(); - let expected = expected_eqs.oeq_class(); - assert!(oeq_class.eq(expected)); - + assert!(eq_properties.oeq_class().eq(expected_eqs.oeq_class())); Ok(()) } @@ -594,34 +536,22 @@ mod tests { Field::new("a", DataType::Int32, true), Field::new("b", DataType::Int32, true), ]); - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let required_columns = [Arc::clone(col_b), Arc::clone(col_a)]; + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let required_columns 
= [Arc::clone(&col_b), Arc::clone(&col_a)]; let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); - eq_properties.add_new_orderings([LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: sort_options_not, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: sort_options, - }, - ])]); - let (result, idxs) = eq_properties.find_longest_permutation(&required_columns); + eq_properties.add_ordering([ + PhysicalSortExpr::new(Arc::new(Column::new("b", 1)), sort_options_not), + PhysicalSortExpr::new(Arc::new(Column::new("a", 0)), sort_options), + ]); + let (result, idxs) = eq_properties.find_longest_permutation(&required_columns)?; assert_eq!(idxs, vec![0, 1]); assert_eq!( result, - LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(col_b), - options: sort_options_not - }, - PhysicalSortExpr { - expr: Arc::clone(col_a), - options: sort_options - } - ]) + vec![ + PhysicalSortExpr::new(col_b, sort_options_not), + PhysicalSortExpr::new(col_a, sort_options), + ] ); let schema = Schema::new(vec![ @@ -629,40 +559,28 @@ mod tests { Field::new("b", DataType::Int32, true), Field::new("c", DataType::Int32, true), ]); - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let required_columns = [Arc::clone(col_b), Arc::clone(col_a)]; + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let required_columns = [Arc::clone(&col_b), Arc::clone(&col_a)]; let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); - eq_properties.add_new_orderings([ - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: sort_options, - }]), - LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: sort_options_not, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: sort_options, - }, - ]), + eq_properties.add_orderings([ + vec![PhysicalSortExpr::new( + Arc::new(Column::new("c", 2)), + sort_options, + )], + vec![ + PhysicalSortExpr::new(Arc::new(Column::new("b", 1)), sort_options_not), + PhysicalSortExpr::new(Arc::new(Column::new("a", 0)), sort_options), + ], ]); - let (result, idxs) = eq_properties.find_longest_permutation(&required_columns); + let (result, idxs) = eq_properties.find_longest_permutation(&required_columns)?; assert_eq!(idxs, vec![0, 1]); assert_eq!( result, - LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(col_b), - options: sort_options_not - }, - PhysicalSortExpr { - expr: Arc::clone(col_a), - options: sort_options - } - ]) + vec![ + PhysicalSortExpr::new(col_b, sort_options_not), + PhysicalSortExpr::new(col_a, sort_options), + ] ); let required_columns = [ @@ -677,21 +595,12 @@ mod tests { let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); // not satisfied orders - eq_properties.add_new_orderings([LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: sort_options_not, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: sort_options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: sort_options, - }, - ])]); - let (_, idxs) = eq_properties.find_longest_permutation(&required_columns); + eq_properties.add_ordering([ + PhysicalSortExpr::new(Arc::new(Column::new("b", 1)), sort_options_not), + PhysicalSortExpr::new(Arc::new(Column::new("c", 2)), sort_options), + PhysicalSortExpr::new(Arc::new(Column::new("a", 0)), sort_options), + ]); + let (_, idxs) = 
eq_properties.find_longest_permutation(&required_columns)?; assert_eq!(idxs, vec![0]); Ok(()) @@ -707,49 +616,35 @@ mod tests { ]); let mut eq_properties = EquivalenceProperties::new(Arc::new(schema.clone())); - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let col_c = &col("c", &schema)?; - let col_d = &col("d", &schema)?; + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let col_c = col("c", &schema)?; + let col_d = col("d", &schema)?; let option_asc = SortOptions { descending: false, nulls_first: false, }; // b=a (e.g they are aliases) - eq_properties.add_equal_conditions(col_b, col_a)?; + eq_properties.add_equal_conditions(Arc::clone(&col_b), Arc::clone(&col_a))?; // [b ASC], [d ASC] - eq_properties.add_new_orderings(vec![ - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(col_b), - options: option_asc, - }]), - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(col_d), - options: option_asc, - }]), + eq_properties.add_orderings([ + vec![PhysicalSortExpr::new(Arc::clone(&col_b), option_asc)], + vec![PhysicalSortExpr::new(Arc::clone(&col_d), option_asc)], ]); let test_cases = vec![ // d + b ( - Arc::new(BinaryExpr::new( - Arc::clone(col_d), - Operator::Plus, - Arc::clone(col_b), - )) as Arc, + Arc::new(BinaryExpr::new(col_d, Operator::Plus, Arc::clone(&col_b))) as _, SortProperties::Ordered(option_asc), ), // b - (Arc::clone(col_b), SortProperties::Ordered(option_asc)), + (col_b, SortProperties::Ordered(option_asc)), // a - (Arc::clone(col_a), SortProperties::Ordered(option_asc)), + (Arc::clone(&col_a), SortProperties::Ordered(option_asc)), // a + c ( - Arc::new(BinaryExpr::new( - Arc::clone(col_a), - Operator::Plus, - Arc::clone(col_c), - )), + Arc::new(BinaryExpr::new(col_a, Operator::Plus, col_c)), SortProperties::Unordered, ), ]; @@ -757,7 +652,7 @@ mod tests { let leading_orderings = eq_properties .oeq_class() .iter() - .flat_map(|ordering| ordering.first().cloned()) + .map(|ordering| ordering.first().clone()) .collect::>(); let expr_props = eq_properties.get_expr_properties(Arc::clone(&expr)); let err_msg = format!( @@ -790,7 +685,7 @@ mod tests { Arc::clone(col_a), Operator::Plus, Arc::clone(col_d), - )) as Arc; + )) as _; let option_asc = SortOptions { descending: false, @@ -801,16 +696,10 @@ mod tests { nulls_first: true, }; // [d ASC, h DESC] also satisfies schema. 
- eq_properties.add_new_orderings([LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(col_d), - options: option_asc, - }, - PhysicalSortExpr { - expr: Arc::clone(col_h), - options: option_desc, - }, - ])]); + eq_properties.add_ordering([ + PhysicalSortExpr::new(Arc::clone(col_d), option_asc), + PhysicalSortExpr::new(Arc::clone(col_h), option_desc), + ]); let test_cases = vec![ // TEST CASE 1 (vec![col_a], vec![(col_a, option_asc)]), @@ -878,7 +767,7 @@ for (exprs, expected) in test_cases { let exprs = exprs.into_iter().cloned().collect::<Vec<_>>(); let expected = convert_to_sort_exprs(&expected); - let (actual, _) = eq_properties.find_longest_permutation(&exprs); + let (actual, _) = eq_properties.find_longest_permutation(&exprs)?; assert_eq!(actual, expected); } @@ -896,7 +785,7 @@ let col_h = &col("h", &test_schema)?; // Add column h as constant - eq_properties = eq_properties.with_constants(vec![ConstExpr::from(col_h)]); + eq_properties.add_constants(vec![ConstExpr::from(Arc::clone(col_h))])?; let test_cases = vec![ // TEST CASE 1 @@ -907,72 +796,13 @@ for (exprs, expected) in test_cases { let exprs = exprs.into_iter().cloned().collect::<Vec<_>>(); let expected = convert_to_sort_exprs(&expected); - let (actual, _) = eq_properties.find_longest_permutation(&exprs); + let (actual, _) = eq_properties.find_longest_permutation(&exprs)?; assert_eq!(actual, expected); } Ok(()) } - #[test] - fn test_get_finer() -> Result<()> { - let schema = create_test_schema()?; - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let col_c = &col("c", &schema)?; - let eq_properties = EquivalenceProperties::new(schema); - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - // First entry, and second entry are the physical sort requirement that are argument for get_finer_requirement. - // Third entry is the expected result.
- let tests_cases = vec![ - // Get finer requirement between [a Some(ASC)] and [a None, b Some(ASC)] - // result should be [a Some(ASC), b Some(ASC)] - ( - vec![(col_a, Some(option_asc))], - vec![(col_a, None), (col_b, Some(option_asc))], - Some(vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))]), - ), - // Get finer requirement between [a Some(ASC), b Some(ASC), c Some(ASC)] and [a Some(ASC), b Some(ASC)] - // result should be [a Some(ASC), b Some(ASC), c Some(ASC)] - ( - vec![ - (col_a, Some(option_asc)), - (col_b, Some(option_asc)), - (col_c, Some(option_asc)), - ], - vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))], - Some(vec![ - (col_a, Some(option_asc)), - (col_b, Some(option_asc)), - (col_c, Some(option_asc)), - ]), - ), - // Get finer requirement between [a Some(ASC), b Some(ASC)] and [a Some(ASC), b Some(DESC)] - // result should be None - ( - vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))], - vec![(col_a, Some(option_asc)), (col_b, Some(option_desc))], - None, - ), - ]; - for (lhs, rhs, expected) in tests_cases { - let lhs = convert_to_sort_reqs(&lhs); - let rhs = convert_to_sort_reqs(&rhs); - let expected = expected.map(|expected| convert_to_sort_reqs(&expected)); - let finer = eq_properties.get_finer_requirement(&lhs, &rhs); - assert_eq!(finer, expected) - } - - Ok(()) - } - #[test] fn test_normalize_sort_reqs() -> Result<()> { // Schema satisfies following properties @@ -1040,7 +870,7 @@ let expected_normalized = convert_to_sort_reqs(&expected_normalized); assert_eq!( - eq_properties.normalize_sort_requirements(&req), + eq_properties.normalize_sort_requirements(req).unwrap(), expected_normalized ); } @@ -1073,8 +903,9 @@ for (reqs, expected) in test_cases.into_iter() { let reqs = convert_to_sort_reqs(&reqs); let expected = convert_to_sort_reqs(&expected); - - let normalized = eq_properties.normalize_sort_requirements(&reqs); + let normalized = eq_properties + .normalize_sort_requirements(reqs.clone()) + .unwrap(); assert!( expected.eq(&normalized), "error in test: reqs: {reqs:?}, expected: {expected:?}, normalized: {normalized:?}" @@ -1091,21 +922,12 @@ Field::new("b", DataType::Utf8, true), Field::new("c", DataType::Timestamp(TimeUnit::Nanosecond, None), true), ])); - let base_properties = EquivalenceProperties::new(Arc::clone(&schema)) - .with_reorder(LexOrdering::new( - ["a", "b", "c"] - .into_iter() - .map(|c| { - col(c, schema.as_ref()).map(|expr| PhysicalSortExpr { - expr, - options: SortOptions { - descending: false, - nulls_first: true, - }, - }) - }) - .collect::<Result<Vec<_>>>()?, - )); + let mut base_properties = EquivalenceProperties::new(Arc::clone(&schema)); + base_properties.reorder( + ["a", "b", "c"] + .into_iter() + .map(|c| PhysicalSortExpr::new_default(col(c, schema.as_ref()).unwrap())), + )?; struct TestCase { name: &'static str, @@ -1118,17 +940,14 @@ let col_a = col("a", schema.as_ref())?; let col_b = col("b", schema.as_ref())?; let col_c = col("c", schema.as_ref())?; - let cast_c = Arc::new(CastExpr::new(col_c, DataType::Date32, None)); + let cast_c = Arc::new(CastExpr::new(col_c, DataType::Date32, None)) as _; let cases = vec![ TestCase { name: "(a, b, c) -> (c)", // b is constant, so it should be removed from the sort order constants: vec![Arc::clone(&col_b)], - equal_conditions: vec![[ - Arc::clone(&cast_c) as Arc<dyn PhysicalExpr>, - Arc::clone(&col_a), - ]], + equal_conditions: vec![[Arc::clone(&cast_c), Arc::clone(&col_a)]], sort_columns: &["c"], should_satisfy_ordering: true, }, @@ -1138,10 +957,7 @@
mod tests { name: "(a, b, c) -> (c)", // b is constant, so it should be removed from the sort order constants: vec![col_b], - equal_conditions: vec![[ - Arc::clone(&col_a), - Arc::clone(&cast_c) as Arc<dyn PhysicalExpr>, - ]], + equal_conditions: vec![[Arc::clone(&col_a), Arc::clone(&cast_c)]], sort_columns: &["c"], should_satisfy_ordering: true, }, @@ -1150,10 +966,7 @@ mod tests { // b is not constant anymore constants: vec![], // a and c are still compatible, but this is irrelevant since the original ordering is (a, b, c) - equal_conditions: vec![[ - Arc::clone(&cast_c) as Arc<dyn PhysicalExpr>, - Arc::clone(&col_a), - ]], + equal_conditions: vec![[Arc::clone(&cast_c), Arc::clone(&col_a)]], sort_columns: &["c"], should_satisfy_ordering: false, }, @@ -1167,19 +980,21 @@ mod tests { // Equal conditions before constants { let mut properties = base_properties.clone(); - for [left, right] in &case.equal_conditions { + for [left, right] in case.equal_conditions.clone() { properties.add_equal_conditions(left, right)? } - properties.with_constants( + properties.add_constants( case.constants.iter().cloned().map(ConstExpr::from), - ) + )?; + properties }, // Constants before equal conditions { - let mut properties = base_properties.clone().with_constants( + let mut properties = base_properties.clone(); + properties.add_constants( case.constants.iter().cloned().map(ConstExpr::from), - ); - for [left, right] in &case.equal_conditions { + )?; + for [left, right] in case.equal_conditions { properties.add_equal_conditions(left, right)? } properties @@ -1188,16 +1003,11 @@ mod tests { let sort = case .sort_columns .iter() - .map(|&name| { - col(name, &schema).map(|col| PhysicalSortExpr { - expr: col, - options: SortOptions::default(), - }) - }) - .collect::<Result<LexOrdering>>()?; + .map(|&name| col(name, &schema).map(PhysicalSortExpr::new_default)) + .collect::<Result<Vec<_>>>()?; assert_eq!( - properties.ordering_satisfy(sort.as_ref()), + properties.ordering_satisfy(sort)?, case.should_satisfy_ordering, "failed test '{}'", case.name @@ -1230,25 +1040,23 @@ mod tests { // Assume existing ordering is [c ASC, a ASC, b ASC] let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); - eq_properties.add_new_ordering(LexOrdering::from(vec![ + eq_properties.add_ordering([ PhysicalSortExpr::new_default(Arc::clone(&col_c)).asc(), PhysicalSortExpr::new_default(Arc::clone(&col_a)).asc(), PhysicalSortExpr::new_default(Arc::clone(&col_b)).asc(), - ])); + ]); // Add equality condition c = concat(a, b) - eq_properties.add_equal_conditions(&col_c, &a_concat_b)?; + eq_properties.add_equal_conditions(Arc::clone(&col_c), a_concat_b)?; let orderings = eq_properties.oeq_class(); - let expected_ordering1 = - LexOrdering::from(vec![ - PhysicalSortExpr::new_default(Arc::clone(&col_c)).asc() - ]); - let expected_ordering2 = LexOrdering::from(vec![ - PhysicalSortExpr::new_default(Arc::clone(&col_a)).asc(), - PhysicalSortExpr::new_default(Arc::clone(&col_b)).asc(), - ]); + let expected_ordering1 = [PhysicalSortExpr::new_default(col_c).asc()].into(); + let expected_ordering2 = [ + PhysicalSortExpr::new_default(col_a).asc(), + PhysicalSortExpr::new_default(col_b).asc(), + ] + .into(); // The ordering should be [c ASC] and [a ASC, b ASC] assert_eq!(orderings.len(), 2); @@ -1270,25 +1078,26 @@ mod tests { let col_b = col("b", &schema)?; let col_c = col("c", &schema)?; - let a_times_b: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new( + let a_times_b = Arc::new(BinaryExpr::new( Arc::clone(&col_a), Operator::Multiply, Arc::clone(&col_b), - )); + )) as _; // Assume existing ordering is [c ASC, a ASC, b
ASC] let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); - let initial_ordering = LexOrdering::from(vec![ + let initial_ordering: LexOrdering = [ PhysicalSortExpr::new_default(Arc::clone(&col_c)).asc(), - PhysicalSortExpr::new_default(Arc::clone(&col_a)).asc(), - PhysicalSortExpr::new_default(Arc::clone(&col_b)).asc(), - ]); + PhysicalSortExpr::new_default(col_a).asc(), + PhysicalSortExpr::new_default(col_b).asc(), + ] + .into(); - eq_properties.add_new_ordering(initial_ordering.clone()); + eq_properties.add_ordering(initial_ordering.clone()); // Add equality condition c = a * b - eq_properties.add_equal_conditions(&col_c, &a_times_b)?; + eq_properties.add_equal_conditions(col_c, a_times_b)?; let orderings = eq_properties.oeq_class(); @@ -1311,37 +1120,35 @@ mod tests { let col_b = col("b", &schema)?; let col_c = col("c", &schema)?; - let a_concat_b: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new( + let a_concat_b = Arc::new(ScalarFunctionExpr::new( "concat", concat(), vec![Arc::clone(&col_a), Arc::clone(&col_b)], Field::new("f", DataType::Utf8, true).into(), - )); + )) as _; // Assume existing ordering is [concat(a, b) ASC, a ASC, b ASC] let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); - eq_properties.add_new_ordering(LexOrdering::from(vec![ + eq_properties.add_ordering([ PhysicalSortExpr::new_default(Arc::clone(&a_concat_b)).asc(), PhysicalSortExpr::new_default(Arc::clone(&col_a)).asc(), PhysicalSortExpr::new_default(Arc::clone(&col_b)).asc(), - ])); + ]); // Add equality condition c = concat(a, b) - eq_properties.add_equal_conditions(&col_c, &a_concat_b)?; + eq_properties.add_equal_conditions(col_c, Arc::clone(&a_concat_b))?; let orderings = eq_properties.oeq_class(); - let expected_ordering1 = LexOrdering::from(vec![PhysicalSortExpr::new_default( - Arc::clone(&a_concat_b), - ) - .asc()]); - let expected_ordering2 = LexOrdering::from(vec![ - PhysicalSortExpr::new_default(Arc::clone(&col_a)).asc(), - PhysicalSortExpr::new_default(Arc::clone(&col_b)).asc(), - ]); + let expected_ordering1 = [PhysicalSortExpr::new_default(a_concat_b).asc()].into(); + let expected_ordering2 = [ + PhysicalSortExpr::new_default(col_a).asc(), + PhysicalSortExpr::new_default(col_b).asc(), + ] + .into(); - // The ordering should be [concat(a, b) ASC] and [a ASC, b ASC] + // The ordering should be [concat(a, b) ASC] and [a ASC, b ASC] assert_eq!(orderings.len(), 2); assert!(orderings.contains(&expected_ordering1)); assert!(orderings.contains(&expected_ordering2)); @@ -1349,6 +1156,35 @@ mod tests { Ok(()) } + #[test] + fn test_requirements_compatible() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ])); + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let col_c = col("c", &schema)?; + + let eq_properties = EquivalenceProperties::new(schema); + let lex_a: LexRequirement = + [PhysicalSortRequirement::new(Arc::clone(&col_a), None)].into(); + let lex_a_b: LexRequirement = [ + PhysicalSortRequirement::new(col_a, None), + PhysicalSortRequirement::new(col_b, None), + ] + .into(); + let lex_c = [PhysicalSortRequirement::new(col_c, None)].into(); + + assert!(eq_properties.requirements_compatible(lex_a.clone(), lex_a.clone())); + assert!(!eq_properties.requirements_compatible(lex_a.clone(), lex_a_b.clone())); + assert!(eq_properties.requirements_compatible(lex_a_b, lex_a.clone())); + assert!(!eq_properties.requirements_compatible(lex_c,
lex_a)); + + Ok(()) + } + #[test] fn test_with_reorder_constant_filtering() -> Result<()> { let schema = create_test_schema()?; @@ -1357,26 +1193,21 @@ mod tests { // Setup constant columns let col_a = col("a", &schema)?; let col_b = col("b", &schema)?; - eq_properties = eq_properties.with_constants([ConstExpr::from(&col_a)]); + eq_properties.add_constants([ConstExpr::from(Arc::clone(&col_a))])?; - let sort_exprs = LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(&col_a), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::clone(&col_b), - options: SortOptions::default(), - }, - ]); + let sort_exprs = vec![ + PhysicalSortExpr::new_default(Arc::clone(&col_a)), + PhysicalSortExpr::new_default(Arc::clone(&col_b)), + ]; - let result = eq_properties.with_reorder(sort_exprs); + let change = eq_properties.reorder(sort_exprs)?; + assert!(change); - // Should only contain b since a is constant - assert_eq!(result.oeq_class().len(), 1); - let ordering = result.oeq_class().iter().next().unwrap(); - assert_eq!(ordering.len(), 1); - assert!(ordering[0].expr.eq(&col_b)); + assert_eq!(eq_properties.oeq_class().len(), 1); + let ordering = eq_properties.oeq_class().iter().next().unwrap(); + assert_eq!(ordering.len(), 2); + assert!(ordering[0].expr.eq(&col_a)); + assert!(ordering[1].expr.eq(&col_b)); Ok(()) } @@ -1397,32 +1228,21 @@ mod tests { }; // Initial ordering: [a ASC, b DESC, c ASC] - eq_properties.add_new_orderings([LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(&col_a), - options: asc, - }, - PhysicalSortExpr { - expr: Arc::clone(&col_b), - options: desc, - }, - PhysicalSortExpr { - expr: Arc::clone(&col_c), - options: asc, - }, - ])]); + eq_properties.add_ordering([ + PhysicalSortExpr::new(Arc::clone(&col_a), asc), + PhysicalSortExpr::new(Arc::clone(&col_b), desc), + PhysicalSortExpr::new(Arc::clone(&col_c), asc), + ]); // New ordering: [a ASC] - let new_order = LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(&col_a), - options: asc, - }]); + let new_order = vec![PhysicalSortExpr::new(Arc::clone(&col_a), asc)]; - let result = eq_properties.with_reorder(new_order); + let change = eq_properties.reorder(new_order)?; + assert!(!change); // Should only contain [a ASC, b DESC, c ASC] - assert_eq!(result.oeq_class().len(), 1); - let ordering = result.oeq_class().iter().next().unwrap(); + assert_eq!(eq_properties.oeq_class().len(), 1); + let ordering = eq_properties.oeq_class().iter().next().unwrap(); assert_eq!(ordering.len(), 3); assert!(ordering[0].expr.eq(&col_a)); assert!(ordering[0].options.eq(&asc)); @@ -1444,37 +1264,28 @@ mod tests { let col_c = col("c", &schema)?; // Make a and b equivalent - eq_properties.add_equal_conditions(&col_a, &col_b)?; - - let asc = SortOptions::default(); + eq_properties.add_equal_conditions(Arc::clone(&col_a), Arc::clone(&col_b))?; // Initial ordering: [a ASC, c ASC] - eq_properties.add_new_orderings([LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(&col_a), - options: asc, - }, - PhysicalSortExpr { - expr: Arc::clone(&col_c), - options: asc, - }, - ])]); + eq_properties.add_ordering([ + PhysicalSortExpr::new_default(Arc::clone(&col_a)), + PhysicalSortExpr::new_default(Arc::clone(&col_c)), + ]); // New ordering: [b ASC] - let new_order = LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(&col_b), - options: asc, - }]); + let new_order = vec![PhysicalSortExpr::new_default(Arc::clone(&col_b))]; - let result = eq_properties.with_reorder(new_order); + let change = 
eq_properties.reorder(new_order)?; - // Should only contain [b ASC, c ASC] - assert_eq!(result.oeq_class().len(), 1); + assert!(!change); + // Should only contain [a/b ASC, c ASC] + assert_eq!(eq_properties.oeq_class().len(), 1); // Verify orderings - let ordering = result.oeq_class().iter().next().unwrap(); + let asc = SortOptions::default(); + let ordering = eq_properties.oeq_class().iter().next().unwrap(); assert_eq!(ordering.len(), 2); - assert!(ordering[0].expr.eq(&col_b)); + assert!(ordering[0].expr.eq(&col_a) || ordering[0].expr.eq(&col_b)); assert!(ordering[0].options.eq(&asc)); assert!(ordering[1].expr.eq(&col_c)); assert!(ordering[1].options.eq(&asc)); @@ -1497,29 +1308,21 @@ mod tests { }; // Initial ordering: [a ASC, b DESC] - eq_properties.add_new_orderings([LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(&col_a), - options: asc, - }, - PhysicalSortExpr { - expr: Arc::clone(&col_b), - options: desc, - }, - ])]); + eq_properties.add_ordering([ + PhysicalSortExpr::new(Arc::clone(&col_a), asc), + PhysicalSortExpr::new(Arc::clone(&col_b), desc), + ]); // New ordering: [a DESC] - let new_order = LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(&col_a), - options: desc, - }]); + let new_order = vec![PhysicalSortExpr::new(Arc::clone(&col_a), desc)]; - let result = eq_properties.with_reorder(new_order.clone()); + let change = eq_properties.reorder(new_order.clone())?; + assert!(change); // Should only contain the new ordering since options don't match - assert_eq!(result.oeq_class().len(), 1); - let ordering = result.oeq_class().iter().next().unwrap(); - assert_eq!(ordering, &new_order); + assert_eq!(eq_properties.oeq_class().len(), 1); + let ordering = eq_properties.oeq_class().iter().next().unwrap(); + assert_eq!(ordering.to_vec(), new_order); Ok(()) } @@ -1535,62 +1338,32 @@ mod tests { let col_d = col("d", &schema)?; let col_e = col("e", &schema)?; - let asc = SortOptions::default(); - // Constants: c is constant - eq_properties = eq_properties.with_constants([ConstExpr::from(&col_c)]); + eq_properties.add_constants([ConstExpr::from(Arc::clone(&col_c))])?; // Equality: b = d - eq_properties.add_equal_conditions(&col_b, &col_d)?; + eq_properties.add_equal_conditions(Arc::clone(&col_b), Arc::clone(&col_d))?; // Orderings: [d ASC, a ASC], [e ASC] - eq_properties.add_new_orderings([ - LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(&col_d), - options: asc, - }, - PhysicalSortExpr { - expr: Arc::clone(&col_a), - options: asc, - }, - ]), - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(&col_e), - options: asc, - }]), + eq_properties.add_orderings([ + vec![ + PhysicalSortExpr::new_default(Arc::clone(&col_d)), + PhysicalSortExpr::new_default(Arc::clone(&col_a)), + ], + vec![PhysicalSortExpr::new_default(Arc::clone(&col_e))], ]); - // Initial ordering: [b ASC, c ASC] - let new_order = LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(&col_b), - options: asc, - }, - PhysicalSortExpr { - expr: Arc::clone(&col_c), - options: asc, - }, - ]); - - let result = eq_properties.with_reorder(new_order); - - // Should preserve the original [d ASC, a ASC] ordering - assert_eq!(result.oeq_class().len(), 1); - let ordering = result.oeq_class().iter().next().unwrap(); - assert_eq!(ordering.len(), 2); - - // First expression should be either b or d (they're equivalent) - assert!( - ordering[0].expr.eq(&col_b) || ordering[0].expr.eq(&col_d), - "Expected b or d as first expression, got {:?}", - ordering[0].expr - ); - 
assert!(ordering[0].options.eq(&asc)); + // New ordering: [b ASC, c ASC] + let new_order = vec![ + PhysicalSortExpr::new_default(Arc::clone(&col_b)), + PhysicalSortExpr::new_default(Arc::clone(&col_c)), + ]; - // Second expression should be a - assert!(ordering[1].expr.eq(&col_a)); - assert!(ordering[1].options.eq(&asc)); + let old_orderings = eq_properties.oeq_class().clone(); + let change = eq_properties.reorder(new_order)?; + // Original orderings should be preserved: + assert!(!change); + assert_eq!(eq_properties.oeq_class, old_orderings); Ok(()) } @@ -1691,75 +1464,62 @@ { let mut eq_properties = EquivalenceProperties::new(Arc::clone(schema)); - // Convert base ordering - let base_ordering = LexOrdering::new( - base_order - .iter() - .map(|col_name| PhysicalSortExpr { - expr: col(col_name, schema).unwrap(), - options: SortOptions::default(), - }) - .collect(), - ); - // Convert string column names to orderings - let satisfied_orderings: Vec<LexOrdering> = satisfied_orders + let satisfied_orderings: Vec<_> = satisfied_orders .iter() .map(|cols| { - LexOrdering::new( - cols.iter() - .map(|col_name| PhysicalSortExpr { - expr: col(col_name, schema).unwrap(), - options: SortOptions::default(), - }) - .collect(), - ) + cols.iter() + .map(|col_name| { + PhysicalSortExpr::new_default(col(col_name, schema).unwrap()) + }) + .collect::<Vec<_>>() }) .collect(); - let unsatisfied_orderings: Vec<LexOrdering> = unsatisfied_orders + let unsatisfied_orderings: Vec<_> = unsatisfied_orders .iter() .map(|cols| { - LexOrdering::new( - cols.iter() - .map(|col_name| PhysicalSortExpr { - expr: col(col_name, schema).unwrap(), - options: SortOptions::default(), - }) - .collect(), - ) + cols.iter() + .map(|col_name| { + PhysicalSortExpr::new_default(col(col_name, schema).unwrap()) + }) + .collect::<Vec<_>>() }) .collect(); // Test that orderings are not satisfied before adding constraints - for ordering in &satisfied_orderings { - assert!( - !eq_properties.ordering_satisfy(ordering), - "{name}: ordering {ordering:?} should not be satisfied before adding constraints" + for ordering in satisfied_orderings.clone() { + let err_msg = format!( + "{name}: ordering {ordering:?} should not be satisfied before adding constraints", ); + assert!(!eq_properties.ordering_satisfy(ordering)?, "{err_msg}"); } // Add base ordering - eq_properties.add_new_ordering(base_ordering); + let base_ordering = base_order.iter().map(|col_name| PhysicalSortExpr { expr: col(col_name, schema).unwrap(), options: SortOptions::default(), }); + eq_properties.add_ordering(base_ordering); // Add constraints eq_properties = eq_properties .with_constraints(Constraints::new_unverified(constraints)); // Test that expected orderings are now satisfied - for ordering in &satisfied_orderings { - assert!( - eq_properties.ordering_satisfy(ordering), - "{name}: ordering {ordering:?} should be satisfied after adding constraints" + for ordering in satisfied_orderings { + let err_msg = format!( + "{name}: ordering {ordering:?} should be satisfied after adding constraints", ); + assert!(eq_properties.ordering_satisfy(ordering)?, "{err_msg}"); } // Test that unsatisfied orderings remain unsatisfied - for ordering in &unsatisfied_orderings { - assert!( - !eq_properties.ordering_satisfy(ordering), - "{name}: ordering {ordering:?} should not be satisfied after adding constraints" + for ordering in unsatisfied_orderings { + let err_msg = format!( + "{name}: ordering {ordering:?} should not be satisfied after adding constraints", ); + assert!(!eq_properties.ordering_satisfy(ordering)?,
"{err_msg}"); } } diff --git a/datafusion/physical-expr/src/equivalence/properties/joins.rs b/datafusion/physical-expr/src/equivalence/properties/joins.rs index 344cf54a57a8..9329ce56b7d6 100644 --- a/datafusion/physical-expr/src/equivalence/properties/joins.rs +++ b/datafusion/physical-expr/src/equivalence/properties/joins.rs @@ -15,11 +15,11 @@ // specific language governing permissions and limitations // under the License. +use super::EquivalenceProperties; use crate::{equivalence::OrderingEquivalenceClass, PhysicalExprRef}; -use arrow::datatypes::SchemaRef; -use datafusion_common::{JoinSide, JoinType}; -use super::EquivalenceProperties; +use arrow::datatypes::SchemaRef; +use datafusion_common::{JoinSide, JoinType, Result}; /// Calculate ordering equivalence properties for the given join operation. pub fn join_equivalence_properties( @@ -30,7 +30,7 @@ pub fn join_equivalence_properties( maintains_input_order: &[bool], probe_side: Option, on: &[(PhysicalExprRef, PhysicalExprRef)], -) -> EquivalenceProperties { +) -> Result { let left_size = left.schema.fields.len(); let mut result = EquivalenceProperties::new(join_schema); result.add_equivalence_group(left.eq_group().join( @@ -38,15 +38,13 @@ pub fn join_equivalence_properties( join_type, left_size, on, - )); + )?)?; let EquivalenceProperties { - constants: left_constants, oeq_class: left_oeq_class, .. } = left; let EquivalenceProperties { - constants: right_constants, oeq_class: mut right_oeq_class, .. } = right; @@ -59,7 +57,7 @@ pub fn join_equivalence_properties( &mut right_oeq_class, join_type, left_size, - ); + )?; // Right side ordering equivalence properties should be prepended // with those of the left side while constructing output ordering @@ -70,9 +68,9 @@ pub fn join_equivalence_properties( // then we should add `a ASC, b ASC` to the ordering equivalences // of the join output. let out_oeq_class = left_oeq_class.join_suffix(&right_oeq_class); - result.add_ordering_equivalence_class(out_oeq_class); + result.add_orderings(out_oeq_class); } else { - result.add_ordering_equivalence_class(left_oeq_class); + result.add_orderings(left_oeq_class); } } [false, true] => { @@ -80,7 +78,7 @@ pub fn join_equivalence_properties( &mut right_oeq_class, join_type, left_size, - ); + )?; // In this special case, left side ordering can be prefixed with // the right side ordering. if let (Some(JoinSide::Right), JoinType::Inner) = (probe_side, join_type) { @@ -93,25 +91,16 @@ pub fn join_equivalence_properties( // then we should add `b ASC, a ASC` to the ordering equivalences // of the join output. 
let out_oeq_class = right_oeq_class.join_suffix(&left_oeq_class); - result.add_ordering_equivalence_class(out_oeq_class); + result.add_orderings(out_oeq_class); } else { - result.add_ordering_equivalence_class(right_oeq_class); + result.add_orderings(right_oeq_class); } } [false, false] => {} [true, true] => unreachable!("Cannot maintain ordering of both sides"), _ => unreachable!("Join operators can not have more than two children"), } - match join_type { - JoinType::LeftAnti | JoinType::LeftSemi => { - result = result.with_constants(left_constants); - } - JoinType::RightAnti | JoinType::RightSemi => { - result = result.with_constants(right_constants); - } - _ => {} - } - result + Ok(result) } /// In the context of a join, update the right side `OrderingEquivalenceClass` @@ -125,28 +114,29 @@ pub fn updated_right_ordering_equivalence_class( right_oeq_class: &mut OrderingEquivalenceClass, join_type: &JoinType, left_size: usize, -) { +) -> Result<()> { if matches!( join_type, JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right ) { - right_oeq_class.add_offset(left_size); + right_oeq_class.add_offset(left_size as _)?; } + Ok(()) } #[cfg(test)] mod tests { - use std::sync::Arc; use super::*; - use crate::equivalence::add_offset_to_expr; - use crate::equivalence::tests::{convert_to_orderings, create_test_schema}; + use crate::equivalence::convert_to_orderings; + use crate::equivalence::tests::create_test_schema; use crate::expressions::col; - use datafusion_common::Result; + use crate::physical_expr::add_offset_to_expr; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Fields, Schema}; + use datafusion_common::Result; #[test] fn test_join_equivalence_properties() -> Result<()> { @@ -154,9 +144,9 @@ mod tests { let col_a = &col("a", &schema)?; let col_b = &col("b", &schema)?; let col_c = &col("c", &schema)?; - let offset = schema.fields.len(); - let col_a2 = &add_offset_to_expr(Arc::clone(col_a), offset); - let col_b2 = &add_offset_to_expr(Arc::clone(col_b), offset); + let offset = schema.fields.len() as _; + let col_a2 = &add_offset_to_expr(Arc::clone(col_a), offset)?; + let col_b2 = &add_offset_to_expr(Arc::clone(col_b), offset)?; let option_asc = SortOptions { descending: false, nulls_first: false, @@ -205,8 +195,8 @@ mod tests { let left_orderings = convert_to_orderings(&left_orderings); let right_orderings = convert_to_orderings(&right_orderings); let expected = convert_to_orderings(&expected); - left_eq_properties.add_new_orderings(left_orderings); - right_eq_properties.add_new_orderings(right_orderings); + left_eq_properties.add_orderings(left_orderings); + right_eq_properties.add_orderings(right_orderings); let join_eq = join_equivalence_properties( left_eq_properties, right_eq_properties, @@ -215,7 +205,7 @@ mod tests { &[true, false], Some(JoinSide::Left), &[], - ); + )?; let err_msg = format!("expected: {:?}, actual:{:?}", expected, &join_eq.oeq_class); assert_eq!(join_eq.oeq_class.len(), expected.len(), "{err_msg}"); @@ -253,7 +243,7 @@ mod tests { ]; let orderings = convert_to_orderings(&orderings); // Right child ordering equivalences - let mut right_oeq_class = OrderingEquivalenceClass::new(orderings); + let mut right_oeq_class = OrderingEquivalenceClass::from(orderings); let left_columns_len = 4; @@ -264,24 +254,24 @@ mod tests { // Join Schema let schema = Schema::new(fields); - let col_a = &col("a", &schema)?; - let col_d = &col("d", &schema)?; - let col_x = &col("x", &schema)?; - let col_y = &col("y", &schema)?; - let col_z = &col("z", 
&schema)?; - let col_w = &col("w", &schema)?; + let col_a = col("a", &schema)?; + let col_d = col("d", &schema)?; + let col_x = col("x", &schema)?; + let col_y = col("y", &schema)?; + let col_z = col("z", &schema)?; + let col_w = col("w", &schema)?; let mut join_eq_properties = EquivalenceProperties::new(Arc::new(schema)); // a=x and d=w - join_eq_properties.add_equal_conditions(col_a, col_x)?; - join_eq_properties.add_equal_conditions(col_d, col_w)?; + join_eq_properties.add_equal_conditions(col_a, Arc::clone(&col_x))?; + join_eq_properties.add_equal_conditions(col_d, Arc::clone(&col_w))?; updated_right_ordering_equivalence_class( &mut right_oeq_class, &join_type, left_columns_len, - ); - join_eq_properties.add_ordering_equivalence_class(right_oeq_class); + )?; + join_eq_properties.add_orderings(right_oeq_class); let result = join_eq_properties.oeq_class().clone(); // [x ASC, y ASC], [z ASC, w ASC] @@ -290,7 +280,7 @@ mod tests { vec![(col_z, option_asc), (col_w, option_asc)], ]; let orderings = convert_to_orderings(&orderings); - let expected = OrderingEquivalenceClass::new(orderings); + let expected = OrderingEquivalenceClass::from(orderings); assert_eq!(result, expected); diff --git a/datafusion/physical-expr/src/equivalence/properties/mod.rs b/datafusion/physical-expr/src/equivalence/properties/mod.rs index 8f6391bc0b5e..6d18d34ca4de 100644 --- a/datafusion/physical-expr/src/equivalence/properties/mod.rs +++ b/datafusion/physical-expr/src/equivalence/properties/mod.rs @@ -19,47 +19,43 @@ mod dependency; // Submodule containing DependencyMap and Dependencies mod joins; // Submodule containing join_equivalence_properties mod union; // Submodule containing calculate_union -use dependency::{ - construct_prefix_orderings, generate_dependency_orderings, referred_dependencies, - Dependencies, DependencyMap, -}; pub use joins::*; pub use union::*; -use std::fmt::Display; -use std::hash::{Hash, Hasher}; +use std::fmt::{self, Display}; +use std::mem; use std::sync::Arc; -use std::{fmt, mem}; -use crate::equivalence::class::{const_exprs_contains, AcrossPartitions}; +use self::dependency::{ + construct_prefix_orderings, generate_dependency_orderings, referred_dependencies, + Dependencies, DependencyMap, +}; use crate::equivalence::{ - EquivalenceClass, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping, + AcrossPartitions, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping, }; use crate::expressions::{with_new_schema, CastExpr, Column, Literal}; use crate::{ - physical_exprs_contains, ConstExpr, LexOrdering, LexRequirement, PhysicalExpr, - PhysicalSortExpr, PhysicalSortRequirement, + ConstExpr, LexOrdering, LexRequirement, PhysicalExpr, PhysicalSortExpr, + PhysicalSortRequirement, }; -use arrow::compute::SortOptions; use arrow::datatypes::SchemaRef; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{plan_err, Constraint, Constraints, HashMap, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_physical_expr_common::sort_expr::options_compatible; use datafusion_physical_expr_common::utils::ExprPropertiesNode; use indexmap::IndexSet; use itertools::Itertools; -/// `EquivalenceProperties` stores information about the output -/// of a plan node, that can be used to optimize the plan. 
-/// -/// Currently, it keeps track of: -/// - Sort expressions (orderings) -/// - Equivalent expressions: expressions that are known to have same value. -/// - Constants expressions: expressions that are known to contain a single -/// constant value. +/// `EquivalenceProperties` stores information about the output of a plan node +/// that can be used to optimize the plan. Currently, it keeps track of: +/// - Sort expressions (orderings), +/// - Equivalent expressions; i.e. expressions known to have the same value. +/// - Constant expressions; i.e. expressions known to contain a single constant +/// value. /// /// Please see the [Using Ordering for Better Plans] blog for more details. /// @@ -81,8 +77,8 @@ use itertools::Itertools; /// ``` /// /// In this case, both `a ASC` and `b DESC` can describe the table ordering. -/// `EquivalenceProperties`, tracks these different valid sort expressions and -/// treat `a ASC` and `b DESC` on an equal footing. For example if the query +/// `EquivalenceProperties` tracks these different valid sort expressions and +/// treats `a ASC` and `b DESC` on an equal footing. For example, if the query /// specifies the output sorted by EITHER `a ASC` or `b DESC`, the sort can be /// avoided. /// @@ -101,12 +97,11 @@ use itertools::Itertools; /// └---┴---┘ /// ``` /// -/// In this case, columns `a` and `b` always have the same value, which can of -/// such equivalences inside this object. With this information, Datafusion can -/// optimize operations such as. For example, if the partition requirement is -/// `Hash(a)` and output partitioning is -/// `Hash(b)`, then DataFusion avoids -/// repartitioning the data as the existing partitioning satisfies the -/// requirement. +/// In this case, columns `a` and `b` always have the same value. With this +/// information, DataFusion can optimize various operations. For example, if +/// the partition requirement is `Hash(a)` and output partitioning is +/// `Hash(b)`, then DataFusion avoids repartitioning the data as the existing +/// partitioning satisfies the requirement. /// /// # Code Example /// ``` @@ -125,40 +120,85 @@ use itertools::Itertools; /// # let col_c = col("c", &schema).unwrap(); /// // This object represents data that is sorted by a ASC, c DESC /// // with a single constant value of b -/// let mut eq_properties = EquivalenceProperties::new(schema) -/// .with_constants(vec![ConstExpr::from(col_b)]); -/// eq_properties.add_new_ordering(LexOrdering::new(vec![ +/// let mut eq_properties = EquivalenceProperties::new(schema); +/// eq_properties.add_constants(vec![ConstExpr::from(col_b)]); +/// eq_properties.add_ordering([ /// PhysicalSortExpr::new_default(col_a).asc(), /// PhysicalSortExpr::new_default(col_c).desc(), -/// ])); +/// ]); /// -/// assert_eq!(eq_properties.to_string(), "order: [[a@0 ASC, c@2 DESC]], const: [b@1(heterogeneous)]") +/// assert_eq!(eq_properties.to_string(), "order: [[a@0 ASC, c@2 DESC]], eq: [{members: [b@1], constant: (heterogeneous)}]"); /// ``` -#[derive(Debug, Clone)] +#[derive(Clone, Debug)] pub struct EquivalenceProperties { - /// Distinct equivalence classes (exprs known to have the same expressions) + /// Distinct equivalence classes (i.e. expressions with the same value). eq_group: EquivalenceGroup, - /// Equivalent sort expressions + /// Equivalent sort expressions (i.e. those that define the same ordering).
oeq_class: OrderingEquivalenceClass, - /// Expressions whose values are constant - /// - /// TODO: We do not need to track constants separately, they can be tracked - /// inside `eq_group` as `Literal` expressions. - constants: Vec<ConstExpr>, - /// Table constraints + /// Cache storing equivalent sort expressions in normal form (i.e. without + /// constants/duplicates and in standard form) and a map associating leading + /// terms with full sort expressions. + oeq_cache: OrderingEquivalenceCache, + /// Table constraints that factor in equivalence calculations. constraints: Constraints, /// Schema associated with this object. schema: SchemaRef, } +/// This object serves as a cache for storing equivalent sort expressions +/// in normal form, and a map associating leading sort expressions with +/// full lexicographical orderings. With this information, DataFusion can +/// efficiently determine whether a given ordering is satisfied by the +/// existing orderings, and discover new orderings based on the existing +/// equivalence properties. +#[derive(Clone, Debug, Default)] +struct OrderingEquivalenceCache { +    /// Equivalent sort expressions in normal form. +    normal_cls: OrderingEquivalenceClass, +    /// Map associating leading sort expressions with full lexicographical +    /// orderings. Values are indices into `normal_cls`. +    leading_map: HashMap<Arc<dyn PhysicalExpr>, Vec<usize>>, +} + +impl OrderingEquivalenceCache { +    /// Creates a new `OrderingEquivalenceCache` object with the given +    /// equivalent orderings, which should be in normal form. +    pub fn new( +        orderings: impl IntoIterator<Item: IntoIterator<Item = PhysicalSortExpr>>, +    ) -> Self { +        let mut cache = Self { +            normal_cls: OrderingEquivalenceClass::new(orderings), +            leading_map: HashMap::new(), +        }; +        cache.update_map(); +        cache +    } + +    /// Updates/reconstructs the leading expression map according to the normal +    /// ordering equivalence class within. +    pub fn update_map(&mut self) { +        self.leading_map.clear(); +        for (idx, ordering) in self.normal_cls.iter().enumerate() { +            let expr = Arc::clone(&ordering.first().expr); +            self.leading_map.entry(expr).or_default().push(idx); +        } +    } + +    /// Clears the cache, removing all orderings and leading expressions. +    pub fn clear(&mut self) { +        self.normal_cls.clear(); +        self.leading_map.clear(); +    } +} + impl EquivalenceProperties { /// Creates an empty `EquivalenceProperties` object. pub fn new(schema: SchemaRef) -> Self { Self { - eq_group: EquivalenceGroup::empty(), - oeq_class: OrderingEquivalenceClass::empty(), - constants: vec![], - constraints: Constraints::empty(), + eq_group: EquivalenceGroup::default(), + oeq_class: OrderingEquivalenceClass::default(), + oeq_cache: OrderingEquivalenceCache::default(), + constraints: Constraints::default(), schema, } } @@ -170,12 +210,23 @@ } /// Creates a new `EquivalenceProperties` object with the given orderings. - pub fn new_with_orderings(schema: SchemaRef, orderings: &[LexOrdering]) -> Self { + pub fn new_with_orderings( + schema: SchemaRef, + orderings: impl IntoIterator<Item: IntoIterator<Item = PhysicalSortExpr>>, + ) -> Self { + let eq_group = EquivalenceGroup::default(); + let oeq_class = OrderingEquivalenceClass::new(orderings); + // Here, we can avoid performing a full normalization, and get by with + // only removing constants because the equivalence group is empty.
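For intuition, the leading-map bookkeeping that `update_map` performs above can be modeled with a toy standalone example (string column names stand in for `Arc<dyn PhysicalExpr>`; this is a sketch, not the crate's code):

```
use std::collections::HashMap;

fn main() {
    // Normal-form orderings, keyed here by column name.
    let normal_cls = vec![vec!["a", "b", "c"], vec!["a", "d"], vec!["e"]];

    // Each ordering's first expression points at the indices of the
    // orderings it leads, mirroring `update_map` above.
    let mut leading_map: HashMap<&str, Vec<usize>> = HashMap::new();
    for (idx, ordering) in normal_cls.iter().enumerate() {
        leading_map.entry(ordering[0]).or_default().push(idx);
    }

    // A lookup on "a" yields both candidate orderings in one step.
    assert_eq!(leading_map["a"], vec![0, 1]);
    assert_eq!(leading_map["e"], vec![2]);
}
```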
+ let normal_orderings = oeq_class.iter().cloned().map(|o| { + o.into_iter() + .filter(|sort_expr| eq_group.is_expr_constant(&sort_expr.expr).is_none()) + }); Self { - eq_group: EquivalenceGroup::empty(), - oeq_class: OrderingEquivalenceClass::new(orderings.to_vec()), - constants: vec![], - constraints: Constraints::empty(), + oeq_cache: OrderingEquivalenceCache::new(normal_orderings), + oeq_class, + eq_group, + constraints: Constraints::default(), schema, } } @@ -190,91 +241,126 @@ &self.oeq_class } - /// Return the inner OrderingEquivalenceClass, consuming self - pub fn into_oeq_class(self) -> OrderingEquivalenceClass { - self.oeq_class - } - /// Returns a reference to the equivalence group within. pub fn eq_group(&self) -> &EquivalenceGroup { &self.eq_group } - /// Returns a reference to the constant expressions - pub fn constants(&self) -> &[ConstExpr] { - &self.constants - } - + /// Returns a reference to the constraints within. pub fn constraints(&self) -> &Constraints { &self.constraints } - /// Returns the output ordering of the properties. - pub fn output_ordering(&self) -> Option<LexOrdering> { - let constants = self.constants(); - let mut output_ordering = self.oeq_class().output_ordering().unwrap_or_default(); - // Prune out constant expressions - output_ordering - .retain(|sort_expr| !const_exprs_contains(constants, &sort_expr.expr)); - (!output_ordering.is_empty()).then_some(output_ordering) + /// Returns all the known constant expressions. + pub fn constants(&self) -> Vec<ConstExpr> { + self.eq_group + .iter() + .filter_map(|c| { + c.constant.as_ref().and_then(|across| { + c.canonical_expr() + .map(|expr| ConstExpr::new(Arc::clone(expr), across.clone())) + }) + }) + .collect() } - /// Returns the normalized version of the ordering equivalence class within. - /// Normalization removes constants and duplicates as well as standardizing - /// expressions according to the equivalence group within. - pub fn normalized_oeq_class(&self) -> OrderingEquivalenceClass { - OrderingEquivalenceClass::new( - self.oeq_class - .iter() - .map(|ordering| self.normalize_sort_exprs(ordering)) - .collect(), - ) + /// Returns the output ordering of the properties. + pub fn output_ordering(&self) -> Option<LexOrdering> { + let concat = self.oeq_class.iter().flat_map(|o| o.iter().cloned()); + self.normalize_sort_exprs(concat) } /// Extends this `EquivalenceProperties` with the `other` object. - pub fn extend(mut self, other: Self) -> Self { - self.eq_group.extend(other.eq_group); - self.oeq_class.extend(other.oeq_class); - self.with_constants(other.constants) + pub fn extend(mut self, other: Self) -> Result<Self> { + self.constraints.extend(other.constraints); + self.add_equivalence_group(other.eq_group)?; + self.add_orderings(other.oeq_class); + Ok(self) } /// Clears (empties) the ordering equivalence class within this object. /// Call this method when existing orderings are invalidated. pub fn clear_orderings(&mut self) { self.oeq_class.clear(); + self.oeq_cache.clear(); } /// Removes constant expressions that may change across partitions. - /// This method should be used when data from different partitions are merged. + /// This method should be used when merging data from different partitions. pub fn clear_per_partition_constants(&mut self) { - self.constants.retain(|item| { - matches!(item.across_partitions(), AcrossPartitions::Uniform(_)) - }) - } - - /// Extends this `EquivalenceProperties` by adding the orderings inside the - /// ordering equivalence class `other`.
- pub fn add_ordering_equivalence_class(&mut self, other: OrderingEquivalenceClass) { - self.oeq_class.extend(other); + if self.eq_group.clear_per_partition_constants() { + // Renormalize orderings if the equivalence group changes: + let normal_orderings = self + .oeq_class + .iter() + .cloned() + .map(|o| self.eq_group.normalize_sort_exprs(o)); + self.oeq_cache = OrderingEquivalenceCache::new(normal_orderings); + } } /// Adds new orderings into the existing ordering equivalence class. - pub fn add_new_orderings( + pub fn add_orderings( &mut self, - orderings: impl IntoIterator<Item = LexOrdering>, + orderings: impl IntoIterator<Item: IntoIterator<Item = PhysicalSortExpr>>, ) { - self.oeq_class.add_new_orderings(orderings); + let orderings: Vec<_> = + orderings.into_iter().filter_map(LexOrdering::new).collect(); + let normal_orderings: Vec<_> = orderings + .iter() + .cloned() + .filter_map(|o| self.normalize_sort_exprs(o)) + .collect(); + if !normal_orderings.is_empty() { + self.oeq_class.extend(orderings); + // Normalize given orderings to update the cache: + self.oeq_cache.normal_cls.extend(normal_orderings); + // TODO: If no ordering is found to be redundant during extension, we + // can use a shortcut algorithm to update the leading map. + self.oeq_cache.update_map(); + } } /// Adds a single ordering to the existing ordering equivalence class. - pub fn add_new_ordering(&mut self, ordering: LexOrdering) { - self.add_new_orderings([ordering]); + pub fn add_ordering(&mut self, ordering: impl IntoIterator<Item = PhysicalSortExpr>) { + self.add_orderings(std::iter::once(ordering)); } /// Incorporates the given equivalence group into the existing /// equivalence group within. - pub fn add_equivalence_group(&mut self, other_eq_group: EquivalenceGroup) { - self.eq_group.extend(other_eq_group); + pub fn add_equivalence_group( + &mut self, + other_eq_group: EquivalenceGroup, + ) -> Result<()> { + if !other_eq_group.is_empty() { + self.eq_group.extend(other_eq_group); + // Renormalize orderings if the equivalence group changes: + let normal_cls = mem::take(&mut self.oeq_cache.normal_cls); + let normal_orderings = normal_cls + .into_iter() + .map(|o| self.eq_group.normalize_sort_exprs(o)); + self.oeq_cache.normal_cls = OrderingEquivalenceClass::new(normal_orderings); + self.oeq_cache.update_map(); + // Discover any new orderings based on the new equivalence classes: + let leading_exprs: Vec<_> = + self.oeq_cache.leading_map.keys().cloned().collect(); + for expr in leading_exprs { + self.discover_new_orderings(expr)?; + } + } + Ok(()) + } + + /// Returns the ordering equivalence class within in normal form. + /// Normalization standardizes expressions according to the equivalence + /// group within, and removes constants/duplicates. + pub fn normalized_oeq_class(&self) -> OrderingEquivalenceClass { + self.oeq_class + .iter() + .cloned() + .filter_map(|ordering| self.normalize_sort_exprs(ordering)) + .collect::<Vec<_>>() + .into() } /// Adds a new equality condition into the existing equivalence group. @@ -282,290 +368,236 @@ /// equivalence class to the equivalence group.
pub fn add_equal_conditions( &mut self, - left: &Arc<dyn PhysicalExpr>, - right: &Arc<dyn PhysicalExpr>, + left: Arc<dyn PhysicalExpr>, + right: Arc<dyn PhysicalExpr>, ) -> Result<()> { - // Discover new constants in light of new the equality: - if self.is_expr_constant(left) { - // Left expression is constant, add right as constant - if !const_exprs_contains(&self.constants, right) { - let const_expr = ConstExpr::from(right) - .with_across_partitions(self.get_expr_constant_value(left)); - self.constants.push(const_expr); - } - } else if self.is_expr_constant(right) { - // Right expression is constant, add left as constant - if !const_exprs_contains(&self.constants, left) { - let const_expr = ConstExpr::from(left) - .with_across_partitions(self.get_expr_constant_value(right)); - self.constants.push(const_expr); - } + // Add equal expressions to the state: + if self.eq_group.add_equal_conditions(Arc::clone(&left), right) { + // Renormalize orderings if the equivalence group changes: + let normal_cls = mem::take(&mut self.oeq_cache.normal_cls); + let normal_orderings = normal_cls + .into_iter() + .map(|o| self.eq_group.normalize_sort_exprs(o)); + self.oeq_cache.normal_cls = OrderingEquivalenceClass::new(normal_orderings); + self.oeq_cache.update_map(); + // Discover any new orderings: + self.discover_new_orderings(left)?; } - - // Add equal expressions to the state - self.eq_group.add_equal_conditions(left, right); - - // Discover any new orderings - self.discover_new_orderings(left)?; Ok(()) } /// Track/register physical expressions with constant values. - #[deprecated(since = "43.0.0", note = "Use [`with_constants`] instead")] - pub fn add_constants(self, constants: impl IntoIterator<Item = ConstExpr>) -> Self { - self.with_constants(constants) - } - - /// Remove the specified constant - pub fn remove_constant(mut self, c: &ConstExpr) -> Self { - self.constants.retain(|existing| existing != c); - self - } - - /// Track/register physical expressions with constant values.
- pub fn with_constants( - mut self, + pub fn add_constants( + &mut self, constants: impl IntoIterator<Item = ConstExpr>, - ) -> Self { - let normalized_constants = constants - .into_iter() - .filter_map(|c| { - let across_partitions = c.across_partitions(); - let expr = c.owned_expr(); - let normalized_expr = self.eq_group.normalize_expr(expr); - - if const_exprs_contains(&self.constants, &normalized_expr) { - return None; - } - - let const_expr = ConstExpr::from(normalized_expr) - .with_across_partitions(across_partitions); - - Some(const_expr) + ) -> Result<()> { + // Add the new constant to the equivalence group: + for constant in constants { + self.eq_group.add_constant(constant); + } + // Renormalize the orderings after adding new constants by removing + // the constants from existing orderings: + let normal_cls = mem::take(&mut self.oeq_cache.normal_cls); + let normal_orderings = normal_cls.into_iter().map(|ordering| { + ordering.into_iter().filter(|sort_expr| { + self.eq_group.is_expr_constant(&sort_expr.expr).is_none() }) - .collect::<Vec<_>>(); - - // Add all new normalized constants - self.constants.extend(normalized_constants); - - // Discover any new orderings based on the constants - for ordering in self.normalized_oeq_class().iter() { - if let Err(e) = self.discover_new_orderings(&ordering[0].expr) { - log::debug!("error discovering new orderings: {e}"); - } + }); + self.oeq_cache.normal_cls = OrderingEquivalenceClass::new(normal_orderings); + self.oeq_cache.update_map(); + // Discover any new orderings based on the constants: + let leading_exprs: Vec<_> = self.oeq_cache.leading_map.keys().cloned().collect(); + for expr in leading_exprs { + self.discover_new_orderings(expr)?; } - - self + Ok(()) } - // Discover new valid orderings in light of a new equality. - // Accepts a single argument (`expr`) which is used to determine - // which orderings should be updated. - // When constants or equivalence classes are changed, there may be new orderings - // that can be discovered with the new equivalence properties. - // For a discussion, see: https://github.com/apache/datafusion/issues/9812 - fn discover_new_orderings(&mut self, expr: &Arc<dyn PhysicalExpr>) -> Result<()> { - let normalized_expr = self.eq_group().normalize_expr(Arc::clone(expr)); + /// Discover new valid orderings in light of a new equality. Accepts a single + /// argument (`expr`) which is used to determine the orderings to update. + /// When constants or equivalence classes change, there may be new orderings + /// that can be discovered with the new equivalence properties.
+ /// For a discussion, see: <https://github.com/apache/datafusion/issues/9812> + fn discover_new_orderings( + &mut self, + normal_expr: Arc<dyn PhysicalExpr>, + ) -> Result<()> { + let Some(ordering_idxs) = self.oeq_cache.leading_map.get(&normal_expr) else { + return Ok(()); + }; let eq_class = self .eq_group - .iter() - .find_map(|class| { - class - .contains(&normalized_expr) - .then(|| class.clone().into_vec()) - }) - .unwrap_or_else(|| vec![Arc::clone(&normalized_expr)]); - - let mut new_orderings: Vec<LexOrdering> = vec![]; - for ordering in self.normalized_oeq_class().iter() { - if !ordering[0].expr.eq(&normalized_expr) { - continue; - } + .get_equivalence_class(&normal_expr) + .map_or_else(|| vec![normal_expr], |class| class.clone().into()); + let mut new_orderings = vec![]; + for idx in ordering_idxs { + let ordering = &self.oeq_cache.normal_cls[*idx]; let leading_ordering_options = ordering[0].options; - for equivalent_expr in &eq_class { + 'exprs: for equivalent_expr in &eq_class { let children = equivalent_expr.children(); if children.is_empty() { continue; } - - // Check if all children match the next expressions in the ordering - let mut all_children_match = true; + // Check if all children match the next expressions in the ordering: let mut child_properties = vec![]; - - // Build properties for each child based on the next expressions - for (i, child) in children.iter().enumerate() { - if let Some(next) = ordering.get(i + 1) { - if !child.as_ref().eq(next.expr.as_ref()) { - all_children_match = false; - break; - } - child_properties.push(ExprProperties { - sort_properties: SortProperties::Ordered(next.options), - range: Interval::make_unbounded( - &child.data_type(&self.schema)?, - )?, - preserves_lex_ordering: true, - }); - } else { - all_children_match = false; - break; + // Build properties for each child based on the next expression: + for (i, child) in children.into_iter().enumerate() { + let Some(next) = ordering.get(i + 1) else { + break 'exprs; + }; + if !next.expr.eq(child) { + break 'exprs; } + let data_type = child.data_type(&self.schema)?; + child_properties.push(ExprProperties { + sort_properties: SortProperties::Ordered(next.options), + range: Interval::make_unbounded(&data_type)?, + preserves_lex_ordering: true, + }); } - - if all_children_match { - // Check if the expression is monotonic in all arguments - if let Ok(expr_properties) = - equivalent_expr.get_properties(&child_properties) - { - if expr_properties.preserves_lex_ordering - && SortProperties::Ordered(leading_ordering_options) - == expr_properties.sort_properties - { - // Assume existing ordering is [c ASC, a ASC, b ASC] - // When equality c = f(a,b) is given, if we know that given ordering `[a ASC, b ASC]`, - // ordering `[f(a,b) ASC]` is valid, then we can deduce that ordering `[a ASC, b ASC]` is also valid. - // Hence, ordering `[a ASC, b ASC]` can be added to the state as a valid ordering. - // (e.g. existing ordering where leading ordering is removed) - new_orderings.push(LexOrdering::new(ordering[1..].to_vec())); - break; - } - } + // Check if the expression is monotonic in all arguments: + let expr_properties = equivalent_expr.get_properties(&child_properties)?; + if expr_properties.preserves_lex_ordering + && expr_properties.sort_properties + == SortProperties::Ordered(leading_ordering_options) + { + // Assume that `[c ASC, a ASC, b ASC]` is among existing + // orderings. If equality `c = f(a, b)` is given, ordering + // `[a ASC, b ASC]` implies the ordering `[c ASC]`. Thus, + // ordering `[a ASC, b ASC]` is also a valid ordering.
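To make the deduction above concrete, here is a toy model of this step (strings stand in for expressions, and the monotonicity check is stubbed out; the real code consults `get_properties`):

```
/// Stand-in for the monotonicity check; the actual implementation derives
/// this from the expression's properties (`preserves_lex_ordering`).
fn preserves_lex_ordering(_expr: &str) -> bool {
    true // assumed for this sketch
}

fn main() {
    // Existing normal-form ordering [c ASC, a ASC, b ASC], plus the
    // equality c = f(a, b).
    let ordering = ["c", "a", "b"];
    let (lhs, args) = ("c", ["a", "b"]);

    // `f`'s arguments must line up with the expressions right after `c`.
    if ordering[0] == lhs && ordering[1..] == args && preserves_lex_ordering("f(a, b)") {
        // Then the tail [a ASC, b ASC] is itself a valid ordering.
        let new_ordering = ordering[1..].to_vec();
        assert_eq!(new_ordering, vec!["a", "b"]);
    }
}
```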
+ new_orderings.push(ordering[1..].to_vec()); + break; } } } - self.oeq_class.add_new_orderings(new_orderings); - Ok(()) - } - - /// Updates the ordering equivalence group within assuming that the table - /// is re-sorted according to the argument `sort_exprs`. Note that constants - /// and equivalence classes are unchanged as they are unaffected by a re-sort. - /// If the given ordering is already satisfied, the function does nothing. - pub fn with_reorder(mut self, sort_exprs: LexOrdering) -> Self { - // Filter out constant expressions as they don't affect ordering - let filtered_exprs = LexOrdering::new( - sort_exprs - .into_iter() - .filter(|expr| !self.is_expr_constant(&expr.expr)) - .collect(), - ); - - if filtered_exprs.is_empty() { - return self; - } - - let mut new_orderings = vec![filtered_exprs.clone()]; - - // Preserve valid suffixes from existing orderings - let oeq_class = mem::take(&mut self.oeq_class); - for existing in oeq_class { - if self.is_prefix_of(&filtered_exprs, &existing) { - let mut extended = filtered_exprs.clone(); - extended.extend(existing.into_iter().skip(filtered_exprs.len())); - new_orderings.push(extended); - } + if !new_orderings.is_empty() { + self.add_orderings(new_orderings); } - - self.oeq_class = OrderingEquivalenceClass::new(new_orderings); - self + Ok(()) } - /// Checks if the new ordering matches a prefix of the existing ordering - /// (considering expression equivalences) - fn is_prefix_of(&self, new_order: &LexOrdering, existing: &LexOrdering) -> bool { - // Check if new order is longer than existing - can't be a prefix - if new_order.len() > existing.len() { - return false; + /// Updates the ordering equivalence class within assuming that the table + /// is re-sorted according to the argument `ordering`, and returns whether + /// this operation resulted in any change. Note that equivalence classes + /// (and constants) do not change as they are unaffected by a re-sort. If + /// the given ordering is already satisfied, the function does nothing. + pub fn reorder( + &mut self, + ordering: impl IntoIterator<Item = PhysicalSortExpr>, + ) -> Result<bool> { + let (ordering, ordering_tee) = ordering.into_iter().tee(); + // First, standardize the given ordering: + let Some(normal_ordering) = self.normalize_sort_exprs(ordering) else { + // If the ordering vanishes after normalization, it is satisfied: + return Ok(false); + }; + if normal_ordering.len() != self.common_sort_prefix_length(&normal_ordering)? { + // If the ordering is unsatisfied, replace existing orderings: + self.clear_orderings(); + self.add_ordering(ordering_tee); + return Ok(true); } - - // Check if new order matches existing prefix (considering equivalences) - new_order.iter().zip(existing).all(|(new, existing)| { - self.eq_group.exprs_equal(&new.expr, &existing.expr) - && new.options == existing.options - }) + Ok(false) } /// Normalizes the given sort expressions (i.e. `sort_exprs`) using the - /// equivalence group and the ordering equivalence class within. - /// - /// Assume that `self.eq_group` states column `a` and `b` are aliases. - /// Also assume that `self.oeq_class` states orderings `d ASC` and `a ASC, c ASC` - /// are equivalent (in the sense that both describe the ordering of the table). - /// If the `sort_exprs` argument were `vec![b ASC, c ASC, a ASC]`, then this - /// function would return `vec![a ASC, c ASC]`. Internally, it would first - /// normalize to `vec![a ASC, c ASC, a ASC]` and end up with the final result - /// after deduplication.
- fn normalize_sort_exprs(&self, sort_exprs: &LexOrdering) -> LexOrdering { - // Convert sort expressions to sort requirements: - let sort_reqs = LexRequirement::from(sort_exprs.clone()); - // Normalize the requirements: - let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs); - // Convert sort requirements back to sort expressions: - LexOrdering::from(normalized_sort_reqs) + /// equivalence group within. Returns a `LexOrdering` instance if the + /// expressions define a proper lexicographical ordering. For more details, + /// see [`EquivalenceGroup::normalize_sort_exprs`]. + pub fn normalize_sort_exprs( + &self, + sort_exprs: impl IntoIterator, + ) -> Option { + LexOrdering::new(self.eq_group.normalize_sort_exprs(sort_exprs)) } /// Normalizes the given sort requirements (i.e. `sort_reqs`) using the - /// equivalence group and the ordering equivalence class within. It works by: - /// - Removing expressions that have a constant value from the given requirement. - /// - Replacing sections that belong to some equivalence class in the equivalence - /// group with the first entry in the matching equivalence class. - /// - /// Assume that `self.eq_group` states column `a` and `b` are aliases. - /// Also assume that `self.oeq_class` states orderings `d ASC` and `a ASC, c ASC` - /// are equivalent (in the sense that both describe the ordering of the table). - /// If the `sort_reqs` argument were `vec![b ASC, c ASC, a ASC]`, then this - /// function would return `vec![a ASC, c ASC]`. Internally, it would first - /// normalize to `vec![a ASC, c ASC, a ASC]` and end up with the final result - /// after deduplication. - fn normalize_sort_requirements(&self, sort_reqs: &LexRequirement) -> LexRequirement { - let normalized_sort_reqs = self.eq_group.normalize_sort_requirements(sort_reqs); - let mut constant_exprs = vec![]; - constant_exprs.extend( - self.constants - .iter() - .map(|const_expr| Arc::clone(const_expr.expr())), - ); - let constants_normalized = self.eq_group.normalize_exprs(constant_exprs); - // Prune redundant sections in the requirement: - normalized_sort_reqs - .iter() - .filter(|&order| !physical_exprs_contains(&constants_normalized, &order.expr)) - .cloned() - .collect::() - .collapse() + /// equivalence group within. Returns a `LexRequirement` instance if the + /// expressions define a proper lexicographical requirement. For more + /// details, see [`EquivalenceGroup::normalize_sort_exprs`]. + pub fn normalize_sort_requirements( + &self, + sort_reqs: impl IntoIterator, + ) -> Option { + LexRequirement::new(self.eq_group.normalize_sort_requirements(sort_reqs)) } - /// Checks whether the given ordering is satisfied by any of the existing - /// orderings. - pub fn ordering_satisfy(&self, given: &LexOrdering) -> bool { - // Convert the given sort expressions to sort requirements: - let sort_requirements = LexRequirement::from(given.clone()); - self.ordering_satisfy_requirement(&sort_requirements) + /// Iteratively checks whether the given ordering is satisfied by any of + /// the existing orderings. See [`Self::ordering_satisfy_requirement`] for + /// more details and examples. + pub fn ordering_satisfy( + &self, + given: impl IntoIterator, + ) -> Result { + // First, standardize the given ordering: + let Some(normal_ordering) = self.normalize_sort_exprs(given) else { + // If the ordering vanishes after normalization, it is satisfied: + return Ok(true); + }; + Ok(normal_ordering.len() == self.common_sort_prefix_length(&normal_ordering)?) 
} - /// Returns the number of consecutive requirements (starting from the left) - /// that are satisfied by the plan ordering. - fn compute_common_sort_prefix_length( + /// Iteratively checks whether the given sort requirement is satisfied by + /// any of the existing orderings. + /// + /// ### Example Scenarios + /// + /// In these scenarios, assume that all expressions share the same sort + /// properties. + /// + /// #### Case 1: Sort Requirement `[a, c]` + /// + /// **Existing orderings:** `[[a, b, c], [a, d]]`, **constants:** `[]` + /// 1. The function first checks the leading requirement `a`, which is + /// satisfied by `[a, b, c].first()`. + /// 2. `a` is added as a constant for the next iteration. + /// 3. Normal orderings become `[[b, c], [d]]`. + /// 4. The function fails for `c` in the second iteration, as neither + /// `[b, c]` nor `[d]` satisfies `c`. + /// + /// #### Case 2: Sort Requirement `[a, d]` + /// + /// **Existing orderings:** `[[a, b, c], [a, d]]`, **constants:** `[]` + /// 1. The function first checks the leading requirement `a`, which is + /// satisfied by `[a, b, c].first()`. + /// 2. `a` is added as a constant for the next iteration. + /// 3. Normal orderings become `[[b, c], [d]]`. + /// 4. The function returns `true` as `[d]` satisfies `d`. + pub fn ordering_satisfy_requirement( &self, - normalized_reqs: &LexRequirement, - ) -> usize { - // Check whether given ordering is satisfied by constraints first - if self.satisfied_by_constraints(normalized_reqs) { - // If the constraints satisfy all requirements, return the full normalized requirements length - return normalized_reqs.len(); + given: impl IntoIterator, + ) -> Result { + // First, standardize the given requirement: + let Some(normal_reqs) = self.normalize_sort_requirements(given) else { + // If the requirement vanishes after normalization, it is satisfied: + return Ok(true); + }; + // Then, check whether given requirement is satisfied by constraints: + if self.satisfied_by_constraints(&normal_reqs) { + return Ok(true); } - + let schema = self.schema(); let mut eq_properties = self.clone(); - - for (i, normalized_req) in normalized_reqs.iter().enumerate() { - // Check whether given ordering is satisfied - if !eq_properties.ordering_satisfy_single(normalized_req) { - // As soon as one requirement is not satisfied, return - // how many we've satisfied so far - return i; + for element in normal_reqs { + // Check whether given requirement is satisfied: + let ExprProperties { + sort_properties, .. + } = eq_properties.get_expr_properties(Arc::clone(&element.expr)); + let satisfy = match sort_properties { + SortProperties::Ordered(options) => element.options.is_none_or(|opts| { + let nullable = element.expr.nullable(schema).unwrap_or(true); + options_compatible(&options, &opts, nullable) + }), + // Singleton expressions satisfy any requirement. + SortProperties::Singleton => true, + SortProperties::Unordered => false, + }; + if !satisfy { + return Ok(false); } // Treat satisfied keys as constants in subsequent iterations. We // can do this because the "next" key only matters in a lexicographical @@ -579,288 +611,263 @@ impl EquivalenceProperties { // From the analysis above, we know that `[a ASC]` is satisfied. Then, // we add column `a` as constant to the algorithm state. This enables us // to deduce that `(b + c) ASC` is satisfied, given `a` is constant. 
- eq_properties = eq_properties.with_constants(std::iter::once( - ConstExpr::from(Arc::clone(&normalized_req.expr)), - )); + let const_expr = ConstExpr::from(element.expr); + eq_properties.add_constants(std::iter::once(const_expr))?; } + Ok(true) + } - // All requirements are satisfied. - normalized_reqs.len() + /// Returns the number of consecutive sort expressions (starting from the + /// left) that are satisfied by the existing ordering. + fn common_sort_prefix_length(&self, normal_ordering: &LexOrdering) -> Result { + let full_length = normal_ordering.len(); + // Check whether the given ordering is satisfied by constraints: + if self.satisfied_by_constraints_ordering(normal_ordering) { + // If constraints satisfy all sort expressions, return the full + // length: + return Ok(full_length); + } + let schema = self.schema(); + let mut eq_properties = self.clone(); + for (idx, element) in normal_ordering.into_iter().enumerate() { + // Check whether given ordering is satisfied: + let ExprProperties { + sort_properties, .. + } = eq_properties.get_expr_properties(Arc::clone(&element.expr)); + let satisfy = match sort_properties { + SortProperties::Ordered(options) => options_compatible( + &options, + &element.options, + element.expr.nullable(schema).unwrap_or(true), + ), + // Singleton expressions satisfy any ordering. + SortProperties::Singleton => true, + SortProperties::Unordered => false, + }; + if !satisfy { + // As soon as one sort expression is unsatisfied, return how + // many we've satisfied so far: + return Ok(idx); + } + // Treat satisfied keys as constants in subsequent iterations. We + // can do this because the "next" key only matters in a lexicographical + // ordering when the keys to its left have the same values. + // + // Note that these expressions are not properly "constants". This is just + // an implementation strategy confined to this function. + // + // For example, assume that the requirement is `[a ASC, (b + c) ASC]`, + // and existing equivalent orderings are `[a ASC, b ASC]` and `[c ASC]`. + // From the analysis above, we know that `[a ASC]` is satisfied. Then, + // we add column `a` as constant to the algorithm state. This enables us + // to deduce that `(b + c) ASC` is satisfied, given `a` is constant. + let const_expr = ConstExpr::from(Arc::clone(&element.expr)); + eq_properties.add_constants(std::iter::once(const_expr))? + } + // All sort expressions are satisfied, return full length: + Ok(full_length) } - /// Determines the longest prefix of `reqs` that is satisfied by the existing ordering. - /// Returns that prefix as a new `LexRequirement`, and a boolean indicating if all the requirements are satisfied. + /// Determines the longest normal prefix of `ordering` satisfied by the + /// existing ordering. Returns that prefix as a new `LexOrdering`, and a + /// boolean indicating whether all the sort expressions are satisfied. pub fn extract_common_sort_prefix( &self, - reqs: &LexRequirement, - ) -> (LexRequirement, bool) { - // First, standardize the given requirement: - let normalized_reqs = self.normalize_sort_requirements(reqs); - - let prefix_len = self.compute_common_sort_prefix_length(&normalized_reqs); - ( - LexRequirement::new(normalized_reqs[..prefix_len].to_vec()), - prefix_len == normalized_reqs.len(), - ) - } - - /// Checks whether the given sort requirements are satisfied by any of the - /// existing orderings. 
-    pub fn ordering_satisfy_requirement(&self, reqs: &LexRequirement) -> bool {
-        self.extract_common_sort_prefix(reqs).1
+        ordering: LexOrdering,
+    ) -> Result<(Vec<PhysicalSortExpr>, bool)> {
+        // First, standardize the given ordering:
+        let Some(normal_ordering) = self.normalize_sort_exprs(ordering) else {
+            // If the ordering vanishes after normalization, it is satisfied:
+            return Ok((vec![], true));
+        };
+        let prefix_len = self.common_sort_prefix_length(&normal_ordering)?;
+        let flag = prefix_len == normal_ordering.len();
+        let mut sort_exprs: Vec<_> = normal_ordering.into();
+        if !flag {
+            sort_exprs.truncate(prefix_len);
+        }
+        Ok((sort_exprs, flag))
     }

-    /// Checks if the sort requirements are satisfied by any of the table constraints (primary key or unique).
-    /// Returns true if any constraint fully satisfies the requirements.
-    fn satisfied_by_constraints(
+    /// Checks if the sort expressions are satisfied by any of the table
+    /// constraints (primary key or unique). Returns true if any constraint
+    /// fully satisfies the expressions (i.e. constraint indices form a valid
+    /// prefix of an existing ordering that matches the expressions). For
+    /// unique constraints, also verifies nullable columns.
+    fn satisfied_by_constraints_ordering(
         &self,
-        normalized_reqs: &[PhysicalSortRequirement],
+        normal_exprs: &[PhysicalSortExpr],
     ) -> bool {
         self.constraints.iter().any(|constraint| match constraint {
-            Constraint::PrimaryKey(indices) | Constraint::Unique(indices) => self
-                .satisfied_by_constraint(
-                    normalized_reqs,
-                    indices,
-                    matches!(constraint, Constraint::Unique(_)),
-                ),
-        })
-    }
-
-    /// Checks if sort requirements are satisfied by a constraint (primary key or unique).
-    /// Returns true if the constraint indices form a valid prefix of an existing ordering
-    /// that matches the requirements. For unique constraints, also verifies nullable columns.
- fn satisfied_by_constraint( - &self, - normalized_reqs: &[PhysicalSortRequirement], - indices: &[usize], - check_null: bool, - ) -> bool { - // Requirements must contain indices - if indices.len() > normalized_reqs.len() { - return false; - } - - // Iterate over all orderings - self.oeq_class.iter().any(|ordering| { - if indices.len() > ordering.len() { - return false; - } - - // Build a map of column positions in the ordering - let mut col_positions = HashMap::with_capacity(ordering.len()); - for (pos, req) in ordering.iter().enumerate() { - if let Some(col) = req.expr.as_any().downcast_ref::() { - col_positions.insert( - col.index(), - (pos, col.nullable(&self.schema).unwrap_or(true)), - ); - } - } - - // Check if all constraint indices appear in valid positions - if !indices.iter().all(|&idx| { - col_positions - .get(&idx) - .map(|&(pos, nullable)| { - // For unique constraints, verify column is not nullable if it's first/last - !check_null - || (pos != 0 && pos != ordering.len() - 1) - || !nullable + Constraint::PrimaryKey(indices) | Constraint::Unique(indices) => { + let check_null = matches!(constraint, Constraint::Unique(_)); + let normalized_size = normal_exprs.len(); + indices.len() <= normalized_size + && self.oeq_class.iter().any(|ordering| { + let length = ordering.len(); + if indices.len() > length || normalized_size < length { + return false; + } + // Build a map of column positions in the ordering: + let mut col_positions = HashMap::with_capacity(length); + for (pos, req) in ordering.iter().enumerate() { + if let Some(col) = req.expr.as_any().downcast_ref::() + { + let nullable = col.nullable(&self.schema).unwrap_or(true); + col_positions.insert(col.index(), (pos, nullable)); + } + } + // Check if all constraint indices appear in valid positions: + if !indices.iter().all(|idx| { + col_positions.get(idx).is_some_and(|&(pos, nullable)| { + // For unique constraints, verify column is not nullable if it's first/last: + !check_null + || !nullable + || (pos != 0 && pos != length - 1) + }) + }) { + return false; + } + // Check if this ordering matches the prefix: + normal_exprs.iter().zip(ordering).all(|(given, existing)| { + existing.satisfy_expr(given, &self.schema) + }) }) - .unwrap_or(false) - }) { - return false; } - - // Check if this ordering matches requirements prefix - let ordering_len = ordering.len(); - normalized_reqs.len() >= ordering_len - && normalized_reqs[..ordering_len].iter().zip(ordering).all( - |(req, existing)| { - req.expr.eq(&existing.expr) - && req - .options - .is_none_or(|req_opts| req_opts == existing.options) - }, - ) }) } - /// Determines whether the ordering specified by the given sort requirement - /// is satisfied based on the orderings within, equivalence classes, and - /// constant expressions. - /// - /// # Parameters - /// - /// - `req`: A reference to a `PhysicalSortRequirement` for which the ordering - /// satisfaction check will be done. - /// - /// # Returns - /// - /// Returns `true` if the specified ordering is satisfied, `false` otherwise. - fn ordering_satisfy_single(&self, req: &PhysicalSortRequirement) -> bool { - let ExprProperties { - sort_properties, .. - } = self.get_expr_properties(Arc::clone(&req.expr)); - match sort_properties { - SortProperties::Ordered(options) => { - let sort_expr = PhysicalSortExpr { - expr: Arc::clone(&req.expr), - options, - }; - sort_expr.satisfy(req, self.schema()) + /// Checks if the sort requirements are satisfied by any of the table + /// constraints (primary key or unique). 
Returns true if any constraint + /// fully satisfies the requirements (i.e. constraint indices form a valid + /// prefix of an existing ordering that matches the requirements). For + /// unique constraints, also verifies nullable columns. + fn satisfied_by_constraints(&self, normal_reqs: &[PhysicalSortRequirement]) -> bool { + self.constraints.iter().any(|constraint| match constraint { + Constraint::PrimaryKey(indices) | Constraint::Unique(indices) => { + let check_null = matches!(constraint, Constraint::Unique(_)); + let normalized_size = normal_reqs.len(); + indices.len() <= normalized_size + && self.oeq_class.iter().any(|ordering| { + let length = ordering.len(); + if indices.len() > length || normalized_size < length { + return false; + } + // Build a map of column positions in the ordering: + let mut col_positions = HashMap::with_capacity(length); + for (pos, req) in ordering.iter().enumerate() { + if let Some(col) = req.expr.as_any().downcast_ref::() + { + let nullable = col.nullable(&self.schema).unwrap_or(true); + col_positions.insert(col.index(), (pos, nullable)); + } + } + // Check if all constraint indices appear in valid positions: + if !indices.iter().all(|idx| { + col_positions.get(idx).is_some_and(|&(pos, nullable)| { + // For unique constraints, verify column is not nullable if it's first/last: + !check_null + || !nullable + || (pos != 0 && pos != length - 1) + }) + }) { + return false; + } + // Check if this ordering matches the prefix: + normal_reqs.iter().zip(ordering).all(|(given, existing)| { + existing.satisfy(given, &self.schema) + }) + }) } - // Singleton expressions satisfies any ordering. - SortProperties::Singleton => true, - SortProperties::Unordered => false, - } + }) } /// Checks whether the `given` sort requirements are equal or more specific /// than the `reference` sort requirements. pub fn requirements_compatible( &self, - given: &LexRequirement, - reference: &LexRequirement, + given: LexRequirement, + reference: LexRequirement, ) -> bool { - let normalized_given = self.normalize_sort_requirements(given); - let normalized_reference = self.normalize_sort_requirements(reference); - - (normalized_reference.len() <= normalized_given.len()) - && normalized_reference + let Some(normal_given) = self.normalize_sort_requirements(given) else { + return true; + }; + let Some(normal_reference) = self.normalize_sort_requirements(reference) else { + return true; + }; + + (normal_reference.len() <= normal_given.len()) + && normal_reference .into_iter() - .zip(normalized_given) + .zip(normal_given) .all(|(reference, given)| given.compatible(&reference)) } - /// Returns the finer ordering among the orderings `lhs` and `rhs`, breaking - /// any ties by choosing `lhs`. - /// - /// The finer ordering is the ordering that satisfies both of the orderings. - /// If the orderings are incomparable, returns `None`. - /// - /// For example, the finer ordering among `[a ASC]` and `[a ASC, b ASC]` is - /// the latter. - pub fn get_finer_ordering( - &self, - lhs: &LexOrdering, - rhs: &LexOrdering, - ) -> Option { - // Convert the given sort expressions to sort requirements: - let lhs = LexRequirement::from(lhs.clone()); - let rhs = LexRequirement::from(rhs.clone()); - let finer = self.get_finer_requirement(&lhs, &rhs); - // Convert the chosen sort requirements back to sort expressions: - finer.map(LexOrdering::from) - } - - /// Returns the finer ordering among the requirements `lhs` and `rhs`, - /// breaking any ties by choosing `lhs`. 
+ /// Modify existing orderings by substituting sort expressions with appropriate + /// targets from the projection mapping. We substitute a sort expression when + /// its physical expression has a one-to-one functional relationship with a + /// target expression in the mapping. /// - /// The finer requirements are the ones that satisfy both of the given - /// requirements. If the requirements are incomparable, returns `None`. + /// After substitution, we may generate more than one `LexOrdering` for each + /// existing equivalent ordering. For example, `[a ASC, b ASC]` will turn + /// into `[CAST(a) ASC, b ASC]` and `[a ASC, b ASC]` when applying projection + /// expressions `a, b, CAST(a)`. /// - /// For example, the finer requirements among `[a ASC]` and `[a ASC, b ASC]` - /// is the latter. - pub fn get_finer_requirement( - &self, - req1: &LexRequirement, - req2: &LexRequirement, - ) -> Option { - let mut lhs = self.normalize_sort_requirements(req1); - let mut rhs = self.normalize_sort_requirements(req2); - lhs.inner - .iter_mut() - .zip(rhs.inner.iter_mut()) - .all(|(lhs, rhs)| { - lhs.expr.eq(&rhs.expr) - && match (lhs.options, rhs.options) { - (Some(lhs_opt), Some(rhs_opt)) => lhs_opt == rhs_opt, - (Some(options), None) => { - rhs.options = Some(options); - true - } - (None, Some(options)) => { - lhs.options = Some(options); - true - } - (None, None) => true, - } - }) - .then_some(if lhs.len() >= rhs.len() { lhs } else { rhs }) - } - - /// we substitute the ordering according to input expression type, this is a simplified version - /// In this case, we just substitute when the expression satisfy the following condition: - /// I. just have one column and is a CAST expression - /// TODO: Add one-to-ones analysis for monotonic ScalarFunctions. - /// TODO: we could precompute all the scenario that is computable, for example: atan(x + 1000) should also be substituted if - /// x is DESC or ASC - /// After substitution, we may generate more than 1 `LexOrdering`. As an example, - /// `[a ASC, b ASC]` will turn into `[a ASC, b ASC], [CAST(a) ASC, b ASC]` when projection expressions `a, b, CAST(a)` is applied. - pub fn substitute_ordering_component( - &self, + /// TODO: Handle all scenarios that allow substitution; e.g. when `x` is + /// sorted, `atan(x + 1000)` should also be substituted. For now, we + /// only consider single-column `CAST` expressions. + fn substitute_oeq_class( + schema: &SchemaRef, mapping: &ProjectionMapping, - sort_expr: &LexOrdering, - ) -> Result> { - let new_orderings = sort_expr - .iter() - .map(|sort_expr| { - let referring_exprs: Vec<_> = mapping - .iter() - .map(|(source, _target)| source) - .filter(|source| expr_refers(source, &sort_expr.expr)) - .cloned() - .collect(); - let mut res = LexOrdering::new(vec![sort_expr.clone()]); - // TODO: Add one-to-ones analysis for ScalarFunctions. 
- for r_expr in referring_exprs { - // we check whether this expression is substitutable or not - if let Some(cast_expr) = r_expr.as_any().downcast_ref::() { - // we need to know whether the Cast Expr matches or not - let expr_type = sort_expr.expr.data_type(&self.schema)?; - if cast_expr.expr.eq(&sort_expr.expr) - && cast_expr.is_bigger_cast(expr_type) + oeq_class: OrderingEquivalenceClass, + ) -> OrderingEquivalenceClass { + let new_orderings = oeq_class.into_iter().flat_map(|order| { + // Modify/expand existing orderings by substituting sort + // expressions with appropriate targets from the mapping: + order + .into_iter() + .map(|sort_expr| { + let referring_exprs = mapping + .iter() + .map(|(source, _target)| source) + .filter(|source| expr_refers(source, &sort_expr.expr)) + .cloned(); + let mut result = vec![]; + // The sort expression comes from this schema, so the + // following call to `unwrap` is safe. + let expr_type = sort_expr.expr.data_type(schema).unwrap(); + // TODO: Add one-to-one analysis for ScalarFunctions. + for r_expr in referring_exprs { + // We check whether this expression is substitutable. + if let Some(cast_expr) = + r_expr.as_any().downcast_ref::() { - res.push(PhysicalSortExpr { - expr: Arc::clone(&r_expr), - options: sort_expr.options, - }); + // For casts, we need to know whether the cast + // expression matches: + if cast_expr.expr.eq(&sort_expr.expr) + && cast_expr.is_bigger_cast(&expr_type) + { + result.push(PhysicalSortExpr::new( + r_expr, + sort_expr.options, + )); + } } } - } - Ok(res) - }) - .collect::>>()?; - // Generate all valid orderings, given substituted expressions. - let res = new_orderings - .into_iter() - .multi_cartesian_product() - .map(LexOrdering::new) - .collect::>(); - Ok(res) + result.push(sort_expr); + result + }) + // Generate all valid orderings given substituted expressions: + .multi_cartesian_product() + }); + OrderingEquivalenceClass::new(new_orderings) } - /// In projection, supposed we have a input function 'A DESC B DESC' and the output shares the same expression - /// with A and B, we could surely use the ordering of the original ordering, However, if the A has been changed, - /// for example, A-> Cast(A, Int64) or any other form, it is invalid if we continue using the original ordering - /// Since it would cause bug in dependency constructions, we should substitute the input order in order to get correct - /// dependency map, happen in issue 8838: - pub fn substitute_oeq_class(&mut self, mapping: &ProjectionMapping) -> Result<()> { - let new_order = self - .oeq_class - .iter() - .map(|order| self.substitute_ordering_component(mapping, order)) - .collect::>>()?; - let new_order = new_order.into_iter().flatten().collect(); - self.oeq_class = OrderingEquivalenceClass::new(new_order); - Ok(()) - } - /// Projects argument `expr` according to `projection_mapping`, taking - /// equivalences into account. + /// Projects argument `expr` according to the projection described by + /// `mapping`, taking equivalences into account. /// /// For example, assume that columns `a` and `c` are always equal, and that - /// `projection_mapping` encodes following mapping: + /// the projection described by `mapping` encodes the following: /// /// ```text /// a -> a1 @@ -868,13 +875,25 @@ impl EquivalenceProperties { /// ``` /// /// Then, this function projects `a + b` to `Some(a1 + b1)`, `c + b` to - /// `Some(a1 + b1)` and `d` to `None`, meaning that it cannot be projected. 
+ /// `Some(a1 + b1)` and `d` to `None`, meaning that it is not projectable. pub fn project_expr( &self, expr: &Arc, - projection_mapping: &ProjectionMapping, + mapping: &ProjectionMapping, ) -> Option> { - self.eq_group.project_expr(projection_mapping, expr) + self.eq_group.project_expr(mapping, expr) + } + + /// Projects the given `expressions` according to the projection described + /// by `mapping`, taking equivalences into account. This function is similar + /// to [`Self::project_expr`], but projects multiple expressions at once + /// more efficiently than calling `project_expr` for each expression. + pub fn project_expressions<'a>( + &'a self, + expressions: impl IntoIterator> + 'a, + mapping: &'a ProjectionMapping, + ) -> impl Iterator>> + 'a { + self.eq_group.project_expressions(mapping, expressions) } /// Constructs a dependency map based on existing orderings referred to in @@ -906,71 +925,85 @@ impl EquivalenceProperties { /// b ASC: Node {Some(b_new ASC), HashSet{a ASC}} /// c ASC: Node {None, HashSet{a ASC}} /// ``` - fn construct_dependency_map(&self, mapping: &ProjectionMapping) -> DependencyMap { - let mut dependency_map = DependencyMap::new(); - for ordering in self.normalized_oeq_class().iter() { - for (idx, sort_expr) in ordering.iter().enumerate() { - let target_sort_expr = - self.project_expr(&sort_expr.expr, mapping).map(|expr| { - PhysicalSortExpr { - expr, - options: sort_expr.options, - } - }); - let is_projected = target_sort_expr.is_some(); - if is_projected - || mapping - .iter() - .any(|(source, _)| expr_refers(source, &sort_expr.expr)) - { - // Previous ordering is a dependency. Note that there is no, - // dependency for a leading ordering (i.e. the first sort - // expression). - let dependency = idx.checked_sub(1).map(|a| &ordering[a]); - // Add sort expressions that can be projected or referred to - // by any of the projection expressions to the dependency map: - dependency_map.insert( - sort_expr, - target_sort_expr.as_ref(), - dependency, - ); - } - if !is_projected { - // If we can not project, stop constructing the dependency - // map as remaining dependencies will be invalid after projection. + fn construct_dependency_map( + &self, + oeq_class: OrderingEquivalenceClass, + mapping: &ProjectionMapping, + ) -> DependencyMap { + let mut map = DependencyMap::default(); + for ordering in oeq_class.into_iter() { + // Previous expression is a dependency. Note that there is no + // dependency for the leading expression. + if !self.insert_to_dependency_map( + mapping, + ordering[0].clone(), + None, + &mut map, + ) { + continue; + } + for (dependency, sort_expr) in ordering.into_iter().tuple_windows() { + if !self.insert_to_dependency_map( + mapping, + sort_expr, + Some(dependency), + &mut map, + ) { + // If we can't project, stop constructing the dependency map + // as remaining dependencies will be invalid post projection. break; } } } - dependency_map + map } - /// Returns a new `ProjectionMapping` where source expressions are normalized. - /// - /// This normalization ensures that source expressions are transformed into a - /// consistent representation. This is beneficial for algorithms that rely on - /// exact equalities, as it allows for more precise and reliable comparisons. + /// Projects the sort expression according to the projection mapping and + /// inserts it into the dependency map with the given dependency. Returns + /// a boolean flag indicating whether the given expression is projectable. 
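+    ///
+    /// For example, under a mapping `a -> a_new`, the entry for `a ASC` is
+    /// projectable (its target is `a_new ASC`), whereas an entry for `b ASC`
+    /// is not; `b ASC` is still recorded if some projection expression (say,
+    /// `a + b`) refers to it, as it may unlock dependent orderings.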
+ fn insert_to_dependency_map( + &self, + mapping: &ProjectionMapping, + sort_expr: PhysicalSortExpr, + dependency: Option, + map: &mut DependencyMap, + ) -> bool { + let target_sort_expr = self + .project_expr(&sort_expr.expr, mapping) + .map(|expr| PhysicalSortExpr::new(expr, sort_expr.options)); + let projectable = target_sort_expr.is_some(); + if projectable + || mapping + .iter() + .any(|(source, _)| expr_refers(source, &sort_expr.expr)) + { + // Add sort expressions that can be projected or referred to + // by any of the projection expressions to the dependency map: + map.insert(sort_expr, target_sort_expr, dependency); + } + projectable + } + + /// Returns a new `ProjectionMapping` where source expressions are in normal + /// form. Normalization ensures that source expressions are transformed into + /// a consistent representation, which is beneficial for algorithms that rely + /// on exact equalities, as it allows for more precise and reliable comparisons. /// /// # Parameters /// - /// - `mapping`: A reference to the original `ProjectionMapping` to be normalized. + /// - `mapping`: A reference to the original `ProjectionMapping` to normalize. /// /// # Returns /// - /// A new `ProjectionMapping` with normalized source expressions. - fn normalized_mapping(&self, mapping: &ProjectionMapping) -> ProjectionMapping { - // Construct the mapping where source expressions are normalized. In this way - // In the algorithms below we can work on exact equalities - ProjectionMapping { - map: mapping - .iter() - .map(|(source, target)| { - let normalized_source = - self.eq_group.normalize_expr(Arc::clone(source)); - (normalized_source, Arc::clone(target)) - }) - .collect(), - } + /// A new `ProjectionMapping` with source expressions in normal form. + fn normalize_mapping(&self, mapping: &ProjectionMapping) -> ProjectionMapping { + mapping + .iter() + .map(|(source, target)| { + let normal_source = self.eq_group.normalize_expr(Arc::clone(source)); + (normal_source, target.clone()) + }) + .collect() } /// Computes projected orderings based on a given projection mapping. @@ -984,42 +1017,55 @@ impl EquivalenceProperties { /// /// - `mapping`: A reference to the `ProjectionMapping` that defines the /// relationship between source and target expressions. + /// - `oeq_class`: The `OrderingEquivalenceClass` containing the orderings + /// to project. /// /// # Returns /// - /// A vector of `LexOrdering` containing all valid orderings after projection. - fn projected_orderings(&self, mapping: &ProjectionMapping) -> Vec { - let mapping = self.normalized_mapping(mapping); - + /// A vector of all valid (but not in normal form) orderings after projection. 
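+    ///
+    /// For example, with the input ordered by `[a ASC, b ASC]` and a mapping
+    /// `a -> a_new, b -> b_new`, the result would contain
+    /// `[a_new ASC, b_new ASC]` (names here are illustrative).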
+    fn projected_orderings(
+        &self,
+        mapping: &ProjectionMapping,
+        mut oeq_class: OrderingEquivalenceClass,
+    ) -> Vec<LexOrdering> {
+        // Normalize source expressions in the mapping:
+        let mapping = self.normalize_mapping(mapping);
         // Get dependency map for existing orderings:
-        let dependency_map = self.construct_dependency_map(&mapping);
-        let orderings = mapping.iter().flat_map(|(source, target)| {
+        oeq_class = Self::substitute_oeq_class(&self.schema, &mapping, oeq_class);
+        let dependency_map = self.construct_dependency_map(oeq_class, &mapping);
+        let orderings = mapping.iter().flat_map(|(source, targets)| {
             referred_dependencies(&dependency_map, source)
                 .into_iter()
-                .filter_map(|relevant_deps| {
-                    if let Ok(SortProperties::Ordered(options)) =
-                        get_expr_properties(source, &relevant_deps, &self.schema)
-                            .map(|prop| prop.sort_properties)
-                    {
-                        Some((options, relevant_deps))
+                .filter_map(|deps| {
+                    let ep = get_expr_properties(source, &deps, &self.schema);
+                    let sort_properties = ep.map(|prop| prop.sort_properties);
+                    if let Ok(SortProperties::Ordered(options)) = sort_properties {
+                        Some((options, deps))
                     } else {
-                        // Do not consider unordered cases
+                        // Do not consider unordered cases.
                         None
                     }
                 })
                 .flat_map(|(options, relevant_deps)| {
-                    let sort_expr = PhysicalSortExpr {
-                        expr: Arc::clone(target),
-                        options,
-                    };
-                    // Generate dependent orderings (i.e. prefixes for `sort_expr`):
-                    let mut dependency_orderings =
+                    // Generate dependent orderings (i.e. prefixes for targets):
+                    let dependency_orderings =
                         generate_dependency_orderings(&relevant_deps, &dependency_map);
-                    // Append `sort_expr` to the dependent orderings:
-                    for ordering in dependency_orderings.iter_mut() {
-                        ordering.push(sort_expr.clone());
+                    let sort_exprs = targets.iter().map(|(target, _)| {
+                        PhysicalSortExpr::new(Arc::clone(target), options)
+                    });
+                    if dependency_orderings.is_empty() {
+                        sort_exprs.map(|sort_expr| [sort_expr].into()).collect()
+                    } else {
+                        sort_exprs
+                            .flat_map(|sort_expr| {
+                                let mut result = dependency_orderings.clone();
+                                for ordering in result.iter_mut() {
+                                    ordering.push(sort_expr.clone());
+                                }
+                                result
+                            })
+                            .collect::<Vec<_>>()
                     }
-                    dependency_orderings
                })
        });
@@ -1033,116 +1079,67 @@
            if prefixes.is_empty() {
-                // If prefix is empty, there is no dependency. Insert
-                // empty ordering:
-                prefixes = vec![LexOrdering::default()];
-            }
-            // Append current ordering on top its dependencies:
-            for ordering in prefixes.iter_mut() {
-                if let Some(target) = &node.target_sort_expr {
-                    ordering.push(target.clone())
+                // If the prefix is empty, there is no dependency; the target
+                // (if any) starts a new ordering on its own:
+                if let Some(target) = &node.target {
+                    prefixes.push([target.clone()].into());
+                }
+            } else {
+                // Append the current ordering on top of its dependencies:
+                for ordering in prefixes.iter_mut() {
+                    if let Some(target) = &node.target {
+                        ordering.push(target.clone());
+                    }
                }
            }
            prefixes
        });

        // Simplify each ordering by removing redundant sections:
-        orderings
-            .chain(projected_orderings)
-            .map(|lex_ordering| lex_ordering.collapse())
-            .collect()
-    }
-
-    /// Projects constants based on the provided `ProjectionMapping`.
-    ///
-    /// This function takes a `ProjectionMapping` and identifies/projects
-    /// constants based on the existing constants and the mapping. It ensures
-    /// that constants are appropriately propagated through the projection.
-    ///
-    /// # Parameters
-    ///
-    /// - `mapping`: A reference to a `ProjectionMapping` representing the
-    ///   mapping of source expressions to target expressions in the projection.
- /// - /// # Returns - /// - /// Returns a `Vec>` containing the projected constants. - fn projected_constants(&self, mapping: &ProjectionMapping) -> Vec { - // First, project existing constants. For example, assume that `a + b` - // is known to be constant. If the projection were `a as a_new`, `b as b_new`, - // then we would project constant `a + b` as `a_new + b_new`. - let mut projected_constants = self - .constants - .iter() - .flat_map(|const_expr| { - const_expr - .map(|expr| self.eq_group.project_expr(mapping, expr)) - .map(|projected_expr| { - projected_expr - .with_across_partitions(const_expr.across_partitions()) - }) - }) - .collect::>(); - - // Add projection expressions that are known to be constant: - for (source, target) in mapping.iter() { - if self.is_expr_constant(source) - && !const_exprs_contains(&projected_constants, target) - { - if self.is_expr_constant_across_partitions(source) { - projected_constants.push( - ConstExpr::from(target) - .with_across_partitions(self.get_expr_constant_value(source)), - ) - } else { - projected_constants.push( - ConstExpr::from(target) - .with_across_partitions(AcrossPartitions::Heterogeneous), - ) - } - } - } - projected_constants + orderings.chain(projected_orderings).collect() } /// Projects constraints according to the given projection mapping. /// - /// This function takes a projection mapping and extracts the column indices of the target columns. - /// It then projects the constraints to only include relationships between - /// columns that exist in the projected output. + /// This function takes a projection mapping and extracts column indices of + /// target columns. It then projects the constraints to only include + /// relationships between columns that exist in the projected output. /// - /// # Arguments + /// # Parameters /// - /// * `mapping` - A reference to `ProjectionMapping` that defines how expressions are mapped - /// in the projection operation + /// * `mapping` - A reference to the `ProjectionMapping` that defines the + /// projection operation. /// /// # Returns /// - /// Returns a new `Constraints` object containing only the constraints - /// that are valid for the projected columns. + /// Returns an optional `Constraints` object containing only the constraints + /// that are valid for the projected columns (if any exists). fn projected_constraints(&self, mapping: &ProjectionMapping) -> Option { let indices = mapping .iter() - .filter_map(|(_, target)| target.as_any().downcast_ref::()) - .map(|col| col.index()) + .flat_map(|(_, targets)| { + targets.iter().flat_map(|(target, _)| { + target.as_any().downcast_ref::().map(|c| c.index()) + }) + }) .collect::>(); - debug_assert_eq!(mapping.map.len(), indices.len()); self.constraints.project(&indices) } - /// Projects the equivalences within according to `mapping` - /// and `output_schema`. + /// Projects the equivalences within according to `mapping` and + /// `output_schema`. 
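+    ///
+    /// A minimal usage sketch (assuming the pre-existing
+    /// `ProjectionMapping::try_new` constructor; variable names are
+    /// illustrative):
+    ///
+    /// ```ignore
+    /// // Carry input-side equivalences through `SELECT a AS a1, b AS b1`:
+    /// let mapping = ProjectionMapping::try_new(proj_exprs, &input_schema)?;
+    /// let projected = eq_properties.project(&mapping, output_schema);
+    /// ```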
pub fn project(&self, mapping: &ProjectionMapping, output_schema: SchemaRef) -> Self { let eq_group = self.eq_group.project(mapping); - let oeq_class = OrderingEquivalenceClass::new(self.projected_orderings(mapping)); - let constants = self.projected_constants(mapping); - let constraints = self - .projected_constraints(mapping) - .unwrap_or_else(Constraints::empty); + let orderings = + self.projected_orderings(mapping, self.oeq_cache.normal_cls.clone()); + let normal_orderings = orderings + .iter() + .cloned() + .map(|o| eq_group.normalize_sort_exprs(o)); Self { + oeq_cache: OrderingEquivalenceCache::new(normal_orderings), + oeq_class: OrderingEquivalenceClass::new(orderings), + constraints: self.projected_constraints(mapping).unwrap_or_default(), schema: output_schema, eq_group, - oeq_class, - constants, - constraints, } } @@ -1159,7 +1156,7 @@ impl EquivalenceProperties { pub fn find_longest_permutation( &self, exprs: &[Arc], - ) -> (LexOrdering, Vec) { + ) -> Result<(Vec, Vec)> { let mut eq_properties = self.clone(); let mut result = vec![]; // The algorithm is as follows: @@ -1172,32 +1169,23 @@ impl EquivalenceProperties { // This algorithm should reach a fixed point in at most `exprs.len()` // iterations. let mut search_indices = (0..exprs.len()).collect::>(); - for _idx in 0..exprs.len() { + for _ in 0..exprs.len() { // Get ordered expressions with their indices. let ordered_exprs = search_indices .iter() - .flat_map(|&idx| { + .filter_map(|&idx| { let ExprProperties { sort_properties, .. } = eq_properties.get_expr_properties(Arc::clone(&exprs[idx])); match sort_properties { - SortProperties::Ordered(options) => Some(( - PhysicalSortExpr { - expr: Arc::clone(&exprs[idx]), - options, - }, - idx, - )), + SortProperties::Ordered(options) => { + let expr = Arc::clone(&exprs[idx]); + Some((PhysicalSortExpr::new(expr, options), idx)) + } SortProperties::Singleton => { - // Assign default ordering to constant expressions - let options = SortOptions::default(); - Some(( - PhysicalSortExpr { - expr: Arc::clone(&exprs[idx]), - options, - }, - idx, - )) + // Assign default ordering to constant expressions: + let expr = Arc::clone(&exprs[idx]); + Some((PhysicalSortExpr::new_default(expr), idx)) } SortProperties::Unordered => None, } @@ -1215,44 +1203,20 @@ impl EquivalenceProperties { // Note that these expressions are not properly "constants". This is just // an implementation strategy confined to this function. for (PhysicalSortExpr { expr, .. }, idx) in &ordered_exprs { - eq_properties = - eq_properties.with_constants(std::iter::once(ConstExpr::from(expr))); + let const_expr = ConstExpr::from(Arc::clone(expr)); + eq_properties.add_constants(std::iter::once(const_expr))?; search_indices.shift_remove(idx); } // Add new ordered section to the state. result.extend(ordered_exprs); } - let (left, right) = result.into_iter().unzip(); - (LexOrdering::new(left), right) - } - - /// This function determines whether the provided expression is constant - /// based on the known constants. - /// - /// # Parameters - /// - /// - `expr`: A reference to a `Arc` representing the - /// expression to be checked. - /// - /// # Returns - /// - /// Returns `true` if the expression is constant according to equivalence - /// group, `false` otherwise. - pub fn is_expr_constant(&self, expr: &Arc) -> bool { - // As an example, assume that we know columns `a` and `b` are constant. - // Then, `a`, `b` and `a + b` will all return `true` whereas `c` will - // return `false`. 
- let const_exprs = self - .constants - .iter() - .map(|const_expr| Arc::clone(const_expr.expr())); - let normalized_constants = self.eq_group.normalize_exprs(const_exprs); - let normalized_expr = self.eq_group.normalize_expr(Arc::clone(expr)); - is_constant_recurse(&normalized_constants, &normalized_expr) + Ok(result.into_iter().unzip()) } /// This function determines whether the provided expression is constant - /// across partitions based on the known constants. + /// based on the known constants. For example, if columns `a` and `b` are + /// constant, then expressions `a`, `b` and `a + b` will all return `true` + /// whereas expression `c` will return `false`. /// /// # Parameters /// @@ -1261,87 +1225,15 @@ impl EquivalenceProperties { /// /// # Returns /// - /// Returns `true` if the expression is constant across all partitions according - /// to equivalence group, `false` otherwise - #[deprecated( - since = "45.0.0", - note = "Use [`is_expr_constant_across_partitions`] instead" - )] - pub fn is_expr_constant_accross_partitions( + /// Returns a `Some` value if the expression is constant according to + /// equivalence group, and `None` otherwise. The `Some` variant contains + /// an `AcrossPartitions` value indicating whether the expression is + /// constant across partitions, and its actual value (if available). + pub fn is_expr_constant( &self, expr: &Arc, - ) -> bool { - self.is_expr_constant_across_partitions(expr) - } - - /// This function determines whether the provided expression is constant - /// across partitions based on the known constants. - /// - /// # Parameters - /// - /// - `expr`: A reference to a `Arc` representing the - /// expression to be checked. - /// - /// # Returns - /// - /// Returns `true` if the expression is constant across all partitions according - /// to equivalence group, `false` otherwise. - pub fn is_expr_constant_across_partitions( - &self, - expr: &Arc, - ) -> bool { - // As an example, assume that we know columns `a` and `b` are constant. - // Then, `a`, `b` and `a + b` will all return `true` whereas `c` will - // return `false`. - let const_exprs = self - .constants - .iter() - .filter_map(|const_expr| { - if matches!( - const_expr.across_partitions(), - AcrossPartitions::Uniform { .. } - ) { - Some(Arc::clone(const_expr.expr())) - } else { - None - } - }) - .collect::>(); - let normalized_constants = self.eq_group.normalize_exprs(const_exprs); - let normalized_expr = self.eq_group.normalize_expr(Arc::clone(expr)); - is_constant_recurse(&normalized_constants, &normalized_expr) - } - - /// Retrieves the constant value of a given physical expression, if it exists. - /// - /// Normalizes the input expression and checks if it matches any known constants - /// in the current context. Returns whether the expression has a uniform value, - /// varies across partitions, or is not constant. - /// - /// # Parameters - /// - `expr`: A reference to the physical expression to evaluate. - /// - /// # Returns - /// - `AcrossPartitions::Uniform(value)`: If the expression has the same value across partitions. - /// - `AcrossPartitions::Heterogeneous`: If the expression varies across partitions. - /// - `None`: If the expression is not recognized as constant. 
- pub fn get_expr_constant_value( - &self, - expr: &Arc, - ) -> AcrossPartitions { - let normalized_expr = self.eq_group.normalize_expr(Arc::clone(expr)); - - if let Some(lit) = normalized_expr.as_any().downcast_ref::() { - return AcrossPartitions::Uniform(Some(lit.value().clone())); - } - - for const_expr in self.constants.iter() { - if normalized_expr.eq(const_expr.expr()) { - return const_expr.across_partitions(); - } - } - - AcrossPartitions::Heterogeneous + ) -> Option { + self.eq_group.is_expr_constant(expr) } /// Retrieves the properties for a given physical expression. @@ -1367,10 +1259,9 @@ impl EquivalenceProperties { .unwrap_or_else(|_| ExprProperties::new_unknown()) } - /// Transforms this `EquivalenceProperties` into a new `EquivalenceProperties` - /// by mapping columns in the original schema to columns in the new schema - /// by index. - pub fn with_new_schema(self, schema: SchemaRef) -> Result { + /// Transforms this `EquivalenceProperties` by mapping columns in the + /// original schema to columns in the new schema by index. + pub fn with_new_schema(mut self, schema: SchemaRef) -> Result { // The new schema and the original schema is aligned when they have the // same number of columns, and fields at the same index have the same // type in both schemas. @@ -1385,54 +1276,49 @@ impl EquivalenceProperties { // Rewriting equivalence properties in terms of new schema is not // safe when schemas are not aligned: return plan_err!( - "Cannot rewrite old_schema:{:?} with new schema: {:?}", + "Schemas have to be aligned to rewrite equivalences:\n Old schema: {:?}\n New schema: {:?}", self.schema, schema ); } - // Rewrite constants according to new schema: - let new_constants = self - .constants - .into_iter() - .map(|const_expr| { - let across_partitions = const_expr.across_partitions(); - let new_const_expr = with_new_schema(const_expr.owned_expr(), &schema)?; - Ok(ConstExpr::new(new_const_expr) - .with_across_partitions(across_partitions)) - }) - .collect::>>()?; - - // Rewrite orderings according to new schema: - let mut new_orderings = vec![]; - for ordering in self.oeq_class { - let new_ordering = ordering - .into_iter() - .map(|mut sort_expr| { - sort_expr.expr = with_new_schema(sort_expr.expr, &schema)?; - Ok(sort_expr) - }) - .collect::>()?; - new_orderings.push(new_ordering); - } // Rewrite equivalence classes according to the new schema: let mut eq_classes = vec![]; - for eq_class in self.eq_group { - let new_eq_exprs = eq_class - .into_vec() + for mut eq_class in self.eq_group { + // Rewrite the expressions in the equivalence class: + eq_class.exprs = eq_class + .exprs .into_iter() .map(|expr| with_new_schema(expr, &schema)) .collect::>()?; - eq_classes.push(EquivalenceClass::new(new_eq_exprs)); + // Rewrite the constant value (if available and known): + let data_type = eq_class + .canonical_expr() + .map(|e| e.data_type(&schema)) + .transpose()?; + if let (Some(data_type), Some(AcrossPartitions::Uniform(Some(value)))) = + (data_type, &mut eq_class.constant) + { + *value = value.cast_to(&data_type)?; + } + eq_classes.push(eq_class); } + self.eq_group = eq_classes.into(); + + // Rewrite orderings according to new schema: + self.oeq_class = self.oeq_class.with_new_schema(&schema)?; + self.oeq_cache.normal_cls = self.oeq_cache.normal_cls.with_new_schema(&schema)?; + + // Update the schema: + self.schema = schema; - // Construct the resulting equivalence properties: - let mut result = EquivalenceProperties::new(schema); - result.constants = new_constants; - 
result.add_new_orderings(new_orderings); - result.add_equivalence_group(EquivalenceGroup::new(eq_classes)); + Ok(self) + } +} - Ok(result) +impl From for OrderingEquivalenceClass { + fn from(eq_properties: EquivalenceProperties) -> Self { + eq_properties.oeq_class } } @@ -1440,24 +1326,21 @@ impl EquivalenceProperties { /// /// Format: /// ```text -/// order: [[a ASC, b ASC], [a ASC, c ASC]], eq: [[a = b], [a = c]], const: [a = 1] +/// order: [[b@1 ASC NULLS LAST]], eq: [{members: [a@0], constant: (heterogeneous)}] /// ``` impl Display for EquivalenceProperties { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if self.eq_group.is_empty() - && self.oeq_class.is_empty() - && self.constants.is_empty() - { - return write!(f, "No properties"); - } - if !self.oeq_class.is_empty() { + let empty_eq_group = self.eq_group.is_empty(); + let empty_oeq_class = self.oeq_class.is_empty(); + if empty_oeq_class && empty_eq_group { + write!(f, "No properties")?; + } else if !empty_oeq_class { write!(f, "order: {}", self.oeq_class)?; - } - if !self.eq_group.is_empty() { - write!(f, ", eq: {}", self.eq_group)?; - } - if !self.constants.is_empty() { - write!(f, ", const: [{}]", ConstExpr::format_list(&self.constants))?; + if !empty_eq_group { + write!(f, ", eq: {}", self.eq_group)?; + } + } else { + write!(f, "eq: {}", self.eq_group)?; } Ok(()) } @@ -1501,45 +1384,20 @@ fn update_properties( Interval::make_unbounded(&node.expr.data_type(eq_properties.schema())?)? } // Now, check what we know about orderings: - let normalized_expr = eq_properties + let normal_expr = eq_properties .eq_group .normalize_expr(Arc::clone(&node.expr)); - let oeq_class = eq_properties.normalized_oeq_class(); - if eq_properties.is_expr_constant(&normalized_expr) - || oeq_class.is_expr_partial_const(&normalized_expr) + let oeq_class = &eq_properties.oeq_cache.normal_cls; + if eq_properties.is_expr_constant(&normal_expr).is_some() + || oeq_class.is_expr_partial_const(&normal_expr) { node.data.sort_properties = SortProperties::Singleton; - } else if let Some(options) = oeq_class.get_options(&normalized_expr) { + } else if let Some(options) = oeq_class.get_options(&normal_expr) { node.data.sort_properties = SortProperties::Ordered(options); } Ok(Transformed::yes(node)) } -/// This function determines whether the provided expression is constant -/// based on the known constants. -/// -/// # Parameters -/// -/// - `constants`: A `&[Arc]` containing expressions known to -/// be a constant. -/// - `expr`: A reference to a `Arc` representing the expression -/// to check. -/// -/// # Returns -/// -/// Returns `true` if the expression is constant according to equivalence -/// group, `false` otherwise. -fn is_constant_recurse( - constants: &[Arc], - expr: &Arc, -) -> bool { - if physical_exprs_contains(constants, expr) || expr.as_any().is::() { - return true; - } - let children = expr.children(); - !children.is_empty() && children.iter().all(|c| is_constant_recurse(constants, c)) -} - /// This function examines whether a referring expression directly refers to a /// given referred expression or if any of its children in the expression tree /// refer to the specified expression. @@ -1614,59 +1472,3 @@ fn get_expr_properties( expr.get_properties(&child_states) } } - -/// Wrapper struct for `Arc` to use them as keys in a hash map. 
-#[derive(Debug, Clone)] -struct ExprWrapper(Arc); - -impl PartialEq for ExprWrapper { - fn eq(&self, other: &Self) -> bool { - self.0.eq(&other.0) - } -} - -impl Eq for ExprWrapper {} - -impl Hash for ExprWrapper { - fn hash(&self, state: &mut H) { - self.0.hash(state); - } -} - -#[cfg(test)] -mod tests { - - use super::*; - use crate::expressions::{col, BinaryExpr}; - - use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; - use datafusion_expr::Operator; - - #[test] - fn test_expr_consists_of_constants() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - Field::new("d", DataType::Int32, true), - Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), - ])); - let col_a = col("a", &schema)?; - let col_b = col("b", &schema)?; - let col_d = col("d", &schema)?; - let b_plus_d = Arc::new(BinaryExpr::new( - Arc::clone(&col_b), - Operator::Plus, - Arc::clone(&col_d), - )) as Arc; - - let constants = vec![Arc::clone(&col_a), Arc::clone(&col_b)]; - let expr = Arc::clone(&b_plus_d); - assert!(!is_constant_recurse(&constants, &expr)); - - let constants = vec![Arc::clone(&col_a), Arc::clone(&col_b), Arc::clone(&col_d)]; - let expr = Arc::clone(&b_plus_d); - assert!(is_constant_recurse(&constants, &expr)); - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/equivalence/properties/union.rs b/datafusion/physical-expr/src/equivalence/properties/union.rs index 64ef9278e248..4f44b9b0c9d4 100644 --- a/datafusion/physical-expr/src/equivalence/properties/union.rs +++ b/datafusion/physical-expr/src/equivalence/properties/union.rs @@ -15,28 +15,26 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::{internal_err, Result}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; use std::iter::Peekable; use std::sync::Arc; +use super::EquivalenceProperties; use crate::equivalence::class::AcrossPartitions; -use crate::ConstExpr; +use crate::{ConstExpr, PhysicalSortExpr}; -use super::EquivalenceProperties; -use crate::PhysicalSortExpr; use arrow::datatypes::SchemaRef; -use std::slice::Iter; +use datafusion_common::{internal_err, Result}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; -/// Calculates the union (in the sense of `UnionExec`) `EquivalenceProperties` -/// of `lhs` and `rhs` according to the schema of `lhs`. +/// Computes the union (in the sense of `UnionExec`) `EquivalenceProperties` +/// of `lhs` and `rhs` according to the schema of `lhs`. /// -/// Rules: The UnionExec does not interleave its inputs: instead it passes each -/// input partition from the children as its own output. +/// Rules: The `UnionExec` does not interleave its inputs, instead it passes +/// each input partition from the children as its own output. /// /// Since the output equivalence properties are properties that are true for /// *all* output partitions, that is the same as being true for all *input* -/// partitions +/// partitions. fn calculate_union_binary( lhs: EquivalenceProperties, mut rhs: EquivalenceProperties, @@ -48,28 +46,21 @@ fn calculate_union_binary( // First, calculate valid constants for the union. An expression is constant // at the output of the union if it is constant in both sides with matching values. 
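+    // For instance, if both children report the same uniform value for an
+    // expression (say `c = 42` in every partition on both sides), `c` stays
+    // a uniform constant at the union output; if the sides disagree, it is
+    // downgraded to a heterogeneous constant below.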
+    let rhs_constants = rhs.constants();
     let constants = lhs
         .constants()
-        .iter()
+        .into_iter()
         .filter_map(|lhs_const| {
             // Find matching constant expression in RHS
-            rhs.constants()
+            rhs_constants
                 .iter()
-                .find(|rhs_const| rhs_const.expr().eq(lhs_const.expr()))
+                .find(|rhs_const| rhs_const.expr.eq(&lhs_const.expr))
                 .map(|rhs_const| {
-                    let mut const_expr = ConstExpr::new(Arc::clone(lhs_const.expr()));
-
-                    // If both sides have matching constant values, preserve the value and set across_partitions=true
-                    if let (
-                        AcrossPartitions::Uniform(Some(lhs_val)),
-                        AcrossPartitions::Uniform(Some(rhs_val)),
-                    ) = (lhs_const.across_partitions(), rhs_const.across_partitions())
-                    {
-                        if lhs_val == rhs_val {
-                            const_expr = const_expr.with_across_partitions(
-                                AcrossPartitions::Uniform(Some(lhs_val)),
-                            )
-                        }
+                    let mut const_expr = lhs_const.clone();
+                    // If both sides have matching constant values, preserve it.
+                    // Otherwise, fall back to heterogeneous values.
+                    if lhs_const.across_partitions != rhs_const.across_partitions {
+                        const_expr.across_partitions = AcrossPartitions::Heterogeneous;
                    }
                    const_expr
                })
@@ -79,14 +70,13 @@ fn calculate_union_binary(
     // Next, calculate valid orderings for the union by searching for prefixes
     // in both sides.
     let mut orderings = UnionEquivalentOrderingBuilder::new();
-    orderings.add_satisfied_orderings(lhs.normalized_oeq_class(), lhs.constants(), &rhs);
-    orderings.add_satisfied_orderings(rhs.normalized_oeq_class(), rhs.constants(), &lhs);
+    orderings.add_satisfied_orderings(&lhs, &rhs)?;
+    orderings.add_satisfied_orderings(&rhs, &lhs)?;
     let orderings = orderings.build();

-    let mut eq_properties =
-        EquivalenceProperties::new(lhs.schema).with_constants(constants);
-
-    eq_properties.add_new_orderings(orderings);
+    let mut eq_properties = EquivalenceProperties::new(lhs.schema);
+    eq_properties.add_constants(constants)?;
+    eq_properties.add_orderings(orderings);
     Ok(eq_properties)
 }

@@ -137,135 +127,139 @@ impl UnionEquivalentOrderingBuilder {
         Self { orderings: vec![] }
     }

-    /// Add all orderings from `orderings` that satisfy `properties`,
-    /// potentially augmented with`constants`.
+    /// Add all orderings from `source` that satisfy `properties`,
+    /// potentially augmented with the constants in `source`.
     ///
-    /// Note: any column that is known to be constant can be inserted into the
-    /// ordering without changing its meaning
+    /// Note: Any column that is known to be constant can be inserted into the
+    /// ordering without changing its meaning.
     ///
     /// For example:
-    /// * `orderings` contains `[a ASC, c ASC]` and `constants` contains `b`
-    /// * `properties` has required ordering `[a ASC, b ASC]`
+    /// * Orderings in `source` contain `[a ASC, c ASC]` and its constants
+    ///   contain `b`,
+    /// * `properties` has the ordering `[a ASC, b ASC]`.
     ///
     /// Then this will add `[a ASC, b ASC]` to the `orderings` list (as `a` was
     /// in the sort order and `b` was a constant).
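+    ///
+    /// If an ordering cannot be added in full, it is retried with its last
+    /// expression popped off, so e.g. `[a ASC, x ASC]` can still contribute
+    /// its satisfied prefix `[a ASC]` (columns here are illustrative).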
fn add_satisfied_orderings( &mut self, - orderings: impl IntoIterator, - constants: &[ConstExpr], + source: &EquivalenceProperties, properties: &EquivalenceProperties, - ) { - for mut ordering in orderings.into_iter() { + ) -> Result<()> { + let constants = source.constants(); + let properties_constants = properties.constants(); + for mut ordering in source.oeq_cache.normal_cls.clone() { // Progressively shorten the ordering to search for a satisfied prefix: loop { - match self.try_add_ordering(ordering, constants, properties) { + ordering = match self.try_add_ordering( + ordering, + &constants, + properties, + &properties_constants, + )? { AddedOrdering::Yes => break, - AddedOrdering::No(o) => { - ordering = o; - ordering.pop(); + AddedOrdering::No(ordering) => { + let mut sort_exprs: Vec<_> = ordering.into(); + sort_exprs.pop(); + if let Some(ordering) = LexOrdering::new(sort_exprs) { + ordering + } else { + break; + } } } } } + Ok(()) } - /// Adds `ordering`, potentially augmented with constants, if it satisfies - /// the target `properties` properties. + /// Adds `ordering`, potentially augmented with `constants`, if it satisfies + /// the given `properties`. /// - /// Returns + /// # Returns /// - /// * [`AddedOrdering::Yes`] if the ordering was added (either directly or - /// augmented), or was empty. - /// - /// * [`AddedOrdering::No`] if the ordering was not added + /// An [`AddedOrdering::Yes`] instance if the ordering was added (either + /// directly or augmented), or was empty. An [`AddedOrdering::No`] instance + /// otherwise. fn try_add_ordering( &mut self, ordering: LexOrdering, constants: &[ConstExpr], properties: &EquivalenceProperties, - ) -> AddedOrdering { - if ordering.is_empty() { - AddedOrdering::Yes - } else if properties.ordering_satisfy(ordering.as_ref()) { + properties_constants: &[ConstExpr], + ) -> Result { + if properties.ordering_satisfy(ordering.clone())? { // If the ordering satisfies the target properties, no need to // augment it with constants. self.orderings.push(ordering); - AddedOrdering::Yes + Ok(AddedOrdering::Yes) + } else if self.try_find_augmented_ordering( + &ordering, + constants, + properties, + properties_constants, + ) { + // Augmented with constants to match the properties. + Ok(AddedOrdering::Yes) } else { - // Did not satisfy target properties, try and augment with constants - // to match the properties - if self.try_find_augmented_ordering(&ordering, constants, properties) { - AddedOrdering::Yes - } else { - AddedOrdering::No(ordering) - } + Ok(AddedOrdering::No(ordering)) } } /// Attempts to add `constants` to `ordering` to satisfy the properties. - /// - /// returns true if any orderings were added, false otherwise + /// Returns `true` if augmentation took place, `false` otherwise. 
fn try_find_augmented_ordering( &mut self, ordering: &LexOrdering, constants: &[ConstExpr], properties: &EquivalenceProperties, + properties_constants: &[ConstExpr], ) -> bool { - // can't augment if there is nothing to augment with - if constants.is_empty() { - return false; - } - let start_num_orderings = self.orderings.len(); - - // for each equivalent ordering in properties, try and augment - // `ordering` it with the constants to match - for existing_ordering in properties.oeq_class.iter() { - if let Some(augmented_ordering) = self.augment_ordering( - ordering, - constants, - existing_ordering, - &properties.constants, - ) { - if !augmented_ordering.is_empty() { - assert!(properties.ordering_satisfy(augmented_ordering.as_ref())); + let mut result = false; + // Can only augment if there are constants. + if !constants.is_empty() { + // For each equivalent ordering in properties, try and augment + // `ordering` with the constants to match `existing_ordering`: + for existing_ordering in properties.oeq_class.iter() { + if let Some(augmented_ordering) = Self::augment_ordering( + ordering, + constants, + existing_ordering, + properties_constants, + ) { self.orderings.push(augmented_ordering); + result = true; } } } - - self.orderings.len() > start_num_orderings + result } - /// Attempts to augment the ordering with constants to match the - /// `existing_ordering` - /// - /// Returns Some(ordering) if an augmented ordering was found, None otherwise + /// Attempts to augment the ordering with constants to match `existing_ordering`. + /// Returns `Some(ordering)` if an augmented ordering was found, `None` otherwise. fn augment_ordering( - &mut self, ordering: &LexOrdering, constants: &[ConstExpr], existing_ordering: &LexOrdering, existing_constants: &[ConstExpr], ) -> Option { - let mut augmented_ordering = LexOrdering::default(); - let mut sort_expr_iter = ordering.iter().peekable(); - let mut existing_sort_expr_iter = existing_ordering.iter().peekable(); - - // walk in parallel down the two orderings, trying to match them up - while sort_expr_iter.peek().is_some() || existing_sort_expr_iter.peek().is_some() - { - // If the next expressions are equal, add the next match - // otherwise try and match with a constant + let mut augmented_ordering = vec![]; + let mut sort_exprs = ordering.iter().peekable(); + let mut existing_sort_exprs = existing_ordering.iter().peekable(); + + // Walk in parallel down the two orderings, trying to match them up: + while sort_exprs.peek().is_some() || existing_sort_exprs.peek().is_some() { + // If the next expressions are equal, add the next match. Otherwise, + // try and match with a constant. 
             if let Some(expr) =
-                advance_if_match(&mut sort_expr_iter, &mut existing_sort_expr_iter)
+                advance_if_match(&mut sort_exprs, &mut existing_sort_exprs)
             {
                 augmented_ordering.push(expr);
             } else if let Some(expr) =
-                advance_if_matches_constant(&mut sort_expr_iter, existing_constants)
+                advance_if_matches_constant(&mut sort_exprs, existing_constants)
             {
                 augmented_ordering.push(expr);
             } else if let Some(expr) =
-                advance_if_matches_constant(&mut existing_sort_expr_iter, constants)
+                advance_if_matches_constant(&mut existing_sort_exprs, constants)
             {
                 augmented_ordering.push(expr);
             } else {
@@ -274,7 +268,7 @@ impl UnionEquivalentOrderingBuilder {
             }
         }

-        Some(augmented_ordering)
+        LexOrdering::new(augmented_ordering)
     }

     fn build(self) -> Vec<LexOrdering> {
@@ -282,134 +276,135 @@ impl UnionEquivalentOrderingBuilder {
     }
 }

-/// Advances two iterators in parallel
-///
-/// If the next expressions are equal, the iterators are advanced and returns
-/// the matched expression .
-///
-/// Otherwise, the iterators are left unchanged and return `None`
-fn advance_if_match(
-    iter1: &mut Peekable<Iter<PhysicalSortExpr>>,
-    iter2: &mut Peekable<Iter<PhysicalSortExpr>>,
+/// Advances two iterators in parallel if their next expressions are equal.
+/// Otherwise, leaves the iterators unchanged and returns `None`.
+fn advance_if_match<'a>(
+    iter1: &mut Peekable<impl Iterator<Item = &'a PhysicalSortExpr>>,
+    iter2: &mut Peekable<impl Iterator<Item = &'a PhysicalSortExpr>>,
 ) -> Option<PhysicalSortExpr> {
-    if matches!((iter1.peek(), iter2.peek()), (Some(expr1), Some(expr2)) if expr1.eq(expr2))
-    {
-        iter1.next().unwrap();
+    let (expr1, expr2) = (iter1.peek()?, iter2.peek()?);
+    if expr1.eq(expr2) {
+        iter1.next();
         iter2.next().cloned()
     } else {
         None
     }
 }

-/// Advances the iterator with a constant
-///
-/// If the next expression matches one of the constants, advances the iterator
-/// returning the matched expression
-///
-/// Otherwise, the iterator is left unchanged and returns `None`
-fn advance_if_matches_constant(
-    iter: &mut Peekable<Iter<PhysicalSortExpr>>,
+/// Advances the iterator with a constant if its next expression matches one of
+/// the constants. Otherwise, leaves the iterator unchanged and returns `None`.
+fn advance_if_matches_constant<'a>(
+    iter: &mut Peekable<impl Iterator<Item = &'a PhysicalSortExpr>>,
     constants: &[ConstExpr],
 ) -> Option<PhysicalSortExpr> {
     let expr = iter.peek()?;
-    let const_expr = constants.iter().find(|c| c.eq_expr(expr))?;
-    let found_expr = PhysicalSortExpr::new(Arc::clone(const_expr.expr()), expr.options);
+    let const_expr = constants.iter().find(|c| expr.expr.eq(&c.expr))?;
+    let found_expr = PhysicalSortExpr::new(Arc::clone(&const_expr.expr), expr.options);
     iter.next();
     Some(found_expr)
 }

 #[cfg(test)]
 mod tests {
-    use super::*;
-    use crate::equivalence::class::const_exprs_contains;
     use crate::equivalence::tests::{create_test_schema, parse_sort_expr};
     use crate::expressions::col;
+    use crate::PhysicalExpr;

     use arrow::datatypes::{DataType, Field, Schema};
     use datafusion_common::ScalarValue;
     use itertools::Itertools;

+    /// Checks whether `expr` is among the expressions in `const_exprs`.
+    fn const_exprs_contains(
+        const_exprs: &[ConstExpr],
+        expr: &Arc<dyn PhysicalExpr>,
+    ) -> bool {
+        const_exprs
+            .iter()
+            .any(|const_expr| const_expr.expr.eq(expr))
+    }
+
     #[test]
-    fn test_union_equivalence_properties_multi_children_1() {
+    fn test_union_equivalence_properties_multi_children_1() -> Result<()> {
         let schema = create_test_schema().unwrap();
         let schema2 = append_fields(&schema, "1");
         let schema3 = append_fields(&schema, "2");
         UnionEquivalenceTest::new(&schema)
             // Children 1
-            .with_child_sort(vec![vec!["a", "b", "c"]], &schema)
+            .with_child_sort(vec![vec!["a", "b", "c"]], &schema)?
// Children 2 - .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2) + .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2)? // Children 3 - .with_child_sort(vec![vec!["a2", "b2"]], &schema3) - .with_expected_sort(vec![vec!["a", "b"]]) + .with_child_sort(vec![vec!["a2", "b2"]], &schema3)? + .with_expected_sort(vec![vec!["a", "b"]])? .run() } #[test] - fn test_union_equivalence_properties_multi_children_2() { + fn test_union_equivalence_properties_multi_children_2() -> Result<()> { let schema = create_test_schema().unwrap(); let schema2 = append_fields(&schema, "1"); let schema3 = append_fields(&schema, "2"); UnionEquivalenceTest::new(&schema) // Children 1 - .with_child_sort(vec![vec!["a", "b", "c"]], &schema) + .with_child_sort(vec![vec!["a", "b", "c"]], &schema)? // Children 2 - .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2) + .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2)? // Children 3 - .with_child_sort(vec![vec!["a2", "b2", "c2"]], &schema3) - .with_expected_sort(vec![vec!["a", "b", "c"]]) + .with_child_sort(vec![vec!["a2", "b2", "c2"]], &schema3)? + .with_expected_sort(vec![vec!["a", "b", "c"]])? .run() } #[test] - fn test_union_equivalence_properties_multi_children_3() { + fn test_union_equivalence_properties_multi_children_3() -> Result<()> { let schema = create_test_schema().unwrap(); let schema2 = append_fields(&schema, "1"); let schema3 = append_fields(&schema, "2"); UnionEquivalenceTest::new(&schema) // Children 1 - .with_child_sort(vec![vec!["a", "b"]], &schema) + .with_child_sort(vec![vec!["a", "b"]], &schema)? // Children 2 - .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2) + .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2)? // Children 3 - .with_child_sort(vec![vec!["a2", "b2", "c2"]], &schema3) - .with_expected_sort(vec![vec!["a", "b"]]) + .with_child_sort(vec![vec!["a2", "b2", "c2"]], &schema3)? + .with_expected_sort(vec![vec!["a", "b"]])? .run() } #[test] - fn test_union_equivalence_properties_multi_children_4() { + fn test_union_equivalence_properties_multi_children_4() -> Result<()> { let schema = create_test_schema().unwrap(); let schema2 = append_fields(&schema, "1"); let schema3 = append_fields(&schema, "2"); UnionEquivalenceTest::new(&schema) // Children 1 - .with_child_sort(vec![vec!["a", "b"]], &schema) + .with_child_sort(vec![vec!["a", "b"]], &schema)? // Children 2 - .with_child_sort(vec![vec!["a1", "b1"]], &schema2) + .with_child_sort(vec![vec!["a1", "b1"]], &schema2)? // Children 3 - .with_child_sort(vec![vec!["b2", "c2"]], &schema3) - .with_expected_sort(vec![]) + .with_child_sort(vec![vec!["b2", "c2"]], &schema3)? + .with_expected_sort(vec![])? .run() } #[test] - fn test_union_equivalence_properties_multi_children_5() { + fn test_union_equivalence_properties_multi_children_5() -> Result<()> { let schema = create_test_schema().unwrap(); let schema2 = append_fields(&schema, "1"); UnionEquivalenceTest::new(&schema) // Children 1 - .with_child_sort(vec![vec!["a", "b"], vec!["c"]], &schema) + .with_child_sort(vec![vec!["a", "b"], vec!["c"]], &schema)? // Children 2 - .with_child_sort(vec![vec!["a1", "b1"], vec!["c1"]], &schema2) - .with_expected_sort(vec![vec!["a", "b"], vec!["c"]]) + .with_child_sort(vec![vec!["a1", "b1"], vec!["c1"]], &schema2)? + .with_expected_sort(vec![vec!["a", "b"], vec!["c"]])? 
.run() } #[test] - fn test_union_equivalence_properties_constants_common_constants() { + fn test_union_equivalence_properties_constants_common_constants() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -417,23 +412,23 @@ mod tests { vec![vec!["a"]], vec!["b", "c"], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child: [b ASC], const [a, c] vec![vec!["b"]], vec!["a", "c"], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union expected orderings: [[a ASC], [b ASC]], const [c] vec![vec!["a"], vec!["b"]], vec!["c"], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_prefix() { + fn test_union_equivalence_properties_constants_prefix() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -441,23 +436,23 @@ mod tests { vec![vec!["a"]], vec![], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child: [a ASC, b ASC], const [] vec![vec!["a", "b"]], vec![], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union orderings: [a ASC], const [] vec![vec!["a"]], vec![], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_asc_desc_mismatch() { + fn test_union_equivalence_properties_constants_asc_desc_mismatch() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -465,23 +460,23 @@ mod tests { vec![vec!["a"]], vec![], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child orderings: [a DESC], const [] vec![vec!["a DESC"]], vec![], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union doesn't have any ordering or constant vec![], vec![], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_different_schemas() { + fn test_union_equivalence_properties_constants_different_schemas() -> Result<()> { let schema = create_test_schema().unwrap(); let schema2 = append_fields(&schema, "1"); UnionEquivalenceTest::new(&schema) @@ -490,13 +485,13 @@ mod tests { vec![vec!["a"]], vec![], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child orderings: [a1 ASC, b1 ASC], const [] vec![vec!["a1", "b1"]], vec![], &schema2, - ) + )? .with_expected_sort_and_const_exprs( // Union orderings: [a ASC] // @@ -504,12 +499,12 @@ mod tests { // corresponding schemas. vec![vec!["a"]], vec![], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_fill_gaps() { + fn test_union_equivalence_properties_constants_fill_gaps() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -517,13 +512,13 @@ mod tests { vec![vec!["a", "c"]], vec!["b"], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child orderings: [b ASC, c ASC], const [a] vec![vec!["b", "c"]], vec!["a"], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union orderings: [ // [a ASC, b ASC, c ASC], @@ -531,12 +526,12 @@ mod tests { // ], const [] vec![vec!["a", "b", "c"], vec!["b", "a", "c"]], vec![], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_no_fill_gaps() { + fn test_union_equivalence_properties_constants_no_fill_gaps() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -544,23 +539,23 @@ mod tests { vec![vec!["a", "c"]], vec!["d"], &schema, - ) + )? 
.with_child_sort_and_const_exprs( // Second child orderings: [b ASC, c ASC], const [a] vec![vec!["b", "c"]], vec!["a"], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union orderings: [[a]] (only a is constant) vec![vec!["a"]], vec![], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_fill_some_gaps() { + fn test_union_equivalence_properties_constants_fill_some_gaps() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -568,23 +563,24 @@ mod tests { vec![vec!["c"]], vec!["a", "b"], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child orderings: [a DESC, b], const [] vec![vec!["a DESC", "b"]], vec![], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union orderings: [[a, b]] (can fill in the a/b with constants) vec![vec!["a DESC", "b"]], vec![], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_fill_gaps_non_symmetric() { + fn test_union_equivalence_properties_constants_fill_gaps_non_symmetric() -> Result<()> + { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -592,13 +588,13 @@ mod tests { vec![vec!["a", "c"]], vec!["b"], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child orderings: [b ASC, c ASC], const [a] vec![vec!["b DESC", "c"]], vec!["a"], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union orderings: [ // [a ASC, b ASC, c ASC], @@ -606,12 +602,12 @@ mod tests { // ], const [] vec![vec!["a", "b DESC", "c"], vec!["b DESC", "a", "c"]], vec![], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_gap_fill_symmetric() { + fn test_union_equivalence_properties_constants_gap_fill_symmetric() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -619,25 +615,25 @@ mod tests { vec![vec!["a", "b", "d"]], vec!["c"], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child: [a ASC, c ASC, d ASC], const [b] vec![vec!["a", "c", "d"]], vec!["b"], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union orderings: // [a, b, c, d] // [a, c, b, d] vec![vec!["a", "c", "b", "d"], vec!["a", "b", "c", "d"]], vec![], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_gap_fill_and_common() { + fn test_union_equivalence_properties_constants_gap_fill_and_common() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -645,24 +641,24 @@ mod tests { vec![vec!["a DESC", "d"]], vec!["b", "c"], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child: [a DESC, c ASC, d ASC], const [b] vec![vec!["a DESC", "c", "d"]], vec!["b"], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union orderings: // [a DESC, c, d] [b] vec![vec!["a DESC", "c", "d"]], vec!["b"], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_middle_desc() { + fn test_union_equivalence_properties_constants_middle_desc() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -672,20 +668,20 @@ mod tests { vec![vec!["a", "b DESC", "d"]], vec!["c"], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child: [a ASC, c ASC, d ASC], const [b] vec![vec!["a", "c", "d"]], vec!["b"], &schema, - ) + )? 
.with_expected_sort_and_const_exprs( // Union orderings: // [a, b, d] (c constant) // [a, c, d] (b constant) vec![vec!["a", "c", "b DESC", "d"], vec!["a", "b DESC", "c", "d"]], vec![], - ) + )? .run() } @@ -718,10 +714,10 @@ mod tests { mut self, orderings: Vec>, schema: &SchemaRef, - ) -> Self { - let properties = self.make_props(orderings, vec![], schema); + ) -> Result { + let properties = self.make_props(orderings, vec![], schema)?; self.child_properties.push(properties); - self + Ok(self) } /// Add a union input with the specified orderings and constant @@ -734,19 +730,19 @@ mod tests { orderings: Vec>, constants: Vec<&str>, schema: &SchemaRef, - ) -> Self { - let properties = self.make_props(orderings, constants, schema); + ) -> Result { + let properties = self.make_props(orderings, constants, schema)?; self.child_properties.push(properties); - self + Ok(self) } /// Set the expected output sort order for the union of the children /// /// See [`Self::make_props`] for the format of the strings in `orderings` - fn with_expected_sort(mut self, orderings: Vec>) -> Self { - let properties = self.make_props(orderings, vec![], &self.output_schema); + fn with_expected_sort(mut self, orderings: Vec>) -> Result { + let properties = self.make_props(orderings, vec![], &self.output_schema)?; self.expected_properties = Some(properties); - self + Ok(self) } /// Set the expected output sort order and constant expressions for the @@ -758,15 +754,16 @@ mod tests { mut self, orderings: Vec>, constants: Vec<&str>, - ) -> Self { - let properties = self.make_props(orderings, constants, &self.output_schema); + ) -> Result { + let properties = + self.make_props(orderings, constants, &self.output_schema)?; self.expected_properties = Some(properties); - self + Ok(self) } /// compute the union's output equivalence properties from the child /// properties, and compare them to the expected properties - fn run(self) { + fn run(self) -> Result<()> { let Self { output_schema, child_properties, @@ -798,6 +795,7 @@ mod tests { ), ); } + Ok(()) } fn assert_eq_properties_same( @@ -808,9 +806,9 @@ mod tests { // Check whether constants are same let lhs_constants = lhs.constants(); let rhs_constants = rhs.constants(); - for rhs_constant in rhs_constants { + for rhs_constant in &rhs_constants { assert!( - const_exprs_contains(lhs_constants, rhs_constant.expr()), + const_exprs_contains(&lhs_constants, &rhs_constant.expr), "{err_msg}\nlhs: {lhs}\nrhs: {rhs}" ); } @@ -845,24 +843,19 @@ mod tests { orderings: Vec>, constants: Vec<&str>, schema: &SchemaRef, - ) -> EquivalenceProperties { - let orderings = orderings - .iter() - .map(|ordering| { - ordering - .iter() - .map(|name| parse_sort_expr(name, schema)) - .collect::() - }) - .collect::>(); + ) -> Result { + let orderings = orderings.iter().map(|ordering| { + ordering.iter().map(|name| parse_sort_expr(name, schema)) + }); let constants = constants .iter() - .map(|col_name| ConstExpr::new(col(col_name, schema).unwrap())) - .collect::>(); + .map(|col_name| ConstExpr::from(col(col_name, schema).unwrap())); - EquivalenceProperties::new_with_orderings(Arc::clone(schema), &orderings) - .with_constants(constants) + let mut props = + EquivalenceProperties::new_with_orderings(Arc::clone(schema), orderings); + props.add_constants(constants)?; + Ok(props) } } @@ -877,25 +870,29 @@ mod tests { let literal_10 = ScalarValue::Int32(Some(10)); // Create first input with a=10 - let const_expr1 = ConstExpr::new(Arc::clone(&col_a)) - 
.with_across_partitions(AcrossPartitions::Uniform(Some(literal_10.clone()))); - let input1 = EquivalenceProperties::new(Arc::clone(&schema)) - .with_constants(vec![const_expr1]); + let const_expr1 = ConstExpr::new( + Arc::clone(&col_a), + AcrossPartitions::Uniform(Some(literal_10.clone())), + ); + let mut input1 = EquivalenceProperties::new(Arc::clone(&schema)); + input1.add_constants(vec![const_expr1])?; // Create second input with a=10 - let const_expr2 = ConstExpr::new(Arc::clone(&col_a)) - .with_across_partitions(AcrossPartitions::Uniform(Some(literal_10.clone()))); - let input2 = EquivalenceProperties::new(Arc::clone(&schema)) - .with_constants(vec![const_expr2]); + let const_expr2 = ConstExpr::new( + Arc::clone(&col_a), + AcrossPartitions::Uniform(Some(literal_10.clone())), + ); + let mut input2 = EquivalenceProperties::new(Arc::clone(&schema)); + input2.add_constants(vec![const_expr2])?; // Calculate union properties let union_props = calculate_union(vec![input1, input2], schema)?; // Verify column 'a' remains constant with value 10 let const_a = &union_props.constants()[0]; - assert!(const_a.expr().eq(&col_a)); + assert!(const_a.expr.eq(&col_a)); assert_eq!( - const_a.across_partitions(), + const_a.across_partitions, AcrossPartitions::Uniform(Some(literal_10)) ); diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 7e345e60271f..c91678317b75 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -97,8 +97,10 @@ impl CastExpr { pub fn cast_options(&self) -> &CastOptions<'static> { &self.cast_options } - pub fn is_bigger_cast(&self, src: DataType) -> bool { - if src == self.cast_type { + + /// Check if the cast is a widening cast (e.g. from `Int8` to `Int16`). 
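For intuition about the widening check documented above, the standalone sketch below shows what such a test can look like for signed integers using arrow's `DataType`. The helper name is hypothetical and intentionally narrower than the real method, which handles more type families.

```rust
use arrow::datatypes::DataType;

/// Illustrative helper (not a DataFusion API): returns `true` if casting
/// `src` to `dst` is a no-op or a lossless widening between signed integers.
fn is_widening_int_cast(src: &DataType, dst: &DataType) -> bool {
    if src == dst {
        // Casting a type to itself never loses information.
        return true;
    }
    matches!(
        (src, dst),
        (DataType::Int8, DataType::Int16 | DataType::Int32 | DataType::Int64)
            | (DataType::Int16, DataType::Int32 | DataType::Int64)
            | (DataType::Int32, DataType::Int64)
    )
}

fn main() {
    assert!(is_widening_int_cast(&DataType::Int8, &DataType::Int64));
    // Narrowing casts can lose information, so they are not "bigger":
    assert!(!is_widening_int_cast(&DataType::Int64, &DataType::Int32));
}
```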
+ pub fn is_bigger_cast(&self, src: &DataType) -> bool { + if self.cast_type.eq(src) { return true; } matches!( diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 9f795c81fa48..6741f94c9545 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -54,20 +54,19 @@ pub use equivalence::{ }; pub use partitioning::{Distribution, Partitioning}; pub use physical_expr::{ - create_ordering, create_physical_sort_expr, create_physical_sort_exprs, - physical_exprs_bag_equal, physical_exprs_contains, physical_exprs_equal, - PhysicalExprRef, + add_offset_to_expr, add_offset_to_physical_sort_exprs, create_ordering, + create_physical_sort_expr, create_physical_sort_exprs, physical_exprs_bag_equal, + physical_exprs_contains, physical_exprs_equal, }; -pub use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +pub use datafusion_physical_expr_common::physical_expr::{PhysicalExpr, PhysicalExprRef}; pub use datafusion_physical_expr_common::sort_expr::{ - LexOrdering, LexRequirement, PhysicalSortExpr, PhysicalSortRequirement, + LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortExpr, + PhysicalSortRequirement, }; pub use planner::{create_physical_expr, create_physical_exprs}; pub use scalar_function::ScalarFunctionExpr; - -pub use datafusion_physical_expr_common::utils::reverse_order_bys; pub use utils::{conjunction, conjunction_opt, split_conjunction}; // For backwards compatibility diff --git a/datafusion/physical-expr/src/partitioning.rs b/datafusion/physical-expr/src/partitioning.rs index eb7e1ea6282b..d6b2b1b046f7 100644 --- a/datafusion/physical-expr/src/partitioning.rs +++ b/datafusion/physical-expr/src/partitioning.rs @@ -199,18 +199,17 @@ impl Partitioning { /// Calculate the output partitioning after applying the given projection. 
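The projection logic below maps each hash-partitioning expression through the projection and falls back to an "unknown column" placeholder when an expression does not survive it. Here is a simplified model with strings standing in for physical expressions; all names are illustrative only.

```rust
use std::collections::HashMap;

// Simplified stand-in for projecting hash partitioning expressions: map each
// expression through the projection, falling back to a placeholder (analogous
// to `UnKnownColumn`) when the expression has no image under the projection.
fn project_hash_exprs(
    exprs: &[&str],
    projection: &HashMap<&str, &str>,
) -> Vec<String> {
    exprs
        .iter()
        .map(|e| match projection.get(e) {
            Some(p) => (*p).to_string(),
            None => format!("unknown({e})"),
        })
        .collect()
}

fn main() {
    let projection = HashMap::from([("a", "a_renamed")]);
    assert_eq!(
        project_hash_exprs(&["a", "b"], &projection),
        vec!["a_renamed".to_string(), "unknown(b)".to_string()]
    );
}
```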
pub fn project( &self, - projection_mapping: &ProjectionMapping, + mapping: &ProjectionMapping, input_eq_properties: &EquivalenceProperties, ) -> Self { if let Partitioning::Hash(exprs, part) = self { - let normalized_exprs = exprs - .iter() - .map(|expr| { - input_eq_properties - .project_expr(expr, projection_mapping) - .unwrap_or_else(|| { - Arc::new(UnKnownColumn::new(&expr.to_string())) - }) + let normalized_exprs = input_eq_properties + .project_expressions(exprs, mapping) + .zip(exprs) + .map(|(proj_expr, expr)| { + proj_expr.unwrap_or_else(|| { + Arc::new(UnKnownColumn::new(&expr.to_string())) + }) }) .collect(); Partitioning::Hash(normalized_exprs, *part) diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index 63c4ccbb4b38..80dd8ce069b7 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -17,13 +17,40 @@ use std::sync::Arc; -use crate::create_physical_expr; +use crate::expressions::{self, Column}; +use crate::{create_physical_expr, LexOrdering, PhysicalSortExpr}; + +use arrow::compute::SortOptions; +use arrow::datatypes::Schema; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{plan_err, Result}; use datafusion_common::{DFSchema, HashMap}; use datafusion_expr::execution_props::ExecutionProps; -pub(crate) use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -pub use datafusion_physical_expr_common::physical_expr::PhysicalExprRef; +use datafusion_expr::{Expr, SortExpr}; + use itertools::izip; +// Exports: +pub(crate) use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + +/// Adds the `offset` value to `Column` indices inside `expr`. This function is +/// generally used during the update of the right table schema in join operations. +pub fn add_offset_to_expr( + expr: Arc, + offset: isize, +) -> Result> { + expr.transform_down(|e| match e.as_any().downcast_ref::() { + Some(col) => { + let Some(idx) = col.index().checked_add_signed(offset) else { + return plan_err!("Column index overflow"); + }; + Ok(Transformed::yes(Arc::new(Column::new(col.name(), idx)))) + } + None => Ok(Transformed::no(e)), + }) + .data() +} + /// This function is similar to the `contains` method of `Vec`. It finds /// whether `expr` is among `physical_exprs`. pub fn physical_exprs_contains( @@ -60,26 +87,21 @@ pub fn physical_exprs_bag_equal( multi_set_lhs == multi_set_rhs } -use crate::{expressions, LexOrdering, PhysicalSortExpr}; -use arrow::compute::SortOptions; -use arrow::datatypes::Schema; -use datafusion_common::plan_err; -use datafusion_common::Result; -use datafusion_expr::{Expr, SortExpr}; - -/// Converts logical sort expressions to physical sort expressions +/// Converts logical sort expressions to physical sort expressions. /// -/// This function transforms a collection of logical sort expressions into their physical -/// representation that can be used during query execution. +/// This function transforms a collection of logical sort expressions into their +/// physical representation that can be used during query execution. /// /// # Arguments /// -/// * `schema` - The schema containing column definitions -/// * `sort_order` - A collection of logical sort expressions grouped into lexicographic orderings +/// * `schema` - The schema containing column definitions. +/// * `sort_order` - A collection of logical sort expressions grouped into +/// lexicographic orderings. 
/// /// # Returns /// -/// A vector of lexicographic orderings for physical execution, or an error if the transformation fails +/// A vector of lexicographic orderings for physical execution, or an error if +/// the transformation fails. /// /// # Examples /// @@ -114,18 +136,13 @@ pub fn create_ordering( for (group_idx, exprs) in sort_order.iter().enumerate() { // Construct PhysicalSortExpr objects from Expr objects: - let mut sort_exprs = LexOrdering::default(); + let mut sort_exprs = vec![]; for (expr_idx, sort) in exprs.iter().enumerate() { match &sort.expr { Expr::Column(col) => match expressions::col(&col.name, schema) { Ok(expr) => { - sort_exprs.push(PhysicalSortExpr { - expr, - options: SortOptions { - descending: !sort.asc, - nulls_first: sort.nulls_first, - }, - }); + let opts = SortOptions::new(!sort.asc, sort.nulls_first); + sort_exprs.push(PhysicalSortExpr::new(expr, opts)); } // Cannot find expression in the projected_schema, stop iterating // since rest of the orderings are violated @@ -141,9 +158,7 @@ pub fn create_ordering( } } } - if !sort_exprs.is_empty() { - all_sort_orders.push(sort_exprs); - } + all_sort_orders.extend(LexOrdering::new(sort_exprs)); } Ok(all_sort_orders) } @@ -154,17 +169,9 @@ pub fn create_physical_sort_expr( input_dfschema: &DFSchema, execution_props: &ExecutionProps, ) -> Result { - let SortExpr { - expr, - asc, - nulls_first, - } = e; - Ok(PhysicalSortExpr { - expr: create_physical_expr(expr, input_dfschema, execution_props)?, - options: SortOptions { - descending: !asc, - nulls_first: *nulls_first, - }, + create_physical_expr(&e.expr, input_dfschema, execution_props).map(|expr| { + let options = SortOptions::new(!e.asc, e.nulls_first); + PhysicalSortExpr::new(expr, options) }) } @@ -173,11 +180,24 @@ pub fn create_physical_sort_exprs( exprs: &[SortExpr], input_dfschema: &DFSchema, execution_props: &ExecutionProps, -) -> Result { +) -> Result> { exprs .iter() - .map(|expr| create_physical_sort_expr(expr, input_dfschema, execution_props)) - .collect::>() + .map(|e| create_physical_sort_expr(e, input_dfschema, execution_props)) + .collect() +} + +pub fn add_offset_to_physical_sort_exprs( + sort_exprs: impl IntoIterator, + offset: isize, +) -> Result> { + sort_exprs + .into_iter() + .map(|mut sort_expr| { + sort_expr.expr = add_offset_to_expr(sort_expr.expr, offset)?; + Ok(sort_expr) + }) + .collect() } #[cfg(test)] diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 8660bff796d5..7cfaaba553f9 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -380,7 +380,7 @@ where exprs .into_iter() .map(|expr| create_physical_expr(expr, input_dfschema, execution_props)) - .collect::>>() + .collect() } /// Convert a logical expression to a physical expression (without any simplification, etc) diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index b4d0758fd2e8..4d01389f00f6 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -33,8 +33,6 @@ use datafusion_common::tree_node::{ use datafusion_common::{HashMap, HashSet, Result}; use datafusion_expr::Operator; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use itertools::Itertools; use petgraph::graph::NodeIndex; use petgraph::stable_graph::StableGraph; @@ -266,15 +264,6 @@ pub fn reassign_predicate_columns( .data() } -/// Merge left and right sort expressions, checking for duplicates. 
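The overflow check in `add_offset_to_expr` and `add_offset_to_physical_sort_exprs` above boils down to `usize::checked_add_signed`: shifting a column index by a signed offset can underflow, and the checked form surfaces that as `None` instead of wrapping. A minimal sketch of that index arithmetic, with a hypothetical helper name:

```rust
// Shift a column index by a signed offset, as done when moving expressions
// between a join output schema and one of its input schemas. Returns `None`
// on overflow or underflow.
fn shift_column_index(index: usize, offset: isize) -> Option<usize> {
    index.checked_add_signed(offset)
}

fn main() {
    assert_eq!(shift_column_index(3, 2), Some(5));
    assert_eq!(shift_column_index(1, -1), Some(0));
    // Shifting column 0 left by one has no valid result:
    assert_eq!(shift_column_index(0, -1), None);
}
```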
-pub fn merge_vectors(left: &LexOrdering, right: &LexOrdering) -> LexOrdering { - left.iter() - .cloned() - .chain(right.iter().cloned()) - .unique() - .collect() -} - #[cfg(test)] pub(crate) mod tests { use std::any::Any; diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index dae0667afb25..6f0e7c963d14 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -27,7 +27,7 @@ use crate::window::window_expr::AggregateWindowExpr; use crate::window::{ PartitionBatches, PartitionWindowAggStates, SlidingAggregateWindowExpr, WindowExpr, }; -use crate::{reverse_order_bys, EquivalenceProperties, PhysicalExpr}; +use crate::{EquivalenceProperties, PhysicalExpr}; use arrow::array::Array; use arrow::array::ArrayRef; @@ -35,7 +35,7 @@ use arrow::datatypes::FieldRef; use arrow::record_batch::RecordBatch; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::{Accumulator, WindowFrame, WindowFrameBound, WindowFrameUnits}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; /// A window expr that takes the form of an aggregate function. /// @@ -44,7 +44,7 @@ use datafusion_physical_expr_common::sort_expr::LexOrdering; pub struct PlainAggregateWindowExpr { aggregate: Arc, partition_by: Vec>, - order_by: LexOrdering, + order_by: Vec, window_frame: Arc, is_constant_in_partition: bool, } @@ -54,7 +54,7 @@ impl PlainAggregateWindowExpr { pub fn new( aggregate: Arc, partition_by: &[Arc], - order_by: &LexOrdering, + order_by: &[PhysicalSortExpr], window_frame: Arc, ) -> Self { let is_constant_in_partition = @@ -62,7 +62,7 @@ impl PlainAggregateWindowExpr { Self { aggregate, partition_by: partition_by.to_vec(), - order_by: order_by.clone(), + order_by: order_by.to_vec(), window_frame, is_constant_in_partition, } @@ -77,7 +77,7 @@ impl PlainAggregateWindowExpr { &self, eq_properties: &mut EquivalenceProperties, window_expr_index: usize, - ) { + ) -> Result<()> { if let Some(expr) = self .get_aggregate_expr() .get_result_ordering(window_expr_index) @@ -86,8 +86,9 @@ impl PlainAggregateWindowExpr { eq_properties, expr, &self.partition_by, - ); + )?; } + Ok(()) } // Returns true if every row in the partition has the same window frame. This allows @@ -100,7 +101,7 @@ impl PlainAggregateWindowExpr { // This results in an invalid range specification. Following PostgreSQL’s convention, // we interpret this as the entire partition being used for the current window frame. 
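As a rough model of the constant-frame predicate defined below: a window frame is the same for every row of a partition when neither of its bounds moves with the current row. The types here are simplified stand-ins, not DataFusion's `WindowFrameBound`, and the real check also accounts for the ORDER BY clause.

```rust
// Illustrative frame-bound model: only unbounded ends are row-independent.
#[allow(dead_code)]
enum Bound {
    UnboundedPreceding,
    UnboundedFollowing,
    CurrentRow,
    Offset(u64),
}

fn is_constant_bound(bound: &Bound) -> bool {
    matches!(bound, Bound::UnboundedPreceding | Bound::UnboundedFollowing)
}

/// The frame covers the same rows for every row of the partition only when
/// both ends are constant (i.e. unbounded).
fn is_window_constant(start: &Bound, end: &Bound) -> bool {
    is_constant_bound(start) && is_constant_bound(end)
}

fn main() {
    assert!(is_window_constant(
        &Bound::UnboundedPreceding,
        &Bound::UnboundedFollowing
    ));
    // `... AND CURRENT ROW` moves with the row, so it is not constant:
    assert!(!is_window_constant(
        &Bound::UnboundedPreceding,
        &Bound::CurrentRow
    ));
}
```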
fn is_window_constant_in_partition( - order_by: &LexOrdering, + order_by: &[PhysicalSortExpr], window_frame: &WindowFrame, ) -> bool { let is_constant_bound = |bound: &WindowFrameBound| match bound { @@ -170,8 +171,8 @@ impl WindowExpr for PlainAggregateWindowExpr { &self.partition_by } - fn order_by(&self) -> &LexOrdering { - self.order_by.as_ref() + fn order_by(&self) -> &[PhysicalSortExpr] { + &self.order_by } fn get_window_frame(&self) -> &Arc { @@ -185,14 +186,22 @@ impl WindowExpr for PlainAggregateWindowExpr { Arc::new(PlainAggregateWindowExpr::new( Arc::new(reverse_expr), &self.partition_by.clone(), - reverse_order_bys(self.order_by.as_ref()).as_ref(), + &self + .order_by + .iter() + .map(|e| e.reverse()) + .collect::>(), Arc::new(self.window_frame.reverse()), )) as _ } else { Arc::new(SlidingAggregateWindowExpr::new( Arc::new(reverse_expr), &self.partition_by.clone(), - reverse_order_bys(self.order_by.as_ref()).as_ref(), + &self + .order_by + .iter() + .map(|e| e.reverse()) + .collect::>(), Arc::new(self.window_frame.reverse()), )) as _ } diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 09d6af748755..33921a57a6ce 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -26,14 +26,13 @@ use crate::window::window_expr::AggregateWindowExpr; use crate::window::{ PartitionBatches, PartitionWindowAggStates, PlainAggregateWindowExpr, WindowExpr, }; -use crate::{expressions::PhysicalSortExpr, reverse_order_bys, PhysicalExpr}; +use crate::{expressions::PhysicalSortExpr, PhysicalExpr}; use arrow::array::{Array, ArrayRef}; use arrow::datatypes::FieldRef; use arrow::record_batch::RecordBatch; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{Accumulator, WindowFrame}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; /// A window expr that takes the form of an aggregate function that /// can be incrementally computed over sliding windows. 
@@ -43,7 +42,7 @@ use datafusion_physical_expr_common::sort_expr::LexOrdering; pub struct SlidingAggregateWindowExpr { aggregate: Arc, partition_by: Vec>, - order_by: LexOrdering, + order_by: Vec, window_frame: Arc, } @@ -52,13 +51,13 @@ impl SlidingAggregateWindowExpr { pub fn new( aggregate: Arc, partition_by: &[Arc], - order_by: &LexOrdering, + order_by: &[PhysicalSortExpr], window_frame: Arc, ) -> Self { Self { aggregate, partition_by: partition_by.to_vec(), - order_by: order_by.clone(), + order_by: order_by.to_vec(), window_frame, } } @@ -108,8 +107,8 @@ impl WindowExpr for SlidingAggregateWindowExpr { &self.partition_by } - fn order_by(&self) -> &LexOrdering { - self.order_by.as_ref() + fn order_by(&self) -> &[PhysicalSortExpr] { + &self.order_by } fn get_window_frame(&self) -> &Arc { @@ -123,14 +122,22 @@ impl WindowExpr for SlidingAggregateWindowExpr { Arc::new(PlainAggregateWindowExpr::new( Arc::new(reverse_expr), &self.partition_by.clone(), - reverse_order_bys(self.order_by.as_ref()).as_ref(), + &self + .order_by + .iter() + .map(|e| e.reverse()) + .collect::>(), Arc::new(self.window_frame.reverse()), )) as _ } else { Arc::new(SlidingAggregateWindowExpr::new( Arc::new(reverse_expr), &self.partition_by.clone(), - reverse_order_bys(self.order_by.as_ref()).as_ref(), + &self + .order_by + .iter() + .map(|e| e.reverse()) + .collect::>(), Arc::new(self.window_frame.reverse()), )) as _ } @@ -157,7 +164,7 @@ impl WindowExpr for SlidingAggregateWindowExpr { expr: new_expr, options: req.options, }) - .collect::(); + .collect(); Some(Arc::new(SlidingAggregateWindowExpr { aggregate: self .aggregate diff --git a/datafusion/physical-expr/src/window/standard.rs b/datafusion/physical-expr/src/window/standard.rs index 73f47b0b6863..c3761aa78f72 100644 --- a/datafusion/physical-expr/src/window/standard.rs +++ b/datafusion/physical-expr/src/window/standard.rs @@ -24,23 +24,23 @@ use std::sync::Arc; use super::{StandardWindowFunctionExpr, WindowExpr}; use crate::window::window_expr::{get_orderby_values, WindowFn}; use crate::window::{PartitionBatches, PartitionWindowAggStates, WindowState}; -use crate::{reverse_order_bys, EquivalenceProperties, PhysicalExpr}; +use crate::{EquivalenceProperties, PhysicalExpr}; + use arrow::array::{new_empty_array, ArrayRef}; -use arrow::compute::SortOptions; use arrow::datatypes::FieldRef; use arrow::record_batch::RecordBatch; use datafusion_common::utils::evaluate_partition_ranges; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::window_state::{WindowAggState, WindowFrameContext}; use datafusion_expr::WindowFrame; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; /// A window expr that takes the form of a [`StandardWindowFunctionExpr`]. #[derive(Debug)] pub struct StandardWindowExpr { expr: Arc, partition_by: Vec>, - order_by: LexOrdering, + order_by: Vec, window_frame: Arc, } @@ -49,13 +49,13 @@ impl StandardWindowExpr { pub fn new( expr: Arc, partition_by: &[Arc], - order_by: &LexOrdering, + order_by: &[PhysicalSortExpr], window_frame: Arc, ) -> Self { Self { expr, partition_by: partition_by.to_vec(), - order_by: order_by.clone(), + order_by: order_by.to_vec(), window_frame, } } @@ -70,15 +70,19 @@ impl StandardWindowExpr { /// If `self.expr` doesn't have an ordering, ordering equivalence properties /// are not updated. Otherwise, ordering equivalence properties are updated /// by the ordering of `self.expr`. 
- pub fn add_equal_orderings(&self, eq_properties: &mut EquivalenceProperties) { + pub fn add_equal_orderings( + &self, + eq_properties: &mut EquivalenceProperties, + ) -> Result<()> { let schema = eq_properties.schema(); if let Some(fn_res_ordering) = self.expr.get_result_ordering(schema) { add_new_ordering_expr_with_partition_by( eq_properties, fn_res_ordering, &self.partition_by, - ); + )?; } + Ok(()) } } @@ -104,16 +108,15 @@ impl WindowExpr for StandardWindowExpr { &self.partition_by } - fn order_by(&self) -> &LexOrdering { - self.order_by.as_ref() + fn order_by(&self) -> &[PhysicalSortExpr] { + &self.order_by } fn evaluate(&self, batch: &RecordBatch) -> Result { let mut evaluator = self.expr.create_evaluator()?; let num_rows = batch.num_rows(); if evaluator.uses_window_frame() { - let sort_options: Vec = - self.order_by.iter().map(|o| o.options).collect(); + let sort_options = self.order_by.iter().map(|o| o.options).collect(); let mut row_wise_results = vec![]; let mut values = self.evaluate_args(batch)?; @@ -253,7 +256,11 @@ impl WindowExpr for StandardWindowExpr { Arc::new(StandardWindowExpr::new( reverse_expr, &self.partition_by.clone(), - reverse_order_bys(self.order_by.as_ref()).as_ref(), + &self + .order_by + .iter() + .map(|e| e.reverse()) + .collect::>(), Arc::new(self.window_frame.reverse()), )) as _ }) @@ -276,10 +283,10 @@ pub(crate) fn add_new_ordering_expr_with_partition_by( eqp: &mut EquivalenceProperties, expr: PhysicalSortExpr, partition_by: &[Arc], -) { +) -> Result<()> { if partition_by.is_empty() { // In the absence of a PARTITION BY, ordering of `self.expr` is global: - eqp.add_new_orderings([LexOrdering::new(vec![expr])]); + eqp.add_ordering([expr]); } else { // If we have a PARTITION BY, standard functions can not introduce // a global ordering unless the existing ordering is compatible @@ -287,10 +294,11 @@ pub(crate) fn add_new_ordering_expr_with_partition_by( // expressions and existing ordering expressions are equal (w.r.t. // set equality), we can prefix the ordering of `self.expr` with // the existing ordering. 
- let (mut ordering, _) = eqp.find_longest_permutation(partition_by); + let (mut ordering, _) = eqp.find_longest_permutation(partition_by)?; if ordering.len() == partition_by.len() { ordering.push(expr); - eqp.add_new_orderings([ordering]); + eqp.add_ordering(ordering); } } + Ok(()) } diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 70a73c44ae9d..59f9e1780731 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -20,7 +20,7 @@ use std::fmt::Debug; use std::ops::Range; use std::sync::Arc; -use crate::{LexOrdering, PhysicalExpr}; +use crate::PhysicalExpr; use arrow::array::{new_empty_array, Array, ArrayRef}; use arrow::compute::kernels::sort::SortColumn; @@ -33,6 +33,7 @@ use datafusion_expr::window_state::{ PartitionBatchState, WindowAggState, WindowFrameContext, WindowFrameStateGroups, }; use datafusion_expr::{Accumulator, PartitionEvaluator, WindowFrame, WindowFrameBound}; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use indexmap::IndexMap; @@ -109,14 +110,14 @@ pub trait WindowExpr: Send + Sync + Debug { fn partition_by(&self) -> &[Arc]; /// Expressions that's from the window function's order by clause, empty if absent - fn order_by(&self) -> &LexOrdering; + fn order_by(&self) -> &[PhysicalSortExpr]; /// Get order by columns, empty if absent fn order_by_columns(&self, batch: &RecordBatch) -> Result> { self.order_by() .iter() .map(|e| e.evaluate_to_sort_column(batch)) - .collect::>>() + .collect() } /// Get the window frame of this [WindowExpr]. @@ -138,7 +139,7 @@ pub trait WindowExpr: Send + Sync + Debug { .order_by() .iter() .map(|sort_expr| Arc::clone(&sort_expr.expr)) - .collect::>(); + .collect(); WindowPhysicalExpressions { args, partition_by_exprs, @@ -194,8 +195,7 @@ pub trait AggregateWindowExpr: WindowExpr { fn aggregate_evaluate(&self, batch: &RecordBatch) -> Result { let mut accumulator = self.get_accumulator()?; let mut last_range = Range { start: 0, end: 0 }; - let sort_options: Vec = - self.order_by().iter().map(|o| o.options).collect(); + let sort_options = self.order_by().iter().map(|o| o.options).collect(); let mut window_frame_ctx = WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options); self.get_result_column( @@ -243,8 +243,7 @@ pub trait AggregateWindowExpr: WindowExpr { // If there is no window state context, initialize it. 
let window_frame_ctx = state.window_frame_ctx.get_or_insert_with(|| { - let sort_options: Vec = - self.order_by().iter().map(|o| o.options).collect(); + let sort_options = self.order_by().iter().map(|o| o.options).collect(); WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options) }); let out_col = self.get_result_column( @@ -353,13 +352,13 @@ pub(crate) fn is_end_bound_safe( window_frame_ctx: &WindowFrameContext, order_bys: &[ArrayRef], most_recent_order_bys: Option<&[ArrayRef]>, - sort_exprs: &LexOrdering, + sort_exprs: &[PhysicalSortExpr], idx: usize, ) -> Result { if sort_exprs.is_empty() { // Early return if no sort expressions are present: return Ok(false); - } + }; match window_frame_ctx { WindowFrameContext::Rows(window_frame) => { diff --git a/datafusion/physical-optimizer/src/enforce_distribution.rs b/datafusion/physical-optimizer/src/enforce_distribution.rs index 700b00c19dd5..478ce39eecb9 100644 --- a/datafusion/physical-optimizer/src/enforce_distribution.rs +++ b/datafusion/physical-optimizer/src/enforce_distribution.rs @@ -944,12 +944,10 @@ fn add_spm_on_top(input: DistributionContext) -> DistributionContext { // if any of the following conditions is true // - Preserving ordering is not helpful in terms of satisfying ordering requirements // - Usage of order preserving variants is not desirable - // (determined by flag `config.optimizer.bounded_order_preserving_variants`) - let should_preserve_ordering = input.plan.output_ordering().is_some(); - - let new_plan = if should_preserve_ordering { + // (determined by flag `config.optimizer.prefer_existing_sort`) + let new_plan = if let Some(ordering) = input.plan.output_ordering() { Arc::new(SortPreservingMergeExec::new( - input.plan.output_ordering().cloned().unwrap_or_default(), + ordering.clone(), Arc::clone(&input.plan), )) as _ } else { @@ -1289,10 +1287,12 @@ pub fn ensure_distribution( // Either: // - Ordering requirement cannot be satisfied by preserving ordering through repartitions, or // - using order preserving variant is not desirable. + let sort_req = required_input_ordering.into_single(); let ordering_satisfied = child .plan .equivalence_properties() - .ordering_satisfy_requirement(&required_input_ordering); + .ordering_satisfy_requirement(sort_req.clone())?; + if (!ordering_satisfied || !order_preserving_variants_desirable) && child.data { @@ -1303,9 +1303,12 @@ pub fn ensure_distribution( // Make sure to satisfy ordering requirement: child = add_sort_above_with_check( child, - required_input_ordering.clone(), - None, - ); + sort_req, + plan.as_any() + .downcast_ref::() + .map(|output| output.fetch()) + .unwrap_or(None), + )?; } } // Stop tracking distribution changing operators diff --git a/datafusion/physical-optimizer/src/enforce_sorting/mod.rs b/datafusion/physical-optimizer/src/enforce_sorting/mod.rs index 37fec2eab3f9..8a71b28486a2 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/mod.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/mod.rs @@ -46,6 +46,7 @@ use crate::enforce_sorting::replace_with_order_preserving_variants::{ use crate::enforce_sorting::sort_pushdown::{ assign_initial_requirements, pushdown_sorts, SortPushDown, }; +use crate::output_requirements::OutputRequirementExec; use crate::utils::{ add_sort_above, add_sort_above_with_check, is_coalesce_partitions, is_limit, is_repartition, is_sort, is_sort_preserving_merge, is_union, is_window, @@ -191,14 +192,20 @@ fn update_coalesce_ctx_children( } /// Performs optimizations based upon a series of subrules. 
-///
 /// Refer to each subrule for detailed descriptions of the optimizations performed:
-/// [`ensure_sorting`], [`parallelize_sorts`], [`replace_with_order_preserving_variants()`],
-/// and [`pushdown_sorts`].
-///
 /// Subrule application is ordering dependent.
 ///
-/// The subrule `parallelize_sorts` is only applied if `repartition_sorts` is enabled.
+/// The optimizer consists of five main parts, which run sequentially:
+/// 1. [`ensure_sorting`] works bottom-up to remove unnecessary [`SortExec`]s and
+///    [`SortPreservingMergeExec`]s, to add [`SortExec`]s where a requirement makes
+///    them necessary, and to adjust window operators.
+/// 2. [`parallelize_sorts`] (optional, controlled by the `repartition_sorts`
+///    configuration) identifies and removes unnecessary partition-unifying
+///    operators, such as a [`SortPreservingMergeExec`] or a
+///    [`CoalescePartitionsExec`] that follows a [`SortExec`], and performs the
+///    simplifications this enables.
+/// 3. [`replace_with_order_preserving_variants()`] replaces operators with
+///    order-preserving alternatives; for example, it can merge a [`SortExec`] and
+///    a [`CoalescePartitionsExec`] into one [`SortPreservingMergeExec`], or turn a
+///    [`SortExec`] + [`RepartitionExec`] combination into an order-preserving
+///    [`RepartitionExec`].
+/// 4. [`sort_pushdown`] works top-down to push sort operators down as deep as
+///    possible in the plan.
+/// 5. `replace_with_partial_sort` checks whether [`SortExec`]s can be replaced
+///    with [`PartialSortExec`] operators.
 impl PhysicalOptimizerRule for EnforceSorting {
     fn optimize(
         &self,
@@ -251,87 +258,93 @@ impl PhysicalOptimizerRule for EnforceSorting {
     }
 }

+/// Handles [`SortExec`]s over unbounded inputs. If the plan is not a
+/// [`SortExec`], or if its child is not unbounded, returns the original plan.
+/// Otherwise, checks how much of the required ordering the child already
+/// satisfies and searches for a replacement opportunity.
+/// If one exists, replaces the [`SortExec`] with a [`PartialSortExec`].
 fn replace_with_partial_sort(
     plan: Arc<dyn ExecutionPlan>,
 ) -> Result<Arc<dyn ExecutionPlan>> {
     let plan_any = plan.as_any();
-    if let Some(sort_plan) = plan_any.downcast_ref::<SortExec>() {
-        let child = Arc::clone(sort_plan.children()[0]);
-        if !child.boundedness().is_unbounded() {
-            return Ok(plan);
-        }
+    let Some(sort_plan) = plan_any.downcast_ref::<SortExec>() else {
+        return Ok(plan);
+    };

-        // here we're trying to find the common prefix for sorted columns that is required for the
-        // sort and already satisfied by the given ordering
-        let child_eq_properties = child.equivalence_properties();
-        let sort_req = LexRequirement::from(sort_plan.expr().clone());
+    // It is safe to take the first child of the SortExec:
+    let child = Arc::clone(sort_plan.children()[0]);
+    if !child.boundedness().is_unbounded() {
+        return Ok(plan);
+    }

-        let mut common_prefix_length = 0;
-        while child_eq_properties.ordering_satisfy_requirement(&LexRequirement {
-            inner: sort_req[0..common_prefix_length + 1].to_vec(),
-        }) {
-            common_prefix_length += 1;
-        }
-        if common_prefix_length > 0 {
-            return Ok(Arc::new(
-                PartialSortExec::new(
-                    LexOrdering::new(sort_plan.expr().to_vec()),
-                    Arc::clone(sort_plan.input()),
-                    common_prefix_length,
-                )
-                .with_preserve_partitioning(sort_plan.preserve_partitioning())
-                .with_fetch(sort_plan.fetch()),
-            ));
-        }
+    // Here, we try to find the common prefix of the columns that the sort
+    // requires and the ordering the child already satisfies:
+    let child_eq_properties = child.equivalence_properties();
+    let sort_exprs = sort_plan.expr().clone();
+
+    let mut common_prefix_length = 0;
+    while child_eq_properties
+        .ordering_satisfy(sort_exprs[0..common_prefix_length + 1].to_vec())?
+    {
+        common_prefix_length += 1;
+    }
+    if common_prefix_length > 0 {
+        return Ok(Arc::new(
+            PartialSortExec::new(
+                sort_exprs,
+                Arc::clone(sort_plan.input()),
+                common_prefix_length,
+            )
+            .with_preserve_partitioning(sort_plan.preserve_partitioning())
+            .with_fetch(sort_plan.fetch()),
+        ));
     }
     Ok(plan)
 }

-/// Transform [`CoalescePartitionsExec`] + [`SortExec`] into
-/// [`SortExec`] + [`SortPreservingMergeExec`] as illustrated below:
+/// Transform [`CoalescePartitionsExec`] + [`SortExec`] cascades into [`SortExec`]
+/// + [`SortPreservingMergeExec`] cascades, as illustrated below.
 ///
-/// The [`CoalescePartitionsExec`] + [`SortExec`] cascades
-/// combine the partitions first, and then sort:
+/// A [`CoalescePartitionsExec`] + [`SortExec`] cascade combines partitions
+/// first, and then sorts:
 /// ```text
-/// ┌ ─ ─ ─ ─ ─ ┐
-/// ┌─┬─┬─┐
-/// ││B│A│D│... ├──┐
-/// └─┴─┴─┘ │
+/// ┌ ─ ─ ─ ─ ─ ┐
+/// ┌─┬─┬─┐
+/// ││B│A│D│... ├──┐
+/// └─┴─┴─┘ │
 /// └ ─ ─ ─ ─ ─ ┘ │ ┌────────────────────────┐ ┌ ─ ─ ─ ─ ─ ─ ┐ ┌────────┐ ┌ ─ ─ ─ ─ ─ ─ ─ ┐
-/// Partition 1 │ │ Coalesce │ ┌─┬─┬─┬─┬─┐ │ │ ┌─┬─┬─┬─┬─┐
+/// Partition 1 │ │ Coalesce │ ┌─┬─┬─┬─┬─┐ │ │ ┌─┬─┬─┬─┬─┐
 /// ├──▶(no ordering guarantees)│──▶││B│E│A│D│C│...───▶ Sort ├───▶││A│B│C│D│E│... │
-/// │ │ │ └─┴─┴─┴─┴─┘ │ │ └─┴─┴─┴─┴─┘
+/// │ │ │ └─┴─┴─┴─┴─┘ │ │ └─┴─┴─┴─┴─┘
 /// ┌ ─ ─ ─ ─ ─ ┐ │ └────────────────────────┘ └ ─ ─ ─ ─ ─ ─ ┘ └────────┘ └ ─ ─ ─ ─ ─ ─ ─ ┘
-/// ┌─┬─┐ │ Partition Partition
-/// ││E│C│ ... ├──┘
-/// └─┴─┘
-/// └ ─ ─ ─ ─ ─ ┘
-/// Partition 2
-/// ```
+/// ┌─┬─┐ │ Partition Partition
+/// ││E│C│ ... ├──┘
+/// └─┴─┘
+/// └ ─ ─ ─ ─ ─ ┘
+/// Partition 2
+/// ```
 ///
 ///
-/// The [`SortExec`] + [`SortPreservingMergeExec`] cascades
-/// sorts each partition first, then merge partitions while retaining the sort:
+/// A [`SortExec`] + [`SortPreservingMergeExec`] cascade sorts each partition
+/// first, then merges partitions while preserving the sort:
 /// ```text
-/// ┌ ─ ─ ─ ─ ─ ┐ ┌────────┐ ┌ ─ ─ ─ ─ ─ ┐
-/// ┌─┬─┬─┐ │ │ ┌─┬─┬─┐
-/// ││B│A│D│... │──▶│ Sort │──▶││A│B│D│... │──┐
-/// └─┴─┴─┘ │ │ └─┴─┴─┘ │
+/// ┌ ─ ─ ─ ─ ─ ┐ ┌────────┐ ┌ ─ ─ ─ ─ ─ ┐
+/// ┌─┬─┬─┐ │ │ ┌─┬─┬─┐
+/// ││B│A│D│... │──▶│ Sort │──▶││A│B│D│... │──┐
+/// └─┴─┴─┘ │ │ └─┴─┴─┘ │
 /// └ ─ ─ ─ ─ ─ ┘ └────────┘ └ ─ ─ ─ ─ ─ ┘ │ ┌─────────────────────┐ ┌ ─ ─ ─ ─ ─ ─ ─ ┐
-/// Partition 1 Partition 1 │ │ │ ┌─┬─┬─┬─┬─┐
+/// Partition 1 Partition 1 │ │ │ ┌─┬─┬─┬─┬─┐
 /// ├──▶ SortPreservingMerge ├───▶││A│B│C│D│E│... │
-/// │ │ │ └─┴─┴─┴─┴─┘
+/// │ │ │ └─┴─┴─┴─┴─┘
 /// ┌ ─ ─ ─ ─ ─ ┐ ┌────────┐ ┌ ─ ─ ─ ─ ─ ┐ │ └─────────────────────┘ └ ─ ─ ─ ─ ─ ─ ─ ┘
-/// ┌─┬─┐ │ │ ┌─┬─┐ │ Partition
-/// ││E│C│ ... │──▶│ Sort ├──▶││C│E│ ... │──┘
-/// └─┴─┘ │ │ └─┴─┘
-/// └ ─ ─ ─ ─ ─ ┘ └────────┘ └ ─ ─ ─ ─ ─ ┘
-/// Partition 2 Partition 2
+/// ┌─┬─┐ │ │ ┌─┬─┐ │ Partition
+/// ││E│C│ ... │──▶│ Sort ├──▶││C│E│ ... │──┘
+/// └─┴─┘ │ │ └─┴─┘
+/// └ ─ ─ ─ ─ ─ ┘ └────────┘ └ ─ ─ ─ ─ ─ ┘
+/// Partition 2 Partition 2
 /// ```
 ///
-/// The latter [`SortExec`] + [`SortPreservingMergeExec`] cascade performs the
-/// sort first on a per-partition basis, thereby parallelizing the sort.
-///
+/// The latter [`SortExec`] + [`SortPreservingMergeExec`] cascade performs
+/// sorting first on a per-partition basis, thereby parallelizing the sort.
 ///
 /// The outcome is that plans of the form
 /// ```text
@@ -348,16 +361,32 @@ fn replace_with_partial_sort(
 ///   "      RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1",
 /// ```
 /// by following connections from [`CoalescePartitionsExec`]s to [`SortExec`]s.
-/// By performing sorting in parallel, we can increase performance in some scenarios.
+/// By performing sorting in parallel, we can increase performance in some
+/// scenarios.
 ///
-/// This requires that there are no nodes between the [`SortExec`] and [`CoalescePartitionsExec`]
-/// which require single partitioning. Do not parallelize when the following scenario occurs:
+/// This optimization requires that there are no nodes between the [`SortExec`]
+/// and the [`CoalescePartitionsExec`] that require single partitioning. Do
+/// not parallelize when the following scenario occurs:
 /// ```text
 /// "SortExec: expr=\[a@0 ASC\]",
 /// "  ...nodes requiring single partitioning..."
 /// "  CoalescePartitionsExec",
 /// "    RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1",
 /// ```
+///
+/// **Steps**
+/// 1. Checks if the plan is a [`SortExec`], a [`SortPreservingMergeExec`], or a
+///    [`CoalescePartitionsExec`]. Otherwise, does nothing.
+/// 2. If the plan is a [`SortExec`] or a final [`SortPreservingMergeExec`]
+///    (i.e. its output partitioning is 1):
+///    - Checks for a [`CoalescePartitionsExec`] among the children. If found,
+///      checks whether it can be removed (along with possible
+///      [`RepartitionExec`]s). If so, removes it (see
+///      `remove_bottleneck_in_subplan`).
+///    - If the plan does not already satisfy the ordering requirements, adds a
+///      [`SortExec`].
+///    - Adds an SPM above the plan and returns.
+/// 3. If the plan is a [`CoalescePartitionsExec`]:
+///    - Checks whether it can be removed (along with possible
+///      [`RepartitionExec`]s). If so, removes it (see
+///      `remove_bottleneck_in_subplan`).
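The parallelization rewrite described above can be pictured as a small tree transformation. The toy plan algebra below (not the real `ExecutionPlan` types) turns a sort over coalesced partitions into per-partition sorts under a sort-preserving merge:

```rust
// Toy plan algebra illustrating the parallelize_sorts rewrite: a global sort
// over coalesced partitions becomes per-partition sorts merged by a
// sort-preserving merge. These types are illustrative only.
#[derive(Debug, PartialEq)]
enum Plan {
    Leaf,
    Sort(Box<Plan>),
    Coalesce(Box<Plan>),
    Spm(Box<Plan>),
}

fn parallelize(plan: Plan) -> Plan {
    match plan {
        Plan::Sort(child) => match *child {
            // Sort(Coalesce(x)) => Spm(Sort(x)): sort each partition of `x`,
            // then merge the partitions while preserving the order.
            Plan::Coalesce(grandchild) => Plan::Spm(Box::new(Plan::Sort(grandchild))),
            other => Plan::Sort(Box::new(parallelize(other))),
        },
        other => other,
    }
}

fn main() {
    let plan = Plan::Sort(Box::new(Plan::Coalesce(Box::new(Plan::Leaf))));
    assert_eq!(
        parallelize(plan),
        Plan::Spm(Box::new(Plan::Sort(Box::new(Plan::Leaf))))
    );
}
```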
 pub fn parallelize_sorts(
     mut requirements: PlanWithCorrespondingCoalescePartitions,
 ) -> Result<Transformed<PlanWithCorrespondingCoalescePartitions>> {
@@ -388,7 +417,7 @@ pub fn parallelize_sorts(
         // deals with the children and their children and so on.
         requirements = requirements.children.swap_remove(0);
 
-        requirements = add_sort_above_with_check(requirements, sort_reqs, fetch);
+        requirements = add_sort_above_with_check(requirements, sort_reqs, fetch)?;
 
         let spm =
             SortPreservingMergeExec::new(sort_exprs, Arc::clone(&requirements.plan));
@@ -424,6 +453,25 @@ pub fn parallelize_sorts(
 /// This function enforces sorting requirements and makes optimizations without
 /// violating these requirements whenever possible. Requires a bottom-up traversal.
+///
+/// **Steps**
+/// 1. Analyzes whether any [`SortExec`]s can be removed immediately. If so,
+///    removes them (see `analyze_immediate_sort_removal`).
+/// 2. For each child of the plan:
+///    - If the plan requires an input ordering from that child, checks
+///      whether the child satisfies it. If not:
+///      - If the child has an output ordering, removes the unnecessary
+///        `SortExec`.
+///      - Adds a sort above the child plan.
+///    - If the plan does not require an input ordering, checks whether a
+///      `SortExec` below is neutralized in the plan. If so, removes it.
+/// 3. Checks and modifies window operators:
+///    - If the plan is a window operator connected with a sort, either tries
+///      to update the window definition or removes unnecessary [`SortExec`]s
+///      (see `adjust_window_sort_removal`).
+/// 4. Checks for and removes a possibly unnecessary SPM:
+///    - If the plan is an SPM whose child has a single output partition, the
+///      SPM is unnecessary and is removed from the plan.
 pub fn ensure_sorting(
     mut requirements: PlanWithCorrespondingSort,
 ) -> Result<Transformed<PlanWithCorrespondingSort>> {
@@ -433,7 +481,7 @@ pub fn ensure_sorting(
     if requirements.children.is_empty() {
         return Ok(Transformed::no(requirements));
     }
-    let maybe_requirements = analyze_immediate_sort_removal(requirements);
+    let maybe_requirements = analyze_immediate_sort_removal(requirements)?;
     requirements = if !maybe_requirements.transformed {
         maybe_requirements.data
     } else {
@@ -452,12 +500,20 @@ pub fn ensure_sorting(
         if let Some(required) = required_ordering {
             let eq_properties = child.plan.equivalence_properties();
-            if !eq_properties.ordering_satisfy_requirement(&required) {
+            let req = required.into_single();
+            if !eq_properties.ordering_satisfy_requirement(req.clone())? {
                 // Make sure we preserve the ordering requirements:
                 if physical_ordering.is_some() {
                     child = update_child_to_remove_unnecessary_sort(idx, child, plan)?;
                 }
-                child = add_sort_above(child, required, None);
+                child = add_sort_above(
+                    child,
+                    req,
+                    plan.as_any()
+                        .downcast_ref::()
+                        .map(|output| output.fetch())
+                        .unwrap_or(None),
+                );
                 child = update_sort_ctx_children_data(child, true)?;
             }
         } else if physical_ordering.is_none()
@@ -493,60 +549,56 @@ pub fn ensure_sorting(
     update_sort_ctx_children_data(requirements, false).map(Transformed::yes)
 }
 
-/// Analyzes a given [`SortExec`] (`plan`) to determine whether its input
-/// already has a finer ordering than it enforces.
+/// Analyzes whether any sorts can be removed immediately, by checking each
+/// `SortExec`'s ordering requirements against the ordering its child already
+/// provides. If the sort is unnecessary, either replaces it with a
+/// [`SortPreservingMergeExec`]/`LimitExec`, or removes the [`SortExec`].
+/// Otherwise, returns the original plan fn analyze_immediate_sort_removal( mut node: PlanWithCorrespondingSort, -) -> Transformed { - if let Some(sort_exec) = node.plan.as_any().downcast_ref::() { - let sort_input = sort_exec.input(); - // If this sort is unnecessary, we should remove it: - if sort_input.equivalence_properties().ordering_satisfy( - sort_exec - .properties() - .output_ordering() - .unwrap_or_else(|| LexOrdering::empty()), - ) { - node.plan = if !sort_exec.preserve_partitioning() - && sort_input.output_partitioning().partition_count() > 1 - { - // Replace the sort with a sort-preserving merge: - let expr = LexOrdering::new(sort_exec.expr().to_vec()); - Arc::new( - SortPreservingMergeExec::new(expr, Arc::clone(sort_input)) - .with_fetch(sort_exec.fetch()), - ) as _ +) -> Result> { + let Some(sort_exec) = node.plan.as_any().downcast_ref::() else { + return Ok(Transformed::no(node)); + }; + let sort_input = sort_exec.input(); + // Check if the sort is unnecessary: + let properties = sort_exec.properties(); + if let Some(ordering) = properties.output_ordering().cloned() { + let eqp = sort_input.equivalence_properties(); + if !eqp.ordering_satisfy(ordering)? { + return Ok(Transformed::no(node)); + } + } + node.plan = if !sort_exec.preserve_partitioning() + && sort_input.output_partitioning().partition_count() > 1 + { + // Replace the sort with a sort-preserving merge: + Arc::new( + SortPreservingMergeExec::new( + sort_exec.expr().clone(), + Arc::clone(sort_input), + ) + .with_fetch(sort_exec.fetch()), + ) as _ + } else { + // Remove the sort: + node.children = node.children.swap_remove(0).children; + if let Some(fetch) = sort_exec.fetch() { + // If the sort has a fetch, we need to add a limit: + if properties.output_partitioning().partition_count() == 1 { + Arc::new(GlobalLimitExec::new(Arc::clone(sort_input), 0, Some(fetch))) } else { - // Remove the sort: - node.children = node.children.swap_remove(0).children; - if let Some(fetch) = sort_exec.fetch() { - // If the sort has a fetch, we need to add a limit: - if sort_exec - .properties() - .output_partitioning() - .partition_count() - == 1 - { - Arc::new(GlobalLimitExec::new( - Arc::clone(sort_input), - 0, - Some(fetch), - )) - } else { - Arc::new(LocalLimitExec::new(Arc::clone(sort_input), fetch)) - } - } else { - Arc::clone(sort_input) - } - }; - for child in node.children.iter_mut() { - child.data = false; + Arc::new(LocalLimitExec::new(Arc::clone(sort_input), fetch)) } - node.data = false; - return Transformed::yes(node); + } else { + Arc::clone(sort_input) } + }; + for child in node.children.iter_mut() { + child.data = false; } - Transformed::no(node) + node.data = false; + Ok(Transformed::yes(node)) } /// Adjusts a [`WindowAggExec`] or a [`BoundedWindowAggExec`] to determine @@ -587,15 +639,13 @@ fn adjust_window_sort_removal( } else { // We were unable to change the window to accommodate the input, so we // will insert a sort. 
- let reqs = window_tree - .plan - .required_input_ordering() - .swap_remove(0) - .unwrap_or_default(); + let reqs = window_tree.plan.required_input_ordering().swap_remove(0); // Satisfy the ordering requirement so that the window can run: let mut child_node = window_tree.children.swap_remove(0); - child_node = add_sort_above(child_node, reqs, None); + if let Some(reqs) = reqs { + child_node = add_sort_above(child_node, reqs.into_single(), None); + } let child_plan = Arc::clone(&child_node.plan); window_tree.children.push(child_node); @@ -742,8 +792,7 @@ fn remove_corresponding_sort_from_sub_plan( let fetch = plan.fetch(); let plan = if let Some(ordering) = plan.output_ordering() { Arc::new( - SortPreservingMergeExec::new(LexOrdering::new(ordering.to_vec()), plan) - .with_fetch(fetch), + SortPreservingMergeExec::new(ordering.clone(), plan).with_fetch(fetch), ) as _ } else { Arc::new(CoalescePartitionsExec::new(plan)) as _ diff --git a/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs b/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs index 9769e2e0366f..b536e7960208 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs @@ -27,10 +27,7 @@ use crate::utils::{ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::Transformed; -use datafusion_physical_expr::LexOrdering; -use datafusion_physical_plan::internal_err; - -use datafusion_common::Result; +use datafusion_common::{internal_err, Result}; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::execution_plan::EmissionType; use datafusion_physical_plan::repartition::RepartitionExec; @@ -145,8 +142,8 @@ pub fn plan_with_order_preserving_variants( if let Some(sort_fetch) = fetch { if coalesce_fetch < sort_fetch { return internal_err!( - "CoalescePartitionsExec fetch [{:?}] should be greater than or equal to SortExec fetch [{:?}]", coalesce_fetch, sort_fetch - ); + "CoalescePartitionsExec fetch [{:?}] should be greater than or equal to SortExec fetch [{:?}]", coalesce_fetch, sort_fetch + ); } } else { // If the sort node does not have a fetch, we need to keep the coalesce node's fetch. @@ -181,18 +178,17 @@ pub fn plan_with_order_breaking_variants( .map(|(node, maintains, required_ordering)| { // Replace with non-order preserving variants as long as ordering is // not required by intermediate operators: - if maintains - && (is_sort_preserving_merge(plan) - || !required_ordering.is_some_and(|required_ordering| { - node.plan - .equivalence_properties() - .ordering_satisfy_requirement(&required_ordering) - })) - { - plan_with_order_breaking_variants(node) - } else { - Ok(node) + if !maintains { + return Ok(node); + } else if is_sort_preserving_merge(plan) { + return plan_with_order_breaking_variants(node); + } else if let Some(required_ordering) = required_ordering { + let eqp = node.plan.equivalence_properties(); + if eqp.ordering_satisfy_requirement(required_ordering.into_single())? 
{
+                return Ok(node);
+            }
+        }
+        plan_with_order_breaking_variants(node)
         })
         .collect::<Result<Vec<_>>>()?;
     sort_input.data = false;
@@ -281,25 +277,18 @@ pub fn replace_with_order_preserving_variants(
     )?;
 
     // If the alternate plan makes this sort unnecessary, accept the alternate:
-    if alternate_plan
-        .plan
-        .equivalence_properties()
-        .ordering_satisfy(
-            requirements
-                .plan
-                .output_ordering()
-                .unwrap_or_else(|| LexOrdering::empty()),
-        )
-    {
-        for child in alternate_plan.children.iter_mut() {
-            child.data = false;
+    if let Some(ordering) = requirements.plan.output_ordering() {
+        let eqp = alternate_plan.plan.equivalence_properties();
+        if !eqp.ordering_satisfy(ordering.clone())? {
+            // The alternate plan does not help, use faster order-breaking variants:
+            alternate_plan = plan_with_order_breaking_variants(alternate_plan)?;
+            alternate_plan.data = false;
+            requirements.children = vec![alternate_plan];
+            return Ok(Transformed::yes(requirements));
         }
-        Ok(Transformed::yes(alternate_plan))
-    } else {
-        // The alternate plan does not help, use faster order-breaking variants:
-        alternate_plan = plan_with_order_breaking_variants(alternate_plan)?;
-        alternate_plan.data = false;
-        requirements.children = vec![alternate_plan];
-        Ok(Transformed::yes(requirements))
     }
+    for child in alternate_plan.children.iter_mut() {
+        child.data = false;
+    }
+    Ok(Transformed::yes(alternate_plan))
 }
diff --git a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs
index 6d2c014f9e7c..bd7b0060c3be 100644
--- a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs
+++ b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs
@@ -24,12 +24,17 @@ use crate::utils::{
 
 use arrow::datatypes::SchemaRef;
 use datafusion_common::tree_node::{Transformed, TreeNode};
-use datafusion_common::{plan_err, HashSet, JoinSide, Result};
+use datafusion_common::{internal_err, HashSet, JoinSide, Result};
 use datafusion_expr::JoinType;
 use datafusion_physical_expr::expressions::Column;
 use datafusion_physical_expr::utils::collect_columns;
-use datafusion_physical_expr::PhysicalSortRequirement;
-use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement};
+use datafusion_physical_expr::{
+    add_offset_to_physical_sort_exprs, EquivalenceProperties,
+};
+use datafusion_physical_expr_common::sort_expr::{
+    LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortExpr,
+    PhysicalSortRequirement,
+};
 use datafusion_physical_plan::filter::FilterExec;
 use datafusion_physical_plan::joins::utils::{
     calculate_join_output_ordering, ColumnIndex,
@@ -50,7 +55,7 @@ use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties};
 /// [`EnforceSorting`]: crate::enforce_sorting::EnforceSorting
 #[derive(Default, Clone, Debug)]
 pub struct ParentRequirements {
-    ordering_requirement: Option<LexRequirement>,
+    ordering_requirement: Option<OrderingRequirements>,
     fetch: Option<usize>,
 }
 
@@ -69,6 +74,7 @@ pub fn assign_initial_requirements(sort_push_down: &mut SortPushDown) {
     }
 }
 
+/// Tries to push down sort requirements as far as possible; if it decides a `SortExec` is unnecessary, it removes it.
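The `min_fetch` helper referenced in the hunks below combines the parent's and the current operator's `fetch` limits, keeping the stricter one. The patch only shows its signature (with the generics lost in rendering; it takes and returns `Option<usize>`), so the body here is an assumed, illustrative implementation:

```rust
/// Assumed implementation of `min_fetch` for illustration: `None` means
/// "no limit", and the smaller of two limits is the stricter one.
fn min_fetch(f1: Option<usize>, f2: Option<usize>) -> Option<usize> {
    match (f1, f2) {
        (Some(a), Some(b)) => Some(a.min(b)),
        (Some(a), None) => Some(a),
        (None, other) => other,
    }
}

fn main() {
    assert_eq!(min_fetch(Some(10), Some(3)), Some(3)); // stricter limit wins
    assert_eq!(min_fetch(None, Some(7)), Some(7));     // any limit beats none
    assert_eq!(min_fetch(None, None), None);           // still unlimited
}
```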
pub fn pushdown_sorts(sort_push_down: SortPushDown) -> Result { sort_push_down .transform_down(pushdown_sorts_helper) @@ -87,91 +93,107 @@ fn min_fetch(f1: Option, f2: Option) -> Option { fn pushdown_sorts_helper( mut sort_push_down: SortPushDown, ) -> Result> { - let plan = &sort_push_down.plan; - let parent_reqs = sort_push_down - .data - .ordering_requirement - .clone() - .unwrap_or_default(); - let satisfy_parent = plan - .equivalence_properties() - .ordering_satisfy_requirement(&parent_reqs); - - if is_sort(plan) { - let current_sort_fetch = plan.fetch(); - let parent_req_fetch = sort_push_down.data.fetch; - - let current_plan_reqs = plan - .output_ordering() - .cloned() - .map(LexRequirement::from) - .unwrap_or_default(); - let parent_is_stricter = plan - .equivalence_properties() - .requirements_compatible(&parent_reqs, ¤t_plan_reqs); - let current_is_stricter = plan - .equivalence_properties() - .requirements_compatible(¤t_plan_reqs, &parent_reqs); + let plan = sort_push_down.plan; + let parent_fetch = sort_push_down.data.fetch; - if !satisfy_parent && !parent_is_stricter { - // This new sort has different requirements than the ordering being pushed down. - // 1. add a `SortExec` here for the pushed down ordering (parent reqs). - // 2. continue sort pushdown, but with the new ordering of the new sort. + let Some(parent_requirement) = sort_push_down.data.ordering_requirement.clone() + else { + // If there are no ordering requirements from the parent, nothing to do + // unless we have a sort. + if is_sort(&plan) { + let Some(sort_ordering) = plan.output_ordering().cloned() else { + return internal_err!("SortExec should have output ordering"); + }; + // The sort is unnecessary, just propagate the stricter fetch and + // ordering requirements. + let fetch = min_fetch(plan.fetch(), parent_fetch); + sort_push_down = sort_push_down + .children + .swap_remove(0) + .update_plan_from_children()?; + sort_push_down.data.fetch = fetch; + sort_push_down.data.ordering_requirement = + Some(OrderingRequirements::from(sort_ordering)); + // Recursive call to helper, so it doesn't transform_down and miss + // the new node (previous child of sort): + return pushdown_sorts_helper(sort_push_down); + } + sort_push_down.plan = plan; + return Ok(Transformed::no(sort_push_down)); + }; - // remove current sort (which will be the new ordering to pushdown) - let new_reqs = current_plan_reqs; - sort_push_down = sort_push_down.children.swap_remove(0); - sort_push_down = sort_push_down.update_plan_from_children()?; // changed plan + let eqp = plan.equivalence_properties(); + let satisfy_parent = + eqp.ordering_satisfy_requirement(parent_requirement.first().clone())?; - // add back sort exec matching parent - sort_push_down = - add_sort_above(sort_push_down, parent_reqs, parent_req_fetch); + if is_sort(&plan) { + let Some(sort_ordering) = plan.output_ordering().cloned() else { + return internal_err!("SortExec should have output ordering"); + }; - // make pushdown requirements be the new ones. + let sort_fetch = plan.fetch(); + let parent_is_stricter = eqp.requirements_compatible( + parent_requirement.first().clone(), + sort_ordering.clone().into(), + ); + + // Remove the current sort as we are either going to prove that it is + // unnecessary, or replace it with a stricter sort. + sort_push_down = sort_push_down + .children + .swap_remove(0) + .update_plan_from_children()?; + if !satisfy_parent && !parent_is_stricter { + // The sort was imposing a different ordering than the one being + // pushed down. 
Replace it with a sort that matches the pushed-down + // ordering, and continue the pushdown. + // Add back the sort: + sort_push_down = add_sort_above( + sort_push_down, + parent_requirement.into_single(), + parent_fetch, + ); + // Update pushdown requirements: sort_push_down.children[0].data = ParentRequirements { - ordering_requirement: Some(new_reqs), - fetch: current_sort_fetch, + ordering_requirement: Some(OrderingRequirements::from(sort_ordering)), + fetch: sort_fetch, }; + return Ok(Transformed::yes(sort_push_down)); } else { - // Don't add a SortExec - // Do update what sort requirements to keep pushing down - - // remove current sort, and get the sort's child - sort_push_down = sort_push_down.children.swap_remove(0); - sort_push_down = sort_push_down.update_plan_from_children()?; // changed plan - - // set the stricter fetch - sort_push_down.data.fetch = min_fetch(current_sort_fetch, parent_req_fetch); - - // set the stricter ordering - if current_is_stricter { - sort_push_down.data.ordering_requirement = Some(current_plan_reqs); + // Sort was unnecessary, just propagate the stricter fetch and + // ordering requirements: + sort_push_down.data.fetch = min_fetch(sort_fetch, parent_fetch); + let current_is_stricter = eqp.requirements_compatible( + sort_ordering.clone().into(), + parent_requirement.first().clone(), + ); + sort_push_down.data.ordering_requirement = if current_is_stricter { + Some(OrderingRequirements::from(sort_ordering)) } else { - sort_push_down.data.ordering_requirement = Some(parent_reqs); - } - - // recursive call to helper, so it doesn't transform_down and miss the new node (previous child of sort) + Some(parent_requirement) + }; + // Recursive call to helper, so it doesn't transform_down and miss + // the new node (previous child of sort): return pushdown_sorts_helper(sort_push_down); } - } else if parent_reqs.is_empty() { - // note: this `satisfy_parent`, but we don't want to push down anything. - // Nothing to do. - return Ok(Transformed::no(sort_push_down)); - } else if satisfy_parent { + } + + sort_push_down.plan = plan; + if satisfy_parent { // For non-sort operators which satisfy ordering: - let reqs = plan.required_input_ordering(); - let parent_req_fetch = sort_push_down.data.fetch; + let reqs = sort_push_down.plan.required_input_ordering(); for (child, order) in sort_push_down.children.iter_mut().zip(reqs) { child.data.ordering_requirement = order; - child.data.fetch = min_fetch(parent_req_fetch, child.data.fetch); + child.data.fetch = min_fetch(parent_fetch, child.data.fetch); } - } else if let Some(adjusted) = pushdown_requirement_to_children(plan, &parent_reqs)? { - // For operators that can take a sort pushdown. - - // Continue pushdown, with updated requirements: - let parent_fetch = sort_push_down.data.fetch; - let current_fetch = plan.fetch(); + } else if let Some(adjusted) = pushdown_requirement_to_children( + &sort_push_down.plan, + parent_requirement.clone(), + )? 
{ + // For operators that can take a sort pushdown, continue with updated + // requirements: + let current_fetch = sort_push_down.plan.fetch(); for (child, order) in sort_push_down.children.iter_mut().zip(adjusted) { child.data.ordering_requirement = order; child.data.fetch = min_fetch(current_fetch, parent_fetch); @@ -179,16 +201,13 @@ fn pushdown_sorts_helper( sort_push_down.data.ordering_requirement = None; } else { // Can not push down requirements, add new `SortExec`: - let sort_reqs = sort_push_down - .data - .ordering_requirement - .clone() - .unwrap_or_default(); - let fetch = sort_push_down.data.fetch; - sort_push_down = add_sort_above(sort_push_down, sort_reqs, fetch); + sort_push_down = add_sort_above( + sort_push_down, + parent_requirement.into_single(), + parent_fetch, + ); assign_initial_requirements(&mut sort_push_down); } - Ok(Transformed::yes(sort_push_down)) } @@ -196,21 +215,18 @@ fn pushdown_sorts_helper( /// If sort cannot be pushed down, return None. fn pushdown_requirement_to_children( plan: &Arc, - parent_required: &LexRequirement, -) -> Result>>> { + parent_required: OrderingRequirements, +) -> Result>>> { let maintains_input_order = plan.maintains_input_order(); if is_window(plan) { - let required_input_ordering = plan.required_input_ordering(); - let request_child = required_input_ordering[0].clone().unwrap_or_default(); + let mut required_input_ordering = plan.required_input_ordering(); + let maybe_child_requirement = required_input_ordering.swap_remove(0); let child_plan = plan.children().swap_remove(0); - - match determine_children_requirement(parent_required, &request_child, child_plan) - { - RequirementsCompatibility::Satisfy => { - let req = (!request_child.is_empty()) - .then(|| LexRequirement::new(request_child.to_vec())); - Ok(Some(vec![req])) - } + let Some(child_req) = maybe_child_requirement else { + return Ok(None); + }; + match determine_children_requirement(&parent_required, &child_req, child_plan) { + RequirementsCompatibility::Satisfy => Ok(Some(vec![Some(child_req)])), RequirementsCompatibility::Compatible(adjusted) => { // If parent requirements are more specific than output ordering // of the window plan, then we can deduce that the parent expects @@ -218,7 +234,7 @@ fn pushdown_requirement_to_children( // that's the case, we block the pushdown of sort operation. if !plan .equivalence_properties() - .ordering_satisfy_requirement(parent_required) + .ordering_satisfy_requirement(parent_required.into_single())? 
{ return Ok(None); } @@ -228,82 +244,71 @@ fn pushdown_requirement_to_children( RequirementsCompatibility::NonCompatible => Ok(None), } } else if let Some(sort_exec) = plan.as_any().downcast_ref::() { - let sort_req = LexRequirement::from( - sort_exec - .properties() - .output_ordering() - .cloned() - .unwrap_or_else(LexOrdering::default), - ); - if sort_exec + let Some(sort_ordering) = sort_exec.properties().output_ordering().cloned() + else { + return internal_err!("SortExec should have output ordering"); + }; + sort_exec .properties() .eq_properties - .requirements_compatible(parent_required, &sort_req) - { - debug_assert!(!parent_required.is_empty()); - Ok(Some(vec![Some(LexRequirement::new( - parent_required.to_vec(), - ))])) - } else { - Ok(None) - } + .requirements_compatible( + parent_required.first().clone(), + sort_ordering.into(), + ) + .then(|| Ok(vec![Some(parent_required)])) + .transpose() } else if plan.fetch().is_some() && plan.supports_limit_pushdown() && plan .maintains_input_order() - .iter() - .all(|maintain| *maintain) + .into_iter() + .all(|maintain| maintain) { - let output_req = LexRequirement::from( - plan.properties() - .output_ordering() - .cloned() - .unwrap_or_else(LexOrdering::default), - ); // Push down through operator with fetch when: // - requirement is aligned with output ordering // - it preserves ordering during execution - if plan - .properties() - .eq_properties - .requirements_compatible(parent_required, &output_req) - { - let req = (!parent_required.is_empty()) - .then(|| LexRequirement::new(parent_required.to_vec())); - Ok(Some(vec![req])) + let Some(ordering) = plan.properties().output_ordering() else { + return Ok(Some(vec![Some(parent_required)])); + }; + if plan.properties().eq_properties.requirements_compatible( + parent_required.first().clone(), + ordering.clone().into(), + ) { + Ok(Some(vec![Some(parent_required)])) } else { Ok(None) } } else if is_union(plan) { - // UnionExec does not have real sort requirements for its input. 
Here we change the adjusted_request_ordering to UnionExec's output ordering and - // propagate the sort requirements down to correct the unnecessary descendant SortExec under the UnionExec - let req = (!parent_required.is_empty()).then(|| parent_required.clone()); - Ok(Some(vec![req; plan.children().len()])) + // `UnionExec` does not have real sort requirements for its input, we + // just propagate the sort requirements down: + Ok(Some(vec![Some(parent_required); plan.children().len()])) } else if let Some(smj) = plan.as_any().downcast_ref::() { - // If the current plan is SortMergeJoinExec let left_columns_len = smj.left().schema().fields().len(); - let parent_required_expr = LexOrdering::from(parent_required.clone()); - match expr_source_side( - parent_required_expr.as_ref(), - smj.join_type(), - left_columns_len, - ) { - Some(JoinSide::Left) => try_pushdown_requirements_to_join( + let parent_ordering: Vec = parent_required + .first() + .iter() + .cloned() + .map(Into::into) + .collect(); + let eqp = smj.properties().equivalence_properties(); + match expr_source_side(eqp, parent_ordering, smj.join_type(), left_columns_len) { + Some((JoinSide::Left, ordering)) => try_pushdown_requirements_to_join( smj, - parent_required, - parent_required_expr.as_ref(), + parent_required.into_single(), + ordering, JoinSide::Left, ), - Some(JoinSide::Right) => { + Some((JoinSide::Right, ordering)) => { let right_offset = smj.schema().fields.len() - smj.right().schema().fields.len(); - let new_right_required = - shift_right_required(parent_required, right_offset)?; - let new_right_required_expr = LexOrdering::from(new_right_required); + let ordering = add_offset_to_physical_sort_exprs( + ordering, + -(right_offset as isize), + )?; try_pushdown_requirements_to_join( smj, - parent_required, - new_right_required_expr.as_ref(), + parent_required.into_single(), + ordering, JoinSide::Right, ) } @@ -318,28 +323,26 @@ fn pushdown_requirement_to_children( || plan.as_any().is::() // TODO: Add support for Projection push down || plan.as_any().is::() - || pushdown_would_violate_requirements(parent_required, plan.as_ref()) + || pushdown_would_violate_requirements(&parent_required, plan.as_ref()) { // If the current plan is a leaf node or can not maintain any of the input ordering, can not pushed down requirements. // For RepartitionExec, we always choose to not push down the sort requirements even the RepartitionExec(input_partition=1) could maintain input ordering. // Pushing down is not beneficial Ok(None) } else if is_sort_preserving_merge(plan) { - let new_ordering = LexOrdering::from(parent_required.clone()); + let new_ordering = LexOrdering::from(parent_required.first().clone()); let mut spm_eqs = plan.equivalence_properties().clone(); + let old_ordering = spm_eqs.output_ordering().unwrap(); // Sort preserving merge will have new ordering, one requirement above is pushed down to its below. - spm_eqs = spm_eqs.with_reorder(new_ordering); - // Do not push-down through SortPreservingMergeExec when - // ordering requirement invalidates requirement of sort preserving merge exec. - if !spm_eqs.ordering_satisfy(&plan.output_ordering().cloned().unwrap_or_default()) - { - Ok(None) - } else { + let change = spm_eqs.reorder(new_ordering)?; + if !change || spm_eqs.ordering_satisfy(old_ordering)? { // Can push-down through SortPreservingMergeExec, because parent requirement is finer // than SortPreservingMergeExec output ordering. 
- let req = (!parent_required.is_empty()) - .then(|| LexRequirement::new(parent_required.to_vec())); - Ok(Some(vec![req])) + Ok(Some(vec![Some(parent_required)])) + } else { + // Do not push-down through SortPreservingMergeExec when + // ordering requirement invalidates requirement of sort preserving merge exec. + Ok(None) } } else if let Some(hash_join) = plan.as_any().downcast_ref::() { handle_hash_join(hash_join, parent_required) @@ -352,22 +355,21 @@ fn pushdown_requirement_to_children( /// Return true if pushing the sort requirements through a node would violate /// the input sorting requirements for the plan fn pushdown_would_violate_requirements( - parent_required: &LexRequirement, + parent_required: &OrderingRequirements, child: &dyn ExecutionPlan, ) -> bool { child .required_input_ordering() - .iter() + .into_iter() + // If there is no requirement, pushing down would not violate anything. + .flatten() .any(|child_required| { - let Some(child_required) = child_required.as_ref() else { - // no requirements, so pushing down would not violate anything - return false; - }; - // check if the plan's requirements would still e satisfied if we pushed - // down the parent requirements + // Check if the plan's requirements would still be satisfied if we + // pushed down the parent requirements: child_required + .into_single() .iter() - .zip(parent_required.iter()) + .zip(parent_required.first().iter()) .all(|(c, p)| !c.compatible(p)) }) } @@ -378,25 +380,24 @@ fn pushdown_would_violate_requirements( /// - If parent requirements are more specific, push down parent requirements. /// - If they are not compatible, need to add a sort. fn determine_children_requirement( - parent_required: &LexRequirement, - request_child: &LexRequirement, + parent_required: &OrderingRequirements, + child_requirement: &OrderingRequirements, child_plan: &Arc, ) -> RequirementsCompatibility { - if child_plan - .equivalence_properties() - .requirements_compatible(request_child, parent_required) - { + let eqp = child_plan.equivalence_properties(); + if eqp.requirements_compatible( + child_requirement.first().clone(), + parent_required.first().clone(), + ) { // Child requirements are more specific, no need to push down. 
RequirementsCompatibility::Satisfy - } else if child_plan - .equivalence_properties() - .requirements_compatible(parent_required, request_child) - { + } else if eqp.requirements_compatible( + parent_required.first().clone(), + child_requirement.first().clone(), + ) { // Parent requirements are more specific, adjust child's requirements // and push down the new requirements: - let adjusted = (!parent_required.is_empty()) - .then(|| LexRequirement::new(parent_required.to_vec())); - RequirementsCompatibility::Compatible(adjusted) + RequirementsCompatibility::Compatible(Some(parent_required.clone())) } else { RequirementsCompatibility::NonCompatible } @@ -404,42 +405,41 @@ fn determine_children_requirement( fn try_pushdown_requirements_to_join( smj: &SortMergeJoinExec, - parent_required: &LexRequirement, - sort_expr: &LexOrdering, + parent_required: LexRequirement, + sort_exprs: Vec, push_side: JoinSide, -) -> Result>>> { - let left_eq_properties = smj.left().equivalence_properties(); - let right_eq_properties = smj.right().equivalence_properties(); +) -> Result>>> { let mut smj_required_orderings = smj.required_input_ordering(); - let right_requirement = smj_required_orderings.swap_remove(1); - let left_requirement = smj_required_orderings.swap_remove(0); - let left_ordering = &smj.left().output_ordering().cloned().unwrap_or_default(); - let right_ordering = &smj.right().output_ordering().cloned().unwrap_or_default(); + let ordering = LexOrdering::new(sort_exprs.clone()); let (new_left_ordering, new_right_ordering) = match push_side { JoinSide::Left => { - let left_eq_properties = - left_eq_properties.clone().with_reorder(sort_expr.clone()); - if left_eq_properties - .ordering_satisfy_requirement(&left_requirement.unwrap_or_default()) + let mut left_eq_properties = smj.left().equivalence_properties().clone(); + left_eq_properties.reorder(sort_exprs)?; + let Some(left_requirement) = smj_required_orderings.swap_remove(0) else { + return Ok(None); + }; + if !left_eq_properties + .ordering_satisfy_requirement(left_requirement.into_single())? { - // After re-ordering requirement is still satisfied - (sort_expr, right_ordering) - } else { return Ok(None); } + // After re-ordering, requirement is still satisfied: + (ordering.as_ref(), smj.right().output_ordering()) } JoinSide::Right => { - let right_eq_properties = - right_eq_properties.clone().with_reorder(sort_expr.clone()); - if right_eq_properties - .ordering_satisfy_requirement(&right_requirement.unwrap_or_default()) + let mut right_eq_properties = smj.right().equivalence_properties().clone(); + right_eq_properties.reorder(sort_exprs)?; + let Some(right_requirement) = smj_required_orderings.swap_remove(1) else { + return Ok(None); + }; + if !right_eq_properties + .ordering_satisfy_requirement(right_requirement.into_single())? { - // After re-ordering requirement is still satisfied - (left_ordering, sort_expr) - } else { return Ok(None); } + // After re-ordering, requirement is still satisfied: + (smj.left().output_ordering(), ordering.as_ref()) } JoinSide::None => return Ok(None), }; @@ -449,18 +449,19 @@ fn try_pushdown_requirements_to_join( new_left_ordering, new_right_ordering, join_type, - smj.on(), smj.left().schema().fields.len(), &smj.maintains_input_order(), Some(probe_side), - ); + )?; let mut smj_eqs = smj.properties().equivalence_properties().clone(); - // smj will have this ordering when its input changes. 
- smj_eqs = smj_eqs.with_reorder(new_output_ordering.unwrap_or_default()); - let should_pushdown = smj_eqs.ordering_satisfy_requirement(parent_required); + if let Some(new_output_ordering) = new_output_ordering { + // smj will have this ordering when its input changes. + smj_eqs.reorder(new_output_ordering)?; + } + let should_pushdown = smj_eqs.ordering_satisfy_requirement(parent_required)?; Ok(should_pushdown.then(|| { let mut required_input_ordering = smj.required_input_ordering(); - let new_req = Some(LexRequirement::from(sort_expr.clone())); + let new_req = ordering.map(Into::into); match push_side { JoinSide::Left => { required_input_ordering[0] = new_req; @@ -475,77 +476,77 @@ fn try_pushdown_requirements_to_join( } fn expr_source_side( - required_exprs: &LexOrdering, + eqp: &EquivalenceProperties, + mut ordering: Vec, join_type: JoinType, left_columns_len: usize, -) -> Option { +) -> Option<(JoinSide, Vec)> { + // TODO: Handle the case where a prefix of the ordering comes from the left + // and a suffix from the right. match join_type { JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full | JoinType::LeftMark => { - let all_column_sides = required_exprs - .iter() - .filter_map(|r| { - r.expr.as_any().downcast_ref::().map(|col| { - if col.index() < left_columns_len { - JoinSide::Left - } else { - JoinSide::Right + let eq_group = eqp.eq_group(); + let mut right_ordering = ordering.clone(); + let (mut valid_left, mut valid_right) = (true, true); + for (left, right) in ordering.iter_mut().zip(right_ordering.iter_mut()) { + let col = left.expr.as_any().downcast_ref::()?; + let eq_class = eq_group.get_equivalence_class(&left.expr); + if col.index() < left_columns_len { + if valid_right { + valid_right = eq_class.is_some_and(|cls| { + for expr in cls.iter() { + if expr + .as_any() + .downcast_ref::() + .is_some_and(|c| c.index() >= left_columns_len) + { + right.expr = Arc::clone(expr); + return true; + } + } + false + }); + } + } else if valid_left { + valid_left = eq_class.is_some_and(|cls| { + for expr in cls.iter() { + if expr + .as_any() + .downcast_ref::() + .is_some_and(|c| c.index() < left_columns_len) + { + left.expr = Arc::clone(expr); + return true; + } } - }) - }) - .collect::>(); - - // If the exprs are all coming from one side, the requirements can be pushed down - if all_column_sides.len() != required_exprs.len() { - None - } else if all_column_sides - .iter() - .all(|side| matches!(side, JoinSide::Left)) - { - Some(JoinSide::Left) - } else if all_column_sides - .iter() - .all(|side| matches!(side, JoinSide::Right)) - { - Some(JoinSide::Right) + false + }); + }; + if !(valid_left || valid_right) { + return None; + } + } + if valid_left { + Some((JoinSide::Left, ordering)) + } else if valid_right { + Some((JoinSide::Right, right_ordering)) } else { + // TODO: Handle the case where we can push down to both sides. 
None } } - JoinType::LeftSemi | JoinType::LeftAnti => required_exprs + JoinType::LeftSemi | JoinType::LeftAnti => ordering .iter() - .all(|e| e.expr.as_any().downcast_ref::().is_some()) - .then_some(JoinSide::Left), - JoinType::RightSemi | JoinType::RightAnti => required_exprs + .all(|e| e.expr.as_any().is::()) + .then_some((JoinSide::Left, ordering)), + JoinType::RightSemi | JoinType::RightAnti => ordering .iter() - .all(|e| e.expr.as_any().downcast_ref::().is_some()) - .then_some(JoinSide::Right), - } -} - -fn shift_right_required( - parent_required: &LexRequirement, - left_columns_len: usize, -) -> Result { - let new_right_required = parent_required - .iter() - .filter_map(|r| { - let col = r.expr.as_any().downcast_ref::()?; - col.index().checked_sub(left_columns_len).map(|offset| { - r.clone() - .with_expr(Arc::new(Column::new(col.name(), offset))) - }) - }) - .collect::>(); - if new_right_required.len() == parent_required.len() { - Ok(LexRequirement::new(new_right_required)) - } else { - plan_err!( - "Expect to shift all the parent required column indexes for SortMergeJoin" - ) + .all(|e| e.expr.as_any().is::()) + .then_some((JoinSide::Right, ordering)), } } @@ -565,16 +566,18 @@ fn shift_right_required( /// pushed down, `Ok(None)` if not. On error, returns a `Result::Err`. fn handle_custom_pushdown( plan: &Arc, - parent_required: &LexRequirement, + parent_required: OrderingRequirements, maintains_input_order: Vec, -) -> Result>>> { - // If there's no requirement from the parent or the plan has no children, return early - if parent_required.is_empty() || plan.children().is_empty() { +) -> Result>>> { + // If the plan has no children, return early: + if plan.children().is_empty() { return Ok(None); } - // Collect all unique column indices used in the parent-required sorting expression - let all_indices: HashSet = parent_required + // Collect all unique column indices used in the parent-required sorting + // expression: + let requirement = parent_required.into_single(); + let all_indices: HashSet = requirement .iter() .flat_map(|order| { collect_columns(&order.expr) @@ -584,14 +587,14 @@ fn handle_custom_pushdown( }) .collect(); - // Get the number of fields in each child's schema - let len_of_child_schemas: Vec = plan + // Get the number of fields in each child's schema: + let children_schema_lengths: Vec = plan .children() .iter() .map(|c| c.schema().fields().len()) .collect(); - // Find the index of the child that maintains input order + // Find the index of the order-maintaining child: let Some(maintained_child_idx) = maintains_input_order .iter() .enumerate() @@ -601,26 +604,28 @@ fn handle_custom_pushdown( return Ok(None); }; - // Check if all required columns come from the child that maintains input order - let start_idx = len_of_child_schemas[..maintained_child_idx] + // Check if all required columns come from the order-maintaining child: + let start_idx = children_schema_lengths[..maintained_child_idx] .iter() .sum::(); - let end_idx = start_idx + len_of_child_schemas[maintained_child_idx]; + let end_idx = start_idx + children_schema_lengths[maintained_child_idx]; let all_from_maintained_child = all_indices.iter().all(|i| i >= &start_idx && i < &end_idx); - // If all columns are from the maintained child, update the parent requirements + // If all columns are from the maintained child, update the parent requirements: if all_from_maintained_child { - let sub_offset = len_of_child_schemas + let sub_offset = children_schema_lengths .iter() .take(maintained_child_idx) .sum::(); - 
// Transform the parent-required expression for the child schema by adjusting columns - let updated_parent_req = parent_required - .iter() + // Transform the parent-required expression for the child schema by + // adjusting columns: + let updated_parent_req = requirement + .into_iter() .map(|req| { let child_schema = plan.children()[maintained_child_idx].schema(); - let updated_columns = Arc::clone(&req.expr) + let updated_columns = req + .expr .transform_up(|expr| { if let Some(col) = expr.as_any().downcast_ref::() { let new_index = col.index() - sub_offset; @@ -642,7 +647,8 @@ fn handle_custom_pushdown( .iter() .map(|&maintains_order| { if maintains_order { - Some(LexRequirement::new(updated_parent_req.clone())) + LexRequirement::new(updated_parent_req.clone()) + .map(OrderingRequirements::new) } else { None } @@ -659,16 +665,17 @@ fn handle_custom_pushdown( // for join type: Inner, Right, RightSemi, RightAnti fn handle_hash_join( plan: &HashJoinExec, - parent_required: &LexRequirement, -) -> Result>>> { - // If there's no requirement from the parent or the plan has no children - // or the join type is not Inner, Right, RightSemi, RightAnti, return early - if parent_required.is_empty() || !plan.maintains_input_order()[1] { + parent_required: OrderingRequirements, +) -> Result>>> { + // If the plan has no children or does not maintain the right side ordering, + // return early: + if !plan.maintains_input_order()[1] { return Ok(None); } // Collect all unique column indices used in the parent-required sorting expression - let all_indices: HashSet = parent_required + let requirement = parent_required.into_single(); + let all_indices: HashSet<_> = requirement .iter() .flat_map(|order| { collect_columns(&order.expr) @@ -694,11 +701,12 @@ fn handle_hash_join( // If all columns are from the right child, update the parent requirements if all_from_right_child { // Transform the parent-required expression for the child schema by adjusting columns - let updated_parent_req = parent_required - .iter() + let updated_parent_req = requirement + .into_iter() .map(|req| { let child_schema = plan.children()[1].schema(); - let updated_columns = Arc::clone(&req.expr) + let updated_columns = req + .expr .transform_up(|expr| { if let Some(col) = expr.as_any().downcast_ref::() { let index = projected_indices[col.index()].index; @@ -718,7 +726,7 @@ fn handle_hash_join( // Populating with the updated requirements for children that maintain order Ok(Some(vec![ None, - Some(LexRequirement::new(updated_parent_req)), + LexRequirement::new(updated_parent_req).map(OrderingRequirements::new), ])) } else { Ok(None) @@ -757,7 +765,7 @@ enum RequirementsCompatibility { /// Requirements satisfy Satisfy, /// Requirements compatible - Compatible(Option), + Compatible(Option), /// Requirements not compatible NonCompatible, } diff --git a/datafusion/physical-optimizer/src/join_selection.rs b/datafusion/physical-optimizer/src/join_selection.rs index 05758e5dfdf1..27eed7024157 100644 --- a/datafusion/physical-optimizer/src/join_selection.rs +++ b/datafusion/physical-optimizer/src/join_selection.rs @@ -459,7 +459,7 @@ fn hash_join_convert_symmetric_subrule( JoinSide::Right => hash_join.right().output_ordering(), JoinSide::None => unreachable!(), } - .map(|p| LexOrdering::new(p.to_vec())) + .cloned() }) .flatten() }; diff --git a/datafusion/physical-optimizer/src/output_requirements.rs b/datafusion/physical-optimizer/src/output_requirements.rs index 0488b3fd49a8..044d27811be6 100644 --- 
a/datafusion/physical-optimizer/src/output_requirements.rs +++ b/datafusion/physical-optimizer/src/output_requirements.rs @@ -30,16 +30,17 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{Result, Statistics}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{Distribution, LexRequirement, PhysicalSortRequirement}; +use datafusion_physical_expr::Distribution; +use datafusion_physical_expr_common::sort_expr::OrderingRequirements; use datafusion_physical_plan::projection::{ - make_with_child, update_expr, ProjectionExec, + make_with_child, update_expr, update_ordering_requirement, ProjectionExec, }; use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, SendableRecordBatchStream, + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, + SendableRecordBatchStream, }; -use datafusion_physical_plan::{ExecutionPlanProperties, PlanProperties}; /// This rule either adds or removes [`OutputRequirements`]s to/from the physical /// plan according to its `mode` attribute, which is set by the constructors @@ -94,7 +95,7 @@ enum RuleMode { #[derive(Debug)] pub struct OutputRequirementExec { input: Arc, - order_requirement: Option, + order_requirement: Option, dist_requirement: Distribution, cache: PlanProperties, } @@ -102,7 +103,7 @@ pub struct OutputRequirementExec { impl OutputRequirementExec { pub fn new( input: Arc, - requirements: Option, + requirements: Option, dist_requirement: Distribution, ) -> Self { let cache = Self::compute_properties(&input); @@ -176,7 +177,7 @@ impl ExecutionPlan for OutputRequirementExec { vec![&self.input] } - fn required_input_ordering(&self) -> Vec> { + fn required_input_ordering(&self) -> Vec> { vec![self.order_requirement.clone()] } @@ -212,23 +213,23 @@ impl ExecutionPlan for OutputRequirementExec { projection: &ProjectionExec, ) -> Result>> { // If the projection does not narrow the schema, we should not try to push it down: - if projection.expr().len() >= projection.input().schema().fields().len() { + let proj_exprs = projection.expr(); + if proj_exprs.len() >= projection.input().schema().fields().len() { return Ok(None); } - let mut updated_sort_reqs = LexRequirement::new(vec![]); - // None or empty_vec can be treated in the same way. - if let Some(reqs) = &self.required_input_ordering()[0] { - for req in &reqs.inner { - let Some(new_expr) = update_expr(&req.expr, projection.expr(), false)? + let mut requirements = self.required_input_ordering().swap_remove(0); + if let Some(reqs) = requirements { + let mut updated_reqs = vec![]; + let (lexes, soft) = reqs.into_alternatives(); + for lex in lexes.into_iter() { + let Some(updated_lex) = update_ordering_requirement(lex, proj_exprs)? 
else { return Ok(None); }; - updated_sort_reqs.push(PhysicalSortRequirement { - expr: new_expr, - options: req.options, - }); + updated_reqs.push(updated_lex); } + requirements = OrderingRequirements::new_alternatives(updated_reqs, soft); } let dist_req = match &self.required_input_distribution()[0] { @@ -246,15 +247,10 @@ impl ExecutionPlan for OutputRequirementExec { dist => dist.clone(), }; - make_with_child(projection, &self.input()) - .map(|input| { - OutputRequirementExec::new( - input, - (!updated_sort_reqs.is_empty()).then_some(updated_sort_reqs), - dist_req, - ) - }) - .map(|e| Some(Arc::new(e) as _)) + make_with_child(projection, &self.input()).map(|input| { + let e = OutputRequirementExec::new(input, requirements, dist_req); + Some(Arc::new(e) as _) + }) } } @@ -317,17 +313,18 @@ fn require_top_ordering_helper( if children.len() != 1 { Ok((plan, false)) } else if let Some(sort_exec) = plan.as_any().downcast_ref::() { - // In case of constant columns, output ordering of SortExec would give an empty set. - // Therefore; we check the sort expression field of the SortExec to assign the requirements. + // In case of constant columns, output ordering of the `SortExec` would + // be an empty set. Therefore; we check the sort expression field to + // assign the requirements. let req_ordering = sort_exec.expr(); - let req_dist = sort_exec.required_input_distribution()[0].clone(); - let reqs = LexRequirement::from(req_ordering.clone()); + let req_dist = sort_exec.required_input_distribution().swap_remove(0); + let reqs = OrderingRequirements::from(req_ordering.clone()); Ok(( Arc::new(OutputRequirementExec::new(plan, Some(reqs), req_dist)) as _, true, )) } else if let Some(spm) = plan.as_any().downcast_ref::() { - let reqs = LexRequirement::from(spm.expr().clone()); + let reqs = OrderingRequirements::from(spm.expr().clone()); Ok(( Arc::new(OutputRequirementExec::new( plan, @@ -337,7 +334,9 @@ fn require_top_ordering_helper( true, )) } else if plan.maintains_input_order()[0] - && plan.required_input_ordering()[0].is_none() + && (plan.required_input_ordering()[0] + .as_ref() + .is_none_or(|o| matches!(o, OrderingRequirements::Soft(_)))) { // Keep searching for a `SortExec` as long as ordering is maintained, // and on-the-way operators do not themselves require an ordering. diff --git a/datafusion/physical-optimizer/src/sanity_checker.rs b/datafusion/physical-optimizer/src/sanity_checker.rs index 8edbb0f09114..acc70d39f057 100644 --- a/datafusion/physical-optimizer/src/sanity_checker.rs +++ b/datafusion/physical-optimizer/src/sanity_checker.rs @@ -137,7 +137,8 @@ pub fn check_plan_sanity( ) { let child_eq_props = child.equivalence_properties(); if let Some(sort_req) = sort_req { - if !child_eq_props.ordering_satisfy_requirement(&sort_req) { + let sort_req = sort_req.into_single(); + if !child_eq_props.ordering_satisfy_requirement(sort_req.clone())? { let plan_str = get_plan_string(&plan); return plan_err!( "Plan: {:?} does not satisfy order requirements: {}. 
Child-{} order: {}", diff --git a/datafusion/physical-optimizer/src/topk_aggregation.rs b/datafusion/physical-optimizer/src/topk_aggregation.rs index faedea55ca15..bff0b1e49684 100644 --- a/datafusion/physical-optimizer/src/topk_aggregation.rs +++ b/datafusion/physical-optimizer/src/topk_aggregation.rs @@ -25,7 +25,6 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::LexOrdering; use datafusion_physical_plan::aggregates::AggregateExec; use datafusion_physical_plan::execution_plan::CardinalityEffect; use datafusion_physical_plan::projection::ProjectionExec; @@ -131,7 +130,7 @@ impl TopKAggregation { Ok(Transformed::no(plan)) }; let child = Arc::clone(child).transform_down(closure).data().ok()?; - let sort = SortExec::new(LexOrdering::new(sort.expr().to_vec()), child) + let sort = SortExec::new(sort.expr().clone(), child) .with_fetch(sort.fetch()) .with_preserve_partitioning(sort.preserve_partitioning()); Some(Arc::new(sort)) diff --git a/datafusion/physical-optimizer/src/update_aggr_exprs.rs b/datafusion/physical-optimizer/src/update_aggr_exprs.rs index ae1a38230d04..61bc715592af 100644 --- a/datafusion/physical-optimizer/src/update_aggr_exprs.rs +++ b/datafusion/physical-optimizer/src/update_aggr_exprs.rs @@ -24,16 +24,10 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{plan_datafusion_err, Result}; use datafusion_physical_expr::aggregate::AggregateFunctionExpr; -use datafusion_physical_expr::LexRequirement; -use datafusion_physical_expr::{ - reverse_order_bys, EquivalenceProperties, PhysicalSortRequirement, -}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use datafusion_physical_plan::aggregates::concat_slices; +use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortRequirement}; +use datafusion_physical_plan::aggregates::{concat_slices, AggregateExec}; use datafusion_physical_plan::windows::get_ordered_partition_by_indices; -use datafusion_physical_plan::{ - aggregates::AggregateExec, ExecutionPlan, ExecutionPlanProperties, -}; +use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; use crate::PhysicalOptimizerRule; @@ -91,32 +85,30 @@ impl PhysicalOptimizerRule for OptimizeAggregateOrder { return Ok(Transformed::no(plan)); } let input = aggr_exec.input(); - let mut aggr_expr = aggr_exec.aggr_expr().to_vec(); + let mut aggr_exprs = aggr_exec.aggr_expr().to_vec(); let groupby_exprs = aggr_exec.group_expr().input_exprs(); // If the existing ordering satisfies a prefix of the GROUP BY // expressions, prefix requirements with this section. In this // case, aggregation will work more efficiently. 
- let indices = get_ordered_partition_by_indices(&groupby_exprs, input); + let indices = get_ordered_partition_by_indices(&groupby_exprs, input)?; let requirement = indices .iter() .map(|&idx| { PhysicalSortRequirement::new( - Arc::::clone( - &groupby_exprs[idx], - ), + Arc::clone(&groupby_exprs[idx]), None, ) }) .collect::>(); - aggr_expr = try_convert_aggregate_if_better( - aggr_expr, + aggr_exprs = try_convert_aggregate_if_better( + aggr_exprs, &requirement, input.equivalence_properties(), )?; - let aggr_exec = aggr_exec.with_new_aggr_exprs(aggr_expr); + let aggr_exec = aggr_exec.with_new_aggr_exprs(aggr_exprs); Ok(Transformed::yes(Arc::new(aggr_exec) as _)) } else { @@ -160,33 +152,30 @@ fn try_convert_aggregate_if_better( aggr_exprs .into_iter() .map(|aggr_expr| { - let aggr_sort_exprs = aggr_expr - .order_bys() - .unwrap_or_else(|| LexOrdering::empty()); - let reverse_aggr_sort_exprs = reverse_order_bys(aggr_sort_exprs); - let aggr_sort_reqs = LexRequirement::from(aggr_sort_exprs.clone()); - let reverse_aggr_req = LexRequirement::from(reverse_aggr_sort_exprs); - + let order_bys = aggr_expr.order_bys(); // If the aggregate expression benefits from input ordering, and // there is an actual ordering enabling this, try to update the // aggregate expression to benefit from the existing ordering. // Otherwise, leave it as is. - if aggr_expr.order_sensitivity().is_beneficial() && !aggr_sort_reqs.is_empty() - { - let reqs = LexRequirement { - inner: concat_slices(prefix_requirement, &aggr_sort_reqs), - }; - - let prefix_requirement = LexRequirement { - inner: prefix_requirement.to_vec(), - }; - - if eq_properties.ordering_satisfy_requirement(&reqs) { + if !aggr_expr.order_sensitivity().is_beneficial() { + Ok(aggr_expr) + } else if !order_bys.is_empty() { + if eq_properties.ordering_satisfy_requirement(concat_slices( + prefix_requirement, + &order_bys + .iter() + .map(|e| e.clone().into()) + .collect::>(), + ))? { // Existing ordering satisfies the aggregator requirements: aggr_expr.with_beneficial_ordering(true)?.map(Arc::new) - } else if eq_properties.ordering_satisfy_requirement(&LexRequirement { - inner: concat_slices(&prefix_requirement, &reverse_aggr_req), - }) { + } else if eq_properties.ordering_satisfy_requirement(concat_slices( + prefix_requirement, + &order_bys + .iter() + .map(|e| e.reverse().into()) + .collect::>(), + ))? 
{ // Converting to reverse enables more efficient execution // given the existing ordering (if possible): aggr_expr diff --git a/datafusion/physical-optimizer/src/utils.rs b/datafusion/physical-optimizer/src/utils.rs index 57a193315a5c..3655e555a744 100644 --- a/datafusion/physical-optimizer/src/utils.rs +++ b/datafusion/physical-optimizer/src/utils.rs @@ -17,8 +17,8 @@ use std::sync::Arc; -use datafusion_physical_expr::LexRequirement; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_common::Result; +use datafusion_physical_expr::{LexOrdering, LexRequirement}; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::repartition::RepartitionExec; @@ -40,14 +40,18 @@ pub fn add_sort_above( sort_requirements: LexRequirement, fetch: Option, ) -> PlanContext { - let mut sort_expr = LexOrdering::from(sort_requirements); - sort_expr.retain(|sort_expr| { - !node - .plan + let mut sort_reqs: Vec<_> = sort_requirements.into(); + sort_reqs.retain(|sort_expr| { + node.plan .equivalence_properties() .is_expr_constant(&sort_expr.expr) + .is_none() }); - let mut new_sort = SortExec::new(sort_expr, Arc::clone(&node.plan)).with_fetch(fetch); + let sort_exprs = sort_reqs.into_iter().map(Into::into).collect::>(); + let Some(ordering) = LexOrdering::new(sort_exprs) else { + return node; + }; + let mut new_sort = SortExec::new(ordering, Arc::clone(&node.plan)).with_fetch(fetch); if node.plan.output_partitioning().partition_count() > 1 { new_sort = new_sort.with_preserve_partitioning(true); } @@ -61,15 +65,15 @@ pub fn add_sort_above_with_check( node: PlanContext, sort_requirements: LexRequirement, fetch: Option, -) -> PlanContext { +) -> Result> { if !node .plan .equivalence_properties() - .ordering_satisfy_requirement(&sort_requirements) + .ordering_satisfy_requirement(sort_requirements.clone())? 
{ - add_sort_above(node, sort_requirements, fetch) + Ok(add_sort_above(node, sort_requirements, fetch)) } else { - node + Ok(node) } } diff --git a/datafusion/physical-plan/benches/partial_ordering.rs b/datafusion/physical-plan/benches/partial_ordering.rs index 22d18dd24891..e1a9d0b583e9 100644 --- a/datafusion/physical-plan/benches/partial_ordering.rs +++ b/datafusion/physical-plan/benches/partial_ordering.rs @@ -18,11 +18,10 @@ use std::sync::Arc; use arrow::array::{ArrayRef, Int32Array}; -use arrow_schema::{DataType, Field, Schema, SortOptions}; -use criterion::{criterion_group, criterion_main, Criterion}; -use datafusion_physical_expr::{expressions::col, LexOrdering, PhysicalSortExpr}; use datafusion_physical_plan::aggregates::order::GroupOrderingPartial; +use criterion::{criterion_group, criterion_main, Criterion}; + const BATCH_SIZE: usize = 8192; fn create_test_arrays(num_columns: usize) -> Vec<ArrayRef> { @@ -39,22 +38,7 @@ fn bench_new_groups(c: &mut Criterion) { // Test with 1, 2, 4, and 8 order indices for num_columns in [1, 2, 4, 8] { - let fields: Vec<Field> = (0..num_columns) - .map(|i| Field::new(format!("col{i}"), DataType::Int32, false)) - .collect(); - let schema = Schema::new(fields); - let order_indices: Vec<usize> = (0..num_columns).collect(); - let ordering = LexOrdering::new( - (0..num_columns) - .map(|i| { - PhysicalSortExpr::new( - col(&format!("col{i}"), &schema).unwrap(), - SortOptions::default(), - ) - }) - .collect(), - ); group.bench_function(format!("order_indices_{num_columns}"), |b| { let batch_group_values = create_test_arrays(num_columns); @@ -62,8 +46,7 @@ fn bench_new_groups(c: &mut Criterion) { b.iter(|| { let mut ordering = - GroupOrderingPartial::try_new(&schema, &order_indices, &ordering) - .unwrap(); + GroupOrderingPartial::try_new(order_indices.clone()).unwrap(); ordering .new_groups(&batch_group_values, &group_indices, BATCH_SIZE) .unwrap(); diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 656c9a2cd5cb..14b2d0a932c2 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -42,13 +42,16 @@ use datafusion_common::{internal_err, not_impl_err, Constraint, Constraints, Result}; use datafusion_execution::TaskContext; use datafusion_expr::{Accumulator, Aggregate}; use datafusion_physical_expr::aggregate::AggregateFunctionExpr; +use datafusion_physical_expr::equivalence::ProjectionMapping; +use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{ - equivalence::ProjectionMapping, expressions::Column, physical_exprs_contains, - ConstExpr, EquivalenceProperties, LexOrdering, LexRequirement, PhysicalExpr, - PhysicalSortRequirement, + physical_exprs_contains, ConstExpr, EquivalenceProperties, +}; +use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExpr}; +use datafusion_physical_expr_common::sort_expr::{ + LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortRequirement, }; -use datafusion_physical_expr_common::physical_expr::fmt_sql; use itertools::Itertools; pub(crate) mod group_values; @@ -397,7 +400,7 @@ pub struct AggregateExec { pub input_schema: SchemaRef, /// Execution metrics metrics: ExecutionPlanMetricsSet, - required_input_ordering: Option<LexRequirement>, + required_input_ordering: Option<OrderingRequirements>, /// Describes how the input is ordered relative to the group by columns input_order_mode: InputOrderMode, cache: PlanProperties, @@ -484,16 +487,13 @@ impl AggregateExec { // If existing ordering satisfies a prefix of the GROUP
BY expressions, // prefix requirements with this section. In this case, aggregation will // work more efficiently. - let indices = get_ordered_partition_by_indices(&groupby_exprs, &input); - let mut new_requirement = LexRequirement::new( - indices - .iter() - .map(|&idx| PhysicalSortRequirement { - expr: Arc::clone(&groupby_exprs[idx]), - options: None, - }) - .collect::<Vec<_>>(), - ); + let indices = get_ordered_partition_by_indices(&groupby_exprs, &input)?; + let mut new_requirements = indices + .iter() + .map(|&idx| { + PhysicalSortRequirement::new(Arc::clone(&groupby_exprs[idx]), None) + }) + .collect::<Vec<_>>(); let req = get_finer_aggregate_exprs_requirement( &mut aggr_expr, @@ -501,8 +501,10 @@ impl AggregateExec { input_eq_properties, &mode, )?; - new_requirement.inner.extend(req); - new_requirement = new_requirement.collapse(); + new_requirements.extend(req); + + let required_input_ordering = + LexRequirement::new(new_requirements).map(OrderingRequirements::new_soft); // If our aggregation has grouping sets then our base grouping exprs will // be expanded based on the flags in `group_by.groups` where for each @@ -527,10 +529,7 @@ impl AggregateExec { // construct a map from the input expression to the output expression of the Aggregation group by let group_expr_mapping = - ProjectionMapping::try_new(&group_by.expr, &input.schema())?; - - let required_input_ordering = - (!new_requirement.is_empty()).then_some(new_requirement); + ProjectionMapping::try_new(group_by.expr.clone(), &input.schema())?; let cache = Self::compute_properties( &input, @@ -539,7 +538,7 @@ impl AggregateExec { &mode, &input_order_mode, aggr_expr.as_slice(), - ); + )?; Ok(AggregateExec { mode, @@ -654,7 +653,7 @@ impl AggregateExec { return false; } // ensure there are no order by expressions - if self.aggr_expr().iter().any(|e| e.order_bys().is_some()) { + if !self.aggr_expr().iter().all(|e| e.order_bys().is_empty()) { return false; } // ensure there is no output ordering; can this rule be relaxed? @@ -662,8 +661,8 @@ impl AggregateExec { return false; } // ensure no ordering is required on the input - if self.required_input_ordering()[0].is_some() { - return false; + if let Some(requirement) = self.required_input_ordering().swap_remove(0) { + return matches!(requirement, OrderingRequirements::Hard(_)); } true } @@ -676,7 +675,7 @@ impl AggregateExec { mode: &AggregateMode, input_order_mode: &InputOrderMode, aggr_exprs: &[Arc<AggregateFunctionExpr>], - ) -> PlanProperties { + ) -> Result<PlanProperties> { // Construct equivalence properties: let mut eq_properties = input .equivalence_properties() .clone(); // If the group by is empty, then we ensure that the operator will produce // only one row, and mark the generated result as a constant value. - if group_expr_mapping.map.is_empty() { - let mut constants = eq_properties.constants().to_vec(); + if group_expr_mapping.is_empty() { let new_constants = aggr_exprs.iter().enumerate().map(|(idx, func)| { - ConstExpr::new(Arc::new(Column::new(func.name(), idx))) + let column = Arc::new(Column::new(func.name(), idx)); + ConstExpr::from(column as Arc<dyn PhysicalExpr>) }); - constants.extend(new_constants); - eq_properties = eq_properties.with_constants(constants); + eq_properties.add_constants(new_constants)?; } // Group by expression will be a distinct value after the aggregation.
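// --- Illustrative sketch (editor's note, not part of the diff) ---
// How the `OrderingRequirements::new_soft` path above fits together: the
// ordered prefix of the GROUP BY columns becomes a *soft* requirement that
// the planner may relax rather than enforce with an extra sort. Only APIs
// visible in this diff are used; the helper name is hypothetical.
use std::sync::Arc;

use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::sort_expr::{
    LexRequirement, OrderingRequirements, PhysicalSortRequirement,
};

fn soft_groupby_requirement(
    groupby_exprs: &[Arc<dyn PhysicalExpr>],
    ordered_prefix: &[usize],
) -> Option<OrderingRequirements> {
    let reqs = ordered_prefix
        .iter()
        // `None` options: any sort direction satisfies the requirement.
        .map(|&idx| PhysicalSortRequirement::new(Arc::clone(&groupby_exprs[idx]), None))
        .collect::<Vec<_>>();
    // `LexRequirement::new` yields `None` for an empty list, so an empty
    // prefix produces no requirement at all, mirroring the hunk above.
    LexRequirement::new(reqs).map(OrderingRequirements::new_soft)
}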
@@ -698,13 +696,11 @@ impl AggregateExec { let mut constraints = eq_properties.constraints().to_vec(); let new_constraint = Constraint::Unique( group_expr_mapping - .map .iter() - .filter_map(|(_, target_col)| { - target_col - .as_any() - .downcast_ref::<Column>() - .map(|c| c.index()) + .flat_map(|(_, target_cols)| { + target_cols.iter().flat_map(|(expr, _)| { + expr.as_any().downcast_ref::<Column>().map(|c| c.index()) + }) }) .collect(), ); @@ -731,12 +727,12 @@ impl AggregateExec { input.pipeline_behavior() }; - PlanProperties::new( + Ok(PlanProperties::new( eq_properties, output_partitioning, emission_type, input.boundedness(), - ) + )) } pub fn input_order_mode(&self) -> &InputOrderMode { @@ -958,7 +954,7 @@ impl ExecutionPlan for AggregateExec { } } - fn required_input_ordering(&self) -> Vec<Option<LexRequirement>> { + fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> { vec![self.required_input_ordering.clone()] } @@ -1075,16 +1071,14 @@ fn get_aggregate_expr_req( aggr_expr: &AggregateFunctionExpr, group_by: &PhysicalGroupBy, agg_mode: &AggregateMode, -) -> LexOrdering { +) -> Option<LexOrdering> { // If the aggregation function's ordering requirement is not absolutely // necessary, or the aggregation is performing a "second stage" calculation, // then ignore the ordering requirement. if !aggr_expr.order_sensitivity().hard_requires() || !agg_mode.is_first_stage() { - return LexOrdering::default(); + return None; } - - let mut req = aggr_expr.order_bys().cloned().unwrap_or_default(); - + let mut sort_exprs = aggr_expr.order_bys().to_vec(); // In non-first stage modes, we accumulate data (using `merge_batch`) from // different partitions (i.e. merge partial results). During this merge, we // consider the ordering of each partial result. Hence, we do not need to @@ -1095,38 +1089,11 @@ fn get_aggregate_expr_req( // will definitely be satisfied -- Each group by expression will have // distinct values per group, hence all requirements are satisfied. let physical_exprs = group_by.input_exprs(); - req.retain(|sort_expr| { + sort_exprs.retain(|sort_expr| { !physical_exprs_contains(&physical_exprs, &sort_expr.expr) }); } - req -} - -/// Computes the finer ordering between the given existing ordering requirement -/// and the aggregate expression's requirement. -/// -/// # Parameters -/// -/// * `existing_req` - The existing lexical ordering that needs refinement. -/// * `aggr_expr` - A reference to an aggregate expression trait object. -/// * `group_by` - Information about the physical grouping (e.g group by expression). -/// * `eq_properties` - Equivalence properties relevant to the computation. -/// * `agg_mode` - The mode of aggregation (e.g., Partial, Final, etc.). -/// -/// # Returns -/// -/// An `Option<LexOrdering>` representing the computed finer lexical ordering, -/// or `None` if there is no finer ordering; e.g. the existing requirement and -/// the aggregator requirement are incompatible. -fn finer_ordering( - existing_req: &LexOrdering, - aggr_expr: &AggregateFunctionExpr, - group_by: &PhysicalGroupBy, - eq_properties: &EquivalenceProperties, - agg_mode: &AggregateMode, -) -> Option<LexOrdering> { - let aggr_req = get_aggregate_expr_req(aggr_expr, group_by, agg_mode); - eq_properties.get_finer_ordering(existing_req, aggr_req.as_ref()) + LexOrdering::new(sort_exprs) } /// Concatenates the given slices. @@ -1134,7 +1101,23 @@ pub fn concat_slices<T: Clone>(lhs: &[T], rhs: &[T]) -> Vec<T> { [lhs, rhs].concat() } -/// Get the common requirement that satisfies all the aggregate expressions.
+// Determines if the candidate ordering is finer than the current ordering. +// Returns `None` if they are incomparable, `Some(true)` if there is no current +// ordering or candidate ordering is finer, and `Some(false)` otherwise. +fn determine_finer( + current: &Option<LexOrdering>, + candidate: &LexOrdering, +) -> Option<bool> { + if let Some(ordering) = current { + candidate.partial_cmp(ordering).map(|cmp| cmp.is_gt()) + } else { + Some(true) + } +} + +/// Gets the common requirement that satisfies all the aggregate expressions. +/// When possible, chooses the requirement that is already satisfied by the +/// equivalence properties. /// /// # Parameters /// @@ -1149,75 +1132,75 @@ pub fn concat_slices<T: Clone>(lhs: &[T], rhs: &[T]) -> Vec<T> { /// /// # Returns /// -/// A `LexRequirement` instance, which is the requirement that satisfies all the -/// aggregate requirements. Returns an error in case of conflicting requirements. +/// A `Result<Vec<PhysicalSortRequirement>>` instance, which is the requirement +/// that satisfies all the aggregate requirements. Returns an error in case of +/// conflicting requirements. pub fn get_finer_aggregate_exprs_requirement( aggr_exprs: &mut [Arc<AggregateFunctionExpr>], group_by: &PhysicalGroupBy, eq_properties: &EquivalenceProperties, agg_mode: &AggregateMode, -) -> Result<LexRequirement> { - let mut requirement = LexOrdering::default(); +) -> Result<Vec<PhysicalSortRequirement>> { + let mut requirement = None; for aggr_expr in aggr_exprs.iter_mut() { - if let Some(finer_ordering) = - finer_ordering(&requirement, aggr_expr, group_by, eq_properties, agg_mode) - { - if eq_properties.ordering_satisfy(finer_ordering.as_ref()) { - // Requirement is satisfied by existing ordering - requirement = finer_ordering; + let Some(aggr_req) = get_aggregate_expr_req(aggr_expr, group_by, agg_mode) + .and_then(|o| eq_properties.normalize_sort_exprs(o)) + else { + // There is no aggregate ordering requirement, or it is trivially + // satisfied -- we can skip this expression. + continue; + }; + // If the common requirement is finer than the current expression's, + // we can skip this expression. If the latter is finer than the former, + // adopt it if it is satisfied by the equivalence properties. Otherwise, + // defer the analysis to the reverse expression. + let forward_finer = determine_finer(&requirement, &aggr_req); + if let Some(finer) = forward_finer { + if !finer { + continue; + } else if eq_properties.ordering_satisfy(aggr_req.clone())? { + requirement = Some(aggr_req); continue; } } if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() { - if let Some(finer_ordering) = finer_ordering( - &requirement, - &reverse_aggr_expr, - group_by, - eq_properties, - agg_mode, - ) { - if eq_properties.ordering_satisfy(finer_ordering.as_ref()) { - // Reverse requirement is satisfied by existing ordering. - // Hence reverse the aggregator - requirement = finer_ordering; - *aggr_expr = Arc::new(reverse_aggr_expr); - continue; - } - } - } - if let Some(finer_ordering) = - finer_ordering(&requirement, aggr_expr, group_by, eq_properties, agg_mode) - { - // There is a requirement that both satisfies existing requirement and current - // aggregate requirement. Use updated requirement - requirement = finer_ordering; - continue; - } - if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() { - if let Some(finer_ordering) = finer_ordering( - &requirement, - &reverse_aggr_expr, - group_by, - eq_properties, - agg_mode, - ) { - // There is a requirement that both satisfies existing requirement and reverse - // aggregate requirement.
Use updated requirement - requirement = finer_ordering; + let Some(rev_aggr_req) = + get_aggregate_expr_req(&reverse_aggr_expr, group_by, agg_mode) + .and_then(|o| eq_properties.normalize_sort_exprs(o)) + else { + // The reverse requirement is trivially satisfied -- just reverse + // the expression and continue with the next one: *aggr_expr = Arc::new(reverse_aggr_expr); continue; + }; + // If the common requirement is finer than the reverse expression's, + // just reverse it and continue the loop with the next aggregate + // expression. If the latter is finer than the former, adopt it if + // it is satisfied by the equivalence properties. Otherwise, adopt + // the forward expression. + if let Some(finer) = determine_finer(&requirement, &rev_aggr_req) { + if !finer { + *aggr_expr = Arc::new(reverse_aggr_expr); + } else if eq_properties.ordering_satisfy(rev_aggr_req.clone())? { + *aggr_expr = Arc::new(reverse_aggr_expr); + requirement = Some(rev_aggr_req); + } else { + requirement = Some(aggr_req); + } + } else if forward_finer.is_some() { + requirement = Some(aggr_req); + } else { + // Neither the existing requirement nor the current aggregate + // requirement satisfy the other (forward or reverse), this + // means they are conflicting. + return not_impl_err!( + "Conflicting ordering requirements in aggregate functions is not supported" + ); } } - - // Neither the existing requirement and current aggregate requirement satisfy the other, this means - // requirements are conflicting. Currently, we do not support - // conflicting requirements. - return not_impl_err!( - "Conflicting ordering requirements in aggregate functions is not supported" - ); } - Ok(LexRequirement::from(requirement)) + Ok(requirement.map_or_else(Vec::new, |o| o.into_iter().map(Into::into).collect())) } /// Returns physical expressions for arguments to evaluate against a batch. @@ -1240,9 +1223,7 @@ pub fn aggregate_expressions( // Append ordering requirements to expressions' results. This // way order sensitive aggregators can satisfy requirement // themselves. 
- if let Some(ordering_req) = agg.order_bys() { - result.extend(ordering_req.iter().map(|item| Arc::clone(&item.expr))); - } + result.extend(agg.order_bys().iter().map(|item| Arc::clone(&item.expr))); result }) .collect()), @@ -2249,14 +2230,14 @@ mod tests { schema: &Schema, sort_options: SortOptions, ) -> Result> { - let ordering_req = [PhysicalSortExpr { + let order_bys = vec![PhysicalSortExpr { expr: col("b", schema)?, options: sort_options, }]; let args = [col("b", schema)?]; AggregateExprBuilder::new(first_value_udaf(), args.to_vec()) - .order_by(LexOrdering::new(ordering_req.to_vec())) + .order_by(order_bys) .schema(Arc::new(schema.clone())) .alias(String::from("first_value(b) ORDER BY [b ASC NULLS LAST]")) .build() @@ -2268,13 +2249,13 @@ mod tests { schema: &Schema, sort_options: SortOptions, ) -> Result> { - let ordering_req = [PhysicalSortExpr { + let order_bys = vec![PhysicalSortExpr { expr: col("b", schema)?, options: sort_options, }]; let args = [col("b", schema)?]; AggregateExprBuilder::new(last_value_udaf(), args.to_vec()) - .order_by(LexOrdering::new(ordering_req.to_vec())) + .order_by(order_bys) .schema(Arc::new(schema.clone())) .alias(String::from("last_value(b) ORDER BY [b ASC NULLS LAST]")) .build() @@ -2396,9 +2377,7 @@ mod tests { async fn test_get_finest_requirements() -> Result<()> { let test_schema = create_test_schema()?; - // Assume column a and b are aliases - // Assume also that a ASC and c DESC describe the same global ordering for the table. (Since they are ordering equivalent). - let options1 = SortOptions { + let options = SortOptions { descending: false, nulls_first: false, }; @@ -2407,58 +2386,51 @@ mod tests { let col_c = &col("c", &test_schema)?; let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); // Columns a and b are equal. 
- eq_properties.add_equal_conditions(col_a, col_b)?; + eq_properties.add_equal_conditions(Arc::clone(col_a), Arc::clone(col_b))?; // Aggregate requirements are // [None], [a ASC], [a ASC, b ASC, c ASC], [a ASC, b ASC] respectively let order_by_exprs = vec![ - None, - Some(vec![PhysicalSortExpr { + vec![], + vec![PhysicalSortExpr { expr: Arc::clone(col_a), - options: options1, - }]), - Some(vec![ + options, + }], + vec![ PhysicalSortExpr { expr: Arc::clone(col_a), - options: options1, + options, }, PhysicalSortExpr { expr: Arc::clone(col_b), - options: options1, + options, }, PhysicalSortExpr { expr: Arc::clone(col_c), - options: options1, + options, }, - ]), - Some(vec![ + ], + vec![ PhysicalSortExpr { expr: Arc::clone(col_a), - options: options1, + options, }, PhysicalSortExpr { expr: Arc::clone(col_b), - options: options1, + options, }, - ]), + ], ]; - let common_requirement = LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(col_a), - options: options1, - }, - PhysicalSortExpr { - expr: Arc::clone(col_c), - options: options1, - }, - ]); + let common_requirement = vec![ + PhysicalSortRequirement::new(Arc::clone(col_a), Some(options)), + PhysicalSortRequirement::new(Arc::clone(col_c), Some(options)), + ]; let mut aggr_exprs = order_by_exprs .into_iter() .map(|order_by_expr| { - let ordering_req = order_by_expr.unwrap_or_default(); AggregateExprBuilder::new(array_agg_udaf(), vec![Arc::clone(col_a)]) .alias("a") - .order_by(LexOrdering::new(ordering_req.to_vec())) + .order_by(order_by_expr) .schema(Arc::clone(&test_schema)) .build() .map(Arc::new) @@ -2466,14 +2438,13 @@ mod tests { }) .collect::>(); let group_by = PhysicalGroupBy::new_single(vec![]); - let res = get_finer_aggregate_exprs_requirement( + let result = get_finer_aggregate_exprs_requirement( &mut aggr_exprs, &group_by, &eq_properties, &AggregateMode::Partial, )?; - let res = LexOrdering::from(res); - assert_eq!(res, common_requirement); + assert_eq!(result, common_requirement); Ok(()) } diff --git a/datafusion/physical-plan/src/aggregates/order/mod.rs b/datafusion/physical-plan/src/aggregates/order/mod.rs index 0b742b3d20fd..bbcb30d877cf 100644 --- a/datafusion/physical-plan/src/aggregates/order/mod.rs +++ b/datafusion/physical-plan/src/aggregates/order/mod.rs @@ -15,12 +15,11 @@ // specific language governing permissions and limitations // under the License. 
+use std::mem::size_of; + use arrow::array::ArrayRef; -use arrow::datatypes::Schema; use datafusion_common::Result; use datafusion_expr::EmitTo; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use std::mem::size_of; mod full; mod partial; @@ -42,15 +41,11 @@ impl GroupOrdering { /// Create a `GroupOrdering` for the specified ordering - pub fn try_new( - input_schema: &Schema, - mode: &InputOrderMode, - ordering: &LexOrdering, - ) -> Result<Self> { + pub fn try_new(mode: &InputOrderMode) -> Result<Self> { match mode { InputOrderMode::Linear => Ok(GroupOrdering::None), InputOrderMode::PartiallySorted(order_indices) => { - GroupOrderingPartial::try_new(input_schema, order_indices, ordering) + GroupOrderingPartial::try_new(order_indices.clone()) .map(GroupOrdering::Partial) } InputOrderMode::Sorted => Ok(GroupOrdering::Full(GroupOrderingFull::new())), diff --git a/datafusion/physical-plan/src/aggregates/order/partial.rs b/datafusion/physical-plan/src/aggregates/order/partial.rs index c7a75e5f2640..3e495900f77a 100644 --- a/datafusion/physical-plan/src/aggregates/order/partial.rs +++ b/datafusion/physical-plan/src/aggregates/order/partial.rs @@ -15,18 +15,17 @@ // specific language governing permissions and limitations // under the License. +use std::cmp::Ordering; +use std::mem::size_of; +use std::sync::Arc; + use arrow::array::ArrayRef; use arrow::compute::SortOptions; -use arrow::datatypes::Schema; use arrow_ord::partition::partition; use datafusion_common::utils::{compare_rows, get_row_at_idx}; use datafusion_common::{Result, ScalarValue}; use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_expr::EmitTo; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use std::cmp::Ordering; -use std::mem::size_of; -use std::sync::Arc; /// Tracks grouping state when the data is ordered by some subset of /// the group keys. @@ -118,17 +117,10 @@ impl State { impl GroupOrderingPartial { - /// TODO: Remove unnecessary `input_schema` parameter.
- pub fn try_new( - _input_schema: &Schema, - order_indices: &[usize], - ordering: &LexOrdering, - ) -> Result<Self> { - assert!(!order_indices.is_empty()); - assert!(order_indices.len() <= ordering.len()); - + pub fn try_new(order_indices: Vec<usize>) -> Result<Self> { + debug_assert!(!order_indices.is_empty()); Ok(Self { state: State::Start, - order_indices: order_indices.to_vec(), + order_indices, }) } @@ -276,29 +269,15 @@ impl GroupOrderingPartial { #[cfg(test)] mod tests { - use arrow::array::Int32Array; - use arrow_schema::{DataType, Field}; - use datafusion_physical_expr::{expressions::col, PhysicalSortExpr}; - use super::*; + use arrow::array::Int32Array; + #[test] fn test_group_ordering_partial() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - ]); - // Ordered on column a let order_indices = vec![0]; - - let ordering = LexOrdering::new(vec![PhysicalSortExpr::new( - col("a", &schema)?, - SortOptions::default(), - )]); - - let mut group_ordering = - GroupOrderingPartial::try_new(&schema, &order_indices, &ordering)?; + let mut group_ordering = GroupOrderingPartial::try_new(order_indices)?; let batch_group_values: Vec<ArrayRef> = vec![ Arc::new(Int32Array::from(vec![1, 2, 3])), diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 62f541443068..405c0b205978 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -21,6 +21,8 @@ use std::sync::Arc; use std::task::{Context, Poll}; use std::vec; +use super::order::GroupOrdering; +use super::AggregateExec; use crate::aggregates::group_values::{new_group_values, GroupValues}; use crate::aggregates::order::GroupOrderingFull; use crate::aggregates::{ @@ -32,11 +34,10 @@ use crate::sorts::sort::sort_batch; use crate::sorts::streaming_merge::StreamingMergeBuilder; use crate::spill::spill_manager::SpillManager; use crate::stream::RecordBatchStreamAdapter; -use crate::{aggregates, metrics, ExecutionPlan, PhysicalExpr}; +use crate::{aggregates, metrics, PhysicalExpr}; use crate::{RecordBatchStream, SendableRecordBatchStream}; use arrow::array::*; -use arrow::compute::SortOptions; use arrow::datatypes::SchemaRef; use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_execution::disk_manager::RefCountedTempFile; @@ -44,13 +45,11 @@ use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; use datafusion_expr::{EmitTo, GroupsAccumulator}; +use datafusion_physical_expr::aggregate::AggregateFunctionExpr; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{GroupsAccumulatorAdapter, PhysicalSortExpr}; - -use super::order::GroupOrdering; -use super::AggregateExec; -use datafusion_physical_expr::aggregate::AggregateFunctionExpr; use datafusion_physical_expr_common::sort_expr::LexOrdering; + use futures::ready; use futures::stream::{Stream, StreamExt}; use log::debug; @@ -519,15 +518,20 @@ impl GroupedHashAggregateStream { let partial_agg_schema = Arc::new(partial_agg_schema); - let spill_expr = group_schema - .fields - .into_iter() - .enumerate() - .map(|(idx, field)| PhysicalSortExpr { - expr: Arc::new(Column::new(field.name().as_str(), idx)) as _, - options: SortOptions::default(), - }) - .collect(); + let spill_expr = + group_schema + .fields + .into_iter() + .enumerate() + .map(|(idx,
field)| { + PhysicalSortExpr::new_default(Arc::new(Column::new( + field.name().as_str(), + idx, + )) as _) + }); + let Some(spill_expr) = LexOrdering::new(spill_expr) else { + return internal_err!("Spill expression is empty"); + }; let agg_fn_names = aggregate_exprs .iter() @@ -538,16 +542,7 @@ impl GroupedHashAggregateStream { let reservation = MemoryConsumer::new(name) .with_can_spill(true) .register(context.memory_pool()); - let (ordering, _) = agg - .properties() - .equivalence_properties() - .find_longest_permutation(&agg_group_by.output_exprs()); - let group_ordering = GroupOrdering::try_new( - &group_schema, - &agg.input_order_mode, - ordering.as_ref(), - )?; - + let group_ordering = GroupOrdering::try_new(&agg.input_order_mode)?; let group_values = new_group_values(group_schema, &group_ordering)?; timer.done(); @@ -1001,7 +996,7 @@ impl GroupedHashAggregateStream { let Some(emit) = self.emit(EmitTo::All, true)? else { return Ok(()); }; - let sorted = sort_batch(&emit, self.spill_state.spill_expr.as_ref(), None)?; + let sorted = sort_batch(&emit, &self.spill_state.spill_expr, None)?; // Spill sorted state to disk let spillfile = self.spill_state.spill_manager.spill_record_batch_by_size( @@ -1068,7 +1063,7 @@ impl GroupedHashAggregateStream { streams.push(Box::pin(RecordBatchStreamAdapter::new( Arc::clone(&schema), futures::stream::once(futures::future::lazy(move |_| { - sort_batch(&batch, expr.as_ref(), None) + sort_batch(&batch, &expr, None) })), ))); for spill in self.spill_state.spills.drain(..) { @@ -1079,7 +1074,7 @@ impl GroupedHashAggregateStream { self.input = StreamingMergeBuilder::new() .with_streams(streams) .with_schema(schema) - .with_expressions(self.spill_state.spill_expr.as_ref()) + .with_expressions(&self.spill_state.spill_expr) .with_metrics(self.baseline_metrics.clone()) .with_batch_size(self.batch_size) .with_reservation(self.reservation.new_empty()) diff --git a/datafusion/physical-plan/src/display.rs b/datafusion/physical-plan/src/display.rs index f555755dd20a..56335f13d01b 100644 --- a/datafusion/physical-plan/src/display.rs +++ b/datafusion/physical-plan/src/display.rs @@ -1034,27 +1034,22 @@ impl fmt::Display for ProjectSchemaDisplay<'_> { } pub fn display_orderings(f: &mut Formatter, orderings: &[LexOrdering]) -> fmt::Result { - if let Some(ordering) = orderings.first() { - if !ordering.is_empty() { - let start = if orderings.len() == 1 { - ", output_ordering=" - } else { - ", output_orderings=[" - }; - write!(f, "{start}")?; - for (idx, ordering) in - orderings.iter().enumerate().filter(|(_, o)| !o.is_empty()) - { - match idx { - 0 => write!(f, "[{ordering}]")?, - _ => write!(f, ", [{ordering}]")?, - } + if !orderings.is_empty() { + let start = if orderings.len() == 1 { + ", output_ordering=" + } else { + ", output_orderings=[" + }; + write!(f, "{start}")?; + for (idx, ordering) in orderings.iter().enumerate() { + match idx { + 0 => write!(f, "[{ordering}]")?, + _ => write!(f, ", [{ordering}]")?, } - let end = if orderings.len() == 1 { "" } else { "]" }; - write!(f, "{end}")?; } + let end = if orderings.len() == 1 { "" } else { "]" }; + write!(f, "{end}")?; } - Ok(()) } diff --git a/datafusion/physical-plan/src/execution_plan.rs b/datafusion/physical-plan/src/execution_plan.rs index b81b3c8beeac..73b405b38471 100644 --- a/datafusion/physical-plan/src/execution_plan.rs +++ b/datafusion/physical-plan/src/execution_plan.rs @@ -51,8 +51,8 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::{exec_err, Constraints, Result}; use 
datafusion_common_runtime::JoinSet; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; -use datafusion_physical_expr_common::sort_expr::LexRequirement; +use datafusion_physical_expr::EquivalenceProperties; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements}; use futures::stream::{StreamExt, TryStreamExt}; @@ -139,7 +139,7 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// NOTE that checking `!is_empty()` does **not** check for a /// required input ordering. Instead, the correct check is that at /// least one entry must be `Some` - fn required_input_ordering(&self) -> Vec<Option<LexRequirement>> { + fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> { vec![None; self.children().len()] } @@ -1153,17 +1153,17 @@ pub enum CardinalityEffect { #[cfg(test)] mod tests { - use super::*; - use arrow::array::{DictionaryArray, Int32Array, NullArray, RunArray}; - use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use std::any::Any; use std::sync::Arc; + use super::*; + use crate::{DisplayAs, DisplayFormatType, ExecutionPlan}; + + use arrow::array::{DictionaryArray, Int32Array, NullArray, RunArray}; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::{Result, Statistics}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; - use crate::{DisplayAs, DisplayFormatType, ExecutionPlan}; - #[derive(Debug)] pub struct EmptyExec; diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 13129e382dec..25eb98a61b2e 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -233,24 +233,18 @@ impl FilterExec { if let Some(binary) = conjunction.as_any().downcast_ref::<BinaryExpr>() { if binary.op() == &Operator::Eq { // Filter evaluates to single value for all partitions - if input_eqs.is_expr_constant(binary.left()) { - let (expr, across_parts) = ( - binary.right(), - input_eqs.get_expr_constant_value(binary.right()), - ); - res_constants.push( - ConstExpr::new(Arc::clone(expr)) - .with_across_partitions(across_parts), - ); - } else if input_eqs.is_expr_constant(binary.right()) { - let (expr, across_parts) = ( - binary.left(), - input_eqs.get_expr_constant_value(binary.left()), - ); - res_constants.push( - ConstExpr::new(Arc::clone(expr)) - .with_across_partitions(across_parts), - ); + if input_eqs.is_expr_constant(binary.left()).is_some() { + let across = input_eqs + .is_expr_constant(binary.right()) + .unwrap_or_default(); + res_constants + .push(ConstExpr::new(Arc::clone(binary.right()), across)); + } else if input_eqs.is_expr_constant(binary.right()).is_some() { + let across = input_eqs + .is_expr_constant(binary.left()) + .unwrap_or_default(); + res_constants + .push(ConstExpr::new(Arc::clone(binary.left()), across)); } } } @@ -275,7 +269,7 @@ impl FilterExec { let mut eq_properties = input.equivalence_properties().clone(); let (equal_pairs, _) = collect_columns_from_predicate(predicate); for (lhs, rhs) in equal_pairs { - eq_properties.add_equal_conditions(lhs, rhs)? + eq_properties.add_equal_conditions(Arc::clone(lhs), Arc::clone(rhs))? } // Add the columns that have only one viable value (singleton) after // filtering to constants.
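// --- Illustrative sketch (editor's note, not part of the diff) ---
// The revised constant-propagation API above in one place: `is_expr_constant`
// now returns `Option<AcrossPartitions>` instead of `bool`, and `ConstExpr::new`
// takes that value directly in place of the old `with_across_partitions`
// builder. Import paths and the helper name are assumptions; the diff only
// shows the call sites.
use std::sync::Arc;

use datafusion_common::{Result, ScalarValue};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::{AcrossPartitions, ConstExpr, EquivalenceProperties};

fn mark_column_constant(eq_properties: &mut EquivalenceProperties) -> Result<()> {
    // For a predicate like `a = 1`, column `a` holds the same value in every
    // partition, so it is a uniform constant across all partitions.
    let a = Arc::new(Column::new("a", 0)) as _;
    let constant =
        ConstExpr::new(a, AcrossPartitions::Uniform(Some(ScalarValue::from(1i32))));
    // `add_constants` now mutates in place and can fail, where the old
    // `with_constants` consumed and returned the properties.
    eq_properties.add_constants(std::iter::once(constant))
}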
@@ -287,15 +281,13 @@ impl FilterExec { .min_value .get_value(); let expr = Arc::new(column) as _; - ConstExpr::new(expr) - .with_across_partitions(AcrossPartitions::Uniform(value.cloned())) + ConstExpr::new(expr, AcrossPartitions::Uniform(value.cloned())) }); // This is for statistics - eq_properties = eq_properties.with_constants(constants); + eq_properties.add_constants(constants)?; // This is for logical constant (for example: a = '1', then a could be marked as a constant) // to do: how to deal with multiple situation to represent = (for example c1 between 0 and 0) - eq_properties = - eq_properties.with_constants(Self::extend_constants(input, predicate)); + eq_properties.add_constants(Self::extend_constants(input, predicate))?; let mut output_partitioning = input.output_partitioning().clone(); // If contains projection, update the PlanProperties. diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 4d8c48c659ef..e4d554ceb62c 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -115,7 +115,7 @@ impl CrossJoinExec { }; let schema = Arc::new(Schema::new(all_columns).with_metadata(metadata)); - let cache = Self::compute_properties(&left, &right, Arc::clone(&schema)); + let cache = Self::compute_properties(&left, &right, Arc::clone(&schema)).unwrap(); CrossJoinExec { left, @@ -142,7 +142,7 @@ impl CrossJoinExec { left: &Arc, right: &Arc, schema: SchemaRef, - ) -> PlanProperties { + ) -> Result { // Calculate equivalence properties // TODO: Check equivalence properties of cross join, it may preserve // ordering in some cases. @@ -154,7 +154,7 @@ impl CrossJoinExec { &[false, false], None, &[], - ); + )?; // Get output partitioning: // TODO: Optimize the cross join implementation to generate M * N @@ -162,14 +162,14 @@ impl CrossJoinExec { let output_partitioning = adjust_right_output_partitioning( right.output_partitioning(), left.schema().fields.len(), - ); + )?; - PlanProperties::new( + Ok(PlanProperties::new( eq_properties, output_partitioning, EmissionType::Final, boundedness_from_children([left, right]), - ) + )) } /// Returns a new `ExecutionPlan` that computes the same join as this one, diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 398c2fed7cdf..96cb09b4cb81 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -532,17 +532,17 @@ impl HashJoinExec { &Self::maintains_input_order(join_type), Some(Self::probe_side()), on, - ); + )?; let mut output_partitioning = match mode { PartitionMode::CollectLeft => { - asymmetric_join_output_partitioning(left, right, &join_type) + asymmetric_join_output_partitioning(left, right, &join_type)? } PartitionMode::Auto => Partitioning::UnknownPartitioning( right.output_partitioning().partition_count(), ), PartitionMode::Partitioned => { - symmetric_join_output_partitioning(left, right, &join_type) + symmetric_join_output_partitioning(left, right, &join_type)? 
} }; diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index f87cf3d8864c..c9b5b9e43b6f 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -259,10 +259,10 @@ impl NestedLoopJoinExec { None, // No on columns in nested loop join &[], - ); + )?; let mut output_partitioning = - asymmetric_join_output_partitioning(left, right, &join_type); + asymmetric_join_output_partitioning(left, right, &join_type)?; let emission_type = if left.boundedness().is_unbounded() { EmissionType::Final @@ -1088,22 +1088,18 @@ pub(crate) mod tests { vec![batch] }; - let mut source = - TestMemoryExec::try_new(&[batches], Arc::clone(&schema), None).unwrap(); - if !sorted_column_names.is_empty() { - let mut sort_info = LexOrdering::default(); - for name in sorted_column_names { - let index = schema.index_of(name).unwrap(); - let sort_expr = PhysicalSortExpr { - expr: Arc::new(Column::new(name, index)), - options: SortOptions { - descending: false, - nulls_first: false, - }, - }; - sort_info.push(sort_expr); - } - source = source.try_with_sort_information(vec![sort_info]).unwrap(); + let mut sort_info = vec![]; + for name in sorted_column_names { + let index = schema.index_of(name).unwrap(); + let sort_expr = PhysicalSortExpr::new( + Arc::new(Column::new(name, index)), + SortOptions::new(false, false), + ); + sort_info.push(sort_expr); + } + let mut source = TestMemoryExec::try_new(&[batches], schema, None).unwrap(); + if let Some(ordering) = LexOrdering::new(sort_info) { + source = source.try_with_sort_information(vec![ordering]).unwrap(); } Arc::new(TestMemoryExec::update_cache(Arc::new(source))) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index cadd2b53ab11..eab3796b150f 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -70,9 +70,8 @@ use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::join_equivalence_properties; -use datafusion_physical_expr::PhysicalExprRef; -use datafusion_physical_expr_common::physical_expr::fmt_sql; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExprRef}; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements}; use futures::{Stream, StreamExt}; @@ -193,11 +192,21 @@ impl SortMergeJoinExec { (left, right) }) .unzip(); + let Some(left_sort_exprs) = LexOrdering::new(left_sort_exprs) else { + return plan_err!( + "SortMergeJoinExec requires valid sort expressions for its left side" + ); + }; + let Some(right_sort_exprs) = LexOrdering::new(right_sort_exprs) else { + return plan_err!( + "SortMergeJoinExec requires valid sort expressions for its right side" + ); + }; let schema = Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0); let cache = - Self::compute_properties(&left, &right, Arc::clone(&schema), join_type, &on); + Self::compute_properties(&left, &right, Arc::clone(&schema), join_type, &on)?; Ok(Self { left, right, @@ -206,8 +215,8 @@ impl SortMergeJoinExec { join_type, schema, metrics: ExecutionPlanMetricsSet::new(), - left_sort_exprs: LexOrdering::new(left_sort_exprs), - 
right_sort_exprs: LexOrdering::new(right_sort_exprs), + left_sort_exprs, + right_sort_exprs, sort_options, null_equals_null, cache, @@ -289,7 +298,7 @@ impl SortMergeJoinExec { schema: SchemaRef, join_type: JoinType, join_on: JoinOnRef, - ) -> PlanProperties { + ) -> Result { // Calculate equivalence properties: let eq_properties = join_equivalence_properties( left.equivalence_properties().clone(), @@ -299,17 +308,17 @@ impl SortMergeJoinExec { &Self::maintains_input_order(join_type), Some(Self::probe_side(&join_type)), join_on, - ); + )?; let output_partitioning = - symmetric_join_output_partitioning(left, right, &join_type); + symmetric_join_output_partitioning(left, right, &join_type)?; - PlanProperties::new( + Ok(PlanProperties::new( eq_properties, output_partitioning, EmissionType::Incremental, boundedness_from_children([left, right]), - ) + )) } pub fn swap_inputs(&self) -> Result> { @@ -409,10 +418,10 @@ impl ExecutionPlan for SortMergeJoinExec { ] } - fn required_input_ordering(&self) -> Vec> { + fn required_input_ordering(&self) -> Vec> { vec![ - Some(LexRequirement::from(self.left_sort_exprs.clone())), - Some(LexRequirement::from(self.right_sort_exprs.clone())), + Some(OrderingRequirements::from(self.left_sort_exprs.clone())), + Some(OrderingRequirements::from(self.right_sort_exprs.clone())), ] } diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 819a3302b062..95497f50930e 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -73,9 +73,8 @@ use datafusion_execution::TaskContext; use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::equivalence::join_equivalence_properties; use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; -use datafusion_physical_expr::PhysicalExprRef; -use datafusion_physical_expr_common::physical_expr::fmt_sql; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExprRef}; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements}; use ahash::RandomState; use futures::{ready, Stream, StreamExt}; @@ -237,8 +236,7 @@ impl SymmetricHashJoinExec { // Initialize the random state for the join operation: let random_state = RandomState::with_seeds(0, 0, 0, 0); let schema = Arc::new(schema); - let cache = - Self::compute_properties(&left, &right, Arc::clone(&schema), *join_type, &on); + let cache = Self::compute_properties(&left, &right, schema, *join_type, &on)?; Ok(SymmetricHashJoinExec { left, right, @@ -263,7 +261,7 @@ impl SymmetricHashJoinExec { schema: SchemaRef, join_type: JoinType, join_on: JoinOnRef, - ) -> PlanProperties { + ) -> Result { // Calculate equivalence properties: let eq_properties = join_equivalence_properties( left.equivalence_properties().clone(), @@ -274,17 +272,17 @@ impl SymmetricHashJoinExec { // Has alternating probe side None, join_on, - ); + )?; let output_partitioning = - symmetric_join_output_partitioning(left, right, &join_type); + symmetric_join_output_partitioning(left, right, &join_type)?; - PlanProperties::new( + Ok(PlanProperties::new( eq_properties, output_partitioning, emission_type_from_children([left, right]), boundedness_from_children([left, right]), - ) + )) } /// left stream @@ -433,16 +431,14 @@ impl ExecutionPlan for SymmetricHashJoinExec { } } - fn required_input_ordering(&self) -> 
Vec> { + fn required_input_ordering(&self) -> Vec> { vec![ self.left_sort_exprs .as_ref() - .cloned() - .map(LexRequirement::from), + .map(|e| OrderingRequirements::from(e.clone())), self.right_sort_exprs .as_ref() - .cloned() - .map(LexRequirement::from), + .map(|e| OrderingRequirements::from(e.clone())), ] } @@ -635,21 +631,18 @@ impl ExecutionPlan for SymmetricHashJoinExec { self.right(), )?; - Ok(Some(Arc::new(SymmetricHashJoinExec::try_new( + SymmetricHashJoinExec::try_new( Arc::new(new_left), Arc::new(new_right), new_on, new_filter, self.join_type(), self.null_equals_null(), - self.right() - .output_ordering() - .map(|p| LexOrdering::new(p.to_vec())), - self.left() - .output_ordering() - .map(|p| LexOrdering::new(p.to_vec())), + self.right().output_ordering().cloned(), + self.left().output_ordering().cloned(), self.partition_mode(), - )?))) + ) + .map(|e| Some(Arc::new(e) as _)) } } @@ -1743,7 +1736,7 @@ mod tests { use datafusion_execution::config::SessionConfig; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{binary, col, lit, Column}; - use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; + use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use rstest::*; @@ -1843,7 +1836,7 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: binary( col("la1", left_schema)?, Operator::Plus, @@ -1851,11 +1844,13 @@ mod tests { left_schema, )?, options: SortOptions::default(), - }]); - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("ra1", right_schema)?, options: SortOptions::default(), - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, @@ -1923,14 +1918,16 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: col("la1", left_schema)?, options: SortOptions::default(), - }]); - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("ra1", right_schema)?, options: SortOptions::default(), - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2068,20 +2065,22 @@ mod tests { let (left_partition, right_partition) = get_or_create_table((11, 21), 8)?; let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: col("la1_des", left_schema)?, options: SortOptions { descending: true, nulls_first: true, }, - }]); - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("ra1_des", right_schema)?, options: SortOptions { descending: true, nulls_first: true, }, - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2127,20 +2126,22 @@ mod tests { let (left_partition, right_partition) = get_or_create_table((10, 11), 8)?; let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: 
col("l_asc_null_first", left_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]); - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("r_asc_null_first", right_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2186,20 +2187,22 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: col("l_asc_null_last", left_schema)?, options: SortOptions { descending: false, nulls_first: false, }, - }]); - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("r_asc_null_last", right_schema)?, options: SortOptions { descending: false, nulls_first: false, }, - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2247,20 +2250,22 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: col("l_desc_null_first", left_schema)?, options: SortOptions { descending: true, nulls_first: true, }, - }]); - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("r_desc_null_first", right_schema)?, options: SortOptions { descending: true, nulls_first: true, }, - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2309,15 +2314,16 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: col("la1", left_schema)?, options: SortOptions::default(), - }]); - - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("ra1", right_schema)?, options: SortOptions::default(), - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2368,20 +2374,23 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); let left_sorted = vec![ - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("la1", left_schema)?, options: SortOptions::default(), - }]), - LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(), + [PhysicalSortExpr { expr: col("la2", left_schema)?, options: SortOptions::default(), - }]), + }] + .into(), ]; - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let right_sorted = [PhysicalSortExpr { expr: col("ra1", right_schema)?, options: SortOptions::default(), - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, @@ -2449,20 +2458,22 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: col("lt1", left_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]); - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr 
{ + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("rt1", right_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2532,20 +2543,22 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: col("li1", left_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]); - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("ri1", right_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2608,14 +2621,16 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: col("l_float", left_schema)?, options: SortOptions::default(), - }]); - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("r_float", right_schema)?, options: SortOptions::default(), - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, diff --git a/datafusion/physical-plan/src/joins/test_utils.rs b/datafusion/physical-plan/src/joins/test_utils.rs index 81f56c865f04..cbabd7cb4510 100644 --- a/datafusion/physical-plan/src/joins/test_utils.rs +++ b/datafusion/physical-plan/src/joins/test_utils.rs @@ -102,10 +102,8 @@ pub async fn partitioned_sym_join_with_filter( filter, join_type, null_equals_null, - left.output_ordering().map(|p| LexOrdering::new(p.to_vec())), - right - .output_ordering() - .map(|p| LexOrdering::new(p.to_vec())), + left.output_ordering().cloned(), + right.output_ordering().cloned(), StreamJoinPartitionMode::Partitioned, )?; diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 3abeff6621d2..f473f82ee26f 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -25,7 +25,9 @@ use std::ops::Range; use std::sync::Arc; use std::task::{Context, Poll}; +use crate::joins::SharedBitmapBuilder; use crate::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder}; +use crate::projection::ProjectionExec; use crate::{ ColumnStatistics, ExecutionPlan, ExecutionPlanProperties, Partitioning, Statistics, }; @@ -45,20 +47,17 @@ use arrow::datatypes::{ }; use datafusion_common::cast::as_boolean_array; use datafusion_common::stats::Precision; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{ plan_err, DataFusionError, JoinSide, JoinType, Result, SharedResult, }; use datafusion_expr::interval_arithmetic::Interval; -use datafusion_physical_expr::equivalence::add_offset_to_expr; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::utils::{collect_columns, merge_vectors}; +use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{ - LexOrdering, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, + add_offset_to_expr, add_offset_to_physical_sort_exprs, LexOrdering, PhysicalExpr, + 
PhysicalExprRef, }; -use crate::joins::SharedBitmapBuilder; -use crate::projection::ProjectionExec; use futures::future::{BoxFuture, Shared}; use futures::{ready, FutureExt}; use parking_lot::Mutex; @@ -114,113 +113,84 @@ fn check_join_set_is_valid( pub fn adjust_right_output_partitioning( right_partitioning: &Partitioning, left_columns_len: usize, -) -> Partitioning { - match right_partitioning { +) -> Result { + let result = match right_partitioning { Partitioning::Hash(exprs, size) => { let new_exprs = exprs .iter() - .map(|expr| add_offset_to_expr(Arc::clone(expr), left_columns_len)) - .collect(); + .map(|expr| add_offset_to_expr(Arc::clone(expr), left_columns_len as _)) + .collect::>()?; Partitioning::Hash(new_exprs, *size) } result => result.clone(), - } -} - -/// Replaces the right column (first index in the `on_column` tuple) with -/// the left column (zeroth index in the tuple) inside `right_ordering`. -fn replace_on_columns_of_right_ordering( - on_columns: &[(PhysicalExprRef, PhysicalExprRef)], - right_ordering: &mut LexOrdering, -) -> Result<()> { - for (left_col, right_col) in on_columns { - right_ordering.transform(|item| { - let new_expr = Arc::clone(&item.expr) - .transform(|e| { - if e.eq(right_col) { - Ok(Transformed::yes(Arc::clone(left_col))) - } else { - Ok(Transformed::no(e)) - } - }) - .data() - .expect("closure is infallible"); - item.expr = new_expr; - }); - } - Ok(()) -} - -fn offset_ordering( - ordering: &LexOrdering, - join_type: &JoinType, - offset: usize, -) -> LexOrdering { - match join_type { - // In the case below, right ordering should be offsetted with the left - // side length, since we append the right table to the left table. - JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => ordering - .iter() - .map(|sort_expr| PhysicalSortExpr { - expr: add_offset_to_expr(Arc::clone(&sort_expr.expr), offset), - options: sort_expr.options, - }) - .collect(), - _ => ordering.clone(), - } + }; + Ok(result) } /// Calculate the output ordering of a given join operation. pub fn calculate_join_output_ordering( - left_ordering: &LexOrdering, - right_ordering: &LexOrdering, + left_ordering: Option<&LexOrdering>, + right_ordering: Option<&LexOrdering>, join_type: JoinType, - on_columns: &[(PhysicalExprRef, PhysicalExprRef)], left_columns_len: usize, maintains_input_order: &[bool], probe_side: Option, -) -> Option { - let output_ordering = match maintains_input_order { +) -> Result> { + match maintains_input_order { [true, false] => { // Special case, we can prefix ordering of right side with the ordering of left side. if join_type == JoinType::Inner && probe_side == Some(JoinSide::Left) { - replace_on_columns_of_right_ordering( - on_columns, - &mut right_ordering.clone(), - ) - .ok()?; - merge_vectors( - left_ordering, - offset_ordering(right_ordering, &join_type, left_columns_len) - .as_ref(), - ) - } else { - left_ordering.clone() + if let Some(right_ordering) = right_ordering.cloned() { + let right_offset = add_offset_to_physical_sort_exprs( + right_ordering, + left_columns_len as _, + )?; + return if let Some(left_ordering) = left_ordering { + let mut result = left_ordering.clone(); + result.extend(right_offset); + Ok(Some(result)) + } else { + Ok(LexOrdering::new(right_offset)) + }; + } } + Ok(left_ordering.cloned()) } [false, true] => { // Special case, we can prefix ordering of left side with the ordering of right side. 
if join_type == JoinType::Inner && probe_side == Some(JoinSide::Right) { - replace_on_columns_of_right_ordering( - on_columns, - &mut right_ordering.clone(), - ) - .ok()?; - merge_vectors( - offset_ordering(right_ordering, &join_type, left_columns_len) - .as_ref(), - left_ordering, - ) - } else { - offset_ordering(right_ordering, &join_type, left_columns_len) + return if let Some(right_ordering) = right_ordering.cloned() { + let mut right_offset = add_offset_to_physical_sort_exprs( + right_ordering, + left_columns_len as _, + )?; + if let Some(left_ordering) = left_ordering { + right_offset.extend(left_ordering.clone()); + } + Ok(LexOrdering::new(right_offset)) + } else { + Ok(left_ordering.cloned()) + }; + } + let Some(right_ordering) = right_ordering else { + return Ok(None); + }; + match join_type { + JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { + add_offset_to_physical_sort_exprs( + right_ordering.clone(), + left_columns_len as _, + ) + .map(LexOrdering::new) + } + _ => Ok(Some(right_ordering.clone())), } } // Doesn't maintain ordering, output ordering is None. - [false, false] => return None, + [false, false] => Ok(None), [true, true] => unreachable!("Cannot maintain ordering of both sides"), _ => unreachable!("Join operators can not have more than two children"), - }; - (!output_ordering.is_empty()).then_some(output_ordering) + } } /// Information about the index and placement (left or right) of the columns @@ -1299,35 +1269,36 @@ pub(crate) fn symmetric_join_output_partitioning( left: &Arc, right: &Arc, join_type: &JoinType, -) -> Partitioning { +) -> Result { let left_columns_len = left.schema().fields.len(); let left_partitioning = left.output_partitioning(); let right_partitioning = right.output_partitioning(); - match join_type { + let result = match join_type { JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { left_partitioning.clone() } JoinType::RightSemi | JoinType::RightAnti => right_partitioning.clone(), JoinType::Inner | JoinType::Right => { - adjust_right_output_partitioning(right_partitioning, left_columns_len) + adjust_right_output_partitioning(right_partitioning, left_columns_len)? } JoinType::Full => { // We could also use left partition count as they are necessarily equal. Partitioning::UnknownPartitioning(right_partitioning.partition_count()) } - } + }; + Ok(result) } pub(crate) fn asymmetric_join_output_partitioning( left: &Arc, right: &Arc, join_type: &JoinType, -) -> Partitioning { - match join_type { +) -> Result { + let result = match join_type { JoinType::Inner | JoinType::Right => adjust_right_output_partitioning( right.output_partitioning(), left.schema().fields().len(), - ), + )?, JoinType::RightSemi | JoinType::RightAnti => right.output_partitioning().clone(), JoinType::Left | JoinType::LeftSemi @@ -1336,7 +1307,8 @@ pub(crate) fn asymmetric_join_output_partitioning( | JoinType::LeftMark => Partitioning::UnknownPartitioning( right.output_partitioning().partition_count(), ), - } + }; + Ok(result) } /// Trait for incrementally generating Join output. 
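// --- Illustrative sketch (editor's note, not part of the diff) ---
// The offsetting step behind `calculate_join_output_ordering` above: when a
// join concatenates the right schema after the left one, right-side sort
// columns must be shifted by the left column count before they can extend
// the left ordering. This mirrors the `[true, false]` inner-join branch;
// the function name is hypothetical.
use datafusion_common::Result;
use datafusion_physical_expr::{add_offset_to_physical_sort_exprs, LexOrdering};

fn left_then_right_ordering(
    left_ordering: LexOrdering,
    right_ordering: LexOrdering,
    left_columns_len: usize,
) -> Result<LexOrdering> {
    // Rebase right-side column indices into the concatenated join schema...
    let shifted =
        add_offset_to_physical_sort_exprs(right_ordering, left_columns_len as _)?;
    // ...then append them, so the joined output is ordered by the left keys
    // first and the (shifted) right keys second.
    let mut result = left_ordering;
    result.extend(shifted);
    Ok(result)
}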
@@ -1503,16 +1475,17 @@ pub(super) fn swap_join_projection( #[cfg(test)] mod tests { - use super::*; use std::collections::HashMap; use std::pin::Pin; + use super::*; + use arrow::array::Int32Array; - use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Fields}; use arrow::error::{ArrowError, Result as ArrowResult}; use datafusion_common::stats::Precision::{Absent, Exact, Inexact}; use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue}; + use datafusion_physical_expr::PhysicalSortExpr; use rstest::rstest; @@ -2322,85 +2295,35 @@ mod tests { #[test] fn test_calculate_join_output_ordering() -> Result<()> { - let options = SortOptions::default(); let left_ordering = LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("d", 3)), - options, - }, + PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))), + PhysicalSortExpr::new_default(Arc::new(Column::new("c", 2))), + PhysicalSortExpr::new_default(Arc::new(Column::new("d", 3))), ]); let right_ordering = LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("z", 2)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("y", 1)), - options, - }, + PhysicalSortExpr::new_default(Arc::new(Column::new("z", 2))), + PhysicalSortExpr::new_default(Arc::new(Column::new("y", 1))), ]); let join_type = JoinType::Inner; - let on_columns = [( - Arc::new(Column::new("b", 1)) as _, - Arc::new(Column::new("x", 0)) as _, - )]; let left_columns_len = 5; let maintains_input_orders = [[true, false], [false, true]]; let probe_sides = [Some(JoinSide::Left), Some(JoinSide::Right)]; let expected = [ - Some(LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("d", 3)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("z", 7)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("y", 6)), - options, - }, - ])), - Some(LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("z", 7)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("y", 6)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("d", 3)), - options, - }, - ])), + LexOrdering::new(vec![ + PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))), + PhysicalSortExpr::new_default(Arc::new(Column::new("c", 2))), + PhysicalSortExpr::new_default(Arc::new(Column::new("d", 3))), + PhysicalSortExpr::new_default(Arc::new(Column::new("z", 7))), + PhysicalSortExpr::new_default(Arc::new(Column::new("y", 6))), + ]), + LexOrdering::new(vec![ + PhysicalSortExpr::new_default(Arc::new(Column::new("z", 7))), + PhysicalSortExpr::new_default(Arc::new(Column::new("y", 6))), + PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))), + PhysicalSortExpr::new_default(Arc::new(Column::new("c", 2))), + PhysicalSortExpr::new_default(Arc::new(Column::new("d", 3))), + ]), ]; for (i, (maintains_input_order, probe_side)) in @@ -2411,11 +2334,10 @@ mod tests { left_ordering.as_ref(), right_ordering.as_ref(), join_type, - &on_columns, left_columns_len, maintains_input_order, probe_side, - ), + )?, expected[i] 
); } diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index f1621acd0deb..a29f4aeb4090 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -46,11 +46,10 @@ use datafusion_common::{internal_err, JoinSide, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::utils::collect_columns; -use datafusion_physical_expr::PhysicalExprRef; +use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExprRef}; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; -use datafusion_physical_expr_common::physical_expr::fmt_sql; use futures::stream::{Stream, StreamExt}; -use itertools::Itertools; use log::trace; /// Execution plan for a projection @@ -98,7 +97,7 @@ impl ProjectionExec { )); // Construct a map from the input expressions to the output expression of the Projection - let projection_mapping = ProjectionMapping::try_new(&expr, &input_schema)?; + let projection_mapping = ProjectionMapping::try_new(expr.clone(), &input_schema)?; let cache = Self::compute_properties(&input, &projection_mapping, Arc::clone(&schema))?; Ok(Self { @@ -127,14 +126,12 @@ impl ProjectionExec { schema: SchemaRef, ) -> Result<PlanProperties> { // Calculate equivalence properties: - let mut input_eq_properties = input.equivalence_properties().clone(); - input_eq_properties.substitute_oeq_class(projection_mapping)?; + let input_eq_properties = input.equivalence_properties(); let eq_properties = input_eq_properties.project(projection_mapping, schema); - // Calculate output partitioning, which needs to respect aliases: - let input_partition = input.output_partitioning(); - let output_partitioning = - input_partition.project(projection_mapping, &input_eq_properties); + let output_partitioning = input + .output_partitioning() + .project(projection_mapping, input_eq_properties); Ok(PlanProperties::new( eq_properties, @@ -622,7 +619,7 @@ pub fn update_expr( let mut state = RewriteState::Unchanged; let new_expr = Arc::clone(expr) - .transform_up(|expr: Arc<dyn PhysicalExpr>| { + .transform_up(|expr| { if state == RewriteState::RewrittenInvalid { return Ok(Transformed::no(expr)); } @@ -666,6 +663,42 @@ pub fn update_expr( new_expr.map(|e| (state == RewriteState::RewrittenValid).then_some(e)) } +/// Updates the given lexicographic ordering according to given projected +/// expressions using the [`update_expr`] function. +pub fn update_ordering( + ordering: LexOrdering, + projected_exprs: &[(Arc<dyn PhysicalExpr>, String)], +) -> Result<Option<LexOrdering>> { + let mut updated_exprs = vec![]; + for mut sort_expr in ordering.into_iter() { + let Some(updated_expr) = update_expr(&sort_expr.expr, projected_exprs, false)? + else { + return Ok(None); + }; + sort_expr.expr = updated_expr; + updated_exprs.push(sort_expr); + } + Ok(LexOrdering::new(updated_exprs)) +} + +/// Updates the given lexicographic requirement according to given projected +/// expressions using the [`update_expr`] function. +pub fn update_ordering_requirement( + reqs: LexRequirement, + projected_exprs: &[(Arc<dyn PhysicalExpr>, String)], +) -> Result<Option<LexRequirement>> { + let mut updated_exprs = vec![]; + for mut sort_expr in reqs.into_iter() { + let Some(updated_expr) = update_expr(&sort_expr.expr, projected_exprs, false)? + else { + return Ok(None); + }; + sort_expr.expr = updated_expr; + updated_exprs.push(sort_expr); + } + Ok(LexRequirement::new(updated_exprs)) +} + /// Downcasts all the expressions in `exprs` to `Column`s.
If any of the given /// expressions is not a `Column`, returns `None`. pub fn physical_to_column_exprs( @@ -700,7 +733,7 @@ pub fn new_join_children( alias.clone(), ) }) - .collect_vec(), + .collect(), Arc::clone(left_child), )?; let left_size = left_child.schema().fields().len() as i32; @@ -718,7 +751,7 @@ pub fn new_join_children( alias.clone(), ) }) - .collect_vec(), + .collect(), Arc::clone(right_child), )?; diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index d0ad50666416..8b0f7e9784af 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -651,13 +651,13 @@ impl ExecutionPlan for RepartitionExec { let input = Arc::clone(&self.input); let partitioning = self.partitioning().clone(); let metrics = self.metrics.clone(); - let preserve_order = self.preserve_order; + let preserve_order = self.sort_exprs().is_some(); let name = self.name().to_owned(); let schema = self.schema(); let schema_captured = Arc::clone(&schema); // Get existing ordering to use for merging - let sort_exprs = self.sort_exprs().cloned().unwrap_or_default(); + let sort_exprs = self.sort_exprs().cloned(); let state = Arc::clone(&self.state); if let Some(mut state) = state.try_lock() { @@ -723,7 +723,7 @@ impl ExecutionPlan for RepartitionExec { StreamingMergeBuilder::new() .with_streams(input_streams) .with_schema(schema_captured) - .with_expressions(&sort_exprs) + .with_expressions(&sort_exprs.unwrap()) .with_metrics(BaselineMetrics::new(&metrics, partition)) .with_batch_size(context.session_config().batch_size()) .with_fetch(fetch) @@ -1810,11 +1810,11 @@ mod test { } fn sort_exprs(schema: &Schema) -> LexOrdering { - let options = SortOptions::default(); - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("c0", schema).unwrap(), - options, - }]) + options: SortOptions::default(), + }] + .into() } fn memory_exec(schema: &SchemaRef) -> Arc { diff --git a/datafusion/physical-plan/src/sorts/partial_sort.rs b/datafusion/physical-plan/src/sorts/partial_sort.rs index 78f898a2d77a..32b34a75cc76 100644 --- a/datafusion/physical-plan/src/sorts/partial_sort.rs +++ b/datafusion/physical-plan/src/sorts/partial_sort.rs @@ -105,7 +105,8 @@ impl PartialSortExec { ) -> Self { debug_assert!(common_prefix_length > 0); let preserve_partitioning = false; - let cache = Self::compute_properties(&input, expr.clone(), preserve_partitioning); + let cache = Self::compute_properties(&input, expr.clone(), preserve_partitioning) + .unwrap(); Self { input, expr, @@ -159,7 +160,7 @@ impl PartialSortExec { /// Sort expressions pub fn expr(&self) -> &LexOrdering { - self.expr.as_ref() + &self.expr } /// If `Some(fetch)`, limits output to only the first "fetch" items @@ -189,24 +190,22 @@ impl PartialSortExec { input: &Arc, sort_exprs: LexOrdering, preserve_partitioning: bool, - ) -> PlanProperties { + ) -> Result { // Calculate equivalence properties; i.e. 
reset the ordering equivalence // class with the new ordering: - let eq_properties = input - .equivalence_properties() - .clone() - .with_reorder(sort_exprs); + let mut eq_properties = input.equivalence_properties().clone(); + eq_properties.reorder(sort_exprs)?; // Get output partitioning: let output_partitioning = Self::output_partitioning_helper(input, preserve_partitioning); - PlanProperties::new( + Ok(PlanProperties::new( eq_properties, output_partitioning, input.pipeline_behavior(), input.boundedness(), - ) + )) } } @@ -421,7 +420,7 @@ impl PartialSortStream { fn sort_in_mem_batches(self: &mut Pin<&mut Self>) -> Result { let input_batch = concat_batches(&self.schema(), &self.in_mem_batches)?; self.in_mem_batches.clear(); - let result = sort_batch(&input_batch, self.expr.as_ref(), self.fetch)?; + let result = sort_batch(&input_batch, &self.expr, self.fetch)?; if let Some(remaining_fetch) = self.fetch { // remaining_fetch - result.num_rows() is always be >= 0 // because result length of sort_batch with limit cannot be @@ -504,7 +503,7 @@ mod tests { }; let partial_sort_exec = Arc::new(PartialSortExec::new( - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: col("a", &schema)?, options: option_asc, @@ -517,10 +516,11 @@ mod tests { expr: col("c", &schema)?, options: option_asc, }, - ]), + ] + .into(), Arc::clone(&source), 2, - )) as Arc; + )); let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?; @@ -569,7 +569,7 @@ mod tests { for common_prefix_length in [1, 2] { let partial_sort_exec = Arc::new( PartialSortExec::new( - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: col("a", &schema)?, options: option_asc, @@ -582,12 +582,13 @@ mod tests { expr: col("c", &schema)?, options: option_asc, }, - ]), + ] + .into(), Arc::clone(&source), common_prefix_length, ) .with_fetch(Some(4)), - ) as Arc; + ); let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?; @@ -642,7 +643,7 @@ mod tests { [(1, &source_tables[0]), (2, &source_tables[1])] { let partial_sort_exec = Arc::new(PartialSortExec::new( - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: col("a", &schema)?, options: option_asc, @@ -655,7 +656,8 @@ mod tests { expr: col("c", &schema)?, options: option_asc, }, - ]), + ] + .into(), Arc::clone(source), common_prefix_length, )); @@ -731,8 +733,8 @@ mod tests { nulls_first: false, }; let schema = mem_exec.schema(); - let partial_sort_executor = PartialSortExec::new( - LexOrdering::new(vec![ + let partial_sort_exec = PartialSortExec::new( + [ PhysicalSortExpr { expr: col("a", &schema)?, options: option_asc, @@ -745,17 +747,16 @@ mod tests { expr: col("c", &schema)?, options: option_asc, }, - ]), + ] + .into(), Arc::clone(&mem_exec), 1, ); - let partial_sort_exec = - Arc::new(partial_sort_executor.clone()) as Arc; let sort_exec = Arc::new(SortExec::new( - partial_sort_executor.expr, - partial_sort_executor.input, - )) as Arc; - let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?; + partial_sort_exec.expr.clone(), + Arc::clone(&partial_sort_exec.input), + )); + let result = collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?; assert_eq!( result.iter().map(|r| r.num_rows()).collect_vec(), [125, 125, 150] @@ -792,8 +793,8 @@ mod tests { (Some(150), vec![125, 25]), (Some(250), vec![125, 125]), ] { - let partial_sort_executor = PartialSortExec::new( - LexOrdering::new(vec![ + let partial_sort_exec = PartialSortExec::new( + [ PhysicalSortExpr { expr: col("a", &schema)?, options: option_asc, @@ -806,19 +807,22 @@ mod tests { expr: 
col("c", &schema)?, options: option_asc, }, - ]), + ] + .into(), Arc::clone(&mem_exec), 1, ) .with_fetch(fetch_size); - let partial_sort_exec = - Arc::new(partial_sort_executor.clone()) as Arc; let sort_exec = Arc::new( - SortExec::new(partial_sort_executor.expr, partial_sort_executor.input) - .with_fetch(fetch_size), - ) as Arc; - let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?; + SortExec::new( + partial_sort_exec.expr.clone(), + Arc::clone(&partial_sort_exec.input), + ) + .with_fetch(fetch_size), + ); + let result = + collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?; assert_eq!( result.iter().map(|r| r.num_rows()).collect_vec(), expected_batch_num_rows @@ -847,8 +851,8 @@ mod tests { nulls_first: false, }; let fetch_size = Some(250); - let partial_sort_executor = PartialSortExec::new( - LexOrdering::new(vec![ + let partial_sort_exec = PartialSortExec::new( + [ PhysicalSortExpr { expr: col("a", &schema)?, options: option_asc, @@ -857,15 +861,14 @@ mod tests { expr: col("c", &schema)?, options: option_asc, }, - ]), + ] + .into(), Arc::clone(&mem_exec), 1, ) .with_fetch(fetch_size); - let partial_sort_exec = - Arc::new(partial_sort_executor.clone()) as Arc; - let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?; + let result = collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?; for rb in result { assert!(rb.num_rows() > 0); } @@ -898,10 +901,11 @@ mod tests { TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)?; let partial_sort_exec = Arc::new(PartialSortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("field_name", &schema)?, options: SortOptions::default(), - }]), + }] + .into(), input, 1, )); @@ -987,7 +991,7 @@ mod tests { )?; let partial_sort_exec = Arc::new(PartialSortExec::new( - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: col("a", &schema)?, options: option_asc, @@ -1000,7 +1004,8 @@ mod tests { expr: col("c", &schema)?, options: option_desc, }, - ]), + ] + .into(), TestMemoryExec::try_new_exec(&[vec![batch]], schema, None)?, 2, )); @@ -1062,10 +1067,11 @@ mod tests { let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let refs = blocking_exec.refs(); let sort_exec = Arc::new(PartialSortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions::default(), - }]), + }] + .into(), blocking_exec, 1, )); diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 683983d9e697..c31f17291e19 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -31,7 +31,7 @@ use crate::limit::LimitStream; use crate::metrics::{ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, SpillMetrics, }; -use crate::projection::{make_with_child, update_expr, ProjectionExec}; +use crate::projection::{make_with_child, update_ordering, ProjectionExec}; use crate::sorts::streaming_merge::StreamingMergeBuilder; use crate::spill::get_record_batch_memory_size; use crate::spill::in_progress_spill_file::InProgressSpillFile; @@ -53,7 +53,6 @@ use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_physical_expr::LexOrdering; -use datafusion_physical_expr_common::sort_expr::LexRequirement; use futures::{StreamExt, TryStreamExt}; use log::{debug, trace}; @@ -200,7 +199,7 @@ struct 
ExternalSorter { /// Schema of the output (and the input) schema: SchemaRef, /// Sort expressions - expr: Arc<[PhysicalSortExpr]>, + expr: LexOrdering, /// The target number of rows for output batches batch_size: usize, /// If the in size of buffered memory batches is below this size, @@ -279,7 +278,7 @@ impl ExternalSorter { in_mem_batches: vec![], in_progress_spill_file: None, finished_spill_files: vec![], - expr: expr.into(), + expr, metrics, reservation, spill_manager, @@ -344,12 +343,10 @@ impl ExternalSorter { streams.push(stream); } - let expressions: LexOrdering = self.expr.iter().cloned().collect(); - StreamingMergeBuilder::new() .with_streams(streams) .with_schema(Arc::clone(&self.schema)) - .with_expressions(expressions.as_ref()) + .with_expressions(&self.expr.clone()) .with_metrics(self.metrics.baseline.clone()) .with_batch_size(self.batch_size) .with_fetch(None) @@ -674,12 +671,10 @@ impl ExternalSorter { }) .collect::<Result<Vec<_>>>()?; - let expressions: LexOrdering = self.expr.iter().cloned().collect(); - StreamingMergeBuilder::new() .with_streams(streams) .with_schema(Arc::clone(&self.schema)) - .with_expressions(expressions.as_ref()) + .with_expressions(&self.expr.clone()) .with_metrics(metrics) .with_batch_size(self.batch_size) .with_fetch(None) @@ -703,7 +698,7 @@ impl ExternalSorter { ); let schema = batch.schema(); - let expressions: LexOrdering = self.expr.iter().cloned().collect(); + let expressions = self.expr.clone(); let stream = futures::stream::once(async move { let _timer = metrics.elapsed_compute().timer(); @@ -845,7 +840,7 @@ pub struct SortExec { /// Fetch highest/lowest n results fetch: Option<usize>, /// Normalized common sort prefix between the input and the sort expressions (only used with fetch) - common_sort_prefix: LexOrdering, + common_sort_prefix: Vec<PhysicalSortExpr>, /// Cache holding plan properties like equivalences, output partitioning etc. cache: PlanProperties, } @@ -856,7 +851,8 @@ impl SortExec { pub fn new(expr: LexOrdering, input: Arc<dyn ExecutionPlan>) -> Self { let preserve_partitioning = false; let (cache, sort_prefix) = - Self::compute_properties(&input, expr.clone(), preserve_partitioning); + Self::compute_properties(&input, expr.clone(), preserve_partitioning) + .unwrap(); Self { expr, input, @@ -954,13 +950,10 @@ impl SortExec { input: &Arc<dyn ExecutionPlan>, sort_exprs: LexOrdering, preserve_partitioning: bool, - ) -> (PlanProperties, LexOrdering) { - // Determine execution mode: - let requirement = LexRequirement::from(sort_exprs); - + ) -> Result<(PlanProperties, Vec<PhysicalSortExpr>)> { let (sort_prefix, sort_satisfied) = input .equivalence_properties() - .extract_common_sort_prefix(&requirement); + .extract_common_sort_prefix(sort_exprs.clone())?; // The emission type depends on whether the input is already sorted: // - If already fully sorted, we can emit results in the same way as the input @@ -989,25 +982,22 @@ impl SortExec { // Calculate equivalence properties; i.e.
reset the ordering equivalence // class with the new ordering: - let sort_exprs = LexOrdering::from(requirement); - let eq_properties = input - .equivalence_properties() - .clone() - .with_reorder(sort_exprs); + let mut eq_properties = input.equivalence_properties().clone(); + eq_properties.reorder(sort_exprs)?; // Get output partitioning: let output_partitioning = Self::output_partitioning_helper(input, preserve_partitioning); - ( + Ok(( PlanProperties::new( eq_properties, output_partitioning, emission_type, boundedness, ), - LexOrdering::from(sort_prefix), - ) + sort_prefix, + )) } } @@ -1020,7 +1010,17 @@ impl DisplayAs for SortExec { Some(fetch) => { write!(f, "SortExec: TopK(fetch={fetch}), expr=[{}], preserve_partitioning=[{preserve_partitioning}]", self.expr)?; if !self.common_sort_prefix.is_empty() { - write!(f, ", sort_prefix=[{}]", self.common_sort_prefix) + write!(f, ", sort_prefix=[")?; + let mut first = true; + for sort_expr in &self.common_sort_prefix { + if first { + first = false; + } else { + write!(f, ", ")?; + } + write!(f, "{sort_expr}")?; + } + write!(f, "]") } else { Ok(()) } @@ -1099,12 +1099,10 @@ impl ExecutionPlan for SortExec { trace!("End SortExec's input.execute for partition: {partition}"); - let requirement = &LexRequirement::from(self.expr.clone()); - let sort_satisfied = self .input .equivalence_properties() - .ordering_satisfy_requirement(requirement); + .ordering_satisfy(self.expr.clone())?; match (sort_satisfied, self.fetch.as_ref()) { (true, Some(fetch)) => Ok(Box::pin(LimitStream::new( @@ -1219,17 +1217,10 @@ impl ExecutionPlan for SortExec { return Ok(None); } - let mut updated_exprs = LexOrdering::default(); - for sort in self.expr() { - let Some(new_expr) = update_expr(&sort.expr, projection.expr(), false)? - else { - return Ok(None); - }; - updated_exprs.push(PhysicalSortExpr { - expr: new_expr, - options: sort.options, - }); - } + let Some(updated_exprs) = update_ordering(self.expr.clone(), projection.expr())? + else { + return Ok(None); + }; Ok(Some(Arc::new( SortExec::new(updated_exprs, make_with_child(projection, self.input())?) 
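To make the new projection helpers concrete, here is a minimal usage sketch. It is not part of the patch: it assumes the post-patch APIs shown in the hunks above (array-to-`LexOrdering` conversion via `.into()`, `PhysicalSortExpr::new_default`, and `projection::update_ordering` returning `Result<Option<LexOrdering>>`), and the column names, indices, and alias are invented for illustration:

use std::sync::Arc;

use datafusion_common::Result;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
use datafusion_physical_plan::projection::update_ordering;

fn push_ordering_through_projection() -> Result<()> {
    // An ordering on (a, b), built the way this patch builds orderings:
    // an array of sort expressions converted with `.into()` instead of
    // `LexOrdering::new(vec![...])`.
    let ordering: LexOrdering = [
        PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))),
        PhysicalSortExpr::new_default(Arc::new(Column::new("b", 1))),
    ]
    .into();

    // A projection that keeps both sort columns and renames "b".
    let projected: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![
        (Arc::new(Column::new("a", 0)) as _, "a".to_string()),
        (Arc::new(Column::new("b", 1)) as _, "b_renamed".to_string()),
    ];

    // Every sort key is rewritten against the projection; if any key stops
    // referring to a projected column, the whole ordering collapses to None.
    match update_ordering(ordering, &projected)? {
        Some(updated) => println!("ordering survives the projection: {updated}"),
        None => println!("ordering is lost behind the projection"),
    }
    Ok(())
}

The `Option` in the return type is what lets callers such as `SortExec::try_swapping` above bail out with a single `let`-`else` instead of rewriting each sort key by hand.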
@@ -1291,9 +1282,9 @@ mod tests { impl SortedUnboundedExec { fn compute_properties(schema: SchemaRef) -> PlanProperties { let mut eq_properties = EquivalenceProperties::new(schema); - eq_properties.add_new_orderings(vec![LexOrdering::new(vec![ - PhysicalSortExpr::new_default(Arc::new(Column::new("c1", 0))), - ])]); + eq_properties.add_ordering([PhysicalSortExpr::new_default(Arc::new( + Column::new("c1", 0), + ))]); PlanProperties::new( eq_properties, Partitioning::UnknownPartitioning(1), @@ -1393,10 +1384,11 @@ mod tests { let schema = csv.schema(); let sort_exec = Arc::new(SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("i", &schema)?, options: SortOptions::default(), - }]), + }] + .into(), Arc::new(CoalescePartitionsExec::new(csv)), )); @@ -1404,7 +1396,6 @@ mod tests { assert_eq!(result.len(), 1); assert_eq!(result[0].num_rows(), 400); - assert_eq!( task_ctx.runtime_env().memory_pool.reserved(), 0, @@ -1439,10 +1430,11 @@ mod tests { let schema = input.schema(); let sort_exec = Arc::new(SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("i", &schema)?, options: SortOptions::default(), - }]), + }] + .into(), Arc::new(CoalescePartitionsExec::new(input)), )); @@ -1475,7 +1467,6 @@ mod tests { let i = as_primitive_array::(&columns[0])?; assert_eq!(i.value(0), 0); assert_eq!(i.value(i.len() - 1), 81); - assert_eq!( task_ctx.runtime_env().memory_pool.reserved(), 0, @@ -1518,18 +1509,11 @@ mod tests { } let sort_exec = Arc::new(SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { - expr: col("i", &plan.schema())?, - options: SortOptions::default(), - }]), + [PhysicalSortExpr::new_default(col("i", &plan.schema())?)].into(), plan, )); - let result = collect( - Arc::clone(&sort_exec) as Arc, - Arc::clone(&task_ctx), - ) - .await; + let result = collect(Arc::clone(&sort_exec) as _, Arc::clone(&task_ctx)).await; let err = result.unwrap_err(); assert!( @@ -1568,18 +1552,15 @@ mod tests { let schema = input.schema(); let sort_exec = Arc::new(SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("i", &schema)?, options: SortOptions::default(), - }]), + }] + .into(), Arc::new(CoalescePartitionsExec::new(input)), )); - let result = collect( - Arc::clone(&sort_exec) as Arc, - Arc::clone(&task_ctx), - ) - .await?; + let result = collect(Arc::clone(&sort_exec) as _, Arc::clone(&task_ctx)).await?; let num_rows = result.iter().map(|batch| batch.num_rows()).sum::(); assert_eq!(num_rows, 20000); @@ -1668,20 +1649,18 @@ mod tests { let sort_exec = Arc::new( SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("i", &schema)?, options: SortOptions::default(), - }]), + }] + .into(), Arc::new(CoalescePartitionsExec::new(csv)), ) .with_fetch(fetch), ); - let result = collect( - Arc::clone(&sort_exec) as Arc, - Arc::clone(&task_ctx), - ) - .await?; + let result = + collect(Arc::clone(&sort_exec) as _, Arc::clone(&task_ctx)).await?; assert_eq!(result.len(), 1); let metrics = sort_exec.metrics().unwrap(); @@ -1711,16 +1690,16 @@ mod tests { let data: ArrayRef = Arc::new(vec![3, 2, 1].into_iter().map(Some).collect::()); - let batch = RecordBatch::try_new(Arc::clone(&schema), vec![data]).unwrap(); + let batch = RecordBatch::try_new(Arc::clone(&schema), vec![data])?; let input = - TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None) - .unwrap(); + TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)?; let sort_exec = 
Arc::new(SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("field_name", &schema)?, options: SortOptions::default(), - }]), + }] + .into(), input, )); @@ -1729,7 +1708,7 @@ mod tests { let expected_data: ArrayRef = Arc::new(vec![1, 2, 3].into_iter().map(Some).collect::()); let expected_batch = - RecordBatch::try_new(Arc::clone(&schema), vec![expected_data]).unwrap(); + RecordBatch::try_new(Arc::clone(&schema), vec![expected_data])?; // Data is correct assert_eq!(&vec![expected_batch], &result); @@ -1768,7 +1747,7 @@ mod tests { )?; let sort_exec = Arc::new(SortExec::new( - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions { @@ -1783,7 +1762,8 @@ mod tests { nulls_first: false, }, }, - ]), + ] + .into(), TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)?, )); @@ -1854,7 +1834,7 @@ mod tests { )?; let sort_exec = Arc::new(SortExec::new( - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions { @@ -1869,7 +1849,8 @@ mod tests { nulls_first: false, }, }, - ]), + ] + .into(), TestMemoryExec::try_new_exec(&[vec![batch]], schema, None)?, )); @@ -1933,10 +1914,11 @@ mod tests { let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let refs = blocking_exec.refs(); let sort_exec = Arc::new(SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions::default(), - }]), + }] + .into(), blocking_exec, )); @@ -1964,12 +1946,13 @@ mod tests { RecordBatch::try_new_with_options(Arc::clone(&schema), vec![], &options) .unwrap(); - let expressions = LexOrdering::new(vec![PhysicalSortExpr { + let expressions = [PhysicalSortExpr { expr: Arc::new(Literal::new(ScalarValue::Int64(Some(1)))), options: SortOptions::default(), - }]); + }] + .into(); - let result = sort_batch(&batch, expressions.as_ref(), None).unwrap(); + let result = sort_batch(&batch, &expressions, None).unwrap(); assert_eq!(result.num_rows(), 1); } @@ -1983,9 +1966,10 @@ mod tests { cache: SortedUnboundedExec::compute_properties(Arc::new(schema.clone())), }; let mut plan = SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr::new_default(Arc::new(Column::new( + [PhysicalSortExpr::new_default(Arc::new(Column::new( "c1", 0, - )))]), + )))] + .into(), Arc::new(source), ); plan = plan.with_fetch(Some(9)); diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs index 6930473360f0..272b8f6d75e0 100644 --- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs +++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use crate::common::spawn_buffered; use crate::limit::LimitStream; use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; -use crate::projection::{make_with_child, update_expr, ProjectionExec}; +use crate::projection::{make_with_child, update_ordering, ProjectionExec}; use crate::sorts::streaming_merge::StreamingMergeBuilder; use crate::{ DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, @@ -33,8 +33,7 @@ use crate::{ use datafusion_common::{internal_err, Result}; use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::TaskContext; -use datafusion_physical_expr::PhysicalSortExpr; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use 
datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements}; use log::{debug, trace}; @@ -144,7 +143,7 @@ impl SortPreservingMergeExec { /// Sort expressions pub fn expr(&self) -> &LexOrdering { - self.expr.as_ref() + &self.expr } /// Fetch @@ -160,7 +159,7 @@ impl SortPreservingMergeExec { ) -> PlanProperties { let mut eq_properties = input.equivalence_properties().clone(); eq_properties.clear_per_partition_constants(); - eq_properties.add_new_orderings(vec![ordering]); + eq_properties.add_ordering(ordering); PlanProperties::new( eq_properties, // Equivalence Properties Partitioning::UnknownPartitioning(1), // Output Partitioning @@ -240,8 +239,8 @@ impl ExecutionPlan for SortPreservingMergeExec { vec![false] } - fn required_input_ordering(&self) -> Vec> { - vec![Some(LexRequirement::from(self.expr.clone()))] + fn required_input_ordering(&self) -> Vec> { + vec![Some(OrderingRequirements::from(self.expr.clone()))] } fn maintains_input_order(&self) -> Vec { @@ -319,7 +318,7 @@ impl ExecutionPlan for SortPreservingMergeExec { let result = StreamingMergeBuilder::new() .with_streams(receivers) .with_schema(schema) - .with_expressions(self.expr.as_ref()) + .with_expressions(&self.expr) .with_metrics(BaselineMetrics::new(&self.metrics, partition)) .with_batch_size(context.session_config().batch_size()) .with_fetch(self.fetch) @@ -362,17 +361,10 @@ impl ExecutionPlan for SortPreservingMergeExec { return Ok(None); } - let mut updated_exprs = LexOrdering::default(); - for sort in self.expr() { - let Some(updated_expr) = update_expr(&sort.expr, projection.expr(), false)? - else { - return Ok(None); - }; - updated_exprs.push(PhysicalSortExpr { - expr: updated_expr, - options: sort.options, - }); - } + let Some(updated_exprs) = update_ordering(self.expr.clone(), projection.expr())? 
+ else { + return Ok(None); + }; Ok(Some(Arc::new( SortPreservingMergeExec::new( @@ -413,7 +405,7 @@ mod tests { use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::test_util::batches_to_string; - use datafusion_common::{assert_batches_eq, assert_contains, DataFusionError}; + use datafusion_common::{assert_batches_eq, exec_err}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnvBuilder; @@ -421,8 +413,8 @@ mod tests { use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::EquivalenceProperties; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; - use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; + use futures::{FutureExt, Stream, StreamExt}; use insta::assert_snapshot; use tokio::time::timeout; @@ -449,24 +441,25 @@ mod tests { let a: ArrayRef = Arc::new(Int32Array::from(vec![1; row_size])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"); row_size])); let c: ArrayRef = Arc::new(Int64Array::from_iter(vec![0; row_size])); - let rb = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); + let rb = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)])?; let rbs = (0..1024).map(|_| rb.clone()).collect::>(); let schema = rb.schema(); - let sort = LexOrdering::new(vec![ + let sort = [ PhysicalSortExpr { - expr: col("b", &schema).unwrap(), + expr: col("b", &schema)?, options: Default::default(), }, PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + expr: col("c", &schema)?, options: Default::default(), }, - ]); + ] + .into(); let repartition_exec = RepartitionExec::try_new( - TestMemoryExec::try_new_exec(&[rbs], schema, None).unwrap(), + TestMemoryExec::try_new_exec(&[rbs], schema, None)?, Partitioning::RoundRobinBatch(2), )?; let coalesce_batches_exec = @@ -485,7 +478,7 @@ mod tests { async fn test_round_robin_tie_breaker_success() -> Result<()> { let task_ctx = generate_task_ctx_for_round_robin_tie_breaker()?; let spm = generate_spm_for_round_robin_tie_breaker(true)?; - let _collected = collect(spm, task_ctx).await.unwrap(); + let _collected = collect(spm, task_ctx).await?; Ok(()) } @@ -550,30 +543,6 @@ mod tests { .await; } - #[tokio::test] - async fn test_merge_no_exprs() { - let task_ctx = Arc::new(TaskContext::default()); - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); - let batch = RecordBatch::try_from_iter(vec![("a", a)]).unwrap(); - - let schema = batch.schema(); - let sort = LexOrdering::default(); // no sort expressions - let exec = TestMemoryExec::try_new_exec( - &[vec![batch.clone()], vec![batch]], - schema, - None, - ) - .unwrap(); - - let merge = Arc::new(SortPreservingMergeExec::new(sort, exec)); - - let res = collect(merge, task_ctx).await.unwrap_err(); - assert_contains!( - res.to_string(), - "Internal error: Sort expressions cannot be empty for streaming merge" - ); - } - #[tokio::test] async fn test_merge_some_overlap() { let task_ctx = Arc::new(TaskContext::default()); @@ -741,7 +710,7 @@ mod tests { context: Arc, ) { let schema = partitions[0][0].schema(); - let sort = LexOrdering::new(vec![ + let sort = [ PhysicalSortExpr { expr: col("b", &schema).unwrap(), options: Default::default(), @@ -750,7 +719,8 @@ mod tests { expr: col("c", &schema).unwrap(), options: Default::default(), }, - ]); + ] + .into(); let exec = TestMemoryExec::try_new_exec(partitions, schema, None).unwrap(); let merge = 
Arc::new(SortPreservingMergeExec::new(sort, exec)); @@ -798,13 +768,14 @@ mod tests { let csv = test::scan_partitioned(partitions); let schema = csv.schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("i", &schema).unwrap(), + let sort: LexOrdering = [PhysicalSortExpr { + expr: col("i", &schema)?, options: SortOptions { descending: true, nulls_first: true, }, - }]); + }] + .into(); let basic = basic_sort(Arc::clone(&csv), sort.clone(), Arc::clone(&task_ctx)).await; @@ -859,17 +830,18 @@ mod tests { let sorted = basic_sort(csv, sort, context).await; let split: Vec<_> = sizes.iter().map(|x| split_batch(&sorted, *x)).collect(); - Ok(TestMemoryExec::try_new_exec(&split, sorted.schema(), None).unwrap()) + TestMemoryExec::try_new_exec(&split, sorted.schema(), None).map(|e| e as _) } #[tokio::test] async fn test_partition_sort_streaming_input() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); let schema = make_partition(11).schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("i", &schema).unwrap(), + let sort: LexOrdering = [PhysicalSortExpr { + expr: col("i", &schema)?, options: Default::default(), - }]); + }] + .into(); let input = sorted_partitioned_input(sort.clone(), &[10, 3, 11], Arc::clone(&task_ctx)) @@ -881,12 +853,9 @@ mod tests { assert_eq!(basic.num_rows(), 1200); assert_eq!(partition.num_rows(), 1200); - let basic = arrow::util::pretty::pretty_format_batches(&[basic]) - .unwrap() - .to_string(); - let partition = arrow::util::pretty::pretty_format_batches(&[partition]) - .unwrap() - .to_string(); + let basic = arrow::util::pretty::pretty_format_batches(&[basic])?.to_string(); + let partition = + arrow::util::pretty::pretty_format_batches(&[partition])?.to_string(); assert_eq!(basic, partition); @@ -896,10 +865,11 @@ mod tests { #[tokio::test] async fn test_partition_sort_streaming_input_output() -> Result<()> { let schema = make_partition(11).schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("i", &schema).unwrap(), + let sort: LexOrdering = [PhysicalSortExpr { + expr: col("i", &schema)?, options: Default::default(), - }]); + }] + .into(); // Test streaming with default batch size let task_ctx = Arc::new(TaskContext::default()); @@ -914,19 +884,14 @@ mod tests { let task_ctx = Arc::new(task_ctx); let merge = Arc::new(SortPreservingMergeExec::new(sort, input)); - let merged = collect(merge, task_ctx).await.unwrap(); + let merged = collect(merge, task_ctx).await?; assert_eq!(merged.len(), 53); - assert_eq!(basic.num_rows(), 1200); assert_eq!(merged.iter().map(|x| x.num_rows()).sum::(), 1200); - let basic = arrow::util::pretty::pretty_format_batches(&[basic]) - .unwrap() - .to_string(); - let partition = arrow::util::pretty::pretty_format_batches(merged.as_slice()) - .unwrap() - .to_string(); + let basic = arrow::util::pretty::pretty_format_batches(&[basic])?.to_string(); + let partition = arrow::util::pretty::pretty_format_batches(&merged)?.to_string(); assert_eq!(basic, partition); @@ -971,7 +936,7 @@ mod tests { let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); let schema = b1.schema(); - let sort = LexOrdering::new(vec![ + let sort = [ PhysicalSortExpr { expr: col("b", &schema).unwrap(), options: SortOptions { @@ -986,7 +951,8 @@ mod tests { nulls_first: false, }, }, - ]); + ] + .into(); let exec = TestMemoryExec::try_new_exec(&[vec![b1], vec![b2]], schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, exec)); @@ -1020,13 +986,14 @@ mod 
tests { let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap(); let schema = batch.schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { + let sort = [PhysicalSortExpr { expr: col("b", &schema).unwrap(), options: SortOptions { descending: false, nulls_first: true, }, - }]); + }] + .into(); let exec = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, exec).with_fetch(Some(2))); @@ -1052,13 +1019,14 @@ mod tests { let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap(); let schema = batch.schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { + let sort = [PhysicalSortExpr { expr: col("b", &schema).unwrap(), options: SortOptions { descending: false, nulls_first: true, }, - }]); + }] + .into(); let exec = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, exec)); @@ -1082,10 +1050,11 @@ mod tests { async fn test_async() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); let schema = make_partition(11).schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { + let sort: LexOrdering = [PhysicalSortExpr { expr: col("i", &schema).unwrap(), options: SortOptions::default(), - }]); + }] + .into(); let batches = sorted_partitioned_input(sort.clone(), &[5, 7, 3], Arc::clone(&task_ctx)) @@ -1121,7 +1090,7 @@ mod tests { let merge_stream = StreamingMergeBuilder::new() .with_streams(streams) .with_schema(batches.schema()) - .with_expressions(sort.as_ref()) + .with_expressions(&sort) .with_metrics(BaselineMetrics::new(&metrics, 0)) .with_batch_size(task_ctx.session_config().batch_size()) .with_fetch(fetch) @@ -1161,10 +1130,11 @@ mod tests { let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap(); let schema = b1.schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { + let sort = [PhysicalSortExpr { expr: col("b", &schema).unwrap(), options: Default::default(), - }]); + }] + .into(); let exec = TestMemoryExec::try_new_exec(&[vec![b1], vec![b2]], schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, exec)); @@ -1220,10 +1190,11 @@ mod tests { let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2)); let refs = blocking_exec.refs(); let sort_preserving_merge_exec = Arc::new(SortPreservingMergeExec::new( - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions::default(), - }]), + }] + .into(), blocking_exec, )); @@ -1268,13 +1239,14 @@ mod tests { let schema = partitions[0][0].schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { + let sort = [PhysicalSortExpr { expr: col("value", &schema).unwrap(), options: SortOptions { descending: false, nulls_first: true, }, - }]); + }] + .into(); let exec = TestMemoryExec::try_new_exec(&partitions, schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, exec)); @@ -1331,10 +1303,11 @@ mod tests { .map(|(i, f)| Arc::new(Column::new(f.name(), i)) as Arc) .collect::>(); let mut eq_properties = EquivalenceProperties::new(schema); - eq_properties.add_new_orderings(vec![columns - .iter() - .map(|expr| PhysicalSortExpr::new_default(Arc::clone(expr))) - .collect::()]); + eq_properties.add_ordering( + columns + .iter() + .map(|expr| PhysicalSortExpr::new_default(Arc::clone(expr))), + ); PlanProperties::new( eq_properties, Partitioning::Hash(columns, 3), @@ -1453,9 +1426,10 @@ mod tests { 
congestion_cleared: Arc::new(Mutex::new(false)), }; let spm = SortPreservingMergeExec::new( - LexOrdering::new(vec![PhysicalSortExpr::new_default(Arc::new(Column::new( + [PhysicalSortExpr::new_default(Arc::new(Column::new( "c1", 0, - )))]), + )))] + .into(), Arc::new(source), ); let spm_task = SpawnedTask::spawn(collect(Arc::new(spm), task_ctx)); @@ -1464,12 +1438,8 @@ mod tests { match result { Ok(Ok(Ok(_batches))) => Ok(()), Ok(Ok(Err(e))) => Err(e), - Ok(Err(_)) => Err(DataFusionError::Execution( - "SortPreservingMerge task panicked or was cancelled".to_string(), - )), - Err(_) => Err(DataFusionError::Execution( - "SortPreservingMerge caused a deadlock".to_string(), - )), + Ok(Err(_)) => exec_err!("SortPreservingMerge task panicked or was cancelled"), + Err(_) => exec_err!("SortPreservingMerge caused a deadlock"), } } } diff --git a/datafusion/physical-plan/src/sorts/streaming_merge.rs b/datafusion/physical-plan/src/sorts/streaming_merge.rs index 3f022ec6095a..b74954eb9688 100644 --- a/datafusion/physical-plan/src/sorts/streaming_merge.rs +++ b/datafusion/physical-plan/src/sorts/streaming_merge.rs @@ -52,10 +52,11 @@ macro_rules! merge_helper { }}; } +#[derive(Default)] pub struct StreamingMergeBuilder<'a> { streams: Vec, schema: Option, - expressions: &'a LexOrdering, + expressions: Option<&'a LexOrdering>, metrics: Option, batch_size: Option, fetch: Option, @@ -63,21 +64,6 @@ pub struct StreamingMergeBuilder<'a> { enable_round_robin_tie_breaker: bool, } -impl Default for StreamingMergeBuilder<'_> { - fn default() -> Self { - Self { - streams: vec![], - schema: None, - expressions: LexOrdering::empty(), - metrics: None, - batch_size: None, - fetch: None, - reservation: None, - enable_round_robin_tie_breaker: false, - } - } -} - impl<'a> StreamingMergeBuilder<'a> { pub fn new() -> Self { Self { @@ -97,7 +83,7 @@ impl<'a> StreamingMergeBuilder<'a> { } pub fn with_expressions(mut self, expressions: &'a LexOrdering) -> Self { - self.expressions = expressions; + self.expressions = Some(expressions); self } @@ -145,22 +131,13 @@ impl<'a> StreamingMergeBuilder<'a> { enable_round_robin_tie_breaker, } = self; - // Early return if streams or expressions are empty - let checks = [ - ( - streams.is_empty(), - "Streams cannot be empty for streaming merge", - ), - ( - expressions.is_empty(), - "Sort expressions cannot be empty for streaming merge", - ), - ]; - - if let Some((_, error_message)) = checks.iter().find(|(condition, _)| *condition) - { - return internal_err!("{}", error_message); + // Early return if streams or expressions are empty: + if streams.is_empty() { + return internal_err!("Streams cannot be empty for streaming merge"); } + let Some(expressions) = expressions else { + return internal_err!("Sort expressions cannot be empty for streaming merge"); + }; // Unwrapping mandatory fields let schema = schema.expect("Schema cannot be empty for streaming merge"); diff --git a/datafusion/physical-plan/src/streaming.rs b/datafusion/physical-plan/src/streaming.rs index 6274995d04da..3f2e99b1d003 100644 --- a/datafusion/physical-plan/src/streaming.rs +++ b/datafusion/physical-plan/src/streaming.rs @@ -27,7 +27,7 @@ use crate::execution_plan::{Boundedness, EmissionType}; use crate::limit::LimitStream; use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use crate::projection::{ - all_alias_free_columns, new_projections_for_columns, update_expr, ProjectionExec, + all_alias_free_columns, new_projections_for_columns, update_ordering, ProjectionExec, }; use 
crate::stream::RecordBatchStreamAdapter; use crate::{ExecutionPlan, Partitioning, SendableRecordBatchStream}; @@ -35,7 +35,7 @@ use crate::{ExecutionPlan, Partitioning, SendableRecordBatchStream}; use arrow::datatypes::{Schema, SchemaRef}; use datafusion_common::{internal_err, plan_err, Result}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{EquivalenceProperties, LexOrdering, PhysicalSortExpr}; +use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; use async_trait::async_trait; use futures::stream::StreamExt; @@ -99,7 +99,7 @@ impl StreamingTableExec { projected_output_ordering.into_iter().collect::>(); let cache = Self::compute_properties( Arc::clone(&projected_schema), - &projected_output_ordering, + projected_output_ordering.clone(), &partitions, infinite, ); @@ -146,7 +146,7 @@ impl StreamingTableExec { /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. fn compute_properties( schema: SchemaRef, - orderings: &[LexOrdering], + orderings: Vec, partitions: &[Arc], infinite: bool, ) -> PlanProperties { @@ -306,20 +306,11 @@ impl ExecutionPlan for StreamingTableExec { ); let mut lex_orderings = vec![]; - for lex_ordering in self.projected_output_ordering().into_iter() { - let mut orderings = LexOrdering::default(); - for order in lex_ordering { - let Some(new_ordering) = - update_expr(&order.expr, projection.expr(), false)? - else { - return Ok(None); - }; - orderings.push(PhysicalSortExpr { - expr: new_ordering, - options: order.options, - }); - } - lex_orderings.push(orderings); + for ordering in self.projected_output_ordering().into_iter() { + let Some(ordering) = update_ordering(ordering, projection.expr())? else { + return Ok(None); + }; + lex_orderings.push(ordering); } StreamingTableExec::try_new( diff --git a/datafusion/physical-plan/src/test.rs b/datafusion/physical-plan/src/test.rs index 4d5244e0e1d4..140a5f35a748 100644 --- a/datafusion/physical-plan/src/test.rs +++ b/datafusion/physical-plan/src/test.rs @@ -40,10 +40,12 @@ use datafusion_common::{ config::ConfigOptions, internal_err, project_schema, Result, Statistics, }; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_physical_expr::{ - equivalence::ProjectionMapping, expressions::Column, utils::collect_columns, - EquivalenceProperties, LexOrdering, Partitioning, +use datafusion_physical_expr::equivalence::{ + OrderingEquivalenceClass, ProjectionMapping, }; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::utils::collect_columns; +use datafusion_physical_expr::{EquivalenceProperties, LexOrdering, Partitioning}; use futures::{Future, FutureExt}; @@ -216,7 +218,7 @@ impl TestMemoryExec { fn eq_properties(&self) -> EquivalenceProperties { EquivalenceProperties::new_with_orderings( Arc::clone(&self.projected_schema), - self.sort_information.as_slice(), + self.sort_information.clone(), ) } @@ -240,7 +242,7 @@ impl TestMemoryExec { cache: PlanProperties::new( EquivalenceProperties::new_with_orderings( Arc::clone(&projected_schema), - vec![].as_slice(), + Vec::::new(), ), Partitioning::UnknownPartitioning(partitions.len()), EmissionType::Incremental, @@ -324,24 +326,21 @@ impl TestMemoryExec { // If there is a projection on the source, we also need to project orderings if let Some(projection) = &self.projection { + let base_schema = self.original_schema(); + let proj_exprs = projection.iter().map(|idx| { + let name = 
base_schema.field(*idx).name(); + (Arc::new(Column::new(name, *idx)) as _, name.to_string()) + }); + let projection_mapping = + ProjectionMapping::try_new(proj_exprs, &base_schema)?; let base_eqp = EquivalenceProperties::new_with_orderings( - self.original_schema(), - &sort_information, + Arc::clone(&base_schema), + sort_information, ); - let proj_exprs = projection - .iter() - .map(|idx| { - let base_schema = self.original_schema(); - let name = base_schema.field(*idx).name(); - (Arc::new(Column::new(name, *idx)) as _, name.to_string()) - }) - .collect::>(); - let projection_mapping = - ProjectionMapping::try_new(&proj_exprs, &self.original_schema())?; - sort_information = base_eqp - .project(&projection_mapping, Arc::clone(&self.projected_schema)) - .into_oeq_class() - .into_inner(); + let proj_eqp = + base_eqp.project(&projection_mapping, Arc::clone(&self.projected_schema)); + let oeq_class: OrderingEquivalenceClass = proj_eqp.into(); + sort_information = oeq_class.into(); } self.sort_information = sort_information; diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 0b5780b9143f..d8b6f0e400b8 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -17,26 +17,23 @@ //! TopK: Combination of Sort / LIMIT -use arrow::{ - compute::interleave_record_batch, - row::{RowConverter, Rows, SortField}, -}; use std::mem::size_of; use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc}; use super::metrics::{BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder}; use crate::spill::get_record_batch_memory_size; use crate::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream}; + use arrow::array::{ArrayRef, RecordBatch}; +use arrow::compute::interleave_record_batch; use arrow::datatypes::SchemaRef; -use datafusion_common::Result; -use datafusion_common::{internal_datafusion_err, HashMap}; +use arrow::row::{RowConverter, Rows, SortField}; +use datafusion_common::{internal_datafusion_err, HashMap, Result}; use datafusion_execution::{ memory_pool::{MemoryConsumer, MemoryReservation}, runtime_env::RuntimeEnv, }; -use datafusion_physical_expr::PhysicalSortExpr; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; /// Global TopK /// @@ -102,7 +99,7 @@ pub struct TopK { /// The target number of rows for output batches batch_size: usize, /// sort expressions - expr: Arc<[PhysicalSortExpr]>, + expr: LexOrdering, /// row converter, for sort keys row_converter: RowConverter, /// scratch space for converting rows @@ -123,7 +120,7 @@ pub struct TopK { const ESTIMATED_BYTES_PER_ROW: usize = 20; fn build_sort_fields( - ordering: &LexOrdering, + ordering: &[PhysicalSortExpr], schema: &SchemaRef, ) -> Result> { ordering @@ -145,7 +142,7 @@ impl TopK { pub fn try_new( partition_id: usize, schema: SchemaRef, - common_sort_prefix: LexOrdering, + common_sort_prefix: Vec, expr: LexOrdering, k: usize, batch_size: usize, @@ -155,7 +152,7 @@ impl TopK { let reservation = MemoryConsumer::new(format!("TopK[{partition_id}]")) .register(&runtime.memory_pool); - let sort_fields: Vec<_> = build_sort_fields(&expr, &schema)?; + let sort_fields = build_sort_fields(&expr, &schema)?; // TODO there is potential to add special cases for single column sort fields // to improve performance @@ -166,8 +163,7 @@ impl TopK { let prefix_row_converter = if common_sort_prefix.is_empty() { None } else { - let input_sort_fields: Vec<_> = - 
build_sort_fields(&common_sort_prefix, &schema)?; + let input_sort_fields = build_sort_fields(&common_sort_prefix, &schema)?; Some(RowConverter::new(input_sort_fields)?) }; @@ -176,7 +172,7 @@ impl TopK { metrics: TopKMetrics::new(metrics, partition_id), reservation, batch_size, - expr: Arc::from(expr), + expr, row_converter, scratch_rows, heap: TopKHeap::new(k, batch_size), @@ -821,8 +817,8 @@ mod tests { }; // Input ordering uses only column "a" (a prefix of the full sort). - let input_ordering = LexOrdering::from(vec![sort_expr_a.clone()]); - let full_expr = LexOrdering::from(vec![sort_expr_a, sort_expr_b]); + let prefix = vec![sort_expr_a.clone()]; + let full_expr = LexOrdering::from([sort_expr_a, sort_expr_b]); // Create a dummy runtime environment and metrics. let runtime = Arc::new(RuntimeEnv::default()); @@ -832,7 +828,7 @@ mod tests { let mut topk = TopK::try_new( 0, Arc::clone(&schema), - input_ordering, + prefix, full_expr, 3, 2, diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 930fe793d1d4..73d7933e7c05 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -677,15 +677,13 @@ fn stats_union(mut left: Statistics, right: Statistics) -> Statistics { mod tests { use super::*; use crate::collect; - use crate::test; - use crate::test::TestMemoryExec; + use crate::test::{self, TestMemoryExec}; use arrow::compute::SortOptions; use arrow::datatypes::DataType; use datafusion_common::ScalarValue; + use datafusion_physical_expr::equivalence::convert_to_orderings; use datafusion_physical_expr::expressions::col; - use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; - use datafusion_physical_expr_common::sort_expr::LexOrdering; // Generate a schema which consists of 7 columns (a, b, c, d, e, f, g) fn create_test_schema() -> Result { @@ -701,19 +699,6 @@ mod tests { Ok(schema) } - // Convert each tuple to PhysicalSortExpr - fn convert_to_sort_exprs( - in_data: &[(&Arc, SortOptions)], - ) -> LexOrdering { - in_data - .iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: Arc::clone(*expr), - options: *options, - }) - .collect::() - } - #[tokio::test] async fn test_union_partitions() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); @@ -889,18 +874,9 @@ mod tests { (first_child_orderings, second_child_orderings, union_orderings), ) in test_cases.iter().enumerate() { - let first_orderings = first_child_orderings - .iter() - .map(|ordering| convert_to_sort_exprs(ordering)) - .collect::>(); - let second_orderings = second_child_orderings - .iter() - .map(|ordering| convert_to_sort_exprs(ordering)) - .collect::>(); - let union_expected_orderings = union_orderings - .iter() - .map(|ordering| convert_to_sort_exprs(ordering)) - .collect::>(); + let first_orderings = convert_to_orderings(first_child_orderings); + let second_orderings = convert_to_orderings(second_child_orderings); + let union_expected_orderings = convert_to_orderings(union_orderings); let child1 = Arc::new(TestMemoryExec::update_cache(Arc::new( TestMemoryExec::try_new(&[], Arc::clone(&schema), None)? 
.try_with_sort_information(first_orderings)?, @@ -911,7 +887,7 @@ mod tests { ))); let mut union_expected_eq = EquivalenceProperties::new(Arc::clone(&schema)); - union_expected_eq.add_new_orderings(union_expected_orderings); + union_expected_eq.add_orderings(union_expected_orderings); let union = UnionExec::new(vec![child1, child2]); let union_eq_properties = union.properties().equivalence_properties(); diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 6751f9b20240..d851d08a101f 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -38,7 +38,7 @@ use crate::{ ExecutionPlanProperties, InputOrderMode, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, WindowExpr, }; -use ahash::RandomState; + use arrow::compute::take_record_batch; use arrow::{ array::{Array, ArrayRef, RecordBatchOptions, UInt32Builder}, @@ -60,9 +60,12 @@ use datafusion_expr::ColumnarValue; use datafusion_physical_expr::window::{ PartitionBatches, PartitionKey, PartitionWindowAggStates, WindowState, }; -use datafusion_physical_expr::PhysicalExpr; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{ + OrderingRequirements, PhysicalSortExpr, +}; +use ahash::RandomState; use futures::stream::Stream; use futures::{ready, StreamExt}; use hashbrown::hash_table::HashTable; @@ -111,7 +114,7 @@ impl BoundedWindowAggExec { let indices = get_ordered_partition_by_indices( window_expr[0].partition_by(), &input, - ); + )?; if indices.len() == partition_by_exprs.len() { indices } else { @@ -123,7 +126,7 @@ impl BoundedWindowAggExec { vec![] } }; - let cache = Self::compute_properties(&input, &schema, &window_expr); + let cache = Self::compute_properties(&input, &schema, &window_expr)?; Ok(Self { input, window_expr, @@ -151,7 +154,7 @@ impl BoundedWindowAggExec { // We are sure that partition by columns are always at the beginning of sort_keys // Hence returned `PhysicalSortExpr` corresponding to `PARTITION BY` columns can be used safely // to calculate partition separation points - pub fn partition_by_sort_keys(&self) -> Result { + pub fn partition_by_sort_keys(&self) -> Result> { let partition_by = self.window_expr()[0].partition_by(); get_partition_by_sort_exprs( &self.input, @@ -191,9 +194,9 @@ impl BoundedWindowAggExec { input: &Arc, schema: &SchemaRef, window_exprs: &[Arc], - ) -> PlanProperties { + ) -> Result { // Calculate equivalence properties: - let eq_properties = window_equivalence_properties(schema, input, window_exprs); + let eq_properties = window_equivalence_properties(schema, input, window_exprs)?; // As we can have repartitioning using the partition keys, this can // be either one or more than one, depending on the presence of @@ -201,13 +204,13 @@ impl BoundedWindowAggExec { let output_partitioning = input.output_partitioning().clone(); // Construct properties cache - PlanProperties::new( + Ok(PlanProperties::new( eq_properties, output_partitioning, // TODO: Emission type and boundedness information can be enhanced here input.pipeline_behavior(), input.boundedness(), - ) + )) } pub fn partition_keys(&self) -> Vec> { @@ -303,14 +306,14 @@ impl ExecutionPlan for BoundedWindowAggExec { vec![&self.input] } - fn required_input_ordering(&self) -> Vec> { + fn 
required_input_ordering(&self) -> Vec> { let partition_bys = self.window_expr()[0].partition_by(); let order_keys = self.window_expr()[0].order_by(); let partition_bys = self .ordered_partition_by_indices .iter() .map(|idx| &partition_bys[*idx]); - vec![calc_requirements(partition_bys, order_keys.iter())] + vec![calc_requirements(partition_bys, order_keys)] } fn required_input_distribution(&self) -> Vec { @@ -750,7 +753,7 @@ impl LinearSearch { /// when computing partitions. pub struct SortedSearch { /// Stores partition by columns and their ordering information - partition_by_sort_keys: LexOrdering, + partition_by_sort_keys: Vec, /// Input ordering and partition by key ordering need not be the same, so /// this vector stores the mapping between them. For instance, if the input /// is ordered by a, b and the window expression contains a PARTITION BY b, a @@ -1347,10 +1350,10 @@ mod tests { Arc::new(Column::new(schema.fields[0].name(), 0)) as Arc; let args = vec![col_expr]; let partitionby_exprs = vec![col(hash, &schema)?]; - let orderby_exprs = LexOrdering::new(vec![PhysicalSortExpr { + let orderby_exprs = vec![PhysicalSortExpr { expr: col(order_by, &schema)?, options: SortOptions::default(), - }]); + }]; let window_frame = WindowFrame::new_bounds( WindowFrameUnits::Range, WindowFrameBound::CurrentRow, @@ -1366,7 +1369,7 @@ mod tests { fn_name, &args, &partitionby_exprs, - orderby_exprs.as_ref(), + &orderby_exprs, Arc::new(window_frame), &input.schema(), false, @@ -1463,13 +1466,14 @@ mod tests { } fn schema_orders(schema: &SchemaRef) -> Result> { - let orderings = vec![LexOrdering::new(vec![PhysicalSortExpr { + let orderings = vec![[PhysicalSortExpr { expr: col("sn", schema)?, options: SortOptions { descending: false, nulls_first: false, }, - }])]; + }] + .into()]; Ok(orderings) } @@ -1620,7 +1624,7 @@ mod tests { Arc::new(StandardWindowExpr::new( last_value_func, &[], - &LexOrdering::default(), + &[], Arc::new(WindowFrame::new_bounds( WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::UInt64(None)), @@ -1631,7 +1635,7 @@ mod tests { Arc::new(StandardWindowExpr::new( nth_value_func1, &[], - &LexOrdering::default(), + &[], Arc::new(WindowFrame::new_bounds( WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::UInt64(None)), @@ -1642,7 +1646,7 @@ mod tests { Arc::new(StandardWindowExpr::new( nth_value_func2, &[], - &LexOrdering::default(), + &[], Arc::new(WindowFrame::new_bounds( WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::UInt64(None)), @@ -1783,8 +1787,8 @@ mod tests { let plan = projection_exec(window)?; let expected_plan = vec![ - "ProjectionExec: expr=[sn@0 as sn, hash@1 as hash, count([Column { name: \"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: [LexOrdering { inner: [PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }] }]@2 as col_2]", - " BoundedWindowAggExec: wdw=[count([Column { name: \"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: [LexOrdering { inner: [PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }] }]: Ok(Field { name: \"count([Column { name: \\\"sn\\\", index: 0 }]) PARTITION BY: [[Column { name: \\\"hash\\\", index: 1 }]], ORDER BY: [LexOrdering { inner: [PhysicalSortExpr { expr: Column { name: \\\"sn\\\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }] }]\", data_type: Int64, 
nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(1)), is_causal: false }], mode=[Linear]", + "ProjectionExec: expr=[sn@0 as sn, hash@1 as hash, count([Column { name: \"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]@2 as col_2]", + " BoundedWindowAggExec: wdw=[count([Column { name: \"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]: Ok(Field { name: \"count([Column { name: \\\"sn\\\", index: 0 }]) PARTITION BY: [[Column { name: \\\"hash\\\", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: \\\"sn\\\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(1)), is_causal: false }], mode=[Linear]", " StreamingTableExec: partition_sizes=1, projection=[sn, hash], infinite_source=true, output_ordering=[sn@0 ASC NULLS LAST]", ]; diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index d2b7e0a49e95..5583abfd72a2 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -22,7 +22,6 @@ mod utils; mod window_agg_exec; use std::borrow::Borrow; -use std::iter; use std::sync::Arc; use crate::{ @@ -42,12 +41,13 @@ use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::{ - reverse_order_bys, - window::{SlidingAggregateWindowExpr, StandardWindowFunctionExpr}, - ConstExpr, EquivalenceProperties, LexOrdering, PhysicalSortRequirement, +use datafusion_physical_expr::window::{ + SlidingAggregateWindowExpr, StandardWindowFunctionExpr, +}; +use datafusion_physical_expr::{ConstExpr, EquivalenceProperties}; +use datafusion_physical_expr_common::sort_expr::{ + LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortRequirement, }; -use datafusion_physical_expr_common::sort_expr::LexRequirement; use itertools::Itertools; @@ -99,7 +99,7 @@ pub fn create_window_expr( name: String, args: &[Arc], partition_by: &[Arc], - order_by: &LexOrdering, + order_by: &[PhysicalSortExpr], window_frame: Arc, input_schema: &Schema, ignore_nulls: bool, @@ -131,7 +131,7 @@ pub fn create_window_expr( /// Creates an appropriate [`WindowExpr`] based on the window frame and fn window_expr_from_aggregate_expr( partition_by: &[Arc], - order_by: &LexOrdering, + order_by: &[PhysicalSortExpr], window_frame: Arc, aggregate: Arc, ) -> Arc { @@ -278,26 +278,33 @@ pub(crate) fn calc_requirements< >( partition_by_exprs: impl IntoIterator, orderby_sort_exprs: impl IntoIterator, -) -> Option { - let mut sort_reqs = LexRequirement::new( - partition_by_exprs - .into_iter() - .map(|partition_by| { - PhysicalSortRequirement::new(Arc::clone(partition_by.borrow()), None) - }) - .collect::>(), - ); +) -> Option { + let mut sort_reqs_with_partition = 
partition_by_exprs
+        .into_iter()
+        .map(|partition_by| {
+            PhysicalSortRequirement::new(Arc::clone(partition_by.borrow()), None)
+        })
+        .collect::<Vec<_>>();
+    let mut sort_reqs = vec![];
     for element in orderby_sort_exprs.into_iter() {
         let PhysicalSortExpr { expr, options } = element.borrow();
-        if !sort_reqs.iter().any(|e| e.expr.eq(expr)) {
-            sort_reqs.push(PhysicalSortRequirement::new(
-                Arc::clone(expr),
-                Some(*options),
-            ));
+        let sort_req = PhysicalSortRequirement::new(Arc::clone(expr), Some(*options));
+        if !sort_reqs_with_partition.iter().any(|e| e.expr.eq(expr)) {
+            sort_reqs_with_partition.push(sort_req.clone());
+        }
+        if !sort_reqs
+            .iter()
+            .any(|e: &PhysicalSortRequirement| e.expr.eq(expr))
+        {
+            sort_reqs.push(sort_req);
         }
     }
-    // Convert empty result to None. Otherwise wrap result inside Some()
-    (!sort_reqs.is_empty()).then_some(sort_reqs)
+
+    let mut alternatives = vec![];
+    alternatives.extend(LexRequirement::new(sort_reqs_with_partition));
+    alternatives.extend(LexRequirement::new(sort_reqs));
+
+    OrderingRequirements::new_alternatives(alternatives, false)
 }
 
 /// This function calculates the indices such that when partition by expressions reordered with the indices
@@ -308,18 +315,18 @@ pub(crate) fn calc_requirements<
 pub fn get_ordered_partition_by_indices(
     partition_by_exprs: &[Arc<dyn PhysicalExpr>],
     input: &Arc<dyn ExecutionPlan>,
-) -> Vec<usize> {
+) -> Result<Vec<usize>> {
     let (_, indices) = input
         .equivalence_properties()
-        .find_longest_permutation(partition_by_exprs);
-    indices
+        .find_longest_permutation(partition_by_exprs)?;
+    Ok(indices)
 }
 
 pub(crate) fn get_partition_by_sort_exprs(
     input: &Arc<dyn ExecutionPlan>,
     partition_by_exprs: &[Arc<dyn PhysicalExpr>],
     ordered_partition_by_indices: &[usize],
-) -> Result<LexOrdering> {
+) -> Result<Vec<PhysicalSortExpr>> {
     let ordered_partition_exprs = ordered_partition_by_indices
         .iter()
         .map(|idx| Arc::clone(&partition_by_exprs[*idx]))
         .collect::<Vec<_>>();
@@ -328,7 +335,7 @@ pub(crate) fn get_partition_by_sort_exprs(
     assert!(ordered_partition_by_indices.len() <= partition_by_exprs.len());
     let (ordering, _) = input
         .equivalence_properties()
-        .find_longest_permutation(&ordered_partition_exprs);
+        .find_longest_permutation(&ordered_partition_exprs)?;
     if ordering.len() == ordered_partition_exprs.len() {
         Ok(ordering)
     } else {
@@ -340,11 +347,11 @@ pub(crate) fn window_equivalence_properties(
     schema: &SchemaRef,
     input: &Arc<dyn ExecutionPlan>,
     window_exprs: &[Arc<dyn WindowExpr>],
-) -> EquivalenceProperties {
+) -> Result<EquivalenceProperties> {
     // We need to update the schema, so we can't directly use input's equivalence
     // properties.
     let mut window_eq_properties = EquivalenceProperties::new(Arc::clone(schema))
-        .extend(input.equivalence_properties().clone());
+        .extend(input.equivalence_properties().clone())?;
 
     let window_schema_len = schema.fields.len();
     let input_schema_len = window_schema_len - window_exprs.len();
@@ -356,22 +363,25 @@ pub(crate) fn window_equivalence_properties(
         // Collect columns defining partitioning, and construct all `SortOptions`
         // variations for them. Then, we will check each one whether it satisfies
         // the existing ordering provided by the input plan.
-        let partition_by_orders = partitioning_exprs
+        let mut all_satisfied_lexs = vec![];
+        for lex in partitioning_exprs
             .iter()
-            .map(|pb_order| sort_options_resolving_constant(Arc::clone(pb_order)));
-        let all_satisfied_lexs = partition_by_orders
+            .map(|pb_order| sort_options_resolving_constant(Arc::clone(pb_order)))
             .multi_cartesian_product()
-            .map(LexOrdering::new)
-            .filter(|lex| window_eq_properties.ordering_satisfy(lex))
-            .collect::<Vec<_>>();
+            .filter_map(LexOrdering::new)
+        {
+            if window_eq_properties.ordering_satisfy(lex.clone())? {
+                all_satisfied_lexs.push(lex);
+            }
+        }
         // If there is a partitioning, and no possible ordering can satisfy
         // the input plan's orderings, then we cannot further introduce any
         // new orderings for the window plan.
         if !no_partitioning && all_satisfied_lexs.is_empty() {
-            return window_eq_properties;
+            return Ok(window_eq_properties);
         } else if let Some(std_expr) = expr.as_any().downcast_ref::<StandardWindowExpr>() {
-            std_expr.add_equal_orderings(&mut window_eq_properties);
+            std_expr.add_equal_orderings(&mut window_eq_properties)?;
         } else if let Some(plain_expr) =
             expr.as_any().downcast_ref::<PlainAggregateWindowExpr>()
         {
@@ -379,26 +389,26 @@ pub(crate) fn window_equivalence_properties(
             // unbounded starting point.
             // First, check if the frame covers the whole table:
             if plain_expr.get_window_frame().end_bound.is_unbounded() {
-                let window_col = Column::new(expr.name(), i + input_schema_len);
+                let window_col =
+                    Arc::new(Column::new(expr.name(), i + input_schema_len)) as _;
                 if no_partitioning {
                     // Window function has a constant result across the table:
-                    window_eq_properties = window_eq_properties
-                        .with_constants(iter::once(ConstExpr::new(Arc::new(window_col))))
+                    window_eq_properties
+                        .add_constants(std::iter::once(ConstExpr::from(window_col)))?
                 } else {
                     // Window function results in a partial constant value in
                     // some ordering. Adjust the ordering equivalences accordingly:
                     let new_lexs = all_satisfied_lexs.into_iter().flat_map(|lex| {
-                        let orderings = lex.take_exprs();
                         let new_partial_consts =
-                            sort_options_resolving_constant(Arc::new(window_col.clone()));
+                            sort_options_resolving_constant(Arc::clone(&window_col));
                         new_partial_consts.into_iter().map(move |partial| {
-                            let mut existing = orderings.clone();
+                            let mut existing = lex.clone();
                             existing.push(partial);
-                            LexOrdering::new(existing)
+                            existing
                         })
                     });
-                    window_eq_properties.add_new_orderings(new_lexs);
+                    window_eq_properties.add_orderings(new_lexs);
                 }
             } else {
                 // The window frame is ever expanding, so set monotonicity comes
@@ -406,7 +416,7 @@ pub(crate) fn window_equivalence_properties(
                 plain_expr.add_equal_orderings(
                     &mut window_eq_properties,
                     window_expr_indices[i],
-                );
+                )?;
             }
         } else if let Some(sliding_expr) =
             expr.as_any().downcast_ref::<SlidingAggregateWindowExpr>()
@@ -424,22 +434,18 @@ pub(crate) fn window_equivalence_properties(
             let window_col = Column::new(expr.name(), i + input_schema_len);
             if no_partitioning {
                 // Reverse set-monotonic cases with no partitioning:
-                let new_ordering =
-                    vec![LexOrdering::new(vec![PhysicalSortExpr::new(
-                        Arc::new(window_col),
-                        SortOptions::new(increasing, true),
-                    )])];
-                window_eq_properties.add_new_orderings(new_ordering);
+                window_eq_properties.add_ordering([PhysicalSortExpr::new(
+                    Arc::new(window_col),
+                    SortOptions::new(increasing, true),
+                )]);
             } else {
                 // Reverse set-monotonic cases for all orderings:
-                for lex in all_satisfied_lexs.into_iter() {
-                    let mut existing = lex.take_exprs();
-                    existing.push(PhysicalSortExpr::new(
+                for mut lex in all_satisfied_lexs.into_iter() {
+                    lex.push(PhysicalSortExpr::new(
                         Arc::new(window_col.clone()),
                         SortOptions::new(increasing, true),
                     ));
-                    window_eq_properties
-                        .add_new_ordering(LexOrdering::new(existing));
+                    window_eq_properties.add_ordering(lex);
                 }
             }
         }
@@ -450,44 +456,44 @@ pub(crate) fn window_equivalence_properties(
             // utilize set-monotonicity since the set shrinks as the frame
             // boundary starts "touching" the end of the table.
             else if frame.is_causal() {
-                let mut args_all_lexs = sliding_expr
+                let args_all_lexs = sliding_expr
                     .get_aggregate_expr()
                     .expressions()
                     .into_iter()
                     .map(sort_options_resolving_constant)
                     .multi_cartesian_product();
 
-                let mut asc = false;
-                if args_all_lexs.any(|order| {
+                let (mut asc, mut satisfied) = (false, false);
+                for order in args_all_lexs {
                     if let Some(f) = order.first() {
                         asc = !f.options.descending;
                     }
-                    window_eq_properties.ordering_satisfy(&LexOrdering::new(order))
-                }) {
+                    if window_eq_properties.ordering_satisfy(order)? {
+                        satisfied = true;
+                        break;
+                    }
+                }
+                if satisfied {
                     let increasing =
                         set_monotonicity.eq(&SetMonotonicity::Increasing);
                     let window_col = Column::new(expr.name(), i + input_schema_len);
                     if increasing && (asc || no_partitioning) {
-                        let new_ordering =
-                            LexOrdering::new(vec![PhysicalSortExpr::new(
-                                Arc::new(window_col),
-                                SortOptions::new(false, false),
-                            )]);
-                        window_eq_properties.add_new_ordering(new_ordering);
+                        window_eq_properties.add_ordering([PhysicalSortExpr::new(
+                            Arc::new(window_col),
+                            SortOptions::new(false, false),
+                        )]);
                     } else if !increasing && (!asc || no_partitioning) {
-                        let new_ordering =
-                            LexOrdering::new(vec![PhysicalSortExpr::new(
-                                Arc::new(window_col),
-                                SortOptions::new(true, false),
-                            )]);
-                        window_eq_properties.add_new_ordering(new_ordering);
+                        window_eq_properties.add_ordering([PhysicalSortExpr::new(
+                            Arc::new(window_col),
+                            SortOptions::new(true, false),
+                        )]);
                     };
                 }
             }
         }
     }
-    window_eq_properties
+    Ok(window_eq_properties)
 }
 
 /// Constructs the best-fitting windowing operator (a `WindowAggExec` or a
@@ -514,7 +520,7 @@ pub fn get_best_fitting_window(
     let orderby_keys = window_exprs[0].order_by();
     let (should_reverse, input_order_mode) =
         if let Some((should_reverse, input_order_mode)) =
-            get_window_mode(partitionby_exprs, orderby_keys, input)
+            get_window_mode(partitionby_exprs, orderby_keys, input)?
         {
             (should_reverse, input_order_mode)
         } else {
@@ -580,35 +586,29 @@ pub fn get_best_fitting_window(
 /// the mode this window operator should work in to accommodate the existing ordering.
 pub fn get_window_mode(
     partitionby_exprs: &[Arc<dyn PhysicalExpr>],
-    orderby_keys: &LexOrdering,
+    orderby_keys: &[PhysicalSortExpr],
     input: &Arc<dyn ExecutionPlan>,
-) -> Option<(bool, InputOrderMode)> {
-    let input_eqs = input.equivalence_properties().clone();
-    let mut partition_by_reqs: LexRequirement = LexRequirement::new(vec![]);
-    let (_, indices) = input_eqs.find_longest_permutation(partitionby_exprs);
-    vec![].extend(indices.iter().map(|&idx| PhysicalSortRequirement {
-        expr: Arc::clone(&partitionby_exprs[idx]),
-        options: None,
-    }));
-    partition_by_reqs
-        .inner
-        .extend(indices.iter().map(|&idx| PhysicalSortRequirement {
+) -> Result<Option<(bool, InputOrderMode)>> {
+    let mut input_eqs = input.equivalence_properties().clone();
+    let (_, indices) = input_eqs.find_longest_permutation(partitionby_exprs)?;
+    let partition_by_reqs = indices
+        .iter()
+        .map(|&idx| PhysicalSortRequirement {
             expr: Arc::clone(&partitionby_exprs[idx]),
             options: None,
-        }));
+        })
+        .collect::<Vec<_>>();
     // Treat partition by exprs as constants during the analysis of whether the
    // ordering requirements are satisfied.
-    let const_exprs = partitionby_exprs.iter().map(ConstExpr::from);
-    let partition_by_eqs = input_eqs.with_constants(const_exprs);
-    let order_by_reqs = LexRequirement::from(orderby_keys.clone());
-    let reverse_order_by_reqs = LexRequirement::from(reverse_order_bys(orderby_keys));
-    for (should_swap, order_by_reqs) in
-        [(false, order_by_reqs), (true, reverse_order_by_reqs)]
+    let const_exprs = partitionby_exprs.iter().cloned().map(ConstExpr::from);
+    input_eqs.add_constants(const_exprs)?;
+    let reverse_orderby_keys =
+        orderby_keys.iter().map(|e| e.reverse()).collect::<Vec<_>>();
+    for (should_swap, orderbys) in
+        [(false, orderby_keys), (true, reverse_orderby_keys.as_ref())]
     {
-        let req = LexRequirement::new(
-            [partition_by_reqs.inner.clone(), order_by_reqs.inner].concat(),
-        )
-        .collapse();
-        if partition_by_eqs.ordering_satisfy_requirement(&req) {
+        let mut req = partition_by_reqs.clone();
+        req.extend(orderbys.iter().cloned().map(Into::into));
+        if req.is_empty() || input_eqs.ordering_satisfy_requirement(req)? {
             // Window can be run with existing ordering
             let mode = if indices.len() == partitionby_exprs.len() {
                 InputOrderMode::Sorted
@@ -617,10 +617,10 @@ pub fn get_window_mode(
             } else {
                 InputOrderMode::PartiallySorted(indices)
             };
-            return Some((should_swap, mode));
+            return Ok(Some((should_swap, mode)));
        }
     }
-    None
+    Ok(None)
 }
 
 fn sort_options_resolving_constant(expr: Arc<dyn PhysicalExpr>) -> Vec<PhysicalSortExpr> {
@@ -642,11 +642,11 @@ mod tests {
     use arrow::compute::SortOptions;
     use arrow_schema::{DataType, Field};
     use datafusion_execution::TaskContext;
-    use datafusion_functions_aggregate::count::count_udaf;
-    use futures::FutureExt;
     use InputOrderMode::{Linear, PartiallySorted, Sorted};
 
+    use futures::FutureExt;
+
     fn create_test_schema() -> Result<SchemaRef> {
         let nullable_column = Field::new("nullable_col", DataType::Int32, true);
         let non_nullable_column = Field::new("non_nullable_col", DataType::Int32, false);
@@ -696,16 +696,14 @@ mod tests {
     /// Creates a sorted Streaming Table exec
     pub fn streaming_table_exec(
         schema: &SchemaRef,
-        sort_exprs: impl IntoIterator<Item = PhysicalSortExpr>,
+        ordering: LexOrdering,
         infinite_source: bool,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        let sort_exprs = sort_exprs.into_iter().collect();
-
         Ok(Arc::new(StreamingTableExec::try_new(
             Arc::clone(schema),
             vec![],
             None,
-            Some(sort_exprs),
+            Some(ordering),
             infinite_source,
             None,
         )?))
     }
@@ -719,25 +717,38 @@ mod tests {
            (
                 vec!["a"],
                 vec![("b", true, true)],
-                vec![("a", None), ("b", Some((true, true)))],
+                vec![
+                    vec![("a", None), ("b", Some((true, true)))],
+                    vec![("b", Some((true, true)))],
+                ],
             ),
             // PARTITION BY a, ORDER BY a ASC NULLS FIRST
-            (vec!["a"], vec![("a", true, true)], vec![("a", None)]),
+            (
+                vec!["a"],
+                vec![("a", true, true)],
+                vec![vec![("a", None)], vec![("a", Some((true, true)))]],
+            ),
             // PARTITION BY a, ORDER BY b ASC NULLS FIRST, c DESC NULLS LAST
             (
                 vec!["a"],
                 vec![("b", true, true), ("c", false, false)],
                 vec![
-                    ("a", None),
-                    ("b", Some((true, true))),
-                    ("c", Some((false, false))),
+                    vec![
+                        ("a", None),
+                        ("b", Some((true, true))),
+                        ("c", Some((false, false))),
+                    ],
+                    vec![("b", Some((true, true))), ("c", Some((false, false)))],
                 ],
             ),
             // PARTITION BY a, c, ORDER BY b ASC NULLS FIRST, c DESC NULLS LAST
             (
                 vec!["a", "c"],
                 vec![("b", true, true), ("c", false, false)],
-                vec![("a", None), ("c", None), ("b", Some((true, true)))],
+                vec![
+                    vec![("a", None), ("c", None), ("b", Some((true, true)))],
+                    vec![("b", Some((true, true))), ("c", Some((false, false)))],
+                ],
             ),
         ];
         for (pb_params, ob_params, expected_params) in test_data {
            let mut orderbys = vec![];
             for (col_name, descending, nulls_first) in ob_params {
                 let expr = col(col_name, &schema)?;
-                let options = SortOptions {
-                    descending,
-                    nulls_first,
-                };
-                orderbys.push(PhysicalSortExpr { expr, options });
+                let options = SortOptions::new(descending, nulls_first);
+                orderbys.push(PhysicalSortExpr::new(expr, options));
             }
 
-            let mut expected: Option<LexRequirement> = None;
-            for (col_name, reqs) in expected_params {
-                let options = reqs.map(|(descending, nulls_first)| SortOptions {
-                    descending,
-                    nulls_first,
-                });
-                let expr = col(col_name, &schema)?;
-                let res = PhysicalSortRequirement::new(expr, options);
-                if let Some(expected) = &mut expected {
-                    expected.push(res);
-                } else {
-                    expected = Some(LexRequirement::new(vec![res]));
+            let mut expected: Option<OrderingRequirements> = None;
+            for expected_param in expected_params.clone() {
+                let mut requirements = vec![];
+                for (col_name, reqs) in expected_param {
+                    let options = reqs.map(|(descending, nulls_first)| {
+                        SortOptions::new(descending, nulls_first)
+                    });
+                    let expr = col(col_name, &schema)?;
+                    requirements.push(PhysicalSortRequirement::new(expr, options));
+                }
+                if let Some(requirements) = LexRequirement::new(requirements) {
+                    if let Some(alts) = expected.as_mut() {
+                        alts.add_alternative(requirements);
+                    } else {
+                        expected = Some(OrderingRequirements::new(requirements));
+                    }
                 }
             }
             assert_eq!(calc_requirements(partitionbys, orderbys), expected);
@@ -789,7 +801,7 @@ mod tests {
             "count".to_owned(),
             &[col("a", &schema)?],
             &[],
-            &LexOrdering::default(),
+            &[],
             Arc::new(WindowFrame::new(None)),
             schema.as_ref(),
             false,
@@ -893,13 +905,14 @@ mod tests {
         // Columns a,c are nullable whereas b,d are not nullable.
         // Source is sorted by a ASC NULLS FIRST, b ASC NULLS FIRST, c ASC NULLS FIRST, d ASC NULLS FIRST
         // Column e is not ordered.
-        let sort_exprs = vec![
+        let ordering = [
             sort_expr("a", &test_schema),
             sort_expr("b", &test_schema),
             sort_expr("c", &test_schema),
             sort_expr("d", &test_schema),
-        ];
-        let exec_unbounded = streaming_table_exec(&test_schema, sort_exprs, true)?;
+        ]
+        .into();
+        let exec_unbounded = streaming_table_exec(&test_schema, ordering, true)?;
 
         // Test cases consist of a vector of tuples, where each tuple represents a single test case.
         // The first field in the tuple is a Vec whose elements are the PARTITION BY columns.
@@ -986,7 +999,7 @@ mod tests {
                 partition_by_exprs.push(col(col_name, &test_schema)?);
             }
 
-            let mut order_by_exprs = LexOrdering::default();
+            let mut order_by_exprs = vec![];
             for col_name in order_by_params {
                 let expr = col(col_name, &test_schema)?;
                 // Give the default ordering; this matches the input ordering
                 // direction, so ordered columns will be matched.
                 let options = SortOptions::default();
                 order_by_exprs.push(PhysicalSortExpr { expr, options });
             }
-            let res = get_window_mode(
-                &partition_by_exprs,
-                order_by_exprs.as_ref(),
-                &exec_unbounded,
-            );
+            let res =
+                get_window_mode(&partition_by_exprs, &order_by_exprs, &exec_unbounded)?;
             // Since reversibility is not important in this test, convert
            // Option<(bool, InputOrderMode)> to Option<InputOrderMode>.
             let res = res.map(|(_, mode)| mode);
             assert_eq!(
@@ -1016,13 +1026,14 @@ mod tests {
         // Columns a,c are nullable whereas b,d are not nullable.
         // Source is sorted by a ASC NULLS FIRST, b ASC NULLS FIRST, c ASC NULLS FIRST, d ASC NULLS FIRST
         // Column e is not ordered.
-        let sort_exprs = vec![
+        let ordering = [
             sort_expr("a", &test_schema),
             sort_expr("b", &test_schema),
             sort_expr("c", &test_schema),
             sort_expr("d", &test_schema),
-        ];
-        let exec_unbounded = streaming_table_exec(&test_schema, sort_exprs, true)?;
+        ]
+        .into();
+        let exec_unbounded = streaming_table_exec(&test_schema, ordering, true)?;
 
         // Test cases consist of a vector of tuples, where each tuple represents a single test case.
         // The first field in the tuple is a Vec whose elements are the PARTITION BY columns.
@@ -1151,7 +1162,7 @@ mod tests {
                 partition_by_exprs.push(col(col_name, &test_schema)?);
             }
 
-            let mut order_by_exprs = LexOrdering::default();
+            let mut order_by_exprs = vec![];
             for (col_name, descending, nulls_first) in order_by_params {
                 let expr = col(col_name, &test_schema)?;
                 let options = SortOptions {
                     descending,
                     nulls_first,
@@ -1162,7 +1173,7 @@ mod tests {
             }
 
             assert_eq!(
-                get_window_mode(&partition_by_exprs, order_by_exprs.as_ref(), &exec_unbounded),
+                get_window_mode(&partition_by_exprs, &order_by_exprs, &exec_unbounded)?,
                 *expected,
                 "Unexpected result for unbounded test case#: {case_idx:?}, case: {test_case:?}"
             );
diff --git a/datafusion/physical-plan/src/windows/window_agg_exec.rs b/datafusion/physical-plan/src/windows/window_agg_exec.rs
index 4c76e2230875..1b7cb9bb76e1 100644
--- a/datafusion/physical-plan/src/windows/window_agg_exec.rs
+++ b/datafusion/physical-plan/src/windows/window_agg_exec.rs
@@ -44,7 +44,9 @@ use datafusion_common::stats::Precision;
 use datafusion_common::utils::{evaluate_partition_ranges, transpose};
 use datafusion_common::{internal_err, Result};
 use datafusion_execution::TaskContext;
-use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement};
+use datafusion_physical_expr_common::sort_expr::{
+    OrderingRequirements, PhysicalSortExpr,
+};
 
 use futures::{ready, Stream, StreamExt};
 
@@ -79,8 +81,8 @@ impl WindowAggExec {
         let schema = Arc::new(schema);
 
         let ordered_partition_by_indices =
-            get_ordered_partition_by_indices(window_expr[0].partition_by(), &input);
-        let cache = Self::compute_properties(Arc::clone(&schema), &input, &window_expr);
+            get_ordered_partition_by_indices(window_expr[0].partition_by(), &input)?;
+        let cache = Self::compute_properties(Arc::clone(&schema), &input, &window_expr)?;
         Ok(Self {
             input,
             window_expr,
@@ -107,7 +109,7 @@ impl WindowAggExec {
     // We are sure that partition by columns are always at the beginning of sort_keys
     // Hence returned `PhysicalSortExpr` corresponding to `PARTITION BY` columns can be used safely
     // to calculate partition separation points
-    pub fn partition_by_sort_keys(&self) -> Result<LexOrdering> {
+    pub fn partition_by_sort_keys(&self) -> Result<Vec<PhysicalSortExpr>> {
         let partition_by = self.window_expr()[0].partition_by();
         get_partition_by_sort_exprs(
             &self.input,
@@ -121,9 +123,9 @@ impl WindowAggExec {
         schema: SchemaRef,
         input: &Arc<dyn ExecutionPlan>,
         window_exprs: &[Arc<dyn WindowExpr>],
-    ) -> PlanProperties {
+    ) -> Result<PlanProperties> {
         // Calculate equivalence properties:
-        let eq_properties = window_equivalence_properties(&schema, input, window_exprs);
+        let eq_properties = window_equivalence_properties(&schema, input, window_exprs)?;
 
         // Get output partitioning:
         // Because we can have repartitioning using the partition keys this
         // is either the out partitioning or an (adjusted) input partitioning
         let output_partitioning = input.output_partitioning().clone();
 
         // Construct properties cache:
-        PlanProperties::new(
+        Ok(PlanProperties::new(
             eq_properties,
             output_partitioning,
             // TODO: Emission type and boundedness information can be enhanced here
             EmissionType::Final,
             input.boundedness(),
-
) + )) } pub fn partition_keys(&self) -> Vec> { @@ -234,17 +236,17 @@ impl ExecutionPlan for WindowAggExec { vec![true] } - fn required_input_ordering(&self) -> Vec> { + fn required_input_ordering(&self) -> Vec> { let partition_bys = self.window_expr()[0].partition_by(); let order_keys = self.window_expr()[0].order_by(); if self.ordered_partition_by_indices.len() < partition_bys.len() { - vec![calc_requirements(partition_bys, order_keys.iter())] + vec![calc_requirements(partition_bys, order_keys)] } else { let partition_bys = self .ordered_partition_by_indices .iter() .map(|idx| &partition_bys[*idx]); - vec![calc_requirements(partition_bys, order_keys.iter())] + vec![calc_requirements(partition_bys, order_keys)] } } @@ -319,7 +321,7 @@ pub struct WindowAggStream { batches: Vec, finished: bool, window_expr: Vec>, - partition_by_sort_keys: LexOrdering, + partition_by_sort_keys: Vec, baseline_metrics: BaselineMetrics, ordered_partition_by_indices: Vec, } @@ -331,7 +333,7 @@ impl WindowAggStream { window_expr: Vec>, input: SendableRecordBatchStream, baseline_metrics: BaselineMetrics, - partition_by_sort_keys: LexOrdering, + partition_by_sort_keys: Vec, ordered_partition_by_indices: Vec, ) -> Result { // In WindowAggExec all partition by columns should be ordered. diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 18073516610c..dc68a1147324 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -315,28 +315,22 @@ pub fn serialize_expr( null_treatment: _, }, } = window_fun.as_ref(); - let (window_function, fun_definition) = match fun { + let mut buf = Vec::new(); + let window_function = match fun { WindowFunctionDefinition::AggregateUDF(aggr_udf) => { - let mut buf = Vec::new(); let _ = codec.try_encode_udaf(aggr_udf, &mut buf); - ( - protobuf::window_expr_node::WindowFunction::Udaf( - aggr_udf.name().to_string(), - ), - (!buf.is_empty()).then_some(buf), + protobuf::window_expr_node::WindowFunction::Udaf( + aggr_udf.name().to_string(), ) } WindowFunctionDefinition::WindowUDF(window_udf) => { - let mut buf = Vec::new(); let _ = codec.try_encode_udwf(window_udf, &mut buf); - ( - protobuf::window_expr_node::WindowFunction::Udwf( - window_udf.name().to_string(), - ), - (!buf.is_empty()).then_some(buf), + protobuf::window_expr_node::WindowFunction::Udwf( + window_udf.name().to_string(), ) } }; + let fun_definition = (!buf.is_empty()).then_some(buf); let partition_by = serialize_exprs(partition_by, codec)?; let order_by = serialize_sorts(order_by, codec)?; diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 5024bb558a65..1c60470b2218 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -102,13 +102,13 @@ pub fn parse_physical_sort_exprs( registry: &dyn FunctionRegistry, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, -) -> Result { +) -> Result> { proto .iter() .map(|sort_expr| { parse_physical_sort_expr(sort_expr, registry, input_schema, codec) }) - .collect::>() + .collect() } /// Parses a physical window expr from a protobuf. 
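The `parse_physical_sort_exprs` hunk above drops the `LexOrdering` return type in favor of a bare `Vec<PhysicalSortExpr>`, so deserialization no longer wraps possibly-empty lists in an ordering. A minimal sketch of the downstream pattern this enables (the helper name and its input shape are illustrative only; it assumes `LexOrdering::new` returns `Option<LexOrdering>`, which the `parse_protobuf_file_scan_config` hunk that follows relies on):

    use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr};

    // Each inner vector is one parsed ordering. `LexOrdering::new` maps an
    // empty vector to `None`, so empty orderings are skipped instead of being
    // stored as meaningless empty `LexOrdering` values.
    fn collect_orderings(parsed: Vec<Vec<PhysicalSortExpr>>) -> Vec<LexOrdering> {
        let mut output_ordering = vec![];
        for sort_exprs in parsed {
            output_ordering.extend(LexOrdering::new(sort_exprs));
        }
        output_ordering
    }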
@@ -175,7 +175,7 @@ pub fn parse_physical_window_expr( name, &window_node_expr, &partition_by, - order_by.as_ref(), + &order_by, Arc::new(window_frame), &extended_schema, false, @@ -528,13 +528,13 @@ pub fn parse_protobuf_file_scan_config( let mut output_ordering = vec![]; for node_collection in &proto.output_ordering { - let sort_expr = parse_physical_sort_exprs( + let sort_exprs = parse_physical_sort_exprs( &node_collection.physical_sort_expr_nodes, registry, &schema, codec, )?; - output_ordering.push(sort_expr); + output_ordering.extend(LexOrdering::new(sort_exprs)); } let config = FileScanConfigBuilder::new(object_store_url, file_schema, file_source) diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 7a85a2a8efbd..672f5d2bd43a 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -1030,7 +1030,7 @@ impl protobuf::PhysicalPlanNode { ) }) .collect::>>()?; - let ordering_req: LexOrdering = agg_node + let order_bys = agg_node .ordering_req .iter() .map(|e| { @@ -1041,7 +1041,7 @@ impl protobuf::PhysicalPlanNode { extension_codec, ) }) - .collect::>()?; + .collect::>()?; agg_node .aggregate_function .as_ref() @@ -1063,7 +1063,7 @@ impl protobuf::PhysicalPlanNode { .alias(name) .with_ignore_nulls(agg_node.ignore_nulls) .with_distinct(agg_node.distinct) - .order_by(ordering_req) + .order_by(order_bys) .build() .map(Arc::new) } @@ -1292,11 +1292,7 @@ impl protobuf::PhysicalPlanNode { &left_schema, extension_codec, )?; - let left_sort_exprs = if left_sort_exprs.is_empty() { - None - } else { - Some(left_sort_exprs) - }; + let left_sort_exprs = LexOrdering::new(left_sort_exprs); let right_sort_exprs = parse_physical_sort_exprs( &sym_join.right_sort_exprs, @@ -1304,11 +1300,7 @@ impl protobuf::PhysicalPlanNode { &right_schema, extension_codec, )?; - let right_sort_exprs = if right_sort_exprs.is_empty() { - None - } else { - Some(right_sort_exprs) - }; + let right_sort_exprs = LexOrdering::new(right_sort_exprs); let partition_mode = protobuf::StreamPartitionMode::try_from( sym_join.partition_mode, @@ -1420,47 +1412,45 @@ impl protobuf::PhysicalPlanNode { runtime: &RuntimeEnv, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { - let input: Arc = - into_physical_plan(&sort.input, registry, runtime, extension_codec)?; + let input = into_physical_plan(&sort.input, registry, runtime, extension_codec)?; let exprs = sort - .expr - .iter() - .map(|expr| { - let expr = expr.expr_type.as_ref().ok_or_else(|| { + .expr + .iter() + .map(|expr| { + let expr = expr.expr_type.as_ref().ok_or_else(|| { + proto_error(format!( + "physical_plan::from_proto() Unexpected expr {self:?}" + )) + })?; + if let ExprType::Sort(sort_expr) = expr { + let expr = sort_expr + .expr + .as_ref() + .ok_or_else(|| { proto_error(format!( - "physical_plan::from_proto() Unexpected expr {self:?}" + "physical_plan::from_proto() Unexpected sort expr {self:?}" )) - })?; - if let ExprType::Sort(sort_expr) = expr { - let expr = sort_expr - .expr - .as_ref() - .ok_or_else(|| { - proto_error(format!( - "physical_plan::from_proto() Unexpected sort expr {self:?}" - )) - })? - .as_ref(); - Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr, registry, input.schema().as_ref(), extension_codec)?, - options: SortOptions { - descending: !sort_expr.asc, - nulls_first: sort_expr.nulls_first, - }, - }) - } else { - internal_err!( - "physical_plan::from_proto() {self:?}" - ) - } + })? 
+ .as_ref(); + Ok(PhysicalSortExpr { + expr: parse_physical_expr(expr, registry, input.schema().as_ref(), extension_codec)?, + options: SortOptions { + descending: !sort_expr.asc, + nulls_first: sort_expr.nulls_first, + }, }) - .collect::>()?; - let fetch = if sort.fetch < 0 { - None - } else { - Some(sort.fetch as usize) + } else { + internal_err!( + "physical_plan::from_proto() {self:?}" + ) + } + }) + .collect::>>()?; + let Some(ordering) = LexOrdering::new(exprs) else { + return internal_err!("SortExec requires an ordering"); }; - let new_sort = SortExec::new(exprs, input) + let fetch = (sort.fetch >= 0).then_some(sort.fetch as _); + let new_sort = SortExec::new(ordering, input) .with_fetch(fetch) .with_preserve_partitioning(sort.preserve_partitioning); @@ -1474,8 +1464,7 @@ impl protobuf::PhysicalPlanNode { runtime: &RuntimeEnv, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { - let input: Arc = - into_physical_plan(&sort.input, registry, runtime, extension_codec)?; + let input = into_physical_plan(&sort.input, registry, runtime, extension_codec)?; let exprs = sort .expr .iter() @@ -1511,14 +1500,13 @@ impl protobuf::PhysicalPlanNode { internal_err!("physical_plan::from_proto() {self:?}") } }) - .collect::>()?; - let fetch = if sort.fetch < 0 { - None - } else { - Some(sort.fetch as usize) + .collect::>>()?; + let Some(ordering) = LexOrdering::new(exprs) else { + return internal_err!("SortExec requires an ordering"); }; + let fetch = (sort.fetch >= 0).then_some(sort.fetch as _); Ok(Arc::new( - SortPreservingMergeExec::new(exprs, input).with_fetch(fetch), + SortPreservingMergeExec::new(ordering, input).with_fetch(fetch), )) } @@ -1657,9 +1645,12 @@ impl protobuf::PhysicalPlanNode { &sink_schema, extension_codec, ) - .map(LexRequirement::from) + .map(|sort_exprs| { + LexRequirement::new(sort_exprs.into_iter().map(Into::into)) + }) }) - .transpose()?; + .transpose()? + .flatten(); Ok(Arc::new(DataSinkExec::new( input, Arc::new(data_sink), @@ -1692,9 +1683,12 @@ impl protobuf::PhysicalPlanNode { &sink_schema, extension_codec, ) - .map(LexRequirement::from) + .map(|sort_exprs| { + LexRequirement::new(sort_exprs.into_iter().map(Into::into)) + }) }) - .transpose()?; + .transpose()? + .flatten(); Ok(Arc::new(DataSinkExec::new( input, Arc::new(data_sink), @@ -1730,9 +1724,12 @@ impl protobuf::PhysicalPlanNode { &sink_schema, extension_codec, ) - .map(LexRequirement::from) + .map(|sort_exprs| { + LexRequirement::new(sort_exprs.into_iter().map(Into::into)) + }) }) - .transpose()?; + .transpose()? 
+ .flatten(); Ok(Arc::new(DataSinkExec::new( input, Arc::new(data_sink), diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index d1b1f51ae107..a702bad0800c 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -21,8 +21,9 @@ use std::sync::Arc; use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::datasource::physical_plan::FileSink; use datafusion::physical_expr::window::{SlidingAggregateWindowExpr, StandardWindowExpr}; -use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr, ScalarFunctionExpr}; +use datafusion::physical_expr::ScalarFunctionExpr; use datafusion::physical_expr_common::physical_expr::snapshot_physical_expr; +use datafusion::physical_expr_common::sort_expr::PhysicalSortExpr; use datafusion::physical_plan::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr, Literal, NegativeExpr, NotExpr, TryCastExpr, UnKnownColumn, @@ -53,11 +54,8 @@ pub fn serialize_physical_aggr_expr( codec: &dyn PhysicalExtensionCodec, ) -> Result { let expressions = serialize_physical_exprs(&aggr_expr.expressions(), codec)?; - let ordering_req = match aggr_expr.order_bys() { - Some(order) => order.clone(), - None => LexOrdering::default(), - }; - let ordering_req = serialize_physical_sort_exprs(ordering_req, codec)?; + let order_bys = + serialize_physical_sort_exprs(aggr_expr.order_bys().iter().cloned(), codec)?; let name = aggr_expr.fun().name().to_string(); let mut buf = Vec::new(); @@ -67,7 +65,7 @@ pub fn serialize_physical_aggr_expr( protobuf::PhysicalAggregateExprNode { aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)), expr: expressions, - ordering_req, + ordering_req: order_bys, distinct: aggr_expr.is_distinct(), ignore_nulls: aggr_expr.ignore_nulls(), fun_definition: (!buf.is_empty()).then_some(buf), diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 7d56bb6c5db1..ab419f371349 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -60,7 +60,7 @@ use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility}; use datafusion::physical_expr::expressions::Literal; use datafusion::physical_expr::window::{SlidingAggregateWindowExpr, StandardWindowExpr}; use datafusion::physical_expr::{ - LexOrdering, LexRequirement, PhysicalSortRequirement, ScalarFunctionExpr, + LexOrdering, PhysicalSortRequirement, ScalarFunctionExpr, }; use datafusion::physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, @@ -74,6 +74,7 @@ use datafusion::physical_plan::expressions::{ use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::joins::{ HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode, + SymmetricHashJoinExec, }; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; @@ -109,8 +110,7 @@ use datafusion_functions_aggregate::string_agg::string_agg_udaf; use datafusion_proto::physical_plan::{ AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, }; -use datafusion_proto::protobuf; -use datafusion_proto::protobuf::PhysicalPlanNode; +use datafusion_proto::protobuf::{self, PhysicalPlanNode}; /// Perform a serde roundtrip and assert that 
the string representation of the before and after plans /// are identical. Note that this often isn't sufficient to guarantee that no information is @@ -321,9 +321,9 @@ fn roundtrip_udwf() -> Result<()> { &[ col("a", &schema)? ], - &LexOrdering::new(vec![ - PhysicalSortExpr::new(col("b", &schema)?, SortOptions::new(true, true)), - ]), + &[ + PhysicalSortExpr::new(col("b", &schema)?, SortOptions::new(true, true)) + ], Arc::new(WindowFrame::new(None)), )); @@ -360,13 +360,13 @@ fn roundtrip_window() -> Result<()> { let udwf_expr = Arc::new(StandardWindowExpr::new( nth_value_window, &[col("b", &schema)?], - &LexOrdering::new(vec![PhysicalSortExpr { + &[PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions { descending: false, nulls_first: false, }, - }]), + }], Arc::new(window_frame), )); @@ -380,7 +380,7 @@ fn roundtrip_window() -> Result<()> { .build() .map(Arc::new)?, &[], - &LexOrdering::default(), + &[], Arc::new(WindowFrame::new(None)), )); @@ -400,7 +400,7 @@ fn roundtrip_window() -> Result<()> { let sliding_aggr_window_expr = Arc::new(SlidingAggregateWindowExpr::new( sum_expr, &[], - &LexOrdering::default(), + &[], Arc::new(window_frame), )); @@ -528,13 +528,13 @@ fn rountrip_aggregate_with_sort() -> Result<()> { let groups: Vec<(Arc, String)> = vec![(col("a", &schema)?, "unused".to_string())]; - let sort_exprs = LexOrdering::new(vec![PhysicalSortExpr { + let sort_exprs = vec![PhysicalSortExpr { expr: col("b", &schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]); + }]; let aggregates = vec![ @@ -654,7 +654,7 @@ fn roundtrip_sort() -> Result<()> { let field_a = Field::new("a", DataType::Boolean, false); let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); - let sort_exprs = LexOrdering::new(vec![ + let sort_exprs = [ PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions { @@ -669,7 +669,8 @@ fn roundtrip_sort() -> Result<()> { nulls_first: true, }, }, - ]); + ] + .into(); roundtrip_test(Arc::new(SortExec::new( sort_exprs, Arc::new(EmptyExec::new(schema)), @@ -681,7 +682,7 @@ fn roundtrip_sort_preserve_partitioning() -> Result<()> { let field_a = Field::new("a", DataType::Boolean, false); let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); - let sort_exprs = LexOrdering::new(vec![ + let sort_exprs: LexOrdering = [ PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions { @@ -696,7 +697,8 @@ fn roundtrip_sort_preserve_partitioning() -> Result<()> { nulls_first: true, }, }, - ]); + ] + .into(); roundtrip_test(Arc::new(SortExec::new( sort_exprs.clone(), @@ -1131,7 +1133,7 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { vec![Arc::new(PlainAggregateWindowExpr::new( aggr_expr.clone(), &[col("author", &schema)?], - &LexOrdering::default(), + &[], Arc::new(WindowFrame::new(None)), ))], filter, @@ -1176,13 +1178,13 @@ fn roundtrip_udwf_extension_codec() -> Result<()> { let udwf_expr = Arc::new(StandardWindowExpr::new( udwf, &[col("b", &schema)?], - &LexOrdering::new(vec![PhysicalSortExpr { + &[PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions { descending: false, nulls_first: false, }, - }]), + }], Arc::new(window_frame), )); @@ -1239,7 +1241,7 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { vec![Arc::new(PlainAggregateWindowExpr::new( aggr_expr, &[col("author", &schema)?], - &LexOrdering::default(), + &[], Arc::new(WindowFrame::new(None)), ))], filter, @@ -1335,13 
+1337,14 @@ fn roundtrip_json_sink() -> Result<()> { file_sink_config, JsonWriterOptions::new(CompressionTypeVariant::UNCOMPRESSED), )); - let sort_order = LexRequirement::new(vec![PhysicalSortRequirement::new( + let sort_order = [PhysicalSortRequirement::new( Arc::new(Column::new("plan_type", 0)), Some(SortOptions { descending: true, nulls_first: false, }), - )]); + )] + .into(); roundtrip_test(Arc::new(DataSinkExec::new( input, @@ -1372,13 +1375,14 @@ fn roundtrip_csv_sink() -> Result<()> { file_sink_config, CsvWriterOptions::new(WriterBuilder::default(), CompressionTypeVariant::ZSTD), )); - let sort_order = LexRequirement::new(vec![PhysicalSortRequirement::new( + let sort_order = [PhysicalSortRequirement::new( Arc::new(Column::new("plan_type", 0)), Some(SortOptions { descending: true, nulls_first: false, }), - )]); + )] + .into(); let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; @@ -1428,13 +1432,14 @@ fn roundtrip_parquet_sink() -> Result<()> { file_sink_config, TableParquetOptions::default(), )); - let sort_order = LexRequirement::new(vec![PhysicalSortRequirement::new( + let sort_order = [PhysicalSortRequirement::new( Arc::new(Column::new("plan_type", 0)), Some(SortOptions { descending: true, nulls_first: false, }), - )]); + )] + .into(); roundtrip_test(Arc::new(DataSinkExec::new( input, @@ -1471,31 +1476,29 @@ fn roundtrip_sym_hash_join() -> Result<()> { ] { for left_order in &[ None, - Some(LexOrdering::new(vec![PhysicalSortExpr { + LexOrdering::new(vec![PhysicalSortExpr { expr: Arc::new(Column::new("col", schema_left.index_of("col")?)), options: Default::default(), - }])), + }]), ] { - for right_order in &[ + for right_order in [ None, - Some(LexOrdering::new(vec![PhysicalSortExpr { + LexOrdering::new(vec![PhysicalSortExpr { expr: Arc::new(Column::new("col", schema_right.index_of("col")?)), options: Default::default(), - }])), + }]), ] { - roundtrip_test(Arc::new( - datafusion::physical_plan::joins::SymmetricHashJoinExec::try_new( - Arc::new(EmptyExec::new(schema_left.clone())), - Arc::new(EmptyExec::new(schema_right.clone())), - on.clone(), - None, - join_type, - false, - left_order.clone(), - right_order.clone(), - *partition_mode, - )?, - ))?; + roundtrip_test(Arc::new(SymmetricHashJoinExec::try_new( + Arc::new(EmptyExec::new(schema_left.clone())), + Arc::new(EmptyExec::new(schema_right.clone())), + on.clone(), + None, + join_type, + false, + left_order.clone(), + right_order, + *partition_mode, + )?))?; } } } diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index f42a3ad138c4..2ea2299c1fcf 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -138,7 +138,7 @@ impl SqlToRel<'_, S> { Some(into) => Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( CreateMemoryTable { name: self.object_name_to_table_reference(into.name)?, - constraints: Constraints::empty(), + constraints: Constraints::default(), input: Arc::new(plan), if_not_exists: false, or_replace: false, diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index ed77435d6a85..0966d9ff0538 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -160,9 +160,6 @@ SELECT approx_percentile_cont(c12) WITHIN GROUP (ORDER BY c12) FROM aggregate_te statement error DataFusion error: This feature is not implemented: Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal SELECT approx_percentile_cont(0.95, c5) WITHIN GROUP 
(ORDER BY c12) FROM aggregate_test_100 -statement error DataFusion error: This feature is not implemented: Conflicting ordering requirements in aggregate functions is not supported -SELECT approx_percentile_cont(0.95) WITHIN GROUP (ORDER BY c5), approx_percentile_cont(0.2) WITHIN GROUP (ORDER BY c12) FROM aggregate_test_100 - statement error DataFusion error: Error during planning: \[IGNORE | RESPECT\] NULLS are not permitted for approx_percentile_cont SELECT approx_percentile_cont(0.95) WITHIN GROUP (ORDER BY c5) IGNORE NULLS FROM aggregate_test_100 diff --git a/datafusion/sqllogictest/test_files/topk.slt b/datafusion/sqllogictest/test_files/topk.slt index ce23fe26528c..26fef6d666b1 100644 --- a/datafusion/sqllogictest/test_files/topk.slt +++ b/datafusion/sqllogictest/test_files/topk.slt @@ -370,7 +370,7 @@ query TT explain select number, letter, age, number as column4, letter as column5 from partial_sorted order by number desc, column4 desc, letter asc, column5 asc, age desc limit 3; ---- physical_plan -01)SortExec: TopK(fetch=3), expr=[number@0 DESC, column4@3 DESC, letter@1 ASC NULLS LAST, column5@4 ASC NULLS LAST, age@2 DESC], preserve_partitioning=[false], sort_prefix=[number@0 DESC, letter@1 ASC NULLS LAST] +01)SortExec: TopK(fetch=3), expr=[number@0 DESC, letter@1 ASC NULLS LAST, age@2 DESC], preserve_partitioning=[false], sort_prefix=[number@0 DESC, letter@1 ASC NULLS LAST] 02)--ProjectionExec: expr=[number@0 as number, letter@1 as letter, age@2 as age, number@0 as column4, letter@1 as column5] 03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet @@ -380,7 +380,7 @@ explain select number + 1 as number_plus, number, number + 1 as other_number_plu ---- physical_plan 01)SortPreservingMergeExec: [number_plus@0 DESC, number@1 DESC, other_number_plus@2 DESC, age@3 ASC NULLS LAST], fetch=3 -02)--SortExec: TopK(fetch=3), expr=[number_plus@0 DESC, number@1 DESC, other_number_plus@2 DESC, age@3 ASC NULLS LAST], preserve_partitioning=[true], sort_prefix=[number_plus@0 DESC, number@1 DESC] +02)--SortExec: TopK(fetch=3), expr=[number_plus@0 DESC, number@1 DESC, age@3 ASC NULLS LAST], preserve_partitioning=[true], sort_prefix=[number_plus@0 DESC, number@1 DESC] 03)----ProjectionExec: expr=[__common_expr_1@0 as number_plus, number@1 as number, __common_expr_1@0 as other_number_plus, age@2 as age] 04)------ProjectionExec: expr=[CAST(number@0 AS Int64) + 1 as __common_expr_1, number@0 as number, age@1 as age] 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
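One pattern recurring throughout this patch deserves a standalone illustration: `LexOrdering::new(vec![...])` construction sites become `[...].into()`. The sketch below shows that style in isolation (the schema and column names are made up; it assumes the `From` impl for arrays of `PhysicalSortExpr` that the `.into()` call sites above lean on):

    use arrow::compute::SortOptions;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::error::Result;
    use datafusion::physical_expr::expressions::col;
    use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr};

    // Build `number DESC NULLS LAST, letter ASC NULLS LAST` the way the
    // updated call sites do: a fixed-size array converted via `.into()`.
    fn example_ordering() -> Result<LexOrdering> {
        let schema = Schema::new(vec![
            Field::new("number", DataType::Int32, false),
            Field::new("letter", DataType::Utf8, false),
        ]);
        Ok([
            PhysicalSortExpr::new(col("number", &schema)?, SortOptions::new(true, false)),
            PhysicalSortExpr::new(col("letter", &schema)?, SortOptions::new(false, false)),
        ]
        .into())
    }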