Skip to content

Commit 47f0cf9

Browse files
Feature:3963 Step HeartBeat components
- Backend heartbeat support (DB, API)
- Heartbeat monitoring worker
1 parent 900a588 commit 47f0cf9

File tree

10 files changed

+334
-3
lines changed

10 files changed

+334
-3
lines changed

src/zenml/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
439439
STATUS = "/status"
440440
STEP_CONFIGURATION = "/step-configuration"
441441
STEPS = "/steps"
442+
HEARTBEAT = "heartbeat"
442443
STOP = "/stop"
443444
TAGS = "/tags"
444445
TAG_RESOURCES = "/tag_resources"

src/zenml/models/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,8 @@
346346
StepRunResponse,
347347
StepRunResponseBody,
348348
StepRunResponseMetadata,
349-
StepRunResponseResources
349+
StepRunResponseResources,
350+
StepHeartbeatResponse,
350351
)
351352
from zenml.models.v2.core.tag import (
352353
TagFilter,
@@ -908,4 +909,5 @@
908909
"StepRunIdentifier",
909910
"ArtifactVersionIdentifier",
910911
"ModelVersionIdentifier",
912+
"StepHeartbeatResponse",
911913
]

src/zenml/models/v2/core/step_run.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
)
2727
from uuid import UUID
2828

29-
from pydantic import ConfigDict, Field
29+
from pydantic import BaseModel, ConfigDict, Field
3030

