Skip to content

Commit 938bed7

Browse files
authored
Infer placeholder datatype for Expr::InSubquery (apache#80)
* infer placeholder datatype for InSubquery * update comment * only infer subquery if exactly 1 field
1 parent 0fbf212 commit 938bed7

File tree

1 file changed

+102
-1
lines changed

1 file changed

+102
-1
lines changed

datafusion/expr/src/expr.rs

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1623,6 +1623,26 @@ impl Expr {
16231623
| Expr::SimilarTo(Like { expr, pattern, .. }) => {
16241624
rewrite_placeholder(pattern.as_mut(), expr.as_ref(), schema)?;
16251625
}
1626+
Expr::InSubquery(InSubquery {
1627+
expr,
1628+
subquery,
1629+
negated: _,
1630+
}) => {
1631+
let subquery_schema = subquery.subquery.schema();
1632+
let fields = subquery_schema.fields();
1633+
1634+
// only supports subquery with exactly 1 field
1635+
if let [first_field] = &fields[..] {
1636+
rewrite_placeholder(
1637+
expr.as_mut(),
1638+
&Expr::Column(Column {
1639+
relation: None,
1640+
name: first_field.name().clone(),
1641+
}),
1642+
schema,
1643+
)?;
1644+
}
1645+
}
16261646
Expr::Placeholder(_) => {
16271647
has_placeholder = true;
16281648
}
@@ -2801,7 +2821,8 @@ mod test {
28012821
use crate::expr_fn::col;
28022822
use crate::{
28032823
case, lit, qualified_wildcard, wildcard, wildcard_with_options, ColumnarValue,
2804-
ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Volatility,
2824+
LogicalPlan, LogicalTableSource, Projection, ScalarFunctionArgs, ScalarUDF,
2825+
ScalarUDFImpl, TableScan, Volatility,
28052826
};
28062827
use arrow::datatypes::{Field, Schema};
28072828
use sqlparser::ast;
@@ -2863,6 +2884,86 @@ mod test {
28632884
}
28642885
}
28652886

2887+
#[test]
2888+
fn infer_placeholder_in_subquery() -> Result<()> {
2889+
// Schema for my_table: A (Int32), B (Int32)
2890+
let schema = Arc::new(Schema::new(vec![
2891+
Field::new("A", DataType::Int32, true),
2892+
Field::new("B", DataType::Int32, true),
2893+
]));
2894+
2895+
let source = Arc::new(LogicalTableSource::new(schema.clone()));
2896+
2897+
// Simulate: SELECT * FROM my_table WHERE $1 IN (SELECT A FROM my_table WHERE B > 3);
2898+
let placeholder = Expr::Placeholder(Placeholder {
2899+
id: "$1".to_string(),
2900+
data_type: None,
2901+
});
2902+
2903+
// Subquery: SELECT A FROM my_table WHERE B > 3
2904+
let subquery_filter = Expr::BinaryExpr(BinaryExpr {
2905+
left: Box::new(col("B")),
2906+
op: Operator::Gt,
2907+
right: Box::new(Expr::Literal(ScalarValue::Int32(Some(3)))),
2908+
});
2909+
2910+
let subquery_scan = LogicalPlan::TableScan(TableScan {
2911+
table_name: TableReference::from("my_table"),
2912+
source,
2913+
projected_schema: Arc::new(DFSchema::try_from(schema.clone())?),
2914+
projection: None,
2915+
filters: vec![subquery_filter.clone()],
2916+
fetch: None,
2917+
});
2918+
2919+
let projected_fields = vec![Field::new("A", DataType::Int32, true)];
2920+
let projected_schema = Arc::new(DFSchema::from_unqualified_fields(
2921+
projected_fields.into(),
2922+
Default::default(),
2923+
)?);
2924+
2925+
let subquery = Subquery {
2926+
subquery: Arc::new(LogicalPlan::Projection(Projection {
2927+
expr: vec![col("A")],
2928+
input: Arc::new(subquery_scan),
2929+
schema: projected_schema,
2930+
})),
2931+
outer_ref_columns: vec![],
2932+
};
2933+
2934+
let in_subquery = Expr::InSubquery(InSubquery {
2935+
expr: Box::new(placeholder),
2936+
subquery,
2937+
negated: false,
2938+
});
2939+
2940+
let df_schema = DFSchema::try_from(schema)?;
2941+
2942+
let (inferred_expr, contains_placeholder) =
2943+
in_subquery.infer_placeholder_types(&df_schema)?;
2944+
2945+
assert!(
2946+
contains_placeholder,
2947+
"Expression should contain a placeholder"
2948+
);
2949+
2950+
match inferred_expr {
2951+
Expr::InSubquery(in_subquery) => match *in_subquery.expr {
2952+
Expr::Placeholder(placeholder) => {
2953+
assert_eq!(
2954+
placeholder.data_type,
2955+
Some(DataType::Int32),
2956+
"Placeholder $1 should infer Int32"
2957+
);
2958+
}
2959+
_ => panic!("Expected Placeholder expression in InSubquery"),
2960+
},
2961+
_ => panic!("Expected InSubquery expression"),
2962+
}
2963+
2964+
Ok(())
2965+
}
2966+
28662967
#[test]
28672968
fn infer_placeholder_like_and_similar_to() {
28682969
// name LIKE $1

0 commit comments

Comments
 (0)