Add UInt16BitmapFilter for UInt16 IN-list evaluation.

Replaces the macro-generated UInt16StaticFilter with an 8 KiB bitmap
(1024 x u64) filter, routes DataType::UInt16 to it in
instantiate_static_filter, and adds a boundary/null unit test.

diff --git a/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs b/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs
index e7647e5adb8b7..8242ba09bddc6 100644
--- a/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs
+++ b/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs
@@ -104,6 +104,81 @@ impl StaticFilter for UInt8BitmapFilter {
     }
 }
 
+/// Bitmap filter for O(1) `UInt16` set membership via single bit test.
+///
+/// `UInt16` has 65,536 possible values, so the filter stores membership in an
+/// 8 KiB heap-allocated bitmap instead of using a hash table.
+pub(super) struct UInt16BitmapFilter {
+    null_count: usize,
+    bits: Box<[u64; 1024]>,
+}
+
+impl UInt16BitmapFilter {
+    pub(super) fn try_new(in_array: &ArrayRef) -> Result<Self> {
+        let prim_array = in_array.as_primitive_opt::<UInt16Type>().ok_or_else(|| {
+            exec_datafusion_err!("UInt16BitmapFilter: expected UInt16 array")
+        })?;
+        let mut bits = Box::new([0u64; 1024]);
+        let mut set_bit = |v: u16| {
+            let index = usize::from(v);
+            bits[index / 64] |= 1u64 << (index % 64);
+        };
+
+        let values = prim_array.values();
+        match prim_array.nulls() {
+            None => {
+                for &v in values {
+                    set_bit(v);
+                }
+            }
+            Some(nulls) => {
+                for i in
+                    BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len())
+                {
+                    set_bit(values[i]);
+                }
+            }
+        }
+        Ok(Self {
+            null_count: prim_array.null_count(),
+            bits,
+        })
+    }
+
+    #[inline(always)]
+    fn check(&self, needle: u16) -> bool {
+        let index = needle as usize;
+        (self.bits[index / 64] >> (index % 64)) & 1 != 0
+    }
+}
+
+impl StaticFilter for UInt16BitmapFilter {
+    fn null_count(&self) -> usize {
+        self.null_count
+    }
+
+    fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
+        handle_dictionary!(self, v, negated);
+        let v = v.as_primitive_opt::<UInt16Type>().ok_or_else(|| {
+            exec_datafusion_err!("UInt16BitmapFilter: expected UInt16 array")
+        })?;
+        let input_values = v.values();
+        Ok(build_in_list_result(
+            v.len(),
+            v.nulls(),
+            self.null_count > 0,
+            negated,
+            #[inline(always)]
+            |i| {
+                // SAFETY: `build_in_list_result` invokes this closure for
+                // indices in `0..v.len()`, which matches `input_values.len()`.
+                let needle = unsafe { *input_values.get_unchecked(i) };
+                self.check(needle)
+            },
+        ))
+    }
+}
+
 /// Wrapper for f32 that implements Hash and Eq using bit comparison.
 /// This treats NaN values as equal to each other when they have the same bit pattern.
 #[derive(Clone, Copy)]
@@ -294,7 +369,6 @@ primitive_static_filter!(Int8StaticFilter, Int8Type);
 primitive_static_filter!(Int16StaticFilter, Int16Type);
 primitive_static_filter!(Int32StaticFilter, Int32Type);
 primitive_static_filter!(Int64StaticFilter, Int64Type);
-primitive_static_filter!(UInt16StaticFilter, UInt16Type);
 primitive_static_filter!(UInt32StaticFilter, UInt32Type);
 primitive_static_filter!(UInt64StaticFilter, UInt64Type);
 
@@ -315,10 +389,10 @@ mod tests {
     use super::*;
     use std::sync::Arc;
 
-    use arrow::array::{DictionaryArray, Int8Array, UInt8Array};
+    use arrow::array::{DictionaryArray, Int8Array, UInt8Array, UInt16Array};
 
     fn assert_contains(
-        filter: &UInt8BitmapFilter,
+        filter: &dyn StaticFilter,
         needles: &dyn Array,
         expected: Vec<Option<bool>>,
     ) -> Result<()> {
@@ -355,4 +429,29 @@ mod tests {
 
         assert_contains(&filter, &needles, vec![Some(true), None, None, Some(true)])
     }
+
+    #[test]
+    fn bitmap_filter_u16_handles_boundaries_and_nulls() -> Result<()> {
+        let haystack: ArrayRef = Arc::new(UInt16Array::from(vec![
+            Some(0),
+            None,
+            Some(1024),
+            Some(u16::MAX),
+        ]));
+        let filter = UInt16BitmapFilter::try_new(&haystack)?;
+        let needles =
+            UInt16Array::from(vec![Some(0), Some(1), Some(1024), Some(u16::MAX), None]);
+
+        assert_contains(
+            &filter,
+            &needles,
+            vec![Some(true), None, Some(true), Some(true), None],
+        )?;
+        assert_eq!(
+            filter.contains(&needles, true)?,
+            BooleanArray::from(vec![Some(false), None, Some(false), Some(false), None])
+        );
+
+        Ok(())
+    }
 }
diff --git a/datafusion/physical-expr/src/expressions/in_list/strategy.rs b/datafusion/physical-expr/src/expressions/in_list/strategy.rs
index 1fb8e03fe2040..aec94bddb920b 100644
--- a/datafusion/physical-expr/src/expressions/in_list/strategy.rs
+++ b/datafusion/physical-expr/src/expressions/in_list/strategy.rs
@@ -43,7 +43,7 @@ pub(super) fn instantiate_static_filter(
         DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)),
         DataType::Int64 => Ok(Arc::new(Int64StaticFilter::try_new(&in_array)?)),
         DataType::UInt8 => Ok(Arc::new(UInt8BitmapFilter::try_new(&in_array)?)),
-        DataType::UInt16 => Ok(Arc::new(UInt16StaticFilter::try_new(&in_array)?)),
+        DataType::UInt16 => Ok(Arc::new(UInt16BitmapFilter::try_new(&in_array)?)),
         DataType::UInt32 => Ok(Arc::new(UInt32StaticFilter::try_new(&in_array)?)),
         DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)),
         // Float primitive types (use ordered wrappers for Hash/Eq)