Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
use crate::logical_plan::consumer::{from_substrait_agg_func, from_substrait_sorts};
use crate::logical_plan::consumer::{NameTracker, SubstraitConsumer};
use datafusion::common::{not_impl_err, DFSchemaRef};
use datafusion::logical_expr::{Expr, GroupingSet, LogicalPlan, LogicalPlanBuilder};
use datafusion::logical_expr::{
Aggregate, Expr, GroupingSet, LogicalPlan, LogicalPlanBuilder,
};
use substrait::proto::aggregate_function::AggregationInvocation;
use substrait::proto::aggregate_rel::Grouping;
use substrait::proto::AggregateRel;
Expand Down Expand Up @@ -116,12 +118,49 @@ pub async fn from_aggregate_rel(
.map(|e| name_tracker.get_uniquely_named_expr(e.clone()))
.collect::<Result<Vec<Expr>, _>>()?;

input.aggregate(group_exprs, aggr_exprs)?.build()
let plan = input.aggregate(group_exprs, aggr_exprs)?.build()?;
reorder_grouping_set_output(plan)
} else {
not_impl_err!("Aggregate without an input is not valid")
}
}

/// Reorders the output of grouping-set aggregates so the column layout matches the Substrait
/// specification. DataFusion's [`Aggregate::output_expressions`] produces
/// `[grouping keys..., __grouping_id, measures...]`, whereas Substrait requires
/// `[grouping keys..., measures..., grouping_id]`. A projection is added only when the internal
/// grouping id is not already in the final position.
fn reorder_grouping_set_output(
plan: LogicalPlan,
) -> datafusion::common::Result<LogicalPlan> {
match plan {
LogicalPlan::Aggregate(agg)
if matches!(agg.group_expr.first(), Some(Expr::GroupingSet(_))) =>
{
let mut columns = agg.schema.columns();
if let Some(idx) = columns
.iter()
.position(|col| col.name == Aggregate::INTERNAL_GROUPING_ID)
{
if idx == columns.len() - 1 {
return Ok(LogicalPlan::Aggregate(agg));
}

let grouping_id = columns.remove(idx);
columns.push(grouping_id);

let exprs = columns.into_iter().map(Expr::Column).collect::<Vec<_>>();
return LogicalPlanBuilder::from(LogicalPlan::Aggregate(agg))
.project(exprs)?
.build();
}

Ok(LogicalPlan::Aggregate(agg))
}
_ => Ok(plan),
}
}

