6 changes: 6 additions & 0 deletions datafusion/core/tests/data/recursive_cte/closure.csv
@@ -0,0 +1,6 @@
start,end
1,2
2,3
2,4
2,4
4,1
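
The fixture is a small directed edge list with a duplicate row ((2,4) appears twice) and a cycle (1 → 2 → 4 → 1), so the transitive-closure test added to cte.slt below terminates only if each iteration's output is deduplicated against everything already produced. As background, here is a minimal sketch of that fixed-point semantics — an illustration only, not code from this PR:

use std::collections::HashSet;

/// Hypothetical helper mirroring what a distinct (`UNION`) recursive CTE
/// computes: iterate the recursive term over the previous iteration's rows,
/// keep only never-seen rows, and stop once an iteration adds nothing new.
fn transitive_closure(edges: &[(u32, u32)]) -> HashSet<(u32, u32)> {
    // Seed with the static term; the HashSet also dedups the repeated (2, 4).
    let mut seen: HashSet<(u32, u32)> = edges.iter().copied().collect();
    // The "work table": rows produced by the previous iteration only.
    let mut work: Vec<(u32, u32)> = seen.iter().copied().collect();
    while !work.is_empty() {
        let mut next = Vec::new();
        for &(start, mid) in &work {
            for &(from, end) in edges {
                // `insert` returns false for duplicates, so already-emitted
                // rows never re-enter the work table; this is what makes the
                // loop terminate despite the 1 -> 2 -> 4 -> 1 cycle.
                if mid == from && seen.insert((start, end)) {
                    next.push((start, end));
                }
            }
        }
        work = next;
    }
    seen
}

fn main() {
    let edges = [(1, 2), (2, 3), (2, 4), (2, 4), (4, 1)];
    // 12 pairs: each of 1, 2 and 4 (the nodes on the cycle) reaches 1, 2, 3, 4.
    assert_eq!(transitive_closure(&edges).len(), 12);
}
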
8 changes: 1 addition & 7 deletions datafusion/expr/src/logical_plan/builder.rs
@@ -54,7 +54,7 @@ use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
use datafusion_common::display::ToStringifiedPlan;
use datafusion_common::file_options::file_type::FileType;
use datafusion_common::{
exec_err, get_target_functional_dependencies, internal_datafusion_err, not_impl_err,
exec_err, get_target_functional_dependencies, internal_datafusion_err,
plan_datafusion_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef,
NullEquality, Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions,
};
@@ -178,12 +178,6 @@ impl LogicalPlanBuilder {
recursive_term: LogicalPlan,
is_distinct: bool,
) -> Result<Self> {
// TODO: we need to do a bunch of validation here. Maybe more.
if is_distinct {
return not_impl_err!(
"Recursive queries with a distinct 'UNION' (in which the previous iteration's results will be de-duplicated) is not supported"
);
}
// Ensure that the static term and the recursive term have the same number of fields
let static_fields_len = self.plan.schema().fields().len();
let recursive_fields_len = recursive_term.schema().fields().len();
88 changes: 78 additions & 10 deletions datafusion/physical-plan/src/recursive_query.rs
@@ -21,14 +21,18 @@ use std::any::Any;
use std::sync::Arc;
use std::task::{Context, Poll};

use super::work_table::{ReservedBatches, WorkTable, WorkTableExec};
use crate::aggregates::group_values::{new_group_values, GroupValues};
use crate::aggregates::order::GroupOrdering;
use crate::execution_plan::{Boundedness, EmissionType};
use crate::metrics::RecordOutput;
use crate::work_table::{ReservedBatches, WorkTable, WorkTableExec};
use crate::{
metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet},
PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
};
use crate::{DisplayAs, DisplayFormatType, ExecutionPlan};

use arrow::array::{BooleanArray, BooleanBuilder};
use arrow::compute::filter_record_batch;
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
@@ -195,8 +199,9 @@ impl ExecutionPlan for RecursiveQueryExec {
Arc::clone(&self.work_table),
Arc::clone(&self.recursive_term),
static_stream,
self.is_distinct,
baseline_metrics,
)))
)?))
}

fn metrics(&self) -> Option<MetricsSet> {
@@ -268,8 +273,10 @@ struct RecursiveQueryStream {
buffer: Vec<RecordBatch>,
/// Tracks the memory used by the buffer
reservation: MemoryReservation,
/// If the distinct flag is set, this hash table is used to remove duplicates from the result and the work table
distinct_deduplicator: Option<DistinctDeduplicator>,
// /// Metrics.
_baseline_metrics: BaselineMetrics,
baseline_metrics: BaselineMetrics,
}

impl RecursiveQueryStream {
@@ -279,12 +286,16 @@
work_table: Arc<WorkTable>,
recursive_term: Arc<dyn ExecutionPlan>,
static_stream: SendableRecordBatchStream,
is_distinct: bool,
baseline_metrics: BaselineMetrics,
) -> Self {
) -> Result<Self> {
let schema = static_stream.schema();
let reservation =
MemoryConsumer::new("RecursiveQuery").register(task_context.memory_pool());
Self {
let distinct_deduplicator = is_distinct
.then(|| DistinctDeduplicator::new(Arc::clone(&schema), &task_context))
.transpose()?;
Ok(Self {
task_context,
work_table,
recursive_term,
@@ -293,21 +304,28 @@
schema,
buffer: vec![],
reservation,
_baseline_metrics: baseline_metrics,
}
distinct_deduplicator,
baseline_metrics,
})
}

/// Push a clone of the given batch to the in memory buffer, and then return
/// a poll with it.
fn push_batch(
mut self: std::pin::Pin<&mut Self>,
batch: RecordBatch,
mut batch: RecordBatch,
) -> Poll<Option<Result<RecordBatch>>> {
let baseline_metrics = self.baseline_metrics.clone();
if let Some(deduplicator) = &mut self.distinct_deduplicator {
let _timer_guard = baseline_metrics.elapsed_compute().timer();
batch = deduplicator.deduplicate(&batch)?;
}

if let Err(e) = self.reservation.try_grow(batch.get_array_memory_size()) {
return Poll::Ready(Some(Err(e)));
}

self.buffer.push(batch.clone());
(&batch).record_output(&baseline_metrics);
Poll::Ready(Some(Ok(batch)))
}

@@ -434,5 +452,55 @@ impl RecordBatchStream for RecursiveQueryStream {
}
}

/// Deduplicator based on a hash table.
struct DistinctDeduplicator {
/// Grouped rows used for distinct
group_values: Box<dyn GroupValues>,
/// Tracks the memory used by the hash table
reservation: MemoryReservation,
/// Scratch buffer, reused across batches, for the group ids assigned by `intern`
intern_output_buffer: Vec<usize>,
}

impl DistinctDeduplicator {
fn new(schema: SchemaRef, task_context: &TaskContext) -> Result<Self> {
let group_values = new_group_values(schema, &GroupOrdering::None)?;
let reservation = MemoryConsumer::new("RecursiveQueryHashTable")
.register(task_context.memory_pool());
Ok(Self {
group_values,
reservation,
intern_output_buffer: Vec::new(),
})
}

fn deduplicate(&mut self, batch: &RecordBatch) -> Result<RecordBatch> {
// We use the hash table to allocate new group ids.
// If they are new, i.e., if they have ids >= length before interning, we keep them.
// We also detect duplicates by enforcing that group ids are increasing.
let size_before = self.group_values.len();
self.intern_output_buffer.reserve(batch.num_rows());
self.group_values
.intern(batch.columns(), &mut self.intern_output_buffer)?;
let mask = are_increasing_mask(&self.intern_output_buffer, size_before);
self.intern_output_buffer.clear();
// We update the reservation to reflect the new size of the hash table.
self.reservation.try_resize(self.group_values.size())?;
Ok(filter_record_batch(batch, &mask)?)
}
}

/// Return a mask where each element is true if the value is greater than all previous ones and greater than or equal to `min_value`
fn are_increasing_mask(values: &[usize], mut min_value: usize) -> BooleanArray {
let mut output = BooleanBuilder::with_capacity(values.len());
for value in values {
if *value >= min_value {
output.append_value(true);
min_value = *value + 1; // We want to be increasing
} else {
output.append_value(false);
}
}
output.finish()
}

#[cfg(test)]
mod tests {}
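
How the deduplication works: `GroupValues::intern` assigns every distinct row a stable group id in first-seen order, so a row is new exactly when its id is at least the hash table's size before interning and larger than every id already kept in this batch — which is what `are_increasing_mask` checks. A dependency-free replica for illustration (`Vec<bool>` stands in for arrow's `BooleanArray`; this is a sketch, not code from this PR):

/// Replica of `are_increasing_mask`: keep a value only if it is >= the
/// running minimum, then bump the minimum past it.
fn are_increasing_mask(values: &[usize], mut min_value: usize) -> Vec<bool> {
    values
        .iter()
        .map(|&v| {
            if v >= min_value {
                min_value = v + 1; // the next kept id must be strictly larger
                true
            } else {
                false
            }
        })
        .collect()
}

fn main() {
    // The hash table already holds 3 groups, so size_before == 3. Suppose
    // `intern` maps the incoming rows [a, b, a, c, b] to group ids:
    let ids = [3, 4, 3, 5, 4];
    // First occurrences of a, b, c are kept; in-batch repeats are dropped.
    assert_eq!(
        are_increasing_mask(&ids, 3),
        vec![true, true, false, true, false]
    );
    // A row interned by an earlier batch gets an id below size_before (here
    // id 1), so rows already emitted in prior iterations are filtered too.
    assert_eq!(are_increasing_mask(&[1, 3], 3), vec![false, true]);
}
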
90 changes: 78 additions & 12 deletions datafusion/sqllogictest/test_files/cte.slt
@@ -58,18 +58,6 @@ WITH RECURSIVE nodes AS (
statement ok
set datafusion.execution.enable_recursive_ctes = true;


# DISTINCT UNION is not supported
query error DataFusion error: This feature is not implemented: Recursive queries with a distinct 'UNION' \(in which the previous iteration's results will be de\-duplicated\) is not supported
WITH RECURSIVE nodes AS (
SELECT 1 as id
UNION
SELECT id + 1 as id
FROM nodes
WHERE id < 3
) SELECT * FROM nodes


# trivial recursive CTE works
query I rowsort
WITH RECURSIVE nodes AS (
@@ -122,6 +110,22 @@ physical_plan
08)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
09)------------WorkTableExec: name=nodes

# trivial deduplicating recursive CTE works
query I
WITH RECURSIVE nodes AS (
SELECT 1 as id
UNION
SELECT id + 1 as id
FROM nodes
WHERE id < 4
)
SELECT * FROM nodes
----
1
2
3
4

# setup
statement ok
CREATE EXTERNAL TABLE balance STORED as CSV LOCATION '../core/tests/data/recursive_cte/balance.csv' OPTIONS ('format.has_header' 'true');
@@ -1049,6 +1053,68 @@ physical_plan
05)----SortExec: TopK(fetch=1), expr=[v@1 ASC NULLS LAST], preserve_partitioning=[false]
06)------WorkTableExec: name=r

# setup
statement ok
CREATE EXTERNAL TABLE closure STORED as CSV LOCATION '../core/tests/data/recursive_cte/closure.csv' OPTIONS ('format.has_header' 'true');

# transitive closure with loop
query II
WITH RECURSIVE trans AS (
SELECT * FROM closure
UNION
SELECT l.start, r.end
FROM trans as l, closure AS r
WHERE l.end = r.start
) SELECT * FROM trans ORDER BY start, end
----
1 1
1 2
1 3
1 4
2 1
2 2
2 3
2 4
4 1
4 2
4 3
4 4

query TT
EXPLAIN WITH RECURSIVE trans AS (
SELECT * FROM closure
UNION
SELECT l.start, r.end
FROM trans as l, closure AS r
WHERE l.end = r.start
) SELECT * FROM trans
----
logical_plan
01)SubqueryAlias: trans
02)--RecursiveQuery: is_distinct=true
03)----Projection: closure.start, closure.end
04)------TableScan: closure
05)----Projection: l.start, r.end
06)------Inner Join: l.end = r.start
07)--------SubqueryAlias: l
08)----------TableScan: trans
09)--------SubqueryAlias: r
10)----------TableScan: closure
physical_plan
01)RecursiveQueryExec: name=trans, is_distinct=true
02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/recursive_cte/closure.csv]]}, projection=[start, end], file_type=csv, has_header=true
03)--CoalescePartitionsExec
04)----CoalesceBatchesExec: target_batch_size=8182
05)------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(end@1, start@0)], projection=[start@0, end@3]
06)--------CoalesceBatchesExec: target_batch_size=8182
07)----------RepartitionExec: partitioning=Hash([end@1], 4), input_partitions=4
08)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
09)--------------WorkTableExec: name=trans
10)--------CoalesceBatchesExec: target_batch_size=8182
11)----------RepartitionExec: partitioning=Hash([start@0], 4), input_partitions=4
12)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
13)--------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/recursive_cte/closure.csv]]}, projection=[start, end], file_type=csv, has_header=true

statement count 0
set datafusion.execution.enable_recursive_ctes = false;
