Skip to content

Commit 3c912ef

Browse files
authored
Allow tasks to depend on other tasks. (#493)
1 parent 46fa675 commit 3c912ef

File tree

9 files changed

+191
-8
lines changed

9 files changed

+191
-8
lines changed

docs/source/changes.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
1818
when a product annotation is used with the argument name `produces`. And, allow
1919
`produces` to intake any node.
2020
- {pull}`490` refactors and better tests parsing of dependencies.
21+
- {pull}`493` allows tasks to depend on other tasks.
2122
- {pull}`496` makes pytask even lazier. Now, when a task produces a node whose hash
2223
remains the same, the consecutive tasks are not executed. It remained from when pytask
2324
relied on timestamps.

docs/source/tutorials/defining_dependencies_products.md

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,38 @@ def task_fit_model(depends_on, produces):
410410
:::
411411
::::
412412

413+
## Depending on a task
414+
415+
In some situations you want to define a task depending on another task without
416+
specifying the relationship explicitly.
417+
418+
pytask allows you to do that, but you loose features like access to paths which is why
419+
defining dependencies explicitly is always preferred.
420+
421+
There are two modes for it and both use {func}`@task(after=...) <pytask.task>`.
422+
423+
First, you can pass the task function or multiple task functions to the decorator.
424+
Applied to the tasks from before, we could have written `task_plot_data` as
425+
426+
```python
427+
@task(after=task_create_random_data)
428+
def task_plot_data(...):
429+
...
430+
```
431+
432+
You can also pass a list of task functions.
433+
434+
The second mode is to pass an expression, a substring of the name of the dependent
435+
tasks. Here, we can pass the function name or a significant part of the function
436+
name.
437+
438+
```python
439+
@task(after="random_data")
440+
def task_plot_data(...):
441+
...
442+
```
443+
444+
You will learn more about expressions in {doc}`selecting_tasks`.
413445

414446
## References
415447

src/_pytask/collect.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,8 @@ def pytask_collect_task(
254254
)
255255

256256
markers = get_all_marks(obj)
257+
collection_id = obj.pytask_meta._id if hasattr(obj, "pytask_meta") else None
258+
after = obj.pytask_meta.after if hasattr(obj, "pytask_meta") else []
257259

258260
# Get the underlying function to avoid having different states of the function,
259261
# e.g. due to pytask_meta, in different layers of the wrapping.
@@ -266,6 +268,7 @@ def pytask_collect_task(
266268
depends_on=dependencies,
267269
produces=products,
268270
markers=markers,
271+
attributes={"collection_id": collection_id, "after": after},
269272
)
270273
return Task(
271274
base_name=name,
@@ -274,6 +277,7 @@ def pytask_collect_task(
274277
depends_on=dependencies,
275278
produces=products,
276279
markers=markers,
280+
attributes={"collection_id": collection_id, "after": after},
277281
)
278282
if isinstance(obj, PTask) and not inspect.isclass(obj):
279283
return obj
@@ -294,7 +298,7 @@ def pytask_collect_task(
294298
295299
Please, align the names to ensure reproducibility on case-sensitive file systems \
296300
(often Linux or macOS) or disable this error with 'check_casing_of_paths = false' in \
297-
your pytask configuration file.
301+
the pyproject.toml file.
298302
299303
Hint: If parts of the path preceding your project directory are not properly \
300304
formatted, check whether you need to call `.resolve()` on `SRC`, `BLD` or other paths \

src/_pytask/dag.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from _pytask.console import render_to_string
1616
from _pytask.console import TASK_ICON
1717
from _pytask.exceptions import ResolvingDependenciesError
18+
from _pytask.mark import select_by_after_keyword
1819
from _pytask.node_protocols import PNode
1920
from _pytask.node_protocols import PTask
2021
from _pytask.nodes import PythonNode
@@ -93,6 +94,30 @@ def _add_product(dag: nx.DiGraph, task: PTask, node: PNode) -> None:
9394
return dag
9495

9596

97+
@hookimpl
98+
def pytask_dag_modify_dag(session: Session, dag: nx.DiGraph) -> None:
99+
"""Create dependencies between tasks when using ``@task(after=...)``."""
100+
temporary_id_to_task = {
101+
task.attributes["collection_id"]: task
102+
for task in session.tasks
103+
if "collection_id" in task.attributes
104+
}
105+
for task in session.tasks:
106+
after = task.attributes.get("after")
107+
if isinstance(after, list):
108+
for temporary_id in after:
109+
other_task = temporary_id_to_task[temporary_id]
110+
for successor in dag.successors(other_task.signature):
111+
dag.add_edge(successor, task.signature)
112+
elif isinstance(after, str):
113+
task_signature = task.signature
114+
signatures = select_by_after_keyword(session, after)
115+
signatures.discard(task_signature)
116+
for signature in signatures:
117+
for successor in dag.successors(signature):
118+
dag.add_edge(successor, task.signature)
119+
120+
96121
def _check_if_dag_has_cycles(dag: nx.DiGraph) -> None:
97122
"""Check if DAG has cycles."""
98123
try:

src/_pytask/mark/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
"MarkDecorator",
4040
"MarkGenerator",
4141
"ParseError",
42+
"select_by_after_keyword",
4243
"select_by_keyword",
4344
"select_by_mark",
4445
]
@@ -168,6 +169,22 @@ def select_by_keyword(session: Session, dag: nx.DiGraph) -> set[str]:
168169
return remaining
169170

170171

172+
def select_by_after_keyword(session: Session, after: str) -> set[str]:
173+
"""Select tasks defined by the after keyword."""
174+
try:
175+
expression = Expression.compile_(after)
176+
except ParseError as e:
177+
msg = f"Wrong expression passed to 'after': {after}: {e}"
178+
raise ValueError(msg) from None
179+
180+
ancestors: set[str] = set()
181+
for task in session.tasks:
182+
if after and expression.evaluate(KeywordMatcher.from_task(task)):
183+
ancestors.add(task.signature)
184+
185+
return ancestors
186+
187+
171188
@define(slots=True)
172189
class MarkMatcher:
173190
"""A matcher for markers which are present.

src/_pytask/mark/__init__.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ from _pytask.tree_util import PyTree
1010
from _pytask.session import Session
1111
import networkx as nx
1212

13+
def select_by_after_keyword(session: Session, after: str) -> set[str]: ...
1314
def select_by_keyword(session: Session, dag: nx.DiGraph) -> set[str]: ...
1415
def select_by_mark(session: Session, dag: nx.DiGraph) -> set[str]: ...
1516

@@ -54,4 +55,5 @@ __all__ = [
5455
"ParseError",
5556
"select_by_keyword",
5657
"select_by_mark",
58+
"select_by_after_keyword",
5759
]

src/_pytask/models.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
from __future__ import annotations
33

44
from typing import Any
5+
from typing import Callable
56
from typing import NamedTuple
67
from typing import TYPE_CHECKING
8+
from uuid import UUID
9+
from uuid import uuid4
710

811
from attrs import define
912
from attrs import field
@@ -16,18 +19,39 @@
1619

1720
@define
1821
class CollectionMetadata:
19-
"""A class for carrying metadata from functions to tasks."""
20-
22+
"""A class for carrying metadata from functions to tasks.
23+
24+
Attributes
25+
----------
26+
after
27+
An expression or a task function or a list of task functions that need to be
28+
executed before this task can.
29+
id_
30+
An id for the task if it is part of a parametrization. Otherwise, an automatic
31+
id will be generated. See
32+
:doc:`this tutorial <../tutorials/repeating_tasks_with_different_inputs>` for
33+
more information.
34+
kwargs
35+
A dictionary containing keyword arguments which are passed to the task when it
36+
is executed.
37+
markers
38+
A list of markers that are attached to the task.
39+
name
40+
Use it to override the name of the task that is, by default, the name of the
41+
callable.
42+
produces
43+
Definition of products to parse the function returns and store them. See
44+
:doc:`this how-to guide <../how_to_guides/using_task_returns>` for more
45+
information.
46+
"""
47+
48+
after: str | list[Callable[..., Any]] = field(factory=list)
2149
id_: str | None = None
22-
"""The id for a single parametrization."""
2350
kwargs: dict[str, Any] = field(factory=dict)
24-
"""Contains kwargs which are necessary for the task function on execution."""
2551
markers: list[Mark] = field(factory=list)
26-
"""Contains the markers of the function."""
2752
name: str | None = None
28-
"""The name of the task function."""
2953
produces: PyTree[Any] | None = None
30-
"""Definition of products to handle returns."""
54+
_id: UUID = field(factory=uuid4)
3155

3256

3357
class NodeInfo(NamedTuple):

src/_pytask/task_utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
def task(
4040
name: str | None = None,
4141
*,
42+
after: str | Callable[..., Any] | list[Callable[..., Any]] | None = None,
4243
id: str | None = None, # noqa: A002
4344
kwargs: dict[Any, Any] | None = None,
4445
produces: PyTree[Any] | None = None,
@@ -55,6 +56,9 @@ def task(
5556
name
5657
Use it to override the name of the task that is, by default, the name of the
5758
callable.
59+
after
60+
An expression or a task function or a list of task functions that need to be
61+
executed before this task can.
5862
id
5963
An id for the task if it is part of a parametrization. Otherwise, an automatic
6064
id will be generated. See
@@ -102,20 +106,23 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
102106

103107
parsed_kwargs = {} if kwargs is None else kwargs
104108
parsed_name = name if isinstance(name, str) else func.__name__
109+
parsed_after = _parse_after(after)
105110

106111
if hasattr(unwrapped, "pytask_meta"):
107112
unwrapped.pytask_meta.name = parsed_name
108113
unwrapped.pytask_meta.kwargs = parsed_kwargs
109114
unwrapped.pytask_meta.markers.append(Mark("task", (), {}))
110115
unwrapped.pytask_meta.id_ = id
111116
unwrapped.pytask_meta.produces = produces
117+
unwrapped.pytask_meta.after = parsed_after
112118
else:
113119
unwrapped.pytask_meta = CollectionMetadata(
114120
name=parsed_name,
115121
kwargs=parsed_kwargs,
116122
markers=[Mark("task", (), {})],
117123
id_=id,
118124
produces=produces,
125+
after=parsed_after,
119126
)
120127

121128
# Store it in the global variable ``COLLECTED_TASKS`` to avoid garbage
@@ -131,6 +138,30 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
131138
return wrapper
132139

133140

141+
def _parse_after(
142+
after: str | Callable[..., Any] | list[Callable[..., Any]] | None
143+
) -> str | list[Callable[..., Any]]:
144+
if not after:
145+
return []
146+
if isinstance(after, str):
147+
return after
148+
if callable(after):
149+
if not hasattr(after, "pytask_meta"):
150+
after.pytask_meta = CollectionMetadata() # type: ignore[attr-defined]
151+
return [after.pytask_meta._id] # type: ignore[attr-defined]
152+
if isinstance(after, list):
153+
new_after = []
154+
for func in after:
155+
if not hasattr(func, "pytask_meta"):
156+
func.pytask_meta = CollectionMetadata() # type: ignore[attr-defined]
157+
new_after.append(func.pytask_meta._id) # type: ignore[attr-defined]
158+
msg = (
159+
"'after' should be an expression string, a task, or a list of class. Got "
160+
f"{after}, instead."
161+
)
162+
raise TypeError(msg)
163+
164+
134165
def parse_collected_tasks_with_task_marker(
135166
tasks: list[Callable[..., Any]],
136167
) -> dict[str, Callable[..., Any]]:

tests/test_task.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,3 +615,50 @@ def func(path: Annotated[Path, Product]):
615615
assert result.exit_code == ExitCode.COLLECTION_FAILED
616616
assert "Duplicated tasks" in result.output
617617
assert "id=b.txt" in result.output
618+
619+
620+
def test_task_will_be_executed_after_another_one_with_string(runner, tmp_path):
621+
source = """
622+
from pytask import task
623+
from pathlib import Path
624+
from typing_extensions import Annotated
625+
626+
@task(after="task_first")
627+
def task_second():
628+
assert Path(__file__).parent.joinpath("out.txt").exists()
629+
630+
def task_first() -> Annotated[str, Path("out.txt")]:
631+
return "Hello, World!"
632+
"""
633+
tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))
634+
635+
result = runner.invoke(cli, [tmp_path.as_posix()])
636+
assert result.exit_code == ExitCode.OK
637+
assert "2 Succeeded" in result.output
638+
639+
# Make sure that the dependence does not only apply to the task (and task module),
640+
# but also it products.
641+
tmp_path.joinpath("out.txt").write_text("Hello, Moon!")
642+
result = runner.invoke(cli, [tmp_path.as_posix()])
643+
assert result.exit_code == ExitCode.OK
644+
assert "1 Succeeded" in result.output
645+
assert "1 Skipped because unchanged" in result.output
646+
647+
648+
def test_task_will_be_executed_after_another_one_with_function(tmp_path):
649+
source = """
650+
from pytask import task
651+
from pathlib import Path
652+
from typing_extensions import Annotated
653+
654+
def task_first() -> Annotated[str, Path("out.txt")]:
655+
return "Hello, World!"
656+
657+
@task(after=task_first)
658+
def task_second():
659+
assert Path(__file__).parent.joinpath("out.txt").exists()
660+
"""
661+
tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))
662+
663+
session = build(paths=tmp_path)
664+
assert session.exit_code == ExitCode.OK

0 commit comments

Comments
 (0)