Skip to content

Improve documentation and add examples for ArrowPredicateFn #7480

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 9, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 69 additions & 8 deletions parquet/src/arrow/arrow_reader/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,18 @@ use arrow_schema::ArrowError;

/// A predicate operating on [`RecordBatch`]
///
/// See [`RowFilter`] for more information on the use of this trait.
/// See also:
/// * [`RowFilter`] for more information on applying filters during the
/// Parquet decoding process.
/// * [`ArrowPredicateFn`] for a concrete implementation based on a function
pub trait ArrowPredicate: Send + 'static {
/// Returns the [`ProjectionMask`] that describes the columns required
/// to evaluate this predicate. All projected columns will be provided in the `batch`
/// passed to [`evaluate`](Self::evaluate)
/// to evaluate this predicate.
///
/// All projected columns will be provided in the `batch` passed to
/// [`evaluate`](Self::evaluate). The projection mask should be as small as
/// possible because any columns needed for the overall projection mask are
/// decoded again after a predicate is applied.
fn projection(&self) -> &ProjectionMask;

/// Evaluate this predicate for the given [`RecordBatch`] containing the columns
Expand All @@ -38,7 +45,63 @@ pub trait ArrowPredicate: Send + 'static {
fn evaluate(&mut self, batch: RecordBatch) -> Result<BooleanArray, ArrowError>;
}

/// An [`ArrowPredicate`] created from an [`FnMut`]
/// An [`ArrowPredicate`] created from an [`FnMut`] and a [`ProjectionMask`]
///
/// See [`RowFilter`] for more information on applying filters during the
/// Parquet decoding process.
///
/// The function is passed `RecordBatch`es with only the columns specified in
/// the [`ProjectionMask`].
///
/// The function must return a [`BooleanArray`] that has the same length as the
/// input `batch` where each row indicates whether the row should be returned:
/// * `true`: the row should be returned
/// * `false` or `null`: the row should not be returned
///
/// # Example:
///
/// Given an input schema: `"a:int64", "b:int64"`, you can create a predicate that
/// evaluates `b > 0` like this:
///
/// ```
/// # use std::sync::Arc;
/// # use arrow::compute::kernels::cmp::gt;
/// # use arrow_array::{BooleanArray, Int64Array, RecordBatch};
/// # use arrow_array::cast::AsArray;
/// # use arrow_array::types::Int64Type;
/// # use parquet::arrow::arrow_reader::ArrowPredicateFn;
/// # use parquet::arrow::ProjectionMask;
/// # use parquet::schema::types::{SchemaDescriptor, Type};
/// # use parquet::basic; // note there are two `Type`s that are different
/// # // Schema for a table with one columns: "a" (int64) and "b" (int64)
/// # let descriptor = SchemaDescriptor::new(
/// # Arc::new(
/// # Type::group_type_builder("my_schema")
/// # .with_fields(vec![
/// # Arc::new(
/// # Type::primitive_type_builder("a", basic::Type::INT64)
/// # .build().unwrap()
/// # ),
/// # Arc::new(
/// # Type::primitive_type_builder("b", basic::Type::INT64)
/// # .build().unwrap()
/// # ),
/// # ])
/// # .build().unwrap()
/// # )
/// # );
/// // Create a mask for selecting only the second column "b" (index 1)
/// let projection_mask = ProjectionMask::leaves(&descriptor, [1]);
/// // Closure that evaluates "b > 0"
/// let predicate = |batch: RecordBatch| {
/// let scalar_0 = Int64Array::new_scalar(0);
/// let column = batch.column(0).as_primitive::<Int64Type>();
/// // call the gt kernel to compute `>` which returns a BooleanArray
/// gt(column, &scalar_0)
/// };
/// // Create ArrowPredicateFn that can be passed to RowFilter
/// let arrow_predicate = ArrowPredicateFn::new(projection_mask, predicate);
/// ```
pub struct ArrowPredicateFn<F> {
f: F,
projection: ProjectionMask,
Expand All @@ -48,10 +111,8 @@ impl<F> ArrowPredicateFn<F>
where
F: FnMut(RecordBatch) -> Result<BooleanArray, ArrowError> + Send + 'static,
{
/// Create a new [`ArrowPredicateFn`]. `f` will be passed batches
/// that contains the columns specified in `projection`
/// and returns a [`BooleanArray`] that describes which rows should
/// be passed along
/// Create a new [`ArrowPredicateFn`] that invokes `f` on the columns
/// specified in `projection`.
pub fn new(projection: ProjectionMask, f: F) -> Self {
Self { f, projection }
}
Expand Down
Loading