Skip to content

Commit 06dba06

Browse files
committed
refactor(langgraph): pydantic typehint and validation for context managers
1 parent 6ec7e14 commit 06dba06

File tree

6 files changed

+99
-43
lines changed

6 files changed

+99
-43
lines changed
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from contextlib import asynccontextmanager
2+
3+
import psycopg
4+
import psycopg.errors
5+
import uvicorn
6+
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
7+
from psycopg_pool import AsyncConnectionPool
8+
9+
from api.core.logs import uvicorn
10+
11+
12+
@asynccontextmanager
13+
async def checkpointer_context(conn_str: str):
14+
"""
15+
Async context manager that sets up and yields a LangGraph checkpointer.
16+
17+
Uses a psycopg async connection pool to initialize AsyncPostgresSaver.
18+
Skips setup if checkpointer is already configured.
19+
20+
Args:
21+
conn_str (str): PostgreSQL connection string.
22+
23+
Yields:
24+
AsyncPostgresSaver: The initialized checkpointer.
25+
"""
26+
# NOTE: LangGraph AsyncPostgresSaver does not support SQLAlchemy ORM Connections.
27+
# A compatible psycopg connection is created via the connection pool to connect to the checkpointer.
28+
async with AsyncConnectionPool(
29+
conninfo=conn_str,
30+
kwargs=dict(prepare_threshold=None),
31+
) as pool:
32+
checkpointer = AsyncPostgresSaver(pool)
33+
try:
34+
await checkpointer.setup()
35+
except (
36+
psycopg.errors.DuplicateColumn,
37+
psycopg.errors.ActiveSqlTransaction,
38+
):
39+
uvicorn.warning("Skipping checkpointer setup — already configured.")
40+
yield checkpointer

backend/api/core/dependencies.py

Lines changed: 12 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
11
from contextlib import asynccontextmanager
2-
from typing import Annotated, AsyncGenerator
2+
from typing import Annotated
33

4-
import psycopg.errors
54
from fastapi import Depends
65
from langchain_mcp_adapters.tools import load_mcp_tools
76
from langchain_openai import ChatOpenAI
8-
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
9-
from mcp import ClientSession
10-
from mcp.client.sse import sse_client
11-
from psycopg_pool import AsyncConnectionPool
127
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
138

9+
from api.core.agent.persistence import checkpointer_context
1410
from api.core.config import settings
15-
from api.core.logs import uvicorn
11+
from api.core.mcps import mcp_sse_client
12+
from api.core.models import Resource
1613

1714

1815
def get_llm() -> ChatOpenAI:
@@ -38,39 +35,15 @@ def get_engine() -> AsyncEngine:
3835
EngineDep = Annotated[AsyncEngine, Depends(get_engine)]
3936

4037

41-
@asynccontextmanager
42-
async def mcp_sse_client() -> AsyncGenerator[ClientSession, None]:
43-
async with sse_client(f"http://mcp:{settings.mcp_server_port}/sse") as (
44-
read_stream,
45-
write_stream,
46-
):
47-
async with ClientSession(read_stream, write_stream) as session:
48-
await session.initialize()
49-
yield session
50-
51-
52-
async def checkpointer_setup(pool):
53-
checkpointer = AsyncPostgresSaver(pool)
54-
try:
55-
await checkpointer.setup()
56-
except (
57-
psycopg.errors.DuplicateColumn,
58-
psycopg.errors.ActiveSqlTransaction,
59-
):
60-
uvicorn.warning("Skipping checkpointer setup — already configured.")
61-
return checkpointer
62-
63-
6438
@asynccontextmanager
6539
async def setup_graph():
66-
# NOTE: LangGraph AsyncPostgresSaver does not support SQLAlchemy ORM Connections.
67-
# A compatible psycopg connection is created via the connection pool to connect to the checkpointer.
68-
async with AsyncConnectionPool(
69-
conninfo=settings.checkpoint_conn_str,
70-
kwargs=dict(prepare_threshold=None),
71-
) as pool:
72-
checkpointer = await checkpointer_setup(pool)
73-
40+
async with checkpointer_context(
41+
settings.checkpoint_conn_str
42+
) as checkpointer:
7443
async with mcp_sse_client() as session:
7544
tools = await load_mcp_tools(session)
76-
yield checkpointer, tools, session
45+
yield Resource(
46+
checkpointer=checkpointer,
47+
tools=tools,
48+
session=session,
49+
)

backend/api/core/mcps.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from contextlib import asynccontextmanager
2+
from typing import AsyncGenerator
3+
4+
from mcp import ClientSession
5+
from mcp.client.sse import sse_client
6+
7+
from api.core.config import settings
8+
9+
10+
@asynccontextmanager
11+
async def mcp_sse_client() -> AsyncGenerator[ClientSession]:
12+
"""
13+
Creates and initializes an MCP client session over SSE.
14+
15+
Establishes an SSE connection to the MCP server and yields an initialized
16+
`ClientSession` for communication.
17+
18+
Yields:
19+
ClientSession: An initialized MCP client session.
20+
"""
21+
async with sse_client(f"http://mcp:{settings.mcp_server_port}/sse") as (
22+
read_stream,
23+
write_stream,
24+
):
25+
async with ClientSession(read_stream, write_stream) as session:
26+
await session.initialize()
27+
yield session

backend/api/core/models.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from langchain_core.tools import StructuredTool
2+
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
3+
from mcp import ClientSession
4+
from pydantic import BaseModel
5+
6+
7+
class Resource(BaseModel):
8+
checkpointer: AsyncPostgresSaver
9+
tools: list[StructuredTool]
10+
session: ClientSession
11+
12+
class Config:
13+
arbitrary_types_allowed = True

backend/api/routers/llms.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,12 @@ async def stream_graph(
5858
query: str,
5959
llm: LLMDep,
6060
) -> AsyncGenerator[dict[str, str], None]:
61-
async with setup_graph() as resources:
62-
checkpointer, tools, _ = resources
63-
graph = get_graph(llm, tools=tools, checkpointer=checkpointer)
61+
async with setup_graph() as resource:
62+
graph = get_graph(
63+
llm,
64+
tools=resource.tools,
65+
checkpointer=resource.checkpointer,
66+
)
6467
config = get_config()
6568
events = dict(messages=[HumanMessage(content=query)])
6669

backend/api/routers/mcps.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from fastapi import APIRouter
44
from mcp import types
55

6-
from api.core.dependencies import mcp_sse_client
6+
from api.core.mcps import mcp_sse_client
77
from shared_mcp.models import ToolRequest
88

99
router = APIRouter(prefix="/mcps", tags=["mcps"])

0 commit comments

Comments
 (0)