diff --git a/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs b/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs index 2c084a1cb247b..e7647e5adb8b7 100644 --- a/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs +++ b/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs @@ -15,16 +15,94 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{ - Array, ArrayRef, AsArray, BooleanArray, downcast_array, downcast_dictionary_array, -}; +//! Optimized primitive type filters for InList expressions. +//! +//! This module provides membership tests for Arrow primitive types. + +use arrow::array::{Array, ArrayRef, AsArray, BooleanArray}; use arrow::buffer::{BooleanBuffer, NullBuffer}; -use arrow::compute::take; use arrow::datatypes::*; +use arrow::util::bit_iterator::BitIndexIterator; use datafusion_common::{HashSet, Result, exec_datafusion_err}; use std::hash::{Hash, Hasher}; -use super::static_filter::StaticFilter; +use super::result::build_in_list_result; +use super::static_filter::{StaticFilter, handle_dictionary}; + +/// Bitmap filter for O(1) set membership via single bit test. +/// +/// `UInt8` has only 256 possible values, so the filter stores membership in a +/// 256-bit bitmap instead of using a hash table. 
+pub(super) struct UInt8BitmapFilter {
+    /// Number of null entries in the IN-list array the filter was built from.
+    null_count: usize,
+    /// Membership bitmap covering all 256 `u8` values: bit `v % 64` of word
+    /// `v / 64` is set iff `v` appeared as a non-null IN-list entry.
+    bits: [u64; 4],
+}
+
+impl UInt8BitmapFilter {
+    /// Builds the bitmap from the IN-list array.
+    ///
+    /// Only valid (non-null) slots contribute bits; the number of null slots
+    /// is recorded separately in `null_count`.
+    ///
+    /// # Errors
+    ///
+    /// Returns an execution error if `in_array` is not a `UInt8` primitive
+    /// array.
+    pub(super) fn try_new(in_array: &ArrayRef) -> Result<Self> {
+        let prim_array = in_array.as_primitive_opt::<UInt8Type>().ok_or_else(|| {
+            exec_datafusion_err!("UInt8BitmapFilter: expected UInt8 array")
+        })?;
+        let mut bits = [0u64; 4];
+        let mut set_bit = |v: u8| {
+            let (word, bit) = Self::bit_position(v);
+            bits[word] |= 1u64 << bit;
+        };
+
+        let values = prim_array.values();
+        match prim_array.nulls() {
+            // No validity buffer: every slot holds a real value.
+            None => {
+                for &v in values {
+                    set_bit(v);
+                }
+            }
+            // Visit only the set (valid) bits of the validity buffer so the
+            // placeholder values stored in null slots never enter the bitmap.
+            Some(nulls) => {
+                for i in
+                    BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len())
+                {
+                    set_bit(values[i]);
+                }
+            }
+        }
+        Ok(Self {
+            null_count: prim_array.null_count(),
+            bits,
+        })
+    }
+
+    /// Splits a value into its `(word, bit)` position inside `bits`.
+    ///
+    /// Construction and lookup both go through this helper so they cannot
+    /// disagree on the bitmap layout. `usize::from` is a lossless widening,
+    /// and since a `u8` is at most 255 the word index is always in `0..4`.
+    #[inline(always)]
+    fn bit_position(value: u8) -> (usize, usize) {
+        let index = usize::from(value);
+        (index / 64, index % 64)
+    }
+
+    /// Returns `true` if `needle` was a non-null entry of the IN list.
+    #[inline(always)]
+    fn check(&self, needle: u8) -> bool {
+        let (word, bit) = Self::bit_position(needle);
+        (self.bits[word] >> bit) & 1 != 0
+    }
+}
+
+impl StaticFilter for UInt8BitmapFilter {
+    /// Number of nulls in the IN list this filter was built from.
+    fn null_count(&self) -> usize {
+        self.null_count
+    }
+
+    /// Tests every element of `v` for membership in the IN list.
+    ///
+    /// Dictionary-encoded input is handled by `handle_dictionary!`, which
+    /// returns early. Otherwise `v` must be a `UInt8` primitive array; the
+    /// result is assembled by `build_in_list_result` from the input's null
+    /// buffer, whether the IN list contained nulls, the `negated` flag and a
+    /// per-index membership probe.
+    ///
+    /// # Errors
+    ///
+    /// Returns an execution error if `v` is neither a dictionary array nor a
+    /// `UInt8` primitive array.
+    fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
+        handle_dictionary!(self, v, negated);
+        let v = v.as_primitive_opt::<UInt8Type>().ok_or_else(|| {
+            exec_datafusion_err!("UInt8BitmapFilter: expected UInt8 array")
+        })?;
+        let input_values = v.values();
+        Ok(build_in_list_result(
+            v.len(),
+            v.nulls(),
+            self.null_count > 0,
+            negated,
+            #[inline(always)]
+            |i| {
+                // Debug-only guard for the unchecked read below; compiled out
+                // of release builds, so the hot path is unchanged.
+                debug_assert!(i < input_values.len());
+                // SAFETY: `build_in_list_result` invokes this closure for
+                // indices in `0..v.len()`, which matches `input_values.len()`.
+                // NOTE(review): this relies on the helper's contract, which is
+                // defined in the sibling `result` module — confirm there.
+                let needle = unsafe { *input_values.get_unchecked(i) };
+                self.check(needle)
+            },
+        ))
+    }
+}
 
 /// Wrapper for f32 that implements Hash and Eq using bit comparison.
 /// This treats NaN values as equal to each other when they have the same bit pattern.
