Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 148 additions & 23 deletions datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,94 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::{
Array, ArrayRef, AsArray, BooleanArray, downcast_array, downcast_dictionary_array,
};
//! Optimized primitive type filters for InList expressions.
//!
//! This module provides membership tests for Arrow primitive types.

use arrow::array::{Array, ArrayRef, AsArray, BooleanArray};
use arrow::buffer::{BooleanBuffer, NullBuffer};
use arrow::compute::take;
use arrow::datatypes::*;
use arrow::util::bit_iterator::BitIndexIterator;
use datafusion_common::{HashSet, Result, exec_datafusion_err};
use std::hash::{Hash, Hasher};

use super::static_filter::StaticFilter;
use super::result::build_in_list_result;
use super::static_filter::{StaticFilter, handle_dictionary};

/// Bitmap filter for O(1) set membership via single bit test.
///
/// `UInt8` has only 256 possible values, so the filter stores membership in a
/// 256-bit bitmap instead of using a hash table.
///
/// The bitmap is laid out as four `u64` words: value `v` is tracked by bit
/// `v % 64` of word `v / 64`.
pub(super) struct UInt8BitmapFilter {
// Null count of the IN-list array the filter was built from; reported via
// `StaticFilter::null_count` and used by `contains` to tell whether the
// list held any nulls.
null_count: usize,
// Membership bitmap (4 x 64 = 256 bits); a set bit means the corresponding
// `u8` value appeared as a non-null entry of the IN-list array.
bits: [u64; 4],
}

impl UInt8BitmapFilter {
pub(super) fn try_new(in_array: &ArrayRef) -> Result<Self> {
let prim_array = in_array.as_primitive_opt::<UInt8Type>().ok_or_else(|| {
exec_datafusion_err!("UInt8BitmapFilter: expected UInt8 array")
})?;
let mut bits = [0u64; 4];
let mut set_bit = |v: u8| {
let index = usize::from(v);
bits[index / 64] |= 1u64 << (index % 64);
};

let values = prim_array.values();
match prim_array.nulls() {
None => {
for &v in values {

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

values.iter().copied().for_each(set_bit) is shorter, but less legible. Both should compile the same.

set_bit(v);
}
}
Some(nulls) => {
for i in
BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len())
{
set_bit(values[i]);
}
}
}
Ok(Self {
null_count: prim_array.null_count(),
bits,
})
}

#[inline(always)]
fn check(&self, needle: u8) -> bool {
let index = needle as usize;
(self.bits[index / 64] >> (index % 64)) & 1 != 0

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be interesting to know if LLVM could optimize away the bounds check here (it knows that index < 256, so index / 64 is always in bounds).

@geoffreyclaude geoffreyclaude Jun 23, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Confirmed (with some Codex help): rustc/LLVM does optimize this bounds check away in optimized builds.

Godbolt repro: https://godbolt.org/z/6r1WjvYv4

I used LLVM IR there because retained Rust bounds checks are easy to spot: they show up as a branch to a panic block calling panic_bounds_check.

For the u8 versions (bitmap_check / bitmap_set), there is no such branch. The IR widens the u8 with:

%index = zext i8 %needle to i64

then computes index / 64 as a shift and directly loads/stores from the [u64; 4].

For contrast, the repro also includes the same expression with index: usize; that version does emit:

icmp ult i64 %index, 256
br i1 ..., label %bb1, label %panic

followed by panic_bounds_check.

So I think the safe indexing here already compiles to the unchecked access we want on the hot path, without needing get_unchecked.

}
}

impl StaticFilter for UInt8BitmapFilter {
    /// Returns the number of null entries recorded when the filter was built.
    fn null_count(&self) -> usize {
        self.null_count
    }

    /// Tests each element of `v` for membership in the bitmap.
    ///
    /// Dictionary-encoded input is handled by `handle_dictionary!`, which
    /// returns from this function early. Any other input must be a `UInt8`
    /// primitive array.
    ///
    /// # Errors
    ///
    /// Returns an execution error when `v` is not a `UInt8` primitive array.
    fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
        handle_dictionary!(self, v, negated);

