Skip to content

Add support for MCP's Streamable HTTP transport #1716

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions docs/mcp/client.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,23 @@ pip/uv-add "pydantic-ai-slim[mcp]"

PydanticAI comes with two ways to connect to MCP servers:

- [`MCPServerHTTP`][pydantic_ai.mcp.MCPServerHTTP] which connects to an MCP server using the [HTTP SSE](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse) transport
- [`MCPServerHTTP`][pydantic_ai.mcp.MCPServerHTTP] which connects to an MCP server using the [Streamable HTTP transport](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http) transport
- [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] which runs the server as a subprocess and connects to it using the [stdio](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) transport

Examples of both are shown below; [mcp-run-python](run-python.md) is used as the MCP server in both examples.

### SSE Client
### HTTP Client

[`MCPServerHTTP`][pydantic_ai.mcp.MCPServerHTTP] connects over HTTP using the [HTTP + Server Sent Events transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse) to a server.
[`MCPServerHTTP`][pydantic_ai.mcp.MCPServerHTTP] connects over HTTP using the [Streamable HTTP transport](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http) to a server.

!!! note
[`MCPServerHTTP`][pydantic_ai.mcp.MCPServerHTTP] requires an MCP server to be running and accepting HTTP connections before calling [`agent.run_mcp_servers()`][pydantic_ai.Agent.run_mcp_servers]. Running the server is not managed by PydanticAI.

