From 0842af71a17785b38202c2df5f3517e519ca08f6 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 7 Jan 2025 15:27:00 -0500 Subject: [PATCH] Encapsulate fields of `EquivalenceGroup` --- .../tests/fuzz_cases/equivalence/utils.rs | 6 ++--- .../physical-expr/src/equivalence/class.rs | 25 ++++++++++++++----- .../physical-expr/src/equivalence/mod.rs | 6 ++--- .../src/equivalence/properties.rs | 5 ++-- 4 files changed, 27 insertions(+), 15 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs index e18dab35fc91..68541f989d3a 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs @@ -179,7 +179,7 @@ fn add_equal_conditions_test() -> Result<()> { // This new entry is redundant, size shouldn't increase eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr)?; assert_eq!(eq_properties.eq_group().len(), 1); - let eq_groups = &eq_properties.eq_group().classes[0]; + let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 2); assert!(eq_groups.contains(&col_a_expr)); assert!(eq_groups.contains(&col_b_expr)); @@ -188,7 +188,7 @@ fn add_equal_conditions_test() -> Result<()> { // however there shouldn't be any new equivalence class eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr)?; assert_eq!(eq_properties.eq_group().len(), 1); - let eq_groups = &eq_properties.eq_group().classes[0]; + let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 3); assert!(eq_groups.contains(&col_a_expr)); assert!(eq_groups.contains(&col_b_expr)); @@ -202,7 +202,7 @@ fn add_equal_conditions_test() -> Result<()> { // Hence equivalent class count should decrease from 2 to 1. eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr)?; assert_eq!(eq_properties.eq_group().len(), 1); - let eq_groups = &eq_properties.eq_group().classes[0]; + let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 5); assert!(eq_groups.contains(&col_a_expr)); assert!(eq_groups.contains(&col_b_expr)); diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 9e535a94eb6e..cb11409479a8 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -20,12 +20,12 @@ use crate::{ expressions::Column, LexOrdering, LexRequirement, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, PhysicalSortRequirement, }; -use std::fmt::Display; -use std::sync::Arc; - use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{JoinType, ScalarValue}; use datafusion_physical_expr_common::physical_expr::format_physical_expr_list; +use std::fmt::Display; +use std::sync::Arc; +use std::vec::IntoIter; use indexmap::{IndexMap, IndexSet}; @@ -323,11 +323,10 @@ impl Display for EquivalenceClass { } } -/// An `EquivalenceGroup` is a collection of `EquivalenceClass`es where each -/// class represents a distinct equivalence class in a relation. +/// A collection of distinct `EquivalenceClass`es #[derive(Debug, Clone)] pub struct EquivalenceGroup { - pub classes: Vec, + classes: Vec, } impl EquivalenceGroup { @@ -717,6 +716,20 @@ impl EquivalenceGroup { .zip(right_children) .all(|(left_child, right_child)| self.exprs_equal(left_child, right_child)) } + + /// Return the inner classes of this equivalence group. + pub fn into_inner(self) -> Vec { + self.classes + } +} + +impl IntoIterator for EquivalenceGroup { + type Item = EquivalenceClass; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.classes.into_iter() + } } impl Display for EquivalenceGroup { diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index d4c14f7bc8ff..b50633d777f7 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -254,7 +254,7 @@ mod tests { // This new entry is redundant, size shouldn't increase eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr)?; assert_eq!(eq_properties.eq_group().len(), 1); - let eq_groups = &eq_properties.eq_group().classes[0]; + let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 2); assert!(eq_groups.contains(&col_a_expr)); assert!(eq_groups.contains(&col_b_expr)); @@ -263,7 +263,7 @@ mod tests { // however there shouldn't be any new equivalence class eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr)?; assert_eq!(eq_properties.eq_group().len(), 1); - let eq_groups = &eq_properties.eq_group().classes[0]; + let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 3); assert!(eq_groups.contains(&col_a_expr)); assert!(eq_groups.contains(&col_b_expr)); @@ -277,7 +277,7 @@ mod tests { // Hence equivalent class count should decrease from 2 to 1. eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr)?; assert_eq!(eq_properties.eq_group().len(), 1); - let eq_groups = &eq_properties.eq_group().classes[0]; + let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 5); assert!(eq_groups.contains(&col_a_expr)); assert!(eq_groups.contains(&col_b_expr)); diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index c3d458103285..ddc36d193dce 100755 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -338,7 +338,6 @@ impl EquivalenceProperties { let normalized_expr = self.eq_group().normalize_expr(Arc::clone(expr)); let eq_class = self .eq_group - .classes .iter() .find_map(|class| { class @@ -1234,7 +1233,7 @@ impl EquivalenceProperties { // Rewrite equivalence classes according to the new schema: let mut eq_classes = vec![]; - for eq_class in self.eq_group.classes { + for eq_class in self.eq_group { let new_eq_exprs = eq_class .into_vec() .into_iter() @@ -2315,7 +2314,7 @@ mod tests { // At the output a1=a2=a3=a4 assert_eq!(out_properties.eq_group().len(), 1); - let eq_class = &out_properties.eq_group().classes[0]; + let eq_class = out_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_class.len(), 4); assert!(eq_class.contains(col_a1)); assert!(eq_class.contains(col_a2));