diff --git a/antarest/core/interfaces/eventbus.py b/antarest/core/interfaces/eventbus.py index 863b9e9af4..ce2d935b5f 100644 --- a/antarest/core/interfaces/eventbus.py +++ b/antarest/core/interfaces/eventbus.py @@ -52,19 +52,40 @@ class Event(BaseModel): channel: str = "" +EventListener = Callable[[Event], Awaitable[None]] + + class IEventBus(ABC): + """ + Interface for the event bus. + + The event bus provides 2 communication mechanisms: + - a broadcasting mechanism, where events are pushed to all + registered listeners + - a message queue mechanism: a message can be pushed to + a specified queue. Only consumers registered for that + queue will be called to handle those messages. + """ + @abstractmethod def push(self, event: Event) -> None: + """ + Pushes an event to registered listeners. + """ pass @abstractmethod def queue(self, event: Event, queue: str) -> None: + """ + Queues an event at the end of the specified queue. + """ pass @abstractmethod - def add_queue_consumer( - self, listener: Callable[[Event], Awaitable[None]], queue: str - ) -> str: + def add_queue_consumer(self, listener: EventListener, queue: str) -> str: + """ + Adds a consumer for events on the specified queue. + """ pass @abstractmethod @@ -74,7 +95,7 @@ def remove_queue_consumer(self, listener_id: str) -> None: @abstractmethod def add_listener( self, - listener: Callable[[Event], Awaitable[None]], + listener: EventListener, type_filter: Optional[List[EventType]] = None, ) -> str: """ diff --git a/antarest/core/interfaces/service.py b/antarest/core/interfaces/service.py index 7adacc6182..9fbe464ecb 100644 --- a/antarest/core/interfaces/service.py +++ b/antarest/core/interfaces/service.py @@ -3,16 +3,28 @@ class IService(ABC): - def __init__(self) -> None: - self.thread = threading.Thread( - target=self._loop, - name=self.__class__.__name__, - daemon=True, - ) + """ + A base class for long running processing services. + + Processing may be started either in a background thread or in current thread. + Implementations must implement the `_loop` method. + """ def start(self, threaded: bool = True) -> None: + """ + Starts the processing loop. + + Args: + threaded: if True, the loop is started in a daemon thread, + else in this thread + """ if threaded: - self.thread.start() + thread = threading.Thread( + target=self._loop, + name=self.__class__.__name__, + daemon=True, + ) + thread.start() else: self._loop() diff --git a/antarest/worker/archive_worker.py b/antarest/worker/archive_worker.py index 83d56d9c7d..5980da866e 100644 --- a/antarest/worker/archive_worker.py +++ b/antarest/worker/archive_worker.py @@ -38,7 +38,7 @@ def __init__( [f"{ArchiveWorker.TASK_TYPE}_{workspace}"], ) - def execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: + def _execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: logger.info(f"Executing task {task_info.json()}") try: # sourcery skip: extract-method diff --git a/antarest/worker/simulator_worker.py b/antarest/worker/simulator_worker.py index e419a54178..ea8d915af8 100644 --- a/antarest/worker/simulator_worker.py +++ b/antarest/worker/simulator_worker.py @@ -59,7 +59,7 @@ def __init__( cache=LocalCache(), ) - def execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: + def _execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: if task_info.task_type == GENERATE_TIMESERIES_TASK_NAME: return self.execute_timeseries_generation_task(task_info) elif task_info.task_type == GENERATE_KIRSHOFF_CONSTRAINTS_TASK_NAME: diff --git a/antarest/worker/worker.py b/antarest/worker/worker.py index ced8b7d834..13f56568ca 100644 --- a/antarest/worker/worker.py +++ b/antarest/worker/worker.py @@ -1,6 +1,7 @@ import logging import threading import time +from abc import abstractmethod from concurrent.futures import ThreadPoolExecutor, Future from typing import Dict, List, Union, Any @@ -57,15 +58,30 @@ def __call__(self, future: "Future[Any]") -> None: # fixme: `AbstractWorker` should not inherit from `IService` class AbstractWorker(IService): + """ + Base class for workers which listens and process events. + + The worker listens for task command events on specified queues, + and processes them with the implementation defined `_execute_task`. + """ + def __init__( self, name: str, event_bus: IEventBus, accept: List[str], ) -> None: + """ + Initializes a worker. + + Args: + name: Name of this worker + event_bus: Event bus used for receiving commands, + and sending back processing events. + accept: The list of queues from which the worker + should consume task commands. + """ super().__init__() - # fixme: `AbstractWorker` should not have any `thread` attribute - del self.thread self.name = name self.event_bus = event_bus self.accept = accept @@ -73,21 +89,19 @@ def __init__( max_workers=MAX_WORKERS, thread_name_prefix="worker_task_", ) - self.lock = threading.Lock() - # fixme: `AbstractWorker.start` should not have any `threaded` parameter - def start(self, threaded: bool = True) -> None: + def _loop(self) -> None: for task_type in self.accept: self.event_bus.add_queue_consumer( self._listen_for_tasks, task_type ) - # Wait a short time to allow the event bus to have the opportunity - # to process the tasks as soon as possible - time.sleep(0.01) - # fixme: `AbstractWorker` should not have any `_loop` function - def _loop(self) -> None: - pass + # All the work is actually performed by callbacks + # on events. + # However, we want to keep the service alive while + # it waits for new events, so infinite loop ... + while True: + time.sleep(1) async def _listen_for_tasks(self, event: Event) -> None: logger.info(f"Accepting new task {event.json()}") @@ -95,21 +109,19 @@ async def _listen_for_tasks(self, event: Event) -> None: self.event_bus.push( Event( type=EventType.WORKER_TASK_STARTED, - payload=task_info, - # Use `NONE` for internal events + payload=task_info, # Use `NONE` for internal events permissions=PermissionInfo(public_mode=PublicMode.NONE), ) ) - with self.lock: - # fmt: off - future = self.threadpool.submit(self._safe_execute_task, task_info) - callback = _WorkerTaskEndedCallback(self.event_bus, task_info.task_id) - future.add_done_callback(callback) - # fmt: on + # fmt: off + future = self.threadpool.submit(self._safe_execute_task, task_info) + callback = _WorkerTaskEndedCallback(self.event_bus, task_info.task_id) + future.add_done_callback(callback) + # fmt: on def _safe_execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: try: - return self.execute_task(task_info) + return self._execute_task(task_info) except Exception as e: logger.error( f"Unexpected error occurred when executing task {task_info.json()}", @@ -117,5 +129,6 @@ def _safe_execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: ) return TaskResult(success=False, message=repr(e)) - def execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: + @abstractmethod + def _execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: raise NotImplementedError() diff --git a/tests/conftest.py b/tests/conftest.py index 7eff60d0a1..755053d959 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -82,11 +82,11 @@ def assert_study(a: SUB_JSON, b: SUB_JSON) -> None: def auto_retry_assert( - predicate: Callable[..., bool], timeout: int = 2 + predicate: Callable[..., bool], timeout: int = 2, delay: float = 0.2 ) -> None: threshold = datetime.now(timezone.utc) + timedelta(seconds=timeout) while datetime.now(timezone.utc) < threshold: if predicate(): return - time.sleep(0.2) + time.sleep(delay) raise AssertionError() diff --git a/tests/core/test_tasks.py b/tests/core/test_tasks.py index 5632f8c75f..de6d71db36 100644 --- a/tests/core/test_tasks.py +++ b/tests/core/test_tasks.py @@ -275,7 +275,7 @@ def __init__( super().__init__("test", event_bus, accept) self.tmp_path = tmp_path - def execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: + def _execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: # simulate a "long" task ;-) time.sleep(0.01) relative_path = task_info.task_args["file"] diff --git a/tests/worker/test_archive_worker.py b/tests/worker/test_archive_worker.py index 707df64185..8630e49a38 100644 --- a/tests/worker/test_archive_worker.py +++ b/tests/worker/test_archive_worker.py @@ -44,7 +44,7 @@ def test_archive_worker_action(tmp_path: Path): "remove_src": True, }, ) - archive_worker.execute_task(task_info) + archive_worker._execute_task(task_info) assert not zip_file.exists() assert expected_output.exists() diff --git a/tests/worker/test_simulator_worker.py b/tests/worker/test_simulator_worker.py index aa602115c3..8ebcd8b80a 100644 --- a/tests/worker/test_simulator_worker.py +++ b/tests/worker/test_simulator_worker.py @@ -40,14 +40,14 @@ def test_execute_task(logger_mock: Mock, tmp_path: Path): worker.study_factory = Mock() with pytest.raises(NotImplementedError): - worker.execute_task( + worker._execute_task( task_info=WorkerTaskCommand( task_id="task_id", task_type="unknown", task_args={} ) ) with pytest.raises(NotImplementedError): - worker.execute_task( + worker._execute_task( task_info=WorkerTaskCommand( task_id="task_id", task_type=GENERATE_KIRSHOFF_CONSTRAINTS_TASK_NAME, @@ -55,7 +55,7 @@ def test_execute_task(logger_mock: Mock, tmp_path: Path): ) ) study_path = tmp_path / "study" - result = worker.execute_task( + result = worker._execute_task( task_info=WorkerTaskCommand( task_id="task_id", task_type=GENERATE_TIMESERIES_TASK_NAME, diff --git a/tests/worker/test_worker.py b/tests/worker/test_worker.py index 9bf7f64a6b..230396256c 100644 --- a/tests/worker/test_worker.py +++ b/tests/worker/test_worker.py @@ -20,7 +20,7 @@ def __init__( super().__init__("test", event_bus, accept) self.tmp_path = tmp_path - def execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: + def _execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: # simulate a "long" task ;-) time.sleep(0.01) relative_path = task_info.task_args["file"] @@ -28,22 +28,19 @@ def execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: return TaskResult(success=True, message="") -@pytest.mark.skip(reason="disabled because it sometimes crashes randomly") def test_simple_task(tmp_path: Path): task_queue = "do_stuff" event_bus = build_eventbus(MagicMock(), Config(), autostart=True) - event_bus.queue( - Event( - type=EventType.WORKER_TASK, - payload=WorkerTaskCommand( - task_type="touch stuff", - task_id="some task", - task_args={"file": "foo"}, - ), - permissions=PermissionInfo(public_mode=PublicMode.READ), + command_event = Event( + type=EventType.WORKER_TASK, + payload=WorkerTaskCommand( + task_type="touch stuff", + task_id="some task", + task_args={"file": "foo"}, ), - task_queue, + permissions=PermissionInfo(public_mode=PublicMode.READ), ) + event_bus.queue(command_event, task_queue) # Add some listeners to debug the event bus notifications msg = [] @@ -61,9 +58,17 @@ async def notify(event: Event): # Wait for the end of the processing # Set a big value to `timeout` if you want to debug the worker auto_retry_assert(lambda: (tmp_path / "foo").exists(), timeout=60) + auto_retry_assert( + lambda: msg == ["WORKER_TASK_STARTED", "WORKER_TASK_ENDED"], + timeout=1, + delay=0.1, + ) - # Wait a short time to allow the event bus to have the opportunity - # to process the notification of the end event. - time.sleep(0.1) - - assert msg == ["WORKER_TASK_STARTED", "WORKER_TASK_ENDED"] + msg.clear() + # Send a second event to check worker is still processing events + event_bus.queue(command_event, task_queue) + auto_retry_assert( + lambda: msg == ["WORKER_TASK_STARTED", "WORKER_TASK_ENDED"], + timeout=1, + delay=0.1, + )