Skip to content

Commit 5e1214c

Browse files
authored
allow min max dictionary (#15827)
1 parent b2c210a commit 5e1214c

File tree

1 file changed

+47
-13
lines changed

1 file changed

+47
-13
lines changed

datafusion/functions-aggregate/src/min_max.rs

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,16 @@
2121
mod min_max_bytes;
2222

2323
use arrow::array::{
24-
ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
25-
Decimal128Array, Decimal256Array, DurationMicrosecondArray, DurationMillisecondArray,
26-
DurationNanosecondArray, DurationSecondArray, Float16Array, Float32Array,
27-
Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray,
28-
IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray,
29-
LargeStringArray, StringArray, StringViewArray, Time32MillisecondArray,
30-
Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
31-
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
32-
TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
24+
ArrayRef, AsArray as _, BinaryArray, BinaryViewArray, BooleanArray, Date32Array,
25+
Date64Array, Decimal128Array, Decimal256Array, DurationMicrosecondArray,
26+
DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, Float16Array,
27+
Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
28+
IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray,
29+
LargeBinaryArray, LargeStringArray, StringArray, StringViewArray,
30+
Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray,
31+
Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
32+
TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array,
33+
UInt64Array, UInt8Array,
3334
};
3435
use arrow::compute;
3536
use arrow::datatypes::{
@@ -610,6 +611,10 @@ fn min_batch(values: &ArrayRef) -> Result<ScalarValue> {
610611
min_binary_view
611612
)
612613
}
614+
DataType::Dictionary(_, _) => {
615+
let values = values.as_any_dictionary().values();
616+
min_batch(values)?
617+
}
613618
_ => min_max_batch!(values, min),
614619
})
615620
}
@@ -653,6 +658,10 @@ pub fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
653658
max_binary
654659
)
655660
}
661+
DataType::Dictionary(_, _) => {
662+
let values = values.as_any_dictionary().values();
663+
max_batch(values)?
664+
}
656665
_ => min_max_batch!(values, max),
657666
})
658667
}
@@ -1627,8 +1636,11 @@ make_udaf_expr_and_func!(
16271636
#[cfg(test)]
16281637
mod tests {
16291638
use super::*;
1630-
use arrow::datatypes::{
1631-
IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType,
1639+
use arrow::{
1640+
array::DictionaryArray,
1641+
datatypes::{
1642+
IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType,
1643+
},
16321644
};
16331645
use std::sync::Arc;
16341646

@@ -1854,9 +1866,31 @@ mod tests {
18541866
#[test]
18551867
fn test_get_min_max_return_type_coerce_dictionary() -> Result<()> {
18561868
let data_type =
1857-
DataType::Dictionary(Box::new(DataType::Utf8), Box::new(DataType::Int32));
1869+
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
18581870
let result = get_min_max_result_type(&[data_type])?;
1859-
assert_eq!(result, vec![DataType::Int32]);
1871+
assert_eq!(result, vec![DataType::Utf8]);
1872+
Ok(())
1873+
}
1874+
1875+
#[test]
1876+
fn test_min_max_dictionary() -> Result<()> {
1877+
let values = StringArray::from(vec!["b", "c", "a", "🦀", "d"]);
1878+
let keys = Int32Array::from(vec![Some(0), Some(1), Some(2), None, Some(4)]);
1879+
let dict_array =
1880+
DictionaryArray::try_new(keys, Arc::new(values) as ArrayRef).unwrap();
1881+
let dict_array_ref = Arc::new(dict_array) as ArrayRef;
1882+
let rt_type =
1883+
get_min_max_result_type(&[dict_array_ref.data_type().clone()])?[0].clone();
1884+
1885+
let mut min_acc = MinAccumulator::try_new(&rt_type)?;
1886+
min_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
1887+
let min_result = min_acc.evaluate()?;
1888+
assert_eq!(min_result, ScalarValue::Utf8(Some("a".to_string())));
1889+
1890+
let mut max_acc = MaxAccumulator::try_new(&rt_type)?;
1891+
max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
1892+
let max_result = max_acc.evaluate()?;
1893+
assert_eq!(max_result, ScalarValue::Utf8(Some("🦀".to_string())));
18601894
Ok(())
18611895
}
18621896
}

0 commit comments

Comments
 (0)