Skip to content

Commit 3860cd3

Browse files
authored
Fix ambiguous reference error in filter plan (#1925)
* Add test that fails with ambiguous reference to column
* Use plan schema rather than combined merged schema in filter optimization
* First try to resolve predicate using plan's schema then fall back to using all schemas
* Revert testing submodule change
1 parent fd17765 commit 3860cd3

File tree

2 files changed

+45
-7
lines changed

2 files changed

+45
-7
lines changed

datafusion/src/optimizer/common_subexpr_eliminate.rs

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,19 @@ fn optimize(plan: &LogicalPlan, execution_props: &ExecutionProps) -> Result<Logi
111111
}))
112112
}
113113
LogicalPlan::Filter(Filter { predicate, input }) => {
114-
let schemas = plan.all_schemas();
115-
let all_schema =
116-
schemas.into_iter().fold(DFSchema::empty(), |mut lhs, rhs| {
117-
lhs.merge(rhs);
118-
lhs
119-
});
120-
let data_type = predicate.get_type(&all_schema)?;
114+
let schema = plan.schema().as_ref().clone();
115+
let data_type = if let Ok(data_type) = predicate.get_type(&schema) {
116+
data_type
117+
} else {
118+
// predicate type could not be resolved in schema, fall back to all schemas
119+
let schemas = plan.all_schemas();
120+
let all_schema =
121+
schemas.into_iter().fold(DFSchema::empty(), |mut lhs, rhs| {
122+
lhs.merge(rhs);
123+
lhs
124+
});
125+
predicate.get_type(&all_schema)?
126+
};
121127

122128
let mut id_array = vec![];
123129
expr_to_identifier(predicate, &mut expr_set, &mut id_array, data_type)?;

datafusion/tests/dataframe.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ use datafusion::error::Result;
2828
use datafusion::execution::context::ExecutionContext;
2929
use datafusion::logical_plan::{col, Expr};
3030
use datafusion::{datasource::MemTable, prelude::JoinType};
31+
use datafusion_expr::lit;
3132

3233
#[tokio::test]
3334
async fn join() -> Result<()> {
@@ -120,3 +121,34 @@ async fn sort_on_unprojected_columns() -> Result<()> {
120121

121122
Ok(())
122123
}
124+
125+
#[tokio::test]
126+
async fn filter_with_alias_overwrite() -> Result<()> {
127+
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
128+
129+
let batch = RecordBatch::try_new(
130+
Arc::new(schema.clone()),
131+
vec![Arc::new(Int32Array::from_slice(&[1, 10, 10, 100]))],
132+
)
133+
.unwrap();
134+
135+
let mut ctx = ExecutionContext::new();
136+
let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]]).unwrap();
137+
ctx.register_table("t", Arc::new(provider)).unwrap();
138+
139+
let df = ctx
140+
.table("t")
141+
.unwrap()
142+
.select(vec![(col("a").eq(lit(10))).alias("a")])
143+
.unwrap()
144+
.filter(col("a"))
145+
.unwrap();
146+
let results = df.collect().await.unwrap();
147+
148+
let expected = vec![
149+
"+------+", "| a    |", "+------+", "| true |", "| true |", "+------+",
150+
];
151+
assert_batches_eq!(expected, &results);
152+
153+
Ok(())
154+
}

0 commit comments

Comments
 (0)