From c438b16439b2b08de09925470ef12c34023323d8 Mon Sep 17 00:00:00 2001 From: RinChanNOWWW Date: Thu, 11 Apr 2024 10:10:10 +0800 Subject: [PATCH 1/4] Implement `filter` kernel for byte view arrays. --- arrow-arith/src/arity.rs | 3 ++- arrow-array/src/types.rs | 5 ---- arrow-buffer/src/native.rs | 5 ++++ arrow-data/src/data.rs | 5 ++++ arrow-select/src/filter.rs | 49 +++++++++++++++++++++++++++++++------- 5 files changed, 52 insertions(+), 15 deletions(-) diff --git a/arrow-arith/src/arity.rs b/arrow-arith/src/arity.rs index 3d8214d89dc3..ff5c8e822cc0 100644 --- a/arrow-arith/src/arity.rs +++ b/arrow-arith/src/arity.rs @@ -21,6 +21,7 @@ use arrow_array::builder::BufferBuilder; use arrow_array::types::ArrowDictionaryKeyType; use arrow_array::*; use arrow_buffer::buffer::NullBuffer; +use arrow_buffer::ArrowNativeType; use arrow_buffer::{Buffer, MutableBuffer}; use arrow_data::ArrayData; use arrow_schema::ArrowError; @@ -386,7 +387,7 @@ where O: ArrowPrimitiveType, F: Fn(A::Item, B::Item) -> Result, { - let mut buffer = MutableBuffer::new(len * O::get_byte_width()); + let mut buffer = MutableBuffer::new(len * O::Native::get_byte_width()); for idx in 0..len { unsafe { buffer.push_unchecked(op(a.value_unchecked(idx), b.value_unchecked(idx))?); diff --git a/arrow-array/src/types.rs b/arrow-array/src/types.rs index e33f7bde7cba..3ade1405d7bc 100644 --- a/arrow-array/src/types.rs +++ b/arrow-array/src/types.rs @@ -58,11 +58,6 @@ pub trait ArrowPrimitiveType: primitive::PrimitiveTypeSealed + 'static { /// the corresponding Arrow data type of this primitive type. const DATA_TYPE: DataType; - /// Returns the byte width of this primitive type. - fn get_byte_width() -> usize { - std::mem::size_of::() - } - /// Returns a default value of this primitive type. /// /// This is useful for aggregate array ops like `sum()`, `mean()`. diff --git a/arrow-buffer/src/native.rs b/arrow-buffer/src/native.rs index 680974351a4b..de665d4e3874 100644 --- a/arrow-buffer/src/native.rs +++ b/arrow-buffer/src/native.rs @@ -47,6 +47,11 @@ mod private { pub trait ArrowNativeType: std::fmt::Debug + Send + Sync + Copy + PartialOrd + Default + private::Sealed + 'static { + /// Returns the byte width of this native type. + fn get_byte_width() -> usize { + std::mem::size_of::() + } + /// Convert native integer type from usize /// /// Returns `None` if [`Self`] is not an integer or conversion would result diff --git a/arrow-data/src/data.rs b/arrow-data/src/data.rs index e227b168eee5..b61decc1629c 100644 --- a/arrow-data/src/data.rs +++ b/arrow-data/src/data.rs @@ -1761,6 +1761,11 @@ impl ArrayDataBuilder { self } + pub fn add_buffers(mut self, bs: Vec) -> Self { + self.buffers.extend(bs); + self + } + pub fn child_data(mut self, v: Vec) -> Self { self.child_data = v; self diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs index 2af19ff85056..0b46664322d4 100644 --- a/arrow-select/src/filter.rs +++ b/arrow-select/src/filter.rs @@ -23,10 +23,10 @@ use std::sync::Arc; use arrow_array::builder::BooleanBufferBuilder; use arrow_array::cast::AsArray; use arrow_array::types::{ - ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, RunEndIndexType, + ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType, }; use arrow_array::*; -use arrow_buffer::{bit_util, BooleanBuffer, NullBuffer, RunEndBuffer}; +use arrow_buffer::{bit_util, ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer}; use arrow_buffer::{Buffer, MutableBuffer}; use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator}; use arrow_data::transform::MutableArrayData; @@ -333,12 +333,18 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result { Ok(Arc::new(filter_bytes(values.as_string::(), predicate))) } + DataType::Utf8View => { + Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate))) + } DataType::Binary => { Ok(Arc::new(filter_bytes(values.as_binary::(), predicate))) } DataType::LargeBinary => { Ok(Arc::new(filter_bytes(values.as_binary::(), predicate))) } + DataType::BinaryView => { + Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate))) + } DataType::RunEndEncoded(_, _) => { downcast_run_array!{ values => Ok(Arc::new(filter_run_end_array(values, predicate)?)), @@ -508,12 +514,8 @@ fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanA BooleanArray::from(data) } -/// `filter` implementation for primitive arrays -fn filter_primitive(array: &PrimitiveArray, predicate: &FilterPredicate) -> PrimitiveArray -where - T: ArrowPrimitiveType, -{ - let values = array.values(); +#[inline(never)] +fn filter_native(values: &[T], predicate: &FilterPredicate) -> Buffer { assert!(values.len() >= predicate.filter.len()); let buffer = match &predicate.strategy { @@ -546,9 +548,19 @@ where IterationStrategy::All | IterationStrategy::None => unreachable!(), }; + buffer.into() +} + +/// `filter` implementation for primitive arrays +fn filter_primitive(array: &PrimitiveArray, predicate: &FilterPredicate) -> PrimitiveArray +where + T: ArrowPrimitiveType, +{ + let values = array.values(); + let buffer = filter_native(values, predicate); let mut builder = ArrayDataBuilder::new(array.data_type().clone()) .len(predicate.count) - .add_buffer(buffer.into()); + .add_buffer(buffer); if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) { builder = builder.null_count(null_count).null_bit_buffer(Some(nulls)); @@ -673,6 +685,25 @@ where GenericByteArray::from(data) } +/// `filter` implementation for byte view arrays. +fn filter_byte_view( + array: &GenericByteViewArray, + predicate: &FilterPredicate, +) -> GenericByteViewArray { + let new_view_buffer = filter_native(array.views(), predicate); + + let mut builder = ArrayDataBuilder::new(T::DATA_TYPE) + .len(predicate.count) + .add_buffer(new_view_buffer) + .add_buffers(array.data_buffers().to_vec()); + + if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) { + builder = builder.null_count(null_count).null_bit_buffer(Some(nulls)); + } + + GenericByteViewArray::from(unsafe { builder.build_unchecked() }) +} + /// `filter` implementation for dictionaries fn filter_dict(array: &DictionaryArray, predicate: &FilterPredicate) -> DictionaryArray where From 70819a13049d88144581531e6deae200a7376a2d Mon Sep 17 00:00:00 2001 From: RinChanNOWWW Date: Thu, 11 Apr 2024 19:38:38 +0800 Subject: [PATCH 2/4] Add unit tests and fix. --- arrow-select/src/filter.rs | 65 +++++++++++++++++++++++++++++++++++++- 1 file changed, 64 insertions(+), 1 deletion(-) diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs index 0b46664322d4..8e06b07f5ef4 100644 --- a/arrow-select/src/filter.rs +++ b/arrow-select/src/filter.rs @@ -343,7 +343,7 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result(), predicate))) } DataType::BinaryView => { - Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate))) + Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate))) } DataType::RunEndEncoded(_, _) => { downcast_run_array!{ @@ -919,6 +919,69 @@ mod tests { assert!(d.is_null(1)); } + fn _test_filter_byte_view() + where + T: ByteViewType, + str: AsRef, + T::Native: PartialEq, + { + let array = { + // ["hello", "world", null, "large payload over 12 bytes", "lulu"] + let mut builder = GenericByteViewBuilder::::new(); + builder.append_value("hello"); + builder.append_value("world"); + builder.append_null(); + builder.append_value("large payload over 12 bytes"); + builder.append_value("lulu"); + builder.finish() + }; + + { + let predicate = BooleanArray::from(vec![true, false, true, true, false]); + let actual = filter(&array, &predicate).unwrap(); + + assert_eq!(actual.len(), 3); + + let expected = { + // ["hello", null, "large payload over 12 bytes"] + let mut builder = GenericByteViewBuilder::::new(); + builder.append_value("hello"); + builder.append_null(); + builder.append_value("large payload over 12 bytes"); + builder.finish() + }; + + assert_eq!(actual.as_ref(), &expected); + } + + { + let predicate = BooleanArray::from(vec![true, false, false, false, true]); + let actual = filter(&array, &predicate).unwrap(); + + assert_eq!(actual.len(), 2); + + let expected = { + // ["hello", "lulu"] + let mut builder = GenericByteViewBuilder::::new(); + builder.append_value("hello"); + builder.append_value("lulu"); + builder.finish() + }; + + assert_eq!(actual.as_ref(), &expected); + } + } + + #[test] + fn test_filter_string_view() { + _test_filter_byte_view::() + } + + #[test] + fn test_filter_binary_view() { + _test_filter_byte_view::() + } + #[test] fn test_filter_array_slice_with_null() { let a = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4); From ef4f44f1c996de1d06a2250e8ec22c9e9f88004a Mon Sep 17 00:00:00 2001 From: RinChanNOWWW Date: Sun, 14 Apr 2024 09:33:03 +0800 Subject: [PATCH 3/4] Deprecate `ArrowPrimitiveType::get_byte_width`. --- arrow-array/src/types.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/arrow-array/src/types.rs b/arrow-array/src/types.rs index 3ade1405d7bc..038b2a291f58 100644 --- a/arrow-array/src/types.rs +++ b/arrow-array/src/types.rs @@ -58,6 +58,12 @@ pub trait ArrowPrimitiveType: primitive::PrimitiveTypeSealed + 'static { /// the corresponding Arrow data type of this primitive type. const DATA_TYPE: DataType; + /// Returns the byte width of this primitive type. + #[deprecated(note = "Use ArrowNativeType::get_byte_width")] + fn get_byte_width() -> usize { + std::mem::size_of::() + } + /// Returns a default value of this primitive type. /// /// This is useful for aggregate array ops like `sum()`, `mean()`. From 9a6fd8139b136cb19d03046b323a89b080fa5025 Mon Sep 17 00:00:00 2001 From: RinChanNOWWW Date: Sun, 14 Apr 2024 10:11:56 +0800 Subject: [PATCH 4/4] Add string view filter benchmark. --- arrow/benches/filter_kernels.rs | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/arrow/benches/filter_kernels.rs b/arrow/benches/filter_kernels.rs index 50f3cb40094d..e48b5302241d 100644 --- a/arrow/benches/filter_kernels.rs +++ b/arrow/benches/filter_kernels.rs @@ -214,6 +214,32 @@ fn add_benchmark(c: &mut Criterion) { c.bench_function("filter single record batch", |b| { b.iter(|| filter_record_batch(&batch, &filter_array)) }); + + let data_array = create_string_view_array_with_len(size, 0.5, 4, false); + c.bench_function("filter context short string view (kept 1/2)", |b| { + b.iter(|| bench_built_filter(&filter, &data_array)) + }); + c.bench_function( + "filter context short string view high selectivity (kept 1023/1024)", + |b| b.iter(|| bench_built_filter(&dense_filter, &data_array)), + ); + c.bench_function( + "filter context short string view low selectivity (kept 1/1024)", + |b| b.iter(|| bench_built_filter(&sparse_filter, &data_array)), + ); + + let data_array = create_string_view_array_with_len(size, 0.5, 4, true); + c.bench_function("filter context mixed string view (kept 1/2)", |b| { + b.iter(|| bench_built_filter(&filter, &data_array)) + }); + c.bench_function( + "filter context mixed string view high selectivity (kept 1023/1024)", + |b| b.iter(|| bench_built_filter(&dense_filter, &data_array)), + ); + c.bench_function( + "filter context mixed string view low selectivity (kept 1/1024)", + |b| b.iter(|| bench_built_filter(&sparse_filter, &data_array)), + ); } criterion_group!(benches, add_benchmark);