Skip to content

Commit

Permalink
fix(worker): archive worker must be kept alive for processing (#1567)
Browse files Browse the repository at this point in the history
Signed-off-by: Sylvain Leclerc <[email protected]>
Co-authored-by: Sylvain Leclerc <[email protected]>
  • Loading branch information
laurent-laporte-pro and sylvlecl authored Jun 7, 2023
1 parent 2784828 commit 34e1675
Show file tree
Hide file tree
Showing 10 changed files with 109 additions and 58 deletions.
29 changes: 25 additions & 4 deletions antarest/core/interfaces/eventbus.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,40 @@ class Event(BaseModel):
channel: str = ""


EventListener = Callable[[Event], Awaitable[None]]


class IEventBus(ABC):
"""
Interface for the event bus.
The event bus provides 2 communication mechanisms:
- a broadcasting mechanism, where events are pushed to all
registered listeners
- a message queue mechanism: a message can be pushed to
a specified queue. Only consumers registered for that
queue will be called to handle those messages.
"""

@abstractmethod
def push(self, event: Event) -> None:
"""
Pushes an event to registered listeners.
"""
pass

@abstractmethod
def queue(self, event: Event, queue: str) -> None:
"""
Queues an event at the end of the specified queue.
"""
pass

@abstractmethod
def add_queue_consumer(
self, listener: Callable[[Event], Awaitable[None]], queue: str
) -> str:
def add_queue_consumer(self, listener: EventListener, queue: str) -> str:
"""
Adds a consumer for events on the specified queue.
"""
pass

@abstractmethod
Expand All @@ -74,7 +95,7 @@ def remove_queue_consumer(self, listener_id: str) -> None:
@abstractmethod
def add_listener(
self,
listener: Callable[[Event], Awaitable[None]],
listener: EventListener,
type_filter: Optional[List[EventType]] = None,
) -> str:
"""
Expand Down
26 changes: 19 additions & 7 deletions antarest/core/interfaces/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,28 @@


class IService(ABC):
def __init__(self) -> None:
self.thread = threading.Thread(
target=self._loop,
name=self.__class__.__name__,
daemon=True,
)
"""
A base class for long-running processing services.
Processing may be started either in a background thread or in the current thread.
Subclasses must implement the `_loop` method.
"""

def start(self, threaded: bool = True) -> None:
"""
Starts the processing loop.
Args:
threaded: if True, the loop is started in a daemon thread;
otherwise it runs in the calling thread
"""
if threaded:
self.thread.start()
thread = threading.Thread(
target=self._loop,
name=self.__class__.__name__,
daemon=True,
)
thread.start()
else:
self._loop()

Expand Down
2 changes: 1 addition & 1 deletion antarest/worker/archive_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
[f"{ArchiveWorker.TASK_TYPE}_{workspace}"],
)

def execute_task(self, task_info: WorkerTaskCommand) -> TaskResult:
def _execute_task(self, task_info: WorkerTaskCommand) -> TaskResult:
logger.info(f"Executing task {task_info.json()}")
try:
# sourcery skip: extract-method
Expand Down
2 changes: 1 addition & 1 deletion antarest/worker/simulator_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(
cache=LocalCache(),
)

def execute_task(self, task_info: WorkerTaskCommand) -> TaskResult:
def _execute_task(self, task_info: WorkerTaskCommand) -> TaskResult:
if task_info.task_type == GENERATE_TIMESERIES_TASK_NAME:
return self.execute_timeseries_generation_task(task_info)
elif task_info.task_type == GENERATE_KIRSHOFF_CONSTRAINTS_TASK_NAME:
Expand Down
55 changes: 34 additions & 21 deletions antarest/worker/worker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import threading
import time
from abc import abstractmethod
from concurrent.futures import ThreadPoolExecutor, Future
from typing import Dict, List, Union, Any

Expand Down Expand Up @@ -57,65 +58,77 @@ def __call__(self, future: "Future[Any]") -> None:

# fixme: `AbstractWorker` should not inherit from `IService`
class AbstractWorker(IService):
"""
Base class for workers which listen for and process events.
The worker listens for task command events on the specified queues,
and processes them with the implementation-defined `_execute_task` method.
"""

def __init__(
self,
name: str,
event_bus: IEventBus,
accept: List[str],
) -> None:
"""
Initializes a worker.
Args:
name: Name of this worker
event_bus: Event bus used for receiving commands,
and sending back processing events.
accept: The list of queues from which the worker
should consume task commands.
"""
super().__init__()
# fixme: `AbstractWorker` should not have any `thread` attribute
del self.thread
self.name = name
self.event_bus = event_bus
self.accept = accept
self.threadpool = ThreadPoolExecutor(
max_workers=MAX_WORKERS,
thread_name_prefix="worker_task_",
)
self.lock = threading.Lock()

# fixme: `AbstractWorker.start` should not have any `threaded` parameter
def start(self, threaded: bool = True) -> None:
def _loop(self) -> None:
for task_type in self.accept:
self.event_bus.add_queue_consumer(
self._listen_for_tasks, task_type
)
# Pause briefly so that the event bus gets a chance to start
# processing the tasks as soon as possible
time.sleep(0.01)

# fixme: `AbstractWorker` should not have any `_loop` function
def _loop(self) -> None:
pass
# All the work is actually performed by the event callbacks.
# However, the service must be kept alive while it waits
# for new events, hence the infinite loop below.
while True:
time.sleep(1)

async def _listen_for_tasks(self, event: Event) -> None:
logger.info(f"Accepting new task {event.json()}")
task_info = WorkerTaskCommand.parse_obj(event.payload)
self.event_bus.push(
Event(
type=EventType.WORKER_TASK_STARTED,
payload=task_info,
# Use `NONE` for internal events
payload=task_info, # Use `NONE` for internal events
permissions=PermissionInfo(public_mode=PublicMode.NONE),
)
)
with self.lock:
# fmt: off
future = self.threadpool.submit(self._safe_execute_task, task_info)
callback = _WorkerTaskEndedCallback(self.event_bus, task_info.task_id)
future.add_done_callback(callback)
# fmt: on
# fmt: off
future = self.threadpool.submit(self._safe_execute_task, task_info)
callback = _WorkerTaskEndedCallback(self.event_bus, task_info.task_id)
future.add_done_callback(callback)
# fmt: on

def _safe_execute_task(self, task_info: WorkerTaskCommand) -> TaskResult:
try:
return self.execute_task(task_info)
return self._execute_task(task_info)
except Exception as e:
logger.error(
f"Unexpected error occurred when executing task {task_info.json()}",
exc_info=e,
)
return TaskResult(success=False, message=repr(e))

def execute_task(self, task_info: WorkerTaskCommand) -> TaskResult:
@abstractmethod
def _execute_task(self, task_info: WorkerTaskCommand) -> TaskResult:
raise NotImplementedError()
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,11 @@ def assert_study(a: SUB_JSON, b: SUB_JSON) -> None:


def auto_retry_assert(
predicate: Callable[..., bool], timeout: int = 2
predicate: Callable[..., bool], timeout: int = 2, delay: float = 0.2
) -> None:
threshold = datetime.now(timezone.utc) + timedelta(seconds=timeout)
while datetime.now(timezone.utc) < threshold:
if predicate():
return
time.sleep(0.2)
time.sleep(delay)
raise AssertionError()
2 changes: 1 addition & 1 deletion tests/core/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def __init__(
super().__init__("test", event_bus, accept)
self.tmp_path = tmp_path

def execute_task(self, task_info: WorkerTaskCommand) -> TaskResult:
def _execute_task(self, task_info: WorkerTaskCommand) -> TaskResult:
# simulate a "long" task ;-)
time.sleep(0.01)
relative_path = task_info.task_args["file"]
Expand Down
2 changes: 1 addition & 1 deletion tests/worker/test_archive_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_archive_worker_action(tmp_path: Path):
"remove_src": True,
},
)
archive_worker.execute_task(task_info)
archive_worker._execute_task(task_info)

assert not zip_file.exists()
assert expected_output.exists()
Expand Down
6 changes: 3 additions & 3 deletions tests/worker/test_simulator_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,22 @@ def test_execute_task(logger_mock: Mock, tmp_path: Path):
worker.study_factory = Mock()

with pytest.raises(NotImplementedError):
worker.execute_task(
worker._execute_task(
task_info=WorkerTaskCommand(
task_id="task_id", task_type="unknown", task_args={}
)
)

with pytest.raises(NotImplementedError):
worker.execute_task(
worker._execute_task(
task_info=WorkerTaskCommand(
task_id="task_id",
task_type=GENERATE_KIRSHOFF_CONSTRAINTS_TASK_NAME,
task_args={},
)
)
study_path = tmp_path / "study"
result = worker.execute_task(
result = worker._execute_task(
task_info=WorkerTaskCommand(
task_id="task_id",
task_type=GENERATE_TIMESERIES_TASK_NAME,
Expand Down
39 changes: 22 additions & 17 deletions tests/worker/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,27 @@ def __init__(
super().__init__("test", event_bus, accept)
self.tmp_path = tmp_path

def execute_task(self, task_info: WorkerTaskCommand) -> TaskResult:
def _execute_task(self, task_info: WorkerTaskCommand) -> TaskResult:
# simulate a "long" task ;-)
time.sleep(0.01)
relative_path = task_info.task_args["file"]
(self.tmp_path / relative_path).touch()
return TaskResult(success=True, message="")


@pytest.mark.skip(reason="disabled because it sometimes crashes randomly")
def test_simple_task(tmp_path: Path):
task_queue = "do_stuff"
event_bus = build_eventbus(MagicMock(), Config(), autostart=True)
event_bus.queue(
Event(
type=EventType.WORKER_TASK,
payload=WorkerTaskCommand(
task_type="touch stuff",
task_id="some task",
task_args={"file": "foo"},
),
permissions=PermissionInfo(public_mode=PublicMode.READ),
command_event = Event(
type=EventType.WORKER_TASK,
payload=WorkerTaskCommand(
task_type="touch stuff",
task_id="some task",
task_args={"file": "foo"},
),
task_queue,
permissions=PermissionInfo(public_mode=PublicMode.READ),
)
event_bus.queue(command_event, task_queue)

# Add some listeners to debug the event bus notifications
msg = []
Expand All @@ -61,9 +58,17 @@ async def notify(event: Event):
# Wait for the end of the processing
# Set `timeout` to a large value if you want to debug the worker
auto_retry_assert(lambda: (tmp_path / "foo").exists(), timeout=60)
auto_retry_assert(
lambda: msg == ["WORKER_TASK_STARTED", "WORKER_TASK_ENDED"],
timeout=1,
delay=0.1,
)

# Wait a short time to allow the event bus to have the opportunity
# to process the notification of the end event.
time.sleep(0.1)

assert msg == ["WORKER_TASK_STARTED", "WORKER_TASK_ENDED"]
msg.clear()
# Send a second event to check that the worker is still processing events
event_bus.queue(command_event, task_queue)
auto_retry_assert(
lambda: msg == ["WORKER_TASK_STARTED", "WORKER_TASK_ENDED"],
timeout=1,
delay=0.1,
)

0 comments on commit 34e1675

Please sign in to comment.