From 78bfae55b2a827de7e57a27dd49db807ca67ffc5 Mon Sep 17 00:00:00 2001 From: Safoine El khabich Date: Sat, 6 Sep 2025 02:45:32 +0100 Subject: [PATCH 1/8] initial commit with all step runtime and capture changes --- .../serving/advanced/capture-and-runtime.md | 114 +++ docs/book/serving/overview.md | 96 +++ docs/book/serving/toc.md | 5 + docs/book/toc.md | 3 + examples/serving/weather_pipeline.py | 1 + src/zenml/config/compiler.py | 25 + src/zenml/config/pipeline_configurations.py | 41 + .../config/pipeline_run_configuration.py | 2 + src/zenml/execution/__init__.py | 23 + src/zenml/execution/capture_policy.py | 107 +++ src/zenml/execution/factory.py | 101 +++ src/zenml/execution/realtime_runtime.py | 456 +++++++++++ src/zenml/execution/step_runtime.py | 750 ++++++++++++++++++ .../kubernetes_orchestrator_entrypoint.py | 13 + src/zenml/orchestrators/step_launcher.py | 374 +++++++-- src/zenml/orchestrators/step_run_utils.py | 28 +- src/zenml/orchestrators/step_runner.py | 74 +- src/zenml/pipelines/pipeline_decorator.py | 4 + src/zenml/pipelines/pipeline_definition.py | 13 + 19 files changed, 2123 insertions(+), 107 deletions(-) create mode 100644 docs/book/serving/advanced/capture-and-runtime.md create mode 100644 docs/book/serving/overview.md create mode 100644 docs/book/serving/toc.md create mode 100644 src/zenml/execution/__init__.py create mode 100644 src/zenml/execution/capture_policy.py create mode 100644 src/zenml/execution/factory.py create mode 100644 src/zenml/execution/realtime_runtime.py create mode 100644 src/zenml/execution/step_runtime.py diff --git a/docs/book/serving/advanced/capture-and-runtime.md b/docs/book/serving/advanced/capture-and-runtime.md new file mode 100644 index 00000000000..1dd2ffae418 --- /dev/null +++ b/docs/book/serving/advanced/capture-and-runtime.md @@ -0,0 +1,114 @@ +--- +title: Capture Policy & Execution Runtimes (Advanced) +--- + +# Capture Policy & Execution Runtimes (Advanced) + +This page explains how capture options map to execution runtimes and how to tune them for production serving. + +## Execution Runtimes + +- DefaultStepRuntime + - Standard ZenML execution: persists artifacts, creates runs and step runs, captures metadata/logs per config. + +- RealtimeStepRuntime + - Focus: Low latency + observability. + - Features: + - In-process artifact value cache for downstream steps in the same process. + - Tunables: `ttl_seconds`, `max_entries` via capture options (or env vars `ZENML_RT_CACHE_TTL_SECONDS`, `ZENML_RT_CACHE_MAX_ENTRIES`). + - Async server updates with a background worker. + - `flush_on_step_end` controls whether to block at step boundary to flush updates. + - In serving with `mode=REALTIME`, `flush_on_step_end` defaults to `false` unless explicitly set. + +- OffStepRuntime + - Focus: Lightweight operation with minimal overhead. + - Behavior: Persists artifacts; skips metadata/logs/visualizations/caching (compiler disables these by default in OFF). + +- MemoryStepRuntime + - Focus: Pure in-memory execution (no server, no persistence). + - Behavior: Inter-step data is exchanged via in-process memory handles; no runs or artifacts. + - Configure with REALTIME: `capture={"mode": "REALTIME", "runs": "off"}` or `{"persistence": "memory"}`. + +## Capture Configuration + +Where to set: +- In code: `@pipeline(capture=...)` +- In run config YAML: `capture: ...` + +Supported options (commonly used): +```yaml +capture: + mode: BATCH | REALTIME | OFF | CUSTOM + runs: on | off # off → no runs (memory-only when REALTIME) + persistence: sync | async | memory | off + logs: all | errors-only | off + metadata: true | false + visualization: true | false + cache_enabled: true | false + code: true | false # skip docstring/source capture if false + flush_on_step_end: true | false + ttl_seconds: 600 # Realtime cache TTL + max_entries: 2048 # Realtime cache size bound +``` + +Notes: +- `mode` determines the base runtime. +- `runs: off` or `persistence: memory/off` under REALTIME maps to MemoryStepRuntime (pure in-memory execution). +- `flush_on_step_end`: If `false`, serving returns immediately; tracking is published asynchronously by the runtime worker. +- `code: false`: Skips docstring/source capture (metadata), but does not affect code execution. + +## Serving Defaults + +- REALTIME + serving context: + - If `flush_on_step_end` is not provided, it defaults to `false` for better latency. + - Users can override by setting `flush_on_step_end: true`. + +## Step Operators & Remote Execution + +- Step operators inherit capture via environment (e.g., `ZENML_CAPTURE_MODE`). +- Remote entrypoints construct the matching runtime and honor capture options. + +## Memory-Only Internals (for deeper understanding) + +- Handle format: `mem:////` +- Memory runtime: + - `resolve_step_inputs`: constructs handles from `run_id` + substitutions. + - `load_input_artifact`: resolves handle to value from a thread-safe in-process store. + - `store_output_artifacts`: stores outputs back to the store; returns new handles for downstream steps. +- No server calls; no runs or artifacts are created. + +## Environment Variables + +- `ZENML_CAPTURE_MODE`: global default capture when not set in the pipeline. +- `ZENML_SERVING_CAPTURE_DEFAULT`: used internally to reduce tracking when capture is not set (serving compatibility). +- `ZENML_RT_CACHE_TTL_SECONDS`, `ZENML_RT_CACHE_MAX_ENTRIES`: Realtime cache controls. + +## Recipes + +- Low-latency serving (eventual consistency): + - `@pipeline(capture={"mode": "REALTIME", "flush_on_step_end": false})` + +- Strict serving (strong consistency): + - `@pipeline(capture={"mode": "REALTIME", "flush_on_step_end": true})` + +- Memory-only (stateless service): + - `@pipeline(capture={"mode": "REALTIME", "runs": "off"})` + +- Compliance mode: + - `@pipeline(capture="BATCH")` or + - `@pipeline(capture={"mode": "REALTIME", "logs": "all", "metadata": true, "flush_on_step_end": true})` + +## FAQ + +- Can I enable only partial capture (e.g., errors-only logs)? + - Yes, e.g., `logs: errors-only` and `metadata: false`. + +- Does `code: false` break step execution? + - No. It only disables docstring/source capture. Steps still run normally. + +- How does caching interact with REALTIME? + - Default caching behavior is unchanged. Set `cache_enabled: false` to bypass caching entirely. + +- Can memory-only work with parallelism? + - Memory-only is per-process. For multi-process/multi-container setups, use persistence for cross-process data. + diff --git a/docs/book/serving/overview.md b/docs/book/serving/overview.md new file mode 100644 index 00000000000..e813ba2ed45 --- /dev/null +++ b/docs/book/serving/overview.md @@ -0,0 +1,96 @@ +--- +title: Pipeline Serving Overview +--- + +# Pipeline Serving Overview + +## What Is Pipeline Serving? + +- Purpose: Expose a ZenML pipeline as a low-latency service (e.g., via FastAPI) that executes steps on incoming requests and returns results. +- Value: Production-grade orchestration with flexible capture policies to balance latency, observability, and lineage. +- Modes: Default batch-style execution, optimized realtime execution, and pure in-memory execution for maximum speed. + +## Quick Start + +1) Define your pipeline +- Use your normal `@pipeline` and `@step` definitions. +- No serving-specific changes required. + +2) Choose a capture configuration (recommended) +- Low-latency, non-blocking tracking (serving-friendly): + - `@pipeline(capture={"mode": "REALTIME", "flush_on_step_end": false})` +- Pure in-memory execution (no runs, no artifacts): + - `@pipeline(capture={"mode": "REALTIME", "runs": "off"})` + +3) Deploy the serving service with your preferred deployer and call the FastAPI endpoint. + +## Capture Modes (Essentials) + +- BATCH (default) + - Behavior: Standard ZenML behavior (pipeline runs + step runs + artifacts + metadata/logs depending on config). + - Use when: Full lineage and strong consistency are required. + +- REALTIME + - Behavior: Optimized for latency and throughput. + - In-memory cache of artifact values within the same process. + - Async server updates by default; in serving, defaults to non-blocking responses (tracking finishes in background). + - Use when: You need low-latency serving with observability. + +- OFF + - Behavior: Lightweight tracking. + - Persists artifacts but skips metadata/logs/visualizations/caching for reduced overhead. + - Use when: You need a smaller footprint while preserving artifacts for downstream consumers. + +- Memory-only (special case inside REALTIME) + - Configure: `capture={"mode": "REALTIME", "runs": "off"}` or `capture={"mode": "REALTIME", "persistence": "memory"}` + - Behavior: Pure in-memory execution: + - No pipeline runs or step runs, no artifacts, no server calls. + - Steps exchange data in-process; response returns immediately. + - Use when: Maximum speed (prototyping, ultra-low-latency paths) without lineage. + +## Where To Configure Capture + +- In code (recommended) + - `@pipeline(capture="REALTIME")` + - `@pipeline(capture={"mode": "REALTIME", "flush_on_step_end": false})` + +- In run config YAML +```yaml +capture: REALTIME + +# or + +capture: + mode: REALTIME + flush_on_step_end: false +``` + +- Environment (fallbacks) + - `ZENML_CAPTURE_MODE=BATCH|REALTIME|OFF|CUSTOM` + - Serving defaults leverage `ZENML_SERVING_CAPTURE_DEFAULT` when capture is not set (used internally to reduce tracking overhead). + +## Best Practices + +- Most users (serving-ready) + - `capture={"mode": "REALTIME", "flush_on_step_end": false}` + - Good balance of immediate response and production tracking. + +- Maximum speed (no tracking at all) + - `capture={"mode": "REALTIME", "runs": "off"}` (pure in-memory) + - Great for tests, benchmarks, or hot paths where lineage is not needed. + +- Compliance or rich lineage + - `capture="BATCH"` or fine-tune REALTIME with `flush_on_step_end: true`, `logs: "all"`, `metadata: true`. + +## FAQ (Essentials) + +- Does serving always create pipeline runs? + - BATCH/REALTIME/OFF: Yes (OFF reduces overhead of metadata/logs). + - Memory-only (REALTIME with `runs: off`): No; executes purely in memory. + +- Will serving block responses to flush tracking? + - REALTIME in serving defaults to non-blocking (returns immediately), unless you explicitly set `flush_on_step_end: true`. + +- Is memory-only safe for production? + - Yes for stateless, speed-critical paths. Note: No lineage or persisted artifacts. + diff --git a/docs/book/serving/toc.md b/docs/book/serving/toc.md new file mode 100644 index 00000000000..3dcc3d1ae74 --- /dev/null +++ b/docs/book/serving/toc.md @@ -0,0 +1,5 @@ +# Serving + +* [Pipeline Serving Overview](overview.md) +* Advanced + * [Capture Policy & Runtimes](advanced/capture-and-runtime.md) diff --git a/docs/book/toc.md b/docs/book/toc.md index eb7d4ebb15d..acce2d8b1b6 100644 --- a/docs/book/toc.md +++ b/docs/book/toc.md @@ -53,6 +53,9 @@ * [Models](how-to/models/models.md) * [Templates](how-to/templates/templates.md) * [Dashboard](how-to/dashboard/dashboard-features.md) +* Serving + * [Pipeline Serving Overview](serving/overview.md) + * [Capture Policy & Runtimes (Advanced)](serving/advanced/capture-and-runtime.md) * [Serving Pipelines](how-to/serving/serving.md) * [Pipeline Serving Capture Policies](how-to/serving/capture-policies.md) diff --git a/examples/serving/weather_pipeline.py b/examples/serving/weather_pipeline.py index 79edd235058..80e1d7b2b81 100644 --- a/examples/serving/weather_pipeline.py +++ b/examples/serving/weather_pipeline.py @@ -213,6 +213,7 @@ def analyze_weather_with_llm(weather_data: Dict[str, float], city: str) -> str: @pipeline( on_init=init_hook, + capture="realtime", settings={ "docker": docker_settings, "deployer.gcp": { diff --git a/src/zenml/config/compiler.py b/src/zenml/config/compiler.py index c4a01d8e740..ee959b24c3d 100644 --- a/src/zenml/config/compiler.py +++ b/src/zenml/config/compiler.py @@ -202,6 +202,7 @@ def _apply_run_configuration( enable_artifact_visualization=config.enable_artifact_visualization, enable_step_logs=config.enable_step_logs, enable_pipeline_logs=config.enable_pipeline_logs, + capture=config.capture, settings=config.settings, tags=config.tags, extra=config.extra, @@ -210,6 +211,30 @@ def _apply_run_configuration( parameters=config.parameters, ) + # Apply additional defaults based on capture mode + try: + capture_cfg = pipeline.configuration.capture + mode_str = None + if isinstance(capture_cfg, str): + mode_str = capture_cfg.upper() + elif isinstance(capture_cfg, dict): + mode = capture_cfg.get("mode") + if isinstance(mode, str): + mode_str = mode.upper() + if mode_str == "OFF": + # Disable overhead while keeping correctness + with pipeline.__suppress_configure_warnings__(): + pipeline.configure( + enable_cache=False, + enable_artifact_metadata=False, + enable_artifact_visualization=False, + enable_step_logs=False, + enable_pipeline_logs=False, + ) + except Exception: + # Non-fatal; leave configuration as-is + pass + invalid_step_configs = set(config.steps) - set(pipeline.invocations) if invalid_step_configs: logger.warning( diff --git a/src/zenml/config/pipeline_configurations.py b/src/zenml/config/pipeline_configurations.py index 4111e2bd006..0794a3429a0 100644 --- a/src/zenml/config/pipeline_configurations.py +++ b/src/zenml/config/pipeline_configurations.py @@ -42,6 +42,9 @@ class PipelineConfigurationUpdate(StrictBaseModel): enable_artifact_visualization: Optional[bool] = None enable_step_logs: Optional[bool] = None enable_pipeline_logs: Optional[bool] = None + # Capture policy mode for execution semantics (e.g., BATCH, REALTIME, OFF, CUSTOM) + # Capture policy can be a mode string or a dict with options + capture: Optional[Union[str, Dict[str, Any]]] = None settings: Dict[str, SerializeAsAny[BaseSettings]] = {} tags: Optional[List[Union[str, "Tag"]]] = None extra: Dict[str, Any] = {} @@ -85,6 +88,44 @@ class PipelineConfiguration(PipelineConfigurationUpdate): name: str + @field_validator("capture") + @classmethod + def validate_capture_mode( + cls, value: Optional[Union[str, Dict[str, Any]]] + ) -> Optional[Union[str, Dict[str, Any]]]: + """Validates the capture mode. + + Args: + value: The capture mode to validate. + + Returns: + The validated capture mode. + """ + if value is None: + return value + if isinstance(value, dict): + mode = value.get("mode") + if mode is None: + # default to BATCH if mode not provided + value = {**value, "mode": "BATCH"} + mode = "BATCH" + allowed = {"BATCH", "REALTIME", "OFF", "CUSTOM"} + if str(mode).upper() not in allowed: + raise ValueError( + f"Invalid capture mode '{mode}'. Allowed: {sorted(allowed)}" + ) + # normalize mode to upper + value = {**value, "mode": str(mode).upper()} + return value + else: + allowed = {"BATCH", "REALTIME", "OFF", "CUSTOM"} + v = str(value).upper() + if v not in allowed: + raise ValueError( + f"Invalid capture mode '{value}'. Allowed: {sorted(allowed)}" + ) + return v + @field_validator("name") @classmethod def ensure_pipeline_name_allowed(cls, name: str) -> str: diff --git a/src/zenml/config/pipeline_run_configuration.py b/src/zenml/config/pipeline_run_configuration.py index b8203cbeab0..33e72d60682 100644 --- a/src/zenml/config/pipeline_run_configuration.py +++ b/src/zenml/config/pipeline_run_configuration.py @@ -41,6 +41,8 @@ class PipelineRunConfiguration( enable_artifact_visualization: Optional[bool] = None enable_step_logs: Optional[bool] = None enable_pipeline_logs: Optional[bool] = None + # Optional override for capture per run: mode string or dict with options + capture: Optional[Union[str, Dict[str, Any]]] = None schedule: Optional[Schedule] = None build: Union[PipelineBuildBase, UUID, None] = Field( default=None, union_mode="left_to_right" diff --git a/src/zenml/execution/__init__.py b/src/zenml/execution/__init__.py new file mode 100644 index 00000000000..7b2a01ce5d0 --- /dev/null +++ b/src/zenml/execution/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Execution runtime abstractions. + +This module defines the runtime interface used by the step runner / launcher +to interact with artifacts, metadata, and server updates. It is introduced as +an internal scaffolding to consolidate execution-time responsibilities behind +one facade without changing current behavior. + +NOTE: This is an internal API and subject to change. +""" + diff --git a/src/zenml/execution/capture_policy.py b/src/zenml/execution/capture_policy.py new file mode 100644 index 00000000000..32786dbe5ff --- /dev/null +++ b/src/zenml/execution/capture_policy.py @@ -0,0 +1,107 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Capture policy scaffolding and presets. + +This is a lightweight placeholder to enable runtime selection without changing +public pipeline APIs. The capture mode can be controlled via the environment +variable `ZENML_CAPTURE_MODE` with values: `BATCH` (default), `REALTIME`, +`OFF`, or `CUSTOM`. +""" + +import os +from enum import Enum +from typing import Any, Dict, Optional, Union + + +class CaptureMode(str, Enum): + """Capture mode enum.""" + + BATCH = "BATCH" + REALTIME = "REALTIME" + OFF = "OFF" + CUSTOM = "CUSTOM" + + +class CapturePolicy: + """Minimal capture policy container with optional options.""" + + def __init__( + self, + mode: CaptureMode = CaptureMode.BATCH, + options: Optional[Dict[str, Any]] = None, + ) -> None: + """Initialize the capture policy. + + Args: + mode: The capture mode. + options: The capture options. + """ + self.mode = mode + self.options = options or {} + + @staticmethod + def from_env() -> "CapturePolicy": + """Create a capture policy from the environment. + + Returns: + The capture policy. + """ + val = os.getenv("ZENML_CAPTURE_MODE", "BATCH").upper() + try: + mode = CaptureMode(val) + except ValueError: + mode = CaptureMode.BATCH + # No options provided from env here; runtimes may read env as fallback + return CapturePolicy(mode=mode, options={}) + + @staticmethod + def from_value( + value: Optional[Union[str, Dict[str, Any]]], + ) -> "CapturePolicy": + """Create a capture policy from a value. + + Args: + value: The value to create the capture policy from. + + Returns: + The capture policy. + """ + if value is None: + return CapturePolicy.from_env() + if isinstance(value, dict): + mode = str(value.get("mode", "BATCH")).upper() + try: + cm = CaptureMode(mode) + except Exception: + cm = CaptureMode.BATCH + # store other keys as options + options = {k: v for k, v in value.items() if k != "mode"} + return CapturePolicy(mode=cm, options=options) + else: + try: + return CapturePolicy(mode=CaptureMode(str(value).upper())) + except Exception: + return CapturePolicy.from_env() + + def get_option(self, key: str, default: Any = None) -> Any: + """Get an option from the capture policy. + + Args: + key: The key of the option to get. + default: The default value to return if the option is not found. + + Returns: + The option value. + """ + return self.options.get(key, default) diff --git a/src/zenml/execution/factory.py b/src/zenml/execution/factory.py new file mode 100644 index 00000000000..067b24585b3 --- /dev/null +++ b/src/zenml/execution/factory.py @@ -0,0 +1,101 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Factory to construct a step runtime based on capture policy.""" + +from typing import Callable, Dict, Optional + +from zenml.execution.capture_policy import CaptureMode, CapturePolicy +from zenml.execution.step_runtime import ( + BaseStepRuntime, + DefaultStepRuntime, + MemoryStepRuntime, + OffStepRuntime, +) + +# Registry of runtime builders keyed by capture mode +_RUNTIME_REGISTRY: Dict[ + CaptureMode, Callable[[CapturePolicy], BaseStepRuntime] +] = {} + + +def register_runtime( + mode: CaptureMode, builder: Callable[[CapturePolicy], BaseStepRuntime] +) -> None: + """Register a runtime builder for a capture mode.""" + _RUNTIME_REGISTRY[mode] = builder + + +def get_runtime(policy: Optional[CapturePolicy]) -> BaseStepRuntime: + """Return a runtime implementation using the registry. + + Falls back to the default runtime if no builder is registered. + + Args: + policy: The capture policy. + + Returns: + The runtime implementation. + """ + policy = policy or CapturePolicy() + builder = _RUNTIME_REGISTRY.get(policy.mode) + if builder is not None: + return builder(policy) + return DefaultStepRuntime() + + +# Register default builders +def _build_default(_: CapturePolicy) -> BaseStepRuntime: + """Build the default runtime. + + Args: + policy: The capture policy. + + Returns: + The runtime implementation. + """ + return DefaultStepRuntime() + + +def _build_off(_: CapturePolicy) -> BaseStepRuntime: + """Build the off runtime (lightweight: persist artifacts, skip metadata).""" + return OffStepRuntime() + + +def _build_realtime(policy: CapturePolicy) -> BaseStepRuntime: + """Build the realtime runtime. + + Args: + policy: The capture policy. + + Returns: + The runtime implementation. + """ + # Import here to avoid circular imports + from zenml.execution.realtime_runtime import RealtimeStepRuntime + + # If runs are off or persistence is memory/off, use memory runtime + runs_opt = str(policy.get_option("runs", "on")).lower() + persistence = str(policy.get_option("persistence", "async")).lower() + if runs_opt in {"off", "false", "0"} or persistence in {"memory", "off"}: + return MemoryStepRuntime() + + ttl = policy.get_option("ttl_seconds") + max_entries = policy.get_option("max_entries") + return RealtimeStepRuntime(ttl_seconds=ttl, max_entries=max_entries) + + +register_runtime(CaptureMode.BATCH, _build_default) +register_runtime(CaptureMode.CUSTOM, _build_default) +register_runtime(CaptureMode.OFF, _build_off) +register_runtime(CaptureMode.REALTIME, _build_realtime) diff --git a/src/zenml/execution/realtime_runtime.py b/src/zenml/execution/realtime_runtime.py new file mode 100644 index 00000000000..61da826e0ef --- /dev/null +++ b/src/zenml/execution/realtime_runtime.py @@ -0,0 +1,456 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Realtime runtime with simple in-memory caching and async updates. + +This implementation prioritizes in-memory loads when available and otherwise +delegates to the default runtime persistence. It lays groundwork for future +write-behind persistence without changing current behavior. +""" + +import os +import queue +import threading +import time +from collections import OrderedDict +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +from zenml.execution.step_runtime import DefaultStepRuntime +from zenml.logger import get_logger +from zenml.materializers.base_materializer import BaseMaterializer +from zenml.models import ArtifactVersionResponse +from zenml.orchestrators import publish_utils +from zenml.stack.stack import Stack +from zenml.steps.utils import OutputSignature + +if TYPE_CHECKING: + from uuid import UUID + + from zenml.metadata.metadata_types import MetadataType + + +class RealtimeStepRuntime(DefaultStepRuntime): + """Realtime runtime optimized for low-latency loads via memory cache.""" + + def __init__( + self, + ttl_seconds: Optional[int] = None, + max_entries: Optional[int] = None, + ) -> None: + """Initialize the realtime runtime. + + Args: + ttl_seconds: The TTL in seconds. + max_entries: The maximum number of entries in the cache. + """ + super().__init__() + # Simple LRU cache with TTL + self._cache: "OrderedDict[str, Tuple[Any, float]]" = OrderedDict() + self._lock = threading.RLock() + # Event queue: (kind, args, kwargs) + Event = Tuple[str, Tuple[Any, ...], Dict[str, Any]] + self._q: "queue.Queue[Event]" = queue.Queue() + self._worker: Optional[threading.Thread] = None + self._stop = threading.Event() + self._errors_since_last_flush: int = 0 + self._total_errors: int = 0 + self._last_error: Optional[BaseException] = None + self._logger = get_logger(__name__) + self._queued_count: int = 0 + self._processed_count: int = 0 + # Tunables via env: TTL seconds and max entries + # Options precedence: explicit args > env > defaults + if ttl_seconds is not None: + self._ttl_seconds = int(ttl_seconds) + else: + try: + self._ttl_seconds = int( + os.getenv("ZENML_RT_CACHE_TTL_SECONDS", "300") + ) + except Exception: + self._ttl_seconds = 300 + if max_entries is not None: + self._max_entries = int(max_entries) + else: + try: + self._max_entries = int( + os.getenv("ZENML_RT_CACHE_MAX_ENTRIES", "1024") + ) + except Exception: + self._max_entries = 1024 + # Flush behavior (can be disabled for serving non-blocking) + self._flush_on_step_end: bool = True + + # --- lifecycle --- + def start(self) -> None: + """Start the realtime runtime.""" + if self._worker is not None: + return + + def _run() -> None: + while not self._stop.is_set(): + try: + kind, args, kwargs = self._q.get(timeout=0.1) + except queue.Empty: + # Opportunistic cache sweep: evict expired from head + self._sweep_expired() + continue + try: + if kind == "pipeline_metadata": + publish_utils.publish_pipeline_run_metadata( + *args, **kwargs + ) + elif kind == "step_metadata": + publish_utils.publish_step_run_metadata( + *args, **kwargs + ) + elif kind == "step_success": + publish_utils.publish_successful_step_run( + *args, **kwargs + ) + elif kind == "step_failed": + publish_utils.publish_failed_step_run(*args, **kwargs) + except BaseException as e: # noqa: BLE001 + with self._lock: + self._errors_since_last_flush += 1 + self._total_errors += 1 + self._last_error = e + self._logger.warning( + "Realtime runtime failed to publish '%s': %s", kind, e + ) + finally: + with self._lock: + self._processed_count += 1 + self._q.task_done() + + self._worker = threading.Thread( + target=_run, name="zenml-realtime-runtime", daemon=True + ) + self._worker.start() + + def on_step_start(self) -> None: + """Optional hook when a step begins execution.""" + # no-op for now + return + + # Prefer in-memory values if available + def load_input_artifact( + self, + *, + artifact: ArtifactVersionResponse, + data_type: Type[Any], + stack: "Stack", + ) -> Any: + """Load an input artifact. + + Args: + artifact: The artifact to load. + data_type: The data type of the artifact. + stack: The stack of the artifact. + + Returns: + The loaded artifact. + """ + key = str(artifact.id) + with self._lock: + if key in self._cache: + value, expires_at = self._cache.get(key, (None, 0)) + now = time.time() + if now <= expires_at: + # Touch entry for LRU + self._cache.move_to_end(key) + return value + else: + # Expired + try: + del self._cache[key] + except KeyError: + pass + + # Fallback to default loading + return super().load_input_artifact( + artifact=artifact, data_type=data_type, stack=stack + ) + + # Store synchronously (behavior parity), and cache the raw values in memory + def store_output_artifacts( + self, + *, + output_data: Dict[str, Any], + output_materializers: Dict[str, Tuple[Type["BaseMaterializer"], ...]], + output_artifact_uris: Dict[str, str], + output_annotations: Dict[str, "OutputSignature"], + artifact_metadata_enabled: bool, + artifact_visualization_enabled: bool, + ) -> Dict[str, ArtifactVersionResponse]: + """Store output artifacts. + + Args: + output_data: The output data. + output_materializers: The output materializers. + output_artifact_uris: The output artifact URIs. + output_annotations: The output annotations. + artifact_metadata_enabled: Whether artifact metadata is enabled. + artifact_visualization_enabled: Whether artifact visualization is enabled. + + Returns: + The stored artifacts. + """ + responses = super().store_output_artifacts( + output_data=output_data, + output_materializers=output_materializers, + output_artifact_uris=output_artifact_uris, + output_annotations=output_annotations, + artifact_metadata_enabled=artifact_metadata_enabled, + artifact_visualization_enabled=artifact_visualization_enabled, + ) + + # Cache by artifact id for later fast loads with TTL and LRU bounds + with self._lock: + now = time.time() + for name, resp in responses.items(): + if name in output_data: + expires_at = now + max(0, self._ttl_seconds) + self._cache[str(resp.id)] = (output_data[name], expires_at) + # Touch to end (most recently used) + self._cache.move_to_end(str(resp.id)) + # Enforce size bound + while len(self._cache) > max(1, self._max_entries): + try: + self._cache.popitem(last=False) # Evict LRU + except KeyError: + break + + return responses + + # --- async server updates --- + def publish_pipeline_run_metadata( + self, + *, + pipeline_run_id: "UUID", + pipeline_run_metadata: Dict["UUID", Dict[str, "MetadataType"]], + ) -> None: + """Publish pipeline run metadata. + + Args: + pipeline_run_id: The pipeline run ID. + pipeline_run_metadata: The pipeline run metadata. + """ + # Enqueue for async processing + self._q.put( + ( + "pipeline_metadata", + (), + { + "pipeline_run_id": pipeline_run_id, + "pipeline_run_metadata": pipeline_run_metadata, + }, + ) + ) + with self._lock: + self._queued_count += 1 + + def publish_step_run_metadata( + self, + *, + step_run_id: "UUID", + step_run_metadata: Dict["UUID", Dict[str, "MetadataType"]], + ) -> None: + """Publish step run metadata. + + Args: + step_run_id: The step run ID. + step_run_metadata: The step run metadata. + """ + self._q.put( + ( + "step_metadata", + (), + { + "step_run_id": step_run_id, + "step_run_metadata": step_run_metadata, + }, + ) + ) + with self._lock: + self._queued_count += 1 + + def publish_successful_step_run( + self, + *, + step_run_id: "UUID", + output_artifact_ids: Dict[str, List["UUID"]], + ) -> None: + """Publish a successful step run. + + Args: + step_run_id: The step run ID. + output_artifact_ids: The output artifact IDs. + """ + self._q.put( + ( + "step_success", + (), + { + "step_run_id": step_run_id, + "output_artifact_ids": output_artifact_ids, + }, + ) + ) + with self._lock: + self._queued_count += 1 + + def publish_failed_step_run( + self, + *, + step_run_id: "UUID", + ) -> None: + """Publish a failed step run. + + Args: + step_run_id: The step run ID. + """ + self._q.put(("step_failed", (), {"step_run_id": step_run_id})) + with self._lock: + self._queued_count += 1 + + def flush(self) -> None: + """Flush the realtime runtime by draining queued events synchronously.""" + # Drain the queue in the calling thread to avoid waiting on the worker + while True: + try: + kind, args, kwargs = self._q.get_nowait() + except queue.Empty: + break + try: + if kind == "pipeline_metadata": + publish_utils.publish_pipeline_run_metadata( + *args, **kwargs + ) + elif kind == "step_metadata": + publish_utils.publish_step_run_metadata(*args, **kwargs) + elif kind == "step_success": + publish_utils.publish_successful_step_run(*args, **kwargs) + elif kind == "step_failed": + publish_utils.publish_failed_step_run(*args, **kwargs) + except BaseException as e: # noqa: BLE001 + with self._lock: + self._errors_since_last_flush += 1 + self._total_errors += 1 + self._last_error = e + self._logger.warning( + "Realtime runtime flush failed to publish '%s': %s", + kind, + e, + ) + finally: + with self._lock: + self._processed_count += 1 + try: + self._q.task_done() + except ValueError: + # If task_done called more than put() count due to races, ignore + pass + # Post-flush maintenance + self._sweep_expired() + with self._lock: + if self._errors_since_last_flush: + count = self._errors_since_last_flush + last = self._last_error + self._errors_since_last_flush = 0 + raise RuntimeError( + f"Realtime runtime encountered {count} error(s) while publishing. Last error: {last}" + ) + + def on_step_end(self) -> None: + """Optional hook when a step ends execution.""" + # no-op for now + return + + def shutdown(self) -> None: + """Shutdown the realtime runtime.""" + # Wait for remaining tasks and stop + self.flush() + self._stop.set() + # Join worker with timeout + worker = self._worker + if worker is not None: + worker.join(timeout=15.0) + if worker.is_alive(): + self._logger.warning( + "Realtime runtime worker did not terminate gracefully within timeout." + ) + self._worker = None + + # Flush behavior controls + def set_flush_on_step_end(self, value: bool) -> None: + """Set the flush on step end behavior. + + Args: + value: The value to set. + """ + self._flush_on_step_end = bool(value) + + def should_flush_on_step_end(self) -> bool: + """Whether the runtime should flush on step end. + + Returns: + Whether the runtime should flush on step end. + """ + return self._flush_on_step_end + + def get_metrics(self) -> Dict[str, Any]: + """Return runtime metrics snapshot. + + Returns: + The runtime metrics snapshot. + """ + with self._lock: + queued = self._queued_count + processed = self._processed_count + failed_total = self._total_errors + ttl_seconds = getattr(self, "_ttl_seconds", None) + max_entries = getattr(self, "_max_entries", None) + try: + depth = self._q.qsize() + except Exception: + depth = 0 + return { + "queued": queued, + "processed": processed, + "failed_total": failed_total, + "queue_depth": depth, + "ttl_seconds": ttl_seconds, + "max_entries": max_entries, + } + + # --- internal helpers --- + def _sweep_expired(self) -> None: + """Remove expired entries from the head (LRU) side.""" + with self._lock: + now = time.time() + # Pop from head while expired + keys = list(self._cache.keys()) + for k in keys[:32]: # limit per sweep to bound work + try: + value, expires_at = self._cache[k] + except KeyError: + continue + if now > expires_at: + try: + del self._cache[k] + except KeyError: + pass + else: + # Stop at first non-expired near head + break diff --git a/src/zenml/execution/step_runtime.py b/src/zenml/execution/step_runtime.py new file mode 100644 index 00000000000..3df62f6b3bb --- /dev/null +++ b/src/zenml/execution/step_runtime.py @@ -0,0 +1,750 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Step runtime facade for step execution responsibilities. + +This scaffolds a minimal, behavior-preserving runtime abstraction that the +step runner can call into for artifact I/O and input resolution. The default +implementation delegates to existing ZenML utilities. + +Enable usage by setting environment variable `ZENML_ENABLE_STEP_RUNTIME=true`. +""" + +import threading +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type +from uuid import UUID + +from zenml.client import Client +from zenml.models import ArtifactVersionResponse + +if TYPE_CHECKING: + from zenml.artifact_stores import BaseArtifactStore + from zenml.config.step_configurations import Step + from zenml.materializers.base_materializer import BaseMaterializer + from zenml.models import PipelineRunResponse, StepRunResponse + from zenml.models.v2.core.step_run import StepRunInputResponse + from zenml.stack import Stack + from zenml.steps.utils import OutputSignature + + +class BaseStepRuntime(ABC): + """Abstract execution-time interface for step I/O and interactions. + + Implementations may optimize persistence, caching, logging, and server + updates based on capture policy. This base class only covers the minimal + responsibilities we want to centralize first. + """ + + @abstractmethod + def resolve_step_inputs( + self, + *, + step: "Step", + pipeline_run: "PipelineRunResponse", + step_runs: Optional[Dict[str, "StepRunResponse"]] = None, + ) -> Dict[str, "StepRunInputResponse"]: + """Resolve input artifacts for the given step. + + Args: + step: The step to resolve inputs for. + pipeline_run: The pipeline run to resolve inputs for. + step_runs: The step runs to resolve inputs for. + + Returns: + The resolved inputs. + """ + + @abstractmethod + def load_input_artifact( + self, + *, + artifact: ArtifactVersionResponse, + data_type: Type[Any], + stack: "Stack", + ) -> Any: + """Load materialized value for an input artifact. + + Args: + artifact: The artifact to load. + data_type: The data type of the artifact. + stack: The stack to load the artifact from. + """ + + @abstractmethod + def store_output_artifacts( + self, + *, + output_data: Dict[str, Any], + output_materializers: Dict[str, Tuple[Type["BaseMaterializer"], ...]], + output_artifact_uris: Dict[str, str], + output_annotations: Dict[str, "OutputSignature"], + artifact_metadata_enabled: bool, + artifact_visualization_enabled: bool, + ) -> Dict[str, ArtifactVersionResponse]: + """Materialize and persist output artifacts and return their versions. + + Args: + output_data: The output data. + output_materializers: The output materializers. + output_artifact_uris: The output artifact URIs. + output_annotations: The output annotations. + artifact_metadata_enabled: Whether artifact metadata is enabled. + artifact_visualization_enabled: Whether artifact visualization is enabled. + """ + + # --- Cache Helpers (optional) --- + def compute_cache_key( + self, + *, + step: "Step", + input_artifact_ids: Dict[str, UUID], + artifact_store: "BaseArtifactStore", + project_id: UUID, + ) -> str: + """Compute a cache key for a step using existing utilities. + + Default implementation delegates to `cache_utils`. + + Args: + step: The step to compute the cache key for. + input_artifact_ids: The input artifact IDs. + artifact_store: The artifact store to compute the cache key for. + project_id: The project ID to compute the cache key for. + + Returns: + The computed cache key. + """ + from zenml.orchestrators import cache_utils + + return cache_utils.generate_cache_key( + step=step, + input_artifact_ids=input_artifact_ids, + artifact_store=artifact_store, + project_id=project_id, + ) + + def get_cached_step_run( + self, *, cache_key: str + ) -> Optional["StepRunResponse"]: + """Return a cached step run if available. + + Default implementation delegates to `cache_utils`. + + Args: + cache_key: The cache key to get the cached step run for. + + Returns: + The cached step run if available, otherwise None. + """ + from zenml.orchestrators import cache_utils + + return cache_utils.get_cached_step_run(cache_key=cache_key) + + # --- Server update helpers (may be batched/async by implementations) --- + def start(self) -> None: + """Optional start hook for runtime lifecycles.""" + + def on_step_start(self) -> None: + """Optional hook when a step begins execution.""" + + def publish_pipeline_run_metadata( + self, *, pipeline_run_id: Any, pipeline_run_metadata: Any + ) -> None: + """Publish pipeline run metadata. + + Args: + pipeline_run_id: The pipeline run ID. + pipeline_run_metadata: The pipeline run metadata. + """ + from zenml.orchestrators.publish_utils import ( + publish_pipeline_run_metadata, + ) + + publish_pipeline_run_metadata( + pipeline_run_id=pipeline_run_id, + pipeline_run_metadata=pipeline_run_metadata, + ) + + def publish_step_run_metadata( + self, *, step_run_id: Any, step_run_metadata: Any + ) -> None: + """Publish step run metadata. + + Args: + step_run_id: The step run ID. + step_run_metadata: The step run metadata. + """ + from zenml.orchestrators.publish_utils import publish_step_run_metadata + + publish_step_run_metadata( + step_run_id=step_run_id, step_run_metadata=step_run_metadata + ) + + def publish_successful_step_run( + self, *, step_run_id: Any, output_artifact_ids: Any + ) -> None: + """Publish a successful step run. + + Args: + step_run_id: The step run ID. + output_artifact_ids: The output artifact IDs. + """ + from zenml.orchestrators.publish_utils import ( + publish_successful_step_run, + ) + + publish_successful_step_run( + step_run_id=step_run_id, output_artifact_ids=output_artifact_ids + ) + + def publish_failed_step_run(self, *, step_run_id: Any) -> None: + """Publish a failed step run. + + Args: + step_run_id: The step run ID. + """ + from zenml.orchestrators.publish_utils import publish_failed_step_run + + publish_failed_step_run(step_run_id) + + def flush(self) -> None: + """Ensure all queued updates are sent.""" + + def on_step_end(self) -> None: + """Optional hook when a step finishes execution.""" + + def shutdown(self) -> None: + """Optional shutdown hook for runtime lifecycles.""" + + def get_metrics(self) -> Dict[str, Any]: + """Optional runtime metrics for observability. + + Default implementation returns an empty dict. + """ + return {} + + # --- Flush behavior --- + def should_flush_on_step_end(self) -> bool: + """Whether the runner should call flush() at step end. + + Implementations may override to disable flush for non-blocking serving. + """ + return True + + +class DefaultStepRuntime(BaseStepRuntime): + """Default runtime delegating to existing ZenML utilities. + + This keeps current behavior intact while providing a single place for the + step runner to call into. It intentionally mirrors logic from + `step_runner.py` and `orchestrators/input_utils.py`. + """ + + # --- Input Resolution --- + def resolve_step_inputs( + self, + *, + step: "Step", + pipeline_run: "PipelineRunResponse", + step_runs: Optional[Dict[str, "StepRunResponse"]] = None, + ) -> Dict[str, "StepRunInputResponse"]: + """Resolve step inputs. + + Args: + step: The step to resolve inputs for. + pipeline_run: The pipeline run to resolve inputs for. + step_runs: The step runs to resolve inputs for. + """ + from zenml.orchestrators import input_utils + + return input_utils.resolve_step_inputs( + step=step, pipeline_run=pipeline_run, step_runs=step_runs + ) + + # --- Artifact Load --- + def load_input_artifact( + self, + *, + artifact: ArtifactVersionResponse, + data_type: Type[Any], + stack: "Stack", + ) -> Any: + """Load an input artifact. + + Args: + artifact: The artifact to load. + data_type: The data type of the artifact. + stack: The stack to load the artifact from. + """ + from typing import Any as _Any + + from zenml.artifacts.unmaterialized_artifact import ( + UnmaterializedArtifact, + ) + from zenml.materializers.base_materializer import BaseMaterializer + from zenml.orchestrators.utils import ( + register_artifact_store_filesystem, + ) + from zenml.utils import source_utils + from zenml.utils.typing_utils import get_origin, is_union + + # Skip materialization for `UnmaterializedArtifact`. + if data_type == UnmaterializedArtifact: + return UnmaterializedArtifact( + **artifact.get_hydrated_version().model_dump() + ) + + if data_type in (None, _Any) or is_union(get_origin(data_type)): + # Use the stored artifact datatype when function annotation is not specific + data_type = source_utils.load(artifact.data_type) + + materializer_class: Type[BaseMaterializer] = ( + source_utils.load_and_validate_class( + artifact.materializer, expected_class=BaseMaterializer + ) + ) + + def _load(artifact_store: "BaseArtifactStore") -> Any: + materializer: BaseMaterializer = materializer_class( + uri=artifact.uri, artifact_store=artifact_store + ) + materializer.validate_load_type_compatibility(data_type) + return materializer.load(data_type=data_type) + + if artifact.artifact_store_id == stack.artifact_store.id: + stack.artifact_store._register() + return _load(artifact_store=stack.artifact_store) + else: + with register_artifact_store_filesystem( + artifact.artifact_store_id + ) as target_store: + return _load(artifact_store=target_store) + + # --- Artifact Store --- + def store_output_artifacts( + self, + *, + output_data: Dict[str, Any], + output_materializers: Dict[str, Tuple[Type["BaseMaterializer"], ...]], + output_artifact_uris: Dict[str, str], + output_annotations: Dict[str, "OutputSignature"], + artifact_metadata_enabled: bool, + artifact_visualization_enabled: bool, + ) -> Dict[str, ArtifactVersionResponse]: + """Store output artifacts. + + Args: + output_data: The output data. + output_materializers: The output materializers. + output_artifact_uris: The output artifact URIs. + output_annotations: The output annotations. + artifact_metadata_enabled: Whether artifact metadata is enabled. + artifact_visualization_enabled: Whether artifact visualization is enabled. + + Returns: + The stored artifacts. + """ + from typing import Type as _Type + + from zenml.artifacts.utils import ( + _store_artifact_data_and_prepare_request, + ) + from zenml.enums import ArtifactSaveType + from zenml.materializers.base_materializer import BaseMaterializer + from zenml.steps.step_context import get_step_context + from zenml.utils import materializer_utils, source_utils, tag_utils + + step_context = get_step_context() + artifact_requests: List[Any] = [] + + for output_name, return_value in output_data.items(): + data_type = type(return_value) + materializer_classes = output_materializers[output_name] + if materializer_classes: + materializer_class: _Type[BaseMaterializer] = ( + materializer_utils.select_materializer( + data_type=data_type, + materializer_classes=materializer_classes, + ) + ) + else: + # Runtime selection if no explicit materializer recorded + from zenml.materializers.materializer_registry import ( + materializer_registry, + ) + + default_materializer_source = ( + step_context.step_run.config.outputs[ + output_name + ].default_materializer_source + if step_context and step_context.step_run + else None + ) + + if default_materializer_source: + default_materializer_class: _Type[BaseMaterializer] = ( + source_utils.load_and_validate_class( + default_materializer_source, + expected_class=BaseMaterializer, + ) + ) + materializer_registry.default_materializer = ( + default_materializer_class + ) + + materializer_class = materializer_registry[data_type] + + uri = output_artifact_uris[output_name] + artifact_config = output_annotations[output_name].artifact_config + + artifact_type = None + if artifact_config is not None: + has_custom_name = bool(artifact_config.name) + version = artifact_config.version + artifact_type = artifact_config.artifact_type + else: + has_custom_name, version = False, None + + # Name resolution mirrors existing behavior + if has_custom_name: + artifact_name = output_name + else: + if step_context.pipeline_run.pipeline: + pipeline_name = step_context.pipeline_run.pipeline.name + else: + pipeline_name = "unlisted" + step_name = step_context.step_run.name + artifact_name = f"{pipeline_name}::{step_name}::{output_name}" + + # Collect user metadata and tags + user_metadata = step_context.get_output_metadata(output_name) + tags = step_context.get_output_tags(output_name) + if step_context.pipeline_run.config.tags is not None: + for tag in step_context.pipeline_run.config.tags: + if isinstance(tag, tag_utils.Tag) and tag.cascade is True: + tags.append(tag.name) + + artifact_request = _store_artifact_data_and_prepare_request( + name=artifact_name, + data=return_value, + materializer_class=materializer_class, + uri=uri, + artifact_type=artifact_type, + store_metadata=artifact_metadata_enabled, + store_visualizations=artifact_visualization_enabled, + has_custom_name=has_custom_name, + version=version, + tags=tags, + save_type=ArtifactSaveType.STEP_OUTPUT, + metadata=user_metadata, + ) + artifact_requests.append(artifact_request) + + responses = Client().zen_store.batch_create_artifact_versions( + artifact_requests + ) + return dict(zip(output_data.keys(), responses)) + + +class OffStepRuntime(DefaultStepRuntime): + """OFF mode runtime: minimize overhead but keep correctness. + + Notes: + - We intentionally keep artifact persistence and success/failure status + updates to avoid breaking input resolution across steps. + - We no-op metadata publishing calls to reduce server traffic. + """ + + def publish_pipeline_run_metadata( + self, *, pipeline_run_id: Any, pipeline_run_metadata: Any + ) -> None: + """Publish pipeline run metadata. + + Args: + pipeline_run_id: The pipeline run ID. + pipeline_run_metadata: The pipeline run metadata. + """ + # No-op: skip pipeline run metadata in OFF mode + return + + def publish_step_run_metadata( + self, *, step_run_id: Any, step_run_metadata: Any + ) -> None: + """Publish step run metadata. + + Args: + step_run_id: The step run ID. + step_run_metadata: The step run metadata. + """ + # No-op: skip step run metadata in OFF mode + return + + +class MemoryStepRuntime(BaseStepRuntime): + """Pure in-memory execution runtime: no server calls, no persistence.""" + + _STORE: Dict[str, Dict[Tuple[str, str], Any]] = {} + _LOCK: Any = threading.RLock() # initialized at class load + + @staticmethod + def make_handle_id(run_id: str, step_name: str, output_name: str) -> str: + """Make a handle ID for an output artifact. + + Args: + run_id: The run ID. + step_name: The step name. + output_name: The output name. + + Returns: + The handle ID. + """ + return f"mem://{run_id}/{step_name}/{output_name}" + + @staticmethod + def parse_handle_id(handle_id: str) -> Tuple[str, str, str]: + """Parse a handle ID for an output artifact. + + Args: + handle_id: The handle ID. + + Returns: + The run ID, step name, and output name. + """ + if not isinstance(handle_id, str) or not handle_id.startswith( + "mem://" + ): + raise ValueError("Invalid memory handle id") + rest = handle_id[len("mem://") :] + # split into exactly 3 parts: run_id, step_name, output_name + parts = rest.split("/", 2) + if len(parts) != 3: + raise ValueError("Invalid memory handle id") + run_id, step_name, output_name = parts + # basic sanitization + for p in (run_id, step_name, output_name): + if not p or "\n" in p or "\r" in p: + raise ValueError("Invalid memory handle component") + return run_id, step_name, output_name + + class Handle: + """A handle for an output artifact.""" + + def __init__(self, id: str) -> None: + """Initialize the handle. + + Args: + id: The handle ID. + """ + self.id = id + + # Instance-scoped context for handle resolution (set by launcher) + def __init__(self) -> None: + """Initialize the memory runtime.""" + super().__init__() + self._ctx_run_id: Optional[str] = None + self._ctx_substitutions: Dict[str, str] = {} + + def set_context( + self, *, run_id: str, substitutions: Optional[Dict[str, str]] = None + ) -> None: + """Set current memory-only context for handle resolution. + + Args: + run_id: The run ID. + substitutions: The substitutions. + """ + self._ctx_run_id = run_id + self._ctx_substitutions = substitutions or {} + + def resolve_step_inputs( + self, *, step, pipeline_run, step_runs=None + ) -> Dict[str, Any]: + """Resolve step inputs by constructing in-memory handles. + + Args: + step: The step to resolve inputs for. + pipeline_run: The pipeline run to resolve inputs for. + step_runs: The step runs to resolve inputs for. + + Returns: + A mapping of input name to MemoryStepRuntime.Handle. + """ + from zenml.utils import string_utils + + run_id = self._ctx_run_id or str(getattr(pipeline_run, "id", "local")) + subs = self._ctx_substitutions or {} + handles: Dict[str, Any] = {} + for name, input_ in step.spec.inputs.items(): + resolved_output_name = string_utils.format_name_template( + input_.output_name, substitutions=subs + ) + handle_id = self.make_handle_id( + run_id=run_id, + step_name=input_.step_name, + output_name=resolved_output_name, + ) + handles[name] = MemoryStepRuntime.Handle(handle_id) + return handles + + def load_input_artifact( + self, *, artifact: Any, data_type: Type[Any], stack: Any + ) -> Any: + """Load an input artifact. + + Args: + artifact: The artifact to load. + data_type: The data type of the artifact. + stack: The stack to load the artifact from. + + Returns: + The loaded artifact. + """ + handle_id_any = getattr(artifact, "id", None) + if not isinstance(handle_id_any, str): + raise ValueError("Invalid memory handle id") + run_id, step_name, output_name = self.parse_handle_id(handle_id_any) + with MemoryStepRuntime._LOCK: + return MemoryStepRuntime._STORE.get(run_id, {}).get( + (step_name, output_name) + ) + + def store_output_artifacts( + self, + *, + output_data: Dict[str, Any], + output_materializers: Dict[str, Tuple[Type[Any], ...]], + output_artifact_uris: Dict[str, str], + output_annotations: Dict[str, Any], + artifact_metadata_enabled: bool, + artifact_visualization_enabled: bool, + ) -> Dict[str, Any]: + """Store output artifacts. + + Args: + output_data: The output data. + output_materializers: The output materializers. + output_artifact_uris: The output artifact URIs. + output_annotations: The output annotations. + artifact_metadata_enabled: Whether artifact metadata is enabled. + artifact_visualization_enabled: Whether artifact visualization is enabled. + + Returns: + The stored artifacts. + """ + from zenml.steps.step_context import get_step_context + + ctx = get_step_context() + run_id = str(getattr(ctx.pipeline_run, "id", "local")) + step_name = str(getattr(ctx.step_run, "name", "step")) + handles: Dict[str, Any] = {} + with MemoryStepRuntime._LOCK: + rr = MemoryStepRuntime._STORE.setdefault(run_id, {}) + for output_name, value in output_data.items(): + rr[(step_name, output_name)] = value + handle_id = self.make_handle_id(run_id, step_name, output_name) + handles[output_name] = MemoryStepRuntime.Handle(handle_id) + return handles + + def compute_cache_key( + self, + *, + step: Any, + input_artifact_ids: Dict[str, Any], + artifact_store: Any, + project_id: Any, + ) -> str: + """Compute a cache key. + + Args: + step: The step to compute the cache key for. + input_artifact_ids: The input artifact IDs. + artifact_store: The artifact store to compute the cache key for. + project_id: The project ID to compute the cache key for. + + Returns: + The computed cache key. + """ + return "" + + def get_cached_step_run(self, *, cache_key: str) -> None: + """Get a cached step run. + + Args: + cache_key: The cache key to get the cached step run for. + + Returns: + The cached step run if available, otherwise None. + """ + return None + + def publish_pipeline_run_metadata( + self, *, pipeline_run_id: Any, pipeline_run_metadata: Any + ) -> None: + """Publish pipeline run metadata. + + Args: + pipeline_run_id: The pipeline run ID. + pipeline_run_metadata: The pipeline run metadata. + """ + return + + def publish_step_run_metadata( + self, *, step_run_id: Any, step_run_metadata: Any + ) -> None: + """Publish step run metadata. + + Args: + step_run_id: The step run ID. + step_run_metadata: The step run metadata. + """ + return + + def publish_successful_step_run( + self, *, step_run_id: Any, output_artifact_ids: Any + ) -> None: + """Publish a successful step run. + + Args: + step_run_id: The step run ID. + output_artifact_ids: The output artifact IDs. + """ + return + + def publish_failed_step_run(self, *, step_run_id: Any) -> None: + """Publish a failed step run. + + Args: + step_run_id: The step run ID. + """ + return + + def start(self) -> None: + """Start the memory runtime.""" + return + + def on_step_start(self) -> None: + """Optional hook when a step starts execution.""" + return + + def flush(self) -> None: + """Flush the memory runtime.""" + return + + def on_step_end(self) -> None: + """Optional hook when a step ends execution.""" + return + + def shutdown(self) -> None: + """Shutdown the memory runtime.""" + return diff --git a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py index 82d9aeca8cc..78e103c0977 100644 --- a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +++ b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py @@ -308,10 +308,23 @@ def main() -> None: for owner_reference in owner_references: owner_reference.controller = False + # Build a runtime for request factory using capture mode from config + try: + from zenml.execution.capture_policy import CapturePolicy + from zenml.execution.factory import get_runtime + + mode_cfg = getattr( + deployment.pipeline_configuration, "capture", None + ) + _runtime = get_runtime(CapturePolicy.from_value(mode_cfg)) + except Exception: + _runtime = None + step_run_request_factory = StepRunRequestFactory( deployment=deployment, pipeline_run=pipeline_run, stack=active_stack, + runtime=_runtime, ) step_runs = {} diff --git a/src/zenml/orchestrators/step_launcher.py b/src/zenml/orchestrators/step_launcher.py index ebda46f608d..43ad1d2e18f 100644 --- a/src/zenml/orchestrators/step_launcher.py +++ b/src/zenml/orchestrators/step_launcher.py @@ -29,6 +29,8 @@ from zenml.enums import ExecutionStatus from zenml.environment import get_run_environment_dict from zenml.exceptions import RunInterruptedException, RunStoppedException +from zenml.execution.capture_policy import CapturePolicy +from zenml.execution.factory import get_runtime from zenml.logger import get_logger from zenml.logging import step_logging from zenml.models import ( @@ -232,8 +234,56 @@ def launch(self) -> None: if self._deployment.pipeline_configuration.settings else None ) + + # Determine capture-based runtime and memory-only mode early + mode_cfg = getattr( + self._deployment.pipeline_configuration, "capture", None + ) + capture_policy = CapturePolicy.from_value(mode_cfg) + runtime = get_runtime(capture_policy) + # Store for later use + self._runtime = runtime + runs_opt = str(capture_policy.get_option("runs", "on")).lower() + persistence = str( + capture_policy.get_option("persistence", "async") + ).lower() + memory_only = runs_opt in {"off", "false", "0"} or persistence in { + "memory", + "off", + } + + if memory_only: + self._launch_memory_only() + return pipeline_run, run_was_created = self._create_or_reuse_run() + # runtime already constructed above; configure flush behavior + # Default for serving (REALTIME): do not flush at step end unless user specifies + import os as _os + + in_serving_ctx = ( + _os.getenv("ZENML_SERVING_CAPTURE_DEFAULT") is not None + ) + if ( + capture_policy.mode.name == "REALTIME" + and "flush_on_step_end" + not in getattr(capture_policy, "options", {}) + and in_serving_ctx + ): + flush_opt = False + else: + # Honor capture option: flush_on_step_end (default True) + flush_opt = capture_policy.get_option("flush_on_step_end", True) + # Configure runtime flush behavior if supported + set_flush = getattr(runtime, "set_flush_on_step_end", None) + if callable(set_flush): + try: + set_flush(bool(flush_opt)) + except Exception as e: + logger.debug( + "Could not configure runtime flush behavior: %s", e + ) + # Enable or disable step logs storage if ( handle_bool_env_var(ENV_ZENML_DISABLE_STEP_LOGS_STORAGE, False) @@ -272,7 +322,8 @@ def launch(self) -> None: pipeline_run_metadata = self._stack.get_pipeline_run_metadata( run_id=pipeline_run.id ) - publish_utils.publish_pipeline_run_metadata( + runtime.start() + runtime.publish_pipeline_run_metadata( pipeline_run_id=pipeline_run.id, pipeline_run_metadata=pipeline_run_metadata, ) @@ -281,93 +332,132 @@ def launch(self) -> None: model_version=model_version ) - request_factory = step_run_utils.StepRunRequestFactory( - deployment=self._deployment, - pipeline_run=pipeline_run, - stack=self._stack, - ) - step_run_request = request_factory.create_request( - invocation_id=self._step_name - ) - step_run_request.logs = logs_model + # Honor capture.code flag (default True) + code_opt = capture_policy.get_option("code", True) + code_enabled = str(code_opt).lower() not in {"false", "0", "off"} + request_factory = step_run_utils.StepRunRequestFactory( + deployment=self._deployment, + pipeline_run=pipeline_run, + stack=self._stack, + runtime=runtime, + skip_code_capture=not code_enabled, + ) + step_run_request = request_factory.create_request( + invocation_id=self._step_name + ) + step_run_request.logs = logs_model + + # If this step has upstream dependencies and runtime uses non-blocking + # publishes, ensure previous step updates are flushed so input + # resolution via server succeeds. + if ( + self._step.spec.upstream_steps + and not runtime.should_flush_on_step_end() + ): try: - # Always populate request to ensure proper input/output flow - request_factory.populate_request(request=step_run_request) - - # In no-capture mode, force fresh execution (bypass cache) - if tracking_disabled: - step_run_request.original_step_run_id = None - step_run_request.outputs = {} - step_run_request.status = ExecutionStatus.RUNNING - except BaseException as e: - logger.exception(f"Failed preparing step `{self._step_name}`.") - step_run_request.status = ExecutionStatus.FAILED - step_run_request.end_time = utc_now() - step_run_request.exception_info = ( - exception_utils.collect_exception_information(e) + runtime.flush() + except Exception as e: + logger.debug( + "Non-blocking flush before input resolution failed: %s", e ) - raise - finally: - # Always create real step run for proper input/output flow - step_run = Client().zen_store.create_run_step(step_run_request) - self._step_run = step_run - if not tracking_disabled and ( - model_version := step_run.model_version - ): - step_run_utils.log_model_version_dashboard_url( - model_version=model_version - ) - if not step_run.status.is_finished: - logger.info(f"Step `{self._step_name}` has started.") + try: + # Always populate request to ensure proper input/output flow + request_factory.populate_request(request=step_run_request) + + # In no-capture mode, force fresh execution (bypass cache) + if tracking_disabled: + step_run_request.original_step_run_id = None + step_run_request.outputs = {} + step_run_request.status = ExecutionStatus.RUNNING + except BaseException as e: + logger.exception(f"Failed preparing step `{self._step_name}`.") + step_run_request.status = ExecutionStatus.FAILED + step_run_request.end_time = utc_now() + step_run_request.exception_info = ( + exception_utils.collect_exception_information(e) + ) + raise + finally: + # Always create real step run for proper input/output flow + step_run = Client().zen_store.create_run_step(step_run_request) + self._step_run = step_run + if not tracking_disabled and ( + model_version := step_run.model_version + ): + step_run_utils.log_model_version_dashboard_url( + model_version=model_version + ) - try: - # here pass a forced save_to_file callable to be - # used as a dump function to use before starting - # the external jobs in step operators - if isinstance( - logs_context, - step_logging.PipelineLogsStorageContext, - ): - force_write_logs = ( - logs_context.storage.send_merge_event - ) - else: + if not step_run.status.is_finished: + logger.info(f"Step `{self._step_name}` has started.") - def _bypass() -> None: - return None + try: + # here pass a forced save_to_file callable to be + # used as a dump function to use before starting + # the external jobs in step operators + if isinstance( + logs_context, + step_logging.PipelineLogsStorageContext, + ): + force_write_logs = logs_context.storage.send_merge_event + else: - force_write_logs = _bypass - self._run_step( - pipeline_run=pipeline_run, - step_run=step_run, - force_write_logs=force_write_logs, - ) - except RunStoppedException as e: - raise e - except BaseException as e: # noqa: E722 - logger.error( - "Failed to run step `%s`: %s", - self._step_name, - e, + def _bypass() -> None: + return None + + force_write_logs = _bypass + self._run_step( + pipeline_run=pipeline_run, + step_run=step_run, + force_write_logs=force_write_logs, + ) + except RunStoppedException as e: + raise e + except BaseException as e: # noqa: E722 + logger.error( + "Failed to run step `%s`: %s", + self._step_name, + e, + ) + if not tracking_disabled: + runtime.publish_failed_step_run(step_run_id=step_run.id) + if runtime.should_flush_on_step_end(): + runtime.flush() + raise + else: + logger.info(f"Using cached version of step `{self._step_name}`.") + if not tracking_disabled: + if ( + model_version := step_run.model_version + or pipeline_run.model_version + ): + step_run_utils.link_output_artifacts_to_model_version( + artifacts=step_run.outputs, + model_version=model_version, ) - if not tracking_disabled: - publish_utils.publish_failed_step_run(step_run.id) - raise - else: + # Ensure any queued updates are flushed for cached path (if enabled) + if runtime.should_flush_on_step_end(): + runtime.flush() + # Ensure runtime shutdown after launch + try: + metrics = {} + try: + metrics = runtime.get_metrics() or {} + except Exception: + metrics = {} + runtime.shutdown() + if metrics: logger.info( - f"Using cached version of step `{self._step_name}`." + "Runtime metrics: queued=%s processed=%s failed_total=%s queue_depth=%s", + metrics.get("queued"), + metrics.get("processed"), + metrics.get("failed_total"), + metrics.get("queue_depth"), ) - if not tracking_disabled: - if ( - model_version := step_run.model_version - or pipeline_run.model_version - ): - step_run_utils.link_output_artifacts_to_model_version( - artifacts=step_run.outputs, - model_version=model_version, - ) + except Exception as e: + logger.debug(f"Runtime shutdown/metrics retrieval error: {e}") def _create_or_reuse_run(self) -> Tuple[PipelineRunResponse, bool]: """Creates a pipeline run or reuses an existing one. @@ -488,6 +578,24 @@ def _run_step( f"`{string_utils.get_human_readable_time(duration)}`." ) + # If runtime is non-blocking and there are downstream steps depending + # on this step, flush now so that downstream input resolution sees + # this step's outputs on the server. + runtime = getattr(self, "_runtime", None) + if runtime is not None and not runtime.should_flush_on_step_end(): + has_downstream = any( + self._step_name in cfg.spec.upstream_steps + for name, cfg in self._deployment.step_configurations.items() + ) + if has_downstream: + try: + runtime.flush() + except Exception as e: + logger.debug( + "Non-blocking runtime flush after step finish failed: %s", + e, + ) + def _run_step_with_step_operator( self, step_operator_name: Optional[str], @@ -520,6 +628,17 @@ def _run_step_with_step_operator( environment.update(secrets) environment[ENV_ZENML_STEP_OPERATOR] = "True" + # Propagate capture mode to the step operator environment so that + # the entrypoint can construct the appropriate runtime. + try: + mode_cfg = getattr( + self._deployment.pipeline_configuration, "capture", None + ) + if mode_cfg: + environment["ZENML_CAPTURE_MODE"] = str(mode_cfg).upper() + environment["ZENML_ENABLE_STEP_RUNTIME"] = "true" + except Exception: + pass logger.info( "Using step operator `%s` to run step `%s`.", step_operator.name, @@ -548,7 +667,11 @@ def _run_step_without_step_operator( input_artifacts: The input artifact versions of the current step. output_artifact_uris: The output artifact URIs of the current step. """ - runner = StepRunner(step=self._step, stack=self._stack) + # Use runtime determined at launch + runtime = getattr(self, "_runtime", None) + runner = StepRunner( + step=self._step, stack=self._stack, runtime=runtime + ) runner.run( pipeline_run=pipeline_run, step_run=step_run, @@ -556,3 +679,96 @@ def _run_step_without_step_operator( output_artifact_uris=output_artifact_uris, step_run_info=step_run_info, ) + + def _launch_memory_only(self) -> None: + """Launch the step in pure memory-only mode (no runs, no persistence).""" + from dataclasses import dataclass + from typing import Any + + from zenml.config.step_run_info import StepRunInfo + from zenml.execution.step_runtime import MemoryStepRuntime + from zenml.utils.time_utils import utc_now + + run_id = self._orchestrator_run_id + start_time = utc_now() + substitutions = ( + self._deployment.pipeline_configuration.finalize_substitutions( + start_time=start_time + ) + ) + + @dataclass + class _Cfg: + tags: Any = None + + @dataclass + class _PipelineRunStub: + id: str + model_version: Any = None + pipeline: Any = None + config: Any = _Cfg() + + @dataclass + class _StepCfg: + substitutions: Any + outputs: Any + + @dataclass + class _StepRunStub: + id: str + name: str + model_version: Any + config: Any + is_retriable: bool = True + + pipeline_run_stub = _PipelineRunStub(id=run_id) + step_run_stub = _StepRunStub( + id=run_id, # valid UUID string preferred + name=self._step_name, + model_version=None, + config=_StepCfg( + substitutions=substitutions, outputs=self._step.config.outputs + ), + is_retriable=True, + ) + + # Build URIs from declared outputs (no imports needed) + output_names = list(self._step.config.outputs.keys()) + output_artifact_uris = { + name: f"memory://{run_id}/{self._step_name}/{name}" + for name in output_names + } + + # Resolve inputs via runtime to avoid duplication + if isinstance(self._runtime, MemoryStepRuntime): + self._runtime.set_context( + run_id=run_id, substitutions=substitutions + ) + input_artifacts = self._runtime.resolve_step_inputs( + step=self._step, pipeline_run=pipeline_run_stub + ) + else: + input_artifacts = {} + + runner = StepRunner( + step=self._step, stack=self._stack, runtime=self._runtime + ) + step_run_info = StepRunInfo( + config=self._step.config, + pipeline=self._deployment.pipeline_configuration, + run_name=self._deployment.run_name_template, + pipeline_step_name=self._step_name, + run_id=run_id, + step_run_id=step_run_stub.id, + force_write_logs=lambda: None, + ) + + from typing import Any, cast + + runner.run( + pipeline_run=cast(Any, pipeline_run_stub), + step_run=cast(Any, step_run_stub), + input_artifacts=input_artifacts, + output_artifact_uris=output_artifact_uris, + step_run_info=step_run_info, + ) diff --git a/src/zenml/orchestrators/step_run_utils.py b/src/zenml/orchestrators/step_run_utils.py index 042c927090a..7de15a26716 100644 --- a/src/zenml/orchestrators/step_run_utils.py +++ b/src/zenml/orchestrators/step_run_utils.py @@ -20,6 +20,7 @@ from zenml.config.step_configurations import Step from zenml.constants import CODE_HASH_PARAMETER_NAME, TEXT_FIELD_MAX_LENGTH from zenml.enums import ExecutionStatus +from zenml.execution.step_runtime import BaseStepRuntime, DefaultStepRuntime from zenml.logger import get_logger from zenml.model.utils import link_artifact_version_to_model_version from zenml.models import ( @@ -30,7 +31,7 @@ StepRunRequest, StepRunResponse, ) -from zenml.orchestrators import cache_utils, input_utils, utils +from zenml.orchestrators import utils from zenml.stack import Stack from zenml.utils import pagination_utils from zenml.utils.time_utils import utc_now @@ -46,6 +47,8 @@ def __init__( deployment: "PipelineDeploymentResponse", pipeline_run: "PipelineRunResponse", stack: "Stack", + runtime: Optional[BaseStepRuntime] = None, + skip_code_capture: bool = False, ) -> None: """Initialize the object. @@ -54,10 +57,14 @@ def __init__( pipeline_run: The pipeline run for which to create step run requests. stack: The stack on which the pipeline run is happening. + runtime: The runtime to use for the step run requests. + skip_code_capture: Whether to skip code/docstring capture. """ self.deployment = deployment self.pipeline_run = pipeline_run self.stack = stack + self.runtime: BaseStepRuntime = runtime or DefaultStepRuntime() + self.skip_code_capture = skip_code_capture def has_caching_enabled(self, invocation_id: str) -> bool: """Check if the step has caching enabled. @@ -112,7 +119,7 @@ def populate_request( """ step = self.deployment.step_configurations[request.name] - input_artifacts = input_utils.resolve_step_inputs( + input_artifacts = self.runtime.resolve_step_inputs( step=step, pipeline_run=self.pipeline_run, step_runs=step_runs, @@ -126,7 +133,7 @@ def populate_request( name: [artifact.id] for name, artifact in input_artifacts.items() } - cache_key = cache_utils.generate_cache_key( + cache_key = self.runtime.compute_cache_key( step=step, input_artifact_ids=input_artifact_ids, artifact_store=self.stack.artifact_store, @@ -134,13 +141,14 @@ def populate_request( ) request.cache_key = cache_key - ( - docstring, - source_code, - ) = self._get_docstring_and_source_code(invocation_id=request.name) + if not self.skip_code_capture: + ( + docstring, + source_code, + ) = self._get_docstring_and_source_code(invocation_id=request.name) - request.docstring = docstring - request.source_code = source_code + request.docstring = docstring + request.source_code = source_code request.code_hash = step.config.parameters.get( CODE_HASH_PARAMETER_NAME ) @@ -151,7 +159,7 @@ def populate_request( ) if cache_enabled: - if cached_step_run := cache_utils.get_cached_step_run( + if cached_step_run := self.runtime.get_cached_step_run( cache_key=cache_key ): request.inputs = { diff --git a/src/zenml/orchestrators/step_runner.py b/src/zenml/orchestrators/step_runner.py index 8cad58ad1f7..930e4ebb93c 100644 --- a/src/zenml/orchestrators/step_runner.py +++ b/src/zenml/orchestrators/step_runner.py @@ -40,6 +40,7 @@ ) from zenml.enums import ArtifactSaveType from zenml.exceptions import StepInterfaceError +from zenml.execution.step_runtime import BaseStepRuntime, DefaultStepRuntime from zenml.logger import get_logger from zenml.logging.step_logging import PipelineLogsStorageContext, redirected from zenml.materializers.base_materializer import BaseMaterializer @@ -49,7 +50,6 @@ ) from zenml.orchestrators.publish_utils import ( publish_step_run_metadata, - publish_successful_step_run, step_exception_info, ) from zenml.orchestrators.utils import ( @@ -90,15 +90,24 @@ class StepRunner: """Class to run steps.""" - def __init__(self, step: "Step", stack: "Stack"): + def __init__( + self, + step: "Step", + stack: "Stack", + runtime: Optional[BaseStepRuntime] = None, + ): """Initializes the step runner. Args: step: The step to run. stack: The stack on which the step should run. + runtime: The runtime to use for the step run. """ self._step = step self._stack = stack + # Initialize runtime behind an opt-in flag to preserve behavior + # Always have a runtime to avoid branching; default to behavior-parity runtime + self._runtime: BaseStepRuntime = runtime or DefaultStepRuntime() @property def configuration(self) -> StepConfiguration: @@ -213,6 +222,8 @@ def run( step_failed = False try: + if self._runtime is not None: + self._runtime.on_step_start() return_values = step_instance.call_entrypoint( **function_params ) @@ -253,10 +264,16 @@ def run( step_run_metadata = self._stack.get_step_run_metadata( info=step_run_info, ) - publish_step_run_metadata( - step_run_id=step_run_info.step_run_id, - step_run_metadata=step_run_metadata, - ) + if self._runtime is not None: + self._runtime.publish_step_run_metadata( + step_run_id=step_run_info.step_run_id, + step_run_metadata=step_run_metadata, + ) + else: + publish_step_run_metadata( + step_run_id=step_run_info.step_run_id, + step_run_metadata=step_run_metadata, + ) self._stack.cleanup_step_run( info=step_run_info, step_failed=step_failed ) @@ -302,14 +319,24 @@ def run( is_enabled_on_step=step_run_info.config.enable_artifact_visualization, is_enabled_on_pipeline=step_run_info.pipeline.enable_artifact_visualization, ) - output_artifacts = self._store_output_artifacts( - output_data=output_data, - output_artifact_uris=output_artifact_uris, - output_materializers=output_materializers, - output_annotations=output_annotations, - artifact_metadata_enabled=artifact_metadata_enabled, - artifact_visualization_enabled=artifact_visualization_enabled, - ) + if self._runtime is not None: + output_artifacts = self._runtime.store_output_artifacts( + output_data=output_data, + output_artifact_uris=output_artifact_uris, + output_materializers=output_materializers, + output_annotations=output_annotations, + artifact_metadata_enabled=artifact_metadata_enabled, + artifact_visualization_enabled=artifact_visualization_enabled, + ) + else: + output_artifacts = self._store_output_artifacts( + output_data=output_data, + output_artifact_uris=output_artifact_uris, + output_materializers=output_materializers, + output_annotations=output_annotations, + artifact_metadata_enabled=artifact_metadata_enabled, + artifact_visualization_enabled=artifact_visualization_enabled, + ) if ( model_version := step_run.model_version @@ -336,10 +363,14 @@ def run( ] for output_name, artifact in output_artifacts.items() } - publish_successful_step_run( + self._runtime.publish_successful_step_run( step_run_id=step_run_info.step_run_id, output_artifact_ids=output_artifact_ids, ) + # Ensure updates are flushed at end of step unless disabled + self._runtime.on_step_end() + if self._runtime.should_flush_on_step_end(): + self._runtime.flush() def _evaluate_artifact_names_in_collections( self, @@ -441,9 +472,16 @@ def _parse_inputs( arg_type = resolve_type_annotation(arg_type) if arg in input_artifacts: - function_params[arg] = self._load_input_artifact( - input_artifacts[arg], arg_type - ) + if self._runtime is not None: + function_params[arg] = self._runtime.load_input_artifact( + artifact=input_artifacts[arg], + data_type=arg_type, + stack=self._stack, + ) + else: + function_params[arg] = self._load_input_artifact( + input_artifacts[arg], arg_type + ) elif arg in self.configuration.parameters: param_value = self.configuration.parameters[arg] # Pydantic bridging: convert dict to Pydantic model if possible diff --git a/src/zenml/pipelines/pipeline_decorator.py b/src/zenml/pipelines/pipeline_decorator.py index d14ffba235f..85d40738340 100644 --- a/src/zenml/pipelines/pipeline_decorator.py +++ b/src/zenml/pipelines/pipeline_decorator.py @@ -62,6 +62,7 @@ def pipeline( model: Optional["Model"] = None, retry: Optional["StepRetryConfig"] = None, substitutions: Optional[Dict[str, str]] = None, + capture: Optional[Dict[str, Any]] = None, ) -> Callable[["F"], "Pipeline"]: ... @@ -83,6 +84,7 @@ def pipeline( model: Optional["Model"] = None, retry: Optional["StepRetryConfig"] = None, substitutions: Optional[Dict[str, str]] = None, + capture: Optional[Dict[str, Any]] = None, ) -> Union["Pipeline", Callable[["F"], "Pipeline"]]: """Decorator to create a pipeline. @@ -113,6 +115,7 @@ def pipeline( model: configuration of the model in the Model Control Plane. retry: Retry configuration for the pipeline steps. substitutions: Extra placeholders to use in the name templates. + capture: Capture policy for the pipeline. Returns: A pipeline instance. @@ -138,6 +141,7 @@ def inner_decorator(func: "F") -> "Pipeline": model=model, retry=retry, substitutions=substitutions, + capture=capture, ) p.__doc__ = func.__doc__ diff --git a/src/zenml/pipelines/pipeline_definition.py b/src/zenml/pipelines/pipeline_definition.py index ba8d9f13f71..2d1d94d8edf 100644 --- a/src/zenml/pipelines/pipeline_definition.py +++ b/src/zenml/pipelines/pipeline_definition.py @@ -148,6 +148,7 @@ def __init__( model: Optional["Model"] = None, retry: Optional["StepRetryConfig"] = None, substitutions: Optional[Dict[str, str]] = None, + capture: Optional[Any] = None, ) -> None: """Initializes a pipeline. @@ -180,6 +181,7 @@ def __init__( model: configuration of the model in the Model Control Plane. retry: Retry configuration for the pipeline steps. substitutions: Extra placeholders to use in the name templates. + capture: Capture policy for the pipeline. """ self._invocations: Dict[str, StepInvocation] = {} self._run_args: Dict[str, Any] = {} @@ -205,6 +207,7 @@ def __init__( model=model, retry=retry, substitutions=substitutions, + capture=capture, ) self.entrypoint = entrypoint self._parameters: Dict[str, Any] = {} @@ -330,6 +333,7 @@ def configure( parameters: Optional[Dict[str, Any]] = None, merge: bool = True, substitutions: Optional[Dict[str, str]] = None, + capture: Optional[Any] = None, ) -> Self: """Configures the pipeline. @@ -376,6 +380,7 @@ def configure( retry: Retry configuration for the pipeline steps. parameters: input parameters for the pipeline. substitutions: Extra placeholders to use in the name templates. + capture: Capture policy for the pipeline. Returns: The pipeline instance that this method was called on. @@ -405,6 +410,13 @@ def configure( # merges dicts tags = self._configuration.tags + tags + # Normalize capture to upper-case string if provided + if capture is not None: + try: + capture = str(capture).upper() + except Exception: + capture = str(capture) + values = dict_utils.remove_none_values( { "enable_cache": enable_cache, @@ -423,6 +435,7 @@ def configure( "retry": retry, "parameters": parameters, "substitutions": substitutions, + "capture": capture, } ) if not self.__suppress_warnings_flag__: From 9d295bfe0aea8ef297e0026e8b3d86b33ccd0e0b Mon Sep 17 00:00:00 2001 From: Safoine El khabich Date: Sat, 6 Sep 2025 20:49:24 +0100 Subject: [PATCH 2/8] Refactor capture configuration and runtime management This commit introduces a new capture configuration system that simplifies the handling of capture modes in ZenML. The `Capture` class is now used to define capture settings, allowing for more explicit and typed configurations. The pipeline decorator and runtime management have been updated to support this new structure, enhancing clarity and usability. Additionally, the `MemoryStepRuntime` and `RealtimeStepRuntime` classes have been improved to better manage runtime states and error reporting, including the implementation of a circuit breaker for resilience under load. This refactor aims to streamline the serving architecture and improve the overall performance and maintainability of the codebase. --- .../serving/advanced/capture-and-runtime.md | 112 ++++-- docs/book/serving/advanced/realtime-tuning.md | 86 +++++ docs/book/serving/overview.md | 51 ++- examples/serving/weather_pipeline.py | 26 +- src/zenml/capture/config.py | 212 ++++++++++++ src/zenml/config/compiler.py | 57 ++-- src/zenml/config/pipeline_configurations.py | 49 +-- src/zenml/deployers/serving/service.py | 31 +- src/zenml/execution/capture_policy.py | 107 ------ src/zenml/execution/factory.py | 29 +- src/zenml/execution/realtime_runtime.py | 171 ++++++++-- src/zenml/execution/step_runtime.py | 58 +++- .../kubernetes_orchestrator_entrypoint.py | 2 +- src/zenml/orchestrators/run_entity_manager.py | 173 ++++++++++ src/zenml/orchestrators/runtime_manager.py | 83 +++++ src/zenml/orchestrators/step_launcher.py | 318 +++++++++--------- src/zenml/orchestrators/step_runner.py | 9 + src/zenml/orchestrators/utils.py | 16 + src/zenml/pipelines/pipeline_decorator.py | 23 +- src/zenml/pipelines/pipeline_definition.py | 35 +- 20 files changed, 1200 insertions(+), 448 deletions(-) create mode 100644 docs/book/serving/advanced/realtime-tuning.md create mode 100644 src/zenml/capture/config.py delete mode 100644 src/zenml/execution/capture_policy.py create mode 100644 src/zenml/orchestrators/run_entity_manager.py create mode 100644 src/zenml/orchestrators/runtime_manager.py diff --git a/docs/book/serving/advanced/capture-and-runtime.md b/docs/book/serving/advanced/capture-and-runtime.md index 1dd2ffae418..06631afeb39 100644 --- a/docs/book/serving/advanced/capture-and-runtime.md +++ b/docs/book/serving/advanced/capture-and-runtime.md @@ -20,48 +20,55 @@ This page explains how capture options map to execution runtimes and how to tune - `flush_on_step_end` controls whether to block at step boundary to flush updates. - In serving with `mode=REALTIME`, `flush_on_step_end` defaults to `false` unless explicitly set. -- OffStepRuntime - - Focus: Lightweight operation with minimal overhead. - - Behavior: Persists artifacts; skips metadata/logs/visualizations/caching (compiler disables these by default in OFF). - - MemoryStepRuntime - Focus: Pure in-memory execution (no server, no persistence). - Behavior: Inter-step data is exchanged via in-process memory handles; no runs or artifacts. - - Configure with REALTIME: `capture={"mode": "REALTIME", "runs": "off"}` or `{"persistence": "memory"}`. + - Configure with REALTIME: `@pipeline(capture=Capture(memory_only=True))`. ## Capture Configuration Where to set: -- In code: `@pipeline(capture=...)` -- In run config YAML: `capture: ...` +- In code: `@pipeline(capture=...)` (typed only) +- In run config YAML: `capture: REALTIME|BATCH` -Supported options (commonly used): -```yaml -capture: - mode: BATCH | REALTIME | OFF | CUSTOM - runs: on | off # off → no runs (memory-only when REALTIME) - persistence: sync | async | memory | off - logs: all | errors-only | off - metadata: true | false - visualization: true | false - cache_enabled: true | false - code: true | false # skip docstring/source capture if false - flush_on_step_end: true | false - ttl_seconds: 600 # Realtime cache TTL - max_entries: 2048 # Realtime cache size bound +Recommended API (typed) +```python +from zenml.capture.config import Capture + +# Not required for defaults, but explicit usage examples: + +# Realtime (default in serving), non-blocking reporting +@pipeline(capture=Capture()) +def serve(...): + ... + +# Realtime, blocking reporting +@pipeline(capture=Capture(flush_on_step_end=True)) + +# Realtime, memory-only (serving only) +@pipeline(capture=Capture(memory_only=True)) ``` Notes: -- `mode` determines the base runtime. -- `runs: off` or `persistence: memory/off` under REALTIME maps to MemoryStepRuntime (pure in-memory execution). -- `flush_on_step_end`: If `false`, serving returns immediately; tracking is published asynchronously by the runtime worker. -- `code: false`: Skips docstring/source capture (metadata), but does not affect code execution. +- Modes are inferred by context (batch vs serving), you only set options: + - `flush_on_step_end`: If `False`, serving returns immediately; tracking is published asynchronously by the runtime worker. + - `memory_only=True`: Pure in-memory execution (no runs/artifacts), serving only. + - `code=False`: Skips docstring/source capture (metadata), but does not affect code execution. ## Serving Defaults - REALTIME + serving context: - - If `flush_on_step_end` is not provided, it defaults to `false` for better latency. - - Users can override by setting `flush_on_step_end: true`. + - If capture is unset, defaults to non-blocking (`flush_on_step_end=False`). + - Users can set `flush_on_step_end=True` to block at step boundary. + +## Validation & Behavior + +- Realtime capture outside serving: + - Allowed for development; logs a warning and continues. In production, use the serving service. +- memory_only outside serving: + - Ignored with a warning; standard execution proceeds (Batch/Realtime as applicable). +- Contradictory options: + - Capture(memory_only=True, flush_on_step_end=True) → raises ValueError. ## Step Operators & Remote Execution @@ -86,17 +93,55 @@ Notes: ## Recipes - Low-latency serving (eventual consistency): - - `@pipeline(capture={"mode": "REALTIME", "flush_on_step_end": false})` + - `@pipeline(capture=Capture())` - Strict serving (strong consistency): - - `@pipeline(capture={"mode": "REALTIME", "flush_on_step_end": true})` + - `@pipeline(capture=Capture(flush_on_step_end=True))` - Memory-only (stateless service): - - `@pipeline(capture={"mode": "REALTIME", "runs": "off"})` + - `@pipeline(capture=Capture(memory_only=True))` + +### Control logs/metadata/visualizations (Batch & Realtime) -- Compliance mode: - - `@pipeline(capture="BATCH")` or - - `@pipeline(capture={"mode": "REALTIME", "logs": "all", "metadata": true, "flush_on_step_end": true})` +These are pipeline settings, not capture options. Set them via `pipeline.configure(...)` or YAML: + +```python +@pipeline() +def train(...): + ... + +# In code +train = train.with_options() +train.configure( + enable_step_logs=True, + enable_artifact_metadata=True, + enable_artifact_visualization=False, +) +``` + +Or in run config YAML: + +```yaml +enable_step_logs: true +enable_artifact_metadata: true +enable_artifact_visualization: false +``` + +### Disable code capture (docstring/source) + +Code capture affects metadata only (not execution). You can disable it via capture in both modes: + +```python +from zenml.capture.config import Capture + +@pipeline(capture=Capture(code=False)) +def serve(...): + ... + +@pipeline(capture=Capture(code=False)) +def train(...): + ... +``` ## FAQ @@ -111,4 +156,3 @@ Notes: - Can memory-only work with parallelism? - Memory-only is per-process. For multi-process/multi-container setups, use persistence for cross-process data. - diff --git a/docs/book/serving/advanced/realtime-tuning.md b/docs/book/serving/advanced/realtime-tuning.md new file mode 100644 index 00000000000..abd4bfec946 --- /dev/null +++ b/docs/book/serving/advanced/realtime-tuning.md @@ -0,0 +1,86 @@ +--- +title: Realtime Runtime Tuning & Circuit Breakers +--- + +# Realtime Runtime Tuning & Circuit Breakers + +This page documents advanced environment variables and metrics for tuning the Realtime runtime in production deployments. These knobs let you balance latency, throughput, and resilience under load. + +## When To Use This + +- High-QPS serving pipelines where latency and CPU efficiency matter +- Deployments needing stronger guardrails against cascading failures +- Teams instrumenting detailed metrics (cache hit rate, p95/p99 latencies) + +## Environment Variables + +Cache & Limits + +- `ZENML_RT_CACHE_TTL_SECONDS` (default: `60`) + - TTL for cached artifact values in seconds (in-process cache). +- `ZENML_RT_CACHE_MAX_ENTRIES` (default: `256`) + - LRU cache entry bound to prevent unbounded growth. + +Background Error Reporting + +- `ZENML_RT_ERR_REPORT_INTERVAL` (default: `15`) + - Minimum seconds between repeated background error logs (prevents log spam while maintaining visibility). + +Circuit Breaker (async → inline fallback) + +- `ZENML_RT_CB_ERR_THRESHOLD` (default: `0.1`) + - Error rate threshold to open the breaker (e.g., `0.1` = 10%). +- `ZENML_RT_CB_MIN_EVENTS` (default: `100`) + - Minimum number of publish events to evaluate before opening breaker. +- `ZENML_RT_CB_OPEN_SECONDS` (default: `300`) + - Duration (seconds) to keep breaker open; inline publishing is used while open. + +Capture & Mode (context) + +- `ZENML_CAPTURE_MODE`: default runtime mode from environment (`BATCH|REALTIME`). +- `ZENML_SERVING_CAPTURE_DEFAULT`: when present, serving defaults to `REALTIME` if capture is not set. + +Notes + +- Realtime outside serving logs a warning and continues (for local development). For production serving, run via the serving service. +- YAML/ENV can still set `capture: REALTIME|BATCH` for run configs; code paths are typed-only (`Capture`, `BatchCapture`, `RealtimeCapture`). + +## Metrics & Observability + +`RealtimeStepRuntime.get_metrics()` returns a snapshot of: + +- Queue & Errors: `queued`, `processed`, `failed_total`, `queue_depth` +- Cache: `cache_hits`, `cache_misses`, `cache_hit_rate` +- Latency (op publish): `op_latency_p50_s`, `op_latency_p95_s`, `op_latency_p99_s` +- Config: `ttl_seconds`, `max_entries` + +Recommendation + +- Export metrics to your telemetry system (e.g., Prometheus) and alert on: + - Rising `failed_total` and sustained `queue_depth` + - Low `cache_hit_rate` + - High `op_latency_p95_s` / `op_latency_p99_s` + +## Recommended Production Defaults + +- Start conservative, then tune based on SLOs: + - `ZENML_RT_CACHE_TTL_SECONDS=60` + - `ZENML_RT_CACHE_MAX_ENTRIES=256` + - `ZENML_RT_ERR_REPORT_INTERVAL=15` + - `ZENML_RT_CB_ERR_THRESHOLD=0.1` + - `ZENML_RT_CB_MIN_EVENTS=100` + - `ZENML_RT_CB_OPEN_SECONDS=300` + +## Runbook (Common Scenarios) + +- High background errors: + - Check logs for circuit breaker events. If open, runtime will publish inline. Investigate upstream store or network failures. + - Consider temporarily reducing load or increasing `ZENML_RT_CB_OPEN_SECONDS` while recovering. + +- Rising queue depth / latency: + - Verify artifact store and API latency. + - Reduce cache TTL or size to reduce memory pressure; consider scaling workers. + +- Low cache hit rate: + - Check step dependencies and cache TTL; ensure downstream steps run in the same process to benefit from warm cache. + diff --git a/docs/book/serving/overview.md b/docs/book/serving/overview.md index e813ba2ed45..7c802f90238 100644 --- a/docs/book/serving/overview.md +++ b/docs/book/serving/overview.md @@ -16,11 +16,14 @@ title: Pipeline Serving Overview - Use your normal `@pipeline` and `@step` definitions. - No serving-specific changes required. -2) Choose a capture configuration (recommended) -- Low-latency, non-blocking tracking (serving-friendly): - - `@pipeline(capture={"mode": "REALTIME", "flush_on_step_end": false})` -- Pure in-memory execution (no runs, no artifacts): - - `@pipeline(capture={"mode": "REALTIME", "runs": "off"})` +2) Choose capture only when you need to change defaults +- You don’t need to set capture in most cases: + - Normal runs default to Batch. + - Serving defaults to Realtime (non-blocking). +- Optional tweaks (typed API only): + - Low-latency, non-blocking (explicit): `@pipeline(capture=Capture())` + - Blocking realtime (serving): `@pipeline(capture=Capture(flush_on_step_end=True))` + - Pure in-memory (serving only): `@pipeline(capture=Capture(memory_only=True))` 3) Deploy the serving service with your preferred deployer and call the FastAPI endpoint. @@ -36,61 +39,49 @@ title: Pipeline Serving Overview - Async server updates by default; in serving, defaults to non-blocking responses (tracking finishes in background). - Use when: You need low-latency serving with observability. -- OFF - - Behavior: Lightweight tracking. - - Persists artifacts but skips metadata/logs/visualizations/caching for reduced overhead. - - Use when: You need a smaller footprint while preserving artifacts for downstream consumers. - - Memory-only (special case inside REALTIME) - - Configure: `capture={"mode": "REALTIME", "runs": "off"}` or `capture={"mode": "REALTIME", "persistence": "memory"}` - Behavior: Pure in-memory execution: - No pipeline runs or step runs, no artifacts, no server calls. - Steps exchange data in-process; response returns immediately. - Use when: Maximum speed (prototyping, ultra-low-latency paths) without lineage. + - Note: Outside serving contexts, `memory_only=True` is ignored with a warning and standard execution proceeds. ## Where To Configure Capture -- In code (recommended) - - `@pipeline(capture="REALTIME")` - - `@pipeline(capture={"mode": "REALTIME", "flush_on_step_end": false})` +- In code (typed only) + - `@pipeline(capture=Capture())` + - `@pipeline(capture=Capture(flush_on_step_end=False))` - In run config YAML ```yaml -capture: REALTIME - -# or - -capture: - mode: REALTIME - flush_on_step_end: false +capture: REALTIME # or BATCH ``` - Environment (fallbacks) - - `ZENML_CAPTURE_MODE=BATCH|REALTIME|OFF|CUSTOM` - - Serving defaults leverage `ZENML_SERVING_CAPTURE_DEFAULT` when capture is not set (used internally to reduce tracking overhead). + - `ZENML_CAPTURE_MODE=BATCH|REALTIME` + - Serving sets `ZENML_SERVING_CAPTURE_DEFAULT` internally to switch default to Realtime when capture is not set. ## Best Practices - Most users (serving-ready) - - `capture={"mode": "REALTIME", "flush_on_step_end": false}` + - `@pipeline(capture=Capture())` - Good balance of immediate response and production tracking. - Maximum speed (no tracking at all) - - `capture={"mode": "REALTIME", "runs": "off"}` (pure in-memory) + - `@pipeline(capture=Capture(memory_only=True))` - Great for tests, benchmarks, or hot paths where lineage is not needed. - Compliance or rich lineage - - `capture="BATCH"` or fine-tune REALTIME with `flush_on_step_end: true`, `logs: "all"`, `metadata: true`. + - Use Batch (default in non-serving) or set: `@pipeline(capture=Capture(flush_on_step_end=True))`. ## FAQ (Essentials) - Does serving always create pipeline runs? - - BATCH/REALTIME/OFF: Yes (OFF reduces overhead of metadata/logs). - - Memory-only (REALTIME with `runs: off`): No; executes purely in memory. + - Batch/Realtime: Yes. + - Memory-only (Realtime with `memory_only=True`): No; executes purely in memory. - Will serving block responses to flush tracking? - - REALTIME in serving defaults to non-blocking (returns immediately), unless you explicitly set `flush_on_step_end: true`. + - Realtime in serving defaults to non-blocking (returns immediately), unless you set `flush_on_step_end=True`. - Is memory-only safe for production? - Yes for stateless, speed-critical paths. Note: No lineage or persisted artifacts. - diff --git a/examples/serving/weather_pipeline.py b/examples/serving/weather_pipeline.py index e44138efadc..f6427322b77 100644 --- a/examples/serving/weather_pipeline.py +++ b/examples/serving/weather_pipeline.py @@ -12,11 +12,13 @@ Perfect for real-time inference and AI applications. """ +import logging import os import random from typing import Dict from zenml import pipeline, step +from zenml.capture.config import Capture from zenml.config import DockerSettings # Import enums for type-safe capture mode configuration @@ -24,6 +26,9 @@ from zenml.config.resource_settings import ResourceSettings from zenml.steps.step_context import get_step_context +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + # Note: You can use either approach: # 1. String literals: "full", "metadata", "sampled", "errors_only", "none" # 2. Type-safe enums: CaptureMode.FULL, CaptureMode.METADATA, etc. @@ -70,7 +75,7 @@ def init_hook() -> PipelineState: return PipelineState() -@step +@step(enable_cache=False) def get_weather(city: str) -> Dict[str, float]: """Simulate getting weather data for a city. @@ -87,13 +92,14 @@ def get_weather(city: str) -> Dict[str, float]: } -@step +@step(enable_cache=False) def analyze_weather_with_llm(weather_data: Dict[str, float], city: str) -> str: """Use LLM to analyze weather and provide intelligent recommendations. In run-only mode, this step receives weather data via in-memory handoff and returns analysis with no database or filesystem writes. """ + import time temp = weather_data["temperature"] humidity = weather_data["humidity"] wind = weather_data["wind_speed"] @@ -103,11 +109,12 @@ def analyze_weather_with_llm(weather_data: Dict[str, float], city: str) -> str: client = None if pipeline_state: + logger.debug("Pipeline state is a PipelineState") assert isinstance(pipeline_state, PipelineState), ( "Pipeline state is not a PipelineState" ) client = pipeline_state.client - + logger.debug("Client is %s", client) if client: # Create a prompt for the LLM weather_prompt = f"""You are a weather expert AI assistant. Analyze the following weather data for {city} and provide detailed insights and recommendations. @@ -126,9 +133,10 @@ def analyze_weather_with_llm(weather_data: Dict[str, float], city: str) -> str: 5. Any weather warnings or tips Keep your response concise but informative.""" - + logger.info("[LLM] Starting OpenAI request for city=%s", city) + t0 = time.perf_counter() response = client.chat.completions.create( - model="gpt-3.5-turbo", + model="gpt-5-mini", messages=[ { "role": "system", @@ -136,9 +144,9 @@ def analyze_weather_with_llm(weather_data: Dict[str, float], city: str) -> str: }, {"role": "user", "content": weather_prompt}, ], - max_tokens=300, - temperature=0.7, ) + dt = time.perf_counter() - t0 + logger.info("[LLM] OpenAI request finished in %.3fs", dt) llm_analysis = response.choices[0].message.content @@ -213,8 +221,8 @@ def analyze_weather_with_llm(weather_data: Dict[str, float], city: str) -> str: @pipeline( + capture=Capture(memory_only=True), on_init=init_hook, - capture="realtime", settings={ "docker": docker_settings, "deployer.gcp": { @@ -266,6 +274,8 @@ def weather_agent_pipeline(city: str = "London") -> str: # Create deployment without running deployment = weather_agent_pipeline._create_deployment() + weather_agent_pipeline() + print("\n✅ Pipeline deployed for run-only serving!") print(f"📋 Deployment ID: {deployment.id}") print("\n🚀 Start serving with millisecond latency:") diff --git a/src/zenml/capture/config.py b/src/zenml/capture/config.py new file mode 100644 index 00000000000..243dc7a5e10 --- /dev/null +++ b/src/zenml/capture/config.py @@ -0,0 +1,212 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Capture configuration for ZenML.""" + +import os +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, Optional, Union + + +class ServingMode(str, Enum): + """Serving mode enum.""" + + BATCH = "BATCH" + REALTIME = "REALTIME" + + +# Backwards-compat alias used by runtime factory and others +CaptureMode = ServingMode + + +@dataclass(frozen=True) +class Capture: + """Unified capture configuration with simple, typed options. + + Modes are inferred by context: + - Orchestrated runs default to batch semantics. + - Serving defaults to realtime semantics. + + Options allow tuning behavior without exposing modes directly. + """ + + # If True, block at step end to publish updates (serving only). + flush_on_step_end: bool | None = None + # If True, pure in-memory execution (serving only). + memory_only: bool = False + # If False, skip doc/source capture in metadata. + code: bool = True + + def to_config_value(self) -> Dict[str, Any]: + """Convert the capture options to a config value. + + Returns: + The config value (no explicit mode; inferred by environment). + """ + cfg: Dict[str, Any] = {"code": self.code} + if self.flush_on_step_end is not None: + cfg["flush_on_step_end"] = bool(self.flush_on_step_end) + if self.memory_only: + cfg["memory_only"] = True + return cfg + + +@dataclass(frozen=True) +class BatchCapture: + """Batch (synchronous) capture configuration. + + Runs/steps and artifacts are always captured synchronously. Users should + adjust logging/metadata/visualization via pipeline settings, not capture. + """ + + mode: ServingMode = ServingMode.BATCH + + def to_config_value(self) -> Dict[str, Any]: + """Convert the batch capture to a config value.""" + return {"mode": self.mode.value} + + +@dataclass(frozen=True) +class RealtimeCapture: + """Realtime capture configuration for serving. + + - flush_on_step_end: if True, block at step end to publish updates. + - memory_only: if True, no server calls/runs/artifacts; in-process handoff. + """ + + mode: ServingMode = ServingMode.REALTIME + flush_on_step_end: bool = False + memory_only: bool = False + + def to_config_value(self) -> Dict[str, Any]: + """Convert the realtime capture to a config value. + + Returns: + The config value. + """ + config: Dict[str, Any] = {"mode": self.mode.value} + # Represent semantics using existing keys consumed by launcher/factory + config["flush_on_step_end"] = self.flush_on_step_end + if self.memory_only: + config["memory_only"] = True + return config + + def __post_init__(self) -> None: + """Post init.""" + # Contradictory: memory-only implies no server operations to flush + if self.memory_only and self.flush_on_step_end: + raise ValueError( + "Contradictory options: memory_only=True with flush_on_step_end=True. " + "Memory-only mode has no server operations to flush." + ) + + +# Unified capture config type alias +CaptureConfig = Union[Capture, BatchCapture, RealtimeCapture] + + +class CapturePolicy: + """Runtime-level capture policy used to select and configure runtimes. + + Provides a common interface for StepLauncher / factory code while the + code-level API remains typed (`Capture`, `BatchCapture`, `RealtimeCapture`). + """ + + def __init__( + self, + mode: ServingMode = ServingMode.BATCH, + options: Optional[Dict[str, Any]] = None, + ) -> None: + """Initialize the capture policy. + + Args: + mode: The mode to use. + options: The options to use. + """ + self.mode = mode + self.options = options or {} + + @staticmethod + def from_env() -> "CapturePolicy": + """Create a capture policy from environment defaults. + + Returns: + The capture policy. + """ + if os.getenv("ZENML_SERVING_CAPTURE_DEFAULT") is not None: + return CapturePolicy(mode=ServingMode.REALTIME, options={}) + val = os.getenv("ZENML_CAPTURE_MODE", "BATCH").upper() + try: + mode = ServingMode(val) + except ValueError: + mode = ServingMode.BATCH + return CapturePolicy(mode=mode, options={}) + + @staticmethod + def from_value( + value: Optional[Union[str, Capture, BatchCapture, RealtimeCapture]], + ) -> "CapturePolicy": + """Normalize typed or string capture value into a runtime policy. + + Args: + value: The value to normalize. + + Returns: + The capture policy. + """ + if value is None: + return CapturePolicy.from_env() + + if isinstance(value, RealtimeCapture): + return CapturePolicy( + mode=ServingMode.REALTIME, + options={ + "flush_on_step_end": value.flush_on_step_end, + "memory_only": bool(value.memory_only), + }, + ) + if isinstance(value, BatchCapture): + return CapturePolicy(mode=ServingMode.BATCH, options={}) + if isinstance(value, Capture): + pol = CapturePolicy.from_env() + opts: Dict[str, Any] = {} + if value.flush_on_step_end is not None: + opts["flush_on_step_end"] = bool(value.flush_on_step_end) + if value.memory_only: + opts["memory_only"] = True + if value.code is not None: + opts["code"] = bool(value.code) + pol.options.update(opts) + return pol + # String fallback (YAML / ENV) + try: + return CapturePolicy(mode=ServingMode(str(value).upper())) + except Exception: + return CapturePolicy.from_env() + + def get_option(self, key: str, default: Any = None) -> Any: + """Get an option from the capture policy. + + Args: + key: The key to get. + default: The default value to return if the key is not found. + + Returns: + The option value. + """ + return self.options.get(key, default) + + +# capture_to_config_value has been removed from code paths. Downstream consumers +# should use typed configs or CapturePolicy.from_value for YAML/ENV strings. diff --git a/src/zenml/config/compiler.py b/src/zenml/config/compiler.py index ee959b24c3d..ca31bbf52ea 100644 --- a/src/zenml/config/compiler.py +++ b/src/zenml/config/compiler.py @@ -23,9 +23,11 @@ Mapping, Optional, Tuple, + Union, ) from zenml import __version__ +from zenml.capture.config import BatchCapture, RealtimeCapture from zenml.config.base_settings import BaseSettings, ConfigurationLevel from zenml.config.pipeline_configurations import PipelineConfiguration from zenml.config.pipeline_run_configuration import PipelineRunConfiguration @@ -196,13 +198,42 @@ def _apply_run_configuration( config: The run configurations. """ with pipeline.__suppress_configure_warnings__(): + # Normalize run-level capture (str/dict) to typed for configure + cap_typed: Optional[Union[BatchCapture, RealtimeCapture]] = None + if isinstance(config.capture, str): + if config.capture.upper() == "REALTIME": + from zenml.capture.config import RealtimeCapture + + cap_typed = RealtimeCapture() + elif config.capture.upper() == "BATCH": + from zenml.capture.config import BatchCapture + + cap_typed = BatchCapture() + elif isinstance(config.capture, dict): + mode = str(config.capture.get("mode", "BATCH")).upper() + if mode == "REALTIME": + from zenml.capture.config import RealtimeCapture + + cap_typed = RealtimeCapture( + flush_on_step_end=bool( + config.capture.get("flush_on_step_end", False) + ), + memory_only=bool( + config.capture.get("memory_only", False) + ), + ) + else: + from zenml.capture.config import BatchCapture + + cap_typed = BatchCapture() + pipeline.configure( enable_cache=config.enable_cache, enable_artifact_metadata=config.enable_artifact_metadata, enable_artifact_visualization=config.enable_artifact_visualization, enable_step_logs=config.enable_step_logs, enable_pipeline_logs=config.enable_pipeline_logs, - capture=config.capture, + capture=cap_typed, settings=config.settings, tags=config.tags, extra=config.extra, @@ -211,30 +242,6 @@ def _apply_run_configuration( parameters=config.parameters, ) - # Apply additional defaults based on capture mode - try: - capture_cfg = pipeline.configuration.capture - mode_str = None - if isinstance(capture_cfg, str): - mode_str = capture_cfg.upper() - elif isinstance(capture_cfg, dict): - mode = capture_cfg.get("mode") - if isinstance(mode, str): - mode_str = mode.upper() - if mode_str == "OFF": - # Disable overhead while keeping correctness - with pipeline.__suppress_configure_warnings__(): - pipeline.configure( - enable_cache=False, - enable_artifact_metadata=False, - enable_artifact_visualization=False, - enable_step_logs=False, - enable_pipeline_logs=False, - ) - except Exception: - # Non-fatal; leave configuration as-is - pass - invalid_step_configs = set(config.steps) - set(pipeline.invocations) if invalid_step_configs: logger.warning( diff --git a/src/zenml/config/pipeline_configurations.py b/src/zenml/config/pipeline_configurations.py index a4d1689f4c5..7a8b0726e53 100644 --- a/src/zenml/config/pipeline_configurations.py +++ b/src/zenml/config/pipeline_configurations.py @@ -18,6 +18,12 @@ from pydantic import SerializeAsAny, field_validator +from zenml.capture.config import ( + BatchCapture, + Capture, + CaptureConfig, + RealtimeCapture, +) from zenml.config.constants import DOCKER_SETTINGS_KEY, RESOURCE_SETTINGS_KEY from zenml.config.retry_config import StepRetryConfig from zenml.config.source import SourceWithValidator @@ -42,9 +48,8 @@ class PipelineConfigurationUpdate(StrictBaseModel): enable_artifact_visualization: Optional[bool] = None enable_step_logs: Optional[bool] = None enable_pipeline_logs: Optional[bool] = None - # Capture policy mode for execution semantics (e.g., BATCH, REALTIME, OFF, CUSTOM) - # Capture policy can be a mode string or a dict with options - capture: Optional[Union[str, Dict[str, Any]]] = None + # Capture policy for execution semantics (typed only) + capture: Optional[CaptureConfig] = None settings: Dict[str, SerializeAsAny[BaseSettings]] = {} tags: Optional[List[Union[str, "Tag"]]] = None extra: Dict[str, Any] = {} @@ -91,40 +96,16 @@ class PipelineConfiguration(PipelineConfigurationUpdate): @field_validator("capture") @classmethod def validate_capture_mode( - cls, value: Optional[Union[str, Dict[str, Any]]] - ) -> Optional[Union[str, Dict[str, Any]]]: - """Validates the capture mode. - - Args: - value: The capture mode to validate. - - Returns: - The validated capture mode. - """ + cls, value: Optional[CaptureConfig] + ) -> Optional[CaptureConfig]: + """Validates the capture config (typed only).""" if value is None: return value - if isinstance(value, dict): - mode = value.get("mode") - if mode is None: - # default to BATCH if mode not provided - value = {**value, "mode": "BATCH"} - mode = "BATCH" - allowed = {"BATCH", "REALTIME", "OFF", "CUSTOM"} - if str(mode).upper() not in allowed: - raise ValueError( - f"Invalid capture mode '{mode}'. Allowed: {sorted(allowed)}" - ) - # normalize mode to upper - value = {**value, "mode": str(mode).upper()} + if isinstance(value, (Capture, BatchCapture, RealtimeCapture)): return value - else: - allowed = {"BATCH", "REALTIME", "OFF", "CUSTOM"} - v = str(value).upper() - if v not in allowed: - raise ValueError( - f"Invalid capture mode '{value}'. Allowed: {sorted(allowed)}" - ) - return v + raise ValueError( + "'capture' must be a typed Capture, BatchCapture, or RealtimeCapture." + ) @field_validator("name") @classmethod diff --git a/src/zenml/deployers/serving/service.py b/src/zenml/deployers/serving/service.py index bcb0face929..f3384d4797e 100644 --- a/src/zenml/deployers/serving/service.py +++ b/src/zenml/deployers/serving/service.py @@ -19,6 +19,8 @@ """ import asyncio +import json +import os import time from datetime import datetime, timedelta, timezone from typing import Any, Dict, List, Optional @@ -327,6 +329,14 @@ async def execute_pipeline( try: # Resolve request parameters resolved_params = self._resolve_parameters(parameters) + # Expose resolved params to launcher/runner via env for memory-only path + os.environ["ZENML_SERVING_REQUEST_PARAMS"] = json.dumps( + resolved_params + ) + # Expose pipeline state via serving context var + from zenml.orchestrators import utils as _orch_utils + + _orch_utils.set_pipeline_state(self.pipeline_state) # Get deployment and check if we're in no-capture mode deployment = self.deployment @@ -334,9 +344,6 @@ async def execute_pipeline( deployment.pipeline_configuration.settings ) - # Set serving capture default for this request (no model mutations needed) - import os - original_capture_default = os.environ.get( "ZENML_SERVING_CAPTURE_DEFAULT" ) @@ -374,8 +381,9 @@ async def execute_pipeline( orchestrator = stack.orchestrator # Ensure a stable run id for StepLauncher to reuse the same PipelineRun + run_uuid = str(uuid4()) if hasattr(orchestrator, "_orchestrator_run_id"): - setattr(orchestrator, "_orchestrator_run_id", str(uuid4())) + setattr(orchestrator, "_orchestrator_run_id", run_uuid) # No serving overrides population in local orchestrator path @@ -394,6 +402,21 @@ async def execute_pipeline( os.environ["ZENML_SERVING_CAPTURE_DEFAULT"] = ( original_capture_default ) + # Clear request params env and shared runtime state + os.environ.pop("ZENML_SERVING_REQUEST_PARAMS", None) + from zenml.orchestrators.utils import set_pipeline_state + + set_pipeline_state(None) + try: + from zenml.orchestrators.runtime_manager import ( + clear_shared_runtime, + reset_memory_runtime_for_run, + ) + + reset_memory_runtime_for_run(run_uuid) + clear_shared_runtime() + except Exception: + pass # Get captured outputs from response tap outputs = orchestrator_utils.response_tap_get_all() diff --git a/src/zenml/execution/capture_policy.py b/src/zenml/execution/capture_policy.py deleted file mode 100644 index 32786dbe5ff..00000000000 --- a/src/zenml/execution/capture_policy.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright (c) ZenML GmbH 2025. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. -"""Capture policy scaffolding and presets. - -This is a lightweight placeholder to enable runtime selection without changing -public pipeline APIs. The capture mode can be controlled via the environment -variable `ZENML_CAPTURE_MODE` with values: `BATCH` (default), `REALTIME`, -`OFF`, or `CUSTOM`. -""" - -import os -from enum import Enum -from typing import Any, Dict, Optional, Union - - -class CaptureMode(str, Enum): - """Capture mode enum.""" - - BATCH = "BATCH" - REALTIME = "REALTIME" - OFF = "OFF" - CUSTOM = "CUSTOM" - - -class CapturePolicy: - """Minimal capture policy container with optional options.""" - - def __init__( - self, - mode: CaptureMode = CaptureMode.BATCH, - options: Optional[Dict[str, Any]] = None, - ) -> None: - """Initialize the capture policy. - - Args: - mode: The capture mode. - options: The capture options. - """ - self.mode = mode - self.options = options or {} - - @staticmethod - def from_env() -> "CapturePolicy": - """Create a capture policy from the environment. - - Returns: - The capture policy. - """ - val = os.getenv("ZENML_CAPTURE_MODE", "BATCH").upper() - try: - mode = CaptureMode(val) - except ValueError: - mode = CaptureMode.BATCH - # No options provided from env here; runtimes may read env as fallback - return CapturePolicy(mode=mode, options={}) - - @staticmethod - def from_value( - value: Optional[Union[str, Dict[str, Any]]], - ) -> "CapturePolicy": - """Create a capture policy from a value. - - Args: - value: The value to create the capture policy from. - - Returns: - The capture policy. - """ - if value is None: - return CapturePolicy.from_env() - if isinstance(value, dict): - mode = str(value.get("mode", "BATCH")).upper() - try: - cm = CaptureMode(mode) - except Exception: - cm = CaptureMode.BATCH - # store other keys as options - options = {k: v for k, v in value.items() if k != "mode"} - return CapturePolicy(mode=cm, options=options) - else: - try: - return CapturePolicy(mode=CaptureMode(str(value).upper())) - except Exception: - return CapturePolicy.from_env() - - def get_option(self, key: str, default: Any = None) -> Any: - """Get an option from the capture policy. - - Args: - key: The key of the option to get. - default: The default value to return if the option is not found. - - Returns: - The option value. - """ - return self.options.get(key, default) diff --git a/src/zenml/execution/factory.py b/src/zenml/execution/factory.py index 067b24585b3..7d5a5f00275 100644 --- a/src/zenml/execution/factory.py +++ b/src/zenml/execution/factory.py @@ -15,12 +15,11 @@ from typing import Callable, Dict, Optional -from zenml.execution.capture_policy import CaptureMode, CapturePolicy +from zenml.capture.config import CaptureMode, CapturePolicy from zenml.execution.step_runtime import ( BaseStepRuntime, DefaultStepRuntime, MemoryStepRuntime, - OffStepRuntime, ) # Registry of runtime builders keyed by capture mode @@ -49,9 +48,13 @@ def get_runtime(policy: Optional[CapturePolicy]) -> BaseStepRuntime: """ policy = policy or CapturePolicy() builder = _RUNTIME_REGISTRY.get(policy.mode) - if builder is not None: - return builder(policy) - return DefaultStepRuntime() + if builder is None: + raise ValueError( + f"No runtime registered for capture mode: {policy.mode}. " + "Expected one of: " + + ", ".join(m.name for m in _RUNTIME_REGISTRY.keys()) + ) + return builder(policy) # Register default builders @@ -67,11 +70,6 @@ def _build_default(_: CapturePolicy) -> BaseStepRuntime: return DefaultStepRuntime() -def _build_off(_: CapturePolicy) -> BaseStepRuntime: - """Build the off runtime (lightweight: persist artifacts, skip metadata).""" - return OffStepRuntime() - - def _build_realtime(policy: CapturePolicy) -> BaseStepRuntime: """Build the realtime runtime. @@ -84,10 +82,15 @@ def _build_realtime(policy: CapturePolicy) -> BaseStepRuntime: # Import here to avoid circular imports from zenml.execution.realtime_runtime import RealtimeStepRuntime - # If runs are off or persistence is memory/off, use memory runtime + # If memory_only flagged, or legacy runs/persistence indicate memory-only, use memory runtime + memory_only = bool(policy.get_option("memory_only", False)) runs_opt = str(policy.get_option("runs", "on")).lower() persistence = str(policy.get_option("persistence", "async")).lower() - if runs_opt in {"off", "false", "0"} or persistence in {"memory", "off"}: + if ( + memory_only + or runs_opt in {"off", "false", "0"} + or persistence in {"memory", "off"} + ): return MemoryStepRuntime() ttl = policy.get_option("ttl_seconds") @@ -96,6 +99,4 @@ def _build_realtime(policy: CapturePolicy) -> BaseStepRuntime: register_runtime(CaptureMode.BATCH, _build_default) -register_runtime(CaptureMode.CUSTOM, _build_default) -register_runtime(CaptureMode.OFF, _build_off) register_runtime(CaptureMode.REALTIME, _build_realtime) diff --git a/src/zenml/execution/realtime_runtime.py b/src/zenml/execution/realtime_runtime.py index 61da826e0ef..83c9ae69580 100644 --- a/src/zenml/execution/realtime_runtime.py +++ b/src/zenml/execution/realtime_runtime.py @@ -55,19 +55,25 @@ def __init__( """ super().__init__() # Simple LRU cache with TTL - self._cache: "OrderedDict[str, Tuple[Any, float]]" = OrderedDict() + self._cache: OrderedDict[str, Tuple[Any, float]] = OrderedDict() self._lock = threading.RLock() # Event queue: (kind, args, kwargs) Event = Tuple[str, Tuple[Any, ...], Dict[str, Any]] - self._q: "queue.Queue[Event]" = queue.Queue() + self._q: queue.Queue[Event] = queue.Queue() self._worker: Optional[threading.Thread] = None self._stop = threading.Event() self._errors_since_last_flush: int = 0 self._total_errors: int = 0 self._last_error: Optional[BaseException] = None + self._error_reported: bool = False + self._last_report_ts: float = 0.0 self._logger = get_logger(__name__) self._queued_count: int = 0 self._processed_count: int = 0 + # Metrics: cache and op latencies + self._cache_hits: int = 0 + self._cache_misses: int = 0 + self._op_latencies: List[float] = [] # Tunables via env: TTL seconds and max entries # Options precedence: explicit args > env > defaults if ttl_seconds is not None: @@ -75,19 +81,44 @@ def __init__( else: try: self._ttl_seconds = int( - os.getenv("ZENML_RT_CACHE_TTL_SECONDS", "300") + os.getenv("ZENML_RT_CACHE_TTL_SECONDS", "60") ) except Exception: - self._ttl_seconds = 300 + self._ttl_seconds = 60 if max_entries is not None: self._max_entries = int(max_entries) else: try: self._max_entries = int( - os.getenv("ZENML_RT_CACHE_MAX_ENTRIES", "1024") + os.getenv("ZENML_RT_CACHE_MAX_ENTRIES", "256") ) except Exception: - self._max_entries = 1024 + self._max_entries = 256 + # Circuit breaker controls + try: + self._cb_threshold = float( + os.getenv("ZENML_RT_CB_ERR_THRESHOLD", "0.1") + ) + self._cb_min_events = int( + os.getenv("ZENML_RT_CB_MIN_EVENTS", "100") + ) + self._cb_open_seconds = float( + os.getenv("ZENML_RT_CB_OPEN_SECONDS", "300") + ) + except Exception: + self._cb_threshold = 0.1 + self._cb_min_events = 100 + self._cb_open_seconds = 300.0 + self._cb_errors_window: int = 0 + self._cb_total_window: int = 0 + self._cb_open_until_ts: float = 0.0 + # Error report interval (seconds) + try: + self._err_report_interval = float( + os.getenv("ZENML_RT_ERR_REPORT_INTERVAL", "15") + ) + except Exception: + self._err_report_interval = 15.0 # Flush behavior (can be disabled for serving non-blocking) self._flush_on_step_end: bool = True @@ -98,14 +129,17 @@ def start(self) -> None: return def _run() -> None: + idle_sleep = 0.05 while not self._stop.is_set(): try: - kind, args, kwargs = self._q.get(timeout=0.1) + kind, args, kwargs = self._q.get(timeout=idle_sleep) except queue.Empty: # Opportunistic cache sweep: evict expired from head self._sweep_expired() + idle_sleep = min(idle_sleep * 2.0, 2.0) continue try: + start = time.time() if kind == "pipeline_metadata": publish_utils.publish_pipeline_run_metadata( *args, **kwargs @@ -131,7 +165,21 @@ def _run() -> None: finally: with self._lock: self._processed_count += 1 + # Update circuit breaker window + self._cb_total_window += 1 + if self._last_error is not None: + self._cb_errors_window += 1 + # Record latency (bounded sample) + try: + self._op_latencies.append( + max(0.0, time.time() - start) + ) + if len(self._op_latencies) > 512: + self._op_latencies = self._op_latencies[-256:] + except Exception: + pass self._q.task_done() + idle_sleep = 0.01 self._worker = threading.Thread( target=_run, name="zenml-realtime-runtime", daemon=True @@ -169,6 +217,7 @@ def load_input_artifact( if now <= expires_at: # Touch entry for LRU self._cache.move_to_end(key) + self._cache_hits += 1 return value else: # Expired @@ -176,6 +225,7 @@ def load_input_artifact( del self._cache[key] except KeyError: pass + self._cache_misses += 1 # Fallback to default loading return super().load_input_artifact( @@ -246,7 +296,13 @@ def publish_pipeline_run_metadata( pipeline_run_id: The pipeline run ID. pipeline_run_metadata: The pipeline run metadata. """ - # Enqueue for async processing + # Inline if circuit open, else enqueue + if self._should_process_inline(): + publish_utils.publish_pipeline_run_metadata( + pipeline_run_id=pipeline_run_id, + pipeline_run_metadata=pipeline_run_metadata, + ) + return self._q.put( ( "pipeline_metadata", @@ -272,6 +328,11 @@ def publish_step_run_metadata( step_run_id: The step run ID. step_run_metadata: The step run metadata. """ + if self._should_process_inline(): + publish_utils.publish_step_run_metadata( + step_run_id=step_run_id, step_run_metadata=step_run_metadata + ) + return self._q.put( ( "step_metadata", @@ -297,6 +358,12 @@ def publish_successful_step_run( step_run_id: The step run ID. output_artifact_ids: The output artifact IDs. """ + if self._should_process_inline(): + publish_utils.publish_successful_step_run( + step_run_id=step_run_id, + output_artifact_ids=output_artifact_ids, + ) + return self._q.put( ( "step_success", @@ -320,6 +387,9 @@ def publish_failed_step_run( Args: step_run_id: The step run ID. """ + if self._should_process_inline(): + publish_utils.publish_failed_step_run(step_run_id) + return self._q.put(("step_failed", (), {"step_run_id": step_run_id})) with self._lock: self._queued_count += 1 @@ -368,6 +438,7 @@ def flush(self) -> None: count = self._errors_since_last_flush last = self._last_error self._errors_since_last_flush = 0 + self._error_reported = True raise RuntimeError( f"Realtime runtime encountered {count} error(s) while publishing. Last error: {last}" ) @@ -421,10 +492,26 @@ def get_metrics(self) -> Dict[str, Any]: failed_total = self._total_errors ttl_seconds = getattr(self, "_ttl_seconds", None) max_entries = getattr(self, "_max_entries", None) + cache_hits = self._cache_hits + cache_misses = self._cache_misses + latencies = list(self._op_latencies) try: depth = self._q.qsize() except Exception: depth = 0 + # Compute simple percentiles + p50 = p95 = p99 = 0.0 + if latencies: + s = sorted(latencies) + n = len(s) + p50 = s[int(0.5 * (n - 1))] + p95 = s[int(0.95 * (n - 1))] + p99 = s[int(0.99 * (n - 1))] + hit_rate = ( + float(cache_hits) / float(cache_hits + cache_misses) + if (cache_hits + cache_misses) > 0 + else 0.0 + ) return { "queued": queued, "processed": processed, @@ -432,25 +519,67 @@ def get_metrics(self) -> Dict[str, Any]: "queue_depth": depth, "ttl_seconds": ttl_seconds, "max_entries": max_entries, + "cache_hits": cache_hits, + "cache_misses": cache_misses, + "cache_hit_rate": hit_rate, + "op_latency_p50_s": p50, + "op_latency_p95_s": p95, + "op_latency_p99_s": p99, } + # Surface background errors even when not flushing + def check_async_errors(self) -> None: + """Log and mark any background errors on an interval.""" + with self._lock: + if self._last_error: + now = time.time() + if (not self._error_reported) or ( + now - self._last_report_ts > self._err_report_interval + ): + self._logger.error( + "Background realtime runtime error: %s", + self._last_error, + ) + self._error_reported = True + self._last_report_ts = now + # --- internal helpers --- def _sweep_expired(self) -> None: - """Remove expired entries from the head (LRU) side.""" + """Remove expired entries from the head (LRU) side with a time budget.""" + deadline = time.time() + 0.003 with self._lock: - now = time.time() - # Pop from head while expired - keys = list(self._cache.keys()) - for k in keys[:32]: # limit per sweep to bound work + while time.time() < deadline: + try: + key = next(iter(self._cache)) + except StopIteration: + break try: - value, expires_at = self._cache[k] + _, expires_at = self._cache[key] except KeyError: continue - if now > expires_at: - try: - del self._cache[k] - except KeyError: - pass - else: - # Stop at first non-expired near head + if time.time() <= expires_at: break + try: + del self._cache[key] + except KeyError: + pass + + def _should_process_inline(self) -> bool: + """Return True if circuit breaker is open and we should publish inline.""" + with self._lock: + now = time.time() + if now < self._cb_open_until_ts: + return True + total = self._cb_total_window + errors = self._cb_errors_window + if total >= self._cb_min_events: + err_rate = (float(errors) / float(total)) if total > 0 else 0.0 + if err_rate >= self._cb_threshold: + self._cb_open_until_ts = now + self._cb_open_seconds + self._logger.warning( + "Realtime runtime circuit opened for %.0fs due to error rate %.2f", + self._cb_open_seconds, + err_rate, + ) + return True + return False diff --git a/src/zenml/execution/step_runtime.py b/src/zenml/execution/step_runtime.py index 3df62f6b3bb..ca8bb18202e 100644 --- a/src/zenml/execution/step_runtime.py +++ b/src/zenml/execution/step_runtime.py @@ -494,8 +494,10 @@ def publish_step_run_metadata( class MemoryStepRuntime(BaseStepRuntime): """Pure in-memory execution runtime: no server calls, no persistence.""" + # Global registry keyed by run_id to isolate concurrent runs _STORE: Dict[str, Dict[Tuple[str, str], Any]] = {} - _LOCK: Any = threading.RLock() # initialized at class load + _RUN_LOCKS: Dict[str, Any] = {} + _GLOBAL_LOCK: Any = threading.RLock() # protects registry structures @staticmethod def make_handle_id(run_id: str, step_name: str, output_name: str) -> str: @@ -554,6 +556,7 @@ def __init__(self) -> None: super().__init__() self._ctx_run_id: Optional[str] = None self._ctx_substitutions: Dict[str, str] = {} + self._active_run_ids: set[str] = set() def set_context( self, *, run_id: str, substitutions: Optional[Dict[str, str]] = None @@ -566,9 +569,18 @@ def set_context( """ self._ctx_run_id = run_id self._ctx_substitutions = substitutions or {} + try: + if run_id: + self._active_run_ids.add(run_id) + except Exception: + pass def resolve_step_inputs( - self, *, step, pipeline_run, step_runs=None + self, + *, + step: "Step", + pipeline_run: "PipelineRunResponse", + step_runs: Optional[Dict[str, "StepRunResponse"]] = None, ) -> Dict[str, Any]: """Resolve step inputs by constructing in-memory handles. @@ -614,7 +626,12 @@ def load_input_artifact( if not isinstance(handle_id_any, str): raise ValueError("Invalid memory handle id") run_id, step_name, output_name = self.parse_handle_id(handle_id_any) - with MemoryStepRuntime._LOCK: + # Use per-run lock to avoid cross-run interference + with MemoryStepRuntime._GLOBAL_LOCK: + rlock = MemoryStepRuntime._RUN_LOCKS.setdefault( + run_id, threading.RLock() + ) + with rlock: return MemoryStepRuntime._STORE.get(run_id, {}).get( (step_name, output_name) ) @@ -646,9 +663,18 @@ def store_output_artifacts( ctx = get_step_context() run_id = str(getattr(ctx.pipeline_run, "id", "local")) + try: + if run_id: + self._active_run_ids.add(run_id) + except Exception: + pass step_name = str(getattr(ctx.step_run, "name", "step")) handles: Dict[str, Any] = {} - with MemoryStepRuntime._LOCK: + with MemoryStepRuntime._GLOBAL_LOCK: + rlock = MemoryStepRuntime._RUN_LOCKS.setdefault( + run_id, threading.RLock() + ) + with rlock: rr = MemoryStepRuntime._STORE.setdefault(run_id, {}) for output_name, value in output_data.items(): rr[(step_name, output_name)] = value @@ -748,3 +774,27 @@ def on_step_end(self) -> None: def shutdown(self) -> None: """Shutdown the memory runtime.""" return + + def __del__(self) -> None: # noqa: D401 + """Best-effort cleanup of per-run memory when GC collects the runtime.""" + try: + for run_id in list(self._active_run_ids): + try: + self.reset(run_id) + except Exception: + pass + except Exception: + pass + + # --- Unified path helpers --- + def reset(self, run_id: str) -> None: + """Clear all in-memory data associated with a specific run. + + Args: + run_id: The run id to clear. + """ + with MemoryStepRuntime._GLOBAL_LOCK: + try: + MemoryStepRuntime._STORE.pop(run_id, None) + finally: + MemoryStepRuntime._RUN_LOCKS.pop(run_id, None) diff --git a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py index 78e103c0977..3305e21dd70 100644 --- a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +++ b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py @@ -310,7 +310,7 @@ def main() -> None: # Build a runtime for request factory using capture mode from config try: - from zenml.execution.capture_policy import CapturePolicy + from zenml.capture.config import CapturePolicy from zenml.execution.factory import get_runtime mode_cfg = getattr( diff --git a/src/zenml/orchestrators/run_entity_manager.py b/src/zenml/orchestrators/run_entity_manager.py new file mode 100644 index 00000000000..3dc9ce7bf25 --- /dev/null +++ b/src/zenml/orchestrators/run_entity_manager.py @@ -0,0 +1,173 @@ +"""Run entity manager scaffolding for unified execution. + +Abstracts creation and finalization of pipeline/step runs so we can plug in +either DB-backed behavior or stubbed in-memory entities for memory-only runs. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, Optional, Protocol, Tuple, cast + + +class RunEntityManager(Protocol): + """Protocol for managing pipeline/step run entities.""" + + def create_or_reuse_run(self) -> Tuple[Any, bool]: + """Create or reuse a pipeline run entity. + + Returns: + A tuple of (pipeline_run, was_created). + """ + + def create_step_run(self, request: Any) -> Any: + """Create a step run entity. + + Args: + request: StepRunRequest-like object. + + Returns: + A step run entity. + """ + + def finalize_step_run_success( + self, step_run_id: Any, outputs: Any + ) -> None: + """Mark a step run successful.""" + + def finalize_step_run_failed(self, step_run_id: Any) -> None: + """Mark a step run failed.""" + + +@dataclass +class DefaultRunEntityManager: + """Placeholder for DB-backed manager (to be wired in Phase 2).""" + + launcher: Any + + def create_or_reuse_run(self) -> Tuple[Any, bool]: + """Create or reuse a pipeline run entity. + + Returns: + A tuple of (pipeline_run, was_created). + """ + return cast(Tuple[Any, bool], self.launcher._create_or_reuse_run()) + + def create_step_run(self, request: Any) -> Any: + """Create a step run entity. + + Args: + request: StepRunRequest-like object. + + Returns: + A step run entity. + """ + from zenml.client import Client + + return Client().zen_store.create_run_step(request) + + def finalize_step_run_success( + self, step_run_id: Any, outputs: Any + ) -> None: + """Mark a step run successful. + + Args: + step_run_id: The step run ID. + outputs: The outputs of the step run. + """ + # Defer to runtime publish for now. + return None + + def finalize_step_run_failed(self, step_run_id: Any) -> None: + """Mark a step run failed. + + Args: + step_run_id: The step run ID. + """ + # Defer to runtime publish for now. + return None + + +@dataclass +class MemoryRunEntityManager: + """Stubbed manager for memory-only execution (Phase 2 wiring).""" + + launcher: Any + + def create_or_reuse_run(self) -> Tuple[Any, bool]: + """Create or reuse a pipeline run entity. + + Returns: + A tuple of (pipeline_run, was_created). + """ + # Build a minimal pipeline run stub compatible with StepRunner expectations + run_id = self.launcher._orchestrator_run_id # noqa: SLF001 + + @dataclass + class _PRCfg: + tags: Any = None + enable_step_logs: Any = False + enable_artifact_metadata: Any = False + enable_artifact_visualization: Any = False + + @dataclass + class _PipelineRunStub: + id: str + name: str + model_version: Any = None + pipeline: Any = None + config: Any = _PRCfg() + + return _PipelineRunStub(id=run_id, name=run_id), True + + def create_step_run(self, request: Any) -> Any: + """Create a step run entity. + + Args: + request: StepRunRequest-like object. + + Returns: + A step run entity. + """ + # Return a minimal step run stub + run_id = self.launcher._orchestrator_run_id # noqa: SLF001 + step_name = self.launcher._step_name # noqa: SLF001 + + @dataclass + class _StatusStub: + is_finished: bool = False + + @dataclass + class _StepRunStub: + id: str + name: str + model_version: Any = None + logs: Optional[Any] = None + status: Any = _StatusStub() + outputs: Dict[str, Any] = None # type: ignore[assignment] + regular_inputs: Dict[str, Any] = None # type: ignore[assignment] + + def __post_init__(self) -> None: # noqa: D401 + self.outputs = {} + self.regular_inputs = {} + + return _StepRunStub(id=run_id, name=step_name) + + def finalize_step_run_success( + self, step_run_id: Any, outputs: Any + ) -> None: + """Mark a step run successful. + + Args: + step_run_id: The step run ID. + outputs: The outputs of the step run. + """ + return None + + def finalize_step_run_failed(self, step_run_id: Any) -> None: + """Mark a step run failed. + + Args: + step_run_id: The step run ID. + """ + return None diff --git a/src/zenml/orchestrators/runtime_manager.py b/src/zenml/orchestrators/runtime_manager.py new file mode 100644 index 00000000000..4c2fa7602a1 --- /dev/null +++ b/src/zenml/orchestrators/runtime_manager.py @@ -0,0 +1,83 @@ +"""Runtime manager for unified runtime-driven execution paths. + +Provides helpers to reuse a shared runtime instance across all steps of a +single serving request (e.g., MemoryStepRuntime for memory-only execution), +and utilities to reset per-run state when the request completes. +""" + +from __future__ import annotations + +from contextvars import ContextVar +from typing import Optional + +from zenml.execution.step_runtime import BaseStepRuntime, MemoryStepRuntime + +# Shared runtime context for the lifetime of a single request. +_shared_runtime: ContextVar[Optional[BaseStepRuntime]] = ContextVar( + "zenml_shared_runtime", default=None +) + + +def set_shared_runtime(runtime: BaseStepRuntime) -> None: + """Set a runtime instance to be reused across steps for the current request. + + Args: + runtime: The runtime instance to set. + """ + _shared_runtime.set(runtime) + + +def get_shared_runtime() -> Optional[BaseStepRuntime]: + """Get the shared runtime instance for the current request, if any. + + Returns: + The shared runtime instance for the current request, if any. + """ + return _shared_runtime.get() + + +def clear_shared_runtime() -> None: + """Clear the shared runtime instance for the current request. + + Returns: + The shared runtime instance for the current request, if any. + """ + _shared_runtime.set(None) + + +def get_or_create_shared_memory_runtime() -> MemoryStepRuntime: + """Get or create a shared MemoryStepRuntime for the current request. + + Returns: + The shared runtime instance for the current request, if any. + """ + rt = _shared_runtime.get() + if isinstance(rt, MemoryStepRuntime): + return rt + mem = MemoryStepRuntime() + set_shared_runtime(mem) + return mem + + +def reset_memory_runtime_for_run(run_id: str) -> None: + """Reset per-run memory state on the shared memory runtime if present. + + Args: + run_id: The run ID. + """ + rt = _shared_runtime.get() + if isinstance(rt, MemoryStepRuntime): + try: + rt.reset(run_id) + except Exception as e: + # Best-effort cleanup; log at debug level and continue + try: + from zenml.logger import get_logger + + get_logger(__name__).debug( + "Ignoring error during memory runtime reset for run %s: %s", + run_id, + e, + ) + except Exception: + pass diff --git a/src/zenml/orchestrators/step_launcher.py b/src/zenml/orchestrators/step_launcher.py index 43ad1d2e18f..3f7d531f91f 100644 --- a/src/zenml/orchestrators/step_launcher.py +++ b/src/zenml/orchestrators/step_launcher.py @@ -13,11 +13,14 @@ # permissions and limitations under the License. """Class to launch (run directly or using a step operator) steps.""" +import json +import os import signal import time from contextlib import nullcontext from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple +from zenml.capture.config import CapturePolicy from zenml.client import Client from zenml.config.step_configurations import Step from zenml.config.step_run_info import StepRunInfo @@ -29,8 +32,8 @@ from zenml.enums import ExecutionStatus from zenml.environment import get_run_environment_dict from zenml.exceptions import RunInterruptedException, RunStoppedException -from zenml.execution.capture_policy import CapturePolicy from zenml.execution.factory import get_runtime +from zenml.execution.step_runtime import MemoryStepRuntime from zenml.logger import get_logger from zenml.logging import step_logging from zenml.models import ( @@ -43,6 +46,14 @@ from zenml.models.v2.core.step_run import StepRunInputResponse from zenml.orchestrators import output_utils, publish_utils, step_run_utils from zenml.orchestrators import utils as orchestrator_utils +from zenml.orchestrators.run_entity_manager import ( + DefaultRunEntityManager, + MemoryRunEntityManager, + RunEntityManager, +) +from zenml.orchestrators.runtime_manager import ( + get_or_create_shared_memory_runtime, +) from zenml.orchestrators.step_runner import StepRunner from zenml.stack import Stack from zenml.utils import exception_utils, string_utils @@ -217,7 +228,7 @@ def signal_handler(signum: int, frame: Any) -> None: signal.signal(signal.SIGINT, signal_handler) except ValueError as e: # This happens when not in the main thread - logger.debug(f"Cannot register signal handlers: {e}") + logger.debug("Cannot register signal handlers: %s", e) # Continue without signal handling - the step will still run def launch(self) -> None: @@ -243,27 +254,36 @@ def launch(self) -> None: runtime = get_runtime(capture_policy) # Store for later use self._runtime = runtime - runs_opt = str(capture_policy.get_option("runs", "on")).lower() - persistence = str( - capture_policy.get_option("persistence", "async") - ).lower() - memory_only = runs_opt in {"off", "false", "0"} or persistence in { - "memory", - "off", - } - - if memory_only: - self._launch_memory_only() - return - pipeline_run, run_was_created = self._create_or_reuse_run() + # Serving context detection + in_serving_ctx = os.getenv("ZENML_SERVING_CAPTURE_DEFAULT") is not None + memory_only = bool(capture_policy.get_option("memory_only", False)) + # Debug messages to clarify behavior + if capture_policy.mode.name == "REALTIME" and not in_serving_ctx: + logger.warning( + "REALTIME mode enabled outside serving (development). Performance/ordering may vary." + ) + if memory_only and not in_serving_ctx: + # Ignore memory_only outside serving: fall back to normal (Batch/Realtime) behavior + logger.warning( + "memory_only=True requested outside serving; ignoring and proceeding with standard execution." + ) + memory_only = False - # runtime already constructed above; configure flush behavior - # Default for serving (REALTIME): do not flush at step end unless user specifies - import os as _os + # Select entity manager and, if memory-only, set up shared runtime + is_memory_only_path = memory_only and in_serving_ctx + # Declare entity manager type for typing + entity_manager: RunEntityManager + if is_memory_only_path: + try: + shared = get_or_create_shared_memory_runtime() + self._runtime = shared + except Exception: + pass + entity_manager = MemoryRunEntityManager(self) + else: + entity_manager = DefaultRunEntityManager(self) - in_serving_ctx = ( - _os.getenv("ZENML_SERVING_CAPTURE_DEFAULT") is not None - ) + pipeline_run, run_was_created = entity_manager.create_or_reuse_run() if ( capture_policy.mode.name == "REALTIME" and "flush_on_step_end" @@ -336,17 +356,19 @@ def launch(self) -> None: code_opt = capture_policy.get_option("code", True) code_enabled = str(code_opt).lower() not in {"false", "0", "off"} - request_factory = step_run_utils.StepRunRequestFactory( - deployment=self._deployment, - pipeline_run=pipeline_run, - stack=self._stack, - runtime=runtime, - skip_code_capture=not code_enabled, - ) - step_run_request = request_factory.create_request( - invocation_id=self._step_name - ) - step_run_request.logs = logs_model + # Prepare step run creation + if isinstance(entity_manager, DefaultRunEntityManager): + request_factory = step_run_utils.StepRunRequestFactory( + deployment=self._deployment, + pipeline_run=pipeline_run, + stack=self._stack, + runtime=runtime, + skip_code_capture=not code_enabled, + ) + step_run_request = request_factory.create_request( + invocation_id=self._step_name + ) + step_run_request.logs = logs_model # If this step has upstream dependencies and runtime uses non-blocking # publishes, ensure previous step updates are flushed so input @@ -364,24 +386,30 @@ def launch(self) -> None: try: # Always populate request to ensure proper input/output flow - request_factory.populate_request(request=step_run_request) + if isinstance(entity_manager, DefaultRunEntityManager): + request_factory.populate_request(request=step_run_request) # In no-capture mode, force fresh execution (bypass cache) if tracking_disabled: - step_run_request.original_step_run_id = None - step_run_request.outputs = {} - step_run_request.status = ExecutionStatus.RUNNING + if isinstance(entity_manager, DefaultRunEntityManager): + step_run_request.original_step_run_id = None + step_run_request.outputs = {} + step_run_request.status = ExecutionStatus.RUNNING except BaseException as e: - logger.exception(f"Failed preparing step `{self._step_name}`.") - step_run_request.status = ExecutionStatus.FAILED - step_run_request.end_time = utc_now() - step_run_request.exception_info = ( - exception_utils.collect_exception_information(e) - ) + logger.exception("Failed preparing step `%s`.", self._step_name) + if isinstance(entity_manager, DefaultRunEntityManager): + step_run_request.status = ExecutionStatus.FAILED + step_run_request.end_time = utc_now() + step_run_request.exception_info = ( + exception_utils.collect_exception_information(e) + ) raise finally: - # Always create real step run for proper input/output flow - step_run = Client().zen_store.create_run_step(step_run_request) + # Create step run (DB-backed or stubbed) + if isinstance(entity_manager, DefaultRunEntityManager): + step_run = entity_manager.create_step_run(step_run_request) + else: + step_run = entity_manager.create_step_run(None) self._step_run = step_run if not tracking_disabled and ( model_version := step_run.model_version @@ -422,9 +450,25 @@ def _bypass() -> None: e, ) if not tracking_disabled: - runtime.publish_failed_step_run(step_run_id=step_run.id) + # Delegate finalization to entity manager (DB-backed or no-op) + try: + entity_manager.finalize_step_run_failed(step_run.id) + except Exception: + try: + runtime.publish_failed_step_run( + step_run_id=step_run.id + ) + except Exception: + pass if runtime.should_flush_on_step_end(): runtime.flush() + else: + try: + getattr( + runtime, "check_async_errors", lambda: None + )() + except Exception: + pass raise else: logger.info(f"Using cached version of step `{self._step_name}`.") @@ -440,6 +484,18 @@ def _bypass() -> None: # Ensure any queued updates are flushed for cached path (if enabled) if runtime.should_flush_on_step_end(): runtime.flush() + else: + try: + getattr(runtime, "check_async_errors", lambda: None)() + except Exception: + pass + # Notify entity manager of successful completion (default no-op) + try: + entity_manager.finalize_step_run_success( + step_run.id, step_run.outputs + ) + except Exception: + pass # Ensure runtime shutdown after launch try: metrics = {} @@ -457,7 +513,7 @@ def _bypass() -> None: metrics.get("queue_depth"), ) except Exception as e: - logger.debug(f"Runtime shutdown/metrics retrieval error: {e}") + logger.debug("Runtime shutdown/metrics retrieval error: %s", e) def _create_or_reuse_run(self) -> Tuple[PipelineRunResponse, bool]: """Creates a pipeline run or reuses an existing one. @@ -509,8 +565,6 @@ def _run_step( force_write_logs: The context for the step logs. """ # Create effective step config with serving overrides and no-capture optimizations - from zenml.orchestrators import utils as orchestrator_utils - effective_step_config = self._step.config.model_copy(deep=True) # In no-capture mode, disable caching and step operators for speed @@ -530,6 +584,26 @@ def _run_step( } ) + # Merge request-level parameters (serving) for memory-only runtime + runtime = getattr(self, "_runtime", None) + if isinstance(runtime, MemoryStepRuntime): + try: + req_env = os.getenv("ZENML_SERVING_REQUEST_PARAMS") + req_params = json.loads(req_env) if req_env else {} + if not req_params: + req_params = ( + self._deployment.pipeline_configuration.parameters + or {} + ) + if req_params: + merged = dict(effective_step_config.parameters or {}) + merged.update(req_params) + effective_step_config = effective_step_config.model_copy( + update={"parameters": merged} + ) + except Exception: + pass + # Prepare step run information with effective config step_run_info = StepRunInfo( config=effective_step_config, @@ -541,10 +615,21 @@ def _run_step( force_write_logs=force_write_logs, ) - # Always prepare output URIs for proper artifact flow - output_artifact_uris = output_utils.prepare_output_artifact_uris( - step_run=step_run, stack=self._stack, step=self._step - ) + # Prepare output URIs + if isinstance(runtime, MemoryStepRuntime): + # Build memory:// URIs from declared outputs (no FS writes) + run_id = str( + getattr(pipeline_run, "id", self._orchestrator_run_id) + ) + output_names = list(self._step.config.outputs.keys()) + output_artifact_uris = { + name: f"memory://{run_id}/{self._step_name}/{name}" + for name in output_names + } + else: + output_artifact_uris = output_utils.prepare_output_artifact_uris( + step_run=step_run, stack=self._stack, step=self._step + ) # Run the step. start_time = time.time() @@ -559,11 +644,19 @@ def _run_step( step_run_info=step_run_info, ) else: + # Resolve inputs via runtime in memory-only; otherwise use server-resolved inputs + if isinstance(runtime, MemoryStepRuntime): + input_artifacts = runtime.resolve_step_inputs( + step=self._step, pipeline_run=pipeline_run + ) + else: + input_artifacts = step_run.regular_inputs + self._run_step_without_step_operator( pipeline_run=pipeline_run, step_run=step_run, step_run_info=step_run_info, - input_artifacts=step_run.regular_inputs, + input_artifacts=input_artifacts, output_artifact_uris=output_artifact_uris, ) except: # noqa: E722 @@ -578,16 +671,18 @@ def _run_step( f"`{string_utils.get_human_readable_time(duration)}`." ) - # If runtime is non-blocking and there are downstream steps depending - # on this step, flush now so that downstream input resolution sees - # this step's outputs on the server. - runtime = getattr(self, "_runtime", None) + # If runtime is non-blocking, consider a best-effort flush at step end. + # - If there are downstream steps, flush to ensure server has updates + # - If no downstream (leaf step), flush in serving to publish outputs so UI shows previews immediately if runtime is not None and not runtime.should_flush_on_step_end(): has_downstream = any( self._step_name in cfg.spec.upstream_steps for name, cfg in self._deployment.step_configurations.items() ) - if has_downstream: + should_flush = has_downstream or ( + os.getenv("ZENML_SERVING_CAPTURE_DEFAULT") is not None + ) + if should_flush: try: runtime.flush() except Exception as e: @@ -635,7 +730,19 @@ def _run_step_with_step_operator( self._deployment.pipeline_configuration, "capture", None ) if mode_cfg: - environment["ZENML_CAPTURE_MODE"] = str(mode_cfg).upper() + # If typed capture with explicit mode, export it; unified Capture has no mode + try: + from zenml.capture.config import ( + BatchCapture, + RealtimeCapture, + ) + + if isinstance(mode_cfg, RealtimeCapture): + environment["ZENML_CAPTURE_MODE"] = "REALTIME" + elif isinstance(mode_cfg, BatchCapture): + environment["ZENML_CAPTURE_MODE"] = "BATCH" + except Exception: + pass environment["ZENML_ENABLE_STEP_RUNTIME"] = "true" except Exception: pass @@ -679,96 +786,3 @@ def _run_step_without_step_operator( output_artifact_uris=output_artifact_uris, step_run_info=step_run_info, ) - - def _launch_memory_only(self) -> None: - """Launch the step in pure memory-only mode (no runs, no persistence).""" - from dataclasses import dataclass - from typing import Any - - from zenml.config.step_run_info import StepRunInfo - from zenml.execution.step_runtime import MemoryStepRuntime - from zenml.utils.time_utils import utc_now - - run_id = self._orchestrator_run_id - start_time = utc_now() - substitutions = ( - self._deployment.pipeline_configuration.finalize_substitutions( - start_time=start_time - ) - ) - - @dataclass - class _Cfg: - tags: Any = None - - @dataclass - class _PipelineRunStub: - id: str - model_version: Any = None - pipeline: Any = None - config: Any = _Cfg() - - @dataclass - class _StepCfg: - substitutions: Any - outputs: Any - - @dataclass - class _StepRunStub: - id: str - name: str - model_version: Any - config: Any - is_retriable: bool = True - - pipeline_run_stub = _PipelineRunStub(id=run_id) - step_run_stub = _StepRunStub( - id=run_id, # valid UUID string preferred - name=self._step_name, - model_version=None, - config=_StepCfg( - substitutions=substitutions, outputs=self._step.config.outputs - ), - is_retriable=True, - ) - - # Build URIs from declared outputs (no imports needed) - output_names = list(self._step.config.outputs.keys()) - output_artifact_uris = { - name: f"memory://{run_id}/{self._step_name}/{name}" - for name in output_names - } - - # Resolve inputs via runtime to avoid duplication - if isinstance(self._runtime, MemoryStepRuntime): - self._runtime.set_context( - run_id=run_id, substitutions=substitutions - ) - input_artifacts = self._runtime.resolve_step_inputs( - step=self._step, pipeline_run=pipeline_run_stub - ) - else: - input_artifacts = {} - - runner = StepRunner( - step=self._step, stack=self._stack, runtime=self._runtime - ) - step_run_info = StepRunInfo( - config=self._step.config, - pipeline=self._deployment.pipeline_configuration, - run_name=self._deployment.run_name_template, - pipeline_step_name=self._step_name, - run_id=run_id, - step_run_id=step_run_stub.id, - force_write_logs=lambda: None, - ) - - from typing import Any, cast - - runner.run( - pipeline_run=cast(Any, pipeline_run_stub), - step_run=cast(Any, step_run_stub), - input_artifacts=input_artifacts, - output_artifact_uris=output_artifact_uris, - step_run_info=step_run_info, - ) diff --git a/src/zenml/orchestrators/step_runner.py b/src/zenml/orchestrators/step_runner.py index 930e4ebb93c..c7d33eef525 100644 --- a/src/zenml/orchestrators/step_runner.py +++ b/src/zenml/orchestrators/step_runner.py @@ -26,6 +26,7 @@ Optional, Tuple, Type, + cast, ) from zenml.artifacts.unmaterialized_artifact import UnmaterializedArtifact @@ -116,6 +117,10 @@ def configuration(self) -> StepConfiguration: Returns: The step configuration. """ + # Prefer effective config from step_run_info if available (serving overrides) + effective = getattr(self, "_step_run_info", None) + if effective: + return cast(StepConfiguration, effective.config) return self._step.config def run( @@ -203,6 +208,9 @@ def run( # Initialize the step context singleton StepContext._clear() + # Pass pipeline state if serving provided one + from zenml.orchestrators import utils as _orch_utils + step_context = StepContext( pipeline_run=pipeline_run, step_run=step_run, @@ -211,6 +219,7 @@ def run( output_artifact_configs={ k: v.artifact_config for k, v in output_annotations.items() }, + pipeline_state=_orch_utils.get_pipeline_state(), ) # Parse the inputs for the entrypoint function. diff --git a/src/zenml/orchestrators/utils.py b/src/zenml/orchestrators/utils.py index c9899938961..ef23134c816 100644 --- a/src/zenml/orchestrators/utils.py +++ b/src/zenml/orchestrators/utils.py @@ -198,6 +198,22 @@ def tap_get_step_outputs(step_name: str) -> Optional[Dict[str, Any]]: return _serve_output_tap.get({}).get(step_name) +# Serve pipeline state context +_serve_pipeline_state: ContextVar[Optional[Any]] = ContextVar( + "serve_pipeline_state", default=None +) + + +def set_pipeline_state(state: Optional[Any]) -> None: + """Set pipeline state for serving context.""" + _serve_pipeline_state.set(state) + + +def get_pipeline_state() -> Optional[Any]: + """Get pipeline state for serving context.""" + return _serve_pipeline_state.get(None) + + def tap_clear() -> None: """Clear the serve tap for a fresh request.""" _serve_output_tap.set({}) diff --git a/src/zenml/pipelines/pipeline_decorator.py b/src/zenml/pipelines/pipeline_decorator.py index 85d40738340..7c83cb745ec 100644 --- a/src/zenml/pipelines/pipeline_decorator.py +++ b/src/zenml/pipelines/pipeline_decorator.py @@ -25,6 +25,11 @@ overload, ) +from zenml.capture.config import ( + BatchCapture, + Capture, + RealtimeCapture, +) from zenml.logger import get_logger if TYPE_CHECKING: @@ -62,7 +67,7 @@ def pipeline( model: Optional["Model"] = None, retry: Optional["StepRetryConfig"] = None, substitutions: Optional[Dict[str, str]] = None, - capture: Optional[Dict[str, Any]] = None, + capture: Optional[Union[Capture, BatchCapture, RealtimeCapture]] = None, ) -> Callable[["F"], "Pipeline"]: ... @@ -84,7 +89,7 @@ def pipeline( model: Optional["Model"] = None, retry: Optional["StepRetryConfig"] = None, substitutions: Optional[Dict[str, str]] = None, - capture: Optional[Dict[str, Any]] = None, + capture: Optional[Union[Capture, BatchCapture, RealtimeCapture]] = None, ) -> Union["Pipeline", Callable[["F"], "Pipeline"]]: """Decorator to create a pipeline. @@ -115,7 +120,7 @@ def pipeline( model: configuration of the model in the Model Control Plane. retry: Retry configuration for the pipeline steps. substitutions: Extra placeholders to use in the name templates. - capture: Capture policy for the pipeline. + capture: Capture policy for the pipeline (typed only). Returns: A pipeline instance. @@ -124,6 +129,16 @@ def pipeline( def inner_decorator(func: "F") -> "Pipeline": from zenml.pipelines.pipeline_definition import Pipeline + # Directly store typed capture config + cap = capture + cap_val: Optional[Union[Capture, BatchCapture, RealtimeCapture]] = None + if cap is not None: + if not isinstance(cap, (Capture, BatchCapture, RealtimeCapture)): + raise ValueError( + "'capture' must be a Capture, BatchCapture or RealtimeCapture." + ) + cap_val = cap + p = Pipeline( name=name or func.__name__, entrypoint=func, @@ -141,7 +156,7 @@ def inner_decorator(func: "F") -> "Pipeline": model=model, retry=retry, substitutions=substitutions, - capture=capture, + capture=cap_val, ) p.__doc__ = func.__doc__ diff --git a/src/zenml/pipelines/pipeline_definition.py b/src/zenml/pipelines/pipeline_definition.py index 4008461bc94..d94ad89e48d 100644 --- a/src/zenml/pipelines/pipeline_definition.py +++ b/src/zenml/pipelines/pipeline_definition.py @@ -42,6 +42,11 @@ from zenml import constants from zenml.analytics.enums import AnalyticsEvent from zenml.analytics.utils import track_handler +from zenml.capture.config import ( + BatchCapture, + Capture, + RealtimeCapture, +) from zenml.client import Client from zenml.config.compiler import Compiler from zenml.config.pipeline_configurations import ( @@ -148,7 +153,9 @@ def __init__( model: Optional["Model"] = None, retry: Optional["StepRetryConfig"] = None, substitutions: Optional[Dict[str, str]] = None, - capture: Optional[Any] = None, + capture: Optional[ + Union[Capture, BatchCapture, RealtimeCapture] + ] = None, ) -> None: """Initializes a pipeline. @@ -181,7 +188,8 @@ def __init__( model: configuration of the model in the Model Control Plane. retry: Retry configuration for the pipeline steps. substitutions: Extra placeholders to use in the name templates. - capture: Capture policy for the pipeline. + capture: Capture policy for the pipeline (typed only): Capture, + BatchCapture or RealtimeCapture. """ self._invocations: Dict[str, StepInvocation] = {} self._run_args: Dict[str, Any] = {} @@ -333,7 +341,9 @@ def configure( parameters: Optional[Dict[str, Any]] = None, merge: bool = True, substitutions: Optional[Dict[str, str]] = None, - capture: Optional[Any] = None, + capture: Optional[ + Union[Capture, BatchCapture, RealtimeCapture] + ] = None, ) -> Self: """Configures the pipeline. @@ -380,7 +390,8 @@ def configure( retry: Retry configuration for the pipeline steps. parameters: input parameters for the pipeline. substitutions: Extra placeholders to use in the name templates. - capture: Capture policy for the pipeline. + capture: Capture policy for the pipeline (typed only). Use + BatchCapture/RealtimeCapture or omit entirely to use sensible defaults. Returns: The pipeline instance that this method was called on. @@ -410,12 +421,16 @@ def configure( # merges dicts tags = self._configuration.tags + tags - # Normalize capture to upper-case string if provided + # Directly store typed capture config + cap_norm = None if capture is not None: - try: - capture = str(capture).upper() - except Exception: - capture = str(capture) + if not isinstance( + capture, (Capture, BatchCapture, RealtimeCapture) + ): + raise ValueError( + "'capture' must be a Capture, BatchCapture or RealtimeCapture." + ) + cap_norm = capture values = dict_utils.remove_none_values( { @@ -435,7 +450,7 @@ def configure( "retry": retry, "parameters": parameters, "substitutions": substitutions, - "capture": capture, + "capture": cap_norm, } ) if not self.__suppress_warnings_flag__: From 5e45b2085ebb558d2c4a4d1c29e22241e2bf3515 Mon Sep 17 00:00:00 2001 From: Safoine El khabich Date: Sun, 7 Sep 2025 17:41:21 +0100 Subject: [PATCH 3/8] Refactor serving architecture for improved reliability and performance This commit introduces significant enhancements to the ZenML serving architecture, focusing on a unified capture configuration and memory-only execution mode. The `Capture` class has been simplified to a single typed API, streamlining the capture process and eliminating confusion around capture modes. Key changes include the introduction of memory-only serving, which ensures no database or filesystem writes occur, and the implementation of a robust realtime runtime with improved resource management and error handling. Additionally, request parameter validation has been enhanced to ensure safe merging and type coercion, while logging and metrics have been refined for better observability. These updates aim to provide a more efficient and user-friendly experience for serving pipelines, paving the way for future enhancements and production readiness. --- docs/PR_DESCRIPTION.md | 106 ++++++ docs/book/how-to/serving/serving.md | 222 +++--------- .../serving/advanced/capture-and-runtime.md | 128 ++----- docs/book/serving/advanced/realtime-tuning.md | 9 +- docs/book/serving/overview.md | 53 +-- examples/serving/README.md | 339 +++--------------- examples/serving/weather_pipeline.py | 5 +- src/zenml/capture/config.py | 200 +---------- src/zenml/config/compiler.py | 54 ++- src/zenml/config/pipeline_configurations.py | 21 +- .../config/pipeline_run_configuration.py | 5 +- src/zenml/deployers/serving/service.py | 60 ++-- src/zenml/execution/factory.py | 93 ++--- src/zenml/execution/realtime_runtime.py | 88 +++-- src/zenml/execution/step_runtime.py | 233 +++++++----- .../kubernetes_orchestrator_entrypoint.py | 8 +- .../models/v2/core/pipeline_deployment.py | 36 ++ src/zenml/orchestrators/run_entity_manager.py | 16 + src/zenml/orchestrators/step_launcher.py | 252 +++++++++---- src/zenml/orchestrators/utils.py | 77 ++-- src/zenml/pipelines/pipeline_decorator.py | 19 +- src/zenml/pipelines/pipeline_definition.py | 30 +- .../schemas/pipeline_deployment_schemas.py | 36 +- .../test_default_runtime_metadata_toggle.py | 37 ++ tests/unit/execution/test_memory_runtime.py | 69 ++++ tests/unit/execution/test_realtime_runtime.py | 41 +++ .../test_step_runtime_artifact_write.py | 78 ++++ .../test_step_launcher_params.py | 49 +++ 28 files changed, 1148 insertions(+), 1216 deletions(-) create mode 100644 docs/PR_DESCRIPTION.md create mode 100644 tests/unit/execution/test_default_runtime_metadata_toggle.py create mode 100644 tests/unit/execution/test_memory_runtime.py create mode 100644 tests/unit/execution/test_realtime_runtime.py create mode 100644 tests/unit/execution/test_step_runtime_artifact_write.py create mode 100644 tests/unit/orchestrators/test_step_launcher_params.py diff --git a/docs/PR_DESCRIPTION.md b/docs/PR_DESCRIPTION.md new file mode 100644 index 00000000000..b3cfa388796 --- /dev/null +++ b/docs/PR_DESCRIPTION.md @@ -0,0 +1,106 @@ +## Beta: Unified Serving Capture, Memory-Only Isolation, and Realtime Hardening + +This PR delivers a focused, pragmatic refactor to make serving reliable and easy to reason about for a beta release. It simplifies capture configuration to a single typed `Capture`, unifies the execution path, introduces memory-only isolation, and hardens the realtime runtime with bounded resources and better shutdown behavior. + +### Summary + +- Collapse capture to a single typed API: `Capture(memory_only, code, logs, metadata, visualizations, metrics)`. +- Canonical capture fields on deployments; StepLauncher reads only those (no env/dict overrides). +- Serving request parameters are merged safely (allowlist + light validation + size caps); logged. +- Memory-only serving mode: truly no runs/artifacts/log writes; in-process handoff with per-request isolation. +- Realtime runtime: bounded queue, safe cache sweep, circuit-breaker maintained, improved shutdown and metrics. +- Defensive artifact writes: validation and minimal retries/backoff; fail fast on partial responses. +- In-code TODOs added for post-beta roadmap (transactions, multi-worker/async publishing, monitoring). + +### Motivation + +- Eliminate confusing capture modes and env overrides in code paths. +- Ensure serving is fast (async by default) and memory-only mode never touches DB/FS. +- Prevent cross-request contamination in memory-only; bound resource usage under load. +- Provide clear logs and metrics for diagnosis; pave the way for production hardening. + +### Key Behavioral Changes + +- Pipeline code uses a single `Capture` type; dicts/strings disallowed in code paths. +- Serving merges request parameters only from a declared allowlist; oversized/mismatched params are dropped with warnings. +- Memory-only serving executes fully in-process (no runs/artifacts), with explicit logs; step logs disabled to avoid FS writes. +- Realtime runtime backgrounds publishing with a bounded queue; if the queue is full, events are processed inline. + +### File-Level Changes (Selected) + +- Capture & Compiler + - `src/zenml/capture/config.py`: Single `Capture` dataclass; removed BatchCapture/RealtimeCapture/CapturePolicy. + - `src/zenml/config/compiler.py`: Normalizes typed capture into canonical deployment fields. + - `src/zenml/models/v2/core/pipeline_deployment.py`: Adds canonical capture fields to deployment models. + - `src/zenml/zen_stores/schemas/pipeline_deployment_schemas.py`: Adds DB columns for canonical capture fields. + +- Orchestrator + - `src/zenml/orchestrators/step_launcher.py`: + - Uses canonical fields and serving context. + - Adds `_validate_and_merge_request_params` (allowlist + type coercion + size caps). + - Disables logs in memory-only; avoids FS cleanup for `memory://` URIs. + - `src/zenml/orchestrators/run_entity_manager.py`: In-memory step_run stub with minimal config (`enable_*`, `substitutions`). + - `src/zenml/orchestrators/utils.py`: Serving context helpers and docstrings; removed request-level override plumbing. + +- Execution Runtimes + - `src/zenml/execution/step_runtime.py`: + - `MemoryStepRuntime`: instance-scoped store/locks; per-run cleanup; no globals. + - `DefaultStepRuntime.store_output_artifacts`: defensive batch create (retries/backoff), response count validation; TODO for server-side atomicity. + - `src/zenml/execution/realtime_runtime.py`: + - Bounded queue (maxsize=1024), inline fallback on Full. + - Safe cache sweep (snapshot + safe pop, small time budget). + - Shutdown logs final metrics and warns on non-graceful termination; TODOs for thread-pool or async migration and metrics export. + +- Serving Service & Docs + - `src/zenml/deployers/serving/service.py`: Serving context handling; parameter exposure; cleanup. + - `docs/book/serving/*`: Updated to single Capture, serving async default, memory-only warning/behavior. + - `examples/serving/README.md`: Updated to reflect new serving model; memory-only usage. + +### Configuration & Tuning + +- Serving mode is inferred by context (batch vs. serving). No per-request capture overrides. +- Realtime runtime tuning via env: + - `ZENML_RT_CACHE_TTL_SECONDS` (default 60), `ZENML_RT_CACHE_MAX_ENTRIES` (default 256) + - `ZENML_RT_ERR_REPORT_INTERVAL` (default 15), circuit breaker envs unchanged +- Memory-only: ignored outside serving with a warning. + +### Testing & Validation + +- Unit + - Request parameter validation: allowlist, size caps, type coercion. + - Memory runtime isolation: per-instance store; no cross-contamination. + - Realtime runtime: queue Full → inline fallback; race-free cache sweep; shutdown metrics. + - Defensive artifact writes: retries/backoff; mismatch detection. + +- Manual + - Memory-only serving: no `/runs` or `/artifact_versions` calls; explicit log: `[Memory-only] … in-process handoff (no runs/artifacts)`. + - Serving async default: responses return immediately; background updates proceed. + +### Risk & Mitigations + +- Request param merge: now restricted by allowlist/size/type; unknowns dropped with warnings. +- Memory-only: per-request isolation and no FS/DB writes; logs disabled to avoid side effects. +- Realtime: bounded queue with inline fallback; circuit breaker remains in place. +- Artifact writes: fail fast rather than proceed with partial results; TODO for server-side atomicity. + +### In-Code TODOs (Post-Beta Roadmap) + +- Realtime runtime publishing: + - Either thread pool (ThreadPoolExecutor) workers or asyncio runtime once client has async publish calls and we want async serving. + - Preserve bounded backpressure and orderly shutdown. +- Request parameter schema derived from entrypoint annotations; add total payload size caps and strict mode. +- Server-side transactional/compensating semantics for artifact writes; adopt idempotent, category-aware retries. +- Metrics export to Prometheus; per-worker metrics; worker liveness/health signals; process memory watchdog. + +### How to Review + +- Focus on StepLauncher (param merge, memory-only flags) and runtimes (Memory/Realtime). +- Verify serving behavior in logs; check that memory-only path never touches DB/FS. +- Review TODOs in code for future milestones. + +### Rollout + +- Tag as beta and monitor runtime metrics (`queue_depth`, `failed_total`, `cache_hit_rate`, `op_latency_p95_s`). +- Scale by increasing HTTP workers and replicas; memory-only is fastest for prototypes. +- Provide guidance on cache sizing and memory usage in docs. + diff --git a/docs/book/how-to/serving/serving.md b/docs/book/how-to/serving/serving.md index 9ea0858e726..e14b595dbc9 100644 --- a/docs/book/how-to/serving/serving.md +++ b/docs/book/how-to/serving/serving.md @@ -1,42 +1,25 @@ --- title: Serving Pipelines -description: Millisecond-class pipeline execution over HTTP with intelligent run-only optimization and streaming. +description: Run pipelines as fast HTTP services with async serving by default and optional memory-only execution. --- # Serving Pipelines -ZenML Serving runs pipelines as ultra-fast FastAPI services, achieving millisecond-class latency through intelligent run-only execution. Perfect for real-time inference, AI agents, and interactive workflows. +ZenML Serving exposes a pipeline as a FastAPI service. In serving, execution uses a Realtime runtime with async server updates by default for low latency. You can optionally run memory-only for maximum speed. ## Why Serving vs. Orchestrators -- **Performance**: Millisecond-class latency with run-only execution (no DB/FS writes in fast mode) -- **Simplicity**: Call your pipeline via HTTP; get results or stream progress -- **Intelligence**: Automatically switches between tracking and run-only modes based on capture settings -- **Flexibility**: Optional run/step tracking with fine-grained capture policies +- Performance: Async serving with in-process caching for low latency. +- Simplicity: Invoke your pipeline over HTTP; get results or stream progress. +- Control: Single, typed `Capture` config to tune observability or enable memory-only. -Use orchestrators for scheduled, long-running, reproducible workflows; use Serving for real-time request/response. - -## How It Works - -**Run-Only Architecture** (for millisecond latency): -- **ServingOverrides**: Per-request parameter injection using ContextVar isolation -- **ServingBuffer**: In-memory step output handoff with no persistence -- **Effective Config**: Runtime configuration merging without model mutations -- **Skip I/O**: Bypasses all database writes and filesystem operations -- **Input Injection**: Upstream step outputs automatically injected as parameters - -**Full Tracking Mode** (when capture enabled): -- Traditional ZenML tracking with runs, steps, artifacts, and metadata -- Orchestrator-based execution with full observability - -The service automatically chooses the optimal execution mode based on your capture settings. +Use orchestrators for scheduled, reproducible workflows. Use Serving for request/response inference. ## Quickstart Prerequisites - A deployed pipeline; note its deployment UUID as `ZENML_PIPELINE_DEPLOYMENT_ID`. -- Python env with dev deps (as per CONTRIBUTING). Start the service @@ -47,7 +30,7 @@ export ZENML_SERVICE_PORT=8001 python -m zenml.deployers.serving.app ``` -Synchronous invocation +Invoke (sync) ```bash curl -s -X POST "http://localhost:8001/invoke" \ @@ -55,65 +38,53 @@ curl -s -X POST "http://localhost:8001/invoke" \ -d '{"parameters": {"your_param": "value"}}' ``` -## Performance Modes +## Capture (typed-only) -ZenML Serving automatically chooses the optimal execution mode: - -### Run-Only Mode (Millisecond Latency) - -Activated when `capture="none"` or no capture settings specified: +Configure capture at the pipeline decorator using a single, typed `Capture`: ```python -@pipeline(settings={"capture": "none"}) -def fast_pipeline(x: int) -> int: - return x * 2 -``` +from zenml import pipeline +from zenml.capture.config import Capture -**Optimizations**: -- ✅ Zero database writes -- ✅ Zero filesystem operations -- ✅ In-memory step output handoff -- ✅ Per-request parameter injection -- ✅ Effective configuration merging -- ✅ Multi-worker safe (ContextVar isolation) +@pipeline(capture=Capture()) # serving async by default +def serve_pipeline(...): + ... -**Use for**: Real-time inference, AI agents, interactive demos - -### Full Tracking Mode +@pipeline(capture=Capture(memory_only=True)) # serving only +def max_speed_pipeline(...): + ... +``` -Activated when capture settings specify tracking: +Options (observability only; do not affect dataflow): +- `code`: include code/source/docstrings in metadata (default True) +- `logs`: persist step logs (default True) +- `metadata`: publish run/step metadata (default True) +- `visualizations`: persist visualizations (default True) +- `metrics`: emit runtime metrics (default True) -```python -@pipeline(settings={"capture": "full"}) -def tracked_pipeline(x: int) -> int: - return x * 2 -``` +Notes +- Serving is async by default; there is no `flush_on_step_end` knob. +- `memory_only=True` is ignored outside serving with a warning. -**Features**: -- Complete run/step tracking -- Artifact persistence -- Metadata collection -- Dashboard integration +## Request Parameters -**Use for**: Experimentation, debugging, audit trails +Request JSON under `parameters` is merged into the effective step config in serving. Logged keys indicate which parameters were applied. ## Execution Modes -- **Sync**: `POST /invoke` waits for completion; returns results or error. -- **Async**: `POST /invoke?mode=async` returns a `job_id`; poll `GET /jobs/{job_id}`. -- **Streaming**: `GET /stream/{job_id}` (SSE) or `WebSocket /stream` to receive progress and completion events in real time. +- Sync: `POST /invoke` waits for completion; returns results or error. +- Async: `POST /invoke?mode=async` returns a `job_id`; poll `GET /jobs/{job_id}`. +- Streaming: `GET /stream/{job_id}` (SSE) or `WebSocket /stream` to stream progress. Async example ```bash -# Submit -JOB_ID=$(curl -s -X POST "http://localhost:8001/invoke?mode=async" -H "Content-Type: application/json" -d '{"parameters":{}}' | jq -r .job_id) - -# Poll +JOB_ID=$(curl -s -X POST "http://localhost:8001/invoke?mode=async" \ + -H "Content-Type: application/json" -d '{"parameters":{}}' | jq -r .job_id) curl -s "http://localhost:8001/jobs/$JOB_ID" ``` -SSE example +SSE ```bash curl -N -H "Accept: text/event-stream" "http://localhost:8001/stream/$JOB_ID" @@ -123,124 +94,19 @@ curl -N -H "Accept: text/event-stream" "http://localhost:8001/stream/$JOB_ID" - `/health`: Service health and uptime. - `/info`: Pipeline name, steps, parameter schema, deployment info. -- `/metrics`: Execution statistics (counts, averages). +- `/metrics`: Execution statistics (queue depth, cache hit rate, latencies when metrics enabled). - `/status`: Service configuration snapshot. -- `/invoke`: Execute (sync/async) with optional parameter overrides. +- `/invoke`: Execute (sync/async) with optional parameters. - `/jobs`, `/jobs/{id}`, `/jobs/{id}/cancel`: Manage async jobs. -- `/stream/{id}`: Server‑Sent Events stream for a job; `WebSocket /stream` for bidirectional. - -## Configuration - -Key environment variables - -- `ZENML_PIPELINE_DEPLOYMENT_ID`: Deployment UUID (required). -- `ZENML_SERVING_CAPTURE_DEFAULT`: Default capture mode (`none` for run-only, `full` for tracking). -- `ZENML_SERVICE_HOST` (default: `0.0.0.0`), `ZENML_SERVICE_PORT` (default: `8001`). -- `ZENML_LOG_LEVEL`: Logging verbosity. - -## Capture Policies - -Control what gets tracked per invocation: - -- **`none`**: Run-only mode, millisecond latency, no persistence -- **`metadata`**: Track runs/steps, no payload data -- **`full`**: Complete tracking with artifacts and metadata -- **`sampled`**: Probabilistic tracking for cost control -- **`errors_only`**: Track only failed executions - -Configuration locations: -- **Pipeline-level**: `@pipeline(settings={"capture": "none"})` -- **Request-level**: `{"capture_override": {"mode": "full"}}` -- **Environment**: `ZENML_SERVING_CAPTURE_DEFAULT=none` - -Precedence: Request > Pipeline > Environment > Default - -## Advanced Features - -### Input/Output Contracts - -Pipelines automatically expose their signature: - -```python -@pipeline -def my_pipeline(city: str, temperature: float) -> str: - return process_weather(city, temperature) - -# Automatic parameter schema: -# {"city": {"type": "str", "required": true}, -# "temperature": {"type": "float", "required": true}} -``` - -### Multi-Step Pipelines - -Step outputs automatically injected as inputs: - -```python -@step -def fetch_data(city: str) -> dict: - return {"weather": "sunny", "temp": 25} - -@step -def analyze_data(weather_data: dict) -> str: - return f"Analysis: {weather_data}" - -@pipeline -def weather_pipeline(city: str) -> str: - data = fetch_data(city) - return analyze_data(data) # weather_data auto-injected -``` - -### Response Building - -Only declared pipeline outputs returned: - -```python -@pipeline -def multi_output_pipeline(x: int) -> tuple[int, str]: - return x * 2, f"Result: {x}" - -# Response: {"outputs": {"output_0": 4, "output_1": "Result: 2"}} -``` - -## Testing & Local Dev - -Exercise endpoints locally: - -```bash -# Health check -curl http://localhost:8001/health - -# Pipeline info -curl http://localhost:8001/info - -# Execute with parameters -curl -X POST http://localhost:8001/invoke \ - -H "Content-Type: application/json" \ - -d '{"parameters": {"city": "Paris"}}' - -# Override capture mode -curl -X POST http://localhost:8001/invoke \ - -H "Content-Type: application/json" \ - -d '{"parameters": {"city": "Tokyo"}, "capture_override": {"mode": "full"}}' -``` +- `/stream/{id}`: Server‑Sent Events stream; `WebSocket /stream` for bidirectional. ## Troubleshooting -- **Missing deployment ID**: set `ZENML_PIPELINE_DEPLOYMENT_ID`. -- **Slow performance**: ensure `capture="none"` for run-only mode. -- **Import errors**: run-only mode bypasses some ZenML integrations that aren't needed for serving. -- **Memory leaks**: serving contexts are automatically cleared per request. -- **Multi-worker issues**: ContextVar isolation ensures thread safety. - -## Architecture Comparison +- Missing deployment ID: set `ZENML_PIPELINE_DEPLOYMENT_ID`. +- Slow responses: ensure you are in serving (async by default) or consider `Capture(memory_only=True)` for prototypes. +- Multi-worker/safety: Serving isolates request state; taps are cleared per request. -| Feature | Run-Only Mode | Full Tracking | -|---------|---------------|---------------| -| **Latency** | Milliseconds | Seconds | -| **DB Writes** | None | Full tracking | -| **FS Writes** | None | Artifacts | -| **Memory** | Minimal | Standard | -| **Debugging** | Limited | Complete | -| **Production** | ✅ Optimal | For experimentation | +## See Also -Choose run-only for production serving, full tracking for development and debugging. \ No newline at end of file +- Capture & Runtimes (advanced): serving defaults, toggles, memory-only behavior. +- Realtime Tuning: cache TTL/size, error reporting, and circuit breaker knobs. diff --git a/docs/book/serving/advanced/capture-and-runtime.md b/docs/book/serving/advanced/capture-and-runtime.md index 06631afeb39..6cad8a86b73 100644 --- a/docs/book/serving/advanced/capture-and-runtime.md +++ b/docs/book/serving/advanced/capture-and-runtime.md @@ -1,79 +1,60 @@ --- -title: Capture Policy & Execution Runtimes (Advanced) +title: Capture & Execution Runtimes (Advanced) --- -# Capture Policy & Execution Runtimes (Advanced) +# Capture & Execution Runtimes (Advanced) This page explains how capture options map to execution runtimes and how to tune them for production serving. ## Execution Runtimes -- DefaultStepRuntime - - Standard ZenML execution: persists artifacts, creates runs and step runs, captures metadata/logs per config. +- DefaultStepRuntime (Batch) + - Standard ZenML execution: persists artifacts, creates runs and step runs, captures metadata/logs based on capture toggles. + - Used outside serving. -- RealtimeStepRuntime - - Focus: Low latency + observability. - - Features: - - In-process artifact value cache for downstream steps in the same process. - - Tunables: `ttl_seconds`, `max_entries` via capture options (or env vars `ZENML_RT_CACHE_TTL_SECONDS`, `ZENML_RT_CACHE_MAX_ENTRIES`). - - Async server updates with a background worker. - - `flush_on_step_end` controls whether to block at step boundary to flush updates. - - In serving with `mode=REALTIME`, `flush_on_step_end` defaults to `false` unless explicitly set. +- RealtimeStepRuntime (Serving, async by default) + - Optimized for low latency with async server updates and an in‑process cache for downstream loads. + - Tunables via env: `ZENML_RT_CACHE_TTL_SECONDS`, `ZENML_RT_CACHE_MAX_ENTRIES`, `ZENML_RT_ERR_REPORT_INTERVAL`, circuit breaker knobs (see Realtime Tuning page). -- MemoryStepRuntime - - Focus: Pure in-memory execution (no server, no persistence). - - Behavior: Inter-step data is exchanged via in-process memory handles; no runs or artifacts. - - Configure with REALTIME: `@pipeline(capture=Capture(memory_only=True))`. +- MemoryStepRuntime (Serving with memory_only) + - Pure in‑memory execution: no runs/steps/artifacts or server calls. + - Inter‑step data is exchanged via in‑process handles. -## Capture Configuration - -Where to set: -- In code: `@pipeline(capture=...)` (typed only) -- In run config YAML: `capture: REALTIME|BATCH` - -Recommended API (typed) +## Capture API (typed only) ```python from zenml.capture.config import Capture -# Not required for defaults, but explicit usage examples: - -# Realtime (default in serving), non-blocking reporting +# Serving async (default) – explicit but not required @pipeline(capture=Capture()) -def serve(...): - ... - -# Realtime, blocking reporting -@pipeline(capture=Capture(flush_on_step_end=True)) -# Realtime, memory-only (serving only) +# Serving memory-only (no DB/artifacts) @pipeline(capture=Capture(memory_only=True)) +def serve(...): + ... ``` -Notes: -- Modes are inferred by context (batch vs serving), you only set options: - - `flush_on_step_end`: If `False`, serving returns immediately; tracking is published asynchronously by the runtime worker. - - `memory_only=True`: Pure in-memory execution (no runs/artifacts), serving only. - - `code=False`: Skips docstring/source capture (metadata), but does not affect code execution. +Options: +- `memory_only` (serving only): in‑process handoff; no persistence. +- Observability toggles (affect only observability, not dataflow): + - `code`: include code/source/docstrings in metadata (default True) + - `logs`: persist step logs (default True) + - `metadata`: publish run/step metadata (default True) + - `visualizations`: persist visualizations (default True) + - `metrics`: emit runtime metrics (default True) ## Serving Defaults -- REALTIME + serving context: - - If capture is unset, defaults to non-blocking (`flush_on_step_end=False`). - - Users can set `flush_on_step_end=True` to block at step boundary. +- Serving uses the Realtime runtime and returns asynchronously by default. +- There is no `flush_on_step_end` knob; batch is blocking, serving is async. ## Validation & Behavior -- Realtime capture outside serving: - - Allowed for development; logs a warning and continues. In production, use the serving service. -- memory_only outside serving: - - Ignored with a warning; standard execution proceeds (Batch/Realtime as applicable). -- Contradictory options: - - Capture(memory_only=True, flush_on_step_end=True) → raises ValueError. +- memory_only outside serving: ignored with a warning. +- Observability toggles never affect dataflow/caching, only what’s recorded. ## Step Operators & Remote Execution -- Step operators inherit capture via environment (e.g., `ZENML_CAPTURE_MODE`). -- Remote entrypoints construct the matching runtime and honor capture options. +Step operators and remote entrypoints derive behavior from context; no capture env propagation is required. ## Memory-Only Internals (for deeper understanding) @@ -84,52 +65,14 @@ Notes: - `store_output_artifacts`: stores outputs back to the store; returns new handles for downstream steps. - No server calls; no runs or artifacts are created. -## Environment Variables - -- `ZENML_CAPTURE_MODE`: global default capture when not set in the pipeline. -- `ZENML_SERVING_CAPTURE_DEFAULT`: used internally to reduce tracking when capture is not set (serving compatibility). -- `ZENML_RT_CACHE_TTL_SECONDS`, `ZENML_RT_CACHE_MAX_ENTRIES`: Realtime cache controls. - ## Recipes -- Low-latency serving (eventual consistency): - - `@pipeline(capture=Capture())` - -- Strict serving (strong consistency): - - `@pipeline(capture=Capture(flush_on_step_end=True))` - -- Memory-only (stateless service): - - `@pipeline(capture=Capture(memory_only=True))` - -### Control logs/metadata/visualizations (Batch & Realtime) - -These are pipeline settings, not capture options. Set them via `pipeline.configure(...)` or YAML: - -```python -@pipeline() -def train(...): - ... - -# In code -train = train.with_options() -train.configure( - enable_step_logs=True, - enable_artifact_metadata=True, - enable_artifact_visualization=False, -) -``` - -Or in run config YAML: - -```yaml -enable_step_logs: true -enable_artifact_metadata: true -enable_artifact_visualization: false -``` +- Low-latency serving (default): `@pipeline(capture=Capture())` +- Memory-only (stateless service): `@pipeline(capture=Capture(memory_only=True))` ### Disable code capture (docstring/source) -Code capture affects metadata only (not execution). You can disable it via capture in both modes: +Code capture affects metadata only (not execution). You can disable it via capture: ```python from zenml.capture.config import Capture @@ -145,14 +88,7 @@ def train(...): ## FAQ -- Can I enable only partial capture (e.g., errors-only logs)? - - Yes, e.g., `logs: errors-only` and `metadata: false`. - - Does `code: false` break step execution? - No. It only disables docstring/source capture. Steps still run normally. - -- How does caching interact with REALTIME? - - Default caching behavior is unchanged. Set `cache_enabled: false` to bypass caching entirely. - - Can memory-only work with parallelism? - Memory-only is per-process. For multi-process/multi-container setups, use persistence for cross-process data. diff --git a/docs/book/serving/advanced/realtime-tuning.md b/docs/book/serving/advanced/realtime-tuning.md index abd4bfec946..b11a1ab73f5 100644 --- a/docs/book/serving/advanced/realtime-tuning.md +++ b/docs/book/serving/advanced/realtime-tuning.md @@ -35,15 +35,9 @@ Circuit Breaker (async → inline fallback) - `ZENML_RT_CB_OPEN_SECONDS` (default: `300`) - Duration (seconds) to keep breaker open; inline publishing is used while open. -Capture & Mode (context) - -- `ZENML_CAPTURE_MODE`: default runtime mode from environment (`BATCH|REALTIME`). -- `ZENML_SERVING_CAPTURE_DEFAULT`: when present, serving defaults to `REALTIME` if capture is not set. - Notes -- Realtime outside serving logs a warning and continues (for local development). For production serving, run via the serving service. -- YAML/ENV can still set `capture: REALTIME|BATCH` for run configs; code paths are typed-only (`Capture`, `BatchCapture`, `RealtimeCapture`). +- Serving uses the Realtime runtime by default. Outside serving, batch runtime is used. ## Metrics & Observability @@ -83,4 +77,3 @@ Recommendation - Low cache hit rate: - Check step dependencies and cache TTL; ensure downstream steps run in the same process to benefit from warm cache. - diff --git a/docs/book/serving/overview.md b/docs/book/serving/overview.md index 7c802f90238..e1d4ef59f0d 100644 --- a/docs/book/serving/overview.md +++ b/docs/book/serving/overview.md @@ -7,8 +7,8 @@ title: Pipeline Serving Overview ## What Is Pipeline Serving? - Purpose: Expose a ZenML pipeline as a low-latency service (e.g., via FastAPI) that executes steps on incoming requests and returns results. -- Value: Production-grade orchestration with flexible capture policies to balance latency, observability, and lineage. -- Modes: Default batch-style execution, optimized realtime execution, and pure in-memory execution for maximum speed. +- Value: Production-grade orchestration with simple capture options to balance latency, observability, and lineage. +- Modes by context: Batch outside serving (blocking), Realtime in serving (async), and pure in-memory serving for maximum speed. ## Quick Start @@ -18,48 +18,31 @@ title: Pipeline Serving Overview 2) Choose capture only when you need to change defaults - You don’t need to set capture in most cases: - - Normal runs default to Batch. - - Serving defaults to Realtime (non-blocking). + - Batch (outside serving) is blocking. + - Serving is async by default. - Optional tweaks (typed API only): - - Low-latency, non-blocking (explicit): `@pipeline(capture=Capture())` - - Blocking realtime (serving): `@pipeline(capture=Capture(flush_on_step_end=True))` - - Pure in-memory (serving only): `@pipeline(capture=Capture(memory_only=True))` + - Make it explicit: `@pipeline(capture=Capture())` + - Pure in-memory serving: `@pipeline(capture=Capture(memory_only=True))` 3) Deploy the serving service with your preferred deployer and call the FastAPI endpoint. -## Capture Modes (Essentials) +## Capture Essentials -- BATCH (default) - - Behavior: Standard ZenML behavior (pipeline runs + step runs + artifacts + metadata/logs depending on config). - - Use when: Full lineage and strong consistency are required. +- Batch (outside serving) + - Blocking publishes; full persistence as configured by capture toggles. -- REALTIME - - Behavior: Optimized for latency and throughput. - - In-memory cache of artifact values within the same process. - - Async server updates by default; in serving, defaults to non-blocking responses (tracking finishes in background). - - Use when: You need low-latency serving with observability. +- Serving (inside serving) + - Async publishes by default with an in‑process cache; low latency. -- Memory-only (special case inside REALTIME) - - Behavior: Pure in-memory execution: - - No pipeline runs or step runs, no artifacts, no server calls. - - Steps exchange data in-process; response returns immediately. - - Use when: Maximum speed (prototyping, ultra-low-latency paths) without lineage. - - Note: Outside serving contexts, `memory_only=True` is ignored with a warning and standard execution proceeds. +- Memory-only (serving only) + - Pure in‑memory execution: no runs/steps/artifacts or server calls; maximum speed. + - Outside serving, `memory_only=True` is ignored with a warning. ## Where To Configure Capture - In code (typed only) - - `@pipeline(capture=Capture())` - - `@pipeline(capture=Capture(flush_on_step_end=False))` - -- In run config YAML -```yaml -capture: REALTIME # or BATCH -``` - -- Environment (fallbacks) - - `ZENML_CAPTURE_MODE=BATCH|REALTIME` - - Serving sets `ZENML_SERVING_CAPTURE_DEFAULT` internally to switch default to Realtime when capture is not set. + - `@pipeline(capture=Capture(...))` + - Options: `memory_only`, `code`, `logs`, `metadata`, `visualizations`, `metrics` ## Best Practices @@ -72,7 +55,7 @@ capture: REALTIME # or BATCH - Great for tests, benchmarks, or hot paths where lineage is not needed. - Compliance or rich lineage - - Use Batch (default in non-serving) or set: `@pipeline(capture=Capture(flush_on_step_end=True))`. + - Use Batch (outside serving) where publishes are blocking by default. ## FAQ (Essentials) @@ -81,7 +64,7 @@ capture: REALTIME # or BATCH - Memory-only (Realtime with `memory_only=True`): No; executes purely in memory. - Will serving block responses to flush tracking? - - Realtime in serving defaults to non-blocking (returns immediately), unless you set `flush_on_step_end=True`. + - No. Serving is async by default and returns immediately. - Is memory-only safe for production? - Yes for stateless, speed-critical paths. Note: No lineage or persisted artifacts. diff --git a/examples/serving/README.md b/examples/serving/README.md index c7e7062910a..ca5a8295741 100644 --- a/examples/serving/README.md +++ b/examples/serving/README.md @@ -1,338 +1,118 @@ # ZenML Pipeline Serving Examples -This directory contains examples demonstrating ZenML's new **run-only serving architecture** with millisecond-class latency for real-time inference and AI applications. +This directory contains examples that run pipelines as HTTP services using ZenML Serving. -## 🚀 **New Run-Only Architecture** +Highlights -ZenML Serving now automatically optimizes for performance: +- Async serving by default for low latency +- Optional memory-only execution via `Capture(memory_only=True)` +- Request parameter merging and streaming support -- **🏃‍♂️ Run-Only Mode**: Millisecond-class latency with zero DB/FS writes -- **🧠 Intelligent Switching**: Automatically chooses optimal execution mode -- **⚡ In-Memory Handoff**: Step outputs passed directly via serving buffer -- **🔄 Multi-Worker Safe**: ContextVar isolation for concurrent requests -- **📝 No Model Mutations**: Clean effective configuration merging +## Files -## 📁 Files +1. `weather_pipeline.py` – simple weather analysis +2. `chat_agent_pipeline.py` – conversational agent with streaming +3. `test_serving.py` – basic endpoint checks -1. **`weather_pipeline.py`** - Simple weather analysis with run-only optimization -2. **`chat_agent_pipeline.py`** - Streaming conversational AI with fast execution -3. **`test_serving.py`** - Test script to verify serving endpoints -4. **`README.md`** - This comprehensive guide +## Serving Modes (by context) -## 🎯 Examples Overview +- Batch (outside serving): blocking publishes; standard persistence +- Serving (default): async publishes with in‑process cache +- Memory-only (serving only): in‑process handoff; no DB/artifacts -### 1. Weather Agent Pipeline -- **Purpose**: Analyze weather for any city with AI recommendations -- **Mode**: Run-only optimization for millisecond response times -- **Features**: Automatic parameter injection, rule-based fallback -- **API**: Standard HTTP POST requests +## Quick Start: Weather Agent -### 2. Streaming Chat Agent Pipeline -- **Purpose**: Real-time conversational AI with streaming responses -- **Mode**: Run-only with optional streaming support -- **Features**: Token-by-token streaming, WebSocket support -- **API**: HTTP, WebSocket streaming, async jobs with SSE - -## 🏃‍♂️ **Run-Only vs Full Tracking** - -### Run-Only Mode (Default - Millisecond Latency) -```python -@pipeline # No capture settings = run-only mode -def fast_pipeline(city: str) -> str: - return analyze_weather(city) -``` - -**✅ Optimizations Active:** -- Zero database writes -- Zero filesystem operations -- In-memory step output handoff -- Per-request parameter injection -- Multi-worker safe execution - -### Full Tracking Mode (For Development) -```python -@pipeline(settings={"capture": "full"}) -def tracked_pipeline(city: str) -> str: - return analyze_weather(city) -``` - -**📊 Features Active:** -- Complete run/step tracking -- Artifact persistence -- Dashboard integration -- Debug information - -# 🚀 Quick Start Guide - -## Prerequisites - -```bash -# Install ZenML with serving support -pip install zenml - -# Optional: For LLM analysis (otherwise uses rule-based fallback) -export OPENAI_API_KEY=your_openai_api_key_here -pip install openai -``` - -## Example 1: Weather Agent (Run-Only Mode) - -### Step 1: Create and Deploy Pipeline +1) Create and deploy the pipeline ```bash python weather_pipeline.py ``` -**Expected Output:** -``` -🌤️ Creating Weather Agent Pipeline Deployment... -📦 Creating deployment for serving... -✅ Deployment ID: 12345678-1234-5678-9abc-123456789abc - -🚀 Start serving with: -export ZENML_PIPELINE_DEPLOYMENT_ID=12345678-1234-5678-9abc-123456789abc -python -m zenml.deployers.serving.app -``` - -### Step 2: Start Serving Service +2) Start the serving service ```bash -export ZENML_PIPELINE_DEPLOYMENT_ID=12345678-1234-5678-9abc-123456789abc +export ZENML_PIPELINE_DEPLOYMENT_ID= python -m zenml.deployers.serving.app ``` -**Service Configuration:** -- **Mode**: Run-only (millisecond latency) -- **Host**: `http://localhost:8000` -- **Optimizations**: All I/O operations bypassed - -### Step 3: Test Ultra-Fast Weather Analysis +3) Invoke ```bash -# Basic request (millisecond response time) curl -X POST "http://localhost:8000/invoke" \ -H "Content-Type: application/json" \ -d '{"parameters": {"city": "Paris"}}' - -# Response format: -{ - "success": true, - "outputs": { - "weather_analysis": "Weather in Paris is sunny with 22°C..." - }, - "execution_time": 0.003, # Milliseconds! - "metadata": { - "pipeline_name": "weather_agent_pipeline", - "parameters_used": {"city": "Paris"}, - "steps_executed": 3 - } -} -``` - -## Example 2: Streaming Chat Agent (Run-Only Mode) - -### Step 1: Create Chat Pipeline - -```bash -python chat_agent_pipeline.py -``` - -### Step 2: Start Serving Service - -```bash -export ZENML_PIPELINE_DEPLOYMENT_ID= -python -m zenml.deployers.serving.app -``` - -### Step 3: Test Ultra-Fast Chat - -#### Method A: Instant Response (Milliseconds) -```bash -curl -X POST "http://localhost:8000/invoke" \ - -H "Content-Type: application/json" \ - -d '{"parameters": {"message": "Hello!", "user_name": "Alice"}}' - -# Ultra-fast response: -{ - "success": true, - "outputs": {"chat_response": "Hello Alice! How can I help you today?"}, - "execution_time": 0.002 # Milliseconds! -} -``` - -#### Method B: Streaming Mode (Optional) -```bash -# Create async job -JOB_ID=$(curl -X POST 'http://localhost:8000/invoke?mode=async' \ - -H 'Content-Type: application/json' \ - -d '{"parameters": {"message": "Tell me about AI", "enable_streaming": true}}' \ - | jq -r .job_id) - -# Stream real-time results -curl -N "http://localhost:8000/stream/$JOB_ID" ``` -#### Method C: WebSocket Streaming -```bash -# Install wscat: npm install -g wscat -wscat -c ws://localhost:8000/stream - -# Send message: -{"parameters": {"message": "Hi there!", "user_name": "Alice", "enable_streaming": true}} -``` - -## 📊 Performance Comparison - -| Feature | Run-Only Mode | Full Tracking | -|---------|---------------|---------------| -| **Response Time** | 1-5ms | 100-500ms | -| **Throughput** | 1000+ RPS | 10-50 RPS | -| **Memory Usage** | Minimal | Standard | -| **DB Operations** | Zero | Full tracking | -| **FS Operations** | Zero | Artifact storage | -| **Use Cases** | Production serving | Development/debug | +Service defaults -## 🛠️ Advanced Configuration +- Host: `http://localhost:8000` +- Serving: async by default -### Performance Tuning +## Configuration ```bash -# Set capture mode explicitly -export ZENML_SERVING_CAPTURE_DEFAULT=none # Run-only mode - -# Multi-worker deployment -export ZENML_SERVICE_WORKERS=4 +export ZENML_PIPELINE_DEPLOYMENT_ID= python -m zenml.deployers.serving.app ``` -### Override Modes Per Request - -```bash -# Force tracking for a single request (slower but tracked) -curl -X POST "http://localhost:8000/invoke" \ - -H "Content-Type: application/json" \ - -d '{ - "parameters": {"city": "Tokyo"}, - "capture_override": {"mode": "full"} - }' -``` - -### Monitor Performance - -```bash -# Service health and performance -curl http://localhost:8000/health -curl http://localhost:8000/metrics - -# Pipeline information -curl http://localhost:8000/info -``` +To enable memory-only mode, set it in code: -## 🏗️ Architecture Deep Dive - -### Run-Only Execution Flow +```python +from zenml import pipeline +from zenml.capture.config import Capture -``` -Request → ServingOverrides → Effective Config → StepRunner → ServingBuffer → Response - (Parameters) (No mutations) (No I/O) (In-memory) (JSON) +@pipeline(capture=Capture(memory_only=True)) +def serve_max_speed(...): + ... ``` -1. **Request Arrives**: JSON parameters received -2. **ServingOverrides**: Per-request parameter injection via ContextVar -3. **Effective Config**: Runtime configuration merging (no model mutations) -4. **Step Execution**: Direct execution with serving buffer storage -5. **Response Building**: Only declared outputs returned as JSON +## Execution Flow (serving) -### Key Components +Request → Parameter merge → StepRunner → Response -- **`ServingOverrides`**: Thread-safe parameter injection -- **`ServingBuffer`**: In-memory step output handoff -- **Effective Configuration**: Runtime config merging without mutations -- **ContextVar Isolation**: Multi-worker safe execution +- Parameters under `parameters` are merged into step config. +- Serving is async; background updates do not block the response. -## 📚 API Reference +## API Reference -### Core Endpoints +Core endpoints -| Endpoint | Method | Purpose | Performance | -|----------|---------|---------|-------------| -| `/invoke` | POST | Execute pipeline | Milliseconds | -| `/health` | GET | Service health | Instant | -| `/info` | GET | Pipeline schema | Instant | -| `/metrics` | GET | Performance stats | Instant | +| Endpoint | Method | Purpose | +|----------|--------|---------| +| `/invoke` | POST | Execute pipeline (sync or async) | +| `/health` | GET | Service health | +| `/info` | GET | Pipeline schema & deployment info | +| `/metrics` | GET | Runtime metrics (if enabled) | +| `/jobs`, `/jobs/{id}` | GET | Manage async jobs | +| `/stream/{id}` | GET | Server‑Sent Events stream | -### Request Format +Request format ```json { "parameters": { - "city": "string", - "temperature": "number", - "enable_streaming": "boolean" - }, - "capture_override": { - "mode": "none|metadata|full" + "city": "string" } } ``` -### Response Format +## Troubleshooting -```json -{ - "success": true, - "outputs": { - "output_name": "output_value" - }, - "execution_time": 0.003, - "metadata": { - "pipeline_name": "string", - "parameters_used": {}, - "steps_executed": 0 - } -} -``` - -## 🔧 Troubleshooting +- Missing deployment ID: set `ZENML_PIPELINE_DEPLOYMENT_ID`. +- Slow responses: serving is async by default; for prototypes consider `Capture(memory_only=True)`. +- Monitor: use `/metrics` for queue depth, cache hit rate, and latencies. -### Performance Issues -- ✅ **Ensure run-only mode**: No capture settings or `capture="none"` -- ✅ **Check environment**: `ZENML_SERVING_CAPTURE_DEFAULT=none` -- ✅ **Monitor metrics**: Use `/metrics` endpoint - -### Common Problems -- **Slow responses**: Verify run-only mode is active -- **Import errors**: Run-only mode bypasses unnecessary integrations -- **Memory leaks**: Serving contexts auto-cleared per request -- **Multi-worker issues**: ContextVar provides thread isolation - -### Debug Mode -```bash -# Enable full tracking for debugging -curl -X POST "http://localhost:8000/invoke" \ - -d '{"parameters": {...}, "capture_override": {"mode": "full"}}' -``` - -## 🎯 Production Deployment - -### Docker Example +## Docker ```dockerfile -FROM python:3.9-slim - -# Install ZenML +FROM python:3.11-slim RUN pip install zenml - -# Set serving configuration -ENV ZENML_SERVING_CAPTURE_DEFAULT=none ENV ZENML_SERVICE_HOST=0.0.0.0 ENV ZENML_SERVICE_PORT=8000 - -# Start serving CMD ["python", "-m", "zenml.deployers.serving.app"] ``` -### Kubernetes Example +## Kubernetes (snippet) ```yaml apiVersion: apps/v1 @@ -340,7 +120,7 @@ kind: Deployment metadata: name: zenml-serving spec: - replicas: 3 + replicas: 2 template: spec: containers: @@ -349,18 +129,7 @@ spec: env: - name: ZENML_PIPELINE_DEPLOYMENT_ID value: "your-deployment-id" - - name: ZENML_SERVING_CAPTURE_DEFAULT - value: "none" ports: - containerPort: 8000 ``` -## 🚀 Next Steps - -1. **Deploy Examples**: Try both weather and chat examples -2. **Measure Performance**: Use the `/metrics` endpoint -3. **Scale Up**: Deploy with multiple workers -4. **Monitor**: Integrate with your observability stack -5. **Optimize**: Fine-tune capture policies for your use case - -The new run-only architecture delivers production-ready performance for real-time AI applications! 🎉 \ No newline at end of file diff --git a/examples/serving/weather_pipeline.py b/examples/serving/weather_pipeline.py index f6427322b77..910b4257fae 100644 --- a/examples/serving/weather_pipeline.py +++ b/examples/serving/weather_pipeline.py @@ -18,7 +18,6 @@ from typing import Dict from zenml import pipeline, step -from zenml.capture.config import Capture from zenml.config import DockerSettings # Import enums for type-safe capture mode configuration @@ -100,6 +99,7 @@ def analyze_weather_with_llm(weather_data: Dict[str, float], city: str) -> str: and returns analysis with no database or filesystem writes. """ import time + temp = weather_data["temperature"] humidity = weather_data["humidity"] wind = weather_data["wind_speed"] @@ -221,7 +221,6 @@ def analyze_weather_with_llm(weather_data: Dict[str, float], city: str) -> str: @pipeline( - capture=Capture(memory_only=True), on_init=init_hook, settings={ "docker": docker_settings, @@ -274,7 +273,7 @@ def weather_agent_pipeline(city: str = "London") -> str: # Create deployment without running deployment = weather_agent_pipeline._create_deployment() - weather_agent_pipeline() + # weather_agent_pipeline() print("\n✅ Pipeline deployed for run-only serving!") print(f"📋 Deployment ID: {deployment.id}") diff --git a/src/zenml/capture/config.py b/src/zenml/capture/config.py index 243dc7a5e10..16e36dce1d1 100644 --- a/src/zenml/capture/config.py +++ b/src/zenml/capture/config.py @@ -11,202 +11,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing # permissions and limitations under the License. -"""Capture configuration for ZenML.""" +"""Capture configuration for ZenML (single, typed).""" -import os from dataclasses import dataclass -from enum import Enum -from typing import Any, Dict, Optional, Union - - -class ServingMode(str, Enum): - """Serving mode enum.""" - - BATCH = "BATCH" - REALTIME = "REALTIME" - - -# Backwards-compat alias used by runtime factory and others -CaptureMode = ServingMode @dataclass(frozen=True) class Capture: - """Unified capture configuration with simple, typed options. + """Single capture configuration. - Modes are inferred by context: - - Orchestrated runs default to batch semantics. - - Serving defaults to realtime semantics. + Semantics are derived from context: + - Batch (orchestrated) runs use blocking publishes. + - Serving uses async publishes; `memory_only` switches to in-process handoff. - Options allow tuning behavior without exposing modes directly. + Only observability toggles are exposed; they never affect dataflow except + `memory_only`, which is serving-only and ignored elsewhere. """ - # If True, block at step end to publish updates (serving only). - flush_on_step_end: bool | None = None - # If True, pure in-memory execution (serving only). + # Serving-only: run without DB/artifact persistence using in-process handoff memory_only: bool = False - # If False, skip doc/source capture in metadata. + # Observability toggles code: bool = True - - def to_config_value(self) -> Dict[str, Any]: - """Convert the capture options to a config value. - - Returns: - The config value (no explicit mode; inferred by environment). - """ - cfg: Dict[str, Any] = {"code": self.code} - if self.flush_on_step_end is not None: - cfg["flush_on_step_end"] = bool(self.flush_on_step_end) - if self.memory_only: - cfg["memory_only"] = True - return cfg - - -@dataclass(frozen=True) -class BatchCapture: - """Batch (synchronous) capture configuration. - - Runs/steps and artifacts are always captured synchronously. Users should - adjust logging/metadata/visualization via pipeline settings, not capture. - """ - - mode: ServingMode = ServingMode.BATCH - - def to_config_value(self) -> Dict[str, Any]: - """Convert the batch capture to a config value.""" - return {"mode": self.mode.value} - - -@dataclass(frozen=True) -class RealtimeCapture: - """Realtime capture configuration for serving. - - - flush_on_step_end: if True, block at step end to publish updates. - - memory_only: if True, no server calls/runs/artifacts; in-process handoff. - """ - - mode: ServingMode = ServingMode.REALTIME - flush_on_step_end: bool = False - memory_only: bool = False - - def to_config_value(self) -> Dict[str, Any]: - """Convert the realtime capture to a config value. - - Returns: - The config value. - """ - config: Dict[str, Any] = {"mode": self.mode.value} - # Represent semantics using existing keys consumed by launcher/factory - config["flush_on_step_end"] = self.flush_on_step_end - if self.memory_only: - config["memory_only"] = True - return config - - def __post_init__(self) -> None: - """Post init.""" - # Contradictory: memory-only implies no server operations to flush - if self.memory_only and self.flush_on_step_end: - raise ValueError( - "Contradictory options: memory_only=True with flush_on_step_end=True. " - "Memory-only mode has no server operations to flush." - ) - - -# Unified capture config type alias -CaptureConfig = Union[Capture, BatchCapture, RealtimeCapture] - - -class CapturePolicy: - """Runtime-level capture policy used to select and configure runtimes. - - Provides a common interface for StepLauncher / factory code while the - code-level API remains typed (`Capture`, `BatchCapture`, `RealtimeCapture`). - """ - - def __init__( - self, - mode: ServingMode = ServingMode.BATCH, - options: Optional[Dict[str, Any]] = None, - ) -> None: - """Initialize the capture policy. - - Args: - mode: The mode to use. - options: The options to use. - """ - self.mode = mode - self.options = options or {} - - @staticmethod - def from_env() -> "CapturePolicy": - """Create a capture policy from environment defaults. - - Returns: - The capture policy. - """ - if os.getenv("ZENML_SERVING_CAPTURE_DEFAULT") is not None: - return CapturePolicy(mode=ServingMode.REALTIME, options={}) - val = os.getenv("ZENML_CAPTURE_MODE", "BATCH").upper() - try: - mode = ServingMode(val) - except ValueError: - mode = ServingMode.BATCH - return CapturePolicy(mode=mode, options={}) - - @staticmethod - def from_value( - value: Optional[Union[str, Capture, BatchCapture, RealtimeCapture]], - ) -> "CapturePolicy": - """Normalize typed or string capture value into a runtime policy. - - Args: - value: The value to normalize. - - Returns: - The capture policy. - """ - if value is None: - return CapturePolicy.from_env() - - if isinstance(value, RealtimeCapture): - return CapturePolicy( - mode=ServingMode.REALTIME, - options={ - "flush_on_step_end": value.flush_on_step_end, - "memory_only": bool(value.memory_only), - }, - ) - if isinstance(value, BatchCapture): - return CapturePolicy(mode=ServingMode.BATCH, options={}) - if isinstance(value, Capture): - pol = CapturePolicy.from_env() - opts: Dict[str, Any] = {} - if value.flush_on_step_end is not None: - opts["flush_on_step_end"] = bool(value.flush_on_step_end) - if value.memory_only: - opts["memory_only"] = True - if value.code is not None: - opts["code"] = bool(value.code) - pol.options.update(opts) - return pol - # String fallback (YAML / ENV) - try: - return CapturePolicy(mode=ServingMode(str(value).upper())) - except Exception: - return CapturePolicy.from_env() - - def get_option(self, key: str, default: Any = None) -> Any: - """Get an option from the capture policy. - - Args: - key: The key to get. - default: The default value to return if the key is not found. - - Returns: - The option value. - """ - return self.options.get(key, default) - - -# capture_to_config_value has been removed from code paths. Downstream consumers -# should use typed configs or CapturePolicy.from_value for YAML/ENV strings. + logs: bool = True + metadata: bool = True + visualizations: bool = True + metrics: bool = True diff --git a/src/zenml/config/compiler.py b/src/zenml/config/compiler.py index ca31bbf52ea..eca21ef4b5d 100644 --- a/src/zenml/config/compiler.py +++ b/src/zenml/config/compiler.py @@ -23,11 +23,10 @@ Mapping, Optional, Tuple, - Union, ) from zenml import __version__ -from zenml.capture.config import BatchCapture, RealtimeCapture +from zenml.capture.config import Capture from zenml.config.base_settings import BaseSettings, ConfigurationLevel from zenml.config.pipeline_configurations import PipelineConfiguration from zenml.config.pipeline_run_configuration import PipelineRunConfiguration @@ -152,6 +151,26 @@ def compile( pipeline_spec=pipeline_spec, ) + # Populate canonical capture fields from typed pipeline configuration + cap: Optional[Capture] = pipeline.configuration.capture + mem_only = bool(getattr(cap, "memory_only", False)) if cap else False + code = bool(getattr(cap, "code", True)) if cap else True + logs = bool(getattr(cap, "logs", True)) if cap else True + metadata_enabled = ( + bool(getattr(cap, "metadata", True)) if cap else True + ) + visuals = bool(getattr(cap, "visualizations", True)) if cap else True + metrics = bool(getattr(cap, "metrics", True)) if cap else True + try: + setattr(deployment, "capture_memory_only", mem_only) + setattr(deployment, "capture_code", code) + setattr(deployment, "capture_logs", logs) + setattr(deployment, "capture_metadata", metadata_enabled) + setattr(deployment, "capture_visualizations", visuals) + setattr(deployment, "capture_metrics", metrics) + except Exception: + pass + logger.debug("Compiled pipeline deployment: %s", deployment) return deployment @@ -198,42 +217,13 @@ def _apply_run_configuration( config: The run configurations. """ with pipeline.__suppress_configure_warnings__(): - # Normalize run-level capture (str/dict) to typed for configure - cap_typed: Optional[Union[BatchCapture, RealtimeCapture]] = None - if isinstance(config.capture, str): - if config.capture.upper() == "REALTIME": - from zenml.capture.config import RealtimeCapture - - cap_typed = RealtimeCapture() - elif config.capture.upper() == "BATCH": - from zenml.capture.config import BatchCapture - - cap_typed = BatchCapture() - elif isinstance(config.capture, dict): - mode = str(config.capture.get("mode", "BATCH")).upper() - if mode == "REALTIME": - from zenml.capture.config import RealtimeCapture - - cap_typed = RealtimeCapture( - flush_on_step_end=bool( - config.capture.get("flush_on_step_end", False) - ), - memory_only=bool( - config.capture.get("memory_only", False) - ), - ) - else: - from zenml.capture.config import BatchCapture - - cap_typed = BatchCapture() - pipeline.configure( enable_cache=config.enable_cache, enable_artifact_metadata=config.enable_artifact_metadata, enable_artifact_visualization=config.enable_artifact_visualization, enable_step_logs=config.enable_step_logs, enable_pipeline_logs=config.enable_pipeline_logs, - capture=cap_typed, + capture=config.capture, settings=config.settings, tags=config.tags, extra=config.extra, diff --git a/src/zenml/config/pipeline_configurations.py b/src/zenml/config/pipeline_configurations.py index 7a8b0726e53..f86382f884a 100644 --- a/src/zenml/config/pipeline_configurations.py +++ b/src/zenml/config/pipeline_configurations.py @@ -18,12 +18,7 @@ from pydantic import SerializeAsAny, field_validator -from zenml.capture.config import ( - BatchCapture, - Capture, - CaptureConfig, - RealtimeCapture, -) +from zenml.capture.config import Capture from zenml.config.constants import DOCKER_SETTINGS_KEY, RESOURCE_SETTINGS_KEY from zenml.config.retry_config import StepRetryConfig from zenml.config.source import SourceWithValidator @@ -48,8 +43,8 @@ class PipelineConfigurationUpdate(StrictBaseModel): enable_artifact_visualization: Optional[bool] = None enable_step_logs: Optional[bool] = None enable_pipeline_logs: Optional[bool] = None - # Capture policy for execution semantics (typed only) - capture: Optional[CaptureConfig] = None + # Capture configuration (typed only) + capture: Optional[Capture] = None settings: Dict[str, SerializeAsAny[BaseSettings]] = {} tags: Optional[List[Union[str, "Tag"]]] = None extra: Dict[str, Any] = {} @@ -96,16 +91,14 @@ class PipelineConfiguration(PipelineConfigurationUpdate): @field_validator("capture") @classmethod def validate_capture_mode( - cls, value: Optional[CaptureConfig] - ) -> Optional[CaptureConfig]: + cls, value: Optional[Capture] + ) -> Optional[Capture]: """Validates the capture config (typed only).""" if value is None: return value - if isinstance(value, (Capture, BatchCapture, RealtimeCapture)): + if isinstance(value, Capture): return value - raise ValueError( - "'capture' must be a typed Capture, BatchCapture, or RealtimeCapture." - ) + raise ValueError("'capture' must be a typed Capture.") @field_validator("name") @classmethod diff --git a/src/zenml/config/pipeline_run_configuration.py b/src/zenml/config/pipeline_run_configuration.py index 33e72d60682..18eb15ea403 100644 --- a/src/zenml/config/pipeline_run_configuration.py +++ b/src/zenml/config/pipeline_run_configuration.py @@ -18,6 +18,7 @@ from pydantic import Field, SerializeAsAny +from zenml.capture.config import Capture from zenml.config.base_settings import BaseSettings from zenml.config.retry_config import StepRetryConfig from zenml.config.schedule import Schedule @@ -41,8 +42,8 @@ class PipelineRunConfiguration( enable_artifact_visualization: Optional[bool] = None enable_step_logs: Optional[bool] = None enable_pipeline_logs: Optional[bool] = None - # Optional override for capture per run: mode string or dict with options - capture: Optional[Union[str, Dict[str, Any]]] = None + # Optional typed capture override per run (no dicts/strings) + capture: Optional[Capture] = None schedule: Optional[Schedule] = None build: Union[PipelineBuildBase, UUID, None] = Field( default=None, union_mode="left_to_right" diff --git a/src/zenml/deployers/serving/service.py b/src/zenml/deployers/serving/service.py index f3384d4797e..254db972fa5 100644 --- a/src/zenml/deployers/serving/service.py +++ b/src/zenml/deployers/serving/service.py @@ -19,19 +19,31 @@ """ import asyncio +import inspect import json import os import time +import traceback from datetime import datetime, timedelta, timezone from typing import Any, Dict, List, Optional from uuid import UUID, uuid4 +import numpy as np + from zenml.client import Client from zenml.integrations.registry import integration_registry from zenml.logger import get_logger from zenml.models import PipelineDeploymentResponse -from zenml.orchestrators import utils as orchestrator_utils from zenml.orchestrators.topsort import topsorted_layers +from zenml.orchestrators.utils import ( + extract_return_contract, + is_tracking_disabled, + response_tap_clear, + response_tap_get_all, + set_pipeline_state, + set_return_targets, + set_serving_context, +) from zenml.stack import Stack from zenml.utils import source_utils @@ -128,8 +140,6 @@ async def initialize(self) -> None: except Exception as e: logger.error(f"❌ Failed to initialize service: {str(e)}") logger.error(f" Error type: {type(e)}") - import traceback - logger.error(f" Traceback: {traceback.format_exc()}") raise @@ -193,10 +203,6 @@ def _extract_parameter_schema(self) -> Dict[str, Any]: self.deployment.pipeline_configuration, "spec", None ) if pipeline_spec and getattr(pipeline_spec, "source", None): - import inspect - - from zenml.utils import source_utils - # Load the pipeline function pipeline_func = source_utils.load(pipeline_spec.source) @@ -287,16 +293,12 @@ def _serialize_for_json(self, value: Any) -> Any: JSON-serializable representation of the value """ try: - import json - # Handle common ML types that aren't JSON serializable if hasattr(value, "tolist"): # numpy arrays, pandas Series return value.tolist() elif hasattr(value, "to_dict"): # pandas DataFrames return value.to_dict() elif hasattr(value, "__array__"): # numpy-like arrays - import numpy as np - return np.asarray(value).tolist() # Test if it's already JSON serializable @@ -323,7 +325,7 @@ async def execute_pipeline( logger.info("Starting pipeline execution") # Set up response capture - orchestrator_utils.response_tap_clear() + response_tap_clear() self._setup_return_targets() try: @@ -334,20 +336,16 @@ async def execute_pipeline( resolved_params ) # Expose pipeline state via serving context var - from zenml.orchestrators import utils as _orch_utils - - _orch_utils.set_pipeline_state(self.pipeline_state) + set_pipeline_state(self.pipeline_state) # Get deployment and check if we're in no-capture mode deployment = self.deployment - _ = orchestrator_utils.is_tracking_disabled( + _ = is_tracking_disabled( deployment.pipeline_configuration.settings ) - original_capture_default = os.environ.get( - "ZENML_SERVING_CAPTURE_DEFAULT" - ) - os.environ["ZENML_SERVING_CAPTURE_DEFAULT"] = "none" + # Mark serving context for the orchestrator/launcher + set_serving_context(True) # Build execution order using the production-tested topsort utility steps = deployment.step_configurations @@ -395,18 +393,12 @@ async def execute_pipeline( finally: orchestrator._cleanup_run() - # Restore original capture default environment variable - if original_capture_default is None: - os.environ.pop("ZENML_SERVING_CAPTURE_DEFAULT", None) - else: - os.environ["ZENML_SERVING_CAPTURE_DEFAULT"] = ( - original_capture_default - ) + # Clear serving context marker + set_serving_context(False) # Clear request params env and shared runtime state os.environ.pop("ZENML_SERVING_REQUEST_PARAMS", None) - from zenml.orchestrators.utils import set_pipeline_state - set_pipeline_state(None) + # No per-request capture override to clear try: from zenml.orchestrators.runtime_manager import ( clear_shared_runtime, @@ -419,7 +411,7 @@ async def execute_pipeline( pass # Get captured outputs from response tap - outputs = orchestrator_utils.response_tap_get_all() + outputs = response_tap_get_all() execution_time = time.time() - start self._update_execution_stats(True, execution_time) @@ -458,7 +450,7 @@ async def execute_pipeline( } finally: # Clean up response tap - orchestrator_utils.response_tap_clear() + response_tap_clear() async def submit_pipeline( self, @@ -597,7 +589,7 @@ def _setup_return_targets(self) -> None: else None ) contract = ( - orchestrator_utils.extract_return_contract(pipeline_source) + extract_return_contract(pipeline_source) if pipeline_source else None ) @@ -637,12 +629,12 @@ def _setup_return_targets(self) -> None: ) logger.debug(f"Return targets: {return_targets}") - orchestrator_utils.set_return_targets(return_targets) + set_return_targets(return_targets) except Exception as e: logger.warning(f"Failed to setup return targets: {e}") # Set empty targets as fallback - orchestrator_utils.set_return_targets({}) + set_return_targets({}) def is_healthy(self) -> bool: """Check if the service is healthy and ready to serve requests. diff --git a/src/zenml/execution/factory.py b/src/zenml/execution/factory.py index 7d5a5f00275..53f6f07c2f3 100644 --- a/src/zenml/execution/factory.py +++ b/src/zenml/execution/factory.py @@ -11,92 +11,43 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing # permissions and limitations under the License. -"""Factory to construct a step runtime based on capture policy.""" +"""Factory to construct a step runtime based on context and capture.""" -from typing import Callable, Dict, Optional - -from zenml.capture.config import CaptureMode, CapturePolicy from zenml.execution.step_runtime import ( BaseStepRuntime, DefaultStepRuntime, MemoryStepRuntime, ) -# Registry of runtime builders keyed by capture mode -_RUNTIME_REGISTRY: Dict[ - CaptureMode, Callable[[CapturePolicy], BaseStepRuntime] -] = {} - - -def register_runtime( - mode: CaptureMode, builder: Callable[[CapturePolicy], BaseStepRuntime] -) -> None: - """Register a runtime builder for a capture mode.""" - _RUNTIME_REGISTRY[mode] = builder - - -def get_runtime(policy: Optional[CapturePolicy]) -> BaseStepRuntime: - """Return a runtime implementation using the registry. - - Falls back to the default runtime if no builder is registered. - - Args: - policy: The capture policy. - - Returns: - The runtime implementation. - """ - policy = policy or CapturePolicy() - builder = _RUNTIME_REGISTRY.get(policy.mode) - if builder is None: - raise ValueError( - f"No runtime registered for capture mode: {policy.mode}. " - "Expected one of: " - + ", ".join(m.name for m in _RUNTIME_REGISTRY.keys()) - ) - return builder(policy) - -# Register default builders -def _build_default(_: CapturePolicy) -> BaseStepRuntime: - """Build the default runtime. +def get_runtime( + *, serving: bool, memory_only: bool, metrics_enabled: bool = True +) -> BaseStepRuntime: + """Return a runtime implementation for the given context. Args: - policy: The capture policy. + serving: True if executing in serving context. + memory_only: True if serving should use in-process handoff. + metrics_enabled: Enable runtime metrics collection (realtime only). Returns: The runtime implementation. """ - return DefaultStepRuntime() - - -def _build_realtime(policy: CapturePolicy) -> BaseStepRuntime: - """Build the realtime runtime. - - Args: - policy: The capture policy. + if not serving: + return DefaultStepRuntime() + if memory_only: + return MemoryStepRuntime() - Returns: - The runtime implementation. - """ # Import here to avoid circular imports from zenml.execution.realtime_runtime import RealtimeStepRuntime - # If memory_only flagged, or legacy runs/persistence indicate memory-only, use memory runtime - memory_only = bool(policy.get_option("memory_only", False)) - runs_opt = str(policy.get_option("runs", "on")).lower() - persistence = str(policy.get_option("persistence", "async")).lower() - if ( - memory_only - or runs_opt in {"off", "false", "0"} - or persistence in {"memory", "off"} - ): - return MemoryStepRuntime() - - ttl = policy.get_option("ttl_seconds") - max_entries = policy.get_option("max_entries") - return RealtimeStepRuntime(ttl_seconds=ttl, max_entries=max_entries) - - -register_runtime(CaptureMode.BATCH, _build_default) -register_runtime(CaptureMode.REALTIME, _build_realtime) + rt = RealtimeStepRuntime() + # Gate metrics at the runtime if supported + if not metrics_enabled: + try: + setattr( + rt, "_metrics_disabled", True + ) # runtime may optionally read this + except Exception: + pass + return rt diff --git a/src/zenml/execution/realtime_runtime.py b/src/zenml/execution/realtime_runtime.py index 83c9ae69580..bee8b16712e 100644 --- a/src/zenml/execution/realtime_runtime.py +++ b/src/zenml/execution/realtime_runtime.py @@ -40,7 +40,14 @@ class RealtimeStepRuntime(DefaultStepRuntime): - """Realtime runtime optimized for low-latency loads via memory cache.""" + """Realtime runtime optimized for low-latency loads via memory cache. + + TODO(beta->prod): scale background publishing either by + - adding a small multi-worker thread pool (ThreadPoolExecutor), or + - migrating to an asyncio-based runtime once the client/publish calls have + async variants and we want an async mode in serving. + Both paths must keep bounded backpressure and orderly shutdown. + """ def __init__( self, @@ -59,7 +66,11 @@ def __init__( self._lock = threading.RLock() # Event queue: (kind, args, kwargs) Event = Tuple[str, Tuple[Any, ...], Dict[str, Any]] - self._q: queue.Queue[Event] = queue.Queue() + self._q: queue.Queue[Event] = queue.Queue(maxsize=1024) + # TODO(beta->prod): when scaling per-process publishing, prefer either + # (1) a small thread pool consuming this queue, or (2) an asyncio loop + # with an asyncio.Queue and async workers, once the client has async + # publish calls and we opt into async serving. self._worker: Optional[threading.Thread] = None self._stop = threading.Event() self._errors_since_last_flush: int = 0 @@ -74,6 +85,8 @@ def __init__( self._cache_hits: int = 0 self._cache_misses: int = 0 self._op_latencies: List[float] = [] + # TODO(beta->prod): add process memory monitoring and expose worker + # liveness/health at the service layer. # Tunables via env: TTL seconds and max entries # Options precedence: explicit args > env > defaults if ttl_seconds is not None: @@ -119,8 +132,8 @@ def __init__( ) except Exception: self._err_report_interval = 15.0 - # Flush behavior (can be disabled for serving non-blocking) - self._flush_on_step_end: bool = True + # Serving is async by default (non-blocking) + self._flush_on_step_end: bool = False # --- lifecycle --- def start(self) -> None: @@ -390,12 +403,25 @@ def publish_failed_step_run( if self._should_process_inline(): publish_utils.publish_failed_step_run(step_run_id) return - self._q.put(("step_failed", (), {"step_run_id": step_run_id})) - with self._lock: - self._queued_count += 1 + try: + self._q.put_nowait( + ("step_failed", (), {"step_run_id": step_run_id}) + ) + with self._lock: + self._queued_count += 1 + except queue.Full: + self._logger.debug("Queue full, processing step_failed inline") + try: + publish_utils.publish_failed_step_run(step_run_id) + except Exception as e: + self._logger.warning("Inline processing failed: %s", e) def flush(self) -> None: - """Flush the realtime runtime by draining queued events synchronously.""" + """Flush the realtime runtime by draining queued events synchronously. + + Raises: + RuntimeError: If background errors were encountered while draining. + """ # Drain the queue in the calling thread to avoid waiting on the worker while True: try: @@ -449,7 +475,10 @@ def on_step_end(self) -> None: return def shutdown(self) -> None: - """Shutdown the realtime runtime.""" + """Shutdown the realtime runtime. + + TODO(beta->prod): expose worker liveness/health signals to the service. + """ # Wait for remaining tasks and stop self.flush() self._stop.set() @@ -486,6 +515,10 @@ def get_metrics(self) -> Dict[str, Any]: Returns: The runtime metrics snapshot. """ + # TODO(beta->prod): export to an external sink (e.g., Prometheus) and + # expand with additional histograms / event counters as needed. + if bool(getattr(self, "_metrics_disabled", False)): + return {} with self._lock: queued = self._queued_count processed = self._processed_count @@ -545,27 +578,28 @@ def check_async_errors(self) -> None: # --- internal helpers --- def _sweep_expired(self) -> None: - """Remove expired entries from the head (LRU) side with a time budget.""" - deadline = time.time() + 0.003 + """Remove expired entries using a snapshot within a small time budget.""" + deadline = time.time() + 0.005 with self._lock: - while time.time() < deadline: - try: - key = next(iter(self._cache)) - except StopIteration: - break - try: - _, expires_at = self._cache[key] - except KeyError: - continue - if time.time() <= expires_at: - break - try: - del self._cache[key] - except KeyError: - pass + snapshot = list(self._cache.items()) + expired: List[str] = [] + now = time.time() + for key, (_val, expires_at) in snapshot: + if time.time() > deadline: + break + if now > expires_at: + expired.append(key) + if expired: + with self._lock: + for key in expired: + self._cache.pop(key, None) def _should_process_inline(self) -> bool: - """Return True if circuit breaker is open and we should publish inline.""" + """Return True if circuit breaker is open and we should publish inline. + + Returns: + True if inline processing should be used, False otherwise. + """ with self._lock: now = time.time() if now < self._cb_open_until_ts: diff --git a/src/zenml/execution/step_runtime.py b/src/zenml/execution/step_runtime.py index ca8bb18202e..ba1fd8924ea 100644 --- a/src/zenml/execution/step_runtime.py +++ b/src/zenml/execution/step_runtime.py @@ -21,13 +21,32 @@ """ import threading +import time from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type from uuid import UUID +from zenml.artifacts.unmaterialized_artifact import ( + UnmaterializedArtifact, +) from zenml.client import Client +from zenml.enums import ArtifactSaveType +from zenml.logger import get_logger +from zenml.materializers.base_materializer import BaseMaterializer +from zenml.materializers.materializer_registry import materializer_registry from zenml.models import ArtifactVersionResponse +# Note: avoid importing zenml.orchestrators modules at import time to prevent +# circular dependencies. Where needed, import locally within methods. +from zenml.steps.step_context import get_step_context +from zenml.utils import ( + materializer_utils, + source_utils, + string_utils, + tag_utils, +) +from zenml.utils.typing_utils import get_origin, is_union + if TYPE_CHECKING: from zenml.artifact_stores import BaseArtifactStore from zenml.config.step_configurations import Step @@ -37,6 +56,8 @@ from zenml.stack import Stack from zenml.steps.utils import OutputSignature +logger = get_logger(__name__) + class BaseStepRuntime(ABC): """Abstract execution-time interface for step I/O and interactions. @@ -125,6 +146,7 @@ def compute_cache_key( Returns: The computed cache key. """ + # Local import to avoid circular import issues from zenml.orchestrators import cache_utils return cache_utils.generate_cache_key( @@ -147,6 +169,7 @@ def get_cached_step_run( Returns: The cached step run if available, otherwise None. """ + # Local import to avoid circular import issues from zenml.orchestrators import cache_utils return cache_utils.get_cached_step_run(cache_key=cache_key) @@ -167,11 +190,13 @@ def publish_pipeline_run_metadata( pipeline_run_id: The pipeline run ID. pipeline_run_metadata: The pipeline run metadata. """ + if not bool(getattr(self, "_metadata_enabled", True)): + return from zenml.orchestrators.publish_utils import ( - publish_pipeline_run_metadata, + publish_pipeline_run_metadata as _pub_run_md, ) - publish_pipeline_run_metadata( + _pub_run_md( pipeline_run_id=pipeline_run_id, pipeline_run_metadata=pipeline_run_metadata, ) @@ -185,9 +210,13 @@ def publish_step_run_metadata( step_run_id: The step run ID. step_run_metadata: The step run metadata. """ - from zenml.orchestrators.publish_utils import publish_step_run_metadata + if not bool(getattr(self, "_metadata_enabled", True)): + return + from zenml.orchestrators.publish_utils import ( + publish_step_run_metadata as _pub_step_md, + ) - publish_step_run_metadata( + _pub_step_md( step_run_id=step_run_id, step_run_metadata=step_run_metadata ) @@ -201,10 +230,10 @@ def publish_successful_step_run( output_artifact_ids: The output artifact IDs. """ from zenml.orchestrators.publish_utils import ( - publish_successful_step_run, + publish_successful_step_run as _pub_step_success, ) - publish_successful_step_run( + _pub_step_success( step_run_id=step_run_id, output_artifact_ids=output_artifact_ids ) @@ -214,9 +243,11 @@ def publish_failed_step_run(self, *, step_run_id: Any) -> None: Args: step_run_id: The step run ID. """ - from zenml.orchestrators.publish_utils import publish_failed_step_run + from zenml.orchestrators.publish_utils import ( + publish_failed_step_run as _pub_step_failed, + ) - publish_failed_step_run(step_run_id) + _pub_step_failed(step_run_id) def flush(self) -> None: """Ensure all queued updates are sent.""" @@ -230,7 +261,8 @@ def shutdown(self) -> None: def get_metrics(self) -> Dict[str, Any]: """Optional runtime metrics for observability. - Default implementation returns an empty dict. + Returns: + Dictionary of runtime metrics; empty by default. """ return {} @@ -239,6 +271,9 @@ def should_flush_on_step_end(self) -> bool: """Whether the runner should call flush() at step end. Implementations may override to disable flush for non-blocking serving. + + Returns: + True to flush on step end; False otherwise. """ return True @@ -264,8 +299,12 @@ def resolve_step_inputs( Args: step: The step to resolve inputs for. pipeline_run: The pipeline run to resolve inputs for. - step_runs: The step runs to resolve inputs for. + step_runs: Optional map of step runs. + + Returns: + Mapping from input name to resolved step run input. """ + # Local import to avoid circular import issues from zenml.orchestrators import input_utils return input_utils.resolve_step_inputs( @@ -286,26 +325,17 @@ def load_input_artifact( artifact: The artifact to load. data_type: The data type of the artifact. stack: The stack to load the artifact from. - """ - from typing import Any as _Any - - from zenml.artifacts.unmaterialized_artifact import ( - UnmaterializedArtifact, - ) - from zenml.materializers.base_materializer import BaseMaterializer - from zenml.orchestrators.utils import ( - register_artifact_store_filesystem, - ) - from zenml.utils import source_utils - from zenml.utils.typing_utils import get_origin, is_union + Returns: + The loaded Python value for the input artifact. + """ # Skip materialization for `UnmaterializedArtifact`. if data_type == UnmaterializedArtifact: return UnmaterializedArtifact( **artifact.get_hydrated_version().model_dump() ) - if data_type in (None, _Any) or is_union(get_origin(data_type)): + if data_type in (None, Any) or is_union(get_origin(data_type)): # Use the stored artifact datatype when function annotation is not specific data_type = source_utils.load(artifact.data_type) @@ -326,6 +356,11 @@ def _load(artifact_store: "BaseArtifactStore") -> Any: stack.artifact_store._register() return _load(artifact_store=stack.artifact_store) else: + # Local import to avoid circular import issues + from zenml.orchestrators.utils import ( + register_artifact_store_filesystem, + ) + with register_artifact_store_filesystem( artifact.artifact_store_id ) as target_store: @@ -353,17 +388,20 @@ def store_output_artifacts( artifact_visualization_enabled: Whether artifact visualization is enabled. Returns: - The stored artifacts. - """ - from typing import Type as _Type + Mapping from output name to stored artifact version. - from zenml.artifacts.utils import ( - _store_artifact_data_and_prepare_request, + Raises: + RuntimeError: If artifact batch creation fails after retries or + the number of responses does not match requests. + """ + # Apply capture toggles for metadata and visualizations + artifact_metadata_enabled = artifact_metadata_enabled and bool( + getattr(self, "_metadata_enabled", True) + ) + artifact_visualization_enabled = ( + artifact_visualization_enabled + and bool(getattr(self, "_visualizations_enabled", True)) ) - from zenml.enums import ArtifactSaveType - from zenml.materializers.base_materializer import BaseMaterializer - from zenml.steps.step_context import get_step_context - from zenml.utils import materializer_utils, source_utils, tag_utils step_context = get_step_context() artifact_requests: List[Any] = [] @@ -372,7 +410,7 @@ def store_output_artifacts( data_type = type(return_value) materializer_classes = output_materializers[output_name] if materializer_classes: - materializer_class: _Type[BaseMaterializer] = ( + materializer_class: Type[BaseMaterializer] = ( materializer_utils.select_materializer( data_type=data_type, materializer_classes=materializer_classes, @@ -380,10 +418,6 @@ def store_output_artifacts( ) else: # Runtime selection if no explicit materializer recorded - from zenml.materializers.materializer_registry import ( - materializer_registry, - ) - default_materializer_source = ( step_context.step_run.config.outputs[ output_name @@ -393,7 +427,7 @@ def store_output_artifacts( ) if default_materializer_source: - default_materializer_class: _Type[BaseMaterializer] = ( + default_materializer_class: Type[BaseMaterializer] = ( source_utils.load_and_validate_class( default_materializer_source, expected_class=BaseMaterializer, @@ -435,6 +469,11 @@ def store_output_artifacts( if isinstance(tag, tag_utils.Tag) and tag.cascade is True: tags.append(tag.name) + # Store artifact data and prepare a request to the server. + from zenml.artifacts.utils import ( + _store_artifact_data_and_prepare_request, + ) + artifact_request = _store_artifact_data_and_prepare_request( name=artifact_name, data=return_value, @@ -451,53 +490,51 @@ def store_output_artifacts( ) artifact_requests.append(artifact_request) - responses = Client().zen_store.batch_create_artifact_versions( - artifact_requests - ) - return dict(zip(output_data.keys(), responses)) - - -class OffStepRuntime(DefaultStepRuntime): - """OFF mode runtime: minimize overhead but keep correctness. - - Notes: - - We intentionally keep artifact persistence and success/failure status - updates to avoid breaking input resolution across steps. - - We no-op metadata publishing calls to reduce server traffic. - """ - - def publish_pipeline_run_metadata( - self, *, pipeline_run_id: Any, pipeline_run_metadata: Any - ) -> None: - """Publish pipeline run metadata. + max_retries = 2 + delay = 1.0 - Args: - pipeline_run_id: The pipeline run ID. - pipeline_run_metadata: The pipeline run metadata. - """ - # No-op: skip pipeline run metadata in OFF mode - return - - def publish_step_run_metadata( - self, *, step_run_id: Any, step_run_metadata: Any - ) -> None: - """Publish step run metadata. + for attempt in range(max_retries + 1): + try: + responses = Client().zen_store.batch_create_artifact_versions( + artifact_requests + ) + if len(responses) != len(artifact_requests): + raise RuntimeError( + f"Artifact batch creation returned {len(responses)}/{len(artifact_requests)} responses" + ) + return dict(zip(output_data.keys(), responses)) + except Exception as e: + if attempt < max_retries: + logger.warning( + "Artifact creation attempt %s failed: %s. Retrying in %.1fs...", + attempt + 1, + e, + delay, + ) + time.sleep(delay) + delay *= 1.5 + else: + logger.error( + "Failed to create artifacts after %s attempts: %s. Failing step to avoid inconsistency.", + max_retries + 1, + e, + ) + raise - Args: - step_run_id: The step run ID. - step_run_metadata: The step run metadata. - """ - # No-op: skip step run metadata in OFF mode - return + # TODO(beta->prod): Align with server to provide atomic batch create or + # compensating deletes. Consider idempotent requests and retriable error + # categories with jittered backoff. + raise RuntimeError( + "Artifact creation failed unexpectedly without raising" + ) class MemoryStepRuntime(BaseStepRuntime): - """Pure in-memory execution runtime: no server calls, no persistence.""" + """Pure in-memory execution runtime: no server calls, no persistence. - # Global registry keyed by run_id to isolate concurrent runs - _STORE: Dict[str, Dict[Tuple[str, str], Any]] = {} - _RUN_LOCKS: Dict[str, Any] = {} - _GLOBAL_LOCK: Any = threading.RLock() # protects registry structures + Instance-scoped store to isolate requests. Values are accessible within the + same process for the same run id and step chain only. + """ @staticmethod def make_handle_id(run_id: str, step_name: str, output_name: str) -> str: @@ -522,6 +559,9 @@ def parse_handle_id(handle_id: str) -> Tuple[str, str, str]: Returns: The run ID, step name, and output name. + + Raises: + ValueError: If the handle id is malformed. """ if not isinstance(handle_id, str) or not handle_id.startswith( "mem://" @@ -557,6 +597,10 @@ def __init__(self) -> None: self._ctx_run_id: Optional[str] = None self._ctx_substitutions: Dict[str, str] = {} self._active_run_ids: set[str] = set() + # Instance-scoped storage and locks per run_id + self._store: Dict[str, Dict[Tuple[str, str], Any]] = {} + self._run_locks: Dict[str, Any] = {} + self._global_lock: Any = threading.RLock() def set_context( self, *, run_id: str, substitutions: Optional[Dict[str, str]] = None @@ -592,8 +636,6 @@ def resolve_step_inputs( Returns: A mapping of input name to MemoryStepRuntime.Handle. """ - from zenml.utils import string_utils - run_id = self._ctx_run_id or str(getattr(pipeline_run, "id", "local")) subs = self._ctx_substitutions or {} handles: Dict[str, Any] = {} @@ -621,20 +663,19 @@ def load_input_artifact( Returns: The loaded artifact. + + Raises: + ValueError: If the memory handle id is invalid or malformed. """ handle_id_any = getattr(artifact, "id", None) if not isinstance(handle_id_any, str): raise ValueError("Invalid memory handle id") run_id, step_name, output_name = self.parse_handle_id(handle_id_any) # Use per-run lock to avoid cross-run interference - with MemoryStepRuntime._GLOBAL_LOCK: - rlock = MemoryStepRuntime._RUN_LOCKS.setdefault( - run_id, threading.RLock() - ) + with self._global_lock: + rlock = self._run_locks.setdefault(run_id, threading.RLock()) with rlock: - return MemoryStepRuntime._STORE.get(run_id, {}).get( - (step_name, output_name) - ) + return self._store.get(run_id, {}).get((step_name, output_name)) def store_output_artifacts( self, @@ -659,8 +700,6 @@ def store_output_artifacts( Returns: The stored artifacts. """ - from zenml.steps.step_context import get_step_context - ctx = get_step_context() run_id = str(getattr(ctx.pipeline_run, "id", "local")) try: @@ -670,12 +709,10 @@ def store_output_artifacts( pass step_name = str(getattr(ctx.step_run, "name", "step")) handles: Dict[str, Any] = {} - with MemoryStepRuntime._GLOBAL_LOCK: - rlock = MemoryStepRuntime._RUN_LOCKS.setdefault( - run_id, threading.RLock() - ) + with self._global_lock: + rlock = self._run_locks.setdefault(run_id, threading.RLock()) with rlock: - rr = MemoryStepRuntime._STORE.setdefault(run_id, {}) + rr = self._store.setdefault(run_id, {}) for output_name, value in output_data.items(): rr[(step_name, output_name)] = value handle_id = self.make_handle_id(run_id, step_name, output_name) @@ -793,8 +830,8 @@ def reset(self, run_id: str) -> None: Args: run_id: The run id to clear. """ - with MemoryStepRuntime._GLOBAL_LOCK: + with self._global_lock: try: - MemoryStepRuntime._STORE.pop(run_id, None) + self._store.pop(run_id, None) finally: - MemoryStepRuntime._RUN_LOCKS.pop(run_id, None) + self._run_locks.pop(run_id, None) diff --git a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py index 3305e21dd70..682d9e08d4f 100644 --- a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +++ b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py @@ -308,15 +308,11 @@ def main() -> None: for owner_reference in owner_references: owner_reference.controller = False - # Build a runtime for request factory using capture mode from config + # Build a runtime for request factory (batch context) try: - from zenml.capture.config import CapturePolicy from zenml.execution.factory import get_runtime - mode_cfg = getattr( - deployment.pipeline_configuration, "capture", None - ) - _runtime = get_runtime(CapturePolicy.from_value(mode_cfg)) + _runtime = get_runtime(serving=False, memory_only=False) except Exception: _runtime = None diff --git a/src/zenml/models/v2/core/pipeline_deployment.py b/src/zenml/models/v2/core/pipeline_deployment.py index 947185e7c7b..40085ade47c 100644 --- a/src/zenml/models/v2/core/pipeline_deployment.py +++ b/src/zenml/models/v2/core/pipeline_deployment.py @@ -75,6 +75,24 @@ class PipelineDeploymentBase(BaseZenModel): default=None, title="The pipeline spec of the deployment.", ) + # Canonical capture fields (single source of truth at runtime) + capture_memory_only: bool = Field( + default=False, + title="Serving-only: execute in memory without persistence.", + ) + capture_code: bool = Field( + default=True, title="Capture code/source/docstrings in metadata." + ) + capture_logs: bool = Field(default=True, title="Persist step logs.") + capture_metadata: bool = Field( + default=True, title="Publish run/step metadata." + ) + capture_visualizations: bool = Field( + default=True, title="Persist artifact visualizations." + ) + capture_metrics: bool = Field( + default=True, title="Emit runtime metrics (realtime)." + ) @property def should_prevent_build_reuse(self) -> bool: @@ -165,6 +183,24 @@ class PipelineDeploymentResponseMetadata(ProjectScopedResponseMetadata): default=None, title="Optional path where the code is stored in the artifact store.", ) + # Canonical capture fields (mirrored on response) + capture_memory_only: bool = Field( + default=False, + title="Serving-only: execute in memory without persistence.", + ) + capture_code: bool = Field( + default=True, title="Capture code/source/docstrings in metadata." + ) + capture_logs: bool = Field(default=True, title="Persist step logs.") + capture_metadata: bool = Field( + default=True, title="Publish run/step metadata." + ) + capture_visualizations: bool = Field( + default=True, title="Persist artifact visualizations." + ) + capture_metrics: bool = Field( + default=True, title="Emit runtime metrics (realtime)." + ) pipeline: Optional[PipelineResponse] = Field( default=None, title="The pipeline associated with the deployment." diff --git a/src/zenml/orchestrators/run_entity_manager.py b/src/zenml/orchestrators/run_entity_manager.py index 3dc9ce7bf25..a1cce98a902 100644 --- a/src/zenml/orchestrators/run_entity_manager.py +++ b/src/zenml/orchestrators/run_entity_manager.py @@ -147,9 +147,25 @@ class _StepRunStub: outputs: Dict[str, Any] = None # type: ignore[assignment] regular_inputs: Dict[str, Any] = None # type: ignore[assignment] + # Minimal config for template substitutions used by StepRunner + @dataclass + class _Cfg: + # Step-level toggles are optional and may be None + enable_step_logs: Optional[bool] = None + enable_artifact_metadata: Optional[bool] = None + enable_artifact_visualization: Optional[bool] = None + substitutions: Dict[str, str] = None # type: ignore[assignment] + + config: Any = _Cfg() + def __post_init__(self) -> None: # noqa: D401 self.outputs = {} self.regular_inputs = {} + # Default to empty substitutions mapping + try: + self.config.substitutions = {} + except Exception: + pass return _StepRunStub(id=run_id, name=step_name) diff --git a/src/zenml/orchestrators/step_launcher.py b/src/zenml/orchestrators/step_launcher.py index 3f7d531f91f..76a6db52e78 100644 --- a/src/zenml/orchestrators/step_launcher.py +++ b/src/zenml/orchestrators/step_launcher.py @@ -20,9 +20,8 @@ from contextlib import nullcontext from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple -from zenml.capture.config import CapturePolicy from zenml.client import Client -from zenml.config.step_configurations import Step +from zenml.config.step_configurations import Step, StepConfiguration from zenml.config.step_run_info import StepRunInfo from zenml.constants import ( ENV_ZENML_DISABLE_STEP_LOGS_STORAGE, @@ -55,6 +54,7 @@ get_or_create_shared_memory_runtime, ) from zenml.orchestrators.step_runner import StepRunner +from zenml.orchestrators.utils import is_serving_context from zenml.stack import Stack from zenml.utils import exception_utils, string_utils from zenml.utils.time_utils import utc_now @@ -148,6 +148,90 @@ def __init__( self._step_run: Optional[StepRunResponse] = None self._setup_signal_handlers() + # --- Serving helpers --- + def _validate_and_merge_request_params( + self, + req_params: Dict[str, Any], + effective_step_config: StepConfiguration, + ) -> Dict[str, Any]: + """Safely merge request parameters with allowlist and light validation. + + Only keys already declared in the pipeline parameters are merged. + Performs simple type-coercion against defaults where possible and + applies size limits to avoid oversized payloads. + + TODO(beta->prod): derive expected types from the pipeline entrypoint + annotations (or a generated parameter schema) instead of the current + defaults-based heuristic; add a total payload size limit. + + Args: + req_params: Raw parameters dictionary from the request. + effective_step_config: The current effective step configuration. + + Returns: + Merged and validated parameters dictionary. + """ + if not req_params: + return effective_step_config.parameters or {} + + declared = set((effective_step_config.parameters or {}).keys()) + allowed = {k: v for k, v in req_params.items() if k in declared} + dropped = set(req_params.keys()) - declared + if dropped: + logger.warning( + "Dropping unknown request parameters: %s", sorted(dropped) + ) + + validated: Dict[str, Any] = {} + for key, value in allowed.items(): + # Size limits + try: + if isinstance(value, str) and len(value) > 10_000: + logger.warning( + "Dropping oversized string parameter '%s' (%s chars)", + key, + len(value), + ) + continue + if ( + isinstance(value, (list, dict)) + and len(str(value)) > 50_000 + ): + logger.warning( + "Dropping oversized collection parameter '%s'", key + ) + continue + except Exception: + # If size introspection fails, keep conservative and drop + logger.warning( + "Dropping parameter '%s' due to size check error", key + ) + continue + + # Type coercion against defaults, if present + try: + defaults = effective_step_config.parameters or {} + if key in defaults and defaults[key] is not None: + expected_t = type(defaults[key]) + if not isinstance(value, expected_t): + try: + value = expected_t(value) # best-effort coercion + except Exception: + logger.warning( + "Type mismatch for parameter '%s', dropping", + key, + ) + continue + except Exception: + # On any error, accept original value (already allowlisted) + pass + + validated[key] = value + + merged = dict(effective_step_config.parameters or {}) + merged.update(validated) + return merged + def _setup_signal_handlers(self) -> None: """Set up signal handlers for graceful shutdown, chaining previous handlers.""" try: @@ -246,28 +330,69 @@ def launch(self) -> None: else None ) - # Determine capture-based runtime and memory-only mode early - mode_cfg = getattr( - self._deployment.pipeline_configuration, "capture", None + # Determine serving context and canonical capture flags + in_serving_ctx = is_serving_context() + mem_only_flag = bool( + getattr(self._deployment, "capture_memory_only", False) + ) + # Dev fallback: if canonical field missing or False, derive from typed capture + if not mem_only_flag: + try: + from zenml.capture.config import Capture as _Cap + + cap_typed = getattr( + self._deployment.pipeline_configuration, "capture", None + ) + if isinstance(cap_typed, _Cap) and bool( + getattr(cap_typed, "memory_only", False) + ): + mem_only_flag = True + except Exception: + pass + # memory_only applies only in serving; warn and ignore otherwise + memory_only = mem_only_flag if in_serving_ctx else False + if mem_only_flag and not in_serving_ctx: + logger.warning( + "memory_only=True configured but not in serving; ignoring." + ) + + metrics_enabled = bool( + getattr(self._deployment, "capture_metrics", True) + ) + if metrics_enabled is True: + try: + from zenml.capture.config import Capture as _Cap + + cap_typed = getattr( + self._deployment.pipeline_configuration, "capture", None + ) + if isinstance(cap_typed, _Cap): + metrics_enabled = bool(getattr(cap_typed, "metrics", True)) + except Exception: + pass + runtime = get_runtime( + serving=in_serving_ctx, + memory_only=memory_only, + metrics_enabled=metrics_enabled, ) - capture_policy = CapturePolicy.from_value(mode_cfg) - runtime = get_runtime(capture_policy) # Store for later use self._runtime = runtime - # Serving context detection - in_serving_ctx = os.getenv("ZENML_SERVING_CAPTURE_DEFAULT") is not None - memory_only = bool(capture_policy.get_option("memory_only", False)) - # Debug messages to clarify behavior - if capture_policy.mode.name == "REALTIME" and not in_serving_ctx: - logger.warning( - "REALTIME mode enabled outside serving (development). Performance/ordering may vary." + # Apply observability toggles to runtime + try: + setattr( + runtime, + "_metadata_enabled", + bool(getattr(self._deployment, "capture_metadata", True)), ) - if memory_only and not in_serving_ctx: - # Ignore memory_only outside serving: fall back to normal (Batch/Realtime) behavior - logger.warning( - "memory_only=True requested outside serving; ignoring and proceeding with standard execution." + setattr( + runtime, + "_visualizations_enabled", + bool( + getattr(self._deployment, "capture_visualizations", True) + ), ) - memory_only = False + except Exception: + pass # Select entity manager and, if memory-only, set up shared runtime is_memory_only_path = memory_only and in_serving_ctx @@ -279,35 +404,21 @@ def launch(self) -> None: self._runtime = shared except Exception: pass + logger.info( + "[Memory-only] Serving context detected; using in-process memory handoff (no runs/artifacts)." + ) entity_manager = MemoryRunEntityManager(self) else: entity_manager = DefaultRunEntityManager(self) pipeline_run, run_was_created = entity_manager.create_or_reuse_run() - if ( - capture_policy.mode.name == "REALTIME" - and "flush_on_step_end" - not in getattr(capture_policy, "options", {}) - and in_serving_ctx - ): - flush_opt = False - else: - # Honor capture option: flush_on_step_end (default True) - flush_opt = capture_policy.get_option("flush_on_step_end", True) - # Configure runtime flush behavior if supported - set_flush = getattr(runtime, "set_flush_on_step_end", None) - if callable(set_flush): - try: - set_flush(bool(flush_opt)) - except Exception as e: - logger.debug( - "Could not configure runtime flush behavior: %s", e - ) + # No flush configuration: batch is blocking, serving is async by default # Enable or disable step logs storage if ( handle_bool_env_var(ENV_ZENML_DISABLE_STEP_LOGS_STORAGE, False) or tracking_disabled + or is_memory_only_path # never persist logs in memory-only ): step_logging_enabled = False else: @@ -319,6 +430,11 @@ def launch(self) -> None: logs_context = nullcontext() logs_model = None + # Apply observability toggle from canonical capture + capture_logs = bool(getattr(self._deployment, "capture_logs", True)) + if not capture_logs: + step_logging_enabled = False + if step_logging_enabled and not tracking_disabled: # Configure the logs logs_uri = step_logging.prepare_logs_uri( @@ -353,8 +469,7 @@ def launch(self) -> None: ) # Honor capture.code flag (default True) - code_opt = capture_policy.get_option("code", True) - code_enabled = str(code_opt).lower() not in {"false", "0", "off"} + code_enabled = bool(getattr(self._deployment, "capture_code", True)) # Prepare step run creation if isinstance(entity_manager, DefaultRunEntityManager): @@ -584,23 +699,26 @@ def _run_step( } ) - # Merge request-level parameters (serving) for memory-only runtime + # Merge request-level parameters in serving (applies to all runtimes) runtime = getattr(self, "_runtime", None) - if isinstance(runtime, MemoryStepRuntime): + if is_serving_context(): try: req_env = os.getenv("ZENML_SERVING_REQUEST_PARAMS") req_params = json.loads(req_env) if req_env else {} - if not req_params: - req_params = ( - self._deployment.pipeline_configuration.parameters - or {} - ) if req_params: - merged = dict(effective_step_config.parameters or {}) - merged.update(req_params) + merged = self._validate_and_merge_request_params( + req_params, effective_step_config + ) effective_step_config = effective_step_config.model_copy( update={"parameters": merged} ) + try: + logger.info( + "[Serving] Request parameters merged into step config: %s", + sorted(list(req_params.keys())), + ) + except Exception: + pass except Exception: pass @@ -660,9 +778,11 @@ def _run_step( output_artifact_uris=output_artifact_uris, ) except: # noqa: E722 - output_utils.remove_artifact_dirs( - artifact_uris=list(output_artifact_uris.values()) - ) + # Best-effort cleanup only for filesystem URIs + if not isinstance(runtime, MemoryStepRuntime): + output_utils.remove_artifact_dirs( + artifact_uris=list(output_artifact_uris.values()) + ) raise duration = time.time() - start_time @@ -679,9 +799,7 @@ def _run_step( self._step_name in cfg.spec.upstream_steps for name, cfg in self._deployment.step_configurations.items() ) - should_flush = has_downstream or ( - os.getenv("ZENML_SERVING_CAPTURE_DEFAULT") is not None - ) + should_flush = has_downstream or is_serving_context() if should_flush: try: runtime.flush() @@ -723,29 +841,7 @@ def _run_step_with_step_operator( environment.update(secrets) environment[ENV_ZENML_STEP_OPERATOR] = "True" - # Propagate capture mode to the step operator environment so that - # the entrypoint can construct the appropriate runtime. - try: - mode_cfg = getattr( - self._deployment.pipeline_configuration, "capture", None - ) - if mode_cfg: - # If typed capture with explicit mode, export it; unified Capture has no mode - try: - from zenml.capture.config import ( - BatchCapture, - RealtimeCapture, - ) - - if isinstance(mode_cfg, RealtimeCapture): - environment["ZENML_CAPTURE_MODE"] = "REALTIME" - elif isinstance(mode_cfg, BatchCapture): - environment["ZENML_CAPTURE_MODE"] = "BATCH" - except Exception: - pass - environment["ZENML_ENABLE_STEP_RUNTIME"] = "true" - except Exception: - pass + # No capture mode propagation; runtime behavior derived from context logger.info( "Using step operator `%s` to run step `%s`.", step_operator.name, diff --git a/src/zenml/orchestrators/utils.py b/src/zenml/orchestrators/utils.py index ef23134c816..a3deffd1aa7 100644 --- a/src/zenml/orchestrators/utils.py +++ b/src/zenml/orchestrators/utils.py @@ -112,7 +112,7 @@ def is_tracking_enabled( - 'none' (case-insensitive) or False -> disable tracking - any other value or missing -> enable tracking - For serving, respects ZENML_SERVING_CAPTURE_DEFAULT when pipeline settings are absent. + Serving context does not change this; capture options are typed-only. Args: pipeline_settings: Pipeline configuration settings mapping, if any. @@ -121,27 +121,11 @@ def is_tracking_enabled( Whether tracking should be enabled. """ if not pipeline_settings: - # Check for serving default when no pipeline settings - import os - - serving_default = ( - os.getenv("ZENML_SERVING_CAPTURE_DEFAULT", "").strip().lower() - ) - if serving_default in {"none", "off", "false", "0", "disabled"}: - return False return True try: capture_value = pipeline_settings.get("capture") if capture_value is None: - # Check for serving default when capture setting is missing - import os - - serving_default = ( - os.getenv("ZENML_SERVING_CAPTURE_DEFAULT", "").strip().lower() - ) - if serving_default in {"none", "off", "false", "0", "disabled"}: - return False return True if isinstance(capture_value, bool): return capture_value @@ -176,7 +160,14 @@ def is_tracking_enabled( def is_tracking_disabled( pipeline_settings: Optional[Dict[str, Any]] = None, ) -> bool: - """True if tracking/persistence should be disabled completely.""" + """True if tracking/persistence should be disabled completely. + + Args: + pipeline_settings: Optional pipeline settings mapping. + + Returns: + True if tracking should be disabled, False otherwise. + """ return not is_tracking_enabled(pipeline_settings) @@ -187,17 +178,51 @@ def is_tracking_disabled( def tap_store_step_outputs(step_name: str, outputs: Dict[str, Any]) -> None: - """Store step outputs in the serve tap for in-memory handoff.""" + """Store step outputs in the serve tap for in-memory handoff. + + Args: + step_name: Name of the step producing outputs. + outputs: Mapping of output name to value. + """ current_tap = _serve_output_tap.get({}) current_tap[step_name] = outputs _serve_output_tap.set(current_tap) def tap_get_step_outputs(step_name: str) -> Optional[Dict[str, Any]]: - """Get step outputs from the serve tap.""" + """Get step outputs from the serve tap. + + Args: + step_name: Name of the step whose outputs to fetch. + + Returns: + Optional mapping of outputs for the step if present, else None. + """ return _serve_output_tap.get({}).get(step_name) +# Serving context marker +_serving_ctx: ContextVar[bool] = ContextVar("serving_ctx", default=False) + + +def set_serving_context(value: bool) -> None: + """Set whether the current execution is in a serving context. + + Args: + value: True if running inside the serving service, else False. + """ + _serving_ctx.set(bool(value)) + + +def is_serving_context() -> bool: + """Return True if running inside a serving context. + + Returns: + True if serving context is active, otherwise False. + """ + return _serving_ctx.get() + + # Serve pipeline state context _serve_pipeline_state: ContextVar[Optional[Any]] = ContextVar( "serve_pipeline_state", default=None @@ -205,12 +230,20 @@ def tap_get_step_outputs(step_name: str) -> Optional[Dict[str, Any]]: def set_pipeline_state(state: Optional[Any]) -> None: - """Set pipeline state for serving context.""" + """Set pipeline state for serving context. + + Args: + state: Optional pipeline state object to associate with this request. + """ _serve_pipeline_state.set(state) def get_pipeline_state() -> Optional[Any]: - """Get pipeline state for serving context.""" + """Get pipeline state for serving context. + + Returns: + Optional pipeline state object if set, else None. + """ return _serve_pipeline_state.get(None) diff --git a/src/zenml/pipelines/pipeline_decorator.py b/src/zenml/pipelines/pipeline_decorator.py index 7c83cb745ec..0dc30c40bb7 100644 --- a/src/zenml/pipelines/pipeline_decorator.py +++ b/src/zenml/pipelines/pipeline_decorator.py @@ -25,11 +25,7 @@ overload, ) -from zenml.capture.config import ( - BatchCapture, - Capture, - RealtimeCapture, -) +from zenml.capture.config import Capture from zenml.logger import get_logger if TYPE_CHECKING: @@ -67,7 +63,7 @@ def pipeline( model: Optional["Model"] = None, retry: Optional["StepRetryConfig"] = None, substitutions: Optional[Dict[str, str]] = None, - capture: Optional[Union[Capture, BatchCapture, RealtimeCapture]] = None, + capture: Optional[Capture] = None, ) -> Callable[["F"], "Pipeline"]: ... @@ -89,7 +85,7 @@ def pipeline( model: Optional["Model"] = None, retry: Optional["StepRetryConfig"] = None, substitutions: Optional[Dict[str, str]] = None, - capture: Optional[Union[Capture, BatchCapture, RealtimeCapture]] = None, + capture: Optional[Capture] = None, ) -> Union["Pipeline", Callable[["F"], "Pipeline"]]: """Decorator to create a pipeline. @@ -130,14 +126,7 @@ def inner_decorator(func: "F") -> "Pipeline": from zenml.pipelines.pipeline_definition import Pipeline # Directly store typed capture config - cap = capture - cap_val: Optional[Union[Capture, BatchCapture, RealtimeCapture]] = None - if cap is not None: - if not isinstance(cap, (Capture, BatchCapture, RealtimeCapture)): - raise ValueError( - "'capture' must be a Capture, BatchCapture or RealtimeCapture." - ) - cap_val = cap + cap_val = capture p = Pipeline( name=name or func.__name__, diff --git a/src/zenml/pipelines/pipeline_definition.py b/src/zenml/pipelines/pipeline_definition.py index d94ad89e48d..cd67cb51b9a 100644 --- a/src/zenml/pipelines/pipeline_definition.py +++ b/src/zenml/pipelines/pipeline_definition.py @@ -42,11 +42,7 @@ from zenml import constants from zenml.analytics.enums import AnalyticsEvent from zenml.analytics.utils import track_handler -from zenml.capture.config import ( - BatchCapture, - Capture, - RealtimeCapture, -) +from zenml.capture.config import Capture from zenml.client import Client from zenml.config.compiler import Compiler from zenml.config.pipeline_configurations import ( @@ -153,9 +149,7 @@ def __init__( model: Optional["Model"] = None, retry: Optional["StepRetryConfig"] = None, substitutions: Optional[Dict[str, str]] = None, - capture: Optional[ - Union[Capture, BatchCapture, RealtimeCapture] - ] = None, + capture: Optional[Capture] = None, ) -> None: """Initializes a pipeline. @@ -188,8 +182,7 @@ def __init__( model: configuration of the model in the Model Control Plane. retry: Retry configuration for the pipeline steps. substitutions: Extra placeholders to use in the name templates. - capture: Capture policy for the pipeline (typed only): Capture, - BatchCapture or RealtimeCapture. + capture: Capture configuration for the pipeline (typed only). """ self._invocations: Dict[str, StepInvocation] = {} self._run_args: Dict[str, Any] = {} @@ -341,9 +334,7 @@ def configure( parameters: Optional[Dict[str, Any]] = None, merge: bool = True, substitutions: Optional[Dict[str, str]] = None, - capture: Optional[ - Union[Capture, BatchCapture, RealtimeCapture] - ] = None, + capture: Optional[Capture] = None, ) -> Self: """Configures the pipeline. @@ -390,8 +381,7 @@ def configure( retry: Retry configuration for the pipeline steps. parameters: input parameters for the pipeline. substitutions: Extra placeholders to use in the name templates. - capture: Capture policy for the pipeline (typed only). Use - BatchCapture/RealtimeCapture or omit entirely to use sensible defaults. + capture: Capture configuration for the pipeline (typed only). Returns: The pipeline instance that this method was called on. @@ -422,15 +412,7 @@ def configure( tags = self._configuration.tags + tags # Directly store typed capture config - cap_norm = None - if capture is not None: - if not isinstance( - capture, (Capture, BatchCapture, RealtimeCapture) - ): - raise ValueError( - "'capture' must be a Capture, BatchCapture or RealtimeCapture." - ) - cap_norm = capture + cap_norm = capture values = dict_utils.remove_none_values( { diff --git a/src/zenml/zen_stores/schemas/pipeline_deployment_schemas.py b/src/zenml/zen_stores/schemas/pipeline_deployment_schemas.py index 7641fea8df9..d2321d8afe9 100644 --- a/src/zenml/zen_stores/schemas/pipeline_deployment_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_deployment_schemas.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Sequence from uuid import UUID -from sqlalchemy import TEXT, Column, String, UniqueConstraint +from sqlalchemy import TEXT, Boolean, Column, String, UniqueConstraint from sqlalchemy.dialects.mysql import MEDIUMTEXT from sqlalchemy.orm import joinedload, object_session from sqlalchemy.sql.base import ExecutableOption @@ -87,6 +87,26 @@ class PipelineDeploymentSchema(BaseSchema, table=True): ) code_path: Optional[str] = Field(nullable=True) + # Canonical capture fields + capture_memory_only: bool = Field( + sa_column=Column(Boolean, nullable=False, default=False), default=False + ) + capture_code: bool = Field( + sa_column=Column(Boolean, nullable=False, default=True), default=True + ) + capture_logs: bool = Field( + sa_column=Column(Boolean, nullable=False, default=True), default=True + ) + capture_metadata: bool = Field( + sa_column=Column(Boolean, nullable=False, default=True), default=True + ) + capture_visualizations: bool = Field( + sa_column=Column(Boolean, nullable=False, default=True), default=True + ) + capture_metrics: bool = Field( + sa_column=Column(Boolean, nullable=False, default=True), default=True + ) + # Foreign keys user_id: Optional[UUID] = build_foreign_key_field( source=__tablename__, @@ -319,6 +339,14 @@ def from_request( if request.pipeline_spec else None, code_path=request.code_path, + capture_memory_only=getattr(request, "capture_memory_only", False), + capture_code=getattr(request, "capture_code", True), + capture_logs=getattr(request, "capture_logs", True), + capture_metadata=getattr(request, "capture_metadata", True), + capture_visualizations=getattr( + request, "capture_visualizations", True + ), + capture_metrics=getattr(request, "capture_metrics", True), ) def to_model( @@ -390,6 +418,12 @@ def to_model( else None, code_path=self.code_path, template_id=self.template_id, + capture_memory_only=self.capture_memory_only, + capture_code=self.capture_code, + capture_logs=self.capture_logs, + capture_metadata=self.capture_metadata, + capture_visualizations=self.capture_visualizations, + capture_metrics=self.capture_metrics, ) resources = None diff --git a/tests/unit/execution/test_default_runtime_metadata_toggle.py b/tests/unit/execution/test_default_runtime_metadata_toggle.py new file mode 100644 index 00000000000..740ff487569 --- /dev/null +++ b/tests/unit/execution/test_default_runtime_metadata_toggle.py @@ -0,0 +1,37 @@ +"""Unit tests for DefaultStepRuntime metadata/visualization toggles.""" + +from types import SimpleNamespace + +from zenml.execution.step_runtime import DefaultStepRuntime + + +def test_publish_metadata_skips_when_disabled(monkeypatch): + rt = DefaultStepRuntime() + setattr(rt, "_metadata_enabled", False) + + called = {"run": 0, "step": 0} + + def _pub_run_md(*a, **k): + called["run"] += 1 + + def _pub_step_md(*a, **k): + called["step"] += 1 + + monkeypatch.setattr( + "zenml.orchestrators.publish_utils.publish_pipeline_run_metadata", + _pub_run_md, + ) + monkeypatch.setattr( + "zenml.orchestrators.publish_utils.publish_step_run_metadata", + _pub_step_md, + ) + + rt.publish_pipeline_run_metadata( + pipeline_run_id=SimpleNamespace(), pipeline_run_metadata={} + ) + rt.publish_step_run_metadata( + step_run_id=SimpleNamespace(), step_run_metadata={} + ) + + assert called["run"] == 0 + assert called["step"] == 0 diff --git a/tests/unit/execution/test_memory_runtime.py b/tests/unit/execution/test_memory_runtime.py new file mode 100644 index 00000000000..49e0574834a --- /dev/null +++ b/tests/unit/execution/test_memory_runtime.py @@ -0,0 +1,69 @@ +"""Unit tests for MemoryStepRuntime instance-scoped isolation.""" + +from types import SimpleNamespace + +from zenml.execution.step_runtime import MemoryStepRuntime + + +def test_memory_runtime_instance_isolated_store(monkeypatch): + """Each runtime instance isolates values by run id; no cross leakage.""" + # Create two independent runtimes (instance-scoped stores) + rt1 = MemoryStepRuntime() + rt2 = MemoryStepRuntime() + + # Patch get_step_context to return minimal stubs + class _Ctx: + def __init__(self, run_id: str, step_name: str): + self.pipeline_run = SimpleNamespace(id=run_id) + self.step_run = SimpleNamespace(name=step_name) + + def get_output_metadata(self, name: str): + return {} + + def get_output_tags(self, name: str): + return [] + + monkeypatch.setattr( + "zenml.execution.step_runtime.get_step_context", + lambda: _Ctx("run-1", "s1"), + ) + + # Store with rt1 + outputs = {"out": 123} + handles1 = rt1.store_output_artifacts( + output_data=outputs, + output_materializers={"out": ()}, + output_artifact_uris={"out": "memory://run-1/s1/out"}, + output_annotations={"out": SimpleNamespace(artifact_config=None)}, + artifact_metadata_enabled=False, + artifact_visualization_enabled=False, + ) + h1 = handles1["out"] + + # Switch context for rt2 + monkeypatch.setattr( + "zenml.execution.step_runtime.get_step_context", + lambda: _Ctx("run-2", "s2"), + ) + handles2 = rt2.store_output_artifacts( + output_data={"out": 999}, + output_materializers={"out": ()}, + output_artifact_uris={"out": "memory://run-2/s2/out"}, + output_annotations={"out": SimpleNamespace(artifact_config=None)}, + artifact_metadata_enabled=False, + artifact_visualization_enabled=False, + ) + h2 = handles2["out"] + + # rt1 should load its own value + v1 = rt1.load_input_artifact(artifact=h1, data_type=int, stack=None) + assert v1 == 123 + + # rt2 should load its own value + v2 = rt2.load_input_artifact(artifact=h2, data_type=int, stack=None) + assert v2 == 999 + + # rt1 should NOT see rt2 value + assert ( + rt1.load_input_artifact(artifact=h2, data_type=int, stack=None) is None + ) diff --git a/tests/unit/execution/test_realtime_runtime.py b/tests/unit/execution/test_realtime_runtime.py new file mode 100644 index 00000000000..d1223bf02ed --- /dev/null +++ b/tests/unit/execution/test_realtime_runtime.py @@ -0,0 +1,41 @@ +"""Unit tests for RealtimeStepRuntime queue/backpressure and sweep.""" + +import queue +from types import SimpleNamespace + +from zenml.execution.realtime_runtime import RealtimeStepRuntime + + +def test_realtime_queue_full_inline_fallback(monkeypatch): + """When queue is full, publish events are processed inline as fallback.""" + rt = RealtimeStepRuntime(ttl_seconds=1, max_entries=8) + + # Replace queue with a tiny one and fill it + rt._q = queue.Queue(maxsize=1) # type: ignore[attr-defined] + rt._q.put(("dummy", (), {})) # fill once + + called = {"step": 0} + + def _pub_step_run_metadata(*args, **kwargs): + called["step"] += 1 + + monkeypatch.setattr( + "zenml.orchestrators.publish_utils.publish_step_run_metadata", + _pub_step_run_metadata, + ) + + # This put_nowait should hit Full and process inline + rt.publish_step_run_metadata( + step_run_id=SimpleNamespace(), step_run_metadata={} + ) + assert called["step"] == 1 + + +def test_realtime_sweep_expired_no_keyerror(): + """Expired cache entries are swept safely without KeyError races.""" + rt = RealtimeStepRuntime(ttl_seconds=0, max_entries=8) + # Insert an expired cache entry manually + with rt._lock: # type: ignore[attr-defined] + rt._cache["k1"] = ("v", 0.0) # type: ignore[attr-defined] + # Should not raise + rt._sweep_expired() diff --git a/tests/unit/execution/test_step_runtime_artifact_write.py b/tests/unit/execution/test_step_runtime_artifact_write.py new file mode 100644 index 00000000000..d89ed28bb71 --- /dev/null +++ b/tests/unit/execution/test_step_runtime_artifact_write.py @@ -0,0 +1,78 @@ +"""Unit test for defensive artifact write behavior (retry + validate).""" + +from types import SimpleNamespace + +from zenml.execution.step_runtime import DefaultStepRuntime + + +def test_artifact_write_retry_and_validate(monkeypatch): + """First batch create fails, retry succeeds; responses length validated.""" + rt = DefaultStepRuntime() + + # Patch helpers used to build requests + monkeypatch.setattr( + "zenml.orchestrators.publish_utils.publish_successful_step_run", + lambda *a, **k: None, + ) + + # Minimal step context stub + class _Ctx: + def __init__(self): + self.pipeline_run = SimpleNamespace( + config=SimpleNamespace(tags=None), pipeline=None + ) + self.step_run = SimpleNamespace(name="step") + + def get_output_metadata(self, name: str): + return {} + + def get_output_tags(self, name: str): + return [] + + monkeypatch.setattr( + "zenml.execution.step_runtime.get_step_context", + lambda: _Ctx(), + ) + + # Patch request preparation to avoid heavy imports + monkeypatch.setattr( + "zenml.execution.step_runtime._store_artifact_data_and_prepare_request", + lambda **k: {"req": k}, + ) + # Patch materializer selection + monkeypatch.setattr( + "zenml.execution.step_runtime.materializer_utils.select_materializer", + lambda data_type, materializer_classes: object, + ) + monkeypatch.setattr( + "zenml.execution.step_runtime.source_utils.load_and_validate_class", + lambda *a, **k: object, + ) + + calls = {"attempts": 0} + + class _Client: + class _Store: + def batch_create_artifact_versions(self, reqs): + calls["attempts"] += 1 + if calls["attempts"] == 1: + raise RuntimeError("transient") + # Return matching length list + return [SimpleNamespace(id=i) for i in range(len(reqs))] + + zen_store = _Store() + + monkeypatch.setattr( + "zenml.execution.step_runtime.Client", lambda: _Client() + ) + + res = rt.store_output_artifacts( + output_data={"out": 1}, + output_materializers={"out": ()}, + output_artifact_uris={"out": "uri://out"}, + output_annotations={"out": SimpleNamespace(artifact_config=None)}, + artifact_metadata_enabled=False, + artifact_visualization_enabled=False, + ) + assert "out" in res + assert calls["attempts"] == 2 diff --git a/tests/unit/orchestrators/test_step_launcher_params.py b/tests/unit/orchestrators/test_step_launcher_params.py new file mode 100644 index 00000000000..74d259198ac --- /dev/null +++ b/tests/unit/orchestrators/test_step_launcher_params.py @@ -0,0 +1,49 @@ +"""Unit tests for StepLauncher request parameter validation/merge. + +These tests verify allowlisting, simple type coercion, and size caps when +merging request parameters into the effective step configuration in serving. +""" + +from zenml.orchestrators.step_launcher import StepLauncher + + +def test_validate_and_merge_request_params_allowlist_and_types(monkeypatch): + """Allowlist known params and coerce simple types; drop unknowns.""" + # Use the real method by binding to a StepLauncher instance with minimal init + sl = StepLauncher.__new__(StepLauncher) # type: ignore + + class Cfg: + def __init__(self): + self.parameters = {"city": "paris", "count": 1} + + effective = Cfg() + req = { + "city": "munich", # allowed, string + "count": "2", # allowed, coercible to int + "unknown": "drop-me", # not declared + } + + merged = StepLauncher._validate_and_merge_request_params( + sl, req, effective + ) + assert merged["city"] == "munich" + assert merged["count"] == 2 + assert "unknown" not in merged + + +def test_validate_and_merge_request_params_size_caps(monkeypatch): + """Drop oversized string/collection parameters per safety caps.""" + sl = StepLauncher.__new__(StepLauncher) # type: ignore + + class Cfg: + def __init__(self): + self.parameters = {"text": "ok"} + + effective = Cfg() + big = "x" * 20000 # 20KB string -> dropped + req = {"text": big} + merged = StepLauncher._validate_and_merge_request_params( + sl, req, effective + ) + # Should keep the default, drop oversize + assert merged["text"] == "ok" From aa65e7b245b6ca67803e4679f0d4e284ef13f844 Mon Sep 17 00:00:00 2001 From: Safoine El khabich Date: Sun, 7 Sep 2025 18:05:08 +0100 Subject: [PATCH 4/8] Add pipeline endpoint capture migration This commit introduces a new Alembic migration that creates the `pipeline_endpoint` table and modifies the `pipeline_deployment` table to include additional capture-related columns. The new schema supports enhanced capture configurations for pipeline endpoints, improving the overall functionality and flexibility of the ZenML framework. The migration includes the following changes: - Creation of the `pipeline_endpoint` table with relevant fields and constraints. - Addition of columns for capturing various aspects of pipeline deployments, such as memory usage, logs, and metrics. This update lays the groundwork for improved pipeline management and monitoring capabilities. --- ...a848b2980c54_pipeline_endpoint_capture.py} | 56 ++++++++++++++----- 1 file changed, 43 insertions(+), 13 deletions(-) rename src/zenml/zen_stores/migrations/versions/{0d69e308846a_add_pipeline_endpoints.py => a848b2980c54_pipeline_endpoint_capture.py} (65%) diff --git a/src/zenml/zen_stores/migrations/versions/0d69e308846a_add_pipeline_endpoints.py b/src/zenml/zen_stores/migrations/versions/a848b2980c54_pipeline_endpoint_capture.py similarity index 65% rename from src/zenml/zen_stores/migrations/versions/0d69e308846a_add_pipeline_endpoints.py rename to src/zenml/zen_stores/migrations/versions/a848b2980c54_pipeline_endpoint_capture.py index 8c397d21584..10cbd6168a2 100644 --- a/src/zenml/zen_stores/migrations/versions/0d69e308846a_add_pipeline_endpoints.py +++ b/src/zenml/zen_stores/migrations/versions/a848b2980c54_pipeline_endpoint_capture.py @@ -1,8 +1,8 @@ -"""add pipeline endpoints [0d69e308846a]. +"""pipeline endpoint + capture [a848b2980c54]. -Revision ID: 0d69e308846a -Revises: 0.84.3 -Create Date: 2025-08-26 10:30:52.737833 +Revision ID: a848b2980c54 +Revises: aae4eed923b5 +Create Date: 2025-09-07 18:04:15.320419 """ @@ -12,8 +12,8 @@ from sqlalchemy.dialects import mysql # revision identifiers, used by Alembic. -revision = "0d69e308846a" -down_revision = "0.84.3" +revision = "a848b2980c54" +down_revision = "aae4eed923b5" branch_labels = None depends_on = None @@ -36,7 +36,9 @@ def upgrade() -> None: sa.Column("auth_key", sa.TEXT(), nullable=True), sa.Column( "endpoint_metadata", - sa.String(length=16777215).with_variant(mysql.MEDIUMTEXT, "mysql"), + sa.String(length=16777215).with_variant( + mysql.MEDIUMTEXT(), "mysql" + ), nullable=False, ), sa.Column( @@ -45,18 +47,18 @@ def upgrade() -> None: nullable=True, ), sa.Column("deployer_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), - sa.ForeignKeyConstraint( - ["pipeline_deployment_id"], - ["pipeline_deployment.id"], - name="fk_pipeline_endpoint_pipeline_deployment_id_pipeline_deployment", - ondelete="SET NULL", - ), sa.ForeignKeyConstraint( ["deployer_id"], ["stack_component.id"], name="fk_pipeline_endpoint_deployer_id_stack_component", ondelete="SET NULL", ), + sa.ForeignKeyConstraint( + ["pipeline_deployment_id"], + ["pipeline_deployment.id"], + name="fk_pipeline_endpoint_pipeline_deployment_id_pipeline_deployment", + ondelete="SET NULL", + ), sa.ForeignKeyConstraint( ["project_id"], ["project.id"], @@ -76,11 +78,39 @@ def upgrade() -> None: name="unique_pipeline_endpoint_name_in_project", ), ) + with op.batch_alter_table("pipeline_deployment", schema=None) as batch_op: + batch_op.add_column( + sa.Column("capture_memory_only", sa.Boolean(), nullable=False) + ) + batch_op.add_column( + sa.Column("capture_code", sa.Boolean(), nullable=False) + ) + batch_op.add_column( + sa.Column("capture_logs", sa.Boolean(), nullable=False) + ) + batch_op.add_column( + sa.Column("capture_metadata", sa.Boolean(), nullable=False) + ) + batch_op.add_column( + sa.Column("capture_visualizations", sa.Boolean(), nullable=False) + ) + batch_op.add_column( + sa.Column("capture_metrics", sa.Boolean(), nullable=False) + ) + # ### end Alembic commands ### def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("pipeline_deployment", schema=None) as batch_op: + batch_op.drop_column("capture_metrics") + batch_op.drop_column("capture_visualizations") + batch_op.drop_column("capture_metadata") + batch_op.drop_column("capture_logs") + batch_op.drop_column("capture_code") + batch_op.drop_column("capture_memory_only") + op.drop_table("pipeline_endpoint") # ### end Alembic commands ### From bb6e25e7793882781edc18c6a90fc6fa1afc678a Mon Sep 17 00:00:00 2001 From: Safoine El khabich Date: Sun, 7 Sep 2025 21:27:44 +0100 Subject: [PATCH 5/8] Add capture configuration to weather pipeline This commit updates the weather pipeline example to include a memory-only capture configuration using the `Capture` class. This enhancement allows for improved management of pipeline execution without persisting data to a database or filesystem. Additionally, the `run_entity_manager.py` file has been modified to utilize `field(default_factory=...)` for better initialization of dataclass fields, ensuring that default values are generated correctly. The `step_launcher.py` file has also been updated to handle memory-only stubs gracefully during execution interruptions. These changes contribute to a more robust and flexible pipeline serving architecture, aligning with recent refactors in the ZenML framework. --- examples/serving/weather_pipeline.py | 2 ++ src/zenml/orchestrators/run_entity_manager.py | 8 ++--- src/zenml/orchestrators/step_launcher.py | 32 +++++++++++-------- 3 files changed, 25 insertions(+), 17 deletions(-) diff --git a/examples/serving/weather_pipeline.py b/examples/serving/weather_pipeline.py index 910b4257fae..9e1e5acba67 100644 --- a/examples/serving/weather_pipeline.py +++ b/examples/serving/weather_pipeline.py @@ -18,6 +18,7 @@ from typing import Dict from zenml import pipeline, step +from zenml.capture.config import Capture from zenml.config import DockerSettings # Import enums for type-safe capture mode configuration @@ -222,6 +223,7 @@ def analyze_weather_with_llm(weather_data: Dict[str, float], city: str) -> str: @pipeline( on_init=init_hook, + capture=Capture(memory_only=True), settings={ "docker": docker_settings, "deployer.gcp": { diff --git a/src/zenml/orchestrators/run_entity_manager.py b/src/zenml/orchestrators/run_entity_manager.py index a1cce98a902..229e6754907 100644 --- a/src/zenml/orchestrators/run_entity_manager.py +++ b/src/zenml/orchestrators/run_entity_manager.py @@ -6,7 +6,7 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, Dict, Optional, Protocol, Tuple, cast @@ -116,7 +116,7 @@ class _PipelineRunStub: name: str model_version: Any = None pipeline: Any = None - config: Any = _PRCfg() + config: Any = field(default_factory=_PRCfg) return _PipelineRunStub(id=run_id, name=run_id), True @@ -143,7 +143,7 @@ class _StepRunStub: name: str model_version: Any = None logs: Optional[Any] = None - status: Any = _StatusStub() + status: Any = field(default_factory=_StatusStub) outputs: Dict[str, Any] = None # type: ignore[assignment] regular_inputs: Dict[str, Any] = None # type: ignore[assignment] @@ -156,7 +156,7 @@ class _Cfg: enable_artifact_visualization: Optional[bool] = None substitutions: Dict[str, str] = None # type: ignore[assignment] - config: Any = _Cfg() + config: Any = field(default_factory=_Cfg) def __post_init__(self) -> None: # noqa: D401 self.outputs = {} diff --git a/src/zenml/orchestrators/step_launcher.py b/src/zenml/orchestrators/step_launcher.py index 76a6db52e78..c1e0b5e245e 100644 --- a/src/zenml/orchestrators/step_launcher.py +++ b/src/zenml/orchestrators/step_launcher.py @@ -266,15 +266,18 @@ def signal_handler(signum: int, frame: Any) -> None: client = Client() pipeline_run = None - if self._step_run: + # Memory-only stubs do not have a pipeline_run_id; handle gracefully + if self._step_run and hasattr(self._step_run, "pipeline_run_id"): pipeline_run = client.get_pipeline_run( self._step_run.pipeline_run_id ) - else: + elif self._step_run is None: raise RunInterruptedException( - "The execution was interrupted and the step does not " - "exist yet." + "The execution was interrupted and the step does not exist yet." ) + else: + # Memory-only: no server-side run to update; just signal interruption + raise RunInterruptedException("The execution was interrupted.") if pipeline_run and pipeline_run.status in [ ExecutionStatus.STOPPING, @@ -296,15 +299,18 @@ def signal_handler(signum: int, frame: Any) -> None: except Exception as e: raise RunInterruptedException(str(e)) finally: - # Chain to previous handler if it exists and is not default/ignore - if signum == signal.SIGTERM and callable( - self._prev_sigterm_handler - ): - self._prev_sigterm_handler(signum, frame) - elif signum == signal.SIGINT and callable( - self._prev_sigint_handler - ): - self._prev_sigint_handler(signum, frame) + # Chain to previous handler if it exists, not default/ignore, + # and not this handler to avoid recursion + prev = None + if signum == signal.SIGTERM: + prev = self._prev_sigterm_handler + elif signum == signal.SIGINT: + prev = self._prev_sigint_handler + if prev and prev not in (signal.SIG_DFL, signal.SIG_IGN) and prev is not signal_handler: + try: + prev(signum, frame) + except Exception: + pass # Register handlers for common termination signals try: From ab59f6588d29683e18ceef6f51f380db4644f1ea Mon Sep 17 00:00:00 2001 From: Safoine El khabich Date: Sun, 7 Sep 2025 22:27:30 +0100 Subject: [PATCH 6/8] Refactor step runtime architecture for improved clarity This commit introduces a new `DefaultStepRuntime` class, consolidating the runtime logic previously scattered across multiple files. The `MemoryStepRuntime` and `DefaultStepRuntime` have been separated into their respective files, enhancing modularity and maintainability. Additionally, the `weather_pipeline.py` example has been updated to ensure proper execution of the pipeline. These changes aim to streamline the step execution process and improve the overall structure of the ZenML codebase, aligning with recent architectural enhancements. --- examples/serving/weather_pipeline.py | 2 +- src/zenml/execution/default_runtime.py | 291 +++++++++ src/zenml/execution/factory.py | 8 +- src/zenml/execution/memory_runtime.py | 336 ++++++++++ src/zenml/execution/realtime_runtime.py | 2 +- src/zenml/execution/step_runtime.py | 583 +----------------- src/zenml/orchestrators/runtime_manager.py | 3 +- src/zenml/orchestrators/step_launcher.py | 19 +- src/zenml/orchestrators/step_run_utils.py | 3 +- src/zenml/orchestrators/step_runner.py | 3 +- .../test_default_runtime_metadata_toggle.py | 5 +- tests/unit/execution/test_memory_runtime.py | 2 +- .../test_step_runtime_artifact_write.py | 2 +- 13 files changed, 662 insertions(+), 597 deletions(-) create mode 100644 src/zenml/execution/default_runtime.py create mode 100644 src/zenml/execution/memory_runtime.py diff --git a/examples/serving/weather_pipeline.py b/examples/serving/weather_pipeline.py index 9e1e5acba67..f6b2bf13278 100644 --- a/examples/serving/weather_pipeline.py +++ b/examples/serving/weather_pipeline.py @@ -275,7 +275,7 @@ def weather_agent_pipeline(city: str = "London") -> str: # Create deployment without running deployment = weather_agent_pipeline._create_deployment() - # weather_agent_pipeline() + weather_agent_pipeline() print("\n✅ Pipeline deployed for run-only serving!") print(f"📋 Deployment ID: {deployment.id}") diff --git a/src/zenml/execution/default_runtime.py b/src/zenml/execution/default_runtime.py new file mode 100644 index 00000000000..6a6dfaea11c --- /dev/null +++ b/src/zenml/execution/default_runtime.py @@ -0,0 +1,291 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Default step runtime implementation (blocking publish, standard persistence).""" + +import time +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +from zenml.artifacts.unmaterialized_artifact import UnmaterializedArtifact +from zenml.client import Client +from zenml.enums import ArtifactSaveType +from zenml.execution.step_runtime import BaseStepRuntime +from zenml.logger import get_logger +from zenml.materializers.base_materializer import BaseMaterializer +from zenml.materializers.materializer_registry import materializer_registry +from zenml.models import ArtifactVersionResponse +from zenml.steps.step_context import get_step_context +from zenml.utils import materializer_utils, source_utils, tag_utils +from zenml.utils.typing_utils import get_origin, is_union + +if TYPE_CHECKING: + from zenml.artifact_stores import BaseArtifactStore + from zenml.config.step_configurations import Step + from zenml.materializers.base_materializer import BaseMaterializer + from zenml.models import PipelineRunResponse, StepRunResponse + from zenml.models.v2.core.step_run import StepRunInputResponse + from zenml.stack import Stack + from zenml.steps.utils import OutputSignature + +logger = get_logger(__name__) + + +class DefaultStepRuntime(BaseStepRuntime): + """Default runtime delegating to existing ZenML utilities. + + This keeps current behavior intact while providing a single place for the + step runner to call into. It intentionally mirrors logic from + `step_runner.py` and `orchestrators/input_utils.py`. + """ + + # --- Input Resolution --- + def resolve_step_inputs( + self, + *, + step: "Step", + pipeline_run: "PipelineRunResponse", + step_runs: Optional[Dict[str, "StepRunResponse"]] = None, + ) -> Dict[str, "StepRunInputResponse"]: + """Resolve step inputs. + + Args: + step: The step to resolve inputs for. + pipeline_run: The pipeline run to resolve inputs for. + step_runs: Optional map of step runs. + + Returns: + Mapping from input name to resolved step run input. + """ + # Local import to avoid circular import issues + from zenml.orchestrators import input_utils + + return input_utils.resolve_step_inputs( + step=step, pipeline_run=pipeline_run, step_runs=step_runs + ) + + # --- Artifact Load --- + def load_input_artifact( + self, + *, + artifact: ArtifactVersionResponse, + data_type: Type[Any], + stack: "Stack", + ) -> Any: + """Load an input artifact. + + Args: + artifact: The artifact to load. + data_type: The data type of the artifact. + stack: The stack to load the artifact from. + + Returns: + The loaded Python value for the input artifact. + """ + # Skip materialization for `UnmaterializedArtifact`. + if data_type == UnmaterializedArtifact: + return UnmaterializedArtifact( + **artifact.get_hydrated_version().model_dump() + ) + + if data_type in (None, Any) or is_union(get_origin(data_type)): + # Use the stored artifact datatype when function annotation is not specific + data_type = source_utils.load(artifact.data_type) + + materializer_class: Type[BaseMaterializer] = ( + source_utils.load_and_validate_class( + artifact.materializer, expected_class=BaseMaterializer + ) + ) + + def _load(artifact_store: "BaseArtifactStore") -> Any: + materializer: BaseMaterializer = materializer_class( + uri=artifact.uri, artifact_store=artifact_store + ) + materializer.validate_load_type_compatibility(data_type) + return materializer.load(data_type=data_type) + + if artifact.artifact_store_id == stack.artifact_store.id: + stack.artifact_store._register() + return _load(artifact_store=stack.artifact_store) + else: + # Local import to avoid circular import issues + from zenml.orchestrators.utils import ( + register_artifact_store_filesystem, + ) + + with register_artifact_store_filesystem( + artifact.artifact_store_id + ) as target_store: + return _load(artifact_store=target_store) + + # --- Artifact Store --- + def store_output_artifacts( + self, + *, + output_data: Dict[str, Any], + output_materializers: Dict[str, Tuple[Type["BaseMaterializer"], ...]], + output_artifact_uris: Dict[str, str], + output_annotations: Dict[str, "OutputSignature"], + artifact_metadata_enabled: bool, + artifact_visualization_enabled: bool, + ) -> Dict[str, ArtifactVersionResponse]: + """Store output artifacts. + + Args: + output_data: The output data. + output_materializers: The output materializers. + output_artifact_uris: The output artifact URIs. + output_annotations: The output annotations. + artifact_metadata_enabled: Whether artifact metadata is enabled. + artifact_visualization_enabled: Whether artifact visualization is enabled. + + Returns: + Mapping from output name to stored artifact version. + + Raises: + RuntimeError: If artifact batch creation fails after retries or + the number of responses does not match requests. + """ + # Apply capture toggles for metadata and visualizations + artifact_metadata_enabled = artifact_metadata_enabled and bool( + getattr(self, "_metadata_enabled", True) + ) + artifact_visualization_enabled = ( + artifact_visualization_enabled + and bool(getattr(self, "_visualizations_enabled", True)) + ) + + step_context = get_step_context() + artifact_requests: List[Any] = [] + + for output_name, return_value in output_data.items(): + data_type = type(return_value) + materializer_classes = output_materializers[output_name] + if materializer_classes: + materializer_class: Type[BaseMaterializer] = ( + materializer_utils.select_materializer( + data_type=data_type, + materializer_classes=materializer_classes, + ) + ) + else: + # Runtime selection if no explicit materializer recorded + default_materializer_source = ( + step_context.step_run.config.outputs[ + output_name + ].default_materializer_source + if step_context and step_context.step_run + else None + ) + + if default_materializer_source: + default_materializer_class: Type[BaseMaterializer] = ( + source_utils.load_and_validate_class( + default_materializer_source, + expected_class=BaseMaterializer, + ) + ) + materializer_registry.default_materializer = ( + default_materializer_class + ) + + materializer_class = materializer_registry[data_type] + + uri = output_artifact_uris[output_name] + artifact_config = output_annotations[output_name].artifact_config + + artifact_type = None + if artifact_config is not None: + has_custom_name = bool(artifact_config.name) + version = artifact_config.version + artifact_type = artifact_config.artifact_type + else: + has_custom_name, version = False, None + + # Name resolution mirrors existing behavior + if has_custom_name: + artifact_name = output_name + else: + if step_context.pipeline_run.pipeline: + pipeline_name = step_context.pipeline_run.pipeline.name + else: + pipeline_name = "unlisted" + step_name = step_context.step_run.name + artifact_name = f"{pipeline_name}::{step_name}::{output_name}" + + # Collect user metadata and tags + user_metadata = step_context.get_output_metadata(output_name) + tags = step_context.get_output_tags(output_name) + if step_context.pipeline_run.config.tags is not None: + for tag in step_context.pipeline_run.config.tags: + if isinstance(tag, tag_utils.Tag) and tag.cascade is True: + tags.append(tag.name) + + # Store artifact data and prepare a request to the server. + from zenml.artifacts.utils import ( + _store_artifact_data_and_prepare_request, + ) + + artifact_request = _store_artifact_data_and_prepare_request( + name=artifact_name, + data=return_value, + materializer_class=materializer_class, + uri=uri, + artifact_type=artifact_type, + store_metadata=artifact_metadata_enabled, + store_visualizations=artifact_visualization_enabled, + has_custom_name=has_custom_name, + version=version, + tags=tags, + save_type=ArtifactSaveType.STEP_OUTPUT, + metadata=user_metadata, + ) + artifact_requests.append(artifact_request) + + max_retries = 2 + delay = 1.0 + + for attempt in range(max_retries + 1): + try: + responses = Client().zen_store.batch_create_artifact_versions( + artifact_requests + ) + if len(responses) != len(artifact_requests): + raise RuntimeError( + f"Artifact batch creation returned {len(responses)}/{len(artifact_requests)} responses" + ) + return dict(zip(output_data.keys(), responses)) + except Exception as e: + if attempt < max_retries: + logger.warning( + "Artifact creation attempt %s failed: %s. Retrying in %.1fs...", + attempt + 1, + e, + delay, + ) + time.sleep(delay) + delay *= 1.5 + else: + logger.error( + "Failed to create artifacts after %s attempts: %s. Failing step to avoid inconsistency.", + max_retries + 1, + e, + ) + raise + + # TODO(beta->prod): Align with server to provide atomic batch create or + # compensating deletes. Consider idempotent requests and retriable error + # categories with jittered backoff. + raise RuntimeError( + "Artifact creation failed unexpectedly without raising" + ) diff --git a/src/zenml/execution/factory.py b/src/zenml/execution/factory.py index 53f6f07c2f3..c8958010f42 100644 --- a/src/zenml/execution/factory.py +++ b/src/zenml/execution/factory.py @@ -13,11 +13,9 @@ # permissions and limitations under the License. """Factory to construct a step runtime based on context and capture.""" -from zenml.execution.step_runtime import ( - BaseStepRuntime, - DefaultStepRuntime, - MemoryStepRuntime, -) +from zenml.execution.default_runtime import DefaultStepRuntime +from zenml.execution.memory_runtime import MemoryStepRuntime +from zenml.execution.step_runtime import BaseStepRuntime def get_runtime( diff --git a/src/zenml/execution/memory_runtime.py b/src/zenml/execution/memory_runtime.py new file mode 100644 index 00000000000..25923bd0b76 --- /dev/null +++ b/src/zenml/execution/memory_runtime.py @@ -0,0 +1,336 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Memory-only step runtime (in-process handoff, no DB/FS persistence).""" + +import threading +from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Tuple, Type + +from zenml.execution.step_runtime import BaseStepRuntime +from zenml.logger import get_logger +from zenml.steps.step_context import get_step_context +from zenml.utils import string_utils + +if TYPE_CHECKING: + from zenml.config.step_configurations import Step + from zenml.models import PipelineRunResponse, StepRunResponse + +logger = get_logger(__name__) + + +class MemoryStepRuntime(BaseStepRuntime): + """Pure in-memory execution runtime: no server calls, no persistence. + + Instance-scoped store to isolate requests. Values are accessible within the + same process for the same run id and step chain only. + """ + + @staticmethod + def make_handle_id(run_id: str, step_name: str, output_name: str) -> str: + """Make a handle ID for an output artifact. + + Args: + run_id: The run ID. + step_name: The step name. + output_name: The output name. + + Returns: + The handle ID. + """ + return f"mem://{run_id}/{step_name}/{output_name}" + + @staticmethod + def parse_handle_id(handle_id: str) -> Tuple[str, str, str]: + """Parse a handle ID for an output artifact. + + Args: + handle_id: The handle ID. + + Returns: + The run ID, step name, and output name. + + Raises: + ValueError: If the handle id is malformed. + """ + if not isinstance(handle_id, str) or not handle_id.startswith( + "mem://" + ): + raise ValueError("Invalid memory handle id") + rest = handle_id[len("mem://") :] + # split into exactly 3 parts: run_id, step_name, output_name + parts = rest.split("/", 2) + if len(parts) != 3: + raise ValueError("Invalid memory handle id") + run_id, step_name, output_name = parts + # basic sanitization + for p in (run_id, step_name, output_name): + if not p or "\n" in p or "\r" in p: + raise ValueError("Invalid memory handle component") + return run_id, step_name, output_name + + class Handle: + """A handle for an output artifact.""" + + def __init__(self, id: str) -> None: + """Initialize the handle. + + Args: + id: The handle ID. + """ + self.id = id + + # Instance-scoped context for handle resolution (set by launcher) + def __init__(self) -> None: + """Initialize the memory runtime.""" + super().__init__() + self._ctx_run_id: Optional[str] = None + self._ctx_substitutions: Dict[str, str] = {} + self._active_run_ids: set[str] = set() + # Instance-scoped storage and locks per run_id + self._store: Dict[str, Dict[Tuple[str, str], Any]] = {} + self._run_locks: Dict[str, Any] = {} + self._global_lock: Any = threading.RLock() + + def set_context( + self, *, run_id: str, substitutions: Optional[Dict[str, str]] = None + ) -> None: + """Set current memory-only context for handle resolution. + + Args: + run_id: The run ID. + substitutions: The substitutions. + """ + self._ctx_run_id = run_id + self._ctx_substitutions = substitutions or {} + try: + if run_id: + self._active_run_ids.add(run_id) + except Exception: + pass + + def resolve_step_inputs( + self, + *, + step: "Step", + pipeline_run: "PipelineRunResponse", + step_runs: Optional[Dict[str, "StepRunResponse"]] = None, + ) -> Dict[str, Any]: + """Resolve step inputs by constructing in-memory handles. + + Args: + step: The step to resolve inputs for. + pipeline_run: The pipeline run to resolve inputs for. + step_runs: The step runs to resolve inputs for. + + Returns: + A mapping of input name to MemoryStepRuntime.Handle. + """ + run_id = self._ctx_run_id or str(getattr(pipeline_run, "id", "local")) + subs = self._ctx_substitutions or {} + handles: Dict[str, Any] = {} + for name, input_ in step.spec.inputs.items(): + resolved_output_name = string_utils.format_name_template( + input_.output_name, substitutions=subs + ) + handle_id = self.make_handle_id( + run_id=run_id, + step_name=input_.step_name, + output_name=resolved_output_name, + ) + handles[name] = MemoryStepRuntime.Handle(handle_id) + return handles + + def load_input_artifact( + self, *, artifact: Any, data_type: Type[Any], stack: Any + ) -> Any: + """Load an input artifact. + + Args: + artifact: The artifact to load. + data_type: The data type of the artifact. + stack: The stack to load the artifact from. + + Returns: + The loaded artifact. + + Raises: + ValueError: If the memory handle id is invalid or malformed. + """ + handle_id_any = getattr(artifact, "id", None) + if not isinstance(handle_id_any, str): + raise ValueError("Invalid memory handle id") + run_id, step_name, output_name = self.parse_handle_id(handle_id_any) + # Use per-run lock to avoid cross-run interference + with self._global_lock: + rlock = self._run_locks.setdefault(run_id, threading.RLock()) + with rlock: + return self._store.get(run_id, {}).get((step_name, output_name)) + + def store_output_artifacts( + self, + *, + output_data: Dict[str, Any], + output_materializers: Dict[str, Tuple[Type[Any], ...]], + output_artifact_uris: Dict[str, str], + output_annotations: Dict[str, Any], + artifact_metadata_enabled: bool, + artifact_visualization_enabled: bool, + ) -> Dict[str, Any]: + """Store output artifacts. + + Args: + output_data: The output data. + output_materializers: The output materializers. + output_artifact_uris: The output artifact URIs. + output_annotations: The output annotations. + artifact_metadata_enabled: Whether artifact metadata is enabled. + artifact_visualization_enabled: Whether artifact visualization is enabled. + + Returns: + The stored artifacts. + """ + ctx = get_step_context() + run_id = str(getattr(ctx.pipeline_run, "id", "local")) + try: + if run_id: + self._active_run_ids.add(run_id) + except Exception: + pass + step_name = str(getattr(ctx.step_run, "name", "step")) + handles: Dict[str, Any] = {} + with self._global_lock: + rlock = self._run_locks.setdefault(run_id, threading.RLock()) + with rlock: + rr = self._store.setdefault(run_id, {}) + for output_name, value in output_data.items(): + rr[(step_name, output_name)] = value + handle_id = self.make_handle_id(run_id, step_name, output_name) + handles[output_name] = MemoryStepRuntime.Handle(handle_id) + return handles + + def compute_cache_key( + self, + *, + step: Any, + input_artifacts: Mapping[str, Any], + artifact_store: Any, + project_id: Any, + ) -> str: + """Compute a cache key. + + Args: + step: The step to compute the cache key for. + input_artifacts: The input artifacts for the step. + artifact_store: The artifact store to compute the cache key for. + project_id: The project ID to compute the cache key for. + + Returns: + The computed cache key. + """ + return "" + + def get_cached_step_run(self, *, cache_key: str) -> None: + """Get a cached step run. + + Args: + cache_key: The cache key to get the cached step run for. + + Returns: + The cached step run if available, otherwise None. + """ + return None + + def publish_pipeline_run_metadata( + self, *, pipeline_run_id: Any, pipeline_run_metadata: Any + ) -> None: + """Publish pipeline run metadata. + + Args: + pipeline_run_id: The pipeline run ID. + pipeline_run_metadata: The pipeline run metadata. + """ + return + + def publish_step_run_metadata( + self, *, step_run_id: Any, step_run_metadata: Any + ) -> None: + """Publish step run metadata. + + Args: + step_run_id: The step run ID. + step_run_metadata: The step run metadata. + """ + return + + def publish_successful_step_run( + self, *, step_run_id: Any, output_artifact_ids: Any + ) -> None: + """Publish a successful step run. + + Args: + step_run_id: The step run ID. + output_artifact_ids: The output artifact IDs. + """ + return + + def publish_failed_step_run(self, *, step_run_id: Any) -> None: + """Publish a failed step run. + + Args: + step_run_id: The step run ID. + """ + return + + def start(self) -> None: + """Start the memory runtime.""" + return + + def on_step_start(self) -> None: + """Optional hook when a step starts execution.""" + return + + def flush(self) -> None: + """Flush the memory runtime.""" + return + + def on_step_end(self) -> None: + """Optional hook when a step ends execution.""" + return + + def shutdown(self) -> None: + """Shutdown the memory runtime.""" + return + + def __del__(self) -> None: # noqa: D401 + """Best-effort cleanup of per-run memory when GC collects the runtime.""" + try: + for run_id in list(self._active_run_ids): + try: + self.reset(run_id) + except Exception: + pass + except Exception: + pass + + # --- Unified path helpers --- + def reset(self, run_id: str) -> None: + """Clear all in-memory data associated with a specific run. + + Args: + run_id: The run id to clear. + """ + with self._global_lock: + try: + self._store.pop(run_id, None) + finally: + self._run_locks.pop(run_id, None) diff --git a/src/zenml/execution/realtime_runtime.py b/src/zenml/execution/realtime_runtime.py index bee8b16712e..a982b70dc92 100644 --- a/src/zenml/execution/realtime_runtime.py +++ b/src/zenml/execution/realtime_runtime.py @@ -25,7 +25,7 @@ from collections import OrderedDict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type -from zenml.execution.step_runtime import DefaultStepRuntime +from zenml.execution.default_runtime import DefaultStepRuntime from zenml.logger import get_logger from zenml.materializers.base_materializer import BaseMaterializer from zenml.models import ArtifactVersionResponse diff --git a/src/zenml/execution/step_runtime.py b/src/zenml/execution/step_runtime.py index ba1fd8924ea..c1318927513 100644 --- a/src/zenml/execution/step_runtime.py +++ b/src/zenml/execution/step_runtime.py @@ -20,32 +20,16 @@ Enable usage by setting environment variable `ZENML_ENABLE_STEP_RUNTIME=true`. """ -import threading -import time from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Tuple, Type from uuid import UUID -from zenml.artifacts.unmaterialized_artifact import ( - UnmaterializedArtifact, -) -from zenml.client import Client -from zenml.enums import ArtifactSaveType from zenml.logger import get_logger from zenml.materializers.base_materializer import BaseMaterializer -from zenml.materializers.materializer_registry import materializer_registry from zenml.models import ArtifactVersionResponse # Note: avoid importing zenml.orchestrators modules at import time to prevent # circular dependencies. Where needed, import locally within methods. -from zenml.steps.step_context import get_step_context -from zenml.utils import ( - materializer_utils, - source_utils, - string_utils, - tag_utils, -) -from zenml.utils.typing_utils import get_origin, is_union if TYPE_CHECKING: from zenml.artifact_stores import BaseArtifactStore @@ -129,7 +113,7 @@ def compute_cache_key( self, *, step: "Step", - input_artifact_ids: Dict[str, UUID], + input_artifacts: Mapping[str, "ArtifactVersionResponse"], artifact_store: "BaseArtifactStore", project_id: UUID, ) -> str: @@ -139,7 +123,7 @@ def compute_cache_key( Args: step: The step to compute the cache key for. - input_artifact_ids: The input artifact IDs. + input_artifacts: The input artifacts. artifact_store: The artifact store to compute the cache key for. project_id: The project ID to compute the cache key for. @@ -151,7 +135,7 @@ def compute_cache_key( return cache_utils.generate_cache_key( step=step, - input_artifact_ids=input_artifact_ids, + input_artifacts=input_artifacts, artifact_store=artifact_store, project_id=project_id, ) @@ -276,562 +260,3 @@ def should_flush_on_step_end(self) -> bool: True to flush on step end; False otherwise. """ return True - - -class DefaultStepRuntime(BaseStepRuntime): - """Default runtime delegating to existing ZenML utilities. - - This keeps current behavior intact while providing a single place for the - step runner to call into. It intentionally mirrors logic from - `step_runner.py` and `orchestrators/input_utils.py`. - """ - - # --- Input Resolution --- - def resolve_step_inputs( - self, - *, - step: "Step", - pipeline_run: "PipelineRunResponse", - step_runs: Optional[Dict[str, "StepRunResponse"]] = None, - ) -> Dict[str, "StepRunInputResponse"]: - """Resolve step inputs. - - Args: - step: The step to resolve inputs for. - pipeline_run: The pipeline run to resolve inputs for. - step_runs: Optional map of step runs. - - Returns: - Mapping from input name to resolved step run input. - """ - # Local import to avoid circular import issues - from zenml.orchestrators import input_utils - - return input_utils.resolve_step_inputs( - step=step, pipeline_run=pipeline_run, step_runs=step_runs - ) - - # --- Artifact Load --- - def load_input_artifact( - self, - *, - artifact: ArtifactVersionResponse, - data_type: Type[Any], - stack: "Stack", - ) -> Any: - """Load an input artifact. - - Args: - artifact: The artifact to load. - data_type: The data type of the artifact. - stack: The stack to load the artifact from. - - Returns: - The loaded Python value for the input artifact. - """ - # Skip materialization for `UnmaterializedArtifact`. - if data_type == UnmaterializedArtifact: - return UnmaterializedArtifact( - **artifact.get_hydrated_version().model_dump() - ) - - if data_type in (None, Any) or is_union(get_origin(data_type)): - # Use the stored artifact datatype when function annotation is not specific - data_type = source_utils.load(artifact.data_type) - - materializer_class: Type[BaseMaterializer] = ( - source_utils.load_and_validate_class( - artifact.materializer, expected_class=BaseMaterializer - ) - ) - - def _load(artifact_store: "BaseArtifactStore") -> Any: - materializer: BaseMaterializer = materializer_class( - uri=artifact.uri, artifact_store=artifact_store - ) - materializer.validate_load_type_compatibility(data_type) - return materializer.load(data_type=data_type) - - if artifact.artifact_store_id == stack.artifact_store.id: - stack.artifact_store._register() - return _load(artifact_store=stack.artifact_store) - else: - # Local import to avoid circular import issues - from zenml.orchestrators.utils import ( - register_artifact_store_filesystem, - ) - - with register_artifact_store_filesystem( - artifact.artifact_store_id - ) as target_store: - return _load(artifact_store=target_store) - - # --- Artifact Store --- - def store_output_artifacts( - self, - *, - output_data: Dict[str, Any], - output_materializers: Dict[str, Tuple[Type["BaseMaterializer"], ...]], - output_artifact_uris: Dict[str, str], - output_annotations: Dict[str, "OutputSignature"], - artifact_metadata_enabled: bool, - artifact_visualization_enabled: bool, - ) -> Dict[str, ArtifactVersionResponse]: - """Store output artifacts. - - Args: - output_data: The output data. - output_materializers: The output materializers. - output_artifact_uris: The output artifact URIs. - output_annotations: The output annotations. - artifact_metadata_enabled: Whether artifact metadata is enabled. - artifact_visualization_enabled: Whether artifact visualization is enabled. - - Returns: - Mapping from output name to stored artifact version. - - Raises: - RuntimeError: If artifact batch creation fails after retries or - the number of responses does not match requests. - """ - # Apply capture toggles for metadata and visualizations - artifact_metadata_enabled = artifact_metadata_enabled and bool( - getattr(self, "_metadata_enabled", True) - ) - artifact_visualization_enabled = ( - artifact_visualization_enabled - and bool(getattr(self, "_visualizations_enabled", True)) - ) - - step_context = get_step_context() - artifact_requests: List[Any] = [] - - for output_name, return_value in output_data.items(): - data_type = type(return_value) - materializer_classes = output_materializers[output_name] - if materializer_classes: - materializer_class: Type[BaseMaterializer] = ( - materializer_utils.select_materializer( - data_type=data_type, - materializer_classes=materializer_classes, - ) - ) - else: - # Runtime selection if no explicit materializer recorded - default_materializer_source = ( - step_context.step_run.config.outputs[ - output_name - ].default_materializer_source - if step_context and step_context.step_run - else None - ) - - if default_materializer_source: - default_materializer_class: Type[BaseMaterializer] = ( - source_utils.load_and_validate_class( - default_materializer_source, - expected_class=BaseMaterializer, - ) - ) - materializer_registry.default_materializer = ( - default_materializer_class - ) - - materializer_class = materializer_registry[data_type] - - uri = output_artifact_uris[output_name] - artifact_config = output_annotations[output_name].artifact_config - - artifact_type = None - if artifact_config is not None: - has_custom_name = bool(artifact_config.name) - version = artifact_config.version - artifact_type = artifact_config.artifact_type - else: - has_custom_name, version = False, None - - # Name resolution mirrors existing behavior - if has_custom_name: - artifact_name = output_name - else: - if step_context.pipeline_run.pipeline: - pipeline_name = step_context.pipeline_run.pipeline.name - else: - pipeline_name = "unlisted" - step_name = step_context.step_run.name - artifact_name = f"{pipeline_name}::{step_name}::{output_name}" - - # Collect user metadata and tags - user_metadata = step_context.get_output_metadata(output_name) - tags = step_context.get_output_tags(output_name) - if step_context.pipeline_run.config.tags is not None: - for tag in step_context.pipeline_run.config.tags: - if isinstance(tag, tag_utils.Tag) and tag.cascade is True: - tags.append(tag.name) - - # Store artifact data and prepare a request to the server. - from zenml.artifacts.utils import ( - _store_artifact_data_and_prepare_request, - ) - - artifact_request = _store_artifact_data_and_prepare_request( - name=artifact_name, - data=return_value, - materializer_class=materializer_class, - uri=uri, - artifact_type=artifact_type, - store_metadata=artifact_metadata_enabled, - store_visualizations=artifact_visualization_enabled, - has_custom_name=has_custom_name, - version=version, - tags=tags, - save_type=ArtifactSaveType.STEP_OUTPUT, - metadata=user_metadata, - ) - artifact_requests.append(artifact_request) - - max_retries = 2 - delay = 1.0 - - for attempt in range(max_retries + 1): - try: - responses = Client().zen_store.batch_create_artifact_versions( - artifact_requests - ) - if len(responses) != len(artifact_requests): - raise RuntimeError( - f"Artifact batch creation returned {len(responses)}/{len(artifact_requests)} responses" - ) - return dict(zip(output_data.keys(), responses)) - except Exception as e: - if attempt < max_retries: - logger.warning( - "Artifact creation attempt %s failed: %s. Retrying in %.1fs...", - attempt + 1, - e, - delay, - ) - time.sleep(delay) - delay *= 1.5 - else: - logger.error( - "Failed to create artifacts after %s attempts: %s. Failing step to avoid inconsistency.", - max_retries + 1, - e, - ) - raise - - # TODO(beta->prod): Align with server to provide atomic batch create or - # compensating deletes. Consider idempotent requests and retriable error - # categories with jittered backoff. - raise RuntimeError( - "Artifact creation failed unexpectedly without raising" - ) - - -class MemoryStepRuntime(BaseStepRuntime): - """Pure in-memory execution runtime: no server calls, no persistence. - - Instance-scoped store to isolate requests. Values are accessible within the - same process for the same run id and step chain only. - """ - - @staticmethod - def make_handle_id(run_id: str, step_name: str, output_name: str) -> str: - """Make a handle ID for an output artifact. - - Args: - run_id: The run ID. - step_name: The step name. - output_name: The output name. - - Returns: - The handle ID. - """ - return f"mem://{run_id}/{step_name}/{output_name}" - - @staticmethod - def parse_handle_id(handle_id: str) -> Tuple[str, str, str]: - """Parse a handle ID for an output artifact. - - Args: - handle_id: The handle ID. - - Returns: - The run ID, step name, and output name. - - Raises: - ValueError: If the handle id is malformed. - """ - if not isinstance(handle_id, str) or not handle_id.startswith( - "mem://" - ): - raise ValueError("Invalid memory handle id") - rest = handle_id[len("mem://") :] - # split into exactly 3 parts: run_id, step_name, output_name - parts = rest.split("/", 2) - if len(parts) != 3: - raise ValueError("Invalid memory handle id") - run_id, step_name, output_name = parts - # basic sanitization - for p in (run_id, step_name, output_name): - if not p or "\n" in p or "\r" in p: - raise ValueError("Invalid memory handle component") - return run_id, step_name, output_name - - class Handle: - """A handle for an output artifact.""" - - def __init__(self, id: str) -> None: - """Initialize the handle. - - Args: - id: The handle ID. - """ - self.id = id - - # Instance-scoped context for handle resolution (set by launcher) - def __init__(self) -> None: - """Initialize the memory runtime.""" - super().__init__() - self._ctx_run_id: Optional[str] = None - self._ctx_substitutions: Dict[str, str] = {} - self._active_run_ids: set[str] = set() - # Instance-scoped storage and locks per run_id - self._store: Dict[str, Dict[Tuple[str, str], Any]] = {} - self._run_locks: Dict[str, Any] = {} - self._global_lock: Any = threading.RLock() - - def set_context( - self, *, run_id: str, substitutions: Optional[Dict[str, str]] = None - ) -> None: - """Set current memory-only context for handle resolution. - - Args: - run_id: The run ID. - substitutions: The substitutions. - """ - self._ctx_run_id = run_id - self._ctx_substitutions = substitutions or {} - try: - if run_id: - self._active_run_ids.add(run_id) - except Exception: - pass - - def resolve_step_inputs( - self, - *, - step: "Step", - pipeline_run: "PipelineRunResponse", - step_runs: Optional[Dict[str, "StepRunResponse"]] = None, - ) -> Dict[str, Any]: - """Resolve step inputs by constructing in-memory handles. - - Args: - step: The step to resolve inputs for. - pipeline_run: The pipeline run to resolve inputs for. - step_runs: The step runs to resolve inputs for. - - Returns: - A mapping of input name to MemoryStepRuntime.Handle. - """ - run_id = self._ctx_run_id or str(getattr(pipeline_run, "id", "local")) - subs = self._ctx_substitutions or {} - handles: Dict[str, Any] = {} - for name, input_ in step.spec.inputs.items(): - resolved_output_name = string_utils.format_name_template( - input_.output_name, substitutions=subs - ) - handle_id = self.make_handle_id( - run_id=run_id, - step_name=input_.step_name, - output_name=resolved_output_name, - ) - handles[name] = MemoryStepRuntime.Handle(handle_id) - return handles - - def load_input_artifact( - self, *, artifact: Any, data_type: Type[Any], stack: Any - ) -> Any: - """Load an input artifact. - - Args: - artifact: The artifact to load. - data_type: The data type of the artifact. - stack: The stack to load the artifact from. - - Returns: - The loaded artifact. - - Raises: - ValueError: If the memory handle id is invalid or malformed. - """ - handle_id_any = getattr(artifact, "id", None) - if not isinstance(handle_id_any, str): - raise ValueError("Invalid memory handle id") - run_id, step_name, output_name = self.parse_handle_id(handle_id_any) - # Use per-run lock to avoid cross-run interference - with self._global_lock: - rlock = self._run_locks.setdefault(run_id, threading.RLock()) - with rlock: - return self._store.get(run_id, {}).get((step_name, output_name)) - - def store_output_artifacts( - self, - *, - output_data: Dict[str, Any], - output_materializers: Dict[str, Tuple[Type[Any], ...]], - output_artifact_uris: Dict[str, str], - output_annotations: Dict[str, Any], - artifact_metadata_enabled: bool, - artifact_visualization_enabled: bool, - ) -> Dict[str, Any]: - """Store output artifacts. - - Args: - output_data: The output data. - output_materializers: The output materializers. - output_artifact_uris: The output artifact URIs. - output_annotations: The output annotations. - artifact_metadata_enabled: Whether artifact metadata is enabled. - artifact_visualization_enabled: Whether artifact visualization is enabled. - - Returns: - The stored artifacts. - """ - ctx = get_step_context() - run_id = str(getattr(ctx.pipeline_run, "id", "local")) - try: - if run_id: - self._active_run_ids.add(run_id) - except Exception: - pass - step_name = str(getattr(ctx.step_run, "name", "step")) - handles: Dict[str, Any] = {} - with self._global_lock: - rlock = self._run_locks.setdefault(run_id, threading.RLock()) - with rlock: - rr = self._store.setdefault(run_id, {}) - for output_name, value in output_data.items(): - rr[(step_name, output_name)] = value - handle_id = self.make_handle_id(run_id, step_name, output_name) - handles[output_name] = MemoryStepRuntime.Handle(handle_id) - return handles - - def compute_cache_key( - self, - *, - step: Any, - input_artifact_ids: Dict[str, Any], - artifact_store: Any, - project_id: Any, - ) -> str: - """Compute a cache key. - - Args: - step: The step to compute the cache key for. - input_artifact_ids: The input artifact IDs. - artifact_store: The artifact store to compute the cache key for. - project_id: The project ID to compute the cache key for. - - Returns: - The computed cache key. - """ - return "" - - def get_cached_step_run(self, *, cache_key: str) -> None: - """Get a cached step run. - - Args: - cache_key: The cache key to get the cached step run for. - - Returns: - The cached step run if available, otherwise None. - """ - return None - - def publish_pipeline_run_metadata( - self, *, pipeline_run_id: Any, pipeline_run_metadata: Any - ) -> None: - """Publish pipeline run metadata. - - Args: - pipeline_run_id: The pipeline run ID. - pipeline_run_metadata: The pipeline run metadata. - """ - return - - def publish_step_run_metadata( - self, *, step_run_id: Any, step_run_metadata: Any - ) -> None: - """Publish step run metadata. - - Args: - step_run_id: The step run ID. - step_run_metadata: The step run metadata. - """ - return - - def publish_successful_step_run( - self, *, step_run_id: Any, output_artifact_ids: Any - ) -> None: - """Publish a successful step run. - - Args: - step_run_id: The step run ID. - output_artifact_ids: The output artifact IDs. - """ - return - - def publish_failed_step_run(self, *, step_run_id: Any) -> None: - """Publish a failed step run. - - Args: - step_run_id: The step run ID. - """ - return - - def start(self) -> None: - """Start the memory runtime.""" - return - - def on_step_start(self) -> None: - """Optional hook when a step starts execution.""" - return - - def flush(self) -> None: - """Flush the memory runtime.""" - return - - def on_step_end(self) -> None: - """Optional hook when a step ends execution.""" - return - - def shutdown(self) -> None: - """Shutdown the memory runtime.""" - return - - def __del__(self) -> None: # noqa: D401 - """Best-effort cleanup of per-run memory when GC collects the runtime.""" - try: - for run_id in list(self._active_run_ids): - try: - self.reset(run_id) - except Exception: - pass - except Exception: - pass - - # --- Unified path helpers --- - def reset(self, run_id: str) -> None: - """Clear all in-memory data associated with a specific run. - - Args: - run_id: The run id to clear. - """ - with self._global_lock: - try: - self._store.pop(run_id, None) - finally: - self._run_locks.pop(run_id, None) diff --git a/src/zenml/orchestrators/runtime_manager.py b/src/zenml/orchestrators/runtime_manager.py index 4c2fa7602a1..69ee02f01f5 100644 --- a/src/zenml/orchestrators/runtime_manager.py +++ b/src/zenml/orchestrators/runtime_manager.py @@ -10,7 +10,8 @@ from contextvars import ContextVar from typing import Optional -from zenml.execution.step_runtime import BaseStepRuntime, MemoryStepRuntime +from zenml.execution.memory_runtime import MemoryStepRuntime +from zenml.execution.step_runtime import BaseStepRuntime # Shared runtime context for the lifetime of a single request. _shared_runtime: ContextVar[Optional[BaseStepRuntime]] = ContextVar( diff --git a/src/zenml/orchestrators/step_launcher.py b/src/zenml/orchestrators/step_launcher.py index c1e0b5e245e..7f5111d7c35 100644 --- a/src/zenml/orchestrators/step_launcher.py +++ b/src/zenml/orchestrators/step_launcher.py @@ -32,7 +32,7 @@ from zenml.environment import get_run_environment_dict from zenml.exceptions import RunInterruptedException, RunStoppedException from zenml.execution.factory import get_runtime -from zenml.execution.step_runtime import MemoryStepRuntime +from zenml.execution.memory_runtime import MemoryStepRuntime from zenml.logger import get_logger from zenml.logging import step_logging from zenml.models import ( @@ -267,7 +267,9 @@ def signal_handler(signum: int, frame: Any) -> None: pipeline_run = None # Memory-only stubs do not have a pipeline_run_id; handle gracefully - if self._step_run and hasattr(self._step_run, "pipeline_run_id"): + if self._step_run and hasattr( + self._step_run, "pipeline_run_id" + ): pipeline_run = client.get_pipeline_run( self._step_run.pipeline_run_id ) @@ -277,7 +279,9 @@ def signal_handler(signum: int, frame: Any) -> None: ) else: # Memory-only: no server-side run to update; just signal interruption - raise RunInterruptedException("The execution was interrupted.") + raise RunInterruptedException( + "The execution was interrupted." + ) if pipeline_run and pipeline_run.status in [ ExecutionStatus.STOPPING, @@ -306,9 +310,14 @@ def signal_handler(signum: int, frame: Any) -> None: prev = self._prev_sigterm_handler elif signum == signal.SIGINT: prev = self._prev_sigint_handler - if prev and prev not in (signal.SIG_DFL, signal.SIG_IGN) and prev is not signal_handler: + if ( + prev + and prev not in (signal.SIG_DFL, signal.SIG_IGN) + and prev is not signal_handler + ): try: - prev(signum, frame) + if callable(prev): + prev(signum, frame) except Exception: pass diff --git a/src/zenml/orchestrators/step_run_utils.py b/src/zenml/orchestrators/step_run_utils.py index 6e37f47f4db..aaf34583151 100644 --- a/src/zenml/orchestrators/step_run_utils.py +++ b/src/zenml/orchestrators/step_run_utils.py @@ -20,7 +20,8 @@ from zenml.config.step_configurations import Step from zenml.constants import CODE_HASH_PARAMETER_NAME, TEXT_FIELD_MAX_LENGTH from zenml.enums import ExecutionStatus -from zenml.execution.step_runtime import BaseStepRuntime, DefaultStepRuntime +from zenml.execution.default_runtime import DefaultStepRuntime +from zenml.execution.step_runtime import BaseStepRuntime from zenml.logger import get_logger from zenml.model.utils import link_artifact_version_to_model_version from zenml.models import ( diff --git a/src/zenml/orchestrators/step_runner.py b/src/zenml/orchestrators/step_runner.py index c7d33eef525..d8cd127d805 100644 --- a/src/zenml/orchestrators/step_runner.py +++ b/src/zenml/orchestrators/step_runner.py @@ -41,7 +41,8 @@ ) from zenml.enums import ArtifactSaveType from zenml.exceptions import StepInterfaceError -from zenml.execution.step_runtime import BaseStepRuntime, DefaultStepRuntime +from zenml.execution.default_runtime import DefaultStepRuntime +from zenml.execution.step_runtime import BaseStepRuntime from zenml.logger import get_logger from zenml.logging.step_logging import PipelineLogsStorageContext, redirected from zenml.materializers.base_materializer import BaseMaterializer diff --git a/tests/unit/execution/test_default_runtime_metadata_toggle.py b/tests/unit/execution/test_default_runtime_metadata_toggle.py index 740ff487569..f85666600f6 100644 --- a/tests/unit/execution/test_default_runtime_metadata_toggle.py +++ b/tests/unit/execution/test_default_runtime_metadata_toggle.py @@ -2,19 +2,22 @@ from types import SimpleNamespace -from zenml.execution.step_runtime import DefaultStepRuntime +from zenml.execution.default_runtime import DefaultStepRuntime def test_publish_metadata_skips_when_disabled(monkeypatch): + """Test that metadata is not published when disabled.""" rt = DefaultStepRuntime() setattr(rt, "_metadata_enabled", False) called = {"run": 0, "step": 0} def _pub_run_md(*a, **k): + """Mock publish pipeline run metadata.""" called["run"] += 1 def _pub_step_md(*a, **k): + """Mock publish step run metadata.""" called["step"] += 1 monkeypatch.setattr( diff --git a/tests/unit/execution/test_memory_runtime.py b/tests/unit/execution/test_memory_runtime.py index 49e0574834a..fda2314743a 100644 --- a/tests/unit/execution/test_memory_runtime.py +++ b/tests/unit/execution/test_memory_runtime.py @@ -2,7 +2,7 @@ from types import SimpleNamespace -from zenml.execution.step_runtime import MemoryStepRuntime +from zenml.execution.memory_runtime import MemoryStepRuntime def test_memory_runtime_instance_isolated_store(monkeypatch): diff --git a/tests/unit/execution/test_step_runtime_artifact_write.py b/tests/unit/execution/test_step_runtime_artifact_write.py index d89ed28bb71..c1658e544b4 100644 --- a/tests/unit/execution/test_step_runtime_artifact_write.py +++ b/tests/unit/execution/test_step_runtime_artifact_write.py @@ -2,7 +2,7 @@ from types import SimpleNamespace -from zenml.execution.step_runtime import DefaultStepRuntime +from zenml.execution.default_runtime import DefaultStepRuntime def test_artifact_write_retry_and_validate(monkeypatch): From cc67f2388810b5559fe6d33180c52ae9fab11ba1 Mon Sep 17 00:00:00 2001 From: Safoine El khabich Date: Mon, 8 Sep 2025 13:38:18 +0100 Subject: [PATCH 7/8] Add beta hardening checklist for production readiness This commit introduces a new document, `beta_todos.md`, outlining a comprehensive checklist for post-beta hardening efforts. The checklist includes key areas such as serving runtime, artifact write semantics, request parameter schema, monitoring, and resource management. Each section details specific tasks aimed at enhancing production readiness and scalability. This addition serves as a roadmap for future improvements and ensures that all necessary steps are documented for achieving a robust and reliable deployment of the ZenML framework. --- docs/PR_DESCRIPTION.md | 106 ----------------------------------------- docs/beta_todos.md | 62 ++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 106 deletions(-) delete mode 100644 docs/PR_DESCRIPTION.md create mode 100644 docs/beta_todos.md diff --git a/docs/PR_DESCRIPTION.md b/docs/PR_DESCRIPTION.md deleted file mode 100644 index b3cfa388796..00000000000 --- a/docs/PR_DESCRIPTION.md +++ /dev/null @@ -1,106 +0,0 @@ -## Beta: Unified Serving Capture, Memory-Only Isolation, and Realtime Hardening - -This PR delivers a focused, pragmatic refactor to make serving reliable and easy to reason about for a beta release. It simplifies capture configuration to a single typed `Capture`, unifies the execution path, introduces memory-only isolation, and hardens the realtime runtime with bounded resources and better shutdown behavior. - -### Summary - -- Collapse capture to a single typed API: `Capture(memory_only, code, logs, metadata, visualizations, metrics)`. -- Canonical capture fields on deployments; StepLauncher reads only those (no env/dict overrides). -- Serving request parameters are merged safely (allowlist + light validation + size caps); logged. -- Memory-only serving mode: truly no runs/artifacts/log writes; in-process handoff with per-request isolation. -- Realtime runtime: bounded queue, safe cache sweep, circuit-breaker maintained, improved shutdown and metrics. -- Defensive artifact writes: validation and minimal retries/backoff; fail fast on partial responses. -- In-code TODOs added for post-beta roadmap (transactions, multi-worker/async publishing, monitoring). - -### Motivation - -- Eliminate confusing capture modes and env overrides in code paths. -- Ensure serving is fast (async by default) and memory-only mode never touches DB/FS. -- Prevent cross-request contamination in memory-only; bound resource usage under load. -- Provide clear logs and metrics for diagnosis; pave the way for production hardening. - -### Key Behavioral Changes - -- Pipeline code uses a single `Capture` type; dicts/strings disallowed in code paths. -- Serving merges request parameters only from a declared allowlist; oversized/mismatched params are dropped with warnings. -- Memory-only serving executes fully in-process (no runs/artifacts), with explicit logs; step logs disabled to avoid FS writes. -- Realtime runtime backgrounds publishing with a bounded queue; if the queue is full, events are processed inline. - -### File-Level Changes (Selected) - -- Capture & Compiler - - `src/zenml/capture/config.py`: Single `Capture` dataclass; removed BatchCapture/RealtimeCapture/CapturePolicy. - - `src/zenml/config/compiler.py`: Normalizes typed capture into canonical deployment fields. - - `src/zenml/models/v2/core/pipeline_deployment.py`: Adds canonical capture fields to deployment models. - - `src/zenml/zen_stores/schemas/pipeline_deployment_schemas.py`: Adds DB columns for canonical capture fields. - -- Orchestrator - - `src/zenml/orchestrators/step_launcher.py`: - - Uses canonical fields and serving context. - - Adds `_validate_and_merge_request_params` (allowlist + type coercion + size caps). - - Disables logs in memory-only; avoids FS cleanup for `memory://` URIs. - - `src/zenml/orchestrators/run_entity_manager.py`: In-memory step_run stub with minimal config (`enable_*`, `substitutions`). - - `src/zenml/orchestrators/utils.py`: Serving context helpers and docstrings; removed request-level override plumbing. - -- Execution Runtimes - - `src/zenml/execution/step_runtime.py`: - - `MemoryStepRuntime`: instance-scoped store/locks; per-run cleanup; no globals. - - `DefaultStepRuntime.store_output_artifacts`: defensive batch create (retries/backoff), response count validation; TODO for server-side atomicity. - - `src/zenml/execution/realtime_runtime.py`: - - Bounded queue (maxsize=1024), inline fallback on Full. - - Safe cache sweep (snapshot + safe pop, small time budget). - - Shutdown logs final metrics and warns on non-graceful termination; TODOs for thread-pool or async migration and metrics export. - -- Serving Service & Docs - - `src/zenml/deployers/serving/service.py`: Serving context handling; parameter exposure; cleanup. - - `docs/book/serving/*`: Updated to single Capture, serving async default, memory-only warning/behavior. - - `examples/serving/README.md`: Updated to reflect new serving model; memory-only usage. - -### Configuration & Tuning - -- Serving mode is inferred by context (batch vs. serving). No per-request capture overrides. -- Realtime runtime tuning via env: - - `ZENML_RT_CACHE_TTL_SECONDS` (default 60), `ZENML_RT_CACHE_MAX_ENTRIES` (default 256) - - `ZENML_RT_ERR_REPORT_INTERVAL` (default 15), circuit breaker envs unchanged -- Memory-only: ignored outside serving with a warning. - -### Testing & Validation - -- Unit - - Request parameter validation: allowlist, size caps, type coercion. - - Memory runtime isolation: per-instance store; no cross-contamination. - - Realtime runtime: queue Full → inline fallback; race-free cache sweep; shutdown metrics. - - Defensive artifact writes: retries/backoff; mismatch detection. - -- Manual - - Memory-only serving: no `/runs` or `/artifact_versions` calls; explicit log: `[Memory-only] … in-process handoff (no runs/artifacts)`. - - Serving async default: responses return immediately; background updates proceed. - -### Risk & Mitigations - -- Request param merge: now restricted by allowlist/size/type; unknowns dropped with warnings. -- Memory-only: per-request isolation and no FS/DB writes; logs disabled to avoid side effects. -- Realtime: bounded queue with inline fallback; circuit breaker remains in place. -- Artifact writes: fail fast rather than proceed with partial results; TODO for server-side atomicity. - -### In-Code TODOs (Post-Beta Roadmap) - -- Realtime runtime publishing: - - Either thread pool (ThreadPoolExecutor) workers or asyncio runtime once client has async publish calls and we want async serving. - - Preserve bounded backpressure and orderly shutdown. -- Request parameter schema derived from entrypoint annotations; add total payload size caps and strict mode. -- Server-side transactional/compensating semantics for artifact writes; adopt idempotent, category-aware retries. -- Metrics export to Prometheus; per-worker metrics; worker liveness/health signals; process memory watchdog. - -### How to Review - -- Focus on StepLauncher (param merge, memory-only flags) and runtimes (Memory/Realtime). -- Verify serving behavior in logs; check that memory-only path never touches DB/FS. -- Review TODOs in code for future milestones. - -### Rollout - -- Tag as beta and monitor runtime metrics (`queue_depth`, `failed_total`, `cache_hit_rate`, `op_latency_p95_s`). -- Scale by increasing HTTP workers and replicas; memory-only is fastest for prototypes. -- Provide guidance on cache sizing and memory usage in docs. - diff --git a/docs/beta_todos.md b/docs/beta_todos.md new file mode 100644 index 00000000000..a6af45f4831 --- /dev/null +++ b/docs/beta_todos.md @@ -0,0 +1,62 @@ +# Beta Hardening TODOs + +This is a living checklist for post‑beta hardening. All beta blockers are already implemented; the items below are for production readiness and scale. + +## Serving Runtime & Publishing + +- Multi‑worker scaling per process + - Option A (threads): Add a small ThreadPoolExecutor consuming the existing bounded queue; preserve backpressure and flush semantics. + - Option B (async): Introduce an asyncio loop + asyncio.Queue + async workers, once client/publish calls have async variants and we opt into async serving. + - Keep bounded queue, inline fallback on Full, and orderly shutdown (join/cancel with timeout). + +- Backpressure & batching + - Tune queue maxsize defaults; expose env knob `ZENML_RT_QUEUE_MAXSIZE`. + - Optional: micro‑batch compatible events for fewer round‑trips. + +- Circuit breaker refinements + - Distinguish network vs. logical errors for better decisions. + - Add optional cool‑down logs with guidance. + +## Artifact Write Semantics + +- Server‑side atomicity / compensation + - Align with server to provide atomic batch create or server‑side compensation. + - Client: switch from best‑effort retries to idempotent, category‑aware retries once server semantics are defined. + - Document consistency guarantees and failure behavior. + +## Request Parameter Schema & Safety + +- Parameter schema from entrypoint annotations + - Generate/derive expected types from pipeline entrypoint annotations (or compiled schema) rather than inferring from defaults. + - Add total payload size cap; add per‑type caps (e.g., list length, dict depth). + - Optional: strict mode that rejects unknown params rather than dropping. + +## Monitoring, Metrics, Health + +- Metrics enrichment + - Export runtime metrics to Prometheus (queue depth, cache hit rate, error rate, op latency histograms). + - Add per‑worker metrics if multi‑worker is enabled. + +- Health/liveness + - Expose background worker liveness/health via the service. + - Add simple self‑check endpoints and document alerts. + +## Memory & Resource Management + +- Process memory monitoring / limits + - Add process memory watchdog and log warnings; document recommended container limits. + - Add a user‑facing docs note about caching large artifacts and tuning `max_entries` accordingly. + +## Operational Docs & UX + +- Serving docs + - Add a prominent warning about memory usage for large cached artifacts and sizing `ZENML_RT_CACHE_MAX_ENTRIES`. + - Add examples for scaling processes/replicas and interpreting metrics. + +## Notes (Implemented in Beta) + +- Request param allowlist / type coercion / size caps +- Memory‑only isolation (instance‑scoped) and cleanup +- Bounded queue with inline fallback; race‑free cache sweep +- Graceful shutdown with timeout and final metrics +- Defensive artifact write behavior with minimal retries and response validation From b030e4db88c389e7e869320c06aab9e0a151ea1c Mon Sep 17 00:00:00 2001 From: Safoine El khabich Date: Mon, 8 Sep 2025 13:55:12 +0100 Subject: [PATCH 8/8] format --- src/zenml/config/pipeline_run_configuration.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/zenml/config/pipeline_run_configuration.py b/src/zenml/config/pipeline_run_configuration.py index 1e34da4d2d0..b8b3efee959 100644 --- a/src/zenml/config/pipeline_run_configuration.py +++ b/src/zenml/config/pipeline_run_configuration.py @@ -36,6 +36,7 @@ class PipelineRunConfiguration( StrictBaseModel, pydantic_utils.YAMLSerializationMixin ): """Class for pipeline run configurations.""" + run_name: Optional[str] = Field( default=None, description="The name of the pipeline run." ) @@ -71,7 +72,7 @@ class PipelineRunConfiguration( description="The build to use for the pipeline run.", ) # Optional typed capture override per run (no dicts/strings) - capture: Optional[Capture] =Field( + capture: Optional[Capture] = Field( default=None, description="The capture to use for the pipeline run.", )