Skip to content

Commit 2dfb0f7

Browse files
authored
Improve handling of task_files. (#568)
1 parent cf6cf40 commit 2dfb0f7

File tree

5 files changed

+31
-9
lines changed

5 files changed

+31
-9
lines changed

docs/source/changes.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
1818
- {pull}`555` uses new-style hook wrappers and requires pluggy 1.3 for typing.
1919
- {pull}`557` fixes an issue with `@task(after=...)` in notebooks and terminals.
2020
- {pull}`566` makes universal-pathlib an official dependency.
21+
- {pull}`568` restricts `task_files` to a list of patterns and raises a better error.
2122

2223
## 0.4.5 - 2024-01-09
2324

docs/source/reference_guides/configuration.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ strict_markers = true
242242
Change the pattern which identify task files.
243243
244244
```toml
245-
task_files = "task_*.py" # default
245+
task_files = ["task_*.py"] # default
246246
247247
task_files = ["task_*.py", "tasks_*.py"]
248248
```

src/_pytask/build.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def build( # noqa: C901, PLR0912, PLR0913, PLR0915
9595
stop_after_first_failure: bool = False,
9696
strict_markers: bool = False,
9797
tasks: Callable[..., Any] | PTask | Iterable[Callable[..., Any] | PTask] = (),
98-
task_files: str | Iterable[str] = "task_*.py",
98+
task_files: Iterable[str] = ("task_*.py",),
9999
trace: bool = False,
100100
verbose: int = 1,
101101
**kwargs: Any,

src/_pytask/config.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,13 @@ def pytask_parse_config(config: dict[str, Any]) -> None:
9292
+ IGNORED_TEMPORARY_FILES_AND_FOLDERS
9393
)
9494

95-
config["task_files"] = to_list(config.get("task_files", "task_*.py"))
95+
value = config.get("task_files", ["task_*.py"])
96+
if not isinstance(value, (list, tuple)) or not all(
97+
isinstance(p, str) for p in value
98+
):
99+
msg = "'task_files' must be a list of patterns."
100+
raise ValueError(msg)
101+
config["task_files"] = value
96102

97103
if config["stop_after_first_failure"]:
98104
config["max_failures"] = 1

tests/test_collect.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,10 @@ def test_collect_same_task_different_ways(tmp_path, path_extension):
103103
@pytest.mark.parametrize(
104104
("task_files", "pattern", "expected_collected_tasks"),
105105
[
106-
(["example_task.py"], "'*_task.py'", 1),
107-
(["tasks_example.py"], "'tasks_*'", 1),
108-
(["example_tasks.py"], "'*_tasks.py'", 1),
109-
(["task_module.py", "tasks_example.py"], "'tasks_*.py'", 1),
106+
(["example_task.py"], "['*_task.py']", 1),
107+
(["tasks_example.py"], "['tasks_*']", 1),
108+
(["example_tasks.py"], "['*_tasks.py']", 1),
109+
(["task_module.py", "tasks_example.py"], "['tasks_*.py']", 1),
110110
(["task_module.py", "tasks_example.py"], "['task_*.py', 'tasks_*.py']", 2),
111111
],
112112
)
@@ -117,15 +117,30 @@ def test_collect_files_w_custom_file_name_pattern(
117117
f"[tool.pytask.ini_options]\ntask_files = {pattern}"
118118
)
119119

120-
for file in task_files:
121-
tmp_path.joinpath(file).write_text("def task_example(): pass")
120+
for file_ in task_files:
121+
tmp_path.joinpath(file_).write_text("def task_example(): pass")
122122

123123
session = build(paths=tmp_path)
124124

125125
assert session.exit_code == ExitCode.OK
126126
assert len(session.tasks) == expected_collected_tasks
127127

128128

129+
def test_error_with_invalid_file_name_pattern(runner, tmp_path):
130+
tmp_path.joinpath("pyproject.toml").write_text(
131+
"[tool.pytask.ini_options]\ntask_files = 'asds'"
132+
)
133+
134+
result = runner.invoke(cli, [tmp_path.as_posix()])
135+
assert result.exit_code == ExitCode.CONFIGURATION_FAILED
136+
assert "'task_files' must be a list of patterns." in result.output
137+
138+
139+
def test_error_with_invalid_file_name_pattern_(tmp_path):
140+
session = build(paths=tmp_path, task_files=[1])
141+
assert session.exit_code == ExitCode.CONFIGURATION_FAILED
142+
143+
129144
@pytest.mark.unit()
130145
@pytest.mark.parametrize(
131146
("session", "path", "node_info", "expected"),

0 commit comments

Comments
 (0)