#[allow(deprecated)]
async fn from_substrait_grouping(
consumer: &impl SubstraitConsumer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ use datafusion::logical_expr::expr::Alias;
use datafusion::logical_expr::{Aggregate, Distinct, Expr, GroupingSet};
use substrait::proto::aggregate_rel::{Grouping, Measure};
use substrait::proto::rel::RelType;
use substrait::proto::{AggregateRel, Expression, Rel};
use substrait::proto::rel_common::EmitKind;
use substrait::proto::{rel_common, AggregateRel, Expression, Rel, RelCommon};

pub fn from_aggregate(
producer: &mut impl SubstraitProducer,
Expand All @@ -38,9 +39,15 @@ pub fn from_aggregate(
.map(|e| to_substrait_agg_measure(producer, e, agg.input.schema()))
.collect::<datafusion::common::Result<Vec<_>>>()?;

let common = grouping_set_emit_mapping(agg).map(|output_mapping| RelCommon {
emit_kind: Some(EmitKind::Emit(rel_common::Emit { output_mapping })),
hint: None,
advanced_extension: None,
});

Ok(Box::new(Rel {
rel_type: Some(RelType::Aggregate(Box::new(AggregateRel {
common: None,
common,
input: Some(input),
grouping_expressions,
groupings,
Expand Down Expand Up @@ -180,3 +187,18 @@ pub fn to_substrait_agg_measure(
),
}
}

/// Builds the Substrait `emit` output mapping for a grouping-set aggregate.
///
/// An emit mapping lists ordinals of the relation's *direct* output. For a Substrait
/// `AggregateRel` with grouping sets that direct output is
/// `[grouping keys..., measures..., grouping_id]`, whereas DataFusion's aggregate schema is
/// `[grouping keys..., __grouping_id, measures...]`. The mapping returned here selects the
/// Substrait columns in DataFusion's order, i.e. `[0..k, k + m, k..k + m]` for `k` distinct
/// grouping keys and `m` measures, so that the emitted relation lines up with the schema of
/// `agg`.
/// NOTE(review): this presumes parent relations are serialized with field references computed
/// from DataFusion's schema order — confirm against the expression producer.
///
/// Returns `None` when `agg.group_expr` is not exactly one [`Expr::GroupingSet`]; such
/// aggregates need no remapping.
fn grouping_set_emit_mapping(agg: &Aggregate) -> Option<Vec<i32>> {
    match agg.group_expr.as_slice() {
        [Expr::GroupingSet(grouping_set)] => {
            // Substrait ordinals are i32; these counts are bounded by the schema width, which is
            // far below `i32::MAX`, so the casts cannot truncate in practice.
            let group_key_count = grouping_set.distinct_expr().len() as i32;
            let measure_count = agg.aggr_expr.len() as i32;
            // In the Substrait layout the grouping id trails the keys and all measures.
            let grouping_id = group_key_count + measure_count;
            let output_mapping: Vec<i32> = (0..group_key_count)
                .chain(std::iter::once(grouping_id))
                .chain(group_key_count..grouping_id)
                .collect();
            Some(output_mapping)
        }
        _ => None,
    }
}
23 changes: 23 additions & 0 deletions datafusion/substrait/tests/cases/aggregation_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,27 @@ mod tests {

Ok(())
}

// Consumes a Substrait plan whose aggregate declares three grouping sets and checks the
// resulting DataFusion logical plan against an inline snapshot, then executes the plan.
#[tokio::test]
async fn multiple_grouping_sets() -> Result<()> {
let proto_plan = read_json(
"tests/testdata/test_plans/aggregate_groupings/multiple_groupings.json",
);
// Register the schemas referenced by the plan so that it can be resolved against the session.
let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?;
let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;

// The expected plan projects only c0, c1 and the renamed sum on top of the aggregate.
// NOTE(review): presumably the grouping id is absent because the fixture's emit mapping
// selects only the first three output columns — confirm against the JSON plan.
assert_snapshot!(
plan,
@r#"
Projection: c0, c1, sum(c0) AS summation
Aggregate: groupBy=[[GROUPING SETS ((c0), (c1), (c0, c1))]], aggr=[[sum(c0)]]
EmptyRelation: rows=0
"#
);

// Trigger execution to ensure plan validity
DataFrame::new(ctx.state(), plan).show().await?;

Ok(())
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
{
"extensionUris": [
{
"extensionUriAnchor": 1,
"uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
}
],
"extensions": [
{
"extensionFunction": {
"extensionUriReference": 1,
"functionAnchor": 1,
"name": "sum:i8"
}
}
],
"relations": [
{
"root": {
"input": {
"aggregate": {
"common": {
"emit": {
"outputMapping": [
0, 1, 2
]
}
},
"input": {
"read": {
"baseSchema": {
"names": [
"c0",
"c1"
],
"struct": {
"nullability": "NULLABILITY_REQUIRED",
"types": [
{
"i8": {
"nullability": "NULLABILITY_NULLABLE"
}
},
{
"i8": {
"nullability": "NULLABILITY_NULLABLE"
}
}
]
}
},
"common": {
"direct": {}
},
"virtualTable": {}
}
},
"groupingExpressions": [
{
"selection": {
"directReference": {
"structField": {}
},
"rootReference": {}
}
},
{
"selection": {
"directReference": {
"structField": {
"field": 1
}
},
"rootReference": {}
}
}
],
"groupings": [
{
"expressionReferences": [0]
},
{
"expressionReferences": [1]
},
{
"expressionReferences": [0, 1]
}
],
"measures": [
{
"measure": {
"arguments": [
{
"value": {
"selection": {
"directReference": {
"structField": {}
},
"rootReference": {}
}
}
}
],
"functionReference": 1,
"invocation": "AGGREGATION_INVOCATION_ALL",
"outputType": {
"i8": {
"nullability": "NULLABILITY_NULLABLE"
}
},
"phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT"
}
}
]
}
},
"names": [
"c0",
"c1",
"summation"
]
}
}
],
"version": {
"minorNumber": 29
}
}