|
43 | 43 | from zenml.orchestrators import utils as orchestrator_utils |
44 | 44 | from zenml.orchestrators.step_runner import StepRunner |
45 | 45 | from zenml.stack import Stack |
| 46 | +from zenml.steps import StepHeartBeatTerminationException, StepHeartbeatWorker |
46 | 47 | from zenml.utils import env_utils, exception_utils, string_utils |
| 48 | +from zenml.utils.exception_utils import ContextReraise |
47 | 49 | from zenml.utils.time_utils import utc_now |
48 | 50 |
|
49 | 51 | if TYPE_CHECKING: |
@@ -167,7 +169,6 @@ def signal_handler(signum: int, frame: Any) -> None: |
167 | 169 |
|
168 | 170 | try: |
169 | 171 | client = Client() |
170 | | - pipeline_run = None |
171 | 172 |
|
172 | 173 | if self._step_run: |
173 | 174 | pipeline_run = client.get_pipeline_run( |
@@ -443,35 +444,58 @@ def _run_step( |
443 | 444 | ) |
444 | 445 |
|
445 | 446 | start_time = time.time() |
446 | | - try: |
447 | | - if self._step.config.step_operator: |
448 | | - step_operator_name = None |
449 | | - if isinstance(self._step.config.step_operator, str): |
450 | | - step_operator_name = self._step.config.step_operator |
451 | | - |
452 | | - self._run_step_with_step_operator( |
453 | | - step_operator_name=step_operator_name, |
454 | | - step_run_info=step_run_info, |
| 447 | + |
| 448 | + # For cross-platform handling of main-thread termination we use Python's |
| 449 | + # interrupt_main instead of termination signals (which are not supported on Windows). |
| 450 | + # Since interrupt_main raises KeyboardInterrupt, we capture it in this context |
| 451 | + # and handle it as a custom exception. |
| 452 | + |
| 453 | + with ContextReraise( |
| 454 | + source_exceptions=[KeyboardInterrupt], |
| 455 | + target_exception=StepHeartBeatTerminationException, |
| 456 | + message=f"Step {self._invocation_id} has been remotely stopped - terminating", |
| 457 | + propagate_traceback=False, |
| 458 | + ) as ctx_reraise: |
| 459 | + |
| 460 | + logger.info(f"Initiating heartbeat for step {self._invocation_id}") |
| 461 | + |
| 462 | + StepHeartbeatWorker(step_id=step_run.id).start() |
| 463 | + |
| 464 | + try: |
| 465 | + if self._step.config.step_operator: |
| 466 | + step_operator_name = None |
| 467 | + if isinstance(self._step.config.step_operator, str): |
| 468 | + step_operator_name = self._step.config.step_operator |
| 469 | + |
| 470 | + self._run_step_with_step_operator( |
| 471 | + step_operator_name=step_operator_name, |
| 472 | + step_run_info=step_run_info, |
| 473 | + ) |
| 474 | + else: |
| 475 | + self._run_step_without_step_operator( |
| 476 | + pipeline_run=pipeline_run, |
| 477 | + step_run=step_run, |
| 478 | + step_run_info=step_run_info, |
| 479 | + input_artifacts=step_run.regular_inputs, |
| 480 | + output_artifact_uris=output_artifact_uris, |
| 481 | + ) |
| 482 | + except StepHeartBeatTerminationException: |
| 483 | + logger.info(ctx_reraise.message) |
| 484 | + output_utils.remove_artifact_dirs( |
| 485 | + artifact_uris=list(output_artifact_uris.values()) |
455 | 486 | ) |
456 | | - else: |
457 | | - self._run_step_without_step_operator( |
458 | | - pipeline_run=pipeline_run, |
459 | | - step_run=step_run, |
460 | | - step_run_info=step_run_info, |
461 | | - input_artifacts=step_run.regular_inputs, |
462 | | - output_artifact_uris=output_artifact_uris, |
| 487 | + raise |
| 488 | + except: # noqa: E722 |
| 489 | + output_utils.remove_artifact_dirs( |
| 490 | + artifact_uris=list(output_artifact_uris.values()) |
463 | 491 | ) |
464 | | - except: # noqa: E722 |
465 | | - output_utils.remove_artifact_dirs( |
466 | | - artifact_uris=list(output_artifact_uris.values()) |
467 | | - ) |
468 | | - raise |
| 492 | + raise |
469 | 493 |
|
470 | | - duration = time.time() - start_time |
471 | | - logger.info( |
472 | | - f"Step `{self._invocation_id}` has finished in " |
473 | | - f"`{string_utils.get_human_readable_time(duration)}`." |
474 | | - ) |
| 494 | + duration = time.time() - start_time |
| 495 | + logger.info( |
| 496 | + f"Step `{self._invocation_id}` has finished in " |
| 497 | + f"`{string_utils.get_human_readable_time(duration)}`." |
| 498 | + ) |
475 | 499 |
|
476 | 500 | def _run_step_with_step_operator( |
477 | 501 | self, |
|
0 commit comments