The name "HTTP" is used since this implemented will be adapted in future to use the new
[Streamable HTTP](https://github.com/modelcontextprotocol/specification/pull/206) currently in development.
The StreamableHTTP Transport is able to connect to both stateless HTTP and older Server Sent Events (SSE) servers.

Before creating the SSE client, we need to run the server (docs [here](run-python.md)):
Before creating the HTTP client, we need to run the server (docs [here](run-python.md)):

```bash {title="terminal (run sse server)"}
```bash {title="terminal (run http server)"}
deno run \
-N -R=node_modules -W=node_modules --node-modules-dir=auto \
jsr:@pydantic/mcp-run-python sse
Expand All @@ -56,7 +55,7 @@ async def main():
#> There are 9,208 days between January 1, 2000, and March 18, 2025.
```

1. Define the MCP server with the URL used to connect.
1. Define the MCP server with the URL used to connect. This will typically end in `/mcp` for HTTP servers and `/sse` for SSE.
2. Create an agent with the MCP server attached.
3. Create a client session to connect to the server.

Expand Down
91 changes: 57 additions & 34 deletions pydantic_ai_slim/pydantic_ai/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,22 @@

import base64
import json
import warnings
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Sequence
from contextlib import AsyncExitStack, asynccontextmanager
from dataclasses import dataclass
from datetime import timedelta
from pathlib import Path
from types import TracebackType
from typing import Any

from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp.shared.message import SessionMessage
from mcp.types import (
BlobResourceContents,
EmbeddedResource,
ImageContent,
JSONRPCMessage,
LoggingLevel,
TextContent,
TextResourceContents,
Expand All @@ -28,8 +30,8 @@

try:
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.client.streamable_http import streamablehttp_client
except ImportError as _import_error:
raise ImportError(
'Please install the `mcp` package to use the MCP server, '
Expand All @@ -55,19 +57,16 @@ class MCPServer(ABC):
"""

_client: ClientSession
_read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
_write_stream: MemoryObjectSendStream[JSONRPCMessage]
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
_write_stream: MemoryObjectSendStream[SessionMessage]
_exit_stack: AsyncExitStack

@abstractmethod
@asynccontextmanager
async def client_streams(
self,
) -> AsyncIterator[
tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage],
]
tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]]
]:
"""Create the streams for the MCP server."""
raise NotImplementedError('MCP Server subclasses must implement this method.')
Expand Down Expand Up @@ -256,10 +255,7 @@ async def main():
async def client_streams(
self,
) -> AsyncIterator[
tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage],
]
tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]]
]:
server = StdioServerParameters(command=self.command, args=list(self.args), env=self.env, cwd=self.cwd)
async with stdio_client(server=server) as (read_stream, write_stream):
Expand All @@ -276,11 +272,11 @@ def __repr__(self) -> str:
class MCPServerHTTP(MCPServer):
"""An MCP server that connects over streamable HTTP connections.
This class implements the SSE transport from the MCP specification.
See <https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse> for more information.
This class implements the Streamable HTTP transport from the MCP specification.
See <https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http> for more information.
The name "HTTP" is used since this implemented will be adapted in future to use the new
[Streamable HTTP](https://github.com/modelcontextprotocol/specification/pull/206) currently in development.
The Streamable HTTP transport is intended to replace the SSE transport from the previous protocol, but it is fully
backwards compatible with SSE-based servers.
!!! note
Using this class as an async context manager will create a new pool of HTTP connections to connect
Expand All @@ -291,7 +287,7 @@ class MCPServerHTTP(MCPServer):
from pydantic_ai import Agent
from pydantic_ai.mcp import MCPServerHTTP
server = MCPServerHTTP('http://localhost:3001/sse') # (1)!
server = MCPServerHTTP('http://localhost:3001/mcp') # (1)!
agent = Agent('openai:gpt-4o', mcp_servers=[server])
async def main():
Expand All @@ -304,27 +300,27 @@ async def main():
"""

url: str
"""The URL of the SSE endpoint on the MCP server.
"""The URL of the SSE or MCP endpoint on the MCP server.
For example for a server running locally, this might be `http://localhost:3001/sse`.
For example for a server running locally, this might be `http://localhost:3001/mcp`.
"""

headers: dict[str, Any] | None = None
"""Optional HTTP headers to be sent with each request to the SSE endpoint.
"""Optional HTTP headers to be sent with each request to the endpoint.
These headers will be passed directly to the underlying `httpx.AsyncClient`.
Useful for authentication, custom headers, or other HTTP-specific configurations.
"""

timeout: float = 5
"""Initial connection timeout in seconds for establishing the SSE connection.
timeout: timedelta | float = timedelta(seconds=5)
"""Initial connection timeout as a timedelta for establishing the connection.
This timeout applies to the initial connection setup and handshake.
If the connection cannot be established within this time, the operation will fail.
"""

sse_read_timeout: float = 60 * 5
"""Maximum time in seconds to wait for new SSE messages before timing out.
sse_read_timeout: timedelta | float = timedelta(minutes=5)
"""Maximum time as a timedelta to wait for new SSE messages before timing out.
This timeout applies to the long-lived SSE connection after it's established.
If no new messages are received within this time, the connection will be considered stale
Expand All @@ -346,21 +342,48 @@ async def main():
For example, if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
"""

def __post_init__(self):
if not isinstance(self.timeout, timedelta):
warnings.warn(
'Passing timeout as a float has been deprecated, please use a timedelta instead.',
DeprecationWarning,
stacklevel=2,
)
self.timeout = timedelta(seconds=self.timeout)

if not isinstance(self.sse_read_timeout, timedelta):
warnings.warn(
'Passing sse_read_timeout as a float has been deprecated, please use a timedelta instead.',
DeprecationWarning,
stacklevel=2,
)
self.sse_read_timeout = timedelta(seconds=self.sse_read_timeout)

@asynccontextmanager
async def client_streams(
self,
) -> AsyncIterator[
tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage],
]
tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]]
]: # pragma: no cover
async with sse_client(
url=self.url,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
) as (read_stream, write_stream):
if not isinstance(self.timeout, timedelta):
warnings.warn(
'Passing timeout as a float has been deprecated, please use a timedelta instead.',
DeprecationWarning,
stacklevel=2,
)
self.timeout = timedelta(seconds=self.timeout)

