Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 10 additions & 45 deletions src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ use crate::max_rows::MaxRowsExec;
use crate::pre_fetch::PrefetchExec;
use crate::stage::DFRayStageExec;
use crate::stage_reader::DFRayStageReaderExec;
use crate::util::ResultExt;
use crate::util::collect_from_stage;
use crate::util::display_plan_with_partition_counts;
use crate::util::physical_plan_to_bytes;
use crate::util::ResultExt;

/// Internal Rust class behind the DFRayDataFrame Python object
///
Expand Down Expand Up @@ -96,49 +96,10 @@ impl DFRayDataFrame {
) -> PyResult<Vec<PyDFRayStage>> {
let mut stages = vec![];

// TODO: This can be done more efficiently, likely in one pass but I'm
// struggling to get the TreeNodeRecursion return values to make it do
// what I want. So, two steps for now

// Step 2: we walk down this stage and replace stages earlier in the tree with
// RayStageReaderExecs as we will need to consume their output instead of
// execute that part of the tree ourselves
let down = |plan: Arc<dyn ExecutionPlan>| {
trace!(
"examining plan down:\n{}",
display_plan_with_partition_counts(&plan)
);

if let Some(stage_exec) = plan.as_any().downcast_ref::<DFRayStageExec>() {
let input = plan.children();
assert!(input.len() == 1, "RayStageExec must have exactly one child");
let input = input[0];

trace!(
"inserting a ray stage reader to consume: {} with partitioning {}",
displayable(plan.as_ref()).one_line(),
plan.output_partitioning().partition_count()
);

let replacement = Arc::new(DFRayStageReaderExec::try_new(
plan.output_partitioning().clone(),
input.schema(),
stage_exec.stage_id,
)?) as Arc<dyn ExecutionPlan>;

Ok(Transformed {
data: replacement,
transformed: true,
tnr: TreeNodeRecursion::Jump,
})
} else {
Ok(Transformed::no(plan))
}
};

let mut partition_groups = vec![];
let mut full_partitions = false;
// Step 1: we walk up the tree from the leaves to find the stages
// We walk up the tree from the leaves to find the stages, record ray stages, and replace
// each ray stage with a corresponding ray reader stage.
let up = |plan: Arc<dyn ExecutionPlan>| {
trace!(
"Examining plan up: {}",
Expand All @@ -151,19 +112,23 @@ impl DFRayDataFrame {
assert!(input.len() == 1, "RayStageExec must have exactly one child");
let input = input[0];

let fixed_plan = input.clone().transform_down(down)?.data;
let replacement = Arc::new(DFRayStageReaderExec::try_new(
plan.output_partitioning().clone(),
input.schema(),
stage_exec.stage_id,
)?) as Arc<dyn ExecutionPlan>;

let stage = PyDFRayStage::new(
stage_exec.stage_id,
fixed_plan,
input.clone(),
partition_groups.clone(),
full_partitions,
);
partition_groups = vec![];
full_partitions = false;

stages.push(stage);
Ok(Transformed::no(plan))
Ok(Transformed::yes(replacement))
} else if plan.as_any().downcast_ref::<RepartitionExec>().is_some() {
trace!("repartition exec");
let (calculated_partition_groups, replacement) = build_replacement(
Expand Down
Loading