From 2b9559864b772757f1f4a0f7f0b4f08272829f3b Mon Sep 17 00:00:00 2001 From: ihrpr Date: Sun, 20 Apr 2025 10:23:37 +0100 Subject: [PATCH 01/19] initial streamable http server --- .../servers/simple-streamablehttp/README.md | 33 ++ .../mcp_simple_streamablehttp/__init__.py | 0 .../mcp_simple_streamablehttp/__main__.py | 4 + .../mcp_simple_streamablehttp/server.py | 167 +++++++ .../simple-streamablehttp/pyproject.toml | 47 ++ src/mcp/server/session.py | 18 +- src/mcp/server/streamableHttp.py | 415 ++++++++++++++++++ src/mcp/shared/session.py | 9 +- src/mcp/types.py | 1 + uv.lock | 41 +- 10 files changed, 727 insertions(+), 8 deletions(-) create mode 100644 examples/servers/simple-streamablehttp/README.md create mode 100644 examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__init__.py create mode 100644 examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py create mode 100644 examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py create mode 100644 examples/servers/simple-streamablehttp/pyproject.toml create mode 100644 src/mcp/server/streamableHttp.py diff --git a/examples/servers/simple-streamablehttp/README.md b/examples/servers/simple-streamablehttp/README.md new file mode 100644 index 00000000..aa5e707a --- /dev/null +++ b/examples/servers/simple-streamablehttp/README.md @@ -0,0 +1,33 @@ +# MCP Simple StreamableHttp Server Example + +A simple MCP server example demonstrating the StreamableHttp transport, which enables HTTP-based communication with MCP servers using streaming. + +## Features + +- Uses the StreamableHTTP transport for server-client communication +- Task management with anyio task groups +- Ability to send multiple notifications over time to the client +- Proper resource cleanup and lifespan management + +## Usage + +Start the server on the default or custom port: + +```bash + +# Using custom port +uv run mcp-simple-streamablehttp --port 3000 + +# Custom logging level +uv run mcp-simple-streamablehttp --log-level DEBUG +``` + +The server exposes a tool named "start-notification-stream" that accepts three arguments: + +- `interval`: Time between notifications in seconds (e.g., 1.0) +- `count`: Number of notifications to send (e.g., 5) +- `caller`: Identifier string for the caller + +## Client + +You can connect to this server using an HTTP client, for now only Typescript SDK has streamable HTTP client examples or you can use (Inspector)[https://github.com/modelcontextprotocol/inspector] \ No newline at end of file diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__init__.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py new file mode 100644 index 00000000..a6876bf9 --- /dev/null +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py @@ -0,0 +1,4 @@ +from .server import main + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py new file mode 100644 index 00000000..19a83790 --- /dev/null +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -0,0 +1,167 @@ +import contextlib +import logging +from uuid import uuid4 + +import anyio +import click +import mcp.types as types +from mcp.server.lowlevel import Server +from mcp.server.streamableHttp import StreamableHTTPServerTransport +from starlette.applications import Starlette +from starlette.routing import Mount + +# Configure logging +logger = logging.getLogger(__name__) + +# Global task group that will be initialized in the lifespan +task_group = None + + +@contextlib.asynccontextmanager +async def lifespan(app): + """Application lifespan context manager for managing task group.""" + global task_group + + async with anyio.create_task_group() as tg: + task_group = tg + logger.info("Application started, task group initialized!") + try: + yield + finally: + logger.info("Application shutting down, cleaning up resources...") + if task_group: + tg.cancel_scope.cancel() + task_group = None + logger.info("Resources cleaned up successfully.") + + +@click.command() +@click.option("--port", default=3000, help="Port to listen on for HTTP") +@click.option( + "--log-level", + default="INFO", + help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)", +) +def main( + port: int, + log_level: str, +) -> int: + # Configure logging + logging.basicConfig( + level=getattr(logging, log_level.upper()), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + app = Server("mcp-streamable-http-demo") + + @app.call_tool() + async def call_tool( + name: str, arguments: dict + ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: + ctx = app.request_context + interval = arguments.get("interval", 1.0) + count = arguments.get("count", 5) + caller = arguments.get("caller", "unknown") + + # Send the specified number of notifications with the given interval + for i in range(count): + await ctx.session.send_log_message( + level="info", + data=f"Notification {i+1}/{count} from caller: {caller}", + logger="notification_stream", + related_request_id=ctx.request_id, + ) + if i < count - 1: # Don't wait after the last notification + await anyio.sleep(interval) + + return [ + types.TextContent( + type="text", + text=( + f"Sent {count} notifications with {interval}s interval" + f" for caller: {caller}" + ), + ) + ] + + @app.list_tools() + async def list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="start-notification-stream", + description=( + "Sends a stream of notifications with configurable count" + " and interval" + ), + inputSchema={ + "type": "object", + "required": ["interval", "count", "caller"], + "properties": { + "interval": { + "type": "number", + "description": "Interval between notifications in seconds", + }, + "count": { + "type": "number", + "description": "Number of notifications to send", + }, + "caller": { + "type": "string", + "description": ( + "Identifier of the caller to include in notifications" + ), + }, + }, + }, + ) + ] + + # Create a Streamable HTTP transport + http_transport = StreamableHTTPServerTransport( + mcp_session_id=uuid4().hex, + ) + + # We need to store the server instances between requests + server_instances = {} + + # ASGI handler for streamable HTTP connections + async def handle_streamable_http(scope, receive, send): + if http_transport.mcp_session_id in server_instances: + logger.debug("Session already exists, handling request directly") + await http_transport.handle_request(scope, receive, send) + else: + # Start new server instance for this session + async with http_transport.connect() as streams: + read_stream, write_stream = streams + + async def run_server(): + await app.run( + read_stream, write_stream, app.create_initialization_options() + ) + + if not task_group: + raise RuntimeError("Task group is not initialized") + + task_group.start_soon(run_server) + + # For initialization requests, store the server reference + if http_transport.mcp_session_id: + server_instances[http_transport.mcp_session_id] = True + + # Handle the HTTP request and return the response + await http_transport.handle_request(scope, receive, send) + + # Create an ASGI application using the transport + starlette_app = Starlette( + debug=True, + routes=[ + Mount("/mcp", app=handle_streamable_http), + ], + lifespan=lifespan, + ) + + import uvicorn + + uvicorn.run(starlette_app, host="0.0.0.0", port=port) + + return 0 diff --git a/examples/servers/simple-streamablehttp/pyproject.toml b/examples/servers/simple-streamablehttp/pyproject.toml new file mode 100644 index 00000000..de43bd2f --- /dev/null +++ b/examples/servers/simple-streamablehttp/pyproject.toml @@ -0,0 +1,47 @@ +[project] +name = "mcp-simple-streamablehttp" +version = "0.1.0" +description = "A simple MCP server exposing a website fetching tool with StreamableHttp transport" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." }] +maintainers = [ + { name = "David Soria Parra", email = "davidsp@anthropic.com" }, + { name = "Justin Spahr-Summers", email = "justin@anthropic.com" }, +] +keywords = ["mcp", "llm", "automation", "web", "fetch", "http", "streamable"] +license = { text = "MIT" } +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", +] +dependencies = ["anyio>=4.5", "click>=8.1.0", "httpx>=0.27", "mcp", "starlette", "uvicorn"] + +[project.scripts] +mcp-simple-streamablehttp = "mcp_simple_streamablehttp.server:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_simple_streamablehttp"] + +[tool.pyright] +include = ["mcp_simple_streamablehttp"] +venvPath = "." +venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 88 +target-version = "py310" + +[tool.uv] +dev-dependencies = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] \ No newline at end of file diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 568ecd4b..3a1f210d 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -179,7 +179,11 @@ async def _received_notification( ) async def send_log_message( - self, level: types.LoggingLevel, data: Any, logger: str | None = None + self, + level: types.LoggingLevel, + data: Any, + logger: str | None = None, + related_request_id: types.RequestId | None = None, ) -> None: """Send a log message notification.""" await self.send_notification( @@ -192,7 +196,8 @@ async def send_log_message( logger=logger, ), ) - ) + ), + related_request_id, ) async def send_resource_updated(self, uri: AnyUrl) -> None: @@ -261,7 +266,11 @@ async def send_ping(self) -> types.EmptyResult: ) async def send_progress_notification( - self, progress_token: str | int, progress: float, total: float | None = None + self, + progress_token: str | int, + progress: float, + total: float | None = None, + related_request_id: str | None = None, ) -> None: """Send a progress notification.""" await self.send_notification( @@ -274,7 +283,8 @@ async def send_progress_notification( total=total, ), ) - ) + ), + related_request_id, ) async def send_resource_list_changed(self) -> None: diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py new file mode 100644 index 00000000..cfce6629 --- /dev/null +++ b/src/mcp/server/streamableHttp.py @@ -0,0 +1,415 @@ +""" +StreamableHTTP Server Transport Module + +This module implements an HTTP transport layer with Streamable HTTP. + +The transport handles bidirectional communication using HTTP requests and +responses, with streaming support for long-running operations. +""" + +import json +import logging +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from typing import Any + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic import ValidationError +from sse_starlette import EventSourceResponse +from starlette.requests import Request +from starlette.responses import Response +from starlette.types import Receive, Scope, Send + +from mcp.types import ( + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, +) + +logger = logging.getLogger(__name__) + +# Maximum size for incoming messages +MAXIMUM_MESSAGE_SIZE = 4 * 1024 * 1024 # 4MB + + +class StreamableHTTPServerTransport: + """ + HTTP server transport with event streaming support for MCP. + + Handles POST requests containing JSON-RPC messages and provides + Server-Sent Events (SSE) responses for streaming communication. + """ + + # Server notification streams for POST requests as well as standalone SSE stream + _read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] | None + _write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] + # Dictionary to track request-specific message streams + _request_streams: dict[str, MemoryObjectSendStream[JSONRPCMessage]] + + def __init__( + self, + mcp_session_id: str | None, + ): + """ + Initialize a new StreamableHTTP server transport. + + Args: + mcp_session_id: Optional session identifier for this connection + """ + self.mcp_session_id = mcp_session_id + self._request_streams = {} + + async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: + """ + ASGI application entry point that handles all HTTP requests + + Args: + stream_id: Unique identifier for this stream + scope: ASGI scope + receive: ASGI receive function + send: ASGI send function + """ + request = Request(scope, receive) + + if request.method == "POST": + await self._handle_post_request(scope, request, receive, send) + elif request.method == "GET": + await self._handle_get_request(request, send) + elif request.method == "DELETE": + await self._handle_delete_request(request, send) + else: + await self._handle_unsupported_request(send) + + async def _handle_post_request( + self, scope: Scope, request: Request, receive: Receive, send: Send + ) -> None: + """ + Handles POST requests containing JSON-RPC messages + + Args: + stream_id: Unique identifier for this stream + scope: ASGI scope + request: Starlette Request object + receive: ASGI receive function + send: ASGI send function + """ + body = await request.body() + writer = self._read_stream_writer + if writer is None: + raise ValueError( + "No read stream writer available. Ensure connect() is called first." + ) + return + try: + # Validate Accept header + accept_header = request.headers.get("accept", "") + if ( + "application/json" not in accept_header + or "text/event-stream" not in accept_header + ): + response = Response( + ( + "Not Acceptable: Client must accept both application/json and " + "text/event-stream" + ), + status_code=406, + headers={"Content-Type": "application/json"}, + ) + await response(scope, receive, send) + return + + # Validate Content-Type + content_type = request.headers.get("content-type", "") + if "application/json" not in content_type: + response = Response( + "Unsupported Media Type: Content-Type must be application/json", + status_code=415, + headers={"Content-Type": "application/json"}, + ) + await response(scope, receive, send) + return + + # Parse the body + body = await request.body() + if len(body) > MAXIMUM_MESSAGE_SIZE: + response = Response( + "Payload Too Large: Message exceeds maximum size", + status_code=413, + headers={"Content-Type": "application/json"}, + ) + await response(scope, receive, send) + return + + try: + raw_message = json.loads(body) + except json.JSONDecodeError as e: + response = Response( + f"Parse error: {str(e)}", + status_code=400, + headers={"Content-Type": "application/json"}, + ) + await response(scope, receive, send) + return + message = None + try: + message = JSONRPCMessage.model_validate(raw_message) + except ValidationError as e: + response = Response( + f"Validation error: {str(e)}", + status_code=400, + headers={"Content-Type": "application/json"}, + ) + await response(scope, receive, send) + return + if not message: + response = Response( + "Invalid Request: Message is empty", + status_code=400, + headers={"Content-Type": "application/json"}, + ) + await response(scope, receive, send) + return + + # Check if this is an initialization request + is_initialization_request = ( + isinstance(message.root, JSONRPCRequest) + and message.root.method == "initialize" + ) + + if is_initialization_request: + # TODO validate + logger.info("INITIALIZATION REQUEST") + # For non-initialization requests, validate the session + elif not await self._validate_session(request, send): + return + + is_request = isinstance(message.root, JSONRPCRequest) + + # For notifications and responses only, return 202 Accepted + if not is_request: + headers: dict[str, str] = {} + if self.mcp_session_id: + headers["mcp-session-id"] = self.mcp_session_id + + # Create response object and send it + response = Response("Accepted", status_code=202, headers=headers) + await response(scope, receive, send) + + # Process the message after sending the response + await writer.send(message) + + return + + # For requests, set up an SSE stream for the response + if is_request: + # Set up headers + headers = { + "Cache-Control": "no-cache, no-transform", + "Connection": "keep-alive", + } + + if self.mcp_session_id: + headers["mcp-session-id"] = self.mcp_session_id + + # For SSE responses, set up SSE stream + headers["Content-Type"] = "text/event-stream" + # Create SSE stream + sse_stream_writer, sse_stream_reader = ( + anyio.create_memory_object_stream[dict[str, Any]](0) + ) + + async def sse_writer(): + try: + # Create a request-specific message stream for this POST request + request_stream_writer, request_stream_reader = ( + anyio.create_memory_object_stream[JSONRPCMessage](0) + ) + + # Get the request ID from the incoming request message + request_id = None + if isinstance(message.root, JSONRPCRequest): + request_id = str(message.root.id) + # Register this stream for the request ID + if request_id: + self._request_streams[request_id] = ( + request_stream_writer + ) + + async with sse_stream_writer, request_stream_reader: + # Process messages from the request-specific stream + async for received_message in request_stream_reader: + # Send the message via SSE + related_request_id = None + + if isinstance( + received_message.root, JSONRPCNotification + ): + # Get related_request_id from params + params = received_message.root.params + if params and "related_request_id" in params: + related_request_id = params.get( + "related_request_id" + ) + logger.debug( + f"NOTIFICATION: {related_request_id}, " + f"{params.get('data')}" + ) + + # Build the event data + event_data = { + "event": "message", + "data": received_message.model_dump_json( + by_alias=True, exclude_none=True + ), + } + + await sse_stream_writer.send(event_data) + + # If response, remove from pending streams and close + if isinstance(received_message.root, JSONRPCResponse): + if request_id: + self._request_streams.pop(request_id, None) + break + except Exception as e: + logger.exception(f"Error in SSE writer: {e}") + finally: + logger.debug("Closing SSE writer") + # TODO + + # Create and start EventSourceResponse + response = EventSourceResponse( + content=sse_stream_reader, + data_sender_callable=sse_writer, + headers=headers, + ) + + # Extract the request ID outside the try block for proper scope + outer_request_id = None + if isinstance(message.root, JSONRPCRequest): + outer_request_id = str(message.root.id) + + # Start the SSE response (this will send headers immediately) + try: + # First send the response to establish the SSE connection + async with anyio.create_task_group() as tg: + tg.start_soon(response, scope, receive, send) + + # Then send the message to be processed by the server + await writer.send(message) + except Exception: + logger.exception("SSE response error") + # Make sure to clean up the request stream if something goes wrong + if outer_request_id and outer_request_id in self._request_streams: + self._request_streams.pop(outer_request_id, None) + + except Exception as err: + logger.exception("Error handling POST request") + response = Response(f"Error handling POST request: {err}", status_code=500) + await response(scope, receive, send) + if writer: + await writer.send(err) + return + + async def _handle_get_request(self, request: Request, send: Send) -> None: + pass + + async def _handle_delete_request(self, request: Request, send: Send) -> None: + pass + + async def _handle_unsupported_request(self, send: Send) -> None: + pass + + async def _validate_session(self, request: Request, send: Send) -> bool: + # TODO + return True + + @asynccontextmanager + async def connect( + self, + ) -> AsyncGenerator[ + tuple[ + MemoryObjectReceiveStream[JSONRPCMessage | Exception], + MemoryObjectSendStream[JSONRPCMessage], + ], + None, + ]: + """ + Context manager that provides read and write streams for a connection + + Yields: + Tuple of (read_stream, write_stream) for bidirectional communication + """ + + # Create the memory streams for this connection + read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] + + write_stream: MemoryObjectSendStream[JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] + + read_stream_writer, read_stream = anyio.create_memory_object_stream[ + JSONRPCMessage | Exception + ](0) + write_stream, write_stream_reader = anyio.create_memory_object_stream[ + JSONRPCMessage + ](0) + + # Store the streams + self._read_stream_writer = read_stream_writer + self._write_stream_reader = write_stream_reader + + # Start a task group for message routing + async with anyio.create_task_group() as tg: + # Create a message router that distributes messages to request streams + async def message_router(): + try: + async for message in write_stream_reader: + # Determine which request stream(s) should receive this message + target_request_id = None + + # For responses, route based on the request ID + if isinstance(message.root, JSONRPCResponse): + target_request_id = str(message.root.id) + # For notifications, route by related_request_id if available + elif isinstance(message.root, JSONRPCNotification): + # Get related_request_id from params + params = message.root.params + if params and "related_request_id" in params: + related_id = params.get("related_request_id") + if related_id is not None: + target_request_id = str(related_id) + + # Send to the specific request stream if available + if ( + target_request_id + and target_request_id in self._request_streams + ): + try: + await self._request_streams[target_request_id].send( + message + ) + except ( + anyio.BrokenResourceError, + anyio.ClosedResourceError, + ): + # Stream might be closed, remove from registry + self._request_streams.pop(target_request_id, None) + except Exception as e: + logger.exception(f"Error in message router: {e}") + + # Start the message router + tg.start_soon(message_router) + + try: + # Yield the streams for the caller to use + yield read_stream, write_stream + finally: + # Clean up any remaining request streams + for stream in list(self._request_streams.values()): + try: + await stream.aclose() + except Exception: + pass + self._request_streams.clear() diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 05fd3ce3..1017bb98 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -267,16 +267,21 @@ async def send_request( else: return result_type.model_validate(response_or_error.result) - async def send_notification(self, notification: SendNotificationT) -> None: + async def send_notification( + self, + notification: SendNotificationT, + related_request_id: RequestId | None = None, + ) -> None: """ Emits a notification, which is a one-way message that does not expect a response. """ + if related_request_id is not None and notification.root.params is not None: + notification.root.params.related_request_id = related_request_id jsonrpc_notification = JSONRPCNotification( jsonrpc="2.0", **notification.model_dump(by_alias=True, mode="json", exclude_none=True), ) - await self._write_stream.send(JSONRPCMessage(jsonrpc_notification)) async def _send_response( diff --git a/src/mcp/types.py b/src/mcp/types.py index bd71d51f..30500e31 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -58,6 +58,7 @@ class Meta(BaseModel): model_config = ConfigDict(extra="allow") meta: Meta | None = Field(alias="_meta", default=None) + related_request_id: RequestId | None = None """ This parameter name is reserved by MCP to allow clients and servers to attach additional metadata to their notifications. diff --git a/uv.lock b/uv.lock index 78f46f47..65439e5c 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,4 @@ version = 1 -revision = 1 requires-python = ">=3.10" [options] @@ -10,6 +9,7 @@ members = [ "mcp", "mcp-simple-prompt", "mcp-simple-resource", + "mcp-simple-streamablehttp", "mcp-simple-tool", ] @@ -487,6 +487,7 @@ wheels = [ [[package]] name = "mcp" +version = "1.6.1.dev12+70115b9" source = { editable = "." } dependencies = [ { name = "anyio" }, @@ -543,7 +544,6 @@ requires-dist = [ { name = "uvicorn", marker = "sys_platform != 'emscripten'", specifier = ">=0.23.1" }, { name = "websockets", marker = "extra == 'ws'", specifier = ">=15.0.1" }, ] -provides-extras = ["cli", "rich", "ws"] [package.metadata.requires-dev] dev = [ @@ -628,6 +628,43 @@ dev = [ { name = "ruff", specifier = ">=0.6.9" }, ] +[[package]] +name = "mcp-simple-streamablehttp" +version = "0.1.0" +source = { editable = "examples/servers/simple-streamablehttp" } +dependencies = [ + { name = "anyio" }, + { name = "click" }, + { name = "httpx" }, + { name = "mcp" }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "pytest" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "anyio", specifier = ">=4.5" }, + { name = "click", specifier = ">=8.1.0" }, + { name = "httpx", specifier = ">=0.27" }, + { name = "mcp", editable = "." }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.378" }, + { name = "pytest", specifier = ">=8.3.3" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + [[package]] name = "mcp-simple-tool" version = "0.1.0" From 3d790f8979bfd43d505151e024433b533376946b Mon Sep 17 00:00:00 2001 From: ihrpr Date: Sun, 20 Apr 2025 12:17:33 +0100 Subject: [PATCH 02/19] add request validation and tests --- src/mcp/server/streamableHttp.py | 270 ++++++++++++++++---- tests/server/test_streamableHttp.py | 378 ++++++++++++++++++++++++++++ 2 files changed, 601 insertions(+), 47 deletions(-) create mode 100644 tests/server/test_streamableHttp.py diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index cfce6629..e65c6c46 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -11,6 +11,7 @@ import logging from collections.abc import AsyncGenerator from contextlib import asynccontextmanager +from http import HTTPStatus from typing import Any import anyio @@ -33,6 +34,14 @@ # Maximum size for incoming messages MAXIMUM_MESSAGE_SIZE = 4 * 1024 * 1024 # 4MB +# Header names +MCP_SESSION_ID_HEADER = "mcp-session-id" +LAST_EVENT_ID_HEADER = "last-event-id" + +# Content types +CONTENT_TYPE_JSON = "application/json" +CONTENT_TYPE_SSE = "text/event-stream" + class StreamableHTTPServerTransport: """ @@ -61,6 +70,34 @@ def __init__( self.mcp_session_id = mcp_session_id self._request_streams = {} + def _create_error_response( + self, + message: str, + status_code: HTTPStatus, + headers: dict[str, str] | None = None, + ) -> Response: + """ + Create a standardized error response. + """ + response_headers = {"Content-Type": CONTENT_TYPE_JSON} + if headers: + response_headers.update(headers) + + if self.mcp_session_id: + response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + return Response( + message, + status_code=status_code, + headers=response_headers, + ) + + def _get_session_id(self, request: Request) -> str | None: + """ + Extract the session ID from request headers. + """ + return request.headers.get(MCP_SESSION_ID_HEADER) + async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: """ ASGI application entry point that handles all HTTP requests @@ -80,7 +117,46 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No elif request.method == "DELETE": await self._handle_delete_request(request, send) else: - await self._handle_unsupported_request(send) + await self._handle_unsupported_request(request, send) + + def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: + """ + Check if the request accepts the required media types. + + Args: + request: The HTTP request + + Returns: + Tuple of (has_json, has_sse) indicating whether each media type is accepted + """ + accept_header = request.headers.get("accept", "") + accept_types = [media_type.strip() for media_type in accept_header.split(",")] + + has_json = any( + media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types + ) + has_sse = any( + media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types + ) + + return has_json, has_sse + + def _check_content_type(self, request: Request) -> bool: + """ + Check if the request has the correct Content-Type. + + Args: + request: The HTTP request + + Returns: + True if Content-Type is acceptable, False otherwise + """ + content_type = request.headers.get("content-type", "") + content_type_parts = [ + part.strip() for part in content_type.split(";")[0].split(",") + ] + + return any(part == CONTENT_TYPE_JSON for part in content_type_parts) async def _handle_post_request( self, scope: Scope, request: Request, receive: Receive, send: Send @@ -89,13 +165,11 @@ async def _handle_post_request( Handles POST requests containing JSON-RPC messages Args: - stream_id: Unique identifier for this stream scope: ASGI scope request: Starlette Request object receive: ASGI receive function send: ASGI send function """ - body = await request.body() writer = self._read_stream_writer if writer is None: raise ValueError( @@ -103,41 +177,34 @@ async def _handle_post_request( ) return try: - # Validate Accept header - accept_header = request.headers.get("accept", "") - if ( - "application/json" not in accept_header - or "text/event-stream" not in accept_header - ): - response = Response( + # Check Accept headers + has_json, has_sse = self._check_accept_headers(request) + if not (has_json and has_sse): + response = self._create_error_response( ( "Not Acceptable: Client must accept both application/json and " "text/event-stream" ), - status_code=406, - headers={"Content-Type": "application/json"}, + HTTPStatus.NOT_ACCEPTABLE, ) await response(scope, receive, send) return # Validate Content-Type - content_type = request.headers.get("content-type", "") - if "application/json" not in content_type: - response = Response( + if not self._check_content_type(request): + response = self._create_error_response( "Unsupported Media Type: Content-Type must be application/json", - status_code=415, - headers={"Content-Type": "application/json"}, + HTTPStatus.UNSUPPORTED_MEDIA_TYPE, ) await response(scope, receive, send) return - # Parse the body + # Parse the body - only read it once body = await request.body() if len(body) > MAXIMUM_MESSAGE_SIZE: - response = Response( + response = self._create_error_response( "Payload Too Large: Message exceeds maximum size", - status_code=413, - headers={"Content-Type": "application/json"}, + HTTPStatus.REQUEST_ENTITY_TOO_LARGE, ) await response(scope, receive, send) return @@ -145,29 +212,28 @@ async def _handle_post_request( try: raw_message = json.loads(body) except json.JSONDecodeError as e: - response = Response( + response = self._create_error_response( f"Parse error: {str(e)}", - status_code=400, - headers={"Content-Type": "application/json"}, + HTTPStatus.BAD_REQUEST, ) await response(scope, receive, send) return + message = None try: message = JSONRPCMessage.model_validate(raw_message) except ValidationError as e: - response = Response( + response = self._create_error_response( f"Validation error: {str(e)}", - status_code=400, - headers={"Content-Type": "application/json"}, + HTTPStatus.BAD_REQUEST, ) await response(scope, receive, send) return + if not message: - response = Response( + response = self._create_error_response( "Invalid Request: Message is empty", - status_code=400, - headers={"Content-Type": "application/json"}, + HTTPStatus.BAD_REQUEST, ) await response(scope, receive, send) return @@ -179,8 +245,19 @@ async def _handle_post_request( ) if is_initialization_request: - # TODO validate - logger.info("INITIALIZATION REQUEST") + # Check if the server already has an established session + if self.mcp_session_id: + # Check if request has a session ID + request_session_id = self._get_session_id(request) + + # If request has a session ID but doesn't match, return 404 + if request_session_id and request_session_id != self.mcp_session_id: + response = self._create_error_response( + "Not Found: Invalid or expired session ID", + HTTPStatus.NOT_FOUND, + ) + await response(scope, receive, send) + return # For non-initialization requests, validate the session elif not await self._validate_session(request, send): return @@ -189,12 +266,11 @@ async def _handle_post_request( # For notifications and responses only, return 202 Accepted if not is_request: - headers: dict[str, str] = {} - if self.mcp_session_id: - headers["mcp-session-id"] = self.mcp_session_id - # Create response object and send it - response = Response("Accepted", status_code=202, headers=headers) + response = self._create_error_response( + "Accepted", + HTTPStatus.ACCEPTED, + ) await response(scope, receive, send) # Process the message after sending the response @@ -208,13 +284,11 @@ async def _handle_post_request( headers = { "Cache-Control": "no-cache, no-transform", "Connection": "keep-alive", + "Content-Type": CONTENT_TYPE_SSE, } if self.mcp_session_id: - headers["mcp-session-id"] = self.mcp_session_id - - # For SSE responses, set up SSE stream - headers["Content-Type"] = "text/event-stream" + headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id # Create SSE stream sse_stream_writer, sse_stream_reader = ( anyio.create_memory_object_stream[dict[str, Any]](0) @@ -306,23 +380,125 @@ async def sse_writer(): except Exception as err: logger.exception("Error handling POST request") - response = Response(f"Error handling POST request: {err}", status_code=500) + response = self._create_error_response( + f"Error handling POST request: {err}", + HTTPStatus.INTERNAL_SERVER_ERROR, + ) await response(scope, receive, send) if writer: await writer.send(err) return async def _handle_get_request(self, request: Request, send: Send) -> None: - pass + """ + Handle GET requests for SSE stream establishment + + Args: + request: The HTTP request + send: ASGI send function + """ + # Validate session ID if server has one + if not await self._validate_session(request, send): + return + + # Validate Accept header - must include text/event-stream + _, has_sse = self._check_accept_headers(request) + + if not has_sse: + response = self._create_error_response( + "Not Acceptable: Client must accept text/event-stream", + HTTPStatus.NOT_ACCEPTABLE, + ) + await response(request.scope, request.receive, send) + return + + # TODO: Implement SSE stream for GET requests + # For now, return 501 Not Implemented + response = self._create_error_response( + "SSE stream from GET request not implemented yet", + HTTPStatus.NOT_IMPLEMENTED, + ) + await response(request.scope, request.receive, send) async def _handle_delete_request(self, request: Request, send: Send) -> None: - pass + """ + Handle DELETE requests for explicit session termination + + Args: + request: The HTTP request + send: ASGI send function + """ + # Validate session ID + if not self.mcp_session_id: + # If no session ID set, return Method Not Allowed + response = self._create_error_response( + "Method Not Allowed: Session termination not supported", + HTTPStatus.METHOD_NOT_ALLOWED, + ) + await response(request.scope, request.receive, send) + return + if not await self._validate_session(request, send): + return + # TODO : Implement session termination logic - async def _handle_unsupported_request(self, send: Send) -> None: - pass + async def _handle_unsupported_request(self, request: Request, send: Send) -> None: + """ + Handle unsupported HTTP methods + + Args: + request: The HTTP request + send: ASGI send function + """ + headers = { + "Content-Type": CONTENT_TYPE_JSON, + "Allow": "GET, POST, DELETE", + } + if self.mcp_session_id: + headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + response = Response( + "Method Not Allowed", + status_code=HTTPStatus.METHOD_NOT_ALLOWED, + headers=headers, + ) + await response(request.scope, request.receive, send) async def _validate_session(self, request: Request, send: Send) -> bool: - # TODO + """ + Validate the session ID in the request. + + Args: + request: The HTTP request + send: ASGI send function + + Returns: + bool: True if session is valid, False otherwise + """ + if not self.mcp_session_id: + # If we're not using session IDs, return True + return True + + # Get the session ID from the request headers + request_session_id = self._get_session_id(request) + + # If no session ID provided but required, return error + if not request_session_id: + response = self._create_error_response( + "Bad Request: Missing session ID", + HTTPStatus.BAD_REQUEST, + ) + await response(request.scope, request.receive, send) + return False + + # If session ID doesn't match, return error + if request_session_id != self.mcp_session_id: + response = self._create_error_response( + "Not Found: Invalid or expired session ID", + HTTPStatus.NOT_FOUND, + ) + await response(request.scope, request.receive, send) + return False + return True @asynccontextmanager diff --git a/tests/server/test_streamableHttp.py b/tests/server/test_streamableHttp.py new file mode 100644 index 00000000..6296f22c --- /dev/null +++ b/tests/server/test_streamableHttp.py @@ -0,0 +1,378 @@ +""" +Tests for the StreamableHTTP server transport validation. + +This file contains tests for request validation in the StreamableHTTP transport. +""" + +import socket +import time +from collections.abc import Generator +from multiprocessing import Process + +import anyio +import pytest +import requests +import uvicorn +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import Route + +from mcp.server.streamableHttp import ( + MCP_SESSION_ID_HEADER, + StreamableHTTPServerTransport, +) +from mcp.types import JSONRPCMessage + +# Test constants +SERVER_NAME = "test_streamable_http_server" +TEST_SESSION_ID = "test-session-id-12345" + + +# App handler class for testing validation (not a pytest test class) +class StreamableAppHandler: + def __init__(self, session_id=None): + self.transport = StreamableHTTPServerTransport(mcp_session_id=session_id) + self.started = False + self.read_stream = None + self.write_stream = None + + async def startup(self): + """Initialize the transport streams.""" + # Create real memory streams to satisfy type checking + read_stream_writer, read_stream = anyio.create_memory_object_stream[ + JSONRPCMessage | Exception + ](0) + write_stream, write_stream_reader = anyio.create_memory_object_stream[ + JSONRPCMessage + ](0) + + # Assign the streams to the transport + self.transport._read_stream_writer = read_stream_writer + self.transport._write_stream_reader = write_stream_reader + + # Store the streams so they don't get garbage collected + self.read_stream = read_stream + self.write_stream = write_stream + + self.started = True + print("Transport streams initialized") + + async def handle_request(self, request: Request): + """Handle incoming requests by validating and responding.""" + # Make sure transport is initialized + if not self.started: + await self.startup() + + # Let the transport handle the request validation and response + try: + await self.transport.handle_request( + request.scope, request.receive, request._send + ) + except Exception as e: + print(f"Error handling request: {e}") + # Make sure we provide an error response + response = Response( + status_code=500, + content=f"Server error: {str(e)}", + media_type="text/plain", + ) + await response(request.scope, request.receive, request._send) + + +@pytest.fixture +def server_port() -> int: + """Find an available port for the test server.""" + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def server_url(server_port: int) -> str: + """Get the URL for the test server.""" + return f"http://127.0.0.1:{server_port}" + + +def create_app(session_id=None) -> Starlette: + """Create a Starlette application for testing.""" + # Create our test app handler + app_handler = StreamableAppHandler(session_id=session_id) + + # Define a startup event to ensure the transport is initialized + async def on_startup(): + """Initialize the transport on application startup.""" + print("Initializing transport streams...") + await app_handler.startup() + app_handler.started = True + print("Transport initialized") + + app = Starlette( + debug=True, # Enable debug mode for better error messages + routes=[ + Route( + "/mcp", + endpoint=app_handler.handle_request, + methods=["GET", "POST", "DELETE"], + ), + ], + on_startup=[on_startup], + ) + + return app + + +def run_server(port: int, session_id=None) -> None: + """Run the test server.""" + print(f"Starting test server on port {port} with session_id={session_id}") + + # Create app with simpler configuration + app = create_app(session_id) + + # Configure to use a single worker and simpler settings + config = uvicorn.Config( + app=app, + host="127.0.0.1", + port=port, + log_level="info", # Use info to see startup messages + limit_concurrency=10, + timeout_keep_alive=2, + access_log=False, + ) + + # Start the server + server = uvicorn.Server(config=config) + + # This is important to catch exceptions and prevent test hangs + try: + print("Server starting...") + server.run() + except Exception as e: + print(f"ERROR: Server failed to run: {e}") + import traceback + + traceback.print_exc() + + print("Server shutdown") + + +@pytest.fixture +def basic_server(server_port: int) -> Generator[None, None, None]: + """Start a basic server without session ID.""" + # Start server process + process = Process(target=run_server, kwargs={"port": server_port}, daemon=True) + process.start() + + # Wait for server to start + max_attempts = 20 + attempt = 0 + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + + yield + + # Clean up + process.terminate() + process.join(timeout=1) + if process.is_alive(): + process.kill() + + +@pytest.fixture +def session_server(server_port: int) -> Generator[str, None, None]: + """Start a server with session ID.""" + # Start server process + process = Process( + target=run_server, + kwargs={"port": server_port, "session_id": TEST_SESSION_ID}, + daemon=True, + ) + process.start() + + # Wait for server to start + max_attempts = 20 + attempt = 0 + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + + yield TEST_SESSION_ID + + # Clean up + process.terminate() + process.join(timeout=1) + if process.is_alive(): + process.kill() + + +# Basic request validation tests +def test_accept_header_validation(basic_server, server_url): + """Test that Accept header is properly validated.""" + # Test without Accept header + response = requests.post( + f"{server_url}/mcp", + headers={"Content-Type": "application/json"}, + json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text + + # Test with only application/json + response = requests.post( + f"{server_url}/mcp", + headers={"Accept": "application/json", "Content-Type": "application/json"}, + json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, + ) + assert response.status_code == 406 + + # Test with only text/event-stream + response = requests.post( + f"{server_url}/mcp", + headers={"Accept": "text/event-stream", "Content-Type": "application/json"}, + json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, + ) + assert response.status_code == 406 + + +def test_content_type_validation(basic_server, server_url): + """Test that Content-Type header is properly validated.""" + # Test with incorrect Content-Type + response = requests.post( + f"{server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "text/plain", + }, + data="This is not JSON", + ) + assert response.status_code == 415 + assert "Unsupported Media Type" in response.text + + +def test_json_validation(basic_server, server_url): + """Test that JSON content is properly validated.""" + # Test with invalid JSON + response = requests.post( + f"{server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + data="this is not valid json", + ) + assert response.status_code == 400 + assert "Parse error" in response.text + + # Test with valid JSON but invalid JSON-RPC + response = requests.post( + f"{server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={"foo": "bar"}, + ) + assert response.status_code == 400 + assert "Validation error" in response.text + + +def test_method_not_allowed(basic_server, server_url): + """Test that unsupported HTTP methods are rejected.""" + # Test with unsupported method (PUT) + response = requests.put( + f"{server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, + ) + assert response.status_code == 405 + assert "Method Not Allowed" in response.text + + +def test_get_request_validation(basic_server, server_url): + """Test GET request validation for SSE streams.""" + # Test GET without Accept header + response = requests.get(f"{server_url}/mcp") + assert response.status_code == 406 + assert "Not Acceptable" in response.text + + # Test GET with wrong Accept header + response = requests.get( + f"{server_url}/mcp", + headers={"Accept": "application/json"}, + ) + assert response.status_code == 406 + + +def test_session_validation(session_server, server_url): + """Test session ID validation.""" + # session_id not used directly in this test + + # Test without session ID + response = requests.post( + f"{server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={"jsonrpc": "2.0", "method": "list_tools", "id": 1}, + ) + assert response.status_code == 400 + assert "Missing session ID" in response.text + + # Test with invalid session ID + response = requests.post( + f"{server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: "invalid-session-id", + }, + json={"jsonrpc": "2.0", "method": "list_tools", "id": 1}, + ) + assert response.status_code == 404 + assert "Invalid or expired session ID" in response.text + + +def test_delete_request(session_server, server_url): + """Test DELETE request for session termination.""" + # session_id not used directly in this test + + # Test without session ID + response = requests.delete(f"{server_url}/mcp") + assert response.status_code == 400 + assert "Missing session ID" in response.text + + # Test with invalid session ID + response = requests.delete( + f"{server_url}/mcp", + headers={MCP_SESSION_ID_HEADER: "invalid-session-id"}, + ) + assert response.status_code == 404 + assert "Invalid or expired session ID" in response.text + + +def test_delete_without_session_support(basic_server, server_url): + """Test DELETE request when server doesn't support sessions.""" + # Server without session support should reject DELETE + response = requests.delete(f"{server_url}/mcp") + assert response.status_code == 405 + assert "Method Not Allowed" in response.text From 27bc01ec4bb63f398316ed7648dbc99108e0176f Mon Sep 17 00:00:00 2001 From: ihrpr Date: Sun, 20 Apr 2025 14:30:35 +0100 Subject: [PATCH 03/19] session management --- .../mcp_simple_streamablehttp/server.py | 78 ++++++++++++------- src/mcp/server/streamableHttp.py | 21 ++++- tests/server/test_streamableHttp.py | 60 +++++++++++++- 3 files changed, 128 insertions(+), 31 deletions(-) diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index 19a83790..3dc972b7 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -1,13 +1,19 @@ import contextlib import logging +from http import HTTPStatus from uuid import uuid4 import anyio import click import mcp.types as types from mcp.server.lowlevel import Server -from mcp.server.streamableHttp import StreamableHTTPServerTransport +from mcp.server.streamableHttp import ( + MCP_SESSION_ID_HEADER, + StreamableHTTPServerTransport, +) from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import Response from starlette.routing import Mount # Configure logging @@ -116,40 +122,56 @@ async def list_tools() -> list[types.Tool]: ) ] - # Create a Streamable HTTP transport - http_transport = StreamableHTTPServerTransport( - mcp_session_id=uuid4().hex, - ) - # We need to store the server instances between requests server_instances = {} + # Lock to prevent race conditions when creating new sessions + session_creation_lock = anyio.Lock() # ASGI handler for streamable HTTP connections async def handle_streamable_http(scope, receive, send): - if http_transport.mcp_session_id in server_instances: + request = Request(scope, receive) + request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) + if ( + request_mcp_session_id is not None + and request_mcp_session_id in server_instances + ): + transport = server_instances[request_mcp_session_id] logger.debug("Session already exists, handling request directly") - await http_transport.handle_request(scope, receive, send) + await transport.handle_request(scope, receive, send) + elif request_mcp_session_id is None: + # try to establish new session + logger.debug("Creating new transport") + # Use lock to prevent race conditions when creating new sessions + async with session_creation_lock: + new_session_id = uuid4().hex + http_transport = StreamableHTTPServerTransport( + mcp_session_id=new_session_id, + ) + async with http_transport.connect() as streams: + read_stream, write_stream = streams + + async def run_server(): + await app.run( + read_stream, + write_stream, + app.create_initialization_options(), + ) + + if not task_group: + raise RuntimeError("Task group is not initialized") + + # Store the instance before starting the task to prevent races + server_instances[http_transport.mcp_session_id] = http_transport + task_group.start_soon(run_server) + + # Handle the HTTP request and return the response + await http_transport.handle_request(scope, receive, send) else: - # Start new server instance for this session - async with http_transport.connect() as streams: - read_stream, write_stream = streams - - async def run_server(): - await app.run( - read_stream, write_stream, app.create_initialization_options() - ) - - if not task_group: - raise RuntimeError("Task group is not initialized") - - task_group.start_soon(run_server) - - # For initialization requests, store the server reference - if http_transport.mcp_session_id: - server_instances[http_transport.mcp_session_id] = True - - # Handle the HTTP request and return the response - await http_transport.handle_request(scope, receive, send) + response = Response( + "Bad Request: No valid session ID provided", + status_code=HTTPStatus.BAD_REQUEST, + ) + await response(scope, receive, send) # Create an ASGI application using the transport starlette_app = Starlette( diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index e65c6c46..8b1498d7 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -9,6 +9,7 @@ import json import logging +import re from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from http import HTTPStatus @@ -42,6 +43,10 @@ CONTENT_TYPE_JSON = "application/json" CONTENT_TYPE_SSE = "text/event-stream" +# Session ID validation pattern (visible ASCII characters ranging from 0x21 to 0x7E) +# Pattern ensures entire string contains only valid characters by using ^ and $ anchors +SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$") + class StreamableHTTPServerTransport: """ @@ -65,8 +70,20 @@ def __init__( Initialize a new StreamableHTTP server transport. Args: - mcp_session_id: Optional session identifier for this connection + mcp_session_id: Optional session identifier for this connection. + Must contain only visible ASCII characters (0x21-0x7E). + + Raises: + ValueError: If the session ID contains invalid characters. """ + if mcp_session_id is not None and ( + not SESSION_ID_PATTERN.match(mcp_session_id) or + SESSION_ID_PATTERN.fullmatch(mcp_session_id) is None + ): + raise ValueError( + "Session ID must only contain visible ASCII characters (0x21-0x7E)" + ) + self.mcp_session_id = mcp_session_id self._request_streams = {} @@ -439,7 +456,7 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None: return if not await self._validate_session(request, send): return - # TODO : Implement session termination logic + # TODO : Implement session termination logic async def _handle_unsupported_request(self, request: Request, send: Send) -> None: """ diff --git a/tests/server/test_streamableHttp.py b/tests/server/test_streamableHttp.py index 6296f22c..bf0128d1 100644 --- a/tests/server/test_streamableHttp.py +++ b/tests/server/test_streamableHttp.py @@ -13,7 +13,6 @@ import pytest import requests import uvicorn -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response @@ -21,6 +20,7 @@ from mcp.server.streamableHttp import ( MCP_SESSION_ID_HEADER, + SESSION_ID_PATTERN, StreamableHTTPServerTransport, ) from mcp.types import JSONRPCMessage @@ -376,3 +376,61 @@ def test_delete_without_session_support(basic_server, server_url): response = requests.delete(f"{server_url}/mcp") assert response.status_code == 405 assert "Method Not Allowed" in response.text + + +def test_session_id_pattern(): + """Test that SESSION_ID_PATTERN correctly validates session IDs.""" + # Valid session IDs (visible ASCII characters from 0x21 to 0x7E) + valid_session_ids = [ + "test-session-id", + "1234567890", + "session!@#$%^&*()_+-=[]{}|;:,.<>?/", + "~`", + ] + + for session_id in valid_session_ids: + assert SESSION_ID_PATTERN.match(session_id) is not None + # Ensure fullmatch matches too (whole string) + assert SESSION_ID_PATTERN.fullmatch(session_id) is not None + + # Invalid session IDs + invalid_session_ids = [ + "", # Empty string + " test", # Space (0x20) + "test\t", # Tab + "test\n", # Newline + "test\r", # Carriage return + "test" + chr(0x7F), # DEL character + "test" + chr(0x80), # Extended ASCII + "test" + chr(0x00), # Null character + "test" + chr(0x20), # Space (0x20) + ] + + for session_id in invalid_session_ids: + # For invalid IDs, either match will fail or fullmatch will fail + if SESSION_ID_PATTERN.match(session_id) is not None: + # If match succeeds, fullmatch should fail (partial match case) + assert SESSION_ID_PATTERN.fullmatch(session_id) is None + + +def test_streamable_http_transport_init_validation(): + """Test that StreamableHTTPServerTransport validates session ID on initialization.""" + # Valid session ID should initialize without errors + valid_transport = StreamableHTTPServerTransport(mcp_session_id="valid-id") + assert valid_transport.mcp_session_id == "valid-id" + + # None should be accepted + none_transport = StreamableHTTPServerTransport(mcp_session_id=None) + assert none_transport.mcp_session_id is None + + # Invalid session ID should raise ValueError + with pytest.raises(ValueError) as excinfo: + StreamableHTTPServerTransport(mcp_session_id="invalid id with space") + assert "Session ID must only contain visible ASCII characters" in str(excinfo.value) + + # Test with control characters + with pytest.raises(ValueError): + StreamableHTTPServerTransport(mcp_session_id="test\nid") + + with pytest.raises(ValueError): + StreamableHTTPServerTransport(mcp_session_id="test\n") From 3c4cf109c2534306105ed7d656bcfe5eacd0d2c0 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Sun, 20 Apr 2025 17:10:02 +0100 Subject: [PATCH 04/19] terminations of a session --- src/mcp/server/streamableHttp.py | 62 ++++- tests/server/test_streamableHttp.py | 351 ++++++++++++++++++---------- 2 files changed, 281 insertions(+), 132 deletions(-) diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index 8b1498d7..0dd73e50 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -77,8 +77,8 @@ def __init__( ValueError: If the session ID contains invalid characters. """ if mcp_session_id is not None and ( - not SESSION_ID_PATTERN.match(mcp_session_id) or - SESSION_ID_PATTERN.fullmatch(mcp_session_id) is None + not SESSION_ID_PATTERN.match(mcp_session_id) + or SESSION_ID_PATTERN.fullmatch(mcp_session_id) is None ): raise ValueError( "Session ID must only contain visible ASCII characters (0x21-0x7E)" @@ -86,6 +86,7 @@ def __init__( self.mcp_session_id = mcp_session_id self._request_streams = {} + self._terminated = False def _create_error_response( self, @@ -126,6 +127,14 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No send: ASGI send function """ request = Request(scope, receive) + if self._terminated: + # If the session has been terminated, return 404 Not Found + response = self._create_error_response( + "Not Found: Session has been terminated", + HTTPStatus.NOT_FOUND, + ) + await response(scope, receive, send) + return if request.method == "POST": await self._handle_post_request(scope, request, receive, send) @@ -192,7 +201,6 @@ async def _handle_post_request( raise ValueError( "No read stream writer available. Ensure connect() is called first." ) - return try: # Check Accept headers has_json, has_sse = self._check_accept_headers(request) @@ -417,7 +425,6 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # Validate session ID if server has one if not await self._validate_session(request, send): return - # Validate Accept header - must include text/event-stream _, has_sse = self._check_accept_headers(request) @@ -454,9 +461,46 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None: ) await response(request.scope, request.receive, send) return + if not await self._validate_session(request, send): return - # TODO : Implement session termination logic + + # Terminate the session + self._terminate_session() + + # Return success response + response = self._create_error_response( + "Session terminated", + HTTPStatus.OK, + ) + await response(request.scope, request.receive, send) + + def _terminate_session(self) -> None: + """ + Terminate the current session, closing all streams and marking as terminated. + + Once terminated, all requests with this session ID will receive 404 Not Found. + """ + + self._terminated = True + logger.info(f"Terminating session: {self.mcp_session_id}") + + # We need a copy of the keys to avoid modification during iteration + request_stream_keys = list(self._request_streams.keys()) + + # Close all request streams (synchronously) + for key in request_stream_keys: + try: + # Get the stream + stream = self._request_streams.get(key) + if stream: + # We must use close() here, not aclose() since this is a sync method + stream.close() + except Exception as e: + logger.debug(f"Error closing stream {key} during termination: {e}") + + # Clear the request streams dictionary immediately + self._request_streams.clear() async def _handle_unsupported_request(self, request: Request, send: Send) -> None: """ @@ -599,10 +643,16 @@ async def message_router(): # Yield the streams for the caller to use yield read_stream, write_stream finally: - # Clean up any remaining request streams for stream in list(self._request_streams.values()): try: await stream.aclose() except Exception: pass self._request_streams.clear() + # Clean up read/write streams + if self._read_stream_writer: + try: + await self._read_stream_writer.aclose() + except Exception: + pass + self._read_stream_writer = None diff --git a/tests/server/test_streamableHttp.py b/tests/server/test_streamableHttp.py index bf0128d1..eb7a5390 100644 --- a/tests/server/test_streamableHttp.py +++ b/tests/server/test_streamableHttp.py @@ -4,81 +4,87 @@ This file contains tests for request validation in the StreamableHTTP transport. """ +import multiprocessing import socket import time -from collections.abc import Generator -from multiprocessing import Process - +from collections.abc import AsyncGenerator, Generator +from http import HTTPStatus +from uuid import uuid4 +import contextlib import anyio import pytest import requests import uvicorn +from pydantic import AnyUrl from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response -from starlette.routing import Route +from starlette.routing import Mount, Route +from mcp.server import Server from mcp.server.streamableHttp import ( MCP_SESSION_ID_HEADER, SESSION_ID_PATTERN, StreamableHTTPServerTransport, ) -from mcp.types import JSONRPCMessage +from mcp.shared.exceptions import McpError +from mcp.types import ( + EmptyResult, + ErrorData, + JSONRPCMessage, + TextContent, + TextResourceContents, + Tool, +) # Test constants SERVER_NAME = "test_streamable_http_server" TEST_SESSION_ID = "test-session-id-12345" +INIT_REQUEST = { + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "clientInfo": {"name": "test-client", "version": "1.0"}, + "protocolVersion": "2025-03-26", + "capabilities": {}, + }, + "id": "init-1", +} + + +# Test server implementation that follows MCP protocol +class ServerTest(Server): + def __init__(self): + super().__init__(SERVER_NAME) + + @self.read_resource() + async def handle_read_resource(uri: AnyUrl) -> str | bytes: + if uri.scheme == "foobar": + return f"Read {uri.host}" + elif uri.scheme == "slow": + # Simulate a slow resource + await anyio.sleep(2.0) + return f"Slow response from {uri.host}" + + raise McpError( + error=ErrorData( + code=404, message="OOPS! no resource with that URI was found" + ) + ) + @self.list_tools() + async def handle_list_tools() -> list[Tool]: + return [ + Tool( + name="test_tool", + description="A test tool", + inputSchema={"type": "object", "properties": {}}, + ) + ] -# App handler class for testing validation (not a pytest test class) -class StreamableAppHandler: - def __init__(self, session_id=None): - self.transport = StreamableHTTPServerTransport(mcp_session_id=session_id) - self.started = False - self.read_stream = None - self.write_stream = None - - async def startup(self): - """Initialize the transport streams.""" - # Create real memory streams to satisfy type checking - read_stream_writer, read_stream = anyio.create_memory_object_stream[ - JSONRPCMessage | Exception - ](0) - write_stream, write_stream_reader = anyio.create_memory_object_stream[ - JSONRPCMessage - ](0) - - # Assign the streams to the transport - self.transport._read_stream_writer = read_stream_writer - self.transport._write_stream_reader = write_stream_reader - - # Store the streams so they don't get garbage collected - self.read_stream = read_stream - self.write_stream = write_stream - - self.started = True - print("Transport streams initialized") - - async def handle_request(self, request: Request): - """Handle incoming requests by validating and responding.""" - # Make sure transport is initialized - if not self.started: - await self.startup() - - # Let the transport handle the request validation and response - try: - await self.transport.handle_request( - request.scope, request.receive, request._send - ) - except Exception as e: - print(f"Error handling request: {e}") - # Make sure we provide an error response - response = Response( - status_code=500, - content=f"Server error: {str(e)}", - media_type="text/plain", - ) - await response(request.scope, request.receive, request._send) + @self.call_tool() + async def handle_call_tool(name: str, args: dict) -> list[TextContent]: + return [TextContent(type="text", text=f"Called {name}")] @pytest.fixture @@ -96,28 +102,93 @@ def server_url(server_port: int) -> str: def create_app(session_id=None) -> Starlette: - """Create a Starlette application for testing.""" - # Create our test app handler - app_handler = StreamableAppHandler(session_id=session_id) - - # Define a startup event to ensure the transport is initialized - async def on_startup(): - """Initialize the transport on application startup.""" - print("Initializing transport streams...") - await app_handler.startup() - app_handler.started = True - print("Transport initialized") + """Create a Starlette application for testing that matches the example server.""" + # Create server instance + server = ServerTest() + + # Store the server instances between requests for session management + server_instances = {} + # Lock to prevent race conditions when creating new sessions + session_creation_lock = anyio.Lock() + # Task group for running server instances + task_group = None + + @contextlib.asynccontextmanager + async def lifespan(app): + """Application lifespan context manager for managing task group.""" + nonlocal task_group + + async with anyio.create_task_group() as tg: + task_group = tg + print("Application started, task group initialized!") + try: + yield + finally: + print("Application shutting down, cleaning up resources...") + if task_group: + tg.cancel_scope.cancel() + task_group = None + print("Resources cleaned up successfully.") + + # ASGI handler for streamable HTTP connections + async def handle_streamable_http(scope, receive, send): + request = Request(scope, receive) + request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) + + # Use existing transport if session ID matches + if ( + request_mcp_session_id is not None + and request_mcp_session_id in server_instances + ): + transport = server_instances[request_mcp_session_id] + print("Session already exists, handling request directly") + await transport.handle_request(scope, receive, send) + elif session_id is None or request_mcp_session_id is None: + async with session_creation_lock: + # For tests with fixed session ID + new_session_id = session_id if session_id else uuid4().hex + + http_transport = StreamableHTTPServerTransport( + mcp_session_id=new_session_id, + ) + + async with http_transport.connect() as streams: + read_stream, write_stream = streams + + async def run_server(): + await server.run( + read_stream, + write_stream, + server.create_initialization_options(), + ) + + if task_group is None: + response = Response( + "Internal Server Error: Task group is not initialized", + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + ) + await response(scope, receive, send) + return + + # Store the instance before starting the task to prevent races + server_instances[http_transport.mcp_session_id] = http_transport + task_group.start_soon(run_server) + + await http_transport.handle_request(scope, receive, send) + else: + response = Response( + "Bad Request: No valid session ID provided", + status_code=HTTPStatus.BAD_REQUEST, + ) + await response(scope, receive, send) + # Create an ASGI application app = Starlette( - debug=True, # Enable debug mode for better error messages + debug=True, routes=[ - Route( - "/mcp", - endpoint=app_handler.handle_request, - methods=["GET", "POST", "DELETE"], - ), + Mount("/mcp", app=handle_streamable_http), ], - on_startup=[on_startup], + lifespan=lifespan, ) return app @@ -127,17 +198,15 @@ def run_server(port: int, session_id=None) -> None: """Run the test server.""" print(f"Starting test server on port {port} with session_id={session_id}") - # Create app with simpler configuration app = create_app(session_id) - - # Configure to use a single worker and simpler settings + # Configure server config = uvicorn.Config( app=app, host="127.0.0.1", port=port, - log_level="info", # Use info to see startup messages + log_level="info", limit_concurrency=10, - timeout_keep_alive=2, + timeout_keep_alive=5, access_log=False, ) @@ -161,7 +230,9 @@ def run_server(port: int, session_id=None) -> None: def basic_server(server_port: int) -> Generator[None, None, None]: """Start a basic server without session ID.""" # Start server process - process = Process(target=run_server, kwargs={"port": server_port}, daemon=True) + process = multiprocessing.Process( + target=run_server, kwargs={"port": server_port}, daemon=True + ) process.start() # Wait for server to start @@ -191,7 +262,7 @@ def basic_server(server_port: int) -> Generator[None, None, None]: def session_server(server_port: int) -> Generator[str, None, None]: """Start a server with session ID.""" # Start server process - process = Process( + process = multiprocessing.Process( target=run_server, kwargs={"port": server_port, "session_id": TEST_SESSION_ID}, daemon=True, @@ -309,17 +380,20 @@ def test_method_not_allowed(basic_server, server_url): def test_get_request_validation(basic_server, server_url): """Test GET request validation for SSE streams.""" - # Test GET without Accept header - response = requests.get(f"{server_url}/mcp") - assert response.status_code == 406 - assert "Not Acceptable" in response.text - # Test GET with wrong Accept header - response = requests.get( + response = requests.post( f"{server_url}/mcp", - headers={"Accept": "application/json"}, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, ) + # Test GET without Accept header + assert response.status_code == 200 + response = requests.get(f"{server_url}/mcp") assert response.status_code == 406 + assert "Not Acceptable" in response.text def test_session_validation(session_server, server_url): @@ -338,45 +412,6 @@ def test_session_validation(session_server, server_url): assert response.status_code == 400 assert "Missing session ID" in response.text - # Test with invalid session ID - response = requests.post( - f"{server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: "invalid-session-id", - }, - json={"jsonrpc": "2.0", "method": "list_tools", "id": 1}, - ) - assert response.status_code == 404 - assert "Invalid or expired session ID" in response.text - - -def test_delete_request(session_server, server_url): - """Test DELETE request for session termination.""" - # session_id not used directly in this test - - # Test without session ID - response = requests.delete(f"{server_url}/mcp") - assert response.status_code == 400 - assert "Missing session ID" in response.text - - # Test with invalid session ID - response = requests.delete( - f"{server_url}/mcp", - headers={MCP_SESSION_ID_HEADER: "invalid-session-id"}, - ) - assert response.status_code == 404 - assert "Invalid or expired session ID" in response.text - - -def test_delete_without_session_support(basic_server, server_url): - """Test DELETE request when server doesn't support sessions.""" - # Server without session support should reject DELETE - response = requests.delete(f"{server_url}/mcp") - assert response.status_code == 405 - assert "Method Not Allowed" in response.text - def test_session_id_pattern(): """Test that SESSION_ID_PATTERN correctly validates session IDs.""" @@ -414,7 +449,7 @@ def test_session_id_pattern(): def test_streamable_http_transport_init_validation(): - """Test that StreamableHTTPServerTransport validates session ID on initialization.""" + """Test that StreamableHTTPServerTransport validates session ID on init.""" # Valid session ID should initialize without errors valid_transport = StreamableHTTPServerTransport(mcp_session_id="valid-id") assert valid_transport.mcp_session_id == "valid-id" @@ -434,3 +469,67 @@ def test_streamable_http_transport_init_validation(): with pytest.raises(ValueError): StreamableHTTPServerTransport(mcp_session_id="test\n") + + +def test_delete_request(session_server, server_url): + """Test DELETE request for session termination.""" + session_id = session_server + + # First, send an initialize request to properly initialize the server + response = requests.post( + f"{server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + + # Test without session ID + response = requests.delete(f"{server_url}/mcp") + assert response.status_code == 400 + assert "Missing session ID" in response.text + + # Test valid session termination + response = requests.delete( + f"{server_url}/mcp", + headers={MCP_SESSION_ID_HEADER: session_id}, + ) + # assert response.status_code == 200 + assert "Session terminated" in response.text + + +def test_session_termination(session_server, server_url): + """Test session termination via DELETE and subsequent request handling.""" + response = requests.post( + f"{server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + + # Now terminate the session + session_id = session_server + response = requests.delete( + f"{server_url}/mcp", + headers={MCP_SESSION_ID_HEADER: session_id}, + ) + assert response.status_code == 200 + assert "Session terminated" in response.text + + # Try to use the terminated session + response = requests.post( + f"{server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + }, + json={"jsonrpc": "2.0", "method": "ping", "id": 2}, + ) + assert response.status_code == 404 + assert "Session has been terminated" in response.text From bce74b3e148038a38324d730d5506aac055b6e05 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Sun, 20 Apr 2025 19:25:21 +0100 Subject: [PATCH 05/19] fix cleaning up --- src/mcp/server/streamableHttp.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index 0dd73e50..a8ee6f9b 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -649,10 +649,3 @@ async def message_router(): except Exception: pass self._request_streams.clear() - # Clean up read/write streams - if self._read_stream_writer: - try: - await self._read_stream_writer.aclose() - except Exception: - pass - self._read_stream_writer = None From 201157912d1d588283ed91fd05567b8b003ef891 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Sun, 20 Apr 2025 20:51:46 +0100 Subject: [PATCH 06/19] add happy path test --- tests/server/test_streamableHttp.py | 64 +++++++++++++++++++++-------- 1 file changed, 46 insertions(+), 18 deletions(-) diff --git a/tests/server/test_streamableHttp.py b/tests/server/test_streamableHttp.py index eb7a5390..51b88c0c 100644 --- a/tests/server/test_streamableHttp.py +++ b/tests/server/test_streamableHttp.py @@ -4,13 +4,14 @@ This file contains tests for request validation in the StreamableHTTP transport. """ +import contextlib import multiprocessing import socket import time -from collections.abc import AsyncGenerator, Generator +from collections.abc import Generator from http import HTTPStatus from uuid import uuid4 -import contextlib + import anyio import pytest import requests @@ -19,7 +20,7 @@ from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response -from starlette.routing import Mount, Route +from starlette.routing import Mount from mcp.server import Server from mcp.server.streamableHttp import ( @@ -29,11 +30,8 @@ ) from mcp.shared.exceptions import McpError from mcp.types import ( - EmptyResult, ErrorData, - JSONRPCMessage, TextContent, - TextResourceContents, Tool, ) @@ -106,11 +104,9 @@ def create_app(session_id=None) -> Starlette: # Create server instance server = ServerTest() - # Store the server instances between requests for session management server_instances = {} # Lock to prevent race conditions when creating new sessions session_creation_lock = anyio.Lock() - # Task group for running server instances task_group = None @contextlib.asynccontextmanager @@ -130,7 +126,6 @@ async def lifespan(app): task_group = None print("Resources cleaned up successfully.") - # ASGI handler for streamable HTTP connections async def handle_streamable_http(scope, receive, send): request = Request(scope, receive) request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) @@ -141,12 +136,11 @@ async def handle_streamable_http(scope, receive, send): and request_mcp_session_id in server_instances ): transport = server_instances[request_mcp_session_id] - print("Session already exists, handling request directly") + await transport.handle_request(scope, receive, send) - elif session_id is None or request_mcp_session_id is None: + elif request_mcp_session_id is None: async with session_creation_lock: - # For tests with fixed session ID - new_session_id = session_id if session_id else uuid4().hex + new_session_id = uuid4().hex http_transport = StreamableHTTPServerTransport( mcp_session_id=new_session_id, @@ -156,11 +150,14 @@ async def handle_streamable_http(scope, receive, send): read_stream, write_stream = streams async def run_server(): - await server.run( - read_stream, - write_stream, - server.create_initialization_options(), - ) + try: + await server.run( + read_stream, + write_stream, + server.create_initialization_options(), + ) + except Exception as e: + print(f"Server exception: {e}") if task_group is None: response = Response( @@ -533,3 +530,34 @@ def test_session_termination(session_server, server_url): ) assert response.status_code == 404 assert "Session has been terminated" in response.text + + +def test_response(basic_server, server_url): + """Test response handling for a valid request.""" + mcp_url = f"{server_url}/mcp" + response = requests.post( + mcp_url, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + + # Now terminate the session + session_id = response.headers.get(MCP_SESSION_ID_HEADER) + + # Try to use the terminated session + tools_response = requests.post( + mcp_url, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, # Use the session ID we got earlier + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-1"}, + stream=True, # Important for SSE + ) + assert tools_response.status_code == 200 + assert tools_response.headers.get("Content-Type") == "text/event-stream" From 2cebf087d6b5e965198b8a3bf57248cf8aa7aa31 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Sun, 20 Apr 2025 21:58:21 +0100 Subject: [PATCH 07/19] tests --- tests/server/test_streamableHttp.py | 70 +++-------------------------- 1 file changed, 5 insertions(+), 65 deletions(-) diff --git a/tests/server/test_streamableHttp.py b/tests/server/test_streamableHttp.py index 51b88c0c..b375fdc9 100644 --- a/tests/server/test_streamableHttp.py +++ b/tests/server/test_streamableHttp.py @@ -301,22 +301,6 @@ def test_accept_header_validation(basic_server, server_url): assert response.status_code == 406 assert "Not Acceptable" in response.text - # Test with only application/json - response = requests.post( - f"{server_url}/mcp", - headers={"Accept": "application/json", "Content-Type": "application/json"}, - json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, - ) - assert response.status_code == 406 - - # Test with only text/event-stream - response = requests.post( - f"{server_url}/mcp", - headers={"Accept": "text/event-stream", "Content-Type": "application/json"}, - json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, - ) - assert response.status_code == 406 - def test_content_type_validation(basic_server, server_url): """Test that Content-Type header is properly validated.""" @@ -347,6 +331,9 @@ def test_json_validation(basic_server, server_url): assert response.status_code == 400 assert "Parse error" in response.text + +def test_json_parsing(basic_server, server_url): + """Test that JSON content is properly parse.""" # Test with valid JSON but invalid JSON-RPC response = requests.post( f"{server_url}/mcp", @@ -375,24 +362,6 @@ def test_method_not_allowed(basic_server, server_url): assert "Method Not Allowed" in response.text -def test_get_request_validation(basic_server, server_url): - """Test GET request validation for SSE streams.""" - - response = requests.post( - f"{server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - # Test GET without Accept header - assert response.status_code == 200 - response = requests.get(f"{server_url}/mcp") - assert response.status_code == 406 - assert "Not Acceptable" in response.text - - def test_session_validation(session_server, server_url): """Test session ID validation.""" # session_id not used directly in this test @@ -468,36 +437,7 @@ def test_streamable_http_transport_init_validation(): StreamableHTTPServerTransport(mcp_session_id="test\n") -def test_delete_request(session_server, server_url): - """Test DELETE request for session termination.""" - session_id = session_server - - # First, send an initialize request to properly initialize the server - response = requests.post( - f"{server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 - - # Test without session ID - response = requests.delete(f"{server_url}/mcp") - assert response.status_code == 400 - assert "Missing session ID" in response.text - - # Test valid session termination - response = requests.delete( - f"{server_url}/mcp", - headers={MCP_SESSION_ID_HEADER: session_id}, - ) - # assert response.status_code == 200 - assert "Session terminated" in response.text - - -def test_session_termination(session_server, server_url): +def test_session_termination(basic_server, server_url): """Test session termination via DELETE and subsequent request handling.""" response = requests.post( f"{server_url}/mcp", @@ -510,7 +450,7 @@ def test_session_termination(session_server, server_url): assert response.status_code == 200 # Now terminate the session - session_id = session_server + session_id = response.headers.get(MCP_SESSION_ID_HEADER) response = requests.delete( f"{server_url}/mcp", headers={MCP_SESSION_ID_HEADER: session_id}, From 6c9c320a38654c9145fe326d9e308e2893e8d9e3 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Sun, 20 Apr 2025 22:19:26 +0100 Subject: [PATCH 08/19] json mode --- .../servers/simple-streamablehttp/README.md | 3 + .../mcp_simple_streamablehttp/server.py | 9 +- src/mcp/server/streamableHttp.py | 295 ++++++++++++------ tests/server/test_streamableHttp.py | 56 +++- 4 files changed, 258 insertions(+), 105 deletions(-) diff --git a/examples/servers/simple-streamablehttp/README.md b/examples/servers/simple-streamablehttp/README.md index aa5e707a..5125c3eb 100644 --- a/examples/servers/simple-streamablehttp/README.md +++ b/examples/servers/simple-streamablehttp/README.md @@ -20,6 +20,9 @@ uv run mcp-simple-streamablehttp --port 3000 # Custom logging level uv run mcp-simple-streamablehttp --log-level DEBUG + +# Enable JSON responses instead of SSE streams +uv run mcp-simple-streamablehttp --json-response ``` The server exposes a tool named "start-notification-stream" that accepts three arguments: diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index 3dc972b7..c39a3720 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -48,9 +48,16 @@ async def lifespan(app): default="INFO", help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)", ) +@click.option( + "--json-response", + is_flag=True, + default=False, + help="Enable JSON responses instead of SSE streams", +) def main( port: int, log_level: str, + json_response: bool, ) -> int: # Configure logging logging.basicConfig( @@ -145,7 +152,7 @@ async def handle_streamable_http(scope, receive, send): async with session_creation_lock: new_session_id = uuid4().hex http_transport = StreamableHTTPServerTransport( - mcp_session_id=new_session_id, + mcp_session_id=new_session_id, is_json_response_enabled=json_response ) async with http_transport.connect() as streams: read_stream, write_stream = streams diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index a8ee6f9b..b6ef396c 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -54,6 +54,7 @@ class StreamableHTTPServerTransport: Handles POST requests containing JSON-RPC messages and provides Server-Sent Events (SSE) responses for streaming communication. + When configured, can also return JSON responses instead of SSE streams. """ # Server notification streams for POST requests as well as standalone SSE stream @@ -65,6 +66,7 @@ class StreamableHTTPServerTransport: def __init__( self, mcp_session_id: str | None, + is_json_response_enabled: bool = False, ): """ Initialize a new StreamableHTTP server transport. @@ -72,6 +74,8 @@ def __init__( Args: mcp_session_id: Optional session identifier for this connection. Must contain only visible ASCII characters (0x21-0x7E). + is_json_response_enabled: If True, return JSON responses for requests + instead of SSE streams. Default is False. Raises: ValueError: If the session ID contains invalid characters. @@ -85,6 +89,7 @@ def __init__( ) self.mcp_session_id = mcp_session_id + self.is_json_response_enabled = is_json_response_enabled self._request_streams = {} self._terminated = False @@ -110,6 +115,36 @@ def _create_error_response( headers=response_headers, ) + def _create_json_response( + self, + response_message: JSONRPCMessage, + status_code: HTTPStatus = HTTPStatus.OK, + headers: dict[str, str] | None = None, + ) -> Response: + """ + Create a JSON response from a JSONRPCMessage. + + Args: + response_message: The JSON-RPC message to include in the response + status_code: HTTP status code (default: 200 OK) + headers: Additional headers to include + + Returns: + A Starlette Response object with the JSON-RPC message + """ + response_headers = {"Content-Type": CONTENT_TYPE_JSON} + if headers: + response_headers.update(headers) + + if self.mcp_session_id: + response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + return Response( + response_message.model_dump_json(by_alias=True, exclude_none=True), + status_code=status_code, + headers=response_headers, + ) + def _get_session_id(self, request: Request) -> str | None: """ Extract the session ID from request headers. @@ -303,105 +338,183 @@ async def _handle_post_request( return - # For requests, set up an SSE stream for the response + # For requests, determine whether to return JSON or set up SSE stream if is_request: - # Set up headers - headers = { - "Cache-Control": "no-cache, no-transform", - "Connection": "keep-alive", - "Content-Type": CONTENT_TYPE_SSE, - } + if self.is_json_response_enabled: + # JSON response mode - create a response future + request_id = None + if isinstance(message.root, JSONRPCRequest): + request_id = str(message.root.id) + + if not request_id: + # Should not happen for valid JSONRPCRequest, but handle just in case + response = self._create_error_response( + "Invalid Request: Missing request ID", + HTTPStatus.BAD_REQUEST, + ) + await response(scope, receive, send) + return - if self.mcp_session_id: - headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id - # Create SSE stream - sse_stream_writer, sse_stream_reader = ( - anyio.create_memory_object_stream[dict[str, Any]](0) - ) + # Create promise stream for getting response + request_stream_writer, request_stream_reader = ( + anyio.create_memory_object_stream[JSONRPCMessage](0) + ) - async def sse_writer(): - try: - # Create a request-specific message stream for this POST request - request_stream_writer, request_stream_reader = ( - anyio.create_memory_object_stream[JSONRPCMessage](0) - ) + # Register this stream for the request ID + self._request_streams[request_id] = request_stream_writer - # Get the request ID from the incoming request message - request_id = None - if isinstance(message.root, JSONRPCRequest): - request_id = str(message.root.id) - # Register this stream for the request ID - if request_id: - self._request_streams[request_id] = ( - request_stream_writer - ) + # Process the message + await writer.send(message) - async with sse_stream_writer, request_stream_reader: - # Process messages from the request-specific stream - async for received_message in request_stream_reader: - # Send the message via SSE - related_request_id = None - - if isinstance( - received_message.root, JSONRPCNotification - ): - # Get related_request_id from params - params = received_message.root.params - if params and "related_request_id" in params: - related_request_id = params.get( - "related_request_id" - ) - logger.debug( - f"NOTIFICATION: {related_request_id}, " - f"{params.get('data')}" - ) - - # Build the event data - event_data = { - "event": "message", - "data": received_message.model_dump_json( - by_alias=True, exclude_none=True - ), - } - - await sse_stream_writer.send(event_data) - - # If response, remove from pending streams and close - if isinstance(received_message.root, JSONRPCResponse): - if request_id: - self._request_streams.pop(request_id, None) - break + try: + # Process messages from the request-specific stream + # We need to collect all messages until we get a response + response_message = None + + # Use similar approach to SSE writer for consistency + async for received_message in request_stream_reader: + # If it's a response, this is what we're waiting for + if isinstance(received_message.root, JSONRPCResponse): + response_message = received_message + break + # For notifications, we need to keep waiting for the actual response + elif isinstance(received_message.root, JSONRPCNotification): + # Just process it and continue waiting + logger.debug( + f"Received notification while waiting for response: {received_message.root.method}" + ) + continue + + # At this point we should have a response + if response_message: + # Create JSON response + response = self._create_json_response(response_message) + await response(scope, receive, send) + else: + # This shouldn't happen in normal operation + logger.error("No response message received before stream closed") + response = self._create_error_response( + "Error processing request: No response received", + HTTPStatus.INTERNAL_SERVER_ERROR, + ) + await response(scope, receive, send) except Exception as e: - logger.exception(f"Error in SSE writer: {e}") + logger.exception(f"Error processing JSON response: {e}") + response = self._create_error_response( + f"Error processing request: {str(e)}", + HTTPStatus.INTERNAL_SERVER_ERROR, + ) + await response(scope, receive, send) finally: - logger.debug("Closing SSE writer") - # TODO - - # Create and start EventSourceResponse - response = EventSourceResponse( - content=sse_stream_reader, - data_sender_callable=sse_writer, - headers=headers, - ) - - # Extract the request ID outside the try block for proper scope - outer_request_id = None - if isinstance(message.root, JSONRPCRequest): - outer_request_id = str(message.root.id) + # Clean up the request stream + if request_id in self._request_streams: + self._request_streams.pop(request_id, None) + await request_stream_reader.aclose() + await request_stream_writer.aclose() + else: + # SSE stream mode (original behavior) + # Set up headers + headers = { + "Cache-Control": "no-cache, no-transform", + "Connection": "keep-alive", + "Content-Type": CONTENT_TYPE_SSE, + } + + if self.mcp_session_id: + headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + # Create SSE stream + sse_stream_writer, sse_stream_reader = ( + anyio.create_memory_object_stream[dict[str, Any]](0) + ) + + async def sse_writer(): + try: + # Create a request-specific message stream for this POST request + request_stream_writer, request_stream_reader = ( + anyio.create_memory_object_stream[JSONRPCMessage](0) + ) + + # Get the request ID from the incoming request message + request_id = None + if isinstance(message.root, JSONRPCRequest): + request_id = str(message.root.id) + # Register this stream for the request ID + if request_id: + self._request_streams[request_id] = ( + request_stream_writer + ) + + async with sse_stream_writer, request_stream_reader: + # Process messages from the request-specific stream + async for received_message in request_stream_reader: + # Send the message via SSE + related_request_id = None + + if isinstance( + received_message.root, JSONRPCNotification + ): + # Get related_request_id from params + params = received_message.root.params + if params and "related_request_id" in params: + related_request_id = params.get( + "related_request_id" + ) + logger.debug( + f"NOTIFICATION: {related_request_id}, " + f"{params.get('data')}" + ) + + # Build the event data + event_data = { + "event": "message", + "data": received_message.model_dump_json( + by_alias=True, exclude_none=True + ), + } + + await sse_stream_writer.send(event_data) + + # If response, remove from pending streams and close + if isinstance( + received_message.root, JSONRPCResponse + ): + if request_id: + self._request_streams.pop(request_id, None) + break + except Exception as e: + logger.exception(f"Error in SSE writer: {e}") + finally: + logger.debug("Closing SSE writer") + # TODO + + # Create and start EventSourceResponse + response = EventSourceResponse( + content=sse_stream_reader, + data_sender_callable=sse_writer, + headers=headers, + ) + + # Extract the request ID outside the try block for proper scope + outer_request_id = None + if isinstance(message.root, JSONRPCRequest): + outer_request_id = str(message.root.id) + + # Start the SSE response (this will send headers immediately) + try: + # First send the response to establish the SSE connection + async with anyio.create_task_group() as tg: + tg.start_soon(response, scope, receive, send) - # Start the SSE response (this will send headers immediately) - try: - # First send the response to establish the SSE connection - async with anyio.create_task_group() as tg: - tg.start_soon(response, scope, receive, send) - - # Then send the message to be processed by the server - await writer.send(message) - except Exception: - logger.exception("SSE response error") - # Make sure to clean up the request stream if something goes wrong - if outer_request_id and outer_request_id in self._request_streams: - self._request_streams.pop(outer_request_id, None) + # Then send the message to be processed by the server + await writer.send(message) + except Exception: + logger.exception("SSE response error") + # Make sure to clean up the request stream if something goes wrong + if ( + outer_request_id + and outer_request_id in self._request_streams + ): + self._request_streams.pop(outer_request_id, None) except Exception as err: logger.exception("Error handling POST request") diff --git a/tests/server/test_streamableHttp.py b/tests/server/test_streamableHttp.py index b375fdc9..42c416c5 100644 --- a/tests/server/test_streamableHttp.py +++ b/tests/server/test_streamableHttp.py @@ -99,8 +99,13 @@ def server_url(server_port: int) -> str: return f"http://127.0.0.1:{server_port}" -def create_app(session_id=None) -> Starlette: - """Create a Starlette application for testing that matches the example server.""" +def create_app(session_id=None, is_json_response_enabled=False) -> Starlette: + """Create a Starlette application for testing that matches the example server. + + Args: + session_id: Optional session ID to use for the server. + is_json_response_enabled: If True, use JSON responses instead of SSE streams. + """ # Create server instance server = ServerTest() @@ -144,6 +149,7 @@ async def handle_streamable_http(scope, receive, send): http_transport = StreamableHTTPServerTransport( mcp_session_id=new_session_id, + is_json_response_enabled=is_json_response_enabled, ) async with http_transport.connect() as streams: @@ -191,11 +197,20 @@ async def run_server(): return app -def run_server(port: int, session_id=None) -> None: - """Run the test server.""" - print(f"Starting test server on port {port} with session_id={session_id}") +def run_server(port: int, session_id=None, is_json_response_enabled=False) -> None: + """Run the test server. + + Args: + port: Port to listen on. + session_id: Optional session ID to use for the server. + is_json_response_enabled: If True, use JSON responses instead of SSE streams. + """ + print( + f"Starting test server on port {port} with " + f"session_id={session_id}, json_enabled={is_json_response_enabled}" + ) - app = create_app(session_id) + app = create_app(session_id, is_json_response_enabled) # Configure server config = uvicorn.Config( app=app, @@ -256,12 +271,12 @@ def basic_server(server_port: int) -> Generator[None, None, None]: @pytest.fixture -def session_server(server_port: int) -> Generator[str, None, None]: - """Start a server with session ID.""" - # Start server process +def json_response_server(server_port: int) -> Generator[None, None, None]: + """Start a server with JSON response enabled.""" + # Start server process with is_json_response_enabled=True process = multiprocessing.Process( target=run_server, - kwargs={"port": server_port, "session_id": TEST_SESSION_ID}, + kwargs={"port": server_port, "is_json_response_enabled": True}, daemon=True, ) process.start() @@ -280,7 +295,7 @@ def session_server(server_port: int) -> Generator[str, None, None]: else: raise RuntimeError(f"Server failed to start after {max_attempts} attempts") - yield TEST_SESSION_ID + yield # Clean up process.terminate() @@ -362,7 +377,7 @@ def test_method_not_allowed(basic_server, server_url): assert "Method Not Allowed" in response.text -def test_session_validation(session_server, server_url): +def test_session_validation(basic_server, server_url): """Test session ID validation.""" # session_id not used directly in this test @@ -497,7 +512,22 @@ def test_response(basic_server, server_url): MCP_SESSION_ID_HEADER: session_id, # Use the session ID we got earlier }, json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-1"}, - stream=True, # Important for SSE + stream=True, ) assert tools_response.status_code == 200 assert tools_response.headers.get("Content-Type") == "text/event-stream" + + +def test_json_response(json_response_server, server_url): + """Test response handling when is_json_response_enabled is True.""" + mcp_url = f"{server_url}/mcp" + response = requests.post( + mcp_url, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + assert response.headers.get("Content-Type") == "application/json" From ede8cde91c938db4a64bcccb78387bc79e713d86 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Mon, 21 Apr 2025 10:40:52 +0100 Subject: [PATCH 09/19] clean up --- .../mcp_simple_streamablehttp/__main__.py | 2 +- .../mcp_simple_streamablehttp/server.py | 3 +- src/mcp/server/streamableHttp.py | 192 +++++------------- uv.lock | 1 - 4 files changed, 58 insertions(+), 140 deletions(-) diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py index a6876bf9..f5f6e402 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py @@ -1,4 +1,4 @@ from .server import main if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index c39a3720..88249baf 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -152,7 +152,8 @@ async def handle_streamable_http(scope, receive, send): async with session_creation_lock: new_session_id = uuid4().hex http_transport = StreamableHTTPServerTransport( - mcp_session_id=new_session_id, is_json_response_enabled=json_response + mcp_session_id=new_session_id, + is_json_response_enabled=json_response, ) async with http_transport.connect() as streams: read_stream, write_stream = streams diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index b6ef396c..2bc528b0 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -52,14 +52,15 @@ class StreamableHTTPServerTransport: """ HTTP server transport with event streaming support for MCP. - Handles POST requests containing JSON-RPC messages and provides - Server-Sent Events (SSE) responses for streaming communication. - When configured, can also return JSON responses instead of SSE streams. + Handles JSON-RPC messages in HTTP POST requests with SSE streaming. + Supports optional JSON responses and session management. """ # Server notification streams for POST requests as well as standalone SSE stream - _read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] | None - _write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] + _read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] | None = ( + None + ) + _write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] | None = None # Dictionary to track request-specific message streams _request_streams: dict[str, MemoryObjectSendStream[JSONRPCMessage]] @@ -67,7 +68,7 @@ def __init__( self, mcp_session_id: str | None, is_json_response_enabled: bool = False, - ): + ) -> None: """ Initialize a new StreamableHTTP server transport. @@ -80,9 +81,8 @@ def __init__( Raises: ValueError: If the session ID contains invalid characters. """ - if mcp_session_id is not None and ( - not SESSION_ID_PATTERN.match(mcp_session_id) - or SESSION_ID_PATTERN.fullmatch(mcp_session_id) is None + if mcp_session_id is not None and not SESSION_ID_PATTERN.fullmatch( + mcp_session_id ): raise ValueError( "Session ID must only contain visible ASCII characters (0x21-0x7E)" @@ -93,15 +93,13 @@ def __init__( self._request_streams = {} self._terminated = False - def _create_error_response( + def _create_server_response( self, message: str, status_code: HTTPStatus, headers: dict[str, str] | None = None, ) -> Response: - """ - Create a standardized error response. - """ + """Create a standardized server response.""" response_headers = {"Content-Type": CONTENT_TYPE_JSON} if headers: response_headers.update(headers) @@ -121,17 +119,7 @@ def _create_json_response( status_code: HTTPStatus = HTTPStatus.OK, headers: dict[str, str] | None = None, ) -> Response: - """ - Create a JSON response from a JSONRPCMessage. - - Args: - response_message: The JSON-RPC message to include in the response - status_code: HTTP status code (default: 200 OK) - headers: Additional headers to include - - Returns: - A Starlette Response object with the JSON-RPC message - """ + """Create a JSON response from a JSONRPCMessage""" response_headers = {"Content-Type": CONTENT_TYPE_JSON} if headers: response_headers.update(headers) @@ -146,25 +134,15 @@ def _create_json_response( ) def _get_session_id(self, request: Request) -> str | None: - """ - Extract the session ID from request headers. - """ + """Extract the session ID from request headers.""" return request.headers.get(MCP_SESSION_ID_HEADER) async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: - """ - ASGI application entry point that handles all HTTP requests - - Args: - stream_id: Unique identifier for this stream - scope: ASGI scope - receive: ASGI receive function - send: ASGI send function - """ + """Application entry point that handles all HTTP requests""" request = Request(scope, receive) if self._terminated: # If the session has been terminated, return 404 Not Found - response = self._create_error_response( + response = self._create_server_response( "Not Found: Session has been terminated", HTTPStatus.NOT_FOUND, ) @@ -181,15 +159,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No await self._handle_unsupported_request(request, send) def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: - """ - Check if the request accepts the required media types. - - Args: - request: The HTTP request - - Returns: - Tuple of (has_json, has_sse) indicating whether each media type is accepted - """ + """Check if the request accepts the required media types.""" accept_header = request.headers.get("accept", "") accept_types = [media_type.strip() for media_type in accept_header.split(",")] @@ -203,15 +173,7 @@ def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: return has_json, has_sse def _check_content_type(self, request: Request) -> bool: - """ - Check if the request has the correct Content-Type. - - Args: - request: The HTTP request - - Returns: - True if Content-Type is acceptable, False otherwise - """ + """Check if the request has the correct Content-Type.""" content_type = request.headers.get("content-type", "") content_type_parts = [ part.strip() for part in content_type.split(";")[0].split(",") @@ -222,15 +184,7 @@ def _check_content_type(self, request: Request) -> bool: async def _handle_post_request( self, scope: Scope, request: Request, receive: Receive, send: Send ) -> None: - """ - Handles POST requests containing JSON-RPC messages - - Args: - scope: ASGI scope - request: Starlette Request object - receive: ASGI receive function - send: ASGI send function - """ + """Handle POST requests containing JSON-RPC messages.""" writer = self._read_stream_writer if writer is None: raise ValueError( @@ -240,7 +194,7 @@ async def _handle_post_request( # Check Accept headers has_json, has_sse = self._check_accept_headers(request) if not (has_json and has_sse): - response = self._create_error_response( + response = self._create_server_response( ( "Not Acceptable: Client must accept both application/json and " "text/event-stream" @@ -252,7 +206,7 @@ async def _handle_post_request( # Validate Content-Type if not self._check_content_type(request): - response = self._create_error_response( + response = self._create_server_response( "Unsupported Media Type: Content-Type must be application/json", HTTPStatus.UNSUPPORTED_MEDIA_TYPE, ) @@ -262,7 +216,7 @@ async def _handle_post_request( # Parse the body - only read it once body = await request.body() if len(body) > MAXIMUM_MESSAGE_SIZE: - response = self._create_error_response( + response = self._create_server_response( "Payload Too Large: Message exceeds maximum size", HTTPStatus.REQUEST_ENTITY_TOO_LARGE, ) @@ -272,32 +226,23 @@ async def _handle_post_request( try: raw_message = json.loads(body) except json.JSONDecodeError as e: - response = self._create_error_response( + response = self._create_server_response( f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, ) await response(scope, receive, send) return - message = None try: message = JSONRPCMessage.model_validate(raw_message) except ValidationError as e: - response = self._create_error_response( + response = self._create_server_response( f"Validation error: {str(e)}", HTTPStatus.BAD_REQUEST, ) await response(scope, receive, send) return - if not message: - response = self._create_error_response( - "Invalid Request: Message is empty", - HTTPStatus.BAD_REQUEST, - ) - await response(scope, receive, send) - return - # Check if this is an initialization request is_initialization_request = ( isinstance(message.root, JSONRPCRequest) @@ -312,7 +257,7 @@ async def _handle_post_request( # If request has a session ID but doesn't match, return 404 if request_session_id and request_session_id != self.mcp_session_id: - response = self._create_error_response( + response = self._create_server_response( "Not Found: Invalid or expired session ID", HTTPStatus.NOT_FOUND, ) @@ -327,7 +272,7 @@ async def _handle_post_request( # For notifications and responses only, return 202 Accepted if not is_request: # Create response object and send it - response = self._create_error_response( + response = self._create_server_response( "Accepted", HTTPStatus.ACCEPTED, ) @@ -347,8 +292,8 @@ async def _handle_post_request( request_id = str(message.root.id) if not request_id: - # Should not happen for valid JSONRPCRequest, but handle just in case - response = self._create_error_response( + # Should not happen for valid JSONRPCRequest, but handle it + response = self._create_server_response( "Invalid Request: Missing request ID", HTTPStatus.BAD_REQUEST, ) @@ -370,20 +315,19 @@ async def _handle_post_request( # Process messages from the request-specific stream # We need to collect all messages until we get a response response_message = None - + # Use similar approach to SSE writer for consistency async for received_message in request_stream_reader: # If it's a response, this is what we're waiting for if isinstance(received_message.root, JSONRPCResponse): response_message = received_message break - # For notifications, we need to keep waiting for the actual response + # For notifications, keep waiting for the actual response elif isinstance(received_message.root, JSONRPCNotification): # Just process it and continue waiting logger.debug( - f"Received notification while waiting for response: {received_message.root.method}" + f"Notification: {received_message.root.method}" ) - continue # At this point we should have a response if response_message: @@ -392,15 +336,17 @@ async def _handle_post_request( await response(scope, receive, send) else: # This shouldn't happen in normal operation - logger.error("No response message received before stream closed") - response = self._create_error_response( + logger.error( + "No response message received before stream closed" + ) + response = self._create_server_response( "Error processing request: No response received", HTTPStatus.INTERNAL_SERVER_ERROR, ) await response(scope, receive, send) except Exception as e: logger.exception(f"Error processing JSON response: {e}") - response = self._create_error_response( + response = self._create_server_response( f"Error processing request: {str(e)}", HTTPStatus.INTERNAL_SERVER_ERROR, ) @@ -428,14 +374,14 @@ async def _handle_post_request( ) async def sse_writer(): + # Get the request ID from the incoming request message + request_id = None try: - # Create a request-specific message stream for this POST request + # Create a request-specific message stream for this POST request_stream_writer, request_stream_reader = ( anyio.create_memory_object_stream[JSONRPCMessage](0) ) - # Get the request ID from the incoming request message - request_id = None if isinstance(message.root, JSONRPCRequest): request_id = str(message.root.id) # Register this stream for the request ID @@ -485,7 +431,9 @@ async def sse_writer(): logger.exception(f"Error in SSE writer: {e}") finally: logger.debug("Closing SSE writer") - # TODO + # Clean up the request-specific streams + if request_id and request_id in self._request_streams: + self._request_streams.pop(request_id, None) # Create and start EventSourceResponse response = EventSourceResponse( @@ -509,7 +457,7 @@ async def sse_writer(): await writer.send(message) except Exception: logger.exception("SSE response error") - # Make sure to clean up the request stream if something goes wrong + # Clean up the request stream if something goes wrong if ( outer_request_id and outer_request_id in self._request_streams @@ -518,7 +466,7 @@ async def sse_writer(): except Exception as err: logger.exception("Error handling POST request") - response = self._create_error_response( + response = self._create_server_response( f"Error handling POST request: {err}", HTTPStatus.INTERNAL_SERVER_ERROR, ) @@ -528,13 +476,7 @@ async def sse_writer(): return async def _handle_get_request(self, request: Request, send: Send) -> None: - """ - Handle GET requests for SSE stream establishment - - Args: - request: The HTTP request - send: ASGI send function - """ + """Handle GET requests for SSE stream establishment.""" # Validate session ID if server has one if not await self._validate_session(request, send): return @@ -542,7 +484,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: _, has_sse = self._check_accept_headers(request) if not has_sse: - response = self._create_error_response( + response = self._create_server_response( "Not Acceptable: Client must accept text/event-stream", HTTPStatus.NOT_ACCEPTABLE, ) @@ -551,24 +493,18 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # TODO: Implement SSE stream for GET requests # For now, return 501 Not Implemented - response = self._create_error_response( + response = self._create_server_response( "SSE stream from GET request not implemented yet", HTTPStatus.NOT_IMPLEMENTED, ) await response(request.scope, request.receive, send) async def _handle_delete_request(self, request: Request, send: Send) -> None: - """ - Handle DELETE requests for explicit session termination - - Args: - request: The HTTP request - send: ASGI send function - """ + """Handle DELETE requests for explicit session termination.""" # Validate session ID if not self.mcp_session_id: # If no session ID set, return Method Not Allowed - response = self._create_error_response( + response = self._create_server_response( "Method Not Allowed: Session termination not supported", HTTPStatus.METHOD_NOT_ALLOWED, ) @@ -581,16 +517,14 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None: # Terminate the session self._terminate_session() - # Return success response - response = self._create_error_response( + response = self._create_server_response( "Session terminated", HTTPStatus.OK, ) await response(request.scope, request.receive, send) def _terminate_session(self) -> None: - """ - Terminate the current session, closing all streams and marking as terminated. + """Terminate the current session, closing all streams. Once terminated, all requests with this session ID will receive 404 Not Found. """ @@ -616,13 +550,7 @@ def _terminate_session(self) -> None: self._request_streams.clear() async def _handle_unsupported_request(self, request: Request, send: Send) -> None: - """ - Handle unsupported HTTP methods - - Args: - request: The HTTP request - send: ASGI send function - """ + """Handle unsupported HTTP methods.""" headers = { "Content-Type": CONTENT_TYPE_JSON, "Allow": "GET, POST, DELETE", @@ -638,16 +566,7 @@ async def _handle_unsupported_request(self, request: Request, send: Send) -> Non await response(request.scope, request.receive, send) async def _validate_session(self, request: Request, send: Send) -> bool: - """ - Validate the session ID in the request. - - Args: - request: The HTTP request - send: ASGI send function - - Returns: - bool: True if session is valid, False otherwise - """ + """Validate the session ID in the request.""" if not self.mcp_session_id: # If we're not using session IDs, return True return True @@ -657,7 +576,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool: # If no session ID provided but required, return error if not request_session_id: - response = self._create_error_response( + response = self._create_server_response( "Bad Request: Missing session ID", HTTPStatus.BAD_REQUEST, ) @@ -666,7 +585,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool: # If session ID doesn't match, return error if request_session_id != self.mcp_session_id: - response = self._create_error_response( + response = self._create_server_response( "Not Found: Invalid or expired session ID", HTTPStatus.NOT_FOUND, ) @@ -685,8 +604,7 @@ async def connect( ], None, ]: - """ - Context manager that provides read and write streams for a connection + """Context manager that provides read and write streams for a connection. Yields: Tuple of (read_stream, write_stream) for bidirectional communication diff --git a/uv.lock b/uv.lock index 65439e5c..6618ea36 100644 --- a/uv.lock +++ b/uv.lock @@ -487,7 +487,6 @@ wheels = [ [[package]] name = "mcp" -version = "1.6.1.dev12+70115b9" source = { editable = "." } dependencies = [ { name = "anyio" }, From 2a3bed8e50d19a572ad3fe9a82e2ad347d2d1c57 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Mon, 21 Apr 2025 11:01:51 +0100 Subject: [PATCH 10/19] fix example server --- .../mcp_simple_streamablehttp/server.py | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index 88249baf..eec5edb4 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -155,25 +155,24 @@ async def handle_streamable_http(scope, receive, send): mcp_session_id=new_session_id, is_json_response_enabled=json_response, ) - async with http_transport.connect() as streams: - read_stream, write_stream = streams + server_instances[http_transport.mcp_session_id] = http_transport + async with http_transport.connect() as streams: + read_stream, write_stream = streams - async def run_server(): - await app.run( - read_stream, - write_stream, - app.create_initialization_options(), - ) + async def run_server(): + await app.run( + read_stream, + write_stream, + app.create_initialization_options(), + ) - if not task_group: - raise RuntimeError("Task group is not initialized") + if not task_group: + raise RuntimeError("Task group is not initialized") - # Store the instance before starting the task to prevent races - server_instances[http_transport.mcp_session_id] = http_transport - task_group.start_soon(run_server) + task_group.start_soon(run_server) - # Handle the HTTP request and return the response - await http_transport.handle_request(scope, receive, send) + # Handle the HTTP request and return the response + await http_transport.handle_request(scope, receive, send) else: response = Response( "Bad Request: No valid session ID provided", From 0456b1bd1c8f5dd2eaf6650ac8b23413c6614322 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Mon, 21 Apr 2025 11:07:15 +0100 Subject: [PATCH 11/19] return 405 for get stream --- .../servers/simple-streamablehttp/pyproject.toml | 13 +------------ src/mcp/server/streamableHttp.py | 5 ++--- 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/examples/servers/simple-streamablehttp/pyproject.toml b/examples/servers/simple-streamablehttp/pyproject.toml index de43bd2f..c35887d1 100644 --- a/examples/servers/simple-streamablehttp/pyproject.toml +++ b/examples/servers/simple-streamablehttp/pyproject.toml @@ -1,23 +1,12 @@ [project] name = "mcp-simple-streamablehttp" version = "0.1.0" -description = "A simple MCP server exposing a website fetching tool with StreamableHttp transport" +description = "A simple MCP server exposing a StreamableHttp transport for testing" readme = "README.md" requires-python = ">=3.10" authors = [{ name = "Anthropic, PBC." }] -maintainers = [ - { name = "David Soria Parra", email = "davidsp@anthropic.com" }, - { name = "Justin Spahr-Summers", email = "justin@anthropic.com" }, -] keywords = ["mcp", "llm", "automation", "web", "fetch", "http", "streamable"] license = { text = "MIT" } -classifiers = [ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10", -] dependencies = ["anyio>=4.5", "click>=8.1.0", "httpx>=0.27", "mcp", "starlette", "uvicorn"] [project.scripts] diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index 2bc528b0..09b94395 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -492,10 +492,10 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: return # TODO: Implement SSE stream for GET requests - # For now, return 501 Not Implemented + # For now, return 405 Method Not Allowed response = self._create_server_response( "SSE stream from GET request not implemented yet", - HTTPStatus.NOT_IMPLEMENTED, + HTTPStatus.METHOD_NOT_ALLOWED, ) await response(request.scope, request.receive, send) @@ -514,7 +514,6 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None: if not await self._validate_session(request, send): return - # Terminate the session self._terminate_session() response = self._create_server_response( From 97ca48dc2dd00ca56e99b955c04151084c1d3801 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Mon, 21 Apr 2025 11:59:46 +0100 Subject: [PATCH 12/19] speed up tests --- tests/server/test_streamableHttp.py | 139 +++++++++++++++------------- 1 file changed, 75 insertions(+), 64 deletions(-) diff --git a/tests/server/test_streamableHttp.py b/tests/server/test_streamableHttp.py index 42c416c5..063ad82b 100644 --- a/tests/server/test_streamableHttp.py +++ b/tests/server/test_streamableHttp.py @@ -85,25 +85,10 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]: return [TextContent(type="text", text=f"Called {name}")] -@pytest.fixture -def server_port() -> int: - """Find an available port for the test server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: - """Get the URL for the test server.""" - return f"http://127.0.0.1:{server_port}" - - -def create_app(session_id=None, is_json_response_enabled=False) -> Starlette: +def create_app(is_json_response_enabled=False) -> Starlette: """Create a Starlette application for testing that matches the example server. Args: - session_id: Optional session ID to use for the server. is_json_response_enabled: If True, use JSON responses instead of SSE streams. """ # Create server instance @@ -197,20 +182,19 @@ async def run_server(): return app -def run_server(port: int, session_id=None, is_json_response_enabled=False) -> None: +def run_server(port: int, is_json_response_enabled=False) -> None: """Run the test server. Args: port: Port to listen on. - session_id: Optional session ID to use for the server. is_json_response_enabled: If True, use JSON responses instead of SSE streams. """ print( f"Starting test server on port {port} with " - f"session_id={session_id}, json_enabled={is_json_response_enabled}" + f"json_enabled={is_json_response_enabled}" ) - app = create_app(session_id, is_json_response_enabled) + app = create_app(is_json_response_enabled) # Configure server config = uvicorn.Config( app=app, @@ -238,22 +222,38 @@ def run_server(port: int, session_id=None, is_json_response_enabled=False) -> No print("Server shutdown") +# Test fixtures - using same approach as SSE tests @pytest.fixture -def basic_server(server_port: int) -> Generator[None, None, None]: - """Start a basic server without session ID.""" - # Start server process - process = multiprocessing.Process( - target=run_server, kwargs={"port": server_port}, daemon=True +def basic_server_port() -> int: + """Find an available port for the basic server.""" + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def json_server_port() -> int: + """Find an available port for the JSON response server.""" + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def basic_server(basic_server_port: int) -> Generator[None, None, None]: + """Start a basic server.""" + proc = multiprocessing.Process( + target=run_server, kwargs={"port": basic_server_port}, daemon=True ) - process.start() + proc.start() - # Wait for server to start + # Wait for server to be running max_attempts = 20 attempt = 0 while attempt < max_attempts: try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect(("127.0.0.1", server_port)) + s.connect(("127.0.0.1", basic_server_port)) break except ConnectionRefusedError: time.sleep(0.1) @@ -264,30 +264,29 @@ def basic_server(server_port: int) -> Generator[None, None, None]: yield # Clean up - process.terminate() - process.join(timeout=1) - if process.is_alive(): - process.kill() + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("server process failed to terminate") @pytest.fixture -def json_response_server(server_port: int) -> Generator[None, None, None]: +def json_response_server(json_server_port: int) -> Generator[None, None, None]: """Start a server with JSON response enabled.""" - # Start server process with is_json_response_enabled=True - process = multiprocessing.Process( + proc = multiprocessing.Process( target=run_server, - kwargs={"port": server_port, "is_json_response_enabled": True}, + kwargs={"port": json_server_port, "is_json_response_enabled": True}, daemon=True, ) - process.start() + proc.start() - # Wait for server to start + # Wait for server to be running max_attempts = 20 attempt = 0 while attempt < max_attempts: try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect(("127.0.0.1", server_port)) + s.connect(("127.0.0.1", json_server_port)) break except ConnectionRefusedError: time.sleep(0.1) @@ -298,18 +297,30 @@ def json_response_server(server_port: int) -> Generator[None, None, None]: yield # Clean up - process.terminate() - process.join(timeout=1) - if process.is_alive(): - process.kill() + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("server process failed to terminate") + + +@pytest.fixture +def basic_server_url(basic_server_port: int) -> str: + """Get the URL for the basic test server.""" + return f"http://127.0.0.1:{basic_server_port}" + + +@pytest.fixture +def json_server_url(json_server_port: int) -> str: + """Get the URL for the JSON response test server.""" + return f"http://127.0.0.1:{json_server_port}" # Basic request validation tests -def test_accept_header_validation(basic_server, server_url): +def test_accept_header_validation(basic_server, basic_server_url): """Test that Accept header is properly validated.""" # Test without Accept header response = requests.post( - f"{server_url}/mcp", + f"{basic_server_url}/mcp", headers={"Content-Type": "application/json"}, json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, ) @@ -317,11 +328,11 @@ def test_accept_header_validation(basic_server, server_url): assert "Not Acceptable" in response.text -def test_content_type_validation(basic_server, server_url): +def test_content_type_validation(basic_server, basic_server_url): """Test that Content-Type header is properly validated.""" # Test with incorrect Content-Type response = requests.post( - f"{server_url}/mcp", + f"{basic_server_url}/mcp", headers={ "Accept": "application/json, text/event-stream", "Content-Type": "text/plain", @@ -332,11 +343,11 @@ def test_content_type_validation(basic_server, server_url): assert "Unsupported Media Type" in response.text -def test_json_validation(basic_server, server_url): +def test_json_validation(basic_server, basic_server_url): """Test that JSON content is properly validated.""" # Test with invalid JSON response = requests.post( - f"{server_url}/mcp", + f"{basic_server_url}/mcp", headers={ "Accept": "application/json, text/event-stream", "Content-Type": "application/json", @@ -347,11 +358,11 @@ def test_json_validation(basic_server, server_url): assert "Parse error" in response.text -def test_json_parsing(basic_server, server_url): +def test_json_parsing(basic_server, basic_server_url): """Test that JSON content is properly parse.""" # Test with valid JSON but invalid JSON-RPC response = requests.post( - f"{server_url}/mcp", + f"{basic_server_url}/mcp", headers={ "Accept": "application/json, text/event-stream", "Content-Type": "application/json", @@ -362,11 +373,11 @@ def test_json_parsing(basic_server, server_url): assert "Validation error" in response.text -def test_method_not_allowed(basic_server, server_url): +def test_method_not_allowed(basic_server, basic_server_url): """Test that unsupported HTTP methods are rejected.""" # Test with unsupported method (PUT) response = requests.put( - f"{server_url}/mcp", + f"{basic_server_url}/mcp", headers={ "Accept": "application/json, text/event-stream", "Content-Type": "application/json", @@ -377,13 +388,13 @@ def test_method_not_allowed(basic_server, server_url): assert "Method Not Allowed" in response.text -def test_session_validation(basic_server, server_url): +def test_session_validation(basic_server, basic_server_url): """Test session ID validation.""" # session_id not used directly in this test # Test without session ID response = requests.post( - f"{server_url}/mcp", + f"{basic_server_url}/mcp", headers={ "Accept": "application/json, text/event-stream", "Content-Type": "application/json", @@ -452,10 +463,10 @@ def test_streamable_http_transport_init_validation(): StreamableHTTPServerTransport(mcp_session_id="test\n") -def test_session_termination(basic_server, server_url): +def test_session_termination(basic_server, basic_server_url): """Test session termination via DELETE and subsequent request handling.""" response = requests.post( - f"{server_url}/mcp", + f"{basic_server_url}/mcp", headers={ "Accept": "application/json, text/event-stream", "Content-Type": "application/json", @@ -467,7 +478,7 @@ def test_session_termination(basic_server, server_url): # Now terminate the session session_id = response.headers.get(MCP_SESSION_ID_HEADER) response = requests.delete( - f"{server_url}/mcp", + f"{basic_server_url}/mcp", headers={MCP_SESSION_ID_HEADER: session_id}, ) assert response.status_code == 200 @@ -475,7 +486,7 @@ def test_session_termination(basic_server, server_url): # Try to use the terminated session response = requests.post( - f"{server_url}/mcp", + f"{basic_server_url}/mcp", headers={ "Accept": "application/json, text/event-stream", "Content-Type": "application/json", @@ -487,9 +498,9 @@ def test_session_termination(basic_server, server_url): assert "Session has been terminated" in response.text -def test_response(basic_server, server_url): +def test_response(basic_server, basic_server_url): """Test response handling for a valid request.""" - mcp_url = f"{server_url}/mcp" + mcp_url = f"{basic_server_url}/mcp" response = requests.post( mcp_url, headers={ @@ -518,9 +529,9 @@ def test_response(basic_server, server_url): assert tools_response.headers.get("Content-Type") == "text/event-stream" -def test_json_response(json_response_server, server_url): +def test_json_response(json_response_server, json_server_url): """Test response handling when is_json_response_enabled is True.""" - mcp_url = f"{server_url}/mcp" + mcp_url = f"{json_server_url}/mcp" response = requests.post( mcp_url, headers={ @@ -530,4 +541,4 @@ def test_json_response(json_response_server, server_url): json=INIT_REQUEST, ) assert response.status_code == 200 - assert response.headers.get("Content-Type") == "application/json" + assert response.headers.get("Content-Type") == "application/json" \ No newline at end of file From 92d42875746f8b417b9b981de9ec491e1178c6be Mon Sep 17 00:00:00 2001 From: ihrpr Date: Mon, 21 Apr 2025 17:07:41 +0100 Subject: [PATCH 13/19] format --- tests/server/test_streamableHttp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/server/test_streamableHttp.py b/tests/server/test_streamableHttp.py index 063ad82b..e23059d6 100644 --- a/tests/server/test_streamableHttp.py +++ b/tests/server/test_streamableHttp.py @@ -541,4 +541,4 @@ def test_json_response(json_response_server, json_server_url): json=INIT_REQUEST, ) assert response.status_code == 200 - assert response.headers.get("Content-Type") == "application/json" \ No newline at end of file + assert response.headers.get("Content-Type") == "application/json" From aa9f6e5f3dac9808a830da9ceba17392635a1c42 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Mon, 21 Apr 2025 17:25:07 +0100 Subject: [PATCH 14/19] uv lock --- uv.lock | 2 ++ 1 file changed, 2 insertions(+) diff --git a/uv.lock b/uv.lock index 6618ea36..3ea01ff8 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = ">=3.10" [options] @@ -543,6 +544,7 @@ requires-dist = [ { name = "uvicorn", marker = "sys_platform != 'emscripten'", specifier = ">=0.23.1" }, { name = "websockets", marker = "extra == 'ws'", specifier = ">=15.0.1" }, ] +provides-extras = ["cli", "rich", "ws"] [package.metadata.requires-dev] dev = [ From 46ec72d0ecdf0b6fe01699c7ccff4d2ed65aa31c Mon Sep 17 00:00:00 2001 From: ihrpr Date: Tue, 22 Apr 2025 21:12:07 +0100 Subject: [PATCH 15/19] clean up --- src/mcp/server/streamableHttp.py | 393 +++++++++++++--------------- tests/server/test_streamableHttp.py | 1 - 2 files changed, 179 insertions(+), 215 deletions(-) diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index 09b94395..34a272f6 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -24,10 +24,17 @@ from starlette.types import Receive, Scope, Send from mcp.types import ( + INTERNAL_ERROR, + INVALID_PARAMS, + INVALID_REQUEST, + PARSE_ERROR, + ErrorData, + JSONRPCError, JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, + RequestId, ) logger = logging.getLogger(__name__) @@ -61,8 +68,6 @@ class StreamableHTTPServerTransport: None ) _write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] | None = None - # Dictionary to track request-specific message streams - _request_streams: dict[str, MemoryObjectSendStream[JSONRPCMessage]] def __init__( self, @@ -90,16 +95,19 @@ def __init__( self.mcp_session_id = mcp_session_id self.is_json_response_enabled = is_json_response_enabled - self._request_streams = {} + self._request_streams: dict[ + RequestId, MemoryObjectSendStream[JSONRPCMessage] + ] = {} self._terminated = False - def _create_server_response( + def _create_error_response( self, - message: str, + error_message: str, status_code: HTTPStatus, + error_code: int = INVALID_REQUEST, headers: dict[str, str] | None = None, ) -> Response: - """Create a standardized server response.""" + """Create an error response with a simple string message.""" response_headers = {"Content-Type": CONTENT_TYPE_JSON} if headers: response_headers.update(headers) @@ -107,15 +115,25 @@ def _create_server_response( if self.mcp_session_id: response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + # Return a properly formatted JSON error response + error_response = JSONRPCError( + jsonrpc="2.0", + id="server-error", # We don't have a request ID for general errors + error=ErrorData( + code=error_code, + message=error_message, + ), + ) + return Response( - message, + error_response.model_dump_json(by_alias=True, exclude_none=True), status_code=status_code, headers=response_headers, ) def _create_json_response( self, - response_message: JSONRPCMessage, + response_message: JSONRPCMessage | None, status_code: HTTPStatus = HTTPStatus.OK, headers: dict[str, str] | None = None, ) -> Response: @@ -128,7 +146,9 @@ def _create_json_response( response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id return Response( - response_message.model_dump_json(by_alias=True, exclude_none=True), + response_message.model_dump_json(by_alias=True, exclude_none=True) + if response_message + else None, status_code=status_code, headers=response_headers, ) @@ -142,7 +162,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No request = Request(scope, receive) if self._terminated: # If the session has been terminated, return 404 Not Found - response = self._create_server_response( + response = self._create_error_response( "Not Found: Session has been terminated", HTTPStatus.NOT_FOUND, ) @@ -194,7 +214,7 @@ async def _handle_post_request( # Check Accept headers has_json, has_sse = self._check_accept_headers(request) if not (has_json and has_sse): - response = self._create_server_response( + response = self._create_error_response( ( "Not Acceptable: Client must accept both application/json and " "text/event-stream" @@ -206,7 +226,7 @@ async def _handle_post_request( # Validate Content-Type if not self._check_content_type(request): - response = self._create_server_response( + response = self._create_error_response( "Unsupported Media Type: Content-Type must be application/json", HTTPStatus.UNSUPPORTED_MEDIA_TYPE, ) @@ -216,7 +236,7 @@ async def _handle_post_request( # Parse the body - only read it once body = await request.body() if len(body) > MAXIMUM_MESSAGE_SIZE: - response = self._create_server_response( + response = self._create_error_response( "Payload Too Large: Message exceeds maximum size", HTTPStatus.REQUEST_ENTITY_TOO_LARGE, ) @@ -226,9 +246,8 @@ async def _handle_post_request( try: raw_message = json.loads(body) except json.JSONDecodeError as e: - response = self._create_server_response( - f"Parse error: {str(e)}", - HTTPStatus.BAD_REQUEST, + response = self._create_error_response( + f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR ) await response(scope, receive, send) return @@ -236,9 +255,10 @@ async def _handle_post_request( try: message = JSONRPCMessage.model_validate(raw_message) except ValidationError as e: - response = self._create_server_response( + response = self._create_error_response( f"Validation error: {str(e)}", HTTPStatus.BAD_REQUEST, + INVALID_PARAMS, ) await response(scope, receive, send) return @@ -257,7 +277,7 @@ async def _handle_post_request( # If request has a session ID but doesn't match, return 404 if request_session_id and request_session_id != self.mcp_session_id: - response = self._create_server_response( + response = self._create_error_response( "Not Found: Invalid or expired session ID", HTTPStatus.NOT_FOUND, ) @@ -267,13 +287,11 @@ async def _handle_post_request( elif not await self._validate_session(request, send): return - is_request = isinstance(message.root, JSONRPCRequest) - # For notifications and responses only, return 202 Accepted - if not is_request: + if not isinstance(message.root, JSONRPCRequest): # Create response object and send it - response = self._create_server_response( - "Accepted", + response = self._create_json_response( + None, HTTPStatus.ACCEPTED, ) await response(scope, receive, send) @@ -283,192 +301,141 @@ async def _handle_post_request( return - # For requests, determine whether to return JSON or set up SSE stream - if is_request: - if self.is_json_response_enabled: - # JSON response mode - create a response future - request_id = None - if isinstance(message.root, JSONRPCRequest): - request_id = str(message.root.id) - - if not request_id: - # Should not happen for valid JSONRPCRequest, but handle it - response = self._create_server_response( - "Invalid Request: Missing request ID", - HTTPStatus.BAD_REQUEST, - ) - await response(scope, receive, send) - return - - # Create promise stream for getting response - request_stream_writer, request_stream_reader = ( - anyio.create_memory_object_stream[JSONRPCMessage](0) - ) - - # Register this stream for the request ID - self._request_streams[request_id] = request_stream_writer - - # Process the message - await writer.send(message) + # Extract the request ID outside the try block for proper scope + request_id = str(message.root.id) + # Create promise stream for getting response + request_stream_writer, request_stream_reader = ( + anyio.create_memory_object_stream[JSONRPCMessage](0) + ) - try: - # Process messages from the request-specific stream - # We need to collect all messages until we get a response - response_message = None - - # Use similar approach to SSE writer for consistency - async for received_message in request_stream_reader: - # If it's a response, this is what we're waiting for - if isinstance(received_message.root, JSONRPCResponse): - response_message = received_message - break - # For notifications, keep waiting for the actual response - elif isinstance(received_message.root, JSONRPCNotification): - # Just process it and continue waiting - logger.debug( - f"Notification: {received_message.root.method}" - ) + # Register this stream for the request ID + self._request_streams[request_id] = request_stream_writer - # At this point we should have a response - if response_message: - # Create JSON response - response = self._create_json_response(response_message) - await response(scope, receive, send) + if self.is_json_response_enabled: + # Process the message + await writer.send(message) + try: + # Process messages from the request-specific stream + # We need to collect all messages until we get a response + response_message = None + + # Use similar approach to SSE writer for consistency + async for received_message in request_stream_reader: + # If it's a response, this is what we're waiting for + if isinstance( + received_message.root, JSONRPCResponse | JSONRPCError + ): + response_message = received_message + break + # For notifications and request, keep waiting else: - # This shouldn't happen in normal operation - logger.error( - "No response message received before stream closed" - ) - response = self._create_server_response( - "Error processing request: No response received", - HTTPStatus.INTERNAL_SERVER_ERROR, - ) - await response(scope, receive, send) - except Exception as e: - logger.exception(f"Error processing JSON response: {e}") - response = self._create_server_response( - f"Error processing request: {str(e)}", + logger.debug(f"received: {received_message.root.method}") + + # At this point we should have a response + if response_message: + # Create JSON response + response = self._create_json_response(response_message) + await response(scope, receive, send) + else: + # This shouldn't happen in normal operation + logger.error( + "No response message received before stream closed" + ) + response = self._create_error_response( + "Error processing request: No response received", HTTPStatus.INTERNAL_SERVER_ERROR, ) await response(scope, receive, send) - finally: - # Clean up the request stream - if request_id in self._request_streams: - self._request_streams.pop(request_id, None) - await request_stream_reader.aclose() - await request_stream_writer.aclose() - else: - # SSE stream mode (original behavior) - # Set up headers - headers = { - "Cache-Control": "no-cache, no-transform", - "Connection": "keep-alive", - "Content-Type": CONTENT_TYPE_SSE, - } - - if self.mcp_session_id: - headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id - # Create SSE stream - sse_stream_writer, sse_stream_reader = ( - anyio.create_memory_object_stream[dict[str, Any]](0) - ) - - async def sse_writer(): - # Get the request ID from the incoming request message - request_id = None - try: - # Create a request-specific message stream for this POST - request_stream_writer, request_stream_reader = ( - anyio.create_memory_object_stream[JSONRPCMessage](0) - ) - - if isinstance(message.root, JSONRPCRequest): - request_id = str(message.root.id) - # Register this stream for the request ID - if request_id: - self._request_streams[request_id] = ( - request_stream_writer - ) - - async with sse_stream_writer, request_stream_reader: - # Process messages from the request-specific stream - async for received_message in request_stream_reader: - # Send the message via SSE - related_request_id = None - - if isinstance( - received_message.root, JSONRPCNotification - ): - # Get related_request_id from params - params = received_message.root.params - if params and "related_request_id" in params: - related_request_id = params.get( - "related_request_id" - ) - logger.debug( - f"NOTIFICATION: {related_request_id}, " - f"{params.get('data')}" - ) - - # Build the event data - event_data = { - "event": "message", - "data": received_message.model_dump_json( - by_alias=True, exclude_none=True - ), - } - - await sse_stream_writer.send(event_data) - - # If response, remove from pending streams and close - if isinstance( - received_message.root, JSONRPCResponse - ): - if request_id: - self._request_streams.pop(request_id, None) - break - except Exception as e: - logger.exception(f"Error in SSE writer: {e}") - finally: - logger.debug("Closing SSE writer") - # Clean up the request-specific streams - if request_id and request_id in self._request_streams: - self._request_streams.pop(request_id, None) - - # Create and start EventSourceResponse - response = EventSourceResponse( - content=sse_stream_reader, - data_sender_callable=sse_writer, - headers=headers, + except Exception as e: + logger.exception(f"Error processing JSON response: {e}") + response = self._create_error_response( + f"Error processing request: {str(e)}", + HTTPStatus.INTERNAL_SERVER_ERROR, + INTERNAL_ERROR, ) + await response(scope, receive, send) + finally: + # Clean up the request stream + if request_id in self._request_streams: + self._request_streams.pop(request_id, None) + await request_stream_reader.aclose() + await request_stream_writer.aclose() + else: + # Create SSE stream + sse_stream_writer, sse_stream_reader = ( + anyio.create_memory_object_stream[dict[str, Any]](0) + ) - # Extract the request ID outside the try block for proper scope - outer_request_id = None - if isinstance(message.root, JSONRPCRequest): - outer_request_id = str(message.root.id) - - # Start the SSE response (this will send headers immediately) + async def sse_writer(): + # Get the request ID from the incoming request message try: - # First send the response to establish the SSE connection - async with anyio.create_task_group() as tg: - tg.start_soon(response, scope, receive, send) + async with sse_stream_writer, request_stream_reader: + # Process messages from the request-specific stream + async for received_message in request_stream_reader: + # Build the event data + event_data = { + "event": "message", + "data": received_message.model_dump_json( + by_alias=True, exclude_none=True + ), + } + + await sse_stream_writer.send(event_data) + + # If response, remove from pending streams and close + if isinstance( + received_message.root, + JSONRPCResponse | JSONRPCError, + ): + if request_id: + self._request_streams.pop(request_id, None) + break + except Exception as e: + logger.exception(f"Error in SSE writer: {e}") + finally: + logger.debug("Closing SSE writer") + # Clean up the request-specific streams + if request_id and request_id in self._request_streams: + self._request_streams.pop(request_id, None) - # Then send the message to be processed by the server - await writer.send(message) - except Exception: - logger.exception("SSE response error") - # Clean up the request stream if something goes wrong - if ( - outer_request_id - and outer_request_id in self._request_streams - ): - self._request_streams.pop(outer_request_id, None) + # Create and start EventSourceResponse + # SSE stream mode (original behavior) + # Set up headers + headers = { + "Cache-Control": "no-cache, no-transform", + "Connection": "keep-alive", + "Content-Type": CONTENT_TYPE_SSE, + **( + {MCP_SESSION_ID_HEADER: self.mcp_session_id} + if self.mcp_session_id + else {} + ), + } + response = EventSourceResponse( + content=sse_stream_reader, + data_sender_callable=sse_writer, + headers=headers, + ) + + # Start the SSE response (this will send headers immediately) + try: + # First send the response to establish the SSE connection + async with anyio.create_task_group() as tg: + tg.start_soon(response, scope, receive, send) + # Then send the message to be processed by the server + await writer.send(message) + except Exception: + logger.exception("SSE response error") + # Clean up the request stream if something goes wrong + if request_id and request_id in self._request_streams: + self._request_streams.pop(request_id, None) except Exception as err: logger.exception("Error handling POST request") - response = self._create_server_response( + response = self._create_error_response( f"Error handling POST request: {err}", HTTPStatus.INTERNAL_SERVER_ERROR, + INTERNAL_ERROR, ) await response(scope, receive, send) if writer: @@ -484,7 +451,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: _, has_sse = self._check_accept_headers(request) if not has_sse: - response = self._create_server_response( + response = self._create_error_response( "Not Acceptable: Client must accept text/event-stream", HTTPStatus.NOT_ACCEPTABLE, ) @@ -493,7 +460,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # TODO: Implement SSE stream for GET requests # For now, return 405 Method Not Allowed - response = self._create_server_response( + response = self._create_error_response( "SSE stream from GET request not implemented yet", HTTPStatus.METHOD_NOT_ALLOWED, ) @@ -504,7 +471,7 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None: # Validate session ID if not self.mcp_session_id: # If no session ID set, return Method Not Allowed - response = self._create_server_response( + response = self._create_error_response( "Method Not Allowed: Session termination not supported", HTTPStatus.METHOD_NOT_ALLOWED, ) @@ -516,8 +483,8 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None: self._terminate_session() - response = self._create_server_response( - "Session terminated", + response = self._create_json_response( + None, HTTPStatus.OK, ) await response(request.scope, request.receive, send) @@ -557,9 +524,9 @@ async def _handle_unsupported_request(self, request: Request, send: Send) -> Non if self.mcp_session_id: headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id - response = Response( + response = self._create_error_response( "Method Not Allowed", - status_code=HTTPStatus.METHOD_NOT_ALLOWED, + HTTPStatus.METHOD_NOT_ALLOWED, headers=headers, ) await response(request.scope, request.receive, send) @@ -575,7 +542,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool: # If no session ID provided but required, return error if not request_session_id: - response = self._create_server_response( + response = self._create_error_response( "Bad Request: Missing session ID", HTTPStatus.BAD_REQUEST, ) @@ -584,7 +551,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool: # If session ID doesn't match, return error if request_session_id != self.mcp_session_id: - response = self._create_server_response( + response = self._create_error_response( "Not Found: Invalid or expired session ID", HTTPStatus.NOT_FOUND, ) @@ -635,18 +602,16 @@ async def message_router(): async for message in write_stream_reader: # Determine which request stream(s) should receive this message target_request_id = None - - # For responses, route based on the request ID - if isinstance(message.root, JSONRPCResponse): + if isinstance( + message.root, JSONRPCNotification | JSONRPCRequest + ): + # Extract related_request_id from params if it exists + if (params := getattr(message.root, "params", None)) and ( + related_id := params.get("related_request_id") + ) is not None: + target_request_id = str(related_id) + else: target_request_id = str(message.root.id) - # For notifications, route by related_request_id if available - elif isinstance(message.root, JSONRPCNotification): - # Get related_request_id from params - params = message.root.params - if params and "related_request_id" in params: - related_id = params.get("related_request_id") - if related_id is not None: - target_request_id = str(related_id) # Send to the specific request stream if available if ( diff --git a/tests/server/test_streamableHttp.py b/tests/server/test_streamableHttp.py index e23059d6..8904bf4f 100644 --- a/tests/server/test_streamableHttp.py +++ b/tests/server/test_streamableHttp.py @@ -482,7 +482,6 @@ def test_session_termination(basic_server, basic_server_url): headers={MCP_SESSION_ID_HEADER: session_id}, ) assert response.status_code == 200 - assert "Session terminated" in response.text # Try to use the terminated session response = requests.post( From 9b096dc558f02d4a903e9215037d302e9554d62a Mon Sep 17 00:00:00 2001 From: ihrpr Date: Wed, 23 Apr 2025 09:44:58 +0100 Subject: [PATCH 16/19] add comments to server example where we use related_request_id --- .../mcp_simple_streamablehttp/server.py | 5 +++++ src/mcp/server/fastmcp/server.py | 6 +++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index eec5edb4..e7bc4430 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -82,6 +82,11 @@ async def call_tool( level="info", data=f"Notification {i+1}/{count} from caller: {caller}", logger="notification_stream", + # Associates this notification with the original request + # Ensures notifications are sent to the correct response stream + # Without this, notifications will either go to: + # - a standalone SSE stream (if GET request is supported) + # - nowhere (if GET request isn't supported) related_request_id=ctx.request_id, ) if i < count - 1: # Don't wait after the last notification diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index f3bb2586..008b235f 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -466,6 +466,7 @@ async def run_stdio_async(self) -> None: async def run_sse_async(self) -> None: """Run the server using SSE transport.""" import uvicorn + starlette_app = self.sse_app() config = uvicorn.Config( @@ -673,7 +674,10 @@ async def log( **extra: Additional structured data to include """ await self.request_context.session.send_log_message( - level=level, data=message, logger=logger_name + level=level, + data=message, + logger=logger_name, + related_request_id=self.request_id, ) @property From a0a9c5b4e5f7c6b39118405a3230b9e2bc66175e Mon Sep 17 00:00:00 2001 From: ihrpr Date: Wed, 23 Apr 2025 10:18:06 +0100 Subject: [PATCH 17/19] small fixes --- examples/servers/simple-streamablehttp/README.md | 1 + src/mcp/server/streamableHttp.py | 5 ----- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/examples/servers/simple-streamablehttp/README.md b/examples/servers/simple-streamablehttp/README.md index 5125c3eb..e5aaa652 100644 --- a/examples/servers/simple-streamablehttp/README.md +++ b/examples/servers/simple-streamablehttp/README.md @@ -5,6 +5,7 @@ A simple MCP server example demonstrating the StreamableHttp transport, which en ## Features - Uses the StreamableHTTP transport for server-client communication +- Supports REST API operations (POST, GET, DELETE) for `/mcp` endpoint - Task management with anyio task groups - Ability to send multiple notifications over time to the client - Proper resource cleanup and lifespan management diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index 34a272f6..2d536a40 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -577,11 +577,6 @@ async def connect( """ # Create the memory streams for this connection - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] - - write_stream: MemoryObjectSendStream[JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream[ JSONRPCMessage | Exception From a5ac2e09df6df2eda17ea73559004a789a3d1f3c Mon Sep 17 00:00:00 2001 From: ihrpr Date: Wed, 23 Apr 2025 11:04:38 +0100 Subject: [PATCH 18/19] use mta field for related_request_id --- src/mcp/server/streamableHttp.py | 11 +++++++---- src/mcp/shared/session.py | 19 +++++++++++++++++-- src/mcp/types.py | 1 - tests/client/test_logging_callback.py | 12 +++++++++--- tests/server/fastmcp/test_server.py | 22 ++++++++++++++++++---- 5 files changed, 51 insertions(+), 14 deletions(-) diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index 2d536a40..2e0f7090 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -600,10 +600,13 @@ async def message_router(): if isinstance( message.root, JSONRPCNotification | JSONRPCRequest ): - # Extract related_request_id from params if it exists - if (params := getattr(message.root, "params", None)) and ( - related_id := params.get("related_request_id") - ) is not None: + # Extract related_request_id from meta if it exists + if ( + (params := getattr(message.root, "params", None)) + and (meta := params.get("_meta")) + and (related_id := meta.get("related_request_id")) + is not None + ): target_request_id = str(related_id) else: target_request_id = str(message.root.id) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 1017bb98..368524f9 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -6,7 +6,6 @@ from typing import Any, Generic, TypeVar import anyio -import anyio.lowlevel import httpx from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import BaseModel @@ -24,6 +23,7 @@ JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, + NotificationParams, RequestParams, ServerNotification, ServerRequest, @@ -276,8 +276,23 @@ async def send_notification( Emits a notification, which is a one-way message that does not expect a response. """ + # Some transport implementations may need to set the related_request_id + # to attribute to the notifications to the request that triggered + # them. + # Update notification meta with related request ID if provided if related_request_id is not None and notification.root.params is not None: - notification.root.params.related_request_id = related_request_id + # Create meta if it doesn't exist + if notification.root.params.meta is None: + # Create meta dict with related_request_id + meta_dict = {"related_request_id": related_request_id} + + else: + # Update existing meta with model_validate to properly handle extra fields + meta_dict = notification.root.params.meta.model_dump( + by_alias=True, mode="json", exclude_none=True + ) + meta_dict["related_request_id"] = related_request_id + notification.root.params.meta = NotificationParams.Meta(**meta_dict) jsonrpc_notification = JSONRPCNotification( jsonrpc="2.0", **notification.model_dump(by_alias=True, mode="json", exclude_none=True), diff --git a/src/mcp/types.py b/src/mcp/types.py index 30500e31..bd71d51f 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -58,7 +58,6 @@ class Meta(BaseModel): model_config = ConfigDict(extra="allow") meta: Meta | None = Field(alias="_meta", default=None) - related_request_id: RequestId | None = None """ This parameter name is reserved by MCP to allow clients and servers to attach additional metadata to their notifications. diff --git a/tests/client/test_logging_callback.py b/tests/client/test_logging_callback.py index 797f817e..588fa649 100644 --- a/tests/client/test_logging_callback.py +++ b/tests/client/test_logging_callback.py @@ -9,6 +9,7 @@ from mcp.shared.session import RequestResponder from mcp.types import ( LoggingMessageNotificationParams, + NotificationParams, TextContent, ) @@ -78,6 +79,11 @@ async def message_handler( ) assert log_result.isError is False assert len(logging_collector.log_messages) == 1 - assert logging_collector.log_messages[0] == LoggingMessageNotificationParams( - level="info", logger="test_logger", data="Test log message" - ) + # Create meta object with related_request_id added dynamically + meta = NotificationParams.Meta() + setattr(meta, "related_request_id", "2") + log = logging_collector.log_messages[0] + assert log.level == "info" + assert log.logger == "test_logger" + assert log.data == "Test log message" + assert log.meta == meta diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py index e76e59c5..772c4152 100644 --- a/tests/server/fastmcp/test_server.py +++ b/tests/server/fastmcp/test_server.py @@ -544,14 +544,28 @@ async def logging_tool(msg: str, ctx: Context) -> str: assert mock_log.call_count == 4 mock_log.assert_any_call( - level="debug", data="Debug message", logger=None + level="debug", + data="Debug message", + logger=None, + related_request_id="1", ) - mock_log.assert_any_call(level="info", data="Info message", logger=None) mock_log.assert_any_call( - level="warning", data="Warning message", logger=None + level="info", + data="Info message", + logger=None, + related_request_id="1", ) mock_log.assert_any_call( - level="error", data="Error message", logger=None + level="warning", + data="Warning message", + logger=None, + related_request_id="1", + ) + mock_log.assert_any_call( + level="error", + data="Error message", + logger=None, + related_request_id="1", ) @pytest.mark.anyio From 2e615f36b7e515ad9efaf2cff46bfc1a2fa00f46 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Wed, 23 Apr 2025 11:10:20 +0100 Subject: [PATCH 19/19] unrelated test and format --- src/mcp/shared/session.py | 6 +----- tests/issues/test_188_concurrency.py | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 368524f9..3a01cb04 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -277,17 +277,13 @@ async def send_notification( a response. """ # Some transport implementations may need to set the related_request_id - # to attribute to the notifications to the request that triggered - # them. - # Update notification meta with related request ID if provided + # to attribute to the notifications to the request that triggered them. if related_request_id is not None and notification.root.params is not None: # Create meta if it doesn't exist if notification.root.params.meta is None: - # Create meta dict with related_request_id meta_dict = {"related_request_id": related_request_id} else: - # Update existing meta with model_validate to properly handle extra fields meta_dict = notification.root.params.meta.model_dump( by_alias=True, mode="json", exclude_none=True ) diff --git a/tests/issues/test_188_concurrency.py b/tests/issues/test_188_concurrency.py index 2aa6c49c..d0a86885 100644 --- a/tests/issues/test_188_concurrency.py +++ b/tests/issues/test_188_concurrency.py @@ -35,7 +35,7 @@ async def slow_resource(): end_time = anyio.current_time() duration = end_time - start_time - assert duration < 3 * _sleep_time_seconds + assert duration < 6 * _sleep_time_seconds print(duration)