Skip to content

Commit 39e2c98

Browse files
fix: reorder grouping-set aggregates output for Substrait
1 parent 531af8e commit 39e2c98

File tree

4 files changed

+219
-4
lines changed

4 files changed

+219
-4
lines changed

datafusion/substrait/src/logical_plan/consumer/rel/aggregate_rel.rs

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
use crate::logical_plan::consumer::{from_substrait_agg_func, from_substrait_sorts};
1919
use crate::logical_plan::consumer::{NameTracker, SubstraitConsumer};
2020
use datafusion::common::{not_impl_err, DFSchemaRef};
21-
use datafusion::logical_expr::{Expr, GroupingSet, LogicalPlan, LogicalPlanBuilder};
21+
use datafusion::logical_expr::{
22+
Aggregate, Expr, GroupingSet, LogicalPlan, LogicalPlanBuilder,
23+
};
2224
use substrait::proto::aggregate_function::AggregationInvocation;
2325
use substrait::proto::aggregate_rel::Grouping;
2426
use substrait::proto::AggregateRel;
@@ -116,12 +118,49 @@ pub async fn from_aggregate_rel(
116118
.map(|e| name_tracker.get_uniquely_named_expr(e.clone()))
117119
.collect::<Result<Vec<Expr>, _>>()?;
118120

119-
input.aggregate(group_exprs, aggr_exprs)?.build()
121+
let plan = input.aggregate(group_exprs, aggr_exprs)?.build()?;
122+
reorder_grouping_set_output(plan)
120123
} else {
121124
not_impl_err!("Aggregate without an input is not valid")
122125
}
123126
}
124127