@@ -94,9 +172,13 @@ macro_rules! 
primitive_static_filter { impl $Name { pub(super) fn try_new(in_array: &ArrayRef) -> Result { - let in_array = in_array - .as_primitive_opt::<$ArrowType>() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; + let in_array = + in_array.as_primitive_opt::<$ArrowType>().ok_or_else(|| { + exec_datafusion_err!( + "Failed to downcast an array to a '{}' array", + stringify!($ArrowType) + ) + })?; let mut values = HashSet::with_capacity(in_array.len()); let null_count = in_array.null_count(); @@ -115,19 +197,14 @@ macro_rules! primitive_static_filter { } fn contains(&self, v: &dyn Array, negated: bool) -> Result { - // Handle dictionary arrays by recursing on the values - downcast_dictionary_array! { - v => { - let values_contains = self.contains(v.values().as_ref(), negated)?; - let result = take(&values_contains, v.keys(), None)?; - return Ok(downcast_array(result.as_ref())) - } - _ => {} - } + handle_dictionary!(self, v, negated); - let v = v - .as_primitive_opt::<$ArrowType>() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; + let v = v.as_primitive_opt::<$ArrowType>().ok_or_else(|| { + exec_datafusion_err!( + "Failed to downcast an array to a '{}' array", + stringify!($ArrowType) + ) + })?; let haystack_has_nulls = self.null_count > 0; let needle_values = v.values(); @@ -188,8 +265,10 @@ macro_rules! 
primitive_static_filter { } (true, true) => { // Both have nulls - combine needle nulls with haystack-induced nulls - let needle_validity = needle_nulls.map(|n| n.inner().clone()) - .unwrap_or_else(|| BooleanBuffer::new_set(needle_values.len())); + let needle_validity = + needle_nulls.map(|n| n.inner().clone()).unwrap_or_else( + || BooleanBuffer::new_set(needle_values.len()), + ); // Valid when original "in set" is true (see above) let haystack_validity = if negated { @@ -215,7 +294,6 @@ primitive_static_filter!(Int8StaticFilter, Int8Type); primitive_static_filter!(Int16StaticFilter, Int16Type); primitive_static_filter!(Int32StaticFilter, Int32Type); primitive_static_filter!(Int64StaticFilter, Int64Type); -primitive_static_filter!(UInt8StaticFilter, UInt8Type); primitive_static_filter!(UInt16StaticFilter, UInt16Type); primitive_static_filter!(UInt32StaticFilter, UInt32Type); primitive_static_filter!(UInt64StaticFilter, UInt64Type); @@ -231,3 +309,50 @@ macro_rules! float_static_filter { // Generate specialized filters for float types using ordered wrappers float_static_filter!(Float32StaticFilter, Float32Type, OrderedFloat32); float_static_filter!(Float64StaticFilter, Float64Type, OrderedFloat64); + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + use arrow::array::{DictionaryArray, Int8Array, UInt8Array}; + + fn assert_contains( + filter: &UInt8BitmapFilter, + needles: &dyn Array, + expected: Vec>, + ) -> Result<()> { + assert_eq!( + filter.contains(needles, false)?, + BooleanArray::from(expected) + ); + Ok(()) + } + + #[test] + fn bitmap_filter_u8_handles_nulls() -> Result<()> { + let haystack: ArrayRef = Arc::new(UInt8Array::from(vec![Some(1), None, Some(3)])); + let filter = UInt8BitmapFilter::try_new(&haystack)?; + let needles = UInt8Array::from(vec![Some(1), Some(2), None, Some(3)]); + + assert_contains(&filter, &needles, vec![Some(true), None, None, Some(true)])?; + assert_eq!( + filter.contains(&needles, true)?, + 
BooleanArray::from(vec![Some(false), None, None, Some(false)]) + ); + + Ok(()) + } + + #[test] + fn bitmap_filter_u8_handles_dictionary_needles() -> Result<()> { + let haystack: ArrayRef = Arc::new(UInt8Array::from(vec![Some(1), None, Some(3)])); + let filter = UInt8BitmapFilter::try_new(&haystack)?; + + let keys = Int8Array::from(vec![Some(0), Some(1), None, Some(2)]); + let values = Arc::new(UInt8Array::from(vec![Some(1), Some(2), Some(3)])); + let needles = DictionaryArray::try_new(keys, values)?; + + assert_contains(&filter, &needles, vec![Some(true), None, None, Some(true)]) + } +} diff --git a/datafusion/physical-expr/src/expressions/in_list/static_filter.rs b/datafusion/physical-expr/src/expressions/in_list/static_filter.rs index 218bd27950266..3c964d4183474 100644 --- a/datafusion/physical-expr/src/expressions/in_list/static_filter.rs +++ b/datafusion/physical-expr/src/expressions/in_list/static_filter.rs @@ -35,3 +35,20 @@ pub(super) trait StaticFilter { /// implementation unwraps the dictionary and operates on its values. fn contains(&self, v: &dyn Array, negated: bool) -> Result; } + +/// Evaluate dictionary-encoded needles by applying a filter to dictionary +/// values and remapping the result through the keys. +macro_rules! handle_dictionary { + ($self:ident, $v:ident, $negated:ident) => { + arrow::array::downcast_dictionary_array! 
{ + $v => { + let values_contains = $self.contains($v.values().as_ref(), $negated)?; + let result = arrow::compute::take(&values_contains, $v.keys(), None)?; + return Ok(arrow::array::downcast_array(result.as_ref())) + } + _ => {} + } + }; +} + +pub(super) use handle_dictionary; diff --git a/datafusion/physical-expr/src/expressions/in_list/strategy.rs b/datafusion/physical-expr/src/expressions/in_list/strategy.rs index b7ee3dd1a3b9d..1fb8e03fe2040 100644 --- a/datafusion/physical-expr/src/expressions/in_list/strategy.rs +++ b/datafusion/physical-expr/src/expressions/in_list/strategy.rs @@ -42,7 +42,7 @@ pub(super) fn instantiate_static_filter( DataType::Int16 => Ok(Arc::new(Int16StaticFilter::try_new(&in_array)?)), DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)), DataType::Int64 => Ok(Arc::new(Int64StaticFilter::try_new(&in_array)?)), - DataType::UInt8 => Ok(Arc::new(UInt8StaticFilter::try_new(&in_array)?)), + DataType::UInt8 => Ok(Arc::new(UInt8BitmapFilter::try_new(&in_array)?)), DataType::UInt16 => Ok(Arc::new(UInt16StaticFilter::try_new(&in_array)?)), DataType::UInt32 => Ok(Arc::new(UInt32StaticFilter::try_new(&in_array)?)), DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)),