|
21 | 21 | mod min_max_bytes;
|
22 | 22 |
|
23 | 23 | use arrow::array::{
|
24 |
| - ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, |
25 |
| - Decimal128Array, Decimal256Array, DurationMicrosecondArray, DurationMillisecondArray, |
26 |
| - DurationNanosecondArray, DurationSecondArray, Float16Array, Float32Array, |
27 |
| - Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, |
28 |
| - IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, |
29 |
| - LargeStringArray, StringArray, StringViewArray, Time32MillisecondArray, |
30 |
| - Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, |
31 |
| - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, |
32 |
| - TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, |
| 24 | + ArrayRef, AsArray as _, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, |
| 25 | + Date64Array, Decimal128Array, Decimal256Array, DurationMicrosecondArray, |
| 26 | + DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, Float16Array, |
| 27 | + Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, |
| 28 | + IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, |
| 29 | + LargeBinaryArray, LargeStringArray, StringArray, StringViewArray, |
| 30 | + Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, |
| 31 | + Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, |
| 32 | + TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, |
| 33 | + UInt64Array, UInt8Array, |
33 | 34 | };
|
34 | 35 | use arrow::compute;
|
35 | 36 | use arrow::datatypes::{
|
@@ -610,6 +611,10 @@ fn min_batch(values: &ArrayRef) -> Result<ScalarValue> {
|
610 | 611 | min_binary_view
|
611 | 612 | )
|
612 | 613 | }
|
| 614 | + DataType::Dictionary(_, _) => { |
| 615 | + let values = values.as_any_dictionary().values(); |
| 616 | + min_batch(values)? |
| 617 | + } |
613 | 618 | _ => min_max_batch!(values, min),
|
614 | 619 | })
|
615 | 620 | }
|
@@ -653,6 +658,10 @@ pub fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
|
653 | 658 | max_binary
|
654 | 659 | )
|
655 | 660 | }
|
| 661 | + DataType::Dictionary(_, _) => { |
| 662 | + let values = values.as_any_dictionary().values(); |
| 663 | + max_batch(values)? |
| 664 | + } |
656 | 665 | _ => min_max_batch!(values, max),
|
657 | 666 | })
|
658 | 667 | }
|
@@ -1627,8 +1636,11 @@ make_udaf_expr_and_func!(
|
1627 | 1636 | #[cfg(test)]
|
1628 | 1637 | mod tests {
|
1629 | 1638 | use super::*;
|
1630 |
| - use arrow::datatypes::{ |
1631 |
| - IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, |
| 1639 | + use arrow::{ |
| 1640 | + array::DictionaryArray, |
| 1641 | + datatypes::{ |
| 1642 | + IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, |
| 1643 | + }, |
1632 | 1644 | };
|
1633 | 1645 | use std::sync::Arc;
|
1634 | 1646 |
|
@@ -1854,9 +1866,31 @@ mod tests {
|
1854 | 1866 | #[test]
|
1855 | 1867 | fn test_get_min_max_return_type_coerce_dictionary() -> Result<()> {
|
1856 | 1868 | let data_type =
|
1857 |
| - DataType::Dictionary(Box::new(DataType::Utf8), Box::new(DataType::Int32)); |
| 1869 | + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); |
1858 | 1870 | let result = get_min_max_result_type(&[data_type])?;
|
1859 |
| - assert_eq!(result, vec![DataType::Int32]); |
| 1871 | + assert_eq!(result, vec![DataType::Utf8]); |
| 1872 | + Ok(()) |
| 1873 | + } |
| 1874 | + |
| 1875 | + #[test] |
| 1876 | + fn test_min_max_dictionary() -> Result<()> { |
| 1877 | + let values = StringArray::from(vec!["b", "c", "a", "🦀", "d"]); |
| 1878 | + let keys = Int32Array::from(vec![Some(0), Some(1), Some(2), None, Some(4)]); |
| 1879 | + let dict_array = |
| 1880 | + DictionaryArray::try_new(keys, Arc::new(values) as ArrayRef).unwrap(); |
| 1881 | + let dict_array_ref = Arc::new(dict_array) as ArrayRef; |
| 1882 | + let rt_type = |
| 1883 | + get_min_max_result_type(&[dict_array_ref.data_type().clone()])?[0].clone(); |
| 1884 | + |
| 1885 | + let mut min_acc = MinAccumulator::try_new(&rt_type)?; |
| 1886 | + min_acc.update_batch(&[Arc::clone(&dict_array_ref)])?; |
| 1887 | + let min_result = min_acc.evaluate()?; |
| 1888 | + assert_eq!(min_result, ScalarValue::Utf8(Some("a".to_string()))); |
| 1889 | + |
| 1890 | + let mut max_acc = MaxAccumulator::try_new(&rt_type)?; |
| 1891 | + max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?; |
| 1892 | + let max_result = max_acc.evaluate()?; |
| 1893 | + assert_eq!(max_result, ScalarValue::Utf8(Some("🦀".to_string()))); |
1860 | 1894 | Ok(())
|
1861 | 1895 | }
|
1862 | 1896 | }
|
0 commit comments