Skip to content

Commit

Permalink
Merge branch 'dev' into fix/lower-case-issues
Browse files Browse the repository at this point in the history
  • Loading branch information
sylvlecl authored Jan 16, 2025
2 parents 16db150 + 36ebe7b commit 9a90453
Show file tree
Hide file tree
Showing 39 changed files with 629 additions and 375 deletions.
6 changes: 6 additions & 0 deletions antarest/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@ class LocalConfig:
enable_nb_cores_detection: bool = True
nb_cores: NbCoresConfig = NbCoresConfig()
time_limit: TimeLimitConfig = TimeLimitConfig()
xpress_dir: Optional[str] = None
local_workspace: Path = Path("./local_workspace")

@classmethod
def from_dict(cls, data: JSON) -> "LocalConfig":
Expand All @@ -278,10 +280,14 @@ def from_dict(cls, data: JSON) -> "LocalConfig":
nb_cores = data.get("nb_cores", asdict(defaults.nb_cores))
if enable_nb_cores_detection:
nb_cores.update(cls._autodetect_nb_cores())
xpress_dir = data.get("xpress_dir", defaults.xpress_dir)
local_workspace = Path(data["local_workspace"]) if "local_workspace" in data else defaults.local_workspace
return cls(
binaries={str(v): Path(p) for v, p in binaries.items()},
enable_nb_cores_detection=enable_nb_cores_detection,
nb_cores=NbCoresConfig(**nb_cores),
xpress_dir=xpress_dir,
local_workspace=local_workspace,
)

@classmethod
Expand Down
149 changes: 77 additions & 72 deletions antarest/core/tasks/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from antarest.core.config import Config
from antarest.core.interfaces.eventbus import Event, EventChannelDirectory, EventType, IEventBus
from antarest.core.jwt import JWTUser
from antarest.core.model import PermissionInfo, PublicMode
from antarest.core.requests import MustBeAuthenticatedError, RequestParameters, UserHasNotPermissionError
from antarest.core.tasks.model import (
Expand All @@ -40,6 +41,7 @@
from antarest.core.tasks.repository import TaskJobRepository
from antarest.core.utils.fastapi_sqlalchemy import db
from antarest.core.utils.utils import retry
from antarest.login.utils import current_user_context
from antarest.worker.worker import WorkerTaskCommand, WorkerTaskResult

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -294,16 +296,16 @@ def _launch_task(
type=EventType.TASK_ADDED,
payload=TaskEventPayload(
id=task.id,
message=custom_event_messages.start
if custom_event_messages is not None
else f"Task {task.id} added",
message=(
custom_event_messages.start if custom_event_messages is not None else f"Task {task.id} added"
),
type=task.type,
study_id=task.ref_id,
).model_dump(),
permissions=PermissionInfo(owner=request_params.user.impersonator),
)
)
future = self.threadpool.submit(self._run_task, action, task.id, custom_event_messages)
future = self.threadpool.submit(self._run_task, action, task.id, request_params.user, custom_event_messages)
self.tasks[task.id] = future

def create_task_event_callback(self) -> t.Callable[[Event], t.Awaitable[None]]:
Expand Down Expand Up @@ -399,94 +401,97 @@ def _run_task(
self,
callback: Task,
task_id: str,
jwt_user: JWTUser,
custom_event_messages: t.Optional[CustomTaskEventMessages] = None,
) -> None:
# We need to catch all exceptions so that the calling thread is guaranteed
# to not die
try:
# attention: this function is executed in a thread, not in the main process
with db():
# Important to keep this retry for now,
# in case commit is not visible (read from replica ...)
task = retry(lambda: self.repo.get_or_raise(task_id))
task_type = task.type
study_id = task.ref_id
with current_user_context(token=jwt_user):
with db():
# Important to keep this retry for now,
# in case commit is not visible (read from replica ...)
task = retry(lambda: self.repo.get_or_raise(task_id))
task_type = task.type
study_id = task.ref_id

