diff --git a/README.md b/README.md index 370aac2..0bf1ca9 100644 --- a/README.md +++ b/README.md @@ -114,7 +114,7 @@ The key difference with `submitit` is that `dask.distributed` distributes the jo ## Contributors -By chronological order: Jérémy Rapin, Louis Martin, Lowik Chanussot, Lucas Hosseini, Fabio Petroni, Francisco Massa, Guillaume Wenzek, Thibaut Lavril, Vinayak Tantia, Andrea Vedaldi, Max Nickel, Quentin Duval (feel free to [contribute](.github/CONTRIBUTING.md) and add your name ;) ) +By chronological order: Jérémy Rapin, Louis Martin, Lowik Chanussot, Lucas Hosseini, Fabio Petroni, Francisco Massa, Guillaume Wenzek, Thibaut Lavril, Vinayak Tantia, Andrea Vedaldi, Max Nickel, Quentin Duval, Rushil Patel (feel free to [contribute](.github/CONTRIBUTING.md) and add your name ;) ) ## License diff --git a/submitit/core/core.py b/submitit/core/core.py index f5bbeed..96cc4e8 100644 --- a/submitit/core/core.py +++ b/submitit/core/core.py @@ -862,7 +862,9 @@ class PicklingExecutor(Executor): Note: during a batch submission, this is the estimated sum of all pickles. 
""" - def __init__(self, folder: tp.Union[Path, str], max_num_timeout: int = 3, max_pickle_size_gb: float = 1.0) -> None: + def __init__( + self, folder: tp.Union[Path, str], max_num_timeout: int = 3, max_pickle_size_gb: float = 1.0 + ) -> None: super().__init__(folder) self.max_num_timeout = max_num_timeout self.max_pickle_size_gb = max_pickle_size_gb @@ -900,7 +902,7 @@ def _internal_process_submissions( if check_size: # warn if the dumped objects are too big check_size = False num = len(delayed_submissions) - size = pickle_path.stat().st_size / 1024**3 + size = pickle_path.stat().st_size / 1024**3 if num * size > self.max_pickle_size_gb: pickle_path.unlink() msg = f"Submitting an estimated {num} x {size:.2f} > {self.max_pickle_size_gb}GB of objects " diff --git a/submitit/core/plugins.py b/submitit/core/plugins.py index 70c7463..d25ff9a 100644 --- a/submitit/core/plugins.py +++ b/submitit/core/plugins.py @@ -6,6 +6,7 @@ import functools import os +from importlib import metadata from typing import TYPE_CHECKING, List, Mapping, Tuple, Type from ..core import logger @@ -16,42 +17,61 @@ from ..core.job_environment import JobEnvironment +def _iter_submitit_entrypoints(): + """Return an iterable of EntryPoint objects in the 'submitit' group + compatible with Python 3.8+ and the backport.""" + + # 3.10+ API: EntryPoints with .select + eps = metadata.entry_points() + if hasattr(eps, "select"): + return eps.select(group="submitit") + + # importlib_metadata backport newer signature: entry_points("submitit") + try: + return metadata.entry_points()["submitit"] + except TypeError: + pass # older API; fall through + + # 3.8/3.9 legacy: mapping {group: [EntryPoint, ...]} + if hasattr(eps, "get"): + return eps.get("submitit", []) + + # old style (should in theory never get here if 3.8+): flat iterable; filter by .group + return [ep for ep in eps if getattr(ep, "group", None) == "submitit"] + + @functools.lru_cache() def _get_plugins() -> Tuple[List[Type["Executor"]], 
List["JobEnvironment"]]: # pylint: disable=cyclic-import,import-outside-toplevel - # Load dynamically to avoid import cycle - # pkg_resources goes through all modules on import. - import pkg_resources - from ..local import debug, local from ..slurm import slurm - # TODO: use sys.modules.keys() and importlib.resources to find the files - # We load both kind of entry points at the same time because we have to go through all module files anyway. executors: List[Type["Executor"]] = [slurm.SlurmExecutor, local.LocalExecutor, debug.DebugExecutor] job_envs = [slurm.SlurmJobEnvironment(), local.LocalJobEnvironment(), debug.DebugJobEnvironment()] - for entry_point in pkg_resources.iter_entry_points("submitit"): + for entry_point in _iter_submitit_entrypoints(): if entry_point.name not in ("executor", "job_environment"): - logger.warning(f"Found unknown entry point in package {entry_point.module_name}: {entry_point}") + logger.warning(f"{entry_point.name} = {entry_point.value}") continue + module_name = entry_point.value.split(":", 1)[0] try: # call `load` rather than `resolve`. # `load` also checks the module and its dependencies are correctly installed. 
- cls = entry_point.load() + obj = entry_point.load() except Exception as e: # This may happen if the plugin haven't been correctly installed - logger.exception(f"Failed to load submitit plugin '{entry_point.module_name}': {e}") + logger.exception(f"Failed to load submitit plugin '{module_name}': {e}") continue if entry_point.name == "executor": - executors.append(cls) + executors.append(obj) else: try: - job_env = cls() + job_env = obj() except Exception as e: + name = getattr(obj, "name", getattr(obj, "__name__", str(obj))) logger.exception( - f"Failed to init JobEnvironment '{cls.name}' ({cls}) from submitit plugin '{entry_point.module_name}': {e}" + f"Failed to init JobEnvironment '{name}' ({obj}) from submitit plugin '{module_name}': {e}" ) continue job_envs.append(job_env) diff --git a/submitit/core/test_core.py b/submitit/core/test_core.py index d36c843..2132f4f 100644 --- a/submitit/core/test_core.py +++ b/submitit/core/test_core.py @@ -250,6 +250,7 @@ def test_max_pickle_size_gb(tmp_path: Path) -> None: with pytest.raises(RuntimeError): _ = executor.submit(_three_time, 4) + if __name__ == "__main__": args, kwargs = [], {} # oversimplisitic parser for argv in sys.argv[1:]: diff --git a/submitit/core/test_plugins.py b/submitit/core/test_plugins.py index 2ec220b..4b7194f 100644 --- a/submitit/core/test_plugins.py +++ b/submitit/core/test_plugins.py @@ -4,12 +4,12 @@ # LICENSE file in the root directory of this source tree. # +import importlib import logging import re import typing as tp from pathlib import Path -import pkg_resources import pytest from . 
import core, plugins @@ -69,17 +69,24 @@ def __init__(self, tmp_path: Path, monkeypatch): self.monkeypatch = monkeypatch def add_plugin(self, name: str, entry_points: str, init: str): - plugin = self.tmp_path / name - plugin.mkdir(mode=0o777) - plugin_egg = plugin.with_suffix(".egg-info") - plugin_egg.mkdir(mode=0o777) - - (plugin_egg / "entry_points.txt").write_text(entry_points) - (plugin / "__init__.py").write_text(init) - - # also fix pkg_resources since it already has loaded old packages in other tests. - working_set = pkg_resources.WorkingSet([str(self.tmp_path)]) - self.monkeypatch.setattr(pkg_resources, "iter_entry_points", working_set.iter_entry_points) + # Extract version from init string if available + version = "0.0.0" # default fallback - this bit doesn't matter for testing + version_match = re.search(r'__version__\s*=\s*["\']([^"\']+)["\']', init) + if version_match: + version = version_match.group(1) + + pkg_dir = self.tmp_path / name + pkg_dir.mkdir(mode=0o777) + (pkg_dir / "__init__.py").write_text(init) + + dist = self.tmp_path / f"{name}-{version}.dist-info" + dist.mkdir(mode=0o777) + (dist / "METADATA").write_text(f"Name: {name}\nVersion: {version}\n") + (dist / "entry_points.txt").write_text(entry_points) + + # Make sure Python and metadata see the new files + self.monkeypatch.syspath_prepend(str(self.tmp_path)) + importlib.invalidate_caches() def __enter__(self) -> None: _clear_plugin_cache() diff --git a/submitit/local/local.py b/submitit/local/local.py index 6165057..7c9db1b 100644 --- a/submitit/local/local.py +++ b/submitit/local/local.py @@ -147,10 +147,11 @@ class LocalExecutor(core.PicklingExecutor): job_class = LocalJob def __init__( - self, folder: tp.Union[str, Path], + self, + folder: tp.Union[str, Path], max_num_timeout: int = 3, max_pickle_size_gb: float = 1.0, - python: tp.Optional[str] = None + python: tp.Optional[str] = None, ) -> None: super().__init__( folder, diff --git a/submitit/slurm/slurm.py b/submitit/slurm/slurm.py index 137d9e5..024dd92 100644 ---
a/submitit/slurm/slurm.py +++ b/submitit/slurm/slurm.py @@ -245,7 +245,7 @@ def __init__( folder: tp.Union[str, Path], max_num_timeout: int = 3, max_pickle_size_gb: float = 1.0, - python: tp.Optional[str] = None + python: tp.Optional[str] = None, ) -> None: super().__init__( folder,