if not isinstance(self.sse_read_timeout, timedelta):
warnings.warn(
'Passing sse_read_timeout as a float has been deprecated, please use a timedelta instead.',
DeprecationWarning,
stacklevel=2,
)
self.sse_read_timeout = timedelta(seconds=self.sse_read_timeout)

async with streamablehttp_client(
url=self.url, headers=self.headers, timeout=self.timeout, sse_read_timeout=self.sse_read_timeout
) as (read_stream, write_stream, _):
yield read_stream, write_stream

def _get_log_level(self) -> LoggingLevel | None:
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ tavily = ["tavily-python>=0.5.0"]
# CLI
cli = ["rich>=13", "prompt-toolkit>=3", "argcomplete>=3.5.0"]
# MCP
mcp = ["mcp>=1.6.0; python_version >= '3.10'"]
mcp = ["mcp>=1.8.0; python_version >= '3.10'"]
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note that I think this may need to be 1.9.0 as the protocol version in 1.8.0 is still the 2024-11-05 one. we want the 2025-03-26 one (latest) as that's where the streamable HTTP protocol is defined

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jlaneve Thanks, fixed in #1840

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think tagging 1.9.0 works great, but I'm not seeing why we needed to. The 1.8.0 release is what introduced streamable http.

# Evals
evals = ["pydantic-evals=={{ version }}"]
# A2A
Expand Down
43 changes: 30 additions & 13 deletions tests/test_mcp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for the MCP (Model Context Protocol) server implementation."""

import re
from datetime import timedelta
from pathlib import Path

import pytest
Expand Down Expand Up @@ -70,25 +71,41 @@ async def test_stdio_server_with_cwd():
assert len(tools) == 10


def test_sse_server():
sse_server = MCPServerHTTP(url='http://localhost:8000/sse')
assert sse_server.url == 'http://localhost:8000/sse'
assert sse_server._get_log_level() is None # pyright: ignore[reportPrivateUsage]
def test_http_server():
http_server = MCPServerHTTP(url='http://localhost:8000/sse')
assert http_server.url == 'http://localhost:8000/sse'
assert http_server._get_log_level() is None # pyright: ignore[reportPrivateUsage]


def test_sse_server_with_header_and_timeout():
sse_server = MCPServerHTTP(
def test_http_server_with_header_and_timeout():
http_server = MCPServerHTTP(
url='http://localhost:8000/sse',
headers={'my-custom-header': 'my-header-value'},
timeout=10,
sse_read_timeout=100,
timeout=timedelta(seconds=10),
sse_read_timeout=timedelta(seconds=100),
log_level='info',
)
assert sse_server.url == 'http://localhost:8000/sse'
assert sse_server.headers is not None and sse_server.headers['my-custom-header'] == 'my-header-value'
assert sse_server.timeout == 10
assert sse_server.sse_read_timeout == 100
assert sse_server._get_log_level() == 'info' # pyright: ignore[reportPrivateUsage]
assert http_server.url == 'http://localhost:8000/sse'
assert http_server.headers is not None and http_server.headers['my-custom-header'] == 'my-header-value'
assert http_server.timeout == timedelta(seconds=10)
assert http_server.sse_read_timeout == timedelta(seconds=100)
assert http_server._get_log_level() == 'info' # pyright: ignore[reportPrivateUsage]


def test_http_server_with_deprecated_arguments():
with pytest.warns(DeprecationWarning):
http_server = MCPServerHTTP(
url='http://localhost:8000/sse',
headers={'my-custom-header': 'my-header-value'},
timeout=10,
sse_read_timeout=100,
log_level='info',
)
assert http_server.url == 'http://localhost:8000/sse'
assert http_server.headers is not None and http_server.headers['my-custom-header'] == 'my-header-value'
assert http_server.timeout == timedelta(seconds=10)
assert http_server.sse_read_timeout == timedelta(seconds=100)
assert http_server._get_log_level() == 'info' # pyright: ignore[reportPrivateUsage]


@pytest.mark.vcr()
Expand Down
11 changes: 6 additions & 5 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.