Skip to content

Commit 44d890c

Browse files
authored
Merge branch 'main' into describe-column-name-bug
2 parents 578636e + 52f340b commit 44d890c

File tree

3 files changed

+70
-2
lines changed

3 files changed

+70
-2
lines changed

datafusion/functions-aggregate-common/src/accumulator.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::datatypes::{Field, Schema};
18+
use arrow::datatypes::{DataType, Field, Schema};
1919
use datafusion_common::Result;
2020
use datafusion_expr_common::accumulator::Accumulator;
2121
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
@@ -71,6 +71,13 @@ pub struct AccumulatorArgs<'a> {
7171
pub exprs: &'a [Arc<dyn PhysicalExpr>],
7272
}
7373

74+
impl AccumulatorArgs<'_> {
75+
/// Returns the return type of the aggregate function.
76+
pub fn return_type(&self) -> &DataType {
77+
self.return_field.data_type()
78+
}
79+
}
80+
7481
/// Factory that returns an accumulator for the given aggregate function.
7582
pub type AccumulatorFactoryFunction =
7683
Arc<dyn Fn(AccumulatorArgs) -> Result<Box<dyn Accumulator>> + Send + Sync>;

datafusion/functions-aggregate/src/average.rs

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ impl AggregateUDFImpl for Avg {
182182
fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
183183
matches!(
184184
args.return_field.data_type(),
185-
DataType::Float64 | DataType::Decimal128(_, _)
185+
DataType::Float64 | DataType::Decimal128(_, _) | DataType::Duration(_)
186186
)
187187
}
188188

@@ -243,6 +243,45 @@ impl AggregateUDFImpl for Avg {
243243
)))
244244
}
245245

246+
(Duration(time_unit), Duration(_result_unit)) => {
247+
let avg_fn = move |sum: i64, count: u64| Ok(sum / count as i64);
248+
249+
match time_unit {
250+
TimeUnit::Second => Ok(Box::new(AvgGroupsAccumulator::<
251+
DurationSecondType,
252+
_,
253+
>::new(
254+
&data_type,
255+
args.return_type(),
256+
avg_fn,
257+
))),
258+
TimeUnit::Millisecond => Ok(Box::new(AvgGroupsAccumulator::<
259+
DurationMillisecondType,
260+
_,
261+
>::new(
262+
&data_type,
263+
args.return_type(),
264+
avg_fn,
265+
))),
266+
TimeUnit::Microsecond => Ok(Box::new(AvgGroupsAccumulator::<
267+
DurationMicrosecondType,
268+
_,
269+
>::new(
270+
&data_type,
271+
args.return_type(),
272+
avg_fn,
273+
))),
274+
TimeUnit::Nanosecond => Ok(Box::new(AvgGroupsAccumulator::<
275+
DurationNanosecondType,
276+
_,
277+
>::new(
278+
&data_type,
279+
args.return_type(),
280+
avg_fn,
281+
))),
282+
}
283+
}
284+
246285
_ => not_impl_err!(
247286
"AvgGroupsAccumulator for ({} --> {})",
248287
&data_type,

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5098,6 +5098,28 @@ FROM d WHERE column1 IS NOT NULL;
50985098
statement ok
50995099
drop table d;
51005100

5101+
statement ok
5102+
create table dn as values
5103+
(arrow_cast(10, 'Duration(Second)'), 'a', 1),
5104+
(arrow_cast(20, 'Duration(Second)'), 'a', 2),
5105+
(NULL, 'b', 1),
5106+
(arrow_cast(40, 'Duration(Second)'), 'b', 2),
5107+
(arrow_cast(50, 'Duration(Second)'), 'c', 1),
5108+
(NULL, 'c', 2);
5109+
5110+
query T?I
5111+
SELECT column2, avg(column1), column3 FROM dn GROUP BY column2, column3 ORDER BY column2, column3;
5112+
----
5113+
a 0 days 0 hours 0 mins 10 secs 1
5114+
a 0 days 0 hours 0 mins 20 secs 2
5115+
b NULL 1
5116+
b 0 days 0 hours 0 mins 40 secs 2
5117+
c 0 days 0 hours 0 mins 50 secs 1
5118+
c NULL 2
5119+
5120+
statement ok
5121+
drop table dn;
5122+
51015123
# Prepare the table with dictionary values for testing
51025124
statement ok
51035125
CREATE TABLE value(x bigint) AS VALUES (1), (2), (3), (1), (3), (4), (5), (2);

0 commit comments

Comments
 (0)