From 2cacfa7fef2e814dde8f303123eb1618f1a4e07d Mon Sep 17 00:00:00 2001 From: MartinBelthle <102529366+MartinBelthle@users.noreply.github.com> Date: Mon, 18 Sep 2023 18:47:28 +0200 Subject: [PATCH] feat(parameters): handle the `--partition` and `--qos` parameters for the `sbatch` command (#58) * feat(parameters): add partition in MainParameters * refactor: simplify the implementation of `SlurmScriptFeatures.compose_launch_command` * refactor(parameters-reader): simplify the implementation of the parameters reader * refactor(parameters-reader): change the signature of the `SlurmScriptFeatures` constructor * style(parameters-reader): sort imports and reformat source code * feat(parameters-reader): make the `partition` parameter optional * feat(parameters-reader): handle the `--qos` (quality of service) parameter for the `sbatch` command --------- Co-authored-by: Laurent LAPORTE --- antareslauncher/antares_launcher.py | 12 +- antareslauncher/config.py | 2 +- antareslauncher/data_repo/data_repo_tinydb.py | 3 +- antareslauncher/main.py | 43 +++++-- antareslauncher/parameters_reader.py | 109 +++++++--------- .../slurm_script_features.py | 116 +++++++++--------- data/README.md | 7 ++ ...test_integration_check_queue_controller.py | 10 +- .../test_integration_job_kill_controller.py | 10 +- .../test_integration_launch_controller.py | 27 ++-- tests/unit/test_antares_launcher.py | 2 +- tests/unit/test_check_queue_controller.py | 4 +- tests/unit/test_config.py | 1 + tests/unit/test_job_kill_controller.py | 4 +- tests/unit/test_main_option_parser.py | 2 +- tests/unit/test_parameters_reader.py | 74 ++++++----- .../test_remote_environment_with_slurm.py | 36 ++++-- tests/unit/test_slurm_queue_show.py | 4 +- tests/unit/test_ssh_connection.py | 3 +- tests/unit/test_study_list_composer.py | 1 - 20 files changed, 252 insertions(+), 218 deletions(-) diff --git a/antareslauncher/antares_launcher.py b/antareslauncher/antares_launcher.py index 1e8abdb..a9a11ee 100644 --- a/antareslauncher/antares_launcher.py +++ b/antareslauncher/antares_launcher.py @@ -4,16 +4,10 @@ from antareslauncher.use_cases.check_remote_queue.check_queue_controller import ( CheckQueueController, ) -from antareslauncher.use_cases.create_list.study_list_composer import ( - StudyListComposer, -) -from antareslauncher.use_cases.kill_job.job_kill_controller import ( - JobKillController, -) +from antareslauncher.use_cases.create_list.study_list_composer import StudyListComposer +from antareslauncher.use_cases.kill_job.job_kill_controller import JobKillController from antareslauncher.use_cases.launch.launch_controller import LaunchController -from antareslauncher.use_cases.retrieve.retrieve_controller import ( - RetrieveController, -) +from antareslauncher.use_cases.retrieve.retrieve_controller import RetrieveController from antareslauncher.use_cases.wait_loop_controller.wait_controller import ( WaitController, ) diff --git a/antareslauncher/config.py b/antareslauncher/config.py index 5fe9342..760d526 100644 --- a/antareslauncher/config.py +++ b/antareslauncher/config.py @@ -13,9 +13,9 @@ from antareslauncher import __author__, __project_name__, __version__ from antareslauncher.exceptions import ( + ConfigFileNotFoundError, InvalidConfigValueError, UnknownFileSuffixError, - ConfigFileNotFoundError, ) APP_NAME = __project_name__ diff --git a/antareslauncher/data_repo/data_repo_tinydb.py b/antareslauncher/data_repo/data_repo_tinydb.py index 3a34067..c40a281 100644 --- a/antareslauncher/data_repo/data_repo_tinydb.py +++ 
b/antareslauncher/data_repo/data_repo_tinydb.py @@ -2,9 +2,10 @@ from typing import List import tinydb +from tinydb import TinyDB, where + from antareslauncher.data_repo.idata_repo import IDataRepo from antareslauncher.study_dto import StudyDTO -from tinydb import TinyDB, where class DataRepoTinydb(IDataRepo): diff --git a/antareslauncher/main.py b/antareslauncher/main.py index 076c054..f05eced 100644 --- a/antareslauncher/main.py +++ b/antareslauncher/main.py @@ -1,7 +1,7 @@ import argparse -from dataclasses import dataclass +import dataclasses from pathlib import Path -from typing import List, Dict +from typing import Dict, List from antareslauncher import __version__ from antareslauncher.antares_launcher import AntaresLauncher @@ -19,9 +19,7 @@ from antareslauncher.use_cases.check_remote_queue.check_queue_controller import ( CheckQueueController, ) -from antareslauncher.use_cases.check_remote_queue.slurm_queue_show import ( - SlurmQueueShow, -) +from antareslauncher.use_cases.check_remote_queue.slurm_queue_show import SlurmQueueShow from antareslauncher.use_cases.create_list.study_list_composer import ( StudyListComposer, StudyListComposerParameters, @@ -29,13 +27,9 @@ from antareslauncher.use_cases.generate_tree_structure.tree_structure_initializer import ( TreeStructureInitializer, ) -from antareslauncher.use_cases.kill_job.job_kill_controller import ( - JobKillController, -) +from antareslauncher.use_cases.kill_job.job_kill_controller import JobKillController from antareslauncher.use_cases.launch.launch_controller import LaunchController -from antareslauncher.use_cases.retrieve.retrieve_controller import ( - RetrieveController, -) +from antareslauncher.use_cases.retrieve.retrieve_controller import RetrieveController from antareslauncher.use_cases.retrieve.state_updater import StateUpdater from antareslauncher.use_cases.wait_loop_controller.wait_controller import ( WaitController, @@ -65,14 +59,33 @@ class SshConnectionNotEstablishedException(Exception): # fmt: on -@dataclass +@dataclasses.dataclass class MainParameters: + """ + Represents the main parameters of the application. + + Attributes: + json_dir: Path to the directory where the JSON database will be stored. + default_json_db_name: The default JSON database name. + slurm_script_path: Path to the SLURM script used to launch studies (a Shell script). + antares_versions_on_remote_server: A list of available Antares Solver versions on the remote server. + default_ssh_dict: A dictionary containing the SSH settings read from `ssh_config.json`. + db_primary_key: The primary key for the database, default to "name". + partition: Extra `sbatch` option to request a specific partition for resource allocation. + If not specified, the default behavior is to allow the SLURM controller + to select the default partition as designated by the system administrator. + quality_of_service: Extra `sbatch` option to request a quality of service for the job. + QOS values can be defined for each user/cluster/account association in the Slurm database. 
+ """ + json_dir: Path default_json_db_name: str slurm_script_path: str antares_versions_on_remote_server: List[str] default_ssh_dict: Dict db_primary_key: str + partition: str = "" + quality_of_service: str = "" def run_with( @@ -114,7 +127,11 @@ def run_with( connection = ssh_connection.SshConnection(config=ssh_dict) verify_connection(connection, display) - slurm_script_features = SlurmScriptFeatures(parameters.slurm_script_path) + slurm_script_features = SlurmScriptFeatures( + parameters.slurm_script_path, + partition=parameters.partition, + quality_of_service=parameters.quality_of_service, + ) environment = RemoteEnvironmentWithSlurm(connection, slurm_script_features) data_repo = DataRepoTinydb( database_file_path=db_json_file_path, db_primary_key=parameters.db_primary_key diff --git a/antareslauncher/parameters_reader.py b/antareslauncher/parameters_reader.py index c9189d6..0b94695 100644 --- a/antareslauncher/parameters_reader.py +++ b/antareslauncher/parameters_reader.py @@ -1,10 +1,11 @@ +import getpass import json import os.path +import typing as t from pathlib import Path -from typing import Dict, Any import yaml -import getpass + from antareslauncher.main import MainParameters from antareslauncher.main_option_parser import ParserParameters @@ -13,45 +14,53 @@ DEFAULT_JSON_DB_NAME = f"{getpass.getuser()}_antares_launcher_db.json" -class ParametersReader: - class EmptyFileException(TypeError): - pass +class MissingValueException(Exception): + def __init__(self, yaml_filepath: Path, key: str) -> None: + super().__init__(f"Missing key '{key}' in '{yaml_filepath}'") - class MissingValueException(KeyError): - pass +class ParametersReader: def __init__(self, json_ssh_conf: Path, yaml_filepath: Path): self.json_ssh_conf = json_ssh_conf - with open(Path(yaml_filepath)) as yaml_file: - self.yaml_content = yaml.load(yaml_file, Loader=yaml.FullLoader) or {} + with open(yaml_filepath) as yaml_file: + obj = yaml.load(yaml_file, Loader=yaml.FullLoader) or {} - # fmt: off - self._wait_time = self._get_compulsory_value("DEFAULT_WAIT_TIME") - self.time_limit = self._get_compulsory_value("DEFAULT_TIME_LIMIT") - self.n_cpu = self._get_compulsory_value("DEFAULT_N_CPU") - self.studies_in_dir = os.path.expanduser(self._get_compulsory_value("STUDIES_IN_DIR")) - self.log_dir = os.path.expanduser(self._get_compulsory_value("LOG_DIR")) - self.finished_dir = os.path.expanduser(self._get_compulsory_value("FINISHED_DIR")) - self.ssh_conf_file_is_required = self._get_compulsory_value("SSH_CONFIG_FILE_IS_REQUIRED") - # fmt: on + try: + self.default_wait_time = obj["DEFAULT_WAIT_TIME"] + self.time_limit = obj["DEFAULT_TIME_LIMIT"] + self.n_cpu = obj["DEFAULT_N_CPU"] + self.studies_in_dir = os.path.expanduser(obj["STUDIES_IN_DIR"]) + self.log_dir = os.path.expanduser(obj["LOG_DIR"]) + self.finished_dir = os.path.expanduser(obj["FINISHED_DIR"]) + self.ssh_conf_file_is_required = obj["SSH_CONFIG_FILE_IS_REQUIRED"] + default_ssh_configfile_name = obj["DEFAULT_SSH_CONFIGFILE_NAME"] + except KeyError as e: + raise MissingValueException(yaml_filepath, str(e)) from None - alt1, alt2 = self._get_ssh_conf_file_alts() - self.ssh_conf_alt1, self.ssh_conf_alt2 = alt1, alt2 - self.default_ssh_dict = self._get_ssh_dict_from_json() - self.remote_slurm_script_path = self._get_compulsory_value("SLURM_SCRIPT_PATH") - self.antares_versions = self._get_compulsory_value( - "ANTARES_VERSIONS_ON_REMOTE_SERVER" - ) - self.db_primary_key = self._get_compulsory_value("DB_PRIMARY_KEY") - self.json_dir = 
Path(self._get_compulsory_value("JSON_DIR")).expanduser() - self.json_db_name = self.yaml_content.get( - "DEFAULT_JSON_DB_NAME", DEFAULT_JSON_DB_NAME - ) + default_alternate1 = ALT1_PARENT / default_ssh_configfile_name + default_alternate2 = ALT2_PARENT / default_ssh_configfile_name + + alt1 = obj.get("SSH_CONFIGFILE_PATH_ALTERNATE1", default_alternate1) + alt2 = obj.get("SSH_CONFIGFILE_PATH_ALTERNATE2", default_alternate2) + + try: + self.ssh_conf_alt1 = alt1 + self.ssh_conf_alt2 = alt2 + self.default_ssh_dict = self._get_ssh_dict_from_json() + self.remote_slurm_script_path = obj["SLURM_SCRIPT_PATH"] + self.partition = obj.get("PARTITION", "") + self.quality_of_service = obj.get("QUALITY_OF_SERVICE", "") + self.antares_versions = obj["ANTARES_VERSIONS_ON_REMOTE_SERVER"] + self.db_primary_key = obj["DB_PRIMARY_KEY"] + self.json_dir = Path(obj["JSON_DIR"]).expanduser() + self.json_db_name = obj.get("DEFAULT_JSON_DB_NAME", DEFAULT_JSON_DB_NAME) + except KeyError as e: + raise MissingValueException(yaml_filepath, str(e)) from None def get_parser_parameters(self): - options = ParserParameters( - default_wait_time=self._wait_time, + return ParserParameters( + default_wait_time=self.default_wait_time, default_time_limit=self.time_limit, default_n_cpu=self.n_cpu, studies_in_dir=self.studies_in_dir, @@ -61,48 +70,20 @@ def get_parser_parameters(self): ssh_configfile_path_alternate1=self.ssh_conf_alt1, ssh_configfile_path_alternate2=self.ssh_conf_alt2, ) - return options def get_main_parameters(self) -> MainParameters: - main_parameters = MainParameters( + return MainParameters( json_dir=self.json_dir, default_json_db_name=self.json_db_name, slurm_script_path=self.remote_slurm_script_path, + partition=self.partition, + quality_of_service=self.quality_of_service, antares_versions_on_remote_server=self.antares_versions, default_ssh_dict=self.default_ssh_dict, db_primary_key=self.db_primary_key, ) - return main_parameters - - def _get_ssh_conf_file_alts(self): - default_alternate1, default_alternate2 = self._get_default_alternate_values() - ssh_conf_alternate1 = self.yaml_content.get( - "SSH_CONFIGFILE_PATH_ALTERNATE1", - default_alternate1, - ) - ssh_conf_alternate2 = self.yaml_content.get( - "SSH_CONFIGFILE_PATH_ALTERNATE2", - default_alternate2, - ) - return ssh_conf_alternate1, ssh_conf_alternate2 - - def _get_default_alternate_values(self): - default_ssh_configfile_name = self._get_compulsory_value( - "DEFAULT_SSH_CONFIGFILE_NAME" - ) - default_alternate1 = ALT1_PARENT / default_ssh_configfile_name - default_alternate2 = ALT2_PARENT / default_ssh_configfile_name - return default_alternate1, default_alternate2 - - def _get_compulsory_value(self, key: str): - try: - value = self.yaml_content[key] - except KeyError as e: - print(f"missing value: {str(e)}") - raise ParametersReader.MissingValueException(e) from None - return value - def _get_ssh_dict_from_json(self) -> Dict[str, Any]: + def _get_ssh_dict_from_json(self) -> t.Dict[str, t.Any]: with open(self.json_ssh_conf) as ssh_connection_json: ssh_dict = json.load(ssh_connection_json) if "private_key_file" in ssh_dict: diff --git a/antareslauncher/remote_environnement/slurm_script_features.py b/antareslauncher/remote_environnement/slurm_script_features.py index 6d3de9c..b233bbf 100644 --- a/antareslauncher/remote_environnement/slurm_script_features.py +++ b/antareslauncher/remote_environnement/slurm_script_features.py @@ -1,9 +1,10 @@ -from dataclasses import dataclass +import dataclasses +import shlex from antareslauncher.study_dto import Modes 
-@dataclass +@dataclasses.dataclass class ScriptParametersDTO: study_dir_name: str input_zipfile_name: str @@ -19,77 +20,72 @@ class SlurmScriptFeatures: """Class that returns data related to the remote SLURM script Installed on the remote server""" - def __init__(self, slurm_script_path: str): - self.JOB_TYPE_PLACEHOLDER = "TO_BE_REPLACED_WITH_JOB_TYPE" - self.JOB_TYPE_ANTARES = "ANTARES" - self.JOB_TYPE_XPANSION_R = "ANTARES_XPANSION_R" - self.JOB_TYPE_XPANSION_CPP = "ANTARES_XPANSION_CPP" + def __init__( + self, + slurm_script_path: str, + *, + partition: str, + quality_of_service: str, + ): + """ + Initialize the slurm script feature. + + Args: + slurm_script_path: Path to the SLURM script used to launch studies (a Shell script). + partition: Request a specific partition for the resource allocation. + If not specified, the default behavior is to allow the slurm controller + to select the default partition as designated by the system administrator. + quality_of_service: Request a quality of service for the job. + QOS values can be defined for each user/cluster/account association in the Slurm database. + """ self.solver_script_path = slurm_script_path - self._script_params = None - self._remote_launch_dir = None + self.partition = partition + self.quality_of_service = quality_of_service def compose_launch_command( self, remote_launch_dir: str, script_params: ScriptParametersDTO, ) -> str: - """Compose and return the complete command to be executed to launch the Antares Solver script. - It includes the change of directory to remote_base_path + """ + Compose and return the complete command to be executed to launch the Antares Solver script. Args: - script_params: ScriptFeaturesDTO dataclass container for script parameters remote_launch_dir: remote directory where the script is launched + script_params: ScriptFeaturesDTO dataclass container for script parameters Returns: - str: the complete command to be executed to launch the including the change of directory to remote_base_path - + str: the complete command to be executed to launch a study on the SLURM server """ - self._script_params = script_params - self._remote_launch_dir = remote_launch_dir - complete_command = self._get_complete_command_with_placeholders() + # The following options can be added to the `sbatch` command + # if they are not empty (or null for integer options). 
+        _opts = {
+            "--partition": self.partition,  # non-empty string
+            "--qos": self.quality_of_service,  # non-empty string
+            "--job-name": script_params.study_dir_name,  # non-empty string
+            "--time": script_params.time_limit,  # greater than 0
+            "--cpus-per-task": script_params.n_cpu,  # greater than 0
+        }
 
-        if script_params.run_mode == Modes.antares:
-            complete_command = complete_command.replace(
-                self.JOB_TYPE_PLACEHOLDER, self.JOB_TYPE_ANTARES
-            )
-        elif script_params.run_mode == Modes.xpansion_r:
-            complete_command = complete_command.replace(
-                self.JOB_TYPE_PLACEHOLDER, self.JOB_TYPE_XPANSION_R
-            )
-        elif script_params.run_mode == Modes.xpansion_cpp:
-            complete_command = complete_command.replace(
-                self.JOB_TYPE_PLACEHOLDER, self.JOB_TYPE_XPANSION_CPP
-            )
+        _job_type = {
+            Modes.antares: "ANTARES",  # Mode for Antares Solver
+            Modes.xpansion_r: "ANTARES_XPANSION_R",  # Mode for Old Xpansion implemented in R
+            Modes.xpansion_cpp: "ANTARES_XPANSION_CPP",  # Mode for Xpansion implemented in C++
+        }[script_params.run_mode]
 
-        return complete_command
-
-    def _bash_options(self):
-        option1_zipfile_name = f' "{self._script_params.input_zipfile_name}"'
-        option2_antares_version = f" {self._script_params.antares_version}"
-        option3_job_type = f" {self.JOB_TYPE_PLACEHOLDER}"
-        option4_post_processing = f" {self._script_params.post_processing}"
-        option5_other_options = f" '{self._script_params.other_options}'"
-        bash_options = (
-            option1_zipfile_name
-            + option2_antares_version
-            + option3_job_type
-            + option4_post_processing
-            + option5_other_options
+        # Construct the `sbatch` command
+        args = ["sbatch"]
+        args.extend(f"{k}={shlex.quote(str(v))}" for k, v in _opts.items() if v)
+        args.extend(
+            shlex.quote(arg)
+            for arg in [
+                self.solver_script_path,
+                script_params.input_zipfile_name,
+                script_params.antares_version,
+                _job_type,
+                str(script_params.post_processing),
+                script_params.other_options,
+            ]
         )
-        return bash_options
-
-    def _sbatch_command_with_slurm_options(self):
-        call_sbatch = f"sbatch"
-        job_name = f' --job-name="{self._script_params.study_dir_name}"'
-        time_limit_opt = f" --time={self._script_params.time_limit}"
-        cpu_per_task = f" --cpus-per-task={self._script_params.n_cpu}"
-        slurm_options = call_sbatch + job_name + time_limit_opt + cpu_per_task
-        return slurm_options
-
-    def _get_complete_command_with_placeholders(self):
-        change_dir = f"cd {self._remote_launch_dir}"
-        slurm_options = self._sbatch_command_with_slurm_options()
-        bash_options = self._bash_options()
-        submit_command = slurm_options + " " + self.solver_script_path + bash_options
-        complete_command = change_dir + " && " + submit_command
-        return complete_command
+        launch_cmd = f"cd {remote_launch_dir} && {' '.join(args)}"
+        return launch_cmd
diff --git a/data/README.md b/data/README.md
index 1895066..851382c 100644
--- a/data/README.md
+++ b/data/README.md
@@ -25,6 +25,8 @@ DB_PRIMARY_KEY : "name"
 DEFAULT_SSH_CONFIGFILE_NAME: "ssh_config.json"
 SSH_CONFIG_FILE_IS_REQUIRED : False
 SLURM_SCRIPT_PATH : "/opt/antares/launchAntares.sh"
+PARTITION : "compute1"
+QUALITY_OF_SERVICE : "user1_qos"
 
 ANTARES_VERSIONS_ON_REMOTE_SERVER :
   - "610"
@@ -51,6 +53,11 @@ Below is a description of the parameters:
 - `DEFAULT_SSH_CONFIGFILE_NAME`: The default name of the SSH configuration file, it should be "ssh_config.json".
 - `SSH_CONFIG_FILE_IS_REQUIRED`: A flag indicating whether an SSH configuration file is required.
 - `SLURM_SCRIPT_PATH`: Path to the SLURM script used to launch studies (a Shell script).
+- `PARTITION`: Extra `sbatch` option to request a specific partition for resource allocation. + If not specified, the default behavior is to allow the SLURM controller + to select the default partition as designated by the system administrator. +- `QUALITY_OF_SERVICE`: Extra `sbatch` option to request a quality of service for the job. + QOS values can be defined for each user/cluster/account association in the Slurm database. - `ANTARES_VERSIONS_ON_REMOTE_SERVER`: A list of strings representing the available Antares Solver versions on the remote server. ## SSH Configuration diff --git a/tests/integration/test_integration_check_queue_controller.py b/tests/integration/test_integration_check_queue_controller.py index 91e781f..144bac8 100644 --- a/tests/integration/test_integration_check_queue_controller.py +++ b/tests/integration/test_integration_check_queue_controller.py @@ -12,9 +12,7 @@ from antareslauncher.use_cases.check_remote_queue.check_queue_controller import ( CheckQueueController, ) -from antareslauncher.use_cases.check_remote_queue.slurm_queue_show import ( - SlurmQueueShow, -) +from antareslauncher.use_cases.check_remote_queue.slurm_queue_show import SlurmQueueShow from antareslauncher.use_cases.retrieve.state_updater import StateUpdater @@ -23,7 +21,11 @@ def setup_method(self): self.connection_mock = mock.Mock(home_dir="path/to/home") self.connection_mock.username = "username" self.connection_mock.execute_command = mock.Mock(return_value=("", "")) - slurm_script_features = SlurmScriptFeatures("slurm_script_path") + slurm_script_features = SlurmScriptFeatures( + "slurm_script_path", + partition="fake_partition", + quality_of_service="user1_qos", + ) env_mock = RemoteEnvironmentWithSlurm( _connection=self.connection_mock, slurm_script_features=slurm_script_features, diff --git a/tests/integration/test_integration_job_kill_controller.py b/tests/integration/test_integration_job_kill_controller.py index e8c8731..f36039f 100644 --- a/tests/integration/test_integration_job_kill_controller.py +++ b/tests/integration/test_integration_job_kill_controller.py @@ -8,14 +8,16 @@ from antareslauncher.remote_environnement.slurm_script_features import ( SlurmScriptFeatures, ) -from antareslauncher.use_cases.kill_job.job_kill_controller import ( - JobKillController, -) +from antareslauncher.use_cases.kill_job.job_kill_controller import JobKillController class TestIntegrationJobKilController: def setup_method(self): - slurm_script_features = SlurmScriptFeatures("slurm_script_path") + slurm_script_features = SlurmScriptFeatures( + "slurm_script_path", + partition="fake_partition", + quality_of_service="user1_qos", + ) connection = mock.Mock(home_dir="path/to/home") env = RemoteEnvironmentWithSlurm(connection, slurm_script_features) self.job_kill_controller = JobKillController(env, mock.Mock(), repo=mock.Mock()) diff --git a/tests/integration/test_integration_launch_controller.py b/tests/integration/test_integration_launch_controller.py index 412e5fa..40a4fe4 100644 --- a/tests/integration/test_integration_launch_controller.py +++ b/tests/integration/test_integration_launch_controller.py @@ -20,7 +20,11 @@ class TestIntegrationLaunchController: @pytest.fixture(scope="function") def launch_controller(self): connection = mock.Mock(home_dir="path/to/home") - slurm_script_features = SlurmScriptFeatures("slurm_script_path") + slurm_script_features = SlurmScriptFeatures( + "slurm_script_path", + partition="fake_partition", + quality_of_service="user1_qos", + ) environment = 
RemoteEnvironmentWithSlurm(connection, slurm_script_features) study1 = mock.Mock() study1.zipfile_path = "filepath" @@ -35,15 +39,13 @@ def launch_controller(self): data_repo.get_list_of_studies = mock.Mock(return_value=[study1, study2]) file_manager = mock.Mock() display = DisplayTerminal() - launch_controller = LaunchController( + return LaunchController( repo=data_repo, env=environment, file_manager=file_manager, display=display, ) - return launch_controller - @pytest.mark.integration_test def test_upload_file__called_twice(self, launch_controller): """ @@ -71,7 +73,11 @@ def test_execute_command__called_with_the_correct_parameters( connection = mock.Mock() connection.execute_command = mock.Mock(return_value=["Submitted 42", ""]) connection.home_dir = "Submitted" - slurm_script_features = SlurmScriptFeatures("slurm_script_path") + slurm_script_features = SlurmScriptFeatures( + "slurm_script_path", + partition="fake_partition", + quality_of_service="user1_qos", + ) environment = RemoteEnvironmentWithSlurm(connection, slurm_script_features) study1 = StudyDTO( path="dummy_path", @@ -84,7 +90,7 @@ def test_execute_command__called_with_the_correct_parameters( home_dir = "Submitted" remote_base_path = ( - str(home_dir) + "/REMOTE_" + getpass.getuser() + "_" + socket.gethostname() + f"{home_dir}/REMOTE_{getpass.getuser()}_{socket.gethostname()}" ) zipfile_name = Path(study1.zipfile_path).name @@ -92,7 +98,7 @@ def test_execute_command__called_with_the_correct_parameters( post_processing = False other_options = "" bash_options = ( - f'"{zipfile_name}"' + f" {zipfile_name}" f" {study1.antares_version}" f" {job_type}" f" {post_processing}" @@ -100,11 +106,14 @@ def test_execute_command__called_with_the_correct_parameters( ) command = ( f"cd {remote_base_path} && " - f'sbatch --job-name="{Path(study1.path).name}"' + f"sbatch" + f" --partition={slurm_script_features.partition}" + f" --qos={slurm_script_features.quality_of_service}" + f" --job-name={Path(study1.path).name}" f" --time={study1.time_limit // 60}" f" --cpus-per-task={study1.n_cpu}" f" {environment.slurm_script_features.solver_script_path}" - f" {bash_options}" + f"{bash_options}" ) data_repo = mock.Mock() diff --git a/tests/unit/test_antares_launcher.py b/tests/unit/test_antares_launcher.py index f8188a8..4803f52 100644 --- a/tests/unit/test_antares_launcher.py +++ b/tests/unit/test_antares_launcher.py @@ -1,5 +1,5 @@ from unittest import mock -from unittest.mock import PropertyMock, Mock +from unittest.mock import Mock, PropertyMock import pytest diff --git a/tests/unit/test_check_queue_controller.py b/tests/unit/test_check_queue_controller.py index d1f53d7..127294b 100644 --- a/tests/unit/test_check_queue_controller.py +++ b/tests/unit/test_check_queue_controller.py @@ -7,9 +7,7 @@ from antareslauncher.use_cases.check_remote_queue.check_queue_controller import ( CheckQueueController, ) -from antareslauncher.use_cases.check_remote_queue.slurm_queue_show import ( - SlurmQueueShow, -) +from antareslauncher.use_cases.check_remote_queue.slurm_queue_show import SlurmQueueShow from antareslauncher.use_cases.retrieve.state_updater import StateUpdater diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 5850686..f33997a 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -7,6 +7,7 @@ import pytest import yaml + from antareslauncher.config import ( APP_AUTHOR, APP_NAME, diff --git a/tests/unit/test_job_kill_controller.py b/tests/unit/test_job_kill_controller.py index 7167577..f12e8c4 100644 --- 
a/tests/unit/test_job_kill_controller.py +++ b/tests/unit/test_job_kill_controller.py @@ -2,9 +2,7 @@ import pytest -from antareslauncher.use_cases.kill_job.job_kill_controller import ( - JobKillController, -) +from antareslauncher.use_cases.kill_job.job_kill_controller import JobKillController class TestJobKillController: diff --git a/tests/unit/test_main_option_parser.py b/tests/unit/test_main_option_parser.py index c31dfe8..e394688 100644 --- a/tests/unit/test_main_option_parser.py +++ b/tests/unit/test_main_option_parser.py @@ -5,8 +5,8 @@ from antareslauncher.main_option_parser import ( MainOptionParser, ParserParameters, + look_for_default_ssh_conf_file, ) -from antareslauncher.main_option_parser import look_for_default_ssh_conf_file class TestMainOptionParser: diff --git a/tests/unit/test_parameters_reader.py b/tests/unit/test_parameters_reader.py index 042f9f8..1ba30d6 100644 --- a/tests/unit/test_parameters_reader.py +++ b/tests/unit/test_parameters_reader.py @@ -3,13 +3,16 @@ from pathlib import Path import pytest +import yaml -from antareslauncher.parameters_reader import ParametersReader +from antareslauncher.parameters_reader import MissingValueException, ParametersReader class TestParametersReader: def setup_method(self): self.SLURM_SCRIPT_PATH = "/path/to/launchAntares_v1.1.3.sh" + self.PARTITION = "compute1" + self.QUALITY_OF_SERVICE = "user1_qos" self.SSH_CONFIG_FILE_IS_REQUIRED = False self.DEFAULT_SSH_CONFIGFILE_NAME = "ssh_config.json" self.DB_PRIMARY_KEY = "name" @@ -22,40 +25,43 @@ def setup_method(self): self.JSON_DIR = "JSON" self.ANTARES_SUPPORTED_VERSIONS = ["610", "700"] - self.yaml_compulsory_content = ( - f'LOG_DIR : "{self.LOG_DIR}"\n' - f'JSON_DIR : "{self.JSON_DIR}"\n' - f'STUDIES_IN_DIR : "{self.STUDIES_IN_DIR}"\n' - f'FINISHED_DIR : "{self.FINISHED_DIR}"\n' - f"DEFAULT_TIME_LIMIT : {self.DEFAULT_TIME_LIMIT}\n" - f"DEFAULT_N_CPU : {self.DEFAULT_N_CPU}\n" - f"DEFAULT_WAIT_TIME : {self.DEFAULT_WAIT_TIME}\n" - f'DB_PRIMARY_KEY : "{self.DB_PRIMARY_KEY}"\n' - f'DEFAULT_SSH_CONFIGFILE_NAME: "{self.DEFAULT_SSH_CONFIGFILE_NAME}"\n' - f"SSH_CONFIG_FILE_IS_REQUIRED : {self.SSH_CONFIG_FILE_IS_REQUIRED}\n" - f'SLURM_SCRIPT_PATH : "{self.SLURM_SCRIPT_PATH}"\n' - f"ANTARES_VERSIONS_ON_REMOTE_SERVER :\n" - f' - "{self.ANTARES_SUPPORTED_VERSIONS[0]}"\n' - f' - "{self.ANTARES_SUPPORTED_VERSIONS[1]}"\n' + self.yaml_compulsory_content = yaml.dump( + { + "LOG_DIR": self.LOG_DIR, + "JSON_DIR": self.JSON_DIR, + "STUDIES_IN_DIR": self.STUDIES_IN_DIR, + "FINISHED_DIR": self.FINISHED_DIR, + "DEFAULT_TIME_LIMIT": self.DEFAULT_TIME_LIMIT, + "DEFAULT_N_CPU": self.DEFAULT_N_CPU, + "DEFAULT_WAIT_TIME": self.DEFAULT_WAIT_TIME, + "DB_PRIMARY_KEY": self.DB_PRIMARY_KEY, + "DEFAULT_SSH_CONFIGFILE_NAME": self.DEFAULT_SSH_CONFIGFILE_NAME, + "SSH_CONFIG_FILE_IS_REQUIRED": self.SSH_CONFIG_FILE_IS_REQUIRED, + "SLURM_SCRIPT_PATH": self.SLURM_SCRIPT_PATH, + "PARTITION": self.PARTITION, + "QUALITY_OF_SERVICE": self.QUALITY_OF_SERVICE, + "ANTARES_VERSIONS_ON_REMOTE_SERVER": self.ANTARES_SUPPORTED_VERSIONS, + }, + default_flow_style=False, ) - self.DEFAULT_JSON_DB_NAME = "db_file.json" - self.yaml_opt_content = ( - f'DEFAULT_JSON_DB_NAME : "{self.DEFAULT_JSON_DB_NAME}\n' - f'DEFAULT_SSH_CONFIGFILE_NAME: "{self.DEFAULT_SSH_CONFIGFILE_NAME}"\n' + + self.yaml_opt_content = yaml.dump( + { + "DEFAULT_JSON_DB_NAME": "db_file.json", + "DEFAULT_SSH_CONFIGFILE_NAME": self.DEFAULT_SSH_CONFIGFILE_NAME, + }, + default_flow_style=False, ) - self.USER = "user" - self.HOST = "host" - self.KEY = "C:\\home\\hello" - 
self.KEY_PSWD = "hello" + self.json_dict = { - "username": self.USER, - "hostname": self.HOST, - "private_key_file": self.KEY, - "key_password": self.KEY_PSWD, + "username": "user", + "hostname": "host", + "private_key_file": "C:\\home\\hello", + "key_password": "hello", } @pytest.mark.unit_test - def test_ParametersReader_raises_exception_with_no_file(self, tmp_path): + def test_parameters_reader_raises_exception_with_no_file(self, tmp_path): with pytest.raises(FileNotFoundError): ParametersReader(Path(tmp_path), Path("empty.yaml")) @@ -64,7 +70,7 @@ def test_get_option_parameters_raises_exception_with_empty_file(self, tmp_path): empty_json = tmp_path / "dummy.json" empty_yaml = tmp_path / "empty.yaml" empty_yaml.write_text("") - with pytest.raises(ParametersReader.MissingValueException): + with pytest.raises(MissingValueException): ParametersReader(empty_json, empty_yaml).get_parser_parameters() @pytest.mark.unit_test @@ -72,7 +78,7 @@ def test_get_main_parameters_raises_exception_with_empty_file(self, tmp_path): empty_json = tmp_path / "dummy.json" empty_yaml = tmp_path / "empty.yaml" empty_yaml.write_text("") - with pytest.raises(ParametersReader.MissingValueException): + with pytest.raises(MissingValueException): ParametersReader(empty_json, empty_yaml).get_main_parameters() @pytest.mark.unit_test @@ -89,7 +95,7 @@ def test_get_option_parameters_raises_exception_if_params_are_missing( "DEFAULT_TIME_LIMIT : 172800\n" "DEFAULT_N_CPU : 2\n" ) - with pytest.raises(ParametersReader.MissingValueException): + with pytest.raises(MissingValueException): ParametersReader(empty_json, config_yaml).get_parser_parameters() @pytest.mark.unit_test @@ -104,7 +110,7 @@ def test_get_main_parameters_raises_exception_if_params_are_missing(self, tmp_pa "DEFAULT_TIME_LIMIT : 172800\n" "DEFAULT_N_CPU : 2\n" ) - with pytest.raises(ParametersReader.MissingValueException): + with pytest.raises(MissingValueException): ParametersReader(empty_json, config_yaml).get_main_parameters() @pytest.mark.unit_test @@ -149,6 +155,8 @@ def test_get_main_parameters_initializes_parameters_correctly(self, tmp_path): main_parameters.default_json_db_name == f"{getpass.getuser()}_antares_launcher_db.json" ) + assert main_parameters.partition == self.PARTITION + assert main_parameters.quality_of_service == self.QUALITY_OF_SERVICE assert main_parameters.db_primary_key == self.DB_PRIMARY_KEY assert not main_parameters.default_ssh_dict assert ( diff --git a/tests/unit/test_remote_environment_with_slurm.py b/tests/unit/test_remote_environment_with_slurm.py index 8116ec0..4c80536 100644 --- a/tests/unit/test_remote_environment_with_slurm.py +++ b/tests/unit/test_remote_environment_with_slurm.py @@ -64,7 +64,11 @@ def remote_env(self) -> RemoteEnvironmentWithSlurm: remote_home_dir = "remote_home_dir" connection = mock.Mock(home_dir="path/to/home") connection.home_dir = remote_home_dir - slurm_script_features = SlurmScriptFeatures("slurm_script_path") + slurm_script_features = SlurmScriptFeatures( + "slurm_script_path", + partition="fake_partition", + quality_of_service="user1_qos", + ) return RemoteEnvironmentWithSlurm(connection, slurm_script_features) @pytest.mark.unit_test @@ -80,7 +84,11 @@ def test_initialise_remote_path_calls_connection_make_dir_with_correct_arguments connection.home_dir = remote_home_dir connection.make_dir = mock.Mock(return_value=True) connection.check_file_not_empty = mock.Mock(return_value=True) - slurm_script_features = SlurmScriptFeatures("slurm_script_path") + slurm_script_features = SlurmScriptFeatures( + 
"slurm_script_path", + partition="fake_partition", + quality_of_service="user1_qos", + ) # when RemoteEnvironmentWithSlurm(connection, slurm_script_features) # then @@ -92,7 +100,11 @@ def test_when_constructor_is_called_and_remote_base_path_cannot_be_created_then_ ): # given connection = mock.Mock(home_dir="path/to/home") - slurm_script_features = SlurmScriptFeatures("slurm_script_path") + slurm_script_features = SlurmScriptFeatures( + "slurm_script_path", + partition="fake_partition", + quality_of_service="user1_qos", + ) # when connection.make_dir = mock.Mock(return_value=False) # then @@ -107,7 +119,11 @@ def test_when_constructor_is_called_then_connection_check_file_not_empty_is_call connection = mock.Mock(home_dir="path/to/home") connection.make_dir = mock.Mock(return_value=True) connection.check_file_not_empty = mock.Mock(return_value=True) - slurm_script_features = SlurmScriptFeatures("slurm_script_path") + slurm_script_features = SlurmScriptFeatures( + "slurm_script_path", + partition="fake_partition", + quality_of_service="user1_qos", + ) # when RemoteEnvironmentWithSlurm(connection, slurm_script_features) # then @@ -123,7 +139,11 @@ def test_when_constructor_is_called_and_connection_check_file_not_empty_is_false connection = mock.Mock(home_dir="path/to/home") connection.home_dir = remote_home_dir connection.make_dir = mock.Mock(return_value=True) - slurm_script_features = SlurmScriptFeatures("slurm_script_path") + slurm_script_features = SlurmScriptFeatures( + "slurm_script_path", + partition="fake_partition", + quality_of_service="user1_qos", + ) # when connection.check_file_not_empty = mock.Mock(return_value=False) # then @@ -689,11 +709,13 @@ def test_compose_launch_command( change_dir = f"cd {remote_env.remote_base_path}" reference_submit_command = ( f"sbatch" - f' --job-name="{Path(study.path).name}"' + " --partition=fake_partition" + " --qos=user1_qos" + f" --job-name={Path(study.path).name}" f" --time={study.time_limit // 60}" f" --cpus-per-task={study.n_cpu}" f" {filename_launch_script}" - f' "{Path(study.zipfile_path).name}"' + f" {Path(study.zipfile_path).name}" f" {study.antares_version}" f" {job_type}" f" {post_processing}" diff --git a/tests/unit/test_slurm_queue_show.py b/tests/unit/test_slurm_queue_show.py index f6b42a2..87a12d2 100644 --- a/tests/unit/test_slurm_queue_show.py +++ b/tests/unit/test_slurm_queue_show.py @@ -2,9 +2,7 @@ import pytest -from antareslauncher.use_cases.check_remote_queue.slurm_queue_show import ( - SlurmQueueShow, -) +from antareslauncher.use_cases.check_remote_queue.slurm_queue_show import SlurmQueueShow @pytest.mark.unit_test diff --git a/tests/unit/test_ssh_connection.py b/tests/unit/test_ssh_connection.py index 4086db3..340d689 100644 --- a/tests/unit/test_ssh_connection.py +++ b/tests/unit/test_ssh_connection.py @@ -7,12 +7,13 @@ import paramiko import pytest +from paramiko.sftp_attr import SFTPAttributes + from antareslauncher.remote_environnement.ssh_connection import ( ConnectionFailedException, DownloadMonitor, SshConnection, ) -from paramiko.sftp_attr import SFTPAttributes LOGGER = DownloadMonitor.__module__ diff --git a/tests/unit/test_study_list_composer.py b/tests/unit/test_study_list_composer.py index b31ecf1..2ad06e0 100644 --- a/tests/unit/test_study_list_composer.py +++ b/tests/unit/test_study_list_composer.py @@ -37,7 +37,6 @@ def study_mock(self): def test_given_repo_when_get_list_of_studies_called_then_repo_get_list_of_studies_is_called( self, ): - # given repo_mock = mock.Mock() repo_mock.get_list_of_studies = 
mock.Mock()
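
The sketch below is a standalone illustration, separate from the patch, of the command-composition behaviour introduced in `SlurmScriptFeatures.compose_launch_command`: optional `sbatch` flags such as `--partition` and `--qos` are appended only when their configured value is non-empty, so leaving `PARTITION` or `QUALITY_OF_SERVICE` unset in the YAML configuration lets the SLURM controller fall back to the cluster defaults. The helper name `compose_sbatch_command` and the sample paths and values are assumptions made for this example only.

import shlex


def compose_sbatch_command(
    remote_launch_dir: str,
    script_path: str,
    *,
    partition: str = "",
    quality_of_service: str = "",
    job_name: str = "my_study",
    time_limit: int = 2880,
    n_cpu: int = 2,
) -> str:
    """Build a launch command; empty or zero-valued options are skipped."""
    opts = {
        "--partition": partition,
        "--qos": quality_of_service,
        "--job-name": job_name,
        "--time": time_limit,
        "--cpus-per-task": n_cpu,
    }
    args = ["sbatch"]
    # Keep only the options whose value is non-empty (strings) or non-zero (integers).
    args.extend(f"{key}={shlex.quote(str(value))}" for key, value in opts.items() if value)
    args.append(shlex.quote(script_path))
    return f"cd {remote_launch_dir} && {' '.join(args)}"


if __name__ == "__main__":
    # No partition/QOS configured: the SLURM controller picks the cluster defaults.
    print(compose_sbatch_command("/remote/base", "/opt/antares/launchAntares.sh"))
    # -> cd /remote/base && sbatch --job-name=my_study --time=2880 --cpus-per-task=2 /opt/antares/launchAntares.sh

    # Both options set, mirroring the data/README.md example values.
    print(
        compose_sbatch_command(
            "/remote/base",
            "/opt/antares/launchAntares.sh",
            partition="compute1",
            quality_of_service="user1_qos",
        )
    )
    # -> cd /remote/base && sbatch --partition=compute1 --qos=user1_qos --job-name=my_study --time=2880 --cpus-per-task=2 /opt/antares/launchAntares.sh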