Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions datafusion/physical-expr/src/expressions/in_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ impl PhysicalExpr for InListExpr {
Some(filter) => {
match value {
ColumnarValue::Array(array) => {
filter.contains(&array, self.negated)?
filter.contains(array, self.negated)?
}
ColumnarValue::Scalar(scalar) => {
if scalar.is_null() {
Expand All @@ -338,8 +338,7 @@ impl PhysicalExpr for InListExpr {
// Use a 1 row array to avoid code duplication/branching
// Since all we do is compute hash and lookup this should be efficient enough
let array = scalar.to_array()?;
let result_array =
filter.contains(array.as_ref(), self.negated)?;
let result_array = filter.contains(array, self.negated)?;
// Broadcast the single result to all rows
// Must check is_null() to preserve NULL values (SQL three-valued logic)
if result_array.is_null(0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
// under the License.

use arrow::array::{
Array, ArrayRef, BooleanArray, downcast_array, downcast_dictionary_array,
make_comparator,
ArrayRef, BooleanArray, downcast_array, downcast_dictionary_array, make_comparator,
};
use arrow::buffer::{BooleanBuffer, NullBuffer};
use arrow::compute::{SortOptions, take};
Expand All @@ -26,6 +25,7 @@ use arrow::util::bit_iterator::BitIndexIterator;
use datafusion_common::Result;
use datafusion_common::hash_utils::{RandomState, with_hashes};
use hashbrown::HashTable;
use std::sync::Arc;

use super::result::build_in_list_result;
use super::static_filter::StaticFilter;
Expand Down Expand Up @@ -99,18 +99,18 @@ impl ArrayStaticFilter {

fn find_needles_in_haystack(
&self,
needles: &dyn Array,
needles: &ArrayRef,
negated: bool,
) -> Result<BooleanArray> {
let needle_nulls = needles.logical_nulls();
let haystack_has_nulls = self.in_array.null_count() != 0;

with_hashes([needles], &self.state, |needle_hashes| {
let cmp = make_comparator(needles, &self.in_array, SortOptions::default())?;
with_hashes([needles.as_ref()], &self.state, |needle_hashes| {
let cmp = make_comparator(&needles, &self.in_array, SortOptions::default())?;

Ok(build_in_list_result(
needles.len(),
needle_nulls.as_ref(),
needle_nulls,
haystack_has_nulls,
negated,
#[inline(always)]
Expand All @@ -129,7 +129,7 @@ impl StaticFilter for ArrayStaticFilter {
}

/// Checks if values in `v` are contained in the `in_array` using this hash set for lookup.
fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
fn contains(&self, v: ArrayRef, negated: bool) -> Result<BooleanArray> {
// Null type comparisons always return null (SQL three-valued logic)
if v.data_type() == &DataType::Null
|| self.in_array.data_type() == &DataType::Null
Expand All @@ -144,19 +144,20 @@ impl StaticFilter for ArrayStaticFilter {
// Unwrap dictionary-encoded needles when the value type matches
// in_array, evaluating against the dictionary values and mapping
// back via keys.
let array = v.as_ref();
downcast_dictionary_array! {
v => {
array => {
// Only unwrap when the haystack (in_array) type matches
// the dictionary value type
if v.values().data_type() == self.in_array.data_type() {
let values_contains = self.contains(v.values().as_ref(), negated)?;
let result = take(&values_contains, v.keys(), None)?;
if array.values().data_type() == self.in_array.data_type() {
let values_contains = self.contains(Arc::clone(array.values()), negated)?;
let result = take(&values_contains, array.keys(), None)?;
return Ok(downcast_array(result.as_ref()));
}
}
_ => {}
}

self.find_needles_in_haystack(v, negated)
self.find_needles_in_haystack(&v, negated)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,19 @@ impl StaticFilter for UInt8BitmapFilter {
self.null_count
}

fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
fn contains(&self, v: ArrayRef, negated: bool) -> Result<BooleanArray> {
handle_dictionary!(self, v, negated);
let v = v.as_primitive_opt::<UInt8Type>().ok_or_else(|| {
exec_datafusion_err!("UInt8BitmapFilter: expected UInt8 array")
})?;
let input_values = v.values();
let v = v
.as_primitive_opt::<UInt8Type>()
.ok_or_else(|| {
exec_datafusion_err!("UInt8BitmapFilter: expected UInt8 array")
})?
.clone();
let len = v.len();
let (_, input_values, needle_nulls) = v.into_parts();
Ok(build_in_list_result(
v.len(),
v.nulls(),
len,
needle_nulls,
self.null_count > 0,
negated,
#[inline(always)]
Expand Down Expand Up @@ -196,20 +200,24 @@ macro_rules! primitive_static_filter {
self.null_count
}

fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
fn contains(&self, v: ArrayRef, negated: bool) -> Result<BooleanArray> {
handle_dictionary!(self, v, negated);

let v = v.as_primitive_opt::<$ArrowType>().ok_or_else(|| {
exec_datafusion_err!(
"Failed to downcast an array to a '{}' array",
stringify!($ArrowType)
)
})?;
let v = v
.as_primitive_opt::<$ArrowType>()
.ok_or_else(|| {
exec_datafusion_err!(
"Failed to downcast an array to a '{}' array",
stringify!($ArrowType)
)
})?
.clone();

let haystack_has_nulls = self.null_count > 0;
let needle_values = v.values();
let needle_nulls = v.nulls();
let needle_has_nulls = v.null_count() > 0;
let (_, needle_values, needle_nulls) = v.into_parts();
let needle_has_nulls = needle_nulls
.as_ref()
.is_some_and(|nulls| nulls.null_count() > 0);

// Truth table for `value [NOT] IN (set)` with SQL three-valued logic:
// ("-" means the value doesn't affect the result)
Expand Down Expand Up @@ -250,7 +258,7 @@ macro_rules! primitive_static_filter {
}
(true, false) => {
// Only needle has nulls - just use needle's null mask
needle_nulls.cloned()
needle_nulls
}
(false, true) => {
// Only haystack has nulls - result is null when value not in set
Expand All @@ -266,9 +274,9 @@ macro_rules! primitive_static_filter {
(true, true) => {
// Both have nulls - combine needle nulls with haystack-induced nulls
let needle_validity =
needle_nulls.map(|n| n.inner().clone()).unwrap_or_else(
|| BooleanBuffer::new_set(needle_values.len()),
);
needle_nulls.map(|n| n.into_inner()).unwrap_or_else(|| {
BooleanBuffer::new_set(needle_values.len())
});

// Valid when original "in set" is true (see above)
let haystack_validity = if negated {
Expand All @@ -278,7 +286,8 @@ macro_rules! primitive_static_filter {
};

// Combined validity: valid only where both are valid
let combined_validity = &needle_validity & &haystack_validity;
let mut combined_validity = needle_validity;
combined_validity &= &haystack_validity;
Some(NullBuffer::new(combined_validity))
}
};
Expand Down Expand Up @@ -319,7 +328,7 @@ mod tests {

fn assert_contains(
filter: &UInt8BitmapFilter,
needles: &dyn Array,
needles: ArrayRef,
expected: Vec<Option<bool>>,
) -> Result<()> {
assert_eq!(
Expand All @@ -333,11 +342,16 @@ mod tests {
fn bitmap_filter_u8_handles_nulls() -> Result<()> {
let haystack: ArrayRef = Arc::new(UInt8Array::from(vec![Some(1), None, Some(3)]));
let filter = UInt8BitmapFilter::try_new(&haystack)?;
let needles = UInt8Array::from(vec![Some(1), Some(2), None, Some(3)]);

assert_contains(&filter, &needles, vec![Some(true), None, None, Some(true)])?;
let needles: ArrayRef =
Arc::new(UInt8Array::from(vec![Some(1), Some(2), None, Some(3)]));

assert_contains(
&filter,
Arc::clone(&needles),
vec![Some(true), None, None, Some(true)],
)?;
assert_eq!(
filter.contains(&needles, true)?,
filter.contains(needles, true)?,
BooleanArray::from(vec![Some(false), None, None, Some(false)])
);

Expand All @@ -351,8 +365,8 @@ mod tests {

let keys = Int8Array::from(vec![Some(0), Some(1), None, Some(2)]);
let values = Arc::new(UInt8Array::from(vec![Some(1), Some(2), Some(3)]));
let needles = DictionaryArray::try_new(keys, values)?;
let needles: ArrayRef = Arc::new(DictionaryArray::try_new(keys, values)?);

assert_contains(&filter, &needles, vec![Some(true), None, None, Some(true)])
assert_contains(&filter, needles, vec![Some(true), None, None, Some(true)])
}
}
21 changes: 14 additions & 7 deletions datafusion/physical-expr/src/expressions/in_list/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ use arrow::buffer::{BooleanBuffer, NullBuffer};
#[inline]
pub(crate) fn build_in_list_result<C>(
len: usize,
needle_nulls: Option<&NullBuffer>,
needle_nulls: Option<NullBuffer>,
haystack_has_nulls: bool,
negated: bool,
contains: C,
Expand All @@ -64,7 +64,7 @@ where
/// It handles nulls using bitmap operations.
#[inline]
pub(crate) fn build_result_from_contains(
needle_nulls: Option<&NullBuffer>,
needle_nulls: Option<NullBuffer>,
haystack_has_nulls: bool,
negated: bool,
contains_buf: BooleanBuffer,
Expand All @@ -73,7 +73,8 @@ pub(crate) fn build_result_from_contains(
// Haystack has nulls: result is null unless value is found.
(Some(v), true, false) => {
// values: valid & contains, nulls: valid & contains
let values = v.inner() & &contains_buf;
let mut values = contains_buf;
values &= v.inner();
BooleanArray::new(values.clone(), Some(NullBuffer::new(values)))
}
(None, true, false) => {
Expand All @@ -83,8 +84,10 @@ pub(crate) fn build_result_from_contains(
// NOT IN with nulls: false if found, null if not found or needle null.
// values: valid & !contains, nulls: valid & contains
let valid = v.inner();
let values = valid & &(!&contains_buf);
let nulls = valid & &contains_buf;
let mut values = !&contains_buf;
values &= valid;
let mut nulls = contains_buf;
nulls &= valid;
BooleanArray::new(values, Some(NullBuffer::new(nulls)))
}
(None, true, true) => {
Expand All @@ -93,11 +96,15 @@ pub(crate) fn build_result_from_contains(
// Haystack has no nulls: result validity follows needle validity.
(Some(v), false, false) => {
// values: valid & contains, nulls: valid
BooleanArray::new(v.inner() & &contains_buf, Some(v.clone()))
let mut values = contains_buf;
values &= v.inner();
BooleanArray::new(values, Some(v))
}
(Some(v), false, true) => {
// values: valid & !contains, nulls: valid
BooleanArray::new(v.inner() & &(!&contains_buf), Some(v.clone()))
let mut values = !&contains_buf;
values &= v.inner();
BooleanArray::new(values, Some(v))
}
(None, false, false) => BooleanArray::new(contains_buf, None),
(None, false, true) => BooleanArray::new(!&contains_buf, None),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::{Array, BooleanArray};
use arrow::array::{ArrayRef, BooleanArray};
use datafusion_common::Result;

/// Trait for InList static filters.
Expand All @@ -33,17 +33,19 @@ pub(super) trait StaticFilter {
/// Checks if values in `v` (needle) are contained in this filter's
/// haystack. `v` may be dictionary-encoded, in which case the
/// implementation unwraps the dictionary and operates on its values.
fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray>;
fn contains(&self, v: ArrayRef, negated: bool) -> Result<BooleanArray>;
}

/// Evaluate dictionary-encoded needles by applying a filter to dictionary
/// values and remapping the result through the keys.
macro_rules! handle_dictionary {
($self:ident, $v:ident, $negated:ident) => {
let array = $v.as_ref();
arrow::array::downcast_dictionary_array! {
$v => {
let values_contains = $self.contains($v.values().as_ref(), $negated)?;
let result = arrow::compute::take(&values_contains, $v.keys(), None)?;
array => {
let values_contains =
$self.contains(std::sync::Arc::clone(array.values()), $negated)?;
let result = arrow::compute::take(&values_contains, array.keys(), None)?;
return Ok(arrow::array::downcast_array(result.as_ref()))
}
_ => {}
Expand Down
Loading