Skip to content

Fix ArrayAgg schema mismatch issue #8055

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1340,6 +1340,92 @@ mod tests {

use super::*;

async fn assert_logical_expr_schema_eq_physical_expr_schema(
df: DataFrame,
) -> Result<()> {
let logical_expr_dfschema = df.schema();
let logical_expr_schema = SchemaRef::from(logical_expr_dfschema.to_owned());
let batches = df.collect().await?;
let physical_expr_schema = batches[0].schema();
assert_eq!(logical_expr_schema, physical_expr_schema);
Ok(())
}

#[tokio::test]
async fn test_array_agg_ord_schema() -> Result<()> {
let ctx = SessionContext::new();

let create_table_query = r#"
CREATE TABLE test_table (
"double_field" DOUBLE,
"string_field" VARCHAR
) AS VALUES
(1.0, 'a'),
(2.0, 'b'),
(3.0, 'c')
"#;
ctx.sql(create_table_query).await?;

let query = r#"SELECT
array_agg("double_field" ORDER BY "string_field") as "double_field",
array_agg("string_field" ORDER BY "string_field") as "string_field"
FROM test_table"#;

let result = ctx.sql(query).await?;
assert_logical_expr_schema_eq_physical_expr_schema(result).await?;
Ok(())
}

#[tokio::test]
async fn test_array_agg_schema() -> Result<()> {
let ctx = SessionContext::new();

let create_table_query = r#"
CREATE TABLE test_table (
"double_field" DOUBLE,
"string_field" VARCHAR
) AS VALUES
(1.0, 'a'),
(2.0, 'b'),
(3.0, 'c')
"#;
ctx.sql(create_table_query).await?;

let query = r#"SELECT
array_agg("double_field") as "double_field",
array_agg("string_field") as "string_field"
FROM test_table"#;

let result = ctx.sql(query).await?;
assert_logical_expr_schema_eq_physical_expr_schema(result).await?;
Ok(())
}

#[tokio::test]
async fn test_array_agg_distinct_schema() -> Result<()> {
let ctx = SessionContext::new();

let create_table_query = r#"
CREATE TABLE test_table (
"double_field" DOUBLE,
"string_field" VARCHAR
) AS VALUES
(1.0, 'a'),
(2.0, 'b'),
(2.0, 'a')
"#;
ctx.sql(create_table_query).await?;

let query = r#"SELECT
array_agg(distinct "double_field") as "double_field",
array_agg(distinct "string_field") as "string_field"
FROM test_table"#;

let result = ctx.sql(query).await?;
assert_logical_expr_schema_eq_physical_expr_schema(result).await?;
Ok(())
}

#[tokio::test]
async fn select_columns() -> Result<()> {
// build plan using Table API
Expand Down
37 changes: 32 additions & 5 deletions datafusion/physical-expr/src/aggregate/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ pub struct ArrayAgg {
name: String,
input_data_type: DataType,
expr: Arc<dyn PhysicalExpr>,
is_expr_nullable: bool,
}

impl ArrayAgg {
Expand All @@ -45,11 +46,13 @@ impl ArrayAgg {
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
data_type: DataType,
is_expr_nullable: bool,
) -> Self {
Self {
name: name.into(),
expr,
input_data_type: data_type,
is_expr_nullable,
}
}
}
Expand All @@ -62,8 +65,9 @@ impl AggregateExpr for ArrayAgg {
fn field(&self) -> Result<Field> {
Ok(Field::new_list(
&self.name,
// This should be the same as return type of AggregateFunction::ArrayAgg
Field::new("item", self.input_data_type.clone(), true),
false,
self.is_expr_nullable,
))
}

Expand All @@ -77,7 +81,7 @@ impl AggregateExpr for ArrayAgg {
Ok(vec![Field::new_list(
format_state_name(&self.name, "array_agg"),
Field::new("item", self.input_data_type.clone(), true),
false,
self.is_expr_nullable,
)])
}

Expand Down Expand Up @@ -184,7 +188,6 @@ mod tests {
use super::*;
use crate::expressions::col;
use crate::expressions::tests::aggregate;
use crate::generic_test_op;
use arrow::array::ArrayRef;
use arrow::array::Int32Array;
use arrow::datatypes::*;
Expand All @@ -195,6 +198,30 @@ mod tests {
use datafusion_common::DataFusionError;
use datafusion_common::Result;

macro_rules! test_op {
($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => {
test_op!($ARRAY, $DATATYPE, $OP, $EXPECTED, $EXPECTED.data_type())
};
($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{
let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]);

let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?;

let agg = Arc::new(<$OP>::new(
col("a", &schema)?,
"bla".to_string(),
$EXPECTED_DATATYPE,
true,
));
let actual = aggregate(&batch, agg)?;
let expected = ScalarValue::from($EXPECTED);

assert_eq!(expected, actual);

Ok(()) as Result<(), DataFusionError>
}};
}

#[test]
fn array_agg_i32() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
Expand All @@ -208,7 +235,7 @@ mod tests {
])]);
let list = ScalarValue::List(Arc::new(list));

generic_test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32)
test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32)
}

