Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 65 additions & 12 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ use arrow::array::{builder::StringBuilder, RecordBatch};
use arrow::compute::SortOptions;
use arrow::datatypes::{Schema, SchemaRef};
use datafusion_common::display::ToStringifiedPlan;
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor,
};
use datafusion_common::{
exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema,
ScalarValue,
Expand Down Expand Up @@ -2075,29 +2077,36 @@ fn maybe_fix_physical_column_name(
expr: Result<Arc<dyn PhysicalExpr>>,
input_physical_schema: &SchemaRef,
) -> Result<Arc<dyn PhysicalExpr>> {
if let Ok(e) = &expr {
if let Some(column) = e.as_any().downcast_ref::<Column>() {
let physical_field = input_physical_schema.field(column.index());
let Ok(expr) = expr else { return expr };
expr.transform_down(|node| {
if let Some(column) = node.as_any().downcast_ref::<Column>() {
let idx = column.index();
let physical_field = input_physical_schema.field(idx);
let expr_col_name = column.name();
let physical_name = physical_field.name();

if physical_name != expr_col_name {
if expr_col_name != physical_name {
// handle edge cases where the physical_name contains ':'.
let colon_count = physical_name.matches(':').count();
let mut splits = expr_col_name.match_indices(':');
let split_pos = splits.nth(colon_count);

if let Some((idx, _)) = split_pos {
let base_name = &expr_col_name[..idx];
if let Some((i, _)) = split_pos {
let base_name = &expr_col_name[..i];
if base_name == physical_name {
let updated_column = Column::new(physical_name, column.index());
return Ok(Arc::new(updated_column));
let updated_column = Column::new(physical_name, idx);
return Ok(Transformed::yes(Arc::new(updated_column)));
}
}
}

// If names already match or fix is not possible, just leave it as it is
Ok(Transformed::no(node))
} else {
Ok(Transformed::no(node))
}
}
expr
})
.data()
}

struct OptimizationInvariantChecker<'a> {
Expand Down Expand Up @@ -2201,8 +2210,11 @@ mod tests {
use datafusion_common::{assert_contains, DFSchemaRef, TableReference};
use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_execution::TaskContext;
use datafusion_expr::{col, lit, LogicalPlanBuilder, UserDefinedLogicalNodeCore};
use datafusion_expr::{
col, lit, LogicalPlanBuilder, Operator, UserDefinedLogicalNodeCore,
};
use datafusion_functions_aggregate::expr_fn::sum;
use datafusion_physical_expr::expressions::{BinaryExpr, IsNotNullExpr};
use datafusion_physical_expr::EquivalenceProperties;
use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};

Expand Down Expand Up @@ -2719,6 +2731,47 @@ mod tests {

assert_eq!(col.name(), "metric:avg");
}

#[tokio::test]
async fn test_maybe_fix_nested_column_name_with_colon() {
let schema = Schema::new(vec![Field::new("column", DataType::Int32, false)]);
let schema_ref: SchemaRef = Arc::new(schema);

// Construct the nested expr
let col_expr = Arc::new(Column::new("column:1", 0)) as Arc<dyn PhysicalExpr>;
let is_not_null_expr = Arc::new(IsNotNullExpr::new(col_expr.clone()));

// Create a binary expression and put the column inside
let binary_expr = Arc::new(BinaryExpr::new(
is_not_null_expr.clone(),
Operator::Or,
is_not_null_expr.clone(),
)) as Arc<dyn PhysicalExpr>;

let fixed_expr =
maybe_fix_physical_column_name(Ok(binary_expr), &schema_ref).unwrap();

let bin = fixed_expr
.as_any()
.downcast_ref::<BinaryExpr>()
.expect("Expected BinaryExpr");

// Check that both sides where renamed
for expr in &[bin.left(), bin.right()] {
let is_not_null = expr
.as_any()
.downcast_ref::<IsNotNullExpr>()
.expect("Expected IsNotNull");

let col = is_not_null
.arg()
.as_any()
.downcast_ref::<Column>()
.expect("Expected Column");

assert_eq!(col.name(), "column");
}
}
struct ErrorExtensionPlanner {}

#[async_trait]
Expand Down
10 changes: 9 additions & 1 deletion datafusion/physical-plan/src/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,12 @@ fn union_schema(inputs: &[Arc<dyn ExecutionPlan>]) -> SchemaRef {

let fields = (0..first_schema.fields().len())
.map(|i| {
inputs
// We take the name from the left side of the union to match how names are coerced during logical planning,
// which also uses the left side names.
let base_field = first_schema.field(i).clone();

// Coerce metadata and nullability across all inputs
let merged_field = inputs
.iter()
.enumerate()
.map(|(input_idx, input)| {
Expand All @@ -535,6 +540,9 @@ fn union_schema(inputs: &[Arc<dyn ExecutionPlan>]) -> SchemaRef {
// We can unwrap this because if inputs was empty, this would've already panic'ed when we
// indexed into inputs[0].
.unwrap()
.with_name(base_field.name());

merged_field
})
.collect::<Vec<_>>();

Expand Down
24 changes: 24 additions & 0 deletions datafusion/substrait/tests/cases/consumer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -560,4 +560,28 @@ mod tests {
);
Ok(())
}

#[tokio::test]
async fn test_multiple_unions() -> Result<()> {
let plan_str = test_plan_to_string("multiple_unions.json").await?;
assert_snapshot!(
plan_str,
@r#"
Projection: Utf8("people") AS product_category, Utf8("people")__temp__0 AS product_type, product_key
Union
Projection: Utf8("people"), Utf8("people") AS Utf8("people")__temp__0, sales.product_key
Left Join: sales.product_key = food.@food_id
TableScan: sales
TableScan: food
Union
Projection: people.$f3, people.$f5, people.product_key0
Left Join: people.product_key0 = food.@food_id
TableScan: people
TableScan: food
TableScan: more_products
"#
);

Ok(())
}
}
Loading
Loading