Skip to content

Commit 8e0ca1a

Browse files
authored
Add String view helper functions (#11517)
* add functions * add tests for hash util
1 parent 8d8732c commit 8e0ca1a

File tree

3 files changed

+159
-12
lines changed

3 files changed

+159
-12
lines changed

datafusion/common/src/cast.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ use arrow::{
3636
},
3737
datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType},
3838
};
39+
use arrow_array::{BinaryViewArray, StringViewArray};
3940

4041
// Downcast ArrayRef to Date32Array
4142
pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array> {
@@ -87,6 +88,11 @@ pub fn as_string_array(array: &dyn Array) -> Result<&StringArray> {
8788
Ok(downcast_value!(array, StringArray))
8889
}
8990

91+
// Downcast a `&dyn Array` to a `&StringViewArray`; a type mismatch is presumably surfaced as `Err` by `downcast_value!` (defined outside this view; confirm)
92+
pub fn as_string_view_array(array: &dyn Array) -> Result<&StringViewArray> {
93+
Ok(downcast_value!(array, StringViewArray))
94+
}
95+
9096
// Downcast ArrayRef to UInt32Array
9197
pub fn as_uint32_array(array: &dyn Array) -> Result<&UInt32Array> {
9298
Ok(downcast_value!(array, UInt32Array))
@@ -221,6 +227,11 @@ pub fn as_binary_array(array: &dyn Array) -> Result<&BinaryArray> {
221227
Ok(downcast_value!(array, BinaryArray))
222228
}
223229

230+
// Downcast a `&dyn Array` to a `&BinaryViewArray`; a type mismatch is presumably surfaced as `Err` by `downcast_value!` (defined outside this view; confirm)
231+
pub fn as_binary_view_array(array: &dyn Array) -> Result<&BinaryViewArray> {
232+
Ok(downcast_value!(array, BinaryViewArray))
233+
}
234+
224235
// Downcast ArrayRef to FixedSizeListArray
225236
pub fn as_fixed_size_list_array(array: &dyn Array) -> Result<&FixedSizeListArray> {
226237
Ok(downcast_value!(array, FixedSizeListArray))

datafusion/common/src/hash_utils.rs

Lines changed: 109 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,8 @@ pub fn create_hashes<'a>(
360360
random_state: &RandomState,
361361
hashes_buffer: &'a mut Vec<u64>,
362362
) -> Result<&'a mut Vec<u64>> {
363+
use crate::cast::{as_binary_view_array, as_string_view_array};
364+
363365
for (i, col) in arrays.iter().enumerate() {
364366
let array = col.as_ref();
365367
// combine hashes with `combine_hashes` for all columns besides the first
@@ -370,8 +372,10 @@ pub fn create_hashes<'a>(
370372
DataType::Boolean => hash_array(as_boolean_array(array)?, random_state, hashes_buffer, rehash),
371373
DataType::Utf8 => hash_array(as_string_array(array)?, random_state, hashes_buffer, rehash),
372374
DataType::LargeUtf8 => hash_array(as_largestring_array(array), random_state, hashes_buffer, rehash),
375+
DataType::Utf8View => hash_array(as_string_view_array(array)?, random_state, hashes_buffer, rehash),
373376
DataType::Binary => hash_array(as_generic_binary_array::<i32>(array)?, random_state, hashes_buffer, rehash),
374377
DataType::LargeBinary => hash_array(as_generic_binary_array::<i64>(array)?, random_state, hashes_buffer, rehash),
378+
DataType::BinaryView => hash_array(as_binary_view_array(array)?, random_state, hashes_buffer, rehash),
375379
DataType::FixedSizeBinary(_) => {
376380
let array: &FixedSizeBinaryArray = array.as_any().downcast_ref().unwrap();
377381
hash_array(array, random_state, hashes_buffer, rehash)
@@ -486,22 +490,57 @@ mod tests {
486490
Ok(())
487491
}
488492

489-
#[test]
490-
fn create_hashes_binary() -> Result<()> {
491-
let byte_array = Arc::new(BinaryArray::from_vec(vec![
492-
&[4, 3, 2],
493-
&[4, 3, 2],
494-
&[1, 2, 3],
495-
]));
493+
macro_rules! create_hash_binary { // Generates a `#[test]` fn `$NAME` hashing the sample byte strings as `$ARRAY` and as a reference `BinaryArray`
494+
($NAME:ident, $ARRAY:ty) => {
495+
#[cfg(not(feature = "force_hash_collisions"))]
496+
#[test]
497+
fn $NAME() {
498+
let binary = [
499+
Some(b"short".to_byte_slice()),
500+
None,
501+
Some(b"long but different 12 bytes string"),
502+
Some(b"short2"),
503+
Some(b"Longer than 12 bytes string"),
504+
Some(b"short"),
505+
Some(b"Longer than 12 bytes string"),
506+
];
507+
508+
let binary_array = Arc::new(binary.iter().cloned().collect::<$ARRAY>());
509+
let ref_array = Arc::new(binary.iter().cloned().collect::<BinaryArray>());
510+
511+
let random_state = RandomState::with_seeds(0, 0, 0, 0);
512+
513+
let mut binary_hashes = vec![0; binary.len()];
514+
create_hashes(&[binary_array], &random_state, &mut binary_hashes)
515+
.unwrap();
516+
517+
let mut ref_hashes = vec![0; binary.len()];
518+
create_hashes(&[ref_array], &random_state, &mut ref_hashes).unwrap();
519+
520+
// Null entries must hash to zero; non-null entries must not
521+
for (val, hash) in binary.iter().zip(binary_hashes.iter()) {
522+
match val {
523+
Some(_) => assert_ne!(*hash, 0),
524+
None => assert_eq!(*hash, 0),
525+
}
526+
}
496527

497-
let random_state = RandomState::with_seeds(0, 0, 0, 0);
498-
let hashes_buff = &mut vec![0; byte_array.len()];
499-
let hashes = create_hashes(&[byte_array], &random_state, hashes_buff)?;
500-
assert_eq!(hashes.len(), 3,);
528+
// The array under test must hash identically to the reference BinaryArray
529+
assert_eq!(binary_hashes, ref_hashes);
501530

502-
Ok(())
531+
// Equal inputs (indices 0/5 and 4/6) must produce equal hashes; compare the hashes, not the input literals
532+
assert_eq!(binary_hashes[0], binary_hashes[5]);
533+
assert_eq!(binary_hashes[4], binary_hashes[6]);
534+
535+
// Different inputs (indices 0 and 2) must produce different hashes
536+
assert_ne!(binary_hashes[0], binary_hashes[2]);
537+
}
538+
};
503539
}
504540

541+
create_hash_binary!(binary_array, BinaryArray);
542+
create_hash_binary!(binary_view_array, BinaryViewArray);
543+
505544
#[test]
506545
fn create_hashes_fixed_size_binary() -> Result<()> {
507546
let input_arg = vec![vec![1, 2], vec![5, 6], vec![5, 6]];
@@ -517,6 +556,64 @@ mod tests {
517556
Ok(())
518557
}
519558

559+
macro_rules! create_hash_string { // Generates a `#[test]` fn `$NAME` hashing the sample strings as `$ARRAY` and as a `DictionaryArray<Int8Type>`
560+
($NAME:ident, $ARRAY:ty) => {
561+
#[cfg(not(feature = "force_hash_collisions"))]
562+
#[test]
563+
fn $NAME() {
564+
let strings = [
565+
Some("short"),
566+
None,
567+
Some("long but different 12 bytes string"),
568+
Some("short2"),
569+
Some("Longer than 12 bytes string"),
570+
Some("short"),
571+
Some("Longer than 12 bytes string"),
572+
];
573+
574+
let string_array = Arc::new(strings.iter().cloned().collect::<$ARRAY>());
575+
let dict_array = Arc::new(
576+
strings
577+
.iter()
578+
.cloned()
579+
.collect::<DictionaryArray<Int8Type>>(),
580+
);
581+
582+
let random_state = RandomState::with_seeds(0, 0, 0, 0);
583+
584+
let mut string_hashes = vec![0; strings.len()];
585+
create_hashes(&[string_array], &random_state, &mut string_hashes)
586+
.unwrap();
587+
588+
let mut dict_hashes = vec![0; strings.len()];
589+
create_hashes(&[dict_array], &random_state, &mut dict_hashes).unwrap();
590+
591+
// Null entries must hash to zero; non-null entries must not
592+
for (val, hash) in strings.iter().zip(string_hashes.iter()) {
593+
match val {
594+
Some(_) => assert_ne!(*hash, 0),
595+
None => assert_eq!(*hash, 0),
596+
}
597+
}
598+
599+
// The array under test must hash identically to the dictionary-encoded array
600+
assert_eq!(string_hashes, dict_hashes);
601+
602+
// Equal inputs (indices 0/5 and 4/6) must produce equal hashes; compare the hashes, not the input literals
603+
assert_eq!(string_hashes[0], string_hashes[5]);
604+
assert_eq!(string_hashes[4], string_hashes[6]);
605+
606+
// Different inputs (indices 0 and 2) must produce different hashes
607+
assert_ne!(string_hashes[0], string_hashes[2]);
608+
}
609+
};
610+
}
611+
612+
create_hash_string!(string_array, StringArray);
613+
create_hash_string!(large_string_array, LargeStringArray);
614+
create_hash_string!(string_view_array, StringViewArray);
615+
create_hash_string!(dict_string_array, DictionaryArray<Int8Type>);
616+
520617
#[test]
521618
// Tests actual values of hashes, which are different if forcing collisions
522619
#[cfg(not(feature = "force_hash_collisions"))]

datafusion/physical-expr/src/aggregate/min_max.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ use arrow_array::types::{
4545
Decimal128Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
4646
UInt16Type, UInt32Type, UInt64Type, UInt8Type,
4747
};
48+
use arrow_array::{BinaryViewArray, StringViewArray};
4849
use datafusion_common::internal_err;
4950
use datafusion_common::ScalarValue;
5051
use datafusion_common::{downcast_value, DataFusionError, Result};
@@ -453,6 +454,14 @@ fn min_batch(values: &ArrayRef) -> Result<ScalarValue> {
453454
DataType::LargeUtf8 => {
454455
typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, min_string)
455456
}
457+
DataType::Utf8View => {
458+
typed_min_max_batch_string!(
459+
values,
460+
StringViewArray,
461+
Utf8View,
462+
min_string_view
463+
)
464+
}
456465
DataType::Boolean => {
457466
typed_min_max_batch!(values, BooleanArray, Boolean, min_boolean)
458467
}
@@ -467,6 +476,14 @@ fn min_batch(values: &ArrayRef) -> Result<ScalarValue> {
467476
min_binary
468477
)
469478
}
479+
DataType::BinaryView => {
480+
typed_min_max_batch_binary!(
481+
&values,
482+
BinaryViewArray,
483+
BinaryView,
484+
min_binary_view
485+
)
486+
}
470487
_ => min_max_batch!(values, min),
471488
})
472489
}
@@ -480,12 +497,28 @@ fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
480497
DataType::LargeUtf8 => {
481498
typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, max_string)
482499
}
500+
DataType::Utf8View => {
501+
typed_min_max_batch_string!(
502+
values,
503+
StringViewArray,
504+
Utf8View,
505+
max_string_view
506+
)
507+
}
483508
DataType::Boolean => {
484509
typed_min_max_batch!(values, BooleanArray, Boolean, max_boolean)
485510
}
486511
DataType::Binary => {
487512
typed_min_max_batch_binary!(&values, BinaryArray, Binary, max_binary)
488513
}
514+
DataType::BinaryView => {
515+
typed_min_max_batch_binary!(
516+
&values,
517+
BinaryViewArray,
518+
BinaryView,
519+
max_binary_view
520+
)
521+
}
489522
DataType::LargeBinary => {
490523
typed_min_max_batch_binary!(
491524
&values,
@@ -629,12 +662,18 @@ macro_rules! min_max {
629662
(ScalarValue::LargeUtf8(lhs), ScalarValue::LargeUtf8(rhs)) => {
630663
typed_min_max_string!(lhs, rhs, LargeUtf8, $OP)
631664
}
665+
(ScalarValue::Utf8View(lhs), ScalarValue::Utf8View(rhs)) => {
666+
typed_min_max_string!(lhs, rhs, Utf8View, $OP)
667+
}
632668
(ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => {
633669
typed_min_max_string!(lhs, rhs, Binary, $OP)
634670
}
635671
(ScalarValue::LargeBinary(lhs), ScalarValue::LargeBinary(rhs)) => {
636672
typed_min_max_string!(lhs, rhs, LargeBinary, $OP)
637673
}
674+
(ScalarValue::BinaryView(lhs), ScalarValue::BinaryView(rhs)) => {
675+
typed_min_max_string!(lhs, rhs, BinaryView, $OP)
676+
}
638677
(ScalarValue::TimestampSecond(lhs, l_tz), ScalarValue::TimestampSecond(rhs, _)) => {
639678
typed_min_max!(lhs, rhs, TimestampSecond, $OP, l_tz)
640679
}

0 commit comments

Comments
 (0)