diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 50ff3936937bf..9acbcd6ab63e3 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -321,7 +321,7 @@ impl PhysicalExpr for InListExpr { Some(filter) => { match value { ColumnarValue::Array(array) => { - filter.contains(&array, self.negated)? + filter.contains(array, self.negated)? } ColumnarValue::Scalar(scalar) => { if scalar.is_null() { @@ -338,8 +338,7 @@ impl PhysicalExpr for InListExpr { // Use a 1 row array to avoid code duplication/branching // Since all we do is compute hash and lookup this should be efficient enough let array = scalar.to_array()?; - let result_array = - filter.contains(array.as_ref(), self.negated)?; + let result_array = filter.contains(array, self.negated)?; // Broadcast the single result to all rows // Must check is_null() to preserve NULL values (SQL three-valued logic) if result_array.is_null(0) { diff --git a/datafusion/physical-expr/src/expressions/in_list/array_static_filter.rs b/datafusion/physical-expr/src/expressions/in_list/array_static_filter.rs index 75e92dbcc59b4..8916df801dc19 100644 --- a/datafusion/physical-expr/src/expressions/in_list/array_static_filter.rs +++ b/datafusion/physical-expr/src/expressions/in_list/array_static_filter.rs @@ -16,8 +16,7 @@ // under the License. use arrow::array::{ - Array, ArrayRef, BooleanArray, downcast_array, downcast_dictionary_array, - make_comparator, + ArrayRef, BooleanArray, downcast_array, downcast_dictionary_array, make_comparator, }; use arrow::buffer::{BooleanBuffer, NullBuffer}; use arrow::compute::{SortOptions, take}; @@ -26,6 +25,7 @@ use arrow::util::bit_iterator::BitIndexIterator; use datafusion_common::Result; use datafusion_common::hash_utils::{RandomState, with_hashes}; use hashbrown::HashTable; +use std::sync::Arc; use super::result::build_in_list_result; use super::static_filter::StaticFilter; @@ -99,18 +99,18 @@ impl ArrayStaticFilter { fn find_needles_in_haystack( &self, - needles: &dyn Array, + needles: &ArrayRef, negated: bool, ) -> Result { let needle_nulls = needles.logical_nulls(); let haystack_has_nulls = self.in_array.null_count() != 0; - with_hashes([needles], &self.state, |needle_hashes| { - let cmp = make_comparator(needles, &self.in_array, SortOptions::default())?; + with_hashes([needles.as_ref()], &self.state, |needle_hashes| { + let cmp = make_comparator(&needles, &self.in_array, SortOptions::default())?; Ok(build_in_list_result( needles.len(), - needle_nulls.as_ref(), + needle_nulls, haystack_has_nulls, negated, #[inline(always)] @@ -129,7 +129,7 @@ impl StaticFilter for ArrayStaticFilter { } /// Checks if values in `v` are contained in the `in_array` using this hash set for lookup. - fn contains(&self, v: &dyn Array, negated: bool) -> Result { + fn contains(&self, v: ArrayRef, negated: bool) -> Result { // Null type comparisons always return null (SQL three-valued logic) if v.data_type() == &DataType::Null || self.in_array.data_type() == &DataType::Null @@ -144,19 +144,20 @@ impl StaticFilter for ArrayStaticFilter { // Unwrap dictionary-encoded needles when the value type matches // in_array, evaluating against the dictionary values and mapping // back via keys. + let array = v.as_ref(); downcast_dictionary_array! { - v => { + array => { // Only unwrap when the haystack (in_array) type matches // the dictionary value type - if v.values().data_type() == self.in_array.data_type() { - let values_contains = self.contains(v.values().as_ref(), negated)?; - let result = take(&values_contains, v.keys(), None)?; + if array.values().data_type() == self.in_array.data_type() { + let values_contains = self.contains(Arc::clone(array.values()), negated)?; + let result = take(&values_contains, array.keys(), None)?; return Ok(downcast_array(result.as_ref())); } } _ => {} } - self.find_needles_in_haystack(v, negated) + self.find_needles_in_haystack(&v, negated) } } diff --git a/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs b/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs index e7647e5adb8b7..10e8e24417fb0 100644 --- a/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs +++ b/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs @@ -82,15 +82,19 @@ impl StaticFilter for UInt8BitmapFilter { self.null_count } - fn contains(&self, v: &dyn Array, negated: bool) -> Result { + fn contains(&self, v: ArrayRef, negated: bool) -> Result { handle_dictionary!(self, v, negated); - let v = v.as_primitive_opt::().ok_or_else(|| { - exec_datafusion_err!("UInt8BitmapFilter: expected UInt8 array") - })?; - let input_values = v.values(); + let v = v + .as_primitive_opt::() + .ok_or_else(|| { + exec_datafusion_err!("UInt8BitmapFilter: expected UInt8 array") + })? + .clone(); + let len = v.len(); + let (_, input_values, needle_nulls) = v.into_parts(); Ok(build_in_list_result( - v.len(), - v.nulls(), + len, + needle_nulls, self.null_count > 0, negated, #[inline(always)] @@ -196,20 +200,24 @@ macro_rules! primitive_static_filter { self.null_count } - fn contains(&self, v: &dyn Array, negated: bool) -> Result { + fn contains(&self, v: ArrayRef, negated: bool) -> Result { handle_dictionary!(self, v, negated); - let v = v.as_primitive_opt::<$ArrowType>().ok_or_else(|| { - exec_datafusion_err!( - "Failed to downcast an array to a '{}' array", - stringify!($ArrowType) - ) - })?; + let v = v + .as_primitive_opt::<$ArrowType>() + .ok_or_else(|| { + exec_datafusion_err!( + "Failed to downcast an array to a '{}' array", + stringify!($ArrowType) + ) + })? + .clone(); let haystack_has_nulls = self.null_count > 0; - let needle_values = v.values(); - let needle_nulls = v.nulls(); - let needle_has_nulls = v.null_count() > 0; + let (_, needle_values, needle_nulls) = v.into_parts(); + let needle_has_nulls = needle_nulls + .as_ref() + .is_some_and(|nulls| nulls.null_count() > 0); // Truth table for `value [NOT] IN (set)` with SQL three-valued logic: // ("-" means the value doesn't affect the result) @@ -250,7 +258,7 @@ macro_rules! primitive_static_filter { } (true, false) => { // Only needle has nulls - just use needle's null mask - needle_nulls.cloned() + needle_nulls } (false, true) => { // Only haystack has nulls - result is null when value not in set @@ -266,9 +274,9 @@ macro_rules! primitive_static_filter { (true, true) => { // Both have nulls - combine needle nulls with haystack-induced nulls let needle_validity = - needle_nulls.map(|n| n.inner().clone()).unwrap_or_else( - || BooleanBuffer::new_set(needle_values.len()), - ); + needle_nulls.map(|n| n.into_inner()).unwrap_or_else(|| { + BooleanBuffer::new_set(needle_values.len()) + }); // Valid when original "in set" is true (see above) let haystack_validity = if negated { @@ -278,7 +286,8 @@ macro_rules! primitive_static_filter { }; // Combined validity: valid only where both are valid - let combined_validity = &needle_validity & &haystack_validity; + let mut combined_validity = needle_validity; + combined_validity &= &haystack_validity; Some(NullBuffer::new(combined_validity)) } }; @@ -319,7 +328,7 @@ mod tests { fn assert_contains( filter: &UInt8BitmapFilter, - needles: &dyn Array, + needles: ArrayRef, expected: Vec>, ) -> Result<()> { assert_eq!( @@ -333,11 +342,16 @@ mod tests { fn bitmap_filter_u8_handles_nulls() -> Result<()> { let haystack: ArrayRef = Arc::new(UInt8Array::from(vec![Some(1), None, Some(3)])); let filter = UInt8BitmapFilter::try_new(&haystack)?; - let needles = UInt8Array::from(vec![Some(1), Some(2), None, Some(3)]); - - assert_contains(&filter, &needles, vec![Some(true), None, None, Some(true)])?; + let needles: ArrayRef = + Arc::new(UInt8Array::from(vec![Some(1), Some(2), None, Some(3)])); + + assert_contains( + &filter, + Arc::clone(&needles), + vec![Some(true), None, None, Some(true)], + )?; assert_eq!( - filter.contains(&needles, true)?, + filter.contains(needles, true)?, BooleanArray::from(vec![Some(false), None, None, Some(false)]) ); @@ -351,8 +365,8 @@ mod tests { let keys = Int8Array::from(vec![Some(0), Some(1), None, Some(2)]); let values = Arc::new(UInt8Array::from(vec![Some(1), Some(2), Some(3)])); - let needles = DictionaryArray::try_new(keys, values)?; + let needles: ArrayRef = Arc::new(DictionaryArray::try_new(keys, values)?); - assert_contains(&filter, &needles, vec![Some(true), None, None, Some(true)]) + assert_contains(&filter, needles, vec![Some(true), None, None, Some(true)]) } } diff --git a/datafusion/physical-expr/src/expressions/in_list/result.rs b/datafusion/physical-expr/src/expressions/in_list/result.rs index 3ebdbfe19f743..efc793302324a 100644 --- a/datafusion/physical-expr/src/expressions/in_list/result.rs +++ b/datafusion/physical-expr/src/expressions/in_list/result.rs @@ -46,7 +46,7 @@ use arrow::buffer::{BooleanBuffer, NullBuffer}; #[inline] pub(crate) fn build_in_list_result( len: usize, - needle_nulls: Option<&NullBuffer>, + needle_nulls: Option, haystack_has_nulls: bool, negated: bool, contains: C, @@ -64,7 +64,7 @@ where /// It handles nulls using bitmap operations. #[inline] pub(crate) fn build_result_from_contains( - needle_nulls: Option<&NullBuffer>, + needle_nulls: Option, haystack_has_nulls: bool, negated: bool, contains_buf: BooleanBuffer, @@ -73,7 +73,8 @@ pub(crate) fn build_result_from_contains( // Haystack has nulls: result is null unless value is found. (Some(v), true, false) => { // values: valid & contains, nulls: valid & contains - let values = v.inner() & &contains_buf; + let mut values = contains_buf; + values &= v.inner(); BooleanArray::new(values.clone(), Some(NullBuffer::new(values))) } (None, true, false) => { @@ -83,8 +84,10 @@ pub(crate) fn build_result_from_contains( // NOT IN with nulls: false if found, null if not found or needle null. // values: valid & !contains, nulls: valid & contains let valid = v.inner(); - let values = valid & &(!&contains_buf); - let nulls = valid & &contains_buf; + let mut values = !&contains_buf; + values &= valid; + let mut nulls = contains_buf; + nulls &= valid; BooleanArray::new(values, Some(NullBuffer::new(nulls))) } (None, true, true) => { @@ -93,11 +96,15 @@ pub(crate) fn build_result_from_contains( // Haystack has no nulls: result validity follows needle validity. (Some(v), false, false) => { // values: valid & contains, nulls: valid - BooleanArray::new(v.inner() & &contains_buf, Some(v.clone())) + let mut values = contains_buf; + values &= v.inner(); + BooleanArray::new(values, Some(v)) } (Some(v), false, true) => { // values: valid & !contains, nulls: valid - BooleanArray::new(v.inner() & &(!&contains_buf), Some(v.clone())) + let mut values = !&contains_buf; + values &= v.inner(); + BooleanArray::new(values, Some(v)) } (None, false, false) => BooleanArray::new(contains_buf, None), (None, false, true) => BooleanArray::new(!&contains_buf, None), diff --git a/datafusion/physical-expr/src/expressions/in_list/static_filter.rs b/datafusion/physical-expr/src/expressions/in_list/static_filter.rs index 3c964d4183474..be59c3f8ea7c9 100644 --- a/datafusion/physical-expr/src/expressions/in_list/static_filter.rs +++ b/datafusion/physical-expr/src/expressions/in_list/static_filter.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Array, BooleanArray}; +use arrow::array::{ArrayRef, BooleanArray}; use datafusion_common::Result; /// Trait for InList static filters. @@ -33,17 +33,19 @@ pub(super) trait StaticFilter { /// Checks if values in `v` (needle) are contained in this filter's /// haystack. `v` may be dictionary-encoded, in which case the /// implementation unwraps the dictionary and operates on its values. - fn contains(&self, v: &dyn Array, negated: bool) -> Result; + fn contains(&self, v: ArrayRef, negated: bool) -> Result; } /// Evaluate dictionary-encoded needles by applying a filter to dictionary /// values and remapping the result through the keys. macro_rules! handle_dictionary { ($self:ident, $v:ident, $negated:ident) => { + let array = $v.as_ref(); arrow::array::downcast_dictionary_array! { - $v => { - let values_contains = $self.contains($v.values().as_ref(), $negated)?; - let result = arrow::compute::take(&values_contains, $v.keys(), None)?; + array => { + let values_contains = + $self.contains(std::sync::Arc::clone(array.values()), $negated)?; + let result = arrow::compute::take(&values_contains, array.keys(), None)?; return Ok(arrow::array::downcast_array(result.as_ref())) } _ => {}