Skip to content

Commit 802b44e

Browse files
authored
Fix detection of task functions. (#437)
1 parent 2f46b5a commit 802b44e

File tree

10 files changed

+50
-25
lines changed

10 files changed

+50
-25
lines changed

docs/source/changes.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
5252
- {pull}`431` enables colors for WSL.
5353
- {pull}`432` fixes type checking of `pytask.mark.xxx`.
5454
- {pull}`433` fixes the ids generated for {class}`~pytask.PythonNode`s.
55+
- {pull}`437` fixes the detection of task functions and publishes
56+
{func}`pytask.is_task_function`.
5557

5658
## 0.3.2 - 2023-06-07
5759

docs/source/reference_guides/api.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,8 @@ To parse dependencies and products from nodes, use the following functions.
264264

265265
```{eval-rst}
266266
.. autofunction:: pytask.depends_on
267-
.. autofunction:: pytask.parse_nodes
267+
.. autofunction:: pytask.parse_dependencies_from_task_function
268+
.. autofunction:: pytask.parse_products_from_task_function
268269
.. autofunction:: pytask.produces
269270
```
270271

src/_pytask/collect.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from _pytask.shared import reduce_node_name
4040
from _pytask.task_utils import task as task_decorator
4141
from _pytask.traceback import render_exc_info
42+
from _pytask.typing import is_task_function
4243
from rich.text import Text
4344

4445
if TYPE_CHECKING:
@@ -102,7 +103,7 @@ def _collect_from_tasks(session: Session) -> None:
102103
obj=raw_task,
103104
)
104105

105-
if callable(raw_task):
106+
if is_task_function(raw_task):
106107
if not hasattr(raw_task, "pytask_meta"):
107108
raw_task = task_decorator()(raw_task) # noqa: PLW2901
108109

