diff --git a/arrow-arith/src/arity.rs b/arrow-arith/src/arity.rs index 3d8214d89dc3..ff5c8e822cc0 100644 --- a/arrow-arith/src/arity.rs +++ b/arrow-arith/src/arity.rs @@ -21,6 +21,7 @@ use arrow_array::builder::BufferBuilder; use arrow_array::types::ArrowDictionaryKeyType; use arrow_array::*; use arrow_buffer::buffer::NullBuffer; +use arrow_buffer::ArrowNativeType; use arrow_buffer::{Buffer, MutableBuffer}; use arrow_data::ArrayData; use arrow_schema::ArrowError; @@ -386,7 +387,7 @@ where O: ArrowPrimitiveType, F: Fn(A::Item, B::Item) -> Result, { - let mut buffer = MutableBuffer::new(len * O::get_byte_width()); + let mut buffer = MutableBuffer::new(len * O::Native::get_byte_width()); for idx in 0..len { unsafe { buffer.push_unchecked(op(a.value_unchecked(idx), b.value_unchecked(idx))?); diff --git a/arrow-array/src/types.rs b/arrow-array/src/types.rs index e33f7bde7cba..038b2a291f58 100644 --- a/arrow-array/src/types.rs +++ b/arrow-array/src/types.rs @@ -59,6 +59,7 @@ pub trait ArrowPrimitiveType: primitive::PrimitiveTypeSealed + 'static { const DATA_TYPE: DataType; /// Returns the byte width of this primitive type. + #[deprecated(note = "Use ArrowNativeType::get_byte_width")] fn get_byte_width() -> usize { std::mem::size_of::() } diff --git a/arrow-buffer/src/native.rs b/arrow-buffer/src/native.rs index 680974351a4b..de665d4e3874 100644 --- a/arrow-buffer/src/native.rs +++ b/arrow-buffer/src/native.rs @@ -47,6 +47,11 @@ mod private { pub trait ArrowNativeType: std::fmt::Debug + Send + Sync + Copy + PartialOrd + Default + private::Sealed + 'static { + /// Returns the byte width of this native type. + fn get_byte_width() -> usize { + std::mem::size_of::() + } + /// Convert native integer type from usize /// /// Returns `None` if [`Self`] is not an integer or conversion would result diff --git a/arrow-data/src/data.rs b/arrow-data/src/data.rs index e227b168eee5..b61decc1629c 100644 --- a/arrow-data/src/data.rs +++ b/arrow-data/src/data.rs @@ -1761,6 +1761,11 @@ impl ArrayDataBuilder { self } + pub fn add_buffers(mut self, bs: Vec) -> Self { + self.buffers.extend(bs); + self + } + pub fn child_data(mut self, v: Vec) -> Self { self.child_data = v; self diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs index 2af19ff85056..8e06b07f5ef4 100644 --- a/arrow-select/src/filter.rs +++ b/arrow-select/src/filter.rs @@ -23,10 +23,10 @@ use std::sync::Arc; use arrow_array::builder::BooleanBufferBuilder; use arrow_array::cast::AsArray; use arrow_array::types::{ - ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, RunEndIndexType, + ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType, }; use arrow_array::*; -use arrow_buffer::{bit_util, BooleanBuffer, NullBuffer, RunEndBuffer}; +use arrow_buffer::{bit_util, ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer}; use arrow_buffer::{Buffer, MutableBuffer}; use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator}; use arrow_data::transform::MutableArrayData; @@ -333,12 +333,18 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result { Ok(Arc::new(filter_bytes(values.as_string::(), predicate))) } + DataType::Utf8View => { + Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate))) + } DataType::Binary => { Ok(Arc::new(filter_bytes(values.as_binary::(), predicate))) } DataType::LargeBinary => { Ok(Arc::new(filter_bytes(values.as_binary::(), predicate))) } + DataType::BinaryView => { + Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate))) + } DataType::RunEndEncoded(_, _) => { downcast_run_array!{ values => Ok(Arc::new(filter_run_end_array(values, predicate)?)), @@ -508,12 +514,8 @@ fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanA BooleanArray::from(data) } -/// `filter` implementation for primitive arrays -fn filter_primitive(array: &PrimitiveArray, predicate: &FilterPredicate) -> PrimitiveArray -where - T: ArrowPrimitiveType, -{ - let values = array.values(); +#[inline(never)] +fn filter_native(values: &[T], predicate: &FilterPredicate) -> Buffer { assert!(values.len() >= predicate.filter.len()); let buffer = match &predicate.strategy { @@ -546,9 +548,19 @@ where IterationStrategy::All | IterationStrategy::None => unreachable!(), }; + buffer.into() +} + +/// `filter` implementation for primitive arrays +fn filter_primitive(array: &PrimitiveArray, predicate: &FilterPredicate) -> PrimitiveArray +where + T: ArrowPrimitiveType, +{ + let values = array.values(); + let buffer = filter_native(values, predicate); let mut builder = ArrayDataBuilder::new(array.data_type().clone()) .len(predicate.count) - .add_buffer(buffer.into()); + .add_buffer(buffer); if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) { builder = builder.null_count(null_count).null_bit_buffer(Some(nulls)); @@ -673,6 +685,25 @@ where GenericByteArray::from(data) } +/// `filter` implementation for byte view arrays. +fn filter_byte_view( + array: &GenericByteViewArray, + predicate: &FilterPredicate, +) -> GenericByteViewArray { + let new_view_buffer = filter_native(array.views(), predicate); + + let mut builder = ArrayDataBuilder::new(T::DATA_TYPE) + .len(predicate.count) + .add_buffer(new_view_buffer) + .add_buffers(array.data_buffers().to_vec()); + + if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) { + builder = builder.null_count(null_count).null_bit_buffer(Some(nulls)); + } + + GenericByteViewArray::from(unsafe { builder.build_unchecked() }) +} + /// `filter` implementation for dictionaries fn filter_dict(array: &DictionaryArray, predicate: &FilterPredicate) -> DictionaryArray where @@ -888,6 +919,69 @@ mod tests { assert!(d.is_null(1)); } + fn _test_filter_byte_view() + where + T: ByteViewType, + str: AsRef, + T::Native: PartialEq, + { + let array = { + // ["hello", "world", null, "large payload over 12 bytes", "lulu"] + let mut builder = GenericByteViewBuilder::::new(); + builder.append_value("hello"); + builder.append_value("world"); + builder.append_null(); + builder.append_value("large payload over 12 bytes"); + builder.append_value("lulu"); + builder.finish() + }; + + { + let predicate = BooleanArray::from(vec![true, false, true, true, false]); + let actual = filter(&array, &predicate).unwrap(); + + assert_eq!(actual.len(), 3); + + let expected = { + // ["hello", null, "large payload over 12 bytes"] + let mut builder = GenericByteViewBuilder::::new(); + builder.append_value("hello"); + builder.append_null(); + builder.append_value("large payload over 12 bytes"); + builder.finish() + }; + + assert_eq!(actual.as_ref(), &expected); + } + + { + let predicate = BooleanArray::from(vec![true, false, false, false, true]); + let actual = filter(&array, &predicate).unwrap(); + + assert_eq!(actual.len(), 2); + + let expected = { + // ["hello", "lulu"] + let mut builder = GenericByteViewBuilder::::new(); + builder.append_value("hello"); + builder.append_value("lulu"); + builder.finish() + }; + + assert_eq!(actual.as_ref(), &expected); + } + } + + #[test] + fn test_filter_string_view() { + _test_filter_byte_view::() + } + + #[test] + fn test_filter_binary_view() { + _test_filter_byte_view::() + } + #[test] fn test_filter_array_slice_with_null() { let a = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4); diff --git a/arrow/benches/filter_kernels.rs b/arrow/benches/filter_kernels.rs index 50f3cb40094d..e48b5302241d 100644 --- a/arrow/benches/filter_kernels.rs +++ b/arrow/benches/filter_kernels.rs @@ -214,6 +214,32 @@ fn add_benchmark(c: &mut Criterion) { c.bench_function("filter single record batch", |b| { b.iter(|| filter_record_batch(&batch, &filter_array)) }); + + let data_array = create_string_view_array_with_len(size, 0.5, 4, false); + c.bench_function("filter context short string view (kept 1/2)", |b| { + b.iter(|| bench_built_filter(&filter, &data_array)) + }); + c.bench_function( + "filter context short string view high selectivity (kept 1023/1024)", + |b| b.iter(|| bench_built_filter(&dense_filter, &data_array)), + ); + c.bench_function( + "filter context short string view low selectivity (kept 1/1024)", + |b| b.iter(|| bench_built_filter(&sparse_filter, &data_array)), + ); + + let data_array = create_string_view_array_with_len(size, 0.5, 4, true); + c.bench_function("filter context mixed string view (kept 1/2)", |b| { + b.iter(|| bench_built_filter(&filter, &data_array)) + }); + c.bench_function( + "filter context mixed string view high selectivity (kept 1023/1024)", + |b| b.iter(|| bench_built_filter(&dense_filter, &data_array)), + ); + c.bench_function( + "filter context mixed string view low selectivity (kept 1/1024)", + |b| b.iter(|| bench_built_filter(&sparse_filter, &data_array)), + ); } criterion_group!(benches, add_benchmark);