Skip to content

Commit 91a44c1

Browse files
authored
Fix ArrayAgg schema mismatch issue (#8055)
* fix schema Signed-off-by: jayzhan211 <[email protected]> * cleanup Signed-off-by: jayzhan211 <[email protected]> * upd parquet-testing Signed-off-by: jayzhan211 <[email protected]> * avoid parquet file Signed-off-by: jayzhan211 <[email protected]> * reset parquet-testing Signed-off-by: jayzhan211 <[email protected]> * remove file Signed-off-by: jayzhan211 <[email protected]> * fix Signed-off-by: jayzhan211 <[email protected]> * fix test Signed-off-by: jayzhan211 <[email protected]> * cleanup Signed-off-by: jayzhan211 <[email protected]> * rename and upd docstring Signed-off-by: jayzhan211 <[email protected]> --------- Signed-off-by: jayzhan211 <[email protected]>
1 parent 1803b25 commit 91a44c1

File tree

5 files changed

+160
-20
lines changed

5 files changed

+160
-20
lines changed

datafusion/core/src/dataframe/mod.rs

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1340,6 +1340,92 @@ mod tests {
13401340

13411341
use super::*;
13421342

1343+
async fn assert_logical_expr_schema_eq_physical_expr_schema(
1344+
df: DataFrame,
1345+
) -> Result<()> {
1346+
let logical_expr_dfschema = df.schema();
1347+
let logical_expr_schema = SchemaRef::from(logical_expr_dfschema.to_owned());
1348+
let batches = df.collect().await?;
1349+
let physical_expr_schema = batches[0].schema();
1350+
assert_eq!(logical_expr_schema, physical_expr_schema);
1351+
Ok(())
1352+
}
1353+
1354+
#[tokio::test]
1355+
async fn test_array_agg_ord_schema() -> Result<()> {
1356+
let ctx = SessionContext::new();
1357+
1358+
let create_table_query = r#"
1359+
CREATE TABLE test_table (
1360+
"double_field" DOUBLE,
1361+
"string_field" VARCHAR
1362+
) AS VALUES
1363+
(1.0, 'a'),
1364+
(2.0, 'b'),
1365+
(3.0, 'c')
1366+
"#;
1367+
ctx.sql(create_table_query).await?;
1368+
1369+
let query = r#"SELECT
1370+
array_agg("double_field" ORDER BY "string_field") as "double_field",
1371+
array_agg("string_field" ORDER BY "string_field") as "string_field"
1372+
FROM test_table"#;
1373+
1374+
let result = ctx.sql(query).await?;
1375+
assert_logical_expr_schema_eq_physical_expr_schema(result).await?;
1376+
Ok(())
1377+
}
1378+
1379+
#[tokio::test]
1380+
async fn test_array_agg_schema() -> Result<()> {
1381+
let ctx = SessionContext::new();
1382+
1383+
let create_table_query = r#"
1384+
CREATE TABLE test_table (
1385+
"double_field" DOUBLE,
1386+
"string_field" VARCHAR
1387+
) AS VALUES
1388+
(1.0, 'a'),
1389+
(2.0, 'b'),
1390+
(3.0, 'c')
1391+
"#;
1392+
ctx.sql(create_table_query).await?;
1393+
1394+
let query = r#"SELECT
1395+
array_agg("double_field") as "double_field",
1396+
array_agg("string_field") as "string_field"
1397+
FROM test_table"#;
1398+
1399+
let result = ctx.sql(query).await?;
1400+
assert_logical_expr_schema_eq_physical_expr_schema(result).await?;
1401+
Ok(())
1402+
}
1403+
1404+
#[tokio::test]
1405+
async fn test_array_agg_distinct_schema() -> Result<()> {
1406+
let ctx = SessionContext::new();
1407+
1408+
let create_table_query = r#"
1409+
CREATE TABLE test_table (
1410+
"double_field" DOUBLE,
1411+
"string_field" VARCHAR
1412+
) AS VALUES
1413+
(1.0, 'a'),
1414+
(2.0, 'b'),
1415+
(2.0, 'a')
1416+
"#;
1417+
ctx.sql(create_table_query).await?;
1418+
1419+
let query = r#"SELECT
1420+
array_agg(distinct "double_field") as "double_field",
1421+
array_agg(distinct "string_field") as "string_field"
1422+
FROM test_table"#;
1423+
1424+
let result = ctx.sql(query).await?;
1425+
assert_logical_expr_schema_eq_physical_expr_schema(result).await?;
1426+
Ok(())
1427+
}
1428+
13431429
#[tokio::test]
13441430
async fn select_columns() -> Result<()> {
13451431
// build plan using Table API

datafusion/physical-expr/src/aggregate/array_agg.rs

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,14 @@ use std::sync::Arc;
3434
/// ARRAY_AGG aggregate expression
3535
#[derive(Debug)]
3636
pub struct ArrayAgg {
37+
/// Column name
3738
name: String,
39+
/// The DataType for the input expression
3840
input_data_type: DataType,
41+
/// The input expression
3942
expr: Arc<dyn PhysicalExpr>,
43+
/// If the input expression can have NULLs
44+
nullable: bool,
4045
}
4146

4247
impl ArrayAgg {
@@ -45,11 +50,13 @@ impl ArrayAgg {
4550
expr: Arc<dyn PhysicalExpr>,
4651
name: impl Into<String>,
4752
data_type: DataType,
53+
nullable: bool,
4854
) -> Self {
4955
Self {
5056
name: name.into(),
51-
expr,
5257
input_data_type: data_type,
58+
expr,
59+
nullable,
5360
}
5461
}
5562
}
@@ -62,8 +69,9 @@ impl AggregateExpr for ArrayAgg {
6269
fn field(&self) -> Result<Field> {
6370
Ok(Field::new_list(
6471
&self.name,
72+
// This should be the same as return type of AggregateFunction::ArrayAgg
6573
Field::new("item", self.input_data_type.clone(), true),
66-
false,
74+
self.nullable,
6775
))
6876
}
6977

@@ -77,7 +85,7 @@ impl AggregateExpr for ArrayAgg {
7785
Ok(vec![Field::new_list(
7886
format_state_name(&self.name, "array_agg"),
7987
Field::new("item", self.input_data_type.clone(), true),
80-
false,
88+
self.nullable,
8189
)])
8290
}
8391

@@ -184,7 +192,6 @@ mod tests {
184192
use super::*;
185193
use crate::expressions::col;
186194
use crate::expressions::tests::aggregate;
187-
use crate::generic_test_op;
188195
use arrow::array::ArrayRef;
189196
use arrow::array::Int32Array;
190197
use arrow::datatypes::*;
@@ -195,6 +202,30 @@ mod tests {
195202
use datafusion_common::DataFusionError;
196203
use datafusion_common::Result;
197204

205+
macro_rules! test_op {
206+
($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => {
207+
test_op!($ARRAY, $DATATYPE, $OP, $EXPECTED, $EXPECTED.data_type())
208+
};
209+
($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{
210+
let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]);
211+
212+
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?;
213+
214+
let agg = Arc::new(<$OP>::new(
215+
col("a", &schema)?,
216+
"bla".to_string(),
217+
$EXPECTED_DATATYPE,
218+
true,
219+
));
220+
let actual = aggregate(&batch, agg)?;
221+
let expected = ScalarValue::from($EXPECTED);
222+
223+
assert_eq!(expected, actual);
224+
225+
Ok(()) as Result<(), DataFusionError>
226+
}};
227+
}
228+
198229
#[test]
199230
fn array_agg_i32() -> Result<()> {
200231
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
@@ -208,7 +239,7 @@ mod tests {
208239
])]);
209240
let list = ScalarValue::List(Arc::new(list));
210241

211-
generic_test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32)
242+
test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32)
212243
}
213244