src/_pytask/collect_utils.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
__all__ = [
3939
"depends_on",
4040
"parse_dependencies_from_task_function",
41+
"parse_products_from_task_function",
4142
"parse_nodes",
4243
"produces",
4344
]
@@ -370,18 +371,6 @@ def _find_args_with_node_annotation(func: Callable[..., Any]) -> dict[str, PNode
370371
Read more about products in the documentation: https://tinyurl.com/yrezszr4.
371372
"""
372373

373-
_WARNING_PRODUCES_AS_KWARG = """Using 'produces' as an argument name to specify \
374-
products is deprecated and won't be available in pytask v0.5. Instead, use the product \
375-
annotation, described in this tutorial: https://tinyurl.com/yrezszr4.
376-
377-
from typing_extensions import Annotated
378-
from pytask import Product
379-
380-
def task_example(produces: Annotated[..., Product]):
381-
...
382-
383-
"""
384-
385374

386375
def parse_products_from_task_function(
387376
session: Session, task_path: Path, task_name: str, node_path: Path, obj: Any

src/_pytask/mark/structures.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,18 @@
11
from __future__ import annotations
22

3-
import functools
43
import warnings
54
from typing import Any
65
from typing import Callable
76
from typing import Iterable
87
from typing import Mapping
98

109
from _pytask.models import CollectionMetadata
10+
from _pytask.typing import is_task_function
1111
from attrs import define
1212
from attrs import field
1313
from attrs import validators
1414

1515

16-
def is_task_function(func: Any) -> bool:
17-
return (callable(func) and getattr(func, "__name__", "<lambda>") != "<lambda>") or (
18-
isinstance(func, functools.partial)
19-
and getattr(func.func, "__name__", "<lambda>") != "<lambda>"
20-
)
21-
22-
2316
@define(frozen=True)
2417
class Mark:
2518
"""A class for a mark containing the name, positional and keyword arguments."""

src/_pytask/report.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def from_exception(
3737
exc_info: ExceptionInfo,
3838
node: MetaNode | None = None,
3939
) -> CollectionReport:
40+
exc_info = remove_internal_traceback_frames_from_exc_info(exc_info)
4041
return cls(outcome=outcome, node=node, exc_info=exc_info)
4142

4243

src/_pytask/task_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from _pytask.mark import Mark
1313
from _pytask.models import CollectionMetadata
1414
from _pytask.shared import find_duplicates
15+
from _pytask.typing import is_task_function
1516

1617
if TYPE_CHECKING:
1718
from _pytask.tree_util import PyTree
@@ -123,7 +124,7 @@ def wrapper(func: Callable[..., Any]) -> None:
123124

124125
# In case the decorator is used without parentheses, wrap the function which is
125126
# passed as the first argument with the default arguments.
126-
if callable(name) and kwargs is None:
127+
if is_task_function(name) and kwargs is None:
127128
return task()(name)
128129
return wrapper
129130

src/_pytask/typing.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from __future__ import annotations
22

3+
import functools
4+
from typing import Any
5+
36
from attr import define
47

58

@@ -13,3 +16,9 @@ class ProductType:
1316

1417
Product = ProductType()
1518
"""ProductType: A singleton to mark products in annotations."""
19+
20+
21+
def is_task_function(func: Any) -> bool:
22+
return (callable(func) and hasattr(func, "__name__")) or (
23+
isinstance(func, functools.partial) and hasattr(func.func, "__name__")
24+
)

src/pytask/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from _pytask.click import ColoredGroup
88
from _pytask.click import EnumChoice
99
from _pytask.collect_utils import depends_on
10-
from _pytask.collect_utils import parse_nodes
10+
from _pytask.collect_utils import parse_dependencies_from_task_function
11+
from _pytask.collect_utils import parse_products_from_task_function
1112
from _pytask.collect_utils import produces
1213
from _pytask.compat import check_for_optional_program
1314
from _pytask.compat import import_optional_dependency
@@ -44,6 +45,7 @@
4445
from _pytask.nodes import PathNode
4546
from _pytask.nodes import PythonNode
4647
from _pytask.nodes import Task
48+
from _pytask.nodes import TaskWithoutPath
4749
from _pytask.outcomes import CollectionOutcome
4850
from _pytask.outcomes import count_outcomes
4951
from _pytask.outcomes import Exit
@@ -63,6 +65,7 @@
6365
from _pytask.traceback import remove_internal_traceback_frames_from_exc_info
6466
from _pytask.traceback import remove_traceback_from_exc_info
6567
from _pytask.traceback import render_exc_info
68+
from _pytask.typing import is_task_function
6669
from _pytask.typing import Product
6770
from _pytask.warnings_utils import parse_warning_filter
6871
from _pytask.warnings_utils import warning_record_to_str
@@ -115,6 +118,7 @@
115118
"State",
116119
"Task",
117120
"TaskOutcome",
121+
"TaskWithoutPath",
118122
"WarningReport",
119123
"__version__",
120124
"build",
@@ -131,8 +135,10 @@
131135
"has_mark",
132136
"hookimpl",
133137
"import_optional_dependency",
138+
"is_task_function",
134139
"mark",
135-
"parse_nodes",
140+
"parse_dependencies_from_task_function",
141+
"parse_products_from_task_function",
136142
"parse_warning_filter",
137143
"produces",
138144
"remove_internal_traceback_frames_from_exc_info",

tests/test_typing.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from __future__ import annotations
2+
3+
import functools
4+
5+
from _pytask.typing import is_task_function
6+
7+
8+
def test_is_task_function():
9+
def func():
10+
pass
11+
12+
assert is_task_function(func)
13+
14+
partialed_func = functools.partial(func)
15+
16+
assert is_task_function(partialed_func)
17+
18+
assert is_task_function(lambda x: x)
19+
20+
partialed_lambda = functools.partial(lambda x: x)
21+
22+
assert is_task_function(partialed_lambda)

0 commit comments

Comments
 (0)