Skip to content

Commit

Permalink
feat(ts-gen): display progress bar via websockets (#2194)
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBelthle authored Oct 29, 2024
1 parent eb10ca1 commit d845e5f
Show file tree
Hide file tree
Showing 43 changed files with 444 additions and 129 deletions.
9 changes: 5 additions & 4 deletions antarest/core/filesystem_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@

import typing_extensions as te
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from pydantic import Field
from starlette.responses import PlainTextResponse, StreamingResponse

from antarest.core.config import Config
from antarest.core.serialization import AntaresBaseModel
from antarest.core.utils.web import APITag
from antarest.login.auth import Auth

Expand All @@ -35,7 +36,7 @@


class FilesystemDTO(
BaseModel,
AntaresBaseModel,
extra="forbid",
json_schema_extra={
"example": {
Expand All @@ -61,7 +62,7 @@ class FilesystemDTO(


class MountPointDTO(
BaseModel,
AntaresBaseModel,
extra="forbid",
json_schema_extra={
"example": {
Expand Down Expand Up @@ -109,7 +110,7 @@ async def from_path(cls, name: str, path: Path) -> "MountPointDTO":


class FileInfoDTO(
BaseModel,
AntaresBaseModel,
extra="forbid",
json_schema_extra={
"example": {
Expand Down
7 changes: 5 additions & 2 deletions antarest/core/interfaces/eventbus.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class EventType(StrEnum):
WORKER_TASK_STARTED = "WORKER_TASK_STARTED"
WORKER_TASK_ENDED = "WORKER_TASK_ENDED"
LAUNCH_PROGRESS = "LAUNCH_PROGRESS"
TS_GENERATION_PROGRESS = "TS_GENERATION_PROGRESS"
TASK_PROGRESS = "TASK_PROGRESS"


class EventChannelDirectory:
Expand Down Expand Up @@ -137,6 +137,9 @@ def start(self, threaded: bool = True) -> None:


class DummyEventBusService(IEventBus):
def __init__(self) -> None:
self.events: List[Event] = []

def queue(self, event: Event, queue: str) -> None:
# Noop
pass
Expand All @@ -150,7 +153,7 @@ def remove_queue_consumer(self, listener_id: str) -> None:

def push(self, event: Event) -> None:
# Noop
pass
self.events.append(event)

def add_listener(
self,
Expand Down
57 changes: 45 additions & 12 deletions antarest/core/tasks/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,23 @@

logger = logging.getLogger(__name__)

TaskUpdateNotifier = t.Callable[[str], None]
Task = t.Callable[[TaskUpdateNotifier], TaskResult]

DEFAULT_AWAIT_MAX_TIMEOUT = 172800 # 48 hours
"""Default timeout for `await_task` in seconds."""


class ITaskNotifier(ABC):
@abstractmethod
def notify_message(self, message: str) -> None:
raise NotImplementedError()

@abstractmethod
def notify_progress(self, progress: int) -> None:
raise NotImplementedError()


Task = t.Callable[[ITaskNotifier], TaskResult]


class ITaskService(ABC):
@abstractmethod
def add_worker_task(
Expand Down Expand Up @@ -94,11 +104,17 @@ def await_task(self, task_id: str, timeout_sec: int = DEFAULT_AWAIT_MAX_TIMEOUT)


# noinspection PyUnusedLocal
def noop_notifier(message: str) -> None:
"""This function is used in tasks when no notification is required."""
class NoopNotifier(ITaskNotifier):
"""This class is used in tasks when no notification is required."""

def notify_message(self, message: str) -> None:
return

def notify_progress(self, progress: int) -> None:
return

class TaskJobLogRecorder:

class TaskLogAndProgressRecorder(ITaskNotifier):
"""
Callback used to register log messages in the TaskJob table.
Expand All @@ -107,15 +123,32 @@ class TaskJobLogRecorder:
session: The database session created in the same thread as the task thread.
"""

def __init__(self, task_id: str, session: Session):
def __init__(self, task_id: str, session: Session, event_bus: IEventBus) -> None:
self.session = session
self.task_id = task_id
self.event_bus = event_bus

def __call__(self, message: str) -> None:
def notify_message(self, message: str) -> None:
task = self.session.query(TaskJob).get(self.task_id)
if task:
task.logs.append(TaskJobLog(message=message, task_id=self.task_id))
db.session.commit()
self.session.commit()

def notify_progress(self, progress: int) -> None:
self.session.query(TaskJob).filter(TaskJob.id == self.task_id).update({TaskJob.progress: progress})
self.session.commit()

self.event_bus.push(
Event(
type=EventType.TASK_PROGRESS,
payload={
"task_id": self.task_id,
"progress": progress,
},
permissions=PermissionInfo(public_mode=PublicMode.READ),
channel=EventChannelDirectory.TASK + self.task_id,
)
)


class TaskJobService(ITaskService):
Expand All @@ -138,7 +171,7 @@ def _create_worker_task(
task_id: str,
task_type: str,
task_args: t.Dict[str, t.Union[int, float, bool, str]],
) -> t.Callable[[TaskUpdateNotifier], TaskResult]:
) -> Task:
task_result_wrapper: t.List[TaskResult] = []

def _create_awaiter(
Expand All @@ -152,7 +185,7 @@ async def _await_task_end(event: Event) -> None:
return _await_task_end

# noinspection PyUnusedLocal
def _send_worker_task(logger_: TaskUpdateNotifier) -> TaskResult:
def _send_worker_task(logger_: ITaskNotifier) -> TaskResult:
listener_id = self.event_bus.add_listener(
_create_awaiter(task_result_wrapper),
[EventType.WORKER_TASK_ENDED],
Expand Down Expand Up @@ -380,7 +413,7 @@ def _run_task(
try:
with db():
# We must use the DB session attached to the current thread
result = callback(TaskJobLogRecorder(task_id, session=db.session))
result = callback(TaskLogAndProgressRecorder(task_id, db.session, self.event_bus))

status = TaskStatus.COMPLETED if result.success else TaskStatus.FAILED
logger.info(f"Task {task_id} ended with status {status}")
Expand Down
4 changes: 2 additions & 2 deletions antarest/launcher/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from antarest.core.model import PermissionInfo, PublicMode, StudyPermissionType
from antarest.core.requests import RequestParameters, UserHasNotPermissionError
from antarest.core.tasks.model import TaskResult, TaskType
from antarest.core.tasks.service import ITaskService, TaskUpdateNotifier
from antarest.core.tasks.service import ITaskNotifier, ITaskService
from antarest.core.utils.archives import ArchiveFormat, archive_dir, is_zip, read_in_zip
from antarest.core.utils.fastapi_sqlalchemy import db
from antarest.core.utils.utils import StopWatch, concat_files, concat_files_to_str
Expand Down Expand Up @@ -598,7 +598,7 @@ def _download_fallback_output(self, job_id: str, params: RequestParameters) -> F
export_path = Path(export_file_download.path)
export_id = export_file_download.id

def export_task(_: TaskUpdateNotifier) -> TaskResult:
def export_task(_: ITaskNotifier) -> TaskResult:
try:
#
archive_dir(output_path, export_path, archive_format=ArchiveFormat.ZIP)
Expand Down
4 changes: 2 additions & 2 deletions antarest/matrixstore/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from antarest.core.requests import RequestParameters, UserHasNotPermissionError
from antarest.core.serialization import from_json
from antarest.core.tasks.model import TaskResult, TaskType
from antarest.core.tasks.service import ITaskService, TaskUpdateNotifier
from antarest.core.tasks.service import ITaskNotifier, ITaskService
from antarest.core.utils.archives import ArchiveFormat, archive_dir
from antarest.core.utils.fastapi_sqlalchemy import db
from antarest.core.utils.utils import StopWatch
Expand Down Expand Up @@ -510,7 +510,7 @@ def download_matrix_list(
export_path = Path(export_file_download.path)
export_id = export_file_download.id

def export_task(notifier: TaskUpdateNotifier) -> TaskResult:
def export_task(notifier: ITaskNotifier) -> TaskResult:
try:
self.create_matrix_files(matrix_ids=matrix_list, export_path=export_path)
self.file_transfer_manager.set_ready(export_id)
Expand Down
8 changes: 5 additions & 3 deletions antarest/study/business/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,19 @@
import typing as t

from antares.study.version import StudyVersion
from pydantic import BaseModel

from antarest.core.exceptions import CommandApplicationError
from antarest.core.jwt import DEFAULT_ADMIN_USER
from antarest.core.requests import RequestParameters
from antarest.core.serialization import AntaresBaseModel
from antarest.study.business.all_optional_meta import camel_case_model
from antarest.study.model import RawStudy, Study
from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy
from antarest.study.storage.storage_service import StudyStorageService
from antarest.study.storage.utils import is_managed
from antarest.study.storage.variantstudy.business.utils import transform_command_to_dto
from antarest.study.storage.variantstudy.model.command.icommand import ICommand
from antarest.study.storage.variantstudy.model.command_listener.command_listener import ICommandListener

# noinspection SpellCheckingInspection
GENERAL_DATA_PATH = "settings/generaldata"
Expand All @@ -35,11 +36,12 @@ def execute_or_add_commands(
file_study: FileStudy,
commands: t.Sequence[ICommand],
storage_service: StudyStorageService,
listener: t.Optional[ICommandListener] = None,
) -> None:
if isinstance(study, RawStudy):
executed_commands: t.MutableSequence[ICommand] = []
for command in commands:
result = command.apply(file_study)
result = command.apply(file_study, listener)
if not result.status:
raise CommandApplicationError(result.message)
executed_commands.append(command)
Expand Down Expand Up @@ -72,7 +74,7 @@ def execute_or_add_commands(

@camel_case_model
class FormFieldsBaseModel(
BaseModel,
AntaresBaseModel,
extra="forbid",
validate_assignment=True,
populate_by_name=True,
Expand Down
Loading

0 comments on commit d845e5f

Please sign in to comment.