Skip to content

Commit 1352677

Browse files
fixup! Improvements and bug fixes
- Updates migration down revision refs - context-reraise exception - changes in the step-heartbeat logic - fix null heartbeat in list/get endpoints
1 parent 47f0cf9 commit 1352677

File tree

7 files changed

+152
-60
lines changed

7 files changed

+152
-60
lines changed

src/zenml/orchestrators/step_launcher.py

Lines changed: 51 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@
4343
from zenml.orchestrators import utils as orchestrator_utils
4444
from zenml.orchestrators.step_runner import StepRunner
4545
from zenml.stack import Stack
46+
from zenml.steps import StepHeartBeatTerminationException, StepHeartbeatWorker
4647
from zenml.utils import env_utils, exception_utils, string_utils
48+
from zenml.utils.exception_utils import ContextReraise
4749
from zenml.utils.time_utils import utc_now
4850

4951
if TYPE_CHECKING:
@@ -167,7 +169,6 @@ def signal_handler(signum: int, frame: Any) -> None:
167169

168170
try:
169171
client = Client()
170-
pipeline_run = None
171172

172173
if self._step_run:
173174
pipeline_run = client.get_pipeline_run(
@@ -443,35 +444,58 @@ def _run_step(
443444
)
444445

445446
start_time = time.time()
446-
try:
447-
if self._step.config.step_operator:
448-
step_operator_name = None
449-
if isinstance(self._step.config.step_operator, str):
450-
step_operator_name = self._step.config.step_operator
451-
452-
self._run_step_with_step_operator(
453-
step_operator_name=step_operator_name,
454-
step_run_info=step_run_info,
447+
448+
# To handle termination of the main thread in a cross-platform way,
449+
# we use Python's interrupt_main instead of termination signals (which are not supported on Windows).
450+
# Since interrupt_main raises KeyboardInterrupt, we capture it in this context
451+
# and re-raise it as a custom exception.
452+
453+
with ContextReraise(
454+
source_exceptions=[KeyboardInterrupt],
455+
target_exception=StepHeartBeatTerminationException,
456+
message=f"Step {self._invocation_id} has been remotely stopped - terminating",
457+
propagate_traceback=False,
458+
) as ctx_reraise:
459+
460+
logger.info(f"Initiating heartbeat for step {self._invocation_id}")
461+
462+
StepHeartbeatWorker(step_id=step_run.id).start()
463+
464+
try:
465+
if self._step.config.step_operator:
466+
step_operator_name = None
467+
if isinstance(self._step.config.step_operator, str):
468+
step_operator_name = self._step.config.step_operator
469+
470+
self._run_step_with_step_operator(
471+
step_operator_name=step_operator_name,
472+
step_run_info=step_run_info,
473+
)
474+
else:
475+
self._run_step_without_step_operator(
476+
pipeline_run=pipeline_run,
477+
step_run=step_run,
478+
step_run_info=step_run_info,
479+
input_artifacts=step_run.regular_inputs,
480+
output_artifact_uris=output_artifact_uris,
481+
)
482+
except StepHeartBeatTerminationException:
483+
logger.info(ctx_reraise.message)
484+
output_utils.remove_artifact_dirs(
485+
artifact_uris=list(output_artifact_uris.values())
455486
)
456-
else:
457-
self._run_step_without_step_operator(
458-
pipeline_run=pipeline_run,
459-
step_run=step_run,
460-
step_run_info=step_run_info,
461-
input_artifacts=step_run.regular_inputs,
462-
output_artifact_uris=output_artifact_uris,
487+
raise
488+
except: # noqa: E722
489+
output_utils.remove_artifact_dirs(
490+
artifact_uris=list(output_artifact_uris.values())
463491
)
464-
except: # noqa: E722
465-
output_utils.remove_artifact_dirs(
466-
artifact_uris=list(output_artifact_uris.values())
467-
)
468-
raise
492+
raise
469493

470-
duration = time.time() - start_time
471-
logger.info(
472-
f"Step `{self._invocation_id}` has finished in "
473-
f"`{string_utils.get_human_readable_time(duration)}`."
474-
)
494+
duration = time.time() - start_time
495+
logger.info(
496+
f"Step `{self._invocation_id}` has finished in "
497+
f"`{string_utils.get_human_readable_time(duration)}`."
498+
)
475499

476500
def _run_step_with_step_operator(
477501
self,

src/zenml/steps/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
from zenml.steps.base_step import BaseStep
3030
from zenml.config.resource_settings import ResourceSettings
31+
from zenml.steps.heartbeat import StepHeartbeatWorker, StepHeartBeatTerminationException
3132
from zenml.steps.step_context import StepContext, get_step_context
3233
from zenml.steps.step_decorator import step
3334

@@ -36,5 +37,7 @@
3637
"ResourceSettings",
3738
"StepContext",
3839
"step",
39-
"get_step_context"
40+
"get_step_context",
41+
"StepHeartbeatWorker",
42+
"StepHeartBeatTerminationException",
4043
]

src/zenml/steps/heartbeat.py

Lines changed: 9 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,8 @@
1717
import logging
1818
import threading
1919
import time
20-
from typing import Annotated
2120
from uuid import UUID
2221

23-
from pydantic import BaseModel, conint, model_validator
24-
2522
from zenml.enums import ExecutionStatus
2623

2724
logger = logging.getLogger(__name__)
@@ -33,36 +30,19 @@ class StepHeartBeatTerminationException(Exception):
3330
pass
3431

3532

36-
class StepHeartBeatOptions(BaseModel):
37-
"""Options group for step heartbeat execution."""
38-
39-
step_id: UUID
40-
interval: Annotated[int, conint(ge=10, le=60)]
41-
name: str | None = None
42-
43-
@model_validator(mode="after")
44-
def set_default_name(self) -> "StepHeartBeatOptions":
45-
"""Model validator - set name value if missing.
46-
47-
Returns:
48-
The validated step heartbeat options.
49-
"""
50-
if not self.name:
51-
self.name = f"HeartBeatWorker-{self.step_id}"
52-
53-
return self
54-
55-
56-
class HeartbeatWorker:
33+
class StepHeartbeatWorker:
5734
"""Worker class implementing heartbeat polling and remote termination."""
5835

59-
def __init__(self, options: StepHeartBeatOptions):
36+
STEP_HEARTBEAT_INTERVAL_SECONDS = 30
37+
38+
def __init__(self, step_id: UUID):
6039
"""Heartbeat worker constructor.
6140
6241
Args:
6342
step_id: The id of the step the heartbeat is running for.
6443
"""
65-
self.options = options
44+
45+
self._step_id = step_id
6646

6747
self._thread: threading.Thread | None = None
6848
self._running: bool = False
@@ -79,7 +59,7 @@ def interval(self) -> int:
7959
Returns:
8060
The heartbeat polling interval value.
8161
"""
82-
return self.options.interval
62+
return self.STEP_HEARTBEAT_INTERVAL_SECONDS
8363

8464
@property
8565
def name(self) -> str:
@@ -88,7 +68,7 @@ def name(self) -> str:
8868
Returns:
8969
The name of the heartbeat worker.
9070
"""
91-
return str(self.options.name)
71+
return f"HeartBeatWorker-{self.step_id}"
9272

9373
@property
def step_id(self) -> UUID:
    """Id of the step this heartbeat worker is attached to.

    Returns:
        The id of the step heartbeat is running for.
    """
    # Read the private attribute assigned in `__init__`. Returning
    # `self.step_id` here would re-enter this very property and recurse
    # until a RecursionError is raised.
    return self._step_id
10181

10282
# public functions
10383

src/zenml/utils/exception_utils.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
import re
1919
import textwrap
2020
import traceback
21-
from typing import TYPE_CHECKING, Optional
21+
from contextlib import ContextDecorator
22+
from types import TracebackType
23+
from typing import TYPE_CHECKING, Optional, Type
2224

2325
from zenml.constants import MEDIUMTEXT_MAX_LENGTH
2426
from zenml.logger import get_logger
@@ -91,3 +93,41 @@ def collect_exception_information(
9193
traceback=tb_bytes.decode(errors="ignore"),
9294
step_code_line=line_number,
9395
)
96+
97+
98+
class ContextReraise(ContextDecorator):
    """Context manager/decorator that converts selected exceptions.

    Any exception raised inside the managed block that is an instance of one
    of the configured source exception types is replaced by a new instance of
    the target exception type, built from the configured message. All other
    exceptions propagate unchanged.
    """

    def __init__(
        self,
        source_exceptions: list[Type[BaseException]],
        target_exception: Type[BaseException],
        message: str,
        propagate_traceback: bool = True,
    ) -> None:
        """Initializes the context manager.

        Args:
            source_exceptions: Exception types that should be converted.
                Subclasses of these types are converted as well.
            target_exception: Exception type to raise instead.
            message: Message passed to the target exception constructor.
            propagate_traceback: If True, the traceback of the original
                exception is attached to the raised target exception.
        """
        # Snapshot as a tuple: later mutation of the caller's list cannot
        # change what is matched, and `isinstance` accepts a tuple directly.
        self._source_exceptions = tuple(source_exceptions)
        self._target_exception = target_exception
        self._message = message
        self._propagate_traceback = propagate_traceback

    @property
    def message(self) -> str:
        """Message used for the re-raised exception.

        Returns:
            The configured message.
        """
        return self._message

    def __enter__(self) -> "ContextReraise":
        """Enters the context.

        Returns:
            The context manager itself.
        """
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        trace: Optional[TracebackType],
    ) -> bool:
        """Exits the context, converting matching exceptions.

        Args:
            exc_type: Type of the exception raised in the block, if any.
            exc_value: The exception raised in the block, if any.
            trace: Traceback of the exception raised in the block, if any.

        Returns:
            False, so that non-matching exceptions keep propagating.

        Raises:
            BaseException: An instance of the configured target exception
                type when the raised exception matches a source type.
        """
        if exc_type is None:
            return False

        if not isinstance(exc_value, self._source_exceptions):
            return False

        converted = self._target_exception(self._message)
        if self._propagate_traceback:
            converted = converted.with_traceback(trace)
        raise converted

src/zenml/zen_stores/migrations/versions/a5a17015b681_add_heartbeat_column_for_step_runs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
# revision identifiers, used by Alembic.
1313
revision = "a5a17015b681"
14-
down_revision = "0.90.0"
14+
down_revision = "0.91.0"
1515
branch_labels = None
1616
depends_on = None
1717

src/zenml/zen_stores/schemas/step_run_schemas.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,7 @@ def to_model(
407407
is_retriable=self.is_retriable,
408408
start_time=self.start_time,
409409
end_time=self.end_time,
410+
latest_heartbeat=self.latest_heartbeat,
410411
created=self.created,
411412
updated=self.updated,
412413
model_version_id=self.model_version_id,

tests/unit/utils/test_exception_utils.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55

66
import pytest
77

8-
from zenml.utils.exception_utils import collect_exception_information
8+
from zenml.utils.exception_utils import (
9+
ContextReraise,
10+
collect_exception_information,
11+
)
912

1013

1114
def test_regex_pattern_no_syntax_warning():
@@ -101,3 +104,44 @@ def test_regex_pattern_matches_windows_paths_and_special_chars():
101104
r' File "C:\Other\path\file.py", line 123, in some_function'
102105
)
103106
assert line_pattern_win.search(non_match_line) is None
107+
108+
109+
def test_context_reraise():
    """Checks conversion, pass-through and subclass matching of ContextReraise."""

    class CustomError(Exception):
        pass

    class CustomValueError(ValueError):
        pass

    class CustomTypeError(TypeError):
        pass

    def _raise_value_error():
        raise ValueError("VALUE ERROR")

    def _divide_by_zero():
        _ = 1 / 0

    def _raise_custom_value_error():
        raise CustomValueError("VALUE ERROR")

    # (target exception, failing action, exception expected to escape)
    scenarios = [
        # source errors are captured and re-raised
        (CustomError, _raise_value_error, CustomError),
        # other errors propagate normally
        (RuntimeError, _divide_by_zero, ZeroDivisionError),
        # subclasses of the source errors are captured as well
        (CustomTypeError, _raise_custom_value_error, CustomTypeError),
    ]

    for target, action, expected in scenarios:
        with pytest.raises(expected):
            with ContextReraise(
                source_exceptions=[ValueError, TypeError],
                target_exception=target,
                message="Oh no",
            ):
                action()

0 commit comments

Comments
 (0)