Skip to content

Commit e88e5aa

Browse files
authored
Implement filter kernel for byte view arrays. (#5624)
* Implement `filter` kernel for byte view arrays.
* Add unit tests and fix.
* Deprecate `ArrowPrimitiveType::get_byte_width`.
* Add string view filter benchmark.
1 parent fee6921 commit e88e5aa

File tree

6 files changed

+142
-10
lines changed

6 files changed

+142
-10
lines changed

arrow-arith/src/arity.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use arrow_array::builder::BufferBuilder;
2121
use arrow_array::types::ArrowDictionaryKeyType;
2222
use arrow_array::*;
2323
use arrow_buffer::buffer::NullBuffer;
24+
use arrow_buffer::ArrowNativeType;
2425
use arrow_buffer::{Buffer, MutableBuffer};
2526
use arrow_data::ArrayData;
2627
use arrow_schema::ArrowError;
@@ -386,7 +387,7 @@ where
386387
O: ArrowPrimitiveType,
387388
F: Fn(A::Item, B::Item) -> Result<O::Native, ArrowError>,
388389
{
389-
let mut buffer = MutableBuffer::new(len * O::get_byte_width());
390+
let mut buffer = MutableBuffer::new(len * O::Native::get_byte_width());
390391
for idx in 0..len {
391392
unsafe {
392393
buffer.push_unchecked(op(a.value_unchecked(idx), b.value_unchecked(idx))?);

arrow-array/src/types.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ pub trait ArrowPrimitiveType: primitive::PrimitiveTypeSealed + 'static {
5959
const DATA_TYPE: DataType;
6060

6161
/// Returns the byte width of this primitive type.
///
/// This is the size of `Self::Native` as reported by `std::mem::size_of`.
/// Deprecated: call `ArrowNativeType::get_byte_width` on the native type instead.
62+
#[deprecated(note = "Use ArrowNativeType::get_byte_width")]
6263
fn get_byte_width() -> usize {
6364
std::mem::size_of::<Self::Native>()
6465
}

arrow-buffer/src/native.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ mod private {
4747
pub trait ArrowNativeType:
4848
std::fmt::Debug + Send + Sync + Copy + PartialOrd + Default + private::Sealed + 'static
4949
{
50+
/// Returns the byte width of this native type.
///
/// The default implementation reports `std::mem::size_of::<Self>()`.
51+
fn get_byte_width() -> usize {
52+
std::mem::size_of::<Self>()
53+
}
54+
5055
/// Convert native integer type from usize
5156
///
5257
/// Returns `None` if [`Self`] is not an integer or conversion would result

arrow-data/src/data.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1770,6 +1770,11 @@ impl ArrayDataBuilder {
17701770
self
17711771
}
17721772

1773+
/// Appends every buffer in `bs`, in order, after the buffers already held by
/// this builder, and returns the builder so calls can be chained.
pub fn add_buffers(mut self, bs: Vec<Buffer>) -> Self {
1774+
self.buffers.extend(bs);
1775+
self
1776+
}
1777+
17731778
pub fn child_data(mut self, v: Vec<ArrayData>) -> Self {
17741779
self.child_data = v;
17751780
self

arrow-select/src/filter.rs

Lines changed: 103 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ use std::sync::Arc;
2323
use arrow_array::builder::BooleanBufferBuilder;
2424
use arrow_array::cast::AsArray;
2525
use arrow_array::types::{
26-
ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, RunEndIndexType,
26+
ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
2727
};
2828
use arrow_array::*;
29-
use arrow_buffer::{bit_util, BooleanBuffer, NullBuffer, RunEndBuffer};
29+
use arrow_buffer::{bit_util, ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer};
3030
use arrow_buffer::{Buffer, MutableBuffer};
3131
use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
3232
use arrow_data::transform::MutableArrayData;
@@ -333,12 +333,18 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<Array
333333
DataType::LargeUtf8 => {
334334
Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate)))
335335
}
336+
DataType::Utf8View => {
337+
Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate)))
338+
}
336339
DataType::Binary => {
337340
Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate)))
338341
}
339342
DataType::LargeBinary => {
340343
Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate)))
341344
}
345+
DataType::BinaryView => {
346+
Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate)))
347+
}
342348
DataType::RunEndEncoded(_, _) => {
343349
downcast_run_array!{
344350
values => Ok(Arc::new(filter_run_end_array(values, predicate)?)),
@@ -508,12 +514,8 @@ fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanA
508514
BooleanArray::from(data)
509515
}
510516

511-
/// `filter` implementation for primitive arrays
512-
fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
513-
where
514-
T: ArrowPrimitiveType,
515-
{
516-
let values = array.values();
517+
#[inline(never)]
518+
fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate) -> Buffer {
517519
assert!(values.len() >= predicate.filter.len());
518520

519521
let buffer = match &predicate.strategy {
@@ -546,9 +548,19 @@ where
546548
IterationStrategy::All | IterationStrategy::None => unreachable!(),
547549
};
548550

551+
buffer.into()
552+
}
553+
554+
/// `filter` implementation for primitive arrays
555+
fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
556+
where
557+
T: ArrowPrimitiveType,
558+
{
559+
let values = array.values();
560+
let buffer = filter_native(values, predicate);
549561
let mut builder = ArrayDataBuilder::new(array.data_type().clone())
550562
.len(predicate.count)
551-
.add_buffer(buffer.into());
563+
.add_buffer(buffer);
552564

553565
if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
554566
builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
@@ -673,6 +685,25 @@ where
673685
GenericByteArray::from(data)
674686
}
675687

688+
/// `filter` implementation for byte view arrays.
///
/// Only the view entries are filtered (through `filter_native`); the data
/// buffers of `array` are attached to the result as returned by
/// `array.data_buffers()`, without being filtered or compacted here. The
/// result has length `predicate.count`, and its null mask, if any, comes
/// from `filter_null_mask`.
689+
fn filter_byte_view<T: ByteViewType>(
690+
array: &GenericByteViewArray<T>,
691+
predicate: &FilterPredicate,
692+
) -> GenericByteViewArray<T> {
693+
// Select the surviving views; this becomes buffer 0 of the output.
let new_view_buffer = filter_native(array.views(), predicate);
694+
695+
let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
696+
.len(predicate.count)
697+
.add_buffer(new_view_buffer)
698+
// Carry the input's data buffers over in their original order.
// NOTE(review): presumably the kept views index into these by position,
// so the order must not change — confirm against the view layout.
.add_buffers(array.data_buffers().to_vec());
699+
700+
if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
701+
builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
702+
}
703+
704+
// NOTE(review): `build_unchecked` is called without a stated safety argument.
// Presumably it is sound because the output reuses views and data buffers
// taken from an already-valid array — confirm and record as a SAFETY comment.
GenericByteViewArray::from(unsafe { builder.build_unchecked() })
705+
}
706+
676707
/// `filter` implementation for dictionaries
677708
fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T>
678709
where
@@ -888,6 +919,69 @@ mod tests {
888919
assert!(d.is_null(1));
889920
}
890921

