diff --git a/src/prefect/client/orchestration.py b/src/prefect/client/orchestration.py index 1668bd02df14..cd0ccec4c948 100644 --- a/src/prefect/client/orchestration.py +++ b/src/prefect/client/orchestration.py @@ -99,6 +99,7 @@ TaskRunResult, Variable, Worker, + WorkerMetadata, WorkPool, WorkQueue, WorkQueueStatusDetail, @@ -2596,6 +2597,7 @@ async def send_worker_heartbeat( worker_name: str, heartbeat_interval_seconds: Optional[float] = None, get_worker_id: bool = False, + worker_metadata: Optional[WorkerMetadata] = None, ) -> Optional[UUID]: """ Sends a worker heartbeat for a given work pool. @@ -2604,20 +2606,20 @@ async def send_worker_heartbeat( work_pool_name: The name of the work pool to heartbeat against. worker_name: The name of the worker sending the heartbeat. return_id: Whether to return the worker ID. Note: will return `None` if the connected server does not support returning worker IDs, even if `return_id` is `True`. + worker_metadata: Metadata about the worker to send to the server. """ - + params = { + "name": worker_name, + "heartbeat_interval_seconds": heartbeat_interval_seconds, + } + if worker_metadata: + params["worker_metadata"] = worker_metadata.model_dump(mode="json") if get_worker_id: - return_dict = {"return_id": get_worker_id} - else: - return_dict = {} + params["return_id"] = get_worker_id resp = await self._client.post( f"/work_pools/{work_pool_name}/workers/heartbeat", - json={ - "name": worker_name, - "heartbeat_interval_seconds": heartbeat_interval_seconds, - } - | return_dict, + json=params, ) if ( diff --git a/src/prefect/client/schemas/objects.py b/src/prefect/client/schemas/objects.py index 06f5c6149588..b4f2fc31cbee 100644 --- a/src/prefect/client/schemas/objects.py +++ b/src/prefect/client/schemas/objects.py @@ -1689,3 +1689,24 @@ class CsrfToken(ObjectBaseModel): __getattr__ = getattr_migration(__name__) + + +class Integration(PrefectBaseModel): + """A representation of an installed Prefect integration.""" + + name: str = Field(description="The name of the Prefect integration.") + version: str = Field(description="The version of the Prefect integration.") + + +class WorkerMetadata(PrefectBaseModel): + """ + Worker metadata. + + We depend on the structure of `integrations`, but otherwise, worker classes + should support flexible metadata. + """ + + integrations: List[Integration] = Field( + default=..., description="Prefect integrations installed in the worker." + ) + model_config = ConfigDict(extra="allow") diff --git a/src/prefect/workers/base.py b/src/prefect/workers/base.py index dba7ab4b8ee6..6e35918f3294 100644 --- a/src/prefect/workers/base.py +++ b/src/prefect/workers/base.py @@ -10,6 +10,7 @@ import anyio.abc import httpx import pendulum +from importlib_metadata import distributions from pydantic import BaseModel, Field, PrivateAttr, field_validator from pydantic.json_schema import GenerateJsonSchema from typing_extensions import Literal @@ -19,7 +20,12 @@ from prefect.client.base import ServerType from prefect.client.orchestration import PrefectClient, get_client from prefect.client.schemas.actions import WorkPoolCreate, WorkPoolUpdate -from prefect.client.schemas.objects import StateType, WorkPool +from prefect.client.schemas.objects import ( + Integration, + StateType, + WorkerMetadata, + WorkPool, +) from prefect.client.utilities import inject_client from prefect.events import Event, RelatedResource, emit_event from prefect.events.related import object_as_related_resource, tags_as_related_resources @@ -438,6 +444,7 @@ def __init__( self._submitting_flow_run_ids = set() self._cancelling_flow_run_ids = set() self._scheduled_task_scopes = set() + self._worker_metadata_sent = False @classmethod def get_documentation_url(cls) -> str: @@ -717,47 +724,81 @@ async def _update_local_work_pool_info(self): self._work_pool = work_pool - async def _send_worker_heartbeat( - self, get_worker_id: bool = False - ) -> Optional[UUID]: + async def _worker_metadata(self) -> Optional[WorkerMetadata]: """ - Sends a heartbeat to the API. - - If `get_worker_id` is True, the worker ID will be retrieved from the API. + Returns metadata about installed Prefect collections for the worker. """ - if self._work_pool: - return await self._client.send_worker_heartbeat( - work_pool_name=self._work_pool_name, - worker_name=self.name, - heartbeat_interval_seconds=self.heartbeat_interval_seconds, - get_worker_id=get_worker_id, - ) + installed_integrations = load_prefect_collections().keys() - async def sync_with_backend(self): + integration_versions = [ + Integration(name=dist.metadata["Name"], version=dist.version) + for dist in distributions() + # PyPI packages often use dashes, but Python package names use underscores + # because they must be valid identifiers. + if dist.metadata.get("Name").replace("_", "-") in installed_integrations + ] + + if integration_versions: + return WorkerMetadata(integrations=integration_versions) + return None + + async def _send_worker_heartbeat(self) -> Optional[UUID]: """ - Updates the worker's local information about it's current work pool and - queues. Sends a worker heartbeat to the API. + Sends a heartbeat to the API. """ - await self._update_local_work_pool_info() + if not self._client: + self._logger.warning("Client has not been initialized; skipping heartbeat.") + return None + if not self._work_pool: + self._logger.debug("Worker has no work pool; skipping heartbeat.") + return None + + should_get_worker_id = self._should_get_worker_id() + + params = { + "work_pool_name": self._work_pool_name, + "worker_name": self.name, + "heartbeat_interval_seconds": self.heartbeat_interval_seconds, + "get_worker_id": should_get_worker_id, + } + if ( + self._client.server_type == ServerType.CLOUD + and not self._worker_metadata_sent + ): + worker_metadata = await self._worker_metadata() + if worker_metadata: + params["worker_metadata"] = worker_metadata + self._worker_metadata_sent = True + + worker_id = None try: - remote_id = await self._send_worker_heartbeat( - get_worker_id=(self._should_get_worker_id()) - ) + worker_id = await self._client.send_worker_heartbeat(**params) except httpx.HTTPStatusError as e: - if e.response.status_code == 422 and self._should_get_worker_id(): + if e.response.status_code == 422 and should_get_worker_id: self._logger.warning( "Failed to retrieve worker ID from the Prefect API server." ) - await self._send_worker_heartbeat(get_worker_id=False) - remote_id = None + params["get_worker_id"] = False + worker_id = await self._client.send_worker_heartbeat(**params) else: raise e - if self._should_get_worker_id() and remote_id is None: + if should_get_worker_id and worker_id is None: self._logger.warning( "Failed to retrieve worker ID from the Prefect API server." ) - elif self.backend_id is None and remote_id is not None: + + return worker_id + + async def sync_with_backend(self): + """ + Updates the worker's local information about it's current work pool and + queues. Sends a worker heartbeat to the API. + """ + await self._update_local_work_pool_info() + + remote_id = await self._send_worker_heartbeat() + if remote_id: self.backend_id = remote_id self._logger = get_worker_logger(self) @@ -769,6 +810,7 @@ def _should_get_worker_id(self): """Determines if the worker should request an ID from the API server.""" return ( get_current_settings().experiments.worker_logging_to_api_enabled + and self._client and self._client.server_type == ServerType.CLOUD and self.backend_id is None ) diff --git a/tests/client/test_prefect_client.py b/tests/client/test_prefect_client.py index dad4fadd9b8a..daf3fae6da21 100644 --- a/tests/client/test_prefect_client.py +++ b/tests/client/test_prefect_client.py @@ -3,6 +3,7 @@ from contextlib import asynccontextmanager from datetime import timedelta from typing import Generator, List +from unittest import mock from unittest.mock import ANY, MagicMock, Mock from uuid import UUID, uuid4 @@ -55,9 +56,11 @@ Flow, FlowRunNotificationPolicy, FlowRunPolicy, + Integration, StateType, TaskRun, Variable, + WorkerMetadata, WorkQueue, ) from prefect.client.schemas.responses import ( @@ -69,6 +72,7 @@ from prefect.client.utilities import inject_client from prefect.events import AutomationCore, EventTrigger, Posture from prefect.server.api.server import create_app +from prefect.server.database.orm_models import WorkPool from prefect.settings import ( PREFECT_API_DATABASE_MIGRATE_ON_START, PREFECT_API_KEY, @@ -2698,3 +2702,59 @@ def test_raise_for_api_version_mismatch_with_incompatible_versions( f"Found incompatible versions: client: {client_version}, server: {api_version}. " in str(e.value) ) + + +class TestPrefectClientWorkerHeartbeat: + async def test_worker_heartbeat( + self, prefect_client: PrefectClient, work_pool: WorkPool + ): + work_pool_name = str(work_pool.name) + await prefect_client.send_worker_heartbeat( + work_pool_name=work_pool_name, + worker_name="test-worker", + heartbeat_interval_seconds=10, + ) + workers = await prefect_client.read_workers_for_work_pool(work_pool_name) + assert len(workers) == 1 + assert workers[0].name == "test-worker" + assert workers[0].heartbeat_interval_seconds == 10 + + async def test_worker_heartbeat_sends_metadata_if_passed( + self, prefect_client: PrefectClient + ): + with mock.patch( + "prefect.client.orchestration.PrefectHttpxAsyncClient.post", + return_value=httpx.Response(status_code=204), + ) as mock_post: + await prefect_client.send_worker_heartbeat( + work_pool_name="work-pool", + worker_name="test-worker", + heartbeat_interval_seconds=10, + worker_metadata=WorkerMetadata( + integrations=[Integration(name="prefect-aws", version="1.0.0")] + ), + ) + assert mock_post.call_args[1]["json"] == { + "name": "test-worker", + "heartbeat_interval_seconds": 10, + "worker_metadata": { + "integrations": [{"name": "prefect-aws", "version": "1.0.0"}] + }, + } + + async def test_worker_heartbeat_does_not_send_metadata_if_not_passed( + self, prefect_client: PrefectClient + ): + with mock.patch( + "prefect.client.orchestration.PrefectHttpxAsyncClient.post", + return_value=httpx.Response(status_code=204), + ) as mock_post: + await prefect_client.send_worker_heartbeat( + work_pool_name="work-pool", + worker_name="test-worker", + heartbeat_interval_seconds=10, + ) + assert mock_post.call_args[1]["json"] == { + "name": "test-worker", + "heartbeat_interval_seconds": 10, + } diff --git a/tests/workers/test_base_worker.py b/tests/workers/test_base_worker.py index 02210b899b99..f7eb56577489 100644 --- a/tests/workers/test_base_worker.py +++ b/tests/workers/test_base_worker.py @@ -1,5 +1,6 @@ import uuid from typing import Any, Dict, Optional, Type +from unittest import mock from unittest.mock import MagicMock import httpx @@ -15,6 +16,7 @@ from prefect.client.base import ServerType from prefect.client.orchestration import PrefectClient, get_client from prefect.client.schemas import FlowRun +from prefect.client.schemas.objects import WorkerMetadata from prefect.exceptions import ( CrashedRun, ObjectNotFound, @@ -1846,3 +1848,107 @@ async def test_env_merge_logic_is_deep( for key, value in expected_env.items(): assert config.env[key] == value + + +class TestBaseWorkerHeartbeat: + async def test_worker_heartbeat_sends_integrations( + self, work_pool, hosted_api_server, experimental_logging_enabled + ): + async with WorkerTestImpl(work_pool_name=work_pool.name) as worker: + await worker.start(run_once=True) + with mock.patch( + "prefect.workers.base.load_prefect_collections" + ) as mock_load_prefect_collections, mock.patch( + "prefect.client.orchestration.PrefectHttpxAsyncClient.post" + ) as mock_send_worker_heartbeat_post, mock.patch( + "prefect.workers.base.distributions" + ) as mock_distributions: + mock_load_prefect_collections.return_value = { + "prefect-aws": "1.0.0", + } + mock_distributions.return_value = [ + mock.MagicMock( + metadata={"Name": "prefect_aws"}, + version="1.0.0", + ) + ] + + async with get_client() as client: + worker._client = client + worker._client.server_type = ServerType.CLOUD + await worker.sync_with_backend() + + mock_send_worker_heartbeat_post.assert_called_once_with( + f"/work_pools/{work_pool.name}/workers/heartbeat", + json={ + "name": worker.name, + "heartbeat_interval_seconds": worker.heartbeat_interval_seconds, + "worker_metadata": { + "integrations": [ + {"name": "prefect_aws", "version": "1.0.0"} + ] + }, + "return_id": True, + }, + ) + + assert worker._worker_metadata_sent + + async def test_custom_worker_can_send_arbitrary_metadata( + self, work_pool, hosted_api_server, experimental_logging_enabled + ): + class CustomWorker(BaseWorker): + type: str = "test-custom-metadata" + job_configuration: Type[BaseJobConfiguration] = BaseJobConfiguration + + async def run(self): + pass + + async def _worker_metadata(self) -> WorkerMetadata: + return WorkerMetadata( + **{ + "integrations": [{"name": "prefect-aws", "version": "1.0.0"}], + "custom_field": "heya", + } + ) + + async with CustomWorker(work_pool_name=work_pool.name) as worker: + await worker.start(run_once=True) + with mock.patch( + "prefect.workers.base.load_prefect_collections" + ) as mock_load_prefect_collections, mock.patch( + "prefect.client.orchestration.PrefectHttpxAsyncClient.post" + ) as mock_send_worker_heartbeat_post, mock.patch( + "prefect.workers.base.distributions" + ) as mock_distributions: + mock_load_prefect_collections.return_value = { + "prefect-aws": "1.0.0", + } + mock_distributions.return_value = [ + mock.MagicMock( + metadata={"Name": "prefect-aws"}, + version="1.0.0", + ) + ] + + async with get_client() as client: + worker._client = client + worker._client.server_type = ServerType.CLOUD + await worker.sync_with_backend() + + mock_send_worker_heartbeat_post.assert_called_once_with( + f"/work_pools/{work_pool.name}/workers/heartbeat", + json={ + "name": worker.name, + "heartbeat_interval_seconds": worker.heartbeat_interval_seconds, + "worker_metadata": { + "integrations": [ + {"name": "prefect-aws", "version": "1.0.0"} + ], + "custom_field": "heya", + }, + "return_id": True, + }, + ) + + assert worker._worker_metadata_sent