|
3 | 3 | import psycopg.errors |
4 | 4 | from fastapi import APIRouter |
5 | 5 | from langchain_core.messages import HumanMessage |
6 | | -from langchain_mcp_adapters.tools import load_mcp_tools |
7 | 6 | from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver |
8 | | -from psycopg_pool import AsyncConnectionPool |
9 | 7 | from sse_starlette.sse import EventSourceResponse |
10 | 8 | from starlette.responses import Response |
11 | 9 |
|
12 | 10 | from api.core.agent.orchestration import get_config, get_graph |
13 | | -from api.core.config import settings |
14 | | -from api.core.dependencies import LLMDep |
| 11 | +from api.core.dependencies import LLMDep, setup_graph |
15 | 12 | from api.core.logs import print, uvicorn |
16 | | -from api.routers.mcps import mcp_sse_client |
17 | 13 |
|
18 | 14 | router = APIRouter(tags=["chat"]) |
19 | 15 |
|
@@ -62,23 +58,13 @@ async def stream_graph( |
62 | 58 | query: str, |
63 | 59 | llm: LLMDep, |
64 | 60 | ) -> AsyncGenerator[dict[str, str], None]: |
65 | | - # NOTE: LangGraph AsyncPostgresSaver does not support SQLAlchemy ORM Connections. |
66 | | - # A compatible psycopg connection is created via the connection pool to connect to the checkpointer. |
67 | | - async with AsyncConnectionPool( |
68 | | - conninfo=settings.checkpoint_conn_str, |
69 | | - kwargs=dict(prepare_threshold=None), |
70 | | - ) as pool: |
71 | | - checkpointer = await checkpointer_setup(pool) |
72 | | - |
73 | | - async with mcp_sse_client() as session: |
74 | | - tools = await load_mcp_tools(session) |
75 | | - graph = get_graph(llm, tools=tools, checkpointer=checkpointer) |
76 | | - config = get_config() |
77 | | - events = dict(messages=[HumanMessage(content=query)]) |
78 | | - |
79 | | - async for event in graph.astream_events( |
80 | | - events, config, version="v2" |
81 | | - ): |
82 | | - if event.get("event").endswith("end"): |
83 | | - print(event) |
84 | | - yield dict(data=event) |
| 61 | + async with setup_graph() as resources: |
| 62 | + checkpointer, tools, _ = resources |
| 63 | + graph = get_graph(llm, tools=tools, checkpointer=checkpointer) |
| 64 | + config = get_config() |
| 65 | + events = dict(messages=[HumanMessage(content=query)]) |
| 66 | + |
| 67 | + async for event in graph.astream_events(events, config, version="v2"): |
| 68 | + if event.get("event").endswith("end"): |
| 69 | + print(event) |
| 70 | + yield dict(data=event) |
0 commit comments