Skip to content

Commit bf5668e

Browse files
authored
Fix import_path. (#424)
1 parent 94deca2 commit bf5668e

File tree

5 files changed

+136
-16
lines changed

5 files changed

+136
-16
lines changed

docs/source/changes.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
4040
decorators.
4141
- {pull}`421` removes the deprecation warning when `produces` is used as an magic
4242
function keyword to define products.
43+
- {pull}`424` fixes problems with {func}`~_pytask.path.import_path`.
4344

4445
## 0.3.2 - 2023-06-07
4546

src/_pytask/path.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Contains code to handle paths."""
22
from __future__ import annotations
33

4+
import contextlib
45
import functools
56
import importlib.util
67
import os
@@ -122,6 +123,8 @@ def import_path(path: Path, root: Path) -> ModuleType:
122123
123124
"""
124125
module_name = _module_name_from_path(path, root)
126+
with contextlib.suppress(KeyError):
127+
return sys.modules[module_name]
125128

126129
spec = importlib.util.spec_from_file_location(module_name, str(path))
127130

@@ -154,6 +157,11 @@ def _module_name_from_path(path: Path, root: Path) -> str:
154157
# Use the parts for the relative path to the root path.
155158
path_parts = relative_path.parts
156159

160+
# Module name for packages do not contain the __init__ file, unless the
161+
# `__init__.py` file is at the root.
162+
if len(path_parts) >= 2 and path_parts[-1] == "__init__": # noqa: PLR2004
163+
path_parts = path_parts[:-1]
164+
157165
return ".".join(path_parts)
158166

159167

tests/conftest.py

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,77 @@
11
from __future__ import annotations
22

3+
import sys
4+
from contextlib import contextmanager
35
from pathlib import Path
6+
from typing import Callable
47

58
import pytest
69
from click.testing import CliRunner
710

811

9-
@pytest.fixture()
10-
def runner():
11-
return CliRunner()
12-
13-
1412
@pytest.fixture(autouse=True)
1513
def _add_objects_to_doctest_namespace(doctest_namespace):
1614
doctest_namespace["Path"] = Path
15+
16+
17+
class SysPathsSnapshot:
18+
"""A snapshot for sys.path."""
19+
20+
def __init__(self) -> None:
21+
self.__saved = list(sys.path), list(sys.meta_path)
22+
23+
def restore(self) -> None:
24+
sys.path[:], sys.meta_path[:] = self.__saved
25+
26+
27+
class SysModulesSnapshot:
28+
"""A snapshot for sys.modules."""
29+
30+
def __init__(self, preserve: Callable[[str], bool] | None = None) -> None:
31+
self.__preserve = preserve
32+
self.__saved = dict(sys.modules)
33+
34+
def restore(self) -> None:
35+
if self.__preserve:
36+
self.__saved.update(
37+
(k, m) for k, m in sys.modules.items() if self.__preserve(k)
38+
)
39+
sys.modules.clear()
40+
sys.modules.update(self.__saved)
41+
42+
43+
@contextmanager
44+
def restore_sys_path_and_module_after_test_execution():
45+
sys_path_snapshot = SysPathsSnapshot()
46+
sys_modules_snapshot = SysModulesSnapshot()
47+
yield
48+
sys_modules_snapshot.restore()
49+
sys_path_snapshot.restore()
50+
51+
52+
@pytest.fixture(autouse=True)
53+
def _restore_sys_path_and_module_after_test_execution():
54+
"""Restore sys.path and sys.modules after every test execution.
55+
56+
This fixture became necessary because most task modules in the tests are named
57+
`task_example`. Since the change in #424, the same module is not reimported which
58+
solves errors with parallelization. At the same time, modules with the same name in
59+
the tests are overshadowing another and letting tests fail.
60+
61+
The changes to `sys.path` might not be necessary to restore, but we do it anyways.
62+
63+
"""
64+
with restore_sys_path_and_module_after_test_execution():
65+
yield
66+
67+
68+
class CustomCliRunner(CliRunner):
69+
def invoke(self, *args, **kwargs):
70+
"""Restore sys.path and sys.modules after an invocation."""
71+
with restore_sys_path_and_module_after_test_execution():
72+
return super().invoke(*args, **kwargs)
73+
74+
75+
@pytest.fixture()
76+
def runner():
77+
return CustomCliRunner()

tests/test_path.py

Lines changed: 57 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -124,33 +124,34 @@ def test_find_case_sensitive_path(tmp_path, path, existing_paths, expected):
124124

125125

126126
@pytest.fixture()
127-
def simple_module(tmp_path: Path) -> Path:
128-
fn = tmp_path / "_src/project/mymod.py"
127+
def simple_module(request, tmp_path: Path) -> Path:
128+
name = f"mymod_{request.node.name}"
129+
fn = tmp_path / f"_src/project/{name}.py"
129130
fn.parent.mkdir(parents=True)
130131
fn.write_text("def foo(x): return 40 + x")
131-
return fn
132+
module_name = _module_name_from_path(fn, root=tmp_path)
133+
yield fn
134+
sys.modules.pop(module_name, None)
132135

133136

134137
@pytest.mark.unit()
135-
def test_importmode_importlib(simple_module: Path, tmp_path: Path) -> None:
138+
def test_importmode_importlib(request, simple_module: Path, tmp_path: Path) -> None:
136139
"""`importlib` mode does not change sys.path."""
137140
module = import_path(simple_module, root=tmp_path)
138141
assert module.foo(2) == 42 # type: ignore[attr-defined]
139142
assert str(simple_module.parent) not in sys.path
140143
assert module.__name__ in sys.modules
141-
assert module.__name__ == "_src.project.mymod"
144+
assert module.__name__ == f"_src.project.mymod_{request.node.name}"
142145
assert "_src" in sys.modules
143146
assert "_src.project" in sys.modules
144147

145148

146149
@pytest.mark.unit()
147-
def test_importmode_twice_is_different_module(
148-
simple_module: Path, tmp_path: Path
149-
) -> None:
150-
"""`importlib` mode always returns a new module."""
150+
def test_remembers_previous_imports(simple_module: Path, tmp_path: Path) -> None:
151+
"""importlib mode called remembers previous module (pytest#10341, pytest#10811)."""
151152
module1 = import_path(simple_module, root=tmp_path)
152153
module2 = import_path(simple_module, root=tmp_path)
153-
assert module1 is not module2
154+
assert module1 is module2
154155

155156

156157
@pytest.mark.unit()
@@ -165,6 +166,9 @@ def test_no_meta_path_found(
165166
# mode='importlib' fails if no spec is found to load the module
166167
import importlib.util
167168

169+
# Force module to be re-imported.
170+
del sys.modules[module.__name__]
171+
168172
monkeypatch.setattr(
169173
importlib.util, "spec_from_file_location", lambda *args: None # noqa: ARG005
170174
)
@@ -288,6 +292,10 @@ def test_module_name_from_path(tmp_path: Path) -> None:
288292
result = _module_name_from_path(Path("/home/foo/task_foo.py"), Path("/bar"))
289293
assert result == "home.foo.task_foo"
290294

295+
# Importing __init__.py files should return the package as module name.
296+
result = _module_name_from_path(tmp_path / "src/app/__init__.py", tmp_path)
297+
assert result == "src.app"
298+
291299

292300
@pytest.mark.unit()
293301
def test_insert_missing_modules(
@@ -308,3 +316,42 @@ def test_insert_missing_modules(
308316
modules = {}
309317
_insert_missing_modules(modules, "")
310318
assert not modules
319+
320+
321+
def test_importlib_package(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
322+
"""
323+
Importing a package using --importmode=importlib should not import the
324+
package's __init__.py file more than once (#11306).
325+
"""
326+
monkeypatch.chdir(tmp_path)
327+
monkeypatch.syspath_prepend(tmp_path)
328+
329+
package_name = "importlib_import_package"
330+
tmp_path.joinpath(package_name).mkdir()
331+
init = tmp_path.joinpath(f"{package_name}/__init__.py")
332+
init.write_text(
333+
textwrap.dedent(
334+
"""
335+
from .singleton import Singleton
336+
instance = Singleton()
337+
"""
338+
),
339+
encoding="ascii",
340+
)
341+
singleton = tmp_path.joinpath(f"{package_name}/singleton.py")
342+
singleton.write_text(
343+
textwrap.dedent(
344+
"""
345+
class Singleton:
346+
INSTANCES = []
347+
def __init__(self) -> None:
348+
self.INSTANCES.append(self)
349+
if len(self.INSTANCES) > 1:
350+
raise RuntimeError("Already initialized")
351+
"""
352+
),
353+
encoding="ascii",
354+
)
355+
356+
mod = import_path(init, root=tmp_path)
357+
assert len(mod.instance.INSTANCES) == 1

tests/test_persist.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from pytask import SkippedUnchanged
1414
from pytask import TaskOutcome
1515

16+
from tests.conftest import restore_sys_path_and_module_after_test_execution
17+
1618

1719
class DummyClass:
1820
pass
@@ -46,7 +48,8 @@ def task_dummy(depends_on, produces):
4648
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
4749
tmp_path.joinpath("in.txt").write_text("I'm not the reason you care.")
4850

49-
session = build(paths=tmp_path)
51+
with restore_sys_path_and_module_after_test_execution():
52+
session = build(paths=tmp_path)
5053

5154
assert session.exit_code == ExitCode.OK
5255
assert len(session.execution_reports) == 1

0 commit comments

Comments
 (0)