#[test]
Expand Down Expand Up @@ -264,7 +291,7 @@ mod tests {

let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap();

generic_test_op!(
test_op!(
array,
DataType::List(Arc::new(Field::new_list(
"item",
Expand Down
11 changes: 9 additions & 2 deletions datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ pub struct DistinctArrayAgg {
input_data_type: DataType,
/// The input expression
expr: Arc<dyn PhysicalExpr>,
/// Whether the input expression can produce NULL values
is_expr_nullable: bool,
}

impl DistinctArrayAgg {
Expand All @@ -48,12 +50,14 @@ impl DistinctArrayAgg {
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
input_data_type: DataType,
is_expr_nullable: bool,
) -> Self {
let name = name.into();
Self {
name,
expr,
input_data_type,
is_expr_nullable,
}
}
}
Expand All @@ -67,8 +71,9 @@ impl AggregateExpr for DistinctArrayAgg {
fn field(&self) -> Result<Field> {
Ok(Field::new_list(
&self.name,
// This should be the same as return type of AggregateFunction::ArrayAgg
Field::new("item", self.input_data_type.clone(), true),
false,
self.is_expr_nullable,
))
}

Expand All @@ -82,7 +87,7 @@ impl AggregateExpr for DistinctArrayAgg {
Ok(vec![Field::new_list(
format_state_name(&self.name, "distinct_array_agg"),
Field::new("item", self.input_data_type.clone(), true),
false,
self.is_expr_nullable,
)])
}

Expand Down Expand Up @@ -238,6 +243,7 @@ mod tests {
col("a", &schema)?,
"bla".to_string(),
datatype,
true,
));
let actual = aggregate(&batch, agg)?;

Expand All @@ -255,6 +261,7 @@ mod tests {
col("a", &schema)?,
"bla".to_string(),
datatype,
true,
));

let mut accum1 = agg.create_accumulator()?;
Expand Down
10 changes: 7 additions & 3 deletions datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ use itertools::izip;
pub struct OrderSensitiveArrayAgg {
name: String,
input_data_type: DataType,
is_expr_nullable: bool,
order_by_data_types: Vec<DataType>,
expr: Arc<dyn PhysicalExpr>,
ordering_req: LexOrdering,
Expand All @@ -61,13 +62,15 @@ impl OrderSensitiveArrayAgg {
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
input_data_type: DataType,
is_expr_nullable: bool,
order_by_data_types: Vec<DataType>,
ordering_req: LexOrdering,
) -> Self {
Self {
name: name.into(),
expr,
input_data_type,
is_expr_nullable,
order_by_data_types,
ordering_req,
}
Expand All @@ -82,8 +85,9 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
fn field(&self) -> Result<Field> {
Ok(Field::new_list(
&self.name,
// This should be the same as return type of AggregateFunction::ArrayAgg
Field::new("item", self.input_data_type.clone(), true),
false,
self.is_expr_nullable,
))
}

Expand All @@ -99,13 +103,13 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
let mut fields = vec![Field::new_list(
format_state_name(&self.name, "array_agg"),
Field::new("item", self.input_data_type.clone(), true),
false,
self.is_expr_nullable, // This should be the same as field()
)];
let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types);
fields.push(Field::new_list(
format_state_name(&self.name, "array_agg_orderings"),
Field::new("item", DataType::Struct(Fields::from(orderings)), true),
false,
self.is_expr_nullable,
));
Ok(fields)
}
Expand Down
23 changes: 17 additions & 6 deletions datafusion/physical-expr/src/aggregate/build_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,21 @@ pub fn create_aggregate_expr(
),
(AggregateFunction::ArrayAgg, false) => {
let expr = input_phy_exprs[0].clone();
let is_expr_nullable = expr.nullable(input_schema)?;

if ordering_req.is_empty() {
Arc::new(expressions::ArrayAgg::new(expr, name, data_type))
Arc::new(expressions::ArrayAgg::new(
expr,
name,
data_type,
is_expr_nullable,
))
} else {
Arc::new(expressions::OrderSensitiveArrayAgg::new(
expr,
name,
data_type,
is_expr_nullable,
ordering_types,
ordering_req.to_vec(),
))
Expand All @@ -132,10 +140,13 @@ pub fn create_aggregate_expr(
"ARRAY_AGG(DISTINCT ORDER BY a ASC) order-sensitive aggregations are not available"
);
}
let expr = input_phy_exprs[0].clone();
let is_expr_nullable = expr.nullable(input_schema)?;
Arc::new(expressions::DistinctArrayAgg::new(
input_phy_exprs[0].clone(),
expr,
name,
data_type,
is_expr_nullable,
))
}
(AggregateFunction::Min, _) => Arc::new(expressions::Min::new(
Expand Down Expand Up @@ -432,8 +443,8 @@ mod tests {
assert_eq!(
Field::new_list(
"c1",
Field::new("item", data_type.clone(), true,),
false,
Field::new("item", data_type.clone(), true),
true,
),
result_agg_phy_exprs.field().unwrap()
);
Expand Down Expand Up @@ -471,8 +482,8 @@ mod tests {
assert_eq!(
Field::new_list(
"c1",
Field::new("item", data_type.clone(), true,),
false,
Field::new("item", data_type.clone(), true),
true,
),
result_agg_phy_exprs.field().unwrap()
);
Expand Down