diff --git a/submitit/core/submission.py b/submitit/core/submission.py index 4d9f65f..f68afca 100644 --- a/submitit/core/submission.py +++ b/submitit/core/submission.py @@ -6,6 +6,7 @@ import argparse import os +import sys import time import traceback from pathlib import Path @@ -49,6 +50,8 @@ def process_job(folder: Union[Path, str]) -> None: ) try: delayed = utils.DelayedSubmission.load(paths.submitted_pickle) + logger.info(f"Set sys.path to {delayed.sys_path} like in the scheduler runtime.") + sys.path = list(delayed.sys_path) env = job_environment.JobEnvironment() env._handle_signals(paths, delayed) result = delayed.result() diff --git a/submitit/core/utils.py b/submitit/core/utils.py index c208d23..62d7ed1 100644 --- a/submitit/core/utils.py +++ b/submitit/core/utils.py @@ -125,6 +125,7 @@ def __init__(self, function: Callable[..., Any], *args: Any, **kwargs: Any) -> N self._done = False self._timeout_min: int = 0 self._timeout_countdown: int = 0 # controlled in submission and execution + self.sys_path = tuple(sys.path) def result(self) -> Any: if self._done: diff --git a/submitit/test_pickle.py b/submitit/test_pickle.py index 0abed7c..fd9ae6b 100644 --- a/submitit/test_pickle.py +++ b/submitit/test_pickle.py @@ -5,6 +5,9 @@ # import pickle +import subprocess +import sys +from pathlib import Path from weakref import ref import pytest @@ -78,3 +81,28 @@ def get_main() -> str: ex = LocalExecutor(tmp_path) j_main = ex.submit(get_main).result() assert main == j_main + + +def test_submitit_respects_sys_path(tmp_path: Path): + # https://github.com/facebookincubator/submitit/issues/12 + CUSTOM_CODE = f""" +import submitit +import sys +from pathlib import Path + +def dump_sys_path(file): + Path(file).write_text("\\n".join(sys.path)) + +dump_sys_path("{tmp_path}/scheduler_path.txt") +ex = submitit.LocalExecutor("{tmp_path}/log") +job = ex.submit(dump_sys_path, "{tmp_path}/job_path.txt") +job.wait() + +""" + scheduler_py = tmp_path / "scheduler.py" + scheduler_py.write_text(CUSTOM_CODE) + subprocess.check_call([sys.executable, scheduler_py]) + job_sys_path = (tmp_path / "job_path.txt").read_text() + scheduler_sys_path = (tmp_path / "scheduler_path.txt").read_text() + + assert job_sys_path == scheduler_sys_path