        let Some(needles) = v.as_primitive_opt::<UInt8Type>() else {
            return Err(exec_datafusion_err!(
                "UInt8BitmapFilter: expected UInt8 array"
            ));
        };

        let haystack_has_nulls = self.null_count > 0;
        let raw = needles.values();

        // Null handling and negation are left to `build_in_list_result`; the
        // closure only answers "is the value at this index in the set?".
        Ok(build_in_list_result(
            needles.len(),
            needles.nulls(),
            haystack_has_nulls,
            negated,
            #[inline(always)]
            |idx| {
                // SAFETY: `build_in_list_result` invokes this closure for
                // indices in `0..needles.len()`, which matches `raw.len()`.
                let needle = unsafe { *raw.get_unchecked(idx) };
                self.check(needle)
            },
        ))
    }
}

/// Wrapper for f32 that implements Hash and Eq using bit comparison.
/// This treats NaN values as equal to each other when they have the same bit pattern.
Expand Down Expand Up @@ -94,9 +172,13 @@ macro_rules! primitive_static_filter {

impl $Name {
pub(super) fn try_new(in_array: &ArrayRef) -> Result<Self> {
let in_array = in_array
.as_primitive_opt::<$ArrowType>()
.ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?;
let in_array =
in_array.as_primitive_opt::<$ArrowType>().ok_or_else(|| {
exec_datafusion_err!(
"Failed to downcast an array to a '{}' array",
stringify!($ArrowType)
)
})?;

let mut values = HashSet::with_capacity(in_array.len());
let null_count = in_array.null_count();
Expand All @@ -115,19 +197,14 @@ macro_rules! primitive_static_filter {
}

fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
// Handle dictionary arrays by recursing on the values
downcast_dictionary_array! {
v => {
let values_contains = self.contains(v.values().as_ref(), negated)?;
let result = take(&values_contains, v.keys(), None)?;
return Ok(downcast_array(result.as_ref()))
}
_ => {}
}
handle_dictionary!(self, v, negated);

let v = v
.as_primitive_opt::<$ArrowType>()
.ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?;
let v = v.as_primitive_opt::<$ArrowType>().ok_or_else(|| {
exec_datafusion_err!(
"Failed to downcast an array to a '{}' array",
stringify!($ArrowType)
)
})?;

let haystack_has_nulls = self.null_count > 0;
let needle_values = v.values();
Expand Down Expand Up @@ -188,8 +265,10 @@ macro_rules! primitive_static_filter {
}
(true, true) => {
// Both have nulls - combine needle nulls with haystack-induced nulls
let needle_validity = needle_nulls.map(|n| n.inner().clone())
.unwrap_or_else(|| BooleanBuffer::new_set(needle_values.len()));
let needle_validity =
needle_nulls.map(|n| n.inner().clone()).unwrap_or_else(
|| BooleanBuffer::new_set(needle_values.len()),
);

// Valid when original "in set" is true (see above)
let haystack_validity = if negated {
Expand All @@ -215,7 +294,6 @@ primitive_static_filter!(Int8StaticFilter, Int8Type);
primitive_static_filter!(Int16StaticFilter, Int16Type);
primitive_static_filter!(Int32StaticFilter, Int32Type);
primitive_static_filter!(Int64StaticFilter, Int64Type);
primitive_static_filter!(UInt8StaticFilter, UInt8Type);
primitive_static_filter!(UInt16StaticFilter, UInt16Type);
primitive_static_filter!(UInt32StaticFilter, UInt32Type);
primitive_static_filter!(UInt64StaticFilter, UInt64Type);
Expand All @@ -231,3 +309,50 @@ macro_rules! float_static_filter {
// Generate specialized filters for float types using ordered wrappers
float_static_filter!(Float32StaticFilter, Float32Type, OrderedFloat32);
float_static_filter!(Float64StaticFilter, Float64Type, OrderedFloat64);

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use arrow::array::{DictionaryArray, Int8Array, UInt8Array};

    /// Builds a filter over the IN-list `[1, NULL, 3]` shared by the tests.
    fn filter_with_null_entry() -> Result<UInt8BitmapFilter> {
        let haystack: ArrayRef = Arc::new(UInt8Array::from(vec![Some(1), None, Some(3)]));
        UInt8BitmapFilter::try_new(&haystack)
    }

    /// Checks that a non-negated lookup of `needles` produces `expected`.
    fn assert_contains(
        filter: &UInt8BitmapFilter,
        needles: &dyn Array,
        expected: Vec<Option<bool>>,
    ) -> Result<()> {
        let actual = filter.contains(needles, false)?;
        let wanted = BooleanArray::from(expected);
        assert_eq!(actual, wanted);
        Ok(())
    }

    #[test]
    fn bitmap_filter_u8_handles_nulls() -> Result<()> {
        let filter = filter_with_null_entry()?;
        let needles = UInt8Array::from(vec![Some(1), Some(2), None, Some(3)]);

        // Misses and null needles both come back as NULL because the list
        // itself contains a NULL.
        assert_contains(&filter, &needles, vec![Some(true), None, None, Some(true)])?;

        // Negation flips the definite hits and leaves the NULLs in place.
        let negated = filter.contains(&needles, true)?;
        let wanted = BooleanArray::from(vec![Some(false), None, None, Some(false)]);
        assert_eq!(negated, wanted);

        Ok(())
    }

    #[test]
    fn bitmap_filter_u8_handles_dictionary_needles() -> Result<()> {
        let filter = filter_with_null_entry()?;

        // Keys 0, 1, NULL, 2 decode to the values 1, 2, NULL, 3.
        let dict_values = Arc::new(UInt8Array::from(vec![Some(1), Some(2), Some(3)]));
        let dict_keys = Int8Array::from(vec![Some(0), Some(1), None, Some(2)]);
        let needles = DictionaryArray::try_new(dict_keys, dict_values)?;

        assert_contains(&filter, &needles, vec![Some(true), None, None, Some(true)])
    }
}
17 changes: 17 additions & 0 deletions datafusion/physical-expr/src/expressions/in_list/static_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,20 @@ pub(super) trait StaticFilter {
/// implementation unwraps the dictionary and operates on its values.
fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray>;
}

/// Evaluate dictionary-encoded needles by applying a filter to dictionary
/// values and remapping the result through the keys.
///
/// Takes three identifiers: the filter (`$self`), the needle array (`$v`) and
/// the `negated` flag. When `$v` is a dictionary array the expansion executes
/// a `return` (and may propagate errors with `?`), so it must be invoked
/// inside a function returning `Result<BooleanArray>`. For any other array
/// the expansion does nothing and execution continues after the invocation.
macro_rules! handle_dictionary {
($self:ident, $v:ident, $negated:ident) => {
arrow::array::downcast_dictionary_array! {
$v => {
// Run the filter once over the dictionary's values rather than
// once per row.
let values_contains = $self.contains($v.values().as_ref(), $negated)?;
// Gather the per-value answers through the keys to get one
// answer per row.
let result = arrow::compute::take(&values_contains, $v.keys(), None)?;
// Early return from the enclosing function with the remapped result.
return Ok(arrow::array::downcast_array(result.as_ref()))
}
// Not a dictionary array: fall through to the caller's own handling.
_ => {}
}
};
}

pub(super) use handle_dictionary;
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub(super) fn instantiate_static_filter(
DataType::Int16 => Ok(Arc::new(Int16StaticFilter::try_new(&in_array)?)),
DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)),
DataType::Int64 => Ok(Arc::new(Int64StaticFilter::try_new(&in_array)?)),
DataType::UInt8 => Ok(Arc::new(UInt8StaticFilter::try_new(&in_array)?)),
DataType::UInt8 => Ok(Arc::new(UInt8BitmapFilter::try_new(&in_array)?)),
DataType::UInt16 => Ok(Arc::new(UInt16StaticFilter::try_new(&in_array)?)),
DataType::UInt32 => Ok(Arc::new(UInt32StaticFilter::try_new(&in_array)?)),
DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)),
Expand Down
Loading