Skip to content

Commit 76975c3

Browse files
committed
ordeq: extend type hint for node factories
1 parent 21df97c commit 76975c3

File tree

6 files changed

+156
-28
lines changed

6 files changed

+156
-28
lines changed

packages/ordeq/src/ordeq/_nodes.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ def create_node(
298298
func: Callable[FuncParams, FuncReturns],
299299
*,
300300
inputs: Sequence[Input | Node] | Input | Node | None = None,
301-
outputs: Sequence[Output] | Output | None = None,
301+
outputs: Sequence[Output] | Output,
302302
checks: Sequence[Input | Output | Node]
303303
| Input
304304
| Output
@@ -435,15 +435,15 @@ def create_node(
435435
def _not_passed(*args, **kwargs): ...
436436

437437

438-
not_passed = cast("Node", _not_passed)
438+
not_passed = cast("View", _not_passed)
439439

440440

441441
@overload
442442
def node(
443443
func: Callable[FuncParams, FuncReturns],
444444
*,
445-
inputs: Sequence[Input | Node] | Input | Node | None = None,
446-
outputs: Sequence[Output] | Output | None = None,
445+
inputs: Sequence[Input | View] | Input | View | None = None,
446+
outputs: Sequence[Output] | Output,
447447
checks: Sequence[Input | Output | Node]
448448
| Input
449449
| Output
@@ -456,9 +456,25 @@ def node(
456456

457457
@overload
458458
def node(
459+
func: Callable[FuncParams, FuncReturns],
459460
*,
460-
inputs: Sequence[Input | Node] | Input | Node = not_passed,
461-
outputs: Sequence[Output] | Output | None = None,
461+
inputs: Sequence[Input | View] | Input | View | None = None,
462+
outputs: None = None,
463+
checks: Sequence[Input | Output | Node]
464+
| Input
465+
| Output
466+
| Node
467+
| ResourceType
468+
| None = None,
469+
**attributes: Any,
470+
) -> View[FuncParams, FuncReturns]: ...
471+
472+
473+
@overload
474+
def node(
475+
*,
476+
inputs: Sequence[Input | View] | Input | View = not_passed,
477+
outputs: Sequence[Output] | Output,
462478
checks: Sequence[Input | Output | Node]
463479
| Input
464480
| Output
@@ -471,10 +487,27 @@ def node(
471487
]: ...
472488

473489

490+
@overload
474491
def node(
475-
func: Callable[FuncParams, FuncReturns] = not_passed,
476492
*,
477-
inputs: Sequence[Input | Node] | Input | Node | None = None,
493+
inputs: Sequence[Input | View] | Input | View = not_passed,
494+
outputs: None = None,
495+
checks: Sequence[Input | Output | Node]
496+
| Input
497+
| Output
498+
| Node
499+
| ResourceType
500+
| None = None,
501+
**attributes: Any,
502+
) -> Callable[
503+
[Callable[FuncParams, FuncReturns]], View[FuncParams, FuncReturns]
504+
]: ...
505+
506+
507+
def node(
508+
func: Callable[FuncParams, FuncReturns] = _not_passed,
509+
*,
510+
inputs: Sequence[Input | View] | Input | View | None = None,
478511
outputs: Sequence[Output] | Output | None = None,
479512
checks: Sequence[Input | Output | Node]
480513
| Input
@@ -485,9 +518,11 @@ def node(
485518
**attributes: Any,
486519
) -> (
487520
Callable[
488-
[Callable[FuncParams, FuncReturns]], Node[FuncParams, FuncReturns]
521+
[Callable[FuncParams, FuncReturns]],
522+
Node[FuncParams, FuncReturns] | View[FuncParams, FuncReturns],
489523
]
490524
| Node[FuncParams, FuncReturns]
525+
| View[FuncParams, FuncReturns]
491526
):
492527
"""Decorator that creates a node from a function. When a node is run,
493528
the inputs are loaded and passed to the function. The returned values

packages/ordeq/tests/resources/nodes/node_type_hints.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import inspect
2+
from typing_extensions import reveal_type
23

34
from ordeq import node
45
from ordeq._nodes import _is_node
@@ -13,6 +14,7 @@ def func(x: str, y: str) -> tuple[str, str]:
1314
return f"{x} + {y}", y
1415

1516

17+
reveal_type(func)
1618
print(type(func))
1719
print(func)
1820
print(inspect.get_annotations(func))
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import inspect
2+
from typing_extensions import reveal_type
3+
4+
from ordeq import node
5+
from ordeq._nodes import _is_node
6+
from ordeq_common import StringBuffer
7+
8+
9+
@node(inputs=[StringBuffer("x"), StringBuffer("y")])
10+
def func(x: str, y: str) -> tuple[str, str]:
11+
return f"{x} + {y}", y
12+
13+
14+
reveal_type(func)
15+
print(type(func))
16+
print(func)
17+
print(inspect.get_annotations(func))
18+
print(_is_node(func))

packages/ordeq/tests/snapshots/nodes/node_type_hints.snapshot.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
```python
44
import inspect
5+
from typing_extensions import reveal_type
56

67
from ordeq import node
78
from ordeq._nodes import _is_node
@@ -16,6 +17,7 @@ def func(x: str, y: str) -> tuple[str, str]:
1617
return f"{x} + {y}", y
1718

1819

20+
reveal_type(func)
1921
print(type(func))
2022
print(func)
2123
print(inspect.get_annotations(func))
@@ -31,4 +33,11 @@ node 'func' in module '__main__'
3133
{'x': <class 'str'>, 'y': <class 'str'>, 'return': tuple[str, str]}
3234
True
3335
36+
```
37+
38+
## Error
39+
40+
```text
41+
Runtime type is 'Node'
42+
3443
```

0 commit comments

Comments
 (0)