diff --git a/src/dataframe.rs b/src/dataframe.rs index adecf6b..dc3a720 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -50,10 +50,10 @@ use crate::max_rows::MaxRowsExec; use crate::pre_fetch::PrefetchExec; use crate::stage::DFRayStageExec; use crate::stage_reader::DFRayStageReaderExec; +use crate::util::ResultExt; use crate::util::collect_from_stage; use crate::util::display_plan_with_partition_counts; use crate::util::physical_plan_to_bytes; -use crate::util::ResultExt; /// Internal rust class beyind the DFRayDataFrame python object /// @@ -96,49 +96,10 @@ impl DFRayDataFrame { ) -> PyResult> { let mut stages = vec![]; - // TODO: This can be done more efficiently, likely in one pass but I'm - // struggling to get the TreeNodeRecursion return values to make it do - // what I want. So, two steps for now - - // Step 2: we walk down this stage and replace stages earlier in the tree with - // RayStageReaderExecs as we will need to consume their output instead of - // execute that part of the tree ourselves - let down = |plan: Arc| { - trace!( - "examining plan down:\n{}", - display_plan_with_partition_counts(&plan) - ); - - if let Some(stage_exec) = plan.as_any().downcast_ref::() { - let input = plan.children(); - assert!(input.len() == 1, "RayStageExec must have exactly one child"); - let input = input[0]; - - trace!( - "inserting a ray stage reader to consume: {} with partitioning {}", - displayable(plan.as_ref()).one_line(), - plan.output_partitioning().partition_count() - ); - - let replacement = Arc::new(DFRayStageReaderExec::try_new( - plan.output_partitioning().clone(), - input.schema(), - stage_exec.stage_id, - )?) as Arc; - - Ok(Transformed { - data: replacement, - transformed: true, - tnr: TreeNodeRecursion::Jump, - }) - } else { - Ok(Transformed::no(plan)) - } - }; - let mut partition_groups = vec![]; let mut full_partitions = false; - // Step 1: we walk up the tree from the leaves to find the stages + // We walk up the tree from the leaves to find the stages, record ray stages, and replace + // each ray stage with a corresponding ray reader stage. let up = |plan: Arc| { trace!( "Examining plan up: {}", @@ -151,11 +112,15 @@ impl DFRayDataFrame { assert!(input.len() == 1, "RayStageExec must have exactly one child"); let input = input[0]; - let fixed_plan = input.clone().transform_down(down)?.data; + let replacement = Arc::new(DFRayStageReaderExec::try_new( + plan.output_partitioning().clone(), + input.schema(), + stage_exec.stage_id, + )?) as Arc; let stage = PyDFRayStage::new( stage_exec.stage_id, - fixed_plan, + input.clone(), partition_groups.clone(), full_partitions, ); @@ -163,7 +128,7 @@ impl DFRayDataFrame { full_partitions = false; stages.push(stage); - Ok(Transformed::no(plan)) + Ok(Transformed::yes(replacement)) } else if plan.as_any().downcast_ref::().is_some() { trace!("repartition exec"); let (calculated_partition_groups, replacement) = build_replacement(