diff --git a/antarest/__init__.py b/antarest/__init__.py
index 4023e2ca7b..447ab7edd4 100644
--- a/antarest/__init__.py
+++ b/antarest/__init__.py
@@ -7,9 +7,9 @@
# Standard project metadata
-__version__ = "2.15.1"
+__version__ = "2.15.2"
__author__ = "RTE, Antares Web Team"
-__date__ = "2023-10-05"
+__date__ = "2023-10-11"
# noinspection SpellCheckingInspection
__credits__ = "(c) Réseau de Transport de l’Électricité (RTE)"
diff --git a/antarest/core/cache/business/redis_cache.py b/antarest/core/cache/business/redis_cache.py
index 3c72845eb8..176e583b76 100644
--- a/antarest/core/cache/business/redis_cache.py
+++ b/antarest/core/cache/business/redis_cache.py
@@ -21,6 +21,7 @@ def __init__(self, redis_client: Redis): # type: ignore
self.redis = redis_client
def start(self) -> None:
+ # Assuming the Redis service is already running; no need to start it here.
pass
def put(self, id: str, data: JSON, duration: int = 3600) -> None:
diff --git a/antarest/core/config.py b/antarest/core/config.py
index 4c232fa475..b48be8ded0 100644
--- a/antarest/core/config.py
+++ b/antarest/core/config.py
@@ -1,16 +1,14 @@
-import logging
+import multiprocessing
import tempfile
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Dict, List, Optional
import yaml
from antarest.core.model import JSON
from antarest.core.roles import RoleType
-logger = logging.getLogger(__name__)
-
@dataclass(frozen=True)
class ExternalAuthConfig:
@@ -23,13 +21,16 @@ class ExternalAuthConfig:
add_ext_groups: bool = False
group_mapping: Dict[str, str] = field(default_factory=dict)
- @staticmethod
- def from_dict(data: JSON) -> "ExternalAuthConfig":
- return ExternalAuthConfig(
- url=data.get("url", None),
- default_group_role=RoleType(data.get("default_group_role", RoleType.READER.value)),
- add_ext_groups=data.get("add_ext_groups", False),
- group_mapping=data.get("group_mapping", {}),
+ @classmethod
+ def from_dict(cls, data: JSON) -> "ExternalAuthConfig":
+ defaults = cls()
+ return cls(
+ url=data.get("url", defaults.url),
+ default_group_role=(
+ RoleType(data["default_group_role"]) if "default_group_role" in data else defaults.default_group_role
+ ),
+ add_ext_groups=data.get("add_ext_groups", defaults.add_ext_groups),
+ group_mapping=data.get("group_mapping", defaults.group_mapping),
)
@@ -44,13 +45,18 @@ class SecurityConfig:
disabled: bool = False
external_auth: ExternalAuthConfig = ExternalAuthConfig()
- @staticmethod
- def from_dict(data: JSON) -> "SecurityConfig":
- return SecurityConfig(
- jwt_key=data.get("jwt", {}).get("key", ""),
- admin_pwd=data.get("login", {}).get("admin", {}).get("pwd", ""),
- disabled=data.get("disabled", False),
- external_auth=ExternalAuthConfig.from_dict(data.get("external_auth", {})),
+ @classmethod
+ def from_dict(cls, data: JSON) -> "SecurityConfig":
+ defaults = cls()
+ return cls(
+ jwt_key=data.get("jwt", {}).get("key", defaults.jwt_key),
+ admin_pwd=data.get("login", {}).get("admin", {}).get("pwd", defaults.admin_pwd),
+ disabled=data.get("disabled", defaults.disabled),
+ external_auth=(
+ ExternalAuthConfig.from_dict(data["external_auth"])
+ if "external_auth" in data
+ else defaults.external_auth
+ ),
)
@@ -65,13 +71,14 @@ class WorkspaceConfig:
groups: List[str] = field(default_factory=lambda: [])
path: Path = Path()
- @staticmethod
- def from_dict(data: JSON) -> "WorkspaceConfig":
- return WorkspaceConfig(
- path=Path(data["path"]),
- groups=data.get("groups", []),
- filter_in=data.get("filter_in", [".*"]),
- filter_out=data.get("filter_out", []),
+ @classmethod
+ def from_dict(cls, data: JSON) -> "WorkspaceConfig":
+ defaults = cls()
+ return cls(
+ filter_in=data.get("filter_in", defaults.filter_in),
+ filter_out=data.get("filter_out", defaults.filter_out),
+ groups=data.get("groups", defaults.groups),
+ path=Path(data["path"]) if "path" in data else defaults.path,
)
@@ -91,18 +98,19 @@ class DbConfig:
pool_size: int = 5
pool_use_lifo: bool = False
- @staticmethod
- def from_dict(data: JSON) -> "DbConfig":
- return DbConfig(
- db_admin_url=data.get("admin_url", None),
- db_url=data.get("url", ""),
- db_connect_timeout=data.get("db_connect_timeout", 10),
- pool_recycle=data.get("pool_recycle", None),
- pool_pre_ping=data.get("pool_pre_ping", False),
- pool_use_null=data.get("pool_use_null", False),
- pool_max_overflow=data.get("pool_max_overflow", 10),
- pool_size=data.get("pool_size", 5),
- pool_use_lifo=data.get("pool_use_lifo", False),
+ @classmethod
+ def from_dict(cls, data: JSON) -> "DbConfig":
+ defaults = cls()
+ return cls(
+ db_admin_url=data.get("admin_url", defaults.db_admin_url),
+ db_url=data.get("url", defaults.db_url),
+ db_connect_timeout=data.get("db_connect_timeout", defaults.db_connect_timeout),
+ pool_recycle=data.get("pool_recycle", defaults.pool_recycle),
+ pool_pre_ping=data.get("pool_pre_ping", defaults.pool_pre_ping),
+ pool_use_null=data.get("pool_use_null", defaults.pool_use_null),
+ pool_max_overflow=data.get("pool_max_overflow", defaults.pool_max_overflow),
+ pool_size=data.get("pool_size", defaults.pool_size),
+ pool_use_lifo=data.get("pool_use_lifo", defaults.pool_use_lifo),
)
@@ -115,7 +123,7 @@ class StorageConfig:
matrixstore: Path = Path("./matrixstore")
archive_dir: Path = Path("./archives")
tmp_dir: Path = Path(tempfile.gettempdir())
- workspaces: Dict[str, WorkspaceConfig] = field(default_factory=lambda: {})
+ workspaces: Dict[str, WorkspaceConfig] = field(default_factory=dict)
allow_deletion: bool = False
watcher_lock: bool = True
watcher_lock_delay: int = 10
@@ -127,39 +135,112 @@ class StorageConfig:
auto_archive_sleeping_time: int = 3600
auto_archive_max_parallel: int = 5
- @staticmethod
- def from_dict(data: JSON) -> "StorageConfig":
- return StorageConfig(
- tmp_dir=Path(data.get("tmp_dir", tempfile.gettempdir())),
- matrixstore=Path(data["matrixstore"]),
- workspaces={n: WorkspaceConfig.from_dict(w) for n, w in data["workspaces"].items()},
- allow_deletion=data.get("allow_deletion", False),
- archive_dir=Path(data["archive_dir"]),
- watcher_lock=data.get("watcher_lock", True),
- watcher_lock_delay=data.get("watcher_lock_delay", 10),
- download_default_expiration_timeout_minutes=data.get("download_default_expiration_timeout_minutes", 1440),
- matrix_gc_sleeping_time=data.get("matrix_gc_sleeping_time", 3600),
- matrix_gc_dry_run=data.get("matrix_gc_dry_run", False),
- auto_archive_threshold_days=data.get("auto_archive_threshold_days", 60),
- auto_archive_dry_run=data.get("auto_archive_dry_run", False),
- auto_archive_sleeping_time=data.get("auto_archive_sleeping_time", 3600),
- auto_archive_max_parallel=data.get("auto_archive_max_parallel", 5),
+ @classmethod
+ def from_dict(cls, data: JSON) -> "StorageConfig":
+ defaults = cls()
+ workspaces = (
+ {key: WorkspaceConfig.from_dict(value) for key, value in data["workspaces"].items()}
+ if "workspaces" in data
+ else defaults.workspaces
+ )
+ return cls(
+ matrixstore=Path(data["matrixstore"]) if "matrixstore" in data else defaults.matrixstore,
+ archive_dir=Path(data["archive_dir"]) if "archive_dir" in data else defaults.archive_dir,
+ tmp_dir=Path(data["tmp_dir"]) if "tmp_dir" in data else defaults.tmp_dir,
+ workspaces=workspaces,
+ allow_deletion=data.get("allow_deletion", defaults.allow_deletion),
+ watcher_lock=data.get("watcher_lock", defaults.watcher_lock),
+ watcher_lock_delay=data.get("watcher_lock_delay", defaults.watcher_lock_delay),
+ download_default_expiration_timeout_minutes=(
+ data.get(
+ "download_default_expiration_timeout_minutes",
+ defaults.download_default_expiration_timeout_minutes,
+ )
+ ),
+ matrix_gc_sleeping_time=data.get("matrix_gc_sleeping_time", defaults.matrix_gc_sleeping_time),
+ matrix_gc_dry_run=data.get("matrix_gc_dry_run", defaults.matrix_gc_dry_run),
+ auto_archive_threshold_days=data.get("auto_archive_threshold_days", defaults.auto_archive_threshold_days),
+ auto_archive_dry_run=data.get("auto_archive_dry_run", defaults.auto_archive_dry_run),
+ auto_archive_sleeping_time=data.get("auto_archive_sleeping_time", defaults.auto_archive_sleeping_time),
+ auto_archive_max_parallel=data.get("auto_archive_max_parallel", defaults.auto_archive_max_parallel),
)
+@dataclass(frozen=True)
+class NbCoresConfig:
+ """
+ The NbCoresConfig class manages the configuration of the number of CPU cores.
+ """
+
+ min: int = 1
+ default: int = 22
+ max: int = 24
+
+ def to_json(self) -> Dict[str, int]:
+ """
+ Retrieves the number-of-cores parameters as a dictionary containing the values "min"
+ (minimum allowed value), "defaultValue" (default value), and "max" (maximum allowed value).
+
+ Returns:
+ A dictionary: `{"min": min, "defaultValue": default, "max": max}`.
+ These keys ("min", "defaultValue", "max") are the ones expected by the React Material UI frontend.
+ """
+ return {"min": self.min, "defaultValue": self.default, "max": self.max}
+
+ def __post_init__(self) -> None:
+ """validation of CPU configuration"""
+ if 1 <= self.min <= self.default <= self.max:
+ return
+ msg = f"Invalid configuration: 1 <= {self.min=} <= {self.default=} <= {self.max=}"
+ raise ValueError(msg)
+
+
@dataclass(frozen=True)
class LocalConfig:
+ """Sub config object dedicated to launcher module (local)"""
+
binaries: Dict[str, Path] = field(default_factory=dict)
+ enable_nb_cores_detection: bool = True
+ nb_cores: NbCoresConfig = NbCoresConfig()
- @staticmethod
- def from_dict(data: JSON) -> Optional["LocalConfig"]:
- return LocalConfig(
- binaries={str(v): Path(p) for v, p in data["binaries"].items()},
+ @classmethod
+ def from_dict(cls, data: JSON) -> "LocalConfig":
+ """
+ Creates an instance of LocalConfig from a data dictionary.
+ Args:
+ data: Parsed config from dict.
+ Returns: A LocalConfig object.
+ """
+ defaults = cls()
+ binaries = data.get("binaries", defaults.binaries)
+ enable_nb_cores_detection = data.get("enable_nb_cores_detection", defaults.enable_nb_cores_detection)
+ nb_cores = data.get("nb_cores", asdict(defaults.nb_cores))
+ if enable_nb_cores_detection:
+ nb_cores.update(cls._autodetect_nb_cores())
+ return cls(
+ binaries={str(v): Path(p) for v, p in binaries.items()},
+ enable_nb_cores_detection=enable_nb_cores_detection,
+ nb_cores=NbCoresConfig(**nb_cores),
)
+ @classmethod
+ def _autodetect_nb_cores(cls) -> Dict[str, int]:
+ """
+ Automatically detects the number of cores available on the local machine.
+ Returns: A dictionary with the "min", "max" and "default" number of cores.
+ """
+ min_cpu = cls.nb_cores.min
+ max_cpu = multiprocessing.cpu_count()
+ default = max(min_cpu, max_cpu - 2)
+ return {"min": min_cpu, "max": max_cpu, "default": default}
+
@dataclass(frozen=True)
class SlurmConfig:
+ """
+ Sub config object dedicated to launcher module (slurm)
+ """
+
local_workspace: Path = Path()
username: str = ""
hostname: str = ""
@@ -169,31 +250,68 @@ class SlurmConfig:
password: str = ""
default_wait_time: int = 0
default_time_limit: int = 0
- default_n_cpu: int = 1
default_json_db_name: str = ""
slurm_script_path: str = ""
max_cores: int = 64
antares_versions_on_remote_server: List[str] = field(default_factory=list)
+ enable_nb_cores_detection: bool = False
+ nb_cores: NbCoresConfig = NbCoresConfig()
- @staticmethod
- def from_dict(data: JSON) -> "SlurmConfig":
- return SlurmConfig(
- local_workspace=Path(data["local_workspace"]),
- username=data["username"],
- hostname=data["hostname"],
- port=data["port"],
- private_key_file=data.get("private_key_file", None),
- key_password=data.get("key_password", None),
- password=data.get("password", None),
- default_wait_time=data["default_wait_time"],
- default_time_limit=data["default_time_limit"],
- default_n_cpu=data["default_n_cpu"],
- default_json_db_name=data["default_json_db_name"],
- slurm_script_path=data["slurm_script_path"],
- antares_versions_on_remote_server=data["antares_versions_on_remote_server"],
- max_cores=data.get("max_cores", 64),
+ @classmethod
+ def from_dict(cls, data: JSON) -> "SlurmConfig":
+ """
+ Creates an instance of SlurmConfig from a data dictionary.
+
+ Args:
+ data: Parsed config from dict.
+ Returns: A SlurmConfig object.
+ """
+ defaults = cls()
+ enable_nb_cores_detection = data.get("enable_nb_cores_detection", defaults.enable_nb_cores_detection)
+ nb_cores = data.get("nb_cores", asdict(defaults.nb_cores))
+ if "default_n_cpu" in data:
+ # Use the old way to configure the NB cores for backward compatibility
+ nb_cores["default"] = int(data["default_n_cpu"])
+ nb_cores["min"] = min(nb_cores["min"], nb_cores["default"])
+ nb_cores["max"] = max(nb_cores["max"], nb_cores["default"])
+ if enable_nb_cores_detection:
+ nb_cores.update(cls._autodetect_nb_cores())
+ return cls(
+ local_workspace=Path(data.get("local_workspace", defaults.local_workspace)),
+ username=data.get("username", defaults.username),
+ hostname=data.get("hostname", defaults.hostname),
+ port=data.get("port", defaults.port),
+ private_key_file=data.get("private_key_file", defaults.private_key_file),
+ key_password=data.get("key_password", defaults.key_password),
+ password=data.get("password", defaults.password),
+ default_wait_time=data.get("default_wait_time", defaults.default_wait_time),
+ default_time_limit=data.get("default_time_limit", defaults.default_time_limit),
+ default_json_db_name=data.get("default_json_db_name", defaults.default_json_db_name),
+ slurm_script_path=data.get("slurm_script_path", defaults.slurm_script_path),
+ antares_versions_on_remote_server=data.get(
+ "antares_versions_on_remote_server",
+ defaults.antares_versions_on_remote_server,
+ ),
+ max_cores=data.get("max_cores", defaults.max_cores),
+ enable_nb_cores_detection=enable_nb_cores_detection,
+ nb_cores=NbCoresConfig(**nb_cores),
)
+ @classmethod
+ def _autodetect_nb_cores(cls) -> Dict[str, int]:
+ raise NotImplementedError("NB Cores auto-detection is not implemented for SLURM server")
+
+
+class InvalidConfigurationError(Exception):
+ """
+ Exception raised when an attempt is made to retrieve the number of cores
+ of a launcher that doesn't exist in the configuration.
+ """
+
+ def __init__(self, launcher: str):
+ msg = f"Configuration is not available for the '{launcher}' launcher"
+ super().__init__(msg)
+
@dataclass(frozen=True)
class LauncherConfig:
@@ -202,27 +320,53 @@ class LauncherConfig:
"""
default: str = "local"
- local: Optional[LocalConfig] = LocalConfig()
- slurm: Optional[SlurmConfig] = SlurmConfig()
+ local: Optional[LocalConfig] = None
+ slurm: Optional[SlurmConfig] = None
batch_size: int = 9999
- @staticmethod
- def from_dict(data: JSON) -> "LauncherConfig":
- local: Optional[LocalConfig] = None
- if "local" in data:
- local = LocalConfig.from_dict(data["local"])
-
- slurm: Optional[SlurmConfig] = None
- if "slurm" in data:
- slurm = SlurmConfig.from_dict(data["slurm"])
-
- return LauncherConfig(
- default=data.get("default", "local"),
+ @classmethod
+ def from_dict(cls, data: JSON) -> "LauncherConfig":
+ defaults = cls()
+ default = data.get("default", cls.default)
+ local = LocalConfig.from_dict(data["local"]) if "local" in data else defaults.local
+ slurm = SlurmConfig.from_dict(data["slurm"]) if "slurm" in data else defaults.slurm
+ batch_size = data.get("batch_size", defaults.batch_size)
+ return cls(
+ default=default,
local=local,
slurm=slurm,
- batch_size=data.get("batch_size", 9999),
+ batch_size=batch_size,
)
+ def __post_init__(self) -> None:
+ possible = {"local", "slurm"}
+ if self.default in possible:
+ return
+ msg = f"Invalid configuration: {self.default=} must be one of {possible!r}"
+ raise ValueError(msg)
+
+ def get_nb_cores(self, launcher: str) -> "NbCoresConfig":
+ """
+ Retrieve the number of cores configuration for a given launcher: "local" or "slurm".
+ If "default" is specified, retrieve the configuration of the default launcher.
+
+ Args:
+ launcher: type of launcher "local", "slurm" or "default".
+
+ Returns:
+ Number of cores of the given launcher.
+
+ Raises:
+ InvalidConfigurationError: Exception raised when an attempt is made to retrieve
+ the number of cores of a launcher that doesn't exist in the configuration.
+ """
+ config_map = {"local": self.local, "slurm": self.slurm}
+ config_map["default"] = config_map[self.default]
+ launcher_config = config_map.get(launcher)
+ if launcher_config is None:
+ raise InvalidConfigurationError(launcher)
+ return launcher_config.nb_cores
+
@dataclass(frozen=True)
class LoggingConfig:
@@ -234,14 +378,13 @@ class LoggingConfig:
json: bool = False
level: str = "INFO"
- @staticmethod
- def from_dict(data: JSON) -> "LoggingConfig":
- logging_config: Dict[str, Any] = data or {}
- logfile: Optional[str] = logging_config.get("logfile")
- return LoggingConfig(
- logfile=Path(logfile) if logfile is not None else None,
- json=logging_config.get("json", False),
- level=logging_config.get("level", "INFO"),
+ @classmethod
+ def from_dict(cls, data: JSON) -> "LoggingConfig":
+ defaults = cls()
+ return cls(
+ logfile=Path(data["logfile"]) if "logfile" in data else defaults.logfile,
+ json=data.get("json", defaults.json),
+ level=data.get("level", defaults.level),
)
@@ -255,12 +398,13 @@ class RedisConfig:
port: int = 6379
password: Optional[str] = None
- @staticmethod
- def from_dict(data: JSON) -> "RedisConfig":
- return RedisConfig(
- host=data["host"],
- port=data["port"],
- password=data.get("password", None),
+ @classmethod
+ def from_dict(cls, data: JSON) -> "RedisConfig":
+ defaults = cls()
+ return cls(
+ host=data.get("host", defaults.host),
+ port=data.get("port", defaults.port),
+ password=data.get("password", defaults.password),
)
@@ -271,9 +415,9 @@ class EventBusConfig:
"""
# noinspection PyUnusedLocal
- @staticmethod
- def from_dict(data: JSON) -> "EventBusConfig":
- return EventBusConfig()
+ @classmethod
+ def from_dict(cls, data: JSON) -> "EventBusConfig":
+ return cls()
@dataclass(frozen=True)
@@ -284,10 +428,11 @@ class CacheConfig:
checker_delay: float = 0.2 # in seconds
- @staticmethod
- def from_dict(data: JSON) -> "CacheConfig":
- return CacheConfig(
- checker_delay=float(data["checker_delay"]) if "checker_delay" in data else 0.2,
+ @classmethod
+ def from_dict(cls, data: JSON) -> "CacheConfig":
+ defaults = cls()
+ return cls(
+ checker_delay=data.get("checker_delay", defaults.checker_delay),
)
@@ -296,9 +441,13 @@ class RemoteWorkerConfig:
name: str
queues: List[str] = field(default_factory=list)
- @staticmethod
- def from_dict(data: JSON) -> "RemoteWorkerConfig":
- return RemoteWorkerConfig(name=data["name"], queues=data.get("queues", []))
+ @classmethod
+ def from_dict(cls, data: JSON) -> "RemoteWorkerConfig":
+ defaults = cls(name="") # `name` is mandatory
+ return cls(
+ name=data["name"],
+ queues=data.get("queues", defaults.queues),
+ )
@dataclass(frozen=True)
@@ -310,16 +459,17 @@ class TaskConfig:
max_workers: int = 5
remote_workers: List[RemoteWorkerConfig] = field(default_factory=list)
- @staticmethod
- def from_dict(data: JSON) -> "TaskConfig":
- return TaskConfig(
- max_workers=int(data["max_workers"]) if "max_workers" in data else 5,
- remote_workers=list(
- map(
- lambda x: RemoteWorkerConfig.from_dict(x),
- data.get("remote_workers", []),
- )
- ),
+ @classmethod
+ def from_dict(cls, data: JSON) -> "TaskConfig":
+ defaults = cls()
+ remote_workers = (
+ [RemoteWorkerConfig.from_dict(d) for d in data["remote_workers"]]
+ if "remote_workers" in data
+ else defaults.remote_workers
+ )
+ return cls(
+ max_workers=data.get("max_workers", defaults.max_workers),
+ remote_workers=remote_workers,
)
@@ -332,11 +482,12 @@ class ServerConfig:
worker_threadpool_size: int = 5
services: List[str] = field(default_factory=list)
- @staticmethod
- def from_dict(data: JSON) -> "ServerConfig":
- return ServerConfig(
- worker_threadpool_size=int(data["worker_threadpool_size"]) if "worker_threadpool_size" in data else 5,
- services=data.get("services", []),
+ @classmethod
+ def from_dict(cls, data: JSON) -> "ServerConfig":
+ defaults = cls()
+ return cls(
+ worker_threadpool_size=data.get("worker_threadpool_size", defaults.worker_threadpool_size),
+ services=data.get("services", defaults.services),
)
@@ -360,36 +511,27 @@ class Config:
tasks: TaskConfig = TaskConfig()
root_path: str = ""
- @staticmethod
- def from_dict(data: JSON, res: Optional[Path] = None) -> "Config":
- """
- Parse config from dict.
-
- Args:
- data: dict struct to parse
- res: resources path is not present in yaml file.
-
- Returns:
-
- """
- return Config(
- security=SecurityConfig.from_dict(data.get("security", {})),
- storage=StorageConfig.from_dict(data["storage"]),
- launcher=LauncherConfig.from_dict(data.get("launcher", {})),
- db=DbConfig.from_dict(data["db"]) if "db" in data else DbConfig(),
- logging=LoggingConfig.from_dict(data.get("logging", {})),
- debug=data.get("debug", False),
- resources_path=res or Path(),
- root_path=data.get("root_path", ""),
- redis=RedisConfig.from_dict(data["redis"]) if "redis" in data else None,
- eventbus=EventBusConfig.from_dict(data["eventbus"]) if "eventbus" in data else EventBusConfig(),
- cache=CacheConfig.from_dict(data["cache"]) if "cache" in data else CacheConfig(),
- tasks=TaskConfig.from_dict(data["tasks"]) if "tasks" in data else TaskConfig(),
- server=ServerConfig.from_dict(data["server"]) if "server" in data else ServerConfig(),
+ @classmethod
+ def from_dict(cls, data: JSON) -> "Config":
+ defaults = cls()
+ return cls(
+ server=ServerConfig.from_dict(data["server"]) if "server" in data else defaults.server,
+ security=SecurityConfig.from_dict(data["security"]) if "security" in data else defaults.security,
+ storage=StorageConfig.from_dict(data["storage"]) if "storage" in data else defaults.storage,
+ launcher=LauncherConfig.from_dict(data["launcher"]) if "launcher" in data else defaults.launcher,
+ db=DbConfig.from_dict(data["db"]) if "db" in data else defaults.db,
+ logging=LoggingConfig.from_dict(data["logging"]) if "logging" in data else defaults.logging,
+ debug=data.get("debug", defaults.debug),
+ resources_path=data["resources_path"] if "resources_path" in data else defaults.resources_path,
+ redis=RedisConfig.from_dict(data["redis"]) if "redis" in data else defaults.redis,
+ eventbus=EventBusConfig.from_dict(data["eventbus"]) if "eventbus" in data else defaults.eventbus,
+ cache=CacheConfig.from_dict(data["cache"]) if "cache" in data else defaults.cache,
+ tasks=TaskConfig.from_dict(data["tasks"]) if "tasks" in data else defaults.tasks,
+ root_path=data.get("root_path", defaults.root_path),
)
- @staticmethod
- def from_yaml_file(file: Path, res: Optional[Path] = None) -> "Config":
+ @classmethod
+ def from_yaml_file(cls, file: Path, res: Optional[Path] = None) -> "Config":
"""
Parse config from yaml file.
@@ -400,5 +542,8 @@ def from_yaml_file(file: Path, res: Optional[Path] = None) -> "Config":
Returns:
"""
- data = yaml.safe_load(open(file))
- return Config.from_dict(data, res)
+ with open(file) as f:
+ data = yaml.safe_load(f)
+ if res is not None:
+ data["resources_path"] = res
+ return cls.from_dict(data)
diff --git a/antarest/core/filetransfer/model.py b/antarest/core/filetransfer/model.py
index 9442f051dc..bbb61c00b6 100644
--- a/antarest/core/filetransfer/model.py
+++ b/antarest/core/filetransfer/model.py
@@ -13,7 +13,7 @@ class FileDownloadNotFound(HTTPException):
def __init__(self) -> None:
super().__init__(
HTTPStatus.NOT_FOUND,
- f"Requested download file was not found. It must have expired",
+ "Requested download file was not found. It must have expired",
)
@@ -21,7 +21,7 @@ class FileDownloadNotReady(HTTPException):
def __init__(self) -> None:
super().__init__(
HTTPStatus.NOT_ACCEPTABLE,
- f"Requested file is not ready for download.",
+ "Requested file is not ready for download.",
)
@@ -70,4 +70,11 @@ def to_dto(self) -> FileDownloadDTO:
)
def __repr__(self) -> str:
- return f"(id={self.id},name={self.name},filename={self.filename},path={self.path},ready={self.ready},expiration_date={self.expiration_date})"
+ return (
+ f"(id={self.id},"
+ f" name={self.name},"
+ f" filename={self.filename},"
+ f" path={self.path},"
+ f" ready={self.ready},"
+ f" expiration_date={self.expiration_date})"
+ )
diff --git a/antarest/core/logging/utils.py b/antarest/core/logging/utils.py
index 9115dba45b..b0e5227ea3 100644
--- a/antarest/core/logging/utils.py
+++ b/antarest/core/logging/utils.py
@@ -124,12 +124,14 @@ def configure_logger(config: Config, handler_cls: str = "logging.FileHandler") -
"filters": ["context"],
}
elif handler_cls == "logging.handlers.TimedRotatingFileHandler":
+ # 90 days = 3 months
+ # keep only 1 backup (0 means keep all)
logging_config["handlers"]["default"] = {
"class": handler_cls,
"filename": config.logging.logfile,
- "when": "D", # D = day
- "interval": 90, # 90 days = 3 months
- "backupCount": 1, # keep only 1 backup (0 means keep all)
+ "when": "D",
+ "interval": 90,
+ "backupCount": 1,
"encoding": "utf-8",
"delay": False,
"utc": False,
diff --git a/antarest/core/tasks/service.py b/antarest/core/tasks/service.py
index 1c9ac3bf18..227b3fbf71 100644
--- a/antarest/core/tasks/service.py
+++ b/antarest/core/tasks/service.py
@@ -80,7 +80,7 @@ def await_task(self, task_id: str, timeout_sec: Optional[int] = None) -> None:
# noinspection PyUnusedLocal
def noop_notifier(message: str) -> None:
- pass
+ """This function is used in tasks when no notification is required."""
DEFAULT_AWAIT_MAX_TIMEOUT = 172800
@@ -121,7 +121,7 @@ async def _await_task_end(event: Event) -> None:
return _await_task_end
- # todo: Is `logger_` parameter required? (consider refactoring)
+ # noinspection PyUnusedLocal
def _send_worker_task(logger_: TaskUpdateNotifier) -> TaskResult:
listener_id = self.event_bus.add_listener(
_create_awaiter(task_result_wrapper),
@@ -338,14 +338,18 @@ def _run_task(
result.message,
result.return_value,
)
+ event_type = {True: EventType.TASK_COMPLETED, False: EventType.TASK_FAILED}[result.success]
+ event_msg = {True: "completed", False: "failed"}[result.success]
self.event_bus.push(
Event(
- type=EventType.TASK_COMPLETED if result.success else EventType.TASK_FAILED,
+ type=event_type,
payload=TaskEventPayload(
id=task_id,
- message=custom_event_messages.end
- if custom_event_messages is not None
- else f'Task {task_id} {"completed" if result.success else "failed"}',
+ message=(
+ custom_event_messages.end
+ if custom_event_messages is not None
+ else f"Task {task_id} {event_msg}"
+ ),
).dict(),
permissions=PermissionInfo(public_mode=PublicMode.READ),
channel=EventChannelDirectory.TASK + task_id,
diff --git a/antarest/launcher/adapters/local_launcher/local_launcher.py b/antarest/launcher/adapters/local_launcher/local_launcher.py
index 7865f8740c..8ee598985b 100644
--- a/antarest/launcher/adapters/local_launcher/local_launcher.py
+++ b/antarest/launcher/adapters/local_launcher/local_launcher.py
@@ -1,3 +1,4 @@
+import io
import logging
import shutil
import signal
@@ -6,7 +7,7 @@
import threading
import time
from pathlib import Path
-from typing import IO, Callable, Dict, Optional, Tuple, cast
+from typing import Callable, Dict, Optional, Tuple, cast
from uuid import UUID
from antarest.core.config import Config
@@ -14,7 +15,7 @@
from antarest.core.interfaces.eventbus import IEventBus
from antarest.core.requests import RequestParameters
from antarest.launcher.adapters.abstractlauncher import AbstractLauncher, LauncherCallbacks, LauncherInitException
-from antarest.launcher.adapters.log_manager import LogTailManager
+from antarest.launcher.adapters.log_manager import follow
from antarest.launcher.model import JobStatus, LauncherParametersDTO, LogType
logger = logging.getLogger(__name__)
@@ -133,8 +134,8 @@ def stop_reading_output() -> bool:
)
thread = threading.Thread(
- target=lambda: LogTailManager.follow(
- cast(IO[str], process.stdout),
+ target=lambda: follow(
+ cast(io.StringIO, process.stdout),
self.create_update_log(str(uuid)),
stop_reading_output,
None,
diff --git a/antarest/launcher/adapters/log_manager.py b/antarest/launcher/adapters/log_manager.py
index a0bfbdfe70..eeca586f32 100644
--- a/antarest/launcher/adapters/log_manager.py
+++ b/antarest/launcher/adapters/log_manager.py
@@ -1,8 +1,10 @@
+import contextlib
+import io
import logging
import time
from pathlib import Path
from threading import Thread
-from typing import IO, Callable, Dict, Optional
+from typing import Callable, Dict, Optional, cast
logger = logging.getLogger(__name__)
@@ -11,7 +13,7 @@ class LogTailManager:
BATCH_SIZE = 10
def __init__(self, log_base_dir: Path) -> None:
- logger.info(f"Initiating Log manager")
+ logger.info("Initiating Log manager")
self.log_base_dir = log_base_dir
self.tracked_logs: Dict[str, Thread] = {}
@@ -47,43 +49,6 @@ def stop_tracking(self, log_path: Optional[Path]) -> None:
if log_path_key in self.tracked_logs:
del self.tracked_logs[log_path_key]
- @staticmethod
- def follow(
- io: IO[str],
- handler: Callable[[str], None],
- stop: Callable[[], bool],
- log_file: Optional[str],
- ) -> None:
- line = ""
- line_count = 0
-
- while True:
- if stop():
- break
- tmp = io.readline()
- if not tmp:
- if line:
- logger.debug(f"Calling handler for {log_file}")
- try:
- handler(line)
- except Exception as e:
- logger.error("Could not handle this log line", exc_info=e)
- line = ""
- line_count = 0
- time.sleep(0.1)
- else:
- line += tmp
- if line.endswith("\n"):
- line_count += 1
- if line_count >= LogTailManager.BATCH_SIZE:
- logger.debug(f"Calling handler for {log_file}")
- try:
- handler(line)
- except Exception as e:
- logger.error("Could not handle this log line", exc_info=e)
- line = ""
- line_count = 0
-
def _follow(
self,
log_file: Optional[Path],
@@ -97,4 +62,37 @@ def _follow(
with open(log_file, "r") as fh:
logger.info(f"Scanning {log_file}")
- LogTailManager.follow(fh, handler, stop, str(log_file))
+ follow(cast(io.StringIO, fh), handler, stop, str(log_file))
+
+
+def follow(
+ file: io.StringIO,
+ handler: Callable[[str], None],
+ stop: Callable[[], bool],
+ log_file: Optional[str],
+) -> None:
+ line = ""
+ line_count = 0
+
+ while True:
+ if stop():
+ break
+ tmp = file.readline()
+ if tmp:
+ line += tmp
+ if line.endswith("\n"):
+ line_count += 1
+ if line_count >= LogTailManager.BATCH_SIZE:
+ logger.debug(f"Calling handler for {log_file}")
+ with contextlib.suppress(Exception):
+ handler(line)
+ line = ""
+ line_count = 0
+ else:
+ if line:
+ logger.debug(f"Calling handler for {log_file}")
+ with contextlib.suppress(Exception):
+ handler(line)
+ line = ""
+ line_count = 0
+ time.sleep(0.1)
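
A short sketch of driving the module-level `follow` helper extracted above with an in-memory stream instead of a process pipe; the stream content and the stop callback are illustrative:

```python
import io

from antarest.launcher.adapters.log_manager import follow

stream = io.StringIO("line 1\nline 2\nline 3\n")
captured = []

def handler(block: str) -> None:
    # Receives batches of at most LogTailManager.BATCH_SIZE lines.
    captured.append(block)

def stop() -> bool:
    # Stop once the stream is drained and the pending batch has been flushed.
    return stream.tell() == len(stream.getvalue()) and bool(captured)

follow(stream, handler, stop, log_file="stdout")
assert captured == ["line 1\nline 2\nline 3\n"]
```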
diff --git a/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py b/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py
index 926f2b50a3..00283b9ce8 100644
--- a/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py
+++ b/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py
@@ -32,7 +32,6 @@
logger = logging.getLogger(__name__)
logging.getLogger("paramiko").setLevel("WARN")
-MAX_NB_CPU = 24
MAX_TIME_LIMIT = 864000
MIN_TIME_LIMIT = 3600
WORKSPACE_LOCK_FILE_NAME = ".lock"
@@ -153,7 +152,7 @@ def _init_launcher_arguments(self, local_workspace: Optional[Path] = None) -> ar
main_options_parameters = ParserParameters(
default_wait_time=self.slurm_config.default_wait_time,
default_time_limit=self.slurm_config.default_time_limit,
- default_n_cpu=self.slurm_config.default_n_cpu,
+ default_n_cpu=self.slurm_config.nb_cores.default,
studies_in_dir=str((Path(local_workspace or self.slurm_config.local_workspace) / STUDIES_INPUT_DIR_NAME)),
log_dir=str((Path(self.slurm_config.local_workspace) / LOG_DIR_NAME)),
finished_dir=str((Path(local_workspace or self.slurm_config.local_workspace) / STUDIES_OUTPUT_DIR_NAME)),
@@ -440,7 +439,7 @@ def _run_study(
_override_solver_version(study_path, version)
append_log(launch_uuid, "Submitting study to slurm launcher")
- launcher_args = self._check_and_apply_launcher_params(launcher_params)
+ launcher_args = self._apply_params(launcher_params)
self._call_launcher(launcher_args, self.launcher_params)
launch_success = self._check_if_study_is_in_launcher_db(launch_uuid)
@@ -481,23 +480,40 @@ def _check_if_study_is_in_launcher_db(self, job_id: str) -> bool:
studies = self.data_repo_tinydb.get_list_of_studies()
return any(s.name == job_id for s in studies)
- def _check_and_apply_launcher_params(self, launcher_params: LauncherParametersDTO) -> argparse.Namespace:
+ def _apply_params(self, launcher_params: LauncherParametersDTO) -> argparse.Namespace:
+ """
+ Populate an `argparse.Namespace` object with the user parameters.
+
+ Args:
+ launcher_params:
+ Contains the launcher parameters selected by the user.
+ If a parameter is not provided (`None`), the default value should be retrieved
+ from the configuration.
+
+ Returns:
+ The `argparse.Namespace` object which is then passed to `antarestlauncher.main.run_with`,
+ to launch a simulation using Antares Launcher.
+ """
if launcher_params:
launcher_args = deepcopy(self.launcher_args)
- other_options = []
+
if launcher_params.other_options:
- options = re.split("\\s+", launcher_params.other_options)
- for opt in options:
- other_options.append(re.sub("[^a-zA-Z0-9_,-]", "", opt))
- if launcher_params.xpansion is not None:
- launcher_args.xpansion_mode = "r" if launcher_params.xpansion_r_version else "cpp"
+ options = launcher_params.other_options.split()
+ other_options = [re.sub("[^a-zA-Z0-9_,-]", "", opt) for opt in options]
+ else:
+ other_options = []
+
+ # launcher_params.xpansion can be an `XpansionParametersDTO`, a bool or `None`
+ if launcher_params.xpansion: # not None and not False
+ launcher_args.xpansion_mode = {True: "r", False: "cpp"}[launcher_params.xpansion_r_version]
if (
isinstance(launcher_params.xpansion, XpansionParametersDTO)
and launcher_params.xpansion.sensitivity_mode
):
other_options.append("xpansion_sensitivity")
+
time_limit = launcher_params.time_limit
- if time_limit and isinstance(time_limit, int):
+ if time_limit is not None:
if MIN_TIME_LIMIT > time_limit:
logger.warning(
f"Invalid slurm launcher time limit ({time_limit}),"
@@ -512,15 +528,23 @@ def _check_and_apply_launcher_params(self, launcher_params: LauncherParametersDT
launcher_args.time_limit = MAX_TIME_LIMIT - 3600
else:
launcher_args.time_limit = time_limit
+
post_processing = launcher_params.post_processing
- if isinstance(post_processing, bool):
+ if post_processing is not None:
launcher_args.post_processing = post_processing
+
nb_cpu = launcher_params.nb_cpu
- if nb_cpu and isinstance(nb_cpu, int):
- if 0 < nb_cpu <= MAX_NB_CPU:
+ if nb_cpu is not None:
+ nb_cores = self.slurm_config.nb_cores
+ if nb_cores.min <= nb_cpu <= nb_cores.max:
launcher_args.n_cpu = nb_cpu
else:
- logger.warning(f"Invalid slurm launcher nb_cpu ({nb_cpu}), should be between 1 and 24")
+ logger.warning(
+ f"Invalid slurm launcher nb_cpu ({nb_cpu}),"
+ f" should be between {nb_cores.min} and {nb_cores.max}"
+ )
+ launcher_args.n_cpu = nb_cores.default
+
if launcher_params.adequacy_patch is not None: # the adequacy patch can be an empty object
launcher_args.post_processing = True
diff --git a/antarest/launcher/model.py b/antarest/launcher/model.py
index 7a3c615811..a9bf0f6fde 100644
--- a/antarest/launcher/model.py
+++ b/antarest/launcher/model.py
@@ -17,14 +17,14 @@ class XpansionParametersDTO(BaseModel):
class LauncherParametersDTO(BaseModel):
- # Warning ! This class must be retrocompatible (that's the reason for the weird bool/XpansionParametersDTO union)
+ # Warning ! This class must be retro-compatible (that's the reason for the weird bool/XpansionParametersDTO union)
# The reason is that it's stored in json format in database and deserialized using the latest class version
# If compatibility is to be broken, an (alembic) data migration script should be added
adequacy_patch: Optional[Dict[str, Any]] = None
nb_cpu: Optional[int] = None
post_processing: bool = False
- time_limit: Optional[int] = None
- xpansion: Union[bool, Optional[XpansionParametersDTO]] = None
+ time_limit: Optional[int] = None # 3600 <= time_limit < 864000 (10 days)
+ xpansion: Union[XpansionParametersDTO, bool, None] = None
xpansion_r_version: bool = False
archive_output: bool = True
auto_unzip: bool = True
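
As a reminder of the retro-compatibility constraint noted above, here is a small sketch (made-up payload) showing that a legacy JSON record with a boolean `xpansion` still parses against the union type:

```python
from antarest.launcher.model import LauncherParametersDTO

# Legacy database payload: `xpansion` stored as a plain boolean.
legacy = LauncherParametersDTO.parse_raw('{"nb_cpu": 8, "xpansion": true, "time_limit": 7200}')
assert legacy.xpansion is True
assert legacy.nb_cpu == 8
assert 3600 <= legacy.time_limit < 864000  # within the documented bounds
```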
diff --git a/antarest/launcher/service.py b/antarest/launcher/service.py
index 340c00c27c..7312181dea 100644
--- a/antarest/launcher/service.py
+++ b/antarest/launcher/service.py
@@ -1,4 +1,5 @@
import functools
+import json
import logging
import os
import shutil
@@ -10,7 +11,7 @@
from fastapi import HTTPException
-from antarest.core.config import Config
+from antarest.core.config import Config, NbCoresConfig
from antarest.core.exceptions import StudyNotFoundError
from antarest.core.filetransfer.model import FileDownloadTaskDTO
from antarest.core.filetransfer.service import FileTransferManager
@@ -99,6 +100,21 @@ def _init_extensions(self) -> Dict[str, ILauncherExtension]:
def get_launchers(self) -> List[str]:
return list(self.launchers.keys())
+ def get_nb_cores(self, launcher: str) -> NbCoresConfig:
+ """
+ Retrieve the launcher's number-of-cores configuration.
+
+ Args:
+ launcher: name of the launcher: "default", "slurm" or "local".
+
+ Returns:
+ Number of cores of the launcher
+
+ Raises:
+ InvalidConfigurationError: if the launcher configuration is not available
+ """
+ return self.config.launcher.get_nb_cores(launcher)
+
def _after_export_flat_hooks(
self,
job_id: str,
@@ -160,7 +176,7 @@ def update(
channel=EventChannelDirectory.JOB_STATUS + job_result.id,
)
)
- logger.info(f"Study status set")
+ logger.info("Study status set")
def append_log(self, job_id: str, message: str, log_type: JobLogType) -> None:
try:
@@ -586,27 +602,31 @@ def get_load(self, from_cluster: bool = False) -> Dict[str, float]:
local_running_jobs.append(job)
else:
logger.warning(f"Unknown job launcher {job.launcher}")
+
load = {}
- if self.config.launcher.slurm:
+
+ slurm_config = self.config.launcher.slurm
+ if slurm_config is not None:
if from_cluster:
- raise NotImplementedError
- slurm_used_cpus = functools.reduce(
- lambda count, j: count
- + (
- LauncherParametersDTO.parse_raw(j.launcher_params or "{}").nb_cpu
- or self.config.launcher.slurm.default_n_cpu # type: ignore
- ),
- slurm_running_jobs,
- 0,
- )
- load["slurm"] = float(slurm_used_cpus) / self.config.launcher.slurm.max_cores
- if self.config.launcher.local:
- local_used_cpus = functools.reduce(
- lambda count, j: count + (LauncherParametersDTO.parse_raw(j.launcher_params or "{}").nb_cpu or 1),
- local_running_jobs,
- 0,
- )
- load["local"] = float(local_used_cpus) / (os.cpu_count() or 1)
+ raise NotImplementedError("Cluster load not implemented yet")
+ default_cpu = slurm_config.nb_cores.default
+ slurm_used_cpus = 0
+ for job in slurm_running_jobs:
+ obj = json.loads(job.launcher_params) if job.launcher_params else {}
+ launch_params = LauncherParametersDTO(**obj)
+ slurm_used_cpus += launch_params.nb_cpu or default_cpu
+ load["slurm"] = slurm_used_cpus / slurm_config.max_cores
+
+ local_config = self.config.launcher.local
+ if local_config is not None:
+ default_cpu = local_config.nb_cores.default
+ local_used_cpus = 0
+ for job in local_running_jobs:
+ obj = json.loads(job.launcher_params) if job.launcher_params else {}
+ launch_params = LauncherParametersDTO(**obj)
+ local_used_cpus += launch_params.nb_cpu or default_cpu
+ load["local"] = local_used_cpus / local_config.nb_cores.max
+
return load
def get_solver_versions(self, solver: str) -> List[str]:
diff --git a/antarest/launcher/web.py b/antarest/launcher/web.py
index ffb4cf6ccf..51b3582997 100644
--- a/antarest/launcher/web.py
+++ b/antarest/launcher/web.py
@@ -6,7 +6,7 @@
from fastapi import APIRouter, Depends, Query
from fastapi.exceptions import HTTPException
-from antarest.core.config import Config
+from antarest.core.config import Config, InvalidConfigurationError
from antarest.core.filetransfer.model import FileDownloadTaskDTO
from antarest.core.jwt import JWTUser
from antarest.core.requests import RequestParameters
@@ -230,4 +230,48 @@ def get_solver_versions(
raise UnknownSolverConfig(solver)
return service.get_solver_versions(solver)
+ # noinspection SpellCheckingInspection
+ @bp.get(
+ "/launcher/nbcores", # We avoid "nb_cores" and "nb-cores" in endpoints
+ tags=[APITag.launcher],
+ summary="Retrieving Min, Default, and Max Core Count",
+ response_model=Dict[str, int],
+ )
+ def get_nb_cores(
+ launcher: str = Query(
+ "default",
+ examples={
+ "Default launcher": {
+ "description": "Min, Default, and Max Core Count",
+ "value": "default",
+ },
+ "SLURM launcher": {
+ "description": "Min, Default, and Max Core Count",
+ "value": "slurm",
+ },
+ "Local launcher": {
+ "description": "Min, Default, and Max Core Count",
+ "value": "local",
+ },
+ },
+ )
+ ) -> Dict[str, int]:
+ """
+ Retrieve the number of cores of the launcher.
+
+ Args:
+ - `launcher`: name of the configuration to read: "slurm" or "local".
+ If "default" is specified, retrieve the configuration of the default launcher.
+
+ Returns:
+ - "min": min number of cores
+ - "defaultValue": default number of cores
+ - "max": max number of cores
+ """
+ logger.info(f"Fetching the number of cores for the '{launcher}' configuration")
+ try:
+ return service.config.launcher.get_nb_cores(launcher).to_json()
+ except InvalidConfigurationError:
+ raise UnknownSolverConfig(launcher)
+
return bp
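
A possible client-side call for the new endpoint; the base URL and API prefix below are hypothetical and depend on where the router is mounted, while the response shape follows `NbCoresConfig.to_json()`:

```python
import requests

# Hypothetical deployment URL and API prefix; adjust to where the router is actually mounted.
BASE_URL = "http://localhost:8080/v1"

resp = requests.get(f"{BASE_URL}/launcher/nbcores", params={"launcher": "slurm"})
resp.raise_for_status()
# Payload shape follows NbCoresConfig.to_json(), e.g. {"min": 1, "defaultValue": 22, "max": 24}
print(resp.json())
```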
diff --git a/antarest/login/web.py b/antarest/login/web.py
index 86561457ac..ab63ec16b2 100644
--- a/antarest/login/web.py
+++ b/antarest/login/web.py
@@ -110,7 +110,7 @@ def users_get_all(
details: Optional[bool] = False,
current_user: JWTUser = Depends(auth.get_current_user),
) -> Any:
- logger.info(f"Fetching users list", extra={"user": current_user.id})
+ logger.info("Fetching users list", extra={"user": current_user.id})
params = RequestParameters(user=current_user)
return service.get_all_users(params, details)
@@ -188,7 +188,7 @@ def groups_get_all(
details: Optional[bool] = False,
current_user: JWTUser = Depends(auth.get_current_user),
) -> Any:
- logger.info(f"Fetching groups list", extra={"user": current_user.id})
+ logger.info("Fetching groups list", extra={"user": current_user.id})
params = RequestParameters(user=current_user)
return service.get_all_groups(params, details)
diff --git a/antarest/matrixstore/web.py b/antarest/matrixstore/web.py
index 4b47135b52..523176b241 100644
--- a/antarest/matrixstore/web.py
+++ b/antarest/matrixstore/web.py
@@ -37,7 +37,7 @@ def create(
matrix: List[List[MatrixData]] = Body(description="matrix dto", default=[]),
current_user: JWTUser = Depends(auth.get_current_user),
) -> Any:
- logger.info(f"Creating new matrix", extra={"user": current_user.id})
+ logger.info("Creating new matrix", extra={"user": current_user.id})
if current_user.id is not None:
return service.create(matrix)
raise UserHasNotPermissionError()
@@ -60,7 +60,7 @@ def create_by_importation(
@bp.get("/matrix/{id}", tags=[APITag.matrix], response_model=MatrixDTO)
def get(id: str, user: JWTUser = Depends(auth.get_current_user)) -> Any:
- logger.info(f"Fetching matrix", extra={"user": user.id})
+ logger.info("Fetching matrix", extra={"user": user.id})
if user.id is not None:
return service.get(id)
raise UserHasNotPermissionError()
diff --git a/antarest/study/business/binding_constraint_management.py b/antarest/study/business/binding_constraint_management.py
index 0fedb98393..ca1f714750 100644
--- a/antarest/study/business/binding_constraint_management.py
+++ b/antarest/study/business/binding_constraint_management.py
@@ -12,8 +12,9 @@
from antarest.matrixstore.model import MatrixData
from antarest.study.business.utils import execute_or_add_commands
from antarest.study.model import Study
+from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import BindingConstraintFrequency
from antarest.study.storage.storage_service import StudyStorageService
-from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator, TimeStep
+from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator
from antarest.study.storage.variantstudy.model.command.update_binding_constraint import UpdateBindingConstraint
@@ -43,7 +44,7 @@ class BindingConstraintDTO(BaseModel):
id: str
name: str
enabled: bool = True
- time_step: TimeStep
+ time_step: BindingConstraintFrequency
operator: BindingConstraintOperator
values: Optional[Union[List[List[MatrixData]], str]] = None
comments: Optional[str] = None
diff --git a/antarest/study/business/st_storage_manager.py b/antarest/study/business/st_storage_manager.py
index d1c040741b..f16ff680d4 100644
--- a/antarest/study/business/st_storage_manager.py
+++ b/antarest/study/business/st_storage_manager.py
@@ -1,10 +1,10 @@
import functools
import json
import operator
-from typing import Any, Dict, List, Mapping, MutableMapping, Sequence
+from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Sequence
import numpy as np
-from pydantic import BaseModel, Extra, Field, root_validator, validator
+from pydantic import BaseModel, Extra, root_validator, validator
from typing_extensions import Literal
from antarest.core.exceptions import (
@@ -12,9 +12,13 @@
STStorageFieldsNotFoundError,
STStorageMatrixNotFoundError,
)
-from antarest.study.business.utils import AllOptionalMetaclass, FormFieldsBaseModel, execute_or_add_commands
+from antarest.study.business.utils import AllOptionalMetaclass, camel_case_model, execute_or_add_commands
from antarest.study.model import Study
-from antarest.study.storage.rawstudy.model.filesystem.config.st_storage import STStorageConfig, STStorageGroup
+from antarest.study.storage.rawstudy.model.filesystem.config.st_storage import (
+ STStorageConfig,
+ STStorageGroup,
+ STStorageProperties,
+)
from antarest.study.storage.storage_service import StudyStorageService
from antarest.study.storage.variantstudy.model.command.create_st_storage import CreateSTStorage
from antarest.study.storage.variantstudy.model.command.remove_st_storage import RemoveSTStorage
@@ -23,77 +27,12 @@
_HOURS_IN_YEAR = 8760
-class FormBaseModel(FormFieldsBaseModel):
- """
- A foundational model for all form-based models, providing common configurations.
- """
-
- class Config:
- validate_assignment = True
- allow_population_by_field_name = True
-
-
-class StorageCreation(FormBaseModel):
+@camel_case_model
+class StorageInput(STStorageProperties, metaclass=AllOptionalMetaclass):
"""
- Model representing the form used to create a new short-term storage entry.
+ Model representing the form used to EDIT an existing short-term storage.
"""
- name: str = Field(
- description="Name of the storage.",
- regex=r"[a-zA-Z0-9_(),& -]+",
- )
- group: STStorageGroup = Field(
- description="Energy storage system group.",
- )
-
- class Config:
- @staticmethod
- def schema_extra(schema: MutableMapping[str, Any]) -> None:
- schema["example"] = StorageCreation(
- name="Siemens Battery",
- group=STStorageGroup.BATTERY,
- )
-
- @property
- def to_config(self) -> STStorageConfig:
- values = self.dict(by_alias=False)
- return STStorageConfig(**values)
-
-
-class StorageUpdate(StorageCreation, metaclass=AllOptionalMetaclass):
- """set name, group as optional fields"""
-
-
-class StorageInput(StorageUpdate):
- """
- Model representing the form used to edit existing short-term storage details.
- """
-
- injection_nominal_capacity: float = Field(
- description="Injection nominal capacity (MW)",
- ge=0,
- )
- withdrawal_nominal_capacity: float = Field(
- description="Withdrawal nominal capacity (MW)",
- ge=0,
- )
- reservoir_capacity: float = Field(
- description="Reservoir capacity (MWh)",
- ge=0,
- )
- efficiency: float = Field(
- description="Efficiency of the storage system",
- ge=0,
- le=1,
- )
- initial_level: float = Field(
- description="Initial level of the storage system",
- ge=0,
- )
- initial_level_optim: bool = Field(
- description="Flag indicating if the initial level is optimized",
- )
-
class Config:
@staticmethod
def schema_extra(schema: MutableMapping[str, Any]) -> None:
@@ -104,19 +43,37 @@ def schema_extra(schema: MutableMapping[str, Any]) -> None:
withdrawal_nominal_capacity=150,
reservoir_capacity=600,
efficiency=0.94,
+ initial_level=0.5,
initial_level_optim=True,
)
-class StorageOutput(StorageInput):
+class StorageCreation(StorageInput):
"""
- Model representing the form used to display the details of a short-term storage entry.
+ Model representing the form used to CREATE a new short-term storage.
"""
- id: str = Field(
- description="Short-term storage ID",
- regex=r"[a-zA-Z0-9_(),& -]+",
- )
+ # noinspection Pydantic
+ @validator("name", pre=True)
+ def validate_name(cls, name: Optional[str]) -> str:
+ """
+ Validator to check if the name is not empty.
+ """
+ if not name:
+ raise ValueError("'name' must not be empty")
+ return name
+
+ @property
+ def to_config(self) -> STStorageConfig:
+ values = self.dict(by_alias=False, exclude_none=True)
+ return STStorageConfig(**values)
+
+
+@camel_case_model
+class StorageOutput(STStorageConfig):
+ """
+ Model representing the form used to display the details of a short-term storage entry.
+ """
class Config:
@staticmethod
diff --git a/antarest/study/business/utils.py b/antarest/study/business/utils.py
index b602c60138..33b62d766c 100644
--- a/antarest/study/business/utils.py
+++ b/antarest/study/business/utils.py
@@ -117,3 +117,19 @@ def __new__(
annotations[field] = Optional[annotations[field]]
namespaces["__annotations__"] = annotations
return super().__new__(cls, name, bases, namespaces)
+
+
+def camel_case_model(model: Type[BaseModel]) -> Type[BaseModel]:
+ """
+ This decorator can be used to modify a model to use camel case aliases.
+
+ Args:
+ model: The pydantic model to modify.
+
+ Returns:
+ The modified model.
+ """
+ model.__config__.alias_generator = to_camel_case
+ for field_name, field in model.__fields__.items():
+ field.alias = to_camel_case(field_name)
+ return model
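
A quick sketch of the new `camel_case_model` decorator applied to a throwaway pydantic model (`StorageForm` is hypothetical, used only for illustration):

```python
from pydantic import BaseModel

from antarest.study.business.utils import camel_case_model

@camel_case_model
class StorageForm(BaseModel):  # hypothetical model, for illustration only
    injection_nominal_capacity: float = 0.0
    initial_level_optim: bool = False

# Input and output now use camelCase aliases.
form = StorageForm(injectionNominalCapacity=150.0)
assert form.dict(by_alias=True) == {"injectionNominalCapacity": 150.0, "initialLevelOptim": False}
```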
diff --git a/antarest/study/business/xpansion_management.py b/antarest/study/business/xpansion_management.py
index 4dec6e3ef6..e8c6c79f98 100644
--- a/antarest/study/business/xpansion_management.py
+++ b/antarest/study/business/xpansion_management.py
@@ -436,7 +436,7 @@ def _assert_candidate_is_correct(
xpansion_candidate_dto: XpansionCandidateDTO,
new_name: bool = False,
) -> None:
- logger.info(f"Checking given candidate is correct")
+ logger.info("Checking given candidate is correct")
self._assert_no_illegal_character_is_in_candidate_name(xpansion_candidate_dto.name)
if new_name:
self._assert_candidate_name_is_not_already_taken(candidates, xpansion_candidate_dto.name)
diff --git a/antarest/study/service.py b/antarest/study/service.py
index b9ec491c18..64d5c7d83c 100644
--- a/antarest/study/service.py
+++ b/antarest/study/service.py
@@ -641,15 +641,20 @@ def create_study(
def get_user_name(self, params: RequestParameters) -> str:
"""
- Args: params : Request parameters
+ Retrieves the name of a user based on the provided request parameters.
- Returns: The user's name
+ Args:
+ params: The request parameters which includes user information.
+
+ Returns:
+ Returns the user's name or, if the logged user is a "bot"
+ (i.e., an application's token), it returns the token's author name.
"""
- author = "Unknown"
if params.user:
- if curr_user := self.user_service.get_user(params.user.id, params):
- author = curr_user.to_dto().name
- return author
+ user_id = params.user.impersonator if params.user.type == "bots" else params.user.id
+ if curr_user := self.user_service.get_user(user_id, params):
+ return curr_user.to_dto().name
+ return "Unknown"
def get_study_synthesis(self, study_id: str, params: RequestParameters) -> FileStudyTreeConfigDTO:
"""
diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/binding_constraint.py b/antarest/study/storage/rawstudy/model/filesystem/config/binding_constraint.py
new file mode 100644
index 0000000000..3648051773
--- /dev/null
+++ b/antarest/study/storage/rawstudy/model/filesystem/config/binding_constraint.py
@@ -0,0 +1,17 @@
+from enum import Enum
+from typing import Set
+
+from pydantic import BaseModel
+
+
+class BindingConstraintFrequency(str, Enum):
+ HOURLY = "hourly"
+ DAILY = "daily"
+ WEEKLY = "weekly"
+
+
+class BindingConstraintDTO(BaseModel):
+ id: str
+ areas: Set[str]
+ clusters: Set[str]
+ time_step: BindingConstraintFrequency
diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/files.py b/antarest/study/storage/rawstudy/model/filesystem/config/files.py
index 9f8111911a..8e174f3ea0 100644
--- a/antarest/study/storage/rawstudy/model/filesystem/config/files.py
+++ b/antarest/study/storage/rawstudy/model/filesystem/config/files.py
@@ -10,13 +10,16 @@
from antarest.core.model import JSON
from antarest.core.utils.utils import extract_file_to_tmp_dir
from antarest.study.storage.rawstudy.io.reader import IniReader, MultipleSameKeysIniReader
+from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import (
+ BindingConstraintDTO,
+ BindingConstraintFrequency,
+)
from antarest.study.storage.rawstudy.model.filesystem.config.exceptions import (
SimulationParsingError,
XpansionParsingError,
)
from antarest.study.storage.rawstudy.model.filesystem.config.model import (
Area,
- BindingConstraintDTO,
Cluster,
DistrictSet,
FileStudyTreeConfig,
@@ -143,8 +146,12 @@ def _parse_bindings(root: Path) -> List[BindingConstraintDTO]:
area_set = set()
# contains a set of strings in the following format: "area.cluster"
cluster_set = set()
+ # Default value for time_step
+ time_step = BindingConstraintFrequency.HOURLY
for key in bind:
- if "%" in key:
+ if key == "type":
+ time_step = BindingConstraintFrequency(bind[key])
+ elif "%" in key:
areas = key.split("%", 1)
area_set.add(areas[0])
area_set.add(areas[1])
@@ -152,7 +159,9 @@ def _parse_bindings(root: Path) -> List[BindingConstraintDTO]:
cluster_set.add(key)
area_set.add(key.split(".", 1)[0])
- output_list.append(BindingConstraintDTO(id=bind["id"], areas=area_set, clusters=cluster_set))
+ output_list.append(
+ BindingConstraintDTO(id=bind["id"], areas=area_set, clusters=cluster_set, time_step=time_step)
+ )
return output_list
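
A standalone sketch mirroring the parsing logic above: how one binding-constraint INI section would map onto the new `BindingConstraintDTO`. The section dict is illustrative, not taken from a real study:

```python
from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import (
    BindingConstraintDTO,
    BindingConstraintFrequency,
)

# Illustrative binding-constraint section, as read from bindingconstraints.ini.
section = {"id": "bc_1", "type": "daily", "fr%de": 0.5, "fr.gas_cluster": 1.0}

areas, clusters = set(), set()
time_step = BindingConstraintFrequency.HOURLY  # default when "type" is absent
for key, value in section.items():
    if key == "type":
        time_step = BindingConstraintFrequency(value)
    elif "%" in key:
        area_1, area_2 = key.split("%", 1)
        areas.update({area_1, area_2})
    elif "." in key:
        clusters.add(key)
        areas.add(key.split(".", 1)[0])

dto = BindingConstraintDTO(id=section["id"], areas=areas, clusters=clusters, time_step=time_step)
assert dto.time_step == BindingConstraintFrequency.DAILY
```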
diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/model.py b/antarest/study/storage/rawstudy/model/filesystem/config/model.py
index f688765922..774d225344 100644
--- a/antarest/study/storage/rawstudy/model/filesystem/config/model.py
+++ b/antarest/study/storage/rawstudy/model/filesystem/config/model.py
@@ -1,7 +1,7 @@
import re
from enum import Enum
from pathlib import Path
-from typing import Dict, List, Optional, Set
+from typing import Dict, List, Optional
from pydantic import Extra
from pydantic.main import BaseModel
@@ -9,6 +9,7 @@
from antarest.core.model import JSON
from antarest.core.utils.utils import DTO
+from .binding_constraint import BindingConstraintDTO
from .st_storage import STStorageConfig
@@ -106,12 +107,6 @@ def get_file(self) -> str:
return f"{self.date}{modes[self.mode]}{dash}{self.name}"
-class BindingConstraintDTO(BaseModel):
- id: str
- areas: Set[str]
- clusters: Set[str]
-
-
class FileStudyTreeConfig(DTO):
"""
Root object to handle all study parameters which impact tree structure
diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/st_storage.py b/antarest/study/storage/rawstudy/model/filesystem/config/st_storage.py
index d2a68c4799..b82910a191 100644
--- a/antarest/study/storage/rawstudy/model/filesystem/config/st_storage.py
+++ b/antarest/study/storage/rawstudy/model/filesystem/config/st_storage.py
@@ -30,22 +30,18 @@ class STStorageGroup(EnumIgnoreCase):
# noinspection SpellCheckingInspection
-class STStorageConfig(BaseModel):
- """
- Manage the configuration files in the context of Short-Term Storage.
- It provides a convenient way to read and write configuration data from/to an INI file format.
+class STStorageProperties(
+ BaseModel,
+ extra=Extra.forbid,
+ validate_assignment=True,
+ allow_population_by_field_name=True,
+):
"""
+ Properties of a short-term storage system read from the configuration files.
- class Config:
- extra = Extra.forbid
- allow_population_by_field_name = True
+ All aliases match the name of the corresponding field in the INI files.
+ """
- # The `id` field is a calculated from the `name` if not provided.
- # This value must be stored in the config cache.
- id: str = Field(
- description="Short-term storage ID",
- regex=r"[a-zA-Z0-9_(),& -]+",
- )
name: str = Field(
description="Short-term storage name",
regex=r"[a-zA-Z0-9_(),& -]+",
@@ -90,6 +86,21 @@ class Config:
alias="initialleveloptim",
)
+
+# noinspection SpellCheckingInspection
+class STStorageConfig(STStorageProperties):
+ """
+ Manage the configuration files in the context of Short-Term Storage.
+ It provides a convenient way to read and write configuration data from/to an INI file format.
+ """
+
+ # The `id` field is calculated from the `name` if not provided.
+ # This value must be stored in the config cache.
+ id: str = Field(
+ description="Short-term storage ID",
+ regex=r"[a-zA-Z0-9_(),& -]+",
+ )
+
@root_validator(pre=True)
def calculate_storage_id(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""
diff --git a/antarest/study/storage/rawstudy/model/filesystem/root/input/bindingconstraints/bindingcontraints.py b/antarest/study/storage/rawstudy/model/filesystem/root/input/bindingconstraints/bindingcontraints.py
index 92d5443957..e86dedfe18 100644
--- a/antarest/study/storage/rawstudy/model/filesystem/root/input/bindingconstraints/bindingcontraints.py
+++ b/antarest/study/storage/rawstudy/model/filesystem/root/input/bindingconstraints/bindingcontraints.py
@@ -1,16 +1,33 @@
+from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import BindingConstraintFrequency
from antarest.study.storage.rawstudy.model.filesystem.folder_node import FolderNode
from antarest.study.storage.rawstudy.model.filesystem.inode import TREE
from antarest.study.storage.rawstudy.model.filesystem.matrix.input_series_matrix import InputSeriesMatrix
+from antarest.study.storage.rawstudy.model.filesystem.matrix.matrix import MatrixFrequency
from antarest.study.storage.rawstudy.model.filesystem.root.input.bindingconstraints.bindingconstraints_ini import (
BindingConstraintsIni,
)
+from antarest.study.storage.variantstudy.business.matrix_constants.binding_constraint.series import (
+ default_binding_constraint_daily,
+ default_binding_constraint_hourly,
+ default_binding_constraint_weekly,
+)
class BindingConstraints(FolderNode):
def build(self) -> TREE:
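+ # Map each binding constraint frequency to its zero-filled default "2nd member" matrix.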
+ default_matrices = {
+ BindingConstraintFrequency.HOURLY: default_binding_constraint_hourly,
+ BindingConstraintFrequency.DAILY: default_binding_constraint_daily,
+ BindingConstraintFrequency.WEEKLY: default_binding_constraint_weekly,
+ }
children: TREE = {
- binding.id: InputSeriesMatrix(self.context, self.config.next_file(f"{binding.id}.txt"))
- # todo get the freq of binding to set the default empty matrix
+ binding.id: InputSeriesMatrix(
+ self.context,
+ self.config.next_file(f"{binding.id}.txt"),
+ freq=MatrixFrequency(binding.time_step),
+ nb_columns=3,
+ default_empty=default_matrices[binding.time_step],
+ )
for binding in self.config.bindings
}
diff --git a/antarest/study/storage/variantstudy/business/command_extractor.py b/antarest/study/storage/variantstudy/business/command_extractor.py
index 0602c8b62a..ab1f8b6d42 100644
--- a/antarest/study/storage/variantstudy/business/command_extractor.py
+++ b/antarest/study/storage/variantstudy/business/command_extractor.py
@@ -9,12 +9,13 @@
from antarest.matrixstore.model import MatrixData
from antarest.matrixstore.service import ISimpleMatrixService
from antarest.study.storage.patch_service import PatchService
+from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import BindingConstraintFrequency
from antarest.study.storage.rawstudy.model.filesystem.config.files import get_playlist
from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy
from antarest.study.storage.rawstudy.model.filesystem.root.filestudytree import FileStudyTree
from antarest.study.storage.variantstudy.business.matrix_constants_generator import GeneratorMatrixConstants
from antarest.study.storage.variantstudy.business.utils import strip_matrix_protocol
-from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator, TimeStep
+from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator
from antarest.study.storage.variantstudy.model.command.create_area import CreateArea
from antarest.study.storage.variantstudy.model.command.create_binding_constraint import CreateBindingConstraint
from antarest.study.storage.variantstudy.model.command.create_cluster import CreateCluster
@@ -348,7 +349,7 @@ def extract_binding_constraint(
binding_constraint_command = CreateBindingConstraint(
name=binding["name"],
enabled=binding["enabled"],
- time_step=TimeStep(binding["type"]),
+ time_step=BindingConstraintFrequency(binding["type"]),
operator=BindingConstraintOperator(binding["operator"]),
coeffs={
coeff: [float(el) for el in str(value).split("%")]
diff --git a/antarest/study/storage/variantstudy/business/matrix_constants/__init__.py b/antarest/study/storage/variantstudy/business/matrix_constants/__init__.py
index 0f9b1e77ca..9212a8e6c6 100644
--- a/antarest/study/storage/variantstudy/business/matrix_constants/__init__.py
+++ b/antarest/study/storage/variantstudy/business/matrix_constants/__init__.py
@@ -1 +1 @@
-from . import hydro, link, prepro, st_storage, thermals
+from . import binding_constraint, hydro, link, prepro, st_storage, thermals
diff --git a/antarest/study/storage/variantstudy/business/matrix_constants/binding_constraint/__init__.py b/antarest/study/storage/variantstudy/business/matrix_constants/binding_constraint/__init__.py
new file mode 100644
index 0000000000..0a1b9046e5
--- /dev/null
+++ b/antarest/study/storage/variantstudy/business/matrix_constants/binding_constraint/__init__.py
@@ -0,0 +1 @@
+from . import series
diff --git a/antarest/study/storage/variantstudy/business/matrix_constants/binding_constraint/series.py b/antarest/study/storage/variantstudy/business/matrix_constants/binding_constraint/series.py
new file mode 100644
index 0000000000..e7b20a1137
--- /dev/null
+++ b/antarest/study/storage/variantstudy/business/matrix_constants/binding_constraint/series.py
@@ -0,0 +1,10 @@
+import numpy as np
+
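+# Default "2nd member" matrices for binding constraints, one per frequency.
+# They are zero-filled and marked read-only so that these shared module-level
+# constants cannot be mutated accidentally.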
+default_binding_constraint_hourly = np.zeros((8760, 3), dtype=np.float64)
+default_binding_constraint_hourly.flags.writeable = False
+
+default_binding_constraint_daily = np.zeros((365, 3), dtype=np.float64)
+default_binding_constraint_daily.flags.writeable = False
+
+default_binding_constraint_weekly = np.zeros((52, 3), dtype=np.float64)
+default_binding_constraint_weekly.flags.writeable = False
diff --git a/antarest/study/storage/variantstudy/business/matrix_constants_generator.py b/antarest/study/storage/variantstudy/business/matrix_constants_generator.py
index 338506dd7c..8cb973785e 100644
--- a/antarest/study/storage/variantstudy/business/matrix_constants_generator.py
+++ b/antarest/study/storage/variantstudy/business/matrix_constants_generator.py
@@ -34,6 +34,10 @@
EMPTY_SCENARIO_MATRIX = "empty_scenario_matrix"
ONES_SCENARIO_MATRIX = "ones_scenario_matrix"
+# Binding constraint aliases
+BINDING_CONSTRAINT_HOURLY = "empty_2nd_member_hourly"
+BINDING_CONSTRAINT_DAILY = "empty_2nd_member_daily"
+BINDING_CONSTRAINT_WEEKLY = "empty_2nd_member_weekly"
# Short-term storage aliases
ST_STORAGE_PMAX_INJECTION = ONES_SCENARIO_MATRIX
@@ -84,6 +88,12 @@ def _init(self) -> None:
self.hashes[RESERVES_TS] = self.matrix_service.create(FIXED_4_COLUMNS)
self.hashes[MISCGEN_TS] = self.matrix_service.create(FIXED_8_COLUMNS)
+ # Binding constraint matrices
+ series = matrix_constants.binding_constraint.series
+ self.hashes[BINDING_CONSTRAINT_HOURLY] = self.matrix_service.create(series.default_binding_constraint_hourly)
+ self.hashes[BINDING_CONSTRAINT_DAILY] = self.matrix_service.create(series.default_binding_constraint_daily)
+ self.hashes[BINDING_CONSTRAINT_WEEKLY] = self.matrix_service.create(series.default_binding_constraint_weekly)
+
# Some short-term storage matrices use np.ones((8760, 1))
self.hashes[ONES_SCENARIO_MATRIX] = self.matrix_service.create(
matrix_constants.st_storage.series.pmax_injection
@@ -141,6 +151,18 @@ def get_default_reserves(self) -> str:
def get_default_miscgen(self) -> str:
return MATRIX_PROTOCOL_PREFIX + self.hashes[MISCGEN_TS]
+ def get_binding_constraint_hourly(self) -> str:
+ """2D-matrix of shape (8760, 3), filled-in with zeros."""
+ return MATRIX_PROTOCOL_PREFIX + self.hashes[BINDING_CONSTRAINT_HOURLY]
+
+ def get_binding_constraint_daily(self) -> str:
+ """2D-matrix of shape (365, 3), filled-in with zeros."""
+ return MATRIX_PROTOCOL_PREFIX + self.hashes[BINDING_CONSTRAINT_DAILY]
+
+ def get_binding_constraint_weekly(self) -> str:
+ """2D-matrix of shape (52, 3), filled-in with zeros."""
+ return MATRIX_PROTOCOL_PREFIX + self.hashes[BINDING_CONSTRAINT_WEEKLY]
+
def get_st_storage_pmax_injection(self) -> str:
"""2D-matrix of shape (8760, 1), filled-in with ones."""
return MATRIX_PROTOCOL_PREFIX + self.hashes[ST_STORAGE_PMAX_INJECTION]
diff --git a/antarest/study/storage/variantstudy/business/utils_binding_constraint.py b/antarest/study/storage/variantstudy/business/utils_binding_constraint.py
index 13e3617ec4..b2eded8a6a 100644
--- a/antarest/study/storage/variantstudy/business/utils_binding_constraint.py
+++ b/antarest/study/storage/variantstudy/business/utils_binding_constraint.py
@@ -1,10 +1,14 @@
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Literal, Mapping, Optional, Sequence, Union
from antarest.core.model import JSON
from antarest.matrixstore.model import MatrixData
-from antarest.study.storage.rawstudy.model.filesystem.config.model import BindingConstraintDTO, FileStudyTreeConfig
+from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import (
+ BindingConstraintDTO,
+ BindingConstraintFrequency,
+)
+from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfig
from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy
-from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator, CommandOutput, TimeStep
+from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator, CommandOutput
def cluster_does_not_exist(study_data: FileStudy, area: str, thermal_id: str) -> bool:
@@ -25,27 +29,27 @@ def apply_binding_constraint(
name: str,
comments: Optional[str],
enabled: bool,
- time_step: TimeStep,
+ freq: BindingConstraintFrequency,
operator: BindingConstraintOperator,
coeffs: Dict[str, List[float]],
values: Optional[Union[List[List[MatrixData]], str]],
filter_year_by_year: Optional[str] = None,
filter_synthesis: Optional[str] = None,
) -> CommandOutput:
- binding_constraints[str(new_key)] = {
+ binding_constraints[new_key] = {
"name": name,
"id": bd_id,
"enabled": enabled,
- "type": time_step.value,
+ "type": freq.value,
"operator": operator.value,
}
if study_data.config.version >= 830:
if filter_year_by_year:
- binding_constraints[str(new_key)]["filter-year-by-year"] = filter_year_by_year
+ binding_constraints[new_key]["filter-year-by-year"] = filter_year_by_year
if filter_synthesis:
- binding_constraints[str(new_key)]["filter-synthesis"] = filter_synthesis
+ binding_constraints[new_key]["filter-synthesis"] = filter_synthesis
if comments is not None:
- binding_constraints[str(new_key)]["comments"] = comments
+ binding_constraints[new_key]["comments"] = comments
for link_or_thermal in coeffs:
if "%" in link_or_thermal:
@@ -67,7 +71,7 @@ def apply_binding_constraint(
if len(coeffs[link_or_thermal]) == 2:
coeffs[link_or_thermal][1] = int(coeffs[link_or_thermal][1])
- binding_constraints[str(new_key)][link_or_thermal] = "%".join(
+ binding_constraints[new_key][link_or_thermal] = "%".join(
[str(coeff_val) for coeff_val in coeffs[link_or_thermal]]
)
parse_bindings_coeffs_and_save_into_config(bd_id, study_data.config, coeffs)
@@ -76,7 +80,8 @@ def apply_binding_constraint(
["input", "bindingconstraints", "bindingconstraints"],
)
if values:
- assert isinstance(values, str)
+ if not isinstance(values, str): # pragma: no cover
+ raise TypeError(repr(values))
study_data.tree.save(values, ["input", "bindingconstraints", bd_id])
return CommandOutput(status=True)
@@ -84,19 +89,24 @@ def apply_binding_constraint(
def parse_bindings_coeffs_and_save_into_config(
bd_id: str,
study_data_config: FileStudyTreeConfig,
- coeffs: Dict[str, List[float]],
+ coeffs: Mapping[str, Union[Literal["hourly", "daily", "weekly"], Sequence[float]]],
) -> None:
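+ # Note: besides link coefficients keyed as "area1%area2" and cluster coefficients
+ # keyed as "area.cluster", the `coeffs` mapping may carry a "type" entry
+ # ("hourly", "daily" or "weekly") giving the constraint frequency, hence the
+ # `Literal` values in the signature.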
if bd_id not in [bind.id for bind in study_data_config.bindings]:
areas_set = set()
clusters_set = set()
+ # Default time_step value
+ time_step = BindingConstraintFrequency.HOURLY
for k, v in coeffs.items():
+ if k == "type":
+ time_step = BindingConstraintFrequency(v)
if "%" in k:
- areas_set.add(k.split("%")[0])
- areas_set.add(k.split("%")[1])
+ areas_set |= set(k.split("%"))
elif "." in k:
clusters_set.add(k)
areas_set.add(k.split(".")[0])
- study_data_config.bindings.append(BindingConstraintDTO(id=bd_id, areas=areas_set, clusters=clusters_set))
+ study_data_config.bindings.append(
+ BindingConstraintDTO(id=bd_id, areas=areas_set, clusters=clusters_set, time_step=time_step)
+ )
def remove_area_cluster_from_binding_constraints(
diff --git a/antarest/study/storage/variantstudy/model/command/common.py b/antarest/study/storage/variantstudy/model/command/common.py
index 34c41402d6..a6ac905fd9 100644
--- a/antarest/study/storage/variantstudy/model/command/common.py
+++ b/antarest/study/storage/variantstudy/model/command/common.py
@@ -8,12 +8,6 @@ class CommandOutput:
message: str = ""
-class TimeStep(Enum):
- HOURLY = "hourly"
- DAILY = "daily"
- WEEKLY = "weekly"
-
-
class BindingConstraintOperator(Enum):
BOTH = "both"
EQUAL = "equal"
diff --git a/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py b/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py
index 02e888f9ac..ed53e5e31c 100644
--- a/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py
+++ b/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py
@@ -1,11 +1,14 @@
+from abc import ABCMeta
from typing import Any, Dict, List, Optional, Tuple, Union, cast
-from pydantic import validator
+import numpy as np
+from pydantic import Field, validator
-from antarest.core.utils.utils import assert_this
from antarest.matrixstore.model import MatrixData
+from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import BindingConstraintFrequency
from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfig, transform_name_to_id
from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy
+from antarest.study.storage.variantstudy.business.matrix_constants_generator import GeneratorMatrixConstants
from antarest.study.storage.variantstudy.business.utils import strip_matrix_protocol, validate_matrix
from antarest.study.storage.variantstudy.business.utils_binding_constraint import (
apply_binding_constraint,
@@ -15,39 +18,119 @@
BindingConstraintOperator,
CommandName,
CommandOutput,
- TimeStep,
)
from antarest.study.storage.variantstudy.model.command.icommand import MATCH_SIGNATURE_SEPARATOR, ICommand
from antarest.study.storage.variantstudy.model.model import CommandDTO
+__all__ = ("AbstractBindingConstraintCommand", "CreateBindingConstraint", "check_matrix_values")
-class CreateBindingConstraint(ICommand):
- name: str
+MatrixType = List[List[MatrixData]]
+
+
+def check_matrix_values(time_step: BindingConstraintFrequency, values: MatrixType) -> None:
+ """
+ Check the binding constraint's matrix values for the specified time step.
+
+ Args:
+ time_step: The frequency of the binding constraint: "hourly", "daily" or "weekly".
+ values: The binding constraint's 2nd member matrix.
+
+ Raises:
+ ValueError:
+ If the matrix shape does not match the expected shape for the given time step.
+ If the matrix values contain NaN (Not-a-Number).
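+
+ Example (hypothetical values): a zero-filled daily matrix of shape (365, 3)
+ passes the check without raising:
+
+ >>> check_matrix_values(BindingConstraintFrequency.DAILY, [[0.0, 0.0, 0.0]] * 365)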
+ """
+ shapes = {
+ BindingConstraintFrequency.HOURLY: (8760, 3),
+ BindingConstraintFrequency.DAILY: (365, 3),
+ BindingConstraintFrequency.WEEKLY: (52, 3),
+ }
+ # Check the matrix shape and values
+ array = np.array(values, dtype=np.float64)
+ if array.shape != shapes[time_step]:
+ raise ValueError(f"Invalid matrix shape {array.shape}, expected {shapes[time_step]}")
+ if np.isnan(array).any():
+ raise ValueError("Matrix values cannot contain NaN")
+
+
+class AbstractBindingConstraintCommand(ICommand, metaclass=ABCMeta):
+ """
+ Abstract class for binding constraint commands.
+ """
+
+ # todo: add the `name` attribute because it should also be updated
enabled: bool = True
- time_step: TimeStep
+ time_step: BindingConstraintFrequency
operator: BindingConstraintOperator
coeffs: Dict[str, List[float]]
- values: Optional[Union[List[List[MatrixData]], str]] = None
+ values: Optional[Union[MatrixType, str]] = Field(None, description="2nd member matrix")
filter_year_by_year: Optional[str] = None
filter_synthesis: Optional[str] = None
comments: Optional[str] = None
- def __init__(self, **data: Any) -> None:
- super().__init__(
- command_name=CommandName.CREATE_BINDING_CONSTRAINT,
- version=1,
- **data,
+ def to_dto(self) -> CommandDTO:
+ args = {
+ "enabled": self.enabled,
+ "time_step": self.time_step.value,
+ "operator": self.operator.value,
+ "coeffs": self.coeffs,
+ "comments": self.comments,
+ "filter_year_by_year": self.filter_year_by_year,
+ "filter_synthesis": self.filter_synthesis,
+ }
+ if self.values is not None:
+ args["values"] = strip_matrix_protocol(self.values)
+ return CommandDTO(
+ action=self.command_name.value,
+ args=args,
)
+ def get_inner_matrices(self) -> List[str]:
+ if self.values is not None:
+ if not isinstance(self.values, str): # pragma: no cover
+ raise TypeError(repr(self.values))
+ return [strip_matrix_protocol(self.values)]
+ return []
+
+
+class CreateBindingConstraint(AbstractBindingConstraintCommand):
+ """
+ Command used to create a binding constraint.
+ """
+
+ command_name: CommandName = CommandName.CREATE_BINDING_CONSTRAINT
+ version: int = 1
+
+ # Properties of the `CREATE_BINDING_CONSTRAINT` command:
+ name: str
+
@validator("values", always=True)
def validate_series(
- cls, v: Optional[Union[List[List[MatrixData]], str]], values: Any
- ) -> Optional[Union[List[List[MatrixData]], str]]:
+ cls,
+ v: Optional[Union[MatrixType, str]],
+ values: Dict[str, Any],
+ ) -> Optional[Union[MatrixType, str]]:
+ constants: GeneratorMatrixConstants
+ constants = values["command_context"].generator_matrix_constants
+ time_step = values["time_step"]
if v is None:
- v = values["command_context"].generator_matrix_constants.get_null_matrix()
- return v
- else:
+ # Use an already-registered default matrix
+ methods = {
+ BindingConstraintFrequency.HOURLY: constants.get_binding_constraint_hourly,
+ BindingConstraintFrequency.DAILY: constants.get_binding_constraint_daily,
+ BindingConstraintFrequency.WEEKLY: constants.get_binding_constraint_weekly,
+ }
+ method = methods[time_step]
+ return method()
+ if isinstance(v, str):
+ # Check the matrix link
+ return validate_matrix(v, values)
+ if isinstance(v, list):
+ check_matrix_values(time_step, v)
return validate_matrix(v, values)
+ # Invalid datatype
+ raise TypeError(repr(v))  # pragma: no cover
def _apply_config(self, study_data_config: FileStudyTreeConfig) -> Tuple[CommandOutput, Dict[str, Any]]:
bd_id = transform_name_to_id(self.name)
@@ -55,7 +138,6 @@ def _apply_config(self, study_data_config: FileStudyTreeConfig) -> Tuple[Command
return CommandOutput(status=True), {}
def _apply(self, study_data: FileStudy) -> CommandOutput:
- assert_this(isinstance(self.values, str))
binding_constraints = study_data.tree.get(["input", "bindingconstraints", "bindingconstraints"])
new_key = len(binding_constraints.keys())
bd_id = transform_name_to_id(self.name)
@@ -76,20 +158,9 @@ def _apply(self, study_data: FileStudy) -> CommandOutput:
)
def to_dto(self) -> CommandDTO:
- return CommandDTO(
- action=CommandName.CREATE_BINDING_CONSTRAINT.value,
- args={
- "name": self.name,
- "enabled": self.enabled,
- "time_step": self.time_step.value,
- "operator": self.operator.value,
- "coeffs": self.coeffs,
- "values": strip_matrix_protocol(self.values),
- "comments": self.comments,
- "filter_year_by_year": self.filter_year_by_year,
- "filter_synthesis": self.filter_synthesis,
- },
- )
+ dto = super().to_dto()
+ dto.args["name"] = self.name # type: ignore
+ return dto
def match_signature(self) -> str:
return str(self.command_name.value + MATCH_SIGNATURE_SEPARATOR + self.name)
@@ -129,9 +200,3 @@ def _create_diff(self, other: "ICommand") -> List["ICommand"]:
command_context=other.command_context,
)
]
-
- def get_inner_matrices(self) -> List[str]:
- if self.values is not None:
- assert_this(isinstance(self.values, str))
- return [strip_matrix_protocol(self.values)]
- return []
diff --git a/antarest/study/storage/variantstudy/model/command/remove_area.py b/antarest/study/storage/variantstudy/model/command/remove_area.py
index 49b785869a..03b287cf9f 100644
--- a/antarest/study/storage/variantstudy/model/command/remove_area.py
+++ b/antarest/study/storage/variantstudy/model/command/remove_area.py
@@ -17,10 +17,15 @@
class RemoveArea(ICommand):
- id: str
+ """
+ Command used to remove an area.
+ """
+
+ command_name: CommandName = CommandName.REMOVE_AREA
+ version: int = 1
- def __init__(self, **data: Any) -> None:
- super().__init__(command_name=CommandName.REMOVE_AREA, version=1, **data)
+ # Properties of the `REMOVE_AREA` command:
+ id: str
def _remove_area_from_links_in_config(self, study_data_config: FileStudyTreeConfig) -> None:
link_to_remove = [
diff --git a/antarest/study/storage/variantstudy/model/command/remove_binding_constraint.py b/antarest/study/storage/variantstudy/model/command/remove_binding_constraint.py
index 4e33fd6000..97a27f5297 100644
--- a/antarest/study/storage/variantstudy/model/command/remove_binding_constraint.py
+++ b/antarest/study/storage/variantstudy/model/command/remove_binding_constraint.py
@@ -9,14 +9,15 @@
class RemoveBindingConstraint(ICommand):
- id: str
+ """
+ Command used to remove a binding constraint.
+ """
- def __init__(self, **data: Any) -> None:
- super().__init__(
- command_name=CommandName.REMOVE_BINDING_CONSTRAINT,
- version=1,
- **data,
- )
+ command_name: CommandName = CommandName.REMOVE_BINDING_CONSTRAINT
+ version: int = 1
+
+ # Properties of the `REMOVE_BINDING_CONSTRAINT` command:
+ id: str
def _apply_config(self, study_data: FileStudyTreeConfig) -> Tuple[CommandOutput, Dict[str, Any]]:
if self.id not in [bind.id for bind in study_data.bindings]:
diff --git a/antarest/study/storage/variantstudy/model/command/remove_link.py b/antarest/study/storage/variantstudy/model/command/remove_link.py
index aeb262fb14..5ec4c042cb 100644
--- a/antarest/study/storage/variantstudy/model/command/remove_link.py
+++ b/antarest/study/storage/variantstudy/model/command/remove_link.py
@@ -8,12 +8,17 @@
class RemoveLink(ICommand):
+ """
+ Command used to remove a link.
+ """
+
+ command_name: CommandName = CommandName.REMOVE_LINK
+ version: int = 1
+
+ # Properties of the `REMOVE_LINK` command:
area1: str
area2: str
- def __init__(self, **data: Any) -> None:
- super().__init__(command_name=CommandName.REMOVE_LINK, version=1, **data)
-
def _apply_config(self, study_data: FileStudyTreeConfig) -> Tuple[CommandOutput, Dict[str, Any]]:
result = self._check_link_exists(study_data)
if result[0].status:
diff --git a/antarest/study/storage/variantstudy/model/command/update_binding_constraint.py b/antarest/study/storage/variantstudy/model/command/update_binding_constraint.py
index dc17838f58..3ec874a375 100644
--- a/antarest/study/storage/variantstudy/model/command/update_binding_constraint.py
+++ b/antarest/study/storage/variantstudy/model/command/update_binding_constraint.py
@@ -3,47 +3,54 @@
from pydantic import validator
from antarest.core.model import JSON
-from antarest.core.utils.utils import assert_this
from antarest.matrixstore.model import MatrixData
from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfig
from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy
-from antarest.study.storage.variantstudy.business.utils import strip_matrix_protocol, validate_matrix
+from antarest.study.storage.variantstudy.business.utils import validate_matrix
from antarest.study.storage.variantstudy.business.utils_binding_constraint import apply_binding_constraint
-from antarest.study.storage.variantstudy.model.command.common import (
- BindingConstraintOperator,
- CommandName,
- CommandOutput,
- TimeStep,
+from antarest.study.storage.variantstudy.model.command.common import CommandName, CommandOutput
+from antarest.study.storage.variantstudy.model.command.create_binding_constraint import (
+ AbstractBindingConstraintCommand,
+ check_matrix_values,
)
from antarest.study.storage.variantstudy.model.command.icommand import MATCH_SIGNATURE_SEPARATOR, ICommand
from antarest.study.storage.variantstudy.model.model import CommandDTO
+__all__ = ("UpdateBindingConstraint",)
-class UpdateBindingConstraint(ICommand):
+MatrixType = List[List[MatrixData]]
+
+
+class UpdateBindingConstraint(AbstractBindingConstraintCommand):
+ """
+ Command used to update a binding constraint.
+ """
+
+ command_name: CommandName = CommandName.UPDATE_BINDING_CONSTRAINT
+ version: int = 1
+
+ # Properties of the `UPDATE_BINDING_CONSTRAINT` command:
id: str
- enabled: bool = True
- time_step: TimeStep
- operator: BindingConstraintOperator
- coeffs: Dict[str, List[float]]
- values: Optional[Union[List[List[MatrixData]], str]] = None
- filter_year_by_year: Optional[str] = None
- filter_synthesis: Optional[str] = None
- comments: Optional[str] = None
-
- def __init__(self, **data: Any) -> None:
- super().__init__(
- command_name=CommandName.UPDATE_BINDING_CONSTRAINT,
- version=1,
- **data,
- )
@validator("values", always=True)
def validate_series(
- cls, v: Optional[Union[List[List[MatrixData]], str]], values: Any
- ) -> Optional[Union[List[List[MatrixData]], str]]:
- if v is not None:
+ cls,
+ v: Optional[Union[MatrixType, str]],
+ values: Dict[str, Any],
+ ) -> Optional[Union[MatrixType, str]]:
+ time_step = values["time_step"]
+ if v is None:
+ # The matrix is not updated
+ return None
+ if isinstance(v, str):
+ # Check the matrix link
+ return validate_matrix(v, values)
+ if isinstance(v, list):
+ check_matrix_values(time_step, v)
return validate_matrix(v, values)
- return None
+ # Invalid datatype
+ raise TypeError(repr(v))  # pragma: no cover
def _apply_config(self, study_data: FileStudyTreeConfig) -> Tuple[CommandOutput, Dict[str, Any]]:
return CommandOutput(status=True), {}
@@ -81,22 +88,9 @@ def _apply(self, study_data: FileStudy) -> CommandOutput:
)
def to_dto(self) -> CommandDTO:
- args = {
- "id": self.id,
- "enabled": self.enabled,
- "time_step": self.time_step.value,
- "operator": self.operator.value,
- "coeffs": self.coeffs,
- "comments": self.comments,
- "filter_year_by_year": self.filter_year_by_year,
- "filter_synthesis": self.filter_synthesis,
- }
- if self.values is not None:
- args["values"] = strip_matrix_protocol(self.values)
- return CommandDTO(
- action=CommandName.UPDATE_BINDING_CONSTRAINT.value,
- args=args,
- )
+ dto = super().to_dto()
+ dto.args["id"] = self.id # type: ignore
+ return dto
def match_signature(self) -> str:
return str(self.command_name.value + MATCH_SIGNATURE_SEPARATOR + self.id)
@@ -119,9 +113,3 @@ def match(self, other: ICommand, equal: bool = False) -> bool:
def _create_diff(self, other: "ICommand") -> List["ICommand"]:
return [other]
-
- def get_inner_matrices(self) -> List[str]:
- if self.values is not None:
- assert_this(isinstance(self.values, str))
- return [strip_matrix_protocol(self.values)]
- return []
diff --git a/antarest/study/web/study_data_blueprint.py b/antarest/study/web/study_data_blueprint.py
index 931816391a..610e3b7263 100644
--- a/antarest/study/web/study_data_blueprint.py
+++ b/antarest/study/web/study_data_blueprint.py
@@ -1577,8 +1577,15 @@ def create_st_storage(
Args:
- `uuid`: The UUID of the study.
- `area_id`: The area ID.
- - `form`: The name and the group(PSP_open, PSP_closed, Pondage, Battery, Other1, Other2, Other3, Other4, Other5)
- of the storage that we want to create.
+ - `form`: The characteristics of the storage to create:
+ - `name`: The name of the new storage.
+ - `group`: The group of the new storage.
+ - `injectionNominalCapacity`: The injection nominal capacity of the new storage.
+ - `withdrawalNominalCapacity`: The withdrawal nominal capacity of the new storage.
+ - `reservoirCapacity`: The reservoir capacity of the new storage.
+ - `efficiency`: The efficiency of the new storage.
+ - `initialLevel`: The initial level of the new storage.
+ - `initialLevelOptim`: The initial level optimization of the new storage.
Returns: New storage with the following attributes:
- `id`: The storage ID of the study.
diff --git a/antarest/study/web/xpansion_studies_blueprint.py b/antarest/study/web/xpansion_studies_blueprint.py
index 98c11a91c2..af9496ba4c 100644
--- a/antarest/study/web/xpansion_studies_blueprint.py
+++ b/antarest/study/web/xpansion_studies_blueprint.py
@@ -162,7 +162,7 @@ def get_candidates(
uuid: str,
current_user: JWTUser = Depends(auth.get_current_user),
) -> Any:
- logger.info(f"Fetching study list", extra={"user": current_user.id})
+ logger.info("Fetching study list", extra={"user": current_user.id})
params = RequestParameters(user=current_user)
return study_service.get_candidates(uuid, params)
diff --git a/antarest/worker/simulator_worker.py b/antarest/worker/simulator_worker.py
index f939d8ea61..d37a8825f5 100644
--- a/antarest/worker/simulator_worker.py
+++ b/antarest/worker/simulator_worker.py
@@ -1,9 +1,10 @@
+import io
import logging
import subprocess
import threading
import time
from pathlib import Path
-from typing import IO, cast
+from typing import cast
from pydantic import BaseModel
@@ -12,7 +13,7 @@
from antarest.core.interfaces.eventbus import IEventBus
from antarest.core.tasks.model import TaskResult
from antarest.core.utils.fastapi_sqlalchemy import db
-from antarest.launcher.adapters.log_manager import LogTailManager
+from antarest.launcher.adapters.log_manager import follow
from antarest.matrixstore.service import MatrixService
from antarest.matrixstore.uri_resolver_service import UriResolverService
from antarest.study.storage.rawstudy.model.filesystem.factory import StudyFactory
@@ -101,8 +102,8 @@ def stop_reading() -> bool:
encoding="utf-8",
)
thread = threading.Thread(
- target=lambda: LogTailManager.follow(
- cast(IO[str], process.stdout),
+ target=lambda: follow(
+ cast(io.StringIO, process.stdout),
append_output,
stop_reading,
None,
diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md
index e6d96180aa..4fc768bc04 100644
--- a/docs/CHANGELOG.md
+++ b/docs/CHANGELOG.md
@@ -1,6 +1,38 @@
Antares Web Changelog
=====================
+v2.15.2 (2023-10-11)
+--------------------
+
+### Hotfix
+
+* **service:** user connected via tokens cannot create a study (#1757) ([f620197](https://github.com/AntaresSimulatorTeam/AntaREST/commit/f6201976a653db19739cbc42e91ea27ac790da10))
+
+
+### Features
+
+* **binding-constraint:** handling binding constraints frequency in study configuration parsing (#1702) ([703351a](https://github.com/AntaresSimulatorTeam/AntaREST/commit/703351a6d8d4f70491e66c3c54a92c6d28cb92ea))
+ - add the binding constraint series in the matrix constants generator ([e00d58b](https://github.com/AntaresSimulatorTeam/AntaREST/commit/e00d58b203023363860cb0e849576e02ed97fd81))
+ - command `create_binding_constraint` can check the matrix shape ([68bf99f](https://github.com/AntaresSimulatorTeam/AntaREST/commit/68bf99f1170181f6111bc15c03ede27030f809d2))
+ - command `update_binding_constraint` can check the matrix shape ([c962f73](https://github.com/AntaresSimulatorTeam/AntaREST/commit/c962f7344c7ea07c7a8c7699b2af35f90f3b853c))
+ - add missing command docstring ([d277805](https://github.com/AntaresSimulatorTeam/AntaREST/commit/d277805c10d3f9c7134166e6d2f7170c7b752428))
+ - reduce code duplication ([b41d957](https://github.com/AntaresSimulatorTeam/AntaREST/commit/b41d957cffa6a8dde21a022f8b6c24c8de2559a2))
+ - correct `test_command_factory` unit test to ignore abstract commands ([789c2ad](https://github.com/AntaresSimulatorTeam/AntaREST/commit/789c2adfc3ef3999f3779a345e0730f2f9ad906a))
+* **api:** add endpoint get_nb_cores (#1727) ([9cfa9f1](https://github.com/AntaresSimulatorTeam/AntaREST/commit/9cfa9f13d363ea4f73aa31ed760d525b091f04a4))
+* **st-storage:** allow all parameters in endpoint for short term storage creation (#1736) ([853cf6b](https://github.com/AntaresSimulatorTeam/AntaREST/commit/853cf6ba48a23d39f247a0842afac440c4ea4570))
+
+
+### Chore
+
+* **sonarcloud:** correct SonarCloud issues ([901d00d](https://github.com/AntaresSimulatorTeam/AntaREST/commit/901d00df558f7e79b728e2ce7406d1bdea69f839))
+
+
+### Contributors
+
+laurent-laporte-pro,
+MartinBelthle
+
+
v2.15.1 (2023-10-05)
--------------------
diff --git a/resources/application.yaml b/resources/application.yaml
index 75a7b8605e..a85357634f 100644
--- a/resources/application.yaml
+++ b/resources/application.yaml
@@ -20,9 +20,9 @@ db:
#pool_recycle:
storage:
- tmp_dir: /tmp
+ tmp_dir: ./tmp
matrixstore: ./matrices
- archive_dir: examples/archives
+ archive_dir: ./examples/archives
allow_deletion: false # indicate if studies found in non default workspace can be deleted by the application
#matrix_gc_sleeping_time: 3600 # time in seconds to sleep between two garbage collection
#matrix_gc_dry_run: False # Skip matrix effective deletion
@@ -32,20 +32,23 @@ storage:
#auto_archive_max_parallel: 5 # max auto archival tasks in parallel
workspaces:
default: # required, no filters applied, this folder is not watched
- path: examples/internal_studies/
+ path: ./examples/internal_studies/
# other workspaces can be added
# if a directory is to be ignored by the watcher, place a file named AW_NO_SCAN inside
tmp:
- path: examples/studies/
+ path: ./examples/studies/
# filter_in: ['.*'] # default to '.*'
# filter_out: [] # default to empty
# groups: [] # default empty
launcher:
default: local
+
local:
binaries:
700: path/to/700
+ enable_nb_cores_detection: true
+
# slurm:
# local_workspace: path/to/workspace
# username: username
@@ -56,7 +59,11 @@ launcher:
# password: password_is_optional_but_necessary_if_key_is_absent
# default_wait_time: 900
# default_time_limit: 172800
-# default_n_cpu: 12
+# enable_nb_cores_detection: False
+# nb_cores:
+# min: 1
+# default: 22
+# max: 24
# default_json_db_name: launcher_db.json
# slurm_script_path: /path/to/launchantares_v1.1.3.sh
# db_primary_key: name
@@ -70,7 +77,7 @@ launcher:
debug: true
-root_path: ""
+root_path: "api"
#tasks:
# max_workers: 5
diff --git a/resources/deploy/config.prod.yaml b/resources/deploy/config.prod.yaml
index 1bb5e30878..02fbb4b8bc 100644
--- a/resources/deploy/config.prod.yaml
+++ b/resources/deploy/config.prod.yaml
@@ -32,9 +32,12 @@ storage:
launcher:
default: local
+
local:
binaries:
800: /antares_simulator/antares-8.2-solver
+ enable_nb_cores_detection: true
+
# slurm:
# local_workspace: path/to/workspace
# username: username
@@ -45,7 +48,11 @@ launcher:
# password: password_is_optional_but_necessary_if_key_is_absent
# default_wait_time: 900
# default_time_limit: 172800
-# default_n_cpu: 12
+# enable_nb_cores_detection: False
+# nb_cores:
+# min: 1
+# default: 22
+# max: 24
# default_json_db_name: launcher_db.json
# slurm_script_path: /path/to/launchantares_v1.1.3.sh
# db_primary_key: name
@@ -59,7 +66,7 @@ launcher:
debug: false
-root_path: "/api"
+root_path: "api"
#tasks:
# max_workers: 5
diff --git a/resources/deploy/config.yaml b/resources/deploy/config.yaml
index 48cea48a22..810e1f8d24 100644
--- a/resources/deploy/config.yaml
+++ b/resources/deploy/config.yaml
@@ -29,9 +29,12 @@ storage:
launcher:
default: local
+
local:
binaries:
700: path/to/700
+ enable_nb_cores_detection: true
+
# slurm:
# local_workspace: path/to/workspace
# username: username
@@ -42,7 +45,11 @@ launcher:
# password: password_is_optional_but_necessary_if_key_is_absent
# default_wait_time: 900
# default_time_limit: 172800
-# default_n_cpu: 12
+# enable_nb_cores_detection: False
+# nb_cores:
+# min: 1
+# default: 22
+# max: 24
# default_json_db_name: launcher_db.json
# slurm_script_path: /path/to/launchantares_v1.1.3.sh
# db_primary_key: name
diff --git a/setup.py b/setup.py
index 92e4e34f98..d0b8c5deb3 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
setup(
name="AntaREST",
- version="2.15.1",
+ version="2.15.2",
description="Antares Server",
long_description=Path("README.md").read_text(encoding="utf-8"),
long_description_content_type="text/markdown",
diff --git a/sonar-project.properties b/sonar-project.properties
index 9bf66c1c14..1d3a132a04 100644
--- a/sonar-project.properties
+++ b/sonar-project.properties
@@ -6,5 +6,5 @@ sonar.exclusions=antarest/gui.py,antarest/main.py
sonar.python.coverage.reportPaths=coverage.xml
sonar.python.version=3.8
sonar.javascript.lcov.reportPaths=webapp/coverage/lcov.info
-sonar.projectVersion=2.15.1
+sonar.projectVersion=2.15.2
sonar.coverage.exclusions=antarest/gui.py,antarest/main.py,antarest/singleton_services.py,antarest/worker/archive_worker_service.py,webapp/**/*
\ No newline at end of file
diff --git a/tests/conftest_db.py b/tests/conftest_db.py
index 877ca119d1..bcb4177766 100644
--- a/tests/conftest_db.py
+++ b/tests/conftest_db.py
@@ -3,7 +3,8 @@
import pytest
from sqlalchemy import create_engine # type: ignore
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy.engine.base import Engine # type: ignore
+from sqlalchemy.orm import Session, sessionmaker # type: ignore
from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware
from antarest.dbmodel import Base
@@ -12,7 +13,7 @@
@pytest.fixture(name="db_engine")
-def db_engine_fixture() -> Generator[Any, None, None]:
+def db_engine_fixture() -> Generator[Engine, None, None]:
"""
Fixture that creates an in-memory SQLite database engine for testing.
@@ -26,7 +27,7 @@ def db_engine_fixture() -> Generator[Any, None, None]:
@pytest.fixture(name="db_session")
-def db_session_fixture(db_engine) -> Generator:
+def db_session_fixture(db_engine: Engine) -> Generator[Session, None, None]:
"""
Fixture that creates a database session for testing purposes.
@@ -46,7 +47,7 @@ def db_session_fixture(db_engine) -> Generator:
@pytest.fixture(name="db_middleware", autouse=True)
def db_middleware_fixture(
- db_engine: Any,
+ db_engine: Engine,
) -> Generator[DBSessionMiddleware, None, None]:
"""
Fixture that sets up a database session middleware with custom engine settings.
diff --git a/tests/core/assets/__init__.py b/tests/core/assets/__init__.py
new file mode 100644
index 0000000000..773f16ec60
--- /dev/null
+++ b/tests/core/assets/__init__.py
@@ -0,0 +1,3 @@
+from pathlib import Path
+
+ASSETS_DIR = Path(__file__).parent.resolve()
diff --git a/tests/core/assets/config/application-2.14.yaml b/tests/core/assets/config/application-2.14.yaml
new file mode 100644
index 0000000000..650093286d
--- /dev/null
+++ b/tests/core/assets/config/application-2.14.yaml
@@ -0,0 +1,61 @@
+security:
+ disabled: false
+ jwt:
+ key: super-secret
+ login:
+ admin:
+ pwd: admin
+
+db:
+ url: "sqlite:////home/john/antares_data/database.db"
+
+storage:
+ tmp_dir: /tmp
+ matrixstore: /home/john/antares_data/matrices
+ archive_dir: /home/john/antares_data/archives
+ allow_deletion: false
+ workspaces:
+ default:
+ path: /home/john/antares_data/internal_studies/
+ studies:
+ path: /home/john/antares_data/studies/
+
+launcher:
+ default: slurm
+ local:
+ binaries:
+ 850: /home/john/opt/antares-8.5.0-Ubuntu-20.04/antares-solver
+ 860: /home/john/opt/antares-8.6.0-Ubuntu-20.04/antares-8.6-solver
+
+ slurm:
+ local_workspace: /home/john/antares_data/slurm_workspace
+
+ username: antares
+ hostname: slurm-prod-01
+
+ port: 22
+ private_key_file: /home/john/.ssh/id_rsa
+ key_password:
+ default_wait_time: 900
+ default_time_limit: 172800
+ default_n_cpu: 20
+ default_json_db_name: launcher_db.json
+ slurm_script_path: /applis/antares/launchAntares.sh
+ db_primary_key: name
+ antares_versions_on_remote_server:
+ - '850' # 8.5.1/antares-8.5-solver
+ - '860' # 8.6.2/antares-8.6-solver
+ - '870' # 8.7.0/antares-8.7-solver
+
+debug: false
+
+root_path: ""
+
+server:
+ worker_threadpool_size: 12
+ services:
+ - watcher
+ - matrix_gc
+
+logging:
+ level: INFO
diff --git a/tests/core/assets/config/application-2.15.yaml b/tests/core/assets/config/application-2.15.yaml
new file mode 100644
index 0000000000..c51d32aaae
--- /dev/null
+++ b/tests/core/assets/config/application-2.15.yaml
@@ -0,0 +1,66 @@
+security:
+ disabled: false
+ jwt:
+ key: super-secret
+ login:
+ admin:
+ pwd: admin
+
+db:
+ url: "sqlite:////home/john/antares_data/database.db"
+
+storage:
+ tmp_dir: /tmp
+ matrixstore: /home/john/antares_data/matrices
+ archive_dir: /home/john/antares_data/archives
+ allow_deletion: false
+ workspaces:
+ default:
+ path: /home/john/antares_data/internal_studies/
+ studies:
+ path: /home/john/antares_data/studies/
+
+launcher:
+ default: slurm
+ local:
+ binaries:
+ 850: /home/john/opt/antares-8.5.0-Ubuntu-20.04/antares-solver
+ 860: /home/john/opt/antares-8.6.0-Ubuntu-20.04/antares-8.6-solver
+ enable_nb_cores_detection: True
+
+ slurm:
+ local_workspace: /home/john/antares_data/slurm_workspace
+
+ username: antares
+ hostname: slurm-prod-01
+
+ port: 22
+ private_key_file: /home/john/.ssh/id_rsa
+ key_password:
+ default_wait_time: 900
+ default_time_limit: 172800
+ enable_nb_cores_detection: False
+ nb_cores:
+ min: 1
+ default: 22
+ max: 24
+ default_json_db_name: launcher_db.json
+ slurm_script_path: /applis/antares/launchAntares.sh
+ db_primary_key: name
+ antares_versions_on_remote_server:
+ - '850' # 8.5.1/antares-8.5-solver
+ - '860' # 8.6.2/antares-8.6-solver
+ - '870' # 8.7.0/antares-8.7-solver
+
+debug: false
+
+root_path: ""
+
+server:
+ worker_threadpool_size: 12
+ services:
+ - watcher
+ - matrix_gc
+
+logging:
+ level: INFO
diff --git a/tests/core/test_config.py b/tests/core/test_config.py
index 1c1c96a180..00c6f9458d 100644
--- a/tests/core/test_config.py
+++ b/tests/core/test_config.py
@@ -1,15 +1,253 @@
from pathlib import Path
+from unittest import mock
import pytest
-from antarest.core.config import Config
+from antarest.core.config import (
+ Config,
+ InvalidConfigurationError,
+ LauncherConfig,
+ LocalConfig,
+ NbCoresConfig,
+ SlurmConfig,
+)
+from tests.core.assets import ASSETS_DIR
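+
+# Sample launcher configuration reused by the `TestLauncherConfig` test cases below.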
+LAUNCHER_CONFIG = {
+ "default": "slurm",
+ "local": {
+ "binaries": {"860": Path("/bin/solver-860.exe")},
+ "enable_nb_cores_detection": False,
+ "nb_cores": {"min": 2, "default": 10, "max": 20},
+ },
+ "slurm": {
+ "local_workspace": Path("/home/john/antares/workspace"),
+ "username": "john",
+ "hostname": "slurm-001",
+ "port": 22,
+ "private_key_file": Path("/home/john/.ssh/id_rsa"),
+ "key_password": "password",
+ "password": "password",
+ "default_wait_time": 10,
+ "default_time_limit": 20,
+ "default_json_db_name": "antares.db",
+ "slurm_script_path": "/path/to/slurm/launcher.sh",
+ "max_cores": 32,
+ "antares_versions_on_remote_server": ["860"],
+ "enable_nb_cores_detection": False,
+ "nb_cores": {"min": 1, "default": 34, "max": 36},
+ },
+ "batch_size": 100,
+}
-@pytest.mark.unit_test
-def test_get_yaml(project_path: Path):
- config = Config.from_yaml_file(file=project_path / "resources/application.yaml")
- assert config.security.admin_pwd == "admin"
- assert config.storage.workspaces["default"].path == Path("examples/internal_studies/")
- assert not config.logging.json
- assert config.logging.level == "INFO"
+class TestNbCoresConfig:
+ def test_init__default_values(self):
+ config = NbCoresConfig()
+ assert config.min == 1
+ assert config.default == 22
+ assert config.max == 24
+
+ def test_init__invalid_values(self):
+ with pytest.raises(ValueError):
+ # default < min
+ NbCoresConfig(min=2, default=1, max=24)
+ with pytest.raises(ValueError):
+ # default > max
+ NbCoresConfig(min=1, default=25, max=24)
+ with pytest.raises(ValueError):
+ # min < 1
+ NbCoresConfig(min=0, default=22, max=23)
+ with pytest.raises(ValueError):
+ # min > max
+ NbCoresConfig(min=22, default=22, max=21)
+
+ def test_to_json(self):
+ config = NbCoresConfig()
+ # ReactJs Material UI expects "min", "defaultValue" and "max" keys
+ assert config.to_json() == {"min": 1, "defaultValue": 22, "max": 24}
+
+
+class TestLocalConfig:
+ def test_init__default_values(self):
+ config = LocalConfig()
+ assert config.binaries == {}, "binaries should be empty by default"
+ assert config.enable_nb_cores_detection, "nb cores auto-detection should be enabled by default"
+ assert config.nb_cores == NbCoresConfig()
+
+ def test_from_dict(self):
+ config = LocalConfig.from_dict(
+ {
+ "binaries": {"860": Path("/bin/solver-860.exe")},
+ "enable_nb_cores_detection": False,
+ "nb_cores": {"min": 2, "default": 10, "max": 20},
+ }
+ )
+ assert config.binaries == {"860": Path("/bin/solver-860.exe")}
+ assert not config.enable_nb_cores_detection
+ assert config.nb_cores == NbCoresConfig(min=2, default=10, max=20)
+
+ def test_from_dict__auto_detect(self):
+ with mock.patch("multiprocessing.cpu_count", return_value=8):
+ config = LocalConfig.from_dict(
+ {
+ "binaries": {"860": Path("/bin/solver-860.exe")},
+ "enable_nb_cores_detection": True,
+ }
+ )
+ assert config.binaries == {"860": Path("/bin/solver-860.exe")}
+ assert config.enable_nb_cores_detection
+ assert config.nb_cores == NbCoresConfig(min=1, default=6, max=8)
+
+
+class TestSlurmConfig:
+ def test_init__default_values(self):
+ config = SlurmConfig()
+ assert config.local_workspace == Path()
+ assert config.username == ""
+ assert config.hostname == ""
+ assert config.port == 0
+ assert config.private_key_file == Path()
+ assert config.key_password == ""
+ assert config.password == ""
+ assert config.default_wait_time == 0
+ assert config.default_time_limit == 0
+ assert config.default_json_db_name == ""
+ assert config.slurm_script_path == ""
+ assert config.max_cores == 64
+ assert config.antares_versions_on_remote_server == [], "solver versions should be empty by default"
+ assert not config.enable_nb_cores_detection, "nb cores auto-detection shouldn't be enabled by default"
+ assert config.nb_cores == NbCoresConfig()
+
+ def test_from_dict(self):
+ config = SlurmConfig.from_dict(
+ {
+ "local_workspace": Path("/home/john/antares/workspace"),
+ "username": "john",
+ "hostname": "slurm-001",
+ "port": 22,
+ "private_key_file": Path("/home/john/.ssh/id_rsa"),
+ "key_password": "password",
+ "password": "password",
+ "default_wait_time": 10,
+ "default_time_limit": 20,
+ "default_json_db_name": "antares.db",
+ "slurm_script_path": "/path/to/slurm/launcher.sh",
+ "max_cores": 32,
+ "antares_versions_on_remote_server": ["860"],
+ "enable_nb_cores_detection": False,
+ "nb_cores": {"min": 2, "default": 10, "max": 20},
+ }
+ )
+ assert config.local_workspace == Path("/home/john/antares/workspace")
+ assert config.username == "john"
+ assert config.hostname == "slurm-001"
+ assert config.port == 22
+ assert config.private_key_file == Path("/home/john/.ssh/id_rsa")
+ assert config.key_password == "password"
+ assert config.password == "password"
+ assert config.default_wait_time == 10
+ assert config.default_time_limit == 20
+ assert config.default_json_db_name == "antares.db"
+ assert config.slurm_script_path == "/path/to/slurm/launcher.sh"
+ assert config.max_cores == 32
+ assert config.antares_versions_on_remote_server == ["860"]
+ assert not config.enable_nb_cores_detection
+ assert config.nb_cores == NbCoresConfig(min=2, default=10, max=20)
+
+ def test_from_dict__default_n_cpu__backport(self):
+ config = SlurmConfig.from_dict(
+ {
+ "local_workspace": Path("/home/john/antares/workspace"),
+ "username": "john",
+ "hostname": "slurm-001",
+ "port": 22,
+ "private_key_file": Path("/home/john/.ssh/id_rsa"),
+ "key_password": "password",
+ "password": "password",
+ "default_wait_time": 10,
+ "default_time_limit": 20,
+ "default_json_db_name": "antares.db",
+ "slurm_script_path": "/path/to/slurm/launcher.sh",
+ "max_cores": 32,
+ "antares_versions_on_remote_server": ["860"],
+ "default_n_cpu": 15,
+ }
+ )
+ assert config.nb_cores == NbCoresConfig(min=1, default=15, max=24)
+
+ def test_from_dict__auto_detect(self):
+ with pytest.raises(NotImplementedError):
+ SlurmConfig.from_dict({"enable_nb_cores_detection": True})
+
+
+class TestLauncherConfig:
+ def test_init__default_values(self):
+ config = LauncherConfig()
+ assert config.default == "local", "default launcher should be local"
+ assert config.local is None
+ assert config.slurm is None
+ assert config.batch_size == 9999
+
+ def test_from_dict(self):
+ config = LauncherConfig.from_dict(LAUNCHER_CONFIG)
+ assert config.default == "slurm"
+ assert config.local == LocalConfig(
+ binaries={"860": Path("/bin/solver-860.exe")},
+ enable_nb_cores_detection=False,
+ nb_cores=NbCoresConfig(min=2, default=10, max=20),
+ )
+ assert config.slurm == SlurmConfig(
+ local_workspace=Path("/home/john/antares/workspace"),
+ username="john",
+ hostname="slurm-001",
+ port=22,
+ private_key_file=Path("/home/john/.ssh/id_rsa"),
+ key_password="password",
+ password="password",
+ default_wait_time=10,
+ default_time_limit=20,
+ default_json_db_name="antares.db",
+ slurm_script_path="/path/to/slurm/launcher.sh",
+ max_cores=32,
+ antares_versions_on_remote_server=["860"],
+ enable_nb_cores_detection=False,
+ nb_cores=NbCoresConfig(min=1, default=34, max=36),
+ )
+ assert config.batch_size == 100
+
+ def test_init__invalid_launcher(self):
+ with pytest.raises(ValueError):
+ LauncherConfig(default="invalid_launcher")
+
+ def test_get_nb_cores__default(self):
+ config = LauncherConfig.from_dict(LAUNCHER_CONFIG)
+ # default == "slurm"
+ assert config.get_nb_cores(launcher="default") == NbCoresConfig(min=1, default=34, max=36)
+
+ def test_get_nb_cores__local(self):
+ config = LauncherConfig.from_dict(LAUNCHER_CONFIG)
+ assert config.get_nb_cores(launcher="local") == NbCoresConfig(min=2, default=10, max=20)
+
+ def test_get_nb_cores__slurm(self):
+ config = LauncherConfig.from_dict(LAUNCHER_CONFIG)
+ assert config.get_nb_cores(launcher="slurm") == NbCoresConfig(min=1, default=34, max=36)
+
+ def test_get_nb_cores__invalid_configuration(self):
+ config = LauncherConfig.from_dict(LAUNCHER_CONFIG)
+ with pytest.raises(InvalidConfigurationError):
+ config.get_nb_cores("invalid_launcher")
+ config = LauncherConfig.from_dict({})
+ with pytest.raises(InvalidConfigurationError):
+ config.get_nb_cores("slurm")
+
+
+class TestConfig:
+ @pytest.mark.parametrize("config_name", ["application-2.14.yaml", "application-2.15.yaml"])
+ def test_from_yaml_file(self, config_name: str) -> None:
+ yaml_path = ASSETS_DIR.joinpath("config", config_name)
+ config = Config.from_yaml_file(yaml_path)
+ assert config.security.admin_pwd == "admin"
+ assert config.storage.workspaces["default"].path == Path("/home/john/antares_data/internal_studies")
+ assert not config.logging.json
+ assert config.logging.level == "INFO"
diff --git a/tests/integration/assets/base_study.zip b/tests/integration/assets/base_study.zip
index 28794fde9e..712833942c 100644
Binary files a/tests/integration/assets/base_study.zip and b/tests/integration/assets/base_study.zip differ
diff --git a/tests/integration/assets/config.template.yml b/tests/integration/assets/config.template.yml
index f3ad1d256f..71c58a1ba0 100644
--- a/tests/integration/assets/config.template.yml
+++ b/tests/integration/assets/config.template.yml
@@ -27,6 +27,7 @@ launcher:
local:
binaries:
700: {{launcher_mock}}
+ enable_nb_cores_detection: True
debug: false
diff --git a/tests/integration/assets/variant_study.zip b/tests/integration/assets/variant_study.zip
index 5693b29a61..4e08926cda 100644
Binary files a/tests/integration/assets/variant_study.zip and b/tests/integration/assets/variant_study.zip differ
diff --git a/tests/integration/launcher_blueprint/test_launcher_local.py b/tests/integration/launcher_blueprint/test_launcher_local.py
new file mode 100644
index 0000000000..7244fba8ee
--- /dev/null
+++ b/tests/integration/launcher_blueprint/test_launcher_local.py
@@ -0,0 +1,70 @@
+import http
+
+import pytest
+from starlette.testclient import TestClient
+
+from antarest.core.config import LocalConfig
+
+
+# noinspection SpellCheckingInspection
+@pytest.mark.integration_test
+class TestLauncherNbCores:
+ """
+ The purpose of this integration test is to check the `/v1/launcher/nbcores` endpoint.
+ """
+
+ def test_get_launcher_nb_cores(
+ self,
+ client: TestClient,
+ user_access_token: str,
+ ) -> None:
+ # NOTE: we have `enable_nb_cores_detection: True` in `tests/integration/assets/config.template.yml`.
+ local_nb_cores = LocalConfig.from_dict({"enable_nb_cores_detection": True}).nb_cores
+ nb_cores_expected = local_nb_cores.to_json()
+ res = client.get(
+ "/v1/launcher/nbcores",
+ headers={"Authorization": f"Bearer {user_access_token}"},
+ )
+ res.raise_for_status()
+ actual = res.json()
+ assert actual == nb_cores_expected
+
+ res = client.get(
+ "/v1/launcher/nbcores?launcher=default",
+ headers={"Authorization": f"Bearer {user_access_token}"},
+ )
+ res.raise_for_status()
+ actual = res.json()
+ assert actual == nb_cores_expected
+
+ res = client.get(
+ "/v1/launcher/nbcores?launcher=local",
+ headers={"Authorization": f"Bearer {user_access_token}"},
+ )
+ res.raise_for_status()
+ actual = res.json()
+ assert actual == nb_cores_expected
+
+ # Check that the endpoint raises an exception when the "slurm" launcher is requested.
+ res = client.get(
+ "/v1/launcher/nbcores?launcher=slurm",
+ headers={"Authorization": f"Bearer {user_access_token}"},
+ )
+ assert res.status_code == http.HTTPStatus.UNPROCESSABLE_ENTITY, res.json()
+ actual = res.json()
+ assert actual == {
+ "description": "Unknown solver configuration: 'slurm'",
+ "exception": "UnknownSolverConfig",
+ }
+
+ # Check that the endpoint raises an exception when an unknown launcher is requested.
+ res = client.get(
+ "/v1/launcher/nbcores?launcher=unknown",
+ headers={"Authorization": f"Bearer {user_access_token}"},
+ )
+ assert res.status_code == http.HTTPStatus.UNPROCESSABLE_ENTITY, res.json()
+ actual = res.json()
+ assert actual == {
+ "description": "Unknown solver configuration: 'unknown'",
+ "exception": "UnknownSolverConfig",
+ }
diff --git a/tests/integration/study_data_blueprint/test_st_storage.py b/tests/integration/study_data_blueprint/test_st_storage.py
index cdd0b464c7..2d1035af3e 100644
--- a/tests/integration/study_data_blueprint/test_st_storage.py
+++ b/tests/integration/study_data_blueprint/test_st_storage.py
@@ -1,4 +1,3 @@
-import json
import re
import numpy as np
@@ -9,6 +8,17 @@
from antarest.study.storage.rawstudy.model.filesystem.config.model import transform_name_to_id
from tests.integration.utils import wait_task_completion
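+
+# Default property values expected from the short-term storage endpoints when the
+# optional properties are not supplied at creation time (only `name` is mandatory).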
+DEFAULT_PROPERTIES = {
+ # `name` field is required
+ "group": "Other1",
+ "injectionNominalCapacity": 0.0,
+ "withdrawalNominalCapacity": 0.0,
+ "reservoirCapacity": 0.0,
+ "efficiency": 1.0,
+ "initialLevel": 0.0,
+ "initialLevelOptim": False,
+}
+
@pytest.mark.unit_test
class TestSTStorage:
@@ -61,28 +71,46 @@ def test_lifecycle__nominal(
task = wait_task_completion(client, user_access_token, task_id)
assert task.status == TaskStatus.COMPLETED, task
- # creation with default values (only mandatory properties specified)
+ # =============================
+ # SHORT-TERM STORAGE CREATION
+ # =============================
+
area_id = transform_name_to_id("FR")
siemens_battery = "Siemens Battery"
+
+ # An attempt to create a short-term storage without a name
+ # should raise a validation error (other properties are optional).
+ # An attempt to create a short-term storage with an empty name
+ # or an invalid name should also raise a validation error.
+ attempts = [{}, {"name": ""}, {"name": "!??"}]
+ for attempt in attempts:
+ res = client.post(
+ f"/v1/studies/{study_id}/areas/{area_id}/storages",
+ headers={"Authorization": f"Bearer {user_access_token}"},
+ json=attempt,
+ )
+ assert res.status_code == 422, res.json()
+ assert res.json()["exception"] in {"ValidationError", "RequestValidationError"}, res.json()
+
+ # We can create a short-term storage with the following properties:
+ siemens_properties = {
+ **DEFAULT_PROPERTIES,
+ "name": siemens_battery,
+ "group": "Battery",
+ "injectionNominalCapacity": 1450,
+ "withdrawalNominalCapacity": 1350,
+ "reservoirCapacity": 1500,
+ }
res = client.post(
f"/v1/studies/{study_id}/areas/{area_id}/storages",
headers={"Authorization": f"Bearer {user_access_token}"},
- json={"name": siemens_battery, "group": "Battery"},
+ json=siemens_properties,
)
assert res.status_code == 200, res.json()
siemens_battery_id = res.json()["id"]
assert siemens_battery_id == transform_name_to_id(siemens_battery)
- assert res.json() == {
- "efficiency": 1.0,
- "group": "Battery",
- "id": siemens_battery_id,
- "initialLevel": 0.0,
- "initialLevelOptim": False,
- "injectionNominalCapacity": 0.0,
- "name": siemens_battery,
- "reservoirCapacity": 0.0,
- "withdrawalNominalCapacity": 0.0,
- }
+ siemens_config = {**siemens_properties, "id": siemens_battery_id}
+ assert res.json() == siemens_config
# reading the properties of a short-term storage
res = client.get(
@@ -90,17 +118,11 @@ def test_lifecycle__nominal(
headers={"Authorization": f"Bearer {user_access_token}"},
)
assert res.status_code == 200, res.json()
- assert res.json() == {
- "efficiency": 1.0,
- "group": "Battery",
- "id": siemens_battery_id,
- "initialLevel": 0.0,
- "initialLevelOptim": False,
- "injectionNominalCapacity": 0.0,
- "name": siemens_battery,
- "reservoirCapacity": 0.0,
- "withdrawalNominalCapacity": 0.0,
- }
+ assert res.json() == siemens_config
+
+ # =============================
+ # SHORT-TERM STORAGE MATRICES
+ # =============================
# updating the matrix of a short-term storage
array = np.random.rand(8760, 1) * 1000
@@ -134,25 +156,17 @@ def test_lifecycle__nominal(
assert res.status_code == 200, res.json()
assert res.json() is True
+ # ==================================
+ # SHORT-TERM STORAGE LIST / GROUPS
+ # ==================================
+
# Reading the list of short-term storages
res = client.get(
f"/v1/studies/{study_id}/areas/{area_id}/storages",
headers={"Authorization": f"Bearer {user_access_token}"},
)
assert res.status_code == 200, res.json()
- assert res.json() == [
- {
- "efficiency": 1.0,
- "group": "Battery",
- "id": siemens_battery_id,
- "initialLevel": 0.0,
- "initialLevelOptim": False,
- "injectionNominalCapacity": 0.0,
- "name": siemens_battery,
- "reservoirCapacity": 0.0,
- "withdrawalNominalCapacity": 0.0,
- }
- ]
+ assert res.json() == [siemens_config]
# updating properties
res = client.patch(
@@ -164,34 +178,23 @@ def test_lifecycle__nominal(
},
)
assert res.status_code == 200, res.json()
- assert json.loads(res.text) == {
- "id": siemens_battery_id,
+ siemens_config = {
+ **siemens_config,
"name": "New Siemens Battery",
- "group": "Battery",
- "efficiency": 1.0,
- "initialLevel": 0.0,
- "initialLevelOptim": False,
- "injectionNominalCapacity": 0.0,
- "withdrawalNominalCapacity": 0.0,
"reservoirCapacity": 2500,
}
+ assert res.json() == siemens_config
res = client.get(
f"/v1/studies/{study_id}/areas/{area_id}/storages/{siemens_battery_id}",
headers={"Authorization": f"Bearer {user_access_token}"},
)
assert res.status_code == 200, res.json()
- assert res.json() == {
- "id": siemens_battery_id,
- "name": "New Siemens Battery",
- "group": "Battery",
- "efficiency": 1.0,
- "initialLevel": 0.0,
- "initialLevelOptim": False,
- "injectionNominalCapacity": 0.0,
- "withdrawalNominalCapacity": 0.0,
- "reservoirCapacity": 2500,
- }
+ assert res.json() == siemens_config
+
+ # ===========================
+ # SHORT-TERM STORAGE UPDATE
+ # ===========================
# updating properties
res = client.patch(
@@ -202,37 +205,38 @@ def test_lifecycle__nominal(
"reservoirCapacity": 0,
},
)
- assert res.status_code == 200, res.json()
- assert json.loads(res.text) == {
- "id": siemens_battery_id,
- "name": "New Siemens Battery",
- "group": "Battery",
- "efficiency": 1.0,
+ siemens_config = {
+ **siemens_config,
"initialLevel": 5900,
- "initialLevelOptim": False,
- "injectionNominalCapacity": 0.0,
- "withdrawalNominalCapacity": 0.0,
"reservoirCapacity": 0,
}
+ assert res.json() == siemens_config
+ # An attempt to update the `efficiency` property with an invalid value
+ # should raise a validation error.
+ # The `efficiency` property must be a float between 0 and 1.
+ bad_properties = {"efficiency": 2.0}
+ res = client.patch(
+ f"/v1/studies/{study_id}/areas/{area_id}/storages/{siemens_battery_id}",
+ headers={"Authorization": f"Bearer {user_access_token}"},
+ json=bad_properties,
+ )
+ assert res.status_code == 422, res.json()
+ assert res.json()["exception"] == "ValidationError", res.json()
+
+ # The short-term storage properties should not have been updated.
res = client.get(
f"/v1/studies/{study_id}/areas/{area_id}/storages/{siemens_battery_id}",
headers={"Authorization": f"Bearer {user_access_token}"},
)
assert res.status_code == 200, res.json()
- assert res.json() == {
- "id": siemens_battery_id,
- "name": "New Siemens Battery",
- "group": "Battery",
- "efficiency": 1.0,
- "initialLevel": 5900,
- "initialLevelOptim": False,
- "injectionNominalCapacity": 0.0,
- "withdrawalNominalCapacity": 0.0,
- "reservoirCapacity": 0,
- }
+ assert res.json() == siemens_config
+
+ # =============================
+ # SHORT-TERM STORAGE DELETION
+ # =============================
- # deletion of short-term storages
+ # To delete a short-term storage, we need to provide its ID.
res = client.request(
"DELETE",
f"/v1/studies/{study_id}/areas/{area_id}/storages",
@@ -242,7 +246,7 @@ def test_lifecycle__nominal(
assert res.status_code == 204, res.json()
assert res.text in {"", "null"} # Old FastAPI versions return 'null'.
- # deletion of short-term storages with empty list
+ # If the short-term storage list is empty, the deletion should be a no-op.
res = client.request(
"DELETE",
f"/v1/studies/{study_id}/areas/{area_id}/storages",
@@ -252,48 +256,79 @@ def test_lifecycle__nominal(
assert res.status_code == 204, res.json()
assert res.text in {"", "null"} # Old FastAPI versions return 'null'.
- # deletion of short-term storages with multiple IDs
+ # It's possible to delete multiple short-term storages at once.
+ # In the following example, we will create two short-term storages:
+ siemens_properties = {
+ "name": siemens_battery,
+ "group": "Battery",
+ "injectionNominalCapacity": 1450,
+ "withdrawalNominalCapacity": 1350,
+ "reservoirCapacity": 1500,
+ "efficiency": 0.90,
+ "initialLevel": 200,
+ "initialLevelOptim": False,
+ }
res = client.post(
f"/v1/studies/{study_id}/areas/{area_id}/storages",
headers={"Authorization": f"Bearer {user_access_token}"},
- json={"name": siemens_battery, "group": "Battery"},
+ json=siemens_properties,
)
assert res.status_code == 200, res.json()
- siemens_battery_id1 = res.json()["id"]
-
- siemens_battery_del = f"{siemens_battery}del"
+ siemens_battery_id = res.json()["id"]
+ # Create another short-term storage: "Grand'Maison"
+ grand_maison = "Grand'Maison"
+ grand_maison_properties = {
+ "name": grand_maison,
+ "group": "PSP_closed",
+ "injectionNominalCapacity": 1500,
+ "withdrawalNominalCapacity": 1800,
+ "reservoirCapacity": 20000,
+ "efficiency": 0.78,
+ "initialLevel": 10000,
+ }
res = client.post(
f"/v1/studies/{study_id}/areas/{area_id}/storages",
headers={"Authorization": f"Bearer {user_access_token}"},
- json={"name": siemens_battery_del, "group": "Battery"},
+ json=grand_maison_properties,
)
assert res.status_code == 200, res.json()
- siemens_battery_id2 = res.json()["id"]
+ grand_maison_id = res.json()["id"]
+    # Reading the list of short-term storages should now return both of them.
+ res = client.get(
+ f"/v1/studies/{study_id}/areas/{area_id}/storages",
+ headers={"Authorization": f"Bearer {user_access_token}"},
+ )
+ assert res.status_code == 200, res.json()
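+    # The expected configurations merge the default properties with the
+    # submitted ones (Grand'Maison omits `initialLevelOptim`) plus the id.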
+ siemens_config = {**DEFAULT_PROPERTIES, **siemens_properties, "id": siemens_battery_id}
+ grand_maison_config = {**DEFAULT_PROPERTIES, **grand_maison_properties, "id": grand_maison_id}
+ assert res.json() == [siemens_config, grand_maison_config]
+
+ # We can delete the two short-term storages at once.
res = client.request(
"DELETE",
f"/v1/studies/{study_id}/areas/{area_id}/storages",
headers={"Authorization": f"Bearer {user_access_token}"},
- json=[siemens_battery_id1, siemens_battery_id2],
+ json=[siemens_battery_id, grand_maison_id],
)
assert res.status_code == 204, res.json()
assert res.text in {"", "null"} # Old FastAPI versions return 'null'.
- # Check the removal
+ # The list of short-term storages should be empty.
res = client.get(
- f"/v1/studies/{study_id}/areas/{area_id}/storages/{siemens_battery_id}",
+ f"/v1/studies/{study_id}/areas/{area_id}/storages",
headers={"Authorization": f"Bearer {user_access_token}"},
)
- obj = res.json()
- description = obj["description"]
- assert siemens_battery_id in description
- assert re.search(r"fields of storage", description, flags=re.IGNORECASE)
- assert re.search(r"not found", description, flags=re.IGNORECASE)
+ assert res.status_code == 200, res.json()
+ assert res.json() == []
- assert res.status_code == 404, res.json()
+ # ===========================
+ # SHORT-TERM STORAGE ERRORS
+ # ===========================
- # Check delete with the wrong value of area_id
+ # Check delete with the wrong value of `area_id`
bad_area_id = "bad_area"
res = client.request(
"DELETE",
@@ -311,7 +346,7 @@ def test_lifecycle__nominal(
flags=re.IGNORECASE,
)
- # Check delete with the wrong value of study_id
+ # Check delete with the wrong value of `study_id`
bad_study_id = "bad_study"
res = client.request(
"DELETE",
@@ -324,8 +359,7 @@ def test_lifecycle__nominal(
assert res.status_code == 404, res.json()
assert bad_study_id in description
- # Check get with wrong area_id
-
+ # Check get with wrong `area_id`
res = client.get(
f"/v1/studies/{study_id}/areas/{bad_area_id}/storages/{siemens_battery_id}",
headers={"Authorization": f"Bearer {user_access_token}"},
@@ -335,8 +369,7 @@ def test_lifecycle__nominal(
assert bad_area_id in description
assert res.status_code == 404, res.json()
- # Check get with wrong study_id
-
+ # Check get with wrong `study_id`
res = client.get(
f"/v1/studies/{bad_study_id}/areas/{area_id}/storages/{siemens_battery_id}",
headers={"Authorization": f"Bearer {user_access_token}"},
@@ -346,7 +379,7 @@ def test_lifecycle__nominal(
assert res.status_code == 404, res.json()
assert bad_study_id in description
- # Check post with wrong study_id
+ # Check POST with wrong `study_id`
res = client.post(
f"/v1/studies/{bad_study_id}/areas/{area_id}/storages",
headers={"Authorization": f"Bearer {user_access_token}"},
@@ -357,11 +390,20 @@ def test_lifecycle__nominal(
assert res.status_code == 404, res.json()
assert bad_study_id in description
- # Check post with wrong area_id
+ # Check POST with wrong `area_id`
res = client.post(
f"/v1/studies/{study_id}/areas/{bad_area_id}/storages",
headers={"Authorization": f"Bearer {user_access_token}"},
- json={"name": siemens_battery, "group": "Battery"},
+ json={
+ "name": siemens_battery,
+ "group": "Battery",
+ "initialLevel": 0.0,
+ "initialLevelOptim": False,
+ "injectionNominalCapacity": 0.0,
+ "reservoirCapacity": 0.0,
+ "withdrawalNominalCapacity": 0.0,
+ "efficiency": 1.0,
+ },
)
assert res.status_code == 500, res.json()
obj = res.json()
@@ -370,7 +412,7 @@ def test_lifecycle__nominal(
assert re.search(r"Area ", description, flags=re.IGNORECASE)
assert re.search(r"does not exist ", description, flags=re.IGNORECASE)
- # Check post with wrong group
+ # Check POST with wrong `group`
res = client.post(
f"/v1/studies/{study_id}/areas/{bad_area_id}/storages",
headers={"Authorization": f"Bearer {user_access_token}"},
@@ -381,7 +423,7 @@ def test_lifecycle__nominal(
description = obj["description"]
assert re.search(r"not a valid enumeration member", description, flags=re.IGNORECASE)
- # Check the put with the wrong area_id
+ # Check PATCH with the wrong `area_id`
res = client.patch(
f"/v1/studies/{study_id}/areas/{bad_area_id}/storages/{siemens_battery_id}",
headers={"Authorization": f"Bearer {user_access_token}"},
@@ -401,7 +443,7 @@ def test_lifecycle__nominal(
assert bad_area_id in description
assert re.search(r"not a child of ", description, flags=re.IGNORECASE)
- # Check the put with the wrong siemens_battery_id
+ # Check PATCH with the wrong `siemens_battery_id`
res = client.patch(
f"/v1/studies/{study_id}/areas/{area_id}/storages/{siemens_battery_id}",
headers={"Authorization": f"Bearer {user_access_token}"},
@@ -422,7 +464,7 @@ def test_lifecycle__nominal(
assert re.search(r"fields of storage", description, flags=re.IGNORECASE)
assert re.search(r"not found", description, flags=re.IGNORECASE)
- # Check the put with the wrong study_id
+ # Check PATCH with the wrong `study_id`
res = client.patch(
f"/v1/studies/{bad_study_id}/areas/{area_id}/storages/{siemens_battery_id}",
headers={"Authorization": f"Bearer {user_access_token}"},
@@ -440,19 +482,3 @@ def test_lifecycle__nominal(
obj = res.json()
description = obj["description"]
assert bad_study_id in description
-
- # Check the put with the wrong efficiency
- res = client.patch(
- f"/v1/studies/{bad_study_id}/areas/{area_id}/storages/{siemens_battery_id}",
- headers={"Authorization": f"Bearer {user_access_token}"},
- json={
- "efficiency": 2.0,
- "initialLevel": 0.0,
- "initialLevelOptim": True,
- "injectionNominalCapacity": 2450,
- "name": "New Siemens Battery",
- "reservoirCapacity": 2500,
- "withdrawalNominalCapacity": 2350,
- },
- )
- assert res.status_code == 422, res.json()
diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py
index 7fd8131a54..df4249ca47 100644
--- a/tests/integration/test_integration.py
+++ b/tests/integration/test_integration.py
@@ -529,7 +529,7 @@ def test_area_management(client: TestClient, admin_access_token: str, study_id:
],
)
- client.post(
+ res = client.post(
f"/v1/studies/{study_id}/commands",
headers=admin_headers,
json=[
@@ -545,8 +545,9 @@ def test_area_management(client: TestClient, admin_access_token: str, study_id:
}
],
)
+ res.raise_for_status()
- client.post(
+ res = client.post(
f"/v1/studies/{study_id}/commands",
headers=admin_headers,
json=[
@@ -562,6 +563,7 @@ def test_area_management(client: TestClient, admin_access_token: str, study_id:
}
],
)
+ res.raise_for_status()
res_areas = client.get(f"/v1/studies/{study_id}/areas", headers=admin_headers)
assert res_areas.json() == [
diff --git a/tests/integration/test_integration_variantmanager_tool.py b/tests/integration/test_integration_variantmanager_tool.py
index d0315a31e1..a247d81c1d 100644
--- a/tests/integration/test_integration_variantmanager_tool.py
+++ b/tests/integration/test_integration_variantmanager_tool.py
@@ -1,19 +1,15 @@
-import os
+import io
import urllib.parse
from pathlib import Path
from typing import List, Tuple
from zipfile import ZipFile
+import numpy as np
+import numpy.typing as npt
from fastapi import FastAPI
from starlette.testclient import TestClient
from antarest.study.storage.rawstudy.io.reader import IniReader
-from antarest.study.storage.rawstudy.model.filesystem.matrix.constants import (
- default_4_fixed_hourly,
- default_8_fixed_hourly,
- default_scenario_daily,
- default_scenario_hourly,
-)
from antarest.study.storage.variantstudy.model.command.common import CommandName
from antarest.study.storage.variantstudy.model.model import CommandDTO, GenerationResultInfoDTO
from antarest.tools.lib import (
@@ -29,11 +25,10 @@
test_dir: Path = Path(__file__).parent
-def generate_csv_string(data: List[List[float]]) -> str:
- csv_str = ""
- for row in data:
- csv_str += "\t".join(["{:.6f}".format(v) for v in row]) + "\n"
- return csv_str
+def generate_csv_string(array: npt.NDArray[np.float64]) -> str:
+ buffer = io.StringIO()
+ np.savetxt(buffer, array, delimiter="\t", fmt="%.6f")
+ return buffer.getvalue()
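+# For example (illustrative), generate_csv_string(np.zeros((2, 2))) returns
+# "0.000000\t0.000000\n0.000000\t0.000000\n".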
def generate_study_with_server(
@@ -60,7 +55,7 @@ def generate_study_with_server(
return generator.apply_commands(commands, matrices_dir), variant_id
-def test_variant_manager(app: FastAPI, tmp_path: str):
+def test_variant_manager(app: FastAPI, tmp_path: str) -> None:
client = TestClient(app, raise_server_exceptions=False)
commands = parse_commands(test_dir / "assets" / "commands1.json")
matrix_dir = Path(tmp_path) / "empty_matrix_store"
@@ -69,7 +64,7 @@ def test_variant_manager(app: FastAPI, tmp_path: str):
assert res is not None and res.success
-def test_parse_commands(tmp_path: str, app: FastAPI):
+def test_parse_commands(tmp_path: str, app: FastAPI) -> None:
base_dir = test_dir / "assets"
export_path = Path(tmp_path) / "commands"
study = "base_study"
@@ -92,138 +87,133 @@ def test_parse_commands(tmp_path: str, app: FastAPI):
assert generated_study_path.exists() and generated_study_path.is_dir()
single_column_empty_items = [
- f"input{os.sep}load{os.sep}series{os.sep}load_hub w.txt",
- f"input{os.sep}load{os.sep}series{os.sep}load_south.txt",
- f"input{os.sep}load{os.sep}series{os.sep}load_hub n.txt",
- f"input{os.sep}load{os.sep}series{os.sep}load_west.txt",
- f"input{os.sep}load{os.sep}series{os.sep}load_north.txt",
- f"input{os.sep}load{os.sep}series{os.sep}load_hub s.txt",
- f"input{os.sep}load{os.sep}series{os.sep}load_hub e.txt",
- f"input{os.sep}load{os.sep}series{os.sep}load_east.txt",
- f"input{os.sep}wind{os.sep}series{os.sep}wind_east.txt",
- f"input{os.sep}wind{os.sep}series{os.sep}wind_north.txt",
- f"input{os.sep}wind{os.sep}series{os.sep}wind_hub n.txt",
- f"input{os.sep}wind{os.sep}series{os.sep}wind_south.txt",
- f"input{os.sep}wind{os.sep}series{os.sep}wind_hub w.txt",
- f"input{os.sep}wind{os.sep}series{os.sep}wind_west.txt",
- f"input{os.sep}wind{os.sep}series{os.sep}wind_hub e.txt",
- f"input{os.sep}wind{os.sep}series{os.sep}wind_hub s.txt",
- f"input{os.sep}solar{os.sep}series{os.sep}solar_east.txt",
- f"input{os.sep}solar{os.sep}series{os.sep}solar_hub n.txt",
- f"input{os.sep}solar{os.sep}series{os.sep}solar_south.txt",
- f"input{os.sep}solar{os.sep}series{os.sep}solar_hub s.txt",
- f"input{os.sep}solar{os.sep}series{os.sep}solar_north.txt",
- f"input{os.sep}solar{os.sep}series{os.sep}solar_hub w.txt",
- f"input{os.sep}solar{os.sep}series{os.sep}solar_hub e.txt",
- f"input{os.sep}solar{os.sep}series{os.sep}solar_west.txt",
- f"input{os.sep}thermal{os.sep}series{os.sep}west{os.sep}semi base{os.sep}series.txt",
- f"input{os.sep}thermal{os.sep}series{os.sep}west{os.sep}peak{os.sep}series.txt",
- f"input{os.sep}thermal{os.sep}series{os.sep}west{os.sep}base{os.sep}series.txt",
- f"input{os.sep}thermal{os.sep}series{os.sep}north{os.sep}semi base{os.sep}series.txt",
- f"input{os.sep}thermal{os.sep}series{os.sep}north{os.sep}peak{os.sep}series.txt",
- f"input{os.sep}thermal{os.sep}series{os.sep}north{os.sep}base{os.sep}series.txt",
- f"input{os.sep}thermal{os.sep}series{os.sep}east{os.sep}semi base{os.sep}series.txt",
- f"input{os.sep}thermal{os.sep}series{os.sep}east{os.sep}peak{os.sep}series.txt",
- f"input{os.sep}thermal{os.sep}series{os.sep}east{os.sep}base{os.sep}series.txt",
- f"input{os.sep}thermal{os.sep}series{os.sep}south{os.sep}semi base{os.sep}series.txt",
- f"input{os.sep}thermal{os.sep}series{os.sep}south{os.sep}peak{os.sep}series.txt",
- f"input{os.sep}thermal{os.sep}series{os.sep}south{os.sep}base{os.sep}series.txt",
- f"input{os.sep}hydro{os.sep}series{os.sep}hub e{os.sep}ror.txt",
- f"input{os.sep}hydro{os.sep}series{os.sep}south{os.sep}ror.txt",
- f"input{os.sep}hydro{os.sep}series{os.sep}hub w{os.sep}ror.txt",
- f"input{os.sep}hydro{os.sep}series{os.sep}hub s{os.sep}ror.txt",
- f"input{os.sep}hydro{os.sep}series{os.sep}west{os.sep}ror.txt",
- f"input{os.sep}hydro{os.sep}series{os.sep}hub n{os.sep}ror.txt",
- f"input{os.sep}hydro{os.sep}series{os.sep}north{os.sep}ror.txt",
- f"input{os.sep}hydro{os.sep}series{os.sep}east{os.sep}ror.txt",
+ "input/load/series/load_hub w.txt",
+ "input/load/series/load_south.txt",
+ "input/load/series/load_hub n.txt",
+ "input/load/series/load_west.txt",
+ "input/load/series/load_north.txt",
+ "input/load/series/load_hub s.txt",
+ "input/load/series/load_hub e.txt",
+ "input/load/series/load_east.txt",
+ "input/wind/series/wind_east.txt",
+ "input/wind/series/wind_north.txt",
+ "input/wind/series/wind_hub n.txt",
+ "input/wind/series/wind_south.txt",
+ "input/wind/series/wind_hub w.txt",
+ "input/wind/series/wind_west.txt",
+ "input/wind/series/wind_hub e.txt",
+ "input/wind/series/wind_hub s.txt",
+ "input/solar/series/solar_east.txt",
+ "input/solar/series/solar_hub n.txt",
+ "input/solar/series/solar_south.txt",
+ "input/solar/series/solar_hub s.txt",
+ "input/solar/series/solar_north.txt",
+ "input/solar/series/solar_hub w.txt",
+ "input/solar/series/solar_hub e.txt",
+ "input/solar/series/solar_west.txt",
+ "input/thermal/series/west/semi base/series.txt",
+ "input/thermal/series/west/peak/series.txt",
+ "input/thermal/series/west/base/series.txt",
+ "input/thermal/series/north/semi base/series.txt",
+ "input/thermal/series/north/peak/series.txt",
+ "input/thermal/series/north/base/series.txt",
+ "input/thermal/series/east/semi base/series.txt",
+ "input/thermal/series/east/peak/series.txt",
+ "input/thermal/series/east/base/series.txt",
+ "input/thermal/series/south/semi base/series.txt",
+ "input/thermal/series/south/peak/series.txt",
+ "input/thermal/series/south/base/series.txt",
+ "input/hydro/series/hub e/ror.txt",
+ "input/hydro/series/south/ror.txt",
+ "input/hydro/series/hub w/ror.txt",
+ "input/hydro/series/hub s/ror.txt",
+ "input/hydro/series/west/ror.txt",
+ "input/hydro/series/hub n/ror.txt",
+ "input/hydro/series/north/ror.txt",
+ "input/hydro/series/east/ror.txt",
]
single_column_daily_empty_items = [
- f"input{os.sep}hydro{os.sep}series{os.sep}hub e{os.sep}mod.txt",
- f"input{os.sep}hydro{os.sep}series{os.sep}south{os.sep}mod.txt",
- f"input{os.sep}hydro{os.sep}series{os.sep}hub w{os.sep}mod.txt",
- f"input{os.sep}hydro{os.sep}series{os.sep}hub s{os.sep}mod.txt",
- f"input{os.sep}hydro{os.sep}series{os.sep}west{os.sep}mod.txt",
- f"input{os.sep}hydro{os.sep}series{os.sep}hub n{os.sep}mod.txt",
- f"input{os.sep}hydro{os.sep}series{os.sep}north{os.sep}mod.txt",
- f"input{os.sep}hydro{os.sep}series{os.sep}east{os.sep}mod.txt",
+ "input/hydro/series/hub e/mod.txt",
+ "input/hydro/series/south/mod.txt",
+ "input/hydro/series/hub w/mod.txt",
+ "input/hydro/series/hub s/mod.txt",
+ "input/hydro/series/west/mod.txt",
+ "input/hydro/series/hub n/mod.txt",
+ "input/hydro/series/north/mod.txt",
+ "input/hydro/series/east/mod.txt",
+ ]
+ fixed_3_cols_hourly_empty_items = [
+ "input/bindingconstraints/northern mesh.txt",
+ "input/bindingconstraints/southern mesh.txt",
]
fixed_4_cols_empty_items = [
- f"input{os.sep}reserves{os.sep}hub s.txt",
- f"input{os.sep}reserves{os.sep}hub n.txt",
- f"input{os.sep}reserves{os.sep}hub w.txt",
- f"input{os.sep}reserves{os.sep}hub e.txt",
+ "input/reserves/hub s.txt",
+ "input/reserves/hub n.txt",
+ "input/reserves/hub w.txt",
+ "input/reserves/hub e.txt",
]
fixed_8_cols_empty_items = [
- f"input{os.sep}misc-gen{os.sep}miscgen-hub w.txt",
- f"input{os.sep}misc-gen{os.sep}miscgen-hub e.txt",
- f"input{os.sep}misc-gen{os.sep}miscgen-hub s.txt",
- f"input{os.sep}misc-gen{os.sep}miscgen-hub n.txt",
+ "input/misc-gen/miscgen-hub w.txt",
+ "input/misc-gen/miscgen-hub e.txt",
+ "input/misc-gen/miscgen-hub s.txt",
+ "input/misc-gen/miscgen-hub n.txt",
]
- single_column_empty_data = generate_csv_string(default_scenario_hourly)
- single_column_daily_empty_data = generate_csv_string(default_scenario_daily)
- fixed_4_columns_empty_data = generate_csv_string(default_4_fixed_hourly)
- fixed_8_columns_empty_data = generate_csv_string(default_8_fixed_hourly)
- for root, dirs, files in os.walk(study_path):
- rel_path = root[len(str(study_path)) + 1 :]
- for item in files:
- if item in [
- "comments.txt",
- "study.antares",
- "Desktop.ini",
- "study.ico",
- ]:
- continue
- elif f"{rel_path}{os.sep}{item}" in single_column_empty_items:
- assert (generated_study_path / rel_path / item).read_text() == single_column_empty_data
- elif f"{rel_path}{os.sep}{item}" in single_column_daily_empty_items:
- assert (generated_study_path / rel_path / item).read_text() == single_column_daily_empty_data
- elif f"{rel_path}{os.sep}{item}" in fixed_4_cols_empty_items:
- assert (generated_study_path / rel_path / item).read_text() == fixed_4_columns_empty_data
- elif f"{rel_path}{os.sep}{item}" in fixed_8_cols_empty_items:
- assert (generated_study_path / rel_path / item).read_text() == fixed_8_columns_empty_data
- else:
- actual = (study_path / rel_path / item).read_text()
- expected = (generated_study_path / rel_path / item).read_text()
- assert actual.strip() == expected.strip()
-
-
-def test_diff_local(tmp_path: Path):
+ single_column_empty_data = generate_csv_string(np.zeros((8760, 1), dtype=np.float64))
+ single_column_daily_empty_data = generate_csv_string(np.zeros((365, 1), dtype=np.float64))
+ fixed_3_cols_hourly_empty_data = generate_csv_string(np.zeros(shape=(8760, 3), dtype=np.float64))
+ fixed_4_columns_empty_data = generate_csv_string(np.zeros((8760, 4), dtype=np.float64))
+ fixed_8_columns_empty_data = generate_csv_string(np.zeros((8760, 8), dtype=np.float64))
+ for file_path in study_path.rglob("*"):
+ if file_path.is_dir() or file_path.name in ["comments.txt", "study.antares", "Desktop.ini", "study.ico"]:
+ continue
+ item_relpath = file_path.relative_to(study_path).as_posix()
+ if item_relpath in single_column_empty_items:
+ assert (generated_study_path / item_relpath).read_text() == single_column_empty_data
+ elif item_relpath in single_column_daily_empty_items:
+ assert (generated_study_path / item_relpath).read_text() == single_column_daily_empty_data
+ elif item_relpath in fixed_3_cols_hourly_empty_items:
+ assert (generated_study_path / item_relpath).read_text() == fixed_3_cols_hourly_empty_data
+ elif item_relpath in fixed_4_cols_empty_items:
+ assert (generated_study_path / item_relpath).read_text() == fixed_4_columns_empty_data
+ elif item_relpath in fixed_8_cols_empty_items:
+ assert (generated_study_path / item_relpath).read_text() == fixed_8_columns_empty_data
+ else:
+ actual = (study_path / item_relpath).read_text()
+ expected = (generated_study_path / item_relpath).read_text()
+ assert actual.strip() == expected.strip()
+
+
+def test_diff_local(tmp_path: Path) -> None:
base_dir = test_dir / "assets"
export_path = Path(tmp_path) / "generation_result"
base_study = "base_study"
variant_study = "variant_study"
- output_study_commands = Path(export_path) / "output_study_commands"
+ output_study_commands = export_path / "output_study_commands"
output_study_path = Path(tmp_path) / base_study
- base_study_commands = Path(export_path) / base_study
- variant_study_commands = Path(export_path) / variant_study
+ base_study_commands = export_path / base_study
+ variant_study_commands = export_path / variant_study
variant_study_path = Path(tmp_path) / variant_study
for study in [base_study, variant_study]:
with ZipFile(base_dir / f"{study}.zip") as zip_output:
zip_output.extractall(path=tmp_path)
- extract_commands(Path(tmp_path) / study, Path(export_path) / study)
+ extract_commands(Path(tmp_path) / study, export_path / study)
- res = generate_study(base_study_commands, None, str(Path(export_path) / "base_generated"))
- res = generate_study(
+ generate_study(base_study_commands, None, str(export_path / "base_generated"))
+ generate_study(
variant_study_commands,
None,
- str(Path(export_path) / "variant_generated"),
+ str(export_path / "variant_generated"),
)
generate_diff(base_study_commands, variant_study_commands, output_study_commands)
res = generate_study(output_study_commands, None, output=str(output_study_path))
assert res.success
assert output_study_path.exists() and output_study_path.is_dir()
- for root, dirs, files in os.walk(variant_study_path):
- rel_path = root[len(str(variant_study_path)) + 1 :]
- for item in files:
- if item in [
- "comments.txt",
- "study.antares",
- "Desktop.ini",
- "study.ico",
- ]:
- continue
- actual = (variant_study_path / rel_path / item).read_text()
- expected = (output_study_path / rel_path / item).read_text()
- assert actual.strip() == expected.strip()
+ for file_path in variant_study_path.rglob("*"):
+ if file_path.is_dir() or file_path.name in ["comments.txt", "study.antares", "Desktop.ini", "study.ico"]:
+ continue
+ item_relpath = file_path.relative_to(variant_study_path).as_posix()
+ actual = (variant_study_path / item_relpath).read_text()
+ expected = (output_study_path / item_relpath).read_text()
+ assert actual.strip() == expected.strip()
diff --git a/tests/launcher/test_local_launcher.py b/tests/launcher/test_local_launcher.py
index 04741d319a..53adc03bf0 100644
--- a/tests/launcher/test_local_launcher.py
+++ b/tests/launcher/test_local_launcher.py
@@ -1,19 +1,28 @@
import os
import textwrap
+import uuid
from pathlib import Path
from unittest.mock import Mock, call
-from uuid import uuid4
import pytest
-from sqlalchemy import create_engine
from antarest.core.config import Config, LauncherConfig, LocalConfig
-from antarest.core.persistence import Base
-from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware
from antarest.launcher.adapters.abstractlauncher import LauncherInitException
from antarest.launcher.adapters.local_launcher.local_launcher import LocalLauncher
from antarest.launcher.model import JobStatus, LauncherParametersDTO
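+# The dummy solver used in these tests is a tiny script; a platform-appropriate
+# name is needed so it can be executed on both Windows and POSIX systems.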
+SOLVER_NAME = "solver.bat" if os.name == "nt" else "solver.sh"
+
+
+@pytest.fixture
+def launcher_config(tmp_path: Path) -> Config:
+ """
+ Fixture to create a launcher config with a local launcher.
+ """
+ solver_path = tmp_path.joinpath(SOLVER_NAME)
+ data = {"binaries": {"700": solver_path}, "enable_nb_cores_detection": True}
+ return Config(launcher=LauncherConfig(local=LocalConfig.from_dict(data)))
+
@pytest.mark.unit_test
def test_local_launcher__launcher_init_exception():
@@ -30,21 +39,12 @@ def test_local_launcher__launcher_init_exception():
@pytest.mark.unit_test
-def test_compute(tmp_path: Path):
- engine = create_engine("sqlite:///:memory:", echo=False)
- Base.metadata.create_all(engine)
- # noinspection SpellCheckingInspection
- DBSessionMiddleware(
- None,
- custom_engine=engine,
- session_args={"autocommit": False, "autoflush": False},
- )
- local_launcher = LocalLauncher(Config(), callbacks=Mock(), event_bus=Mock(), cache=Mock())
+def test_compute(tmp_path: Path, launcher_config: Config):
+ local_launcher = LocalLauncher(launcher_config, callbacks=Mock(), event_bus=Mock(), cache=Mock())
# prepare a dummy executable to simulate Antares Solver
if os.name == "nt":
- solver_name = "solver.bat"
- solver_path = tmp_path.joinpath(solver_name)
+ solver_path = tmp_path.joinpath(SOLVER_NAME)
solver_path.write_text(
textwrap.dedent(
"""\
@@ -55,8 +55,7 @@ def test_compute(tmp_path: Path):
)
)
else:
- solver_name = "solver.sh"
- solver_path = tmp_path.joinpath(solver_name)
+ solver_path = tmp_path.joinpath(SOLVER_NAME)
solver_path.write_text(
textwrap.dedent(
"""\
@@ -68,8 +67,8 @@ def test_compute(tmp_path: Path):
)
solver_path.chmod(0o775)
- uuid = uuid4()
- local_launcher.job_id_to_study_id = {str(uuid): ("study-id", tmp_path / "run", Mock())}
+ study_id = uuid.uuid4()
+ local_launcher.job_id_to_study_id = {str(study_id): ("study-id", tmp_path / "run", Mock())}
local_launcher.callbacks.import_output.return_value = "some output"
launcher_parameters = LauncherParametersDTO(
adequacy_patch=None,
@@ -86,15 +85,15 @@ def test_compute(tmp_path: Path):
local_launcher._compute(
antares_solver_path=solver_path,
study_uuid="study-id",
- uuid=uuid,
+ uuid=study_id,
launcher_parameters=launcher_parameters,
)
# noinspection PyUnresolvedReferences
local_launcher.callbacks.update_status.assert_has_calls(
[
- call(str(uuid), JobStatus.RUNNING, None, None),
- call(str(uuid), JobStatus.SUCCESS, None, "some output"),
+ call(str(study_id), JobStatus.RUNNING, None, None),
+ call(str(study_id), JobStatus.SUCCESS, None, "some output"),
]
)
diff --git a/tests/launcher/test_service.py b/tests/launcher/test_service.py
index 2c9f94d89c..a6177c5e61 100644
--- a/tests/launcher/test_service.py
+++ b/tests/launcher/test_service.py
@@ -3,15 +3,24 @@
import time
from datetime import datetime, timedelta
from pathlib import Path
-from typing import Dict, List, Literal, Union
+from typing import Dict, List, Union
from unittest.mock import Mock, call, patch
from uuid import uuid4
from zipfile import ZIP_DEFLATED, ZipFile
import pytest
from sqlalchemy import create_engine
-
-from antarest.core.config import Config, LauncherConfig, LocalConfig, SlurmConfig, StorageConfig
+from typing_extensions import Literal
+
+from antarest.core.config import (
+ Config,
+ InvalidConfigurationError,
+ LauncherConfig,
+ LocalConfig,
+ NbCoresConfig,
+ SlurmConfig,
+ StorageConfig,
+)
from antarest.core.exceptions import StudyNotFoundError
from antarest.core.filetransfer.model import FileDownload, FileDownloadDTO, FileDownloadTaskDTO
from antarest.core.interfaces.eventbus import Event, EventType
@@ -20,7 +29,7 @@
from antarest.core.requests import RequestParameters, UserHasNotPermissionError
from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware
from antarest.dbmodel import Base
-from antarest.launcher.model import JobLog, JobLogType, JobResult, JobStatus, LogType
+from antarest.launcher.model import JobLog, JobLogType, JobResult, JobStatus, LauncherParametersDTO, LogType
from antarest.launcher.service import (
EXECUTION_INFO_FILE,
LAUNCHER_PARAM_NAME_SUFFIX,
@@ -33,780 +42,908 @@
from antarest.study.model import OwnerInfo, PublicMode, Study, StudyMetadataDTO
-@pytest.mark.unit_test
-@patch.object(Auth, "get_current_user")
-def test_service_run_study(get_current_user_mock):
- get_current_user_mock.return_value = None
- storage_service_mock = Mock()
- storage_service_mock.get_study_information.return_value = StudyMetadataDTO(
- id="id",
- name="name",
- created=1,
- updated=1,
- type="rawstudy",
- owner=OwnerInfo(id=0, name="author"),
- groups=[],
- public_mode=PublicMode.NONE,
- version=42,
- workspace="default",
- managed=True,
- archived=False,
- )
- storage_service_mock.get_study_path.return_value = Path("path/to/study")
-
- uuid = uuid4()
- launcher_mock = Mock()
- factory_launcher_mock = Mock()
- factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock}
-
- event_bus = Mock()
-
- pending = JobResult(
- id=str(uuid),
- study_id="study_uuid",
- job_status=JobStatus.PENDING,
- launcher="local",
- )
- repository = Mock()
- repository.save.return_value = pending
-
- launcher_service = LauncherService(
- config=Config(),
- study_service=storage_service_mock,
- job_result_repository=repository,
- factory_launcher=factory_launcher_mock,
- event_bus=event_bus,
- file_transfer_manager=Mock(),
- task_service=Mock(),
- cache=Mock(),
- )
- launcher_service._generate_new_id = lambda: str(uuid)
-
- job_id = launcher_service.run_study(
- "study_uuid",
- "local",
- None,
- RequestParameters(
- user=JWTUser(
- id=0,
- impersonator=0,
- type="users",
- )
- ),
- )
-
- assert job_id == str(uuid)
- repository.save.assert_called_once_with(pending)
- event_bus.push.assert_called_once_with(
- Event(
- type=EventType.STUDY_JOB_STARTED,
- payload=pending.to_dto().dict(),
- permissions=PermissionInfo(owner=0),
+class TestLauncherService:
+ @pytest.mark.unit_test
+ @patch.object(Auth, "get_current_user")
+ def test_service_run_study(self, get_current_user_mock) -> None:
+ get_current_user_mock.return_value = None
+ storage_service_mock = Mock()
+ # noinspection SpellCheckingInspection
+ storage_service_mock.get_study_information.return_value = StudyMetadataDTO(
+ id="id",
+ name="name",
+ created="1",
+ updated="1",
+ type="rawstudy",
+ owner=OwnerInfo(id=0, name="author"),
+ groups=[],
+ public_mode=PublicMode.NONE,
+ version=42,
+ workspace="default",
+ managed=True,
+ archived=False,
)
- )
+ storage_service_mock.get_study_path.return_value = Path("path/to/study")
+ uuid = uuid4()
+ launcher_mock = Mock()
+ factory_launcher_mock = Mock()
+ factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock}
-@pytest.mark.unit_test
-def test_service_get_result_from_launcher():
- launcher_mock = Mock()
- fake_execution_result = JobResult(
- id=str(uuid4()),
- study_id="sid",
- job_status=JobStatus.SUCCESS,
- msg="Hello, World!",
- exit_code=0,
- launcher="local",
- )
- factory_launcher_mock = Mock()
- factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock}
-
- repository = Mock()
- repository.get.return_value = fake_execution_result
-
- study_service = Mock()
- study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE)
-
- launcher_service = LauncherService(
- config=Config(),
- study_service=study_service,
- job_result_repository=repository,
- factory_launcher=factory_launcher_mock,
- event_bus=Mock(),
- file_transfer_manager=Mock(),
- task_service=Mock(),
- cache=Mock(),
- )
-
- job_id = uuid4()
- assert (
- launcher_service.get_result(job_uuid=job_id, params=RequestParameters(user=DEFAULT_ADMIN_USER))
- == fake_execution_result
- )
+ event_bus = Mock()
+ pending = JobResult(
+ id=str(uuid),
+ study_id="study_uuid",
+ job_status=JobStatus.PENDING,
+ launcher="local",
+ launcher_params=LauncherParametersDTO().json(),
+ )
+ repository = Mock()
+ repository.save.return_value = pending
+
+ launcher_service = LauncherService(
+ config=Config(),
+ study_service=storage_service_mock,
+ job_result_repository=repository,
+ factory_launcher=factory_launcher_mock,
+ event_bus=event_bus,
+ file_transfer_manager=Mock(),
+ task_service=Mock(),
+ cache=Mock(),
+ )
+ launcher_service._generate_new_id = lambda: str(uuid)
-@pytest.mark.unit_test
-def test_service_get_result_from_database():
- launcher_mock = Mock()
- fake_execution_result = JobResult(
- id=str(uuid4()),
- study_id="sid",
- job_status=JobStatus.SUCCESS,
- msg="Hello, World!",
- exit_code=0,
- )
- launcher_mock.get_result.return_value = None
- factory_launcher_mock = Mock()
- factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock}
-
- repository = Mock()
- repository.get.return_value = fake_execution_result
-
- study_service = Mock()
- study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE)
-
- launcher_service = LauncherService(
- config=Config(),
- study_service=study_service,
- job_result_repository=repository,
- factory_launcher=factory_launcher_mock,
- event_bus=Mock(),
- file_transfer_manager=Mock(),
- task_service=Mock(),
- cache=Mock(),
- )
-
- assert (
- launcher_service.get_result(job_uuid=uuid4(), params=RequestParameters(user=DEFAULT_ADMIN_USER))
- == fake_execution_result
- )
+ job_id = launcher_service.run_study(
+ "study_uuid",
+ "local",
+ LauncherParametersDTO(),
+ RequestParameters(
+ user=JWTUser(
+ id=0,
+ impersonator=0,
+ type="users",
+ )
+ ),
+ )
+ assert job_id == str(uuid)
+ repository.save.assert_called_once_with(pending)
+ event_bus.push.assert_called_once_with(
+ Event(
+ type=EventType.STUDY_JOB_STARTED,
+ payload=pending.to_dto().dict(),
+ permissions=PermissionInfo(owner=0),
+ )
+ )
-@pytest.mark.unit_test
-def test_service_get_jobs_from_database():
- launcher_mock = Mock()
- now = datetime.utcnow()
- fake_execution_result = [
- JobResult(
+ @pytest.mark.unit_test
+ def test_service_get_result_from_launcher(self) -> None:
+ launcher_mock = Mock()
+ fake_execution_result = JobResult(
id=str(uuid4()),
- study_id="a",
+ study_id="sid",
job_status=JobStatus.SUCCESS,
msg="Hello, World!",
exit_code=0,
+ launcher="local",
)
- ]
- returned_faked_execution_results = [
- JobResult(
- id="1",
- study_id="a",
- job_status=JobStatus.SUCCESS,
- msg="Hello, World!",
- exit_code=0,
- creation_date=now,
- ),
- JobResult(
- id="2",
- study_id="b",
- job_status=JobStatus.SUCCESS,
- msg="Hello, World!",
- exit_code=0,
- creation_date=now,
- ),
- ]
- all_faked_execution_results = returned_faked_execution_results + [
- JobResult(
- id="3",
- study_id="c",
+ factory_launcher_mock = Mock()
+ factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock}
+
+ repository = Mock()
+ repository.get.return_value = fake_execution_result
+
+ study_service = Mock()
+ study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE)
+
+ launcher_service = LauncherService(
+ config=Config(),
+ study_service=study_service,
+ job_result_repository=repository,
+ factory_launcher=factory_launcher_mock,
+ event_bus=Mock(),
+ file_transfer_manager=Mock(),
+ task_service=Mock(),
+ cache=Mock(),
+ )
+
+ job_id = uuid4()
+ assert (
+ launcher_service.get_result(job_uuid=job_id, params=RequestParameters(user=DEFAULT_ADMIN_USER))
+ == fake_execution_result
+ )
+
+ @pytest.mark.unit_test
+ def test_service_get_result_from_database(self) -> None:
+ launcher_mock = Mock()
+ fake_execution_result = JobResult(
+ id=str(uuid4()),
+ study_id="sid",
job_status=JobStatus.SUCCESS,
msg="Hello, World!",
exit_code=0,
- creation_date=now - timedelta(days=ORPHAN_JOBS_VISIBILITY_THRESHOLD + 1),
- )
- ]
- launcher_mock.get_result.return_value = None
- factory_launcher_mock = Mock()
- factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock}
-
- repository = Mock()
- repository.find_by_study.return_value = fake_execution_result
- repository.get_all.return_value = all_faked_execution_results
-
- study_service = Mock()
- study_service.repository = Mock()
- study_service.repository.get_list.return_value = [
- Mock(
- spec=Study,
- id="b",
- groups=[],
- owner=User(id=2),
- public_mode=PublicMode.NONE,
)
- ]
-
- launcher_service = LauncherService(
- config=Config(),
- study_service=study_service,
- job_result_repository=repository,
- factory_launcher=factory_launcher_mock,
- event_bus=Mock(),
- file_transfer_manager=Mock(),
- task_service=Mock(),
- cache=Mock(),
- )
+ launcher_mock.get_result.return_value = None
+ factory_launcher_mock = Mock()
+ factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock}
+
+ repository = Mock()
+ repository.get.return_value = fake_execution_result
+
+ study_service = Mock()
+ study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE)
+
+ launcher_service = LauncherService(
+ config=Config(),
+ study_service=study_service,
+ job_result_repository=repository,
+ factory_launcher=factory_launcher_mock,
+ event_bus=Mock(),
+ file_transfer_manager=Mock(),
+ task_service=Mock(),
+ cache=Mock(),
+ )
- study_id = uuid4()
- assert (
- launcher_service.get_jobs(str(study_id), params=RequestParameters(user=DEFAULT_ADMIN_USER))
- == fake_execution_result
- )
- repository.find_by_study.assert_called_once_with(str(study_id))
- assert (
- launcher_service.get_jobs(None, params=RequestParameters(user=DEFAULT_ADMIN_USER))
- == all_faked_execution_results
- )
- assert (
- launcher_service.get_jobs(
- None,
- params=RequestParameters(
- user=JWTUser(
- id=2,
- impersonator=2,
- type="users",
- groups=[],
- )
- ),
+ assert (
+ launcher_service.get_result(job_uuid=uuid4(), params=RequestParameters(user=DEFAULT_ADMIN_USER))
+ == fake_execution_result
)
- == returned_faked_execution_results
- )
- with pytest.raises(UserHasNotPermissionError):
- launcher_service.remove_job(
- "some job",
- RequestParameters(
- user=JWTUser(
- id=2,
- impersonator=2,
- type="users",
- groups=[],
- )
+ @pytest.mark.unit_test
+ def test_service_get_jobs_from_database(self) -> None:
+ launcher_mock = Mock()
+ now = datetime.utcnow()
+ fake_execution_result = [
+ JobResult(
+ id=str(uuid4()),
+ study_id="a",
+ job_status=JobStatus.SUCCESS,
+ msg="Hello, World!",
+ exit_code=0,
+ )
+ ]
+ returned_faked_execution_results = [
+ JobResult(
+ id="1",
+ study_id="a",
+ job_status=JobStatus.SUCCESS,
+ msg="Hello, World!",
+ exit_code=0,
+ creation_date=now,
+ ),
+ JobResult(
+ id="2",
+ study_id="b",
+ job_status=JobStatus.SUCCESS,
+ msg="Hello, World!",
+ exit_code=0,
+ creation_date=now,
),
+ ]
+ all_faked_execution_results = returned_faked_execution_results + [
+ JobResult(
+ id="3",
+ study_id="c",
+ job_status=JobStatus.SUCCESS,
+ msg="Hello, World!",
+ exit_code=0,
+ creation_date=now - timedelta(days=ORPHAN_JOBS_VISIBILITY_THRESHOLD + 1),
+ )
+ ]
+ launcher_mock.get_result.return_value = None
+ factory_launcher_mock = Mock()
+ factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock}
+
+ repository = Mock()
+ repository.find_by_study.return_value = fake_execution_result
+ repository.get_all.return_value = all_faked_execution_results
+
+ study_service = Mock()
+ study_service.repository = Mock()
+ study_service.repository.get_list.return_value = [
+ Mock(
+ spec=Study,
+ id="b",
+ groups=[],
+ owner=User(id=2),
+ public_mode=PublicMode.NONE,
+ )
+ ]
+
+ launcher_service = LauncherService(
+ config=Config(),
+ study_service=study_service,
+ job_result_repository=repository,
+ factory_launcher=factory_launcher_mock,
+ event_bus=Mock(),
+ file_transfer_manager=Mock(),
+ task_service=Mock(),
+ cache=Mock(),
)
- launcher_service.remove_job("some job", RequestParameters(user=DEFAULT_ADMIN_USER))
- repository.delete.assert_called_with("some job")
+ study_id = uuid4()
+ assert (
+ launcher_service.get_jobs(str(study_id), params=RequestParameters(user=DEFAULT_ADMIN_USER))
+ == fake_execution_result
+ )
+ repository.find_by_study.assert_called_once_with(str(study_id))
+ assert (
+ launcher_service.get_jobs(None, params=RequestParameters(user=DEFAULT_ADMIN_USER))
+ == all_faked_execution_results
+ )
+ assert (
+ launcher_service.get_jobs(
+ None,
+ params=RequestParameters(
+ user=JWTUser(
+ id=2,
+ impersonator=2,
+ type="users",
+ groups=[],
+ )
+ ),
+ )
+ == returned_faked_execution_results
+ )
+ with pytest.raises(UserHasNotPermissionError):
+ launcher_service.remove_job(
+ "some job",
+ RequestParameters(
+ user=JWTUser(
+ id=2,
+ impersonator=2,
+ type="users",
+ groups=[],
+ )
+ ),
+ )
-@pytest.mark.unit_test
-@pytest.mark.parametrize(
- "config, solver, expected",
- [
- pytest.param(
- {
- "default": "local",
- "local": [],
- "slurm": [],
- },
- "default",
- [],
- id="empty-config",
- ),
- pytest.param(
- {
- "default": "local",
- "local": ["456", "123", "798"],
- },
- "default",
- ["123", "456", "798"],
- id="local-config-default",
- ),
- pytest.param(
- {
- "default": "local",
- "local": ["456", "123", "798"],
- },
- "slurm",
- [],
- id="local-config-slurm",
- ),
- pytest.param(
- {
- "default": "local",
- "local": ["456", "123", "798"],
- },
- "unknown",
- [],
- id="local-config-unknown",
- marks=pytest.mark.xfail(
- reason="Unknown solver configuration: 'unknown'",
- raises=KeyError,
- strict=True,
+ launcher_service.remove_job("some job", RequestParameters(user=DEFAULT_ADMIN_USER))
+ repository.delete.assert_called_with("some job")
+
+ @pytest.mark.unit_test
+ @pytest.mark.parametrize(
+ "config, solver, expected",
+ [
+ pytest.param(
+ {
+ "default": "local",
+ "local": [],
+ "slurm": [],
+ },
+ "default",
+ [],
+ id="empty-config",
),
- ),
- pytest.param(
- {
- "default": "slurm",
- "slurm": ["258", "147", "369"],
- },
- "default",
- ["147", "258", "369"],
- id="slurm-config-default",
- ),
- pytest.param(
- {
- "default": "slurm",
- "slurm": ["258", "147", "369"],
- },
- "local",
- [],
- id="slurm-config-local",
- ),
- pytest.param(
- {
- "default": "slurm",
- "slurm": ["258", "147", "369"],
- },
- "unknown",
- [],
- id="slurm-config-unknown",
- marks=pytest.mark.xfail(
- reason="Unknown solver configuration: 'unknown'",
- raises=KeyError,
- strict=True,
+ pytest.param(
+ {
+ "default": "local",
+ "local": ["456", "123", "798"],
+ },
+ "default",
+ ["123", "456", "798"],
+ id="local-config-default",
),
- ),
- pytest.param(
- {
- "default": "slurm",
- "local": ["456", "123", "798"],
- "slurm": ["258", "147", "369"],
- },
- "local",
- ["123", "456", "798"],
- id="local+slurm-config-local",
- ),
- ],
-)
-def test_service_get_solver_versions(
- config: Dict[str, Union[str, List[str]]],
- solver: Literal["default", "local", "slurm", "unknown"],
- expected: List[str],
-) -> None:
- # Prepare the configuration
- default = config.get("default", "local")
- local = LocalConfig(binaries={k: Path(f"solver-{k}.exe") for k in config.get("local", [])})
- slurm = SlurmConfig(antares_versions_on_remote_server=config.get("slurm", []))
- launcher_config = LauncherConfig(
- default=default,
- local=local if local else None,
- slurm=slurm if slurm else None,
- )
- config = Config(launcher=launcher_config)
- launcher_service = LauncherService(
- config=config,
- study_service=Mock(),
- job_result_repository=Mock(),
- factory_launcher=Mock(),
- event_bus=Mock(),
- file_transfer_manager=Mock(),
- task_service=Mock(),
- cache=Mock(),
+ pytest.param(
+ {
+ "default": "local",
+ "local": ["456", "123", "798"],
+ },
+ "slurm",
+ [],
+ id="local-config-slurm",
+ ),
+ pytest.param(
+ {
+ "default": "local",
+ "local": ["456", "123", "798"],
+ },
+ "unknown",
+ [],
+ id="local-config-unknown",
+ marks=pytest.mark.xfail(
+ reason="Unknown solver configuration: 'unknown'",
+ raises=KeyError,
+ strict=True,
+ ),
+ ),
+ pytest.param(
+ {
+ "default": "slurm",
+ "slurm": ["258", "147", "369"],
+ },
+ "default",
+ ["147", "258", "369"],
+ id="slurm-config-default",
+ ),
+ pytest.param(
+ {
+ "default": "slurm",
+ "slurm": ["258", "147", "369"],
+ },
+ "local",
+ [],
+ id="slurm-config-local",
+ ),
+ pytest.param(
+ {
+ "default": "slurm",
+ "slurm": ["258", "147", "369"],
+ },
+ "unknown",
+ [],
+ id="slurm-config-unknown",
+ marks=pytest.mark.xfail(
+ reason="Unknown solver configuration: 'unknown'",
+ raises=KeyError,
+ strict=True,
+ ),
+ ),
+ pytest.param(
+ {
+ "default": "slurm",
+ "local": ["456", "123", "798"],
+ "slurm": ["258", "147", "369"],
+ },
+ "local",
+ ["123", "456", "798"],
+ id="local+slurm-config-local",
+ ),
+ ],
)
+ def test_service_get_solver_versions(
+ self,
+ config: Dict[str, Union[str, List[str]]],
+ solver: Literal["default", "local", "slurm", "unknown"],
+ expected: List[str],
+ ) -> None:
+        # Prepare the configuration.
+        # The default launcher is read from the configuration dict and
+        # falls back to "local" when it is not specified.
+ default = config.get("default", "local")
+ local = LocalConfig(binaries={k: Path(f"solver-{k}.exe") for k in config.get("local", [])})
+ slurm = SlurmConfig(antares_versions_on_remote_server=config.get("slurm", []))
+ launcher_config = LauncherConfig(
+ default=default,
+ local=local if local else None,
+ slurm=slurm if slurm else None,
+ )
+ config = Config(launcher=launcher_config)
+ launcher_service = LauncherService(
+ config=config,
+ study_service=Mock(),
+ job_result_repository=Mock(),
+ factory_launcher=Mock(),
+ event_bus=Mock(),
+ file_transfer_manager=Mock(),
+ task_service=Mock(),
+ cache=Mock(),
+ )
- # Fetch the solver versions
- actual = launcher_service.get_solver_versions(solver)
+ # Fetch the solver versions
+ actual = launcher_service.get_solver_versions(solver)
+ assert actual == expected
- # Check the result
- assert actual == expected
+ @pytest.mark.unit_test
+ @pytest.mark.parametrize(
+ "config_map, solver, expected",
+ [
+ pytest.param(
+ {"default": "local", "local": {}, "slurm": {}},
+ "default",
+ {},
+ id="empty-config",
+ ),
+ pytest.param(
+ {
+ "default": "local",
+ "local": {"min": 1, "default": 11, "max": 12},
+ },
+ "default",
+ {"min": 1, "default": 11, "max": 12},
+ id="local-config-default",
+ ),
+ pytest.param(
+ {
+ "default": "local",
+ "local": {"min": 1, "default": 11, "max": 12},
+ },
+ "slurm",
+ {},
+ id="local-config-slurm",
+ ),
+ pytest.param(
+ {
+ "default": "local",
+ "local": {"min": 1, "default": 11, "max": 12},
+ },
+ "unknown",
+ {},
+ id="local-config-unknown",
+ marks=pytest.mark.xfail(
+ reason="Configuration is not available for the 'unknown' launcher",
+ raises=InvalidConfigurationError,
+ strict=True,
+ ),
+ ),
+ pytest.param(
+ {
+ "default": "slurm",
+ "slurm": {"min": 4, "default": 8, "max": 16},
+ },
+ "default",
+ {"min": 4, "default": 8, "max": 16},
+ id="slurm-config-default",
+ ),
+ pytest.param(
+ {
+ "default": "slurm",
+ "slurm": {"min": 4, "default": 8, "max": 16},
+ },
+ "local",
+ {},
+ id="slurm-config-local",
+ ),
+ pytest.param(
+ {
+ "default": "slurm",
+ "slurm": {"min": 4, "default": 8, "max": 16},
+ },
+ "unknown",
+ {},
+ id="slurm-config-unknown",
+ marks=pytest.mark.xfail(
+ reason="Configuration is not available for the 'unknown' launcher",
+ raises=InvalidConfigurationError,
+ strict=True,
+ ),
+ ),
+ pytest.param(
+ {
+ "default": "slurm",
+ "local": {"min": 1, "default": 11, "max": 12},
+ "slurm": {"min": 4, "default": 8, "max": 16},
+ },
+ "local",
+ {"min": 1, "default": 11, "max": 12},
+ id="local+slurm-config-local",
+ ),
+ ],
+ )
+ def test_get_nb_cores(
+ self,
+ config_map: Dict[str, Union[str, Dict[str, int]]],
+ solver: Literal["default", "local", "slurm", "unknown"],
+ expected: Dict[str, int],
+ ) -> None:
+ # Prepare the configuration
+ default = config_map.get("default", "local")
+ local_nb_cores = config_map.get("local", {})
+ slurm_nb_cores = config_map.get("slurm", {})
+ launcher_config = LauncherConfig(
+ default=default,
+ local=LocalConfig.from_dict({"enable_nb_cores_detection": False, "nb_cores": local_nb_cores}),
+ slurm=SlurmConfig.from_dict({"enable_nb_cores_detection": False, "nb_cores": slurm_nb_cores}),
+ )
+ launcher_service = LauncherService(
+ config=Config(launcher=launcher_config),
+ study_service=Mock(),
+ job_result_repository=Mock(),
+ factory_launcher=Mock(),
+ event_bus=Mock(),
+ file_transfer_manager=Mock(),
+ task_service=Mock(),
+ cache=Mock(),
+ )
+ # Fetch the number of cores
+ actual = launcher_service.get_nb_cores(solver)
+
+ # Check the result
+ assert actual == NbCoresConfig(**expected)
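+        # Note: when `expected` is empty, the comparison falls back to the
+        # default NbCoresConfig values.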
+
+ @pytest.mark.unit_test
+ def test_service_kill_job(self, tmp_path: Path) -> None:
+ study_service = Mock()
+ study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE)
+
+ launcher_service = LauncherService(
+ config=Config(storage=StorageConfig(tmp_dir=tmp_path)),
+ study_service=study_service,
+ job_result_repository=Mock(),
+ event_bus=Mock(),
+ factory_launcher=Mock(),
+ file_transfer_manager=Mock(),
+ task_service=Mock(),
+ cache=Mock(),
+ )
+ launcher = "slurm"
+ job_id = "job_id"
+ job_result_mock = Mock()
+ job_result_mock.id = job_id
+ job_result_mock.study_id = "study_id"
+ job_result_mock.launcher = launcher
+ launcher_service.job_result_repository.get.return_value = job_result_mock
+ launcher_service.launchers = {"slurm": Mock()}
+
+ job_status = launcher_service.kill_job(
+ job_id=job_id,
+ params=RequestParameters(user=DEFAULT_ADMIN_USER),
+ )
-@pytest.mark.unit_test
-def test_service_kill_job(tmp_path: Path):
- study_service = Mock()
- study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE)
+ launcher_service.launchers[launcher].kill_job.assert_called_once_with(job_id=job_id)
- launcher_service = LauncherService(
- config=Config(storage=StorageConfig(tmp_dir=tmp_path)),
- study_service=study_service,
- job_result_repository=Mock(),
- event_bus=Mock(),
- factory_launcher=Mock(),
- file_transfer_manager=Mock(),
- task_service=Mock(),
- cache=Mock(),
- )
- launcher = "slurm"
- job_id = "job_id"
- job_result_mock = Mock()
- job_result_mock.id = job_id
- job_result_mock.study_id = "study_id"
- job_result_mock.launcher = launcher
- launcher_service.job_result_repository.get.return_value = job_result_mock
- launcher_service.launchers = {"slurm": Mock()}
-
- job_status = launcher_service.kill_job(
- job_id=job_id,
- params=RequestParameters(user=DEFAULT_ADMIN_USER),
- )
+ assert job_status.job_status == JobStatus.FAILED
+ launcher_service.job_result_repository.save.assert_called_once_with(job_status)
- launcher_service.launchers[launcher].kill_job.assert_called_once_with(job_id=job_id)
+ def test_append_logs(self, tmp_path: Path) -> None:
+ study_service = Mock()
+ study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE)
- assert job_status.job_status == JobStatus.FAILED
- launcher_service.job_result_repository.save.assert_called_once_with(job_status)
+ launcher_service = LauncherService(
+ config=Config(storage=StorageConfig(tmp_dir=tmp_path)),
+ study_service=study_service,
+ job_result_repository=Mock(),
+ event_bus=Mock(),
+ factory_launcher=Mock(),
+ file_transfer_manager=Mock(),
+ task_service=Mock(),
+ cache=Mock(),
+ )
+ launcher = "slurm"
+ job_id = "job_id"
+ job_result_mock = Mock()
+ job_result_mock.id = job_id
+ job_result_mock.study_id = "study_id"
+ job_result_mock.output_id = None
+ job_result_mock.launcher = launcher
+ job_result_mock.logs = []
+ launcher_service.job_result_repository.get.return_value = job_result_mock
+
+ engine = create_engine("sqlite:///:memory:", echo=False)
+ Base.metadata.create_all(engine)
+ # noinspection SpellCheckingInspection
+ DBSessionMiddleware(
+ None,
+ custom_engine=engine,
+ session_args={"autocommit": False, "autoflush": False},
+ )
+ launcher_service.append_log(job_id, "test", JobLogType.BEFORE)
+ launcher_service.job_result_repository.save.assert_called_with(job_result_mock)
+ assert job_result_mock.logs[0].message == "test"
+ assert job_result_mock.logs[0].job_id == "job_id"
+ assert job_result_mock.logs[0].log_type == str(JobLogType.BEFORE)
+
+ def test_get_logs(self, tmp_path: Path) -> None:
+ study_service = Mock()
+ launcher_service = LauncherService(
+ config=Config(storage=StorageConfig(tmp_dir=tmp_path)),
+ study_service=study_service,
+ job_result_repository=Mock(),
+ event_bus=Mock(),
+ factory_launcher=Mock(),
+ file_transfer_manager=Mock(),
+ task_service=Mock(),
+ cache=Mock(),
+ )
+ launcher = "slurm"
+ job_id = "job_id"
+ job_result_mock = Mock()
+ job_result_mock.id = job_id
+ job_result_mock.study_id = "study_id"
+ job_result_mock.output_id = None
+ job_result_mock.launcher = launcher
+ job_result_mock.logs = [
+ JobLog(message="first message", log_type=str(JobLogType.BEFORE)),
+ JobLog(message="second message", log_type=str(JobLogType.BEFORE)),
+ JobLog(message="last message", log_type=str(JobLogType.AFTER)),
+ ]
+ job_result_mock.launcher_params = '{"archive_output": false}'
+
+ launcher_service.job_result_repository.get.return_value = job_result_mock
+ slurm_launcher = Mock()
+ launcher_service.launchers = {"slurm": slurm_launcher}
+ slurm_launcher.get_log.return_value = "launcher logs"
+
+ logs = launcher_service.get_log(job_id, LogType.STDOUT, RequestParameters(DEFAULT_ADMIN_USER))
+ assert logs == "first message\nsecond message\nlauncher logs\nlast message"
+ logs = launcher_service.get_log(job_id, LogType.STDERR, RequestParameters(DEFAULT_ADMIN_USER))
+ assert logs == "launcher logs"
+
+ study_service.get_logs.side_effect = ["some sim log", "error log"]
+
+ job_result_mock.output_id = "some id"
+ logs = launcher_service.get_log(job_id, LogType.STDOUT, RequestParameters(DEFAULT_ADMIN_USER))
+ assert logs == "first message\nsecond message\nsome sim log\nlast message"
+
+ logs = launcher_service.get_log(job_id, LogType.STDERR, RequestParameters(DEFAULT_ADMIN_USER))
+ assert logs == "error log"
+
+ study_service.get_logs.assert_has_calls(
+ [
+ call(
+ "study_id",
+ "some id",
+ job_id,
+ False,
+ params=RequestParameters(DEFAULT_ADMIN_USER),
+ ),
+ call(
+ "study_id",
+ "some id",
+ job_id,
+ True,
+ params=RequestParameters(DEFAULT_ADMIN_USER),
+ ),
+ ]
+ )
+
+ def test_manage_output(self, tmp_path: Path) -> None:
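+ # Set up an in-memory SQLite database and session middleware for this test.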
+ engine = create_engine("sqlite:///:memory:", echo=False)
+ Base.metadata.create_all(engine)
+ # noinspection SpellCheckingInspection
+ DBSessionMiddleware(
+ None,
+ custom_engine=engine,
+ session_args={"autocommit": False, "autoflush": False},
+ )
-def test_append_logs(tmp_path: Path):
- study_service = Mock()
- study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE)
+ study_service = Mock()
+ study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE)
+
+ launcher_service = LauncherService(
+ config=Mock(storage=StorageConfig(tmp_dir=tmp_path)),
+ study_service=study_service,
+ job_result_repository=Mock(),
+ event_bus=Mock(),
+ factory_launcher=Mock(),
+ file_transfer_manager=Mock(),
+ task_service=Mock(),
+ cache=Mock(),
+ )
- launcher_service = LauncherService(
- config=Config(storage=StorageConfig(tmp_dir=tmp_path)),
- study_service=study_service,
- job_result_repository=Mock(),
- event_bus=Mock(),
- factory_launcher=Mock(),
- file_transfer_manager=Mock(),
- task_service=Mock(),
- cache=Mock(),
- )
- launcher = "slurm"
- job_id = "job_id"
- job_result_mock = Mock()
- job_result_mock.id = job_id
- job_result_mock.study_id = "study_id"
- job_result_mock.output_id = None
- job_result_mock.launcher = launcher
- job_result_mock.logs = []
- launcher_service.job_result_repository.get.return_value = job_result_mock
-
- engine = create_engine("sqlite:///:memory:", echo=False)
- Base.metadata.create_all(engine)
- # noinspection SpellCheckingInspection
- DBSessionMiddleware(
- None,
- custom_engine=engine,
- session_args={"autocommit": False, "autoflush": False},
- )
- launcher_service.append_log(job_id, "test", JobLogType.BEFORE)
- launcher_service.job_result_repository.save.assert_called_with(job_result_mock)
- assert job_result_mock.logs[0].message == "test"
- assert job_result_mock.logs[0].job_id == "job_id"
- assert job_result_mock.logs[0].log_type == str(JobLogType.BEFORE)
-
-
-def test_get_logs(tmp_path: Path):
- study_service = Mock()
- launcher_service = LauncherService(
- config=Config(storage=StorageConfig(tmp_dir=tmp_path)),
- study_service=study_service,
- job_result_repository=Mock(),
- event_bus=Mock(),
- factory_launcher=Mock(),
- file_transfer_manager=Mock(),
- task_service=Mock(),
- cache=Mock(),
- )
- launcher = "slurm"
- job_id = "job_id"
- job_result_mock = Mock()
- job_result_mock.id = job_id
- job_result_mock.study_id = "study_id"
- job_result_mock.output_id = None
- job_result_mock.launcher = launcher
- job_result_mock.logs = [
- JobLog(message="first message", log_type=str(JobLogType.BEFORE)),
- JobLog(message="second message", log_type=str(JobLogType.BEFORE)),
- JobLog(message="last message", log_type=str(JobLogType.AFTER)),
- ]
- job_result_mock.launcher_params = '{"archive_output": false}'
-
- launcher_service.job_result_repository.get.return_value = job_result_mock
- slurm_launcher = Mock()
- launcher_service.launchers = {"slurm": slurm_launcher}
- slurm_launcher.get_log.return_value = "launcher logs"
-
- logs = launcher_service.get_log(job_id, LogType.STDOUT, RequestParameters(DEFAULT_ADMIN_USER))
- assert logs == "first message\nsecond message\nlauncher logs\nlast message"
- logs = launcher_service.get_log(job_id, LogType.STDERR, RequestParameters(DEFAULT_ADMIN_USER))
- assert logs == "launcher logs"
-
- study_service.get_logs.side_effect = ["some sim log", "error log"]
-
- job_result_mock.output_id = "some id"
- logs = launcher_service.get_log(job_id, LogType.STDOUT, RequestParameters(DEFAULT_ADMIN_USER))
- assert logs == "first message\nsecond message\nsome sim log\nlast message"
-
- logs = launcher_service.get_log(job_id, LogType.STDERR, RequestParameters(DEFAULT_ADMIN_USER))
- assert logs == "error log"
-
- study_service.get_logs.assert_has_calls(
- [
- call(
- "study_id",
- "some id",
- job_id,
- False,
- params=RequestParameters(DEFAULT_ADMIN_USER),
+ output_path = tmp_path / "output"
+ zipped_output_path = tmp_path / "zipped_output"
+ os.mkdir(output_path)
+ os.mkdir(zipped_output_path)
+ new_output_path = output_path / "new_output"
+ os.mkdir(new_output_path)
+ (new_output_path / "log").touch()
+ (new_output_path / "data").touch()
+ additional_log = tmp_path / "output.log"
+ additional_log.write_text("some log")
+ new_output_zipped_path = zipped_output_path / "test.zip"
+ with ZipFile(new_output_zipped_path, "w", ZIP_DEFLATED) as output_data:
+ output_data.writestr("some output", "0\n1")
+ job_id = "job_id"
+ zipped_job_id = "zipped_job_id"
+ study_id = "study_id"
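+ # One JobResult per repository lookup made below; the first lookup returns None to trigger JobNotFound.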
+ launcher_service.job_result_repository.get.side_effect = [
+ None,
+ JobResult(id=job_id, study_id=study_id),
+ JobResult(id=job_id, study_id=study_id, output_id="some id"),
+ JobResult(id=zipped_job_id, study_id=study_id),
+ JobResult(
+ id=job_id,
+ study_id=study_id,
),
- call(
- "study_id",
- "some id",
- job_id,
- True,
- params=RequestParameters(DEFAULT_ADMIN_USER),
+ JobResult(
+ id=job_id,
+ study_id=study_id,
+ launcher_params=json.dumps(
+ {
+ "archive_output": False,
+ LAUNCHER_PARAM_NAME_SUFFIX: "hello",
+ }
+ ),
),
]
- )
+ with pytest.raises(JobNotFound):
+ launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]})
-
-def test_manage_output(tmp_path: Path):
- engine = create_engine("sqlite:///:memory:", echo=False)
- Base.metadata.create_all(engine)
- # noinspection SpellCheckingInspection
- DBSessionMiddleware(
- None,
- custom_engine=engine,
- session_args={"autocommit": False, "autoflush": False},
- )
-
- study_service = Mock()
- study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE)
-
- launcher_service = LauncherService(
- config=Mock(storage=StorageConfig(tmp_dir=tmp_path)),
- study_service=study_service,
- job_result_repository=Mock(),
- event_bus=Mock(),
- factory_launcher=Mock(),
- file_transfer_manager=Mock(),
- task_service=Mock(),
- cache=Mock(),
- )
-
- output_path = tmp_path / "output"
- zipped_output_path = tmp_path / "zipped_output"
- os.mkdir(output_path)
- os.mkdir(zipped_output_path)
- new_output_path = output_path / "new_output"
- os.mkdir(new_output_path)
- (new_output_path / "log").touch()
- (new_output_path / "data").touch()
- additional_log = tmp_path / "output.log"
- additional_log.write_text("some log")
- new_output_zipped_path = zipped_output_path / "test.zip"
- with ZipFile(new_output_zipped_path, "w", ZIP_DEFLATED) as output_data:
- output_data.writestr("some output", "0\n1")
- job_id = "job_id"
- zipped_job_id = "zipped_job_id"
- study_id = "study_id"
- launcher_service.job_result_repository.get.side_effect = [
- None,
- JobResult(id=job_id, study_id=study_id),
- JobResult(id=job_id, study_id=study_id, output_id="some id"),
- JobResult(id=zipped_job_id, study_id=study_id),
- JobResult(
- id=job_id,
- study_id=study_id,
- ),
- JobResult(
- id=job_id,
- study_id=study_id,
- launcher_params=json.dumps(
- {
- "archive_output": False,
- f"{LAUNCHER_PARAM_NAME_SUFFIX}": "hello",
- }
- ),
- ),
- ]
- with pytest.raises(JobNotFound):
launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]})
+ assert not launcher_service._get_job_output_fallback_path(job_id).exists()
+ launcher_service.study_service.import_output.assert_called()
- launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]})
- assert not launcher_service._get_job_output_fallback_path(job_id).exists()
- launcher_service.study_service.import_output.assert_called()
-
- launcher_service.download_output("job_id", RequestParameters(DEFAULT_ADMIN_USER))
- launcher_service.study_service.export_output.assert_called()
-
- launcher_service._import_output(
- zipped_job_id,
- zipped_output_path,
- {
- "out.log": [additional_log],
- "antares-out": [additional_log],
- "antares-err": [additional_log],
- },
- )
- launcher_service.study_service.save_logs.has_calls(
- [
- call(study_id, zipped_job_id, "out.log", "some log"),
- call(study_id, zipped_job_id, "out", "some log"),
- call(study_id, zipped_job_id, "err", "some log"),
- ]
- )
-
- launcher_service.study_service.import_output.side_effect = [
- StudyNotFoundError(""),
- StudyNotFoundError(""),
- ]
-
- assert launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]}) is None
-
- (new_output_path / "info.antares-output").write_text(f"[general]\nmode=eco\nname=foo\ntimestamp={time.time()}")
- output_name = launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]})
- assert output_name is not None
- assert output_name.endswith("-hello")
- assert launcher_service._get_job_output_fallback_path(job_id).exists()
- assert (launcher_service._get_job_output_fallback_path(job_id) / output_name / "out.log").exists()
-
- launcher_service.job_result_repository.get.reset_mock()
- launcher_service.job_result_repository.get.side_effect = [
- None,
- JobResult(id=job_id, study_id=study_id, output_id=output_name),
- ]
- with pytest.raises(JobNotFound):
launcher_service.download_output("job_id", RequestParameters(DEFAULT_ADMIN_USER))
+ launcher_service.study_service.export_output.assert_called()
- study_service.get_study.reset_mock()
- study_service.get_study.side_effect = StudyNotFoundError("")
+ launcher_service._import_output(
+ zipped_job_id,
+ zipped_output_path,
+ {
+ "out.log": [additional_log],
+ "antares-out": [additional_log],
+ "antares-err": [additional_log],
+ },
+ )
+ launcher_service.study_service.save_logs.assert_has_calls(
+ [
+ call(study_id, zipped_job_id, "out.log", "some log"),
+ call(study_id, zipped_job_id, "out", "some log"),
+ call(study_id, zipped_job_id, "err", "some log"),
+ ]
+ )
- export_file = FileDownloadDTO(id="a", name="a", filename="a", ready=True)
- launcher_service.file_transfer_manager.request_download.return_value = FileDownload(
- id="a", name="a", filename="a", ready=True, path="a"
- )
- launcher_service.task_service.add_task.return_value = "some id"
+ launcher_service.study_service.import_output.side_effect = [
+ StudyNotFoundError(""),
+ StudyNotFoundError(""),
+ ]
- assert launcher_service.download_output("job_id", RequestParameters(DEFAULT_ADMIN_USER)) == FileDownloadTaskDTO(
- task="some id", file=export_file
- )
+ assert launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]}) is None
- launcher_service.remove_job(job_id, RequestParameters(user=DEFAULT_ADMIN_USER))
- assert not launcher_service._get_job_output_fallback_path(job_id).exists()
+ (new_output_path / "info.antares-output").write_text(f"[general]\nmode=eco\nname=foo\ntimestamp={time.time()}")
+ output_name = launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]})
+ assert output_name is not None
+ assert output_name.endswith("-hello")
+ assert launcher_service._get_job_output_fallback_path(job_id).exists()
+ assert (launcher_service._get_job_output_fallback_path(job_id) / output_name / "out.log").exists()
+ launcher_service.job_result_repository.get.reset_mock()
+ launcher_service.job_result_repository.get.side_effect = [
+ None,
+ JobResult(id=job_id, study_id=study_id, output_id=output_name),
+ ]
+ with pytest.raises(JobNotFound):
+ launcher_service.download_output("job_id", RequestParameters(DEFAULT_ADMIN_USER))
-def test_save_stats(tmp_path: Path) -> None:
- study_service = Mock()
- study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE)
+ study_service.get_study.reset_mock()
+ study_service.get_study.side_effect = StudyNotFoundError("")
- launcher_service = LauncherService(
- config=Mock(storage=StorageConfig(tmp_dir=tmp_path)),
- study_service=study_service,
- job_result_repository=Mock(),
- event_bus=Mock(),
- factory_launcher=Mock(),
- file_transfer_manager=Mock(),
- task_service=Mock(),
- cache=Mock(),
- )
+ export_file = FileDownloadDTO(id="a", name="a", filename="a", ready=True)
+ launcher_service.file_transfer_manager.request_download.return_value = FileDownload(
+ id="a", name="a", filename="a", ready=True, path="a"
+ )
+ launcher_service.task_service.add_task.return_value = "some id"
- job_id = "job_id"
- study_id = "study_id"
- job_result = JobResult(id=job_id, study_id=study_id, job_status=JobStatus.SUCCESS)
-
- output_path = tmp_path / "some-output"
- output_path.mkdir()
-
- launcher_service._save_solver_stats(job_result, output_path)
- launcher_service.job_result_repository.save.assert_not_called()
-
- expected_saved_stats = """#item duration_ms NbOccurences
-mc_years 216328 1
-study_loading 4304 1
-survey_report 158 1
-total 244581 1
-tsgen_hydro 1683 1
-tsgen_load 2702 1
-tsgen_solar 21606 1
-tsgen_thermal 407 2
-tsgen_wind 2500 1
- """
- (output_path / EXECUTION_INFO_FILE).write_text(expected_saved_stats)
-
- launcher_service._save_solver_stats(job_result, output_path)
- launcher_service.job_result_repository.save.assert_called_with(
- JobResult(
- id=job_id,
- study_id=study_id,
- job_status=JobStatus.SUCCESS,
- solver_stats=expected_saved_stats,
+ assert launcher_service.download_output("job_id", RequestParameters(DEFAULT_ADMIN_USER)) == FileDownloadTaskDTO(
+ task="some id", file=export_file
)
- )
- zip_file = tmp_path / "test.zip"
- with ZipFile(zip_file, "w", ZIP_DEFLATED) as output_data:
- output_data.writestr(EXECUTION_INFO_FILE, "0\n1")
+ launcher_service.remove_job(job_id, RequestParameters(user=DEFAULT_ADMIN_USER))
+ assert not launcher_service._get_job_output_fallback_path(job_id).exists()
+
+ def test_save_solver_stats(self, tmp_path: Path) -> None:
+ study_service = Mock()
+ study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE)
+
+ launcher_service = LauncherService(
+ config=Mock(storage=StorageConfig(tmp_dir=tmp_path)),
+ study_service=study_service,
+ job_result_repository=Mock(),
+ event_bus=Mock(),
+ factory_launcher=Mock(),
+ file_transfer_manager=Mock(),
+ task_service=Mock(),
+ cache=Mock(),
+ )
- launcher_service._save_solver_stats(job_result, zip_file)
- launcher_service.job_result_repository.save.assert_called_with(
- JobResult(
- id=job_id,
- study_id=study_id,
- job_status=JobStatus.SUCCESS,
- solver_stats="0\n1",
+ job_id = "job_id"
+ study_id = "study_id"
+ job_result = JobResult(id=job_id, study_id=study_id, job_status=JobStatus.SUCCESS)
+
+ output_path = tmp_path / "some-output"
+ output_path.mkdir()
+
+ launcher_service._save_solver_stats(job_result, output_path)
+ launcher_service.job_result_repository.save.assert_not_called()
+
+ expected_saved_stats = """#item duration_ms NbOccurences
+ mc_years 216328 1
+ study_loading 4304 1
+ survey_report 158 1
+ total 244581 1
+ tsgen_hydro 1683 1
+ tsgen_load 2702 1
+ tsgen_solar 21606 1
+ tsgen_thermal 407 2
+ tsgen_wind 2500 1
+ """
+ (output_path / EXECUTION_INFO_FILE).write_text(expected_saved_stats)
+
+ launcher_service._save_solver_stats(job_result, output_path)
+ launcher_service.job_result_repository.save.assert_called_with(
+ JobResult(
+ id=job_id,
+ study_id=study_id,
+ job_status=JobStatus.SUCCESS,
+ solver_stats=expected_saved_stats,
+ )
)
- )
+ zip_file = tmp_path / "test.zip"
+ with ZipFile(zip_file, "w", ZIP_DEFLATED) as output_data:
+ output_data.writestr(EXECUTION_INFO_FILE, "0\n1")
+
+ launcher_service._save_solver_stats(job_result, zip_file)
+ launcher_service.job_result_repository.save.assert_called_with(
+ JobResult(
+ id=job_id,
+ study_id=study_id,
+ job_status=JobStatus.SUCCESS,
+ solver_stats="0\n1",
+ )
+ )
-def test_get_load(tmp_path: Path):
- study_service = Mock()
- job_repository = Mock()
+
+ def test_get_load(self, tmp_path: Path) -> None:
+ study_service = Mock()
+ job_repository = Mock()
- launcher_service = LauncherService(
- config=Mock(
+ config = Config(
storage=StorageConfig(tmp_dir=tmp_path),
- launcher=LauncherConfig(local=LocalConfig(), slurm=SlurmConfig(default_n_cpu=12)),
- ),
- study_service=study_service,
- job_result_repository=job_repository,
- event_bus=Mock(),
- factory_launcher=Mock(),
- file_transfer_manager=Mock(),
- task_service=Mock(),
- cache=Mock(),
- )
-
- job_repository.get_running.side_effect = [
- [],
- [],
- [
- Mock(
- spec=JobResult,
- launcher="slurm",
- launcher_params=None,
- ),
- ],
- [
- Mock(
- spec=JobResult,
- launcher="slurm",
- launcher_params='{"nb_cpu": 18}',
- ),
- Mock(
- spec=JobResult,
- launcher="local",
- launcher_params=None,
- ),
- Mock(
- spec=JobResult,
- launcher="slurm",
- launcher_params=None,
+ launcher=LauncherConfig(
+ local=LocalConfig(),
+ slurm=SlurmConfig(nb_cores=NbCoresConfig(min=1, default=12, max=24)),
),
- Mock(
- spec=JobResult,
- launcher="local",
- launcher_params='{"nb_cpu": 7}',
- ),
- ],
- ]
-
- with pytest.raises(NotImplementedError):
- launcher_service.get_load(from_cluster=True)
-
- load = launcher_service.get_load()
- assert load["slurm"] == 0
- assert load["local"] == 0
- load = launcher_service.get_load()
- assert load["slurm"] == 12.0 / 64
- assert load["local"] == 0
- load = launcher_service.get_load()
- assert load["slurm"] == 30.0 / 64
- assert load["local"] == 8.0 / os.cpu_count()
+ )
+ launcher_service = LauncherService(
+ config=config,
+ study_service=study_service,
+ job_result_repository=job_repository,
+ event_bus=Mock(),
+ factory_launcher=Mock(),
+ file_transfer_manager=Mock(),
+ task_service=Mock(),
+ cache=Mock(),
+ )
+
+ job_repository.get_running.side_effect = [
+ # call #1
+ [],
+ # call #2
+ [],
+ # call #3
+ [
+ Mock(
+ spec=JobResult,
+ launcher="slurm",
+ launcher_params=None,
+ ),
+ ],
+ # call #4
+ [
+ Mock(
+ spec=JobResult,
+ launcher="slurm",
+ launcher_params='{"nb_cpu": 18}',
+ ),
+ Mock(
+ spec=JobResult,
+ launcher="local",
+ launcher_params=None,
+ ),
+ Mock(
+ spec=JobResult,
+ launcher="slurm",
+ launcher_params=None,
+ ),
+ Mock(
+ spec=JobResult,
+ launcher="local",
+ launcher_params='{"nb_cpu": 7}',
+ ),
+ ],
+ ]
+
+ # call #1
+ with pytest.raises(NotImplementedError):
+ launcher_service.get_load(from_cluster=True)
+
+ # call #2
+ load = launcher_service.get_load()
+ assert load["slurm"] == 0
+ assert load["local"] == 0
+
+ # call #3
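+ # A running SLURM job without launcher_params counts for the default number of cores.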
+ load = launcher_service.get_load()
+ slurm_config = config.launcher.slurm
+ assert load["slurm"] == slurm_config.nb_cores.default / slurm_config.max_cores
+ assert load["local"] == 0
+
+ # call #4
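+ # Explicit nb_cpu values are added to the defaults of the jobs that have no launcher_params.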
+ load = launcher_service.get_load()
+ local_config = config.launcher.local
+ assert load["slurm"] == (18 + slurm_config.nb_cores.default) / slurm_config.max_cores
+ assert load["local"] == (7 + local_config.nb_cores.default) / local_config.nb_cores.max
diff --git a/tests/launcher/test_slurm_launcher.py b/tests/launcher/test_slurm_launcher.py
index 7820abcdea..dfb6846e89 100644
--- a/tests/launcher/test_slurm_launcher.py
+++ b/tests/launcher/test_slurm_launcher.py
@@ -10,11 +10,9 @@
from antareslauncher.data_repo.data_repo_tinydb import DataRepoTinydb
from antareslauncher.main import MainParameters
from antareslauncher.study_dto import StudyDTO
-from sqlalchemy import create_engine
+from sqlalchemy.orm import Session # type: ignore
-from antarest.core.config import Config, LauncherConfig, SlurmConfig
-from antarest.core.persistence import Base
-from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware
+from antarest.core.config import Config, LauncherConfig, NbCoresConfig, SlurmConfig
from antarest.launcher.adapters.abstractlauncher import LauncherInitException
from antarest.launcher.adapters.slurm_launcher.slurm_launcher import (
LOG_DIR_NAME,
@@ -24,32 +22,34 @@
SlurmLauncher,
VersionNotSupportedError,
)
-from antarest.launcher.model import JobStatus, LauncherParametersDTO
+from antarest.launcher.model import JobStatus, LauncherParametersDTO, XpansionParametersDTO
from antarest.tools.admin_lib import clean_locks_from_config
@pytest.fixture
def launcher_config(tmp_path: Path) -> Config:
- return Config(
- launcher=LauncherConfig(
- slurm=SlurmConfig(
- local_workspace=tmp_path,
- default_json_db_name="default_json_db_name",
- slurm_script_path="slurm_script_path",
- antares_versions_on_remote_server=["42", "45"],
- username="username",
- hostname="hostname",
- port=42,
- private_key_file=Path("private_key_file"),
- key_password="key_password",
- password="password",
- )
- )
- )
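+ # Build the SlurmConfig from a dict, the same way it is parsed from the YAML configuration.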
+ data = {
+ "local_workspace": tmp_path,
+ "username": "john",
+ "hostname": "slurm-001",
+ "port": 22,
+ "private_key_file": Path("/home/john/.ssh/id_rsa"),
+ "key_password": "password",
+ "password": "password",
+ "default_wait_time": 10,
+ "default_time_limit": 20,
+ "default_json_db_name": "antares.db",
+ "slurm_script_path": "/path/to/slurm/launcher.sh",
+ "max_cores": 32,
+ "antares_versions_on_remote_server": ["840", "850", "860"],
+ "enable_nb_cores_detection": False,
+ "nb_cores": {"min": 1, "default": 34, "max": 36},
+ }
+ return Config(launcher=LauncherConfig(slurm=SlurmConfig.from_dict(data)))
@pytest.mark.unit_test
-def test_slurm_launcher__launcher_init_exception():
+def test_slurm_launcher__launcher_init_exception() -> None:
with pytest.raises(
LauncherInitException,
match="Missing parameter 'launcher.slurm'",
@@ -63,13 +63,13 @@ def test_slurm_launcher__launcher_init_exception():
@pytest.mark.unit_test
-def test_init_slurm_launcher_arguments(tmp_path: Path):
+def test_init_slurm_launcher_arguments(tmp_path: Path) -> None:
config = Config(
launcher=LauncherConfig(
slurm=SlurmConfig(
default_wait_time=42,
default_time_limit=43,
- default_n_cpu=44,
+ nb_cores=NbCoresConfig(min=1, default=30, max=36),
local_workspace=tmp_path,
)
)
@@ -88,13 +88,15 @@ def test_init_slurm_launcher_arguments(tmp_path: Path):
assert not arguments.xpansion_mode
assert not arguments.version
assert not arguments.post_processing
- assert Path(arguments.studies_in) == config.launcher.slurm.local_workspace / "STUDIES_IN"
- assert Path(arguments.output_dir) == config.launcher.slurm.local_workspace / "OUTPUT"
- assert Path(arguments.log_dir) == config.launcher.slurm.local_workspace / "LOGS"
+ slurm_config = config.launcher.slurm
+ assert slurm_config is not None
+ assert Path(arguments.studies_in) == slurm_config.local_workspace / "STUDIES_IN"
+ assert Path(arguments.output_dir) == slurm_config.local_workspace / "OUTPUT"
+ assert Path(arguments.log_dir) == slurm_config.local_workspace / "LOGS"
@pytest.mark.unit_test
-def test_init_slurm_launcher_parameters(tmp_path: Path):
+def test_init_slurm_launcher_parameters(tmp_path: Path) -> None:
config = Config(
launcher=LauncherConfig(
slurm=SlurmConfig(
@@ -115,23 +117,25 @@ def test_init_slurm_launcher_parameters(tmp_path: Path):
slurm_launcher = SlurmLauncher(config=config, callbacks=Mock(), event_bus=Mock(), cache=Mock())
main_parameters = slurm_launcher._init_launcher_parameters()
- assert main_parameters.json_dir == config.launcher.slurm.local_workspace
- assert main_parameters.default_json_db_name == config.launcher.slurm.default_json_db_name
- assert main_parameters.slurm_script_path == config.launcher.slurm.slurm_script_path
- assert main_parameters.antares_versions_on_remote_server == config.launcher.slurm.antares_versions_on_remote_server
+ slurm_config = config.launcher.slurm
+ assert slurm_config is not None
+ assert main_parameters.json_dir == slurm_config.local_workspace
+ assert main_parameters.default_json_db_name == slurm_config.default_json_db_name
+ assert main_parameters.slurm_script_path == slurm_config.slurm_script_path
+ assert main_parameters.antares_versions_on_remote_server == slurm_config.antares_versions_on_remote_server
assert main_parameters.default_ssh_dict == {
- "username": config.launcher.slurm.username,
- "hostname": config.launcher.slurm.hostname,
- "port": config.launcher.slurm.port,
- "private_key_file": config.launcher.slurm.private_key_file,
- "key_password": config.launcher.slurm.key_password,
- "password": config.launcher.slurm.password,
+ "username": slurm_config.username,
+ "hostname": slurm_config.hostname,
+ "port": slurm_config.port,
+ "private_key_file": slurm_config.private_key_file,
+ "key_password": slurm_config.key_password,
+ "password": slurm_config.password,
}
assert main_parameters.db_primary_key == "name"
@pytest.mark.unit_test
-def test_slurm_launcher_delete_function(tmp_path: str):
+def test_slurm_launcher_delete_function(tmp_path: str) -> None:
config = Config(launcher=LauncherConfig(slurm=SlurmConfig(local_workspace=Path(tmp_path))))
slurm_launcher = SlurmLauncher(
config=config,
@@ -155,64 +159,104 @@ def test_slurm_launcher_delete_function(tmp_path: str):
assert not file_path.exists()
-def test_extra_parameters(launcher_config: Config):
+def test_extra_parameters(launcher_config: Config) -> None:
+ """
+ The goal of this unit test is to check the protected method `_apply_params`,
+ which is called by the `SlurmLauncher.run_study` function in a separate thread.
+
+ The `_apply_params` method extracts the parameters from the configuration
+ and populates an `argparse.Namespace`, which is then used to launch a simulation with Antares Launcher.
+
+ We want to make sure all the parameters are populated correctly.
+ """
slurm_launcher = SlurmLauncher(
config=launcher_config,
callbacks=Mock(),
event_bus=Mock(),
cache=Mock(),
)
- launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO())
- assert launcher_params.n_cpu == 1
- assert launcher_params.time_limit == 0
+
+ apply_params = slurm_launcher._apply_params
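+ # Without overrides, the parameters fall back to the SlurmConfig defaults.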
+ launcher_params = apply_params(LauncherParametersDTO())
+ slurm_config = slurm_launcher.config.launcher.slurm
+ assert slurm_config is not None
+ assert launcher_params.n_cpu == slurm_config.nb_cores.default
+ assert launcher_params.time_limit == slurm_config.default_time_limit
assert not launcher_params.xpansion_mode
assert not launcher_params.post_processing
- launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(nb_cpu=12))
+ launcher_params = apply_params(LauncherParametersDTO(other_options=""))
+ assert launcher_params.other_options == ""
+
+ launcher_params = apply_params(LauncherParametersDTO(other_options="foo\tbar baz "))
+ assert launcher_params.other_options == "foo bar baz"
+
+ launcher_params = apply_params(LauncherParametersDTO(other_options="/foo?bar"))
+ assert launcher_params.other_options == "foobar"
+
+ launcher_params = apply_params(LauncherParametersDTO(nb_cpu=12))
assert launcher_params.n_cpu == 12
- launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(nb_cpu=48))
- assert launcher_params.n_cpu == 1
+ launcher_params = apply_params(LauncherParametersDTO(nb_cpu=999))
+ assert launcher_params.n_cpu == slurm_config.nb_cores.default # out of range
- launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(time_limit=10))
+ launcher_params = apply_params(LauncherParametersDTO(time_limit=10))
assert launcher_params.time_limit == MIN_TIME_LIMIT
- launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(time_limit=999999999))
+ launcher_params = apply_params(LauncherParametersDTO(time_limit=999999999))
assert launcher_params.time_limit == MAX_TIME_LIMIT - 3600
- launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(time_limit=99999))
+ launcher_params = apply_params(LauncherParametersDTO(time_limit=99999))
assert launcher_params.time_limit == 99999
- launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(xpansion=True))
- assert launcher_params.xpansion_mode
+ launcher_params = apply_params(LauncherParametersDTO(xpansion=False))
+ assert launcher_params.xpansion_mode is None
+ assert launcher_params.other_options == ""
+
+ launcher_params = apply_params(LauncherParametersDTO(xpansion=True))
+ assert launcher_params.xpansion_mode == "cpp"
+ assert launcher_params.other_options == ""
+
+ launcher_params = apply_params(LauncherParametersDTO(xpansion=True, xpansion_r_version=True))
+ assert launcher_params.xpansion_mode == "r"
+ assert launcher_params.other_options == ""
+
+ launcher_params = apply_params(LauncherParametersDTO(xpansion=XpansionParametersDTO(sensitivity_mode=False)))
+ assert launcher_params.xpansion_mode == "cpp"
+ assert launcher_params.other_options == ""
+
+ launcher_params = apply_params(LauncherParametersDTO(xpansion=XpansionParametersDTO(sensitivity_mode=True)))
+ assert launcher_params.xpansion_mode == "cpp"
+ assert launcher_params.other_options == "xpansion_sensitivity"
+
+ launcher_params = apply_params(LauncherParametersDTO(post_processing=False))
+ assert launcher_params.post_processing is False
- launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(post_processing=True))
- assert launcher_params.post_processing
+ launcher_params = apply_params(LauncherParametersDTO(post_processing=True))
+ assert launcher_params.post_processing is True
- launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(adequacy_patch={}))
- assert launcher_params.post_processing
+ launcher_params = apply_params(LauncherParametersDTO(adequacy_patch={}))
+ assert launcher_params.post_processing is True
# noinspection PyUnresolvedReferences
@pytest.mark.parametrize(
- "version, job_status",
- [(42, JobStatus.RUNNING), (99, JobStatus.FAILED), (45, JobStatus.FAILED)],
+ "version, launcher_called, job_status",
+ [
+ (840, True, JobStatus.RUNNING),
+ (860, False, JobStatus.FAILED),
+ pytest.param(
+ 999, False, JobStatus.FAILED, marks=pytest.mark.xfail(raises=VersionNotSupportedError, strict=True)
+ ),
+ ],
)
@pytest.mark.unit_test
def test_run_study(
- tmp_path: Path,
launcher_config: Config,
version: int,
+ launcher_called: bool,
job_status: JobStatus,
-):
- engine = create_engine("sqlite:///:memory:", echo=False)
- Base.metadata.create_all(engine)
- # noinspection SpellCheckingInspection
- DBSessionMiddleware(
- None,
- custom_engine=engine,
- session_args={"autocommit": False, "autoflush": False},
- )
+) -> None:
slurm_launcher = SlurmLauncher(
config=launcher_config,
callbacks=Mock(),
@@ -231,7 +275,8 @@ def test_run_study(
job_id = str(uuid.uuid4())
study_dir = argument.studies_in / job_id
study_dir.mkdir(parents=True)
- (study_dir / "study.antares").write_text(
+ study_antares_path = study_dir.joinpath("study.antares")
+ study_antares_path.write_text(
textwrap.dedent(
"""\
[antares]
@@ -242,22 +287,20 @@ def test_run_study(
# noinspection PyUnusedLocal
def call_launcher_mock(arguments: Namespace, parameters: MainParameters):
- if version != 45:
+ if launcher_called:
slurm_launcher.data_repo_tinydb.save_study(StudyDTO(job_id))
slurm_launcher._call_launcher = call_launcher_mock
- if version == 99:
- with pytest.raises(VersionNotSupportedError):
- slurm_launcher._run_study(study_uuid, job_id, LauncherParametersDTO(), str(version))
- else:
- slurm_launcher._run_study(study_uuid, job_id, LauncherParametersDTO(), str(version))
+ # When the launcher is called
+ slurm_launcher._run_study(study_uuid, job_id, LauncherParametersDTO(), str(version))
+ # Check the results
assert (
version not in launcher_config.launcher.slurm.antares_versions_on_remote_server
- or f"solver_version = {version}" in (study_dir / "study.antares").read_text(encoding="utf-8")
+ or f"solver_version = {version}" in study_antares_path.read_text(encoding="utf-8")
)
- # slurm_launcher._clean_local_workspace.assert_called_once()
+
slurm_launcher.callbacks.export_study.assert_called_once()
slurm_launcher.callbacks.update_status.assert_called_once_with(ANY, job_status, ANY, None)
if job_status == JobStatus.RUNNING:
@@ -266,7 +309,7 @@ def call_launcher_mock(arguments: Namespace, parameters: MainParameters):
@pytest.mark.unit_test
-def test_check_state(tmp_path: Path, launcher_config: Config):
+def test_check_state(tmp_path: Path, launcher_config: Config) -> None:
slurm_launcher = SlurmLauncher(
config=launcher_config,
callbacks=Mock(),
@@ -308,16 +351,7 @@ def test_check_state(tmp_path: Path, launcher_config: Config):
@pytest.mark.unit_test
-def test_clean_local_workspace(tmp_path: Path, launcher_config: Config):
- engine = create_engine("sqlite:///:memory:", echo=False)
- Base.metadata.create_all(engine)
- # noinspection SpellCheckingInspection
- DBSessionMiddleware(
- None,
- custom_engine=engine,
- session_args={"autocommit": False, "autoflush": False},
- )
-
+def test_clean_local_workspace(tmp_path: Path, launcher_config: Config) -> None:
slurm_launcher = SlurmLauncher(
config=launcher_config,
callbacks=Mock(),
@@ -325,7 +359,6 @@ def test_clean_local_workspace(tmp_path: Path, launcher_config: Config):
use_private_workspace=False,
cache=Mock(),
)
-
(launcher_config.launcher.slurm.local_workspace / "machin.txt").touch()
assert os.listdir(launcher_config.launcher.slurm.local_workspace)
@@ -335,7 +368,7 @@ def test_clean_local_workspace(tmp_path: Path, launcher_config: Config):
# noinspection PyUnresolvedReferences
@pytest.mark.unit_test
-def test_import_study_output(launcher_config, tmp_path):
+def test_import_study_output(launcher_config, tmp_path) -> None:
slurm_launcher = SlurmLauncher(
config=launcher_config,
callbacks=Mock(),
@@ -399,7 +432,7 @@ def test_kill_job(
run_with_mock,
tmp_path: Path,
launcher_config: Config,
-):
+) -> None:
launch_id = "launch_id"
mock_study = Mock()
mock_study.name = launch_id
@@ -419,35 +452,36 @@ def test_kill_job(
slurm_launcher.kill_job(job_id=launch_id)
+ slurm_config = launcher_config.launcher.slurm
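+ # The expected arguments mirror the SlurmConfig defaults from the launcher_config fixture.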
launcher_arguments = Namespace(
antares_version=0,
check_queue=False,
- job_id_to_kill=42,
+ job_id_to_kill=mock_study.job_id,
json_ssh_config=None,
log_dir=str(tmp_path / "LOGS"),
- n_cpu=1,
+ n_cpu=slurm_config.nb_cores.default,
output_dir=str(tmp_path / "OUTPUT"),
post_processing=False,
studies_in=str(tmp_path / "STUDIES_IN"),
- time_limit=0,
+ time_limit=slurm_config.default_time_limit,
version=False,
wait_mode=False,
- wait_time=0,
+ wait_time=slurm_config.default_wait_time,
xpansion_mode=None,
other_options=None,
)
launcher_parameters = MainParameters(
json_dir=Path(tmp_path),
- default_json_db_name="default_json_db_name",
- slurm_script_path="slurm_script_path",
- antares_versions_on_remote_server=["42", "45"],
+ default_json_db_name=slurm_config.default_json_db_name,
+ slurm_script_path=slurm_config.slurm_script_path,
+ antares_versions_on_remote_server=slurm_config.antares_versions_on_remote_server,
default_ssh_dict={
- "username": "username",
- "hostname": "hostname",
- "port": 42,
- "private_key_file": Path("private_key_file"),
- "key_password": "key_password",
- "password": "password",
+ "username": slurm_config.username,
+ "hostname": slurm_config.hostname,
+ "port": slurm_config.port,
+ "private_key_file": slurm_config.private_key_file,
+ "key_password": slurm_config.key_password,
+ "password": slurm_config.password,
},
db_primary_key="name",
)
@@ -456,7 +490,7 @@ def test_kill_job(
@patch("antarest.launcher.adapters.slurm_launcher.slurm_launcher.run_with")
-def test_launcher_workspace_init(run_with_mock, tmp_path: Path, launcher_config: Config):
+def test_launcher_workspace_init(run_with_mock, tmp_path: Path, launcher_config: Config) -> None:
callbacks = Mock()
(tmp_path / LOG_DIR_NAME).mkdir()
@@ -474,11 +508,7 @@ def test_launcher_workspace_init(run_with_mock, tmp_path: Path, launcher_config:
clean_locks_from_config(launcher_config)
assert not (workspaces[0] / WORKSPACE_LOCK_FILE_NAME).exists()
- slurm_launcher.data_repo_tinydb.save_study(
- StudyDTO(
- path="somepath",
- )
- )
+ slurm_launcher.data_repo_tinydb.save_study(StudyDTO(path="some_path"))
run_with_mock.assert_not_called()
# will use existing private workspace
diff --git a/tests/storage/repository/filesystem/config/test_config_files.py b/tests/storage/repository/filesystem/config/test_config_files.py
index 9b029fcea7..e0e0ccb45f 100644
--- a/tests/storage/repository/filesystem/config/test_config_files.py
+++ b/tests/storage/repository/filesystem/config/test_config_files.py
@@ -4,6 +4,7 @@
import pytest
+from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import BindingConstraintFrequency
from antarest.study.storage.rawstudy.model.filesystem.config.files import (
_parse_links,
_parse_outputs,
@@ -73,6 +74,7 @@ def test_parse_bindings(tmp_path: Path) -> None:
[bindB]
id = bindB
+ type = weekly
"""
(study_path / "input/bindingconstraints/bindingconstraints.ini").write_text(content)
@@ -81,8 +83,18 @@ def test_parse_bindings(tmp_path: Path) -> None:
path=study_path,
version=-1,
bindings=[
- BindingConstraintDTO(id="bindA", areas=[], clusters=[]),
- BindingConstraintDTO(id="bindB", areas=[], clusters=[]),
+ BindingConstraintDTO(
+ id="bindA",
+ areas=set(),
+ clusters=set(),
+ time_step=BindingConstraintFrequency.HOURLY,
+ ),
+ BindingConstraintDTO(
+ id="bindB",
+ areas=set(),
+ clusters=set(),
+ time_step=BindingConstraintFrequency.WEEKLY,
+ ),
],
study_id="id",
output_path=study_path / "output",
diff --git a/tests/storage/test_model.py b/tests/storage/test_model.py
index 0b0ed0db87..4986073713 100644
--- a/tests/storage/test_model.py
+++ b/tests/storage/test_model.py
@@ -1,5 +1,6 @@
from pathlib import Path
+from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import BindingConstraintFrequency
from antarest.study.storage.rawstudy.model.filesystem.config.model import (
Area,
BindingConstraintDTO,
@@ -41,7 +42,14 @@ def test_file_study_tree_config_dto():
xpansion="",
)
},
- bindings=[BindingConstraintDTO(id="b1", areas=[], clusters=[])],
+ bindings=[
+ BindingConstraintDTO(
+ id="b1",
+ areas=set(),
+ clusters=set(),
+ time_step=BindingConstraintFrequency.DAILY,
+ )
+ ],
store_new_set=False,
archive_input_series=["?"],
enr_modelling="aggregated",
diff --git a/tests/study/business/conftest.py b/tests/study/business/conftest.py
deleted file mode 100644
index 2638d47b3d..0000000000
--- a/tests/study/business/conftest.py
+++ /dev/null
@@ -1,22 +0,0 @@
-import contextlib
-
-import pytest
-from sqlalchemy import create_engine
-from sqlalchemy.orm import sessionmaker
-
-from antarest.dbmodel import Base
-
-
-@pytest.fixture(scope="function", name="db_engine")
-def db_engine_fixture():
- engine = create_engine("sqlite:///:memory:")
- Base.metadata.create_all(engine)
- yield engine
- engine.dispose()
-
-
-@pytest.fixture(scope="function", name="db_session")
-def db_session_fixture(db_engine):
- make_session = sessionmaker(bind=db_engine)
- with contextlib.closing(make_session()) as session:
- yield session
diff --git a/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py b/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py
index e9689834c5..93a3262259 100644
--- a/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py
+++ b/tests/study/storage/variantstudy/business/test_matrix_constants_generator.py
@@ -36,3 +36,22 @@ def test_get_st_storage(self, tmp_path):
matrix_id5 = ref5.split(MATRIX_PROTOCOL_PREFIX)[1]
matrix_dto5 = generator.matrix_service.get(matrix_id5)
assert np.array(matrix_dto5.data).all() == matrix_constants.st_storage.series.inflows.all()
+
+ def test_get_binding_constraint(self, tmp_path):
+ generator = GeneratorMatrixConstants(matrix_service=SimpleMatrixService(bucket_dir=tmp_path))
+ series = matrix_constants.binding_constraint.series
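+ # Each generated matrix is checked against the corresponding default binding-constraint series.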
+
+ hourly = generator.get_binding_constraint_hourly()
+ hourly_matrix_id = hourly.split(MATRIX_PROTOCOL_PREFIX)[1]
+ hourly_matrix_dto = generator.matrix_service.get(hourly_matrix_id)
+ assert np.array(hourly_matrix_dto.data).all() == series.default_binding_constraint_hourly.all()
+
+ daily = generator.get_binding_constraint_daily()
+ daily_matrix_id = daily.split(MATRIX_PROTOCOL_PREFIX)[1]
+ daily_matrix_dto = generator.matrix_service.get(daily_matrix_id)
+ assert np.array(daily_matrix_dto.data).all() == series.default_binding_constraint_daily.all()
+
+ weekly = generator.get_binding_constraint_weekly()
+ weekly_matrix_id = weekly.split(MATRIX_PROTOCOL_PREFIX)[1]
+ weekly_matrix_dto = generator.matrix_service.get(weekly_matrix_id)
+ assert np.array(weekly_matrix_dto.data).all() == series.default_binding_constraint_weekly.all()
diff --git a/tests/test_resources.py b/tests/test_resources.py
index 2a0bf94677..330116e507 100644
--- a/tests/test_resources.py
+++ b/tests/test_resources.py
@@ -4,6 +4,8 @@
import pytest
+from antarest.core.config import Config
+
HERE = pathlib.Path(__file__).parent.resolve()
PROJECT_DIR = next(iter(p for p in HERE.parents if p.joinpath("antarest").exists()))
RESOURCES_DIR = PROJECT_DIR.joinpath("resources")
@@ -84,3 +86,17 @@ def test_empty_study_zip(filename: str, expected_list: Sequence[str]):
with zipfile.ZipFile(resource_path) as myzip:
actual = sorted(myzip.namelist())
assert actual == expected_list
+
+
+def test_resources_config():
+ """
+ Check that the "resources/deploy/config.yaml" file is valid.
+
+ The launcher section must be configured to use a local launcher
+ with detection of the number of CPU cores enabled.
+ """
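+ # A minimal sketch of the expected launcher section (the actual file may contain more settings):
+ #   launcher:
+ #     default: local
+ #     local:
+ #       enable_nb_cores_detection: true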
+ config_path = RESOURCES_DIR.joinpath("deploy/config.yaml")
+ config = Config.from_yaml_file(config_path, res=RESOURCES_DIR)
+ assert config.launcher.default == "local"
+ assert config.launcher.local is not None
+ assert config.launcher.local.enable_nb_cores_detection is True
diff --git a/tests/variantstudy/model/command/test_manage_binding_constraints.py b/tests/variantstudy/model/command/test_manage_binding_constraints.py
index 1596ce6476..d22e05ce1e 100644
--- a/tests/variantstudy/model/command/test_manage_binding_constraints.py
+++ b/tests/variantstudy/model/command/test_manage_binding_constraints.py
@@ -1,10 +1,16 @@
from unittest.mock import Mock
from antarest.study.storage.rawstudy.io.reader import IniReader
+from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import BindingConstraintFrequency
from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy
from antarest.study.storage.variantstudy.business.command_extractor import CommandExtractor
from antarest.study.storage.variantstudy.business.command_reverter import CommandReverter
-from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator, TimeStep
+from antarest.study.storage.variantstudy.business.matrix_constants.binding_constraint.series import (
+ default_binding_constraint_daily,
+ default_binding_constraint_hourly,
+ default_binding_constraint_weekly,
+)
+from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator
from antarest.study.storage.variantstudy.model.command.create_area import CreateArea
from antarest.study.storage.variantstudy.model.command.create_binding_constraint import CreateBindingConstraint
from antarest.study.storage.variantstudy.model.command.create_cluster import CreateCluster
@@ -56,7 +62,7 @@ def test_manage_binding_constraint(
bind1_cmd = CreateBindingConstraint(
name="BD 1",
- time_step=TimeStep.HOURLY,
+ time_step=BindingConstraintFrequency.HOURLY,
operator=BindingConstraintOperator.LESS,
coeffs={"area1%area2": [800, 30]},
comments="Hello",
@@ -68,7 +74,7 @@ def test_manage_binding_constraint(
bind2_cmd = CreateBindingConstraint(
name="BD 2",
enabled=False,
- time_step=TimeStep.DAILY,
+ time_step=BindingConstraintFrequency.DAILY,
operator=BindingConstraintOperator.BOTH,
coeffs={"area1.cluster": [50]},
command_context=command_context,
@@ -101,13 +107,14 @@ def test_manage_binding_constraint(
"type": "daily",
}
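+ # Use the default weekly matrix so the updated values match the WEEKLY time step.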
+ weekly_values = default_binding_constraint_weekly.tolist()
bind_update = UpdateBindingConstraint(
id="bd 1",
enabled=False,
- time_step=TimeStep.WEEKLY,
+ time_step=BindingConstraintFrequency.WEEKLY,
operator=BindingConstraintOperator.BOTH,
coeffs={"area1%area2": [800, 30]},
- values=[[0]],
+ values=weekly_values,
command_context=command_context,
)
res = bind_update.apply(empty_study)
@@ -139,28 +146,29 @@ def test_manage_binding_constraint(
def test_match(command_context: CommandContext):
+ values = default_binding_constraint_daily.tolist()
base = CreateBindingConstraint(
name="foo",
enabled=False,
- time_step=TimeStep.DAILY,
+ time_step=BindingConstraintFrequency.DAILY,
operator=BindingConstraintOperator.BOTH,
coeffs={"a": [0.3]},
- values=[[0]],
+ values=values,
command_context=command_context,
)
other_match = CreateBindingConstraint(
name="foo",
enabled=False,
- time_step=TimeStep.DAILY,
+ time_step=BindingConstraintFrequency.DAILY,
operator=BindingConstraintOperator.BOTH,
coeffs={"a": [0.3]},
- values=[[0]],
+ values=values,
command_context=command_context,
)
other_not_match = CreateBindingConstraint(
name="bar",
enabled=False,
- time_step=TimeStep.DAILY,
+ time_step=BindingConstraintFrequency.DAILY,
operator=BindingConstraintOperator.BOTH,
coeffs={"a": [0.3]},
command_context=command_context,
@@ -171,31 +179,31 @@ def test_match(command_context: CommandContext):
assert not base.match(other_other)
assert base.match_signature() == "create_binding_constraint%foo"
# check the matrices links
- matrix_id = command_context.matrix_service.create([[0]])
+ matrix_id = command_context.matrix_service.create(values)
assert base.get_inner_matrices() == [matrix_id]
base = UpdateBindingConstraint(
id="foo",
enabled=False,
- time_step=TimeStep.DAILY,
+ time_step=BindingConstraintFrequency.DAILY,
operator=BindingConstraintOperator.BOTH,
coeffs={"a": [0.3]},
- values=[[0]],
+ values=values,
command_context=command_context,
)
other_match = UpdateBindingConstraint(
id="foo",
enabled=False,
- time_step=TimeStep.DAILY,
+ time_step=BindingConstraintFrequency.DAILY,
operator=BindingConstraintOperator.BOTH,
coeffs={"a": [0.3]},
- values=[[0]],
+ values=values,
command_context=command_context,
)
other_not_match = UpdateBindingConstraint(
id="bar",
enabled=False,
- time_step=TimeStep.DAILY,
+ time_step=BindingConstraintFrequency.DAILY,
operator=BindingConstraintOperator.BOTH,
coeffs={"a": [0.3]},
command_context=command_context,
@@ -206,7 +214,7 @@ def test_match(command_context: CommandContext):
assert not base.match(other_other)
assert base.match_signature() == "update_binding_constraint%foo"
# check the matrices links
- matrix_id = command_context.matrix_service.create([[0]])
+ matrix_id = command_context.matrix_service.create(values)
assert base.get_inner_matrices() == [matrix_id]
base = RemoveBindingConstraint(id="foo", command_context=command_context)
@@ -221,13 +229,16 @@ def test_match(command_context: CommandContext):
def test_revert(command_context: CommandContext):
+ hourly_values = default_binding_constraint_hourly.tolist()
+ daily_values = default_binding_constraint_daily.tolist()
+ weekly_values = default_binding_constraint_weekly.tolist()
base = CreateBindingConstraint(
name="foo",
enabled=False,
- time_step=TimeStep.DAILY,
+ time_step=BindingConstraintFrequency.DAILY,
operator=BindingConstraintOperator.BOTH,
coeffs={"a": [0.3]},
- values=[[0]],
+ values=daily_values,
command_context=command_context,
)
assert CommandReverter().revert(base, [], Mock(spec=FileStudy)) == [
@@ -237,10 +248,10 @@ def test_revert(command_context: CommandContext):
base = UpdateBindingConstraint(
id="foo",
enabled=False,
- time_step=TimeStep.DAILY,
+ time_step=BindingConstraintFrequency.DAILY,
operator=BindingConstraintOperator.BOTH,
coeffs={"a": [0.3]},
- values=[[0]],
+ values=daily_values,
command_context=command_context,
)
mock_command_extractor = Mock(spec=CommandExtractor)
@@ -255,19 +266,19 @@ def test_revert(command_context: CommandContext):
UpdateBindingConstraint(
id="foo",
enabled=True,
- time_step=TimeStep.WEEKLY,
+ time_step=BindingConstraintFrequency.WEEKLY,
operator=BindingConstraintOperator.BOTH,
coeffs={"a": [0.3]},
- values=[[0]],
+ values=weekly_values,
command_context=command_context,
),
UpdateBindingConstraint(
id="foo",
enabled=True,
- time_step=TimeStep.HOURLY,
+ time_step=BindingConstraintFrequency.HOURLY,
operator=BindingConstraintOperator.BOTH,
coeffs={"a": [0.3]},
- values=[[0]],
+ values=hourly_values,
command_context=command_context,
),
],
@@ -276,34 +287,34 @@ def test_revert(command_context: CommandContext):
UpdateBindingConstraint(
id="foo",
enabled=True,
- time_step=TimeStep.HOURLY,
+ time_step=BindingConstraintFrequency.HOURLY,
operator=BindingConstraintOperator.BOTH,
coeffs={"a": [0.3]},
- values=[[0]],
+ values=hourly_values,
command_context=command_context,
)
]
# check the matrices links
- matrix_id = command_context.matrix_service.create([[0]])
+ hourly_matrix_id = command_context.matrix_service.create(hourly_values)
assert CommandReverter().revert(
base,
[
UpdateBindingConstraint(
id="foo",
enabled=True,
- time_step=TimeStep.WEEKLY,
+ time_step=BindingConstraintFrequency.WEEKLY,
operator=BindingConstraintOperator.BOTH,
coeffs={"a": [0.3]},
- values=[[0]],
+ values=weekly_values,
command_context=command_context,
),
CreateBindingConstraint(
name="foo",
enabled=True,
- time_step=TimeStep.HOURLY,
+ time_step=BindingConstraintFrequency.HOURLY,
operator=BindingConstraintOperator.EQUAL,
coeffs={"a": [0.3]},
- values=[[0]],
+ values=hourly_values,
command_context=command_context,
),
],
@@ -312,10 +323,10 @@ def test_revert(command_context: CommandContext):
UpdateBindingConstraint(
id="foo",
enabled=True,
- time_step=TimeStep.HOURLY,
+ time_step=BindingConstraintFrequency.HOURLY,
operator=BindingConstraintOperator.EQUAL,
coeffs={"a": [0.3]},
- values=matrix_id,
+ values=hourly_matrix_id,
comments=None,
command_context=command_context,
)
@@ -329,7 +340,7 @@ def test_create_diff(command_context: CommandContext):
base = CreateBindingConstraint(
name="foo",
enabled=False,
- time_step=TimeStep.DAILY,
+ time_step=BindingConstraintFrequency.DAILY,
operator=BindingConstraintOperator.BOTH,
coeffs={"a": [0.3]},
values="a",
@@ -338,7 +349,7 @@ def test_create_diff(command_context: CommandContext):
other_match = CreateBindingConstraint(
name="foo",
enabled=True,
- time_step=TimeStep.HOURLY,
+ time_step=BindingConstraintFrequency.HOURLY,
operator=BindingConstraintOperator.EQUAL,
coeffs={"b": [0.3]},
values="b",
@@ -348,7 +359,7 @@ def test_create_diff(command_context: CommandContext):
UpdateBindingConstraint(
id="foo",
enabled=True,
- time_step=TimeStep.HOURLY,
+ time_step=BindingConstraintFrequency.HOURLY,
operator=BindingConstraintOperator.EQUAL,
coeffs={"b": [0.3]},
values="b",
@@ -356,22 +367,23 @@ def test_create_diff(command_context: CommandContext):
)
]
+ values = default_binding_constraint_daily.tolist()
base = UpdateBindingConstraint(
id="foo",
enabled=False,
- time_step=TimeStep.DAILY,
+ time_step=BindingConstraintFrequency.DAILY,
operator=BindingConstraintOperator.BOTH,
coeffs={"a": [0.3]},
- values=[[0]],
+ values=values,
command_context=command_context,
)
other_match = UpdateBindingConstraint(
id="foo",
enabled=False,
- time_step=TimeStep.DAILY,
+ time_step=BindingConstraintFrequency.DAILY,
operator=BindingConstraintOperator.BOTH,
coeffs={"a": [0.3]},
- values=[[0]],
+ values=values,
command_context=command_context,
)
assert base.create_diff(other_match) == [other_match]
diff --git a/tests/variantstudy/model/command/test_remove_area.py b/tests/variantstudy/model/command/test_remove_area.py
index cb03c4f2e0..3fb77082f2 100644
--- a/tests/variantstudy/model/command/test_remove_area.py
+++ b/tests/variantstudy/model/command/test_remove_area.py
@@ -1,9 +1,10 @@
import pytest
+from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import BindingConstraintFrequency
from antarest.study.storage.rawstudy.model.filesystem.config.model import transform_name_to_id
from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy
from antarest.study.storage.study_upgrader import upgrade_study
-from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator, TimeStep
+from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator
from antarest.study.storage.variantstudy.model.command.create_area import CreateArea
from antarest.study.storage.variantstudy.model.command.create_binding_constraint import CreateBindingConstraint
from antarest.study.storage.variantstudy.model.command.create_cluster import CreateCluster
@@ -125,7 +126,7 @@ def test_apply(
bind1_cmd = CreateBindingConstraint(
name="BD 2",
- time_step=TimeStep.HOURLY,
+ time_step=BindingConstraintFrequency.HOURLY,
operator=BindingConstraintOperator.LESS,
coeffs={
f"{area_id}%{area_id2}": [400, 30],
diff --git a/tests/variantstudy/model/command/test_remove_cluster.py b/tests/variantstudy/model/command/test_remove_cluster.py
index 0948e962f9..e4525fbc36 100644
--- a/tests/variantstudy/model/command/test_remove_cluster.py
+++ b/tests/variantstudy/model/command/test_remove_cluster.py
@@ -1,8 +1,9 @@
from checksumdir import dirhash
+from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import BindingConstraintFrequency
from antarest.study.storage.rawstudy.model.filesystem.config.model import transform_name_to_id
from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy
-from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator, TimeStep
+from antarest.study.storage.variantstudy.model.command.common import BindingConstraintOperator
from antarest.study.storage.variantstudy.model.command.create_area import CreateArea
from antarest.study.storage.variantstudy.model.command.create_binding_constraint import CreateBindingConstraint
from antarest.study.storage.variantstudy.model.command.create_cluster import CreateCluster
@@ -48,7 +49,7 @@ def test_apply(self, empty_study: FileStudy, command_context: CommandContext):
bind1_cmd = CreateBindingConstraint(
name="BD 1",
- time_step=TimeStep.HOURLY,
+ time_step=BindingConstraintFrequency.HOURLY,
operator=BindingConstraintOperator.LESS,
coeffs={
f"{area_id}.{cluster_id}": [800, 30],
diff --git a/tests/variantstudy/test_command_factory.py b/tests/variantstudy/test_command_factory.py
index 4dbcb6a06f..47e3c57811 100644
--- a/tests/variantstudy/test_command_factory.py
+++ b/tests/variantstudy/test_command_factory.py
@@ -31,7 +31,10 @@ def setup_class(self):
f".{name}",
package="antarest.study.storage.variantstudy.model.command",
)
- self.command_class_set = {command.__name__ for command in ICommand.__subclasses__()}
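+ # Abstract base classes are excluded because they are not concrete commands.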
+ abstract_commands = {"AbstractBindingConstraintCommand"}
+ self.command_class_set = {
+ cmd.__name__ for cmd in ICommand.__subclasses__() if cmd.__name__ not in abstract_commands
+ }
# noinspection SpellCheckingInspection
@pytest.mark.parametrize(
diff --git a/webapp/package-lock.json b/webapp/package-lock.json
index 9a4f7cd401..3ab3cca0b7 100644
--- a/webapp/package-lock.json
+++ b/webapp/package-lock.json
@@ -1,6 +1,6 @@
{
"name": "antares-web",
- "version": "2.15.1",
+ "version": "2.15.2",
"lockfileVersion": 3,
"requires": true,
"packages": {
diff --git a/webapp/package.json b/webapp/package.json
index 74a8fd65f0..506c324bf2 100644
--- a/webapp/package.json
+++ b/webapp/package.json
@@ -1,6 +1,6 @@
{
"name": "antares-web",
- "version": "2.15.1",
+ "version": "2.15.2",
"private": true,
"engines": {
"node": "18.16.1"