From 1e3e9cb462bf03033fafcde55536dc276e771b89 Mon Sep 17 00:00:00 2001 From: LiaCastaneda Date: Mon, 19 May 2025 10:37:05 +0200 Subject: [PATCH 1/6] Fix union schema name coercion --- datafusion/physical-plan/src/union.rs | 10 +- .../tests/cases/consumer_integration.rs | 22 ++ .../testdata/test_plans/multiple_unions.json | 328 ++++++++++++++++++ 3 files changed, 359 insertions(+), 1 deletion(-) create mode 100644 datafusion/substrait/tests/testdata/test_plans/multiple_unions.json diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 2b666093f29e0..f2f9e7ce76c40 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -513,7 +513,12 @@ fn union_schema(inputs: &[Arc]) -> SchemaRef { let fields = (0..first_schema.fields().len()) .map(|i| { - inputs + // We take the name from the left side of the union to match how names are coerced during logical planning, + // which also uses the left side names. + let base_field = first_schema.field(i).clone(); + + // Coerce metadata and nullability across all inputs + let merged_field = inputs .iter() .enumerate() .map(|(input_idx, input)| { @@ -535,6 +540,9 @@ fn union_schema(inputs: &[Arc]) -> SchemaRef { // We can unwrap this because if inputs was empty, this would've already panic'ed when we // indexed into inputs[0]. 
.unwrap() + .with_name(base_field.name()); + + merged_field }) .collect::>(); diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index bdeeeb585c0cb..4f75503bd3b76 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -560,4 +560,26 @@ mod tests { ); Ok(()) } + + #[tokio::test] + async fn test_multiple_unions() -> Result<()> { + let plan_str = test_plan_to_string("multiple_unions.json").await?; + assert_snapshot!( + plan_str, + @"Projection: Utf8(\"people\") AS product_category, Utf8(\"people\")__temp__0 AS product_type, product_key\ + \n Union\ + \n Projection: Utf8(\"people\"), Utf8(\"people\") AS Utf8(\"people\")__temp__0, sales.product_key\ + \n Left Join: sales.product_key = food.@food_id\ + \n TableScan: sales\ + \n TableScan: food\ + \n Union\ + \n Projection: people.$f3, people.$f5, people.product_key0\ + \n Left Join: people.product_key0 = food.@food_id\ + \n TableScan: people\ + \n TableScan: food\ + \n TableScan: more_products" + ); + + Ok(()) + } } diff --git a/datafusion/substrait/tests/testdata/test_plans/multiple_unions.json b/datafusion/substrait/tests/testdata/test_plans/multiple_unions.json new file mode 100644 index 0000000000000..8b82d6eec7552 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/multiple_unions.json @@ -0,0 +1,328 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "equal:any_any" + } + }], + "relations": [{ + "root": { + "input": { + "set": { + "common": { + "direct": { + } + }, + "inputs": [{ + "project": { + "common": { + "emit": { + "outputMapping": [2, 3, 4] + } + }, + "input": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + 
"baseSchema": { + "names": ["product_key"], + "struct": { + "types": [{ + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "sales" + ] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["@food_id"], + "struct": { + "types": [{ + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "food" + ] + } + } + }, + "expression": { + "scalarFunction": { + "functionReference": 0, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }] + } + }, + "type": "JOIN_TYPE_LEFT" + } + }, + "expressions": [{ + "literal": { + "string": "people" + } + }, { + "literal": { + "string": "people" + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }] + } + }, { + "set": { + "common": { + "direct": { + } + }, + "inputs": [{ + "project": { + "common": { + "emit": { + "outputMapping": [4, 5, 6] + } + }, + "input": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["$f3", "$f5", "product_key0"], + "struct": { + "types": [{ + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "people" + ] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + 
"names": ["@food_id"], + "struct": { + "types": [{ + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "food" + ] + } + + } + }, + "expression": { + "scalarFunction": { + "functionReference": 0, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + } + }] + } + }, + "type": "JOIN_TYPE_LEFT" + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }] + } + }, { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["$f1000", "$f2000", "more_products_key0000"], + "struct": { + "types": [{ + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "more_products" + ] + } + + } + }], + "op": "SET_OP_UNION_ALL" + } + }], + "op": "SET_OP_UNION_ALL" + } + }, + "names": ["product_category", "product_type", "product_key"] + } + }] +} \ No newline at end of file From 9c3b201a146607013a48abf8569390a29da005dc Mon Sep 17 00:00:00 2001 From: LiaCastaneda Date: Mon, 19 May 2025 15:23:54 +0200 Subject: [PATCH 2/6] Address renaming for columns that are not in the top level as well --- datafusion/core/src/physical_planner.rs | 52 
+++++++++++++++---------- 1 file changed, 31 insertions(+), 21 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 53199294709a4..4fd92e81e7ff9 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -62,7 +62,9 @@ use arrow::array::{builder::StringBuilder, RecordBatch}; use arrow::compute::SortOptions; use arrow::datatypes::{Schema, SchemaRef}; use datafusion_common::display::ToStringifiedPlan; -use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor, +}; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, ScalarValue, @@ -2069,29 +2071,37 @@ fn maybe_fix_physical_column_name( expr: Result>, input_physical_schema: &SchemaRef, ) -> Result> { - if let Ok(e) = &expr { - if let Some(column) = e.as_any().downcast_ref::() { - let physical_field = input_physical_schema.field(column.index()); - let expr_col_name = column.name(); - let physical_name = physical_field.name(); - - if physical_name != expr_col_name { - // handle edge cases where the physical_name contains ':'. 
- let colon_count = physical_name.matches(':').count(); - let mut splits = expr_col_name.match_indices(':'); - let split_pos = splits.nth(colon_count); - - if let Some((idx, _)) = split_pos { - let base_name = &expr_col_name[..idx]; - if base_name == physical_name { - let updated_column = Column::new(physical_name, column.index()); - return Ok(Arc::new(updated_column)); + expr.and_then(|e| { + e.transform_down(|node| { + if let Some(column) = node.as_any().downcast_ref::<Column>() { + let idx = column.index(); + let physical_field = input_physical_schema.field(idx); + let expr_col_name = column.name(); + let physical_name = physical_field.name(); + + if expr_col_name != physical_name { + // handle edge cases where the physical_name contains ':'. + let colon_count = physical_name.matches(':').count(); + let mut splits = expr_col_name.match_indices(':'); + let split_pos = splits.nth(colon_count); + + if let Some((i, _)) = split_pos { + let base_name = &expr_col_name[..i]; + if base_name == physical_name { + let updated_column = Column::new(physical_name, idx); + return Ok(Transformed::yes(Arc::new(updated_column))); + } } } + + // If names already match or fix is not possible, just leave it as it is + Ok(Transformed::no(node)) + } else { + Ok(Transformed::no(node)) } - } - } - expr + }) + .data() + }) } struct OptimizationInvariantChecker<'a> { From f7560089565a2f8c1a47c72eaa5717a191f797cc Mon Sep 17 00:00:00 2001 From: LiaCastaneda Date: Mon, 19 May 2025 16:48:46 +0200 Subject: [PATCH 3/6] Add unit test --- datafusion/core/src/physical_planner.rs | 48 ++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 4fd92e81e7ff9..dce9cbe886fb9 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -2205,8 +2205,11 @@ mod tests { use datafusion_common::{assert_contains, DFSchemaRef, TableReference}; use
datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; - use datafusion_expr::{col, lit, LogicalPlanBuilder, UserDefinedLogicalNodeCore}; + use datafusion_expr::{ + col, lit, LogicalPlanBuilder, Operator, UserDefinedLogicalNodeCore, + }; use datafusion_functions_aggregate::expr_fn::sum; + use datafusion_physical_expr::expressions::{BinaryExpr, IsNotNullExpr}; use datafusion_physical_expr::EquivalenceProperties; use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; @@ -2723,6 +2726,49 @@ mod tests { assert_eq!(col.name(), "metric:avg"); } + + #[tokio::test] + async fn test_maybe_fix_nested_column_name_with_colon() { + let schema = + Schema::new(vec![Field::new("column", DataType::Int32, false)]); + let schema_ref: SchemaRef = Arc::new(schema); + + // Construct the nested expr + let col_expr = + Arc::new(Column::new("column:1", 0)) as Arc<dyn PhysicalExpr>; + let is_not_null_expr = Arc::new(IsNotNullExpr::new(col_expr.clone())); + + // Create a binary expression and put the column inside + let binary_expr = Arc::new(BinaryExpr::new( + is_not_null_expr.clone(), + Operator::Or, + is_not_null_expr.clone(), + )) as Arc<dyn PhysicalExpr>; + + let fixed_expr = + maybe_fix_physical_column_name(Ok(binary_expr), &schema_ref).unwrap(); + + let bin = fixed_expr + .as_any() + .downcast_ref::<BinaryExpr>() + .expect("Expected BinaryExpr"); + + // Check that both sides were renamed + for expr in &[bin.left(), bin.right()] { + let is_not_null = expr + .as_any() + .downcast_ref::<IsNotNullExpr>() + .expect("Expected IsNotNull"); + + let col = is_not_null + .arg() + .as_any() + .downcast_ref::<Column>() + .expect("Expected Column"); + + assert_eq!(col.name(), "column"); + } + } struct ErrorExtensionPlanner {} #[async_trait] From 12cfeba5900ef247e6f09cdb005e91a25be89358 Mon Sep 17 00:00:00 2001 From: LiaCastaneda Date: Mon, 19 May 2025 17:26:24 +0200 Subject: [PATCH 4/6] Format --- datafusion/core/src/physical_planner.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git
a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index dce9cbe886fb9..c0269f43a11fc 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -2729,13 +2729,11 @@ mod tests { #[tokio::test] async fn test_maybe_fix_nested_column_name_with_colon() { - let schema = - Schema::new(vec![Field::new("column", DataType::Int32, false)]); + let schema = Schema::new(vec![Field::new("column", DataType::Int32, false)]); let schema_ref: SchemaRef = Arc::new(schema); // Construct the nested expr - let col_expr = - Arc::new(Column::new("column:1", 0)) as Arc<dyn PhysicalExpr>; + let col_expr = Arc::new(Column::new("column:1", 0)) as Arc<dyn PhysicalExpr>; let is_not_null_expr = Arc::new(IsNotNullExpr::new(col_expr.clone())); // Create a binary expression and put the column inside From e69a28e4689227efba9c8f35e6c15fabbbd868f5 Mon Sep 17 00:00:00 2001 From: LiaCastaneda Date: Mon, 19 May 2025 18:28:57 +0200 Subject: [PATCH 5/6] Use insta tests properly --- .../tests/cases/consumer_integration.rs | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index f86396d905f98..4a121e41d27e7 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -566,18 +566,20 @@ mod tests { let plan_str = test_plan_to_string("multiple_unions.json").await?; assert_snapshot!( plan_str, - @"Projection: Utf8(\"people\") AS product_category, Utf8(\"people\")__temp__0 AS product_type, product_key\ - \n Union\ - \n Projection: Utf8(\"people\"), Utf8(\"people\") AS Utf8(\"people\")__temp__0, sales.product_key\ - \n Left Join: sales.product_key = food.@food_id\ - \n TableScan: sales\ - \n TableScan: food\ - \n Union\ - \n Projection: people.$f3, people.$f5, people.product_key0\ - \n Left Join: people.product_key0 = food.@food_id\ - \n TableScan:
people\ - \n TableScan: food\ - \n TableScan: more_products" + @r#" + Projection: Utf8("people") AS product_category, Utf8("people")__temp__0 AS product_type, product_key + Union + Projection: Utf8("people"), Utf8("people") AS Utf8("people")__temp__0, sales.product_key + Left Join: sales.product_key = food.@food_id + TableScan: sales + TableScan: food + Union + Projection: people.$f3, people.$f5, people.product_key0 + Left Join: people.product_key0 = food.@food_id + TableScan: people + TableScan: food + TableScan: more_products + "# ); Ok(()) From 8d627db46f3d3802b38f0c81f009f4c1121275ca Mon Sep 17 00:00:00 2001 From: LiaCastaneda Date: Tue, 20 May 2025 18:49:06 +0200 Subject: [PATCH 6/6] Address review - comment + minor simplification change --- datafusion/core/src/physical_planner.rs | 53 ++++++++++++------------- 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 5930bb6eb48e7..2588b5816780f 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -2069,37 +2069,36 @@ fn maybe_fix_physical_column_name( expr: Result>, input_physical_schema: &SchemaRef, ) -> Result> { - expr.and_then(|e| { - e.transform_down(|node| { - if let Some(column) = node.as_any().downcast_ref::() { - let idx = column.index(); - let physical_field = input_physical_schema.field(idx); - let expr_col_name = column.name(); - let physical_name = physical_field.name(); - - if expr_col_name != physical_name { - // handle edge cases where the physical_name contains ':'. 
- let colon_count = physical_name.matches(':').count(); - let mut splits = expr_col_name.match_indices(':'); - let split_pos = splits.nth(colon_count); - - if let Some((i, _)) = split_pos { - let base_name = &expr_col_name[..i]; - if base_name == physical_name { - let updated_column = Column::new(physical_name, idx); - return Ok(Transformed::yes(Arc::new(updated_column))); - } + let Ok(expr) = expr else { return expr }; + expr.transform_down(|node| { + if let Some(column) = node.as_any().downcast_ref::<Column>() { + let idx = column.index(); + let physical_field = input_physical_schema.field(idx); + let expr_col_name = column.name(); + let physical_name = physical_field.name(); + + if expr_col_name != physical_name { + // handle edge cases where the physical_name contains ':'. + let colon_count = physical_name.matches(':').count(); + let mut splits = expr_col_name.match_indices(':'); + let split_pos = splits.nth(colon_count); + + if let Some((i, _)) = split_pos { + let base_name = &expr_col_name[..i]; + if base_name == physical_name { + let updated_column = Column::new(physical_name, idx); + return Ok(Transformed::yes(Arc::new(updated_column))); } } - - // If names already match or fix is not possible, just leave it as it is - Ok(Transformed::no(node)) - } else { - Ok(Transformed::no(node)) } - }) - .data() + + // If names already match or fix is not possible, just leave it as it is + Ok(Transformed::no(node)) + } else { + Ok(Transformed::no(node)) + } }) + .data() } struct OptimizationInvariantChecker<'a> {