922+
// Shared driver for the byte view filter tests, generic over `ByteViewType`.
// Builds a five-element array containing one null and one value longer than
// 12 bytes, filters it with two predicates, and compares each result with an
// array built directly from the expected survivors.
fn _test_filter_byte_view<T>()
923+
where
924+
T: ByteViewType,
925+
str: AsRef<T::Native>,
926+
T::Native: PartialEq,
927+
{
928+
let array = {
929+
// ["hello", "world", null, "large payload over 12 bytes", "lulu"]
930+
let mut builder = GenericByteViewBuilder::<T>::new();
931+
builder.append_value("hello");
932+
builder.append_value("world");
933+
builder.append_null();
934+
builder.append_value("large payload over 12 bytes");
935+
builder.append_value("lulu");
936+
builder.finish()
937+
};
938+
939+
{
940+
// Case 1: keep indices 0, 2 and 3, so the null and the long value both survive.
let predicate = BooleanArray::from(vec![true, false, true, true, false]);
941+
let actual = filter(&array, &predicate).unwrap();
942+
943+
assert_eq!(actual.len(), 3);
944+
945+
let expected = {
946+
// ["hello", null, "large payload over 12 bytes"]
947+
let mut builder = GenericByteViewBuilder::<T>::new();
948+
builder.append_value("hello");
949+
builder.append_null();
950+
builder.append_value("large payload over 12 bytes");
951+
builder.finish()
952+
};
953+
954+
assert_eq!(actual.as_ref(), &expected);
955+
}
956+
957+
{
958+
// Case 2: keep only the first and last elements; the null is dropped.
let predicate = BooleanArray::from(vec![true, false, false, false, true]);
959+
let actual = filter(&array, &predicate).unwrap();
960+
961+
assert_eq!(actual.len(), 2);
962+
963+
let expected = {
964+
// ["hello", "lulu"]
965+
let mut builder = GenericByteViewBuilder::<T>::new();
966+
builder.append_value("hello");
967+
builder.append_value("lulu");
968+
builder.finish()
969+
};
970+
971+
assert_eq!(actual.as_ref(), &expected);
972+
}
973+
}
974+
975+
// Runs the shared byte view filter checks with `StringViewType`.
#[test]
976+
fn test_filter_string_view() {
977+
_test_filter_byte_view::<StringViewType>()
978+
}
979+
980+
// Runs the shared byte view filter checks with `BinaryViewType`.
#[test]
981+
fn test_filter_binary_view() {
982+
_test_filter_byte_view::<BinaryViewType>()
983+
}
984+
891985
#[test]
892986
fn test_filter_array_slice_with_null() {
893987
let a = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);

arrow/benches/filter_kernels.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,32 @@ fn add_benchmark(c: &mut Criterion) {
214214
c.bench_function("filter single record batch", |b| {
215215
b.iter(|| filter_record_batch(&batch, &filter_array))
216216
});
217+
218+
let data_array = create_string_view_array_with_len(size, 0.5, 4, false);
219+
c.bench_function("filter context short string view (kept 1/2)", |b| {
220+
b.iter(|| bench_built_filter(&filter, &data_array))
221+
});
222+
c.bench_function(
223+
"filter context short string view high selectivity (kept 1023/1024)",
224+
|b| b.iter(|| bench_built_filter(&dense_filter, &data_array)),
225+
);
226+
c.bench_function(
227+
"filter context short string view low selectivity (kept 1/1024)",
228+
|b| b.iter(|| bench_built_filter(&sparse_filter, &data_array)),
229+
);
230+
231+
let data_array = create_string_view_array_with_len(size, 0.5, 4, true);
232+
c.bench_function("filter context mixed string view (kept 1/2)", |b| {
233+
b.iter(|| bench_built_filter(&filter, &data_array))
234+
});
235+
c.bench_function(
236+
"filter context mixed string view high selectivity (kept 1023/1024)",
237+
|b| b.iter(|| bench_built_filter(&dense_filter, &data_array)),
238+
);
239+
c.bench_function(
240+
"filter context mixed string view low selectivity (kept 1/1024)",
241+
|b| b.iter(|| bench_built_filter(&sparse_filter, &data_array)),
242+
);
217243
}
218244

219245
criterion_group!(benches, add_benchmark);

0 commit comments

Comments
 (0)