diff --git a/antarest/__init__.py b/antarest/__init__.py index 4023e2ca7b..447ab7edd4 100644 --- a/antarest/__init__.py +++ b/antarest/__init__.py @@ -7,9 +7,9 @@ # Standard project metadata -__version__ = "2.15.1" +__version__ = "2.15.2" __author__ = "RTE, Antares Web Team" -__date__ = "2023-10-05" +__date__ = "2023-10-11" # noinspection SpellCheckingInspection __credits__ = "(c) Réseau de Transport de l’Électricité (RTE)" diff --git a/antarest/core/cache/business/redis_cache.py b/antarest/core/cache/business/redis_cache.py index 3c72845eb8..176e583b76 100644 --- a/antarest/core/cache/business/redis_cache.py +++ b/antarest/core/cache/business/redis_cache.py @@ -21,6 +21,7 @@ def __init__(self, redis_client: Redis): # type: ignore self.redis = redis_client def start(self) -> None: + # Assuming the Redis service is already running; no need to start it here. pass def put(self, id: str, data: JSON, duration: int = 3600) -> None: diff --git a/antarest/core/config.py b/antarest/core/config.py index 4c232fa475..b48be8ded0 100644 --- a/antarest/core/config.py +++ b/antarest/core/config.py @@ -1,16 +1,14 @@ -import logging +import multiprocessing import tempfile -from dataclasses import dataclass, field +from dataclasses import asdict, dataclass, field from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional import yaml from antarest.core.model import JSON from antarest.core.roles import RoleType -logger = logging.getLogger(__name__) - @dataclass(frozen=True) class ExternalAuthConfig: @@ -23,13 +21,16 @@ class ExternalAuthConfig: add_ext_groups: bool = False group_mapping: Dict[str, str] = field(default_factory=dict) - @staticmethod - def from_dict(data: JSON) -> "ExternalAuthConfig": - return ExternalAuthConfig( - url=data.get("url", None), - default_group_role=RoleType(data.get("default_group_role", RoleType.READER.value)), - add_ext_groups=data.get("add_ext_groups", False), - group_mapping=data.get("group_mapping", {}), + @classmethod + def from_dict(cls, data: JSON) -> "ExternalAuthConfig": + defaults = cls() + return cls( + url=data.get("url", defaults.url), + default_group_role=( + RoleType(data["default_group_role"]) if "default_group_role" in data else defaults.default_group_role + ), + add_ext_groups=data.get("add_ext_groups", defaults.add_ext_groups), + group_mapping=data.get("group_mapping", defaults.group_mapping), ) @@ -44,13 +45,18 @@ class SecurityConfig: disabled: bool = False external_auth: ExternalAuthConfig = ExternalAuthConfig() - @staticmethod - def from_dict(data: JSON) -> "SecurityConfig": - return SecurityConfig( - jwt_key=data.get("jwt", {}).get("key", ""), - admin_pwd=data.get("login", {}).get("admin", {}).get("pwd", ""), - disabled=data.get("disabled", False), - external_auth=ExternalAuthConfig.from_dict(data.get("external_auth", {})), + @classmethod + def from_dict(cls, data: JSON) -> "SecurityConfig": + defaults = cls() + return cls( + jwt_key=data.get("jwt", {}).get("key", defaults.jwt_key), + admin_pwd=data.get("login", {}).get("admin", {}).get("pwd", defaults.admin_pwd), + disabled=data.get("disabled", defaults.disabled), + external_auth=( + ExternalAuthConfig.from_dict(data["external_auth"]) + if "external_auth" in data + else defaults.external_auth + ), ) @@ -65,13 +71,14 @@ class WorkspaceConfig: groups: List[str] = field(default_factory=lambda: []) path: Path = Path() - @staticmethod - def from_dict(data: JSON) -> "WorkspaceConfig": - return WorkspaceConfig( - path=Path(data["path"]), - groups=data.get("groups", []), - filter_in=data.get("filter_in", [".*"]), - filter_out=data.get("filter_out", []), + @classmethod + def from_dict(cls, data: JSON) -> "WorkspaceConfig": + defaults = cls() + return cls( + filter_in=data.get("filter_in", defaults.filter_in), + filter_out=data.get("filter_out", defaults.filter_out), + groups=data.get("groups", defaults.groups), + path=Path(data["path"]) if "path" in data else defaults.path, ) @@ -91,18 +98,19 @@ class DbConfig: pool_size: int = 5 pool_use_lifo: bool = False - @staticmethod - def from_dict(data: JSON) -> "DbConfig": - return DbConfig( - db_admin_url=data.get("admin_url", None), - db_url=data.get("url", ""), - db_connect_timeout=data.get("db_connect_timeout", 10), - pool_recycle=data.get("pool_recycle", None), - pool_pre_ping=data.get("pool_pre_ping", False), - pool_use_null=data.get("pool_use_null", False), - pool_max_overflow=data.get("pool_max_overflow", 10), - pool_size=data.get("pool_size", 5), - pool_use_lifo=data.get("pool_use_lifo", False), + @classmethod + def from_dict(cls, data: JSON) -> "DbConfig": + defaults = cls() + return cls( + db_admin_url=data.get("admin_url", defaults.db_admin_url), + db_url=data.get("url", defaults.db_url), + db_connect_timeout=data.get("db_connect_timeout", defaults.db_connect_timeout), + pool_recycle=data.get("pool_recycle", defaults.pool_recycle), + pool_pre_ping=data.get("pool_pre_ping", defaults.pool_pre_ping), + pool_use_null=data.get("pool_use_null", defaults.pool_use_null), + pool_max_overflow=data.get("pool_max_overflow", defaults.pool_max_overflow), + pool_size=data.get("pool_size", defaults.pool_size), + pool_use_lifo=data.get("pool_use_lifo", defaults.pool_use_lifo), ) @@ -115,7 +123,7 @@ class StorageConfig: matrixstore: Path = Path("./matrixstore") archive_dir: Path = Path("./archives") tmp_dir: Path = Path(tempfile.gettempdir()) - workspaces: Dict[str, WorkspaceConfig] = field(default_factory=lambda: {}) + workspaces: Dict[str, WorkspaceConfig] = field(default_factory=dict) allow_deletion: bool = False watcher_lock: bool = True watcher_lock_delay: int = 10 @@ -127,39 +135,112 @@ class StorageConfig: auto_archive_sleeping_time: int = 3600 auto_archive_max_parallel: int = 5 - @staticmethod - def from_dict(data: JSON) -> "StorageConfig": - return StorageConfig( - tmp_dir=Path(data.get("tmp_dir", tempfile.gettempdir())), - matrixstore=Path(data["matrixstore"]), - workspaces={n: WorkspaceConfig.from_dict(w) for n, w in data["workspaces"].items()}, - allow_deletion=data.get("allow_deletion", False), - archive_dir=Path(data["archive_dir"]), - watcher_lock=data.get("watcher_lock", True), - watcher_lock_delay=data.get("watcher_lock_delay", 10), - download_default_expiration_timeout_minutes=data.get("download_default_expiration_timeout_minutes", 1440), - matrix_gc_sleeping_time=data.get("matrix_gc_sleeping_time", 3600), - matrix_gc_dry_run=data.get("matrix_gc_dry_run", False), - auto_archive_threshold_days=data.get("auto_archive_threshold_days", 60), - auto_archive_dry_run=data.get("auto_archive_dry_run", False), - auto_archive_sleeping_time=data.get("auto_archive_sleeping_time", 3600), - auto_archive_max_parallel=data.get("auto_archive_max_parallel", 5), + @classmethod + def from_dict(cls, data: JSON) -> "StorageConfig": + defaults = cls() + workspaces = ( + {key: WorkspaceConfig.from_dict(value) for key, value in data["workspaces"].items()} + if "workspaces" in data + else defaults.workspaces + ) + return cls( + matrixstore=Path(data["matrixstore"]) if "matrixstore" in data else defaults.matrixstore, + archive_dir=Path(data["archive_dir"]) if "archive_dir" in data else defaults.archive_dir, + tmp_dir=Path(data["tmp_dir"]) if "tmp_dir" in data else defaults.tmp_dir, + workspaces=workspaces, + allow_deletion=data.get("allow_deletion", defaults.allow_deletion), + watcher_lock=data.get("watcher_lock", defaults.watcher_lock), + watcher_lock_delay=data.get("watcher_lock_delay", defaults.watcher_lock_delay), + download_default_expiration_timeout_minutes=( + data.get( + "download_default_expiration_timeout_minutes", + defaults.download_default_expiration_timeout_minutes, + ) + ), + matrix_gc_sleeping_time=data.get("matrix_gc_sleeping_time", defaults.matrix_gc_sleeping_time), + matrix_gc_dry_run=data.get("matrix_gc_dry_run", defaults.matrix_gc_dry_run), + auto_archive_threshold_days=data.get("auto_archive_threshold_days", defaults.auto_archive_threshold_days), + auto_archive_dry_run=data.get("auto_archive_dry_run", defaults.auto_archive_dry_run), + auto_archive_sleeping_time=data.get("auto_archive_sleeping_time", defaults.auto_archive_sleeping_time), + auto_archive_max_parallel=data.get("auto_archive_max_parallel", defaults.auto_archive_max_parallel), ) +@dataclass(frozen=True) +class NbCoresConfig: + """ + The NBCoresConfig class is designed to manage the configuration of the number of CPU cores + """ + + min: int = 1 + default: int = 22 + max: int = 24 + + def to_json(self) -> Dict[str, int]: + """ + Retrieves the number of cores parameters, returning a dictionary containing the values "min" + (minimum allowed value), "defaultValue" (default value), and "max" (maximum allowed value) + + Returns: + A dictionary: `{"min": min, "defaultValue": default, "max": max}`. + Because ReactJs Material UI expects "min", "defaultValue" and "max" keys. + """ + return {"min": self.min, "defaultValue": self.default, "max": self.max} + + def __post_init__(self) -> None: + """validation of CPU configuration""" + if 1 <= self.min <= self.default <= self.max: + return + msg = f"Invalid configuration: 1 <= {self.min=} <= {self.default=} <= {self.max=}" + raise ValueError(msg) + + @dataclass(frozen=True) class LocalConfig: + """Sub config object dedicated to launcher module (local)""" + binaries: Dict[str, Path] = field(default_factory=dict) + enable_nb_cores_detection: bool = True + nb_cores: NbCoresConfig = NbCoresConfig() - @staticmethod - def from_dict(data: JSON) -> Optional["LocalConfig"]: - return LocalConfig( - binaries={str(v): Path(p) for v, p in data["binaries"].items()}, + @classmethod + def from_dict(cls, data: JSON) -> "LocalConfig": + """ + Creates an instance of LocalConfig from a data dictionary + Args: + data: Parse config from dict. + Returns: object NbCoresConfig + """ + defaults = cls() + binaries = data.get("binaries", defaults.binaries) + enable_nb_cores_detection = data.get("enable_nb_cores_detection", defaults.enable_nb_cores_detection) + nb_cores = data.get("nb_cores", asdict(defaults.nb_cores)) + if enable_nb_cores_detection: + nb_cores.update(cls._autodetect_nb_cores()) + return cls( + binaries={str(v): Path(p) for v, p in binaries.items()}, + enable_nb_cores_detection=enable_nb_cores_detection, + nb_cores=NbCoresConfig(**nb_cores), ) + @classmethod + def _autodetect_nb_cores(cls) -> Dict[str, int]: + """ + Automatically detects the number of cores available on the user's machine + Returns: Instance of NbCoresConfig + """ + min_cpu = cls.nb_cores.min + max_cpu = multiprocessing.cpu_count() + default = max(min_cpu, max_cpu - 2) + return {"min": min_cpu, "max": max_cpu, "default": default} + @dataclass(frozen=True) class SlurmConfig: + """ + Sub config object dedicated to launcher module (slurm) + """ + local_workspace: Path = Path() username: str = "" hostname: str = "" @@ -169,31 +250,68 @@ class SlurmConfig: password: str = "" default_wait_time: int = 0 default_time_limit: int = 0 - default_n_cpu: int = 1 default_json_db_name: str = "" slurm_script_path: str = "" max_cores: int = 64 antares_versions_on_remote_server: List[str] = field(default_factory=list) + enable_nb_cores_detection: bool = False + nb_cores: NbCoresConfig = NbCoresConfig() - @staticmethod - def from_dict(data: JSON) -> "SlurmConfig": - return SlurmConfig( - local_workspace=Path(data["local_workspace"]), - username=data["username"], - hostname=data["hostname"], - port=data["port"], - private_key_file=data.get("private_key_file", None), - key_password=data.get("key_password", None), - password=data.get("password", None), - default_wait_time=data["default_wait_time"], - default_time_limit=data["default_time_limit"], - default_n_cpu=data["default_n_cpu"], - default_json_db_name=data["default_json_db_name"], - slurm_script_path=data["slurm_script_path"], - antares_versions_on_remote_server=data["antares_versions_on_remote_server"], - max_cores=data.get("max_cores", 64), + @classmethod + def from_dict(cls, data: JSON) -> "SlurmConfig": + """ + Creates an instance of SlurmConfig from a data dictionary + + Args: + data: Parsed config from dict. + Returns: object SlurmConfig + """ + defaults = cls() + enable_nb_cores_detection = data.get("enable_nb_cores_detection", defaults.enable_nb_cores_detection) + nb_cores = data.get("nb_cores", asdict(defaults.nb_cores)) + if "default_n_cpu" in data: + # Use the old way to configure the NB cores for backward compatibility + nb_cores["default"] = int(data["default_n_cpu"]) + nb_cores["min"] = min(nb_cores["min"], nb_cores["default"]) + nb_cores["max"] = max(nb_cores["max"], nb_cores["default"]) + if enable_nb_cores_detection: + nb_cores.update(cls._autodetect_nb_cores()) + return cls( + local_workspace=Path(data.get("local_workspace", defaults.local_workspace)), + username=data.get("username", defaults.username), + hostname=data.get("hostname", defaults.hostname), + port=data.get("port", defaults.port), + private_key_file=data.get("private_key_file", defaults.private_key_file), + key_password=data.get("key_password", defaults.key_password), + password=data.get("password", defaults.password), + default_wait_time=data.get("default_wait_time", defaults.default_wait_time), + default_time_limit=data.get("default_time_limit", defaults.default_time_limit), + default_json_db_name=data.get("default_json_db_name", defaults.default_json_db_name), + slurm_script_path=data.get("slurm_script_path", defaults.slurm_script_path), + antares_versions_on_remote_server=data.get( + "antares_versions_on_remote_server", + defaults.antares_versions_on_remote_server, + ), + max_cores=data.get("max_cores", defaults.max_cores), + enable_nb_cores_detection=enable_nb_cores_detection, + nb_cores=NbCoresConfig(**nb_cores), ) + @classmethod + def _autodetect_nb_cores(cls) -> Dict[str, int]: + raise NotImplementedError("NB Cores auto-detection is not implemented for SLURM server") + + +class InvalidConfigurationError(Exception): + """ + Exception raised when an attempt is made to retrieve the number of cores + of a launcher that doesn't exist in the configuration. + """ + + def __init__(self, launcher: str): + msg = f"Configuration is not available for the '{launcher}' launcher" + super().__init__(msg) + @dataclass(frozen=True) class LauncherConfig: @@ -202,27 +320,53 @@ class LauncherConfig: """ default: str = "local" - local: Optional[LocalConfig] = LocalConfig() - slurm: Optional[SlurmConfig] = SlurmConfig() + local: Optional[LocalConfig] = None + slurm: Optional[SlurmConfig] = None batch_size: int = 9999 - @staticmethod - def from_dict(data: JSON) -> "LauncherConfig": - local: Optional[LocalConfig] = None - if "local" in data: - local = LocalConfig.from_dict(data["local"]) - - slurm: Optional[SlurmConfig] = None - if "slurm" in data: - slurm = SlurmConfig.from_dict(data["slurm"]) - - return LauncherConfig( - default=data.get("default", "local"), + @classmethod + def from_dict(cls, data: JSON) -> "LauncherConfig": + defaults = cls() + default = data.get("default", cls.default) + local = LocalConfig.from_dict(data["local"]) if "local" in data else defaults.local + slurm = SlurmConfig.from_dict(data["slurm"]) if "slurm" in data else defaults.slurm + batch_size = data.get("batch_size", defaults.batch_size) + return cls( + default=default, local=local, slurm=slurm, - batch_size=data.get("batch_size", 9999), + batch_size=batch_size, ) + def __post_init__(self) -> None: + possible = {"local", "slurm"} + if self.default in possible: + return + msg = f"Invalid configuration: {self.default=} must be one of {possible!r}" + raise ValueError(msg) + + def get_nb_cores(self, launcher: str) -> "NbCoresConfig": + """ + Retrieve the number of cores configuration for a given launcher: "local" or "slurm". + If "default" is specified, retrieve the configuration of the default launcher. + + Args: + launcher: type of launcher "local", "slurm" or "default". + + Returns: + Number of cores of the given launcher. + + Raises: + InvalidConfigurationError: Exception raised when an attempt is made to retrieve + the number of cores of a launcher that doesn't exist in the configuration. + """ + config_map = {"local": self.local, "slurm": self.slurm} + config_map["default"] = config_map[self.default] + launcher_config = config_map.get(launcher) + if launcher_config is None: + raise InvalidConfigurationError(launcher) + return launcher_config.nb_cores + @dataclass(frozen=True) class LoggingConfig: @@ -234,14 +378,13 @@ class LoggingConfig: json: bool = False level: str = "INFO" - @staticmethod - def from_dict(data: JSON) -> "LoggingConfig": - logging_config: Dict[str, Any] = data or {} - logfile: Optional[str] = logging_config.get("logfile") - return LoggingConfig( - logfile=Path(logfile) if logfile is not None else None, - json=logging_config.get("json", False), - level=logging_config.get("level", "INFO"), + @classmethod + def from_dict(cls, data: JSON) -> "LoggingConfig": + defaults = cls() + return cls( + logfile=Path(data["logfile"]) if "logfile" in data else defaults.logfile, + json=data.get("json", defaults.json), + level=data.get("level", defaults.level), ) @@ -255,12 +398,13 @@ class RedisConfig: port: int = 6379 password: Optional[str] = None - @staticmethod - def from_dict(data: JSON) -> "RedisConfig": - return RedisConfig( - host=data["host"], - port=data["port"], - password=data.get("password", None), + @classmethod + def from_dict(cls, data: JSON) -> "RedisConfig": + defaults = cls() + return cls( + host=data.get("host", defaults.host), + port=data.get("port", defaults.port), + password=data.get("password", defaults.password), ) @@ -271,9 +415,9 @@ class EventBusConfig: """ # noinspection PyUnusedLocal - @staticmethod - def from_dict(data: JSON) -> "EventBusConfig": - return EventBusConfig() + @classmethod + def from_dict(cls, data: JSON) -> "EventBusConfig": + return cls() @dataclass(frozen=True) @@ -284,10 +428,11 @@ class CacheConfig: checker_delay: float = 0.2 # in seconds - @staticmethod - def from_dict(data: JSON) -> "CacheConfig": - return CacheConfig( - checker_delay=float(data["checker_delay"]) if "checker_delay" in data else 0.2, + @classmethod + def from_dict(cls, data: JSON) -> "CacheConfig": + defaults = cls() + return cls( + checker_delay=data.get("checker_delay", defaults.checker_delay), ) @@ -296,9 +441,13 @@ class RemoteWorkerConfig: name: str queues: List[str] = field(default_factory=list) - @staticmethod - def from_dict(data: JSON) -> "RemoteWorkerConfig": - return RemoteWorkerConfig(name=data["name"], queues=data.get("queues", [])) + @classmethod + def from_dict(cls, data: JSON) -> "RemoteWorkerConfig": + defaults = cls(name="") # `name` is mandatory + return cls( + name=data["name"], + queues=data.get("queues", defaults.queues), + ) @dataclass(frozen=True) @@ -310,16 +459,17 @@ class TaskConfig: max_workers: int = 5 remote_workers: List[RemoteWorkerConfig] = field(default_factory=list) - @staticmethod - def from_dict(data: JSON) -> "TaskConfig": - return TaskConfig( - max_workers=int(data["max_workers"]) if "max_workers" in data else 5, - remote_workers=list( - map( - lambda x: RemoteWorkerConfig.from_dict(x), - data.get("remote_workers", []), - ) - ), + @classmethod + def from_dict(cls, data: JSON) -> "TaskConfig": + defaults = cls() + remote_workers = ( + [RemoteWorkerConfig.from_dict(d) for d in data["remote_workers"]] + if "remote_workers" in data + else defaults.remote_workers + ) + return cls( + max_workers=data.get("max_workers", defaults.max_workers), + remote_workers=remote_workers, ) @@ -332,11 +482,12 @@ class ServerConfig: worker_threadpool_size: int = 5 services: List[str] = field(default_factory=list) - @staticmethod - def from_dict(data: JSON) -> "ServerConfig": - return ServerConfig( - worker_threadpool_size=int(data["worker_threadpool_size"]) if "worker_threadpool_size" in data else 5, - services=data.get("services", []), + @classmethod + def from_dict(cls, data: JSON) -> "ServerConfig": + defaults = cls() + return cls( + worker_threadpool_size=data.get("worker_threadpool_size", defaults.worker_threadpool_size), + services=data.get("services", defaults.services), ) @@ -360,36 +511,27 @@ class Config: tasks: TaskConfig = TaskConfig() root_path: str = "" - @staticmethod - def from_dict(data: JSON, res: Optional[Path] = None) -> "Config": - """ - Parse config from dict. - - Args: - data: dict struct to parse - res: resources path is not present in yaml file. - - Returns: - - """ - return Config( - security=SecurityConfig.from_dict(data.get("security", {})), - storage=StorageConfig.from_dict(data["storage"]), - launcher=LauncherConfig.from_dict(data.get("launcher", {})), - db=DbConfig.from_dict(data["db"]) if "db" in data else DbConfig(), - logging=LoggingConfig.from_dict(data.get("logging", {})), - debug=data.get("debug", False), - resources_path=res or Path(), - root_path=data.get("root_path", ""), - redis=RedisConfig.from_dict(data["redis"]) if "redis" in data else None, - eventbus=EventBusConfig.from_dict(data["eventbus"]) if "eventbus" in data else EventBusConfig(), - cache=CacheConfig.from_dict(data["cache"]) if "cache" in data else CacheConfig(), - tasks=TaskConfig.from_dict(data["tasks"]) if "tasks" in data else TaskConfig(), - server=ServerConfig.from_dict(data["server"]) if "server" in data else ServerConfig(), + @classmethod + def from_dict(cls, data: JSON) -> "Config": + defaults = cls() + return cls( + server=ServerConfig.from_dict(data["server"]) if "server" in data else defaults.server, + security=SecurityConfig.from_dict(data["security"]) if "security" in data else defaults.security, + storage=StorageConfig.from_dict(data["storage"]) if "storage" in data else defaults.storage, + launcher=LauncherConfig.from_dict(data["launcher"]) if "launcher" in data else defaults.launcher, + db=DbConfig.from_dict(data["db"]) if "db" in data else defaults.db, + logging=LoggingConfig.from_dict(data["logging"]) if "logging" in data else defaults.logging, + debug=data.get("debug", defaults.debug), + resources_path=data["resources_path"] if "resources_path" in data else defaults.resources_path, + redis=RedisConfig.from_dict(data["redis"]) if "redis" in data else defaults.redis, + eventbus=EventBusConfig.from_dict(data["eventbus"]) if "eventbus" in data else defaults.eventbus, + cache=CacheConfig.from_dict(data["cache"]) if "cache" in data else defaults.cache, + tasks=TaskConfig.from_dict(data["tasks"]) if "tasks" in data else defaults.tasks, + root_path=data.get("root_path", defaults.root_path), ) - @staticmethod - def from_yaml_file(file: Path, res: Optional[Path] = None) -> "Config": + @classmethod + def from_yaml_file(cls, file: Path, res: Optional[Path] = None) -> "Config": """ Parse config from yaml file. @@ -400,5 +542,8 @@ def from_yaml_file(file: Path, res: Optional[Path] = None) -> "Config": Returns: """ - data = yaml.safe_load(open(file)) - return Config.from_dict(data, res) + with open(file) as f: + data = yaml.safe_load(f) + if res is not None: + data["resources_path"] = res + return cls.from_dict(data) diff --git a/antarest/core/filetransfer/model.py b/antarest/core/filetransfer/model.py index 9442f051dc..bbb61c00b6 100644 --- a/antarest/core/filetransfer/model.py +++ b/antarest/core/filetransfer/model.py @@ -13,7 +13,7 @@ class FileDownloadNotFound(HTTPException): def __init__(self) -> None: super().__init__( HTTPStatus.NOT_FOUND, - f"Requested download file was not found. It must have expired", + "Requested download file was not found. It must have expired", ) @@ -21,7 +21,7 @@ class FileDownloadNotReady(HTTPException): def __init__(self) -> None: super().__init__( HTTPStatus.NOT_ACCEPTABLE, - f"Requested file is not ready for download.", + "Requested file is not ready for download.", ) @@ -70,4 +70,11 @@ def to_dto(self) -> FileDownloadDTO: ) def __repr__(self) -> str: - return f"(id={self.id},name={self.name},filename={self.filename},path={self.path},ready={self.ready},expiration_date={self.expiration_date})" + return ( + f"(id={self.id}," + f" name={self.name}," + f" filename={self.filename}," + f" path={self.path}," + f" ready={self.ready}," + f" expiration_date={self.expiration_date})" + ) diff --git a/antarest/core/logging/utils.py b/antarest/core/logging/utils.py index 9115dba45b..b0e5227ea3 100644 --- a/antarest/core/logging/utils.py +++ b/antarest/core/logging/utils.py @@ -124,12 +124,14 @@ def configure_logger(config: Config, handler_cls: str = "logging.FileHandler") - "filters": ["context"], } elif handler_cls == "logging.handlers.TimedRotatingFileHandler": + # 90 days = 3 months + # keep only 1 backup (0 means keep all) logging_config["handlers"]["default"] = { "class": handler_cls, "filename": config.logging.logfile, - "when": "D", # D = day - "interval": 90, # 90 days = 3 months - "backupCount": 1, # keep only 1 backup (0 means keep all) + "when": "D", + "interval": 90, + "backupCount": 1, "encoding": "utf-8", "delay": False, "utc": False, diff --git a/antarest/core/tasks/service.py b/antarest/core/tasks/service.py index 1c9ac3bf18..227b3fbf71 100644 --- a/antarest/core/tasks/service.py +++ b/antarest/core/tasks/service.py @@ -80,7 +80,7 @@ def await_task(self, task_id: str, timeout_sec: Optional[int] = None) -> None: # noinspection PyUnusedLocal def noop_notifier(message: str) -> None: - pass + """This function is used in tasks when no notification is required.""" DEFAULT_AWAIT_MAX_TIMEOUT = 172800 @@ -121,7 +121,7 @@ async def _await_task_end(event: Event) -> None: return _await_task_end - # todo: Is `logger_` parameter required? (consider refactoring) + # noinspection PyUnusedLocal def _send_worker_task(logger_: TaskUpdateNotifier) -> TaskResult: listener_id = self.event_bus.add_listener( _create_awaiter(task_result_wrapper), @@ -338,14 +338,18 @@ def _run_task( result.message, result.return_value, ) + event_type = {True: EventType.TASK_COMPLETED, False: EventType.TASK_FAILED}[result.success] + event_msg = {True: "completed", False: "failed"}[result.success] self.event_bus.push( Event( - type=EventType.TASK_COMPLETED if result.success else EventType.TASK_FAILED, + type=event_type, payload=TaskEventPayload( id=task_id, - message=custom_event_messages.end - if custom_event_messages is not None - else f'Task {task_id} {"completed" if result.success else "failed"}', + message=( + custom_event_messages.end + if custom_event_messages is not None + else f"Task {task_id} {event_msg}" + ), ).dict(), permissions=PermissionInfo(public_mode=PublicMode.READ), channel=EventChannelDirectory.TASK + task_id, diff --git a/antarest/launcher/adapters/local_launcher/local_launcher.py b/antarest/launcher/adapters/local_launcher/local_launcher.py index 7865f8740c..8ee598985b 100644 --- a/antarest/launcher/adapters/local_launcher/local_launcher.py +++ b/antarest/launcher/adapters/local_launcher/local_launcher.py @@ -1,3 +1,4 @@ +import io import logging import shutil import signal @@ -6,7 +7,7 @@ import threading import time from pathlib import Path -from typing import IO, Callable, Dict, Optional, Tuple, cast +from typing import Callable, Dict, Optional, Tuple, cast from uuid import UUID from antarest.core.config import Config @@ -14,7 +15,7 @@ from antarest.core.interfaces.eventbus import IEventBus from antarest.core.requests import RequestParameters from antarest.launcher.adapters.abstractlauncher import AbstractLauncher, LauncherCallbacks, LauncherInitException -from antarest.launcher.adapters.log_manager import LogTailManager +from antarest.launcher.adapters.log_manager import follow from antarest.launcher.model import JobStatus, LauncherParametersDTO, LogType logger = logging.getLogger(__name__) @@ -133,8 +134,8 @@ def stop_reading_output() -> bool: ) thread = threading.Thread( - target=lambda: LogTailManager.follow( - cast(IO[str], process.stdout), + target=lambda: follow( + cast(io.StringIO, process.stdout), self.create_update_log(str(uuid)), stop_reading_output, None, diff --git a/antarest/launcher/adapters/log_manager.py b/antarest/launcher/adapters/log_manager.py index a0bfbdfe70..eeca586f32 100644 --- a/antarest/launcher/adapters/log_manager.py +++ b/antarest/launcher/adapters/log_manager.py @@ -1,8 +1,10 @@ +import contextlib +import io import logging import time from pathlib import Path from threading import Thread -from typing import IO, Callable, Dict, Optional +from typing import Callable, Dict, Optional, cast logger = logging.getLogger(__name__) @@ -11,7 +13,7 @@ class LogTailManager: BATCH_SIZE = 10 def __init__(self, log_base_dir: Path) -> None: - logger.info(f"Initiating Log manager") + logger.info("Initiating Log manager") self.log_base_dir = log_base_dir self.tracked_logs: Dict[str, Thread] = {} @@ -47,43 +49,6 @@ def stop_tracking(self, log_path: Optional[Path]) -> None: if log_path_key in self.tracked_logs: del self.tracked_logs[log_path_key] - @staticmethod - def follow( - io: IO[str], - handler: Callable[[str], None], - stop: Callable[[], bool], - log_file: Optional[str], - ) -> None: - line = "" - line_count = 0 - - while True: - if stop(): - break - tmp = io.readline() - if not tmp: - if line: - logger.debug(f"Calling handler for {log_file}") - try: - handler(line) - except Exception as e: - logger.error("Could not handle this log line", exc_info=e) - line = "" - line_count = 0 - time.sleep(0.1) - else: - line += tmp - if line.endswith("\n"): - line_count += 1 - if line_count >= LogTailManager.BATCH_SIZE: - logger.debug(f"Calling handler for {log_file}") - try: - handler(line) - except Exception as e: - logger.error("Could not handle this log line", exc_info=e) - line = "" - line_count = 0 - def _follow( self, log_file: Optional[Path], @@ -97,4 +62,37 @@ def _follow( with open(log_file, "r") as fh: logger.info(f"Scanning {log_file}") - LogTailManager.follow(fh, handler, stop, str(log_file)) + follow(cast(io.StringIO, fh), handler, stop, str(log_file)) + + +def follow( + file: io.StringIO, + handler: Callable[[str], None], + stop: Callable[[], bool], + log_file: Optional[str], +) -> None: + line = "" + line_count = 0 + + while True: + if stop(): + break + tmp = file.readline() + if tmp: + line += tmp + if line.endswith("\n"): + line_count += 1 + if line_count >= LogTailManager.BATCH_SIZE: + logger.debug(f"Calling handler for {log_file}") + with contextlib.suppress(Exception): + handler(line) + line = "" + line_count = 0 + else: + if line: + logger.debug(f"Calling handler for {log_file}") + with contextlib.suppress(Exception): + handler(line) + line = "" + line_count = 0 + time.sleep(0.1) diff --git a/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py b/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py index 926f2b50a3..00283b9ce8 100644 --- a/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py +++ b/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py @@ -32,7 +32,6 @@ logger = logging.getLogger(__name__) logging.getLogger("paramiko").setLevel("WARN") -MAX_NB_CPU = 24 MAX_TIME_LIMIT = 864000 MIN_TIME_LIMIT = 3600 WORKSPACE_LOCK_FILE_NAME = ".lock" @@ -153,7 +152,7 @@ def _init_launcher_arguments(self, local_workspace: Optional[Path] = None) -> ar main_options_parameters = ParserParameters( default_wait_time=self.slurm_config.default_wait_time, default_time_limit=self.slurm_config.default_time_limit, - default_n_cpu=self.slurm_config.default_n_cpu, + default_n_cpu=self.slurm_config.nb_cores.default, studies_in_dir=str((Path(local_workspace or self.slurm_config.local_workspace) / STUDIES_INPUT_DIR_NAME)), log_dir=str((Path(self.slurm_config.local_workspace) / LOG_DIR_NAME)), finished_dir=str((Path(local_workspace or self.slurm_config.local_workspace) / STUDIES_OUTPUT_DIR_NAME)), @@ -440,7 +439,7 @@ def _run_study( _override_solver_version(study_path, version) append_log(launch_uuid, "Submitting study to slurm launcher") - launcher_args = self._check_and_apply_launcher_params(launcher_params) + launcher_args = self._apply_params(launcher_params) self._call_launcher(launcher_args, self.launcher_params) launch_success = self._check_if_study_is_in_launcher_db(launch_uuid) @@ -481,23 +480,40 @@ def _check_if_study_is_in_launcher_db(self, job_id: str) -> bool: studies = self.data_repo_tinydb.get_list_of_studies() return any(s.name == job_id for s in studies) - def _check_and_apply_launcher_params(self, launcher_params: LauncherParametersDTO) -> argparse.Namespace: + def _apply_params(self, launcher_params: LauncherParametersDTO) -> argparse.Namespace: + """ + Populate a `argparse.Namespace` object with the user parameters. + + Args: + launcher_params: + Contains the launcher parameters selected by the user. + If a parameter is not provided (`None`), the default value should be retrieved + from the configuration. + + Returns: + The `argparse.Namespace` object which is then passed to `antarestlauncher.main.run_with`, + to launch a simulation using Antares Launcher. + """ if launcher_params: launcher_args = deepcopy(self.launcher_args) - other_options = [] + if launcher_params.other_options: - options = re.split("\\s+", launcher_params.other_options) - for opt in options: - other_options.append(re.sub("[^a-zA-Z0-9_,-]", "", opt)) - if launcher_params.xpansion is not None: - launcher_args.xpansion_mode = "r" if launcher_params.xpansion_r_version else "cpp" + options = launcher_params.other_options.split() + other_options = [re.sub("[^a-zA-Z0-9_,-]", "", opt) for opt in options] + else: + other_options = [] + + # launcher_params.xpansion can be an `XpansionParametersDTO`, a bool or `None` + if launcher_params.xpansion: # not None and not False + launcher_args.xpansion_mode = {True: "r", False: "cpp"}[launcher_params.xpansion_r_version] if ( isinstance(launcher_params.xpansion, XpansionParametersDTO) and launcher_params.xpansion.sensitivity_mode ): other_options.append("xpansion_sensitivity") + time_limit = launcher_params.time_limit - if time_limit and isinstance(time_limit, int): + if time_limit is not None: if MIN_TIME_LIMIT > time_limit: logger.warning( f"Invalid slurm launcher time limit ({time_limit})," @@ -512,15 +528,23 @@ def _check_and_apply_launcher_params(self, launcher_params: LauncherParametersDT launcher_args.time_limit = MAX_TIME_LIMIT - 3600 else: launcher_args.time_limit = time_limit + post_processing = launcher_params.post_processing - if isinstance(post_processing, bool): + if post_processing is not None: launcher_args.post_processing = post_processing + nb_cpu = launcher_params.nb_cpu - if nb_cpu and isinstance(nb_cpu, int): - if 0 < nb_cpu <= MAX_NB_CPU: + if nb_cpu is not None: + nb_cores = self.slurm_config.nb_cores + if nb_cores.min <= nb_cpu <= nb_cores.max: launcher_args.n_cpu = nb_cpu else: - logger.warning(f"Invalid slurm launcher nb_cpu ({nb_cpu}), should be between 1 and 24") + logger.warning( + f"Invalid slurm launcher nb_cpu ({nb_cpu})," + f" should be between {nb_cores.min} and {nb_cores.max}" + ) + launcher_args.n_cpu = nb_cores.default + if launcher_params.adequacy_patch is not None: # the adequacy patch can be an empty object launcher_args.post_processing = True diff --git a/antarest/launcher/model.py b/antarest/launcher/model.py index 7a3c615811..a9bf0f6fde 100644 --- a/antarest/launcher/model.py +++ b/antarest/launcher/model.py @@ -17,14 +17,14 @@ class XpansionParametersDTO(BaseModel): class LauncherParametersDTO(BaseModel): - # Warning ! This class must be retrocompatible (that's the reason for the weird bool/XpansionParametersDTO union) + # Warning ! This class must be retro-compatible (that's the reason for the weird bool/XpansionParametersDTO union) # The reason is that it's stored in json format in database and deserialized using the latest class version # If compatibility is to be broken, an (alembic) data migration script should be added adequacy_patch: Optional[Dict[str, Any]] = None nb_cpu: Optional[int] = None post_processing: bool = False - time_limit: Optional[int] = None - xpansion: Union[bool, Optional[XpansionParametersDTO]] = None + time_limit: Optional[int] = None # 3600 <= time_limit < 864000 (10 days) + xpansion: Union[XpansionParametersDTO, bool, None] = None xpansion_r_version: bool = False archive_output: bool = True auto_unzip: bool = True diff --git a/antarest/launcher/service.py b/antarest/launcher/service.py index 340c00c27c..7312181dea 100644 --- a/antarest/launcher/service.py +++ b/antarest/launcher/service.py @@ -1,4 +1,5 @@ import functools +import json import logging import os import shutil @@ -10,7 +11,7 @@ from fastapi import HTTPException -from antarest.core.config import Config +from antarest.core.config import Config, NbCoresConfig from antarest.core.exceptions import StudyNotFoundError from antarest.core.filetransfer.model import FileDownloadTaskDTO from antarest.core.filetransfer.service import FileTransferManager @@ -99,6 +100,21 @@ def _init_extensions(self) -> Dict[str, ILauncherExtension]: def get_launchers(self) -> List[str]: return list(self.launchers.keys()) + def get_nb_cores(self, launcher: str) -> NbCoresConfig: + """ + Retrieve the configuration of the launcher's nb of cores. + + Args: + launcher: name of the launcher: "default", "slurm" or "local". + + Returns: + Number of cores of the launcher + + Raises: + InvalidConfigurationError: if the launcher configuration is not available + """ + return self.config.launcher.get_nb_cores(launcher) + def _after_export_flat_hooks( self, job_id: str, @@ -160,7 +176,7 @@ def update( channel=EventChannelDirectory.JOB_STATUS + job_result.id, ) ) - logger.info(f"Study status set") + logger.info("Study status set") def append_log(self, job_id: str, message: str, log_type: JobLogType) -> None: try: @@ -586,27 +602,31 @@ def get_load(self, from_cluster: bool = False) -> Dict[str, float]: local_running_jobs.append(job) else: logger.warning(f"Unknown job launcher {job.launcher}") + load = {} - if self.config.launcher.slurm: + + slurm_config = self.config.launcher.slurm + if slurm_config is not None: if from_cluster: - raise NotImplementedError - slurm_used_cpus = functools.reduce( - lambda count, j: count - + ( - LauncherParametersDTO.parse_raw(j.launcher_params or "{}").nb_cpu - or self.config.launcher.slurm.default_n_cpu # type: ignore - ), - slurm_running_jobs, - 0, - ) - load["slurm"] = float(slurm_used_cpus) / self.config.launcher.slurm.max_cores - if self.config.launcher.local: - local_used_cpus = functools.reduce( - lambda count, j: count + (LauncherParametersDTO.parse_raw(j.launcher_params or "{}").nb_cpu or 1), - local_running_jobs, - 0, - ) - load["local"] = float(local_used_cpus) / (os.cpu_count() or 1) + raise NotImplementedError("Cluster load not implemented yet") + default_cpu = slurm_config.nb_cores.default + slurm_used_cpus = 0 + for job in slurm_running_jobs: + obj = json.loads(job.launcher_params) if job.launcher_params else {} + launch_params = LauncherParametersDTO(**obj) + slurm_used_cpus += launch_params.nb_cpu or default_cpu + load["slurm"] = slurm_used_cpus / slurm_config.max_cores + + local_config = self.config.launcher.local + if local_config is not None: + default_cpu = local_config.nb_cores.default + local_used_cpus = 0 + for job in local_running_jobs: + obj = json.loads(job.launcher_params) if job.launcher_params else {} + launch_params = LauncherParametersDTO(**obj) + local_used_cpus += launch_params.nb_cpu or default_cpu + load["local"] = local_used_cpus / local_config.nb_cores.max + return load def get_solver_versions(self, solver: str) -> List[str]: diff --git a/antarest/launcher/web.py b/antarest/launcher/web.py index ffb4cf6ccf..51b3582997 100644 --- a/antarest/launcher/web.py +++ b/antarest/launcher/web.py @@ -6,7 +6,7 @@ from fastapi import APIRouter, Depends, Query from fastapi.exceptions import HTTPException -from antarest.core.config import Config +from antarest.core.config import Config, InvalidConfigurationError from antarest.core.filetransfer.model import FileDownloadTaskDTO from antarest.core.jwt import JWTUser from antarest.core.requests import RequestParameters @@ -230,4 +230,48 @@ def get_solver_versions( raise UnknownSolverConfig(solver) return service.get_solver_versions(solver) + # noinspection SpellCheckingInspection + @bp.get( + "/launcher/nbcores", # We avoid "nb_cores" and "nb-cores" in endpoints + tags=[APITag.launcher], + summary="Retrieving Min, Default, and Max Core Count", + response_model=Dict[str, int], + ) + def get_nb_cores( + launcher: str = Query( + "default", + examples={ + "Default launcher": { + "description": "Min, Default, and Max Core Count", + "value": "default", + }, + "SLURM launcher": { + "description": "Min, Default, and Max Core Count", + "value": "slurm", + }, + "Local launcher": { + "description": "Min, Default, and Max Core Count", + "value": "local", + }, + }, + ) + ) -> Dict[str, int]: + """ + Retrieve the numer of cores of the launcher. + + Args: + - `launcher`: name of the configuration to read: "slurm" or "local". + If "default" is specified, retrieve the configuration of the default launcher. + + Returns: + - "min": min number of cores + - "defaultValue": default number of cores + - "max": max number of cores + """ + logger.info(f"Fetching the number of cores for the '{launcher}' configuration") + try: + return service.config.launcher.get_nb_cores(launcher).to_json() + except InvalidConfigurationError: + raise UnknownSolverConfig(launcher) + return bp diff --git a/antarest/login/web.py b/antarest/login/web.py index 86561457ac..ab63ec16b2 100644 --- a/antarest/login/web.py +++ b/antarest/login/web.py @@ -110,7 +110,7 @@ def users_get_all( details: Optional[bool] = False, current_user: JWTUser = Depends(auth.get_current_user), ) -> Any: - logger.info(f"Fetching users list", extra={"user": current_user.id}) + logger.info("Fetching users list", extra={"user": current_user.id}) params = RequestParameters(user=current_user) return service.get_all_users(params, details) @@ -188,7 +188,7 @@ def groups_get_all( details: Optional[bool] = False, current_user: JWTUser = Depends(auth.get_current_user), ) -> Any: - logger.info(f"Fetching groups list", extra={"user": current_user.id}) + logger.info("Fetching groups list", extra={"user": current_user.id}) params = RequestParameters(user=current_user) return service.get_all_groups(params, details) diff --git a/antarest/matrixstore/web.py b/antarest/matrixstore/web.py index 4b47135b52..523176b241 100644 --- a/antarest/matrixstore/web.py +++ b/antarest/matrixstore/web.py @@ -37,7 +37,7 @@ def create( matrix: List[List[MatrixData]] = Body(description="matrix dto", default=[]), current_user: JWTUser = Depends(auth.get_current_user), ) -> Any: - logger.info(f"Creating new matrix", extra={"user": current_user.id}) + logger.info("Creating new matrix", extra={"user": current_user.id}) if current_user.id is not None: return service.create(matrix) raise UserHasNotPermissionError() @@ -60,7 +60,7 @@ def create_by_importation( @bp.get("/matrix/{id}", tags=[APITag.matrix], response_model=MatrixDTO) def get(id: str, user: JWTUser = Depends(auth.get_current_user)) -> Any: - logger.info(f"Fetching matrix", extra={"user": user.id}) + logger.info("Fetching matrix", extra={"user": user.id}) if user.id is not None: return service.get(id) raise UserHasNotPermissionError() diff --git a/antarest/study/business/binding_constraint_management.py b/antarest/study/business/binding_constraint_management.py index 0fedb98393..ca1f714750 100644 --- a/antarest/study/business/binding_constraint_management.py +++ b/antarest/study/business/binding_constraint_management.py @@ -12,8 +12,9 @@ from antarest.matrixstore.model import MatrixData from antarest.study.business.utils import execute_or_add_commands from antarest.study.model import Study +from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import BindingConstraintFrequency from antarest.study.storage.storage_service import StudyStorageService -from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator, TimeStep +from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator from antarest.study.storage.variantstudy.model.command.update_binding_constraint import UpdateBindingConstraint @@ -43,7 +44,7 @@ class BindingConstraintDTO(BaseModel): id: str name: str enabled: bool = True - time_step: TimeStep + time_step: BindingConstraintFrequency operator: BindingConstraintOperator values: Optional[Union[List[List[MatrixData]], str]] = None comments: Optional[str] = None diff --git a/antarest/study/business/st_storage_manager.py b/antarest/study/business/st_storage_manager.py index d1c040741b..f16ff680d4 100644 --- a/antarest/study/business/st_storage_manager.py +++ b/antarest/study/business/st_storage_manager.py @@ -1,10 +1,10 @@ import functools import json import operator -from typing import Any, Dict, List, Mapping, MutableMapping, Sequence +from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Sequence import numpy as np -from pydantic import BaseModel, Extra, Field, root_validator, validator +from pydantic import BaseModel, Extra, root_validator, validator from typing_extensions import Literal from antarest.core.exceptions import ( @@ -12,9 +12,13 @@ STStorageFieldsNotFoundError, STStorageMatrixNotFoundError, ) -from antarest.study.business.utils import AllOptionalMetaclass, FormFieldsBaseModel, execute_or_add_commands +from antarest.study.business.utils import AllOptionalMetaclass, camel_case_model, execute_or_add_commands from antarest.study.model import Study -from antarest.study.storage.rawstudy.model.filesystem.config.st_storage import STStorageConfig, STStorageGroup +from antarest.study.storage.rawstudy.model.filesystem.config.st_storage import ( + STStorageConfig, + STStorageGroup, + STStorageProperties, +) from antarest.study.storage.storage_service import StudyStorageService from antarest.study.storage.variantstudy.model.command.create_st_storage import CreateSTStorage from antarest.study.storage.variantstudy.model.command.remove_st_storage import RemoveSTStorage @@ -23,77 +27,12 @@ _HOURS_IN_YEAR = 8760 -class FormBaseModel(FormFieldsBaseModel): - """ - A foundational model for all form-based models, providing common configurations. - """ - - class Config: - validate_assignment = True - allow_population_by_field_name = True - - -class StorageCreation(FormBaseModel): +@camel_case_model +class StorageInput(STStorageProperties, metaclass=AllOptionalMetaclass): """ - Model representing the form used to create a new short-term storage entry. + Model representing the form used to EDIT an existing short-term storage. """ - name: str = Field( - description="Name of the storage.", - regex=r"[a-zA-Z0-9_(),& -]+", - ) - group: STStorageGroup = Field( - description="Energy storage system group.", - ) - - class Config: - @staticmethod - def schema_extra(schema: MutableMapping[str, Any]) -> None: - schema["example"] = StorageCreation( - name="Siemens Battery", - group=STStorageGroup.BATTERY, - ) - - @property - def to_config(self) -> STStorageConfig: - values = self.dict(by_alias=False) - return STStorageConfig(**values) - - -class StorageUpdate(StorageCreation, metaclass=AllOptionalMetaclass): - """set name, group as optional fields""" - - -class StorageInput(StorageUpdate): - """ - Model representing the form used to edit existing short-term storage details. - """ - - injection_nominal_capacity: float = Field( - description="Injection nominal capacity (MW)", - ge=0, - ) - withdrawal_nominal_capacity: float = Field( - description="Withdrawal nominal capacity (MW)", - ge=0, - ) - reservoir_capacity: float = Field( - description="Reservoir capacity (MWh)", - ge=0, - ) - efficiency: float = Field( - description="Efficiency of the storage system", - ge=0, - le=1, - ) - initial_level: float = Field( - description="Initial level of the storage system", - ge=0, - ) - initial_level_optim: bool = Field( - description="Flag indicating if the initial level is optimized", - ) - class Config: @staticmethod def schema_extra(schema: MutableMapping[str, Any]) -> None: @@ -104,19 +43,37 @@ def schema_extra(schema: MutableMapping[str, Any]) -> None: withdrawal_nominal_capacity=150, reservoir_capacity=600, efficiency=0.94, + initial_level=0.5, initial_level_optim=True, ) -class StorageOutput(StorageInput): +class StorageCreation(StorageInput): """ - Model representing the form used to display the details of a short-term storage entry. + Model representing the form used to CREATE a new short-term storage. """ - id: str = Field( - description="Short-term storage ID", - regex=r"[a-zA-Z0-9_(),& -]+", - ) + # noinspection Pydantic + @validator("name", pre=True) + def validate_name(cls, name: Optional[str]) -> str: + """ + Validator to check if the name is not empty. + """ + if not name: + raise ValueError("'name' must not be empty") + return name + + @property + def to_config(self) -> STStorageConfig: + values = self.dict(by_alias=False, exclude_none=True) + return STStorageConfig(**values) + + +@camel_case_model +class StorageOutput(STStorageConfig): + """ + Model representing the form used to display the details of a short-term storage entry. + """ class Config: @staticmethod diff --git a/antarest/study/business/utils.py b/antarest/study/business/utils.py index b602c60138..33b62d766c 100644 --- a/antarest/study/business/utils.py +++ b/antarest/study/business/utils.py @@ -117,3 +117,19 @@ def __new__( annotations[field] = Optional[annotations[field]] namespaces["__annotations__"] = annotations return super().__new__(cls, name, bases, namespaces) + + +def camel_case_model(model: Type[BaseModel]) -> Type[BaseModel]: + """ + This decorator can be used to modify a model to use camel case aliases. + + Args: + model: The pydantic model to modify. + + Returns: + The modified model. + """ + model.__config__.alias_generator = to_camel_case + for field_name, field in model.__fields__.items(): + field.alias = to_camel_case(field_name) + return model diff --git a/antarest/study/business/xpansion_management.py b/antarest/study/business/xpansion_management.py index 4dec6e3ef6..e8c6c79f98 100644 --- a/antarest/study/business/xpansion_management.py +++ b/antarest/study/business/xpansion_management.py @@ -436,7 +436,7 @@ def _assert_candidate_is_correct( xpansion_candidate_dto: XpansionCandidateDTO, new_name: bool = False, ) -> None: - logger.info(f"Checking given candidate is correct") + logger.info("Checking given candidate is correct") self._assert_no_illegal_character_is_in_candidate_name(xpansion_candidate_dto.name) if new_name: self._assert_candidate_name_is_not_already_taken(candidates, xpansion_candidate_dto.name) diff --git a/antarest/study/service.py b/antarest/study/service.py index b9ec491c18..64d5c7d83c 100644 --- a/antarest/study/service.py +++ b/antarest/study/service.py @@ -641,15 +641,20 @@ def create_study( def get_user_name(self, params: RequestParameters) -> str: """ - Args: params : Request parameters + Retrieves the name of a user based on the provided request parameters. - Returns: The user's name + Args: + params: The request parameters which includes user information. + + Returns: + Returns the user's name or, if the logged user is a "bot" + (i.e., an application's token), it returns the token's author name. """ - author = "Unknown" if params.user: - if curr_user := self.user_service.get_user(params.user.id, params): - author = curr_user.to_dto().name - return author + user_id = params.user.impersonator if params.user.type == "bots" else params.user.id + if curr_user := self.user_service.get_user(user_id, params): + return curr_user.to_dto().name + return "Unknown" def get_study_synthesis(self, study_id: str, params: RequestParameters) -> FileStudyTreeConfigDTO: """ diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/binding_constraint.py b/antarest/study/storage/rawstudy/model/filesystem/config/binding_constraint.py new file mode 100644 index 0000000000..3648051773 --- /dev/null +++ b/antarest/study/storage/rawstudy/model/filesystem/config/binding_constraint.py @@ -0,0 +1,17 @@ +from enum import Enum +from typing import Set + +from pydantic import BaseModel + + +class BindingConstraintFrequency(str, Enum): + HOURLY = "hourly" + DAILY = "daily" + WEEKLY = "weekly" + + +class BindingConstraintDTO(BaseModel): + id: str + areas: Set[str] + clusters: Set[str] + time_step: BindingConstraintFrequency diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/files.py b/antarest/study/storage/rawstudy/model/filesystem/config/files.py index 9f8111911a..8e174f3ea0 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/files.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/files.py @@ -10,13 +10,16 @@ from antarest.core.model import JSON from antarest.core.utils.utils import extract_file_to_tmp_dir from antarest.study.storage.rawstudy.io.reader import IniReader, MultipleSameKeysIniReader +from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import ( + BindingConstraintDTO, + BindingConstraintFrequency, +) from antarest.study.storage.rawstudy.model.filesystem.config.exceptions import ( SimulationParsingError, XpansionParsingError, ) from antarest.study.storage.rawstudy.model.filesystem.config.model import ( Area, - BindingConstraintDTO, Cluster, DistrictSet, FileStudyTreeConfig, @@ -143,8 +146,12 @@ def _parse_bindings(root: Path) -> List[BindingConstraintDTO]: area_set = set() # contains a set of strings in the following format: "area.cluster" cluster_set = set() + # Default value for time_step + time_step = BindingConstraintFrequency.HOURLY for key in bind: - if "%" in key: + if key == "type": + time_step = BindingConstraintFrequency(bind[key]) + elif "%" in key: areas = key.split("%", 1) area_set.add(areas[0]) area_set.add(areas[1]) @@ -152,7 +159,9 @@ def _parse_bindings(root: Path) -> List[BindingConstraintDTO]: cluster_set.add(key) area_set.add(key.split(".", 1)[0]) - output_list.append(BindingConstraintDTO(id=bind["id"], areas=area_set, clusters=cluster_set)) + output_list.append( + BindingConstraintDTO(id=bind["id"], areas=area_set, clusters=cluster_set, time_step=time_step) + ) return output_list diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/model.py b/antarest/study/storage/rawstudy/model/filesystem/config/model.py index f688765922..774d225344 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/model.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/model.py @@ -1,7 +1,7 @@ import re from enum import Enum from pathlib import Path -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional from pydantic import Extra from pydantic.main import BaseModel @@ -9,6 +9,7 @@ from antarest.core.model import JSON from antarest.core.utils.utils import DTO +from .binding_constraint import BindingConstraintDTO from .st_storage import STStorageConfig @@ -106,12 +107,6 @@ def get_file(self) -> str: return f"{self.date}{modes[self.mode]}{dash}{self.name}" -class BindingConstraintDTO(BaseModel): - id: str - areas: Set[str] - clusters: Set[str] - - class FileStudyTreeConfig(DTO): """ Root object to handle all study parameters which impact tree structure diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/st_storage.py b/antarest/study/storage/rawstudy/model/filesystem/config/st_storage.py index d2a68c4799..b82910a191 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/st_storage.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/st_storage.py @@ -30,22 +30,18 @@ class STStorageGroup(EnumIgnoreCase): # noinspection SpellCheckingInspection -class STStorageConfig(BaseModel): - """ - Manage the configuration files in the context of Short-Term Storage. - It provides a convenient way to read and write configuration data from/to an INI file format. +class STStorageProperties( + BaseModel, + extra=Extra.forbid, + validate_assignment=True, + allow_population_by_field_name=True, +): """ + Properties of a short-term storage system read from the configuration files. - class Config: - extra = Extra.forbid - allow_population_by_field_name = True + All aliases match the name of the corresponding field in the INI files. + """ - # The `id` field is a calculated from the `name` if not provided. - # This value must be stored in the config cache. - id: str = Field( - description="Short-term storage ID", - regex=r"[a-zA-Z0-9_(),& -]+", - ) name: str = Field( description="Short-term storage name", regex=r"[a-zA-Z0-9_(),& -]+", @@ -90,6 +86,21 @@ class Config: alias="initialleveloptim", ) + +# noinspection SpellCheckingInspection +class STStorageConfig(STStorageProperties): + """ + Manage the configuration files in the context of Short-Term Storage. + It provides a convenient way to read and write configuration data from/to an INI file format. + """ + + # The `id` field is a calculated from the `name` if not provided. + # This value must be stored in the config cache. + id: str = Field( + description="Short-term storage ID", + regex=r"[a-zA-Z0-9_(),& -]+", + ) + @root_validator(pre=True) def calculate_storage_id(cls, values: Dict[str, Any]) -> Dict[str, Any]: """ diff --git a/antarest/study/storage/rawstudy/model/filesystem/root/input/bindingconstraints/bindingcontraints.py b/antarest/study/storage/rawstudy/model/filesystem/root/input/bindingconstraints/bindingcontraints.py index 92d5443957..e86dedfe18 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/root/input/bindingconstraints/bindingcontraints.py +++ b/antarest/study/storage/rawstudy/model/filesystem/root/input/bindingconstraints/bindingcontraints.py @@ -1,16 +1,33 @@ +from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import BindingConstraintFrequency from antarest.study.storage.rawstudy.model.filesystem.folder_node import FolderNode from antarest.study.storage.rawstudy.model.filesystem.inode import TREE from antarest.study.storage.rawstudy.model.filesystem.matrix.input_series_matrix import InputSeriesMatrix +from antarest.study.storage.rawstudy.model.filesystem.matrix.matrix import MatrixFrequency from antarest.study.storage.rawstudy.model.filesystem.root.input.bindingconstraints.bindingconstraints_ini import ( BindingConstraintsIni, ) +from antarest.study.storage.variantstudy.business.matrix_constants.binding_constraint.series import ( + default_binding_constraint_daily, + default_binding_constraint_hourly, + default_binding_constraint_weekly, +) class BindingConstraints(FolderNode): def build(self) -> TREE: + default_matrices = { + BindingConstraintFrequency.HOURLY: default_binding_constraint_hourly, + BindingConstraintFrequency.DAILY: default_binding_constraint_daily, + BindingConstraintFrequency.WEEKLY: default_binding_constraint_weekly, + } children: TREE = { - binding.id: InputSeriesMatrix(self.context, self.config.next_file(f"{binding.id}.txt")) - # todo get the freq of binding to set the default empty matrix + binding.id: InputSeriesMatrix( + self.context, + self.config.next_file(f"{binding.id}.txt"), + freq=MatrixFrequency(binding.time_step), + nb_columns=3, + default_empty=default_matrices[binding.time_step], + ) for binding in self.config.bindings } diff --git a/antarest/study/storage/variantstudy/business/command_extractor.py b/antarest/study/storage/variantstudy/business/command_extractor.py index 0602c8b62a..ab1f8b6d42 100644 --- a/antarest/study/storage/variantstudy/business/command_extractor.py +++ b/antarest/study/storage/variantstudy/business/command_extractor.py @@ -9,12 +9,13 @@ from antarest.matrixstore.model import MatrixData from antarest.matrixstore.service import ISimpleMatrixService from antarest.study.storage.patch_service import PatchService +from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import BindingConstraintFrequency from antarest.study.storage.rawstudy.model.filesystem.config.files import get_playlist from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy from antarest.study.storage.rawstudy.model.filesystem.root.filestudytree import FileStudyTree from antarest.study.storage.variantstudy.business.matrix_constants_generator import GeneratorMatrixConstants from antarest.study.storage.variantstudy.business.utils import strip_matrix_protocol -from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator, TimeStep +from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator from antarest.study.storage.variantstudy.model.command.create_area import CreateArea from antarest.study.storage.variantstudy.model.command.create_binding_constraint import CreateBindingConstraint from antarest.study.storage.variantstudy.model.command.create_cluster import CreateCluster @@ -348,7 +349,7 @@ def extract_binding_constraint( binding_constraint_command = CreateBindingConstraint( name=binding["name"], enabled=binding["enabled"], - time_step=TimeStep(binding["type"]), + time_step=BindingConstraintFrequency(binding["type"]), operator=BindingConstraintOperator(binding["operator"]), coeffs={ coeff: [float(el) for el in str(value).split("%")] diff --git a/antarest/study/storage/variantstudy/business/matrix_constants/__init__.py b/antarest/study/storage/variantstudy/business/matrix_constants/__init__.py index 0f9b1e77ca..9212a8e6c6 100644 --- a/antarest/study/storage/variantstudy/business/matrix_constants/__init__.py +++ b/antarest/study/storage/variantstudy/business/matrix_constants/__init__.py @@ -1 +1 @@ -from . import hydro, link, prepro, st_storage, thermals +from . import binding_constraint, hydro, link, prepro, st_storage, thermals diff --git a/antarest/study/storage/variantstudy/business/matrix_constants/binding_constraint/__init__.py b/antarest/study/storage/variantstudy/business/matrix_constants/binding_constraint/__init__.py new file mode 100644 index 0000000000..0a1b9046e5 --- /dev/null +++ b/antarest/study/storage/variantstudy/business/matrix_constants/binding_constraint/__init__.py @@ -0,0 +1 @@ +from . import series diff --git a/antarest/study/storage/variantstudy/business/matrix_constants/binding_constraint/series.py b/antarest/study/storage/variantstudy/business/matrix_constants/binding_constraint/series.py new file mode 100644 index 0000000000..e7b20a1137 --- /dev/null +++ b/antarest/study/storage/variantstudy/business/matrix_constants/binding_constraint/series.py @@ -0,0 +1,10 @@ +import numpy as np + +default_binding_constraint_hourly = np.zeros((8760, 3), dtype=np.float64) +default_binding_constraint_hourly.flags.writeable = False + +default_binding_constraint_daily = np.zeros((365, 3), dtype=np.float64) +default_binding_constraint_daily.flags.writeable = False + +default_binding_constraint_weekly = np.zeros((52, 3), dtype=np.float64) +default_binding_constraint_weekly.flags.writeable = False diff --git a/antarest/study/storage/variantstudy/business/matrix_constants_generator.py b/antarest/study/storage/variantstudy/business/matrix_constants_generator.py index 338506dd7c..8cb973785e 100644 --- a/antarest/study/storage/variantstudy/business/matrix_constants_generator.py +++ b/antarest/study/storage/variantstudy/business/matrix_constants_generator.py @@ -34,6 +34,10 @@ EMPTY_SCENARIO_MATRIX = "empty_scenario_matrix" ONES_SCENARIO_MATRIX = "ones_scenario_matrix" +# Binding constraint aliases +BINDING_CONSTRAINT_HOURLY = "empty_2nd_member_hourly" +BINDING_CONSTRAINT_DAILY = "empty_2nd_member_daily" +BINDING_CONSTRAINT_WEEKLY = "empty_2nd_member_daily" # Short-term storage aliases ST_STORAGE_PMAX_INJECTION = ONES_SCENARIO_MATRIX @@ -84,6 +88,12 @@ def _init(self) -> None: self.hashes[RESERVES_TS] = self.matrix_service.create(FIXED_4_COLUMNS) self.hashes[MISCGEN_TS] = self.matrix_service.create(FIXED_8_COLUMNS) + # Binding constraint matrices + series = matrix_constants.binding_constraint.series + self.hashes[BINDING_CONSTRAINT_HOURLY] = self.matrix_service.create(series.default_binding_constraint_hourly) + self.hashes[BINDING_CONSTRAINT_DAILY] = self.matrix_service.create(series.default_binding_constraint_daily) + self.hashes[BINDING_CONSTRAINT_WEEKLY] = self.matrix_service.create(series.default_binding_constraint_weekly) + # Some short-term storage matrices use np.ones((8760, 1)) self.hashes[ONES_SCENARIO_MATRIX] = self.matrix_service.create( matrix_constants.st_storage.series.pmax_injection @@ -141,6 +151,18 @@ def get_default_reserves(self) -> str: def get_default_miscgen(self) -> str: return MATRIX_PROTOCOL_PREFIX + self.hashes[MISCGEN_TS] + def get_binding_constraint_hourly(self) -> str: + """2D-matrix of shape (8760, 3), filled-in with zeros.""" + return MATRIX_PROTOCOL_PREFIX + self.hashes[BINDING_CONSTRAINT_HOURLY] + + def get_binding_constraint_daily(self) -> str: + """2D-matrix of shape (365, 3), filled-in with zeros.""" + return MATRIX_PROTOCOL_PREFIX + self.hashes[BINDING_CONSTRAINT_DAILY] + + def get_binding_constraint_weekly(self) -> str: + """2D-matrix of shape (52, 3), filled-in with zeros.""" + return MATRIX_PROTOCOL_PREFIX + self.hashes[BINDING_CONSTRAINT_WEEKLY] + def get_st_storage_pmax_injection(self) -> str: """2D-matrix of shape (8760, 1), filled-in with ones.""" return MATRIX_PROTOCOL_PREFIX + self.hashes[ST_STORAGE_PMAX_INJECTION] diff --git a/antarest/study/storage/variantstudy/business/utils_binding_constraint.py b/antarest/study/storage/variantstudy/business/utils_binding_constraint.py index 13e3617ec4..b2eded8a6a 100644 --- a/antarest/study/storage/variantstudy/business/utils_binding_constraint.py +++ b/antarest/study/storage/variantstudy/business/utils_binding_constraint.py @@ -1,10 +1,14 @@ -from typing import Dict, List, Optional, Union +from typing import Dict, List, Literal, Mapping, Optional, Sequence, Union from antarest.core.model import JSON from antarest.matrixstore.model import MatrixData -from antarest.study.storage.rawstudy.model.filesystem.config.model import BindingConstraintDTO, FileStudyTreeConfig +from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import ( + BindingConstraintDTO, + BindingConstraintFrequency, +) +from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfig from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy -from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator, CommandOutput, TimeStep +from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator, CommandOutput def cluster_does_not_exist(study_data: FileStudy, area: str, thermal_id: str) -> bool: @@ -25,27 +29,27 @@ def apply_binding_constraint( name: str, comments: Optional[str], enabled: bool, - time_step: TimeStep, + freq: BindingConstraintFrequency, operator: BindingConstraintOperator, coeffs: Dict[str, List[float]], values: Optional[Union[List[List[MatrixData]], str]], filter_year_by_year: Optional[str] = None, filter_synthesis: Optional[str] = None, ) -> CommandOutput: - binding_constraints[str(new_key)] = { + binding_constraints[new_key] = { "name": name, "id": bd_id, "enabled": enabled, - "type": time_step.value, + "type": freq.value, "operator": operator.value, } if study_data.config.version >= 830: if filter_year_by_year: - binding_constraints[str(new_key)]["filter-year-by-year"] = filter_year_by_year + binding_constraints[new_key]["filter-year-by-year"] = filter_year_by_year if filter_synthesis: - binding_constraints[str(new_key)]["filter-synthesis"] = filter_synthesis + binding_constraints[new_key]["filter-synthesis"] = filter_synthesis if comments is not None: - binding_constraints[str(new_key)]["comments"] = comments + binding_constraints[new_key]["comments"] = comments for link_or_thermal in coeffs: if "%" in link_or_thermal: @@ -67,7 +71,7 @@ def apply_binding_constraint( if len(coeffs[link_or_thermal]) == 2: coeffs[link_or_thermal][1] = int(coeffs[link_or_thermal][1]) - binding_constraints[str(new_key)][link_or_thermal] = "%".join( + binding_constraints[new_key][link_or_thermal] = "%".join( [str(coeff_val) for coeff_val in coeffs[link_or_thermal]] ) parse_bindings_coeffs_and_save_into_config(bd_id, study_data.config, coeffs) @@ -76,7 +80,8 @@ def apply_binding_constraint( ["input", "bindingconstraints", "bindingconstraints"], ) if values: - assert isinstance(values, str) + if not isinstance(values, str): # pragma: no cover + raise TypeError(repr(values)) study_data.tree.save(values, ["input", "bindingconstraints", bd_id]) return CommandOutput(status=True) @@ -84,19 +89,24 @@ def apply_binding_constraint( def parse_bindings_coeffs_and_save_into_config( bd_id: str, study_data_config: FileStudyTreeConfig, - coeffs: Dict[str, List[float]], + coeffs: Mapping[str, Union[Literal["hourly", "daily", "weekly"], Sequence[float]]], ) -> None: if bd_id not in [bind.id for bind in study_data_config.bindings]: areas_set = set() clusters_set = set() + # Default time_step value + time_step = BindingConstraintFrequency.HOURLY for k, v in coeffs.items(): + if k == "type": + time_step = BindingConstraintFrequency(v) if "%" in k: - areas_set.add(k.split("%")[0]) - areas_set.add(k.split("%")[1]) + areas_set |= set(k.split("%")) elif "." in k: clusters_set.add(k) areas_set.add(k.split(".")[0]) - study_data_config.bindings.append(BindingConstraintDTO(id=bd_id, areas=areas_set, clusters=clusters_set)) + study_data_config.bindings.append( + BindingConstraintDTO(id=bd_id, areas=areas_set, clusters=clusters_set, time_step=time_step) + ) def remove_area_cluster_from_binding_constraints( diff --git a/antarest/study/storage/variantstudy/model/command/common.py b/antarest/study/storage/variantstudy/model/command/common.py index 34c41402d6..a6ac905fd9 100644 --- a/antarest/study/storage/variantstudy/model/command/common.py +++ b/antarest/study/storage/variantstudy/model/command/common.py @@ -8,12 +8,6 @@ class CommandOutput: message: str = "" -class TimeStep(Enum): - HOURLY = "hourly" - DAILY = "daily" - WEEKLY = "weekly" - - class BindingConstraintOperator(Enum): BOTH = "both" EQUAL = "equal" diff --git a/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py b/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py index 02e888f9ac..ed53e5e31c 100644 --- a/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py +++ b/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py @@ -1,11 +1,14 @@ +from abc import ABCMeta from typing import Any, Dict, List, Optional, Tuple, Union, cast -from pydantic import validator +import numpy as np +from pydantic import Field, validator -from antarest.core.utils.utils import assert_this from antarest.matrixstore.model import MatrixData +from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import BindingConstraintFrequency from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfig, transform_name_to_id from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy +from antarest.study.storage.variantstudy.business.matrix_constants_generator import GeneratorMatrixConstants from antarest.study.storage.variantstudy.business.utils import strip_matrix_protocol, validate_matrix from antarest.study.storage.variantstudy.business.utils_binding_constraint import ( apply_binding_constraint, @@ -15,39 +18,119 @@ BindingConstraintOperator, CommandName, CommandOutput, - TimeStep, ) from antarest.study.storage.variantstudy.model.command.icommand import MATCH_SIGNATURE_SEPARATOR, ICommand from antarest.study.storage.variantstudy.model.model import CommandDTO +__all__ = ("AbstractBindingConstraintCommand", "CreateBindingConstraint", "check_matrix_values") -class CreateBindingConstraint(ICommand): - name: str +MatrixType = List[List[MatrixData]] + + +def check_matrix_values(time_step: BindingConstraintFrequency, values: MatrixType) -> None: + """ + Check the binding constraint's matrix values for the specified time step. + + Args: + time_step: The frequency of the binding constraint: "hourly", "daily" or "weekly". + values: The binding constraint's 2nd member matrix. + + Raises: + ValueError: + If the matrix shape does not match the expected shape for the given time step. + If the matrix values contain NaN (Not-a-Number). + """ + shapes = { + BindingConstraintFrequency.HOURLY: (8760, 3), + BindingConstraintFrequency.DAILY: (365, 3), + BindingConstraintFrequency.WEEKLY: (52, 3), + } + # Check the matrix values and create the corresponding matrix link + array = np.array(values, dtype=np.float64) + if array.shape != shapes[time_step]: + raise ValueError(f"Invalid matrix shape {array.shape}, expected {shapes[time_step]}") + if np.isnan(array).any(): + raise ValueError("Matrix values cannot contain NaN") + + +class AbstractBindingConstraintCommand(ICommand, metaclass=ABCMeta): + """ + Abstract class for binding constraint commands. + """ + + # todo: add the `name` attribute because it should also be updated enabled: bool = True - time_step: TimeStep + time_step: BindingConstraintFrequency operator: BindingConstraintOperator coeffs: Dict[str, List[float]] - values: Optional[Union[List[List[MatrixData]], str]] = None + values: Optional[Union[MatrixType, str]] = Field(None, description="2nd member matrix") filter_year_by_year: Optional[str] = None filter_synthesis: Optional[str] = None comments: Optional[str] = None - def __init__(self, **data: Any) -> None: - super().__init__( - command_name=CommandName.CREATE_BINDING_CONSTRAINT, - version=1, - **data, + def to_dto(self) -> CommandDTO: + args = { + "enabled": self.enabled, + "time_step": self.time_step.value, + "operator": self.operator.value, + "coeffs": self.coeffs, + "comments": self.comments, + "filter_year_by_year": self.filter_year_by_year, + "filter_synthesis": self.filter_synthesis, + } + if self.values is not None: + args["values"] = strip_matrix_protocol(self.values) + return CommandDTO( + action=self.command_name.value, + args=args, ) + def get_inner_matrices(self) -> List[str]: + if self.values is not None: + if not isinstance(self.values, str): # pragma: no cover + raise TypeError(repr(self.values)) + return [strip_matrix_protocol(self.values)] + return [] + + +class CreateBindingConstraint(AbstractBindingConstraintCommand): + """ + Command used to create a binding constraint. + """ + + command_name: CommandName = CommandName.CREATE_BINDING_CONSTRAINT + version: int = 1 + + # Properties of the `CREATE_BINDING_CONSTRAINT` command: + name: str + @validator("values", always=True) def validate_series( - cls, v: Optional[Union[List[List[MatrixData]], str]], values: Any - ) -> Optional[Union[List[List[MatrixData]], str]]: + cls, + v: Optional[Union[MatrixType, str]], + values: Dict[str, Any], + ) -> Optional[Union[MatrixType, str]]: + constants: GeneratorMatrixConstants + constants = values["command_context"].generator_matrix_constants + time_step = values["time_step"] if v is None: - v = values["command_context"].generator_matrix_constants.get_null_matrix() - return v - else: + # Use an already-registered default matrix + methods = { + BindingConstraintFrequency.HOURLY: constants.get_binding_constraint_hourly, + BindingConstraintFrequency.DAILY: constants.get_binding_constraint_daily, + BindingConstraintFrequency.WEEKLY: constants.get_binding_constraint_weekly, + } + method = methods[time_step] + return method() + if isinstance(v, str): + # Check the matrix link + return validate_matrix(v, values) + if isinstance(v, list): + check_matrix_values(time_step, v) return validate_matrix(v, values) + # Invalid datatype + # pragma: no cover + raise TypeError(repr(v)) def _apply_config(self, study_data_config: FileStudyTreeConfig) -> Tuple[CommandOutput, Dict[str, Any]]: bd_id = transform_name_to_id(self.name) @@ -55,7 +138,6 @@ def _apply_config(self, study_data_config: FileStudyTreeConfig) -> Tuple[Command return CommandOutput(status=True), {} def _apply(self, study_data: FileStudy) -> CommandOutput: - assert_this(isinstance(self.values, str)) binding_constraints = study_data.tree.get(["input", "bindingconstraints", "bindingconstraints"]) new_key = len(binding_constraints.keys()) bd_id = transform_name_to_id(self.name) @@ -76,20 +158,9 @@ def _apply(self, study_data: FileStudy) -> CommandOutput: ) def to_dto(self) -> CommandDTO: - return CommandDTO( - action=CommandName.CREATE_BINDING_CONSTRAINT.value, - args={ - "name": self.name, - "enabled": self.enabled, - "time_step": self.time_step.value, - "operator": self.operator.value, - "coeffs": self.coeffs, - "values": strip_matrix_protocol(self.values), - "comments": self.comments, - "filter_year_by_year": self.filter_year_by_year, - "filter_synthesis": self.filter_synthesis, - }, - ) + dto = super().to_dto() + dto.args["name"] = self.name # type: ignore + return dto def match_signature(self) -> str: return str(self.command_name.value + MATCH_SIGNATURE_SEPARATOR + self.name) @@ -129,9 +200,3 @@ def _create_diff(self, other: "ICommand") -> List["ICommand"]: command_context=other.command_context, ) ] - - def get_inner_matrices(self) -> List[str]: - if self.values is not None: - assert_this(isinstance(self.values, str)) - return [strip_matrix_protocol(self.values)] - return [] diff --git a/antarest/study/storage/variantstudy/model/command/remove_area.py b/antarest/study/storage/variantstudy/model/command/remove_area.py index 49b785869a..03b287cf9f 100644 --- a/antarest/study/storage/variantstudy/model/command/remove_area.py +++ b/antarest/study/storage/variantstudy/model/command/remove_area.py @@ -17,10 +17,15 @@ class RemoveArea(ICommand): - id: str + """ + Command used to remove an area. + """ + + command_name: CommandName = CommandName.REMOVE_AREA + version: int = 1 - def __init__(self, **data: Any) -> None: - super().__init__(command_name=CommandName.REMOVE_AREA, version=1, **data) + # Properties of the `REMOVE_AREA` command: + id: str def _remove_area_from_links_in_config(self, study_data_config: FileStudyTreeConfig) -> None: link_to_remove = [ diff --git a/antarest/study/storage/variantstudy/model/command/remove_binding_constraint.py b/antarest/study/storage/variantstudy/model/command/remove_binding_constraint.py index 4e33fd6000..97a27f5297 100644 --- a/antarest/study/storage/variantstudy/model/command/remove_binding_constraint.py +++ b/antarest/study/storage/variantstudy/model/command/remove_binding_constraint.py @@ -9,14 +9,15 @@ class RemoveBindingConstraint(ICommand): - id: str + """ + Command used to remove a binding constraint. + """ - def __init__(self, **data: Any) -> None: - super().__init__( - command_name=CommandName.REMOVE_BINDING_CONSTRAINT, - version=1, - **data, - ) + command_name: CommandName = CommandName.REMOVE_BINDING_CONSTRAINT + version: int = 1 + + # Properties of the `REMOVE_BINDING_CONSTRAINT` command: + id: str def _apply_config(self, study_data: FileStudyTreeConfig) -> Tuple[CommandOutput, Dict[str, Any]]: if self.id not in [bind.id for bind in study_data.bindings]: diff --git a/antarest/study/storage/variantstudy/model/command/remove_link.py b/antarest/study/storage/variantstudy/model/command/remove_link.py index aeb262fb14..5ec4c042cb 100644 --- a/antarest/study/storage/variantstudy/model/command/remove_link.py +++ b/antarest/study/storage/variantstudy/model/command/remove_link.py @@ -8,12 +8,17 @@ class RemoveLink(ICommand): + """ + Command used to remove a link. + """ + + command_name: CommandName = CommandName.REMOVE_LINK + version: int = 1 + + # Properties of the `REMOVE_LINK` command: area1: str area2: str - def __init__(self, **data: Any) -> None: - super().__init__(command_name=CommandName.REMOVE_LINK, version=1, **data) - def _apply_config(self, study_data: FileStudyTreeConfig) -> Tuple[CommandOutput, Dict[str, Any]]: result = self._check_link_exists(study_data) if result[0].status: diff --git a/antarest/study/storage/variantstudy/model/command/update_binding_constraint.py b/antarest/study/storage/variantstudy/model/command/update_binding_constraint.py index dc17838f58..3ec874a375 100644 --- a/antarest/study/storage/variantstudy/model/command/update_binding_constraint.py +++ b/antarest/study/storage/variantstudy/model/command/update_binding_constraint.py @@ -3,47 +3,54 @@ from pydantic import validator from antarest.core.model import JSON -from antarest.core.utils.utils import assert_this from antarest.matrixstore.model import MatrixData from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfig from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy -from antarest.study.storage.variantstudy.business.utils import strip_matrix_protocol, validate_matrix +from antarest.study.storage.variantstudy.business.utils import validate_matrix from antarest.study.storage.variantstudy.business.utils_binding_constraint import apply_binding_constraint -from antarest.study.storage.variantstudy.model.command.common import ( - BindingConstraintOperator, - CommandName, - CommandOutput, - TimeStep, +from antarest.study.storage.variantstudy.model.command.common import CommandName, CommandOutput +from antarest.study.storage.variantstudy.model.command.create_binding_constraint import ( + AbstractBindingConstraintCommand, + check_matrix_values, ) from antarest.study.storage.variantstudy.model.command.icommand import MATCH_SIGNATURE_SEPARATOR, ICommand from antarest.study.storage.variantstudy.model.model import CommandDTO +__all__ = ("UpdateBindingConstraint",) -class UpdateBindingConstraint(ICommand): +MatrixType = List[List[MatrixData]] + + +class UpdateBindingConstraint(AbstractBindingConstraintCommand): + """ + Command used to update a binding constraint. + """ + + command_name: CommandName = CommandName.UPDATE_BINDING_CONSTRAINT + version: int = 1 + + # Properties of the `UPDATE_BINDING_CONSTRAINT` command: id: str - enabled: bool = True - time_step: TimeStep - operator: BindingConstraintOperator - coeffs: Dict[str, List[float]] - values: Optional[Union[List[List[MatrixData]], str]] = None - filter_year_by_year: Optional[str] = None - filter_synthesis: Optional[str] = None - comments: Optional[str] = None - - def __init__(self, **data: Any) -> None: - super().__init__( - command_name=CommandName.UPDATE_BINDING_CONSTRAINT, - version=1, - **data, - ) @validator("values", always=True) def validate_series( - cls, v: Optional[Union[List[List[MatrixData]], str]], values: Any - ) -> Optional[Union[List[List[MatrixData]], str]]: - if v is not None: + cls, + v: Optional[Union[MatrixType, str]], + values: Dict[str, Any], + ) -> Optional[Union[MatrixType, str]]: + time_step = values["time_step"] + if v is None: + # The matrix is not updated + return None + if isinstance(v, str): + # Check the matrix link + return validate_matrix(v, values) + if isinstance(v, list): + check_matrix_values(time_step, v) return validate_matrix(v, values) - return None + # Invalid datatype + # pragma: no cover + raise TypeError(repr(v)) def _apply_config(self, study_data: FileStudyTreeConfig) -> Tuple[CommandOutput, Dict[str, Any]]: return CommandOutput(status=True), {} @@ -81,22 +88,9 @@ def _apply(self, study_data: FileStudy) -> CommandOutput: ) def to_dto(self) -> CommandDTO: - args = { - "id": self.id, - "enabled": self.enabled, - "time_step": self.time_step.value, - "operator": self.operator.value, - "coeffs": self.coeffs, - "comments": self.comments, - "filter_year_by_year": self.filter_year_by_year, - "filter_synthesis": self.filter_synthesis, - } - if self.values is not None: - args["values"] = strip_matrix_protocol(self.values) - return CommandDTO( - action=CommandName.UPDATE_BINDING_CONSTRAINT.value, - args=args, - ) + dto = super().to_dto() + dto.args["id"] = self.id # type: ignore + return dto def match_signature(self) -> str: return str(self.command_name.value + MATCH_SIGNATURE_SEPARATOR + self.id) @@ -119,9 +113,3 @@ def match(self, other: ICommand, equal: bool = False) -> bool: def _create_diff(self, other: "ICommand") -> List["ICommand"]: return [other] - - def get_inner_matrices(self) -> List[str]: - if self.values is not None: - assert_this(isinstance(self.values, str)) - return [strip_matrix_protocol(self.values)] - return [] diff --git a/antarest/study/web/study_data_blueprint.py b/antarest/study/web/study_data_blueprint.py index 931816391a..610e3b7263 100644 --- a/antarest/study/web/study_data_blueprint.py +++ b/antarest/study/web/study_data_blueprint.py @@ -1577,8 +1577,15 @@ def create_st_storage( Args: - `uuid`: The UUID of the study. - `area_id`: The area ID. - - `form`: The name and the group(PSP_open, PSP_closed, Pondage, Battery, Other1, Other2, Other3, Other4, Other5) - of the storage that we want to create. + - `form`: The characteristic of the storage that we can update: + - `name`: The name of the updated storage. + - `group`: The group of the updated storage. + - `injectionNominalCapacity`: The injection Nominal Capacity of the updated storage. + - `withdrawalNominalCapacity`: The withdrawal Nominal Capacity of the updated storage. + - `reservoirCapacity`: The reservoir capacity of the updated storage. + - `efficiency`: The efficiency of the updated storage + - `initialLevel`: The initial Level of the updated storage + - `initialLevelOptim`: The initial Level Optim of the updated storage Returns: New storage with the following attributes: - `id`: The storage ID of the study. diff --git a/antarest/study/web/xpansion_studies_blueprint.py b/antarest/study/web/xpansion_studies_blueprint.py index 98c11a91c2..af9496ba4c 100644 --- a/antarest/study/web/xpansion_studies_blueprint.py +++ b/antarest/study/web/xpansion_studies_blueprint.py @@ -162,7 +162,7 @@ def get_candidates( uuid: str, current_user: JWTUser = Depends(auth.get_current_user), ) -> Any: - logger.info(f"Fetching study list", extra={"user": current_user.id}) + logger.info("Fetching study list", extra={"user": current_user.id}) params = RequestParameters(user=current_user) return study_service.get_candidates(uuid, params) diff --git a/antarest/worker/simulator_worker.py b/antarest/worker/simulator_worker.py index f939d8ea61..d37a8825f5 100644 --- a/antarest/worker/simulator_worker.py +++ b/antarest/worker/simulator_worker.py @@ -1,9 +1,10 @@ +import io import logging import subprocess import threading import time from pathlib import Path -from typing import IO, cast +from typing import cast from pydantic import BaseModel @@ -12,7 +13,7 @@ from antarest.core.interfaces.eventbus import IEventBus from antarest.core.tasks.model import TaskResult from antarest.core.utils.fastapi_sqlalchemy import db -from antarest.launcher.adapters.log_manager import LogTailManager +from antarest.launcher.adapters.log_manager import follow from antarest.matrixstore.service import MatrixService from antarest.matrixstore.uri_resolver_service import UriResolverService from antarest.study.storage.rawstudy.model.filesystem.factory import StudyFactory @@ -101,8 +102,8 @@ def stop_reading() -> bool: encoding="utf-8", ) thread = threading.Thread( - target=lambda: LogTailManager.follow( - cast(IO[str], process.stdout), + target=lambda: follow( + cast(io.StringIO, process.stdout), append_output, stop_reading, None, diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index e6d96180aa..4fc768bc04 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -1,6 +1,38 @@ Antares Web Changelog ===================== +v2.15.2 (2023-10-11) +-------------------- + +### Hotfix + +* **service:** user connected via tokens cannot create a study (#1757) ([f620197](https://github.com/AntaresSimulatorTeam/AntaREST/commit/f6201976a653db19739cbc42e91ea27ac790da10)) + + +### Features + +* **binding-constraint:** handling binding constraints frequency in study configuration parsing (#1702) ([703351a](https://github.com/AntaresSimulatorTeam/AntaREST/commit/703351a6d8d4f70491e66c3c54a92c6d28cb92ea)) + - add the binding constraint series in the matrix constants generator ([e00d58b](https://github.com/AntaresSimulatorTeam/AntaREST/commit/e00d58b203023363860cb0e849576e02ed97fd81)) + - command `create_binding_constraint` can check the matrix shape ([68bf99f](https://github.com/AntaresSimulatorTeam/AntaREST/commit/68bf99f1170181f6111bc15c03ede27030f809d2)) + - command `update_binding_constraint` can check the matrix shape ([c962f73](https://github.com/AntaresSimulatorTeam/AntaREST/commit/c962f7344c7ea07c7a8c7699b2af35f90f3b853c)) + - add missing command docstring ([d277805](https://github.com/AntaresSimulatorTeam/AntaREST/commit/d277805c10d3f9c7134166e6d2f7170c7b752428)) + - reduce code duplication ([b41d957](https://github.com/AntaresSimulatorTeam/AntaREST/commit/b41d957cffa6a8dde21a022f8b6c24c8de2559a2)) + - correct `test_command_factory` unit test to ignore abstract commands ([789c2ad](https://github.com/AntaresSimulatorTeam/AntaREST/commit/789c2adfc3ef3999f3779a345e0730f2f9ad906a)) +* **api:** add endpoint get_nb_cores (#1727) ([9cfa9f1](https://github.com/AntaresSimulatorTeam/AntaREST/commit/9cfa9f13d363ea4f73aa31ed760d525b091f04a4)) +* **st-storage:** allow all parameters in endpoint for short term storage creation (#1736) ([853cf6b](https://github.com/AntaresSimulatorTeam/AntaREST/commit/853cf6ba48a23d39f247a0842afac440c4ea4570)) + + +### Chore + +* **sonarcloud:** correct SonarCloud issues ([901d00d](https://github.com/AntaresSimulatorTeam/AntaREST/commit/901d00df558f7e79b728e2ce7406d1bdea69f839)) + + +### Contributors + +laurent-laporte-pro, +MartinBelthle + + v2.15.1 (2023-10-05) -------------------- diff --git a/resources/application.yaml b/resources/application.yaml index 75a7b8605e..a85357634f 100644 --- a/resources/application.yaml +++ b/resources/application.yaml @@ -20,9 +20,9 @@ db: #pool_recycle: storage: - tmp_dir: /tmp + tmp_dir: ./tmp matrixstore: ./matrices - archive_dir: examples/archives + archive_dir: ./examples/archives allow_deletion: false # indicate if studies found in non default workspace can be deleted by the application #matrix_gc_sleeping_time: 3600 # time in seconds to sleep between two garbage collection #matrix_gc_dry_run: False # Skip matrix effective deletion @@ -32,20 +32,23 @@ storage: #auto_archive_max_parallel: 5 # max auto archival tasks in parallel workspaces: default: # required, no filters applied, this folder is not watched - path: examples/internal_studies/ + path: ./examples/internal_studies/ # other workspaces can be added # if a directory is to be ignored by the watcher, place a file named AW_NO_SCAN inside tmp: - path: examples/studies/ + path: ./examples/studies/ # filter_in: ['.*'] # default to '.*' # filter_out: [] # default to empty # groups: [] # default empty launcher: default: local + local: binaries: 700: path/to/700 + enable_nb_cores_detection: true + # slurm: # local_workspace: path/to/workspace # username: username @@ -56,7 +59,11 @@ launcher: # password: password_is_optional_but_necessary_if_key_is_absent # default_wait_time: 900 # default_time_limit: 172800 -# default_n_cpu: 12 +# enable_nb_cores_detection: False +# nb_cores: +# min: 1 +# default: 22 +# max: 24 # default_json_db_name: launcher_db.json # slurm_script_path: /path/to/launchantares_v1.1.3.sh # db_primary_key: name @@ -70,7 +77,7 @@ launcher: debug: true -root_path: "" +root_path: "api" #tasks: # max_workers: 5 diff --git a/resources/deploy/config.prod.yaml b/resources/deploy/config.prod.yaml index 1bb5e30878..02fbb4b8bc 100644 --- a/resources/deploy/config.prod.yaml +++ b/resources/deploy/config.prod.yaml @@ -32,9 +32,12 @@ storage: launcher: default: local + local: binaries: 800: /antares_simulator/antares-8.2-solver + enable_nb_cores_detection: true + # slurm: # local_workspace: path/to/workspace # username: username @@ -45,7 +48,11 @@ launcher: # password: password_is_optional_but_necessary_if_key_is_absent # default_wait_time: 900 # default_time_limit: 172800 -# default_n_cpu: 12 +# enable_nb_cores_detection: False +# nb_cores: +# min: 1 +# default: 22 +# max: 24 # default_json_db_name: launcher_db.json # slurm_script_path: /path/to/launchantares_v1.1.3.sh # db_primary_key: name @@ -59,7 +66,7 @@ launcher: debug: false -root_path: "/api" +root_path: "api" #tasks: # max_workers: 5 diff --git a/resources/deploy/config.yaml b/resources/deploy/config.yaml index 48cea48a22..810e1f8d24 100644 --- a/resources/deploy/config.yaml +++ b/resources/deploy/config.yaml @@ -29,9 +29,12 @@ storage: launcher: default: local + local: binaries: 700: path/to/700 + enable_nb_cores_detection: true + # slurm: # local_workspace: path/to/workspace # username: username @@ -42,7 +45,11 @@ launcher: # password: password_is_optional_but_necessary_if_key_is_absent # default_wait_time: 900 # default_time_limit: 172800 -# default_n_cpu: 12 +# enable_nb_cores_detection: False +# nb_cores: +# min: 1 +# default: 22 +# max: 24 # default_json_db_name: launcher_db.json # slurm_script_path: /path/to/launchantares_v1.1.3.sh # db_primary_key: name diff --git a/setup.py b/setup.py index 92e4e34f98..d0b8c5deb3 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name="AntaREST", - version="2.15.1", + version="2.15.2", description="Antares Server", long_description=Path("README.md").read_text(encoding="utf-8"), long_description_content_type="text/markdown", diff --git a/sonar-project.properties b/sonar-project.properties index 9bf66c1c14..1d3a132a04 100644 --- a/sonar-project.properties +++ b/sonar-project.properties @@ -6,5 +6,5 @@ sonar.exclusions=antarest/gui.py,antarest/main.py sonar.python.coverage.reportPaths=coverage.xml sonar.python.version=3.8 sonar.javascript.lcov.reportPaths=webapp/coverage/lcov.info -sonar.projectVersion=2.15.1 +sonar.projectVersion=2.15.2 sonar.coverage.exclusions=antarest/gui.py,antarest/main.py,antarest/singleton_services.py,antarest/worker/archive_worker_service.py,webapp/**/* \ No newline at end of file diff --git a/tests/conftest_db.py b/tests/conftest_db.py index 877ca119d1..bcb4177766 100644 --- a/tests/conftest_db.py +++ b/tests/conftest_db.py @@ -3,7 +3,8 @@ import pytest from sqlalchemy import create_engine # type: ignore -from sqlalchemy.orm import sessionmaker +from sqlalchemy.engine.base import Engine # type: ignore +from sqlalchemy.orm import Session, sessionmaker # type: ignore from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware from antarest.dbmodel import Base @@ -12,7 +13,7 @@ @pytest.fixture(name="db_engine") -def db_engine_fixture() -> Generator[Any, None, None]: +def db_engine_fixture() -> Generator[Engine, None, None]: """ Fixture that creates an in-memory SQLite database engine for testing. @@ -26,7 +27,7 @@ def db_engine_fixture() -> Generator[Any, None, None]: @pytest.fixture(name="db_session") -def db_session_fixture(db_engine) -> Generator: +def db_session_fixture(db_engine: Engine) -> Generator[Session, None, None]: """ Fixture that creates a database session for testing purposes. @@ -46,7 +47,7 @@ def db_session_fixture(db_engine) -> Generator: @pytest.fixture(name="db_middleware", autouse=True) def db_middleware_fixture( - db_engine: Any, + db_engine: Engine, ) -> Generator[DBSessionMiddleware, None, None]: """ Fixture that sets up a database session middleware with custom engine settings. diff --git a/tests/core/assets/__init__.py b/tests/core/assets/__init__.py new file mode 100644 index 0000000000..773f16ec60 --- /dev/null +++ b/tests/core/assets/__init__.py @@ -0,0 +1,3 @@ +from pathlib import Path + +ASSETS_DIR = Path(__file__).parent.resolve() diff --git a/tests/core/assets/config/application-2.14.yaml b/tests/core/assets/config/application-2.14.yaml new file mode 100644 index 0000000000..650093286d --- /dev/null +++ b/tests/core/assets/config/application-2.14.yaml @@ -0,0 +1,61 @@ +security: + disabled: false + jwt: + key: super-secret + login: + admin: + pwd: admin + +db: + url: "sqlite:////home/john/antares_data/database.db" + +storage: + tmp_dir: /tmp + matrixstore: /home/john/antares_data/matrices + archive_dir: /home/john/antares_data/archives + allow_deletion: false + workspaces: + default: + path: /home/john/antares_data/internal_studies/ + studies: + path: /home/john/antares_data/studies/ + +launcher: + default: slurm + local: + binaries: + 850: /home/john/opt/antares-8.5.0-Ubuntu-20.04/antares-solver + 860: /home/john/opt/antares-8.6.0-Ubuntu-20.04/antares-8.6-solver + + slurm: + local_workspace: /home/john/antares_data/slurm_workspace + + username: antares + hostname: slurm-prod-01 + + port: 22 + private_key_file: /home/john/.ssh/id_rsa + key_password: + default_wait_time: 900 + default_time_limit: 172800 + default_n_cpu: 20 + default_json_db_name: launcher_db.json + slurm_script_path: /applis/antares/launchAntares.sh + db_primary_key: name + antares_versions_on_remote_server: + - '850' # 8.5.1/antares-8.5-solver + - '860' # 8.6.2/antares-8.6-solver + - '870' # 8.7.0/antares-8.7-solver + +debug: false + +root_path: "" + +server: + worker_threadpool_size: 12 + services: + - watcher + - matrix_gc + +logging: + level: INFO diff --git a/tests/core/assets/config/application-2.15.yaml b/tests/core/assets/config/application-2.15.yaml new file mode 100644 index 0000000000..c51d32aaae --- /dev/null +++ b/tests/core/assets/config/application-2.15.yaml @@ -0,0 +1,66 @@ +security: + disabled: false + jwt: + key: super-secret + login: + admin: + pwd: admin + +db: + url: "sqlite:////home/john/antares_data/database.db" + +storage: + tmp_dir: /tmp + matrixstore: /home/john/antares_data/matrices + archive_dir: /home/john/antares_data/archives + allow_deletion: false + workspaces: + default: + path: /home/john/antares_data/internal_studies/ + studies: + path: /home/john/antares_data/studies/ + +launcher: + default: slurm + local: + binaries: + 850: /home/john/opt/antares-8.5.0-Ubuntu-20.04/antares-solver + 860: /home/john/opt/antares-8.6.0-Ubuntu-20.04/antares-8.6-solver + enable_nb_cores_detection: True + + slurm: + local_workspace: /home/john/antares_data/slurm_workspace + + username: antares + hostname: slurm-prod-01 + + port: 22 + private_key_file: /home/john/.ssh/id_rsa + key_password: + default_wait_time: 900 + default_time_limit: 172800 + enable_nb_cores_detection: False + nb_cores: + min: 1 + default: 22 + max: 24 + default_json_db_name: launcher_db.json + slurm_script_path: /applis/antares/launchAntares.sh + db_primary_key: name + antares_versions_on_remote_server: + - '850' # 8.5.1/antares-8.5-solver + - '860' # 8.6.2/antares-8.6-solver + - '870' # 8.7.0/antares-8.7-solver + +debug: false + +root_path: "" + +server: + worker_threadpool_size: 12 + services: + - watcher + - matrix_gc + +logging: + level: INFO diff --git a/tests/core/test_config.py b/tests/core/test_config.py index 1c1c96a180..00c6f9458d 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -1,15 +1,253 @@ from pathlib import Path +from unittest import mock import pytest -from antarest.core.config import Config +from antarest.core.config import ( + Config, + InvalidConfigurationError, + LauncherConfig, + LocalConfig, + NbCoresConfig, + SlurmConfig, +) +from tests.core.assets import ASSETS_DIR +LAUNCHER_CONFIG = { + "default": "slurm", + "local": { + "binaries": {"860": Path("/bin/solver-860.exe")}, + "enable_nb_cores_detection": False, + "nb_cores": {"min": 2, "default": 10, "max": 20}, + }, + "slurm": { + "local_workspace": Path("/home/john/antares/workspace"), + "username": "john", + "hostname": "slurm-001", + "port": 22, + "private_key_file": Path("/home/john/.ssh/id_rsa"), + "key_password": "password", + "password": "password", + "default_wait_time": 10, + "default_time_limit": 20, + "default_json_db_name": "antares.db", + "slurm_script_path": "/path/to/slurm/launcher.sh", + "max_cores": 32, + "antares_versions_on_remote_server": ["860"], + "enable_nb_cores_detection": False, + "nb_cores": {"min": 1, "default": 34, "max": 36}, + }, + "batch_size": 100, +} -@pytest.mark.unit_test -def test_get_yaml(project_path: Path): - config = Config.from_yaml_file(file=project_path / "resources/application.yaml") - assert config.security.admin_pwd == "admin" - assert config.storage.workspaces["default"].path == Path("examples/internal_studies/") - assert not config.logging.json - assert config.logging.level == "INFO" +class TestNbCoresConfig: + def test_init__default_values(self): + config = NbCoresConfig() + assert config.min == 1 + assert config.default == 22 + assert config.max == 24 + + def test_init__invalid_values(self): + with pytest.raises(ValueError): + # default < min + NbCoresConfig(min=2, default=1, max=24) + with pytest.raises(ValueError): + # default > max + NbCoresConfig(min=1, default=25, max=24) + with pytest.raises(ValueError): + # min < 0 + NbCoresConfig(min=0, default=22, max=23) + with pytest.raises(ValueError): + # min > max + NbCoresConfig(min=22, default=22, max=21) + + def test_to_json(self): + config = NbCoresConfig() + # ReactJs Material UI expects "min", "defaultValue" and "max" keys + assert config.to_json() == {"min": 1, "defaultValue": 22, "max": 24} + + +class TestLocalConfig: + def test_init__default_values(self): + config = LocalConfig() + assert config.binaries == {}, "binaries should be empty by default" + assert config.enable_nb_cores_detection, "nb cores auto-detection should be enabled by default" + assert config.nb_cores == NbCoresConfig() + + def test_from_dict(self): + config = LocalConfig.from_dict( + { + "binaries": {"860": Path("/bin/solver-860.exe")}, + "enable_nb_cores_detection": False, + "nb_cores": {"min": 2, "default": 10, "max": 20}, + } + ) + assert config.binaries == {"860": Path("/bin/solver-860.exe")} + assert not config.enable_nb_cores_detection + assert config.nb_cores == NbCoresConfig(min=2, default=10, max=20) + + def test_from_dict__auto_detect(self): + with mock.patch("multiprocessing.cpu_count", return_value=8): + config = LocalConfig.from_dict( + { + "binaries": {"860": Path("/bin/solver-860.exe")}, + "enable_nb_cores_detection": True, + } + ) + assert config.binaries == {"860": Path("/bin/solver-860.exe")} + assert config.enable_nb_cores_detection + assert config.nb_cores == NbCoresConfig(min=1, default=6, max=8) + + +class TestSlurmConfig: + def test_init__default_values(self): + config = SlurmConfig() + assert config.local_workspace == Path() + assert config.username == "" + assert config.hostname == "" + assert config.port == 0 + assert config.private_key_file == Path() + assert config.key_password == "" + assert config.password == "" + assert config.default_wait_time == 0 + assert config.default_time_limit == 0 + assert config.default_json_db_name == "" + assert config.slurm_script_path == "" + assert config.max_cores == 64 + assert config.antares_versions_on_remote_server == [], "solver versions should be empty by default" + assert not config.enable_nb_cores_detection, "nb cores auto-detection shouldn't be enabled by default" + assert config.nb_cores == NbCoresConfig() + + def test_from_dict(self): + config = SlurmConfig.from_dict( + { + "local_workspace": Path("/home/john/antares/workspace"), + "username": "john", + "hostname": "slurm-001", + "port": 22, + "private_key_file": Path("/home/john/.ssh/id_rsa"), + "key_password": "password", + "password": "password", + "default_wait_time": 10, + "default_time_limit": 20, + "default_json_db_name": "antares.db", + "slurm_script_path": "/path/to/slurm/launcher.sh", + "max_cores": 32, + "antares_versions_on_remote_server": ["860"], + "enable_nb_cores_detection": False, + "nb_cores": {"min": 2, "default": 10, "max": 20}, + } + ) + assert config.local_workspace == Path("/home/john/antares/workspace") + assert config.username == "john" + assert config.hostname == "slurm-001" + assert config.port == 22 + assert config.private_key_file == Path("/home/john/.ssh/id_rsa") + assert config.key_password == "password" + assert config.password == "password" + assert config.default_wait_time == 10 + assert config.default_time_limit == 20 + assert config.default_json_db_name == "antares.db" + assert config.slurm_script_path == "/path/to/slurm/launcher.sh" + assert config.max_cores == 32 + assert config.antares_versions_on_remote_server == ["860"] + assert not config.enable_nb_cores_detection + assert config.nb_cores == NbCoresConfig(min=2, default=10, max=20) + + def test_from_dict__default_n_cpu__backport(self): + config = SlurmConfig.from_dict( + { + "local_workspace": Path("/home/john/antares/workspace"), + "username": "john", + "hostname": "slurm-001", + "port": 22, + "private_key_file": Path("/home/john/.ssh/id_rsa"), + "key_password": "password", + "password": "password", + "default_wait_time": 10, + "default_time_limit": 20, + "default_json_db_name": "antares.db", + "slurm_script_path": "/path/to/slurm/launcher.sh", + "max_cores": 32, + "antares_versions_on_remote_server": ["860"], + "default_n_cpu": 15, + } + ) + assert config.nb_cores == NbCoresConfig(min=1, default=15, max=24) + + def test_from_dict__auto_detect(self): + with pytest.raises(NotImplementedError): + SlurmConfig.from_dict({"enable_nb_cores_detection": True}) + + +class TestLauncherConfig: + def test_init__default_values(self): + config = LauncherConfig() + assert config.default == "local", "default launcher should be local" + assert config.local is None + assert config.slurm is None + assert config.batch_size == 9999 + + def test_from_dict(self): + config = LauncherConfig.from_dict(LAUNCHER_CONFIG) + assert config.default == "slurm" + assert config.local == LocalConfig( + binaries={"860": Path("/bin/solver-860.exe")}, + enable_nb_cores_detection=False, + nb_cores=NbCoresConfig(min=2, default=10, max=20), + ) + assert config.slurm == SlurmConfig( + local_workspace=Path("/home/john/antares/workspace"), + username="john", + hostname="slurm-001", + port=22, + private_key_file=Path("/home/john/.ssh/id_rsa"), + key_password="password", + password="password", + default_wait_time=10, + default_time_limit=20, + default_json_db_name="antares.db", + slurm_script_path="/path/to/slurm/launcher.sh", + max_cores=32, + antares_versions_on_remote_server=["860"], + enable_nb_cores_detection=False, + nb_cores=NbCoresConfig(min=1, default=34, max=36), + ) + assert config.batch_size == 100 + + def test_init__invalid_launcher(self): + with pytest.raises(ValueError): + LauncherConfig(default="invalid_launcher") + + def test_get_nb_cores__default(self): + config = LauncherConfig.from_dict(LAUNCHER_CONFIG) + # default == "slurm" + assert config.get_nb_cores(launcher="default") == NbCoresConfig(min=1, default=34, max=36) + + def test_get_nb_cores__local(self): + config = LauncherConfig.from_dict(LAUNCHER_CONFIG) + assert config.get_nb_cores(launcher="local") == NbCoresConfig(min=2, default=10, max=20) + + def test_get_nb_cores__slurm(self): + config = LauncherConfig.from_dict(LAUNCHER_CONFIG) + assert config.get_nb_cores(launcher="slurm") == NbCoresConfig(min=1, default=34, max=36) + + def test_get_nb_cores__invalid_configuration(self): + config = LauncherConfig.from_dict(LAUNCHER_CONFIG) + with pytest.raises(InvalidConfigurationError): + config.get_nb_cores("invalid_launcher") + config = LauncherConfig.from_dict({}) + with pytest.raises(InvalidConfigurationError): + config.get_nb_cores("slurm") + + +class TestConfig: + @pytest.mark.parametrize("config_name", ["application-2.14.yaml", "application-2.15.yaml"]) + def test_from_yaml_file(self, config_name: str) -> None: + yaml_path = ASSETS_DIR.joinpath("config", config_name) + config = Config.from_yaml_file(yaml_path) + assert config.security.admin_pwd == "admin" + assert config.storage.workspaces["default"].path == Path("/home/john/antares_data/internal_studies") + assert not config.logging.json + assert config.logging.level == "INFO" diff --git a/tests/integration/assets/base_study.zip b/tests/integration/assets/base_study.zip index 28794fde9e..712833942c 100644 Binary files a/tests/integration/assets/base_study.zip and b/tests/integration/assets/base_study.zip differ diff --git a/tests/integration/assets/config.template.yml b/tests/integration/assets/config.template.yml index f3ad1d256f..71c58a1ba0 100644 --- a/tests/integration/assets/config.template.yml +++ b/tests/integration/assets/config.template.yml @@ -27,6 +27,7 @@ launcher: local: binaries: 700: {{launcher_mock}} + enable_nb_cores_detection: True debug: false diff --git a/tests/integration/assets/variant_study.zip b/tests/integration/assets/variant_study.zip index 5693b29a61..4e08926cda 100644 Binary files a/tests/integration/assets/variant_study.zip and b/tests/integration/assets/variant_study.zip differ diff --git a/tests/integration/launcher_blueprint/test_launcher_local.py b/tests/integration/launcher_blueprint/test_launcher_local.py new file mode 100644 index 0000000000..7244fba8ee --- /dev/null +++ b/tests/integration/launcher_blueprint/test_launcher_local.py @@ -0,0 +1,70 @@ +import http + +import pytest +from starlette.testclient import TestClient + +from antarest.core.config import LocalConfig + + +# noinspection SpellCheckingInspection +@pytest.mark.integration_test +class TestLauncherNbCores: + """ + The purpose of this unit test is to check the `/v1/launcher/nbcores` endpoint. + """ + + def test_get_launcher_nb_cores( + self, + client: TestClient, + user_access_token: str, + ) -> None: + # NOTE: we have `enable_nb_cores_detection: True` in `tests/integration/assets/config.template.yml`. + local_nb_cores = LocalConfig.from_dict({"enable_nb_cores_detection": True}).nb_cores + nb_cores_expected = local_nb_cores.to_json() + res = client.get( + "/v1/launcher/nbcores", + headers={"Authorization": f"Bearer {user_access_token}"}, + ) + res.raise_for_status() + actual = res.json() + assert actual == nb_cores_expected + + res = client.get( + "/v1/launcher/nbcores?launcher=default", + headers={"Authorization": f"Bearer {user_access_token}"}, + ) + res.raise_for_status() + actual = res.json() + assert actual == nb_cores_expected + + res = client.get( + "/v1/launcher/nbcores?launcher=local", + headers={"Authorization": f"Bearer {user_access_token}"}, + ) + res.raise_for_status() + actual = res.json() + assert actual == nb_cores_expected + + # Check that the endpoint raise an exception when the "slurm" launcher is requested. + res = client.get( + "/v1/launcher/nbcores?launcher=slurm", + headers={"Authorization": f"Bearer {user_access_token}"}, + ) + assert res.status_code == http.HTTPStatus.UNPROCESSABLE_ENTITY, res.json() + actual = res.json() + assert actual == { + "description": "Unknown solver configuration: 'slurm'", + "exception": "UnknownSolverConfig", + } + + # Check that the endpoint raise an exception when an unknown launcher is requested. + res = client.get( + "/v1/launcher/nbcores?launcher=unknown", + headers={"Authorization": f"Bearer {user_access_token}"}, + ) + assert res.status_code == http.HTTPStatus.UNPROCESSABLE_ENTITY, res.json() + actual = res.json() + assert actual == { + "description": "Unknown solver configuration: 'unknown'", + "exception": "UnknownSolverConfig", + } diff --git a/tests/integration/study_data_blueprint/test_st_storage.py b/tests/integration/study_data_blueprint/test_st_storage.py index cdd0b464c7..2d1035af3e 100644 --- a/tests/integration/study_data_blueprint/test_st_storage.py +++ b/tests/integration/study_data_blueprint/test_st_storage.py @@ -1,4 +1,3 @@ -import json import re import numpy as np @@ -9,6 +8,17 @@ from antarest.study.storage.rawstudy.model.filesystem.config.model import transform_name_to_id from tests.integration.utils import wait_task_completion +DEFAULT_PROPERTIES = { + # `name` field is required + "group": "Other1", + "injectionNominalCapacity": 0.0, + "withdrawalNominalCapacity": 0.0, + "reservoirCapacity": 0.0, + "efficiency": 1.0, + "initialLevel": 0.0, + "initialLevelOptim": False, +} + @pytest.mark.unit_test class TestSTStorage: @@ -61,28 +71,46 @@ def test_lifecycle__nominal( task = wait_task_completion(client, user_access_token, task_id) assert task.status == TaskStatus.COMPLETED, task - # creation with default values (only mandatory properties specified) + # ============================= + # SHORT-TERM STORAGE CREATION + # ============================= + area_id = transform_name_to_id("FR") siemens_battery = "Siemens Battery" + + # Un attempt to create a short-term storage without name + # should raise a validation error (other properties are optional). + # Un attempt to create a short-term storage with an empty name + # or an invalid name should also raise a validation error. + attempts = [{}, {"name": ""}, {"name": "!??"}] + for attempt in attempts: + res = client.post( + f"/v1/studies/{study_id}/areas/{area_id}/storages", + headers={"Authorization": f"Bearer {user_access_token}"}, + json=attempt, + ) + assert res.status_code == 422, res.json() + assert res.json()["exception"] in {"ValidationError", "RequestValidationError"}, res.json() + + # We can create a short-term storage with the following properties: + siemens_properties = { + **DEFAULT_PROPERTIES, + "name": siemens_battery, + "group": "Battery", + "injectionNominalCapacity": 1450, + "withdrawalNominalCapacity": 1350, + "reservoirCapacity": 1500, + } res = client.post( f"/v1/studies/{study_id}/areas/{area_id}/storages", headers={"Authorization": f"Bearer {user_access_token}"}, - json={"name": siemens_battery, "group": "Battery"}, + json=siemens_properties, ) assert res.status_code == 200, res.json() siemens_battery_id = res.json()["id"] assert siemens_battery_id == transform_name_to_id(siemens_battery) - assert res.json() == { - "efficiency": 1.0, - "group": "Battery", - "id": siemens_battery_id, - "initialLevel": 0.0, - "initialLevelOptim": False, - "injectionNominalCapacity": 0.0, - "name": siemens_battery, - "reservoirCapacity": 0.0, - "withdrawalNominalCapacity": 0.0, - } + siemens_config = {**siemens_properties, "id": siemens_battery_id} + assert res.json() == siemens_config # reading the properties of a short-term storage res = client.get( @@ -90,17 +118,11 @@ def test_lifecycle__nominal( headers={"Authorization": f"Bearer {user_access_token}"}, ) assert res.status_code == 200, res.json() - assert res.json() == { - "efficiency": 1.0, - "group": "Battery", - "id": siemens_battery_id, - "initialLevel": 0.0, - "initialLevelOptim": False, - "injectionNominalCapacity": 0.0, - "name": siemens_battery, - "reservoirCapacity": 0.0, - "withdrawalNominalCapacity": 0.0, - } + assert res.json() == siemens_config + + # ============================= + # SHORT-TERM STORAGE MATRICES + # ============================= # updating the matrix of a short-term storage array = np.random.rand(8760, 1) * 1000 @@ -134,25 +156,17 @@ def test_lifecycle__nominal( assert res.status_code == 200, res.json() assert res.json() is True + # ================================== + # SHORT-TERM STORAGE LIST / GROUPS + # ================================== + # Reading the list of short-term storages res = client.get( f"/v1/studies/{study_id}/areas/{area_id}/storages", headers={"Authorization": f"Bearer {user_access_token}"}, ) assert res.status_code == 200, res.json() - assert res.json() == [ - { - "efficiency": 1.0, - "group": "Battery", - "id": siemens_battery_id, - "initialLevel": 0.0, - "initialLevelOptim": False, - "injectionNominalCapacity": 0.0, - "name": siemens_battery, - "reservoirCapacity": 0.0, - "withdrawalNominalCapacity": 0.0, - } - ] + assert res.json() == [siemens_config] # updating properties res = client.patch( @@ -164,34 +178,23 @@ def test_lifecycle__nominal( }, ) assert res.status_code == 200, res.json() - assert json.loads(res.text) == { - "id": siemens_battery_id, + siemens_config = { + **siemens_config, "name": "New Siemens Battery", - "group": "Battery", - "efficiency": 1.0, - "initialLevel": 0.0, - "initialLevelOptim": False, - "injectionNominalCapacity": 0.0, - "withdrawalNominalCapacity": 0.0, "reservoirCapacity": 2500, } + assert res.json() == siemens_config res = client.get( f"/v1/studies/{study_id}/areas/{area_id}/storages/{siemens_battery_id}", headers={"Authorization": f"Bearer {user_access_token}"}, ) assert res.status_code == 200, res.json() - assert res.json() == { - "id": siemens_battery_id, - "name": "New Siemens Battery", - "group": "Battery", - "efficiency": 1.0, - "initialLevel": 0.0, - "initialLevelOptim": False, - "injectionNominalCapacity": 0.0, - "withdrawalNominalCapacity": 0.0, - "reservoirCapacity": 2500, - } + assert res.json() == siemens_config + + # =========================== + # SHORT-TERM STORAGE UPDATE + # =========================== # updating properties res = client.patch( @@ -202,37 +205,38 @@ def test_lifecycle__nominal( "reservoirCapacity": 0, }, ) - assert res.status_code == 200, res.json() - assert json.loads(res.text) == { - "id": siemens_battery_id, - "name": "New Siemens Battery", - "group": "Battery", - "efficiency": 1.0, + siemens_config = { + **siemens_config, "initialLevel": 5900, - "initialLevelOptim": False, - "injectionNominalCapacity": 0.0, - "withdrawalNominalCapacity": 0.0, "reservoirCapacity": 0, } + assert res.json() == siemens_config + # An attempt to update the `efficiency` property with an invalid value + # should raise a validation error. + # The `efficiency` property must be a float between 0 and 1. + bad_properties = {"efficiency": 2.0} + res = client.patch( + f"/v1/studies/{study_id}/areas/{area_id}/storages/{siemens_battery_id}", + headers={"Authorization": f"Bearer {user_access_token}"}, + json=bad_properties, + ) + assert res.status_code == 422, res.json() + assert res.json()["exception"] == "ValidationError", res.json() + + # The short-term storage properties should not have been updated. res = client.get( f"/v1/studies/{study_id}/areas/{area_id}/storages/{siemens_battery_id}", headers={"Authorization": f"Bearer {user_access_token}"}, ) assert res.status_code == 200, res.json() - assert res.json() == { - "id": siemens_battery_id, - "name": "New Siemens Battery", - "group": "Battery", - "efficiency": 1.0, - "initialLevel": 5900, - "initialLevelOptim": False, - "injectionNominalCapacity": 0.0, - "withdrawalNominalCapacity": 0.0, - "reservoirCapacity": 0, - } + assert res.json() == siemens_config + + # ============================= + # SHORT-TERM STORAGE DELETION + # ============================= - # deletion of short-term storages + # To delete a short-term storage, we need to provide its ID. res = client.request( "DELETE", f"/v1/studies/{study_id}/areas/{area_id}/storages", @@ -242,7 +246,7 @@ def test_lifecycle__nominal( assert res.status_code == 204, res.json() assert res.text in {"", "null"} # Old FastAPI versions return 'null'. - # deletion of short-term storages with empty list + # If the short-term storage list is empty, the deletion should be a no-op. res = client.request( "DELETE", f"/v1/studies/{study_id}/areas/{area_id}/storages", @@ -252,48 +256,79 @@ def test_lifecycle__nominal( assert res.status_code == 204, res.json() assert res.text in {"", "null"} # Old FastAPI versions return 'null'. - # deletion of short-term storages with multiple IDs + # It's possible to delete multiple short-term storages at once. + # In the following example, we will create two short-term storages: + siemens_properties = { + "name": siemens_battery, + "group": "Battery", + "injectionNominalCapacity": 1450, + "withdrawalNominalCapacity": 1350, + "reservoirCapacity": 1500, + "efficiency": 0.90, + "initialLevel": 200, + "initialLevelOptim": False, + } res = client.post( f"/v1/studies/{study_id}/areas/{area_id}/storages", headers={"Authorization": f"Bearer {user_access_token}"}, - json={"name": siemens_battery, "group": "Battery"}, + json=siemens_properties, ) assert res.status_code == 200, res.json() - siemens_battery_id1 = res.json()["id"] - - siemens_battery_del = f"{siemens_battery}del" + siemens_battery_id = res.json()["id"] + # Create another short-term storage: "Grand'Maison" + grand_maison = "Grand'Maison" + grand_maison_properties = { + "name": grand_maison, + "group": "PSP_closed", + "injectionNominalCapacity": 1500, + "withdrawalNominalCapacity": 1800, + "reservoirCapacity": 20000, + "efficiency": 0.78, + "initialLevel": 10000, + } res = client.post( f"/v1/studies/{study_id}/areas/{area_id}/storages", headers={"Authorization": f"Bearer {user_access_token}"}, - json={"name": siemens_battery_del, "group": "Battery"}, + json=grand_maison_properties, ) assert res.status_code == 200, res.json() - siemens_battery_id2 = res.json()["id"] + grand_maison_id = res.json()["id"] + # We can check that we have 2 short-term storages in the list. + # Reading the list of short-term storages + res = client.get( + f"/v1/studies/{study_id}/areas/{area_id}/storages", + headers={"Authorization": f"Bearer {user_access_token}"}, + ) + assert res.status_code == 200, res.json() + siemens_config = {**DEFAULT_PROPERTIES, **siemens_properties, "id": siemens_battery_id} + grand_maison_config = {**DEFAULT_PROPERTIES, **grand_maison_properties, "id": grand_maison_id} + assert res.json() == [siemens_config, grand_maison_config] + + # We can delete the two short-term storages at once. res = client.request( "DELETE", f"/v1/studies/{study_id}/areas/{area_id}/storages", headers={"Authorization": f"Bearer {user_access_token}"}, - json=[siemens_battery_id1, siemens_battery_id2], + json=[siemens_battery_id, grand_maison_id], ) assert res.status_code == 204, res.json() assert res.text in {"", "null"} # Old FastAPI versions return 'null'. - # Check the removal + # The list of short-term storages should be empty. res = client.get( - f"/v1/studies/{study_id}/areas/{area_id}/storages/{siemens_battery_id}", + f"/v1/studies/{study_id}/areas/{area_id}/storages", headers={"Authorization": f"Bearer {user_access_token}"}, ) - obj = res.json() - description = obj["description"] - assert siemens_battery_id in description - assert re.search(r"fields of storage", description, flags=re.IGNORECASE) - assert re.search(r"not found", description, flags=re.IGNORECASE) + assert res.status_code == 200, res.json() + assert res.json() == [] - assert res.status_code == 404, res.json() + # =========================== + # SHORT-TERM STORAGE ERRORS + # =========================== - # Check delete with the wrong value of area_id + # Check delete with the wrong value of `area_id` bad_area_id = "bad_area" res = client.request( "DELETE", @@ -311,7 +346,7 @@ def test_lifecycle__nominal( flags=re.IGNORECASE, ) - # Check delete with the wrong value of study_id + # Check delete with the wrong value of `study_id` bad_study_id = "bad_study" res = client.request( "DELETE", @@ -324,8 +359,7 @@ def test_lifecycle__nominal( assert res.status_code == 404, res.json() assert bad_study_id in description - # Check get with wrong area_id - + # Check get with wrong `area_id` res = client.get( f"/v1/studies/{study_id}/areas/{bad_area_id}/storages/{siemens_battery_id}", headers={"Authorization": f"Bearer {user_access_token}"}, @@ -335,8 +369,7 @@ def test_lifecycle__nominal( assert bad_area_id in description assert res.status_code == 404, res.json() - # Check get with wrong study_id - + # Check get with wrong `study_id` res = client.get( f"/v1/studies/{bad_study_id}/areas/{area_id}/storages/{siemens_battery_id}", headers={"Authorization": f"Bearer {user_access_token}"}, @@ -346,7 +379,7 @@ def test_lifecycle__nominal( assert res.status_code == 404, res.json() assert bad_study_id in description - # Check post with wrong study_id + # Check POST with wrong `study_id` res = client.post( f"/v1/studies/{bad_study_id}/areas/{area_id}/storages", headers={"Authorization": f"Bearer {user_access_token}"}, @@ -357,11 +390,20 @@ def test_lifecycle__nominal( assert res.status_code == 404, res.json() assert bad_study_id in description - # Check post with wrong area_id + # Check POST with wrong `area_id` res = client.post( f"/v1/studies/{study_id}/areas/{bad_area_id}/storages", headers={"Authorization": f"Bearer {user_access_token}"}, - json={"name": siemens_battery, "group": "Battery"}, + json={ + "name": siemens_battery, + "group": "Battery", + "initialLevel": 0.0, + "initialLevelOptim": False, + "injectionNominalCapacity": 0.0, + "reservoirCapacity": 0.0, + "withdrawalNominalCapacity": 0.0, + "efficiency": 1.0, + }, ) assert res.status_code == 500, res.json() obj = res.json() @@ -370,7 +412,7 @@ def test_lifecycle__nominal( assert re.search(r"Area ", description, flags=re.IGNORECASE) assert re.search(r"does not exist ", description, flags=re.IGNORECASE) - # Check post with wrong group + # Check POST with wrong `group` res = client.post( f"/v1/studies/{study_id}/areas/{bad_area_id}/storages", headers={"Authorization": f"Bearer {user_access_token}"}, @@ -381,7 +423,7 @@ def test_lifecycle__nominal( description = obj["description"] assert re.search(r"not a valid enumeration member", description, flags=re.IGNORECASE) - # Check the put with the wrong area_id + # Check PATCH with the wrong `area_id` res = client.patch( f"/v1/studies/{study_id}/areas/{bad_area_id}/storages/{siemens_battery_id}", headers={"Authorization": f"Bearer {user_access_token}"}, @@ -401,7 +443,7 @@ def test_lifecycle__nominal( assert bad_area_id in description assert re.search(r"not a child of ", description, flags=re.IGNORECASE) - # Check the put with the wrong siemens_battery_id + # Check PATCH with the wrong `siemens_battery_id` res = client.patch( f"/v1/studies/{study_id}/areas/{area_id}/storages/{siemens_battery_id}", headers={"Authorization": f"Bearer {user_access_token}"}, @@ -422,7 +464,7 @@ def test_lifecycle__nominal( assert re.search(r"fields of storage", description, flags=re.IGNORECASE) assert re.search(r"not found", description, flags=re.IGNORECASE) - # Check the put with the wrong study_id + # Check PATCH with the wrong `study_id` res = client.patch( f"/v1/studies/{bad_study_id}/areas/{area_id}/storages/{siemens_battery_id}", headers={"Authorization": f"Bearer {user_access_token}"}, @@ -440,19 +482,3 @@ def test_lifecycle__nominal( obj = res.json() description = obj["description"] assert bad_study_id in description - - # Check the put with the wrong efficiency - res = client.patch( - f"/v1/studies/{bad_study_id}/areas/{area_id}/storages/{siemens_battery_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - json={ - "efficiency": 2.0, - "initialLevel": 0.0, - "initialLevelOptim": True, - "injectionNominalCapacity": 2450, - "name": "New Siemens Battery", - "reservoirCapacity": 2500, - "withdrawalNominalCapacity": 2350, - }, - ) - assert res.status_code == 422, res.json() diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index 7fd8131a54..df4249ca47 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -529,7 +529,7 @@ def test_area_management(client: TestClient, admin_access_token: str, study_id: ], ) - client.post( + res = client.post( f"/v1/studies/{study_id}/commands", headers=admin_headers, json=[ @@ -545,8 +545,9 @@ def test_area_management(client: TestClient, admin_access_token: str, study_id: } ], ) + res.raise_for_status() - client.post( + res = client.post( f"/v1/studies/{study_id}/commands", headers=admin_headers, json=[ @@ -562,6 +563,7 @@ def test_area_management(client: TestClient, admin_access_token: str, study_id: } ], ) + res.raise_for_status() res_areas = client.get(f"/v1/studies/{study_id}/areas", headers=admin_headers) assert res_areas.json() == [ diff --git a/tests/integration/test_integration_variantmanager_tool.py b/tests/integration/test_integration_variantmanager_tool.py index d0315a31e1..a247d81c1d 100644 --- a/tests/integration/test_integration_variantmanager_tool.py +++ b/tests/integration/test_integration_variantmanager_tool.py @@ -1,19 +1,15 @@ -import os +import io import urllib.parse from pathlib import Path from typing import List, Tuple from zipfile import ZipFile +import numpy as np +import numpy.typing as npt from fastapi import FastAPI from starlette.testclient import TestClient from antarest.study.storage.rawstudy.io.reader import IniReader -from antarest.study.storage.rawstudy.model.filesystem.matrix.constants import ( - default_4_fixed_hourly, - default_8_fixed_hourly, - default_scenario_daily, - default_scenario_hourly, -) from antarest.study.storage.variantstudy.model.command.common import CommandName from antarest.study.storage.variantstudy.model.model import CommandDTO, GenerationResultInfoDTO from antarest.tools.lib import ( @@ -29,11 +25,10 @@ test_dir: Path = Path(__file__).parent -def generate_csv_string(data: List[List[float]]) -> str: - csv_str = "" - for row in data: - csv_str += "\t".join(["{:.6f}".format(v) for v in row]) + "\n" - return csv_str +def generate_csv_string(array: npt.NDArray[np.float64]) -> str: + buffer = io.StringIO() + np.savetxt(buffer, array, delimiter="\t", fmt="%.6f") + return buffer.getvalue() def generate_study_with_server( @@ -60,7 +55,7 @@ def generate_study_with_server( return generator.apply_commands(commands, matrices_dir), variant_id -def test_variant_manager(app: FastAPI, tmp_path: str): +def test_variant_manager(app: FastAPI, tmp_path: str) -> None: client = TestClient(app, raise_server_exceptions=False) commands = parse_commands(test_dir / "assets" / "commands1.json") matrix_dir = Path(tmp_path) / "empty_matrix_store" @@ -69,7 +64,7 @@ def test_variant_manager(app: FastAPI, tmp_path: str): assert res is not None and res.success -def test_parse_commands(tmp_path: str, app: FastAPI): +def test_parse_commands(tmp_path: str, app: FastAPI) -> None: base_dir = test_dir / "assets" export_path = Path(tmp_path) / "commands" study = "base_study" @@ -92,138 +87,133 @@ def test_parse_commands(tmp_path: str, app: FastAPI): assert generated_study_path.exists() and generated_study_path.is_dir() single_column_empty_items = [ - f"input{os.sep}load{os.sep}series{os.sep}load_hub w.txt", - f"input{os.sep}load{os.sep}series{os.sep}load_south.txt", - f"input{os.sep}load{os.sep}series{os.sep}load_hub n.txt", - f"input{os.sep}load{os.sep}series{os.sep}load_west.txt", - f"input{os.sep}load{os.sep}series{os.sep}load_north.txt", - f"input{os.sep}load{os.sep}series{os.sep}load_hub s.txt", - f"input{os.sep}load{os.sep}series{os.sep}load_hub e.txt", - f"input{os.sep}load{os.sep}series{os.sep}load_east.txt", - f"input{os.sep}wind{os.sep}series{os.sep}wind_east.txt", - f"input{os.sep}wind{os.sep}series{os.sep}wind_north.txt", - f"input{os.sep}wind{os.sep}series{os.sep}wind_hub n.txt", - f"input{os.sep}wind{os.sep}series{os.sep}wind_south.txt", - f"input{os.sep}wind{os.sep}series{os.sep}wind_hub w.txt", - f"input{os.sep}wind{os.sep}series{os.sep}wind_west.txt", - f"input{os.sep}wind{os.sep}series{os.sep}wind_hub e.txt", - f"input{os.sep}wind{os.sep}series{os.sep}wind_hub s.txt", - f"input{os.sep}solar{os.sep}series{os.sep}solar_east.txt", - f"input{os.sep}solar{os.sep}series{os.sep}solar_hub n.txt", - f"input{os.sep}solar{os.sep}series{os.sep}solar_south.txt", - f"input{os.sep}solar{os.sep}series{os.sep}solar_hub s.txt", - f"input{os.sep}solar{os.sep}series{os.sep}solar_north.txt", - f"input{os.sep}solar{os.sep}series{os.sep}solar_hub w.txt", - f"input{os.sep}solar{os.sep}series{os.sep}solar_hub e.txt", - f"input{os.sep}solar{os.sep}series{os.sep}solar_west.txt", - f"input{os.sep}thermal{os.sep}series{os.sep}west{os.sep}semi base{os.sep}series.txt", - f"input{os.sep}thermal{os.sep}series{os.sep}west{os.sep}peak{os.sep}series.txt", - f"input{os.sep}thermal{os.sep}series{os.sep}west{os.sep}base{os.sep}series.txt", - f"input{os.sep}thermal{os.sep}series{os.sep}north{os.sep}semi base{os.sep}series.txt", - f"input{os.sep}thermal{os.sep}series{os.sep}north{os.sep}peak{os.sep}series.txt", - f"input{os.sep}thermal{os.sep}series{os.sep}north{os.sep}base{os.sep}series.txt", - f"input{os.sep}thermal{os.sep}series{os.sep}east{os.sep}semi base{os.sep}series.txt", - f"input{os.sep}thermal{os.sep}series{os.sep}east{os.sep}peak{os.sep}series.txt", - f"input{os.sep}thermal{os.sep}series{os.sep}east{os.sep}base{os.sep}series.txt", - f"input{os.sep}thermal{os.sep}series{os.sep}south{os.sep}semi base{os.sep}series.txt", - f"input{os.sep}thermal{os.sep}series{os.sep}south{os.sep}peak{os.sep}series.txt", - f"input{os.sep}thermal{os.sep}series{os.sep}south{os.sep}base{os.sep}series.txt", - f"input{os.sep}hydro{os.sep}series{os.sep}hub e{os.sep}ror.txt", - f"input{os.sep}hydro{os.sep}series{os.sep}south{os.sep}ror.txt", - f"input{os.sep}hydro{os.sep}series{os.sep}hub w{os.sep}ror.txt", - f"input{os.sep}hydro{os.sep}series{os.sep}hub s{os.sep}ror.txt", - f"input{os.sep}hydro{os.sep}series{os.sep}west{os.sep}ror.txt", - f"input{os.sep}hydro{os.sep}series{os.sep}hub n{os.sep}ror.txt", - f"input{os.sep}hydro{os.sep}series{os.sep}north{os.sep}ror.txt", - f"input{os.sep}hydro{os.sep}series{os.sep}east{os.sep}ror.txt", + "input/load/series/load_hub w.txt", + "input/load/series/load_south.txt", + "input/load/series/load_hub n.txt", + "input/load/series/load_west.txt", + "input/load/series/load_north.txt", + "input/load/series/load_hub s.txt", + "input/load/series/load_hub e.txt", + "input/load/series/load_east.txt", + "input/wind/series/wind_east.txt", + "input/wind/series/wind_north.txt", + "input/wind/series/wind_hub n.txt", + "input/wind/series/wind_south.txt", + "input/wind/series/wind_hub w.txt", + "input/wind/series/wind_west.txt", + "input/wind/series/wind_hub e.txt", + "input/wind/series/wind_hub s.txt", + "input/solar/series/solar_east.txt", + "input/solar/series/solar_hub n.txt", + "input/solar/series/solar_south.txt", + "input/solar/series/solar_hub s.txt", + "input/solar/series/solar_north.txt", + "input/solar/series/solar_hub w.txt", + "input/solar/series/solar_hub e.txt", + "input/solar/series/solar_west.txt", + "input/thermal/series/west/semi base/series.txt", + "input/thermal/series/west/peak/series.txt", + "input/thermal/series/west/base/series.txt", + "input/thermal/series/north/semi base/series.txt", + "input/thermal/series/north/peak/series.txt", + "input/thermal/series/north/base/series.txt", + "input/thermal/series/east/semi base/series.txt", + "input/thermal/series/east/peak/series.txt", + "input/thermal/series/east/base/series.txt", + "input/thermal/series/south/semi base/series.txt", + "input/thermal/series/south/peak/series.txt", + "input/thermal/series/south/base/series.txt", + "input/hydro/series/hub e/ror.txt", + "input/hydro/series/south/ror.txt", + "input/hydro/series/hub w/ror.txt", + "input/hydro/series/hub s/ror.txt", + "input/hydro/series/west/ror.txt", + "input/hydro/series/hub n/ror.txt", + "input/hydro/series/north/ror.txt", + "input/hydro/series/east/ror.txt", ] single_column_daily_empty_items = [ - f"input{os.sep}hydro{os.sep}series{os.sep}hub e{os.sep}mod.txt", - f"input{os.sep}hydro{os.sep}series{os.sep}south{os.sep}mod.txt", - f"input{os.sep}hydro{os.sep}series{os.sep}hub w{os.sep}mod.txt", - f"input{os.sep}hydro{os.sep}series{os.sep}hub s{os.sep}mod.txt", - f"input{os.sep}hydro{os.sep}series{os.sep}west{os.sep}mod.txt", - f"input{os.sep}hydro{os.sep}series{os.sep}hub n{os.sep}mod.txt", - f"input{os.sep}hydro{os.sep}series{os.sep}north{os.sep}mod.txt", - f"input{os.sep}hydro{os.sep}series{os.sep}east{os.sep}mod.txt", + "input/hydro/series/hub e/mod.txt", + "input/hydro/series/south/mod.txt", + "input/hydro/series/hub w/mod.txt", + "input/hydro/series/hub s/mod.txt", + "input/hydro/series/west/mod.txt", + "input/hydro/series/hub n/mod.txt", + "input/hydro/series/north/mod.txt", + "input/hydro/series/east/mod.txt", + ] + fixed_3_cols_hourly_empty_items = [ + "input/bindingconstraints/northern mesh.txt", + "input/bindingconstraints/southern mesh.txt", ] fixed_4_cols_empty_items = [ - f"input{os.sep}reserves{os.sep}hub s.txt", - f"input{os.sep}reserves{os.sep}hub n.txt", - f"input{os.sep}reserves{os.sep}hub w.txt", - f"input{os.sep}reserves{os.sep}hub e.txt", + "input/reserves/hub s.txt", + "input/reserves/hub n.txt", + "input/reserves/hub w.txt", + "input/reserves/hub e.txt", ] fixed_8_cols_empty_items = [ - f"input{os.sep}misc-gen{os.sep}miscgen-hub w.txt", - f"input{os.sep}misc-gen{os.sep}miscgen-hub e.txt", - f"input{os.sep}misc-gen{os.sep}miscgen-hub s.txt", - f"input{os.sep}misc-gen{os.sep}miscgen-hub n.txt", + "input/misc-gen/miscgen-hub w.txt", + "input/misc-gen/miscgen-hub e.txt", + "input/misc-gen/miscgen-hub s.txt", + "input/misc-gen/miscgen-hub n.txt", ] - single_column_empty_data = generate_csv_string(default_scenario_hourly) - single_column_daily_empty_data = generate_csv_string(default_scenario_daily) - fixed_4_columns_empty_data = generate_csv_string(default_4_fixed_hourly) - fixed_8_columns_empty_data = generate_csv_string(default_8_fixed_hourly) - for root, dirs, files in os.walk(study_path): - rel_path = root[len(str(study_path)) + 1 :] - for item in files: - if item in [ - "comments.txt", - "study.antares", - "Desktop.ini", - "study.ico", - ]: - continue - elif f"{rel_path}{os.sep}{item}" in single_column_empty_items: - assert (generated_study_path / rel_path / item).read_text() == single_column_empty_data - elif f"{rel_path}{os.sep}{item}" in single_column_daily_empty_items: - assert (generated_study_path / rel_path / item).read_text() == single_column_daily_empty_data - elif f"{rel_path}{os.sep}{item}" in fixed_4_cols_empty_items: - assert (generated_study_path / rel_path / item).read_text() == fixed_4_columns_empty_data - elif f"{rel_path}{os.sep}{item}" in fixed_8_cols_empty_items: - assert (generated_study_path / rel_path / item).read_text() == fixed_8_columns_empty_data - else: - actual = (study_path / rel_path / item).read_text() - expected = (generated_study_path / rel_path / item).read_text() - assert actual.strip() == expected.strip() - - -def test_diff_local(tmp_path: Path): + single_column_empty_data = generate_csv_string(np.zeros((8760, 1), dtype=np.float64)) + single_column_daily_empty_data = generate_csv_string(np.zeros((365, 1), dtype=np.float64)) + fixed_3_cols_hourly_empty_data = generate_csv_string(np.zeros(shape=(8760, 3), dtype=np.float64)) + fixed_4_columns_empty_data = generate_csv_string(np.zeros((8760, 4), dtype=np.float64)) + fixed_8_columns_empty_data = generate_csv_string(np.zeros((8760, 8), dtype=np.float64)) + for file_path in study_path.rglob("*"): + if file_path.is_dir() or file_path.name in ["comments.txt", "study.antares", "Desktop.ini", "study.ico"]: + continue + item_relpath = file_path.relative_to(study_path).as_posix() + if item_relpath in single_column_empty_items: + assert (generated_study_path / item_relpath).read_text() == single_column_empty_data + elif item_relpath in single_column_daily_empty_items: + assert (generated_study_path / item_relpath).read_text() == single_column_daily_empty_data + elif item_relpath in fixed_3_cols_hourly_empty_items: + assert (generated_study_path / item_relpath).read_text() == fixed_3_cols_hourly_empty_data + elif item_relpath in fixed_4_cols_empty_items: + assert (generated_study_path / item_relpath).read_text() == fixed_4_columns_empty_data + elif item_relpath in fixed_8_cols_empty_items: + assert (generated_study_path / item_relpath).read_text() == fixed_8_columns_empty_data + else: + actual = (study_path / item_relpath).read_text() + expected = (generated_study_path / item_relpath).read_text() + assert actual.strip() == expected.strip() + + +def test_diff_local(tmp_path: Path) -> None: base_dir = test_dir / "assets" export_path = Path(tmp_path) / "generation_result" base_study = "base_study" variant_study = "variant_study" - output_study_commands = Path(export_path) / "output_study_commands" + output_study_commands = export_path / "output_study_commands" output_study_path = Path(tmp_path) / base_study - base_study_commands = Path(export_path) / base_study - variant_study_commands = Path(export_path) / variant_study + base_study_commands = export_path / base_study + variant_study_commands = export_path / variant_study variant_study_path = Path(tmp_path) / variant_study for study in [base_study, variant_study]: with ZipFile(base_dir / f"{study}.zip") as zip_output: zip_output.extractall(path=tmp_path) - extract_commands(Path(tmp_path) / study, Path(export_path) / study) + extract_commands(Path(tmp_path) / study, export_path / study) - res = generate_study(base_study_commands, None, str(Path(export_path) / "base_generated")) - res = generate_study( + generate_study(base_study_commands, None, str(export_path / "base_generated")) + generate_study( variant_study_commands, None, - str(Path(export_path) / "variant_generated"), + str(export_path / "variant_generated"), ) generate_diff(base_study_commands, variant_study_commands, output_study_commands) res = generate_study(output_study_commands, None, output=str(output_study_path)) assert res.success assert output_study_path.exists() and output_study_path.is_dir() - for root, dirs, files in os.walk(variant_study_path): - rel_path = root[len(str(variant_study_path)) + 1 :] - for item in files: - if item in [ - "comments.txt", - "study.antares", - "Desktop.ini", - "study.ico", - ]: - continue - actual = (variant_study_path / rel_path / item).read_text() - expected = (output_study_path / rel_path / item).read_text() - assert actual.strip() == expected.strip() + for file_path in variant_study_path.rglob("*"): + if file_path.is_dir() or file_path.name in ["comments.txt", "study.antares", "Desktop.ini", "study.ico"]: + continue + item_relpath = file_path.relative_to(variant_study_path).as_posix() + actual = (variant_study_path / item_relpath).read_text() + expected = (output_study_path / item_relpath).read_text() + assert actual.strip() == expected.strip() diff --git a/tests/launcher/test_local_launcher.py b/tests/launcher/test_local_launcher.py index 04741d319a..53adc03bf0 100644 --- a/tests/launcher/test_local_launcher.py +++ b/tests/launcher/test_local_launcher.py @@ -1,19 +1,28 @@ import os import textwrap +import uuid from pathlib import Path from unittest.mock import Mock, call -from uuid import uuid4 import pytest -from sqlalchemy import create_engine from antarest.core.config import Config, LauncherConfig, LocalConfig -from antarest.core.persistence import Base -from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware from antarest.launcher.adapters.abstractlauncher import LauncherInitException from antarest.launcher.adapters.local_launcher.local_launcher import LocalLauncher from antarest.launcher.model import JobStatus, LauncherParametersDTO +SOLVER_NAME = "solver.bat" if os.name == "nt" else "solver.sh" + + +@pytest.fixture +def launcher_config(tmp_path: Path) -> Config: + """ + Fixture to create a launcher config with a local launcher. + """ + solver_path = tmp_path.joinpath(SOLVER_NAME) + data = {"binaries": {"700": solver_path}, "enable_nb_cores_detection": True} + return Config(launcher=LauncherConfig(local=LocalConfig.from_dict(data))) + @pytest.mark.unit_test def test_local_launcher__launcher_init_exception(): @@ -30,21 +39,12 @@ def test_local_launcher__launcher_init_exception(): @pytest.mark.unit_test -def test_compute(tmp_path: Path): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - local_launcher = LocalLauncher(Config(), callbacks=Mock(), event_bus=Mock(), cache=Mock()) +def test_compute(tmp_path: Path, launcher_config: Config): + local_launcher = LocalLauncher(launcher_config, callbacks=Mock(), event_bus=Mock(), cache=Mock()) # prepare a dummy executable to simulate Antares Solver if os.name == "nt": - solver_name = "solver.bat" - solver_path = tmp_path.joinpath(solver_name) + solver_path = tmp_path.joinpath(SOLVER_NAME) solver_path.write_text( textwrap.dedent( """\ @@ -55,8 +55,7 @@ def test_compute(tmp_path: Path): ) ) else: - solver_name = "solver.sh" - solver_path = tmp_path.joinpath(solver_name) + solver_path = tmp_path.joinpath(SOLVER_NAME) solver_path.write_text( textwrap.dedent( """\ @@ -68,8 +67,8 @@ def test_compute(tmp_path: Path): ) solver_path.chmod(0o775) - uuid = uuid4() - local_launcher.job_id_to_study_id = {str(uuid): ("study-id", tmp_path / "run", Mock())} + study_id = uuid.uuid4() + local_launcher.job_id_to_study_id = {str(study_id): ("study-id", tmp_path / "run", Mock())} local_launcher.callbacks.import_output.return_value = "some output" launcher_parameters = LauncherParametersDTO( adequacy_patch=None, @@ -86,15 +85,15 @@ def test_compute(tmp_path: Path): local_launcher._compute( antares_solver_path=solver_path, study_uuid="study-id", - uuid=uuid, + uuid=study_id, launcher_parameters=launcher_parameters, ) # noinspection PyUnresolvedReferences local_launcher.callbacks.update_status.assert_has_calls( [ - call(str(uuid), JobStatus.RUNNING, None, None), - call(str(uuid), JobStatus.SUCCESS, None, "some output"), + call(str(study_id), JobStatus.RUNNING, None, None), + call(str(study_id), JobStatus.SUCCESS, None, "some output"), ] ) diff --git a/tests/launcher/test_service.py b/tests/launcher/test_service.py index 2c9f94d89c..a6177c5e61 100644 --- a/tests/launcher/test_service.py +++ b/tests/launcher/test_service.py @@ -3,15 +3,24 @@ import time from datetime import datetime, timedelta from pathlib import Path -from typing import Dict, List, Literal, Union +from typing import Dict, List, Union from unittest.mock import Mock, call, patch from uuid import uuid4 from zipfile import ZIP_DEFLATED, ZipFile import pytest from sqlalchemy import create_engine - -from antarest.core.config import Config, LauncherConfig, LocalConfig, SlurmConfig, StorageConfig +from typing_extensions import Literal + +from antarest.core.config import ( + Config, + InvalidConfigurationError, + LauncherConfig, + LocalConfig, + NbCoresConfig, + SlurmConfig, + StorageConfig, +) from antarest.core.exceptions import StudyNotFoundError from antarest.core.filetransfer.model import FileDownload, FileDownloadDTO, FileDownloadTaskDTO from antarest.core.interfaces.eventbus import Event, EventType @@ -20,7 +29,7 @@ from antarest.core.requests import RequestParameters, UserHasNotPermissionError from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware from antarest.dbmodel import Base -from antarest.launcher.model import JobLog, JobLogType, JobResult, JobStatus, LogType +from antarest.launcher.model import JobLog, JobLogType, JobResult, JobStatus, LauncherParametersDTO, LogType from antarest.launcher.service import ( EXECUTION_INFO_FILE, LAUNCHER_PARAM_NAME_SUFFIX, @@ -33,780 +42,908 @@ from antarest.study.model import OwnerInfo, PublicMode, Study, StudyMetadataDTO -@pytest.mark.unit_test -@patch.object(Auth, "get_current_user") -def test_service_run_study(get_current_user_mock): - get_current_user_mock.return_value = None - storage_service_mock = Mock() - storage_service_mock.get_study_information.return_value = StudyMetadataDTO( - id="id", - name="name", - created=1, - updated=1, - type="rawstudy", - owner=OwnerInfo(id=0, name="author"), - groups=[], - public_mode=PublicMode.NONE, - version=42, - workspace="default", - managed=True, - archived=False, - ) - storage_service_mock.get_study_path.return_value = Path("path/to/study") - - uuid = uuid4() - launcher_mock = Mock() - factory_launcher_mock = Mock() - factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock} - - event_bus = Mock() - - pending = JobResult( - id=str(uuid), - study_id="study_uuid", - job_status=JobStatus.PENDING, - launcher="local", - ) - repository = Mock() - repository.save.return_value = pending - - launcher_service = LauncherService( - config=Config(), - study_service=storage_service_mock, - job_result_repository=repository, - factory_launcher=factory_launcher_mock, - event_bus=event_bus, - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) - launcher_service._generate_new_id = lambda: str(uuid) - - job_id = launcher_service.run_study( - "study_uuid", - "local", - None, - RequestParameters( - user=JWTUser( - id=0, - impersonator=0, - type="users", - ) - ), - ) - - assert job_id == str(uuid) - repository.save.assert_called_once_with(pending) - event_bus.push.assert_called_once_with( - Event( - type=EventType.STUDY_JOB_STARTED, - payload=pending.to_dto().dict(), - permissions=PermissionInfo(owner=0), +class TestLauncherService: + @pytest.mark.unit_test + @patch.object(Auth, "get_current_user") + def test_service_run_study(self, get_current_user_mock) -> None: + get_current_user_mock.return_value = None + storage_service_mock = Mock() + # noinspection SpellCheckingInspection + storage_service_mock.get_study_information.return_value = StudyMetadataDTO( + id="id", + name="name", + created="1", + updated="1", + type="rawstudy", + owner=OwnerInfo(id=0, name="author"), + groups=[], + public_mode=PublicMode.NONE, + version=42, + workspace="default", + managed=True, + archived=False, ) - ) + storage_service_mock.get_study_path.return_value = Path("path/to/study") + uuid = uuid4() + launcher_mock = Mock() + factory_launcher_mock = Mock() + factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock} -@pytest.mark.unit_test -def test_service_get_result_from_launcher(): - launcher_mock = Mock() - fake_execution_result = JobResult( - id=str(uuid4()), - study_id="sid", - job_status=JobStatus.SUCCESS, - msg="Hello, World!", - exit_code=0, - launcher="local", - ) - factory_launcher_mock = Mock() - factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock} - - repository = Mock() - repository.get.return_value = fake_execution_result - - study_service = Mock() - study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) - - launcher_service = LauncherService( - config=Config(), - study_service=study_service, - job_result_repository=repository, - factory_launcher=factory_launcher_mock, - event_bus=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) - - job_id = uuid4() - assert ( - launcher_service.get_result(job_uuid=job_id, params=RequestParameters(user=DEFAULT_ADMIN_USER)) - == fake_execution_result - ) + event_bus = Mock() + pending = JobResult( + id=str(uuid), + study_id="study_uuid", + job_status=JobStatus.PENDING, + launcher="local", + launcher_params=LauncherParametersDTO().json(), + ) + repository = Mock() + repository.save.return_value = pending + + launcher_service = LauncherService( + config=Config(), + study_service=storage_service_mock, + job_result_repository=repository, + factory_launcher=factory_launcher_mock, + event_bus=event_bus, + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) + launcher_service._generate_new_id = lambda: str(uuid) -@pytest.mark.unit_test -def test_service_get_result_from_database(): - launcher_mock = Mock() - fake_execution_result = JobResult( - id=str(uuid4()), - study_id="sid", - job_status=JobStatus.SUCCESS, - msg="Hello, World!", - exit_code=0, - ) - launcher_mock.get_result.return_value = None - factory_launcher_mock = Mock() - factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock} - - repository = Mock() - repository.get.return_value = fake_execution_result - - study_service = Mock() - study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) - - launcher_service = LauncherService( - config=Config(), - study_service=study_service, - job_result_repository=repository, - factory_launcher=factory_launcher_mock, - event_bus=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) - - assert ( - launcher_service.get_result(job_uuid=uuid4(), params=RequestParameters(user=DEFAULT_ADMIN_USER)) - == fake_execution_result - ) + job_id = launcher_service.run_study( + "study_uuid", + "local", + LauncherParametersDTO(), + RequestParameters( + user=JWTUser( + id=0, + impersonator=0, + type="users", + ) + ), + ) + assert job_id == str(uuid) + repository.save.assert_called_once_with(pending) + event_bus.push.assert_called_once_with( + Event( + type=EventType.STUDY_JOB_STARTED, + payload=pending.to_dto().dict(), + permissions=PermissionInfo(owner=0), + ) + ) -@pytest.mark.unit_test -def test_service_get_jobs_from_database(): - launcher_mock = Mock() - now = datetime.utcnow() - fake_execution_result = [ - JobResult( + @pytest.mark.unit_test + def test_service_get_result_from_launcher(self) -> None: + launcher_mock = Mock() + fake_execution_result = JobResult( id=str(uuid4()), - study_id="a", + study_id="sid", job_status=JobStatus.SUCCESS, msg="Hello, World!", exit_code=0, + launcher="local", ) - ] - returned_faked_execution_results = [ - JobResult( - id="1", - study_id="a", - job_status=JobStatus.SUCCESS, - msg="Hello, World!", - exit_code=0, - creation_date=now, - ), - JobResult( - id="2", - study_id="b", - job_status=JobStatus.SUCCESS, - msg="Hello, World!", - exit_code=0, - creation_date=now, - ), - ] - all_faked_execution_results = returned_faked_execution_results + [ - JobResult( - id="3", - study_id="c", + factory_launcher_mock = Mock() + factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock} + + repository = Mock() + repository.get.return_value = fake_execution_result + + study_service = Mock() + study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) + + launcher_service = LauncherService( + config=Config(), + study_service=study_service, + job_result_repository=repository, + factory_launcher=factory_launcher_mock, + event_bus=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) + + job_id = uuid4() + assert ( + launcher_service.get_result(job_uuid=job_id, params=RequestParameters(user=DEFAULT_ADMIN_USER)) + == fake_execution_result + ) + + @pytest.mark.unit_test + def test_service_get_result_from_database(self) -> None: + launcher_mock = Mock() + fake_execution_result = JobResult( + id=str(uuid4()), + study_id="sid", job_status=JobStatus.SUCCESS, msg="Hello, World!", exit_code=0, - creation_date=now - timedelta(days=ORPHAN_JOBS_VISIBILITY_THRESHOLD + 1), - ) - ] - launcher_mock.get_result.return_value = None - factory_launcher_mock = Mock() - factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock} - - repository = Mock() - repository.find_by_study.return_value = fake_execution_result - repository.get_all.return_value = all_faked_execution_results - - study_service = Mock() - study_service.repository = Mock() - study_service.repository.get_list.return_value = [ - Mock( - spec=Study, - id="b", - groups=[], - owner=User(id=2), - public_mode=PublicMode.NONE, ) - ] - - launcher_service = LauncherService( - config=Config(), - study_service=study_service, - job_result_repository=repository, - factory_launcher=factory_launcher_mock, - event_bus=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) + launcher_mock.get_result.return_value = None + factory_launcher_mock = Mock() + factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock} + + repository = Mock() + repository.get.return_value = fake_execution_result + + study_service = Mock() + study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) + + launcher_service = LauncherService( + config=Config(), + study_service=study_service, + job_result_repository=repository, + factory_launcher=factory_launcher_mock, + event_bus=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) - study_id = uuid4() - assert ( - launcher_service.get_jobs(str(study_id), params=RequestParameters(user=DEFAULT_ADMIN_USER)) - == fake_execution_result - ) - repository.find_by_study.assert_called_once_with(str(study_id)) - assert ( - launcher_service.get_jobs(None, params=RequestParameters(user=DEFAULT_ADMIN_USER)) - == all_faked_execution_results - ) - assert ( - launcher_service.get_jobs( - None, - params=RequestParameters( - user=JWTUser( - id=2, - impersonator=2, - type="users", - groups=[], - ) - ), + assert ( + launcher_service.get_result(job_uuid=uuid4(), params=RequestParameters(user=DEFAULT_ADMIN_USER)) + == fake_execution_result ) - == returned_faked_execution_results - ) - with pytest.raises(UserHasNotPermissionError): - launcher_service.remove_job( - "some job", - RequestParameters( - user=JWTUser( - id=2, - impersonator=2, - type="users", - groups=[], - ) + @pytest.mark.unit_test + def test_service_get_jobs_from_database(self) -> None: + launcher_mock = Mock() + now = datetime.utcnow() + fake_execution_result = [ + JobResult( + id=str(uuid4()), + study_id="a", + job_status=JobStatus.SUCCESS, + msg="Hello, World!", + exit_code=0, + ) + ] + returned_faked_execution_results = [ + JobResult( + id="1", + study_id="a", + job_status=JobStatus.SUCCESS, + msg="Hello, World!", + exit_code=0, + creation_date=now, + ), + JobResult( + id="2", + study_id="b", + job_status=JobStatus.SUCCESS, + msg="Hello, World!", + exit_code=0, + creation_date=now, ), + ] + all_faked_execution_results = returned_faked_execution_results + [ + JobResult( + id="3", + study_id="c", + job_status=JobStatus.SUCCESS, + msg="Hello, World!", + exit_code=0, + creation_date=now - timedelta(days=ORPHAN_JOBS_VISIBILITY_THRESHOLD + 1), + ) + ] + launcher_mock.get_result.return_value = None + factory_launcher_mock = Mock() + factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock} + + repository = Mock() + repository.find_by_study.return_value = fake_execution_result + repository.get_all.return_value = all_faked_execution_results + + study_service = Mock() + study_service.repository = Mock() + study_service.repository.get_list.return_value = [ + Mock( + spec=Study, + id="b", + groups=[], + owner=User(id=2), + public_mode=PublicMode.NONE, + ) + ] + + launcher_service = LauncherService( + config=Config(), + study_service=study_service, + job_result_repository=repository, + factory_launcher=factory_launcher_mock, + event_bus=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), ) - launcher_service.remove_job("some job", RequestParameters(user=DEFAULT_ADMIN_USER)) - repository.delete.assert_called_with("some job") + study_id = uuid4() + assert ( + launcher_service.get_jobs(str(study_id), params=RequestParameters(user=DEFAULT_ADMIN_USER)) + == fake_execution_result + ) + repository.find_by_study.assert_called_once_with(str(study_id)) + assert ( + launcher_service.get_jobs(None, params=RequestParameters(user=DEFAULT_ADMIN_USER)) + == all_faked_execution_results + ) + assert ( + launcher_service.get_jobs( + None, + params=RequestParameters( + user=JWTUser( + id=2, + impersonator=2, + type="users", + groups=[], + ) + ), + ) + == returned_faked_execution_results + ) + with pytest.raises(UserHasNotPermissionError): + launcher_service.remove_job( + "some job", + RequestParameters( + user=JWTUser( + id=2, + impersonator=2, + type="users", + groups=[], + ) + ), + ) -@pytest.mark.unit_test -@pytest.mark.parametrize( - "config, solver, expected", - [ - pytest.param( - { - "default": "local", - "local": [], - "slurm": [], - }, - "default", - [], - id="empty-config", - ), - pytest.param( - { - "default": "local", - "local": ["456", "123", "798"], - }, - "default", - ["123", "456", "798"], - id="local-config-default", - ), - pytest.param( - { - "default": "local", - "local": ["456", "123", "798"], - }, - "slurm", - [], - id="local-config-slurm", - ), - pytest.param( - { - "default": "local", - "local": ["456", "123", "798"], - }, - "unknown", - [], - id="local-config-unknown", - marks=pytest.mark.xfail( - reason="Unknown solver configuration: 'unknown'", - raises=KeyError, - strict=True, + launcher_service.remove_job("some job", RequestParameters(user=DEFAULT_ADMIN_USER)) + repository.delete.assert_called_with("some job") + + @pytest.mark.unit_test + @pytest.mark.parametrize( + "config, solver, expected", + [ + pytest.param( + { + "default": "local", + "local": [], + "slurm": [], + }, + "default", + [], + id="empty-config", ), - ), - pytest.param( - { - "default": "slurm", - "slurm": ["258", "147", "369"], - }, - "default", - ["147", "258", "369"], - id="slurm-config-default", - ), - pytest.param( - { - "default": "slurm", - "slurm": ["258", "147", "369"], - }, - "local", - [], - id="slurm-config-local", - ), - pytest.param( - { - "default": "slurm", - "slurm": ["258", "147", "369"], - }, - "unknown", - [], - id="slurm-config-unknown", - marks=pytest.mark.xfail( - reason="Unknown solver configuration: 'unknown'", - raises=KeyError, - strict=True, + pytest.param( + { + "default": "local", + "local": ["456", "123", "798"], + }, + "default", + ["123", "456", "798"], + id="local-config-default", ), - ), - pytest.param( - { - "default": "slurm", - "local": ["456", "123", "798"], - "slurm": ["258", "147", "369"], - }, - "local", - ["123", "456", "798"], - id="local+slurm-config-local", - ), - ], -) -def test_service_get_solver_versions( - config: Dict[str, Union[str, List[str]]], - solver: Literal["default", "local", "slurm", "unknown"], - expected: List[str], -) -> None: - # Prepare the configuration - default = config.get("default", "local") - local = LocalConfig(binaries={k: Path(f"solver-{k}.exe") for k in config.get("local", [])}) - slurm = SlurmConfig(antares_versions_on_remote_server=config.get("slurm", [])) - launcher_config = LauncherConfig( - default=default, - local=local if local else None, - slurm=slurm if slurm else None, - ) - config = Config(launcher=launcher_config) - launcher_service = LauncherService( - config=config, - study_service=Mock(), - job_result_repository=Mock(), - factory_launcher=Mock(), - event_bus=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), + pytest.param( + { + "default": "local", + "local": ["456", "123", "798"], + }, + "slurm", + [], + id="local-config-slurm", + ), + pytest.param( + { + "default": "local", + "local": ["456", "123", "798"], + }, + "unknown", + [], + id="local-config-unknown", + marks=pytest.mark.xfail( + reason="Unknown solver configuration: 'unknown'", + raises=KeyError, + strict=True, + ), + ), + pytest.param( + { + "default": "slurm", + "slurm": ["258", "147", "369"], + }, + "default", + ["147", "258", "369"], + id="slurm-config-default", + ), + pytest.param( + { + "default": "slurm", + "slurm": ["258", "147", "369"], + }, + "local", + [], + id="slurm-config-local", + ), + pytest.param( + { + "default": "slurm", + "slurm": ["258", "147", "369"], + }, + "unknown", + [], + id="slurm-config-unknown", + marks=pytest.mark.xfail( + reason="Unknown solver configuration: 'unknown'", + raises=KeyError, + strict=True, + ), + ), + pytest.param( + { + "default": "slurm", + "local": ["456", "123", "798"], + "slurm": ["258", "147", "369"], + }, + "local", + ["123", "456", "798"], + id="local+slurm-config-local", + ), + ], ) + def test_service_get_solver_versions( + self, + config: Dict[str, Union[str, List[str]]], + solver: Literal["default", "local", "slurm", "unknown"], + expected: List[str], + ) -> None: + # Prepare the configuration + # the default server version from the configuration file. + # the default server is initialised to local + default = config.get("default", "local") + local = LocalConfig(binaries={k: Path(f"solver-{k}.exe") for k in config.get("local", [])}) + slurm = SlurmConfig(antares_versions_on_remote_server=config.get("slurm", [])) + launcher_config = LauncherConfig( + default=default, + local=local if local else None, + slurm=slurm if slurm else None, + ) + config = Config(launcher=launcher_config) + launcher_service = LauncherService( + config=config, + study_service=Mock(), + job_result_repository=Mock(), + factory_launcher=Mock(), + event_bus=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) - # Fetch the solver versions - actual = launcher_service.get_solver_versions(solver) + # Fetch the solver versions + actual = launcher_service.get_solver_versions(solver) + assert actual == expected - # Check the result - assert actual == expected + @pytest.mark.unit_test + @pytest.mark.parametrize( + "config_map, solver, expected", + [ + pytest.param( + {"default": "local", "local": {}, "slurm": {}}, + "default", + {}, + id="empty-config", + ), + pytest.param( + { + "default": "local", + "local": {"min": 1, "default": 11, "max": 12}, + }, + "default", + {"min": 1, "default": 11, "max": 12}, + id="local-config-default", + ), + pytest.param( + { + "default": "local", + "local": {"min": 1, "default": 11, "max": 12}, + }, + "slurm", + {}, + id="local-config-slurm", + ), + pytest.param( + { + "default": "local", + "local": {"min": 1, "default": 11, "max": 12}, + }, + "unknown", + {}, + id="local-config-unknown", + marks=pytest.mark.xfail( + reason="Configuration is not available for the 'unknown' launcher", + raises=InvalidConfigurationError, + strict=True, + ), + ), + pytest.param( + { + "default": "slurm", + "slurm": {"min": 4, "default": 8, "max": 16}, + }, + "default", + {"min": 4, "default": 8, "max": 16}, + id="slurm-config-default", + ), + pytest.param( + { + "default": "slurm", + "slurm": {"min": 4, "default": 8, "max": 16}, + }, + "local", + {}, + id="slurm-config-local", + ), + pytest.param( + { + "default": "slurm", + "slurm": {"min": 4, "default": 8, "max": 16}, + }, + "unknown", + {}, + id="slurm-config-unknown", + marks=pytest.mark.xfail( + reason="Configuration is not available for the 'unknown' launcher", + raises=InvalidConfigurationError, + strict=True, + ), + ), + pytest.param( + { + "default": "slurm", + "local": {"min": 1, "default": 11, "max": 12}, + "slurm": {"min": 4, "default": 8, "max": 16}, + }, + "local", + {"min": 1, "default": 11, "max": 12}, + id="local+slurm-config-local", + ), + ], + ) + def test_get_nb_cores( + self, + config_map: Dict[str, Union[str, Dict[str, int]]], + solver: Literal["default", "local", "slurm", "unknown"], + expected: Dict[str, int], + ) -> None: + # Prepare the configuration + default = config_map.get("default", "local") + local_nb_cores = config_map.get("local", {}) + slurm_nb_cores = config_map.get("slurm", {}) + launcher_config = LauncherConfig( + default=default, + local=LocalConfig.from_dict({"enable_nb_cores_detection": False, "nb_cores": local_nb_cores}), + slurm=SlurmConfig.from_dict({"enable_nb_cores_detection": False, "nb_cores": slurm_nb_cores}), + ) + launcher_service = LauncherService( + config=Config(launcher=launcher_config), + study_service=Mock(), + job_result_repository=Mock(), + factory_launcher=Mock(), + event_bus=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) + # Fetch the number of cores + actual = launcher_service.get_nb_cores(solver) + + # Check the result + assert actual == NbCoresConfig(**expected) + + @pytest.mark.unit_test + def test_service_kill_job(self, tmp_path: Path) -> None: + study_service = Mock() + study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) + + launcher_service = LauncherService( + config=Config(storage=StorageConfig(tmp_dir=tmp_path)), + study_service=study_service, + job_result_repository=Mock(), + event_bus=Mock(), + factory_launcher=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) + launcher = "slurm" + job_id = "job_id" + job_result_mock = Mock() + job_result_mock.id = job_id + job_result_mock.study_id = "study_id" + job_result_mock.launcher = launcher + launcher_service.job_result_repository.get.return_value = job_result_mock + launcher_service.launchers = {"slurm": Mock()} + + job_status = launcher_service.kill_job( + job_id=job_id, + params=RequestParameters(user=DEFAULT_ADMIN_USER), + ) -@pytest.mark.unit_test -def test_service_kill_job(tmp_path: Path): - study_service = Mock() - study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) + launcher_service.launchers[launcher].kill_job.assert_called_once_with(job_id=job_id) - launcher_service = LauncherService( - config=Config(storage=StorageConfig(tmp_dir=tmp_path)), - study_service=study_service, - job_result_repository=Mock(), - event_bus=Mock(), - factory_launcher=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) - launcher = "slurm" - job_id = "job_id" - job_result_mock = Mock() - job_result_mock.id = job_id - job_result_mock.study_id = "study_id" - job_result_mock.launcher = launcher - launcher_service.job_result_repository.get.return_value = job_result_mock - launcher_service.launchers = {"slurm": Mock()} - - job_status = launcher_service.kill_job( - job_id=job_id, - params=RequestParameters(user=DEFAULT_ADMIN_USER), - ) + assert job_status.job_status == JobStatus.FAILED + launcher_service.job_result_repository.save.assert_called_once_with(job_status) - launcher_service.launchers[launcher].kill_job.assert_called_once_with(job_id=job_id) + def test_append_logs(self, tmp_path: Path) -> None: + study_service = Mock() + study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) - assert job_status.job_status == JobStatus.FAILED - launcher_service.job_result_repository.save.assert_called_once_with(job_status) + launcher_service = LauncherService( + config=Config(storage=StorageConfig(tmp_dir=tmp_path)), + study_service=study_service, + job_result_repository=Mock(), + event_bus=Mock(), + factory_launcher=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) + launcher = "slurm" + job_id = "job_id" + job_result_mock = Mock() + job_result_mock.id = job_id + job_result_mock.study_id = "study_id" + job_result_mock.output_id = None + job_result_mock.launcher = launcher + job_result_mock.logs = [] + launcher_service.job_result_repository.get.return_value = job_result_mock + + engine = create_engine("sqlite:///:memory:", echo=False) + Base.metadata.create_all(engine) + # noinspection SpellCheckingInspection + DBSessionMiddleware( + None, + custom_engine=engine, + session_args={"autocommit": False, "autoflush": False}, + ) + launcher_service.append_log(job_id, "test", JobLogType.BEFORE) + launcher_service.job_result_repository.save.assert_called_with(job_result_mock) + assert job_result_mock.logs[0].message == "test" + assert job_result_mock.logs[0].job_id == "job_id" + assert job_result_mock.logs[0].log_type == str(JobLogType.BEFORE) + + def test_get_logs(self, tmp_path: Path) -> None: + study_service = Mock() + launcher_service = LauncherService( + config=Config(storage=StorageConfig(tmp_dir=tmp_path)), + study_service=study_service, + job_result_repository=Mock(), + event_bus=Mock(), + factory_launcher=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) + launcher = "slurm" + job_id = "job_id" + job_result_mock = Mock() + job_result_mock.id = job_id + job_result_mock.study_id = "study_id" + job_result_mock.output_id = None + job_result_mock.launcher = launcher + job_result_mock.logs = [ + JobLog(message="first message", log_type=str(JobLogType.BEFORE)), + JobLog(message="second message", log_type=str(JobLogType.BEFORE)), + JobLog(message="last message", log_type=str(JobLogType.AFTER)), + ] + job_result_mock.launcher_params = '{"archive_output": false}' + + launcher_service.job_result_repository.get.return_value = job_result_mock + slurm_launcher = Mock() + launcher_service.launchers = {"slurm": slurm_launcher} + slurm_launcher.get_log.return_value = "launcher logs" + + logs = launcher_service.get_log(job_id, LogType.STDOUT, RequestParameters(DEFAULT_ADMIN_USER)) + assert logs == "first message\nsecond message\nlauncher logs\nlast message" + logs = launcher_service.get_log(job_id, LogType.STDERR, RequestParameters(DEFAULT_ADMIN_USER)) + assert logs == "launcher logs" + + study_service.get_logs.side_effect = ["some sim log", "error log"] + + job_result_mock.output_id = "some id" + logs = launcher_service.get_log(job_id, LogType.STDOUT, RequestParameters(DEFAULT_ADMIN_USER)) + assert logs == "first message\nsecond message\nsome sim log\nlast message" + + logs = launcher_service.get_log(job_id, LogType.STDERR, RequestParameters(DEFAULT_ADMIN_USER)) + assert logs == "error log" + + study_service.get_logs.assert_has_calls( + [ + call( + "study_id", + "some id", + job_id, + False, + params=RequestParameters(DEFAULT_ADMIN_USER), + ), + call( + "study_id", + "some id", + job_id, + True, + params=RequestParameters(DEFAULT_ADMIN_USER), + ), + ] + ) + def test_manage_output(self, tmp_path: Path) -> None: + engine = create_engine("sqlite:///:memory:", echo=False) + Base.metadata.create_all(engine) + # noinspection SpellCheckingInspection + DBSessionMiddleware( + None, + custom_engine=engine, + session_args={"autocommit": False, "autoflush": False}, + ) -def test_append_logs(tmp_path: Path): - study_service = Mock() - study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) + study_service = Mock() + study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) + + launcher_service = LauncherService( + config=Mock(storage=StorageConfig(tmp_dir=tmp_path)), + study_service=study_service, + job_result_repository=Mock(), + event_bus=Mock(), + factory_launcher=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) - launcher_service = LauncherService( - config=Config(storage=StorageConfig(tmp_dir=tmp_path)), - study_service=study_service, - job_result_repository=Mock(), - event_bus=Mock(), - factory_launcher=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) - launcher = "slurm" - job_id = "job_id" - job_result_mock = Mock() - job_result_mock.id = job_id - job_result_mock.study_id = "study_id" - job_result_mock.output_id = None - job_result_mock.launcher = launcher - job_result_mock.logs = [] - launcher_service.job_result_repository.get.return_value = job_result_mock - - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - launcher_service.append_log(job_id, "test", JobLogType.BEFORE) - launcher_service.job_result_repository.save.assert_called_with(job_result_mock) - assert job_result_mock.logs[0].message == "test" - assert job_result_mock.logs[0].job_id == "job_id" - assert job_result_mock.logs[0].log_type == str(JobLogType.BEFORE) - - -def test_get_logs(tmp_path: Path): - study_service = Mock() - launcher_service = LauncherService( - config=Config(storage=StorageConfig(tmp_dir=tmp_path)), - study_service=study_service, - job_result_repository=Mock(), - event_bus=Mock(), - factory_launcher=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) - launcher = "slurm" - job_id = "job_id" - job_result_mock = Mock() - job_result_mock.id = job_id - job_result_mock.study_id = "study_id" - job_result_mock.output_id = None - job_result_mock.launcher = launcher - job_result_mock.logs = [ - JobLog(message="first message", log_type=str(JobLogType.BEFORE)), - JobLog(message="second message", log_type=str(JobLogType.BEFORE)), - JobLog(message="last message", log_type=str(JobLogType.AFTER)), - ] - job_result_mock.launcher_params = '{"archive_output": false}' - - launcher_service.job_result_repository.get.return_value = job_result_mock - slurm_launcher = Mock() - launcher_service.launchers = {"slurm": slurm_launcher} - slurm_launcher.get_log.return_value = "launcher logs" - - logs = launcher_service.get_log(job_id, LogType.STDOUT, RequestParameters(DEFAULT_ADMIN_USER)) - assert logs == "first message\nsecond message\nlauncher logs\nlast message" - logs = launcher_service.get_log(job_id, LogType.STDERR, RequestParameters(DEFAULT_ADMIN_USER)) - assert logs == "launcher logs" - - study_service.get_logs.side_effect = ["some sim log", "error log"] - - job_result_mock.output_id = "some id" - logs = launcher_service.get_log(job_id, LogType.STDOUT, RequestParameters(DEFAULT_ADMIN_USER)) - assert logs == "first message\nsecond message\nsome sim log\nlast message" - - logs = launcher_service.get_log(job_id, LogType.STDERR, RequestParameters(DEFAULT_ADMIN_USER)) - assert logs == "error log" - - study_service.get_logs.assert_has_calls( - [ - call( - "study_id", - "some id", - job_id, - False, - params=RequestParameters(DEFAULT_ADMIN_USER), + output_path = tmp_path / "output" + zipped_output_path = tmp_path / "zipped_output" + os.mkdir(output_path) + os.mkdir(zipped_output_path) + new_output_path = output_path / "new_output" + os.mkdir(new_output_path) + (new_output_path / "log").touch() + (new_output_path / "data").touch() + additional_log = tmp_path / "output.log" + additional_log.write_text("some log") + new_output_zipped_path = zipped_output_path / "test.zip" + with ZipFile(new_output_zipped_path, "w", ZIP_DEFLATED) as output_data: + output_data.writestr("some output", "0\n1") + job_id = "job_id" + zipped_job_id = "zipped_job_id" + study_id = "study_id" + launcher_service.job_result_repository.get.side_effect = [ + None, + JobResult(id=job_id, study_id=study_id), + JobResult(id=job_id, study_id=study_id, output_id="some id"), + JobResult(id=zipped_job_id, study_id=study_id), + JobResult( + id=job_id, + study_id=study_id, ), - call( - "study_id", - "some id", - job_id, - True, - params=RequestParameters(DEFAULT_ADMIN_USER), + JobResult( + id=job_id, + study_id=study_id, + launcher_params=json.dumps( + { + "archive_output": False, + f"{LAUNCHER_PARAM_NAME_SUFFIX}": "hello", + } + ), ), ] - ) + with pytest.raises(JobNotFound): + launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]}) - -def test_manage_output(tmp_path: Path): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - - study_service = Mock() - study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) - - launcher_service = LauncherService( - config=Mock(storage=StorageConfig(tmp_dir=tmp_path)), - study_service=study_service, - job_result_repository=Mock(), - event_bus=Mock(), - factory_launcher=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) - - output_path = tmp_path / "output" - zipped_output_path = tmp_path / "zipped_output" - os.mkdir(output_path) - os.mkdir(zipped_output_path) - new_output_path = output_path / "new_output" - os.mkdir(new_output_path) - (new_output_path / "log").touch() - (new_output_path / "data").touch() - additional_log = tmp_path / "output.log" - additional_log.write_text("some log") - new_output_zipped_path = zipped_output_path / "test.zip" - with ZipFile(new_output_zipped_path, "w", ZIP_DEFLATED) as output_data: - output_data.writestr("some output", "0\n1") - job_id = "job_id" - zipped_job_id = "zipped_job_id" - study_id = "study_id" - launcher_service.job_result_repository.get.side_effect = [ - None, - JobResult(id=job_id, study_id=study_id), - JobResult(id=job_id, study_id=study_id, output_id="some id"), - JobResult(id=zipped_job_id, study_id=study_id), - JobResult( - id=job_id, - study_id=study_id, - ), - JobResult( - id=job_id, - study_id=study_id, - launcher_params=json.dumps( - { - "archive_output": False, - f"{LAUNCHER_PARAM_NAME_SUFFIX}": "hello", - } - ), - ), - ] - with pytest.raises(JobNotFound): launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]}) + assert not launcher_service._get_job_output_fallback_path(job_id).exists() + launcher_service.study_service.import_output.assert_called() - launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]}) - assert not launcher_service._get_job_output_fallback_path(job_id).exists() - launcher_service.study_service.import_output.assert_called() - - launcher_service.download_output("job_id", RequestParameters(DEFAULT_ADMIN_USER)) - launcher_service.study_service.export_output.assert_called() - - launcher_service._import_output( - zipped_job_id, - zipped_output_path, - { - "out.log": [additional_log], - "antares-out": [additional_log], - "antares-err": [additional_log], - }, - ) - launcher_service.study_service.save_logs.has_calls( - [ - call(study_id, zipped_job_id, "out.log", "some log"), - call(study_id, zipped_job_id, "out", "some log"), - call(study_id, zipped_job_id, "err", "some log"), - ] - ) - - launcher_service.study_service.import_output.side_effect = [ - StudyNotFoundError(""), - StudyNotFoundError(""), - ] - - assert launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]}) is None - - (new_output_path / "info.antares-output").write_text(f"[general]\nmode=eco\nname=foo\ntimestamp={time.time()}") - output_name = launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]}) - assert output_name is not None - assert output_name.endswith("-hello") - assert launcher_service._get_job_output_fallback_path(job_id).exists() - assert (launcher_service._get_job_output_fallback_path(job_id) / output_name / "out.log").exists() - - launcher_service.job_result_repository.get.reset_mock() - launcher_service.job_result_repository.get.side_effect = [ - None, - JobResult(id=job_id, study_id=study_id, output_id=output_name), - ] - with pytest.raises(JobNotFound): launcher_service.download_output("job_id", RequestParameters(DEFAULT_ADMIN_USER)) + launcher_service.study_service.export_output.assert_called() - study_service.get_study.reset_mock() - study_service.get_study.side_effect = StudyNotFoundError("") + launcher_service._import_output( + zipped_job_id, + zipped_output_path, + { + "out.log": [additional_log], + "antares-out": [additional_log], + "antares-err": [additional_log], + }, + ) + launcher_service.study_service.save_logs.has_calls( + [ + call(study_id, zipped_job_id, "out.log", "some log"), + call(study_id, zipped_job_id, "out", "some log"), + call(study_id, zipped_job_id, "err", "some log"), + ] + ) - export_file = FileDownloadDTO(id="a", name="a", filename="a", ready=True) - launcher_service.file_transfer_manager.request_download.return_value = FileDownload( - id="a", name="a", filename="a", ready=True, path="a" - ) - launcher_service.task_service.add_task.return_value = "some id" + launcher_service.study_service.import_output.side_effect = [ + StudyNotFoundError(""), + StudyNotFoundError(""), + ] - assert launcher_service.download_output("job_id", RequestParameters(DEFAULT_ADMIN_USER)) == FileDownloadTaskDTO( - task="some id", file=export_file - ) + assert launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]}) is None - launcher_service.remove_job(job_id, RequestParameters(user=DEFAULT_ADMIN_USER)) - assert not launcher_service._get_job_output_fallback_path(job_id).exists() + (new_output_path / "info.antares-output").write_text(f"[general]\nmode=eco\nname=foo\ntimestamp={time.time()}") + output_name = launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]}) + assert output_name is not None + assert output_name.endswith("-hello") + assert launcher_service._get_job_output_fallback_path(job_id).exists() + assert (launcher_service._get_job_output_fallback_path(job_id) / output_name / "out.log").exists() + launcher_service.job_result_repository.get.reset_mock() + launcher_service.job_result_repository.get.side_effect = [ + None, + JobResult(id=job_id, study_id=study_id, output_id=output_name), + ] + with pytest.raises(JobNotFound): + launcher_service.download_output("job_id", RequestParameters(DEFAULT_ADMIN_USER)) -def test_save_stats(tmp_path: Path) -> None: - study_service = Mock() - study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) + study_service.get_study.reset_mock() + study_service.get_study.side_effect = StudyNotFoundError("") - launcher_service = LauncherService( - config=Mock(storage=StorageConfig(tmp_dir=tmp_path)), - study_service=study_service, - job_result_repository=Mock(), - event_bus=Mock(), - factory_launcher=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) + export_file = FileDownloadDTO(id="a", name="a", filename="a", ready=True) + launcher_service.file_transfer_manager.request_download.return_value = FileDownload( + id="a", name="a", filename="a", ready=True, path="a" + ) + launcher_service.task_service.add_task.return_value = "some id" - job_id = "job_id" - study_id = "study_id" - job_result = JobResult(id=job_id, study_id=study_id, job_status=JobStatus.SUCCESS) - - output_path = tmp_path / "some-output" - output_path.mkdir() - - launcher_service._save_solver_stats(job_result, output_path) - launcher_service.job_result_repository.save.assert_not_called() - - expected_saved_stats = """#item duration_ms NbOccurences -mc_years 216328 1 -study_loading 4304 1 -survey_report 158 1 -total 244581 1 -tsgen_hydro 1683 1 -tsgen_load 2702 1 -tsgen_solar 21606 1 -tsgen_thermal 407 2 -tsgen_wind 2500 1 - """ - (output_path / EXECUTION_INFO_FILE).write_text(expected_saved_stats) - - launcher_service._save_solver_stats(job_result, output_path) - launcher_service.job_result_repository.save.assert_called_with( - JobResult( - id=job_id, - study_id=study_id, - job_status=JobStatus.SUCCESS, - solver_stats=expected_saved_stats, + assert launcher_service.download_output("job_id", RequestParameters(DEFAULT_ADMIN_USER)) == FileDownloadTaskDTO( + task="some id", file=export_file ) - ) - zip_file = tmp_path / "test.zip" - with ZipFile(zip_file, "w", ZIP_DEFLATED) as output_data: - output_data.writestr(EXECUTION_INFO_FILE, "0\n1") + launcher_service.remove_job(job_id, RequestParameters(user=DEFAULT_ADMIN_USER)) + assert not launcher_service._get_job_output_fallback_path(job_id).exists() + + def test_save_solver_stats(self, tmp_path: Path) -> None: + study_service = Mock() + study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) + + launcher_service = LauncherService( + config=Mock(storage=StorageConfig(tmp_dir=tmp_path)), + study_service=study_service, + job_result_repository=Mock(), + event_bus=Mock(), + factory_launcher=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) - launcher_service._save_solver_stats(job_result, zip_file) - launcher_service.job_result_repository.save.assert_called_with( - JobResult( - id=job_id, - study_id=study_id, - job_status=JobStatus.SUCCESS, - solver_stats="0\n1", + job_id = "job_id" + study_id = "study_id" + job_result = JobResult(id=job_id, study_id=study_id, job_status=JobStatus.SUCCESS) + + output_path = tmp_path / "some-output" + output_path.mkdir() + + launcher_service._save_solver_stats(job_result, output_path) + launcher_service.job_result_repository.save.assert_not_called() + + expected_saved_stats = """#item duration_ms NbOccurences + mc_years 216328 1 + study_loading 4304 1 + survey_report 158 1 + total 244581 1 + tsgen_hydro 1683 1 + tsgen_load 2702 1 + tsgen_solar 21606 1 + tsgen_thermal 407 2 + tsgen_wind 2500 1 + """ + (output_path / EXECUTION_INFO_FILE).write_text(expected_saved_stats) + + launcher_service._save_solver_stats(job_result, output_path) + launcher_service.job_result_repository.save.assert_called_with( + JobResult( + id=job_id, + study_id=study_id, + job_status=JobStatus.SUCCESS, + solver_stats=expected_saved_stats, + ) ) - ) + zip_file = tmp_path / "test.zip" + with ZipFile(zip_file, "w", ZIP_DEFLATED) as output_data: + output_data.writestr(EXECUTION_INFO_FILE, "0\n1") + + launcher_service._save_solver_stats(job_result, zip_file) + launcher_service.job_result_repository.save.assert_called_with( + JobResult( + id=job_id, + study_id=study_id, + job_status=JobStatus.SUCCESS, + solver_stats="0\n1", + ) + ) -def test_get_load(tmp_path: Path): - study_service = Mock() - job_repository = Mock() + def test_get_load(self, tmp_path: Path) -> None: + study_service = Mock() + job_repository = Mock() - launcher_service = LauncherService( - config=Mock( + config = Config( storage=StorageConfig(tmp_dir=tmp_path), - launcher=LauncherConfig(local=LocalConfig(), slurm=SlurmConfig(default_n_cpu=12)), - ), - study_service=study_service, - job_result_repository=job_repository, - event_bus=Mock(), - factory_launcher=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) - - job_repository.get_running.side_effect = [ - [], - [], - [ - Mock( - spec=JobResult, - launcher="slurm", - launcher_params=None, - ), - ], - [ - Mock( - spec=JobResult, - launcher="slurm", - launcher_params='{"nb_cpu": 18}', - ), - Mock( - spec=JobResult, - launcher="local", - launcher_params=None, - ), - Mock( - spec=JobResult, - launcher="slurm", - launcher_params=None, + launcher=LauncherConfig( + local=LocalConfig(), + slurm=SlurmConfig(nb_cores=NbCoresConfig(min=1, default=12, max=24)), ), - Mock( - spec=JobResult, - launcher="local", - launcher_params='{"nb_cpu": 7}', - ), - ], - ] - - with pytest.raises(NotImplementedError): - launcher_service.get_load(from_cluster=True) - - load = launcher_service.get_load() - assert load["slurm"] == 0 - assert load["local"] == 0 - load = launcher_service.get_load() - assert load["slurm"] == 12.0 / 64 - assert load["local"] == 0 - load = launcher_service.get_load() - assert load["slurm"] == 30.0 / 64 - assert load["local"] == 8.0 / os.cpu_count() + ) + launcher_service = LauncherService( + config=config, + study_service=study_service, + job_result_repository=job_repository, + event_bus=Mock(), + factory_launcher=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) + + job_repository.get_running.side_effect = [ + # call #1 + [], + # call #2 + [], + # call #3 + [ + Mock( + spec=JobResult, + launcher="slurm", + launcher_params=None, + ), + ], + # call #4 + [ + Mock( + spec=JobResult, + launcher="slurm", + launcher_params='{"nb_cpu": 18}', + ), + Mock( + spec=JobResult, + launcher="local", + launcher_params=None, + ), + Mock( + spec=JobResult, + launcher="slurm", + launcher_params=None, + ), + Mock( + spec=JobResult, + launcher="local", + launcher_params='{"nb_cpu": 7}', + ), + ], + ] + + # call #1 + with pytest.raises(NotImplementedError): + launcher_service.get_load(from_cluster=True) + + # call #2 + load = launcher_service.get_load() + assert load["slurm"] == 0 + assert load["local"] == 0 + + # call #3 + load = launcher_service.get_load() + slurm_config = config.launcher.slurm + assert load["slurm"] == slurm_config.nb_cores.default / slurm_config.max_cores + assert load["local"] == 0 + + # call #4 + load = launcher_service.get_load() + local_config = config.launcher.local + assert load["slurm"] == (18 + slurm_config.nb_cores.default) / slurm_config.max_cores + assert load["local"] == (7 + local_config.nb_cores.default) / local_config.nb_cores.max diff --git a/tests/launcher/test_slurm_launcher.py b/tests/launcher/test_slurm_launcher.py index 7820abcdea..dfb6846e89 100644 --- a/tests/launcher/test_slurm_launcher.py +++ b/tests/launcher/test_slurm_launcher.py @@ -10,11 +10,9 @@ from antareslauncher.data_repo.data_repo_tinydb import DataRepoTinydb from antareslauncher.main import MainParameters from antareslauncher.study_dto import StudyDTO -from sqlalchemy import create_engine +from sqlalchemy.orm import Session # type: ignore -from antarest.core.config import Config, LauncherConfig, SlurmConfig -from antarest.core.persistence import Base -from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware +from antarest.core.config import Config, LauncherConfig, NbCoresConfig, SlurmConfig from antarest.launcher.adapters.abstractlauncher import LauncherInitException from antarest.launcher.adapters.slurm_launcher.slurm_launcher import ( LOG_DIR_NAME, @@ -24,32 +22,34 @@ SlurmLauncher, VersionNotSupportedError, ) -from antarest.launcher.model import JobStatus, LauncherParametersDTO +from antarest.launcher.model import JobStatus, LauncherParametersDTO, XpansionParametersDTO from antarest.tools.admin_lib import clean_locks_from_config @pytest.fixture def launcher_config(tmp_path: Path) -> Config: - return Config( - launcher=LauncherConfig( - slurm=SlurmConfig( - local_workspace=tmp_path, - default_json_db_name="default_json_db_name", - slurm_script_path="slurm_script_path", - antares_versions_on_remote_server=["42", "45"], - username="username", - hostname="hostname", - port=42, - private_key_file=Path("private_key_file"), - key_password="key_password", - password="password", - ) - ) - ) + data = { + "local_workspace": tmp_path, + "username": "john", + "hostname": "slurm-001", + "port": 22, + "private_key_file": Path("/home/john/.ssh/id_rsa"), + "key_password": "password", + "password": "password", + "default_wait_time": 10, + "default_time_limit": 20, + "default_json_db_name": "antares.db", + "slurm_script_path": "/path/to/slurm/launcher.sh", + "max_cores": 32, + "antares_versions_on_remote_server": ["840", "850", "860"], + "enable_nb_cores_detection": False, + "nb_cores": {"min": 1, "default": 34, "max": 36}, + } + return Config(launcher=LauncherConfig(slurm=SlurmConfig.from_dict(data))) @pytest.mark.unit_test -def test_slurm_launcher__launcher_init_exception(): +def test_slurm_launcher__launcher_init_exception() -> None: with pytest.raises( LauncherInitException, match="Missing parameter 'launcher.slurm'", @@ -63,13 +63,13 @@ def test_slurm_launcher__launcher_init_exception(): @pytest.mark.unit_test -def test_init_slurm_launcher_arguments(tmp_path: Path): +def test_init_slurm_launcher_arguments(tmp_path: Path) -> None: config = Config( launcher=LauncherConfig( slurm=SlurmConfig( default_wait_time=42, default_time_limit=43, - default_n_cpu=44, + nb_cores=NbCoresConfig(min=1, default=30, max=36), local_workspace=tmp_path, ) ) @@ -88,13 +88,15 @@ def test_init_slurm_launcher_arguments(tmp_path: Path): assert not arguments.xpansion_mode assert not arguments.version assert not arguments.post_processing - assert Path(arguments.studies_in) == config.launcher.slurm.local_workspace / "STUDIES_IN" - assert Path(arguments.output_dir) == config.launcher.slurm.local_workspace / "OUTPUT" - assert Path(arguments.log_dir) == config.launcher.slurm.local_workspace / "LOGS" + slurm_config = config.launcher.slurm + assert slurm_config is not None + assert Path(arguments.studies_in) == slurm_config.local_workspace / "STUDIES_IN" + assert Path(arguments.output_dir) == slurm_config.local_workspace / "OUTPUT" + assert Path(arguments.log_dir) == slurm_config.local_workspace / "LOGS" @pytest.mark.unit_test -def test_init_slurm_launcher_parameters(tmp_path: Path): +def test_init_slurm_launcher_parameters(tmp_path: Path) -> None: config = Config( launcher=LauncherConfig( slurm=SlurmConfig( @@ -115,23 +117,25 @@ def test_init_slurm_launcher_parameters(tmp_path: Path): slurm_launcher = SlurmLauncher(config=config, callbacks=Mock(), event_bus=Mock(), cache=Mock()) main_parameters = slurm_launcher._init_launcher_parameters() - assert main_parameters.json_dir == config.launcher.slurm.local_workspace - assert main_parameters.default_json_db_name == config.launcher.slurm.default_json_db_name - assert main_parameters.slurm_script_path == config.launcher.slurm.slurm_script_path - assert main_parameters.antares_versions_on_remote_server == config.launcher.slurm.antares_versions_on_remote_server + slurm_config = config.launcher.slurm + assert slurm_config is not None + assert main_parameters.json_dir == slurm_config.local_workspace + assert main_parameters.default_json_db_name == slurm_config.default_json_db_name + assert main_parameters.slurm_script_path == slurm_config.slurm_script_path + assert main_parameters.antares_versions_on_remote_server == slurm_config.antares_versions_on_remote_server assert main_parameters.default_ssh_dict == { - "username": config.launcher.slurm.username, - "hostname": config.launcher.slurm.hostname, - "port": config.launcher.slurm.port, - "private_key_file": config.launcher.slurm.private_key_file, - "key_password": config.launcher.slurm.key_password, - "password": config.launcher.slurm.password, + "username": slurm_config.username, + "hostname": slurm_config.hostname, + "port": slurm_config.port, + "private_key_file": slurm_config.private_key_file, + "key_password": slurm_config.key_password, + "password": slurm_config.password, } assert main_parameters.db_primary_key == "name" @pytest.mark.unit_test -def test_slurm_launcher_delete_function(tmp_path: str): +def test_slurm_launcher_delete_function(tmp_path: str) -> None: config = Config(launcher=LauncherConfig(slurm=SlurmConfig(local_workspace=Path(tmp_path)))) slurm_launcher = SlurmLauncher( config=config, @@ -155,64 +159,104 @@ def test_slurm_launcher_delete_function(tmp_path: str): assert not file_path.exists() -def test_extra_parameters(launcher_config: Config): +def test_extra_parameters(launcher_config: Config) -> None: + """ + The goal of this unit test is to control the protected method `_check_and_apply_launcher_params`, + which is called by the `SlurmLauncher.run_study` function, in a separate thread. + + The `_check_and_apply_launcher_params` method extract the parameters from the configuration + and populate a `argparse.Namespace` which is used to launch a simulation using Antares Launcher. + + We want to make sure all the parameters are populated correctly. + """ slurm_launcher = SlurmLauncher( config=launcher_config, callbacks=Mock(), event_bus=Mock(), cache=Mock(), ) - launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO()) - assert launcher_params.n_cpu == 1 - assert launcher_params.time_limit == 0 + + apply_params = slurm_launcher._apply_params + launcher_params = apply_params(LauncherParametersDTO()) + slurm_config = slurm_launcher.config.launcher.slurm + assert slurm_config is not None + assert launcher_params.n_cpu == slurm_config.nb_cores.default + assert launcher_params.time_limit == slurm_config.default_time_limit assert not launcher_params.xpansion_mode assert not launcher_params.post_processing - launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(nb_cpu=12)) + launcher_params = apply_params(LauncherParametersDTO(other_options="")) + assert launcher_params.other_options == "" + + launcher_params = apply_params(LauncherParametersDTO(other_options="foo\tbar baz ")) + assert launcher_params.other_options == "foo bar baz" + + launcher_params = apply_params(LauncherParametersDTO(other_options="/foo?bar")) + assert launcher_params.other_options == "foobar" + + launcher_params = apply_params(LauncherParametersDTO(nb_cpu=12)) assert launcher_params.n_cpu == 12 - launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(nb_cpu=48)) - assert launcher_params.n_cpu == 1 + launcher_params = apply_params(LauncherParametersDTO(nb_cpu=999)) + assert launcher_params.n_cpu == slurm_config.nb_cores.default # out of range - launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(time_limit=10)) + launcher_params = apply_params(LauncherParametersDTO(time_limit=10)) assert launcher_params.time_limit == MIN_TIME_LIMIT - launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(time_limit=999999999)) + launcher_params = apply_params(LauncherParametersDTO(time_limit=999999999)) assert launcher_params.time_limit == MAX_TIME_LIMIT - 3600 - launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(time_limit=99999)) + launcher_params = apply_params(LauncherParametersDTO(time_limit=99999)) assert launcher_params.time_limit == 99999 - launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(xpansion=True)) - assert launcher_params.xpansion_mode + launcher_params = apply_params(LauncherParametersDTO(xpansion=False)) + assert launcher_params.xpansion_mode is None + assert launcher_params.other_options == "" + + launcher_params = apply_params(LauncherParametersDTO(xpansion=True)) + assert launcher_params.xpansion_mode == "cpp" + assert launcher_params.other_options == "" + + launcher_params = apply_params(LauncherParametersDTO(xpansion=True, xpansion_r_version=True)) + assert launcher_params.xpansion_mode == "r" + assert launcher_params.other_options == "" + + launcher_params = apply_params(LauncherParametersDTO(xpansion=XpansionParametersDTO(sensitivity_mode=False))) + assert launcher_params.xpansion_mode == "cpp" + assert launcher_params.other_options == "" + + launcher_params = apply_params(LauncherParametersDTO(xpansion=XpansionParametersDTO(sensitivity_mode=True))) + assert launcher_params.xpansion_mode == "cpp" + assert launcher_params.other_options == "xpansion_sensitivity" + + launcher_params = apply_params(LauncherParametersDTO(post_processing=False)) + assert launcher_params.post_processing is False - launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(post_processing=True)) - assert launcher_params.post_processing + launcher_params = apply_params(LauncherParametersDTO(post_processing=True)) + assert launcher_params.post_processing is True - launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(adequacy_patch={})) - assert launcher_params.post_processing + launcher_params = apply_params(LauncherParametersDTO(adequacy_patch={})) + assert launcher_params.post_processing is True # noinspection PyUnresolvedReferences @pytest.mark.parametrize( - "version, job_status", - [(42, JobStatus.RUNNING), (99, JobStatus.FAILED), (45, JobStatus.FAILED)], + "version, launcher_called, job_status", + [ + (840, True, JobStatus.RUNNING), + (860, False, JobStatus.FAILED), + pytest.param( + 999, False, JobStatus.FAILED, marks=pytest.mark.xfail(raises=VersionNotSupportedError, strict=True) + ), + ], ) @pytest.mark.unit_test def test_run_study( - tmp_path: Path, launcher_config: Config, version: int, + launcher_called: bool, job_status: JobStatus, -): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) +) -> None: slurm_launcher = SlurmLauncher( config=launcher_config, callbacks=Mock(), @@ -231,7 +275,8 @@ def test_run_study( job_id = str(uuid.uuid4()) study_dir = argument.studies_in / job_id study_dir.mkdir(parents=True) - (study_dir / "study.antares").write_text( + study_antares_path = study_dir.joinpath("study.antares") + study_antares_path.write_text( textwrap.dedent( """\ [antares] @@ -242,22 +287,20 @@ def test_run_study( # noinspection PyUnusedLocal def call_launcher_mock(arguments: Namespace, parameters: MainParameters): - if version != 45: + if launcher_called: slurm_launcher.data_repo_tinydb.save_study(StudyDTO(job_id)) slurm_launcher._call_launcher = call_launcher_mock - if version == 99: - with pytest.raises(VersionNotSupportedError): - slurm_launcher._run_study(study_uuid, job_id, LauncherParametersDTO(), str(version)) - else: - slurm_launcher._run_study(study_uuid, job_id, LauncherParametersDTO(), str(version)) + # When the launcher is called + slurm_launcher._run_study(study_uuid, job_id, LauncherParametersDTO(), str(version)) + # Check the results assert ( version not in launcher_config.launcher.slurm.antares_versions_on_remote_server - or f"solver_version = {version}" in (study_dir / "study.antares").read_text(encoding="utf-8") + or f"solver_version = {version}" in study_antares_path.read_text(encoding="utf-8") ) - # slurm_launcher._clean_local_workspace.assert_called_once() + slurm_launcher.callbacks.export_study.assert_called_once() slurm_launcher.callbacks.update_status.assert_called_once_with(ANY, job_status, ANY, None) if job_status == JobStatus.RUNNING: @@ -266,7 +309,7 @@ def call_launcher_mock(arguments: Namespace, parameters: MainParameters): @pytest.mark.unit_test -def test_check_state(tmp_path: Path, launcher_config: Config): +def test_check_state(tmp_path: Path, launcher_config: Config) -> None: slurm_launcher = SlurmLauncher( config=launcher_config, callbacks=Mock(), @@ -308,16 +351,7 @@ def test_check_state(tmp_path: Path, launcher_config: Config): @pytest.mark.unit_test -def test_clean_local_workspace(tmp_path: Path, launcher_config: Config): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - +def test_clean_local_workspace(tmp_path: Path, launcher_config: Config) -> None: slurm_launcher = SlurmLauncher( config=launcher_config, callbacks=Mock(), @@ -325,7 +359,6 @@ def test_clean_local_workspace(tmp_path: Path, launcher_config: Config): use_private_workspace=False, cache=Mock(), ) - (launcher_config.launcher.slurm.local_workspace / "machin.txt").touch() assert os.listdir(launcher_config.launcher.slurm.local_workspace) @@ -335,7 +368,7 @@ def test_clean_local_workspace(tmp_path: Path, launcher_config: Config): # noinspection PyUnresolvedReferences @pytest.mark.unit_test -def test_import_study_output(launcher_config, tmp_path): +def test_import_study_output(launcher_config, tmp_path) -> None: slurm_launcher = SlurmLauncher( config=launcher_config, callbacks=Mock(), @@ -399,7 +432,7 @@ def test_kill_job( run_with_mock, tmp_path: Path, launcher_config: Config, -): +) -> None: launch_id = "launch_id" mock_study = Mock() mock_study.name = launch_id @@ -419,35 +452,36 @@ def test_kill_job( slurm_launcher.kill_job(job_id=launch_id) + slurm_config = launcher_config.launcher.slurm launcher_arguments = Namespace( antares_version=0, check_queue=False, - job_id_to_kill=42, + job_id_to_kill=mock_study.job_id, json_ssh_config=None, log_dir=str(tmp_path / "LOGS"), - n_cpu=1, + n_cpu=slurm_config.nb_cores.default, output_dir=str(tmp_path / "OUTPUT"), post_processing=False, studies_in=str(tmp_path / "STUDIES_IN"), - time_limit=0, + time_limit=slurm_config.default_time_limit, version=False, wait_mode=False, - wait_time=0, + wait_time=slurm_config.default_wait_time, xpansion_mode=None, other_options=None, ) launcher_parameters = MainParameters( json_dir=Path(tmp_path), - default_json_db_name="default_json_db_name", - slurm_script_path="slurm_script_path", - antares_versions_on_remote_server=["42", "45"], + default_json_db_name=slurm_config.default_json_db_name, + slurm_script_path=slurm_config.slurm_script_path, + antares_versions_on_remote_server=slurm_config.antares_versions_on_remote_server, default_ssh_dict={ - "username": "username", - "hostname": "hostname", - "port": 42, - "private_key_file": Path("private_key_file"), - "key_password": "key_password", - "password": "password", + "username": slurm_config.username, + "hostname": slurm_config.hostname, + "port": slurm_config.port, + "private_key_file": slurm_config.private_key_file, + "key_password": slurm_config.key_password, + "password": slurm_config.password, }, db_primary_key="name", ) @@ -456,7 +490,7 @@ def test_kill_job( @patch("antarest.launcher.adapters.slurm_launcher.slurm_launcher.run_with") -def test_launcher_workspace_init(run_with_mock, tmp_path: Path, launcher_config: Config): +def test_launcher_workspace_init(run_with_mock, tmp_path: Path, launcher_config: Config) -> None: callbacks = Mock() (tmp_path / LOG_DIR_NAME).mkdir() @@ -474,11 +508,7 @@ def test_launcher_workspace_init(run_with_mock, tmp_path: Path, launcher_config: clean_locks_from_config(launcher_config) assert not (workspaces[0] / WORKSPACE_LOCK_FILE_NAME).exists() - slurm_launcher.data_repo_tinydb.save_study( - StudyDTO( - path="somepath", - ) - ) + slurm_launcher.data_repo_tinydb.save_study(StudyDTO(path="some_path")) run_with_mock.assert_not_called() # will use existing private workspace diff --git a/tests/storage/repository/filesystem/config/test_config_files.py b/tests/storage/repository/filesystem/config/test_config_files.py index 9b029fcea7..e0e0ccb45f 100644 --- a/tests/storage/repository/filesystem/config/test_config_files.py +++ b/tests/storage/repository/filesystem/config/test_config_files.py @@ -4,6 +4,7 @@ import pytest +from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import BindingConstraintFrequency from antarest.study.storage.rawstudy.model.filesystem.config.files import ( _parse_links, _parse_outputs, @@ -73,6 +74,7 @@ def test_parse_bindings(tmp_path: Path) -> None: [bindB] id = bindB + type = weekly """ (study_path / "input/bindingconstraints/bindingconstraints.ini").write_text(content) @@ -81,8 +83,18 @@ def test_parse_bindings(tmp_path: Path) -> None: path=study_path, version=-1, bindings=[ - BindingConstraintDTO(id="bindA", areas=[], clusters=[]), - BindingConstraintDTO(id="bindB", areas=[], clusters=[]), + BindingConstraintDTO( + id="bindA", + areas=set(), + clusters=set(), + time_step=BindingConstraintFrequency.HOURLY, + ), + BindingConstraintDTO( + id="bindB", + areas=set(), + clusters=set(), + time_step=BindingConstraintFrequency.WEEKLY, + ), ], study_id="id", output_path=study_path / "output", diff --git a/tests/storage/test_model.py b/tests/storage/test_model.py index 0b0ed0db87..4986073713 100644 --- a/tests/storage/test_model.py +++ b/tests/storage/test_model.py @@ -1,5 +1,6 @@ from pathlib import Path +from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import BindingConstraintFrequency from antarest.study.storage.rawstudy.model.filesystem.config.model import ( Area, BindingConstraintDTO, @@ -41,7 +42,14 @@ def test_file_study_tree_config_dto(): xpansion="", ) }, - bindings=[BindingConstraintDTO(id="b1", areas=[], clusters=[])], + bindings=[ + BindingConstraintDTO( + id="b1", + areas=set(), + clusters=set(), + time_step=BindingConstraintFrequency.DAILY, + ) + ], store_new_set=False, archive_input_series=["?"], enr_modelling="aggregated", diff --git a/tests/study/business/conftest.py b/tests/study/business/conftest.py deleted file mode 100644 index 2638d47b3d..0000000000 --- a/tests/study/business/conftest.py +++ /dev/null @@ -1,22 +0,0 @@ -import contextlib - -import pytest -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker - -from antarest.dbmodel import Base - - -@pytest.fixture(scope="function", name="db_engine") -def db_engine_fixture(): - engine = create_engine("sqlite:///:memory:") - Base.metadata.create_all(engine) - yield engine - engine.dispose() - - -@pytest.fixture(scope="function", name="db_session") -def db_session_fixture(db_engine): - make_session = sessionmaker(bind=db_engine) - with contextlib.closing(make_session()) as session: - yield session diff --git a/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py b/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py index e9689834c5..93a3262259 100644 --- a/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py +++ b/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py @@ -36,3 +36,22 @@ def test_get_st_storage(self, tmp_path): matrix_id5 = ref5.split(MATRIX_PROTOCOL_PREFIX)[1] matrix_dto5 = generator.matrix_service.get(matrix_id5) assert np.array(matrix_dto5.data).all() == matrix_constants.st_storage.series.inflows.all() + + def test_get_binding_constraint(self, tmp_path): + generator = GeneratorMatrixConstants(matrix_service=SimpleMatrixService(bucket_dir=tmp_path)) + series = matrix_constants.binding_constraint.series + + hourly = generator.get_binding_constraint_hourly() + hourly_matrix_id = hourly.split(MATRIX_PROTOCOL_PREFIX)[1] + hourly_matrix_dto = generator.matrix_service.get(hourly_matrix_id) + assert np.array(hourly_matrix_dto.data).all() == series.default_binding_constraint_hourly.all() + + daily = generator.get_binding_constraint_daily() + daily_matrix_id = daily.split(MATRIX_PROTOCOL_PREFIX)[1] + daily_matrix_dto = generator.matrix_service.get(daily_matrix_id) + assert np.array(daily_matrix_dto.data).all() == series.default_binding_constraint_daily.all() + + weekly = generator.get_binding_constraint_weekly() + weekly_matrix_id = weekly.split(MATRIX_PROTOCOL_PREFIX)[1] + weekly_matrix_dto = generator.matrix_service.get(weekly_matrix_id) + assert np.array(weekly_matrix_dto.data).all() == series.default_binding_constraint_weekly.all() diff --git a/tests/test_resources.py b/tests/test_resources.py index 2a0bf94677..330116e507 100644 --- a/tests/test_resources.py +++ b/tests/test_resources.py @@ -4,6 +4,8 @@ import pytest +from antarest.core.config import Config + HERE = pathlib.Path(__file__).parent.resolve() PROJECT_DIR = next(iter(p for p in HERE.parents if p.joinpath("antarest").exists())) RESOURCES_DIR = PROJECT_DIR.joinpath("resources") @@ -84,3 +86,17 @@ def test_empty_study_zip(filename: str, expected_list: Sequence[str]): with zipfile.ZipFile(resource_path) as myzip: actual = sorted(myzip.namelist()) assert actual == expected_list + + +def test_resources_config(): + """ + Check that the "resources/config.yaml" file is valid. + + The launcher section must be configured to use a local launcher + with NB Cores detection enabled. + """ + config_path = RESOURCES_DIR.joinpath("deploy/config.yaml") + config = Config.from_yaml_file(config_path, res=RESOURCES_DIR) + assert config.launcher.default == "local" + assert config.launcher.local is not None + assert config.launcher.local.enable_nb_cores_detection is True diff --git a/tests/variantstudy/model/command/test_manage_binding_constraints.py b/tests/variantstudy/model/command/test_manage_binding_constraints.py index 1596ce6476..d22e05ce1e 100644 --- a/tests/variantstudy/model/command/test_manage_binding_constraints.py +++ b/tests/variantstudy/model/command/test_manage_binding_constraints.py @@ -1,10 +1,16 @@ from unittest.mock import Mock from antarest.study.storage.rawstudy.io.reader import IniReader +from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import BindingConstraintFrequency from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy from antarest.study.storage.variantstudy.business.command_extractor import CommandExtractor from antarest.study.storage.variantstudy.business.command_reverter import CommandReverter -from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator, TimeStep +from antarest.study.storage.variantstudy.business.matrix_constants.binding_constraint.series import ( + default_binding_constraint_daily, + default_binding_constraint_hourly, + default_binding_constraint_weekly, +) +from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator from antarest.study.storage.variantstudy.model.command.create_area import CreateArea from antarest.study.storage.variantstudy.model.command.create_binding_constraint import CreateBindingConstraint from antarest.study.storage.variantstudy.model.command.create_cluster import CreateCluster @@ -56,7 +62,7 @@ def test_manage_binding_constraint( bind1_cmd = CreateBindingConstraint( name="BD 1", - time_step=TimeStep.HOURLY, + time_step=BindingConstraintFrequency.HOURLY, operator=BindingConstraintOperator.LESS, coeffs={"area1%area2": [800, 30]}, comments="Hello", @@ -68,7 +74,7 @@ def test_manage_binding_constraint( bind2_cmd = CreateBindingConstraint( name="BD 2", enabled=False, - time_step=TimeStep.DAILY, + time_step=BindingConstraintFrequency.DAILY, operator=BindingConstraintOperator.BOTH, coeffs={"area1.cluster": [50]}, command_context=command_context, @@ -101,13 +107,14 @@ def test_manage_binding_constraint( "type": "daily", } + weekly_values = default_binding_constraint_weekly.tolist() bind_update = UpdateBindingConstraint( id="bd 1", enabled=False, - time_step=TimeStep.WEEKLY, + time_step=BindingConstraintFrequency.WEEKLY, operator=BindingConstraintOperator.BOTH, coeffs={"area1%area2": [800, 30]}, - values=[[0]], + values=weekly_values, command_context=command_context, ) res = bind_update.apply(empty_study) @@ -139,28 +146,29 @@ def test_manage_binding_constraint( def test_match(command_context: CommandContext): + values = default_binding_constraint_daily.tolist() base = CreateBindingConstraint( name="foo", enabled=False, - time_step=TimeStep.DAILY, + time_step=BindingConstraintFrequency.DAILY, operator=BindingConstraintOperator.BOTH, coeffs={"a": [0.3]}, - values=[[0]], + values=values, command_context=command_context, ) other_match = CreateBindingConstraint( name="foo", enabled=False, - time_step=TimeStep.DAILY, + time_step=BindingConstraintFrequency.DAILY, operator=BindingConstraintOperator.BOTH, coeffs={"a": [0.3]}, - values=[[0]], + values=values, command_context=command_context, ) other_not_match = CreateBindingConstraint( name="bar", enabled=False, - time_step=TimeStep.DAILY, + time_step=BindingConstraintFrequency.DAILY, operator=BindingConstraintOperator.BOTH, coeffs={"a": [0.3]}, command_context=command_context, @@ -171,31 +179,31 @@ def test_match(command_context: CommandContext): assert not base.match(other_other) assert base.match_signature() == "create_binding_constraint%foo" # check the matrices links - matrix_id = command_context.matrix_service.create([[0]]) + matrix_id = command_context.matrix_service.create(values) assert base.get_inner_matrices() == [matrix_id] base = UpdateBindingConstraint( id="foo", enabled=False, - time_step=TimeStep.DAILY, + time_step=BindingConstraintFrequency.DAILY, operator=BindingConstraintOperator.BOTH, coeffs={"a": [0.3]}, - values=[[0]], + values=values, command_context=command_context, ) other_match = UpdateBindingConstraint( id="foo", enabled=False, - time_step=TimeStep.DAILY, + time_step=BindingConstraintFrequency.DAILY, operator=BindingConstraintOperator.BOTH, coeffs={"a": [0.3]}, - values=[[0]], + values=values, command_context=command_context, ) other_not_match = UpdateBindingConstraint( id="bar", enabled=False, - time_step=TimeStep.DAILY, + time_step=BindingConstraintFrequency.DAILY, operator=BindingConstraintOperator.BOTH, coeffs={"a": [0.3]}, command_context=command_context, @@ -206,7 +214,7 @@ def test_match(command_context: CommandContext): assert not base.match(other_other) assert base.match_signature() == "update_binding_constraint%foo" # check the matrices links - matrix_id = command_context.matrix_service.create([[0]]) + matrix_id = command_context.matrix_service.create(values) assert base.get_inner_matrices() == [matrix_id] base = RemoveBindingConstraint(id="foo", command_context=command_context) @@ -221,13 +229,16 @@ def test_match(command_context: CommandContext): def test_revert(command_context: CommandContext): + hourly_values = default_binding_constraint_hourly.tolist() + daily_values = default_binding_constraint_daily.tolist() + weekly_values = default_binding_constraint_weekly.tolist() base = CreateBindingConstraint( name="foo", enabled=False, - time_step=TimeStep.DAILY, + time_step=BindingConstraintFrequency.DAILY, operator=BindingConstraintOperator.BOTH, coeffs={"a": [0.3]}, - values=[[0]], + values=daily_values, command_context=command_context, ) assert CommandReverter().revert(base, [], Mock(spec=FileStudy)) == [ @@ -237,10 +248,10 @@ def test_revert(command_context: CommandContext): base = UpdateBindingConstraint( id="foo", enabled=False, - time_step=TimeStep.DAILY, + time_step=BindingConstraintFrequency.DAILY, operator=BindingConstraintOperator.BOTH, coeffs={"a": [0.3]}, - values=[[0]], + values=daily_values, command_context=command_context, ) mock_command_extractor = Mock(spec=CommandExtractor) @@ -255,19 +266,19 @@ def test_revert(command_context: CommandContext): UpdateBindingConstraint( id="foo", enabled=True, - time_step=TimeStep.WEEKLY, + time_step=BindingConstraintFrequency.WEEKLY, operator=BindingConstraintOperator.BOTH, coeffs={"a": [0.3]}, - values=[[0]], + values=weekly_values, command_context=command_context, ), UpdateBindingConstraint( id="foo", enabled=True, - time_step=TimeStep.HOURLY, + time_step=BindingConstraintFrequency.HOURLY, operator=BindingConstraintOperator.BOTH, coeffs={"a": [0.3]}, - values=[[0]], + values=hourly_values, command_context=command_context, ), ], @@ -276,34 +287,34 @@ def test_revert(command_context: CommandContext): UpdateBindingConstraint( id="foo", enabled=True, - time_step=TimeStep.HOURLY, + time_step=BindingConstraintFrequency.HOURLY, operator=BindingConstraintOperator.BOTH, coeffs={"a": [0.3]}, - values=[[0]], + values=hourly_values, command_context=command_context, ) ] # check the matrices links - matrix_id = command_context.matrix_service.create([[0]]) + hourly_matrix_id = command_context.matrix_service.create(hourly_values) assert CommandReverter().revert( base, [ UpdateBindingConstraint( id="foo", enabled=True, - time_step=TimeStep.WEEKLY, + time_step=BindingConstraintFrequency.WEEKLY, operator=BindingConstraintOperator.BOTH, coeffs={"a": [0.3]}, - values=[[0]], + values=weekly_values, command_context=command_context, ), CreateBindingConstraint( name="foo", enabled=True, - time_step=TimeStep.HOURLY, + time_step=BindingConstraintFrequency.HOURLY, operator=BindingConstraintOperator.EQUAL, coeffs={"a": [0.3]}, - values=[[0]], + values=hourly_values, command_context=command_context, ), ], @@ -312,10 +323,10 @@ def test_revert(command_context: CommandContext): UpdateBindingConstraint( id="foo", enabled=True, - time_step=TimeStep.HOURLY, + time_step=BindingConstraintFrequency.HOURLY, operator=BindingConstraintOperator.EQUAL, coeffs={"a": [0.3]}, - values=matrix_id, + values=hourly_matrix_id, comments=None, command_context=command_context, ) @@ -329,7 +340,7 @@ def test_create_diff(command_context: CommandContext): base = CreateBindingConstraint( name="foo", enabled=False, - time_step=TimeStep.DAILY, + time_step=BindingConstraintFrequency.DAILY, operator=BindingConstraintOperator.BOTH, coeffs={"a": [0.3]}, values="a", @@ -338,7 +349,7 @@ def test_create_diff(command_context: CommandContext): other_match = CreateBindingConstraint( name="foo", enabled=True, - time_step=TimeStep.HOURLY, + time_step=BindingConstraintFrequency.HOURLY, operator=BindingConstraintOperator.EQUAL, coeffs={"b": [0.3]}, values="b", @@ -348,7 +359,7 @@ def test_create_diff(command_context: CommandContext): UpdateBindingConstraint( id="foo", enabled=True, - time_step=TimeStep.HOURLY, + time_step=BindingConstraintFrequency.HOURLY, operator=BindingConstraintOperator.EQUAL, coeffs={"b": [0.3]}, values="b", @@ -356,22 +367,23 @@ def test_create_diff(command_context: CommandContext): ) ] + values = default_binding_constraint_daily.tolist() base = UpdateBindingConstraint( id="foo", enabled=False, - time_step=TimeStep.DAILY, + time_step=BindingConstraintFrequency.DAILY, operator=BindingConstraintOperator.BOTH, coeffs={"a": [0.3]}, - values=[[0]], + values=values, command_context=command_context, ) other_match = UpdateBindingConstraint( id="foo", enabled=False, - time_step=TimeStep.DAILY, + time_step=BindingConstraintFrequency.DAILY, operator=BindingConstraintOperator.BOTH, coeffs={"a": [0.3]}, - values=[[0]], + values=values, command_context=command_context, ) assert base.create_diff(other_match) == [other_match] diff --git a/tests/variantstudy/model/command/test_remove_area.py b/tests/variantstudy/model/command/test_remove_area.py index cb03c4f2e0..3fb77082f2 100644 --- a/tests/variantstudy/model/command/test_remove_area.py +++ b/tests/variantstudy/model/command/test_remove_area.py @@ -1,9 +1,10 @@ import pytest +from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import BindingConstraintFrequency from antarest.study.storage.rawstudy.model.filesystem.config.model import transform_name_to_id from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy from antarest.study.storage.study_upgrader import upgrade_study -from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator, TimeStep +from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator from antarest.study.storage.variantstudy.model.command.create_area import CreateArea from antarest.study.storage.variantstudy.model.command.create_binding_constraint import CreateBindingConstraint from antarest.study.storage.variantstudy.model.command.create_cluster import CreateCluster @@ -125,7 +126,7 @@ def test_apply( bind1_cmd = CreateBindingConstraint( name="BD 2", - time_step=TimeStep.HOURLY, + time_step=BindingConstraintFrequency.HOURLY, operator=BindingConstraintOperator.LESS, coeffs={ f"{area_id}%{area_id2}": [400, 30], diff --git a/tests/variantstudy/model/command/test_remove_cluster.py b/tests/variantstudy/model/command/test_remove_cluster.py index 0948e962f9..e4525fbc36 100644 --- a/tests/variantstudy/model/command/test_remove_cluster.py +++ b/tests/variantstudy/model/command/test_remove_cluster.py @@ -1,8 +1,9 @@ from checksumdir import dirhash +from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import BindingConstraintFrequency from antarest.study.storage.rawstudy.model.filesystem.config.model import transform_name_to_id from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy -from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator, TimeStep +from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator from antarest.study.storage.variantstudy.model.command.create_area import CreateArea from antarest.study.storage.variantstudy.model.command.create_binding_constraint import CreateBindingConstraint from antarest.study.storage.variantstudy.model.command.create_cluster import CreateCluster @@ -48,7 +49,7 @@ def test_apply(self, empty_study: FileStudy, command_context: CommandContext): bind1_cmd = CreateBindingConstraint( name="BD 1", - time_step=TimeStep.HOURLY, + time_step=BindingConstraintFrequency.HOURLY, operator=BindingConstraintOperator.LESS, coeffs={ f"{area_id}.{cluster_id}": [800, 30], diff --git a/tests/variantstudy/test_command_factory.py b/tests/variantstudy/test_command_factory.py index 4dbcb6a06f..47e3c57811 100644 --- a/tests/variantstudy/test_command_factory.py +++ b/tests/variantstudy/test_command_factory.py @@ -31,7 +31,10 @@ def setup_class(self): f".{name}", package="antarest.study.storage.variantstudy.model.command", ) - self.command_class_set = {command.__name__ for command in ICommand.__subclasses__()} + abstract_commands = {"AbstractBindingConstraintCommand"} + self.command_class_set = { + cmd.__name__ for cmd in ICommand.__subclasses__() if cmd.__name__ not in abstract_commands + } # noinspection SpellCheckingInspection @pytest.mark.parametrize( diff --git a/webapp/package-lock.json b/webapp/package-lock.json index 9a4f7cd401..3ab3cca0b7 100644 --- a/webapp/package-lock.json +++ b/webapp/package-lock.json @@ -1,6 +1,6 @@ { "name": "antares-web", - "version": "2.15.1", + "version": "2.15.2", "lockfileVersion": 3, "requires": true, "packages": { diff --git a/webapp/package.json b/webapp/package.json index 74a8fd65f0..506c324bf2 100644 --- a/webapp/package.json +++ b/webapp/package.json @@ -1,6 +1,6 @@ { "name": "antares-web", - "version": "2.15.1", + "version": "2.15.2", "private": true, "engines": { "node": "18.16.1"