128+
/// Reorders the output of grouping-set aggregates so the column layout matches the Substrait
129+
/// specification. DataFusion's [`Aggregate::output_expressions`] produces
130+
/// `[grouping keys..., __grouping_id, measures...]`, whereas Substrait requires
131+
/// `[grouping keys..., measures..., grouping_id]`. A projection is added only when the internal
132+
/// grouping id is not already in the final position.
133+
fn reorder_grouping_set_output(
134+
plan: LogicalPlan,
135+
) -> datafusion::common::Result<LogicalPlan> {
136+
match plan {
137+
LogicalPlan::Aggregate(agg)
138+
if matches!(agg.group_expr.first(), Some(Expr::GroupingSet(_))) =>
139+
{
140+
let mut columns = agg.schema.columns();
141+
let Some(idx) = columns
142+
.iter()
143+
.position(|col| col.name == Aggregate::INTERNAL_GROUPING_ID)
144+
else {
145+
return Ok(LogicalPlan::Aggregate(agg));
146+
};
147+
148+
if idx + 1 == columns.len() {
149+
return Ok(LogicalPlan::Aggregate(agg));
150+
}
151+
152+
let grouping_id = columns.remove(idx);
153+
let mut exprs: Vec<Expr> = columns.into_iter().map(Expr::Column).collect();
154+
exprs.push(Expr::Column(grouping_id));
155+
156+
LogicalPlanBuilder::from(LogicalPlan::Aggregate(agg))
157+
.project(exprs)?
158+
.build()
159+
}
160+
_ => Ok(plan),
161+
}
162+
}
163+
125164
#[allow(deprecated)]
126165
async fn from_substrait_grouping(
127166
consumer: &impl SubstraitConsumer,

datafusion/substrait/src/logical_plan/producer/rel/aggregate_rel.rs

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ use datafusion::logical_expr::expr::Alias;
2323
use datafusion::logical_expr::{Aggregate, Distinct, Expr, GroupingSet};
2424
use substrait::proto::aggregate_rel::{Grouping, Measure};
2525
use substrait::proto::rel::RelType;
26-
use substrait::proto::{AggregateRel, Expression, Rel};
26+
use substrait::proto::rel_common::EmitKind;
27+
use substrait::proto::{rel_common, AggregateRel, Expression, Rel, RelCommon};
2728

2829
pub fn from_aggregate(
2930
producer: &mut impl SubstraitProducer,
@@ -38,9 +39,15 @@ pub fn from_aggregate(
3839
.map(|e| to_substrait_agg_measure(producer, e, agg.input.schema()))
3940
.collect::<datafusion::common::Result<Vec<_>>>()?;
4041

42+
let common = grouping_set_emit_mapping(agg).map(|output_mapping| RelCommon {
43+
emit_kind: Some(EmitKind::Emit(rel_common::Emit { output_mapping })),
44+
hint: None,
45+
advanced_extension: None,
46+
});
47+
4148
Ok(Box::new(Rel {
4249
rel_type: Some(RelType::Aggregate(Box::new(AggregateRel {
43-
common: None,
50+
common,
4451
input: Some(input),
4552
grouping_expressions,
4653
groupings,
@@ -180,3 +187,21 @@ pub fn to_substrait_agg_measure(
180187
),
181188
}
182189
}
190+
191+
fn grouping_set_emit_mapping(agg: &Aggregate) -> Option<Vec<i32>> {
192+
match agg.group_expr.as_slice() {
193+
[Expr::GroupingSet(grouping_set)] => {
194+
let group_key_count = grouping_set.distinct_expr().len() as i32;
195+
let measure_count = agg.aggr_expr.len() as i32;
196+
let grouping_id_index = group_key_count;
197+
198+
let mut output_mapping: Vec<i32> = (0..group_key_count).collect();
199+
for offset in 0..measure_count {
200+
output_mapping.push(group_key_count + 1 + offset);
201+
}
202+
output_mapping.push(grouping_id_index);
203+
Some(output_mapping)
204+
}
205+
_ => None,
206+
}
207+
}

datafusion/substrait/tests/cases/aggregation_tests.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,27 @@ mod tests {
6868

6969
Ok(())
7070
}
71+
72+
#[tokio::test]
73+
async fn multiple_grouping_sets() -> Result<()> {
74+
let proto_plan = read_json(
75+
"tests/testdata/test_plans/aggregate_groupings/multiple_groupings.json",
76+
);
77+
let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?;
78+
let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;
79+
80+
assert_snapshot!(
81+
plan,
82+
@r#"
83+
Projection: c0, c1, sum(c0) AS summation
84+
Aggregate: groupBy=[[GROUPING SETS ((c0), (c1), (c0, c1))]], aggr=[[sum(c0)]]
85+
EmptyRelation: rows=0
86+
"#
87+
);
88+
89+
// Trigger execution to ensure plan validity
90+
DataFrame::new(ctx.state(), plan).show().await?;
91+
92+
Ok(())
93+
}
7194
}
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
{
2+
"extensionUris": [
3+
{
4+
"extensionUriAnchor": 1,
5+
"uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
6+
}
7+
],
8+
"extensions": [
9+
{
10+
"extensionFunction": {
11+
"extensionUriReference": 1,
12+
"functionAnchor": 1,
13+
"name": "sum:i8"
14+
}
15+
}
16+
],
17+
"relations": [
18+
{
19+
"root": {
20+
"input": {
21+
"aggregate": {
22+
"common": {
23+
"emit": {
24+
"outputMapping": [
25+
0, 1, 2
26+
]
27+
}
28+
},
29+
"input": {
30+
"read": {
31+
"baseSchema": {
32+
"names": [
33+
"c0",
34+
"c1"
35+
],
36+
"struct": {
37+
"nullability": "NULLABILITY_REQUIRED",
38+
"types": [
39+
{
40+
"i8": {
41+
"nullability": "NULLABILITY_NULLABLE"
42+
}
43+
},
44+
{
45+
"i8": {
46+
"nullability": "NULLABILITY_NULLABLE"
47+
}
48+
}
49+
]
50+
}
51+
},
52+
"common": {
53+
"direct": {}
54+
},
55+
"virtualTable": {}
56+
}
57+
},
58+
"groupingExpressions": [
59+
{
60+
"selection": {
61+
"directReference": {
62+
"structField": {}
63+
},
64+
"rootReference": {}
65+
}
66+
},
67+
{
68+
"selection": {
69+
"directReference": {
70+
"structField": {
71+
"field": 1
72+
}
73+
},
74+
"rootReference": {}
75+
}
76+
}
77+
],
78+
"groupings": [
79+
{
80+
"expressionReferences": [0]
81+
},
82+
{
83+
"expressionReferences": [1]
84+
},
85+
{
86+
"expressionReferences": [0, 1]
87+
}
88+
],
89+
"measures": [
90+
{
91+
"measure": {
92+
"arguments": [
93+
{
94+
"value": {
95+
"selection": {
96+
"directReference": {
97+
"structField": {}
98+
},
99+
"rootReference": {}
100+
}
101+
}
102+
}
103+
],
104+
"functionReference": 1,
105+
"invocation": "AGGREGATION_INVOCATION_ALL",
106+
"outputType": {
107+
"i8": {
108+
"nullability": "NULLABILITY_NULLABLE"
109+
}
110+
},
111+
"phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT"
112+
}
113+
}
114+
]
115+
}
116+
},
117+
"names": [
118+
"c0",
119+
"c1",
120+
"summation"
121+
]
122+
}
123+
}
124+
],
125+
"version": {
126+
"minorNumber": 29
127+
}
128+
}

0 commit comments

Comments
 (0)