self.event_bus.push(
Event(
type=EventType.TASK_RUNNING,
payload=TaskEventPayload(
id=task_id,
message=custom_event_messages.running
if custom_event_messages is not None
else f"Task {task_id} is running",
type=task_type,
study_id=study_id,
).model_dump(),
permissions=PermissionInfo(public_mode=PublicMode.READ),
channel=EventChannelDirectory.TASK + task_id,
self.event_bus.push(
Event(
type=EventType.TASK_RUNNING,
payload=TaskEventPayload(
id=task_id,
message=(
custom_event_messages.running
if custom_event_messages is not None
else f"Task {task_id} is running"
),
type=task_type,
study_id=study_id,
).model_dump(),
permissions=PermissionInfo(public_mode=PublicMode.READ),
channel=EventChannelDirectory.TASK + task_id,
)
)
)

logger.info(f"Starting task {task_id}")
with db():
db.session.query(TaskJob).filter(TaskJob.id == task_id).update(
{TaskJob.status: TaskStatus.RUNNING.value}
)
db.session.commit()
logger.info(f"Task {task_id} set to RUNNING")

with db():
# We must use the DB session attached to the current thread
result = callback(TaskLogAndProgressRecorder(task_id, db.session, self.event_bus))

status = TaskStatus.COMPLETED if result.success else TaskStatus.FAILED
logger.info(f"Task {task_id} ended with status {status}")

with db():
# Do not use the `timezone.utc` timezone to preserve a naive datetime.
completion_date = datetime.datetime.utcnow() if status.is_final() else None
db.session.query(TaskJob).filter(TaskJob.id == task_id).update(
{
TaskJob.status: status.value,
TaskJob.result_msg: result.message,
TaskJob.result_status: result.success,
TaskJob.result: result.return_value,
TaskJob.completion_date: completion_date,
}
)
db.session.commit()
logger.info(f"Starting task {task_id}")
with db():
db.session.query(TaskJob).filter(TaskJob.id == task_id).update(
{TaskJob.status: TaskStatus.RUNNING.value}
)
db.session.commit()
logger.info(f"Task {task_id} set to RUNNING")

event_type = {True: EventType.TASK_COMPLETED, False: EventType.TASK_FAILED}[result.success]
event_msg = {True: "completed", False: "failed"}[result.success]
self.event_bus.push(
Event(
type=event_type,
payload=TaskEventPayload(
id=task_id,
message=(
custom_event_messages.end
if custom_event_messages is not None
else f"Task {task_id} {event_msg}"
),
type=task_type,
study_id=study_id,
).model_dump(),
permissions=PermissionInfo(public_mode=PublicMode.READ),
channel=EventChannelDirectory.TASK + task_id,
with db():
# We must use the DB session attached to the current thread
result = callback(TaskLogAndProgressRecorder(task_id, db.session, self.event_bus))

status = TaskStatus.COMPLETED if result.success else TaskStatus.FAILED
logger.info(f"Task {task_id} ended with status {status}")

with db():
# Do not use the `timezone.utc` timezone to preserve a naive datetime.
completion_date = datetime.datetime.utcnow() if status.is_final() else None
db.session.query(TaskJob).filter(TaskJob.id == task_id).update(
{
TaskJob.status: status.value,
TaskJob.result_msg: result.message,
TaskJob.result_status: result.success,
TaskJob.result: result.return_value,
TaskJob.completion_date: completion_date,
}
)
db.session.commit()

event_type = {True: EventType.TASK_COMPLETED, False: EventType.TASK_FAILED}[result.success]
event_msg = {True: "completed", False: "failed"}[result.success]
self.event_bus.push(
Event(
type=event_type,
payload=TaskEventPayload(
id=task_id,
message=(
custom_event_messages.end
if custom_event_messages is not None
else f"Task {task_id} {event_msg}"
),
type=task_type,
study_id=study_id,
).model_dump(),
permissions=PermissionInfo(public_mode=PublicMode.READ),
channel=EventChannelDirectory.TASK + task_id,
)
)
)
except Exception as exc:
err_msg = f"Task {task_id} failed: Unhandled exception {exc}"
logger.error(err_msg, exc_info=exc)

try:
with db():
result_msg = f"{err_msg}\nSee the logs for detailed information and the error traceback."
db.session.query(TaskJob).filter(TaskJob.id == task_id).update(
{
TaskJob.status: TaskStatus.FAILED.value,
TaskJob.result_msg: result_msg,
TaskJob.result_msg: str(exc),
TaskJob.result_status: False,
TaskJob.completion_date: datetime.datetime.utcnow(),
}
Expand Down
8 changes: 1 addition & 7 deletions antarest/launcher/adapters/abstractlauncher.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from antarest.core.interfaces.cache import ICache
from antarest.core.interfaces.eventbus import Event, EventChannelDirectory, EventType, IEventBus
from antarest.core.model import PermissionInfo, PublicMode
from antarest.core.requests import RequestParameters
from antarest.launcher.adapters.log_parser import LaunchProgressDTO
from antarest.launcher.model import JobStatus, LauncherParametersDTO, LogType

Expand Down Expand Up @@ -70,12 +69,7 @@ def __init__(

@abstractmethod
def run_study(
self,
study_uuid: str,
job_id: str,
version: SolverVersion,
launcher_parameters: LauncherParametersDTO,
params: RequestParameters,
self, study_uuid: str, job_id: str, version: SolverVersion, launcher_parameters: LauncherParametersDTO
) -> None:
raise NotImplementedError()

Expand Down
Loading

0 comments on commit 9a90453

Please sign in to comment.