6
6
from contextlib import AbstractContextManager , ExitStack , asynccontextmanager
7
7
from dataclasses import dataclass , field
8
8
from functools import cached_property
9
- from typing import Any , Generic , TypeVar , cast
9
+ from typing import Any , Generic , cast
10
10
11
11
import logfire_api
12
12
import typing_extensions
34
34
35
35
_logfire = logfire_api .Logfire (otel_scope = 'pydantic-graph' )
36
36
37
- T = TypeVar ('T' )
38
- """An invariant typevar."""
39
-
40
37
41
38
@dataclass (init = False )
42
39
class Graph (Generic [StateT , DepsT , RunEndT ]):
@@ -121,15 +118,15 @@ def __init__(
121
118
self ._validate_edges ()
122
119
123
120
async def run (
124
- self : Graph [ StateT , DepsT , T ] ,
125
- start_node : BaseNode [StateT , DepsT , T ],
121
+ self ,
122
+ start_node : BaseNode [StateT , DepsT , RunEndT ],
126
123
* ,
127
124
state : StateT = None ,
128
125
deps : DepsT = None ,
129
- persistence : BaseStatePersistence [StateT , T ] | None = None ,
126
+ persistence : BaseStatePersistence [StateT , RunEndT ] | None = None ,
130
127
infer_name : bool = True ,
131
128
span : LogfireSpan | None = None ,
132
- ) -> GraphRunResult [StateT , T ]:
129
+ ) -> GraphRunResult [StateT , RunEndT ]:
133
130
"""Run the graph from a starting node until it ends.
134
131
135
132
Args:
@@ -177,14 +174,14 @@ async def main():
177
174
return final_result
178
175
179
176
def run_sync (
180
- self : Graph [ StateT , DepsT , T ] ,
181
- start_node : BaseNode [StateT , DepsT , T ],
177
+ self ,
178
+ start_node : BaseNode [StateT , DepsT , RunEndT ],
182
179
* ,
183
180
state : StateT = None ,
184
181
deps : DepsT = None ,
185
- persistence : BaseStatePersistence [StateT , T ] | None = None ,
182
+ persistence : BaseStatePersistence [StateT , RunEndT ] | None = None ,
186
183
infer_name : bool = True ,
187
- ) -> GraphRunResult [StateT , T ]:
184
+ ) -> GraphRunResult [StateT , RunEndT ]:
188
185
"""Synchronously run the graph.
189
186
190
187
This is a convenience method that wraps [`self.run`][pydantic_graph.Graph.run] with `loop.run_until_complete(...)`.
@@ -211,15 +208,15 @@ def run_sync(
211
208
212
209
@asynccontextmanager
213
210
async def iter (
214
- self : Graph [ StateT , DepsT , T ] ,
215
- start_node : BaseNode [StateT , DepsT , T ],
211
+ self ,
212
+ start_node : BaseNode [StateT , DepsT , RunEndT ],
216
213
* ,
217
214
state : StateT = None ,
218
215
deps : DepsT = None ,
219
- persistence : BaseStatePersistence [StateT , T ] | None = None ,
216
+ persistence : BaseStatePersistence [StateT , RunEndT ] | None = None ,
220
217
span : AbstractContextManager [Any ] | None = None ,
221
218
infer_name : bool = True ,
222
- ) -> AsyncIterator [GraphRun [StateT , DepsT , T ]]:
219
+ ) -> AsyncIterator [GraphRun [StateT , DepsT , RunEndT ]]:
223
220
"""A contextmanager which can be used to iterate over the graph's nodes as they are executed.
224
221
225
222
This method returns a `GraphRun` object which can be used to async-iterate over the nodes of this `Graph` as
@@ -261,19 +258,19 @@ async def iter(
261
258
with ExitStack () as stack :
262
259
if span is not None :
263
260
stack .enter_context (span )
264
- yield GraphRun [StateT , DepsT , T ](
261
+ yield GraphRun [StateT , DepsT , RunEndT ](
265
262
graph = self , start_node = start_node , persistence = persistence , state = state , deps = deps
266
263
)
267
264
268
265
@asynccontextmanager
269
266
async def iter_from_persistence (
270
- self : Graph [ StateT , DepsT , T ] ,
271
- persistence : BaseStatePersistence [StateT , T ],
267
+ self ,
268
+ persistence : BaseStatePersistence [StateT , RunEndT ],
272
269
* ,
273
270
deps : DepsT = None ,
274
271
span : AbstractContextManager [Any ] | None = None ,
275
272
infer_name : bool = True ,
276
- ) -> AsyncIterator [GraphRun [StateT , DepsT , T ]]:
273
+ ) -> AsyncIterator [GraphRun [StateT , DepsT , RunEndT ]]:
277
274
"""A contextmanager to iterate over the graph's nodes as they are executed, created from a persistence object.
278
275
279
276
This method has similar functionality to [`iter`][pydantic_graph.graph.Graph.iter],
@@ -306,7 +303,7 @@ async def iter_from_persistence(
306
303
with ExitStack () as stack :
307
304
if span is not None :
308
305
stack .enter_context (span )
309
- yield GraphRun [StateT , DepsT , T ](
306
+ yield GraphRun [StateT , DepsT , RunEndT ](
310
307
graph = self ,
311
308
start_node = snapshot .node ,
312
309
persistence = persistence ,
@@ -316,9 +313,9 @@ async def iter_from_persistence(
316
313
)
317
314
318
315
async def initialize (
319
- self : Graph [ StateT , DepsT , T ] ,
320
- node : BaseNode [StateT , DepsT , T ],
321
- persistence : BaseStatePersistence [StateT , T ],
316
+ self ,
317
+ node : BaseNode [StateT , DepsT , RunEndT ],
318
+ persistence : BaseStatePersistence [StateT , RunEndT ],
322
319
* ,
323
320
state : StateT = None ,
324
321
infer_name : bool = True ,
@@ -342,14 +339,14 @@ async def initialize(
342
339
343
340
@deprecated ('`next` is deprecated, use `async with graph.iter(...) as run: run.next()` instead' )
344
341
async def next (
345
- self : Graph [ StateT , DepsT , T ] ,
346
- node : BaseNode [StateT , DepsT , T ],
347
- persistence : BaseStatePersistence [StateT , T ],
342
+ self ,
343
+ node : BaseNode [StateT , DepsT , RunEndT ],
344
+ persistence : BaseStatePersistence [StateT , RunEndT ],
348
345
* ,
349
346
state : StateT = None ,
350
347
deps : DepsT = None ,
351
348
infer_name : bool = True ,
352
- ) -> BaseNode [StateT , DepsT , Any ] | End [T ]:
349
+ ) -> BaseNode [StateT , DepsT , Any ] | End [RunEndT ]:
353
350
"""Run a node in the graph and return the next node to run.
354
351
355
352
Args:
@@ -367,7 +364,7 @@ async def next(
367
364
self ._infer_name (inspect .currentframe ())
368
365
369
366
persistence .set_graph_types (self )
370
- run = GraphRun [StateT , DepsT , T ](
367
+ run = GraphRun [StateT , DepsT , RunEndT ](
371
368
graph = self ,
372
369
start_node = node ,
373
370
persistence = persistence ,
@@ -537,8 +534,8 @@ def inferred_types(self) -> tuple[type[StateT], type[RunEndT]]:
537
534
return state_type , run_end_type # pyright: ignore[reportReturnType]
538
535
539
536
def _register_node (
540
- self : Graph [ StateT , DepsT , T ] ,
541
- node : type [BaseNode [StateT , DepsT , T ]],
537
+ self ,
538
+ node : type [BaseNode [StateT , DepsT , RunEndT ]],
542
539
parent_namespace : dict [str , Any ] | None ,
543
540
) -> None :
544
541
node_id = node .get_node_id ()
@@ -689,8 +686,8 @@ def result(self) -> GraphRunResult[StateT, RunEndT] | None:
689
686
)
690
687
691
688
async def next (
692
- self : GraphRun [ StateT , DepsT , T ], node : BaseNode [StateT , DepsT , T ] | None = None
693
- ) -> BaseNode [StateT , DepsT , T ] | End [T ]:
689
+ self , node : BaseNode [StateT , DepsT , RunEndT ] | None = None
690
+ ) -> BaseNode [StateT , DepsT , RunEndT ] | End [RunEndT ]:
694
691
"""Manually drive the graph run by passing in the node you want to run next.
695
692
696
693
This lets you inspect or mutate the node before continuing execution, or skip certain nodes
@@ -733,7 +730,10 @@ async def main():
733
730
the run has completed.
734
731
"""
735
732
if node is None :
736
- node = cast (BaseNode [StateT , DepsT , T ], self ._next_node )
733
+ # This cast is necessary because self._next_node could be an `End`. You'll get a runtime error if that's
734
+ # the case, but if it is, the only way to get there would be to have tried calling next manually after
735
+ # the run finished. Either way, maybe it would be better to not do this cast...
736
+ node = cast (BaseNode [StateT , DepsT , RunEndT ], self ._next_node )
737
737
node_snapshot_id = node .get_snapshot_id ()
738
738
else :
739
739
node_snapshot_id = node .get_snapshot_id ()
0 commit comments