diff --git a/metaflow/cli_components/run_cmds.py b/metaflow/cli_components/run_cmds.py
index 159e2764303..07d8ec5cd36 100644
--- a/metaflow/cli_components/run_cmds.py
+++ b/metaflow/cli_components/run_cmds.py
@@ -191,7 +191,14 @@ def wrapper(*args, **kwargs):
     hidden=True,
     help="If specified, it identifies the task that started this resume call. It is in the form of {step_name}-{task_id}",
 )
-@click.argument("step-to-rerun", required=False)
+@click.option(
+    "--until",
+    default=None,
+    show_default=True,
+    help="If specified, runs up to the specified step(s) (exclusive) and stops. "
+    "Multiple steps can be specified using a comma-separated list.",
+)
+@click.argument("step-to-rerun", required=False, nargs=-1)
 @click.command(help="Resume execution of a previous run of this flow.")
 @tracing.cli("cli/resume")
 @common_run_options
@@ -200,6 +207,7 @@ def resume(
     obj,
     tags=None,
     step_to_rerun=None,
+    until=None,
     origin_run_id=None,
     run_id=None,
     clone_only=False,
@@ -225,13 +233,14 @@ def resume(
         steps_to_rerun = set()
     else:
         # validate step name
-        if step_to_rerun not in obj.graph.nodes:
-            raise CommandException(
-                "invalid step name {0} specified, must be step present in "
-                "current form of execution graph. Valid step names include: {1}".format(
-                    step_to_rerun, ",".join(list(obj.graph.nodes.keys()))
+        for s in step_to_rerun:
+            if s not in obj.graph.nodes:
+                raise CommandException(
+                    "Invalid step name {0} specified, must be step present in "
+                    "current form of execution graph. Valid step names include: {1}".format(
+                        s, ",".join(list(obj.graph.nodes.keys()))
+                    )
                 )
-            )
 
         ## TODO: instead of checking execution path here, can add a warning later
         ## instead of throwing an error. This is for resuming a step which was not
@@ -245,8 +254,22 @@ def resume(
         #         f"part of the original execution path for run '{origin_run_id}'."
         #     )
 
-        steps_to_rerun = {step_to_rerun}
-
+        steps_to_rerun = set(step_to_rerun)
+
+    if clone_only and until is not None:
+        raise CommandException("Cannot specify both --clone-only and --until")
+
+    until_steps = None
+    if until is not None:
+        until_steps = set(until.split(","))
+        for step in until_steps:
+            if step not in obj.graph.nodes:
+                raise CommandException(
+                    "Invalid until step name {0} specified, must be step present in "
+                    "current form of execution graph. Valid step names include: {1}".format(
+                        step, ",".join(list(obj.graph.nodes.keys()))
+                    )
+                )
     if run_id:
         # Run-ids that are provided by the metadata service are always integers.
         # External providers or run-ids (like external schedulers) always need to
@@ -274,12 +297,16 @@ def resume(
             clone_only=clone_only,
             reentrant=reentrant,
             steps_to_rerun=steps_to_rerun,
+            until_steps=until_steps,
             max_workers=max_workers,
             max_num_splits=max_num_splits,
             max_log_size=max_log_size * 1024 * 1024,
             resume_identifier=resume_identifier,
        )
         write_file(run_id_file, runtime.run_id)
+        if until is not None:
+            write_latest_run_id(obj, runtime.run_id)
+
         runtime.print_workflow_info()
         runtime.persist_constants()
 
diff --git a/metaflow/runtime.py b/metaflow/runtime.py
index 8f3864df3af..952001c9389 100644
--- a/metaflow/runtime.py
+++ b/metaflow/runtime.py
@@ -26,7 +26,6 @@
 from contextlib import contextmanager
 
 from . import get_namespace
-from .metadata_provider import MetaDatum
 from .metaflow_config import FEAT_ALWAYS_UPLOAD_CODE_PACKAGE, MAX_ATTEMPTS, UI_URL
 from .exception import (
     MetaflowException,
@@ -44,7 +43,6 @@
 from .unbounded_foreach import (
     CONTROL_TASK_TAG,
     UBF_CONTROL,
-    UBF_TASK,
 )
 from .user_configs.config_options import ConfigInput
 
@@ -103,6 +101,7 @@ def __init__(
         clone_only=False,
         reentrant=False,
         steps_to_rerun=None,
+        until_steps=None,
         max_workers=MAX_WORKERS,
         max_num_splits=MAX_NUM_SPLITS,
         max_log_size=MAX_LOG_SIZE,
@@ -145,14 +144,29 @@ def __init__(
         self._skip_decorator_hooks = skip_decorator_hooks
 
         # If steps_to_rerun is specified, we will not clone them in resume mode.
-        self._steps_to_rerun = steps_to_rerun or {}
-        # sorted_nodes are in topological order already, so we only need to
-        # iterate through the nodes once to get a stable set of rerun steps.
+        self._steps_to_rerun = steps_to_rerun or set()
+        self._steps_can_clone = set()
+        self._steps_no_run = until_steps or set()
+
+        all_steps = set()
+        # If clone_only is specified, we should have no until_steps and no steps_to_rerun
+        # so the computation below yields rerunning all the steps that we previously
+        # executed.
+        # In the other cases, we will allow the cloning of steps up to but not
+        # inclusive of anything in steps_to_rerun and at the end, steps_to_rerun
+        # will contain all steps up to but not inclusive of anything in _steps_no_run.
         for step_name in self._graph.sorted_nodes:
-            if step_name in self._steps_to_rerun:
-                out_funcs = self._graph[step_name].out_funcs or []
+            all_steps.add(step_name)
+            out_funcs = self._graph[step_name].out_funcs or []
+            if step_name in self._steps_no_run:
+                for next_step in out_funcs:
+                    self._steps_no_run.add(next_step)
+            elif step_name in self._steps_to_rerun:
                 for next_step in out_funcs:
                     self._steps_to_rerun.add(next_step)
+        # Remove any steps that should not be run (--until takes precedence)
+        self._steps_to_rerun = self._steps_to_rerun - self._steps_no_run
+        self._steps_can_clone = all_steps - self._steps_to_rerun - self._steps_no_run
 
         self._origin_ds_set = None
         if clone_run_id:
@@ -399,7 +413,7 @@ def clone_original_run(self, generate_task_obj=False, verbose=True):
             if (
                 task_ds["_task_ok"]
                 and step_name != "_parameters"
-                and (step_name not in self._steps_to_rerun)
+                and (step_name in self._steps_can_clone)
             ):
                 # "_unbounded_foreach" is a special flag to indicate that the transition
                 # is an unbounded foreach.
@@ -677,6 +691,21 @@ def execute(self):
                     system_msg=True,
                 )
                 self._params_task.mark_resume_done()
+            elif self._steps_no_run:
+                # Ran a subset of the graph
+                count_cloned = -1  # Account for _parameters task
+                count_reexec = 0
+                for t in self._is_cloned.values():
+                    if t:
+                        count_cloned += 1
+                    else:
+                        count_reexec += 1
+
+                self._logger(
+                    f"Partial resume complete -- cloned {count_cloned} step(s) and "
+                    f"executed {count_reexec} step(s)",
+                    system_msg=True,
+                )
             else:
                 raise MetaflowInternalError(
                     "The *end* step was not successful by the end of flow."
@@ -1073,6 +1102,7 @@ def _queue_task_foreach(self, task, next_steps):
 
     def _queue_tasks(self, finished_tasks):
         # finished tasks include only successful tasks
         for task in finished_tasks:
+            step_name, _, _ = task.finished_id
             self._finished[task.finished_id] = task.path
             self._is_cloned[task.path] = task.is_cloned
@@ -1137,6 +1167,15 @@ def _queue_tasks(self, finished_tasks):
                     )
                 )
 
+            if self._steps_no_run:
+                # We need to filter next_steps to drop any steps that fall in
+                # self._steps_no_run
+                next_steps = [
+                    step for step in next_steps if step not in self._steps_no_run
+                ]
+                if not next_steps:
+                    # No steps to execute, so we can stop
+                    return
             # Different transition types require different treatment
             if any(self._graph[f].type == "join" for f in next_steps):
                 # Next step is a join
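
Note (not part of the patch): the runtime changes above partition steps into three sets (clone, rerun, skip) in a single pass over the topologically sorted graph, with --until taking precedence over rerun. The standalone sketch below mirrors that logic under stated assumptions; partition_steps, the toy flow, and its step names are hypothetical, with sorted_nodes and out_funcs standing in for FlowGraph.sorted_nodes and each node's out_funcs.

def partition_steps(sorted_nodes, out_funcs, steps_to_rerun, until_steps):
    # Walk steps in topological order and propagate membership downstream,
    # mirroring the loop added to NativeRuntime.__init__ in this patch.
    steps_to_rerun = set(steps_to_rerun)
    steps_no_run = set(until_steps)
    all_steps = set()
    for step in sorted_nodes:
        all_steps.add(step)
        successors = out_funcs.get(step, [])
        if step in steps_no_run:
            # Everything downstream of an --until step is skipped entirely.
            steps_no_run.update(successors)
        elif step in steps_to_rerun:
            # Everything downstream of a rerun step must also be rerun.
            steps_to_rerun.update(successors)
    # --until takes precedence over rerun; whatever remains can be cloned.
    steps_to_rerun -= steps_no_run
    steps_can_clone = all_steps - steps_to_rerun - steps_no_run
    return steps_can_clone, steps_to_rerun, steps_no_run

# Toy linear flow: start -> a -> b -> end, rerunning "a" but stopping at "b".
nodes = ["start", "a", "b", "end"]
succ = {"start": ["a"], "a": ["b"], "b": ["end"], "end": []}
clone, rerun, skip = partition_steps(nodes, succ, {"a"}, {"b"})
assert clone == {"start"} and rerun == {"a"} and skip == {"b", "end"}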