Skip to content

Commit

Permalink
feat(parameters): handle the --partition and --qos parameters for…
Browse files Browse the repository at this point in the history
… the `sbatch` command (#58)

* feat(parameters): add partition in MainParameters

* refactor: simplify the implementation of `SlurmScriptFeatures.compose_launch_command`

* refactor(parameters-reader): simplify the implementation of the parameters reader

* refactor(parameters-reader): change the signature of the `SlurmScriptFeatures` constructor

* style(parameters-reader): sort imports and reformat source code

* feat(parameters-reader): make the `partition` parameter optional

* feat(parameters-reader): handle the `--qos` (quality of service) parameter for the `sbatch` command

---------

Co-authored-by: Laurent LAPORTE <[email protected]>
  • Loading branch information
MartinBelthle and laurent-laporte-pro authored Sep 18, 2023
1 parent 329dbee commit 2cacfa7
Show file tree
Hide file tree
Showing 20 changed files with 252 additions and 218 deletions.
12 changes: 3 additions & 9 deletions antareslauncher/antares_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,10 @@
from antareslauncher.use_cases.check_remote_queue.check_queue_controller import (
CheckQueueController,
)
from antareslauncher.use_cases.create_list.study_list_composer import (
StudyListComposer,
)
from antareslauncher.use_cases.kill_job.job_kill_controller import (
JobKillController,
)
from antareslauncher.use_cases.create_list.study_list_composer import StudyListComposer
from antareslauncher.use_cases.kill_job.job_kill_controller import JobKillController
from antareslauncher.use_cases.launch.launch_controller import LaunchController
from antareslauncher.use_cases.retrieve.retrieve_controller import (
RetrieveController,
)
from antareslauncher.use_cases.retrieve.retrieve_controller import RetrieveController
from antareslauncher.use_cases.wait_loop_controller.wait_controller import (
WaitController,
)
Expand Down
2 changes: 1 addition & 1 deletion antareslauncher/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@

from antareslauncher import __author__, __project_name__, __version__
from antareslauncher.exceptions import (
ConfigFileNotFoundError,
InvalidConfigValueError,
UnknownFileSuffixError,
ConfigFileNotFoundError,
)

APP_NAME = __project_name__
Expand Down
3 changes: 2 additions & 1 deletion antareslauncher/data_repo/data_repo_tinydb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from typing import List

import tinydb
from tinydb import TinyDB, where

from antareslauncher.data_repo.idata_repo import IDataRepo
from antareslauncher.study_dto import StudyDTO
from tinydb import TinyDB, where


class DataRepoTinydb(IDataRepo):
Expand Down
43 changes: 30 additions & 13 deletions antareslauncher/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
from dataclasses import dataclass
import dataclasses
from pathlib import Path
from typing import List, Dict
from typing import Dict, List

from antareslauncher import __version__
from antareslauncher.antares_launcher import AntaresLauncher
Expand All @@ -19,23 +19,17 @@
from antareslauncher.use_cases.check_remote_queue.check_queue_controller import (
CheckQueueController,
)
from antareslauncher.use_cases.check_remote_queue.slurm_queue_show import (
SlurmQueueShow,
)
from antareslauncher.use_cases.check_remote_queue.slurm_queue_show import SlurmQueueShow
from antareslauncher.use_cases.create_list.study_list_composer import (
StudyListComposer,
StudyListComposerParameters,
)
from antareslauncher.use_cases.generate_tree_structure.tree_structure_initializer import (
TreeStructureInitializer,
)
from antareslauncher.use_cases.kill_job.job_kill_controller import (
JobKillController,
)
from antareslauncher.use_cases.kill_job.job_kill_controller import JobKillController
from antareslauncher.use_cases.launch.launch_controller import LaunchController
from antareslauncher.use_cases.retrieve.retrieve_controller import (
RetrieveController,
)
from antareslauncher.use_cases.retrieve.retrieve_controller import RetrieveController
from antareslauncher.use_cases.retrieve.state_updater import StateUpdater
from antareslauncher.use_cases.wait_loop_controller.wait_controller import (
WaitController,
Expand Down Expand Up @@ -65,14 +59,33 @@ class SshConnectionNotEstablishedException(Exception):
# fmt: on


@dataclass
@dataclasses.dataclass
class MainParameters:
"""
Represents the main parameters of the application.
Attributes:
json_dir: Path to the directory where the JSON database will be stored.
default_json_db_name: The default JSON database name.
slurm_script_path: Path to the SLURM script used to launch studies (a Shell script).
antares_versions_on_remote_server: A list of available Antares Solver versions on the remote server.
default_ssh_dict: A dictionary containing the SSH settings read from `ssh_config.json`.
db_primary_key: The primary key for the database, default to "name".
partition: Extra `sbatch` option to request a specific partition for resource allocation.
If not specified, the default behavior is to allow the SLURM controller
to select the default partition as designated by the system administrator.
quality_of_service: Extra `sbatch` option to request a quality of service for the job.
QOS values can be defined for each user/cluster/account association in the Slurm database.
"""

json_dir: Path
default_json_db_name: str
slurm_script_path: str
antares_versions_on_remote_server: List[str]
default_ssh_dict: Dict
db_primary_key: str
partition: str = ""
quality_of_service: str = ""


def run_with(
Expand Down Expand Up @@ -114,7 +127,11 @@ def run_with(
connection = ssh_connection.SshConnection(config=ssh_dict)
verify_connection(connection, display)

slurm_script_features = SlurmScriptFeatures(parameters.slurm_script_path)
slurm_script_features = SlurmScriptFeatures(
parameters.slurm_script_path,
partition=parameters.partition,
quality_of_service=parameters.quality_of_service,
)
environment = RemoteEnvironmentWithSlurm(connection, slurm_script_features)
data_repo = DataRepoTinydb(
database_file_path=db_json_file_path, db_primary_key=parameters.db_primary_key
Expand Down
109 changes: 45 additions & 64 deletions antareslauncher/parameters_reader.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import getpass
import json
import os.path
import typing as t
from pathlib import Path
from typing import Dict, Any

import yaml
import getpass

from antareslauncher.main import MainParameters
from antareslauncher.main_option_parser import ParserParameters

Expand All @@ -13,45 +14,53 @@
DEFAULT_JSON_DB_NAME = f"{getpass.getuser()}_antares_launcher_db.json"


class ParametersReader:
class EmptyFileException(TypeError):
pass
class MissingValueException(Exception):
def __init__(self, yaml_filepath: Path, key: str) -> None:
super().__init__(f"Missing key '{key}' in '{yaml_filepath}'")

class MissingValueException(KeyError):
pass

class ParametersReader:
def __init__(self, json_ssh_conf: Path, yaml_filepath: Path):
self.json_ssh_conf = json_ssh_conf

with open(Path(yaml_filepath)) as yaml_file:
self.yaml_content = yaml.load(yaml_file, Loader=yaml.FullLoader) or {}
with open(yaml_filepath) as yaml_file:
obj = yaml.load(yaml_file, Loader=yaml.FullLoader) or {}

# fmt: off
self._wait_time = self._get_compulsory_value("DEFAULT_WAIT_TIME")
self.time_limit = self._get_compulsory_value("DEFAULT_TIME_LIMIT")
self.n_cpu = self._get_compulsory_value("DEFAULT_N_CPU")
self.studies_in_dir = os.path.expanduser(self._get_compulsory_value("STUDIES_IN_DIR"))
self.log_dir = os.path.expanduser(self._get_compulsory_value("LOG_DIR"))
self.finished_dir = os.path.expanduser(self._get_compulsory_value("FINISHED_DIR"))
self.ssh_conf_file_is_required = self._get_compulsory_value("SSH_CONFIG_FILE_IS_REQUIRED")
# fmt: on
try:
self.default_wait_time = obj["DEFAULT_WAIT_TIME"]
self.time_limit = obj["DEFAULT_TIME_LIMIT"]
self.n_cpu = obj["DEFAULT_N_CPU"]
self.studies_in_dir = os.path.expanduser(obj["STUDIES_IN_DIR"])
self.log_dir = os.path.expanduser(obj["LOG_DIR"])
self.finished_dir = os.path.expanduser(obj["FINISHED_DIR"])
self.ssh_conf_file_is_required = obj["SSH_CONFIG_FILE_IS_REQUIRED"]
default_ssh_configfile_name = obj["DEFAULT_SSH_CONFIGFILE_NAME"]
except KeyError as e:
raise MissingValueException(yaml_filepath, str(e)) from None

alt1, alt2 = self._get_ssh_conf_file_alts()
self.ssh_conf_alt1, self.ssh_conf_alt2 = alt1, alt2
self.default_ssh_dict = self._get_ssh_dict_from_json()
self.remote_slurm_script_path = self._get_compulsory_value("SLURM_SCRIPT_PATH")
self.antares_versions = self._get_compulsory_value(
"ANTARES_VERSIONS_ON_REMOTE_SERVER"
)
self.db_primary_key = self._get_compulsory_value("DB_PRIMARY_KEY")
self.json_dir = Path(self._get_compulsory_value("JSON_DIR")).expanduser()
self.json_db_name = self.yaml_content.get(
"DEFAULT_JSON_DB_NAME", DEFAULT_JSON_DB_NAME
)
default_alternate1 = ALT1_PARENT / default_ssh_configfile_name
default_alternate2 = ALT2_PARENT / default_ssh_configfile_name

alt1 = obj.get("SSH_CONFIGFILE_PATH_ALTERNATE1", default_alternate1)
alt2 = obj.get("SSH_CONFIGFILE_PATH_ALTERNATE2", default_alternate2)

try:
self.ssh_conf_alt1 = alt1
self.ssh_conf_alt2 = alt2
self.default_ssh_dict = self._get_ssh_dict_from_json()
self.remote_slurm_script_path = obj["SLURM_SCRIPT_PATH"]
self.partition = obj.get("PARTITION", "")
self.quality_of_service = obj.get("QUALITY_OF_SERVICE", "")
self.antares_versions = obj["ANTARES_VERSIONS_ON_REMOTE_SERVER"]
self.db_primary_key = obj["DB_PRIMARY_KEY"]
self.json_dir = Path(obj["JSON_DIR"]).expanduser()
self.json_db_name = obj.get("DEFAULT_JSON_DB_NAME", DEFAULT_JSON_DB_NAME)
except KeyError as e:
raise MissingValueException(yaml_filepath, str(e)) from None

def get_parser_parameters(self):
options = ParserParameters(
default_wait_time=self._wait_time,
return ParserParameters(
default_wait_time=self.default_wait_time,
default_time_limit=self.time_limit,
default_n_cpu=self.n_cpu,
studies_in_dir=self.studies_in_dir,
Expand All @@ -61,48 +70,20 @@ def get_parser_parameters(self):
ssh_configfile_path_alternate1=self.ssh_conf_alt1,
ssh_configfile_path_alternate2=self.ssh_conf_alt2,
)
return options

def get_main_parameters(self) -> MainParameters:
main_parameters = MainParameters(
return MainParameters(
json_dir=self.json_dir,
default_json_db_name=self.json_db_name,
slurm_script_path=self.remote_slurm_script_path,
partition=self.partition,
quality_of_service=self.quality_of_service,
antares_versions_on_remote_server=self.antares_versions,
default_ssh_dict=self.default_ssh_dict,
db_primary_key=self.db_primary_key,
)
return main_parameters

def _get_ssh_conf_file_alts(self):
default_alternate1, default_alternate2 = self._get_default_alternate_values()
ssh_conf_alternate1 = self.yaml_content.get(
"SSH_CONFIGFILE_PATH_ALTERNATE1",
default_alternate1,
)
ssh_conf_alternate2 = self.yaml_content.get(
"SSH_CONFIGFILE_PATH_ALTERNATE2",
default_alternate2,
)
return ssh_conf_alternate1, ssh_conf_alternate2

def _get_default_alternate_values(self):
default_ssh_configfile_name = self._get_compulsory_value(
"DEFAULT_SSH_CONFIGFILE_NAME"
)
default_alternate1 = ALT1_PARENT / default_ssh_configfile_name
default_alternate2 = ALT2_PARENT / default_ssh_configfile_name
return default_alternate1, default_alternate2

def _get_compulsory_value(self, key: str):
try:
value = self.yaml_content[key]
except KeyError as e:
print(f"missing value: {str(e)}")
raise ParametersReader.MissingValueException(e) from None
return value

def _get_ssh_dict_from_json(self) -> Dict[str, Any]:
def _get_ssh_dict_from_json(self) -> t.Dict[str, t.Any]:
with open(self.json_ssh_conf) as ssh_connection_json:
ssh_dict = json.load(ssh_connection_json)
if "private_key_file" in ssh_dict:
Expand Down
Loading

0 comments on commit 2cacfa7

Please sign in to comment.