Skip to content

Commit 21df97c

Browse files
ordeq: narrow type hint on node factory (#474)
1 parent 534579f commit 21df97c

File tree

2 files changed

+42
-35
lines changed

2 files changed

+42
-35
lines changed

packages/ordeq/src/ordeq/_nodes.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -297,12 +297,12 @@ def _is_node(obj: object) -> TypeGuard[Node]:
297297
def create_node(
298298
func: Callable[FuncParams, FuncReturns],
299299
*,
300-
inputs: Sequence[Input | Callable] | Input | Callable | None = None,
300+
inputs: Sequence[Input | Node] | Input | Node | None = None,
301301
outputs: Sequence[Output] | Output | None = None,
302-
checks: Sequence[Input | Output | Callable]
302+
checks: Sequence[Input | Output | Node]
303303
| Input
304304
| Output
305-
| Callable
305+
| Node
306306
| ResourceType
307307
| None = None,
308308
attributes: dict[str, Any] | None = None,
@@ -315,12 +315,12 @@ def create_node(
315315
def create_node(
316316
func: Callable[FuncParams, FuncReturns],
317317
*,
318-
inputs: Sequence[Input | Callable] | Input | Callable | None = None,
318+
inputs: Sequence[Input | Node] | Input | Node | None = None,
319319
outputs: None = None,
320-
checks: Sequence[Input | Output | Callable]
320+
checks: Sequence[Input | Output | Node]
321321
| Input
322322
| Output
323-
| Callable
323+
| Node
324324
| ResourceType
325325
| None = None,
326326
attributes: dict[str, Any] | None = None,
@@ -332,12 +332,12 @@ def create_node(
332332
def create_node(
333333
func: Callable[FuncParams, FuncReturns],
334334
*,
335-
inputs: Sequence[Input | Callable] | Input | Callable | None = None,
335+
inputs: Sequence[Input | Node] | Input | Node | None = None,
336336
outputs: Sequence[Output] | Output | None = None,
337-
checks: Sequence[Input | Output | Callable]
337+
checks: Sequence[Input | Output | Node]
338338
| Input
339339
| Output
340-
| Callable
340+
| Node
341341
| ResourceType
342342
| None = None,
343343
attributes: dict[str, Any] | None = None,
@@ -432,19 +432,22 @@ def create_node(
432432

433433
# Default value for 'func' in case it is not passed.
434434
# Used to distinguish between 'func=None' and func missing as positional arg.
435-
def not_passed(*args, **kwargs): ...
435+
def _not_passed(*args, **kwargs): ...
436+
437+
438+
not_passed = cast("Node", _not_passed)
436439

437440

438441
@overload
439442
def node(
440443
func: Callable[FuncParams, FuncReturns],
441444
*,
442-
inputs: Sequence[Input | Callable] | Input | Callable | None = None,
445+
inputs: Sequence[Input | Node] | Input | Node | None = None,
443446
outputs: Sequence[Output] | Output | None = None,
444-
checks: Sequence[Input | Output | Callable]
447+
checks: Sequence[Input | Output | Node]
445448
| Input
446449
| Output
447-
| Callable
450+
| Node
448451
| ResourceType
449452
| None = None,
450453
**attributes: Any,
@@ -454,12 +457,12 @@ def node(
454457
@overload
455458
def node(
456459
*,
457-
inputs: Sequence[Input | Callable] | Input | Callable = not_passed,
460+
inputs: Sequence[Input | Node] | Input | Node = not_passed,
458461
outputs: Sequence[Output] | Output | None = None,
459-
checks: Sequence[Input | Output | Callable]
462+
checks: Sequence[Input | Output | Node]
460463
| Input
461464
| Output
462-
| Callable
465+
| Node
463466
| ResourceType
464467
| None = None,
465468
**attributes: Any,
@@ -471,12 +474,12 @@ def node(
471474
def node(
472475
func: Callable[FuncParams, FuncReturns] = not_passed,
473476
*,
474-
inputs: Sequence[Input | Callable] | Input | Callable | None = None,
477+
inputs: Sequence[Input | Node] | Input | Node | None = None,
475478
outputs: Sequence[Output] | Output | None = None,
476-
checks: Sequence[Input | Output | Callable]
479+
checks: Sequence[Input | Output | Node]
477480
| Input
478481
| Output
479-
| Callable
482+
| Node
480483
| ResourceType
481484
| None = None,
482485
**attributes: Any,

0 commit comments

Comments
 (0)