Skip to content

Commit e2d9e52

Browse files
authored
Improve generic typehints in graph.py (#1139)
1 parent 35895ce commit e2d9e52

File tree

1 file changed

+34
-34
lines changed

1 file changed

+34
-34
lines changed

pydantic_graph/pydantic_graph/graph.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from contextlib import AbstractContextManager, ExitStack, asynccontextmanager
77
from dataclasses import dataclass, field
88
from functools import cached_property
9-
from typing import Any, Generic, TypeVar, cast
9+
from typing import Any, Generic, cast
1010

1111
import logfire_api
1212
import typing_extensions
@@ -34,9 +34,6 @@
3434

3535
_logfire = logfire_api.Logfire(otel_scope='pydantic-graph')
3636

37-
T = TypeVar('T')
38-
"""An invariant typevar."""
39-
4037

4138
@dataclass(init=False)
4239
class Graph(Generic[StateT, DepsT, RunEndT]):
@@ -121,15 +118,15 @@ def __init__(
121118
self._validate_edges()
122119

123120
async def run(
124-
self: Graph[StateT, DepsT, T],
125-
start_node: BaseNode[StateT, DepsT, T],
121+
self,
122+
start_node: BaseNode[StateT, DepsT, RunEndT],
126123
*,
127124
state: StateT = None,
128125
deps: DepsT = None,
129-
persistence: BaseStatePersistence[StateT, T] | None = None,
126+
persistence: BaseStatePersistence[StateT, RunEndT] | None = None,
130127
infer_name: bool = True,
131128
span: LogfireSpan | None = None,
132-
) -> GraphRunResult[StateT, T]:
129+
) -> GraphRunResult[StateT, RunEndT]:
133130
"""Run the graph from a starting node until it ends.
134131
135132
Args:
@@ -177,14 +174,14 @@ async def main():
177174
return final_result
178175

179176
def run_sync(
180-
self: Graph[StateT, DepsT, T],
181-
start_node: BaseNode[StateT, DepsT, T],
177+
self,
178+
start_node: BaseNode[StateT, DepsT, RunEndT],
182179
*,
183180
state: StateT = None,
184181
deps: DepsT = None,
185-
persistence: BaseStatePersistence[StateT, T] | None = None,
182+
persistence: BaseStatePersistence[StateT, RunEndT] | None = None,
186183
infer_name: bool = True,
187-
) -> GraphRunResult[StateT, T]:
184+
) -> GraphRunResult[StateT, RunEndT]:
188185
"""Synchronously run the graph.
189186
190187
This is a convenience method that wraps [`self.run`][pydantic_graph.Graph.run] with `loop.run_until_complete(...)`.
@@ -211,15 +208,15 @@ def run_sync(
211208

212209
@asynccontextmanager
213210
async def iter(
214-
self: Graph[StateT, DepsT, T],
215-
start_node: BaseNode[StateT, DepsT, T],
211+
self,
212+
start_node: BaseNode[StateT, DepsT, RunEndT],
216213
*,
217214
state: StateT = None,
218215
deps: DepsT = None,
219-
persistence: BaseStatePersistence[StateT, T] | None = None,
216+
persistence: BaseStatePersistence[StateT, RunEndT] | None = None,
220217
span: AbstractContextManager[Any] | None = None,
221218
infer_name: bool = True,
222-
) -> AsyncIterator[GraphRun[StateT, DepsT, T]]:
219+
) -> AsyncIterator[GraphRun[StateT, DepsT, RunEndT]]:
223220
"""A contextmanager which can be used to iterate over the graph's nodes as they are executed.
224221
225222
This method returns a `GraphRun` object which can be used to async-iterate over the nodes of this `Graph` as
@@ -261,19 +258,19 @@ async def iter(
261258
with ExitStack() as stack:
262259
if span is not None:
263260
stack.enter_context(span)
264-
yield GraphRun[StateT, DepsT, T](
261+
yield GraphRun[StateT, DepsT, RunEndT](
265262
graph=self, start_node=start_node, persistence=persistence, state=state, deps=deps
266263
)
267264

268265
@asynccontextmanager
269266
async def iter_from_persistence(
270-
self: Graph[StateT, DepsT, T],
271-
persistence: BaseStatePersistence[StateT, T],
267+
self,
268+
persistence: BaseStatePersistence[StateT, RunEndT],
272269
*,
273270
deps: DepsT = None,
274271
span: AbstractContextManager[Any] | None = None,
275272
infer_name: bool = True,
276-
) -> AsyncIterator[GraphRun[StateT, DepsT, T]]:
273+
) -> AsyncIterator[GraphRun[StateT, DepsT, RunEndT]]:
277274
"""A contextmanager to iterate over the graph's nodes as they are executed, created from a persistence object.
278275
279276
This method has similar functionality to [`iter`][pydantic_graph.graph.Graph.iter],
@@ -306,7 +303,7 @@ async def iter_from_persistence(
306303
with ExitStack() as stack:
307304
if span is not None:
308305
stack.enter_context(span)
309-
yield GraphRun[StateT, DepsT, T](
306+
yield GraphRun[StateT, DepsT, RunEndT](
310307
graph=self,
311308
start_node=snapshot.node,
312309
persistence=persistence,
@@ -316,9 +313,9 @@ async def iter_from_persistence(
316313
)
317314

318315
async def initialize(
319-
self: Graph[StateT, DepsT, T],
320-
node: BaseNode[StateT, DepsT, T],
321-
persistence: BaseStatePersistence[StateT, T],
316+
self,
317+
node: BaseNode[StateT, DepsT, RunEndT],
318+
persistence: BaseStatePersistence[StateT, RunEndT],
322319
*,
323320
state: StateT = None,
324321
infer_name: bool = True,
@@ -342,14 +339,14 @@ async def initialize(
342339

343340
@deprecated('`next` is deprecated, use `async with graph.iter(...) as run: run.next()` instead')
344341
async def next(
345-
self: Graph[StateT, DepsT, T],
346-
node: BaseNode[StateT, DepsT, T],
347-
persistence: BaseStatePersistence[StateT, T],
342+
self,
343+
node: BaseNode[StateT, DepsT, RunEndT],
344+
persistence: BaseStatePersistence[StateT, RunEndT],
348345
*,
349346
state: StateT = None,
350347
deps: DepsT = None,
351348
infer_name: bool = True,
352-
) -> BaseNode[StateT, DepsT, Any] | End[T]:
349+
) -> BaseNode[StateT, DepsT, Any] | End[RunEndT]:
353350
"""Run a node in the graph and return the next node to run.
354351
355352
Args:
@@ -367,7 +364,7 @@ async def next(
367364
self._infer_name(inspect.currentframe())
368365

369366
persistence.set_graph_types(self)
370-
run = GraphRun[StateT, DepsT, T](
367+
run = GraphRun[StateT, DepsT, RunEndT](
371368
graph=self,
372369
start_node=node,
373370
persistence=persistence,
@@ -537,8 +534,8 @@ def inferred_types(self) -> tuple[type[StateT], type[RunEndT]]:
537534
return state_type, run_end_type # pyright: ignore[reportReturnType]
538535

539536
def _register_node(
540-
self: Graph[StateT, DepsT, T],
541-
node: type[BaseNode[StateT, DepsT, T]],
537+
self,
538+
node: type[BaseNode[StateT, DepsT, RunEndT]],
542539
parent_namespace: dict[str, Any] | None,
543540
) -> None:
544541
node_id = node.get_node_id()
@@ -689,8 +686,8 @@ def result(self) -> GraphRunResult[StateT, RunEndT] | None:
689686
)
690687

691688
async def next(
692-
self: GraphRun[StateT, DepsT, T], node: BaseNode[StateT, DepsT, T] | None = None
693-
) -> BaseNode[StateT, DepsT, T] | End[T]:
689+
self, node: BaseNode[StateT, DepsT, RunEndT] | None = None
690+
) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]:
694691
"""Manually drive the graph run by passing in the node you want to run next.
695692
696693
This lets you inspect or mutate the node before continuing execution, or skip certain nodes
@@ -733,7 +730,10 @@ async def main():
733730
the run has completed.
734731
"""
735732
if node is None:
736-
node = cast(BaseNode[StateT, DepsT, T], self._next_node)
733+
# This cast is necessary because self._next_node could be an `End`. You'll get a runtime error if that's
734+
# the case, but if it is, the only way to get there would be to have tried calling next manually after
735+
# the run finished. Either way, maybe it would be better to not do this cast...
736+
node = cast(BaseNode[StateT, DepsT, RunEndT], self._next_node)
737737
node_snapshot_id = node.get_snapshot_id()
738738
else:
739739
node_snapshot_id = node.get_snapshot_id()

0 commit comments

Comments
 (0)