6 changes: 6 additions & 0 deletions datafusion/core/tests/data/recursive_cte/closure.csv
@@ -0,0 +1,6 @@
start,end
1,2
2,3
2,4
2,4
4,1
8 changes: 1 addition & 7 deletions datafusion/expr/src/logical_plan/builder.rs
@@ -54,7 +54,7 @@ use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
use datafusion_common::display::ToStringifiedPlan;
use datafusion_common::file_options::file_type::FileType;
use datafusion_common::{
exec_err, get_target_functional_dependencies, internal_datafusion_err, not_impl_err,
exec_err, get_target_functional_dependencies, internal_datafusion_err,
plan_datafusion_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef,
NullEquality, Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions,
};
@@ -178,12 +178,6 @@ impl LogicalPlanBuilder {
recursive_term: LogicalPlan,
is_distinct: bool,
) -> Result<Self> {
// TODO: we need to do a bunch of validation here. Maybe more.
if is_distinct {
return not_impl_err!(
"Recursive queries with a distinct 'UNION' (in which the previous iteration's results will be de-duplicated) is not supported"
);
}
// Ensure that the static term and the recursive term have the same number of fields
let static_fields_len = self.plan.schema().fields().len();
let recursive_fields_len = recursive_term.schema().fields().len();
94 changes: 82 additions & 12 deletions datafusion/physical-plan/src/recursive_query.rs
@@ -21,14 +21,18 @@ use std::any::Any;
use std::sync::Arc;
use std::task::{Context, Poll};

use super::work_table::{ReservedBatches, WorkTable, WorkTableExec};
use crate::aggregates::group_values::{new_group_values, GroupValues};
use crate::aggregates::order::GroupOrdering;
use crate::execution_plan::{Boundedness, EmissionType};
use crate::metrics::RecordOutput;
use crate::work_table::{ReservedBatches, WorkTable, WorkTableExec};
use crate::{
metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet},
PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
};
use crate::{DisplayAs, DisplayFormatType, ExecutionPlan};

use arrow::array::{BooleanArray, BooleanBuilder};
use arrow::compute::filter_record_batch;
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
@@ -195,8 +199,9 @@ impl ExecutionPlan for RecursiveQueryExec {
Arc::clone(&self.work_table),
Arc::clone(&self.recursive_term),
static_stream,
self.is_distinct,
baseline_metrics,
)))
)?))
}

fn metrics(&self) -> Option<MetricsSet> {
@@ -268,8 +273,10 @@ struct RecursiveQueryStream {
buffer: Vec<RecordBatch>,
/// Tracks the memory used by the buffer
reservation: MemoryReservation,
// /// Metrics.
_baseline_metrics: BaselineMetrics,
/// If the distinct flag is set, we use this hash table to remove duplicates from the result and the work table
distinct_deduplicator: Option<DistinctDeduplicator>,
/// Metrics.
baseline_metrics: BaselineMetrics,
}

impl RecursiveQueryStream {
@@ -279,12 +286,16 @@ impl RecursiveQueryStream {
work_table: Arc<WorkTable>,
recursive_term: Arc<dyn ExecutionPlan>,
static_stream: SendableRecordBatchStream,
is_distinct: bool,
baseline_metrics: BaselineMetrics,
) -> Self {
) -> Result<Self> {
let schema = static_stream.schema();
let reservation =
MemoryConsumer::new("RecursiveQuery").register(task_context.memory_pool());
Self {
let distinct_deduplicator = is_distinct
.then(|| DistinctDeduplicator::new(Arc::clone(&schema), &task_context))
.transpose()?;
Ok(Self {
task_context,
work_table,
recursive_term,
@@ -293,21 +304,28 @@ impl RecursiveQueryStream {
schema,
buffer: vec![],
reservation,
_baseline_metrics: baseline_metrics,
}
distinct_deduplicator,
baseline_metrics,
})
}

/// Deduplicate the batch if required, push a clone of it to the in-memory
/// buffer, and then return a poll with it.
fn push_batch(
mut self: std::pin::Pin<&mut Self>,
batch: RecordBatch,
mut batch: RecordBatch,
) -> Poll<Option<Result<RecordBatch>>> {
let baseline_metrics = self.baseline_metrics.clone();
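// Deduplicate against everything seen so far. Since the buffer we push
// to below also feeds the work table for the next iteration, this keeps
// both the result and the work table distinct.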
if let Some(deduplicator) = &mut self.distinct_deduplicator {
let _timer_guard = baseline_metrics.elapsed_compute().timer();
batch = deduplicator.deduplicate(&batch)?;
}

if let Err(e) = self.reservation.try_grow(batch.get_array_memory_size()) {
return Poll::Ready(Some(Err(e)));
}

self.buffer.push(batch.clone());
(&batch).record_output(&baseline_metrics);
Poll::Ready(Some(Ok(batch)))
}

@@ -397,7 +415,6 @@ impl Stream for RecursiveQueryStream {
mut self: std::pin::Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
// TODO: we should use this poll to record some metrics!
if let Some(static_stream) = &mut self.static_stream {
// While the static term's stream is available, we'll be forwarding the batches from it (also
// saving them for the initial iteration of the recursive term).
@@ -434,5 +451,58 @@ impl RecordBatchStream for RecursiveQueryStream {
}
}

/// Deduplicator based on a hash table.
struct DistinctDeduplicator {
/// Grouped rows used for distinct
group_values: Box<dyn GroupValues>,
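/// Tracks the memory used by the hash table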
reservation: MemoryReservation,
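/// Scratch buffer reused across batches for the group ids produced by `intern`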
intern_output_buffer: Vec<usize>,
}

impl DistinctDeduplicator {
fn new(schema: SchemaRef, task_context: &TaskContext) -> Result<Self> {
let group_values = new_group_values(schema, &GroupOrdering::None)?;
let reservation = MemoryConsumer::new("RecursiveQueryHashTable")
.register(task_context.memory_pool());
Ok(Self {
group_values,
reservation,
intern_output_buffer: Vec::new(),
})
}

/// Remove duplicated rows from the given batch, keeping state between batches.
///
/// We use a hash table to allocate new group ids for the new rows.
/// [`GroupValues`] allocates increasing group ids.
/// Hence, if groups (i.e., rows) are new, they have ids >= the table length before interning, and we keep them.
/// We also detect duplicates within the batch by enforcing that group ids are strictly increasing.
fn deduplicate(&mut self, batch: &RecordBatch) -> Result<RecordBatch> {
let size_before = self.group_values.len();
self.intern_output_buffer.reserve(batch.num_rows());
self.group_values
.intern(batch.columns(), &mut self.intern_output_buffer)?;
let mask = are_increasing_mask(&self.intern_output_buffer, size_before);
self.intern_output_buffer.clear();
// We update the reservation to reflect the new size of the hash table.
self.reservation.try_resize(self.group_values.size())?;
Ok(filter_record_batch(batch, &mask)?)
}
}

/// Return a mask whose elements are true if, and only if, the corresponding value is greater than all previous values and greater than or equal to the provided min_value
fn are_increasing_mask(values: &[usize], mut min_value: usize) -> BooleanArray {
let mut output = BooleanBuilder::with_capacity(values.len());
for value in values {
if *value >= min_value {
output.append_value(true);
min_value = *value + 1; // require strictly increasing ids from here on
} else {
output.append_value(false);
}
}
output.finish()
}

#[cfg(test)]
mod tests {}
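
To make the dedup logic concrete, here is a minimal, self-contained sketch (assuming only an `arrow` dependency): `are_increasing_mask` is reproduced from the patch above, while the group ids are a hypothetical `intern` output rather than real `GroupValues` state. The hash table holds 3 groups before the batch is interned, so id 1 marks a pre-existing row and the repeated 4 marks a within-batch duplicate; only rows 0, 2, and 4 survive.

use arrow::array::{BooleanArray, BooleanBuilder};

fn are_increasing_mask(values: &[usize], mut min_value: usize) -> BooleanArray {
    let mut output = BooleanBuilder::with_capacity(values.len());
    for value in values {
        if *value >= min_value {
            output.append_value(true);
            min_value = *value + 1; // require strictly increasing ids from here on
        } else {
            output.append_value(false);
        }
    }
    output.finish()
}

fn main() {
    // Hypothetical `intern` output for a 5-row batch: the table held 3
    // groups beforehand, so ids below 3 are pre-existing rows and a
    // repeated id is a duplicate within the batch.
    let group_ids = [3, 1, 4, 4, 5];
    let mask = are_increasing_mask(&group_ids, 3);
    assert_eq!(mask, BooleanArray::from(vec![true, false, true, false, true]));
    // `deduplicate` would then pass this mask to `filter_record_batch`,
    // keeping only rows 0, 2, and 4.
}
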
90 changes: 78 additions & 12 deletions datafusion/sqllogictest/test_files/cte.slt
@@ -58,18 +58,6 @@ WITH RECURSIVE nodes AS (
statement ok
set datafusion.execution.enable_recursive_ctes = true;


# DISTINCT UNION is not supported
query error DataFusion error: This feature is not implemented: Recursive queries with a distinct 'UNION' \(in which the previous iteration's results will be de\-duplicated\) is not supported
WITH RECURSIVE nodes AS (
SELECT 1 as id
UNION
SELECT id + 1 as id
FROM nodes
WHERE id < 3
) SELECT * FROM nodes


# trivial recursive CTE works
query I rowsort
WITH RECURSIVE nodes AS (
@@ -122,6 +110,22 @@
08)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
09)------------WorkTableExec: name=nodes

# simple deduplicating recursive CTE works (with UNION ALL, 2, 3, and 4 would each appear twice)
query I
WITH RECURSIVE nodes AS (
SELECT id from (VALUES (1), (2)) nodes(id)
UNION
SELECT id + 1 as id
FROM nodes
WHERE id < 4
)
SELECT * FROM nodes
----
1
2
3
4

# setup
statement ok
CREATE EXTERNAL TABLE balance STORED as CSV LOCATION '../core/tests/data/recursive_cte/balance.csv' OPTIONS ('format.has_header' 'true');
@@ -1049,6 +1053,68 @@
05)----SortExec: TopK(fetch=1), expr=[v@1 ASC NULLS LAST], preserve_partitioning=[false]
06)------WorkTableExec: name=r

# setup
statement ok
CREATE EXTERNAL TABLE closure STORED as CSV LOCATION '../core/tests/data/recursive_cte/closure.csv' OPTIONS ('format.has_header' 'true');

# transitive closure with a loop: the 4,1 edge creates a cycle, so this query terminates only because UNION deduplicates and the recursion reaches a fixpoint
query II
WITH RECURSIVE trans AS (
SELECT * FROM closure
UNION
SELECT l.start, r.end
FROM trans as l, closure AS r
WHERE l.end = r.start
) SELECT * FROM trans ORDER BY start, end
----
1 1
1 2
1 3
1 4
2 1
2 2
2 3
2 4
4 1
4 2
4 3
4 4

query TT
EXPLAIN WITH RECURSIVE trans AS (
SELECT * FROM closure
UNION
SELECT l.start, r.end
FROM trans as l, closure AS r
WHERE l.end = r.start
) SELECT * FROM trans
----
logical_plan
01)SubqueryAlias: trans
02)--RecursiveQuery: is_distinct=true
03)----Projection: closure.start, closure.end
04)------TableScan: closure
05)----Projection: l.start, r.end
06)------Inner Join: l.end = r.start
07)--------SubqueryAlias: l
08)----------TableScan: trans
09)--------SubqueryAlias: r
10)----------TableScan: closure
physical_plan
01)RecursiveQueryExec: name=trans, is_distinct=true
02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/recursive_cte/closure.csv]]}, projection=[start, end], file_type=csv, has_header=true
03)--CoalescePartitionsExec
04)----CoalesceBatchesExec: target_batch_size=8182
05)------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(end@1, start@0)], projection=[start@0, end@3]
06)--------CoalesceBatchesExec: target_batch_size=8182
07)----------RepartitionExec: partitioning=Hash([end@1], 4), input_partitions=4
08)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
09)--------------WorkTableExec: name=trans
10)--------CoalesceBatchesExec: target_batch_size=8182
11)----------RepartitionExec: partitioning=Hash([start@0], 4), input_partitions=4
12)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
13)--------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/recursive_cte/closure.csv]]}, projection=[start, end], file_type=csv, has_header=true

statement count 0
set datafusion.execution.enable_recursive_ctes = false;
