Skip to content

Commit

Permalink
Merge pull request #704 from AntaresSimulatorTeam/dev
Browse files Browse the repository at this point in the history
v2.2.2
  • Loading branch information
pl-buiquang authored Jan 11, 2022
2 parents 97c2cad + bab4c4d commit 9c430d0
Show file tree
Hide file tree
Showing 31 changed files with 472 additions and 167 deletions.
2 changes: 1 addition & 1 deletion antarest/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "2.2.1"
__version__ = "2.2.2"

from pathlib import Path

Expand Down
7 changes: 7 additions & 0 deletions antarest/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ class ExternalAuthConfig:

url: Optional[str] = None
default_group_role: RoleType = RoleType.READER
add_ext_groups: bool = False
group_mapping: Dict[str, str] = field(default_factory=dict)

@staticmethod
def from_dict(data: JSON) -> "ExternalAuthConfig":
Expand All @@ -28,6 +30,8 @@ def from_dict(data: JSON) -> "ExternalAuthConfig":
default_group_role=RoleType(
data.get("default_group_role", RoleType.READER.value)
),
add_ext_groups=data.get("add_ext_groups", False),
group_mapping=data.get("group_mapping", {}),
)


Expand Down Expand Up @@ -339,6 +343,9 @@ def from_dict(data: JSON, res: Optional[Path] = None) -> "Config":
tasks=TaskConfig.from_dict(data["tasks"])
if "tasks" in data
else TaskConfig(),
server=ServerConfig.from_dict(data["server"])
if "server" in data
else ServerConfig(),
)

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion antarest/core/tasks/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def _cancel_task(self, task_id: str, dispatch: bool = False) -> None:
task = self.repo.get_or_raise(task_id)
if task_id in self.tasks:
self.tasks[task_id].cancel()
task.status = TaskStatus.CANCELLED
task.status = TaskStatus.CANCELLED.value
self.repo.save(task)
elif dispatch:
self.event_bus.push(
Expand Down
115 changes: 69 additions & 46 deletions antarest/launcher/adapters/slurm_launcher/slurm_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import os
import shutil
import tempfile
import threading
import time
from copy import deepcopy
Expand Down Expand Up @@ -60,6 +61,7 @@ def __init__(
study_service: StudyService,
callbacks: LauncherCallbacks,
event_bus: IEventBus,
use_private_workspace: bool = True,
) -> None:
super().__init__(config, study_service, callbacks)
if config.launcher.slurm is None:
Expand All @@ -71,11 +73,20 @@ def __init__(
self.thread: Optional[threading.Thread] = None
self.job_id_to_study_id: Dict[str, str] = {}
self._check_config()
self.log_tail_manager = LogTailManager(
self.slurm_config.local_workspace
self.antares_launcher_lock = threading.Lock()
self.local_workspace = (
Path(tempfile.mkdtemp(dir=str(self.slurm_config.local_workspace)))
if use_private_workspace
else Path(self.slurm_config.local_workspace)
)

self.log_tail_manager = LogTailManager(self.local_workspace)
self.launcher_args = self._init_launcher_arguments(
self.local_workspace
)
self.launcher_params = self._init_launcher_parameters(
self.local_workspace
)
self.launcher_args = self._init_launcher_arguments()
self.launcher_params = self._init_launcher_parameters()
self.data_repo_tinydb = DataRepoTinydb(
database_file_path=(
self.launcher_params.json_dir
Expand Down Expand Up @@ -104,17 +115,30 @@ def stop(self) -> None:
self.check_state = False
self.thread = None

def _init_launcher_arguments(self) -> argparse.Namespace:
def _init_launcher_arguments(
self, local_workspace: Optional[Path] = None
) -> argparse.Namespace:
main_options_parameters = ParserParameters(
default_wait_time=self.slurm_config.default_wait_time,
default_time_limit=self.slurm_config.default_time_limit,
default_n_cpu=self.slurm_config.default_n_cpu,
studies_in_dir=str(
(Path(self.slurm_config.local_workspace) / "STUDIES_IN")
(
Path(local_workspace or self.slurm_config.local_workspace)
/ "STUDIES_IN"
)
),
log_dir=str(
(
Path(local_workspace or self.slurm_config.local_workspace)
/ "LOGS"
)
),
log_dir=str((Path(self.slurm_config.local_workspace) / "LOGS")),
finished_dir=str(
(Path(self.slurm_config.local_workspace) / "OUTPUT")
(
Path(local_workspace or self.slurm_config.local_workspace)
/ "OUTPUT"
)
),
ssh_config_file_is_required=False,
ssh_configfile_path_alternate1=None,
Expand All @@ -135,9 +159,11 @@ def _init_launcher_arguments(self) -> argparse.Namespace:
arguments.post_processing = False
return arguments

def _init_launcher_parameters(self) -> MainParameters:
def _init_launcher_parameters(
self, local_workspace: Optional[Path] = None
) -> MainParameters:
main_parameters = MainParameters(
json_dir=self.slurm_config.local_workspace,
json_dir=local_workspace or self.slurm_config.local_workspace,
default_json_db_name=self.slurm_config.default_json_db_name,
slurm_script_path=self.slurm_config.slurm_script_path,
antares_versions_on_remote_server=self.slurm_config.antares_versions_on_remote_server,
Expand All @@ -154,10 +180,8 @@ def _init_launcher_parameters(self) -> MainParameters:
return main_parameters

def _delete_study(self, study_path: Path) -> None:
if (
self.slurm_config.local_workspace.absolute()
in study_path.absolute().parents
):
logger.info(f"Deleting study export at {study_path}")
if self.local_workspace.absolute() in study_path.absolute().parents:
if study_path.exists():
shutil.rmtree(study_path)

Expand All @@ -169,28 +193,22 @@ def _import_study_output(
self._import_xpansion_result(job_id, study_id)
return self.storage_service.import_output(
study_id,
self.slurm_config.local_workspace / "OUTPUT" / job_id / "output",
self.local_workspace / "OUTPUT" / job_id / "output",
params=RequestParameters(DEFAULT_ADMIN_USER),
)

def _import_xpansion_result(self, job_id: str, study_id: str) -> None:
output_path = (
self.slurm_config.local_workspace / "OUTPUT" / job_id / "output"
)
output_path = self.local_workspace / "OUTPUT" / job_id / "output"
if output_path.exists() and len(os.listdir(output_path)) == 1:
output_path = output_path / os.listdir(output_path)[0]
shutil.copytree(
self.slurm_config.local_workspace
/ "OUTPUT"
/ job_id
/ "input"
/ "links",
self.local_workspace / "OUTPUT" / job_id / "input" / "links",
output_path / "updated_links",
)
study = self.storage_service.get_study(study_id)
if int(study.version) < 800:
shutil.copytree(
self.slurm_config.local_workspace
self.local_workspace
/ "OUTPUT"
/ job_id
/ "user"
Expand All @@ -202,11 +220,12 @@ def _import_xpansion_result(self, job_id: str, study_id: str) -> None:

def _check_studies_state(self) -> None:
try:
run_with(
arguments=self.launcher_args,
parameters=self.launcher_params,
show_banner=False,
)
with self.antares_launcher_lock:
run_with(
arguments=self.launcher_args,
parameters=self.launcher_params,
show_banner=False,
)
except Exception as e:
logger.info("Could not get data on remote server", exc_info=e)

Expand Down Expand Up @@ -291,7 +310,7 @@ def _get_log_path(

def _clean_local_workspace(self) -> None:
logger.info("Cleaning up slurm workspace")
local_workspace = self.slurm_config.local_workspace
local_workspace = self.local_workspace
for filename in os.listdir(local_workspace):
file_path = os.path.join(local_workspace, filename)
if os.path.isfile(file_path) or os.path.islink(file_path):
Expand All @@ -315,10 +334,9 @@ def _assert_study_version_is_supported(
)

def _clean_up_study(self, launch_id: str) -> None:
logger.info(f"Cleaning up study with launch_id {launch_id}")
self.data_repo_tinydb.remove_study(launch_id)
self._delete_study(
self.slurm_config.local_workspace / "OUTPUT" / launch_id
)
self._delete_study(self.local_workspace / "OUTPUT" / launch_id)
del self.job_id_to_study_id[launch_id]

def _run_study(
Expand All @@ -335,18 +353,21 @@ def _run_study(

try:
# export study
self.storage_service.export_study_flat(
study_uuid, params, study_path, outputs=False
)
with self.antares_launcher_lock:
self.storage_service.export_study_flat(
study_uuid, params, study_path, outputs=False
)

self._assert_study_version_is_supported(study_uuid, params)
self._assert_study_version_is_supported(study_uuid, params)

launcher_args = self._check_and_apply_launcher_params(
launcher_params
)
run_with(
launcher_args, self.launcher_params, show_banner=False
)
logger.info("Study exported and run with launcher")

launcher_args = self._check_and_apply_launcher_params(
launcher_params
)
run_with(
launcher_args, self.launcher_params, show_banner=False
)
self.callbacks.update_status(
str(launch_uuid), JobStatus.RUNNING, None, None
)
Expand Down Expand Up @@ -423,10 +444,12 @@ def kill_job(self, job_id: str) -> None:
for study in self.data_repo_tinydb.get_list_of_studies():
if study.name == job_id:
launcher_args.job_id_to_kill = study.job_id
run_with(
launcher_args, self.launcher_params, show_banner=False
)
with self.antares_launcher_lock:
run_with(
launcher_args, self.launcher_params, show_banner=False
)
return
# todo kill job should be sent to other slurm launchers so that it is correctly killed
logger.warning(
"Failed to retrieve job id in antares launcher database"
)
Expand Down
5 changes: 5 additions & 0 deletions antarest/launcher/service.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from datetime import datetime
from http import HTTPStatus
from typing import List, Optional, cast, Dict
Expand Down Expand Up @@ -32,6 +33,8 @@
create_permission_from_study,
)

logger = logging.getLogger(__name__)


class JobNotFound(HTTPException):
def __init__(self) -> None:
Expand Down Expand Up @@ -79,6 +82,7 @@ def update(
msg: Optional[str],
output_id: Optional[str],
) -> None:
logger.info(f"Setting study with job id {job_uuid} status to {status}")
job_result = self.job_result_repository.get(job_uuid)
if job_result is not None:
job_result.job_status = status
Expand All @@ -97,6 +101,7 @@ def update(
channel=EventChannelDirectory.JOB_STATUS + job_result.id,
)
)
logger.info(f"Study status set")

def _assert_launcher_is_initialized(self, launcher: str) -> None:
if launcher not in self.launchers:
Expand Down
21 changes: 17 additions & 4 deletions antarest/login/ldap.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from antarest.core.config import Config
from antarest.core.model import JSON
from antarest.login.model import UserLdap, Group, Role
from antarest.login.model import UserLdap, Group, Role, GroupDTO
from antarest.login.repository import (
UserLdapRepository,
RoleRepository,
Expand Down Expand Up @@ -75,6 +75,8 @@ def __init__(
roles: RoleRepository,
):
self.url = config.security.external_auth.url
self.group_mapping = config.security.external_auth.group_mapping
self.add_ext_groups = config.security.external_auth.add_ext_groups
self.users = users
self.groups = groups
self.roles = roles
Expand Down Expand Up @@ -131,11 +133,22 @@ def _save_or_update(self, user: ExternalUser) -> UserLdap:

existing_roles = self.roles.get_all_by_user(existing_user.id)

grouprole_to_add = [
(group_id, user.groups[group_id])
mapped_groups = [
Group(
id=self.group_mapping.get(group_id, group_id),
name=user.groups[group_id],
)
for group_id in user.groups
if group_id not in [role.group_id for role in existing_roles]
if self.add_ext_groups or group_id in self.group_mapping.keys()
]
grouprole_to_add = [
(group.id, (self.groups.get(group.id) or group).name)
for group in mapped_groups
if group.id not in [role.group_id for role in existing_roles]
]
logger.info(
f"Saving new groups from external user {grouprole_to_add} from received {user.groups}"
)
for group_id, group_name in grouprole_to_add:
logger.info(
"Adding user %s role %s to group %s (%s) following ldap sync",
Expand Down
3 changes: 3 additions & 0 deletions antarest/login/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,9 @@ def __eq__(self, other: Any) -> bool:

return bool(self.id == other.id and self.name == other.name)

def __repr__(self) -> str:
    """Debug-friendly representation exposing the group's id and name."""
    group_id, group_name = self.id, self.name
    return "Group(id={}, name={})".format(group_id, group_name)


@dataclass
class Role(Base): # type: ignore
Expand Down
Loading

0 comments on commit 9c430d0

Please sign in to comment.