Skip to content

Commit 6ec7e14

Browse files
committed
refactor(langgraph): combine context managers
1 parent f7d3e33 commit 6ec7e14

File tree

2 files changed

+43
-25
lines changed

2 files changed

+43
-25
lines changed

backend/api/core/dependencies.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
from contextlib import asynccontextmanager
22
from typing import Annotated, AsyncGenerator
33

4+
import psycopg.errors
45
from fastapi import Depends
6+
from langchain_mcp_adapters.tools import load_mcp_tools
57
from langchain_openai import ChatOpenAI
8+
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
69
from mcp import ClientSession
710
from mcp.client.sse import sse_client
11+
from psycopg_pool import AsyncConnectionPool
812
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
913

1014
from api.core.config import settings
15+
from api.core.logs import uvicorn
1116

1217

1318
def get_llm() -> ChatOpenAI:
@@ -42,3 +47,30 @@ async def mcp_sse_client() -> AsyncGenerator[ClientSession, None]:
4247
async with ClientSession(read_stream, write_stream) as session:
4348
await session.initialize()
4449
yield session
50+
51+
52+
async def checkpointer_setup(pool):
    """Build an ``AsyncPostgresSaver`` over *pool* and run its one-time schema setup.

    Setup failures that indicate the checkpoint schema already exists
    (``DuplicateColumn`` / ``ActiveSqlTransaction``) are logged and swallowed,
    so repeated application start-ups stay idempotent. Any other exception
    propagates to the caller.
    """
    saver = AsyncPostgresSaver(pool)
    try:
        await saver.setup()
    except (psycopg.errors.DuplicateColumn, psycopg.errors.ActiveSqlTransaction):
        # A previous start-up already created the tables — nothing left to do.
        uvicorn.warning("Skipping checkpointer setup — already configured.")
    return saver
62+
63+
64+
@asynccontextmanager
async def setup_graph():
    """Async context manager yielding ``(checkpointer, tools, session)``.

    Opens a psycopg connection pool for the LangGraph checkpointer and an
    MCP SSE session for tool loading, tearing both down on exit.
    """
    # NOTE: LangGraph AsyncPostgresSaver does not support SQLAlchemy ORM
    # Connections, so a compatible psycopg connection pool is opened here
    # specifically for the checkpointer.
    async with AsyncConnectionPool(
        conninfo=settings.checkpoint_conn_str,
        kwargs={"prepare_threshold": None},
    ) as conn_pool:
        saver = await checkpointer_setup(conn_pool)

        async with mcp_sse_client() as mcp_session:
            mcp_tools = await load_mcp_tools(mcp_session)
            yield saver, mcp_tools, mcp_session

backend/api/routers/llms.py

Lines changed: 11 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,13 @@
33
import psycopg.errors
44
from fastapi import APIRouter
55
from langchain_core.messages import HumanMessage
6-
from langchain_mcp_adapters.tools import load_mcp_tools
76
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
8-
from psycopg_pool import AsyncConnectionPool
97
from sse_starlette.sse import EventSourceResponse
108
from starlette.responses import Response
119

1210
from api.core.agent.orchestration import get_config, get_graph
13-
from api.core.config import settings
14-
from api.core.dependencies import LLMDep
11+
from api.core.dependencies import LLMDep, setup_graph
1512
from api.core.logs import print, uvicorn
16-
from api.routers.mcps import mcp_sse_client
1713

1814
router = APIRouter(tags=["chat"])
1915

@@ -62,23 +58,13 @@ async def stream_graph(
6258
query: str,
6359
llm: LLMDep,
6460
) -> AsyncGenerator[dict[str, str], None]:
65-
# NOTE: LangGraph AsyncPostgresSaver does not support SQLAlchemy ORM Connections.
66-
# A compatible psycopg connection is created via the connection pool to connect to the checkpointer.
67-
async with AsyncConnectionPool(
68-
conninfo=settings.checkpoint_conn_str,
69-
kwargs=dict(prepare_threshold=None),
70-
) as pool:
71-
checkpointer = await checkpointer_setup(pool)
72-
73-
async with mcp_sse_client() as session:
74-
tools = await load_mcp_tools(session)
75-
graph = get_graph(llm, tools=tools, checkpointer=checkpointer)
76-
config = get_config()
77-
events = dict(messages=[HumanMessage(content=query)])
78-
79-
async for event in graph.astream_events(
80-
events, config, version="v2"
81-
):
82-
if event.get("event").endswith("end"):
83-
print(event)
84-
yield dict(data=event)
61+
async with setup_graph() as resources:
62+
checkpointer, tools, _ = resources
63+
graph = get_graph(llm, tools=tools, checkpointer=checkpointer)
64+
config = get_config()
65+
events = dict(messages=[HumanMessage(content=query)])
66+
67+
async for event in graph.astream_events(events, config, version="v2"):
68+
if event.get("event").endswith("end"):
69+
print(event)
70+
yield dict(data=event)

0 commit comments

Comments
 (0)