3131
from zenml.config.step_configurations import StepConfiguration, StepSpec
3232
from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
@@ -210,6 +210,10 @@ class StepRunResponseBody(ProjectScopedResponseBody):
210210
title="The end time of the step run.",
211211
default=None,
212212
)
213+
latest_heartbeat: Optional[datetime] = Field(
214+
title="The latest heartbeat of the step run.",
215+
default=None,
216+
)
213217
model_version_id: Optional[UUID] = Field(
214218
title="The ID of the model version that was "
215219
"configured by this step run explicitly.",
@@ -589,6 +593,15 @@ def end_time(self) -> Optional[datetime]:
589593
"""
590594
return self.get_body().end_time
591595

596+
@property
def latest_heartbeat(self) -> Optional[datetime]:
    """The `latest_heartbeat` property.

    Returns:
        The `latest_heartbeat` value stored on the response body, or
        `None` if no heartbeat is set.
    """
    body = self.get_body()
    return body.latest_heartbeat
604+
592605
@property
593606
def logs(self) -> Optional["LogsResponse"]:
594607
"""The `logs` property.
@@ -795,3 +808,14 @@ def get_custom_filters(
795808
custom_filters.append(cache_expiration_filter)
796809

797810
return custom_filters
811+
812+
813+
# ------------------ Heartbeat Model ---------------
814+
815+
816+
class StepHeartbeatResponse(BaseModel):
    """Light-weight response model for step run heartbeat updates.

    Carries only the step run ID, its status and its latest heartbeat
    timestamp.
    """

    # ID of the step run.
    id: UUID
    # Status of the step run, as a plain string.
    status: str
    # Timestamp of the latest heartbeat of the step run.
    latest_heartbeat: datetime

src/zenml/steps/heartbeat.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
# Copyright (c) ZenML GmbH 2022. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at:
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12+
# or implied. See the License for the specific language governing
13+
# permissions and limitations under the License.
14+
"""ZenML Step HeartBeat functionality."""
15+
16+
import _thread
17+
import logging
18+
import threading
19+
import time
20+
from typing import Annotated
21+
from uuid import UUID
22+
23+
from pydantic import BaseModel, Field, conint, model_validator
24+
25+
from zenml.enums import ExecutionStatus
26+
27+
logger = logging.getLogger(__name__)
28+
29+
30+
class StepHeartBeatTerminationException(Exception):
    """Signals that a heartbeat found the step to be remotely stopped."""
34+
35+
36+
class StepHeartBeatOptions(BaseModel):
    """Options group for step heartbeat execution.

    Attributes:
        step_id: ID of the step run the heartbeat is sent for.
        interval: Number of seconds between heartbeats; must be within
            the inclusive range 10-60.
        name: Name of the heartbeat worker. When missing or empty it is
            set to `HeartBeatWorker-<step_id>` after validation.
    """

    step_id: UUID
    # The bounds are declared with `Field`, which is the form Pydantic v2
    # documents for constrained integers. The previous spelling passed the
    # result of `conint(...)` as `Annotated` metadata, which is not a
    # documented pattern and is not guaranteed to enforce the bounds.
    interval: Annotated[int, Field(ge=10, le=60)]
    name: str | None = None

    @model_validator(mode="after")
    def set_default_name(self) -> "StepHeartBeatOptions":
        """Model validator - set name value if missing.

        Returns:
            The validated step heartbeat options.
        """
        if not self.name:
            self.name = f"HeartBeatWorker-{self.step_id}"

        return self
54+
55+
56+
class HeartbeatWorker:
    """Worker class implementing heartbeat polling and remote termination.

    A daemon thread periodically sends a heartbeat for one step run. If the
    heartbeat response reports that the step is stopping or stopped, the
    worker interrupts the main thread (raising `KeyboardInterrupt` there)
    and ends its own loop.
    """

    def __init__(self, options: "StepHeartBeatOptions"):
        """Heartbeat worker constructor.

        Args:
            options: Parameter group - polling interval, step id, etc.
        """
        self.options = options

        self._thread: threading.Thread | None = None
        self._running: bool = False
        self._terminated: bool = (
            False  # one-shot guard to avoid repeated interrupts
        )
        # Stop signal for the current worker thread. Waiting on it between
        # heartbeats (instead of `time.sleep`) lets `stop()` take effect
        # immediately rather than after up to `interval` seconds.
        self._stop_event = threading.Event()

    # properties

    @property
    def interval(self) -> int:
        """Property function for heartbeat interval.

        Returns:
            The heartbeat polling interval value.
        """
        return self.options.interval

    @property
    def name(self) -> str:
        """Property function for heartbeat worker name.

        Returns:
            The name of the heartbeat worker.
        """
        return str(self.options.name)

    @property
    def step_id(self) -> UUID:
        """Property function for heartbeat worker step ID.

        Returns:
            The id of the step heartbeat is running for.
        """
        return self.options.step_id

    # public functions

    def start(self) -> None:
        """Start the heartbeat worker on a background thread.

        This is a no-op while a worker thread is alive and has not been
        asked to stop. If a previous thread was stopped but has not exited
        yet, a new thread is started; the old one exits on its own because
        its stop event stays set.
        """
        if self._thread and self._thread.is_alive() and self._running:
            logger.info("%s already running; start() is a no-op", self.name)
            return

        # Each thread gets its own event so that restarting never revives
        # a thread that was already told to stop.
        stop_event = threading.Event()
        self._stop_event = stop_event
        self._running = True
        self._terminated = False
        self._thread = threading.Thread(
            target=self._run,
            args=(stop_event,),
            name=self.name,
            daemon=True,
        )
        self._thread.start()
        logger.info(
            "Daemon thread %s started (interval=%s)", self.name, self.interval
        )

    def stop(self) -> None:
        """Stops the heartbeat worker.

        Returns immediately; the worker thread is woken from its wait and
        exits after finishing any heartbeat call already in flight.
        """
        if not self._running:
            return
        self._running = False
        self._stop_event.set()
        logger.info("%s stop requested", self.name)

    def is_alive(self) -> bool:
        """Liveness of the heartbeat worker thread.

        Returns:
            True if the heartbeat worker thread is alive, False otherwise.
        """
        t = self._thread
        return bool(t and t.is_alive())

    def _run(self, stop_event: "threading.Event | None" = None) -> None:
        """Heartbeat loop executed on the worker thread.

        Args:
            stop_event: Event that ends this loop when set. Defaults to the
                worker's current stop event.
        """
        if stop_event is None:
            stop_event = self._stop_event

        logger.info("%s run() loop entered", self.name)
        try:
            while self._running and not stop_event.is_set():
                try:
                    self._heartbeat()
                except StepHeartBeatTerminationException:
                    # One-shot: signal the main thread and stop the loop.
                    if not self._terminated:
                        self._terminated = True
                        logger.info(
                            "%s received HeartBeatTerminationException; "
                            "interrupting main thread",
                            self.name,
                        )
                        _thread.interrupt_main()  # raises KeyboardInterrupt in main thread
                    # Ensure we stop our own loop as well.
                    self._running = False
                    stop_event.set()
                except Exception:
                    # Log-and-continue policy for all other errors.
                    logger.exception(
                        "%s heartbeat() failed; continuing", self.name
                    )
                # Wait after each attempt (even after errors, unless
                # stopped). The wait returns early once a stop is requested.
                if self._running:
                    stop_event.wait(self.interval)
        finally:
            logger.info("%s run() loop exiting", self.name)

    def _heartbeat(self) -> None:
        """Send one heartbeat and check whether the step was stopped.

        Raises:
            StepHeartBeatTerminationException: If the heartbeat response
                reports the status `STOPPED` or `STOPPING`.
        """
        # Imported here rather than at module level, as in the original.
        from zenml.config.global_config import GlobalConfiguration

        store = GlobalConfiguration().zen_store

        response = store.update_step_heartbeat(step_run_id=self.step_id)

        if response.status in {
            ExecutionStatus.STOPPED,
            ExecutionStatus.STOPPING,
        }:
            raise StepHeartBeatTerminationException(
                f"Step {self.step_id} remotely stopped with status {response.status}."
            )

src/zenml/zen_server/routers/steps_endpoints.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from zenml.constants import (
2222
API,
23+
HEARTBEAT,
2324
LOGS,
2425
STATUS,
2526
STEP_CONFIGURATION,
@@ -38,6 +39,7 @@
3839
StepRunResponse,
3940
StepRunUpdate,
4041
)
42+
from zenml.models.v2.core.step_run import StepHeartbeatResponse
4143
from zenml.zen_server.auth import (
4244
AuthContext,
4345
authorize,
@@ -200,6 +202,30 @@ def update_step(
200202
return dehydrate_response_model(updated_step)
201203

202204

205+
@router.put(
    "/{step_run_id}/" + HEARTBEAT,
    responses={401: error_response, 404: error_response, 422: error_response},
)
@async_fastapi_endpoint_wrapper(deduplicate=True)
def update_heartbeat(
    step_run_id: UUID,
    _: AuthContext = Security(authorize),
) -> StepHeartbeatResponse:
    """Updates the heartbeat of a step run.

    Update permission is verified on the pipeline run that the step run
    belongs to before the heartbeat is updated.

    Args:
        step_run_id: ID of the step run.

    Returns:
        The step heartbeat response (id, status, latest_heartbeat).
    """
    step_run = zen_store().get_run_step(step_run_id, hydrate=True)
    parent_run = zen_store().get_run(step_run.pipeline_run_id)
    verify_permission_for_model(parent_run, action=Action.UPDATE)

    heartbeat = zen_store().update_step_heartbeat(step_run_id=step_run_id)
    return heartbeat
227+
228+
203229
@router.get(
204230
"/{step_id}" + STEP_CONFIGURATION,
205231
responses={401: error_response, 404: error_response, 422: error_response},
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""Add heartbeat column for step runs [a5a17015b681].
2+
3+
Revision ID: a5a17015b681
4+
Revises: 0.90.0
5+
Create Date: 2025-10-13 12:24:12.470803
6+
7+
"""
8+
9+
import sqlalchemy as sa
10+
from alembic import op
11+
12+
# revision identifiers, used by Alembic.
13+
revision = "a5a17015b681"
14+
down_revision = "0.90.0"
15+
branch_labels = None
16+
depends_on = None
17+
18+
19+
def upgrade() -> None:
    """Add the nullable `latest_heartbeat` datetime column to `step_run`."""
    heartbeat_column = sa.Column(
        "latest_heartbeat", sa.DateTime(), nullable=True
    )
    with op.batch_alter_table("step_run", schema=None) as step_run_table:
        step_run_table.add_column(heartbeat_column)
25+
26+
27+
def downgrade() -> None:
    """Drop the `latest_heartbeat` column from the `step_run` table."""
    with op.batch_alter_table("step_run", schema=None) as step_run_table:
        step_run_table.drop_column("latest_heartbeat")

src/zenml/zen_stores/rest_zen_store.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
ENV_ZENML_DISABLE_CLIENT_SERVER_MISMATCH_WARNING,
7676
EVENT_SOURCES,
7777
FLAVORS,
78+
HEARTBEAT,
7879
INFO,
7980
LOGIN,
8081
LOGS,
@@ -258,6 +259,7 @@
258259
StackRequest,
259260
StackResponse,
260261
StackUpdate,
262+
StepHeartbeatResponse,
261263
StepRunFilter,
262264
StepRunRequest,
263265
StepRunResponse,
@@ -3378,6 +3380,23 @@ def update_run_step(
33783380
route=STEPS,
33793381
)
33803382

3383+
def update_step_heartbeat(
    self, step_run_id: UUID
) -> StepHeartbeatResponse:
    """Updates the heartbeat of a step run through the REST API.

    Issues a PUT request without body or query parameters to the
    heartbeat route of the given step run.

    Args:
        step_run_id: The ID of the step run whose heartbeat to update.

    Returns:
        The step heartbeat response.
    """
    heartbeat_path = f"{STEPS}/{step_run_id}/{HEARTBEAT}"
    payload = self.put(heartbeat_path, body=None, params=None)
    return StepHeartbeatResponse.model_validate(payload)
3399+
33813400
# -------------------- Triggers --------------------
33823401

33833402
def create_trigger(self, trigger: TriggerRequest) -> TriggerResponse:

src/zenml/zen_stores/schemas/step_run_schemas.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True):
8686
# Fields
8787
start_time: Optional[datetime] = Field(nullable=True)
8888
end_time: Optional[datetime] = Field(nullable=True)
89+
latest_heartbeat: Optional[datetime] = Field(
90+
nullable=True,
91+
description="The latest execution heartbeat.",
92+
)
8993
status: str = Field(nullable=False)
9094

9195
docstring: Optional[str] = Field(sa_column=Column(TEXT, nullable=True))

0 commit comments

Comments
 (0)