214245
#[test]
@@ -264,7 +295,7 @@ mod tests {
264295

265296
let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap();
266297

267-
generic_test_op!(
298+
test_op!(
268299
array,
269300
DataType::List(Arc::new(Field::new_list(
270301
"item",

datafusion/physical-expr/src/aggregate/array_agg_distinct.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ pub struct DistinctArrayAgg {
4040
input_data_type: DataType,
4141
/// The input expression
4242
expr: Arc<dyn PhysicalExpr>,
43+
/// If the input expression can have NULLs
44+
nullable: bool,
4345
}
4446

4547
impl DistinctArrayAgg {
@@ -48,12 +50,14 @@ impl DistinctArrayAgg {
4850
expr: Arc<dyn PhysicalExpr>,
4951
name: impl Into<String>,
5052
input_data_type: DataType,
53+
nullable: bool,
5154
) -> Self {
5255
let name = name.into();
5356
Self {
5457
name,
55-
expr,
5658
input_data_type,
59+
expr,
60+
nullable,
5761
}
5862
}
5963
}
@@ -67,8 +71,9 @@ impl AggregateExpr for DistinctArrayAgg {
6771
fn field(&self) -> Result<Field> {
6872
Ok(Field::new_list(
6973
&self.name,
74+
// This should be the same as return type of AggregateFunction::ArrayAgg
7075
Field::new("item", self.input_data_type.clone(), true),
71-
false,
76+
self.nullable,
7277
))
7378
}
7479

@@ -82,7 +87,7 @@ impl AggregateExpr for DistinctArrayAgg {
8287
Ok(vec![Field::new_list(
8388
format_state_name(&self.name, "distinct_array_agg"),
8489
Field::new("item", self.input_data_type.clone(), true),
85-
false,
90+
self.nullable,
8691
)])
8792
}
8893

@@ -238,6 +243,7 @@ mod tests {
238243
col("a", &schema)?,
239244
"bla".to_string(),
240245
datatype,
246+
true,
241247
));
242248
let actual = aggregate(&batch, agg)?;
243249

@@ -255,6 +261,7 @@ mod tests {
255261
col("a", &schema)?,
256262
"bla".to_string(),
257263
datatype,
264+
true,
258265
));
259266

260267
let mut accum1 = agg.create_accumulator()?;

datafusion/physical-expr/src/aggregate/array_agg_ordered.rs

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,17 @@ use itertools::izip;
4848
/// and that can merge aggregations from multiple partitions.
4949
#[derive(Debug)]
5050
pub struct OrderSensitiveArrayAgg {
51+
/// Column name
5152
name: String,
53+
/// The DataType for the input expression
5254
input_data_type: DataType,
53-
order_by_data_types: Vec<DataType>,
55+
/// The input expression
5456
expr: Arc<dyn PhysicalExpr>,
57+
/// If the input expression can have NULLs
58+
nullable: bool,
59+
/// Ordering data types
60+
order_by_data_types: Vec<DataType>,
61+
/// Ordering requirement
5562
ordering_req: LexOrdering,
5663
}
5764

@@ -61,13 +68,15 @@ impl OrderSensitiveArrayAgg {
6168
expr: Arc<dyn PhysicalExpr>,
6269
name: impl Into<String>,
6370
input_data_type: DataType,
71+
nullable: bool,
6472
order_by_data_types: Vec<DataType>,
6573
ordering_req: LexOrdering,
6674
) -> Self {
6775
Self {
6876
name: name.into(),
69-
expr,
7077
input_data_type,
78+
expr,
79+
nullable,
7180
order_by_data_types,
7281
ordering_req,
7382
}
@@ -82,8 +91,9 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
8291
fn field(&self) -> Result<Field> {
8392
Ok(Field::new_list(
8493
&self.name,
94+
// This should be the same as return type of AggregateFunction::ArrayAgg
8595
Field::new("item", self.input_data_type.clone(), true),
86-
false,
96+
self.nullable,
8797
))
8898
}
8999

@@ -99,13 +109,13 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
99109
let mut fields = vec![Field::new_list(
100110
format_state_name(&self.name, "array_agg"),
101111
Field::new("item", self.input_data_type.clone(), true),
102-
false,
112+
self.nullable, // This should be the same as field()
103113
)];
104114
let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types);
105115
fields.push(Field::new_list(
106116
format_state_name(&self.name, "array_agg_orderings"),
107117
Field::new("item", DataType::Struct(Fields::from(orderings)), true),
108-
false,
118+
self.nullable,
109119
));
110120
Ok(fields)
111121
}

datafusion/physical-expr/src/aggregate/build_in.rs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,13 +114,16 @@ pub fn create_aggregate_expr(
114114
),
115115
(AggregateFunction::ArrayAgg, false) => {
116116
let expr = input_phy_exprs[0].clone();
117+
let nullable = expr.nullable(input_schema)?;
118+
117119
if ordering_req.is_empty() {
118-
Arc::new(expressions::ArrayAgg::new(expr, name, data_type))
120+
Arc::new(expressions::ArrayAgg::new(expr, name, data_type, nullable))
119121
} else {
120122
Arc::new(expressions::OrderSensitiveArrayAgg::new(
121123
expr,
122124
name,
123125
data_type,
126+
nullable,
124127
ordering_types,
125128
ordering_req.to_vec(),
126129
))
@@ -132,10 +135,13 @@ pub fn create_aggregate_expr(
132135
"ARRAY_AGG(DISTINCT ORDER BY a ASC) order-sensitive aggregations are not available"
133136
);
134137
}
138+
let expr = input_phy_exprs[0].clone();
139+
let is_expr_nullable = expr.nullable(input_schema)?;
135140
Arc::new(expressions::DistinctArrayAgg::new(
136-
input_phy_exprs[0].clone(),
141+
expr,
137142
name,
138143
data_type,
144+
is_expr_nullable,
139145
))
140146
}
141147
(AggregateFunction::Min, _) => Arc::new(expressions::Min::new(
@@ -432,8 +438,8 @@ mod tests {
432438
assert_eq!(
433439
Field::new_list(
434440
"c1",
435-
Field::new("item", data_type.clone(), true,),
436-
false,
441+
Field::new("item", data_type.clone(), true),
442+
true,
437443
),
438444
result_agg_phy_exprs.field().unwrap()
439445
);
@@ -471,8 +477,8 @@ mod tests {
471477
assert_eq!(
472478
Field::new_list(
473479
"c1",
474-
Field::new("item", data_type.clone(), true,),
475-
false,
480+
Field::new("item", data_type.clone(), true),
481+
true,
476482
),
477483
result_agg_phy_exprs.field().unwrap()
478484
);

0 commit comments

Comments
 (0)