diff --git a/antarest/__init__.py b/antarest/__init__.py index ada981b5ca..ea7c2d6185 100644 --- a/antarest/__init__.py +++ b/antarest/__init__.py @@ -7,9 +7,9 @@ # Standard project metadata -__version__ = "2.16.0" +__version__ = "2.16.1" __author__ = "RTE, Antares Web Team" -__date__ = "2023-11-30" +__date__ = "2023-12-14" # noinspection SpellCheckingInspection __credits__ = "(c) Réseau de Transport de l’Électricité (RTE)" diff --git a/antarest/core/exceptions.py b/antarest/core/exceptions.py index a666394d8b..ab39c3a566 100644 --- a/antarest/core/exceptions.py +++ b/antarest/core/exceptions.py @@ -189,6 +189,11 @@ def __init__(self, message: str) -> None: super().__init__(HTTPStatus.NOT_FOUND, message) +class DuplicateConstraintName(HTTPException): + def __init__(self, message: str) -> None: + super().__init__(HTTPStatus.CONFLICT, message) + + class MissingDataError(HTTPException): def __init__(self, message: str) -> None: super().__init__(HTTPStatus.NOT_FOUND, message) diff --git a/antarest/core/jwt.py b/antarest/core/jwt.py index ff9ffd1187..16849fa9f0 100644 --- a/antarest/core/jwt.py +++ b/antarest/core/jwt.py @@ -3,9 +3,7 @@ from pydantic import BaseModel from antarest.core.roles import RoleType -from antarest.login.model import Group, Identity - -ADMIN_ID = 1 +from antarest.login.model import ADMIN_ID, Group, Identity class JWTGroup(BaseModel): diff --git a/antarest/core/tasks/model.py b/antarest/core/tasks/model.py index af3a46b8f7..1d7a9e1566 100644 --- a/antarest/core/tasks/model.py +++ b/antarest/core/tasks/model.py @@ -1,11 +1,12 @@ import uuid from datetime import datetime from enum import Enum -from typing import Any, List, Optional +from typing import Any, List, Mapping, Optional from pydantic import BaseModel, Extra from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, Sequence, String # type: ignore -from sqlalchemy.orm import relationship # type: ignore +from sqlalchemy.engine.base import Engine # type: ignore +from sqlalchemy.orm import Session, relationship, sessionmaker # type: ignore from antarest.core.persistence import Base @@ -171,3 +172,29 @@ def __repr__(self) -> str: f" result_msg={self.result_msg}," f" result_status={self.result_status}" ) + + +def cancel_orphan_tasks(engine: Engine, session_args: Mapping[str, bool]) -> None: + """ + Cancel all tasks that are currently running or pending. + + When the web application restarts, such as after a new deployment, any pending or running tasks may be lost. + To mitigate this, it is preferable to set these tasks to a "FAILED" status. + This ensures that users can easily identify the tasks that were affected by the restart and take appropriate + actions, such as restarting the tasks manually. + + Args: + engine: The database engine (SQLAlchemy connection to SQLite or PostgreSQL). + session_args: The session arguments (SQLAlchemy session arguments). 
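+
+    Note:
+        Call this once at application startup, before the web server begins handling
+        requests, e.g. ``cancel_orphan_tasks(engine=engine, session_args=SESSION_ARGS)``
+        as done in ``antarest.main.fastapi_app``.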
+ """ + updated_values = { + TaskJob.status: TaskStatus.FAILED.value, + TaskJob.result_status: False, + TaskJob.result_msg: "Task was interrupted due to server restart", + TaskJob.completion_date: datetime.utcnow(), + } + with sessionmaker(bind=engine, **session_args)() as session: + session.query(TaskJob).filter(TaskJob.status.in_([TaskStatus.RUNNING.value, TaskStatus.PENDING.value])).update( + updated_values, synchronize_session=False + ) + session.commit() diff --git a/antarest/core/tasks/repository.py b/antarest/core/tasks/repository.py index 1994c55fab..294f63255b 100644 --- a/antarest/core/tasks/repository.py +++ b/antarest/core/tasks/repository.py @@ -1,9 +1,10 @@ import datetime +import typing as t from http import HTTPStatus from operator import and_ -from typing import Any, List, Optional from fastapi import HTTPException +from sqlalchemy.orm import Session # type: ignore from antarest.core.tasks.model import TaskJob, TaskListFilter, TaskStatus from antarest.core.utils.fastapi_sqlalchemy import db @@ -11,16 +12,45 @@ class TaskJobRepository: + """ + Database connector to manage Tasks/Jobs entities. + """ + + def __init__(self, session: t.Optional[Session] = None): + """ + Initialize the repository. + + Args: + session: Optional SQLAlchemy session to be used. + """ + self._session = session + + @property + def session(self) -> Session: + """ + Get the SQLAlchemy session for the repository. + + Returns: + SQLAlchemy session. + """ + if self._session is None: + # Get or create the session from a context variable (thread local variable) + return db.session + # Get the user-defined session + return self._session + def save(self, task: TaskJob) -> TaskJob: - task = db.session.merge(task) - db.session.add(task) - db.session.commit() + session = self.session + task = session.merge(task) + session.add(task) + session.commit() return task - def get(self, id: str) -> Optional[TaskJob]: - task: TaskJob = db.session.get(TaskJob, id) + def get(self, id: str) -> t.Optional[TaskJob]: + session = self.session + task: TaskJob = session.get(TaskJob, id) if task is not None: - db.session.refresh(task) + session.refresh(task) return task def get_or_raise(self, id: str) -> TaskJob: @@ -30,7 +60,7 @@ def get_or_raise(self, id: str) -> TaskJob: return task @staticmethod - def _combine_clauses(where_clauses: List[Any]) -> Any: + def _combine_clauses(where_clauses: t.List[t.Any]) -> t.Any: assert_this(len(where_clauses) > 0) if len(where_clauses) > 1: return and_( @@ -40,9 +70,9 @@ def _combine_clauses(where_clauses: List[Any]) -> Any: else: return where_clauses[0] - def list(self, filter: TaskListFilter, user: Optional[int] = None) -> List[TaskJob]: - query = db.session.query(TaskJob) - where_clauses: List[Any] = [] + def list(self, filter: TaskListFilter, user: t.Optional[int] = None) -> t.List[TaskJob]: + query = self.session.query(TaskJob) + where_clauses: t.List[t.Any] = [] if user: where_clauses.append(TaskJob.owner_id == user) if len(filter.status) > 0: @@ -74,19 +104,21 @@ def list(self, filter: TaskListFilter, user: Optional[int] = None) -> List[TaskJ elif len(where_clauses) == 1: query = query.where(*where_clauses) - tasks: List[TaskJob] = query.all() + tasks: t.List[TaskJob] = query.all() return tasks def delete(self, tid: str) -> None: - task = db.session.get(TaskJob, tid) + session = self.session + task = session.get(TaskJob, tid) if task: - db.session.delete(task) - db.session.commit() + session.delete(task) + session.commit() def update_timeout(self, task_id: str, timeout: int) -> None: 
"""Update task status to TIMEOUT.""" - task: TaskJob = db.session.get(TaskJob, task_id) + session = self.session + task: TaskJob = session.get(TaskJob, task_id) task.status = TaskStatus.TIMEOUT task.result_msg = f"Task '{task_id}' timeout after {timeout} seconds" task.result_status = False - db.session.commit() + session.commit() diff --git a/antarest/core/tasks/service.py b/antarest/core/tasks/service.py index 28038e2d4c..07832e7365 100644 --- a/antarest/core/tasks/service.py +++ b/antarest/core/tasks/service.py @@ -1,16 +1,16 @@ import datetime import logging import time +import typing as t from abc import ABC, abstractmethod from concurrent.futures import Future, ThreadPoolExecutor from http import HTTPStatus -from typing import Awaitable, Callable, Dict, List, Optional, Union from fastapi import HTTPException +from sqlalchemy.orm import Session # type: ignore from antarest.core.config import Config from antarest.core.interfaces.eventbus import Event, EventChannelDirectory, EventType, IEventBus -from antarest.core.jwt import DEFAULT_ADMIN_USER from antarest.core.model import PermissionInfo, PublicMode from antarest.core.requests import MustBeAuthenticatedError, RequestParameters, UserHasNotPermissionError from antarest.core.tasks.model import ( @@ -26,13 +26,12 @@ ) from antarest.core.tasks.repository import TaskJobRepository from antarest.core.utils.fastapi_sqlalchemy import db -from antarest.core.utils.utils import retry from antarest.worker.worker import WorkerTaskCommand, WorkerTaskResult logger = logging.getLogger(__name__) -TaskUpdateNotifier = Callable[[str], None] -Task = Callable[[TaskUpdateNotifier], TaskResult] +TaskUpdateNotifier = t.Callable[[str], None] +Task = t.Callable[[TaskUpdateNotifier], TaskResult] DEFAULT_AWAIT_MAX_TIMEOUT = 172800 # 48 hours """Default timeout for `await_task` in seconds.""" @@ -44,21 +43,21 @@ def add_worker_task( self, task_type: TaskType, task_queue: str, - task_args: Dict[str, Union[int, float, bool, str]], - name: Optional[str], - ref_id: Optional[str], + task_args: t.Dict[str, t.Union[int, float, bool, str]], + name: t.Optional[str], + ref_id: t.Optional[str], request_params: RequestParameters, - ) -> Optional[str]: + ) -> t.Optional[str]: raise NotImplementedError() @abstractmethod def add_task( self, action: Task, - name: Optional[str], - task_type: Optional[TaskType], - ref_id: Optional[str], - custom_event_messages: Optional[CustomTaskEventMessages], + name: t.Optional[str], + task_type: t.Optional[TaskType], + ref_id: t.Optional[str], + custom_event_messages: t.Optional[CustomTaskEventMessages], request_params: RequestParameters, ) -> str: raise NotImplementedError() @@ -73,7 +72,7 @@ def status_task( raise NotImplementedError() @abstractmethod - def list_tasks(self, task_filter: TaskListFilter, request_params: RequestParameters) -> List[TaskDTO]: + def list_tasks(self, task_filter: TaskListFilter, request_params: RequestParameters) -> t.List[TaskDTO]: raise NotImplementedError() @abstractmethod @@ -86,6 +85,26 @@ def noop_notifier(message: str) -> None: """This function is used in tasks when no notification is required.""" +class TaskJobLogRecorder: + """ + Callback used to register log messages in the TaskJob table. + + Args: + task_id: The task id. + session: The database session created in the same thread as the task thread. 
+ """ + + def __init__(self, task_id: str, session: Session): + self.session = session + self.task_id = task_id + + def __call__(self, message: str) -> None: + task = self.session.query(TaskJob).get(self.task_id) + if task: + task.logs.append(TaskJobLog(message=message, task_id=self.task_id)) + db.session.commit() + + class TaskJobService(ITaskService): def __init__( self, @@ -96,24 +115,22 @@ def __init__( self.config = config self.repo = repository self.event_bus = event_bus - self.tasks: Dict[str, Future[None]] = {} + self.tasks: t.Dict[str, Future[None]] = {} self.threadpool = ThreadPoolExecutor(max_workers=config.tasks.max_workers, thread_name_prefix="taskjob_") self.event_bus.add_listener(self.create_task_event_callback(), [EventType.TASK_CANCEL_REQUEST]) self.remote_workers = config.tasks.remote_workers - # set the status of previously running job to FAILED due to server restart - self._fix_running_status() def _create_worker_task( self, task_id: str, task_type: str, - task_args: Dict[str, Union[int, float, bool, str]], - ) -> Callable[[TaskUpdateNotifier], TaskResult]: - task_result_wrapper: List[TaskResult] = [] + task_args: t.Dict[str, t.Union[int, float, bool, str]], + ) -> t.Callable[[TaskUpdateNotifier], TaskResult]: + task_result_wrapper: t.List[TaskResult] = [] def _create_awaiter( - res_wrapper: List[TaskResult], - ) -> Callable[[Event], Awaitable[None]]: + res_wrapper: t.List[TaskResult], + ) -> t.Callable[[Event], t.Awaitable[None]]: async def _await_task_end(event: Event) -> None: task_event = WorkerTaskResult.parse_obj(event.payload) if task_event.task_id == task_id: @@ -155,11 +172,11 @@ def add_worker_task( self, task_type: TaskType, task_queue: str, - task_args: Dict[str, Union[int, float, bool, str]], - name: Optional[str], - ref_id: Optional[str], + task_args: t.Dict[str, t.Union[int, float, bool, str]], + name: t.Optional[str], + ref_id: t.Optional[str], request_params: RequestParameters, - ) -> Optional[str]: + ) -> t.Optional[str]: if not self.check_remote_worker_for_queue(task_queue): logger.warning(f"Failed to find configured remote worker for task queue {task_queue}") return None @@ -176,10 +193,10 @@ def add_worker_task( def add_task( self, action: Task, - name: Optional[str], - task_type: Optional[TaskType], - ref_id: Optional[str], - custom_event_messages: Optional[CustomTaskEventMessages], + name: t.Optional[str], + task_type: t.Optional[TaskType], + ref_id: t.Optional[str], + custom_event_messages: t.Optional[CustomTaskEventMessages], request_params: RequestParameters, ) -> str: task = self._create_task(name, task_type, ref_id, request_params) @@ -188,9 +205,9 @@ def add_task( def _create_task( self, - name: Optional[str], - task_type: Optional[TaskType], - ref_id: Optional[str], + name: t.Optional[str], + task_type: t.Optional[TaskType], + ref_id: t.Optional[str], request_params: RequestParameters, ) -> TaskJob: if not request_params.user: @@ -209,7 +226,7 @@ def _launch_task( self, action: Task, task: TaskJob, - custom_event_messages: Optional[CustomTaskEventMessages], + custom_event_messages: t.Optional[CustomTaskEventMessages], request_params: RequestParameters, ) -> None: if not request_params.user: @@ -230,7 +247,7 @@ def _launch_task( future = self.threadpool.submit(self._run_task, action, task.id, custom_event_messages) self.tasks[task.id] = future - def create_task_event_callback(self) -> Callable[[Event], Awaitable[None]]: + def create_task_event_callback(self) -> t.Callable[[Event], t.Awaitable[None]]: async def task_event_callback(event: 
Event) -> None: self._cancel_task(str(event.payload), dispatch=False) @@ -275,10 +292,10 @@ def status_task( detail=f"Failed to retrieve task {task_id} in db", ) - def list_tasks(self, task_filter: TaskListFilter, request_params: RequestParameters) -> List[TaskDTO]: + def list_tasks(self, task_filter: TaskListFilter, request_params: RequestParameters) -> t.List[TaskDTO]: return [task.to_dto() for task in self.list_db_tasks(task_filter, request_params)] - def list_db_tasks(self, task_filter: TaskListFilter, request_params: RequestParameters) -> List[TaskJob]: + def list_db_tasks(self, task_filter: TaskListFilter, request_params: RequestParameters) -> t.List[TaskJob]: if not request_params.user: raise MustBeAuthenticatedError() user = None if request_params.user.is_site_admin() else request_params.user.impersonator @@ -297,25 +314,33 @@ def await_task(self, task_id: str, timeout_sec: int = DEFAULT_AWAIT_MAX_TIMEOUT) logger.warning(f"Task '{task_id}' not handled by this worker, will poll for task completion from db") end = time.time() + timeout_sec while time.time() < end: - with db(): - task = self.repo.get(task_id) - if task is None: - logger.error(f"Awaited task '{task_id}' was not found") - return - if TaskStatus(task.status).is_final(): - return + task_status = db.session.query(TaskJob.status).filter(TaskJob.id == task_id).scalar() + if task_status is None: + logger.error(f"Awaited task '{task_id}' was not found") + return + if TaskStatus(task_status).is_final(): + return logger.info("💤 Sleeping 2 seconds...") time.sleep(2) + logger.error(f"Timeout while awaiting task '{task_id}'") - with db(): - self.repo.update_timeout(task_id, timeout_sec) + db.session.query(TaskJob).filter(TaskJob.id == task_id).update( + { + TaskJob.status: TaskStatus.TIMEOUT.value, + TaskJob.result_msg: f"Task '{task_id}' timeout after {timeout_sec} seconds", + TaskJob.result_status: False, + } + ) + db.session.commit() def _run_task( self, callback: Task, task_id: str, - custom_event_messages: Optional[CustomTaskEventMessages] = None, + custom_event_messages: t.Optional[CustomTaskEventMessages] = None, ) -> None: + # Note: this function is executed in a worker thread, not in the main thread + self.event_bus.push( Event( type=EventType.TASK_RUNNING, @@ -332,22 +357,32 @@ def _run_task( logger.info(f"Starting task {task_id}") with db(): - task = retry(lambda: self.repo.get_or_raise(task_id)) - task.status = TaskStatus.RUNNING.value - self.repo.save(task) - logger.info(f"Task {task_id} set to RUNNING") + db.session.query(TaskJob).filter(TaskJob.id == task_id).update({TaskJob.status: TaskStatus.RUNNING.value}) + db.session.commit() + logger.info(f"Task {task_id} set to RUNNING") + try: with db(): - result = callback(self._task_logger(task_id)) - logger.info(f"Task {task_id} ended") + # We must use the DB session attached to the current thread + result = callback(TaskJobLogRecorder(task_id, session=db.session)) + + status = TaskStatus.COMPLETED if result.success else TaskStatus.FAILED + logger.info(f"Task {task_id} ended with status {status}") + with db(): - self._update_task_status( - task_id, - TaskStatus.COMPLETED if result.success else TaskStatus.FAILED, - result.success, - result.message, - result.return_value, + # Do not use the `timezone.utc` timezone to preserve a naive datetime. 
+ completion_date = datetime.datetime.utcnow() if status.is_final() else None + db.session.query(TaskJob).filter(TaskJob.id == task_id).update( + { + TaskJob.status: status.value, + TaskJob.result_msg: result.message, + TaskJob.result_status: result.success, + TaskJob.result: result.return_value, + TaskJob.completion_date: completion_date, + } ) + db.session.commit() + event_type = {True: EventType.TASK_COMPLETED, False: EventType.TASK_FAILED}[result.success] event_msg = {True: "completed", False: "failed"}[result.success] self.event_bus.push( @@ -368,13 +403,19 @@ def _run_task( except Exception as exc: err_msg = f"Task {task_id} failed: Unhandled exception {exc}" logger.error(err_msg, exc_info=exc) + with db(): - self._update_task_status( - task_id, - TaskStatus.FAILED, - False, - f"{err_msg}\nSee the logs for detailed information and the error traceback.", + result_msg = f"{err_msg}\nSee the logs for detailed information and the error traceback." + db.session.query(TaskJob).filter(TaskJob.id == task_id).update( + { + TaskJob.status: TaskStatus.FAILED.value, + TaskJob.result_msg: result_msg, + TaskJob.result_status: False, + TaskJob.completion_date: datetime.datetime.utcnow(), + } ) + db.session.commit() + message = err_msg if custom_event_messages is None else custom_event_messages.end self.event_bus.push( Event( @@ -384,44 +425,3 @@ def _run_task( channel=EventChannelDirectory.TASK + task_id, ) ) - - def _task_logger(self, task_id: str) -> Callable[[str], None]: - def log_msg(message: str) -> None: - task = self.repo.get(task_id) - if task: - task.logs.append(TaskJobLog(message=message, task_id=task_id)) - self.repo.save(task) - - return log_msg - - def _fix_running_status(self) -> None: - with db(): - previous_tasks = self.list_db_tasks( - TaskListFilter(status=[TaskStatus.RUNNING, TaskStatus.PENDING]), - request_params=RequestParameters(user=DEFAULT_ADMIN_USER), - ) - for task in previous_tasks: - self._update_task_status( - task.id, - TaskStatus.FAILED, - False, - "Task was interrupted due to server restart", - ) - - def _update_task_status( - self, - task_id: str, - status: TaskStatus, - result: bool, - message: str, - command_result: Optional[str] = None, - ) -> None: - task = self.repo.get_or_raise(task_id) - task.status = status.value - task.result_msg = message - task.result_status = result - task.result = command_result - if status.is_final(): - # Do not use the `timezone.utc` timezone to preserve a naive datetime. 
- task.completion_date = datetime.datetime.utcnow() - self.repo.save(task) diff --git a/antarest/core/utils/fastapi_sqlalchemy/exceptions.py b/antarest/core/utils/fastapi_sqlalchemy/exceptions.py index 7e435ba286..ad1eccff2c 100644 --- a/antarest/core/utils/fastapi_sqlalchemy/exceptions.py +++ b/antarest/core/utils/fastapi_sqlalchemy/exceptions.py @@ -1,5 +1,5 @@ class MissingSessionError(Exception): - """Excetion raised for when the user tries to access a database session before it is created.""" + """Exception raised for when the user tries to access a database session before it is created.""" def __init__(self) -> None: msg = """ diff --git a/antarest/login/main.py b/antarest/login/main.py index 9b487de5b7..d87a082abd 100644 --- a/antarest/login/main.py +++ b/antarest/login/main.py @@ -37,7 +37,7 @@ def build_login( """ if service is None: - user_repo = UserRepository(config) + user_repo = UserRepository() bot_repo = BotRepository() group_repo = GroupRepository() role_repo = RoleRepository() diff --git a/antarest/login/model.py b/antarest/login/model.py index 52106685bc..5012a4995c 100644 --- a/antarest/login/model.py +++ b/antarest/login/model.py @@ -1,11 +1,14 @@ +import contextlib import typing as t import uuid import bcrypt from pydantic.main import BaseModel from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, Sequence, String # type: ignore +from sqlalchemy.engine.base import Engine # type: ignore +from sqlalchemy.exc import IntegrityError # type: ignore from sqlalchemy.ext.hybrid import hybrid_property # type: ignore -from sqlalchemy.orm import relationship # type: ignore +from sqlalchemy.orm import relationship, sessionmaker # type: ignore from antarest.core.persistence import Base from antarest.core.roles import RoleType @@ -15,6 +18,19 @@ from antarest.launcher.model import JobResult +GROUP_ID = "admin" +"""Unique ID of the administrator group.""" + +GROUP_NAME = "admin" +"""Name of the administrator group.""" + +ADMIN_ID = 1 +"""Unique ID of the site administrator.""" + +ADMIN_NAME = "admin" +"""Name of the site administrator.""" + + class UserInfo(BaseModel): id: int name: str @@ -282,3 +298,32 @@ class CredentialsDTO(BaseModel): user: int access_token: str refresh_token: str + + +def init_admin_user(engine: Engine, session_args: t.Mapping[str, bool], admin_password: str) -> None: + """ + Create the default admin user, group and role if they do not already exist in the database. + + Args: + engine: The database engine (SQLAlchemy connection to SQLite or PostgreSQL). + session_args: The session arguments (SQLAlchemy session arguments). + admin_password: The admin password extracted from the configuration file. 
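+
+    Note:
+        The group, user and role are each inserted in their own session and any
+        ``IntegrityError`` is suppressed, so an entity that already exists does not
+        prevent the creation of the remaining ones.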
+ """ + make_session = sessionmaker(bind=engine, **session_args) + with make_session() as session: + group = Group(id=GROUP_ID, name=GROUP_NAME) + with contextlib.suppress(IntegrityError): + session.add(group) + session.commit() + + with make_session() as session: + user = User(id=ADMIN_ID, name=ADMIN_NAME, password=Password(admin_password)) + with contextlib.suppress(IntegrityError): + session.add(user) + session.commit() + + with make_session() as session: + role = Role(type=RoleType.ADMIN, identity_id=ADMIN_ID, group_id=GROUP_ID) + with contextlib.suppress(IntegrityError): + session.add(role) + session.commit() diff --git a/antarest/login/repository.py b/antarest/login/repository.py index 4f68e1924c..d70fa57e13 100644 --- a/antarest/login/repository.py +++ b/antarest/login/repository.py @@ -2,13 +2,10 @@ from typing import List, Optional from sqlalchemy import exists # type: ignore -from sqlalchemy.orm import joinedload # type: ignore +from sqlalchemy.orm import Session, joinedload # type: ignore -from antarest.core.config import Config -from antarest.core.jwt import ADMIN_ID -from antarest.core.roles import RoleType from antarest.core.utils.fastapi_sqlalchemy import db -from antarest.login.model import Bot, Group, Password, Role, User, UserLdap +from antarest.login.model import Bot, Group, Role, User, UserLdap logger = logging.getLogger(__name__) @@ -18,37 +15,46 @@ class GroupRepository: Database connector to manage Group entity. """ - def __init__(self) -> None: - with db(): - self.save(Group(id="admin", name="admin")) + def __init__( + self, + session: Optional[Session] = None, + ) -> None: + self._session = session + + @property + def session(self) -> Session: + """Get the SqlAlchemy session or create a new one on the fly if not available in the current thread.""" + if self._session is None: + return db.session + return self._session def save(self, group: Group) -> Group: - res = db.session.query(exists().where(Group.id == group.id)).scalar() + res = self.session.query(exists().where(Group.id == group.id)).scalar() if res: - db.session.merge(group) + self.session.merge(group) else: - db.session.add(group) - db.session.commit() + self.session.add(group) + self.session.commit() logger.debug(f"Group {group.id} saved") return group def get(self, id: str) -> Optional[Group]: - group: Group = db.session.query(Group).get(id) + group: Group = self.session.query(Group).get(id) return group def get_by_name(self, name: str) -> Group: - group: Group = db.session.query(Group).filter_by(name=name).first() + group: Group = self.session.query(Group).filter_by(name=name).first() return group def get_all(self) -> List[Group]: - groups: List[Group] = db.session.query(Group).all() + groups: List[Group] = self.session.query(Group).all() return groups def delete(self, id: str) -> None: - g = db.session.query(Group).get(id) - db.session.delete(g) - db.session.commit() + g = self.session.query(Group).get(id) + self.session.delete(g) + self.session.commit() logger.debug(f"Group {id} deleted") @@ -58,49 +64,46 @@ class UserRepository: Database connector to manage User entity. 
""" - def __init__(self, config: Config) -> None: - # init seed admin user from conf - with db(): - admin_user = self.get_by_name("admin") - if admin_user is None: - self.save( - User( - id=ADMIN_ID, - name="admin", - password=Password(config.security.admin_pwd), - ) - ) - elif not admin_user.password.check(config.security.admin_pwd): # type: ignore - admin_user.password = Password(config.security.admin_pwd) # type: ignore - self.save(admin_user) + def __init__( + self, + session: Optional[Session] = None, + ) -> None: + self._session = session + + @property + def session(self) -> Session: + """Get the SqlAlchemy session or create a new one on the fly if not available in the current thread.""" + if self._session is None: + return db.session + return self._session def save(self, user: User) -> User: - res = db.session.query(exists().where(User.id == user.id)).scalar() + res = self.session.query(exists().where(User.id == user.id)).scalar() if res: - db.session.merge(user) + self.session.merge(user) else: - db.session.add(user) - db.session.commit() + self.session.add(user) + self.session.commit() logger.debug(f"User {user.id} saved") return user - def get(self, id: int) -> Optional[User]: - user: User = db.session.query(User).get(id) + def get(self, id_number: int) -> Optional[User]: + user: User = self.session.query(User).get(id_number) return user def get_by_name(self, name: str) -> Optional[User]: - user: User = db.session.query(User).filter_by(name=name).first() + user: User = self.session.query(User).filter_by(name=name).first() return user def get_all(self) -> List[User]: - users: List[User] = db.session.query(User).all() + users: List[User] = self.session.query(User).all() return users def delete(self, id: int) -> None: - u: User = db.session.query(User).get(id) - db.session.delete(u) - db.session.commit() + u: User = self.session.query(User).get(id) + self.session.delete(u) + self.session.commit() logger.debug(f"User {id} deleted") @@ -110,39 +113,54 @@ class UserLdapRepository: Database connector to manage UserLdap entity. 
""" + def __init__( + self, + session: Optional[Session] = None, + ) -> None: + self._session = session + + @property + def session(self) -> Session: + """Get the SqlAlchemy session or create a new one on the fly if not available in the current thread.""" + if self._session is None: + return db.session + return self._session + def save(self, user_ldap: UserLdap) -> UserLdap: - res = db.session.query(exists().where(UserLdap.id == user_ldap.id)).scalar() + res = self.session.query(exists().where(UserLdap.id == user_ldap.id)).scalar() if res: - db.session.merge(user_ldap) + self.session.merge(user_ldap) else: - db.session.add(user_ldap) - db.session.commit() + self.session.add(user_ldap) + self.session.commit() logger.debug(f"User LDAP {user_ldap.id} saved") return user_ldap - def get(self, id: int) -> Optional[UserLdap]: - user_ldap: Optional[UserLdap] = db.session.query(UserLdap).get(id) + def get(self, id_number: int) -> Optional[UserLdap]: + user_ldap: Optional[UserLdap] = self.session.query(UserLdap).get(id_number) return user_ldap def get_by_name(self, name: str) -> Optional[UserLdap]: - user: UserLdap = db.session.query(UserLdap).filter_by(name=name).first() + user: UserLdap = self.session.query(UserLdap).filter_by(name=name).first() return user def get_by_external_id(self, external_id: str) -> Optional[UserLdap]: - user: UserLdap = db.session.query(UserLdap).filter_by(external_id=external_id).first() + user: UserLdap = self.session.query(UserLdap).filter_by(external_id=external_id).first() return user - def get_all(self) -> List[UserLdap]: - users_ldap: List[UserLdap] = db.session.query(UserLdap).all() + def get_all( + self, + ) -> List[UserLdap]: + users_ldap: List[UserLdap] = self.session.query(UserLdap).all() return users_ldap - def delete(self, id: int) -> None: - u: UserLdap = db.session.query(UserLdap).get(id) - db.session.delete(u) - db.session.commit() + def delete(self, id_number: int) -> None: + u: UserLdap = self.session.query(UserLdap).get(id_number) + self.session.delete(u) + self.session.commit() - logger.debug(f"User LDAP {id} deleted") + logger.debug(f"User LDAP {id_number} deleted") class BotRepository: @@ -150,42 +168,57 @@ class BotRepository: Database connector to manage Bot entity. 
""" + def __init__( + self, + session: Optional[Session] = None, + ) -> None: + self._session = session + + @property + def session(self) -> Session: + """Get the SqlAlchemy session or create a new one on the fly if not available in the current thread.""" + if self._session is None: + return db.session + return self._session + def save(self, bot: Bot) -> Bot: - res = db.session.query(exists().where(Bot.id == bot.id)).scalar() + res = self.session.query(exists().where(Bot.id == bot.id)).scalar() if res: raise ValueError("Bot already exist") else: - db.session.add(bot) - db.session.commit() + self.session.add(bot) + self.session.commit() logger.debug(f"Bot {bot.id} saved") return bot - def get(self, id: int) -> Optional[Bot]: - bot: Bot = db.session.query(Bot).get(id) + def get(self, id_number: int) -> Optional[Bot]: + bot: Bot = self.session.query(Bot).get(id_number) return bot - def get_all(self) -> List[Bot]: - bots: List[Bot] = db.session.query(Bot).all() + def get_all( + self, + ) -> List[Bot]: + bots: List[Bot] = self.session.query(Bot).all() return bots - def delete(self, id: int) -> None: - u: Bot = db.session.query(Bot).get(id) - db.session.delete(u) - db.session.commit() + def delete(self, id_number: int) -> None: + u: Bot = self.session.query(Bot).get(id_number) + self.session.delete(u) + self.session.commit() - logger.debug(f"Bot {id} deleted") + logger.debug(f"Bot {id_number} deleted") def get_all_by_owner(self, owner: int) -> List[Bot]: - bots: List[Bot] = db.session.query(Bot).filter_by(owner=owner).all() + bots: List[Bot] = self.session.query(Bot).filter_by(owner=owner).all() return bots def get_by_name_and_owner(self, owner: int, name: str) -> Optional[Bot]: - bot: Bot = db.session.query(Bot).filter_by(owner=owner, name=name).first() + bot: Bot = self.session.query(Bot).filter_by(owner=owner, name=name).first() return bot - def exists(self, id: int) -> bool: - res: bool = db.session.query(exists().where(Bot.id == id)).scalar() + def exists(self, id_number: int) -> bool: + res: bool = self.session.query(exists().where(Bot.id == id_number)).scalar() return res @@ -194,29 +227,31 @@ class RoleRepository: Database connector to manage Role entity. """ - def __init__(self) -> None: - with db(): - if self.get(1, "admin") is None: - self.save( - Role( - type=RoleType.ADMIN, - identity=User(id=1), - group=Group(id="admin"), - ) - ) + def __init__( + self, + session: Optional[Session] = None, + ) -> None: + self._session = session + + @property + def session(self) -> Session: + """Get the SqlAlchemy session or create a new one on the fly if not available in the current thread.""" + if self._session is None: + return db.session + return self._session def save(self, role: Role) -> Role: - role.group = db.session.merge(role.group) - role.identity = db.session.merge(role.identity) + role.group = self.session.merge(role.group) + role.identity = self.session.merge(role.identity) - db.session.add(role) - db.session.commit() + self.session.add(role) + self.session.commit() logger.debug(f"Role (user={role.identity}, group={role.group} saved") return role def get(self, user: int, group: str) -> Optional[Role]: - role: Role = db.session.query(Role).get((user, group)) + role: Role = self.session.query(Role).get((user, group)) return role def get_all_by_user(self, /, user_id: int) -> List[Role]: @@ -231,17 +266,17 @@ def get_all_by_user(self, /, user_id: int) -> List[Role]: """ # When we fetch the list of roles, we also need to fetch the associated groups. 
# We use a SQL query with joins to fetch all these data efficiently. - stm = db.session.query(Role).options(joinedload(Role.group)).filter_by(identity_id=user_id) + stm = self.session.query(Role).options(joinedload(Role.group)).filter_by(identity_id=user_id) roles: List[Role] = stm.all() return roles def get_all_by_group(self, group: str) -> List[Role]: - roles: List[Role] = db.session.query(Role).filter_by(group_id=group).all() + roles: List[Role] = self.session.query(Role).filter_by(group_id=group).all() return roles def delete(self, user: int, group: str) -> None: - r = db.session.query(Role).get((user, group)) - db.session.delete(r) - db.session.commit() + r = self.session.query(Role).get((user, group)) + self.session.delete(r) + self.session.commit() logger.debug(f"Role (user={user}, group={group} deleted") diff --git a/antarest/main.py b/antarest/main.py index 1e0c9183dd..700947a4dd 100644 --- a/antarest/main.py +++ b/antarest/main.py @@ -30,15 +30,18 @@ from antarest.core.logging.utils import LoggingMiddleware, configure_logger from antarest.core.requests import RATE_LIMIT_CONFIG from antarest.core.swagger import customize_openapi +from antarest.core.tasks.model import cancel_orphan_tasks +from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware from antarest.core.utils.utils import get_local_path from antarest.core.utils.web import tags_metadata from antarest.login.auth import Auth, JwtSettings +from antarest.login.model import init_admin_user from antarest.matrixstore.matrix_garbage_collector import MatrixGarbageCollector -from antarest.singleton_services import SingletonServices +from antarest.singleton_services import start_all_services from antarest.study.storage.auto_archive_service import AutoArchiveService from antarest.study.storage.rawstudy.watcher import Watcher from antarest.tools.admin_lib import clean_locks -from antarest.utils import Module, create_services, init_db +from antarest.utils import SESSION_ARGS, Module, create_services, init_db_engine logger = logging.getLogger(__name__) @@ -246,7 +249,8 @@ def fastapi_app( ) # Database - init_db(config_file, config, auto_upgrade_db, application) + engine = init_db_engine(config_file, config, auto_upgrade_db) + application.add_middleware(DBSessionMiddleware, custom_engine=engine, session_args=SESSION_ARGS) application.add_middleware(LoggingMiddleware) @@ -401,6 +405,7 @@ def handle_all_exception(request: Request, exc: Exception) -> Any: config=RATE_LIMIT_CONFIG, ) + init_admin_user(engine=engine, session_args=SESSION_ARGS, admin_password=config.security.admin_pwd) services = create_services(config, application) if mount_front: @@ -428,6 +433,7 @@ def handle_all_exception(request: Request, exc: Exception) -> Any: auto_archiver.start() customize_openapi(application) + cancel_orphan_tasks(engine=engine, session_args=SESSION_ARGS) return application, services @@ -455,8 +461,7 @@ def main() -> None: # noinspection PyTypeChecker uvicorn.run(app, host="0.0.0.0", port=8080, log_config=LOGGING_CONFIG) else: - services = SingletonServices(arguments.config_file, [arguments.module]) - services.start() + start_all_services(arguments.config_file, [arguments.module]) if __name__ == "__main__": diff --git a/antarest/matrixstore/model.py b/antarest/matrixstore/model.py index 1f6e500c42..aa9a4a91a9 100644 --- a/antarest/matrixstore/model.py +++ b/antarest/matrixstore/model.py @@ -1,6 +1,6 @@ import datetime +import typing as t import uuid -from typing import Any, List, Union from pydantic import BaseModel from sqlalchemy 
import Boolean, Column, DateTime, ForeignKey, Integer, String, Table # type: ignore @@ -29,7 +29,11 @@ class Matrix(Base): # type: ignore height: int = Column(Integer) created_at: datetime.datetime = Column(DateTime) - def __eq__(self, other: Any) -> bool: + def __repr__(self) -> str: # pragma: no cover + """Returns a string representation of the matrix.""" + return f"Matrix(id={self.id}, shape={(self.height, self.width)}, created_at={self.created_at})" + + def __eq__(self, other: t.Any) -> bool: if not isinstance(other, Matrix): return False @@ -50,9 +54,9 @@ class MatrixInfoDTO(BaseModel): class MatrixDataSetDTO(BaseModel): id: str name: str - matrices: List[MatrixInfoDTO] + matrices: t.List[MatrixInfoDTO] owner: UserInfo - groups: List[GroupDTO] + groups: t.List[GroupDTO] public: bool created_at: str updated_at: str @@ -85,7 +89,11 @@ class MatrixDataSetRelation(Base): # type: ignore name: str = Column(String, primary_key=True) matrix: Matrix = relationship(Matrix) - def __eq__(self, other: Any) -> bool: + def __repr__(self) -> str: # pragma: no cover + """Returns a string representation of the matrix dataset relation.""" + return f"MatrixDataSetRelation(dataset_id={self.dataset_id}, matrix_id={self.matrix_id}, name={self.name})" + + def __eq__(self, other: t.Any) -> bool: if not isinstance(other, MatrixDataSetRelation): return False @@ -152,7 +160,18 @@ def to_dto(self) -> MatrixDataSetDTO: updated_at=str(self.updated_at), ) - def __eq__(self, other: Any) -> bool: + def __repr__(self) -> str: # pragma: no cover + """Returns a string representation of the matrix dataset.""" + return ( + f"MatrixDataSet(id={self.id}," + f" name={self.name}," + f" owner_id={self.owner_id}," + f" public={self.public}," + f" created_at={self.created_at}," + f" updated_at={self.updated_at})" + ) + + def __eq__(self, other: t.Any) -> bool: if not isinstance(other, MatrixDataSet): return False @@ -181,9 +200,9 @@ def __eq__(self, other: Any) -> bool: class MatrixDTO(BaseModel): width: int height: int - index: List[str] - columns: List[str] - data: List[List[MatrixData]] + index: t.List[str] + columns: t.List[str] + data: t.List[t.List[MatrixData]] created_at: int = 0 id: str = "" @@ -198,12 +217,12 @@ class MatrixContent(BaseModel): columns: A list of columns indexes or names. 
""" - data: List[List[MatrixData]] - index: List[Union[int, str]] - columns: List[Union[int, str]] + data: t.List[t.List[MatrixData]] + index: t.List[t.Union[int, str]] + columns: t.List[t.Union[int, str]] class MatrixDataSetUpdateDTO(BaseModel): name: str - groups: List[str] + groups: t.List[str] public: bool diff --git a/antarest/matrixstore/repository.py b/antarest/matrixstore/repository.py index 6301e39c7f..9ab44a69ec 100644 --- a/antarest/matrixstore/repository.py +++ b/antarest/matrixstore/repository.py @@ -7,7 +7,7 @@ from filelock import FileLock from numpy import typing as npt from sqlalchemy import and_, exists # type: ignore -from sqlalchemy.orm import aliased # type: ignore +from sqlalchemy.orm import Session, aliased # type: ignore from antarest.core.utils.fastapi_sqlalchemy import db from antarest.matrixstore.model import Matrix, MatrixContent, MatrixData, MatrixDataSet @@ -20,23 +20,33 @@ class MatrixDataSetRepository: Database connector to manage Matrix metadata entity """ + def __init__(self, session: t.Optional[Session] = None) -> None: + self._session = session + + @property + def session(self) -> Session: + """Get the SqlAlchemy session or create a new one on the fly if not available in the current thread.""" + if self._session is None: + return db.session + return self._session + def save(self, matrix_user_metadata: MatrixDataSet) -> MatrixDataSet: - res: bool = db.session.query(exists().where(MatrixDataSet.id == matrix_user_metadata.id)).scalar() + res: bool = self.session.query(exists().where(MatrixDataSet.id == matrix_user_metadata.id)).scalar() if res: - matrix_user_metadata = db.session.merge(matrix_user_metadata) + matrix_user_metadata = self.session.merge(matrix_user_metadata) else: - db.session.add(matrix_user_metadata) - db.session.commit() + self.session.add(matrix_user_metadata) + self.session.commit() logger.debug(f"Matrix dataset {matrix_user_metadata.id} for user {matrix_user_metadata.owner_id} saved") return matrix_user_metadata - def get(self, id: str) -> t.Optional[MatrixDataSet]: - matrix: MatrixDataSet = db.session.query(MatrixDataSet).get(id) + def get(self, id_number: str) -> t.Optional[MatrixDataSet]: + matrix: MatrixDataSet = self.session.query(MatrixDataSet).get(id_number) return matrix def get_all_datasets(self) -> t.List[MatrixDataSet]: - matrix_datasets: t.List[MatrixDataSet] = db.session.query(MatrixDataSet).all() + matrix_datasets: t.List[MatrixDataSet] = self.session.query(MatrixDataSet).all() return matrix_datasets def query( @@ -54,7 +64,7 @@ def query( Returns: the list of metadata per user, matching the query """ - query = db.session.query(MatrixDataSet) + query = self.session.query(MatrixDataSet) if name is not None: query = query.filter(MatrixDataSet.name.ilike(f"%{name}%")) # type: ignore if owner is not None: @@ -63,9 +73,9 @@ def query( return datasets def delete(self, dataset_id: str) -> None: - dataset = db.session.query(MatrixDataSet).get(dataset_id) - db.session.delete(dataset) - db.session.commit() + dataset = self.session.query(MatrixDataSet).get(dataset_id) + self.session.delete(dataset) + self.session.commit() class MatrixRepository: @@ -73,28 +83,38 @@ class MatrixRepository: Database connector to manage Matrix entity. 
""" + def __init__(self, session: t.Optional[Session] = None) -> None: + self._session = session + + @property + def session(self) -> Session: + """Get the SqlAlchemy session or create a new one on the fly if not available in the current thread.""" + if self._session is None: + return db.session + return self._session + def save(self, matrix: Matrix) -> Matrix: - if db.session.query(exists().where(Matrix.id == matrix.id)).scalar(): - db.session.merge(matrix) + if self.session.query(exists().where(Matrix.id == matrix.id)).scalar(): + self.session.merge(matrix) else: - db.session.add(matrix) - db.session.commit() + self.session.add(matrix) + self.session.commit() logger.debug(f"Matrix {matrix.id} saved") return matrix def get(self, matrix_hash: str) -> t.Optional[Matrix]: - matrix: Matrix = db.session.query(Matrix).get(matrix_hash) + matrix: Matrix = self.session.query(Matrix).get(matrix_hash) return matrix def exists(self, matrix_hash: str) -> bool: - res: bool = db.session.query(exists().where(Matrix.id == matrix_hash)).scalar() + res: bool = self.session.query(exists().where(Matrix.id == matrix_hash)).scalar() return res def delete(self, matrix_hash: str) -> None: - if g := db.session.query(Matrix).get(matrix_hash): - db.session.delete(g) - db.session.commit() + if g := self.session.query(Matrix).get(matrix_hash): + self.session.delete(g) + self.session.commit() else: logger.warning(f"Trying to delete matrix {matrix_hash}, but was not found in database!") logger.debug(f"Matrix {matrix_hash} deleted") diff --git a/antarest/matrixstore/service.py b/antarest/matrixstore/service.py index 4869ed11fa..c0a9d91788 100644 --- a/antarest/matrixstore/service.py +++ b/antarest/matrixstore/service.py @@ -54,6 +54,9 @@ class ISimpleMatrixService(ABC): + def __init__(self, matrix_content_repository: MatrixContentRepository) -> None: + self.matrix_content_repository = matrix_content_repository + @abstractmethod def create(self, data: Union[List[List[MatrixData]], npt.NDArray[np.float64]]) -> str: raise NotImplementedError() @@ -72,15 +75,14 @@ def delete(self, matrix_id: str) -> None: class SimpleMatrixService(ISimpleMatrixService): - def __init__(self, bucket_dir: Path): - self.bucket_dir = bucket_dir - self.content_repo = MatrixContentRepository(bucket_dir) + def __init__(self, matrix_content_repository: MatrixContentRepository): + super().__init__(matrix_content_repository=matrix_content_repository) def create(self, data: Union[List[List[MatrixData]], npt.NDArray[np.float64]]) -> str: - return self.content_repo.save(data) + return self.matrix_content_repository.save(data) def get(self, matrix_id: str) -> MatrixDTO: - data = self.content_repo.get(matrix_id) + data = self.matrix_content_repository.get(matrix_id) return MatrixDTO.construct( id=matrix_id, width=len(data.columns), @@ -91,10 +93,10 @@ def get(self, matrix_id: str) -> MatrixDTO: ) def exists(self, matrix_id: str) -> bool: - return self.content_repo.exists(matrix_id) + return self.matrix_content_repository.exists(matrix_id) def delete(self, matrix_id: str) -> None: - self.content_repo.delete(matrix_id) + self.matrix_content_repository.delete(matrix_id) class MatrixService(ISimpleMatrixService): @@ -108,9 +110,9 @@ def __init__( config: Config, user_service: LoginService, ): + super().__init__(matrix_content_repository=matrix_content_repository) self.repo = repo self.repo_dataset = repo_dataset - self.matrix_content_repository = matrix_content_repository self.user_service = user_service self.file_transfer_manager = file_transfer_manager 
self.task_service = task_service diff --git a/antarest/singleton_services.py b/antarest/singleton_services.py index 9b702a346b..f106099523 100644 --- a/antarest/singleton_services.py +++ b/antarest/singleton_services.py @@ -1,90 +1,76 @@ -import logging -import time from pathlib import Path -from typing import Dict, List +from typing import Dict, List, cast from antarest.core.config import Config from antarest.core.interfaces.service import IService from antarest.core.logging.utils import configure_logger +from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware from antarest.core.utils.utils import get_local_path from antarest.study.storage.auto_archive_service import AutoArchiveService from antarest.utils import ( + SESSION_ARGS, Module, create_archive_worker, create_core_services, create_matrix_gc, create_simulator_worker, create_watcher, - init_db, + init_db_engine, ) -logger = logging.getLogger(__name__) - -class SingletonServices: - def __init__(self, config_file: Path, services_list: List[Module]) -> None: - self.services_list = self._init(config_file, services_list) - - @staticmethod - def _init(config_file: Path, services_list: List[Module]) -> Dict[Module, IService]: - res = get_local_path() / "resources" - config = Config.from_yaml_file(res=res, file=config_file) - init_db(config_file, config, False, None) - configure_logger(config) - - ( - cache, - event_bus, - task_service, - ft_manager, - login_service, - matrix_service, - study_service, - ) = create_core_services(None, config) - - services: Dict[Module, IService] = {} - - if Module.WATCHER in services_list: - watcher = create_watcher(config=config, application=None, study_service=study_service) - services[Module.WATCHER] = watcher - - if Module.MATRIX_GC in services_list: - matrix_gc = create_matrix_gc( - config=config, - application=None, - study_service=study_service, - matrix_service=matrix_service, - ) - services[Module.MATRIX_GC] = matrix_gc - - if Module.ARCHIVE_WORKER in services_list: - worker = create_archive_worker(config, "test", event_bus=event_bus) - services[Module.ARCHIVE_WORKER] = worker - - if Module.SIMULATOR_WORKER in services_list: - worker = create_simulator_worker(config, matrix_service=matrix_service, event_bus=event_bus) - services[Module.SIMULATOR_WORKER] = worker - - if Module.AUTO_ARCHIVER in services_list: - auto_archive_service = AutoArchiveService(study_service, config) - services[Module.AUTO_ARCHIVER] = auto_archive_service - - return services - - def start(self) -> None: - for service in self.services_list: - self.services_list[service].start(threaded=True) - - self._loop() - - def _loop(self) -> None: - while True: - try: - pass - except Exception as e: - logger.error( - "Unexpected error happened while processing service manager loop", - exc_info=e, - ) - finally: - time.sleep(2) +def _init(config_file: Path, services_list: List[Module]) -> Dict[Module, IService]: + res = get_local_path() / "resources" + config = Config.from_yaml_file(res=res, file=config_file) + engine = init_db_engine( + config_file, + config, + False, + ) + DBSessionMiddleware(None, custom_engine=engine, session_args=cast(Dict[str, bool], SESSION_ARGS)) + configure_logger(config) + + ( + cache, + event_bus, + task_service, + ft_manager, + login_service, + matrix_service, + study_service, + ) = create_core_services(None, config) + + services: Dict[Module, IService] = {} + + if Module.WATCHER in services_list: + watcher = create_watcher(config=config, application=None, study_service=study_service) + 
services[Module.WATCHER] = watcher + + if Module.MATRIX_GC in services_list: + matrix_gc = create_matrix_gc( + config=config, + application=None, + study_service=study_service, + matrix_service=matrix_service, + ) + services[Module.MATRIX_GC] = matrix_gc + + if Module.ARCHIVE_WORKER in services_list: + worker = create_archive_worker(config, "test", event_bus=event_bus) + services[Module.ARCHIVE_WORKER] = worker + + if Module.SIMULATOR_WORKER in services_list: + worker = create_simulator_worker(config, matrix_service=matrix_service, event_bus=event_bus) + services[Module.SIMULATOR_WORKER] = worker + + if Module.AUTO_ARCHIVER in services_list: + auto_archive_service = AutoArchiveService(study_service, config) + services[Module.AUTO_ARCHIVER] = auto_archive_service + + return services + + +def start_all_services(config_file: Path, services_list: List[Module]) -> None: + services = _init(config_file, services_list) + for service in services: + services[service].start(threaded=True) diff --git a/antarest/study/business/binding_constraint_management.py b/antarest/study/business/binding_constraint_management.py index ca1f714750..7caeabd9ab 100644 --- a/antarest/study/business/binding_constraint_management.py +++ b/antarest/study/business/binding_constraint_management.py @@ -5,6 +5,7 @@ from antarest.core.exceptions import ( ConstraintAlreadyExistError, ConstraintIdNotFoundError, + DuplicateConstraintName, MissingDataError, NoBindingConstraintError, NoConstraintError, @@ -13,8 +14,13 @@ from antarest.study.business.utils import execute_or_add_commands from antarest.study.model import Study from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import BindingConstraintFrequency +from antarest.study.storage.rawstudy.model.filesystem.config.model import transform_name_to_id from antarest.study.storage.storage_service import StudyStorageService from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator +from antarest.study.storage.variantstudy.model.command.create_binding_constraint import ( + BindingConstraintProperties, + CreateBindingConstraint, +) from antarest.study.storage.variantstudy.model.command.update_binding_constraint import UpdateBindingConstraint @@ -40,6 +46,10 @@ class UpdateBindingConstProps(BaseModel): value: Any +class BindingConstraintPropertiesWithName(BindingConstraintProperties): + name: str + + class BindingConstraintDTO(BaseModel): id: str name: str @@ -153,6 +163,32 @@ def get_binding_constraint( binding_constraint.append(new_config) return binding_constraint + def create_binding_constraint( + self, + study: Study, + data: BindingConstraintPropertiesWithName, + ) -> None: + binding_constraints = self.get_binding_constraint(study, None) + existing_ids = [bd.id for bd in binding_constraints] # type: ignore + bd_id = transform_name_to_id(data.name) + if bd_id in existing_ids: + raise DuplicateConstraintName(f"A binding constraint with the same name already exists: {bd_id}.") + + file_study = self.storage_service.get_storage(study).get_raw(study) + command = CreateBindingConstraint( + name=bd_id, + enabled=data.enabled, + time_step=data.time_step, + operator=data.operator, + coeffs=data.coeffs, + values=data.values, + filter_year_by_year=data.filter_year_by_year, + filter_synthesis=data.filter_synthesis, + comments=data.comments or "", + command_context=self.storage_service.variant_study_service.command_factory.command_context, + ) + execute_or_add_commands(study, file_study, [command], self.storage_service) + def 
update_binding_constraint( self, study: Study, diff --git a/antarest/study/main.py b/antarest/study/main.py index e4a981afd2..83ad90dca3 100644 --- a/antarest/study/main.py +++ b/antarest/study/main.py @@ -80,7 +80,9 @@ def build_study_service( cache=cache, ) - generator_matrix_constants = generator_matrix_constants or GeneratorMatrixConstants(matrix_service=matrix_service) + if not generator_matrix_constants: + generator_matrix_constants = GeneratorMatrixConstants(matrix_service=matrix_service) + generator_matrix_constants.init_constant_matrices() command_factory = CommandFactory( generator_matrix_constants=generator_matrix_constants, matrix_service=matrix_service, diff --git a/antarest/study/repository.py b/antarest/study/repository.py index 94a0220e37..1a830c7428 100644 --- a/antarest/study/repository.py +++ b/antarest/study/repository.py @@ -53,18 +53,19 @@ def save( if update_modification_date: metadata.updated_at = datetime.datetime.utcnow() - metadata.groups = [db.session.merge(g) for g in metadata.groups] + session = self.session + metadata.groups = [session.merge(g) for g in metadata.groups] if metadata.owner: - metadata.owner = db.session.merge(metadata.owner) - db.session.add(metadata) - db.session.commit() + metadata.owner = session.merge(metadata.owner) + session.add(metadata) + session.commit() if update_in_listing: self._update_study_from_cache_listing(metadata) return metadata def refresh(self, metadata: Study) -> None: - db.session.refresh(metadata) + self.session.refresh(metadata) def get(self, id: str) -> t.Optional[Study]: """Get the study by ID or return `None` if not found in database.""" @@ -72,7 +73,7 @@ def get(self, id: str) -> t.Optional[Study]: # to check the permissions of the current user efficiently. study: Study = ( # fmt: off - db.session.query(Study) + self.session.query(Study) .options(joinedload(Study.owner)) .options(joinedload(Study.groups)) .get(id) @@ -85,7 +86,7 @@ def one(self, id: str) -> Study: # When we fetch a study, we also need to fetch the associated owner and groups # to check the permissions of the current user efficiently. study: Study = ( - db.session.query(Study) + self.session.query(Study) .options(joinedload(Study.owner)) .options(joinedload(Study.groups)) .filter_by(id=id) @@ -97,7 +98,7 @@ def get_list(self, study_id: t.List[str]) -> t.List[Study]: # When we fetch a study, we also need to fetch the associated owner and groups # to check the permissions of the current user efficiently. 
studies: t.List[Study] = ( - db.session.query(Study) + self.session.query(Study) .options(joinedload(Study.owner)) .options(joinedload(Study.groups)) .where(Study.id.in_(study_id)) @@ -106,16 +107,16 @@ def get_list(self, study_id: t.List[str]) -> t.List[Study]: return studies def get_additional_data(self, study_id: str) -> t.Optional[StudyAdditionalData]: - study: StudyAdditionalData = db.session.query(StudyAdditionalData).get(study_id) + study: StudyAdditionalData = self.session.query(StudyAdditionalData).get(study_id) return study def get_all(self) -> t.List[Study]: entity = with_polymorphic(Study, "*") - studies: t.List[Study] = db.session.query(entity).filter(RawStudy.missing.is_(None)).all() + studies: t.List[Study] = self.session.query(entity).filter(RawStudy.missing.is_(None)).all() return studies def get_all_raw(self, show_missing: bool = True) -> t.List[RawStudy]: - query = db.session.query(RawStudy) + query = self.session.query(RawStudy) if not show_missing: query = query.filter(RawStudy.missing.is_(None)) studies: t.List[RawStudy] = query.all() @@ -123,9 +124,10 @@ def get_all_raw(self, show_missing: bool = True) -> t.List[RawStudy]: def delete(self, id: str) -> None: logger.debug(f"Deleting study {id}") - u: Study = db.session.query(Study).get(id) - db.session.delete(u) - db.session.commit() + session = self.session + u: Study = session.query(Study).get(id) + session.delete(u) + session.commit() self._remove_study_from_cache_listing(id) def _remove_study_from_cache_listing(self, study_id: str) -> None: diff --git a/antarest/study/storage/rawstudy/model/filesystem/root/input/bindingconstraints/bindingconstraints_ini.py b/antarest/study/storage/rawstudy/model/filesystem/root/input/bindingconstraints/bindingconstraints_ini.py index 5e4059252a..51e426fda2 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/root/input/bindingconstraints/bindingconstraints_ini.py +++ b/antarest/study/storage/rawstudy/model/filesystem/root/input/bindingconstraints/bindingconstraints_ini.py @@ -3,6 +3,23 @@ from antarest.study.storage.rawstudy.model.filesystem.ini_file_node import IniFileNode +# noinspection SpellCheckingInspection class BindingConstraintsIni(IniFileNode): + """ + Handle the binding constraints configuration file: `/input/bindingconstraints/bindingconstraints.ini`. + + This file contains a list of sections numbered from 1 to n. + + Each section contains the following fields: + + - `name`: the name of the binding constraint. + - `id`: the id of the binding constraint (normalized name in lower case). + - `enabled`: whether the binding constraint is enabled or not. + - `type`: the frequency of the binding constraint ("hourly", "daily" or "weekly"). + - `operator`: the operator of the binding constraint ("both", "equal", "greater", "less"). + - `comment`: a comment. + - a list of coefficients (one per line) of the form `{area1}%{area2} = {coeff}`. 
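+
+    Example of a section with illustrative values::
+
+        [1]
+        name = My Constraint
+        id = my constraint
+        enabled = true
+        type = hourly
+        operator = less
+        comment = my comment
+        area1%area2 = 2.5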
+ """ + def __init__(self, context: ContextServer, config: FileStudyTreeConfig): - IniFileNode.__init__(self, context, config, types={}) + super().__init__(context, config, types={}) diff --git a/antarest/study/storage/rawstudy/model/filesystem/root/input/bindingconstraints/bindingcontraints.py b/antarest/study/storage/rawstudy/model/filesystem/root/input/bindingconstraints/bindingcontraints.py index e86dedfe18..69fe669183 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/root/input/bindingconstraints/bindingcontraints.py +++ b/antarest/study/storage/rawstudy/model/filesystem/root/input/bindingconstraints/bindingcontraints.py @@ -7,18 +7,22 @@ BindingConstraintsIni, ) from antarest.study.storage.variantstudy.business.matrix_constants.binding_constraint.series import ( - default_binding_constraint_daily, - default_binding_constraint_hourly, - default_binding_constraint_weekly, + default_bc_hourly, + default_bc_weekly_daily, ) class BindingConstraints(FolderNode): + """ + Handle the binding constraints folder which contains the binding constraints + configuration and matrices. + """ + def build(self) -> TREE: default_matrices = { - BindingConstraintFrequency.HOURLY: default_binding_constraint_hourly, - BindingConstraintFrequency.DAILY: default_binding_constraint_daily, - BindingConstraintFrequency.WEEKLY: default_binding_constraint_weekly, + BindingConstraintFrequency.HOURLY: default_bc_hourly, + BindingConstraintFrequency.DAILY: default_bc_weekly_daily, + BindingConstraintFrequency.WEEKLY: default_bc_weekly_daily, } children: TREE = { binding.id: InputSeriesMatrix( @@ -31,6 +35,7 @@ def build(self) -> TREE: for binding in self.config.bindings } + # noinspection SpellCheckingInspection children["bindingconstraints"] = BindingConstraintsIni( self.context, self.config.next_file("bindingconstraints.ini") ) diff --git a/antarest/study/storage/rawstudy/model/filesystem/root/input/input.py b/antarest/study/storage/rawstudy/model/filesystem/root/input/input.py index 995dbf92f0..88b58c5369 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/root/input/input.py +++ b/antarest/study/storage/rawstudy/model/filesystem/root/input/input.py @@ -18,7 +18,12 @@ class Input(FolderNode): + """ + Handle the input folder which contains all the input data of the study. + """ + def build(self) -> TREE: + # noinspection SpellCheckingInspection children: TREE = { "areas": InputAreas(self.context, self.config.next_file("areas")), "bindingconstraints": BindingConstraints(self.context, self.config.next_file("bindingconstraints")), diff --git a/antarest/study/storage/study_upgrader/__init__.py b/antarest/study/storage/study_upgrader/__init__.py index 6b96dc711b..1993b4a0c3 100644 --- a/antarest/study/storage/study_upgrader/__init__.py +++ b/antarest/study/storage/study_upgrader/__init__.py @@ -201,9 +201,10 @@ def _copies_only_necessary_files(files_to_upgrade: List[Path], study_path: Path, The list of files and folders that were really copied. It's the same as files_to_upgrade but without any children that has parents already in the list. 
""" - files_to_upgrade.append(Path("study.antares")) + files_to_copy = _filters_out_children_files(files_to_upgrade) + files_to_copy.append(Path("study.antares")) files_to_retrieve = [] - for path in files_to_upgrade: + for path in files_to_copy: entire_path = study_path / path if entire_path.is_dir(): if not (tmp_path / path).exists(): @@ -220,6 +221,22 @@ def _copies_only_necessary_files(files_to_upgrade: List[Path], study_path: Path, return files_to_retrieve +def _filters_out_children_files(files_to_upgrade: List[Path]) -> List[Path]: + """ + Filters out children paths of "input" if "input" is already in the list. + Args: + files_to_upgrade: List[Path]: List of the files and folders concerned by the upgrade. + Returns: + The list of files filtered + """ + is_input_in_files_to_upgrade = Path("input") in files_to_upgrade + if is_input_in_files_to_upgrade: + files_to_keep = [Path("input")] + files_to_keep.extend(path for path in files_to_upgrade if "input" not in path.parts) + return files_to_keep + return files_to_upgrade + + def _replace_safely_original_files(files_to_replace: List[Path], study_path: Path, tmp_path: Path) -> None: """ Replace files/folders of the study that should be upgraded by their copy already upgraded in the tmp directory. diff --git a/antarest/study/storage/variantstudy/business/command_extractor.py b/antarest/study/storage/variantstudy/business/command_extractor.py index e0fd1d1e3c..9aa5a9b397 100644 --- a/antarest/study/storage/variantstudy/business/command_extractor.py +++ b/antarest/study/storage/variantstudy/business/command_extractor.py @@ -48,6 +48,7 @@ class CommandExtractor(ICommandExtractor): def __init__(self, matrix_service: ISimpleMatrixService, patch_service: PatchService): self.matrix_service = matrix_service self.generator_matrix_constants = GeneratorMatrixConstants(self.matrix_service) + self.generator_matrix_constants.init_constant_matrices() self.patch_service = patch_service self.command_context = CommandContext( generator_matrix_constants=self.generator_matrix_constants, diff --git a/antarest/study/storage/variantstudy/business/matrix_constants/binding_constraint/series.py b/antarest/study/storage/variantstudy/business/matrix_constants/binding_constraint/series.py index e7b20a1137..f093c8e4a3 100644 --- a/antarest/study/storage/variantstudy/business/matrix_constants/binding_constraint/series.py +++ b/antarest/study/storage/variantstudy/business/matrix_constants/binding_constraint/series.py @@ -1,10 +1,13 @@ import numpy as np -default_binding_constraint_hourly = np.zeros((8760, 3), dtype=np.float64) -default_binding_constraint_hourly.flags.writeable = False +# Matrice shapes for binding constraints are different from usual shapes, +# because we need to take leap years into account, which contains 366 days and 8784 hours. +# Also, we use the same matrices for "weekly" and "daily" frequencies, +# because the solver calculates the weekly matrix from the daily matrix. 
+# See https://github.com/AntaresSimulatorTeam/AntaREST/issues/1843 -default_binding_constraint_daily = np.zeros((365, 3), dtype=np.float64) -default_binding_constraint_daily.flags.writeable = False +default_bc_hourly = np.zeros((8784, 3), dtype=np.float64) +default_bc_hourly.flags.writeable = False -default_binding_constraint_weekly = np.zeros((52, 3), dtype=np.float64) -default_binding_constraint_weekly.flags.writeable = False +default_bc_weekly_daily = np.zeros((366, 3), dtype=np.float64) +default_bc_weekly_daily.flags.writeable = False diff --git a/antarest/study/storage/variantstudy/business/matrix_constants_generator.py b/antarest/study/storage/variantstudy/business/matrix_constants_generator.py index 8cb973785e..6a4dc233d4 100644 --- a/antarest/study/storage/variantstudy/business/matrix_constants_generator.py +++ b/antarest/study/storage/variantstudy/business/matrix_constants_generator.py @@ -36,8 +36,10 @@ # Binding constraint aliases BINDING_CONSTRAINT_HOURLY = "empty_2nd_member_hourly" -BINDING_CONSTRAINT_DAILY = "empty_2nd_member_daily" -BINDING_CONSTRAINT_WEEKLY = "empty_2nd_member_daily" +"""2D-matrix of shape (8784, 3), filled-in with zeros for hourly binding constraints.""" + +BINDING_CONSTRAINT_WEEKLY_DAILY = "empty_2nd_member_weekly_daily" +"""2D-matrix of shape (366, 3), filled-in with zeros for weekly/daily binding constraints.""" # Short-term storage aliases ST_STORAGE_PMAX_INJECTION = ONES_SCENARIO_MATRIX @@ -47,6 +49,7 @@ ST_STORAGE_INFLOWS = EMPTY_SCENARIO_MATRIX MATRIX_PROTOCOL_PREFIX = "matrix://" +_LOCK_FILE_NAME = "matrix_constant_init.lock" # noinspection SpellCheckingInspection @@ -54,50 +57,53 @@ class GeneratorMatrixConstants: def __init__(self, matrix_service: ISimpleMatrixService) -> None: self.hashes: Dict[str, str] = {} self.matrix_service: ISimpleMatrixService = matrix_service - with FileLock(str(Path(tempfile.gettempdir()) / "matrix_constant_init.lock")): - self._init() - - def _init(self) -> None: - self.hashes[HYDRO_COMMON_CAPACITY_MAX_POWER_V7] = self.matrix_service.create( - matrix_constants.hydro.v7.max_power - ) - self.hashes[HYDRO_COMMON_CAPACITY_RESERVOIR_V7] = self.matrix_service.create( - matrix_constants.hydro.v7.reservoir - ) - self.hashes[HYDRO_COMMON_CAPACITY_RESERVOIR_V6] = self.matrix_service.create( - matrix_constants.hydro.v6.reservoir - ) - self.hashes[HYDRO_COMMON_CAPACITY_INFLOW_PATTERN] = self.matrix_service.create( - matrix_constants.hydro.v7.inflow_pattern - ) - self.hashes[HYDRO_COMMON_CAPACITY_CREDIT_MODULATION] = self.matrix_service.create( - matrix_constants.hydro.v7.credit_modulations - ) - self.hashes[PREPRO_CONVERSION] = self.matrix_service.create(matrix_constants.prepro.conversion) - self.hashes[PREPRO_DATA] = self.matrix_service.create(matrix_constants.prepro.data) - self.hashes[THERMAL_PREPRO_DATA] = self.matrix_service.create(matrix_constants.thermals.prepro.data) - - self.hashes[THERMAL_PREPRO_MODULATION] = self.matrix_service.create(matrix_constants.thermals.prepro.modulation) - self.hashes[LINK_V7] = self.matrix_service.create(matrix_constants.link.v7.link) - self.hashes[LINK_V8] = self.matrix_service.create(matrix_constants.link.v8.link) - self.hashes[LINK_DIRECT] = self.matrix_service.create(matrix_constants.link.v8.direct) - self.hashes[LINK_INDIRECT] = self.matrix_service.create(matrix_constants.link.v8.indirect) - - self.hashes[NULL_MATRIX_NAME] = self.matrix_service.create(NULL_MATRIX) - self.hashes[EMPTY_SCENARIO_MATRIX] = self.matrix_service.create(NULL_SCENARIO_MATRIX) - self.hashes[RESERVES_TS] = 
self.matrix_service.create(FIXED_4_COLUMNS) - self.hashes[MISCGEN_TS] = self.matrix_service.create(FIXED_8_COLUMNS) - - # Binding constraint matrices - series = matrix_constants.binding_constraint.series - self.hashes[BINDING_CONSTRAINT_HOURLY] = self.matrix_service.create(series.default_binding_constraint_hourly) - self.hashes[BINDING_CONSTRAINT_DAILY] = self.matrix_service.create(series.default_binding_constraint_daily) - self.hashes[BINDING_CONSTRAINT_WEEKLY] = self.matrix_service.create(series.default_binding_constraint_weekly) - - # Some short-term storage matrices use np.ones((8760, 1)) - self.hashes[ONES_SCENARIO_MATRIX] = self.matrix_service.create( - matrix_constants.st_storage.series.pmax_injection - ) + self._lock_dir = tempfile.gettempdir() + + def init_constant_matrices( + self, + ) -> None: + with FileLock(str(Path(self._lock_dir) / _LOCK_FILE_NAME)): + self.hashes[HYDRO_COMMON_CAPACITY_MAX_POWER_V7] = self.matrix_service.create( + matrix_constants.hydro.v7.max_power + ) + self.hashes[HYDRO_COMMON_CAPACITY_RESERVOIR_V7] = self.matrix_service.create( + matrix_constants.hydro.v7.reservoir + ) + self.hashes[HYDRO_COMMON_CAPACITY_RESERVOIR_V6] = self.matrix_service.create( + matrix_constants.hydro.v6.reservoir + ) + self.hashes[HYDRO_COMMON_CAPACITY_INFLOW_PATTERN] = self.matrix_service.create( + matrix_constants.hydro.v7.inflow_pattern + ) + self.hashes[HYDRO_COMMON_CAPACITY_CREDIT_MODULATION] = self.matrix_service.create( + matrix_constants.hydro.v7.credit_modulations + ) + self.hashes[PREPRO_CONVERSION] = self.matrix_service.create(matrix_constants.prepro.conversion) + self.hashes[PREPRO_DATA] = self.matrix_service.create(matrix_constants.prepro.data) + self.hashes[THERMAL_PREPRO_DATA] = self.matrix_service.create(matrix_constants.thermals.prepro.data) + + self.hashes[THERMAL_PREPRO_MODULATION] = self.matrix_service.create( + matrix_constants.thermals.prepro.modulation + ) + self.hashes[LINK_V7] = self.matrix_service.create(matrix_constants.link.v7.link) + self.hashes[LINK_V8] = self.matrix_service.create(matrix_constants.link.v8.link) + self.hashes[LINK_DIRECT] = self.matrix_service.create(matrix_constants.link.v8.direct) + self.hashes[LINK_INDIRECT] = self.matrix_service.create(matrix_constants.link.v8.indirect) + + self.hashes[NULL_MATRIX_NAME] = self.matrix_service.create(NULL_MATRIX) + self.hashes[EMPTY_SCENARIO_MATRIX] = self.matrix_service.create(NULL_SCENARIO_MATRIX) + self.hashes[RESERVES_TS] = self.matrix_service.create(FIXED_4_COLUMNS) + self.hashes[MISCGEN_TS] = self.matrix_service.create(FIXED_8_COLUMNS) + + # Binding constraint matrices + series = matrix_constants.binding_constraint.series + self.hashes[BINDING_CONSTRAINT_HOURLY] = self.matrix_service.create(series.default_bc_hourly) + self.hashes[BINDING_CONSTRAINT_WEEKLY_DAILY] = self.matrix_service.create(series.default_bc_weekly_daily) + + # Some short-term storage matrices use np.ones((8760, 1)) + self.hashes[ONES_SCENARIO_MATRIX] = self.matrix_service.create( + matrix_constants.st_storage.series.pmax_injection + ) def get_hydro_max_power(self, version: int) -> str: if version > 650: @@ -152,16 +158,16 @@ def get_default_miscgen(self) -> str: return MATRIX_PROTOCOL_PREFIX + self.hashes[MISCGEN_TS] def get_binding_constraint_hourly(self) -> str: - """2D-matrix of shape (8760, 3), filled-in with zeros.""" + """2D-matrix of shape (8784, 3), filled-in with zeros.""" return MATRIX_PROTOCOL_PREFIX + self.hashes[BINDING_CONSTRAINT_HOURLY] def get_binding_constraint_daily(self) -> str: - """2D-matrix of shape (365, 
3), filled-in with zeros.""" - return MATRIX_PROTOCOL_PREFIX + self.hashes[BINDING_CONSTRAINT_DAILY] + """2D-matrix of shape (366, 3), filled-in with zeros.""" + return MATRIX_PROTOCOL_PREFIX + self.hashes[BINDING_CONSTRAINT_WEEKLY_DAILY] def get_binding_constraint_weekly(self) -> str: - """2D-matrix of shape (52, 3), filled-in with zeros.""" - return MATRIX_PROTOCOL_PREFIX + self.hashes[BINDING_CONSTRAINT_WEEKLY] + """2D-matrix of shape (366, 3), filled-in with zeros, same as daily.""" + return MATRIX_PROTOCOL_PREFIX + self.hashes[BINDING_CONSTRAINT_WEEKLY_DAILY] def get_st_storage_pmax_injection(self) -> str: """2D-matrix of shape (8760, 1), filled-in with ones.""" diff --git a/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py b/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py index 178c918a0c..901294a73d 100644 --- a/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py +++ b/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union, cast import numpy as np -from pydantic import Field, validator +from pydantic import BaseModel, Field, validator from antarest.matrixstore.model import MatrixData from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import BindingConstraintFrequency @@ -40,10 +40,15 @@ def check_matrix_values(time_step: BindingConstraintFrequency, values: MatrixTyp If the matrix shape does not match the expected shape for the given time step. If the matrix values contain NaN (Not-a-Number). """ + # Matrice shapes for binding constraints are different from usual shapes, + # because we need to take leap years into account, which contains 366 days and 8784 hours. + # Also, we use the same matrices for "weekly" and "daily" frequencies, + # because the solver calculates the weekly matrix from the daily matrix. + # See https://github.com/AntaresSimulatorTeam/AntaREST/issues/1843 shapes = { - BindingConstraintFrequency.HOURLY: (8760, 3), - BindingConstraintFrequency.DAILY: (365, 3), - BindingConstraintFrequency.WEEKLY: (52, 3), + BindingConstraintFrequency.HOURLY: (8784, 3), + BindingConstraintFrequency.DAILY: (366, 3), + BindingConstraintFrequency.WEEKLY: (366, 3), } # Check the matrix values and create the corresponding matrix link array = np.array(values, dtype=np.float64) @@ -53,12 +58,9 @@ def check_matrix_values(time_step: BindingConstraintFrequency, values: MatrixTyp raise ValueError("Matrix values cannot contain NaN") -class AbstractBindingConstraintCommand(ICommand, metaclass=ABCMeta): - """ - Abstract class for binding constraint commands. - """ - +class BindingConstraintProperties(BaseModel): # todo: add the `name` attribute because it should also be updated + # It would lead to an API change as update_binding_constraint currently does not have it enabled: bool = True time_step: BindingConstraintFrequency operator: BindingConstraintOperator @@ -68,6 +70,12 @@ class AbstractBindingConstraintCommand(ICommand, metaclass=ABCMeta): filter_synthesis: Optional[str] = None comments: Optional[str] = None + +class AbstractBindingConstraintCommand(BindingConstraintProperties, ICommand, metaclass=ABCMeta): + """ + Abstract class for binding constraint commands. 
+ """ + def to_dto(self) -> CommandDTO: args = { "enabled": self.enabled, diff --git a/antarest/study/storage/variantstudy/model/dbmodel.py b/antarest/study/storage/variantstudy/model/dbmodel.py index 3e547bce13..1a88a76853 100644 --- a/antarest/study/storage/variantstudy/model/dbmodel.py +++ b/antarest/study/storage/variantstudy/model/dbmodel.py @@ -99,7 +99,7 @@ def snapshot_dir(self) -> Path: """Get the path of the snapshot directory.""" return Path(self.path) / "snapshot" - def is_snapshot_recent(self) -> bool: + def is_snapshot_up_to_date(self) -> bool: """Check if the snapshot exists and is up-to-date.""" return ( (self.snapshot is not None) diff --git a/antarest/study/storage/variantstudy/snapshot_generator.py b/antarest/study/storage/variantstudy/snapshot_generator.py index f36632ea87..50972ae99a 100644 --- a/antarest/study/storage/variantstudy/snapshot_generator.py +++ b/antarest/study/storage/variantstudy/snapshot_generator.py @@ -4,7 +4,6 @@ import datetime import logging import shutil -import tempfile import typing as t from pathlib import Path @@ -51,8 +50,6 @@ def __init__( self.study_factory = study_factory self.patch_service = patch_service self.repository = repository - # Temporary directory used to generate the snapshot - self._tmp_dir: Path = Path() def generate_snapshot( self, @@ -75,32 +72,29 @@ def generate_snapshot( root_study, descendants = self._retrieve_descendants(variant_study_id) assert_permission_on_studies(jwt_user, [root_study, *descendants], StudyPermissionType.READ, raising=True) - ref_study, cmd_blocks = search_ref_study(root_study, descendants, from_scratch=from_scratch) + search_result = search_ref_study(root_study, descendants, from_scratch=from_scratch) - # We are going to generate the snapshot in a temporary directory which will be renamed - # at the end of the process. This prevents incomplete snapshots in case of error. + ref_study = search_result.ref_study + cmd_blocks = search_result.cmd_blocks - # Get snapshot directory and prepare a temporary directory next to it. + # Get snapshot directory variant_study = descendants[-1] snapshot_dir = variant_study.snapshot_dir - snapshot_dir.parent.mkdir(parents=True, exist_ok=True) - self._tmp_dir = Path(tempfile.mkdtemp(dir=snapshot_dir.parent, prefix=f"~{snapshot_dir.name}", suffix=".tmp")) + try: - logger.info(f"Exporting the reference study '{ref_study.id}' to '{self._tmp_dir.name}'...") - self._export_ref_study(ref_study) + if search_result.force_regenerate or not snapshot_dir.exists(): + logger.info(f"Exporting the reference study '{ref_study.id}' to '{snapshot_dir.name}'...") + shutil.rmtree(snapshot_dir, ignore_errors=True) + self._export_ref_study(snapshot_dir, ref_study) logger.info(f"Applying commands to the reference study '{ref_study.id}'...") - results = self._apply_commands(variant_study, ref_study, cmd_blocks) - - if (snapshot_dir / "user").exists(): - logger.info("Keeping previous unmanaged user config...") - shutil.copytree(snapshot_dir / "user", self._tmp_dir / "user", dirs_exist_ok=True) + results = self._apply_commands(snapshot_dir, variant_study, cmd_blocks) # The snapshot is generated, we also need to de-normalize the matrices. 
file_study = self.study_factory.create_from_fs( - self._tmp_dir, + snapshot_dir, study_id=variant_study_id, - output_path=self._tmp_dir / OUTPUT_RELATIVE_PATH, + output_path=snapshot_dir / OUTPUT_RELATIVE_PATH, use_cache=False, # Avoid saving the study config in the cache ) if denormalize: @@ -112,26 +106,20 @@ def generate_snapshot( variant_study.snapshot = VariantStudySnapshot( id=variant_study_id, created_at=datetime.datetime.utcnow(), - last_executed_command=cmd_blocks[-1].id if cmd_blocks else None, + last_executed_command=variant_study.commands[-1].id if variant_study.commands else None, ) logger.info(f"Reading additional data from files for study {file_study.config.study_id}") variant_study.additional_data = self._read_additional_data(file_study) self.repository.save(variant_study) - # Store the study config in the cache (with adjusted paths). - file_study.config.study_path = file_study.config.path = snapshot_dir - file_study.config.output_path = snapshot_dir / OUTPUT_RELATIVE_PATH self._update_cache(file_study) except Exception: - shutil.rmtree(self._tmp_dir, ignore_errors=True) + shutil.rmtree(snapshot_dir, ignore_errors=True) raise else: - # Rename the temporary directory to the final snapshot directory - shutil.rmtree(snapshot_dir, ignore_errors=True) - self._tmp_dir.rename(snapshot_dir) try: notifier(results.json()) except Exception as exc: @@ -149,12 +137,12 @@ def _retrieve_descendants(self, variant_study_id: str) -> t.Tuple[RawStudy, t.Se root_study = self.repository.one(descendant_ids[0]) return root_study, descendants - def _export_ref_study(self, ref_study: t.Union[RawStudy, VariantStudy]) -> None: - self._tmp_dir.rmdir() # remove the temporary directory for shutil.copytree + def _export_ref_study(self, snapshot_dir: Path, ref_study: t.Union[RawStudy, VariantStudy]) -> None: if isinstance(ref_study, VariantStudy): + snapshot_dir.parent.mkdir(parents=True, exist_ok=True) export_study_flat( ref_study.snapshot_dir, - self._tmp_dir, + snapshot_dir, self.study_factory, denormalize=False, # de-normalization is done at the end outputs=False, # do NOT export outputs @@ -162,7 +150,7 @@ def _export_ref_study(self, ref_study: t.Union[RawStudy, VariantStudy]) -> None: elif isinstance(ref_study, RawStudy): self.raw_study_service.export_study_flat( ref_study, - self._tmp_dir, + snapshot_dir, denormalize=False, # de-normalization is done at the end outputs=False, # do NOT export outputs ) @@ -171,15 +159,15 @@ def _export_ref_study(self, ref_study: t.Union[RawStudy, VariantStudy]) -> None: def _apply_commands( self, + snapshot_dir: Path, variant_study: VariantStudy, - ref_study: t.Union[RawStudy, VariantStudy], cmd_blocks: t.Sequence[CommandBlock], ) -> GenerationResultInfoDTO: commands = [self.command_factory.to_command(cb.to_dto()) for cb in cmd_blocks] generator = VariantCommandGenerator(self.study_factory) results = generator.generate( commands, - self._tmp_dir, + snapshot_dir, variant_study, delete_on_failure=False, # Not needed, because we are using a temporary directory notifier=None, @@ -208,12 +196,22 @@ def _update_cache(self, file_study: FileStudy) -> None: ) +class RefStudySearchResult(t.NamedTuple): + """ + Result of the search for the reference study. 
+ """ + + ref_study: t.Union[RawStudy, VariantStudy] + cmd_blocks: t.Sequence[CommandBlock] + force_regenerate: bool = False + + def search_ref_study( root_study: t.Union[RawStudy, VariantStudy], descendants: t.Sequence[VariantStudy], *, from_scratch: bool = False, -) -> t.Tuple[t.Union[RawStudy, VariantStudy], t.Sequence[CommandBlock]]: +) -> RefStudySearchResult: """ Search for the reference study and the commands to use for snapshot generation. @@ -225,6 +223,9 @@ def search_ref_study( Returns: The reference study and the commands to use for snapshot generation. """ + if not descendants: + # Edge case where the list of studies is empty. + return RefStudySearchResult(ref_study=root_study, cmd_blocks=[], force_regenerate=True) # The reference study is the root study or a variant study with a valid snapshot ref_study: t.Union[RawStudy, VariantStudy] @@ -236,42 +237,68 @@ def search_ref_study( # In the case of a from scratch generation, the root study will be used as the reference study. # We need to retrieve all commands from the descendants of variants in order to apply them # on the reference study. - ref_study = root_study - cmd_blocks = [c for v in descendants for c in v.commands] + return RefStudySearchResult( + ref_study=root_study, + cmd_blocks=[c for v in descendants for c in v.commands], + force_regenerate=True, + ) - else: - # To generate the last variant of a descendant of variants, we must search for - # the most recent snapshot in order to use it as a reference study. - # If no snapshot is found, we use the root study as a reference study. - - snapshot_vars = [v for v in descendants if v.is_snapshot_recent()] - - if snapshot_vars: - # We use the most recent snapshot as a reference study - ref_study = max(snapshot_vars, key=lambda v: v.snapshot.created_at) - - # This variant's snapshot corresponds to the commands actually generated - # at the time of the snapshot. However, we need to retrieve the remaining commands, - # because the snapshot generation may be incomplete. - last_exec_cmd = ref_study.snapshot.last_executed_command # ID of the command - if not last_exec_cmd: - # It is unlikely that this case will occur, but it means that - # the snapshot is not correctly generated (corrupted database). - # It better to use all commands to force snapshot re-generation. - cmd_blocks = ref_study.commands[:] - else: - command_ids = [c.id for c in ref_study.commands] - last_exec_index = command_ids.index(last_exec_cmd) - cmd_blocks = ref_study.commands[last_exec_index + 1 :] - - # We need to add all commands from the descendants of variants - # starting at the first descendant of reference study. - index = descendants.index(ref_study) - cmd_blocks.extend([c for v in descendants[index + 1 :] for c in v.commands]) + # To reuse the snapshot of the current variant, the last executed command + # must be one of the commands of the current variant. 
+ curr_variant = descendants[-1] + if curr_variant.snapshot: + last_exec_cmd = curr_variant.snapshot.last_executed_command + command_ids = [c.id for c in curr_variant.commands] + # If the variant has no commands, we can reuse the snapshot if it is recent + if not last_exec_cmd and not command_ids and curr_variant.is_snapshot_up_to_date(): + return RefStudySearchResult( + ref_study=curr_variant, + cmd_blocks=[], + force_regenerate=False, + ) + elif last_exec_cmd and last_exec_cmd in command_ids: + # We can reuse the snapshot of the current variant + last_exec_index = command_ids.index(last_exec_cmd) + return RefStudySearchResult( + ref_study=curr_variant, + cmd_blocks=curr_variant.commands[last_exec_index + 1 :], + force_regenerate=False, + ) + # We cannot reuse the snapshot of the current variant + # To generate the last variant of a descendant of variants, we must search for + # the most recent snapshot in order to use it as a reference study. + # If no snapshot is found, we use the root study as a reference study. + + snapshot_vars = [v for v in descendants if v.is_snapshot_up_to_date()] + + if snapshot_vars: + # We use the most recent snapshot as a reference study + ref_study = max(snapshot_vars, key=lambda v: v.snapshot.created_at) + + # This variant's snapshot corresponds to the commands actually generated + # at the time of the snapshot. However, we need to retrieve the remaining commands, + # because the snapshot generation may be incomplete. + last_exec_cmd = ref_study.snapshot.last_executed_command # ID of the command + command_ids = [c.id for c in ref_study.commands] + if not last_exec_cmd or last_exec_cmd not in command_ids: + # The last executed command may be missing (probably caused by a bug) + # or may reference a removed command. + # This requires regenerating the snapshot from scratch, + # with all commands from the reference study. + cmd_blocks = ref_study.commands[:] else: - # We use the root study as a reference study - ref_study = root_study - cmd_blocks = [c for v in descendants for c in v.commands] + last_exec_index = command_ids.index(last_exec_cmd) + cmd_blocks = ref_study.commands[last_exec_index + 1 :] + + # We need to add all commands from the descendants of variants + # starting at the first descendant of the reference study.
+ index = descendants.index(ref_study) + cmd_blocks.extend([c for v in descendants[index + 1 :] for c in v.commands]) + + else: + # We use the root study as a reference study + ref_study = root_study + cmd_blocks = [c for v in descendants for c in v.commands] - return ref_study, cmd_blocks + return RefStudySearchResult(ref_study=ref_study, cmd_blocks=cmd_blocks, force_regenerate=True) diff --git a/antarest/study/storage/variantstudy/variant_command_extractor.py b/antarest/study/storage/variantstudy/variant_command_extractor.py index 5a88dde857..bd052a6c0a 100644 --- a/antarest/study/storage/variantstudy/variant_command_extractor.py +++ b/antarest/study/storage/variantstudy/variant_command_extractor.py @@ -20,6 +20,7 @@ class VariantCommandsExtractor: def __init__(self, matrix_service: ISimpleMatrixService, patch_service: PatchService): self.matrix_service = matrix_service self.generator_matrix_constants = GeneratorMatrixConstants(self.matrix_service) + self.generator_matrix_constants.init_constant_matrices() self.command_extractor = CommandExtractor(self.matrix_service, patch_service=patch_service) def extract(self, study: FileStudy) -> List[CommandDTO]: diff --git a/antarest/study/storage/variantstudy/variant_study_service.py b/antarest/study/storage/variantstudy/variant_study_service.py index de6fa1651c..f9d3eea0aa 100644 --- a/antarest/study/storage/variantstudy/variant_study_service.py +++ b/antarest/study/storage/variantstudy/variant_study_service.py @@ -654,7 +654,7 @@ def generate( if variant_study.parent_id is None: raise NoParentStudyError(variant_study_id) - return self.generate_task(variant_study, denormalize) + return self.generate_task(variant_study, denormalize, from_scratch=from_scratch) def generate_study_config( self, diff --git a/antarest/study/web/study_data_blueprint.py b/antarest/study/web/study_data_blueprint.py index de9fbcecd1..440539a4ab 100644 --- a/antarest/study/web/study_data_blueprint.py +++ b/antarest/study/web/study_data_blueprint.py @@ -26,7 +26,11 @@ ) from antarest.study.business.areas.st_storage_management import * from antarest.study.business.areas.thermal_management import * -from antarest.study.business.binding_constraint_management import ConstraintTermDTO, UpdateBindingConstProps +from antarest.study.business.binding_constraint_management import ( + BindingConstraintPropertiesWithName, + ConstraintTermDTO, + UpdateBindingConstProps, +) from antarest.study.business.correlation_management import CorrelationFormFields, CorrelationManager, CorrelationMatrix from antarest.study.business.district_manager import DistrictCreationDTO, DistrictInfoDTO, DistrictUpdateDTO from antarest.study.business.general_management import GeneralFormFields @@ -857,6 +861,23 @@ def update_binding_constraint( study = study_service.check_study_access(uuid, StudyPermissionType.WRITE, params) return study_service.binding_constraint_manager.update_binding_constraint(study, binding_constraint_id, data) + @bp.post( + "/studies/{uuid}/bindingconstraints", + tags=[APITag.study_data], + summary="Create a binding constraint", + response_model=None, + ) + def create_binding_constraint( + uuid: str, data: BindingConstraintPropertiesWithName, current_user: JWTUser = Depends(auth.get_current_user) + ) -> None: + logger.info( + f"Creating a new binding constraint for study {uuid}", + extra={"user": current_user.id}, + ) + params = RequestParameters(user=current_user) + study = study_service.check_study_access(uuid, StudyPermissionType.WRITE, params) + return
study_service.binding_constraint_manager.create_binding_constraint(study, data) + @bp.post( "/studies/{uuid}/bindingconstraints/{binding_constraint_id}/term", tags=[APITag.study_data], diff --git a/antarest/tools/lib.py b/antarest/tools/lib.py index 5ade3d214b..c3c5db9dff 100644 --- a/antarest/tools/lib.py +++ b/antarest/tools/lib.py @@ -24,6 +24,7 @@ from antarest.core.config import CacheConfig from antarest.core.tasks.model import TaskDTO from antarest.core.utils.utils import StopWatch, get_local_path +from antarest.matrixstore.repository import MatrixContentRepository from antarest.matrixstore.service import SimpleMatrixService from antarest.matrixstore.uri_resolver_service import UriResolverService from antarest.study.model import NEW_DEFAULT_STUDY_VERSION, STUDY_REFERENCE_TEMPLATES @@ -140,7 +141,12 @@ def render_template(self, study_version: str = NEW_DEFAULT_STUDY_VERSION) -> Non def apply_commands(self, commands: List[CommandDTO], matrices_dir: Path) -> GenerationResultInfoDTO: stopwatch = StopWatch() - matrix_service = SimpleMatrixService(matrices_dir) + matrix_content_repository = MatrixContentRepository( + bucket_dir=matrices_dir, + ) + matrix_service = SimpleMatrixService( + matrix_content_repository=matrix_content_repository, + ) matrix_resolver = UriResolverService(matrix_service) local_cache = LocalCache(CacheConfig()) study_factory = StudyFactory( @@ -149,8 +155,10 @@ def apply_commands(self, commands: List[CommandDTO], matrices_dir: Path) -> Gene cache=local_cache, ) generator = VariantCommandGenerator(study_factory) + generator_matrix_constants = GeneratorMatrixConstants(matrix_service) + generator_matrix_constants.init_constant_matrices() command_factory = CommandFactory( - generator_matrix_constants=GeneratorMatrixConstants(matrix_service), + generator_matrix_constants=generator_matrix_constants, matrix_service=matrix_service, patch_service=PatchService(), ) @@ -176,8 +184,12 @@ def extract_commands(study_path: Path, commands_output_dir: Path) -> None: commands_output_dir.mkdir(parents=True) matrices_dir = commands_output_dir / MATRIX_STORE_DIR matrices_dir.mkdir() - - matrix_service = SimpleMatrixService(matrices_dir) + matrix_content_repository = MatrixContentRepository( + bucket_dir=matrices_dir, + ) + matrix_service = SimpleMatrixService( + matrix_content_repository=matrix_content_repository, + ) matrix_resolver = UriResolverService(matrix_service) cache = LocalCache(CacheConfig()) study_factory = StudyFactory( @@ -187,7 +199,12 @@ def extract_commands(study_path: Path, commands_output_dir: Path) -> None: ) study = study_factory.create_from_fs(study_path, str(study_path), use_cache=False) - local_matrix_service = SimpleMatrixService(matrices_dir) + matrix_content_repository = MatrixContentRepository( + bucket_dir=matrices_dir, + ) + local_matrix_service = SimpleMatrixService( + matrix_content_repository=matrix_content_repository, + ) extractor = VariantCommandsExtractor(local_matrix_service, patch_service=PatchService()) command_list = extractor.extract(study) @@ -233,7 +250,12 @@ def generate_diff( study_id = "empty_base" path_study = output_dir / study_id - local_matrix_service = SimpleMatrixService(matrices_dir) + matrix_content_repository = MatrixContentRepository( + bucket_dir=matrices_dir, + ) + local_matrix_service = SimpleMatrixService( + matrix_content_repository=matrix_content_repository, + ) resolver = UriResolverService(matrix_service=local_matrix_service) cache = LocalCache() diff --git a/antarest/utils.py b/antarest/utils.py index 
d49951017f..39ea094168 100644 --- a/antarest/utils.py +++ b/antarest/utils.py @@ -1,7 +1,8 @@ +import datetime import logging from enum import Enum from pathlib import Path -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Mapping, Optional, Tuple import redis import sqlalchemy.ext.baked # type: ignore @@ -12,6 +13,7 @@ from ratelimit.backends.redis import RedisBackend # type: ignore from ratelimit.backends.simple import MemoryBackend # type: ignore from sqlalchemy import create_engine +from sqlalchemy.engine.base import Engine # type: ignore from sqlalchemy.pool import NullPool # type: ignore from antarest.core.cache.main import build_cache @@ -20,13 +22,11 @@ from antarest.core.filetransfer.service import FileTransferManager from antarest.core.interfaces.cache import ICache from antarest.core.interfaces.eventbus import IEventBus -from antarest.core.logging.utils import configure_logger from antarest.core.maintenance.main import build_maintenance_manager from antarest.core.persistence import upgrade_db from antarest.core.tasks.main import build_taskjob_manager from antarest.core.tasks.service import ITaskService -from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware -from antarest.core.utils.utils import get_local_path, new_redis_instance +from antarest.core.utils.utils import new_redis_instance from antarest.eventbus.main import build_eventbus from antarest.launcher.main import build_launcher from antarest.login.main import build_login @@ -46,6 +46,19 @@ logger = logging.getLogger(__name__) +SESSION_ARGS: Mapping[str, bool] = { + "autocommit": False, + "expire_on_commit": False, + "autoflush": False, +} +""" +This mapping can be used to instantiate a new session, for example: + +>>> with sessionmaker(engine, **SESSION_ARGS)() as session: +... 
session.execute("SELECT 1") +""" + + class Module(str, Enum): APP = "app" WATCHER = "watcher" @@ -55,12 +68,11 @@ class Module(str, Enum): SIMULATOR_WORKER = "simulator_worker" -def init_db( +def init_db_engine( config_file: Path, config: Config, auto_upgrade_db: bool, - application: Optional[FastAPI], -) -> None: +) -> Engine: if auto_upgrade_db: upgrade_db(config_file) connect_args: Dict[str, Any] = {} @@ -86,19 +98,7 @@ def init_db( engine = create_engine(config.db.db_url, echo=config.debug, connect_args=connect_args, **extra) - session_args = { - "autocommit": False, - "expire_on_commit": False, - "autoflush": False, - } - if application: - application.add_middleware( - DBSessionMiddleware, - custom_engine=engine, - session_args=session_args, - ) - else: - DBSessionMiddleware(None, custom_engine=engine, session_args=session_args) + return engine def create_event_bus( @@ -264,14 +264,3 @@ def create_services(config: Config, application: Optional[FastAPI], create_all: services["cache"] = cache services["maintenance"] = maintenance_service return services - - -def create_env(config_file: Path) -> Dict[str, Any]: - """ - Create application services env for testing and scripting purpose - """ - res = get_local_path() / "resources" - config = Config.from_yaml_file(res=res, file=config_file) - configure_logger(config) - init_db(config_file, config, False, None) - return create_services(config, None) diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index f48e0941ca..1e587bf6e3 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -1,6 +1,45 @@ Antares Web Changelog ===================== +v2.16.1 (2023-12-14) +-------------------- + +### Features + +* **ui:** add manual submit on clusters form [`#1852`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1852) +* **ui-modelling:** add dynamic area selection on Areas tab click [`#1835`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1835) +* **ui-storages:** use percentage values instead of ratio values [`#1846`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1846) +* **upgrade:** correction of study upgrade when upgrading from v8.2 to v8.6 (creation of MinGen) [`#1861`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1861) + + +### Bug Fixes + +* **bc:** correct the name and shape of the binding constraint matrices [`#1849`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1849) +* **bc:** avoid duplicates in Binding Constraints creation through REST API [`#1858`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1858) +* **ui:** update current area after window reload [`#1862`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1862) +* **ui-study:** fix the study card explore button visibility [`#1842`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1842) +* **ui-matrix:** prevent matrices float values to be converted [`#1850`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1850) +* **ui-matrix:** calculate the prepend index according to the existence of a time column [`#1856`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1856) +* **ui-output:** add the missing "ST Storages" option in the Display selector in results view [`#1855`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1855) + + +### Performance + +* **db-init:** separate database initialization from global database session [`#1837`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1837) +* **variant:** improve performances and correct snapshot generation 
[`#1854`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1854) + + +### Documentation + + * **config:** enhance application configuration documentation [`#1710`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1710) + + +### Chore + +* **deps:** upgrade material-react-table [`#1851`](https://github.com/AntaresSimulatorTeam/AntaREST/pull/1851) + + + v2.16.0 (2023-11-30) -------------------- diff --git a/docs/install/1-CONFIG.md b/docs/install/1-CONFIG.md index 7c82b044dd..8af3e22615 100644 --- a/docs/install/1-CONFIG.md +++ b/docs/install/1-CONFIG.md @@ -1,11 +1,690 @@ -# Application Configuration +# Application Configuration Documentation -Almost all the configuration of the application can be found in the -[application.yaml](https://github.com/AntaresSimulatorTeam/AntaREST/blob/master/resources/application.yaml) file. -If the path to this configuration file is not explicitly provided (through the `-c` option), -the application will try to look for files in the following location (in order): +In the following, we will explore how to edit your application configuration file. <br> +As explained in the main documentation readme file, you can use the following command line +to start the API: - 1. `./config.yaml` - 2. `../config.yaml` - 3. `$HOME/.antares/config.yaml` +```shell +python3 antarest/main.py -c resources/application.yaml --auto-upgrade-db --no-front +``` +The `-c` option here gives the path to the configuration `.yaml` file. If this option is +not provided, the application will look for files in the following locations (in order): + +1. `./config.yaml` +2. `../config.yaml` +3. `$HOME/.antares/config.yaml` + +<br> +This page gives a global overview of the configuration +file structure and, for each of the `.yaml` fields, its data type, +default value, and a description. + + +# File Structure + +- [Security](#security) +- [Database](#db) +- [Storage](#storage) +- [Launcher](#launcher) +- [Logging](#logging) +- [Root Path](#root_path) +- [Optional sections](#debug) + +# security + +This section defines the settings for application security, authentication, and groups. + +## **disabled** + +- **Type:** Boolean +- **Default value:** false +- **Description:** If set to `false`, user identification will be required when launching the app. + +## **jwt** + +### **key** + +- **Type:** String (usually a Base64-encoded one) +- **Default value:** "" +- **Description:** JWT (JSON Web Token) secret key for authentication. + +## **login** + +### **admin** + +#### **pwd** + +- **Type:** String +- **Default value:** "" +- **Description:** Admin user's password. + +## **external_auth** + +This subsection is about setting up an external authentication service that lets you connect to an LDAP directory through a web +service. The group names and their IDs are obtained from the LDAP directory. + +### **url** + +- **Type:** String +- **Default value:** "" +- **Description:** External authentication URL. To use local authentication instead, leave it empty (`""`). + +### **default_group_role** + +- **Type:** Integer +- **Default value:** 10 +- **Description:** Default user role for external authentication: + - `ADMIN = 40` + - `WRITER = 30` + - `RUNNER = 20` + - `READER = 10` + +### **add_ext_groups** + +- **Type:** Boolean +- **Default value:** false +- **Description:** Whether to add external groups to user roles.
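As a quick sanity check, this part of the configuration can be loaded with PyYAML and the role code compared against the values listed above (a standalone sketch, not application code; the `config.yaml` path is an assumption):

```python
import yaml  # requires PyYAML

with open("config.yaml") as fh:
    config = yaml.safe_load(fh)

role = config["security"]["external_auth"]["default_group_role"]
assert role in (10, 20, 30, 40), f"unknown role code: {role}"
```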
+ +### **group_mapping** + +- **Type:** Dictionary +- **Default value:** {} +- **Description:** Mapping between external group IDs and the application's groups (keys = IDs, values = names). Example: + - 00000001: espace_commun + - 00001188: drd + - 00001574: cnes + +```yaml +# example for security settings +security: + disabled: false + jwt: + key: best-key + login: + admin: + pwd: root + external_auth: + url: "" + default_group_role: 10 + group_mapping: + id_ext: id_int + add_ext_groups: false +``` + +# db + +This section presents the configuration of the application's database connection. + +## **url** + +- **Type:** String +- **Default value:** "" +- **Description:** The database URL. For example, `sqlite:///database.db` for a local SQLite DB + or `postgresql://postgres_user:postgres_password@postgres_host:postgres_port/postgres_db` for a PostgreSQL DB. + +## **admin_url** + +- **Type:** String +- **Default value:** None +- **Description:** The URL you can use to directly access your database. + +## **pool_use_null** + +- **Type:** Boolean +- **Default value:** false +- **Description:** If set to `true`, connections are not pooled. This parameter should be kept at `false` to avoid + issues. + +## **db_connect_timeout** + +- **Type:** Integer +- **Default value:** 10 +- **Description:** The timeout (in seconds) for database connection creation. + +## **pool_recycle** + +- **Type:** Integer +- **Default value:** None +- **Description:** Number of seconds after which a connection is recycled, preventing the pool from reusing a + connection that has been open for too long. An often-used value is 3600, which corresponds to an hour. *Not used for SQLite DB.* + +## **pool_size** + +- **Type:** Integer +- **Default value:** 5 +- **Description:** The maximum number of permanent connections to keep. *Not used for SQLite DB.* + +## **pool_use_lifo** + +- **Type:** Boolean +- **Default value:** false +- **Description:** Whether the connection pool should reuse connections in last-in-first-out (LIFO) order. Serving the most + recently returned connection first lets idle connections time out and be closed, which suits the application context. Therefore, it's better + to set this parameter to `true`. *Not used for SQLite DB.* + +## **pool_pre_ping** + +- **Type:** Boolean +- **Default value:** false +- **Description:** If `true`, connections are tested before use, so connections closed from the server side are gracefully detected by the connection pool and + replaced with a new connection. *Not used for SQLite DB.* + +## **pool_max_overflow** + +- **Type:** Integer +- **Default value:** 10 +- **Description:** The number of connections that may temporarily be opened beyond `pool_size` when no pooled connection is available. *Not used for SQLite DB.* + +```yaml +# example for db settings +db: + url: "postgresql://postgres:My:s3Cr3t/@127.0.0.1:30432/antares" + admin_url: "postgresql://{{postgres_owner}}:{{postgres_owner_password}}@{{postgres_host}}:{{postgres_port}}/{{postgres_db}}" + pool_recycle: 3600 + pool_max_overflow: 10 + pool_size: 5 + pool_use_lifo: true + pool_use_null: false +``` + +# storage + +The parameters in this section define the application's paths and service options. + +## **tmp_dir** + +- **Type:** Path +- **Default value:** `tempfile.gettempdir()` ( + documentation [here](https://docs.python.org/3/library/tempfile.html#tempfile.gettempdir)) +- **Description:** The temporary directory for storing temporary files. An often-used value is `./tmp`. + +## **matrixstore** + +- **Type:** Path +- **Default value:** `./matrixstore` +- **Description:** Antares Web extracts matrix data and shares it between managed studies to save space.
These + matrices are stored here. + +## **archive_dir** + +- **Type:** Path +- **Default value:** `./archives` +- **Description:** The directory for archived (zipped) studies. + +## **workspaces** + +- **Type:** Dictionary +- **Default value:** {} +- **Description:** Different workspaces where the application expects to find studies. Keys = folder names, values = + `WorkspaceConfig` objects. Such an object has 4 fields: + - `groups`: List of groups corresponding to the workspace (default []) + - `path`: Path of the workspace (default `Path()`) + - `filter_in`: List of regexes. If a folder does not contain a file whose name matches one of the regexes, it's not + scanned (default [".*"]) + - `filter_out`: List of regexes. If a folder contains any file whose name matches one of the regexes, it's not scanned ( + default []) + +> NOTE: If a directory is to be ignored by the watcher, place a file named `AW_NO_SCAN` inside. + +Examples: + +```yaml +default: + path: /home/john/Projects/antarest_data/internal_studies/ +studies: + path: /home/john/Projects/antarest_data/studies/ +staging_studies: + path: /home/john/Projects/antarest_data/staging_studies/ +``` + +```yaml +default: + path: /studies/internal +"public": + path: /mounts/public + filter_in: + - .* + filter_out: + - ^R$ + - System Volume Information + - .*RECYCLE.BIN + - .Rproj.* + - ^.git$ + - ^areas$ +"aws_share_2": + path: /mounts/aws_share_2 + groups: + - test +"sedre_archive": + path: /mounts/sedre_archive + groups: + - sedre +``` + +## **allow_deletion** + +- **Type:** Boolean +- **Default value:** false +- **Description:** Indicates if studies found in non-default workspaces can be deleted by the application. + +## **matrix_gc_sleeping_time** + +- **Type:** Integer +- **Default value:** 3600 (corresponds to 1 hour) +- **Description:** Time in seconds to sleep between two garbage collections (i.e., deletion of unused matrices). + +## **matrix_gc_dry_run** + +- **Type:** Boolean +- **Default value:** false +- **Description:** If `true`, matrices are never removed; otherwise, unused matrices are deleted. + +## **auto_archive_sleeping_time** + +- **Type:** Integer +- **Default value:** 3600 (corresponds to 1 hour) +- **Description:** Time in seconds to sleep between two auto-archiver tasks (i.e., zipping of unused studies). + +## **auto_archive_dry_run** + +- **Type:** Boolean +- **Default value:** false +- **Description:** If `true`, studies are never archived; otherwise, studies that no one has accessed for a while are. + +## **auto_archive_threshold_days** + +- **Type:** Integer +- **Default value:** 60 +- **Description:** Number of days after the last study access date before it should be archived. + +## **auto_archive_max_parallel** + +- **Type:** Integer +- **Default value:** 5 +- **Description:** Maximum number of auto-archival tasks run in parallel. + +## **watcher_lock** + +- **Type:** Boolean +- **Default value:** true +- **Description:** If `false`, the watcher scans without any delay; otherwise, the delay between two scans is the value of the + field `watcher_lock_delay`. + +## **watcher_lock_delay** + +- **Type:** Integer +- **Default value:** 10 +- **Description:** Delay in seconds between two scans. + +## **download_default_expiration_timeout_minutes** + +- **Type:** Integer +- **Default value:** 1440 (corresponds to 1 day) +- **Description:** Number of minutes before a study download is cleared. The value can be lower than the default, since a + user is expected to download their study soon after the download becomes available.
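The `filter_in`/`filter_out` rules described above for workspaces can be summarized in a few lines of Python (an illustrative re-implementation, not the watcher's actual code; the patterns are taken from the `public` workspace example):

```python
import re
from typing import Iterable, List

def should_scan(filenames: Iterable[str], filter_in: List[str], filter_out: List[str]) -> bool:
    """A folder is scanned only if some file matches `filter_in` and no file matches `filter_out`."""
    names = list(filenames)
    matched_in = any(re.match(p, n) for p in filter_in for n in names)
    matched_out = any(re.match(p, n) for p in filter_out for n in names)
    return matched_in and not matched_out

assert should_scan(["study.antares"], [".*"], ["^R$"]) is True
assert should_scan(["R", "study.antares"], [".*"], ["^R$"]) is False
```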
+ +```yaml +# example for storage settings +storage: + tmp_dir: /home/jon/Projects/antarest_data/tmp + matrixstore: /home/jon/Projects/antarest_data/matrices + archive_dir: /home/jon/Projects/antarest_data/archives + allow_deletion: false + matrix_gc_sleeping_time: 3600 + matrix_gc_dry_run: False + workspaces: + default: + path: /home/jon/Projects/antarest_data/internal_studies/ + studies: + path: /home/jon/Projects/antarest_data/studies/ + staging_studies: + path: /home/jon/Projects/antarest_data/staging_studies/ +``` + +# launcher + +This section configures the launcher options and the settings for the solver binaries. + +## **default** + +- **Type:** String, possible values: `local` or `slurm` +- **Default value:** `local` +- **Description:** Default launcher. If set to `local`, simulations run locally; if set to `slurm`, +they run on shared servers through SLURM. + +## **local** + +### **enable_nb_cores_detection** + +- **Type:** Boolean +- **Default value:** false +- **Description:** Enables detection of the CPUs available to the solver. If enabled, the default number of cores is `max(1, + multiprocessing.cpu_count() - 2)`; otherwise it is 22. To maximize the solver's performance, it is recommended to + enable this option. + +### **binaries** + +- **Type:** Dictionary +- **Default value:** {} +- **Description:** Solver binary paths for the various study versions. Example: + +```yaml +700: /home/john/Antares/antares_web_data/antares-solver/antares-8.0-solver +800: /home/john/Antares/antares_web_data/antares-solver/antares-8.0-solver +810: /home/john/Antares/antares_web_data/antares-solver/antares-8.3-solver +820: /home/john/Antares/antares_web_data/antares-solver/antares-8.3-solver +830: /home/john/Antares/antares_web_data/antares-solver/antares-8.3-solver +840: /home/john/Antares/antares_web_data/antares-solver/antares-8.4-solver +850: /home/john/Antares/antares_web_data/antares-solver/antares-8.5-solver +860: /home/john/Antares/antares_web_data/antares-solver/antares-8.6-solver +``` + +> NOTE: As you can see, a newer solver can be used for older study versions thanks to solver backward compatibility. + +## **slurm** + +SLURM (Simple Linux Utility for Resource Management) is a workload manager used to interact with a remote environment (for +Antares, the computing server). + +### **local_workspace** + +- **Type:** Path +- **Default value:** Path() +- **Description:** Path to the local SLURM workspace. + +### **username** + +- **Type:** String +- **Default value:** "" +- **Description:** Username used by SLURM to connect to the computing server over SSH. + +### **hostname** + +- **Type:** String +- **Default value:** "" +- **Description:** IP address of the computing server that SLURM connects to over SSH. + +### **port** + +- **Type:** Integer +- **Default value:** 0 +- **Description:** SSH port for SLURM. + +Examples: + +- Options to connect SLURM to computing server `prod-server-name` (production): + +``` +username: run-antares +hostname: XX.XXX.XXX.XXX +port: 22 +``` + +- Options to connect SLURM to computing server `dev-server-name` (staging and integration): + +``` +username: dev-antares +hostname: XX.XXX.XXX.XXX +port: 22 +``` + +### **private_key_file** + +- **Type:** Path +- **Default value:** Path() +- **Description:** SSH private key file. If you do not have one, you have to fill the `password` field. + +### **password** + +- **Type:** String +- **Default value:** "" +- **Description:** SSH password for the remote server.
Either this password or a private key file is required for SLURM to connect. + +### **key_password** + +- **Type:** String +- **Default value:** "" +- **Description:** An optional password used to decrypt the key file, if it's encrypted. + +### **default_wait_time** + +> NOTE: Deprecated, as the app is launched with `wait_mode=false`. + +- **Type:** Integer +- **Default value:** 0 +- **Description:** Default delay (in seconds) of the SLURM loop that checks the status of the tasks and retrieves those + completed in the loop. Often-used value: 900 (15 minutes). + +### **default_time_limit** + +- **Type:** Integer +- **Default value:** 0 +- **Description:** Time limit for SLURM jobs (in seconds). If a job exceeds this time limit, SLURM kills it and the job + is considered failed. Often-used value: 172800 (48 hours). + +### **enable_nb_cores_detection** + +- **Type:** Boolean +- **Default value:** false +- **Description:** Enables detection of the CPUs available to the solver (not implemented yet). + +### **nb_cores** + +#### **min** + +- **Type:** Integer +- **Default value:** 1 +- **Description:** Minimum number of CPUs to use when launching a simulation. + +#### **default** + +- **Type:** Integer +- **Default value:** 22 +- **Description:** Default number of CPUs to use when launching a simulation. The user can override this value in the + launch dialog box. + +#### **max** + +- **Type:** Integer +- **Default value:** 24 +- **Description:** Maximum number of CPUs to use when launching a simulation. + +### **default_json_db_name** + +- **Type:** String +- **Default value:** "" +- **Description:** SLURM local DB name. Often-used value: `launcher_db.json` + +### **slurm_script_path** + +- **Type:** String +- **Default value:** "" +- **Description:** Path of the Bash script to execute on the remote server. + - If SLURM is connected to `prod-server-name` (*production*), use this path: `/applis/antares/launchAntares.sh` + - If SLURM is connected to `dev-server-name` (*staging* and *integration*), use this + path: `/applis/antares/launchAntaresRec.sh` + +### **antares_versions_on_remote_server** + +- **Type:** List of String +- **Default value:** [] +- **Description:** List of Antares solver versions available on the remote server. Example: + +```yaml +# example for launcher settings +launcher: + default: local + local: + binaries: + 860: /home/jon/opt/antares-solver_ubuntu20.04/antares-8.6-solver + slurm: + local_workspace: /home/jon/Projects/antarest_data/slurm_workspace + username: jon + hostname: localhost + port: 22 + private_key_file: /home/jon/.ssh/id_rsa + key_password: + default_wait_time: 900 + default_time_limit: 172800 + default_n_cpu: 20 + default_json_db_name: launcher_db.json + slurm_script_path: /applis/antares/launchAntares.sh + db_primary_key: name + antares_versions_on_remote_server: + - '610' + - '700' +``` + +# Logging + +This section sets the configuration for the application logs. + +## **level** + +- **Type:** String, possible values: "DEBUG", "INFO", "WARNING", "ERROR" +- **Default value:** `INFO` +- **Description:** The logging level of the application (INFO, DEBUG, etc.). + +## **logfile** + +- **Type:** Path +- **Default value:** None +- **Description:** The path to the application log file. An often-used value is `.tmp/antarest.log`. + +## **json** + +- **Type:** Boolean +- **Default value:** false +- **Description:** If `true`, the logging format will be `json`; otherwise, it is `console`.
+
+### **default_json_db_name**
+
+- **Type:** String
+- **Default value:** ""
+- **Description:** Name of the local SLURM JSON database. A commonly used value is `launcher_db.json`.
+
+### **slurm_script_path**
+
+- **Type:** String
+- **Default value:** ""
+- **Description:** Path of the bash script to execute on the remote server.
+  - If SLURM is connected to `prod-server-name` (*production*), use this path: `/applis/antares/launchAntares.sh`
+  - If SLURM is connected to `dev-server-name` (*recette* and *integration*), use this
+    path: `/applis/antares/launchAntaresRec.sh`
+
+### **antares_versions_on_remote_server**
+
+- **Type:** List of String
+- **Default value:** []
+- **Description:** List of the Antares solver versions available on the remote server. Example:
+
+```yaml
+# example for launcher settings
+launcher:
+  default: local
+  local:
+    binaries:
+      860: /home/jon/opt/antares-solver_ubuntu20.04/antares-8.6-solver
+  slurm:
+    local_workspace: /home/jon/Projects/antarest_data/slurm_workspace
+    username: jon
+    hostname: localhost
+    port: 22
+    private_key_file: /home/jon/.ssh/id_rsa
+    key_password:
+    default_wait_time: 900
+    default_time_limit: 172800
+    default_n_cpu: 20
+    default_json_db_name: launcher_db.json
+    slurm_script_path: /applis/antares/launchAntares.sh
+    db_primary_key: name
+    antares_versions_on_remote_server:
+      - '610'
+      - '700'
+```
+
+# Logging
+
+This section sets the configuration of the application logs.
+
+## **level**
+
+- **Type:** String, possible values: `DEBUG`, `INFO`, `WARNING`, `ERROR`
+- **Default value:** `INFO`
+- **Description:** The logging level of the application.
+
+## **logfile**
+
+- **Type:** Path
+- **Default value:** None
+- **Description:** The path to the application log file. An often-used value is `./tmp/antarest.log`.
+
+## **json**
+
+- **Type:** Boolean
+- **Default value:** false
+- **Description:** If `true`, the logging format is `json`; otherwise, it is `console`.
+  - `console`: The default format used for console output, suitable for Desktop versions or development environments.
+  - `json`: A specific JSON format suitable for consumption by monitoring tools via a web service.
+
+```yaml
+# example for logging settings
+logging:
+  level: INFO
+  logfile: ./tmp/antarest.log
+  json: false
+```
+
+# root_path
+
+- **Type:** String
+- **Default value:** ""
+- **Description:** The root path for FastAPI. Use `/api` when the application is deployed behind a reverse proxy on a
+  remote server, and `api` for a local environment.
+
+```yaml
+# example for root_path settings
+root_path: "/{root_path}"
+```
+
+## Extra optional configuration
+
+# debug
+
+- **Type:** Boolean
+- **Default value:** false
+- **Description:** Determines whether the database engine logs to the console all the SQL statements it executes.
+  Set it to `true` to see a detailed log of the database queries.
+
+```yaml
+# example for debug settings
+debug: false
+```
+
+# cache
+
+## **checker_delay**
+
+- **Type:** Float
+- **Default value:** 0.2
+- **Description:** The time (in seconds) to sleep before checking what needs to be removed from the cache.
+
+```yaml
+# example for cache settings
+cache:
+  checker_delay: 0.2
+```
+
+# tasks
+
+## **max_workers**
+
+- **Type:** Integer
+- **Default value:** 5
+- **Description:** The number of threads in the tasks' `ThreadPoolExecutor`.
+
+## **remote_workers**
+
+- **Type:** List
+- **Default value:** []
+- **Description:** List of remote workers and the task queues they listen to. Example:
+
+```yaml
+# example for tasks settings
+tasks:
+  max_workers: 4
+  remote_workers:
+    - name: aws_share_2
+      queues:
+        - unarchive_aws_share_2
+    - name: simulator_worker
+      queues:
+        - generate-timeseries
+        - generate-kirshoff-constraints
+```
+
+# server
+
+## **worker_threadpool_size**
+
+- **Type:** Integer
+- **Default value:** 5
+- **Description:** The number of threads in the server's `ThreadPoolExecutor`.
+
+## **services**
+
+- **Type:** List of Strings
+- **Default value:** []
+- **Description:** Services to enable when launching the application. Possible values: `watcher`, `matrix_gc`,
+  `archive_worker`, `auto_archiver`, `simulator_worker`.
+
+```yaml
+# example for server settings
+server:
+  worker_threadpool_size: 5
+  services:
+    - watcher
+    - matrix_gc
+```
+
+# redis
+
+This section configures the Redis backend, which is used to manage the event bus and for in-memory caching.
+
+## **host**
+
+- **Type:** String
+- **Default value:** `localhost`
+- **Description:** The Redis server hostname.
+
+## **port**
+
+- **Type:** Integer
+- **Default value:** 6379
+- **Description:** The Redis server port.
+
+## **password**
+
+- **Type:** String
+- **Default value:** None
+- **Description:** The Redis password.
+
+```yaml
+# example for redis settings
+redis:
+  host: localhost
+  port: 9862
+```
\ No newline at end of file
diff --git a/resources/application.yaml b/resources/application.yaml
index a85357634f..6fbdb31f9f 100644
--- a/resources/application.yaml
+++ b/resources/application.yaml
@@ -1,45 +1,22 @@
+# Documentation about this file can be found in this file: `docs/install/1-CONFIG.md`
+
 security:
   disabled: true
   jwt:
     key: super-secret
-  login:
-    admin:
-      pwd: admin
-  external_auth:
-    url: ""
-    default_group_role: 10
-#    group_mapping:
-#      id_ext: id_int
-#    ...
- add_ext_groups: false - - db: url: "sqlite:///database.db" - #pool_recycle: storage: tmp_dir: ./tmp matrixstore: ./matrices archive_dir: ./examples/archives - allow_deletion: false # indicate if studies found in non default workspace can be deleted by the application - #matrix_gc_sleeping_time: 3600 # time in seconds to sleep between two garbage collection - #matrix_gc_dry_run: False # Skip matrix effective deletion - #auto_archive_sleeping_time: 3600 # time in seconds to sleep between two auto archival checks - #auto_archive_dry_run: True # Skip auto archive effective archival - #auto_archive_threshold_days: 60 # number of days after last study access when the study should be archived - #auto_archive_max_parallel: 5 # max auto archival tasks in parallel workspaces: - default: # required, no filters applied, this folder is not watched + default: path: ./examples/internal_studies/ - # other workspaces can be added - # if a directory is to be ignored by the watcher, place a file named AW_NO_SCAN inside - tmp: + studies: path: ./examples/studies/ - # filter_in: ['.*'] # default to '.*' - # filter_out: [] # default to empty - # groups: [] # default empty launcher: default: local @@ -49,39 +26,8 @@ launcher: 700: path/to/700 enable_nb_cores_detection: true -# slurm: -# local_workspace: path/to/workspace -# username: username -# hostname: 0.0.0.0 -# port: 22 -# private_key_file: path/to/key -# key_password: key_password -# password: password_is_optional_but_necessary_if_key_is_absent -# default_wait_time: 900 -# default_time_limit: 172800 -# enable_nb_cores_detection: False -# nb_cores: -# min: 1 -# default: 22 -# max: 24 -# default_json_db_name: launcher_db.json -# slurm_script_path: /path/to/launchantares_v1.1.3.sh -# db_primary_key: name -# antares_versions_on_remote_server : -# - "610" -# - "700" -# - "710" -# - "720" -# - "800" - - -debug: true - root_path: "api" -#tasks: -# max_workers: 5 - server: worker_threadpool_size: 12 services: @@ -90,12 +36,4 @@ server: logging: level: INFO - logfile: ./tmp/antarest.log -# json: false - -# Uncomment these lines to use redis as a backend for the eventbus -# It is required to use redis when using this application on multiple workers in a preforked model like gunicorn for instance -#eventbus: -# redis: -# host: localhost -# port: 6379 + logfile: ./tmp/antarest.log \ No newline at end of file diff --git a/resources/deploy/config.prod.yaml b/resources/deploy/config.prod.yaml index 02fbb4b8bc..e69de29bb2 100644 --- a/resources/deploy/config.prod.yaml +++ b/resources/deploy/config.prod.yaml @@ -1,87 +0,0 @@ -security: - disabled: false - jwt: - key: secretkeytochange - login: - admin: - pwd: admin - external_auth: - url: "" - default_group_role: 10 - -db: - url: "postgresql://postgres:somepass@postgresql:5432/postgres" - admin_url: "postgresql://postgres:somepass@postgresql:5432/postgres" - pool_recycle: 3600 - -storage: - tmp_dir: /antarest_tmp_dir - archive_dir: /studies/archives - matrixstore: /matrixstore - matrix_gc_dry_run: true - workspaces: - default: # required, no filters applied, this folder is not watched - path: /workspaces/internal_studies/ - # other workspaces can be added - # if a directory is to be ignored by the watcher, place a file named AW_NO_SCAN inside - tmp: - path: /workspaces/studies/ - # filter_in: ['.*'] # default to '.*' - # filter_out: [] # default to empty - # groups: [] # default empty - -launcher: - default: local - - local: - binaries: - 800: /antares_simulator/antares-8.2-solver - enable_nb_cores_detection: true - -# 
slurm: -# local_workspace: path/to/workspace -# username: username -# hostname: 0.0.0.0 -# port: 22 -# private_key_file: path/to/key -# key_password: key_password -# password: password_is_optional_but_necessary_if_key_is_absent -# default_wait_time: 900 -# default_time_limit: 172800 -# enable_nb_cores_detection: False -# nb_cores: -# min: 1 -# default: 22 -# max: 24 -# default_json_db_name: launcher_db.json -# slurm_script_path: /path/to/launchantares_v1.1.3.sh -# db_primary_key: name -# antares_versions_on_remote_server : -# - "610" -# - "700" -# - "710" -# - "720" -# - "800" - - -debug: false - -root_path: "api" - -#tasks: -# max_workers: 5 -server: - worker_threadpool_size: 12 -# services: -# - watcher - -logging: - level: INFO -# logfile: /logs/antarest.log -# json: true - -# Uncomment these lines to use redis as a backend for the eventbus -# It is required to use redis when using this application on multiple workers in a preforked model like gunicorn for instance -redis: - host: redis - port: 6379 diff --git a/resources/deploy/config.yaml b/resources/deploy/config.yaml index 810e1f8d24..3eaaf891b6 100644 --- a/resources/deploy/config.yaml +++ b/resources/deploy/config.yaml @@ -1,13 +1,9 @@ +# Documentation about this file can be found in this file: `docs/install/1-CONFIG.md` + security: disabled: true jwt: key: super-secret - login: - admin: - pwd: admin - external_auth: - url: "" - default_group_role: 10 db: url: "sqlite:///database.db" @@ -17,43 +13,33 @@ storage: matrixstore: ./matrices archive_dir: ./examples/archives workspaces: - default: # required, no filters applied, this folder is not watched + default: path: ./examples/internal_studies/ - # other workspaces can be added - # if a directory is to be ignored by the watcher, place a file named AW_NO_SCAN inside - tmp: + studies: path: ./examples/studies/ - # filter_in: ['.*'] # default to '.*' - # filter_out: [] # default to empty - # groups: [] # default empty launcher: - default: local - local: binaries: - 700: path/to/700 - enable_nb_cores_detection: true + VER: ANTARES_SOLVER_PATH # slurm: -# local_workspace: path/to/workspace -# username: username -# hostname: 0.0.0.0 -# port: 22 -# private_key_file: path/to/key -# key_password: key_password -# password: password_is_optional_but_necessary_if_key_is_absent -# default_wait_time: 900 -# default_time_limit: 172800 -# enable_nb_cores_detection: False +# local_workspace: /path/to/slurm_workspace # Path to the local SLURM workspace +# username: run-antares # SLURM username +# hostname: 10.134.248.111 # SLURM server hostname +# port: 22 # SSH port for SLURM +# private_key_file: /path/to/ssh_private_key # SSH private key file +# default_wait_time: 900 # Default wait time for SLURM jobs +# default_time_limit: 172800 # Default time limit for SLURM jobs +# enable_nb_cores_detection: False # Enable detection of available CPU cores for SLURM # nb_cores: -# min: 1 -# default: 22 -# max: 24 -# default_json_db_name: launcher_db.json -# slurm_script_path: /path/to/launchantares_v1.1.3.sh -# db_primary_key: name -# antares_versions_on_remote_server : +# min: 1 # Minimum number of CPU cores +# default: 22 # Default number of CPU cores +# max: 24 # Maximum number of CPU cores +# default_json_db_name: launcher_db.json # Default JSON database name for SLURM +# slurm_script_path: /applis/antares/launchAntares.sh # Path to the SLURM script (on distant server) +# db_primary_key: name # Primary key for the SLURM database +# antares_versions_on_remote_server: #List of Antares versions available on 
the remote SLURM server # - "840" # - "850" @@ -62,20 +48,5 @@ debug: false root_path: "api" -#tasks: -# max_workers: 5 -server: - worker_threadpool_size: 12 - services: - - watcher - logging: - level: INFO logfile: ./tmp/antarest.log -# json: false - -# Uncomment these lines to use redis as a backend for the eventbus -# It is required to use redis when using this application on multiple workers in a preforked model like gunicorn for instance -#redis: -# host: localhost -# port: 6379 diff --git a/scripts/package_antares_web.sh b/scripts/package_antares_web.sh index 5dc2da6adb..31ae7ac0f1 100755 --- a/scripts/package_antares_web.sh +++ b/scripts/package_antares_web.sh @@ -73,9 +73,9 @@ echo "INFO: Copying basic configuration files..." rm -rf "${DIST_DIR}/examples" # in case of replay cp -r "${RESOURCES_DIR}"/deploy/* "${DIST_DIR}" if [[ "$OSTYPE" == "msys"* ]]; then - sed -i "s/700: path\/to\/700/$ANTARES_SOLVER_FULL_VERSION_INT: .\/AntaresWeb\/antares_solver\/antares-$ANTARES_SOLVER_VERSION-solver.exe/g" "${DIST_DIR}/config.yaml" + sed -i "s/VER: ANTARES_SOLVER_PATH/$ANTARES_SOLVER_FULL_VERSION_INT: .\/AntaresWeb\/antares_solver\/antares-$ANTARES_SOLVER_VERSION-solver.exe/g" "${DIST_DIR}/config.yaml" else - sed -i "s/700: path\/to\/700/$ANTARES_SOLVER_FULL_VERSION_INT: .\/AntaresWeb\/antares_solver\/antares-$ANTARES_SOLVER_VERSION-solver/g" "${DIST_DIR}/config.yaml" + sed -i "s/VER: ANTARES_SOLVER_PATH/$ANTARES_SOLVER_FULL_VERSION_INT: .\/AntaresWeb\/antares_solver\/antares-$ANTARES_SOLVER_VERSION-solver/g" "${DIST_DIR}/config.yaml" fi echo "INFO: Creating shortcuts..." diff --git a/setup.py b/setup.py index 37074c1e33..1760ecac46 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name="AntaREST", - version="2.16.0", + version="2.16.1", description="Antares Server", long_description=Path("README.md").read_text(encoding="utf-8"), long_description_content_type="text/markdown", diff --git a/sonar-project.properties b/sonar-project.properties index e19a4a82dc..69dd022476 100644 --- a/sonar-project.properties +++ b/sonar-project.properties @@ -6,5 +6,5 @@ sonar.exclusions=antarest/gui.py,antarest/main.py sonar.python.coverage.reportPaths=coverage.xml sonar.python.version=3.8 sonar.javascript.lcov.reportPaths=webapp/coverage/lcov.info -sonar.projectVersion=2.16.0 +sonar.projectVersion=2.16.1 sonar.coverage.exclusions=antarest/gui.py,antarest/main.py,antarest/singleton_services.py,antarest/worker/archive_worker_service.py,webapp/**/* \ No newline at end of file diff --git a/tests/conftest_services.py b/tests/conftest_services.py index ee2fea2057..59f562e241 100644 --- a/tests/conftest_services.py +++ b/tests/conftest_services.py @@ -1,12 +1,13 @@ """ This module provides various pytest fixtures for unit testing the AntaREST application. -Fixtures in this module are used to set up and provide instances of different classes and services required during testing. +Fixtures in this module are used to set up and provide instances of different classes +and services required during testing. 
""" import datetime +import typing as t import uuid from pathlib import Path -from typing import Dict, List, Optional, Union from unittest.mock import Mock import pytest @@ -18,6 +19,9 @@ from antarest.core.tasks.model import CustomTaskEventMessages, TaskDTO, TaskListFilter, TaskResult, TaskStatus, TaskType from antarest.core.tasks.service import ITaskService, Task from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware +from antarest.eventbus.business.local_eventbus import LocalEventBus +from antarest.eventbus.service import EventBusService +from antarest.matrixstore.repository import MatrixContentRepository from antarest.matrixstore.service import SimpleMatrixService from antarest.matrixstore.uri_resolver_service import UriResolverService from antarest.study.storage.patch_service import PatchService @@ -50,26 +54,26 @@ class SynchTaskService(ITaskService): def __init__(self) -> None: - self._task_result: Optional[TaskResult] = None + self._task_result: t.Optional[TaskResult] = None def add_worker_task( self, task_type: TaskType, task_queue: str, - task_args: Dict[str, Union[int, float, bool, str]], - name: Optional[str], - ref_id: Optional[str], + task_args: t.Dict[str, t.Union[int, float, bool, str]], + name: t.Optional[str], + ref_id: t.Optional[str], request_params: RequestParameters, - ) -> Optional[str]: + ) -> t.Optional[str]: raise NotImplementedError() def add_task( self, action: Task, - name: Optional[str], - task_type: Optional[TaskType], - ref_id: Optional[str], - custom_event_messages: Optional[CustomTaskEventMessages], + name: t.Optional[str], + task_type: t.Optional[TaskType], + ref_id: t.Optional[str], + custom_event_messages: t.Optional[CustomTaskEventMessages], request_params: RequestParameters, ) -> str: self._task_result = action(lambda message: None) @@ -92,15 +96,15 @@ def status_task( logs=None, ) - def list_tasks(self, task_filter: TaskListFilter, request_params: RequestParameters) -> List[TaskDTO]: + def list_tasks(self, task_filter: TaskListFilter, request_params: RequestParameters) -> t.List[TaskDTO]: return [] - def await_task(self, task_id: str, timeout_sec: Optional[int] = None) -> None: + def await_task(self, task_id: str, timeout_sec: t.Optional[int] = None) -> None: pass @pytest.fixture(name="bucket_dir", scope="session") -def bucket_dir_fixture(tmp_path_factory) -> Path: +def bucket_dir_fixture(tmp_path_factory: t.Any) -> Path: """ Fixture that creates a session-level temporary directory named "matrix_store" for storing matrices. @@ -114,7 +118,7 @@ def bucket_dir_fixture(tmp_path_factory) -> Path: Returns: A Path object representing the created temporary directory for storing matrices. """ - return tmp_path_factory.mktemp("matrix_store", numbered=False) + return t.cast(Path, tmp_path_factory.mktemp("matrix_store")) @pytest.fixture(name="simple_matrix_service", scope="session") @@ -128,7 +132,10 @@ def simple_matrix_service_fixture(bucket_dir: Path) -> SimpleMatrixService: Returns: An instance of the SimpleMatrixService class representing the matrix service. """ - return SimpleMatrixService(bucket_dir) + matrix_content_repository = MatrixContentRepository( + bucket_dir=bucket_dir, + ) + return SimpleMatrixService(matrix_content_repository=matrix_content_repository) @pytest.fixture(name="generator_matrix_constants", scope="session") @@ -144,7 +151,9 @@ def generator_matrix_constants_fixture( Returns: An instance of the GeneratorMatrixConstants class representing the matrix constants generator. 
""" - return GeneratorMatrixConstants(matrix_service=simple_matrix_service) + out_generator_matrix_constants = GeneratorMatrixConstants(simple_matrix_service) + out_generator_matrix_constants.init_constant_matrices() + return out_generator_matrix_constants @pytest.fixture(name="uri_resolver_service", scope="session") @@ -269,7 +278,7 @@ def event_bus_fixture() -> IEventBus: Returns: A Mock instance of the IEventBus class for event bus-related testing. """ - return Mock(spec=IEventBus) + return EventBusService(LocalEventBus()) @pytest.fixture(name="command_factory", scope="session") diff --git a/tests/core/test_config.py b/tests/core/test_config.py index 00c6f9458d..e69de29bb2 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -1,253 +0,0 @@ -from pathlib import Path -from unittest import mock - -import pytest - -from antarest.core.config import ( - Config, - InvalidConfigurationError, - LauncherConfig, - LocalConfig, - NbCoresConfig, - SlurmConfig, -) -from tests.core.assets import ASSETS_DIR - -LAUNCHER_CONFIG = { - "default": "slurm", - "local": { - "binaries": {"860": Path("/bin/solver-860.exe")}, - "enable_nb_cores_detection": False, - "nb_cores": {"min": 2, "default": 10, "max": 20}, - }, - "slurm": { - "local_workspace": Path("/home/john/antares/workspace"), - "username": "john", - "hostname": "slurm-001", - "port": 22, - "private_key_file": Path("/home/john/.ssh/id_rsa"), - "key_password": "password", - "password": "password", - "default_wait_time": 10, - "default_time_limit": 20, - "default_json_db_name": "antares.db", - "slurm_script_path": "/path/to/slurm/launcher.sh", - "max_cores": 32, - "antares_versions_on_remote_server": ["860"], - "enable_nb_cores_detection": False, - "nb_cores": {"min": 1, "default": 34, "max": 36}, - }, - "batch_size": 100, -} - - -class TestNbCoresConfig: - def test_init__default_values(self): - config = NbCoresConfig() - assert config.min == 1 - assert config.default == 22 - assert config.max == 24 - - def test_init__invalid_values(self): - with pytest.raises(ValueError): - # default < min - NbCoresConfig(min=2, default=1, max=24) - with pytest.raises(ValueError): - # default > max - NbCoresConfig(min=1, default=25, max=24) - with pytest.raises(ValueError): - # min < 0 - NbCoresConfig(min=0, default=22, max=23) - with pytest.raises(ValueError): - # min > max - NbCoresConfig(min=22, default=22, max=21) - - def test_to_json(self): - config = NbCoresConfig() - # ReactJs Material UI expects "min", "defaultValue" and "max" keys - assert config.to_json() == {"min": 1, "defaultValue": 22, "max": 24} - - -class TestLocalConfig: - def test_init__default_values(self): - config = LocalConfig() - assert config.binaries == {}, "binaries should be empty by default" - assert config.enable_nb_cores_detection, "nb cores auto-detection should be enabled by default" - assert config.nb_cores == NbCoresConfig() - - def test_from_dict(self): - config = LocalConfig.from_dict( - { - "binaries": {"860": Path("/bin/solver-860.exe")}, - "enable_nb_cores_detection": False, - "nb_cores": {"min": 2, "default": 10, "max": 20}, - } - ) - assert config.binaries == {"860": Path("/bin/solver-860.exe")} - assert not config.enable_nb_cores_detection - assert config.nb_cores == NbCoresConfig(min=2, default=10, max=20) - - def test_from_dict__auto_detect(self): - with mock.patch("multiprocessing.cpu_count", return_value=8): - config = LocalConfig.from_dict( - { - "binaries": {"860": Path("/bin/solver-860.exe")}, - "enable_nb_cores_detection": True, - } - ) - assert 
config.binaries == {"860": Path("/bin/solver-860.exe")} - assert config.enable_nb_cores_detection - assert config.nb_cores == NbCoresConfig(min=1, default=6, max=8) - - -class TestSlurmConfig: - def test_init__default_values(self): - config = SlurmConfig() - assert config.local_workspace == Path() - assert config.username == "" - assert config.hostname == "" - assert config.port == 0 - assert config.private_key_file == Path() - assert config.key_password == "" - assert config.password == "" - assert config.default_wait_time == 0 - assert config.default_time_limit == 0 - assert config.default_json_db_name == "" - assert config.slurm_script_path == "" - assert config.max_cores == 64 - assert config.antares_versions_on_remote_server == [], "solver versions should be empty by default" - assert not config.enable_nb_cores_detection, "nb cores auto-detection shouldn't be enabled by default" - assert config.nb_cores == NbCoresConfig() - - def test_from_dict(self): - config = SlurmConfig.from_dict( - { - "local_workspace": Path("/home/john/antares/workspace"), - "username": "john", - "hostname": "slurm-001", - "port": 22, - "private_key_file": Path("/home/john/.ssh/id_rsa"), - "key_password": "password", - "password": "password", - "default_wait_time": 10, - "default_time_limit": 20, - "default_json_db_name": "antares.db", - "slurm_script_path": "/path/to/slurm/launcher.sh", - "max_cores": 32, - "antares_versions_on_remote_server": ["860"], - "enable_nb_cores_detection": False, - "nb_cores": {"min": 2, "default": 10, "max": 20}, - } - ) - assert config.local_workspace == Path("/home/john/antares/workspace") - assert config.username == "john" - assert config.hostname == "slurm-001" - assert config.port == 22 - assert config.private_key_file == Path("/home/john/.ssh/id_rsa") - assert config.key_password == "password" - assert config.password == "password" - assert config.default_wait_time == 10 - assert config.default_time_limit == 20 - assert config.default_json_db_name == "antares.db" - assert config.slurm_script_path == "/path/to/slurm/launcher.sh" - assert config.max_cores == 32 - assert config.antares_versions_on_remote_server == ["860"] - assert not config.enable_nb_cores_detection - assert config.nb_cores == NbCoresConfig(min=2, default=10, max=20) - - def test_from_dict__default_n_cpu__backport(self): - config = SlurmConfig.from_dict( - { - "local_workspace": Path("/home/john/antares/workspace"), - "username": "john", - "hostname": "slurm-001", - "port": 22, - "private_key_file": Path("/home/john/.ssh/id_rsa"), - "key_password": "password", - "password": "password", - "default_wait_time": 10, - "default_time_limit": 20, - "default_json_db_name": "antares.db", - "slurm_script_path": "/path/to/slurm/launcher.sh", - "max_cores": 32, - "antares_versions_on_remote_server": ["860"], - "default_n_cpu": 15, - } - ) - assert config.nb_cores == NbCoresConfig(min=1, default=15, max=24) - - def test_from_dict__auto_detect(self): - with pytest.raises(NotImplementedError): - SlurmConfig.from_dict({"enable_nb_cores_detection": True}) - - -class TestLauncherConfig: - def test_init__default_values(self): - config = LauncherConfig() - assert config.default == "local", "default launcher should be local" - assert config.local is None - assert config.slurm is None - assert config.batch_size == 9999 - - def test_from_dict(self): - config = LauncherConfig.from_dict(LAUNCHER_CONFIG) - assert config.default == "slurm" - assert config.local == LocalConfig( - binaries={"860": Path("/bin/solver-860.exe")}, - 
enable_nb_cores_detection=False, - nb_cores=NbCoresConfig(min=2, default=10, max=20), - ) - assert config.slurm == SlurmConfig( - local_workspace=Path("/home/john/antares/workspace"), - username="john", - hostname="slurm-001", - port=22, - private_key_file=Path("/home/john/.ssh/id_rsa"), - key_password="password", - password="password", - default_wait_time=10, - default_time_limit=20, - default_json_db_name="antares.db", - slurm_script_path="/path/to/slurm/launcher.sh", - max_cores=32, - antares_versions_on_remote_server=["860"], - enable_nb_cores_detection=False, - nb_cores=NbCoresConfig(min=1, default=34, max=36), - ) - assert config.batch_size == 100 - - def test_init__invalid_launcher(self): - with pytest.raises(ValueError): - LauncherConfig(default="invalid_launcher") - - def test_get_nb_cores__default(self): - config = LauncherConfig.from_dict(LAUNCHER_CONFIG) - # default == "slurm" - assert config.get_nb_cores(launcher="default") == NbCoresConfig(min=1, default=34, max=36) - - def test_get_nb_cores__local(self): - config = LauncherConfig.from_dict(LAUNCHER_CONFIG) - assert config.get_nb_cores(launcher="local") == NbCoresConfig(min=2, default=10, max=20) - - def test_get_nb_cores__slurm(self): - config = LauncherConfig.from_dict(LAUNCHER_CONFIG) - assert config.get_nb_cores(launcher="slurm") == NbCoresConfig(min=1, default=34, max=36) - - def test_get_nb_cores__invalid_configuration(self): - config = LauncherConfig.from_dict(LAUNCHER_CONFIG) - with pytest.raises(InvalidConfigurationError): - config.get_nb_cores("invalid_launcher") - config = LauncherConfig.from_dict({}) - with pytest.raises(InvalidConfigurationError): - config.get_nb_cores("slurm") - - -class TestConfig: - @pytest.mark.parametrize("config_name", ["application-2.14.yaml", "application-2.15.yaml"]) - def test_from_yaml_file(self, config_name: str) -> None: - yaml_path = ASSETS_DIR.joinpath("config", config_name) - config = Config.from_yaml_file(yaml_path) - assert config.security.admin_pwd == "admin" - assert config.storage.workspaces["default"].path == Path("/home/john/antares_data/internal_studies") - assert not config.logging.json - assert config.logging.level == "INFO" diff --git a/tests/core/test_tasks.py b/tests/core/test_tasks.py index dfad126555..cc730bf0ea 100644 --- a/tests/core/test_tasks.py +++ b/tests/core/test_tasks.py @@ -1,56 +1,83 @@ +import dataclasses import datetime import time +import typing as t from pathlib import Path -from typing import Callable, List -from unittest.mock import ANY, Mock, call +from unittest.mock import ANY, Mock import pytest -from sqlalchemy import create_engine +from sqlalchemy import create_engine # type: ignore +from sqlalchemy.engine.base import Engine # type: ignore +from sqlalchemy.orm import Session, sessionmaker # type: ignore from antarest.core.config import Config, RemoteWorkerConfig, TaskConfig -from antarest.core.interfaces.eventbus import Event, EventType, IEventBus +from antarest.core.interfaces.eventbus import EventType, IEventBus from antarest.core.jwt import DEFAULT_ADMIN_USER from antarest.core.model import PermissionInfo, PublicMode from antarest.core.persistence import Base from antarest.core.requests import RequestParameters, UserHasNotPermissionError -from antarest.core.tasks.model import TaskDTO, TaskJob, TaskJobLog, TaskListFilter, TaskResult, TaskStatus, TaskType +from antarest.core.tasks.model import ( + TaskJob, + TaskJobLog, + TaskListFilter, + TaskResult, + TaskStatus, + TaskType, + cancel_orphan_tasks, +) from antarest.core.tasks.repository 
import TaskJobRepository from antarest.core.tasks.service import TaskJobService -from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware, db +from antarest.core.utils.fastapi_sqlalchemy import db from antarest.eventbus.business.local_eventbus import LocalEventBus from antarest.eventbus.service import EventBusService +from antarest.login.model import User +from antarest.study.model import RawStudy +from antarest.utils import SESSION_ARGS from antarest.worker.worker import AbstractWorker, WorkerTaskCommand from tests.helpers import with_db_context -def test_service() -> None: - # sourcery skip: aware-datetime-for-utc - engine = create_engine("sqlite:///:memory:", echo=False) +@pytest.fixture(name="db_engine", autouse=True) +def db_engine_fixture(tmp_path: Path) -> t.Generator[Engine, None, None]: + """ + Fixture that creates an SQLite database in a temporary directory. + + When a function runs in a different thread than the main one and needs to use + the database, it uses the global `db` object. This object helps create a new + local session in the thread to connect to the SQLite database. + However, we can't use an in-memory SQLite database ("sqlite:///:memory:") because + it creates a new empty database each time. That's why we use a SQLite database stored on disk. + + Yields: + An instance of the created SQLite database engine. + """ + db_path = tmp_path / "db.sqlite" + db_url = f"sqlite:///{db_path}" + engine = create_engine(db_url, echo=False) + engine.execute("PRAGMA foreign_keys = ON") Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) + yield engine + engine.dispose() - repo_mock = Mock(spec=TaskJobRepository) - creation_date = datetime.datetime.now(datetime.timezone.utc) - task = TaskJob(id="a", name="b", status=2, creation_date=creation_date) - repo_mock.list.return_value = [task] - repo_mock.get_or_raise.return_value = task - service = TaskJobService(config=Config(), repository=repo_mock, event_bus=Mock()) - repo_mock.save.assert_called_with( - TaskJob( - id="a", - name="b", - status=4, - creation_date=creation_date, - result_status=False, - result_msg="Task was interrupted due to server restart", - completion_date=ANY, - ) - ) + +@with_db_context +def test_service(core_config: Config, event_bus: IEventBus) -> None: + engine = db.session.bind + task_job_repo = TaskJobRepository() + + # Prepare a TaskJob in the database + creation_date = datetime.datetime.utcnow() + running_task = TaskJob(id="a", name="b", status=TaskStatus.RUNNING.value, creation_date=creation_date) + task_job_repo.save(running_task) + + # Create a TaskJobService + service = TaskJobService(config=core_config, repository=task_job_repo, event_bus=event_bus) + + # Cancel pending and running tasks + cancel_orphan_tasks(engine=engine, session_args=SESSION_ARGS) + + # Test Case: list tasks + # ===================== tasks = service.list_tasks( TaskListFilter(), @@ -60,52 +87,37 @@ def test_service() -> None: assert tasks[0].status == TaskStatus.FAILED assert tasks[0].creation_date_utc == str(creation_date) - start = datetime.datetime.now(datetime.timezone.utc) - end = start + datetime.timedelta(seconds=1) - repo_mock.reset_mock() - repo_mock.get.return_value = TaskJob( - id="a", - completion_date=end, - name="Unnamed", - owner_id=1, - status=TaskStatus.COMPLETED.value, - result_status=True, - result_msg="OK", - creation_date=start, - ) + # Test Case: get task status + 
# ========================== + res = service.status_task("a", RequestParameters(user=DEFAULT_ADMIN_USER)) assert res is not None - assert res == TaskDTO( - id="a", - completion_date_utc=str(end), - creation_date_utc=str(start), - owner=1, - name="Unnamed", - result=TaskResult(success=True, message="OK"), - status=TaskStatus.COMPLETED, - ) + expected = { + "completion_date_utc": ANY, + "creation_date_utc": creation_date.isoformat(" "), + "id": "a", + "logs": None, + "name": "b", + "owner": None, + "ref_id": None, + "result": { + "message": "Task was interrupted due to server restart", + "return_value": None, + "success": False, + }, + "status": TaskStatus.FAILED, + "type": None, + } + assert res.dict() == expected + + # Test Case: add a task that fails and wait for it + # ================================================ # noinspection PyUnusedLocal - def action_fail(update_msg: Callable[[str], None]) -> TaskResult: - raise NotImplementedError() - - def action_ok(update_msg: Callable[[str], None]) -> TaskResult: - update_msg("start") - update_msg("end") - return TaskResult(success=True, message="OK") + def action_fail(update_msg: t.Callable[[str], None]) -> TaskResult: + raise Exception("this action failed") - repo_mock.reset_mock() - now = datetime.datetime.utcnow() - task = TaskJob( - name="failed action", - owner_id=1, - id="a", - creation_date=now, - status=TaskStatus.PENDING.value, - ) - repo_mock.save.side_effect = lambda x: task - repo_mock.get_or_raise.return_value = task - service.add_task( + failed_id = service.add_task( action_fail, "failed action", None, @@ -113,79 +125,27 @@ def action_ok(update_msg: Callable[[str], None]) -> TaskResult: None, RequestParameters(user=DEFAULT_ADMIN_USER), ) - service.await_task("a") - repo_mock.save.assert_has_calls( - [ - call( - TaskJob( - id=None, - logs=[], - owner_id=1, - creation_date=None, - completion_date=None, - name="failed action", - status=None, - result_msg=None, - result_status=None, - ) - ), - call( - TaskJob( - id="a", - logs=[], - owner_id=1, - creation_date=now, - completion_date=ANY, - name="failed action", - status=4, - result_msg=ANY, # "Task a failed: Unhandled exception [...]" - result_status=False, - ) - ), - call( - TaskJob( - id="a", - logs=[], - owner_id=1, - creation_date=now, - completion_date=ANY, - name="failed action", - status=4, - result_msg=ANY, # "Task a failed: Unhandled exception [...]" - result_status=False, - ) - ), - ] + service.await_task(failed_id, timeout_sec=2) + + failed_task = task_job_repo.get(failed_id) + assert failed_task is not None + assert failed_task.status == TaskStatus.FAILED.value + assert failed_task.result_status is False + assert failed_task.result_msg == ( + f"Task {failed_id} failed: Unhandled exception this action failed" + f"\nSee the logs for detailed information and the error traceback." 
) + assert failed_task.completion_date is not None - repo_mock.reset_mock() - now = datetime.datetime.utcnow() - task = TaskJob( - name="Unnamed", - owner_id=1, - id="a", - creation_date=now, - status=TaskStatus.PENDING.value, - ) - repo_mock.save.side_effect = lambda x: task - repo_mock.get_or_raise.return_value = task - repo_mock.get.side_effect = [ - TaskJob( - name="Unnamed", - owner_id=1, - id="a", - creation_date=now, - status=TaskStatus.RUNNING.value, - ), - TaskJob( - name="Unnamed", - owner_id=1, - id="a", - creation_date=now, - status=TaskStatus.RUNNING.value, - ), - ] - service.add_task( + # Test Case: add a task that succeeds and wait for it + # =================================================== + + def action_ok(update_msg: t.Callable[[str], None]) -> TaskResult: + update_msg("start") + update_msg("end") + return TaskResult(success=True, message="OK") + + ok_id = service.add_task( action_ok, None, None, @@ -193,134 +153,46 @@ def action_ok(update_msg: Callable[[str], None]) -> TaskResult: None, request_params=RequestParameters(user=DEFAULT_ADMIN_USER), ) - service.await_task("a") - repo_mock.save.assert_has_calls( - [ - call(TaskJob(owner_id=1, name="Unnamed")), - # this is not called with that because the object is mutated, and mock seems to suck.. - # TaskJob( - # id="a", - # name="failed action", - # owner_id=1, - # status=TaskStatus.RUNNING.value, - # creation_date=now, - # ), - call( - TaskJob( - id="a", - completion_date=ANY, - name="Unnamed", - owner_id=1, - status=TaskStatus.COMPLETED.value, - result_status=True, - result_msg="OK", - creation_date=now, - ) - ), - call( - TaskJob( - name="Unnamed", - owner_id=1, - id="a", - creation_date=now, - status=TaskStatus.RUNNING.value, - logs=[TaskJobLog(message="start", task_id="a")], - ) - ), - call( - TaskJob( - name="Unnamed", - owner_id=1, - id="a", - creation_date=now, - status=TaskStatus.RUNNING.value, - logs=[TaskJobLog(message="end", task_id="a")], - ) - ), - call( - TaskJob( - id="a", - completion_date=ANY, - name="Unnamed", - owner_id=1, - status=TaskStatus.COMPLETED.value, - result_status=True, - result_msg="OK", - creation_date=now, - ) - ), - ] - ) + service.await_task(ok_id, timeout_sec=2) - repo_mock.get.reset_mock() - repo_mock.get.side_effect = [None] - service.await_task("elsewhere") - repo_mock.get.assert_called_with("elsewhere") + ok_task = task_job_repo.get(ok_id) + assert ok_task is not None + assert ok_task.status == TaskStatus.COMPLETED.value + assert ok_task.result_status is True + assert ok_task.result_msg == "OK" + assert ok_task.completion_date is not None + assert len(ok_task.logs) == 2 + assert ok_task.logs[0].message == "start" + assert ok_task.logs[1].message == "end" class DummyWorker(AbstractWorker): - def __init__(self, event_bus: IEventBus, accept: List[str], tmp_path: Path): + def __init__(self, event_bus: IEventBus, accept: t.List[str], tmp_path: Path): super().__init__("test", event_bus, accept) self.tmp_path = tmp_path def _execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: # simulate a "long" task ;-) time.sleep(0.01) - relative_path = task_info.task_args["file"] + relative_path = t.cast(str, task_info.task_args["file"]) (self.tmp_path / relative_path).touch() return TaskResult(success=True, message="") @with_db_context -def test_worker_tasks(tmp_path: Path): - repo_mock = Mock(spec=TaskJobRepository) - repo_mock.list.return_value = [] - event_bus = EventBusService(LocalEventBus()) - service = TaskJobService( - 
config=Config(tasks=TaskConfig(remote_workers=[RemoteWorkerConfig(name="test", queues=["test"])])), - repository=repo_mock, - event_bus=event_bus, - ) +def test_worker_tasks(tmp_path: Path, core_config: Config, event_bus: IEventBus) -> None: + # Create a TaskJobService + task_job_repo = TaskJobRepository() + task_config = TaskConfig(remote_workers=[RemoteWorkerConfig(name="test", queues=["test"])]) + config = dataclasses.replace(core_config, tasks=task_config) + service = TaskJobService(config=config, repository=task_job_repo, event_bus=event_bus) worker = DummyWorker(event_bus, ["test"], tmp_path) worker.start(threaded=True) file_to_create = "foo" - assert not (tmp_path / file_to_create).exists() - repo_mock.save.side_effect = [ - TaskJob( - id="taskid", - name="Unnamed", - owner_id=0, - type=TaskType.WORKER_TASK, - ref_id=None, - ), - TaskJob( - id="taskid", - name="Unnamed", - owner_id=0, - type=TaskType.WORKER_TASK, - ref_id=None, - status=TaskStatus.RUNNING, - ), - TaskJob( - id="taskid", - name="Unnamed", - owner_id=0, - type=TaskType.WORKER_TASK, - ref_id=None, - status=TaskStatus.COMPLETED, - ), - ] - repo_mock.get_or_raise.return_value = TaskJob( - id="taskid", - name="Unnamed", - owner_id=0, - type=TaskType.WORKER_TASK, - ref_id=None, - ) task_id = service.add_worker_task( TaskType.WORKER_TASK, "test", @@ -329,127 +201,192 @@ def test_worker_tasks(tmp_path: Path): None, request_params=RequestParameters(user=DEFAULT_ADMIN_USER), ) - service.await_task(task_id) + assert task_id is not None + service.await_task(task_id, timeout_sec=2) assert (tmp_path / file_to_create).exists() -def test_repository(): - # sourcery skip: aware-datetime-for-utc - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) +def test_repository(db_session: Session) -> None: + # Prepare two users in the database + user1_id = 9 + db_session.add(User(id=user1_id, name="John")) + user2_id = 10 + db_session.add(User(id=user2_id, name="Jane")) + db_session.commit() - with db(): - # sourcery skip: extract-method - task_repository = TaskJobRepository() - - new_task = TaskJob(name="foo", owner_id=0, type=TaskType.COPY) - second_task = TaskJob(owner_id=1, ref_id="a") - - now = datetime.datetime.utcnow() - new_task = task_repository.save(new_task) - assert task_repository.get(new_task.id) == new_task - assert new_task.status == TaskStatus.PENDING.value - assert new_task.owner_id == 0 - assert new_task.creation_date >= now - - second_task = task_repository.save(second_task) - - result = task_repository.list(TaskListFilter(type=[TaskType.COPY])) - assert len(result) == 1 - assert result[0].id == new_task.id - - result = task_repository.list(TaskListFilter(ref_id="a")) - assert len(result) == 1 - assert result[0].id == second_task.id - - result = task_repository.list(TaskListFilter(), user=1) - assert len(result) == 1 - assert result[0].id == second_task.id - - result = task_repository.list(TaskListFilter()) - assert len(result) == 2 - - result = task_repository.list(TaskListFilter(name="fo")) - assert len(result) == 1 - - result = task_repository.list(TaskListFilter(name="fo", status=[TaskStatus.RUNNING])) - assert len(result) == 0 - new_task.status = TaskStatus.RUNNING.value - task_repository.save(new_task) - result = task_repository.list(TaskListFilter(name="fo", status=[TaskStatus.RUNNING])) - assert len(result) == 1 - - 
new_task.completion_date = datetime.datetime.utcnow() - task_repository.save(new_task) - result = task_repository.list( - TaskListFilter( - name="fo", - from_completion_date_utc=(new_task.completion_date + datetime.timedelta(seconds=1)).timestamp(), - ) + # Create a RawStudy in the database + study_id = "e34fe4d5-5964-4ef2-9baf-fad66dadc512" + db_session.add(RawStudy(id="study_id", name="foo", version="860")) + db_session.commit() + + # Create a TaskJobService + task_job_repo = TaskJobRepository(db_session) + + new_task = TaskJob(name="foo", owner_id=user1_id, type=TaskType.COPY) + + now = datetime.datetime.utcnow() + new_task = task_job_repo.save(new_task) + assert task_job_repo.get(new_task.id) == new_task + assert new_task.status == TaskStatus.PENDING.value + assert new_task.owner_id == user1_id + assert new_task.creation_date >= now + + second_task = TaskJob(owner_id=user2_id, ref_id=study_id) + second_task = task_job_repo.save(second_task) + + result = task_job_repo.list(TaskListFilter(type=[TaskType.COPY])) + assert len(result) == 1 + assert result[0].id == new_task.id + + result = task_job_repo.list(TaskListFilter(ref_id=study_id)) + assert len(result) == 1 + assert result[0].id == second_task.id + + result = task_job_repo.list(TaskListFilter(), user=user2_id) + assert len(result) == 1 + assert result[0].id == second_task.id + + result = task_job_repo.list(TaskListFilter()) + assert len(result) == 2 + + result = task_job_repo.list(TaskListFilter(name="fo")) + assert len(result) == 1 + + result = task_job_repo.list(TaskListFilter(name="fo", status=[TaskStatus.RUNNING])) + assert len(result) == 0 + new_task.status = TaskStatus.RUNNING.value + task_job_repo.save(new_task) + result = task_job_repo.list(TaskListFilter(name="fo", status=[TaskStatus.RUNNING])) + assert len(result) == 1 + + new_task.completion_date = datetime.datetime.utcnow() + task_job_repo.save(new_task) + result = task_job_repo.list( + TaskListFilter( + name="fo", + from_completion_date_utc=(new_task.completion_date + datetime.timedelta(seconds=1)).timestamp(), ) - assert len(result) == 0 - result = task_repository.list( - TaskListFilter( - name="fo", - from_completion_date_utc=(new_task.completion_date - datetime.timedelta(seconds=1)).timestamp(), - ) + ) + assert len(result) == 0 + result = task_job_repo.list( + TaskListFilter( + name="fo", + from_completion_date_utc=(new_task.completion_date - datetime.timedelta(seconds=1)).timestamp(), ) - assert len(result) == 1 + ) + assert len(result) == 1 - new_task.logs.append(TaskJobLog(message="hello")) - new_task.logs.append(TaskJobLog(message="bar")) - task_repository.save(new_task) - new_task = task_repository.get(new_task.id) - assert len(new_task.logs) == 2 - assert new_task.logs[0].message == "hello" + new_task.logs.append(TaskJobLog(message="hello")) + new_task.logs.append(TaskJobLog(message="bar")) + task_job_repo.save(new_task) + assert new_task.id is not None + new_task = task_job_repo.get_or_raise(new_task.id) + assert len(new_task.logs) == 2 + assert new_task.logs[0].message == "hello" - assert len(db.session.query(TaskJobLog).where(TaskJobLog.task_id == new_task.id).all()) == 2 + assert len(db_session.query(TaskJobLog).where(TaskJobLog.task_id == new_task.id).all()) == 2 - task_repository.delete(new_task.id) - assert len(db.session.query(TaskJobLog).where(TaskJobLog.task_id == new_task.id).all()) == 0 - assert task_repository.get(new_task.id) is None + task_job_repo.delete(new_task.id) + assert len(db_session.query(TaskJobLog).where(TaskJobLog.task_id == 
new_task.id).all()) == 0 + assert task_job_repo.get(new_task.id) is None -def test_cancel(): - # sourcery skip: aware-datetime-for-utc - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) +@with_db_context +def test_cancel(core_config: Config, event_bus: IEventBus) -> None: + # Create a TaskJobService and add tasks + task_job_repo = TaskJobRepository() + task_job_repo.save(TaskJob(id="a")) + task_job_repo.save(TaskJob(id="b")) - repo_mock = Mock(spec=TaskJobRepository) - repo_mock.list.return_value = [] - service = TaskJobService(config=Config(), repository=repo_mock, event_bus=Mock()) + # Create a TaskJobService + service = TaskJobService(config=core_config, repository=task_job_repo, event_bus=event_bus) with pytest.raises(UserHasNotPermissionError): service.cancel_task("a", RequestParameters()) + # The event_bus fixture is actually a EventBusService with LocalEventBus backend + backend = t.cast(LocalEventBus, t.cast(EventBusService, event_bus).backend) + + # Test Case: cancel a task that is not in the service tasks map + # ============================================================= + + backend.clear_events() + service.cancel_task("b", RequestParameters(user=DEFAULT_ADMIN_USER), dispatch=True) - # noinspection PyUnresolvedReferences - service.event_bus.push.assert_called_with( - Event( - type=EventType.TASK_CANCEL_REQUEST, - payload="b", - permissions=PermissionInfo(public_mode=PublicMode.NONE), - ) - ) - creation_date = datetime.datetime.utcnow() - task = TaskJob(id="a", name="b", status=2, creation_date=creation_date) - repo_mock.list.return_value = [task] - repo_mock.get_or_raise.return_value = task - service.tasks["a"] = Mock() + collected_events = backend.get_events() + + assert len(collected_events) == 1 + assert collected_events[0].type == EventType.TASK_CANCEL_REQUEST + assert collected_events[0].payload == "b" + assert collected_events[0].permissions == PermissionInfo(public_mode=PublicMode.NONE) + + # Test Case: cancel a task that is in the service tasks map + # ========================================================= + + service.tasks["a"] = Mock(cancel=Mock(return_value=None)) + + backend.clear_events() + service.cancel_task("a", RequestParameters(user=DEFAULT_ADMIN_USER), dispatch=True) - task.status = TaskStatus.CANCELLED.value - repo_mock.save.assert_called_with(task) + + collected_events = backend.get_events() + assert len(collected_events) == 0, "No event should have been emitted because the task is in the service map" + task_a = task_job_repo.get("a") + assert task_a is not None + assert task_a.status == TaskStatus.CANCELLED.value + + +@pytest.mark.parametrize( + ("status", "result_status", "result_msg"), + [ + (TaskStatus.RUNNING.value, False, "task ongoing"), + (TaskStatus.PENDING.value, True, "task pending"), + (TaskStatus.FAILED.value, False, "task failed"), + (TaskStatus.COMPLETED.value, True, "task finished"), + (TaskStatus.TIMEOUT.value, False, "task timed out"), + (TaskStatus.CANCELLED.value, True, "task canceled"), + ], +) +def test_cancel_orphan_tasks( + db_engine: Engine, + status: int, + result_status: bool, + result_msg: str, +) -> None: + max_diff_seconds: int = 1 + test_id: str = "2ea94758-9ea5-4015-a45f-b245a6ffc147" + + completion_date: datetime.datetime = datetime.datetime.utcnow() + task_job = TaskJob( + id=test_id, + status=status, + 
result_status=result_status,
+        result_msg=result_msg,
+        completion_date=completion_date,
+    )
+    make_session = sessionmaker(bind=db_engine, **SESSION_ARGS)
+    with make_session() as session:
+        session.add(task_job)
+        session.commit()
+    cancel_orphan_tasks(engine=db_engine, session_args=SESSION_ARGS)
+    with make_session() as session:
+        if status in [TaskStatus.RUNNING.value, TaskStatus.PENDING.value]:
+            update_tasks_count = (
+                session.query(TaskJob)
+                .filter(TaskJob.status.in_([TaskStatus.RUNNING.value, TaskStatus.PENDING.value]))
+                .count()
+            )
+            assert not update_tasks_count
+            updated_task_job = session.query(TaskJob).get(test_id)
+            assert updated_task_job.status == TaskStatus.FAILED.value
+            assert not updated_task_job.result_status
+            assert updated_task_job.result_msg == "Task was interrupted due to server restart"
+            assert (datetime.datetime.utcnow() - updated_task_job.completion_date).seconds <= max_diff_seconds
+        else:
+            updated_task_job = session.query(TaskJob).get(test_id)
+            assert updated_task_job.status == status
+            assert updated_task_job.result_status == result_status
+            assert updated_task_job.result_msg == result_msg
+            assert (datetime.datetime.utcnow() - updated_task_job.completion_date).seconds <= max_diff_seconds
diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py
index 0938557239..1e9fd99caa 100644
--- a/tests/integration/test_integration.py
+++ b/tests/integration/test_integration.py
@@ -2192,6 +2192,45 @@ def test_binding_constraint_manager(client: TestClient, admin_access_token: str,
     assert res.status_code == 200
     assert constraints is None
 
+    # Creates a binding constraint with the new API
+    res = client.post(
+        f"/v1/studies/{variant_id}/bindingconstraints",
+        json={
+            "name": "binding_constraint_3",
+            "enabled": True,
+            "time_step": "hourly",
+            "operator": "less",
+            "coeffs": {},
+            "comments": "New API",
+        },
+        headers=admin_headers,
+    )
+    assert res.status_code == 200
+
+    # Asserts that creating 2 binding constraints with the same name raises an Exception
+    res = client.post(
+        f"/v1/studies/{variant_id}/bindingconstraints",
+        json={
+            "name": "binding_constraint_3",
+            "enabled": True,
+            "time_step": "hourly",
+            "operator": "less",
+            "coeffs": {},
+            "comments": "New API",
+        },
+        headers=admin_headers,
+    )
+    assert res.status_code == 409
+    assert res.json() == {
+        "description": "A binding constraint with the same name already exists: binding_constraint_3.",
+        "exception": "DuplicateConstraintName",
+    }
+
+    # Asserts that only 3 binding constraints have been created
+    res = client.get(f"/v1/studies/{variant_id}/bindingconstraints", headers=admin_headers)
+    assert res.status_code == 200
+    assert len(res.json()) == 3
+
 
 def test_import(client: TestClient, admin_access_token: str, study_id: str) -> None:
     admin_headers = {"Authorization": f"Bearer {admin_access_token}"}
diff --git a/tests/login/conftest.py b/tests/login/conftest.py
index 7b7935d6d3..61bebee728 100644
--- a/tests/login/conftest.py
+++ b/tests/login/conftest.py
@@ -13,19 +13,15 @@ def group_repo_fixture(db_middleware: DBSessionMiddleware) -> GroupRepository:
     """Fixture that creates a GroupRepository instance."""
     # note: `DBSessionMiddleware` is required to instantiate a thread-local db session.
- # important: the `GroupRepository` insert an admin group in the database if it does not exist: - # >>> Group(id="admin", name="admin") return GroupRepository() # noinspection PyUnusedLocal @pytest.fixture(name="user_repo") -def user_repo_fixture(core_config: Config, db_middleware: DBSessionMiddleware) -> UserRepository: +def user_repo_fixture(db_middleware: DBSessionMiddleware) -> UserRepository: """Fixture that creates a UserRepository instance.""" # note: `DBSessionMiddleware` is required to instantiate a thread-local db session. - # important: the `UserRepository` insert an admin user in the database if it does not exist. - # >>> User(id=1, name="admin", password=Password(config.security.admin_pwd)) - return UserRepository(config=core_config) + return UserRepository() # noinspection PyUnusedLocal @@ -49,8 +45,6 @@ def bot_repo_fixture(db_middleware: DBSessionMiddleware) -> BotRepository: def role_repo_fixture(db_middleware: DBSessionMiddleware) -> RoleRepository: """Fixture that creates a RoleRepository instance.""" # note: `DBSessionMiddleware` is required to instantiate a thread-local db session. - # important: the `RoleRepository` insert an admin role in the database if it does not exist. - # >>> Role(type=RoleType.ADMIN, identity=User(id=1), group=Group(id="admin")) return RoleRepository() diff --git a/tests/login/test_login_service.py b/tests/login/test_login_service.py index a56f256003..e48a54a918 100644 --- a/tests/login/test_login_service.py +++ b/tests/login/test_login_service.py @@ -1,13 +1,13 @@ import typing as t -from unittest.mock import Mock +from unittest.mock import patch import pytest -from fastapi import HTTPException from antarest.core.jwt import JWTGroup, JWTUser -from antarest.core.requests import RequestParameters, UserHasNotPermissionError +from antarest.core.requests import RequestParameters from antarest.core.roles import RoleType from antarest.login.model import ( + ADMIN_ID, Bot, BotCreateDTO, BotRoleCreateDTO, @@ -22,573 +22,654 @@ from antarest.login.service import LoginService from tests.helpers import with_db_context -SITE_ADMIN = RequestParameters( - user=JWTUser( - id=0, - impersonator=0, +# For the unit tests, we will define several fictitious users, groups and roles. 
+ +GroupObj = t.TypedDict("GroupObj", {"id": str, "name": str}) +UserObj = t.TypedDict("UserObj", {"id": int, "name": str}) +RoleObj = t.TypedDict("RoleObj", {"type": RoleType, "group_id": str, "identity_id": int}) + + +_GROUPS: t.List[GroupObj] = [ + {"id": "admin", "name": "X-Men"}, + {"id": "superman", "name": "Superman"}, + {"id": "metropolis", "name": "Metropolis"}, +] + + +_USERS: t.List[UserObj] = [ + # main characters + {"id": ADMIN_ID, "name": "Professor Xavier"}, # site admin + {"id": 2, "name": "Clark Kent"}, # admin of "Superman" group + {"id": 3, "name": "Lois Lane"}, # reader in "Superman" group + {"id": 4, "name": "Joh Fredersen"}, # "Metropolis" leader + {"id": 5, "name": "Freder Fredersen"}, # reader in "Metropolis" group + # secondary characters + {"id": 50, "name": "Storm"}, # evil man in "X-Men" group + {"id": 60, "name": "Livewire"}, # evil woman in "Superman" group + {"id": 70, "name": "Maria"}, # robot in "Metropolis" group + {"id": 80, "name": "Jane DOE"}, # external user +] + +_ROLES: t.List[RoleObj] = [ + {"type": RoleType.ADMIN, "group_id": "admin", "identity_id": ADMIN_ID}, + {"type": RoleType.ADMIN, "group_id": "superman", "identity_id": 2}, + {"type": RoleType.READER, "group_id": "superman", "identity_id": 3}, + {"type": RoleType.ADMIN, "group_id": "metropolis", "identity_id": 4}, + {"type": RoleType.READER, "group_id": "metropolis", "identity_id": 5}, +] + + +def get_jwt_user(user: User, roles: t.Iterable[Role], owner_id: int = 0) -> JWTUser: + jwt_user = JWTUser( + id=user.id, + impersonator=owner_id or user.id, type="users", - groups=[JWTGroup(id="admin", name="admin", role=RoleType.ADMIN)], + groups=[JWTGroup(id=role.group.id, name=role.group.name, role=role.type) for role in roles], ) -) + return jwt_user -GROUP_ADMIN = RequestParameters( - user=JWTUser( - id=1, - impersonator=1, - type="users", - groups=[JWTGroup(id="group", name="group", role=RoleType.ADMIN)], - ) -) - -USER3 = RequestParameters( - user=JWTUser( - id=3, - impersonator=3, - type="users", - groups=[JWTGroup(id="group", name="group", role=RoleType.READER)], - ) -) -BAD_PARAM = RequestParameters(user=None) +def get_request_param( + user: t.Union[User, UserLdap, Bot], + role: t.Optional[Role], + owner_id: int = 0, +) -> RequestParameters: + if user is None: + return RequestParameters(user=None) + roles = (role,) if role else () + jwt_user = get_jwt_user(user, roles, owner_id=owner_id) + return RequestParameters(user=jwt_user) -class TestLoginService: - """ - Test login service. +def get_user_param(login_service: LoginService, user_id: int, group_id: str = "(unknown)") -> RequestParameters: + user = login_service.users.get(user_id) or login_service.ldap.get(user_id) + assert user is not None + role = login_service.roles.get(user_id, group_id) + return get_request_param(user, role) - important: - - the `GroupRepository` insert an admin group in the database if it does not exist: - `Group(id="admin", name="admin")` +def get_bot_param(login_service: LoginService, bot_id: int, group_id: str = "(unknown)") -> RequestParameters: + bot = login_service.bots.get(bot_id) + assert bot is not None + role = login_service.roles.get(bot_id, group_id) + return get_request_param(bot, role, owner_id=bot.owner) - - the `UserRepository` insert an admin user in the database if it does not exist. - `User(id=1, name="admin", password=Password(config.security.admin_pwd))` - - the `RoleRepository` insert an admin role in the database if it does not exist. 
- `Role(type=RoleType.ADMIN, identity=User(id=1), group=Group(id="admin"))` +class TestLoginService: + """ + Test login service. """ + @pytest.fixture(name="populate_db", autouse=True) @with_db_context - @pytest.mark.parametrize( - "param, can_create", [(SITE_ADMIN, True), (GROUP_ADMIN, True), (USER3, False), (BAD_PARAM, False)] - ) - def test_save_group(self, login_service: LoginService, param: RequestParameters, can_create: bool) -> None: - group = Group(id="group", name="group") - - # Only site admin and group admin can update a group - if can_create: - actual = login_service.save_group(group, param) - assert actual == group - else: - with pytest.raises(UserHasNotPermissionError): - login_service.save_group(group, param) - actual = login_service.groups.get(group.id) - assert actual is None - - # Users can't create a duplicate group - with pytest.raises(HTTPException): - login_service.save_group(group, param) + def populate_db_fixture(self, login_service: LoginService) -> None: + for group in _GROUPS: + login_service.groups.save(Group(**group)) + main_characters = (u for u in _USERS if u["id"] < 10) + for user in main_characters: + login_service.users.save(User(**user)) + for role in _ROLES: + group = t.cast(Group, login_service.groups.get(role["group_id"])) + user = t.cast(User, login_service.users.get(role["identity_id"])) + role = Role(**role, group=group, identity=user) + login_service.roles.save(role) @with_db_context - @pytest.mark.parametrize( - "param, can_create", [(SITE_ADMIN, True), (GROUP_ADMIN, False), (USER3, False), (BAD_PARAM, False)] - ) - def test_create_user(self, login_service: LoginService, param: RequestParameters, can_create: bool) -> None: - create = UserCreateDTO(name="hello", password="world") - - # Only site admin can create a user - if can_create: - actual = login_service.create_user(create, param) - assert actual.name == create.name - else: - with pytest.raises(UserHasNotPermissionError): - login_service.create_user(create, param) - actual = login_service.users.get_by_name(create.name) - assert actual is None - - # Users can't create a duplicate user - with pytest.raises(HTTPException): - login_service.create_user(create, param) + def test_save_group(self, login_service: LoginService) -> None: + # site admin can update any group + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + login_service.save_group(Group(id="superman", name="Poor Men"), _param) + actual = login_service.groups.get("superman") + assert actual is not None + assert actual.name == "Poor Men" + + # Group admin can update his own group + _param = get_user_param(login_service, user_id=2, group_id="superman") + login_service.save_group(Group(id="superman", name="Man of Steel"), _param) + actual = login_service.groups.get("superman") + assert actual is not None + assert actual.name == "Man of Steel" + + # Another user of the same group cannot update the group + _param = get_user_param(login_service, user_id=3, group_id="superman") + with pytest.raises(Exception): + login_service.save_group(Group(id="superman", name="Woman of Steel"), _param) + actual = login_service.groups.get("superman") + assert actual is not None + assert actual.name == "Man of Steel" # not updated + + # Group admin cannot update another group + _param = get_user_param(login_service, user_id=2, group_id="superman") + with pytest.raises(Exception): + login_service.save_group(Group(id="metropolis", name="Man of Steel"), _param) + actual = login_service.groups.get("metropolis") + assert actual is not 
None + assert actual.name == "Metropolis" # not updated @with_db_context - @pytest.mark.parametrize( - "param, can_save", [(SITE_ADMIN, True), (GROUP_ADMIN, False), (USER3, False), (BAD_PARAM, False)] - ) - def test_save_user(self, login_service: LoginService, param: RequestParameters, can_save: bool) -> None: - create = UserCreateDTO(name="Laurent", password="S3cr3t") - user = login_service.create_user(create, SITE_ADMIN) - user.name = "Roland" + def test_create_user(self, login_service: LoginService) -> None: + # Site admin can create a user + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + login_service.create_user(UserCreateDTO(name="Laurent", password="S3cr3t"), _param) + actual = login_service.users.get_by_name("Laurent") + assert actual is not None + assert actual.name == "Laurent" + + # Group admin cannot create a user + _param = get_user_param(login_service, user_id=2, group_id="superman") + with pytest.raises(Exception): + login_service.create_user(UserCreateDTO(name="Alexandre", password="S3cr3t"), _param) + actual = login_service.users.get_by_name("Alexandre") + assert actual is None + + @with_db_context + def test_save_user(self, login_service: LoginService) -> None: + # Prepare a new user + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + user = login_service.create_user(UserCreateDTO(name="Laurentius", password="S3cr3t"), _param) # Only site admin can update a user - if can_save: - login_service.save_user(user, param) - actual = login_service.users.get_by_name(user.name) - assert actual == user - else: - with pytest.raises(UserHasNotPermissionError): - login_service.save_user(user, param) - actual = login_service.users.get_by_name(user.name) - assert actual != user + login_service.save_user(User(id=user.id, name="Lawrence"), _param) + actual = login_service.users.get(user.id) + assert actual is not None + assert actual.name == "Lawrence" + + # Group admin cannot update a user + _param = get_user_param(login_service, user_id=2, group_id="superman") + with pytest.raises(Exception): + login_service.save_user(User(id=user.id, name="Loran"), _param) + actual = login_service.users.get(user.id) + assert actual is not None + assert actual.name == "Lawrence" - @with_db_context - def test_save_user__themselves(self, login_service: LoginService) -> None: - user_create = UserCreateDTO(name="Laurent", password="S3cr3t") - user = login_service.create_user(user_create, SITE_ADMIN) - - # users can update themselves - param = RequestParameters( - user=JWTUser( - id=user.id, - impersonator=user.id, - type="users", - groups=[JWTGroup(id="group", name="group", role=RoleType.READER)], - ) - ) - user.name = "Roland" - actual = login_service.save_user(user, param) - assert actual == user + # A user can update himself + _param = get_user_param(login_service, user_id=user.id) + login_service.save_user(User(id=user.id, name="Loran"), _param) + actual = login_service.users.get(user.id) + assert actual is not None + assert actual.name == "Loran" @with_db_context def test_save_bot(self, login_service: LoginService) -> None: - # Prepare the user3 in the db - assert USER3.user is not None - user3 = User(id=USER3.user.id, name="Scoobydoo") - login_service.users.save(user3) - - # Prepare the user group and role - for jwt_group in USER3.user.groups: - group = Group(id=jwt_group.id, name=jwt_group.name) - login_service.groups.save(group) - role = Role(type=jwt_group.role, identity=user3, group=group) - login_service.roles.save(role) - - # Request 
parameters must reference a user
-        with pytest.raises(HTTPException):
-            login_service.save_bot(BotCreateDTO(name="bot", roles=[]), BAD_PARAM)
-
-        # The user USER3 is a reader in the group "group" and can crate a bot with the same role
-        assert all(jwt_group.role == RoleType.READER for jwt_group in USER3.user.groups)
-        bot_create = BotCreateDTO(name="bot", roles=[BotRoleCreateDTO(group="group", role=RoleType.READER.value)])
-        bot = login_service.save_bot(bot_create, USER3)
-
-        assert bot.name == bot_create.name
-        assert bot.owner == USER3.user.id
-        assert bot.is_author is True
-
-        # The user can't create a bot with an empty name
-        bot_create = BotCreateDTO(name="", roles=[BotRoleCreateDTO(group="group", role=RoleType.READER.value)])
-        with pytest.raises(HTTPException):
-            login_service.save_bot(bot_create, USER3)
-
-        # The user can't create a bot with a higher role than his own
-        for role_type in set(RoleType) - {RoleType.READER}:
-            bot_create = BotCreateDTO(name="bot", roles=[BotRoleCreateDTO(group="group", role=role_type.value)])
-            with pytest.raises(UserHasNotPermissionError):
-                login_service.save_bot(bot_create, USER3)
-
-        # The user can't create a bot that already exists
-        bot_create = BotCreateDTO(name="bot", roles=[BotRoleCreateDTO(group="group", role=RoleType.READER.value)])
-        with pytest.raises(HTTPException):
-            login_service.save_bot(bot_create, USER3)
+        # Joh Fredersen can create Maria because he is the leader of Metropolis
+        _param = get_user_param(login_service, user_id=4, group_id="metropolis")
+        login_service.save_bot(BotCreateDTO(name="Maria I", roles=[]), _param)
+        actual: t.Sequence[Bot] = login_service.bots.get_all_by_owner(4)
+        assert len(actual) == 1
+        assert actual[0].name == "Maria I"
+
+        # Freder Fredersen can create Maria with the reader role
+        _param = get_user_param(login_service, user_id=5, group_id="metropolis")
+        login_service.save_bot(
+            BotCreateDTO(
+                name="Maria II",
+                roles=[BotRoleCreateDTO(group="metropolis", role=RoleType.READER.value)],
+            ),
+            _param,
+        )
+        actual = login_service.bots.get_all_by_owner(5)
+        assert len(actual) == 1
+        assert actual[0].name == "Maria II"
+
+        # Freder Fredersen cannot create Maria with the admin role
+        _param = get_user_param(login_service, user_id=5, group_id="metropolis")
+        with pytest.raises(Exception):
+            login_service.save_bot(
+                BotCreateDTO(
+                    name="Maria III",
+                    roles=[BotRoleCreateDTO(group="metropolis", role=RoleType.ADMIN.value)],
+                ),
+                _param,
+            )
+        actual = login_service.bots.get_all_by_owner(5)
+        assert len(actual) == 1
+        assert actual[0].name == "Maria II"
+
+        # Freder Fredersen cannot create a bot with an empty name
+        _param = get_user_param(login_service, user_id=5, group_id="metropolis")
+        with pytest.raises(Exception):
+            login_service.save_bot(
+                BotCreateDTO(
+                    name="",
+                    roles=[BotRoleCreateDTO(group="metropolis", role=RoleType.ADMIN.value)],
+                ),
+                _param,
+            )
+        actual = login_service.bots.get_all_by_owner(5)
+        assert len(actual) == 1
+        assert actual[0].name == "Maria II"
+
+        # Freder Fredersen cannot create a bot that already exists
+        _param = get_user_param(login_service, user_id=5, group_id="metropolis")
+        with pytest.raises(Exception):
+            login_service.save_bot(
+                BotCreateDTO(
+                    name="Maria II",
+                    roles=[BotRoleCreateDTO(group="metropolis", role=RoleType.ADMIN.value)],
+                ),
+                _param,
+            )
+        actual = login_service.bots.get_all_by_owner(5)
+        assert len(actual) == 1
+        assert actual[0].name == "Maria II"
+
+        # Freder Fredersen cannot create a bot with an invalid group
+        _param = get_user_param(login_service, user_id=5, group_id="metropolis")
+        with pytest.raises(Exception):
+            login_service.save_bot(
+                BotCreateDTO(
+                    name="Maria III",
+                    roles=[BotRoleCreateDTO(group="metropolis2", role=RoleType.ADMIN.value)],
+                ),
+                _param,
+            )
+        actual = login_service.bots.get_all_by_owner(5)
+        assert len(actual) == 1
+        assert actual[0].name == "Maria II"
+
+        # Bot's name cannot be empty
+        _param = get_user_param(login_service, user_id=4, group_id="metropolis")
+        with pytest.raises(Exception):
+            login_service.save_bot(
+                BotCreateDTO(
+                    name="",
+                    roles=[BotRoleCreateDTO(group="metropolis", role=RoleType.ADMIN.value)],
+                ),
+                _param,
+            )
 
-        # The user can't create a bot with a group that does not exist
-        bot_create = BotCreateDTO(name="bot", roles=[BotRoleCreateDTO(group="unknown", role=RoleType.READER.value)])
-        with pytest.raises(HTTPException):
-            login_service.save_bot(bot_create, USER3)
+        # Avoid duplicate bots
+        _param = get_user_param(login_service, user_id=4, group_id="metropolis")
+        with pytest.raises(Exception):
+            login_service.save_bot(
+                BotCreateDTO(
+                    name="Maria I",
+                    roles=[BotRoleCreateDTO(group="metropolis", role=RoleType.ADMIN.value)],
+                ),
+                _param,
+            )
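All of the bot-creation cases above exercise one delegation rule: a user can only grant a bot a role that does not exceed the role the user himself holds in that group. Below is a minimal, self-contained sketch of that rule; the `RoleType` values mirror the ones used in these tests, but `check_bot_role` is an invented name for illustration, not the service's actual API.

```python
from enum import IntEnum
from typing import Mapping


class RoleType(IntEnum):
    READER = 10
    RUNNER = 20
    WRITER = 30
    ADMIN = 40


def check_bot_role(owner_roles: Mapping[str, RoleType], group_id: str, requested: RoleType) -> None:
    """Raise PermissionError if the requested bot role exceeds the owner's own role."""
    owner_role = owner_roles.get(group_id)
    if owner_role is None:
        # The owner has no role at all in this group (the "invalid group" case).
        raise PermissionError(f"no role in group {group_id!r}")
    if requested > owner_role:
        # The "higher role than his own" case.
        raise PermissionError(f"cannot delegate {requested.name} with only {owner_role.name}")


# Freder (READER in "metropolis") may delegate READER, but not ADMIN:
check_bot_role({"metropolis": RoleType.READER}, "metropolis", RoleType.READER)  # ok
try:
    check_bot_role({"metropolis": RoleType.READER}, "metropolis", RoleType.ADMIN)
except PermissionError as exc:
    print(exc)  # cannot delegate ADMIN with only READER
```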
     @with_db_context
-    @pytest.mark.parametrize(
-        "param, can_save", [(SITE_ADMIN, True), (GROUP_ADMIN, True), (USER3, False), (BAD_PARAM, False)]
-    )
-    def test_save_role(self, login_service: LoginService, param: RequestParameters, can_save: bool) -> None:
-        # Prepare the site admin in the db
-        assert SITE_ADMIN.user is not None
-        admin = User(id=SITE_ADMIN.user.id, name="Superman")
-        login_service.users.save(admin)
-
-        # Prepare the group "group" in the db
-        # noinspection SpellCheckingInspection
-        group = Group(id="group", name="Kryptonians")
-        login_service.groups.save(group)
-
-        # Only site admin and group admin can update a role
-        role = RoleCreationDTO(type=RoleType.ADMIN, identity_id=0, group_id="group")
-        if can_save:
-            actual = login_service.save_role(role, param)
-            assert actual.type == RoleType.ADMIN
-            assert actual.identity == admin
-        else:
-            with pytest.raises(UserHasNotPermissionError):
-                login_service.save_role(role, param)
-            actual = login_service.roles.get_all_by_group(group.id)
-            assert len(actual) == 0
+    def test_save_role(self, login_service: LoginService) -> None:
+        # Prepare a new group and two new users
+        _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin")
+        login_service.groups.save(Group(id="web", name="Spider Web"))
+        login_service.users.save(User(id=20, name="Spider-man"))
+        login_service.users.save(User(id=21, name="Spider-woman"))
+
+        # The site admin can create a role
+        login_service.save_role(
+            RoleCreationDTO(type=RoleType.ADMIN, group_id="web", identity_id=20),
+            _param,
+        )
+        actual = login_service.roles.get(20, "web")
+        assert actual is not None
+        assert actual.type == RoleType.ADMIN
+
+        # The group admin can create a role
+        _param = get_user_param(login_service, user_id=20, group_id="web")
+        login_service.save_role(
+            RoleCreationDTO(type=RoleType.WRITER, group_id="web", identity_id=21),
+            _param,
+        )
+        actual = login_service.roles.get(21, "web")
+        assert actual is not None
+        assert actual.type == RoleType.WRITER
+
+        # The group admin cannot create a role with an invalid group
+        _param = get_user_param(login_service, user_id=20, group_id="web")
+        with pytest.raises(Exception):
+            login_service.save_role(
+                RoleCreationDTO(type=RoleType.WRITER, group_id="web2", identity_id=21),
+                _param,
+            )
+        actual = login_service.roles.get(21, "web")
+        assert actual is not 
None + assert actual.type == RoleType.WRITER + + # The user cannot create a role + _param = get_user_param(login_service, user_id=21, group_id="web") + with pytest.raises(Exception): + login_service.save_role( + RoleCreationDTO(type=RoleType.READER, group_id="web", identity_id=20), + _param, + ) + actual = login_service.roles.get(20, "web") + assert actual is not None + assert actual.type == RoleType.ADMIN @with_db_context - @pytest.mark.parametrize( - "param, can_get", [(SITE_ADMIN, True), (GROUP_ADMIN, True), (USER3, True), (BAD_PARAM, False)] - ) - def test_get_group(self, login_service: LoginService, param: RequestParameters, can_get: bool) -> None: - # Prepare the group "group" in the db - # noinspection SpellCheckingInspection - group = Group(id="group", name="Vulcans") - login_service.groups.save(group) - - # Anybody except anonymous can get a group - if can_get: - actual = login_service.get_group("group", param) - assert actual == group - else: - with pytest.raises(UserHasNotPermissionError): - login_service.get_group(group.id, param) - - # noinspection SpellCheckingInspection + def test_get_group(self, login_service: LoginService) -> None: + # Site admin can get any group + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + actual = login_service.get_group("superman", _param) + assert actual is not None + assert actual.name == "Superman" + + # Group admin can get his own group + _param = get_user_param(login_service, user_id=2, group_id="superman") + actual = login_service.get_group("superman", _param) + assert actual is not None + assert actual.name == "Superman" + + # Group admin cannot get another group + _param = get_user_param(login_service, user_id=2, group_id="superman") + with pytest.raises(Exception): + login_service.get_group("metropolis", _param) + + # Lois Lane can get its own group + _param = get_user_param(login_service, user_id=3, group_id="superman") + actual = login_service.get_group("superman", _param) + assert actual is not None + assert actual.id == "superman" + @with_db_context - @pytest.mark.parametrize( - "param, expected", - [ - ( - SITE_ADMIN, - { - "id": "group", - "name": "Vulcans", - "users": [ - {"id": 3, "name": "Spock", "role": RoleType.RUNNER}, - {"id": 4, "name": "Saavik", "role": RoleType.RUNNER}, - ], - }, - ), - ( - GROUP_ADMIN, - { - "id": "group", - "name": "Vulcans", - "users": [ - {"id": 3, "name": "Spock", "role": RoleType.RUNNER}, - {"id": 4, "name": "Saavik", "role": RoleType.RUNNER}, - ], - }, - ), - (USER3, {}), - (BAD_PARAM, {}), - ], - ) - def test_get_group_info( - self, - login_service: LoginService, - param: RequestParameters, - expected: t.Mapping[str, t.Any], - ) -> None: - # Prepare the group "group" in the db - # noinspection SpellCheckingInspection - group = Group(id="group", name="Vulcans") - login_service.groups.save(group) - - # Prepare the user3 in the db - assert USER3.user is not None - user3 = User(id=USER3.user.id, name="Spock") - login_service.users.save(user3) - - # Prepare an LDAP user named "Jane" with id=4 - user4 = UserLdap(id=4, name="Saavik") - login_service.users.save(user4) - - # Spock and Saavik are vulcans and can run simulations - role = Role(type=RoleType.RUNNER, identity=user3, group=group) - login_service.roles.save(role) - role = Role(type=RoleType.RUNNER, identity=user4, group=group) - login_service.roles.save(role) - - # Only site admin and group admin can get a group info - if expected: - actual = login_service.get_group_info("group", param) - assert actual.dict() == 
expected - else: - with pytest.raises(UserHasNotPermissionError): - login_service.get_group_info(group.id, param) + def test_get_group_info(self, login_service: LoginService) -> None: + # Site admin can get any group + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + actual = login_service.get_group_info("superman", _param) + assert actual is not None + assert actual.name == "Superman" + assert [obj.dict() for obj in actual.users] == [ + {"id": 2, "name": "Clark Kent", "role": RoleType.ADMIN}, + {"id": 3, "name": "Lois Lane", "role": RoleType.READER}, + ] + + # Group admin can get his own group + _param = get_user_param(login_service, user_id=2, group_id="superman") + actual = login_service.get_group_info("superman", _param) + assert actual is not None + assert actual.name == "Superman" + + # Group admin cannot get another group + _param = get_user_param(login_service, user_id=2, group_id="superman") + with pytest.raises(Exception): + login_service.get_group_info("metropolis", _param) + + # Lois Lane cannot get its own group + _param = get_user_param(login_service, user_id=3, group_id="superman") + with pytest.raises(Exception): + login_service.get_group_info("superman", _param) @with_db_context - @pytest.mark.parametrize( - "param, can_get", [(SITE_ADMIN, True), (GROUP_ADMIN, True), (USER3, True), (BAD_PARAM, False)] - ) - def test_get_user(self, login_service: LoginService, param: RequestParameters, can_get: bool) -> None: - # Prepare a group of readers - group = Group(id="group", name="readers") - login_service.groups.save(group) - - # The user3 is a reader in the group "group" - user3 = User(id=USER3.user.id, name="Batman") - login_service.users.save(user3) - role = Role(type=RoleType.READER, identity=user3, group=group) - login_service.roles.save(role) - - # Anybody except anonymous can get the user3 - if can_get: - actual = login_service.get_user(user3.id, param) - assert actual == user3 - else: - # This function doesn't raise an exception if the user does not exist - actual = login_service.get_user(user3.id, param) - assert actual is None + def test_get_user(self, login_service: LoginService) -> None: + # Site admin can get any user + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + actual = login_service.get_user(2, _param) + assert actual is not None + assert actual.name == "Clark Kent" + + # Group admin can get a user of his own group + _param = get_user_param(login_service, user_id=2, group_id="superman") + actual = login_service.get_user(3, _param) + assert actual is not None + assert actual.name == "Lois Lane" + + # Group admin cannot get a user of another group + _param = get_user_param(login_service, user_id=2, group_id="superman") + actual = login_service.get_user(5, _param) + assert actual is None + + # Lois Lane can get its own user + _param = get_user_param(login_service, user_id=3, group_id="superman") + actual = login_service.get_user(3, _param) + assert actual is not None + assert actual.name == "Lois Lane" + + # Create a bot for Lois Lane + _param = get_user_param(login_service, user_id=3, group_id="superman") + bot = login_service.save_bot(BotCreateDTO(name="Lois bot", roles=[]), _param) + + # The bot can get its owner + _param = get_bot_param(login_service, bot_id=bot.id) + actual = login_service.get_user(3, _param) + assert actual is not None + assert actual.name == "Lois Lane" @with_db_context def test_get_identity(self, login_service: LoginService) -> None: - # important: id=1 is the admin user - user = 
login_service.users.save(User(id=2, name="John")) - user_ldap = login_service.users.save(UserLdap(id=3, name="Jane")) - bot = login_service.users.save(Bot(id=4, name="my-app", owner=3, is_author=False)) + # Create the admin user "Storm" + storm = login_service.users.save(User(id=50, name="Storm")) + # Create the LDAP user "Jane DOE" + jane = login_service.users.save(UserLdap(id=60, name="Jane DOE")) + # Create the bot "Maria" + maria = login_service.users.save(Bot(id=70, name="Maria", owner=50, is_author=False)) - assert login_service.get_identity(2, include_token=False) == user - assert login_service.get_identity(3, include_token=False) == user_ldap - assert login_service.get_identity(4, include_token=False) is None + assert login_service.get_identity(50, include_token=False) == storm + assert login_service.get_identity(60, include_token=False) == jane + assert login_service.get_identity(70, include_token=False) is None - assert login_service.get_identity(2, include_token=True) == user - assert login_service.get_identity(3, include_token=True) == user_ldap - assert login_service.get_identity(4, include_token=True) == bot + assert login_service.get_identity(50, include_token=True) == storm + assert login_service.get_identity(60, include_token=True) == jane + assert login_service.get_identity(70, include_token=True) == maria @with_db_context - @pytest.mark.parametrize( - "param, expected", - [ - ( - SITE_ADMIN, + def test_get_user_info(self, login_service: LoginService) -> None: + # Site admin can get any user + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + clark_id = 2 + actual = login_service.get_user_info(clark_id, _param) + assert actual is not None + assert actual.dict() == { + "id": clark_id, + "name": "Clark Kent", + "roles": [ { - "id": 3, - "name": "Batman", - "roles": [ - { - "group_id": "group", - "group_name": "readers", - "identity_id": 3, - "type": RoleType.READER, - } - ], - }, - ), - ( - GROUP_ADMIN, + "group_id": "superman", + "group_name": "Superman", + "identity_id": clark_id, + "type": RoleType.ADMIN, + } + ], + } + + # Group admin can get a user of his own group + _param = get_user_param(login_service, user_id=clark_id, group_id="superman") + lois_id = 3 + actual = login_service.get_user_info(lois_id, _param) + assert actual is not None + assert actual.dict() == { + "id": lois_id, + "name": "Lois Lane", + "roles": [ { - "id": 3, - "name": "Batman", - "roles": [ - { - "group_id": "group", - "group_name": "readers", - "identity_id": 3, - "type": RoleType.READER, - } - ], - }, - ), - ( - USER3, + "group_id": "superman", + "group_name": "Superman", + "identity_id": lois_id, + "type": RoleType.READER, + } + ], + } + + # Group admin cannot get a user of another group + _param = get_user_param(login_service, user_id=clark_id, group_id="superman") + freder_id = 5 + actual = login_service.get_user_info(freder_id, _param) + assert actual is None + + # Lois Lane can get its own user info + _param = get_user_param(login_service, user_id=lois_id, group_id="superman") + actual = login_service.get_user_info(lois_id, _param) + assert actual is not None + assert actual.dict() == { + "id": lois_id, + "name": "Lois Lane", + "roles": [ { - "id": 3, - "name": "Batman", - "roles": [ - { - "group_id": "group", - "group_name": "readers", - "identity_id": 3, - "type": RoleType.READER, - } - ], - }, - ), - (BAD_PARAM, {}), - ], - ) - def test_get_user_info( - self, - login_service: LoginService, - param: RequestParameters, - expected: t.Mapping[str, t.Any], - ) -> 
None: - # Prepare a group of readers - group = Group(id="group", name="readers") - login_service.groups.save(group) - - # The user3 is a reader in the group "group" - user3 = User(id=USER3.user.id, name="Batman") - login_service.users.save(user3) - role = Role(type=RoleType.READER, identity=user3, group=group) - login_service.roles.save(role) - - # Anybody except anonymous can get the user3 - if expected: - actual = login_service.get_user_info(user3.id, param) - assert actual.dict() == expected - else: - # This function doesn't raise an exception if the user does not exist - actual = login_service.get_user_info(user3.id, param) - assert actual is None + "group_id": "superman", + "group_name": "Superman", + "identity_id": lois_id, + "type": RoleType.READER, + } + ], + } + + # Create a bot for Lois Lane + _param = get_user_param(login_service, user_id=lois_id, group_id="superman") + bot = login_service.save_bot(BotCreateDTO(name="Lois bot", roles=[]), _param) + + # The bot can get its owner + _param = get_bot_param(login_service, bot_id=bot.id) + actual = login_service.get_user_info(lois_id, _param) + assert actual is not None + assert actual.dict() == { + "id": lois_id, + "name": "Lois Lane", + "roles": [ + { + "group_id": "superman", + "group_name": "Superman", + "identity_id": lois_id, + "type": RoleType.READER, + } + ], + } @with_db_context - @pytest.mark.parametrize( - "param, can_get", [(SITE_ADMIN, True), (GROUP_ADMIN, False), (USER3, True), (BAD_PARAM, False)] - ) - def test_get_bot(self, login_service: LoginService, param: RequestParameters, can_get: bool) -> None: - # Prepare a user in the db, with id=4 - clark = User(id=3, name="Clark") - login_service.users.save(clark) - - # Prepare a bot in the db (using the ID or user3) - bot = Bot(id=4, name="Maria", owner=clark.id, is_author=True) - login_service.users.save(bot) - - # Only site admin and the owner can get a bot - if can_get: - actual = login_service.get_bot(bot.id, param) - assert actual == bot - else: - with pytest.raises(UserHasNotPermissionError): - login_service.get_bot(bot.id, param) + def test_get_bot(self, login_service: LoginService) -> None: + # Create a bot for Joh Fredersen + joh_id = 4 + _param = get_user_param(login_service, user_id=joh_id, group_id="metropolis") + joh_bot = login_service.save_bot(BotCreateDTO(name="Maria", roles=[]), _param) + + # The site admin can get any bot + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + actual = login_service.get_bot(joh_bot.id, _param) + assert actual is not None + assert actual.name == "Maria" + + # Joh Fredersen can get its own bot + _param = get_user_param(login_service, user_id=joh_id, group_id="superman") + actual = login_service.get_bot(joh_bot.id, _param) + assert actual is not None + assert actual.name == "Maria" + + # The bot cannot get itself + _param = get_bot_param(login_service, bot_id=joh_bot.id) + with pytest.raises(Exception): + login_service.get_bot(joh_bot.id, _param) + + # Freder Fredersen cannot get the bot + freder_id = 5 + _param = get_user_param(login_service, user_id=freder_id, group_id="superman") + with pytest.raises(Exception): + login_service.get_bot(joh_bot.id, _param) @with_db_context - @pytest.mark.parametrize( - "param, expected", - [ - ( - SITE_ADMIN, - { - "id": 4, - "isAuthor": True, - "name": "Maria", - "roles": [ - { - "group_id": "Metropolis", - "group_name": "watchers", - "identity_id": 4, - "type": RoleType.READER, - } - ], - }, - ), - (GROUP_ADMIN, {}), - ( - USER3, - { - "id": 4, - "isAuthor": True, - 
"name": "Maria", - "roles": [ - { - "group_id": "Metropolis", - "group_name": "watchers", - "identity_id": 4, - "type": RoleType.READER, - } - ], - }, - ), - (BAD_PARAM, {}), - ], - ) - def test_get_bot_info( - self, - login_service: LoginService, - param: RequestParameters, - expected: t.Mapping[str, t.Any], - ) -> None: - # Prepare a user in the db, with id=4 - clark = User(id=3, name="Clark") - login_service.users.save(clark) - - # Prepare a bot in the db (using the ID or user3) - bot = Bot(id=4, name="Maria", owner=clark.id, is_author=True) - login_service.users.save(bot) - - # Prepare a group of readers - group = Group(id="Metropolis", name="watchers") - login_service.groups.save(group) - - # The user3 is a reader in the group "group" - role = Role(type=RoleType.READER, identity=bot, group=group) - login_service.roles.save(role) - - # Only site admin and the owner can get a bot - if expected: - actual = login_service.get_bot_info(bot.id, param) - assert actual is not None - assert actual.dict() == expected - else: - with pytest.raises(UserHasNotPermissionError): - login_service.get_bot_info(bot.id, param) + def test_get_bot_info(self, login_service: LoginService) -> None: + # Create a bot for Joh Fredersen + joh_id = 4 + _param = get_user_param(login_service, user_id=joh_id, group_id="superman") + joh_bot = login_service.save_bot(BotCreateDTO(name="Maria", roles=[]), _param) + + # The site admin can get any bot + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + actual = login_service.get_bot_info(joh_bot.id, _param) + assert actual is not None + assert actual.dict() == {"id": 6, "isAuthor": True, "name": "Maria", "roles": []} + + # Joh Fredersen can get its own bot + _param = get_user_param(login_service, user_id=joh_id, group_id="superman") + actual = login_service.get_bot_info(joh_bot.id, _param) + assert actual is not None + assert actual.dict() == {"id": 6, "isAuthor": True, "name": "Maria", "roles": []} + + # The bot cannot get itself + _param = get_bot_param(login_service, bot_id=joh_bot.id) + with pytest.raises(Exception): + login_service.get_bot_info(joh_bot.id, _param) + + # Freder Fredersen cannot get the bot + freder_id = 5 + _param = get_user_param(login_service, user_id=freder_id, group_id="superman") + with pytest.raises(Exception): + login_service.get_bot_info(joh_bot.id, _param) + + # Freder Fredersen cannot get the bot + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + with pytest.raises(Exception): + login_service.get_bot_info(999, _param) @with_db_context - @pytest.mark.parametrize("param, expected", [(SITE_ADMIN, [5]), (GROUP_ADMIN, []), (USER3, [5]), (BAD_PARAM, [])]) - def test_get_all_bots_by_owner( - self, - login_service: LoginService, - param: RequestParameters, - expected: t.Mapping[str, t.Any], - ) -> None: - # add a user, an LDAP user and a bot in the db - user = User(id=3, name="John") - login_service.users.save(user) - user_ldap = UserLdap(id=4, name="Jane") - login_service.users.save(user_ldap) - bot = Bot(id=5, name="my-app", owner=3, is_author=False) - login_service.users.save(bot) - - if expected: - actual = login_service.get_all_bots_by_owner(3, param) - assert [b.id for b in actual] == expected - else: - with pytest.raises(UserHasNotPermissionError): - login_service.get_all_bots_by_owner(3, param) + def test_get_all_bots_by_owner(self, login_service: LoginService) -> None: + # Create a bot for Joh Fredersen + joh_id = 4 + _param = get_user_param(login_service, user_id=joh_id, 
     @with_db_context
     def test_exists_bot(self, login_service: LoginService) -> None:
-        # Prepare the user3 in the db
-        assert USER3.user is not None
-        user3 = User(id=USER3.user.id, name="Clark")
-        login_service.users.save(user3)
-
-        # Prepare a bot in the db (using the ID or user3)
-        bot = Bot(id=4, name="Maria", owner=user3.id, is_author=True)
-        login_service.users.save(bot)
+        # Create a bot for Joh Fredersen
+        joh_id = 4
+        _param = get_user_param(login_service, user_id=joh_id, group_id="superman")
+        joh_bot = login_service.save_bot(BotCreateDTO(name="Maria", roles=[]), _param)
 
         # Everybody can check the existence of a bot
-        assert login_service.exists_bot(4)
-        assert not login_service.exists_bot(5)  # unknown ID
-        assert not login_service.exists_bot(3)  # user ID, not bot ID
+        assert login_service.exists_bot(joh_id) is False, "not a bot"
+        assert login_service.exists_bot(joh_bot.id) is True
+        assert login_service.exists_bot(999) is False
 
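The authentication tests below rely on the `Password` value object: it stores only a derived hash and exposes a `check()` method. Here is a minimal stand-in built on `hashlib` and `secrets`; the real class likely uses a different, stronger scheme, so treat this purely as an illustration of the interface.

```python
import hashlib
import secrets


class Password:
    """Store a salted digest of the clear text, never the clear text itself."""

    def __init__(self, clear_text: str) -> None:
        self._salt = secrets.token_bytes(16)
        self._digest = hashlib.sha256(self._salt + clear_text.encode("utf-8")).digest()

    def check(self, clear_text: str) -> bool:
        candidate = hashlib.sha256(self._salt + clear_text.encode("utf-8")).digest()
        # Constant-time comparison to avoid leaking timing information.
        return secrets.compare_digest(candidate, self._digest)


assert Password("S3cr3t").check("S3cr3t")
assert not Password("S3cr3t").check("wrong")
```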
     @with_db_context
-    def test_authenticate__unknown_user(self, login_service: LoginService) -> None:
-        # An unknown user cannot log in
-        user = login_service.authenticate(name="unknown", pwd="S3cr3t")
-        assert user is None
-
-    @with_db_context
-    def test_authenticate__known_user(self, login_service: LoginService) -> None:
-        # Create a user named "Tarzan" in the group "Adventure"
-        group = Group(id="adventure", name="Adventure")
-        login_service.groups.save(group)
-        user = User(id=3, name="Tarzan", password=Password("S3cr3t"))
-        login_service.users.save(user)
-        role = Role(type=RoleType.READER, identity=user, group=group)
-        login_service.roles.save(role)
+    def test_authenticate(self, login_service: LoginService) -> None:
+        # Update the password of "Lois Lane"
+        lois_id = 3
+        login_service.users.save(User(id=lois_id, name="Lois Lane", password=Password("S3cr3t")))
 
         # A known user can log in
-        jwt_user = login_service.authenticate(name="Tarzan", pwd="S3cr3t")
+        jwt_user = login_service.authenticate(name="Lois Lane", pwd="S3cr3t")
         assert jwt_user is not None
-        assert jwt_user.id == user.id
-        assert jwt_user.impersonator == user.id
+        assert jwt_user.id == lois_id
+        assert jwt_user.impersonator == lois_id
         assert jwt_user.type == "users"
-        assert jwt_user.groups == [JWTGroup(id="adventure", name="Adventure", role=RoleType.READER)]
+        assert jwt_user.groups == [
+            JWTGroup(id="superman", name="Superman", role=RoleType.READER),
+        ]
 
-    @with_db_context
-    def test_authenticate__external_user(self, login_service: LoginService) -> None:
-        # Create a user named "Tarzan"
-        user_ldap = UserLdap(id=3, name="Tarzan", external_id="tarzan", firstname="Tarzan", lastname="Jungle")
+        # An unknown user cannot log in
+        user = login_service.authenticate(name="unknown", pwd="S3cr3t")
+        assert user is None
+
+        # Update the user "Jane DOE" which is an LDAP user
+        jane_id = 60
+        user_ldap = UserLdap(
+            id=jane_id,
+            name="Jane DOE",
+            external_id="j.doe",
+            firstname="Jane",
+            lastname="DOE",
+        )
         login_service.users.save(user_ldap)
 
         # Mock the LDAP service
-        login_service.ldap.login = Mock(return_value=user_ldap)  # type: ignore
-        login_service.ldap.get = Mock(return_value=user_ldap)  # type: ignore
+        with patch("antarest.login.ldap.LdapService.login") as mock_login:
+            mock_login.return_value = user_ldap
+            with patch("antarest.login.ldap.LdapService.get") as mock_get:
+                mock_get.return_value = user_ldap
+                jwt_user = login_service.authenticate(name="Jane DOE", pwd="S3cr3t")
 
-        # A known user can log in
-        jwt_user = login_service.authenticate(name="Tarzan", pwd="S3cr3t")
         assert jwt_user is not None
         assert jwt_user.id == user_ldap.id
         assert jwt_user.impersonator == user_ldap.id
@@ -597,32 +678,30 @@ def test_authenticate__external_user(self, login_service: LoginService) -> None:
 
     @with_db_context
     def test_get_jwt(self, login_service: LoginService) -> None:
-        # Prepare the user3 in the db
-        assert USER3.user is not None
-        user3 = User(id=USER3.user.id, name="Clark")
-        login_service.users.save(user3)
-
-        # Attach a group to the user
-        group = Group(id="group", name="readers")
-        login_service.groups.save(group)
-        role = Role(type=RoleType.READER, identity=user3, group=group)
-        login_service.roles.save(role)
-
-        # Prepare an LDAP user in the db
-        user_ldap = UserLdap(id=4, name="Jane")
+        # Create a bot for Joh Fredersen
+        joh_id = 4
+        _param = get_user_param(login_service, user_id=joh_id, group_id="superman")
+        joh_bot = login_service.save_bot(BotCreateDTO(name="Maria", roles=[]), _param)
+
+        # Update the user "Jane DOE" which is an LDAP user
+        jane_id = 60
+        user_ldap = UserLdap(
+            id=jane_id,
+            name="Jane DOE",
+            external_id="j.doe",
+            firstname="Jane",
+            lastname="DOE",
+        )
         login_service.users.save(user_ldap)
 
-        # Prepare a bot in the db (using the ID or user3)
-        bot = Bot(id=5, name="Maria", owner=user3.id, is_author=True)
-        login_service.users.save(bot)
-
         # We can get a JWT for a user, an LDAP user, but not a bot
-        jwt_user = login_service.get_jwt(user3.id)
+        lois_id = 3
+        jwt_user = login_service.get_jwt(lois_id)
         assert jwt_user is not None
-        assert jwt_user.id == user3.id
-        assert jwt_user.impersonator == user3.id
+        assert jwt_user.id == lois_id
+        assert jwt_user.impersonator == lois_id
         assert jwt_user.type == "users"
-        assert jwt_user.groups == [JWTGroup(id="group", name="readers", role=RoleType.READER)]
+        assert jwt_user.groups == [JWTGroup(id="superman", name="Superman", role=RoleType.READER)]
 
         jwt_user = login_service.get_jwt(user_ldap.id)
         assert jwt_user is not None
@@ -631,367 +710,306 @@ def test_get_jwt(self, login_service: LoginService) -> None:
         assert jwt_user.type == "users_ldap"
         assert jwt_user.groups == []
 
-        jwt_user = login_service.get_jwt(bot.id)
+        jwt_user = login_service.get_jwt(joh_bot.id)
         assert jwt_user is None
 
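The group-listing test below asserts a simple visibility rule: the site admin sees every group, while any other user only sees the groups in which they hold a role. A compact sketch of that rule; `visible_groups` and `SITE_ADMIN_GROUP` are names invented for the example:

```python
from typing import List

SITE_ADMIN_GROUP = "admin"


def visible_groups(all_groups: List[str], user_groups: List[str]) -> List[str]:
    # Site admins see everything; everyone else sees only their own groups.
    if SITE_ADMIN_GROUP in user_groups:
        return list(all_groups)
    return [g for g in all_groups if g in user_groups]


groups = ["admin", "superman", "metropolis"]
assert visible_groups(groups, ["admin"]) == groups
assert visible_groups(groups, ["superman"]) == ["superman"]
```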
     @with_db_context
-    @pytest.mark.parametrize(
-        "param, expected",
-        [
-            (
-                SITE_ADMIN,
-                [
-                    {"id": "admin", "name": "admin"},
-                    {"id": "gr1", "name": "Adventure"},
-                    {"id": "gr2", "name": "Comedy"},
-                ],
-            ),
-            (
-                GROUP_ADMIN,
-                [
-                    {"id": "admin", "name": "admin"},
-                    {"id": "gr2", "name": "Comedy"},
-                ],
-            ),
-            (
-                USER3,
-                [
-                    {"id": "gr1", "name": "Adventure"},
-                ],
-            ),
-            (BAD_PARAM, []),
-        ],
-    )
-    def test_get_all_groups(
-        self,
-        login_service: LoginService,
-        param: RequestParameters,
-        expected: t.Sequence[t.Mapping[str, str]],
-    ) -> None:
-        # Prepare some groups in the db
-        group1 = Group(id="gr1", name="Adventure")
-        login_service.groups.save(group1)
-        group2 = Group(id="gr2", name="Comedy")
-        login_service.groups.save(group2)
-
-        # The group admin is a reader in the group "gr2"
-        assert GROUP_ADMIN.user is not None
-        robin_hood = User(id=GROUP_ADMIN.user.id, name="Robin")
-        login_service.users.save(robin_hood)
-        role = Role(type=RoleType.READER, identity=robin_hood, group=group2)
-        login_service.roles.save(role)
-
-        # The user3 is a reader in the group "gr1"
-        assert USER3.user is not None
-        indiana_johns = User(id=USER3.user.id, name="Indiana")
-        login_service.users.save(indiana_johns)
-        role = Role(type=RoleType.READER, identity=indiana_johns, group=group1)
-        login_service.roles.save(role)
-
-        # Anybody except anonymous can get the list of groups
-        if expected:
-            # The site admin can see all groups
-            actual = login_service.get_all_groups(param)
-            assert [g.dict() for g in actual] == expected
-        else:
-            with pytest.raises(UserHasNotPermissionError):
-                login_service.get_all_groups(param)
+    def test_get_all_groups(self, login_service: LoginService) -> None:
+        # The site admin can get all groups
+        _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin")
+        actual = login_service.get_all_groups(_param)
+        assert [g.dict() for g in actual] == [
+            {"id": "admin", "name": "X-Men"},
+            {"id": "superman", "name": "Superman"},
+            {"id": "metropolis", "name": "Metropolis"},
+        ]
+
+        # The group admin can get his own groups
+        _param = get_user_param(login_service, user_id=2, group_id="superman")
+        actual = login_service.get_all_groups(_param)
+        assert [g.dict() for g in actual] == [{"id": "superman", "name": "Superman"}]
+
+        # The user can get her own groups
+        _param = get_user_param(login_service, user_id=3, group_id="superman")
+        actual = login_service.get_all_groups(_param)
+        assert [g.dict() for g in actual] == [{"id": "superman", "name": "Superman"}]
 
     @with_db_context
-    @pytest.mark.parametrize(
-        "param, expected",
-        [
-            (
-                SITE_ADMIN,
-                [
-                    {"id": 0, "name": "Superman"},
-                    {"id": 1, "name": "John"},
-                    {"id": 2, "name": "Jane"},
-                    {"id": 3, "name": "Tarzan"},
-                ],
-            ),
-            (
-                GROUP_ADMIN,
-                [
-                    {"id": 1, "name": "John"},
-                ],
-            ),
-            (
-                USER3,
-                [
-                    {"id": 3, "name": "Tarzan"},
-                ],
-            ),
-            (BAD_PARAM, []),
-        ],
-    )
-    def test_get_all_users(
-        self,
-        login_service: LoginService,
-        param: RequestParameters,
-        expected: t.Sequence[t.Mapping[str, t.Union[int, str]]],
-    ) -> None:
-        # Insert some users in the db
-        user0 = User(id=0, name="Superman")
-        login_service.users.save(user0)
-        user1 = User(id=1, name="John")
-        login_service.users.save(user1)
-        user2 = User(id=2, name="Jane")
-        login_service.users.save(user2)
-        user3 = User(id=3, name="Tarzan")
-        login_service.users.save(user3)
-
-        # user3 is a reader in the group "group"
-        group = Group(id="group", name="readers")
-        login_service.groups.save(group)
-        role = Role(type=RoleType.READER, identity=user3, group=group)
-        login_service.roles.save(role)
-
-        # Anybody except anonymous can get the list of users
-        if expected:
-            actual = login_service.get_all_users(param)
-            assert [u.dict() for u in actual] == expected
-        else:
-            with 
pytest.raises(UserHasNotPermissionError): - login_service.get_all_users(param) + def test_get_all_users(self, login_service: LoginService) -> None: + # The site admin can get all users + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + actual = login_service.get_all_users(_param) + assert [u.dict() for u in actual] == [ + {"id": 1, "name": "Professor Xavier"}, + {"id": 2, "name": "Clark Kent"}, + {"id": 3, "name": "Lois Lane"}, + {"id": 4, "name": "Joh Fredersen"}, + {"id": 5, "name": "Freder Fredersen"}, + ] + + # The group admin can get its own users, but also the users of the other groups + # note: I don't know why the group admin can get all users -- Laurent + _param = get_user_param(login_service, user_id=2, group_id="superman") + actual = login_service.get_all_users(_param) + assert [u.dict() for u in actual] == [ + {"id": 1, "name": "Professor Xavier"}, + {"id": 2, "name": "Clark Kent"}, + {"id": 3, "name": "Lois Lane"}, + {"id": 4, "name": "Joh Fredersen"}, + {"id": 5, "name": "Freder Fredersen"}, + ] + + # The user can get its own users + _param = get_user_param(login_service, user_id=3, group_id="superman") + actual = login_service.get_all_users(_param) + assert [u.dict() for u in actual] == [ + {"id": 2, "name": "Clark Kent"}, + {"id": 3, "name": "Lois Lane"}, + ] @with_db_context - @pytest.mark.parametrize( - "param, expected", - [ - (SITE_ADMIN, [5]), - (GROUP_ADMIN, []), - (USER3, []), - (BAD_PARAM, []), - ], - ) - def test_get_all_bots( - self, - login_service: LoginService, - param: RequestParameters, - expected: t.Sequence[int], - ) -> None: - # add a user, an LDAP user and a bot in the db - user = User(id=3, name="John") - login_service.users.save(user) - user_ldap = UserLdap(id=4, name="Jane") - login_service.users.save(user_ldap) - bot = Bot(id=5, name="my-app", owner=3, is_author=False) - login_service.users.save(bot) - - if expected: - actual = login_service.get_all_bots(param) - assert [b.id for b in actual] == expected - else: - with pytest.raises(UserHasNotPermissionError): - login_service.get_all_bots(param) + def test_get_all_bots(self, login_service: LoginService) -> None: + # Create a bot for Joh Fredersen + joh_id = 4 + _param = get_user_param(login_service, user_id=joh_id, group_id="superman") + joh_bot = login_service.save_bot(BotCreateDTO(name="Maria", roles=[]), _param) + + # The site admin can get all bots + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + actual = login_service.get_all_bots(_param) + assert [b.to_dto().dict() for b in actual] == [ + {"id": joh_bot.id, "is_author": True, "name": "Maria", "owner": joh_id}, + ] + + # The group admin cannot access the list of bots + _param = get_user_param(login_service, user_id=2, group_id="superman") + with pytest.raises(Exception): + login_service.get_all_bots(_param) + + # The user cannot access the list of bots + _param = get_user_param(login_service, user_id=3, group_id="superman") + with pytest.raises(Exception): + login_service.get_all_bots(_param) @with_db_context - @pytest.mark.parametrize( - "param, expected", - [ - (SITE_ADMIN, [(3, "group")]), - (GROUP_ADMIN, [(3, "group")]), - (USER3, []), - (BAD_PARAM, []), - ], - ) - def test_get_all_roles_in_group( - self, - login_service: LoginService, - param: RequestParameters, - expected: t.Sequence[t.Tuple[int, str]], - ) -> None: - # Insert some users in the db - user0 = User(id=0, name="Superman") - login_service.users.save(user0) - user1 = User(id=1, name="John") - 
login_service.users.save(user1)
-        user2 = User(id=2, name="Jane")
-        login_service.users.save(user2)
-        user3 = User(id=3, name="Tarzan")
-        login_service.users.save(user3)
-
-        # user3 is a reader in the group "group"
-        group = Group(id="group", name="readers")
-        login_service.groups.save(group)
-        role = Role(type=RoleType.READER, identity=user3, group=group)
-        login_service.roles.save(role)
-
-        # The site admin and the group admin can get the list of roles in a group
-        if expected:
-            actual = login_service.get_all_roles_in_group("group", param)
-            assert [(r.identity_id, r.group_id) for r in actual] == expected
-        else:
-            with pytest.raises(UserHasNotPermissionError):
-                login_service.get_all_roles_in_group("group", param)
+    def test_get_all_roles_in_group(self, login_service: LoginService) -> None:
+        # The site admin can get all roles in a given group
+        _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin")
+        actual = login_service.get_all_roles_in_group("superman", _param)
+        assert [b.to_dto().dict() for b in actual] == [
+            {
+                "group": {"id": "superman", "name": "Superman"},
+                "identity": {"id": 2, "name": "Clark Kent"},
+                "type": RoleType.ADMIN,
+            },
+            {
+                "group": {"id": "superman", "name": "Superman"},
+                "identity": {"id": 3, "name": "Lois Lane"},
+                "type": RoleType.READER,
+            },
+        ]
+
+        # The group admin can get all roles in his own group
+        _param = get_user_param(login_service, user_id=2, group_id="superman")
+        actual = login_service.get_all_roles_in_group("superman", _param)
+        assert [b.to_dto().dict() for b in actual] == [
+            {
+                "group": {"id": "superman", "name": "Superman"},
+                "identity": {"id": 2, "name": "Clark Kent"},
+                "type": RoleType.ADMIN,
+            },
+            {
+                "group": {"id": "superman", "name": "Superman"},
+                "identity": {"id": 3, "name": "Lois Lane"},
+                "type": RoleType.READER,
+            },
+        ]
+
+        # The user cannot access the list of roles
+        _param = get_user_param(login_service, user_id=3, group_id="superman")
+        with pytest.raises(Exception):
+            login_service.get_all_roles_in_group("superman", _param)
 
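`test_delete_group` below expects deletion to cascade: removing a group must also drop every role attached to it, so no orphan role keeps pointing at a missing group. A toy model of that cascade; the in-memory stores are assumptions made for the illustration, not the service's real persistence layer:

```python
from typing import Dict, List, Tuple

groups: Dict[str, str] = {"g1": "Group I", "g2": "Group II"}
roles: List[Tuple[int, str]] = [(3, "g1"), (5, "g1"), (3, "g2")]  # (user id, group id)


def delete_group(group_id: str) -> None:
    groups.pop(group_id, None)
    # Cascade: keep only the roles that reference a surviving group.
    roles[:] = [(uid, gid) for uid, gid in roles if gid != group_id]


delete_group("g1")
assert "g1" not in groups
assert roles == [(3, "g2")]
```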
     @with_db_context
-    @pytest.mark.parametrize(
-        "param, can_delete",
-        [
-            (SITE_ADMIN, True),
-            (GROUP_ADMIN, True),
-            (USER3, False),
-            (BAD_PARAM, False),
-        ],
-    )
-    def test_delete_group(self, login_service: LoginService, param: RequestParameters, can_delete: bool) -> None:
-        # Insert a group in the db
-        group = Group(id="group", name="readers")
-        login_service.groups.save(group)
-
-        # The site admin and the group admin can delete a group
-        if can_delete:
-            login_service.delete_group("group", param)
-            actual = login_service.groups.get("group")
-            assert actual is None
-        else:
-            with pytest.raises(UserHasNotPermissionError):
-                login_service.delete_group("group", param)
-            actual = login_service.groups.get("group")
-            assert actual is not None
+    def test_delete_group(self, login_service: LoginService) -> None:
+        # Create new groups for Lois Lane (3) and Freder Fredersen (5)
+        group1 = login_service.groups.save(Group(id="g1", name="Group I"))
+        group2 = login_service.groups.save(Group(id="g2", name="Group II"))
+        group3 = login_service.groups.save(Group(id="g3", name="Group III"))
+
+        lois = t.cast(User, login_service.users.get(3))  # group admin
+        freder = t.cast(User, login_service.users.get(5))  # user
+
+        login_service.roles.save(Role(type=RoleType.ADMIN, group=group1, identity=lois))
+        login_service.roles.save(Role(type=RoleType.READER, group=group1, identity=freder))
+        login_service.roles.save(Role(type=RoleType.ADMIN, group=group2, identity=lois))
+        login_service.roles.save(Role(type=RoleType.WRITER, group=group2, identity=freder))
+        login_service.roles.save(Role(type=RoleType.ADMIN, group=group3, identity=lois))
+        login_service.roles.save(Role(type=RoleType.RUNNER, group=group3, identity=freder))
+
+        # The site admin can delete any group
+        _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin")
+        login_service.delete_group("g1", _param)
+        assert login_service.groups.get(group1.id) is None
+
+        # The group admin can delete his own group
+        _param = get_user_param(login_service, user_id=3, group_id="g2")
+        login_service.delete_group("g2", _param)
+        assert login_service.groups.get(group2.id) is None
+
+        # The user cannot delete a group
+        _param = get_user_param(login_service, user_id=5, group_id="g3")
+        with pytest.raises(Exception):
+            login_service.delete_group("g3", _param)
+        assert login_service.groups.get(group3.id) is not None
 
     @with_db_context
-    @pytest.mark.parametrize(
-        "param, can_delete",
-        [
-            (SITE_ADMIN, True),
-            (GROUP_ADMIN, False),
-            (USER3, False),
-            (BAD_PARAM, False),
-        ],
-    )
-    def test_delete_user(self, login_service: LoginService, param: RequestParameters, can_delete: bool) -> None:
-        # Insert a user in the db which is an owner of a bot
-        user = User(id=3, name="John")
-        login_service.users.save(user)
-        bot = Bot(id=4, name="my-app", owner=3, is_author=False)
-        login_service.users.save(bot)
-
-        # The site admin can delete the user
-        if can_delete:
-            login_service.delete_user(3, param)
-            actual = login_service.users.get(3)
-            assert actual is None
-        else:
-            with pytest.raises(UserHasNotPermissionError):
-                login_service.delete_user(3, param)
-            actual = login_service.users.get(3)
-            assert actual is not None
+    def test_delete_user(self, login_service: LoginService) -> None:
+        # Create Joh's bot
+        joh_id = 4
+        _param = get_user_param(login_service, user_id=joh_id, group_id="metropolis")
+        joh_bot = login_service.save_bot(BotCreateDTO(name="Maria", roles=[]), _param)
+
+        # The site admin can delete Freder Fredersen (5)
+        freder_id = 5
+        _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin")
+        login_service.delete_user(freder_id, _param)
+        assert login_service.users.get(freder_id) is None
+
+        # The group admin Joh can delete himself (4)
+        _param = get_user_param(login_service, user_id=joh_id, group_id="metropolis")
+        login_service.delete_user(joh_id, _param)
+        assert login_service.users.get(joh_id) is None
+        assert login_service.bots.get(joh_bot.id) is None
+
+        # Lois Lane cannot delete Clark Kent (2)
+        lois_id = 3
+        clark_id = 2
+        _param = get_user_param(login_service, user_id=lois_id, group_id="superman")
+        with pytest.raises(Exception):
+            login_service.delete_user(clark_id, _param)
+        assert login_service.users.get(clark_id) is not None
+
+        # Clark Kent cannot delete Lois Lane (3)
+        _param = get_user_param(login_service, user_id=clark_id, group_id="superman")
+        with pytest.raises(Exception):
+            login_service.delete_user(lois_id, _param)
+        assert login_service.users.get(lois_id) is not None
 
     @with_db_context
-    @pytest.mark.parametrize(
-        "param, can_delete",
-        [
-            (SITE_ADMIN, True),
-            (GROUP_ADMIN, False),
-            (USER3, True),
-            (BAD_PARAM, False),
-        ],
-    )
-    def test_delete_bot(self, login_service: LoginService, param: RequestParameters, can_delete: bool) -> None:
-        # Insert a user in the db which is an owner of a bot
-        user = User(id=3, name="John")
-        login_service.users.save(user)
-        bot = Bot(id=4, name="my-app", owner=3, is_author=False)
-        login_service.users.save(bot)
-
-        # The site admin and 
the current owner can delete the bot - if can_delete: - login_service.delete_bot(4, param) - actual = login_service.bots.get(4) - assert actual is None - else: - with pytest.raises(UserHasNotPermissionError): - login_service.delete_bot(4, param) - actual = login_service.bots.get(4) - assert actual is not None + def test_delete_bot(self, login_service: LoginService) -> None: + # Create Joh's bot + joh_id = 4 + _param = get_user_param(login_service, user_id=joh_id, group_id="metropolis") + joh_bot = login_service.save_bot(BotCreateDTO(name="Maria", roles=[]), _param) + + # The site admin can delete the bot + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + login_service.delete_bot(joh_bot.id, _param) + assert login_service.bots.get(joh_bot.id) is None + + # Create Lois's bot + lois_id = 3 + _param = get_user_param(login_service, user_id=lois_id, group_id="superman") + lois_bot = login_service.save_bot(BotCreateDTO(name="Lois bot", roles=[]), _param) + + # The group admin cannot delete the bot + clark_id = 2 + _param = get_user_param(login_service, user_id=clark_id, group_id="superman") + with pytest.raises(Exception): + login_service.delete_bot(lois_bot.id, _param) + assert login_service.bots.get(lois_bot.id) is not None + + # Create Freder's bot + freder_id = 5 + _param = get_user_param(login_service, user_id=freder_id, group_id="metropolis") + freder_bot = login_service.save_bot(BotCreateDTO(name="Freder bot", roles=[]), _param) + + # Freder can delete his own bot + _param = get_user_param(login_service, user_id=freder_id, group_id="metropolis") + login_service.delete_bot(freder_bot.id, _param) + assert login_service.bots.get(freder_bot.id) is None + + # Freder cannot delete Lois's bot + _param = get_user_param(login_service, user_id=freder_id, group_id="metropolis") + with pytest.raises(Exception): + login_service.delete_bot(lois_bot.id, _param) + assert login_service.bots.get(lois_bot.id) is not None @with_db_context - @pytest.mark.parametrize( - "param, can_delete", - [ - (SITE_ADMIN, True), - (GROUP_ADMIN, True), - (USER3, False), - (BAD_PARAM, False), - ], - ) - def test_delete_role(self, login_service: LoginService, param: RequestParameters, can_delete: bool) -> None: - # Insert the user3 in the db - user = User(id=3, name="Tarzan") - login_service.users.save(user) - - # Insert a group in the db - group = Group(id="group", name="readers") - login_service.groups.save(group) - - # Insert a role in the db - role = Role(type=RoleType.READER, identity=user, group=group) - login_service.roles.save(role) - - # The site admin and the group admin can delete a role - if can_delete: - login_service.delete_role(3, "group", param) - actual = login_service.roles.get_all_by_group("group") - assert len(actual) == 0 - else: - with pytest.raises(UserHasNotPermissionError): - login_service.delete_role(3, "group", param) - actual = login_service.roles.get_all_by_group("group") - assert len(actual) == 1 + def test_delete_role(self, login_service: LoginService) -> None: + # Create a new group + group = login_service.groups.save(Group(id="g1", name="Group I")) + + # Create a new user + user = login_service.users.save(User(id=10, name="User 1")) + + # Create a new role + role = login_service.roles.save(Role(type=RoleType.ADMIN, group=group, identity=user)) + + # The site admin can delete any role + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + login_service.delete_role(role.identity.id, role.group.id, _param) + assert 
login_service.roles.get(role.identity.id, role.group.id) is None + + # Create a new role + role = login_service.roles.save(Role(type=RoleType.ADMIN, group=group, identity=user)) + + # The group admin can delete a role of his own group + _param = get_user_param(login_service, user_id=user.id, group_id="g1") + login_service.delete_role(role.identity.id, role.group.id, _param) + assert login_service.roles.get(role.identity.id, role.group.id) is None + + # Create a new role + role = login_service.roles.save(Role(type=RoleType.ADMIN, group=group, identity=user)) + + # The group admin cannot delete a role of another group + _param = get_user_param(login_service, user_id=2, group_id="superman") + with pytest.raises(Exception): + login_service.delete_role(role.identity.id, "g1", _param) + assert login_service.roles.get(role.identity.id, "g1") is not None + + # The user cannot delete a role + _param = get_user_param(login_service, user_id=1, group_id="g1") + with pytest.raises(Exception): + login_service.delete_role(role.identity.id, role.group.id, _param) + assert login_service.roles.get(role.identity.id, role.group.id) is not None @with_db_context - @pytest.mark.parametrize( - "param, can_delete", - [ - (SITE_ADMIN, True), - (GROUP_ADMIN, True), - (USER3, False), - (BAD_PARAM, False), - ], - ) - def test_delete_all_roles_from_user( - self, login_service: LoginService, param: RequestParameters, can_delete: bool - ) -> None: - # Insert the user3 in the db - assert USER3.user is not None - user = User(id=USER3.user.id, name="Tarzan") - login_service.users.save(user) - - # Insert a group in the db - group = Group(id="group", name="readers") - login_service.groups.save(group) - - # Insert a role in the db - role = Role(type=RoleType.READER, identity=user, group=group) - login_service.roles.save(role) - - # Insert the group admin in the db - assert GROUP_ADMIN.user is not None - group_admin = User(id=GROUP_ADMIN.user.id, name="John") - login_service.users.save(group_admin) - - # Insert another group in the db - group2 = Group(id="group2", name="readers") - login_service.groups.save(group2) - - # Insert a role in the db - role2 = Role(type=RoleType.READER, identity=group_admin, group=group2) - login_service.roles.save(role2) - - # The site admin and the group admin can delete a role - if can_delete: - login_service.delete_all_roles_from_user(3, param) - actual = login_service.roles.get_all_by_group("group") - assert len(actual) == 0 - actual = login_service.roles.get_all_by_group("group2") - assert len(actual) == 1 - else: - with pytest.raises(UserHasNotPermissionError): - login_service.delete_all_roles_from_user(3, param) - actual = login_service.roles.get_all_by_group("group") - assert len(actual) == 1 - actual = login_service.roles.get_all_by_group("group2") - assert len(actual) == 1 + def test_delete_all_roles_from_user(self, login_service: LoginService) -> None: + # Create a new group + group = login_service.groups.save(Group(id="g1", name="Group I")) + + # Create a new user + user = login_service.users.save(User(id=10, name="User 1")) + + # Create a new role + role = login_service.roles.save(Role(type=RoleType.ADMIN, group=group, identity=user)) + + # The site admin can delete any role + _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") + login_service.delete_all_roles_from_user(user.id, _param) + assert login_service.roles.get(role.identity.id, role.group.id) is None + + # Create a new role + role = login_service.roles.save(Role(type=RoleType.ADMIN, group=group, 
identity=user)) + + # The group admin can delete a role of his own group + _param = get_user_param(login_service, user_id=user.id, group_id="g1") + login_service.delete_all_roles_from_user(user.id, _param) + assert login_service.roles.get(role.identity.id, role.group.id) is None + + # Create a new role + role = login_service.roles.save(Role(type=RoleType.ADMIN, group=group, identity=user)) + + # The group admin cannot delete a role of another group + _param = get_user_param(login_service, user_id=2, group_id="superman") + with pytest.raises(Exception): + login_service.delete_all_roles_from_user(user.id, _param) + assert login_service.roles.get(role.identity.id, role.group.id) is not None + + # The user cannot delete a role + _param = get_user_param(login_service, user_id=1, group_id="g1") + with pytest.raises(Exception): + login_service.delete_all_roles_from_user(user.id, _param) + assert login_service.roles.get(role.identity.id, role.group.id) is not None diff --git a/tests/login/test_model.py b/tests/login/test_model.py index 787f4f2d6a..2dee1d994e 100644 --- a/tests/login/test_model.py +++ b/tests/login/test_model.py @@ -1,5 +1,66 @@ -from antarest.login.model import Password +import contextlib + +from sqlalchemy.engine.base import Engine # type: ignore +from sqlalchemy.exc import IntegrityError # type: ignore +from sqlalchemy.orm import sessionmaker # type: ignore + +from antarest.login.model import ( + ADMIN_ID, + ADMIN_NAME, + GROUP_ID, + GROUP_NAME, + Group, + Password, + Role, + User, + init_admin_user, +) +from antarest.utils import SESSION_ARGS + +TEST_ADMIN_PASS_WORD = "test" def test_password(): assert Password("pwd").check("pwd") + + +class TestInitAdminUser: + def test_init_admin_user_nominal(self, db_engine: Engine): + init_admin_user(db_engine, SESSION_ARGS, admin_password=TEST_ADMIN_PASS_WORD) + make_session = sessionmaker(bind=db_engine) + with make_session() as session: + user = session.query(User).get(ADMIN_ID) + assert user is not None + assert user.id == ADMIN_ID + assert user.name == ADMIN_NAME + assert user.password.check(TEST_ADMIN_PASS_WORD) + group = session.query(Group).get(GROUP_ID) + assert group is not None + assert group.id == GROUP_ID + assert group.name == GROUP_NAME + role = session.query(Role).get((ADMIN_ID, GROUP_ID)) + assert role is not None + assert role.identity is user + assert role.group is group + + def test_init_admin_user_redundancy_check(self, db_engine: Engine): + # run first time + init_admin_user(db_engine, SESSION_ARGS, admin_password=TEST_ADMIN_PASS_WORD) + # run second time + init_admin_user(db_engine, SESSION_ARGS, admin_password=TEST_ADMIN_PASS_WORD) + + def test_init_admin_user_existing_group(self, db_engine: Engine): + make_session = sessionmaker(bind=db_engine) + with make_session() as session: + group = Group(id=GROUP_ID, name=GROUP_NAME) + session.add(group) + session.commit() + init_admin_user(db_engine, SESSION_ARGS, admin_password=TEST_ADMIN_PASS_WORD) + + def test_init_admin_user_existing_user(self, db_engine: Engine): + make_session = sessionmaker(bind=db_engine) + with make_session() as session: + user = User(id=ADMIN_ID, name=ADMIN_NAME, password=Password(TEST_ADMIN_PASS_WORD)) + session.add(user) + session.commit() + init_admin_user(db_engine, SESSION_ARGS, admin_password=TEST_ADMIN_PASS_WORD) diff --git a/tests/login/test_repository.py b/tests/login/test_repository.py index 6669747507..60bdbc0dbf 100644 --- a/tests/login/test_repository.py +++ b/tests/login/test_repository.py @@ -1,29 +1,14 @@ import pytest -from 
sqlalchemy import create_engine -from sqlalchemy.orm import scoped_session, sessionmaker # type: ignore +from sqlalchemy.orm import Session, scoped_session, sessionmaker # type: ignore -from antarest.core.config import Config, SecurityConfig -from antarest.core.persistence import Base -from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware, db from antarest.login.model import Bot, Group, Password, Role, RoleType, User, UserLdap from antarest.login.repository import BotRepository, GroupRepository, RoleRepository, UserLdapRepository, UserRepository @pytest.mark.unit_test -def test_users(): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - - with db(): - repo = UserRepository( - config=Config(security=SecurityConfig(admin_pwd="admin")), - ) +def test_users(db_session: Session): + with db_session: + repo = UserRepository(session=db_session) a = User( name="a", password=Password("a"), @@ -43,18 +28,9 @@ def test_users(): @pytest.mark.unit_test -def test_users_ldap(): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - - with db(): - repo = UserLdapRepository() +def test_users_ldap(db_session: Session): + repo = UserLdapRepository(session=db_session) + with repo.session: a = UserLdap(name="a", external_id="b") a = repo.save(a) @@ -67,18 +43,9 @@ def test_users_ldap(): @pytest.mark.unit_test -def test_bots(): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - - with db(): - repo = BotRepository() +def test_bots(db_session: Session): + repo = BotRepository(session=db_session) + with repo.session: a = Bot(name="a", owner=1) a = repo.save(a) assert a.id @@ -98,19 +65,9 @@ def test_bots(): @pytest.mark.unit_test -def test_groups(): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - - with db(): - repo = GroupRepository() - +def test_groups(db_session: Session): + repo = GroupRepository(session=db_session) + with repo.session: a = Group(name="a") a = repo.save(a) @@ -125,19 +82,9 @@ def test_groups(): @pytest.mark.unit_test -def test_roles(): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - - with db(): - repo = RoleRepository() - +def test_roles(db_session: Session): + repo = RoleRepository(session=db_session) + with repo.session: a = Role(type=RoleType.ADMIN, identity=User(id=0), group=Group(id="group")) a = repo.save(a) diff --git a/tests/matrixstore/test_repository.py b/tests/matrixstore/test_repository.py index 3973a18d39..ba76b17fac 100644 --- a/tests/matrixstore/test_repository.py +++ b/tests/matrixstore/test_repository.py @@ -1,13 +1,12 @@ +import datetime import typing as t -from 
datetime import datetime from pathlib import Path import numpy as np import pytest from numpy import typing as npt +from sqlalchemy.orm import Session # type: ignore -from antarest.core.config import Config, SecurityConfig -from antarest.core.utils.fastapi_sqlalchemy import db from antarest.login.model import Group, Password, User from antarest.login.repository import GroupRepository, UserRepository from antarest.matrixstore.model import Matrix, MatrixContent, MatrixDataSet, MatrixDataSetRelation @@ -17,11 +16,10 @@ class TestMatrixRepository: - def test_db_lifecycle(self) -> None: - with db(): - # sourcery skip: extract-method - repo = MatrixRepository() - m = Matrix(id="hello", created_at=datetime.now()) + def test_db_lifecycle(self, db_session: Session) -> None: + with db_session: + repo = MatrixRepository(db_session) + m = Matrix(id="hello", created_at=datetime.datetime.now()) repo.save(m) assert m.id assert m == repo.get(m.id) @@ -51,24 +49,21 @@ def test_bucket_lifecycle(self, tmp_path: Path) -> None: with pytest.raises(FileNotFoundError): repo.get(aid) - def test_dataset(self) -> None: - with db(): - # sourcery skip: extract-duplicate-method, extract-method - repo = MatrixRepository() + def test_dataset(self, db_session: Session) -> None: + with db_session: + repo = MatrixRepository(session=db_session) - user_repo = UserRepository(Config(security=SecurityConfig())) - # noinspection PyArgumentList + user_repo = UserRepository(session=db_session) user = user_repo.save(User(name="foo", password=Password("bar"))) - group_repo = GroupRepository() - # noinspection PyArgumentList + group_repo = GroupRepository(session=db_session) group = group_repo.save(Group(name="group")) - dataset_repo = MatrixDataSetRepository() + dataset_repo = MatrixDataSetRepository(session=db_session) - m1 = Matrix(id="hello", created_at=datetime.now()) + m1 = Matrix(id="hello", created_at=datetime.datetime.now()) repo.save(m1) - m2 = Matrix(id="world", created_at=datetime.now()) + m2 = Matrix(id="world", created_at=datetime.datetime.now()) repo.save(m2) dataset = MatrixDataSet( @@ -76,8 +71,8 @@ def test_dataset(self) -> None: public=True, owner_id=user.id, groups=[group], - created_at=datetime.now(), - updated_at=datetime.now(), + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), ) matrix_relation = MatrixDataSetRelation(name="m1") @@ -97,7 +92,7 @@ def test_dataset(self) -> None: id=dataset.id, name="some name change", public=False, - updated_at=datetime.now(), + updated_at=datetime.datetime.now(), ) dataset_repo.save(dataset_update) dataset_query_result = dataset_repo.get(dataset.id) @@ -105,29 +100,27 @@ def test_dataset(self) -> None: assert dataset_query_result.name == "some name change" assert dataset_query_result.owner_id == user.id - def test_datastore_query(self) -> None: + def test_datastore_query(self, db_session: Session) -> None: # sourcery skip: extract-duplicate-method - with db(): - user_repo = UserRepository(Config(security=SecurityConfig())) - # noinspection PyArgumentList + with db_session: + user_repo = UserRepository(session=db_session) user1 = user_repo.save(User(name="foo", password=Password("bar"))) - # noinspection PyArgumentList user2 = user_repo.save(User(name="hello", password=Password("world"))) - repo = MatrixRepository() - m1 = Matrix(id="hello", created_at=datetime.now()) + repo = MatrixRepository(session=db_session) + m1 = Matrix(id="hello", created_at=datetime.datetime.now()) repo.save(m1) - m2 = Matrix(id="world", created_at=datetime.now()) + m2 = 
Matrix(id="world", created_at=datetime.datetime.now()) repo.save(m2) - dataset_repo = MatrixDataSetRepository() + dataset_repo = MatrixDataSetRepository(session=db_session) dataset = MatrixDataSet( name="some name", public=True, owner_id=user1.id, - created_at=datetime.now(), - updated_at=datetime.now(), + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), ) matrix_relation = MatrixDataSetRelation(name="m1") matrix_relation.matrix_id = "hello" @@ -141,8 +134,8 @@ def test_datastore_query(self) -> None: name="some name 2", public=False, owner_id=user2.id, - created_at=datetime.now(), - updated_at=datetime.now(), + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), ) matrix_relation = MatrixDataSetRelation(name="m1") matrix_relation.matrix_id = "hello" @@ -163,14 +156,12 @@ def test_datastore_query(self) -> None: assert len(dataset_repo.query("name 2")) == 0 assert repo.get(m1.id) is not None assert ( - len( - # fmt: off - db.session - .query(MatrixDataSetRelation) - .filter(MatrixDataSetRelation.dataset_id == dataset.id) - .all() - # fmt: on - ) + # fmt: off + db_session + .query(MatrixDataSetRelation) + .filter(MatrixDataSetRelation.dataset_id == dataset.id) + .count() + # fmt: on == 0 ) diff --git a/tests/storage/business/test_arealink_manager.py b/tests/storage/business/test_arealink_manager.py index 9f8e0be884..4caee7b7bd 100644 --- a/tests/storage/business/test_arealink_manager.py +++ b/tests/storage/business/test_arealink_manager.py @@ -9,6 +9,7 @@ from antarest.core.jwt import DEFAULT_ADMIN_USER from antarest.core.requests import RequestParameters from antarest.core.utils.fastapi_sqlalchemy import db +from antarest.matrixstore.repository import MatrixContentRepository from antarest.matrixstore.service import SimpleMatrixService from antarest.study.business.area_management import AreaCreationDTO, AreaManager, AreaType, AreaUI from antarest.study.business.link_management import LinkInfoDTO, LinkManager @@ -66,7 +67,10 @@ def matrix_service_fixture(tmp_path: Path) -> SimpleMatrixService: """ matrix_path = tmp_path.joinpath("matrix-store") matrix_path.mkdir() - return SimpleMatrixService(matrix_path) + matrix_content_repository = MatrixContentRepository( + bucket_dir=matrix_path, + ) + return SimpleMatrixService(matrix_content_repository=matrix_content_repository) @with_db_context @@ -94,8 +98,10 @@ def test_area_crud(empty_study: FileStudy, matrix_service: SimpleMatrixService): raw_study_service.get_raw.return_value = empty_study raw_study_service.cache = Mock() + generator_matrix_constants = GeneratorMatrixConstants(matrix_service) + generator_matrix_constants.init_constant_matrices() variant_study_service.command_factory = CommandFactory( - GeneratorMatrixConstants(matrix_service), + generator_matrix_constants, matrix_service, patch_service=Mock(spec=PatchService), ) diff --git a/tests/storage/integration/conftest.py b/tests/storage/integration/conftest.py index 4ff8fbf888..197be27144 100644 --- a/tests/storage/integration/conftest.py +++ b/tests/storage/integration/conftest.py @@ -12,6 +12,7 @@ from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware from antarest.dbmodel import Base from antarest.login.model import User +from antarest.matrixstore.repository import MatrixContentRepository from antarest.matrixstore.service import SimpleMatrixService from antarest.study.main import build_study_service from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, StudyAdditionalData @@ -87,7 +88,10 @@ def 
storage_service(tmp_path: Path, project_path: Path, sta_mini_zip_path: Path) matrix_path = tmp_path / "matrices" matrix_path.mkdir() - matrix_service = SimpleMatrixService(matrix_path) + matrix_content_repository = MatrixContentRepository( + bucket_dir=matrix_path, + ) + matrix_service = SimpleMatrixService(matrix_content_repository=matrix_content_repository) storage_service = build_study_service( application=Mock(), cache=LocalCache(config=config.cache), diff --git a/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py b/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py index 93a3262259..a216571510 100644 --- a/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py +++ b/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py @@ -1,5 +1,6 @@ import numpy as np +from antarest.matrixstore.repository import MatrixContentRepository from antarest.matrixstore.service import SimpleMatrixService from antarest.study.storage.variantstudy.business import matrix_constants from antarest.study.storage.variantstudy.business.matrix_constants_generator import ( @@ -10,7 +11,15 @@ class TestGeneratorMatrixConstants: def test_get_st_storage(self, tmp_path): - generator = GeneratorMatrixConstants(matrix_service=SimpleMatrixService(bucket_dir=tmp_path)) + matrix_content_repository = MatrixContentRepository( + bucket_dir=tmp_path, + ) + generator = GeneratorMatrixConstants( + matrix_service=SimpleMatrixService( + matrix_content_repository=matrix_content_repository, + ) + ) + generator.init_constant_matrices() ref1 = generator.get_st_storage_pmax_injection() matrix_id1 = ref1.split(MATRIX_PROTOCOL_PREFIX)[1] @@ -38,20 +47,28 @@ def test_get_st_storage(self, tmp_path): assert np.array(matrix_dto5.data).all() == matrix_constants.st_storage.series.inflows.all() def test_get_binding_constraint(self, tmp_path): - generator = GeneratorMatrixConstants(matrix_service=SimpleMatrixService(bucket_dir=tmp_path)) + matrix_content_repository = MatrixContentRepository( + bucket_dir=tmp_path, + ) + generator = GeneratorMatrixConstants( + matrix_service=SimpleMatrixService( + matrix_content_repository=matrix_content_repository, + ) + ) + generator.init_constant_matrices() series = matrix_constants.binding_constraint.series hourly = generator.get_binding_constraint_hourly() hourly_matrix_id = hourly.split(MATRIX_PROTOCOL_PREFIX)[1] hourly_matrix_dto = generator.matrix_service.get(hourly_matrix_id) - assert np.array(hourly_matrix_dto.data).all() == series.default_binding_constraint_hourly.all() + assert np.array(hourly_matrix_dto.data).all() == series.default_bc_hourly.all() daily = generator.get_binding_constraint_daily() daily_matrix_id = daily.split(MATRIX_PROTOCOL_PREFIX)[1] daily_matrix_dto = generator.matrix_service.get(daily_matrix_id) - assert np.array(daily_matrix_dto.data).all() == series.default_binding_constraint_daily.all() + assert np.array(daily_matrix_dto.data).all() == series.default_bc_weekly_daily.all() weekly = generator.get_binding_constraint_weekly() weekly_matrix_id = weekly.split(MATRIX_PROTOCOL_PREFIX)[1] weekly_matrix_dto = generator.matrix_service.get(weekly_matrix_id) - assert np.array(weekly_matrix_dto.data).all() == series.default_binding_constraint_weekly.all() + assert np.array(weekly_matrix_dto.data).all() == series.default_bc_weekly_daily.all() diff --git a/tests/study/storage/variantstudy/model/test_dbmodel.py b/tests/study/storage/variantstudy/model/test_dbmodel.py index 0bcd107518..0715dec535 100644 
--- a/tests/study/storage/variantstudy/model/test_dbmodel.py +++ b/tests/study/storage/variantstudy/model/test_dbmodel.py @@ -215,7 +215,7 @@ def test_init__without_snapshot(self, db_session: Session, raw_study_id: str, us # check Variant-specific properties assert obj.snapshot_dir == Path(variant_study_path).joinpath("snapshot") - assert obj.is_snapshot_recent() is False + assert obj.is_snapshot_up_to_date() is False @pytest.mark.parametrize( "created_at, updated_at, study_antares_file, expected", @@ -294,4 +294,4 @@ def test_is_snapshot_recent( # Check the snapshot_uptodate() method obj: VariantStudy = db_session.query(VariantStudy).filter(VariantStudy.id == variant_id).one() - assert obj.is_snapshot_recent() == expected + assert obj.is_snapshot_up_to_date() == expected diff --git a/tests/study/storage/variantstudy/test_snapshot_generator.py b/tests/study/storage/variantstudy/test_snapshot_generator.py index 2247b21045..5e90b6ee06 100644 --- a/tests/study/storage/variantstudy/test_snapshot_generator.py +++ b/tests/study/storage/variantstudy/test_snapshot_generator.py @@ -25,6 +25,7 @@ from antarest.study.storage.variantstudy.model.model import CommandDTO, GenerationResultInfoDTO from antarest.study.storage.variantstudy.snapshot_generator import SnapshotGenerator, search_ref_study from antarest.study.storage.variantstudy.variant_study_service import VariantStudyService +from tests.db_statement_recorder import DBStatementRecorder from tests.helpers import with_db_context @@ -85,12 +86,12 @@ class TestSearchRefStudy: and corresponding to a variant with an up-to-date snapshot. - The case where the list of studies contains two variants with up-to-date snapshots and - where the first is older than the second. + where the first is older than the second, and a third variant without a snapshot. We expect to have a reference study corresponding to the second variant and the list of commands of the third variant. - The case where the list of studies contains two variants with up-to-date snapshots and - where the first is more recent than the second. + where the first is more recent than the second, and a third variant without a snapshot. We expect to have a reference study corresponding to the first variant and the list of commands of the second and third variants in order.
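In the tests below, `search_ref_study` now returns a single result object exposing `ref_study`, `cmd_blocks` and `force_regenerate` (attribute names taken from the assertions), instead of the former `(ref_study, cmd_blocks)` tuple. The class added in `snapshot_generator.py` is not shown in this diff; a minimal, non-authoritative sketch of what such a result type could look like, assuming a frozen dataclass and a hypothetical name:

from dataclasses import dataclass, field
import typing as t

@dataclass(frozen=True)
class RefStudySearchResult:  # hypothetical name, for illustration only; not taken from the PR
    ref_study: t.Any  # root study or variant whose snapshot serves as the generation base
    cmd_blocks: t.List[t.Any] = field(default_factory=list)  # CommandBlock objects still to replay
    force_regenerate: bool = False  # True when the snapshot must be rebuilt from scratch

Returning a named result rather than a bare tuple lets the tests assert `force_regenerate` explicitly without changing every unpacking site each time a field is added.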
@@ -118,9 +119,10 @@ def test_search_ref_study__empty_descendants(self) -> None: """ root_study = Study(id=str(uuid.uuid4()), name="root") references: t.Sequence[VariantStudy] = [] - ref_study, cmd_blocks = search_ref_study(root_study, references) - assert ref_study == root_study - assert cmd_blocks == [] + search_result = search_ref_study(root_study, references) + assert search_result.ref_study == root_study + assert search_result.cmd_blocks == [] + assert search_result.force_regenerate is True def test_search_ref_study__from_scratch(self, tmp_path: Path) -> None: """ @@ -195,9 +197,10 @@ def test_search_ref_study__from_scratch(self, tmp_path: Path) -> None: # Check the variants references = [variant1, variant2, variant3] - ref_study, cmd_blocks = search_ref_study(root_study, references, from_scratch=True) - assert ref_study == root_study - assert cmd_blocks == [c for v in [variant1, variant2, variant3] for c in v.commands] + search_result = search_ref_study(root_study, references, from_scratch=True) + assert search_result.ref_study == root_study + assert search_result.cmd_blocks == [c for v in [variant1, variant2, variant3] for c in v.commands] + assert search_result.force_regenerate is True def test_search_ref_study__obsolete_snapshots(self, tmp_path: Path) -> None: """ @@ -205,12 +208,9 @@ def test_search_ref_study__obsolete_snapshots(self, tmp_path: Path) -> None: - either there is no snapshot, - or the snapshot's creation date is earlier than the variant's last modification date. Note: The situation where the "snapshot/study.antares" file does not exist is not considered. - We expect to have the root study and a list of `CommandBlock` for all variants. - - Given a list of descendants with some variants with obsolete snapshots, - When calling search_ref_study with the flag from_scratch=False, - Then the root study is returned as reference study, - and all commands of all variants are returned. + The third variant has no snapshot, and must be generated from scratch. + We expect to have a reference study corresponding to the root study + and the list of commands of all variants in order. """ root_study = Study(id=str(uuid.uuid4()), name="root") @@ -231,6 +231,14 @@ def test_search_ref_study__obsolete_snapshots(self, tmp_path: Path) -> None: datetime.datetime(year=2023, month=1, day=2), snapshot_created_at=datetime.datetime(year=2023, month=1, day=1), ) + # Variant 3 has no snapshot. 
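+ # Since it has never been generated, all of its commands must be replayed on top of the reference study.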
+ variant3 = _create_variant( + tmp_path, + "variant3", + variant2.id, + datetime.datetime(year=2023, month=1, day=3), + snapshot_created_at=None, + ) # Add some variant commands variant1.commands = [ @@ -252,25 +260,41 @@ def test_search_ref_study__obsolete_snapshots(self, tmp_path: Path) -> None: version=1, args='{"area_name": "DE", "cluster_name": "DE", "cluster_type": "thermal"}', ), + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant2.id, + index=1, + command="create_thermal_cluster", + version=1, + args='{"area_name": "IT", "cluster_name": "IT", "cluster_type": "gas"}', + ), ] variant2.snapshot.last_executed_command = variant2.commands[0].id + variant3.commands = [ + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant3.id, + index=0, + command="create_thermal_cluster", + version=1, + args='{"area_name": "BE", "cluster_name": "BE", "cluster_type": "oil"}', + ), + ] # Check the variants - references = [variant1, variant2] - ref_study, cmd_blocks = search_ref_study(root_study, references) - assert ref_study == root_study - assert cmd_blocks == [c for v in [variant1, variant2] for c in v.commands] + references = [variant1, variant2, variant3] + search_result = search_ref_study(root_study, references) + assert search_result.ref_study == root_study + assert search_result.cmd_blocks == [c for v in [variant1, variant2, variant3] for c in v.commands] + assert search_result.force_regenerate is True def test_search_ref_study__old_recent_snapshot(self, tmp_path: Path) -> None: """ Case where the list of studies contains two variants with up-to-date snapshots and where the first is older than the second. + The third variant has no snapshot, and must be generated from scratch. We expect to have a reference study corresponding to the second variant - and an empty list of commands, because the snapshot is already completely up-to-date. - - Given a list of descendants with some variants with up-to-date snapshots, - When calling search_ref_study with the flag from_scratch=False, - Then the second variant is returned as reference study, and no commands are returned. + and the list of commands of the third variant. """ root_study = Study(id=str(uuid.uuid4()), name="root") @@ -291,6 +315,14 @@ def test_search_ref_study__old_recent_snapshot(self, tmp_path: Path) -> None: datetime.datetime(year=2023, month=1, day=2), snapshot_created_at=datetime.datetime(year=2023, month=1, day=3), ) + # Variant 3 has no snapshot.
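+ # Only its own commands need to be replayed, on top of the up-to-date snapshot of variant 2.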
+ variant3 = _create_variant( + tmp_path, + "variant3", + variant2.id, + datetime.datetime(year=2023, month=1, day=3), + snapshot_created_at=None, + ) # Add some variant commands variant1.commands = [ @@ -315,24 +347,31 @@ def test_search_ref_study__old_recent_snapshot(self, tmp_path: Path) -> None: ), ] variant2.snapshot.last_executed_command = variant2.commands[0].id + variant3.commands = [ + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant3.id, + index=0, + command="create_thermal_cluster", + version=1, + args='{"area_name": "BE", "cluster_name": "BE", "cluster_type": "oil"}', + ), + ] # Check the variants - references = [variant1, variant2] - ref_study, cmd_blocks = search_ref_study(root_study, references) - assert ref_study == variant2 - assert cmd_blocks == [] + references = [variant1, variant2, variant3] + search_result = search_ref_study(root_study, references) + assert search_result.ref_study == variant2 + assert search_result.cmd_blocks == variant3.commands + assert search_result.force_regenerate is True def test_search_ref_study__recent_old_snapshot(self, tmp_path: Path) -> None: """ Case where the list of studies contains two variants with up-to-date snapshots and where the second is older than the first. + The third variant has no snapshot, and must be generated from scratch. We expect to have a reference study corresponding to the first variant - and the list of commands of the second variant, because the first is completely up-to-date. - - Given a list of descendants with some variants with up-to-date snapshots, - When calling search_ref_study with the flag from_scratch=False, - Then the first variant is returned as reference study, - and the commands of the second variant are returned. + and the list of commands of the second and third variants. """ root_study = Study(id=str(uuid.uuid4()), name="root") @@ -353,6 +392,14 @@ def test_search_ref_study__recent_old_snapshot(self, tmp_path: Path) -> None: datetime.datetime(year=2023, month=1, day=2), snapshot_created_at=datetime.datetime(year=2023, month=1, day=2), ) + # Variant 3 has no snapshot.
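+ # Its commands are replayed after those of variant 2, on top of the snapshot of variant 1.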
+ variant3 = _create_variant( + tmp_path, + "variant3", + variant2.id, + datetime.datetime(year=2023, month=1, day=3), + snapshot_created_at=None, + ) # Add some variant commands variant1.commands = [ @@ -377,12 +424,23 @@ def test_search_ref_study__recent_old_snapshot(self, tmp_path: Path) -> None: ), ] variant2.snapshot.last_executed_command = variant2.commands[0].id + variant3.commands = [ + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant3.id, + index=0, + command="create_thermal_cluster", + version=1, + args='{"area_name": "BE", "cluster_name": "BE", "cluster_type": "oil"}', + ), + ] # Check the variants - references = [variant1, variant2] - ref_study, cmd_blocks = search_ref_study(root_study, references) - assert ref_study == variant1 - assert cmd_blocks == variant2.commands + references = [variant1, variant2, variant3] + search_result = search_ref_study(root_study, references) + assert search_result.ref_study == variant1 + assert search_result.cmd_blocks == [c for v in [variant2, variant3] for c in v.commands] + assert search_result.force_regenerate is True def test_search_ref_study__one_variant_completely_uptodate(self, tmp_path: Path) -> None: """ @@ -438,9 +496,10 @@ def test_search_ref_study__one_variant_completely_uptodate(self, tmp_path: Path) # Check the variants references = [variant1] - ref_study, cmd_blocks = search_ref_study(root_study, references) - assert ref_study == variant1 - assert cmd_blocks == [] + search_result = search_ref_study(root_study, references) + assert search_result.ref_study == variant1 + assert search_result.cmd_blocks == [] + assert search_result.force_regenerate is False def test_search_ref_study__one_variant_partially_uptodate(self, tmp_path: Path) -> None: """ @@ -496,9 +555,10 @@ def test_search_ref_study__one_variant_partially_uptodate(self, tmp_path: Path) # Check the variants references = [variant1] - ref_study, cmd_blocks = search_ref_study(root_study, references) - assert ref_study == variant1 - assert cmd_blocks == variant1.commands[1:] + search_result = search_ref_study(root_study, references) + assert search_result.ref_study == variant1 + assert search_result.cmd_blocks == variant1.commands[1:] + assert search_result.force_regenerate is False def test_search_ref_study__missing_last_command(self, tmp_path: Path) -> None: """ @@ -550,9 +610,65 @@ def test_search_ref_study__missing_last_command(self, tmp_path: Path) -> None: # Check the variants references = [variant1] - ref_study, cmd_blocks = search_ref_study(root_study, references) - assert ref_study == variant1 - assert cmd_blocks == variant1.commands + search_result = search_ref_study(root_study, references) + assert search_result.ref_study == variant1 + assert search_result.cmd_blocks == variant1.commands + assert search_result.force_regenerate is True + + def test_search_ref_study__deleted_last_command(self, tmp_path: Path) -> None: + """ + Case where the list of studies contains a variant with an up-to-date snapshot, + but the last executed command has been deleted, so its identifier no longer matches any command. + We expect to have the list of all variant commands, so that the snapshot can be re-generated.
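+ The `force_regenerate` flag is therefore expected to be True.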
+ """ + root_study = Study(id=str(uuid.uuid4()), name="root") + + # Prepare some variants with snapshots: + variant1 = _create_variant( + tmp_path, + "variant1", + root_study.id, + datetime.datetime(year=2023, month=1, day=1), + snapshot_created_at=datetime.datetime(year=2023, month=1, day=2), + ) + + # Add some variant commands + variant1.commands = [ + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant1.id, + index=0, + command="create_area", + version=1, + args='{"area_name": "DE"}', + ), + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant1.id, + index=1, + command="create_thermal_cluster", + version=1, + args='{"area_name": "DE", "cluster_name": "DE", "cluster_type": "thermal"}', + ), + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant1.id, + index=2, + command="update_thermal_cluster", + version=1, + args='{"area_name": "DE", "cluster_name": "DE", "capacity": 1500}', + ), + ] + + # The last executed command is missing. + variant1.snapshot.last_executed_command = str(uuid.uuid4()) + + # Check the variants + references = [variant1] + search_result = search_ref_study(root_study, references) + assert search_result.ref_study == variant1 + assert search_result.cmd_blocks == variant1.commands + assert search_result.force_regenerate is True class RegisterNotification: @@ -715,15 +831,9 @@ def test_generate__nominal_case( repository=variant_study_service.repository, ) - sql_statements = [] notifier = RegisterNotification() - @event.listens_for(db.session.bind, "before_cursor_execute") # type: ignore - def before_cursor_execute(conn, cursor, statement: str, parameters, context, executemany) -> None: - # note: add a breakpoint here to debug the SQL statements. - sql_statements.append(statement) - - try: + with DBStatementRecorder(db.session.bind) as db_recorder: results = generator.generate_snapshot( variant_study.id, jwt_user, @@ -731,8 +841,6 @@ def before_cursor_execute(conn, cursor, statement: str, parameters, context, exe from_scratch=False, notifier=notifier, ) - finally: - event.remove(db.session.bind, "before_cursor_execute", before_cursor_execute) # Check: the number of database queries is kept as low as possible. # We expect 5 queries: @@ -741,7 +849,7 @@ def before_cursor_execute(conn, cursor, statement: str, parameters, context, exe # - 1 query to fetch the list of variants with snapshot, commands, etc., # - 1 query to update the variant study additional_data, # - 1 query to insert the variant study snapshot. - assert len(sql_statements) == 5, "\n-------\n".join(sql_statements) + assert len(db_recorder.sql_statements) == 5, str(db_recorder) # Check: the variant generation must succeed. assert results == GenerationResultInfoDTO( @@ -826,11 +934,6 @@ def test_generate__with_user_dir( Test the generation of a variant study containing a user directory. We expect that the user directory is correctly preserved. """ - # Add a user directory to the variant study. 
- user_dir = Path(variant_study.snapshot_dir) / "user" - user_dir.mkdir(parents=True, exist_ok=True) - user_dir.joinpath("user_file.txt").touch() - generator = SnapshotGenerator( cache=variant_study_service.cache, raw_study_service=variant_study_service.raw_study_service, @@ -840,22 +943,25 @@ def test_generate__with_user_dir( repository=variant_study_service.repository, ) - results = generator.generate_snapshot( + # Generate the snapshot once + generator.generate_snapshot( variant_study.id, jwt_user, denormalize=False, from_scratch=False, ) - # Check the results - assert results == GenerationResultInfoDTO( - success=True, - details=[ - ("create_area", True, "Area 'North' created"), - ("create_area", True, "Area 'South' created"), - ("create_link", True, "Link between 'north' and 'south' created"), - ("create_cluster", True, "Thermal cluster 'gas_cluster' added to area 'south'."), - ], + # Add a user directory to the variant study. + user_dir = Path(variant_study.snapshot_dir) / "user" + user_dir.mkdir(parents=True, exist_ok=True) + user_dir.joinpath("user_file.txt").touch() + + # Generate the snapshot again + generator.generate_snapshot( + variant_study.id, + jwt_user, + denormalize=False, + from_scratch=False, ) # Check that the user directory is correctly preserved. diff --git a/tests/variantstudy/conftest.py b/tests/variantstudy/conftest.py index 9db21ab220..011a6bb68d 100644 --- a/tests/variantstudy/conftest.py +++ b/tests/variantstudy/conftest.py @@ -91,8 +91,10 @@ def command_context_fixture(matrix_service: MatrixService) -> CommandContext: CommandContext: The CommandContext object. """ # sourcery skip: inline-immediately-returned-variable + generator_matrix_constants = GeneratorMatrixConstants(matrix_service) + generator_matrix_constants.init_constant_matrices() command_context = CommandContext( - generator_matrix_constants=GeneratorMatrixConstants(matrix_service=matrix_service), + generator_matrix_constants=generator_matrix_constants, matrix_service=matrix_service, patch_service=PatchService(repository=Mock(spec=StudyMetadataRepository)), ) @@ -110,8 +112,10 @@ def command_factory_fixture(matrix_service: MatrixService) -> CommandFactory: Returns: CommandFactory: The CommandFactory object. 
""" + generator_matrix_constants = GeneratorMatrixConstants(matrix_service) + generator_matrix_constants.init_constant_matrices() return CommandFactory( - generator_matrix_constants=GeneratorMatrixConstants(matrix_service=matrix_service), + generator_matrix_constants=generator_matrix_constants, matrix_service=matrix_service, patch_service=PatchService(), ) diff --git a/tests/variantstudy/model/command/test_manage_binding_constraints.py b/tests/variantstudy/model/command/test_manage_binding_constraints.py index a1309c2e47..3387db8e6d 100644 --- a/tests/variantstudy/model/command/test_manage_binding_constraints.py +++ b/tests/variantstudy/model/command/test_manage_binding_constraints.py @@ -8,9 +8,8 @@ from antarest.study.storage.variantstudy.business.command_extractor import CommandExtractor from antarest.study.storage.variantstudy.business.command_reverter import CommandReverter from antarest.study.storage.variantstudy.business.matrix_constants.binding_constraint.series import ( - default_binding_constraint_daily, - default_binding_constraint_hourly, - default_binding_constraint_weekly, + default_bc_hourly, + default_bc_weekly_daily, ) from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator from antarest.study.storage.variantstudy.model.command.create_area import CreateArea @@ -109,7 +108,7 @@ def test_manage_binding_constraint( "type": "daily", } - weekly_values = default_binding_constraint_weekly.tolist() + weekly_values = default_bc_weekly_daily.tolist() bind_update = UpdateBindingConstraint( id="bd 1", enabled=False, @@ -148,7 +147,7 @@ def test_manage_binding_constraint( def test_match(command_context: CommandContext): - values = default_binding_constraint_daily.tolist() + values = default_bc_weekly_daily.tolist() base = CreateBindingConstraint( name="foo", enabled=False, @@ -231,9 +230,9 @@ def test_match(command_context: CommandContext): def test_revert(command_context: CommandContext): - hourly_values = default_binding_constraint_hourly.tolist() - daily_values = default_binding_constraint_daily.tolist() - weekly_values = default_binding_constraint_weekly.tolist() + hourly_values = default_bc_hourly.tolist() + daily_values = default_bc_weekly_daily.tolist() + weekly_values = default_bc_weekly_daily.tolist() base = CreateBindingConstraint( name="foo", enabled=False, @@ -339,7 +338,7 @@ def test_revert(command_context: CommandContext): def test_create_diff(command_context: CommandContext): - values_a = np.random.rand(365, 3).tolist() + values_a = np.random.rand(366, 3).tolist() base = CreateBindingConstraint( name="foo", enabled=False, @@ -350,7 +349,7 @@ def test_create_diff(command_context: CommandContext): command_context=command_context, ) - values_b = np.random.rand(8760, 3).tolist() + values_b = np.random.rand(8784, 3).tolist() other_match = CreateBindingConstraint( name="foo", enabled=True, @@ -372,7 +371,7 @@ def test_create_diff(command_context: CommandContext): ) ] - values = default_binding_constraint_daily.tolist() + values = default_bc_weekly_daily.tolist() base = UpdateBindingConstraint( id="foo", enabled=False, diff --git a/tests/variantstudy/model/test_variant_model.py b/tests/variantstudy/model/test_variant_model.py index 7acbce4530..63ac7293b8 100644 --- a/tests/variantstudy/model/test_variant_model.py +++ b/tests/variantstudy/model/test_variant_model.py @@ -75,8 +75,7 @@ def test_commands_service( variant_study_service: VariantStudyService, ) -> None: # Initialize the default matrix constants - # noinspection PyProtectedMember 
- generator_matrix_constants._init() + generator_matrix_constants.init_constant_matrices() params = RequestParameters(user=jwt_user) diff --git a/webapp/package-lock.json b/webapp/package-lock.json index 19ba7c4221..6ce088d97a 100644 --- a/webapp/package-lock.json +++ b/webapp/package-lock.json @@ -1,12 +1,12 @@ { "name": "antares-web", - "version": "2.16.0", + "version": "2.16.1", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "antares-web", - "version": "2.16.0", + "version": "2.16.1", "dependencies": { "@emotion/react": "11.11.1", "@emotion/styled": "11.11.0", @@ -14,6 +14,7 @@ "@mui/icons-material": "5.14.11", "@mui/lab": "5.0.0-alpha.146", "@mui/material": "5.14.11", + "@mui/x-date-pickers": "6.18.3", "@reduxjs/toolkit": "1.9.6", "@types/d3": "5.16.0", "@types/draft-convert": "2.1.5", @@ -44,7 +45,7 @@ "js-cookie": "3.0.5", "jwt-decode": "3.1.2", "lodash": "4.17.21", - "material-react-table": "1.15.0", + "material-react-table": "2.0.5", "moment": "2.29.4", "notistack": "3.0.1", "os": "0.1.2", @@ -2191,9 +2192,9 @@ "integrity": "sha512-x/rqGMdzj+fWZvCOYForTghzbtqPDZ5gPwaoNGHdgDfF2QA/XZbCBp4Moo5scrkAMPhB7z26XM/AaHuIJdgauA==" }, "node_modules/@babel/runtime": { - "version": "7.23.1", - "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.23.1.tgz", - "integrity": "sha512-hC2v6p8ZSI/W0HUzh3V8C5g+NwSKzKPtJwSpTjwl0o297GP9+ZLQSkdvHz46CM3LqyoXxq+5G9komY+eSqSO0g==", + "version": "7.23.5", + "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.23.5.tgz", + "integrity": "sha512-NdUTHcPe4C99WxPub+K9l9tK5/lV4UXIoaHSYgzco9BCyjKAAwzdBI+wWtYqHt7LJdbo74ZjRPJgzVweq1sz0w==", "dependencies": { "regenerator-runtime": "^0.14.0" }, @@ -2769,9 +2770,9 @@ } }, "node_modules/@floating-ui/react-dom": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/@floating-ui/react-dom/-/react-dom-2.0.2.tgz", - "integrity": "sha512-5qhlDvjaLmAst/rKb3VdlCinwTF4EYMiVxuuc/HVUjs46W0zgtbMmAZ1UTsDrRTxRmUEzl92mOtWbeeXL26lSQ==", + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/@floating-ui/react-dom/-/react-dom-2.0.4.tgz", + "integrity": "sha512-CF8k2rgKeh/49UrnIBs4BdxPUV6vize/Db1d/YbCLyp9GiVZ0BEwf5AiDSxJRCr6yOkGqTFHtmrULxkEfYZ7dQ==", "dependencies": { "@floating-ui/dom": "^1.5.1" }, @@ -3570,11 +3571,11 @@ } }, "node_modules/@mui/types": { - "version": "7.2.4", - "resolved": "https://registry.npmjs.org/@mui/types/-/types-7.2.4.tgz", - "integrity": "sha512-LBcwa8rN84bKF+f5sDyku42w1NTxaPgPyYKODsh01U1fVstTClbUoSA96oyRBnSNyEiAVjKm6Gwx9vjR+xyqHA==", + "version": "7.2.10", + "resolved": "https://registry.npmjs.org/@mui/types/-/types-7.2.10.tgz", + "integrity": "sha512-wX1vbDC+lzF7FlhT6A3ffRZgEoKWPF8VqRoTu4lZwouFX2t90KyCMsgepMw5DxLak1BSp/KP86CmtZttikb/gQ==", "peerDependencies": { - "@types/react": "*" + "@types/react": "^17.0.0 || ^18.0.0" }, "peerDependenciesMeta": { "@types/react": { @@ -3583,12 +3584,12 @@ } }, "node_modules/@mui/utils": { - "version": "5.14.11", - "resolved": "https://registry.npmjs.org/@mui/utils/-/utils-5.14.11.tgz", - "integrity": "sha512-fmkIiCPKyDssYrJ5qk+dime1nlO3dmWfCtaPY/uVBqCRMBZ11JhddB9m8sjI2mgqQQwRJG5bq3biaosNdU/s4Q==", + "version": "5.14.20", + "resolved": "https://registry.npmjs.org/@mui/utils/-/utils-5.14.20.tgz", + "integrity": "sha512-Y6yL5MoFmtQml20DZnaaK1znrCEwG6/vRSzW8PKOTrzhyqKIql0FazZRUR7sA5EPASgiyKZfq0FPwISRXm5NdA==", "dependencies": { - "@babel/runtime": "^7.22.15", - "@types/prop-types": "^15.7.5", + "@babel/runtime": "^7.23.4", + "@types/prop-types": "^15.7.11", "prop-types": "^15.8.1", "react-is": "^18.2.0" }, @@ 
-3597,7 +3598,7 @@ }, "funding": { "type": "opencollective", - "url": "https://opencollective.com/mui" + "url": "https://opencollective.com/mui-org" }, "peerDependencies": { "@types/react": "^17.0.0 || ^18.0.0", @@ -3609,6 +3610,102 @@ } } }, + "node_modules/@mui/x-date-pickers": { + "version": "6.18.3", + "resolved": "https://registry.npmjs.org/@mui/x-date-pickers/-/x-date-pickers-6.18.3.tgz", + "integrity": "sha512-DmJrAAr6EfhuWA9yubANAdeQayAbUppCezdhxkYKwn38G8+HJPZBol0V5fKji+B4jMxruO78lkQYsGUxVxaR7A==", + "dependencies": { + "@babel/runtime": "^7.23.2", + "@mui/base": "^5.0.0-beta.22", + "@mui/utils": "^5.14.16", + "@types/react-transition-group": "^4.4.8", + "clsx": "^2.0.0", + "prop-types": "^15.8.1", + "react-transition-group": "^4.4.5" + }, + "engines": { + "node": ">=14.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/mui" + }, + "peerDependencies": { + "@emotion/react": "^11.9.0", + "@emotion/styled": "^11.8.1", + "@mui/material": "^5.8.6", + "@mui/system": "^5.8.0", + "date-fns": "^2.25.0", + "date-fns-jalali": "^2.13.0-0", + "dayjs": "^1.10.7", + "luxon": "^3.0.2", + "moment": "^2.29.4", + "moment-hijri": "^2.1.2", + "moment-jalaali": "^0.7.4 || ^0.8.0 || ^0.9.0 || ^0.10.0", + "react": "^17.0.0 || ^18.0.0", + "react-dom": "^17.0.0 || ^18.0.0" + }, + "peerDependenciesMeta": { + "@emotion/react": { + "optional": true + }, + "@emotion/styled": { + "optional": true + }, + "date-fns": { + "optional": true + }, + "date-fns-jalali": { + "optional": true + }, + "dayjs": { + "optional": true + }, + "luxon": { + "optional": true + }, + "moment": { + "optional": true + }, + "moment-hijri": { + "optional": true + }, + "moment-jalaali": { + "optional": true + } + } + }, + "node_modules/@mui/x-date-pickers/node_modules/@mui/base": { + "version": "5.0.0-beta.26", + "resolved": "https://registry.npmjs.org/@mui/base/-/base-5.0.0-beta.26.tgz", + "integrity": "sha512-gPMRKC84VRw+tjqYoyBzyrBUqHQucMXdlBpYazHa5rCXrb91fYEQk5SqQ2U5kjxx9QxZxTBvWAmZ6DblIgaGhQ==", + "dependencies": { + "@babel/runtime": "^7.23.4", + "@floating-ui/react-dom": "^2.0.4", + "@mui/types": "^7.2.10", + "@mui/utils": "^5.14.20", + "@popperjs/core": "^2.11.8", + "clsx": "^2.0.0", + "prop-types": "^15.8.1" + }, + "engines": { + "node": ">=12.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/mui-org" + }, + "peerDependencies": { + "@types/react": "^17.0.0 || ^18.0.0", + "react": "^17.0.0 || ^18.0.0", + "react-dom": "^17.0.0 || ^18.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, "node_modules/@mui/x-tree-view": { "version": "6.0.0-alpha.1", "resolved": "https://registry.npmjs.org/@mui/x-tree-view/-/x-tree-view-6.0.0-alpha.1.tgz", @@ -4596,11 +4693,11 @@ } }, "node_modules/@tanstack/react-table": { - "version": "8.10.3", - "resolved": "https://registry.npmjs.org/@tanstack/react-table/-/react-table-8.10.3.tgz", - "integrity": "sha512-Qya1cJ+91arAlW7IRDWksRDnYw28O446jJ/ljkRSc663EaftJoBCAU10M+VV1K6MpCBLrXq1BD5IQc1zj/ZEjA==", + "version": "8.10.7", + "resolved": "https://registry.npmjs.org/@tanstack/react-table/-/react-table-8.10.7.tgz", + "integrity": "sha512-bXhjA7xsTcsW8JPTTYlUg/FuBpn8MNjiEPhkNhIGCUR6iRQM2+WEco4OBpvDeVcR9SE+bmWLzdfiY7bCbCSVuA==", "dependencies": { - "@tanstack/table-core": "8.10.3" + "@tanstack/table-core": "8.10.7" }, "engines": { "node": ">=12" @@ -4615,24 +4712,25 @@ } }, "node_modules/@tanstack/react-virtual": { - "version": "3.0.0-beta.60", - "resolved": 
"https://registry.npmjs.org/@tanstack/react-virtual/-/react-virtual-3.0.0-beta.60.tgz", - "integrity": "sha512-F0wL9+byp7lf/tH6U5LW0ZjBqs+hrMXJrj5xcIGcklI0pggvjzMNW9DdIBcyltPNr6hmHQ0wt8FDGe1n1ZAThA==", + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/@tanstack/react-virtual/-/react-virtual-3.0.1.tgz", + "integrity": "sha512-IFOFuRUTaiM/yibty9qQ9BfycQnYXIDHGP2+cU+0LrFFGNhVxCXSQnaY6wkX8uJVteFEBjUondX0Hmpp7TNcag==", "dependencies": { - "@tanstack/virtual-core": "3.0.0-beta.60" + "@tanstack/virtual-core": "3.0.0" }, "funding": { "type": "github", "url": "https://github.com/sponsors/tannerlinsley" }, "peerDependencies": { - "react": "^16.8.0 || ^17.0.0 || ^18.0.0" + "react": "^16.8.0 || ^17.0.0 || ^18.0.0", + "react-dom": "^16.8.0 || ^17.0.0 || ^18.0.0" } }, "node_modules/@tanstack/table-core": { - "version": "8.10.3", - "resolved": "https://registry.npmjs.org/@tanstack/table-core/-/table-core-8.10.3.tgz", - "integrity": "sha512-hJ55YfJlWbfzRROfcyA/kC1aZr/shsLA8XNAwN8jXylhYWGLnPmiJJISrUfj4dMMWRiFi0xBlnlC7MLH+zSrcw==", + "version": "8.10.7", + "resolved": "https://registry.npmjs.org/@tanstack/table-core/-/table-core-8.10.7.tgz", + "integrity": "sha512-KQk5OMg5OH6rmbHZxuNROvdI+hKDIUxANaHlV+dPlNN7ED3qYQ/WkpY2qlXww1SIdeMlkIhpN/2L00rof0fXFw==", "engines": { "node": ">=12" }, @@ -4642,9 +4740,9 @@ } }, "node_modules/@tanstack/virtual-core": { - "version": "3.0.0-beta.60", - "resolved": "https://registry.npmjs.org/@tanstack/virtual-core/-/virtual-core-3.0.0-beta.60.tgz", - "integrity": "sha512-QlCdhsV1+JIf0c0U6ge6SQmpwsyAT0oQaOSZk50AtEeAyQl9tQrd6qCHAslxQpgphrfe945abvKG8uYvw3hIGA==", + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/@tanstack/virtual-core/-/virtual-core-3.0.0.tgz", + "integrity": "sha512-SYXOBTjJb05rXa2vl55TTwO40A6wKu0R5i1qQwhJYNDIqaIGF7D0HsLw+pJAyi2OvntlEIVusx3xtbbgSUi6zg==", "funding": { "type": "github", "url": "https://github.com/sponsors/tannerlinsley" @@ -5238,9 +5336,9 @@ "integrity": "sha512-+68kP9yzs4LMp7VNh8gdzMSPZFL44MLGqiHWvttYJe+6qnuVr4Ek9wSBQoveqY/r+LwjCcU29kNVkidwim+kYA==" }, "node_modules/@types/prop-types": { - "version": "15.7.8", - "resolved": "https://registry.npmjs.org/@types/prop-types/-/prop-types-15.7.8.tgz", - "integrity": "sha512-kMpQpfZKSCBqltAJwskgePRaYRFukDkm1oItcAbC3gNELR20XIBcN9VRgg4+m8DKsTfkWeA4m4Imp4DDuWy7FQ==" + "version": "15.7.11", + "resolved": "https://registry.npmjs.org/@types/prop-types/-/prop-types-15.7.11.tgz", + "integrity": "sha512-ga8y9v9uyeiLdpKddhxYQkxNDrfvuPrlFb0N1qnZZByvcElJaXthF1UhvCh9TLWJBEHeNtdnbysW7Y6Uq8CVng==" }, "node_modules/@types/q": { "version": "1.5.6", @@ -5332,9 +5430,9 @@ } }, "node_modules/@types/react-transition-group": { - "version": "4.4.7", - "resolved": "https://registry.npmjs.org/@types/react-transition-group/-/react-transition-group-4.4.7.tgz", - "integrity": "sha512-ICCyBl5mvyqYp8Qeq9B5G/fyBSRC0zx3XM3sCC6KkcMsNeAHqXBKkmat4GqdJET5jtYUpZXrxI5flve5qhi2Eg==", + "version": "4.4.9", + "resolved": "https://registry.npmjs.org/@types/react-transition-group/-/react-transition-group-4.4.9.tgz", + "integrity": "sha512-ZVNmWumUIh5NhH8aMD9CR2hdW0fNuYInlocZHaZ+dgk/1K49j1w/HoAuK1ki+pgscQrOFRTlXeoURtuzEkV3dg==", "dependencies": { "@types/react": "*" } @@ -15102,29 +15200,30 @@ "integrity": "sha512-6qE4B9deFBIa9YSpOc9O0Sgc43zTeVYbgDT5veRKSlB2+ZuHNoVVxA1L/ckMUayV9Ay9y7Z/SZCLcGteW9i7bg==" }, "node_modules/material-react-table": { - "version": "1.15.0", - "resolved": "https://registry.npmjs.org/material-react-table/-/material-react-table-1.15.0.tgz", - "integrity": 
"sha512-f59XPZ+jFErRAs3ym3cHsK6kBLCrYJGX6GoF473V1/gCpsNbkWEEdmCVMpB8ycOUNDEXtnRDMZzk3LjTMd6wpg==", + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/material-react-table/-/material-react-table-2.0.5.tgz", + "integrity": "sha512-axRrqa/2QQ+AO3SiJbOtSyemlHX0S03X+IXW72z344d3LT+u/jsKiAmdWMLTN8ARScYMAN5NgrArujiLEmftSQ==", "dependencies": { "@tanstack/match-sorter-utils": "8.8.4", - "@tanstack/react-table": "8.10.3", - "@tanstack/react-virtual": "3.0.0-beta.60", + "@tanstack/react-table": "8.10.7", + "@tanstack/react-virtual": "3.0.1", "highlight-words": "1.2.2" }, "engines": { - "node": ">=14" + "node": ">=16" }, "funding": { "type": "github", "url": "https://github.com/sponsors/kevinvandy" }, "peerDependencies": { - "@emotion/react": ">=11", - "@emotion/styled": ">=11", - "@mui/icons-material": ">=5", - "@mui/material": ">=5", - "react": ">=17.0", - "react-dom": ">=17.0" + "@emotion/react": ">=11.11", + "@emotion/styled": ">=11.11", + "@mui/icons-material": ">=5.11", + "@mui/material": ">=5.13", + "@mui/x-date-pickers": ">=6.15.0", + "react": ">=18.0", + "react-dom": ">=18.0" } }, "node_modules/math-log2": { diff --git a/webapp/package.json b/webapp/package.json index 1bc82aae3e..6b687cb3ae 100644 --- a/webapp/package.json +++ b/webapp/package.json @@ -1,6 +1,6 @@ { "name": "antares-web", - "version": "2.16.0", + "version": "2.16.1", "private": true, "engines": { "node": "18.16.1" @@ -12,6 +12,7 @@ "@mui/icons-material": "5.14.11", "@mui/lab": "5.0.0-alpha.146", "@mui/material": "5.14.11", + "@mui/x-date-pickers": "6.18.3", "@reduxjs/toolkit": "1.9.6", "@types/d3": "5.16.0", "@types/draft-convert": "2.1.5", @@ -42,7 +43,7 @@ "js-cookie": "3.0.5", "jwt-decode": "3.1.2", "lodash": "4.17.21", - "material-react-table": "1.15.0", + "material-react-table": "2.0.5", "moment": "2.29.4", "notistack": "3.0.1", "os": "0.1.2", diff --git a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Hydro/index.tsx b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Hydro/index.tsx index 42a84ca56c..8b65d57f44 100644 --- a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Hydro/index.tsx +++ b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Hydro/index.tsx @@ -65,7 +65,7 @@ function Hydro() { return ( <Root> <DocLink to={`${ACTIVE_WINDOWS_DOC_PATH}#hydro`} isAbsolute /> - <TabWrapper study={study} tabStyle="normal" tabList={tabList} /> + <TabWrapper study={study} tabList={tabList} isScrollable /> </Root> ); } diff --git a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Hydro/style.ts b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Hydro/style.ts index 3103e60f79..0bd3dab7d4 100644 --- a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Hydro/style.ts +++ b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Hydro/style.ts @@ -4,6 +4,7 @@ export const Root = styled(Box)(({ theme }) => ({ width: "100%", height: "100%", padding: theme.spacing(2), + paddingTop: 0, display: "flex", overflow: "auto", })); diff --git a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Renewables/Form.tsx b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Renewables/Form.tsx index 0f98902d24..30dc1bc7b7 100644 --- a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Renewables/Form.tsx +++ b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Renewables/Form.tsx @@ -65,24 +65,21 @@ function RenewablesForm() { key={study.id + areaId} config={{ 
defaultValues }}
         onSubmit={handleSubmit}
-        autoSubmit
+        enableUndoRedo
       >
         <Fields />
-        <Box
-          sx={{
-            width: 1,
-            display: "flex",
-            flexDirection: "column",
-            height: "500px",
-          }}
-        >
-          <Matrix
-            study={study}
-            areaId={areaId}
-            clusterId={nameToId(clusterId)}
-          />
-        </Box>
       </Form>
+      <Box
+        sx={{
+          width: 1,
+          display: "flex",
+          flexDirection: "column",
+          py: 3,
+          height: "75vh",
+        }}
+      >
+        <Matrix study={study} areaId={areaId} clusterId={nameToId(clusterId)} />
+      </Box>
     </Box>
   );
 }
diff --git a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Storages/Fields.tsx b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Storages/Fields.tsx
index 0e04b378a0..8485fd29e6 100644
--- a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Storages/Fields.tsx
+++ b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Storages/Fields.tsx
@@ -101,8 +101,8 @@ function Fields() {
             message: t("form.field.minValue", { 0: 0 }),
           },
           max: {
-            value: 1,
-            message: t("form.field.maxValue", { 0: 1 }),
+            value: 100,
+            message: t("form.field.maxValue", { 0: 100 }),
           },
         }}
       />
@@ -116,8 +116,8 @@ function Fields() {
             message: t("form.field.minValue", { 0: 0 }),
           },
           max: {
-            value: 1,
-            message: t("form.field.maxValue", { 0: 1 }),
+            value: 100,
+            message: t("form.field.maxValue", { 0: 100 }),
           },
         }}
       />
diff --git a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Storages/Form.tsx b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Storages/Form.tsx
index c3f2f717dd..40aa166664 100644
--- a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Storages/Form.tsx
+++ b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Storages/Form.tsx
@@ -3,6 +3,7 @@ import { Box, Button } from "@mui/material";
 import { useParams, useOutletContext, useNavigate } from "react-router-dom";
 import ArrowBackIcon from "@mui/icons-material/ArrowBack";
 import { useTranslation } from "react-i18next";
+import * as RA from "ramda-adjunct";
 import { StudyMetadata } from "../../../../../../../common/types";
 import Form from "../../../../../../common/Form";
 import Fields from "./Fields";
@@ -27,17 +28,34 @@ function StorageForm() {
   });

   // prevent re-fetch while useNavigateOnCondition event occurs
-  const defaultValues = useCallback(() => {
-    return getStorage(study.id, areaId, storageId);
+  const defaultValues = useCallback(
+    async () => {
+      const storage = await getStorage(study.id, areaId, storageId);
+      return {
+        ...storage,
+        // Convert to percentage ([0-1] -> [0-100])
+        efficiency: storage.efficiency * 100,
+        initialLevel: storage.initialLevel * 100,
+      };
+    },
     // eslint-disable-next-line react-hooks/exhaustive-deps
-  }, []);
+    [],
+  );

   ////////////////////////////////////////////////////////////////
   // Event handlers
   ////////////////////////////////////////////////////////////////

   const handleSubmit = ({ dirtyValues }: SubmitHandlerPlus<Storage>) => {
-    return updateStorage(study.id, areaId, storageId, dirtyValues);
+    const newValues = { ...dirtyValues };
+    // Convert to ratio ([0-100] -> [0-1])
+    if (RA.isNumber(newValues.efficiency)) {
+      newValues.efficiency /= 100;
+    }
+    if (RA.isNumber(newValues.initialLevel)) {
+      newValues.initialLevel /= 100;
+    }
+    return updateStorage(study.id, areaId, storageId, newValues);
   };

   ////////////////////////////////////////////////////////////////
@@ -61,24 +79,21 @@ function StorageForm() {
         defaultValues,
       }}
       onSubmit={handleSubmit}
-      autoSubmit
+      enableUndoRedo
     >
       <Fields />
-      <Box
-        sx={{
-          width: 1,
-          display: "flex",
-          flexDirection: "column",
-          height: "500px",
-        }}
-      >
-        <Matrix
-          study={study}
-          areaId={areaId}
-          storageId={nameToId(storageId)}
-        />
-      </Box>
     </Form>
+    <Box
+      sx={{
+        width: 1,
+        display: "flex",
+        flexDirection: "column",
+        py: 3,
+        height: "75vh",
+      }}
+    >
+      <Matrix study={study} areaId={areaId} storageId={nameToId(storageId)} />
+    </Box>
   </Box>
 );
}
diff --git a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Thermal/Fields.tsx b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Thermal/Fields.tsx
index 3253167d5a..cf5cb2fc66 100644
--- a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Thermal/Fields.tsx
+++ b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Thermal/Fields.tsx
@@ -233,6 +233,7 @@ function Fields() {
             message: t("form.field.maxValue", { 0: 1 }),
           },
         }}
+        inputProps={{ step: 0.1 }}
       />
       <NumberFE
         label={t("study.modelization.clusters.volatilityPlanned")}
@@ -248,6 +249,7 @@ function Fields() {
             message: t("form.field.maxValue", { 0: 1 }),
           },
         }}
+        inputProps={{ step: 0.1 }}
       />
       <SelectFE
         label={t("study.modelization.clusters.lawForced")}
diff --git a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Thermal/Form.tsx b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Thermal/Form.tsx
index ee16dbbf6d..810338b581 100644
--- a/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Thermal/Form.tsx
+++ b/webapp/src/components/App/Singlestudy/explore/Modelization/Areas/Thermal/Form.tsx
@@ -63,24 +63,21 @@ function ThermalForm() {
       key={study.id + areaId}
       config={{ defaultValues }}
       onSubmit={handleSubmit}
-      autoSubmit
+      enableUndoRedo
     >
       <Fields />
-      <Box
-        sx={{
-          width: 1,
-          display: "flex",
-          flexDirection: "column",
-          height: "500px",
-        }}
-      >
-        <Matrix
-          study={study}
-          areaId={areaId}
-          clusterId={nameToId(clusterId)}
-        />
-      </Box>
     </Form>
+    <Box
+      sx={{
+        width: 1,
+        display: "flex",
+        flexDirection: "column",
+        py: 3,
+        height: "75vh",
+      }}
+    >
+      <Matrix study={study} areaId={areaId} clusterId={nameToId(clusterId)} />
+    </Box>
   </Box>
 );
}
diff --git a/webapp/src/components/App/Singlestudy/explore/Modelization/index.tsx b/webapp/src/components/App/Singlestudy/explore/Modelization/index.tsx
index 5d81c34155..06b1a4da83 100644
--- a/webapp/src/components/App/Singlestudy/explore/Modelization/index.tsx
+++ b/webapp/src/components/App/Singlestudy/explore/Modelization/index.tsx
@@ -1,57 +1,80 @@
-import { useMemo } from "react";
-import { useOutletContext } from "react-router-dom";
+import { useEffect, useMemo } from "react";
+import { useNavigate, useOutletContext, useParams } from "react-router-dom";
 import { Box } from "@mui/material";
 import { useTranslation } from "react-i18next";
 import { StudyMetadata } from "../../../../../common/types";
 import TabWrapper from "../TabWrapper";
 import useAppSelector from "../../../../../redux/hooks/useAppSelector";
-import { getCurrentAreaId } from "../../../../../redux/selectors";
+import { getAreas, getCurrentAreaId } from "../../../../../redux/selectors";
+import useAppDispatch from "../../../../../redux/hooks/useAppDispatch";
+import { setCurrentArea } from "../../../../../redux/ducks/studySyntheses";

 function Modelization() {
   const { study } = useOutletContext<{ study: StudyMetadata }>();
-  const areaId = useAppSelector(getCurrentAreaId);
   const [t] = useTranslation();
+  const dispatch = useAppDispatch();
+  const navigate = useNavigate();
+  const { areaId: paramAreaId } = useParams();
+  const areas = useAppSelector((state) => getAreas(state, study.id));
+  const areaId = useAppSelector(getCurrentAreaId);
+
+  useEffect(() => {
+    if (!areaId && paramAreaId) {
+      dispatch(setCurrentArea(paramAreaId));
+    }
+  }, [paramAreaId, dispatch, areaId]);
+
+  const tabList = useMemo(() => {
+    const basePath = `/studies/${study.id}/explore/modelization`;

-  const tabList = useMemo(
-    () => [
+    const handleAreasClick = () => {
+      if (areaId.length === 0 && areas.length > 0) {
+        const firstAreaId = areas[0].id ?? null;
+
+        if (firstAreaId) {
+          dispatch(setCurrentArea(firstAreaId));
+          navigate(`${basePath}/area/${firstAreaId}`, { replace: true });
+        }
+      }
+    };
+
+    return [
       {
         label: t("study.modelization.map"),
-        path: `/studies/${study?.id}/explore/modelization/map`,
+        path: `${basePath}/map`,
       },
       {
         label: t("study.areas"),
-        path: `/studies/${study?.id}/explore/modelization/area/${areaId}`,
+        path: `${basePath}/area/${areaId}`,
+        onClick: handleAreasClick,
       },
       {
         label: t("study.links"),
-        path: `/studies/${study?.id}/explore/modelization/links`,
+        path: `${basePath}/links`,
       },
       {
         label: t("study.bindingconstraints"),
-        path: `/studies/${study?.id}/explore/modelization/bindingcontraint`,
+        path: `${basePath}/bindingcontraint`,
       },
       {
         label: t("study.debug"),
-        path: `/studies/${study?.id}/explore/modelization/debug`,
+        path: `${basePath}/debug`,
       },
       {
         label: t("study.modelization.tableMode"),
-        path: `/studies/${study?.id}/explore/modelization/tablemode`,
+        path: `${basePath}/tablemode`,
       },
-    ],
-    [areaId, study?.id, t],
-  );
+    ];
+  }, [areaId, areas, dispatch, navigate, study?.id, t]);

   return (
     <Box
-      width="100%"
-      flexGrow={1}
-      display="flex"
-      flexDirection="column"
-      justifyContent="center"
-      alignItems="center"
-      boxSizing="border-box"
-      overflow="hidden"
+      sx={{
+        display: "flex",
+        flex: 1,
+        width: 1,
+        overflow: "hidden",
+      }}
     >
       <TabWrapper study={study} tabStyle="withoutBorder" tabList={tabList} />
     </Box>
diff --git a/webapp/src/components/App/Singlestudy/explore/Results/ResultDetails/index.tsx b/webapp/src/components/App/Singlestudy/explore/Results/ResultDetails/index.tsx
index 7875ff7246..1f6b4a9bdc 100644
--- a/webapp/src/components/App/Singlestudy/explore/Results/ResultDetails/index.tsx
+++ b/webapp/src/components/App/Singlestudy/explore/Results/ResultDetails/index.tsx
@@ -296,6 +296,7 @@ function ResultDetails() {
               { value: DataType.Thermal, label: "Thermal plants" },
               { value: DataType.Renewable, label: "Ren. clusters" },
               { value: DataType.Record, label: "RecordYears" },
+              { value: DataType.STStorage, label: "ST Storages" },
             ]}
             size="small"
             variant="outlined"
diff --git a/webapp/src/components/App/Singlestudy/explore/Results/ResultDetails/utils.ts b/webapp/src/components/App/Singlestudy/explore/Results/ResultDetails/utils.ts
index f140205498..cba62478c8 100644
--- a/webapp/src/components/App/Singlestudy/explore/Results/ResultDetails/utils.ts
+++ b/webapp/src/components/App/Singlestudy/explore/Results/ResultDetails/utils.ts
@@ -11,6 +11,7 @@ export enum DataType {
   Thermal = "details",
   Renewable = "details-res",
   Record = "id",
+  STStorage = "details-STstorage",
 }

 export enum Timestep {
diff --git a/webapp/src/components/App/Singlestudy/explore/TabWrapper.tsx b/webapp/src/components/App/Singlestudy/explore/TabWrapper.tsx
index 7a61f90fb6..482801c482 100644
--- a/webapp/src/components/App/Singlestudy/explore/TabWrapper.tsx
+++ b/webapp/src/components/App/Singlestudy/explore/TabWrapper.tsx
@@ -1,5 +1,5 @@
 /* eslint-disable react/jsx-props-no-spreading */
-import { useEffect } from "react";
+import { useEffect, useState } from "react";
 import * as React from "react";
 import { styled, SxProps, Theme } from "@mui/material";
 import Tabs from "@mui/material/Tabs";
@@ -28,19 +28,32 @@ export const StyledTab = styled(Tabs, {
   }),
 );

+interface TabItem {
+  label: string;
+  path: string;
+  onClick?: () => void;
+}
+
 interface Props {
   study: StudyMetadata | undefined;
-  tabList: Array<{ label: string; path: string }>;
+  tabList: TabItem[];
   border?: boolean;
   tabStyle?: "normal" | "withoutBorder";
   sx?: SxProps<Theme>;
+  isScrollable?: boolean;
 }

-function TabWrapper(props: Props) {
-  const { study, tabList, border, tabStyle, sx } = props;
+function TabWrapper({
+  study,
+  tabList,
+  border,
+  tabStyle,
+  sx,
+  isScrollable = false,
+}: Props) {
   const location = useLocation();
   const navigate = useNavigate();
-  const [selectedTab, setSelectedTab] = React.useState(0);
+  const [selectedTab, setSelectedTab] = useState(0);

   useEffect(() => {
     const getTabIndex = (): number => {
@@ -66,6 +79,11 @@ function TabWrapper(props: Props) {
   const handleChange = (event: React.SyntheticEvent, newValue: number) => {
     setSelectedTab(newValue);
     navigate(tabList[newValue].path);
+
+    const onTabClick = tabList[newValue].onClick;
+    if (onTabClick) {
+      onTabClick();
+    }
   };

   ////////////////////////////////////////////////////////////////
@@ -87,16 +105,15 @@ function TabWrapper(props: Props) {
       )}
     >
       <StyledTab
-        border={border === true}
+        border={border}
         tabStyle={tabStyle}
         value={selectedTab}
         onChange={handleChange}
-        variant="scrollable"
+        variant={isScrollable ? "scrollable" : "standard"}
         sx={{
           width: "98%",
-          ...(border === true
-            ? { borderBottom: 1, borderColor: "divider" }
-            : {}),
+          borderBottom: border ? 1 : 0,
+          borderColor: border ? "divider" : "inherit",
         }}
       >
         {tabList.map((tab) => (
@@ -108,9 +125,4 @@ function TabWrapper(props: Props) {
   );
 }

-TabWrapper.defaultProps = {
-  border: undefined,
-  tabStyle: "normal",
-};
-
 export default TabWrapper;
diff --git a/webapp/src/components/App/Studies/StudyCard.tsx b/webapp/src/components/App/Studies/StudyCard.tsx
index d6f6e91dd6..976ee7acd8 100644
--- a/webapp/src/components/App/Studies/StudyCard.tsx
+++ b/webapp/src/components/App/Studies/StudyCard.tsx
@@ -1,5 +1,5 @@
 import { memo, useState } from "react";
-import { NavLink } from "react-router-dom";
+import { NavLink, useNavigate } from "react-router-dom";
 import { AxiosError } from "axios";
 import { useSnackbar } from "notistack";
 import { useTranslation } from "react-i18next";
@@ -16,6 +16,7 @@ import {
   ListItemText,
   Tooltip,
   Chip,
+  Divider,
 } from "@mui/material";
 import { styled } from "@mui/material/styles";
 import { indigo } from "@mui/material/colors";
@@ -93,6 +94,7 @@ const StudyCard = memo((props: Props) => {
   const study = useAppSelector((state) => getStudy(state, id));
   const isFavorite = useAppSelector((state) => isStudyFavorite(state, id));
   const dispatch = useAppDispatch();
+  const navigate = useNavigate();

   ////////////////////////////////////////////////////////////////
   // Event Handlers
@@ -218,7 +220,13 @@ const StudyCard = memo((props: Props) => {
         </Box>
       )}
       <CardContent
-        sx={{ flexGrow: 1, display: "flex", flexDirection: "column" }}
+        sx={{
+          flexGrow: 1,
+          display: "flex",
+          flexDirection: "column",
+          overflow: "auto",
+          maxHeight: "calc(100% - 48px)",
+        }}
       >
         <Box
           sx={{
@@ -236,6 +244,7 @@ const StudyCard = memo((props: Props) => {
             noWrap
             variant="h6"
             component="div"
+            onClick={() => navigate(`/studies/${study.id}`)}
             sx={{
               color: "white",
               boxSizing: "border-box",
@@ -244,6 +253,11 @@ const StudyCard = memo((props: Props) => {
               whiteSpace: "nowrap",
               textOverflow: "ellipsis",
               overflow: "hidden",
+              cursor: "pointer",
+              "&:hover": {
+                color: "primary.main",
+                textDecoration: "underline",
+              },
             }}
           >
             {study.name}
@@ -289,6 +303,10 @@ const StudyCard = memo((props: Props) => {
             flexFlow: "nowrap",
             px: 0.5,
             paddingBottom: 0.5,
+            width: "90%",
+            whiteSpace: "nowrap",
+            textOverflow: "ellipsis",
+            overflow: "hidden",
           }}
         >
           {study.folder}
@@ -308,8 +326,6 @@ const StudyCard = memo((props: Props) => {
           sx={{
             display: "flex",
             maxWidth: "65%",
-            flexDirection: "row",
-            justifyContent: "flex-start",
             alignItems: "center",
           }}
         >
@@ -321,38 +337,27 @@ const StudyCard = memo((props: Props) => {
           <Box
             sx={{
               display: "flex",
-              flexDirection: "row",
-              justifyContent: "flex-start",
-              alignItems: "center",
+              gap: 1,
             }}
           >
-            <UpdateOutlinedIcon sx={{ color: "text.secondary", mr: 1 }} />
+            <UpdateOutlinedIcon sx={{ color: "text.secondary" }} />
             <TinyText>
               {buildModificationDate(study.modificationDate, t, i18n.language)}
             </TinyText>
+            <Divider flexItem orientation="vertical" />
+            <TinyText>{`v${displayVersionName(study.version)}`}</TinyText>
           </Box>
         </Box>
         <Box
           sx={{
-            width: "100%",
             display: "flex",
-            flexDirection: "row",
-            justifyContent: "space-between",
-            alignItems: "center",
+            textOverflow: "ellipsis",
+            overflow: "hidden",
+            mt: 1,
           }}
         >
-          <Box
-            sx={{
-              display: "flex",
-              flexDirection: "row",
-              justifyContent: "flex-start",
-              alignItems: "center",
-            }}
-          >
-            <PersonOutlineIcon sx={{ color: "text.secondary", mr: 1 }} />
-            <TinyText>{study.owner.name}</TinyText>
-          </Box>
-          <TinyText>{`v${displayVersionName(study.version)}`}</TinyText>
+          <PersonOutlineIcon sx={{ color: "text.secondary" }} />
+          <TinyText>{study.owner.name}</TinyText>
         </Box>
         <Box
           sx={{
@@ -364,8 +369,7 @@ const StudyCard = memo((props: Props) => {
             flexWrap: "wrap",
             justifyContent: "flex-start",
             alignItems: "center",
-            overflowX: "hidden",
-            overflowY: "auto",
+            gap: 0.5,

             ".MuiChip-root": {
               color: "black",
@@ -377,24 +381,24 @@ const StudyCard = memo((props: Props) => {
               icon={<AltRouteOutlinedIcon />}
               label={t("studies.variant").toLowerCase()}
               color="primary"
+              size="small"
             />
           )}
           <Chip
             label={study.workspace}
-            variant="filled"
+            size="small"
             sx={{
               bgcolor: study.managed ? "secondary.main" : "gray",
             }}
           />
-          {study.tags &&
-            study.tags.map((elm) => (
-              <Chip
-                key={elm}
-                label={elm}
-                variant="filled"
-                sx={{ bgcolor: indigo[300] }}
-              />
-            ))}
+          {study.tags?.map((tag) => (
+            <Chip
+              key={tag}
+              label={tag}
+              size="small"
+              sx={{ bgcolor: indigo[300] }}
+            />
+          ))}
         </Box>
       </CardContent>
       <CardActions>
diff --git a/webapp/src/components/common/EditableMatrix/index.tsx b/webapp/src/components/common/EditableMatrix/index.tsx
index 91bfeb7d81..8176cc0f5c 100644
--- a/webapp/src/components/common/EditableMatrix/index.tsx
+++ b/webapp/src/components/common/EditableMatrix/index.tsx
@@ -13,7 +13,11 @@ import "handsontable/dist/handsontable.min.css";
 import MatrixGraphView from "./MatrixGraphView";
 import { Root } from "./style";
 import "./style.css";
-import { computeStats, createDateFromIndex, slice } from "./utils";
+import {
+  computeStats,
+  createDateFromIndex,
+  cellChangesToMatrixEdits,
+} from "./utils";
 import Handsontable from "../Handsontable";

 const logError = debug("antares:editablematrix:error");
@@ -68,18 +72,19 @@ function EditableMatrix(props: PropTypes) {
   // Event Handlers
   ////////////////////////////////////////////////////////////////

-  const handleSlice = (change: CellChange[], source: string) => {
-    const isChanged = change.map((item) => {
-      if (parseFloat(item[2]) === parseFloat(item[3])) {
-        return;
-      }
-      return item;
-    });
-    if (onUpdate) {
-      const edit = slice(
-        isChanged.filter((e) => e !== undefined) as CellChange[],
-      );
-      onUpdate(edit, source);
+  const handleSlice = (changes: CellChange[], source: string) => {
+    if (!onUpdate) {
+      return;
+    }
+
+    const filteredChanges = changes.filter(
+      ([, , oldValue, newValue]) =>
+        parseFloat(oldValue) !== parseFloat(newValue),
+    );
+
+    if (filteredChanges.length > 0) {
+      const edits = cellChangesToMatrixEdits(filteredChanges, matrixTime);
+      onUpdate(edits, source);
     }
   };
diff --git a/webapp/src/components/common/EditableMatrix/utils.ts b/webapp/src/components/common/EditableMatrix/utils.ts
index 020455cd5e..341fb8d2a3 100644
--- a/webapp/src/components/common/EditableMatrix/utils.ts
+++ b/webapp/src/components/common/EditableMatrix/utils.ts
@@ -82,14 +82,19 @@ export const createDateFromIndex = (
   return date;
 };

-export const slice = (tab: CellChange[]): MatrixEditDTO[] => {
-  return tab.map((cell) => {
+export const cellChangesToMatrixEdits = (
+  cellChanges: CellChange[],
+  matrixTime: boolean,
+): MatrixEditDTO[] =>
+  cellChanges.map(([row, column, , value]) => {
+    const rowIndex = parseFloat(row.toString());
+    const colIndex = parseFloat(column.toString()) - (matrixTime ? 1 : 0);
+
     return {
-      coordinates: [[cell[0] as number, (cell[1] as number) - 1]],
-      operation: { operation: Operator.EQ, value: parseInt(cell[3], 10) },
+      coordinates: [[rowIndex, colIndex]],
+      operation: { operation: Operator.EQ, value: parseFloat(value) },
     };
   });
-};

 export const computeStats = (
   statsType: string,
diff --git a/webapp/src/components/common/GroupedDataTable/index.tsx b/webapp/src/components/common/GroupedDataTable/index.tsx
index 45c23ab55d..17f1df92a1 100644
--- a/webapp/src/components/common/GroupedDataTable/index.tsx
+++ b/webapp/src/components/common/GroupedDataTable/index.tsx
@@ -5,7 +5,8 @@ import AddIcon from "@mui/icons-material/Add";
 import { Button } from "@mui/material";
 import DeleteIcon from "@mui/icons-material/Delete";
 import ContentCopyIcon from "@mui/icons-material/ContentCopy";
-import MaterialReactTable, {
+import {
+  MaterialReactTable,
   MRT_RowSelectionState,
   MRT_ToggleFiltersButton,
   MRT_ToggleGlobalFilterButton,
diff --git a/webapp/src/redux/ducks/studySyntheses.ts b/webapp/src/redux/ducks/studySyntheses.ts
index 9aad642d31..a7ae6cdc52 100644
--- a/webapp/src/redux/ducks/studySyntheses.ts
+++ b/webapp/src/redux/ducks/studySyntheses.ts
@@ -87,14 +87,7 @@ const initDefaultAreaLinkSelection = (
   studyData?: FileStudyTreeConfigDTO,
 ): void => {
   if (studyData) {
-    // Set current area
-    const areas = Object.keys(studyData.areas);
-    if (areas.length > 0) {
-      dispatch(setCurrentArea(areas[0]));
-    } else {
-      dispatch(setCurrentArea(""));
-    }
-
+    dispatch(setCurrentArea(""));
     dispatch(setCurrentLink(""));
   } else {
     dispatch(setCurrentArea(""));
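
Note: the last hunk above works together with the Modelization/index.tsx changes: initDefaultAreaLinkSelection no longer auto-selects the first area when a study synthesis loads; instead, the first area is selected (and navigated to) only when the user clicks the Areas tab. Separately, the EditableMatrix refactor replaces the loosely typed slice helper with cellChangesToMatrixEdits, which keeps decimal values (parseFloat rather than the former parseInt) and shifts the column index only when a leading date/time column is displayed. The sketch below illustrates that mapping in isolation; the CellChange tuple and the MatrixEditDTO/operator shapes are simplified stand-ins for the real handsontable and API types, not the project's actual definitions.

    // Simplified stand-ins (assumed shapes, not the project's real types).
    type CellChange = [number, number, string | null, string];

    interface MatrixEditDTO {
      coordinates: Array<[number, number]>;
      operation: { operation: "="; value: number };
    }

    function cellChangesToMatrixEdits(
      cellChanges: CellChange[],
      matrixTime: boolean,
    ): MatrixEditDTO[] {
      return cellChanges.map(([row, column, , newValue]) => ({
        // When the grid prepends a date/time column (matrixTime), matrix
        // column 0 is displayed at grid column 1, so shift the index back.
        coordinates: [[row, column - (matrixTime ? 1 : 0)]],
        // parseFloat keeps decimals; parseInt silently truncated them.
        operation: { operation: "=", value: parseFloat(newValue) },
      }));
    }

    // Editing grid cell (row 2, column 3) from "0" to "1.5" on a grid with a
    // time column targets column 2 of the raw matrix:
    console.log(cellChangesToMatrixEdits([[2, 3, "0", "1.5"]], true));
    // -> [{ coordinates: [[2, 2]], operation: { operation: "=", value: 1.5 } }]

Filtering out no-op changes (old and new values equal after parseFloat) happens upstream in handleSlice, so the helper can assume every change it receives is a real edit.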