@@ -297,12 +297,12 @@ def _is_node(obj: object) -> TypeGuard[Node]:
297297def create_node (
298298 func : Callable [FuncParams , FuncReturns ],
299299 * ,
300- inputs : Sequence [Input | Callable ] | Input | Callable | None = None ,
300+ inputs : Sequence [Input | Node ] | Input | Node | None = None ,
301301 outputs : Sequence [Output ] | Output | None = None ,
302- checks : Sequence [Input | Output | Callable ]
302+ checks : Sequence [Input | Output | Node ]
303303 | Input
304304 | Output
305- | Callable
305+ | Node
306306 | ResourceType
307307 | None = None ,
308308 attributes : dict [str , Any ] | None = None ,
@@ -315,12 +315,12 @@ def create_node(
315315def create_node (
316316 func : Callable [FuncParams , FuncReturns ],
317317 * ,
318- inputs : Sequence [Input | Callable ] | Input | Callable | None = None ,
318+ inputs : Sequence [Input | Node ] | Input | Node | None = None ,
319319 outputs : None = None ,
320- checks : Sequence [Input | Output | Callable ]
320+ checks : Sequence [Input | Output | Node ]
321321 | Input
322322 | Output
323- | Callable
323+ | Node
324324 | ResourceType
325325 | None = None ,
326326 attributes : dict [str , Any ] | None = None ,
@@ -332,12 +332,12 @@ def create_node(
332332def create_node (
333333 func : Callable [FuncParams , FuncReturns ],
334334 * ,
335- inputs : Sequence [Input | Callable ] | Input | Callable | None = None ,
335+ inputs : Sequence [Input | Node ] | Input | Node | None = None ,
336336 outputs : Sequence [Output ] | Output | None = None ,
337- checks : Sequence [Input | Output | Callable ]
337+ checks : Sequence [Input | Output | Node ]
338338 | Input
339339 | Output
340- | Callable
340+ | Node
341341 | ResourceType
342342 | None = None ,
343343 attributes : dict [str , Any ] | None = None ,
@@ -432,19 +432,22 @@ def create_node(
432432
433433# Default value for 'func' in case it is not passed.
434434# Used to distinguish between 'func=None' and func missing as positional arg.
435- def not_passed (* args , ** kwargs ): ...
435+ def _not_passed (* args , ** kwargs ): ...
436+
437+
438+ not_passed = cast ("Node" , _not_passed )
436439
437440
438441@overload
439442def node (
440443 func : Callable [FuncParams , FuncReturns ],
441444 * ,
442- inputs : Sequence [Input | Callable ] | Input | Callable | None = None ,
445+ inputs : Sequence [Input | Node ] | Input | Node | None = None ,
443446 outputs : Sequence [Output ] | Output | None = None ,
444- checks : Sequence [Input | Output | Callable ]
447+ checks : Sequence [Input | Output | Node ]
445448 | Input
446449 | Output
447- | Callable
450+ | Node
448451 | ResourceType
449452 | None = None ,
450453 ** attributes : Any ,
@@ -454,12 +457,12 @@ def node(
454457@overload
455458def node (
456459 * ,
457- inputs : Sequence [Input | Callable ] | Input | Callable = not_passed ,
460+ inputs : Sequence [Input | Node ] | Input | Node = not_passed ,
458461 outputs : Sequence [Output ] | Output | None = None ,
459- checks : Sequence [Input | Output | Callable ]
462+ checks : Sequence [Input | Output | Node ]
460463 | Input
461464 | Output
462- | Callable
465+ | Node
463466 | ResourceType
464467 | None = None ,
465468 ** attributes : Any ,
@@ -471,12 +474,12 @@ def node(
471474def node (
472475 func : Callable [FuncParams , FuncReturns ] = not_passed ,
473476 * ,
474- inputs : Sequence [Input | Callable ] | Input | Callable | None = None ,
477+ inputs : Sequence [Input | Node ] | Input | Node | None = None ,
475478 outputs : Sequence [Output ] | Output | None = None ,
476- checks : Sequence [Input | Output | Callable ]
479+ checks : Sequence [Input | Output | Node ]
477480 | Input
478481 | Output
479- | Callable
482+ | Node
480483 | ResourceType
481484 | None = None ,
482485 ** attributes : Any ,
0 commit comments