diff --git a/submitit/__init__.py b/submitit/__init__.py
index 3051ac6..6c3cde9 100644
--- a/submitit/__init__.py
+++ b/submitit/__init__.py
@@ -18,4 +18,4 @@
 from .slurm.slurm import SlurmExecutor as SlurmExecutor
 from .slurm.slurm import SlurmJob as SlurmJob
 
-__version__ = "1.4.5"
+__version__ = "1.4.6"
diff --git a/submitit/helpers.py b/submitit/helpers.py
index c94f0d1..a3292d7 100644
--- a/submitit/helpers.py
+++ b/submitit/helpers.py
@@ -24,6 +24,7 @@
 from .core.utils import CommandFunction as CommandFunction  # noqa
 from .core.utils import DelayedSubmission as DelayedSubmission  # noqa
 from .core.utils import environment_variables as environment_variables  # noqa
+from .slurm.slurm import SlurmJobEnvironment
 
 
 class Checkpointable:
@@ -331,7 +332,14 @@ def __init__(self) -> None:
         >>> torch.distributed.init_process_group(backend="nccl")
         >>> print(f"master: {dist_env.master_addr}:{dist_env.master_port}")
         """
-        self._job_env = JobEnvironment()
+        try:
+            self._job_env = JobEnvironment()
+        except RuntimeError as e:
+            if SlurmJobEnvironment._env["job_id"] in os.environ:
+                # identified a slurm env without submitit, so let's use it
+                self._job_env = SlurmJobEnvironment()
+            else:
+                raise e
         self.master_addr = self._job_env.hostnames[0]
         self.master_port = self._get_master_port()
         self.rank = self._job_env.global_rank
@@ -339,6 +347,11 @@
         self.local_rank = self._job_env.local_rank
         self.local_world_size = self._job_env.num_tasks // self._job_env.num_nodes
 
+    def __repr__(self) -> str:
+        cls = self.__class__.__name__
+        env = sorted(f"{name}={val}" for name, val in self.__dict__.items() if not name.startswith("_"))
+        return f"{cls}<{','.join(env)}>"
+
     def _get_master_port(self) -> int:
         # MIN_MASTER_PORT, MAX_MASTER_PORT = (1023, 65535)
         MIN_MASTER_PORT, MAX_MASTER_PORT = (20000, 60000)
diff --git a/submitit/test_helpers.py b/submitit/test_helpers.py
index c23722e..b8127d9 100644
--- a/submitit/test_helpers.py
+++ b/submitit/test_helpers.py
@@ -136,6 +136,15 @@ def _get_env() -> tp.Dict[str, str]:
     return {x: y for x, y in os.environ.items() if x.startswith(("SLURM_", "SUBMITIT_"))}
 
 
+def test_torch_distrib_env() -> None:
+    with pytest.raises(RuntimeError):
+        env = helpers.TorchDistributedEnvironment()
+    with utils.environment_variables(SLURM_JOB_ID=12):
+        env = helpers.TorchDistributedEnvironment()
+        # port is deduced from job id
+        assert env.master_port == 58811
+
+
 def test_clean_env() -> None:
     base = _get_env()
     with utils.environment_variables(SLURM_BLUBLU=12, SUBMITIT_BLUBLU=12):
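
For context, a minimal sketch of what the new fallback enables. The script name `train.py` and the launch command `srun python train.py` are assumptions for illustration, not part of this change: the point is that inside a plain `sbatch` allocation with no submitit job at all, `TorchDistributedEnvironment` now reads the `SLURM_*` variables through `SlurmJobEnvironment` instead of raising.

```python
# train.py -- illustrative sketch only; launched as e.g. `srun python train.py`
# inside a SLURM allocation where SLURM_JOB_ID is set but submitit is unused.
import torch.distributed

from submitit.helpers import TorchDistributedEnvironment

# Without a submitit job, JobEnvironment() raises RuntimeError; since
# SLURM_JOB_ID is present, the new code path falls back to SlurmJobEnvironment.
dist_env = TorchDistributedEnvironment()
print(dist_env)  # the new __repr__ lists all non-private attributes

torch.distributed.init_process_group(
    backend="nccl",
    init_method=f"tcp://{dist_env.master_addr}:{dist_env.master_port}",
    rank=dist_env.rank,
    world_size=dist_env.world_size,
)
```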
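
The fallback can also be exercised locally without a cluster, mirroring the new unit test: with only `SLURM_JOB_ID` present, construction succeeds and `master_port` is derived deterministically from the job id (58811 for job 12, as the test pins down).

```python
# Local sanity check (sketch): run on a machine with no SLURM_* variables set.
from submitit import helpers
from submitit.core import utils

try:
    helpers.TorchDistributedEnvironment()  # no job environment at all
except RuntimeError:
    pass  # still raises when even SLURM_JOB_ID is missing

with utils.environment_variables(SLURM_JOB_ID=12):
    env = helpers.TorchDistributedEnvironment()
    print(env.master_port)  # 58811, deduced from the job id
```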