diff --git a/examples/servers/simple-streamablehttp/README.md b/examples/servers/simple-streamablehttp/README.md new file mode 100644 index 00000000..e5aaa652 --- /dev/null +++ b/examples/servers/simple-streamablehttp/README.md @@ -0,0 +1,37 @@ +# MCP Simple StreamableHttp Server Example + +A simple MCP server example demonstrating the StreamableHttp transport, which enables HTTP-based communication with MCP servers using streaming. + +## Features + +- Uses the StreamableHTTP transport for server-client communication +- Supports REST API operations (POST, GET, DELETE) for `/mcp` endpoint +- Task management with anyio task groups +- Ability to send multiple notifications over time to the client +- Proper resource cleanup and lifespan management + +## Usage + +Start the server on the default or custom port: + +```bash + +# Using custom port +uv run mcp-simple-streamablehttp --port 3000 + +# Custom logging level +uv run mcp-simple-streamablehttp --log-level DEBUG + +# Enable JSON responses instead of SSE streams +uv run mcp-simple-streamablehttp --json-response +``` + +The server exposes a tool named "start-notification-stream" that accepts three arguments: + +- `interval`: Time between notifications in seconds (e.g., 1.0) +- `count`: Number of notifications to send (e.g., 5) +- `caller`: Identifier string for the caller + +## Client + +You can connect to this server using an HTTP client, for now only Typescript SDK has streamable HTTP client examples or you can use (Inspector)[https://github.com/modelcontextprotocol/inspector] \ No newline at end of file diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__init__.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py new file mode 100644 index 00000000..f5f6e402 --- /dev/null +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py @@ -0,0 +1,4 @@ +from .server import main + +if __name__ == "__main__": + main() diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py new file mode 100644 index 00000000..e7bc4430 --- /dev/null +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -0,0 +1,201 @@ +import contextlib +import logging +from http import HTTPStatus +from uuid import uuid4 + +import anyio +import click +import mcp.types as types +from mcp.server.lowlevel import Server +from mcp.server.streamableHttp import ( + MCP_SESSION_ID_HEADER, + StreamableHTTPServerTransport, +) +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import Mount + +# Configure logging +logger = logging.getLogger(__name__) + +# Global task group that will be initialized in the lifespan +task_group = None + + +@contextlib.asynccontextmanager +async def lifespan(app): + """Application lifespan context manager for managing task group.""" + global task_group + + async with anyio.create_task_group() as tg: + task_group = tg + logger.info("Application started, task group initialized!") + try: + yield + finally: + logger.info("Application shutting down, cleaning up resources...") + if task_group: + tg.cancel_scope.cancel() + task_group = None + logger.info("Resources cleaned up successfully.") + + +@click.command() +@click.option("--port", default=3000, help="Port to listen on for HTTP") +@click.option( + "--log-level", + default="INFO", + help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)", +) +@click.option( + "--json-response", + is_flag=True, + default=False, + help="Enable JSON responses instead of SSE streams", +) +def main( + port: int, + log_level: str, + json_response: bool, +) -> int: + # Configure logging + logging.basicConfig( + level=getattr(logging, log_level.upper()), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + app = Server("mcp-streamable-http-demo") + + @app.call_tool() + async def call_tool( + name: str, arguments: dict + ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: + ctx = app.request_context + interval = arguments.get("interval", 1.0) + count = arguments.get("count", 5) + caller = arguments.get("caller", "unknown") + + # Send the specified number of notifications with the given interval + for i in range(count): + await ctx.session.send_log_message( + level="info", + data=f"Notification {i+1}/{count} from caller: {caller}", + logger="notification_stream", + # Associates this notification with the original request + # Ensures notifications are sent to the correct response stream + # Without this, notifications will either go to: + # - a standalone SSE stream (if GET request is supported) + # - nowhere (if GET request isn't supported) + related_request_id=ctx.request_id, + ) + if i < count - 1: # Don't wait after the last notification + await anyio.sleep(interval) + + return [ + types.TextContent( + type="text", + text=( + f"Sent {count} notifications with {interval}s interval" + f" for caller: {caller}" + ), + ) + ] + + @app.list_tools() + async def list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="start-notification-stream", + description=( + "Sends a stream of notifications with configurable count" + " and interval" + ), + inputSchema={ + "type": "object", + "required": ["interval", "count", "caller"], + "properties": { + "interval": { + "type": "number", + "description": "Interval between notifications in seconds", + }, + "count": { + "type": "number", + "description": "Number of notifications to send", + }, + "caller": { + "type": "string", + "description": ( + "Identifier of the caller to include in notifications" + ), + }, + }, + }, + ) + ] + + # We need to store the server instances between requests + server_instances = {} + # Lock to prevent race conditions when creating new sessions + session_creation_lock = anyio.Lock() + + # ASGI handler for streamable HTTP connections + async def handle_streamable_http(scope, receive, send): + request = Request(scope, receive) + request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) + if ( + request_mcp_session_id is not None + and request_mcp_session_id in server_instances + ): + transport = server_instances[request_mcp_session_id] + logger.debug("Session already exists, handling request directly") + await transport.handle_request(scope, receive, send) + elif request_mcp_session_id is None: + # try to establish new session + logger.debug("Creating new transport") + # Use lock to prevent race conditions when creating new sessions + async with session_creation_lock: + new_session_id = uuid4().hex + http_transport = StreamableHTTPServerTransport( + mcp_session_id=new_session_id, + is_json_response_enabled=json_response, + ) + server_instances[http_transport.mcp_session_id] = http_transport + async with http_transport.connect() as streams: + read_stream, write_stream = streams + + async def run_server(): + await app.run( + read_stream, + write_stream, + app.create_initialization_options(), + ) + + if not task_group: + raise RuntimeError("Task group is not initialized") + + task_group.start_soon(run_server) + + # Handle the HTTP request and return the response + await http_transport.handle_request(scope, receive, send) + else: + response = Response( + "Bad Request: No valid session ID provided", + status_code=HTTPStatus.BAD_REQUEST, + ) + await response(scope, receive, send) + + # Create an ASGI application using the transport + starlette_app = Starlette( + debug=True, + routes=[ + Mount("/mcp", app=handle_streamable_http), + ], + lifespan=lifespan, + ) + + import uvicorn + + uvicorn.run(starlette_app, host="0.0.0.0", port=port) + + return 0 diff --git a/examples/servers/simple-streamablehttp/pyproject.toml b/examples/servers/simple-streamablehttp/pyproject.toml new file mode 100644 index 00000000..c35887d1 --- /dev/null +++ b/examples/servers/simple-streamablehttp/pyproject.toml @@ -0,0 +1,36 @@ +[project] +name = "mcp-simple-streamablehttp" +version = "0.1.0" +description = "A simple MCP server exposing a StreamableHttp transport for testing" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." }] +keywords = ["mcp", "llm", "automation", "web", "fetch", "http", "streamable"] +license = { text = "MIT" } +dependencies = ["anyio>=4.5", "click>=8.1.0", "httpx>=0.27", "mcp", "starlette", "uvicorn"] + +[project.scripts] +mcp-simple-streamablehttp = "mcp_simple_streamablehttp.server:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_simple_streamablehttp"] + +[tool.pyright] +include = ["mcp_simple_streamablehttp"] +venvPath = "." +venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 88 +target-version = "py310" + +[tool.uv] +dev-dependencies = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] \ No newline at end of file diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 65d342e1..5b57eb13 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -814,7 +814,10 @@ async def log( **extra: Additional structured data to include """ await self.request_context.session.send_log_message( - level=level, data=message, logger=logger_name + level=level, + data=message, + logger=logger_name, + related_request_id=self.request_id, ) @property diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 568ecd4b..3a1f210d 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -179,7 +179,11 @@ async def _received_notification( ) async def send_log_message( - self, level: types.LoggingLevel, data: Any, logger: str | None = None + self, + level: types.LoggingLevel, + data: Any, + logger: str | None = None, + related_request_id: types.RequestId | None = None, ) -> None: """Send a log message notification.""" await self.send_notification( @@ -192,7 +196,8 @@ async def send_log_message( logger=logger, ), ) - ) + ), + related_request_id, ) async def send_resource_updated(self, uri: AnyUrl) -> None: @@ -261,7 +266,11 @@ async def send_ping(self) -> types.EmptyResult: ) async def send_progress_notification( - self, progress_token: str | int, progress: float, total: float | None = None + self, + progress_token: str | int, + progress: float, + total: float | None = None, + related_request_id: str | None = None, ) -> None: """Send a progress notification.""" await self.send_notification( @@ -274,7 +283,8 @@ async def send_progress_notification( total=total, ), ) - ) + ), + related_request_id, ) async def send_resource_list_changed(self) -> None: diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py new file mode 100644 index 00000000..2e0f7090 --- /dev/null +++ b/src/mcp/server/streamableHttp.py @@ -0,0 +1,644 @@ +""" +StreamableHTTP Server Transport Module + +This module implements an HTTP transport layer with Streamable HTTP. + +The transport handles bidirectional communication using HTTP requests and +responses, with streaming support for long-running operations. +""" + +import json +import logging +import re +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from http import HTTPStatus +from typing import Any + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic import ValidationError +from sse_starlette import EventSourceResponse +from starlette.requests import Request +from starlette.responses import Response +from starlette.types import Receive, Scope, Send + +from mcp.types import ( + INTERNAL_ERROR, + INVALID_PARAMS, + INVALID_REQUEST, + PARSE_ERROR, + ErrorData, + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + RequestId, +) + +logger = logging.getLogger(__name__) + +# Maximum size for incoming messages +MAXIMUM_MESSAGE_SIZE = 4 * 1024 * 1024 # 4MB + +# Header names +MCP_SESSION_ID_HEADER = "mcp-session-id" +LAST_EVENT_ID_HEADER = "last-event-id" + +# Content types +CONTENT_TYPE_JSON = "application/json" +CONTENT_TYPE_SSE = "text/event-stream" + +# Session ID validation pattern (visible ASCII characters ranging from 0x21 to 0x7E) +# Pattern ensures entire string contains only valid characters by using ^ and $ anchors +SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$") + + +class StreamableHTTPServerTransport: + """ + HTTP server transport with event streaming support for MCP. + + Handles JSON-RPC messages in HTTP POST requests with SSE streaming. + Supports optional JSON responses and session management. + """ + + # Server notification streams for POST requests as well as standalone SSE stream + _read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] | None = ( + None + ) + _write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] | None = None + + def __init__( + self, + mcp_session_id: str | None, + is_json_response_enabled: bool = False, + ) -> None: + """ + Initialize a new StreamableHTTP server transport. + + Args: + mcp_session_id: Optional session identifier for this connection. + Must contain only visible ASCII characters (0x21-0x7E). + is_json_response_enabled: If True, return JSON responses for requests + instead of SSE streams. Default is False. + + Raises: + ValueError: If the session ID contains invalid characters. + """ + if mcp_session_id is not None and not SESSION_ID_PATTERN.fullmatch( + mcp_session_id + ): + raise ValueError( + "Session ID must only contain visible ASCII characters (0x21-0x7E)" + ) + + self.mcp_session_id = mcp_session_id + self.is_json_response_enabled = is_json_response_enabled + self._request_streams: dict[ + RequestId, MemoryObjectSendStream[JSONRPCMessage] + ] = {} + self._terminated = False + + def _create_error_response( + self, + error_message: str, + status_code: HTTPStatus, + error_code: int = INVALID_REQUEST, + headers: dict[str, str] | None = None, + ) -> Response: + """Create an error response with a simple string message.""" + response_headers = {"Content-Type": CONTENT_TYPE_JSON} + if headers: + response_headers.update(headers) + + if self.mcp_session_id: + response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + # Return a properly formatted JSON error response + error_response = JSONRPCError( + jsonrpc="2.0", + id="server-error", # We don't have a request ID for general errors + error=ErrorData( + code=error_code, + message=error_message, + ), + ) + + return Response( + error_response.model_dump_json(by_alias=True, exclude_none=True), + status_code=status_code, + headers=response_headers, + ) + + def _create_json_response( + self, + response_message: JSONRPCMessage | None, + status_code: HTTPStatus = HTTPStatus.OK, + headers: dict[str, str] | None = None, + ) -> Response: + """Create a JSON response from a JSONRPCMessage""" + response_headers = {"Content-Type": CONTENT_TYPE_JSON} + if headers: + response_headers.update(headers) + + if self.mcp_session_id: + response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + return Response( + response_message.model_dump_json(by_alias=True, exclude_none=True) + if response_message + else None, + status_code=status_code, + headers=response_headers, + ) + + def _get_session_id(self, request: Request) -> str | None: + """Extract the session ID from request headers.""" + return request.headers.get(MCP_SESSION_ID_HEADER) + + async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: + """Application entry point that handles all HTTP requests""" + request = Request(scope, receive) + if self._terminated: + # If the session has been terminated, return 404 Not Found + response = self._create_error_response( + "Not Found: Session has been terminated", + HTTPStatus.NOT_FOUND, + ) + await response(scope, receive, send) + return + + if request.method == "POST": + await self._handle_post_request(scope, request, receive, send) + elif request.method == "GET": + await self._handle_get_request(request, send) + elif request.method == "DELETE": + await self._handle_delete_request(request, send) + else: + await self._handle_unsupported_request(request, send) + + def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: + """Check if the request accepts the required media types.""" + accept_header = request.headers.get("accept", "") + accept_types = [media_type.strip() for media_type in accept_header.split(",")] + + has_json = any( + media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types + ) + has_sse = any( + media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types + ) + + return has_json, has_sse + + def _check_content_type(self, request: Request) -> bool: + """Check if the request has the correct Content-Type.""" + content_type = request.headers.get("content-type", "") + content_type_parts = [ + part.strip() for part in content_type.split(";")[0].split(",") + ] + + return any(part == CONTENT_TYPE_JSON for part in content_type_parts) + + async def _handle_post_request( + self, scope: Scope, request: Request, receive: Receive, send: Send + ) -> None: + """Handle POST requests containing JSON-RPC messages.""" + writer = self._read_stream_writer + if writer is None: + raise ValueError( + "No read stream writer available. Ensure connect() is called first." + ) + try: + # Check Accept headers + has_json, has_sse = self._check_accept_headers(request) + if not (has_json and has_sse): + response = self._create_error_response( + ( + "Not Acceptable: Client must accept both application/json and " + "text/event-stream" + ), + HTTPStatus.NOT_ACCEPTABLE, + ) + await response(scope, receive, send) + return + + # Validate Content-Type + if not self._check_content_type(request): + response = self._create_error_response( + "Unsupported Media Type: Content-Type must be application/json", + HTTPStatus.UNSUPPORTED_MEDIA_TYPE, + ) + await response(scope, receive, send) + return + + # Parse the body - only read it once + body = await request.body() + if len(body) > MAXIMUM_MESSAGE_SIZE: + response = self._create_error_response( + "Payload Too Large: Message exceeds maximum size", + HTTPStatus.REQUEST_ENTITY_TOO_LARGE, + ) + await response(scope, receive, send) + return + + try: + raw_message = json.loads(body) + except json.JSONDecodeError as e: + response = self._create_error_response( + f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR + ) + await response(scope, receive, send) + return + + try: + message = JSONRPCMessage.model_validate(raw_message) + except ValidationError as e: + response = self._create_error_response( + f"Validation error: {str(e)}", + HTTPStatus.BAD_REQUEST, + INVALID_PARAMS, + ) + await response(scope, receive, send) + return + + # Check if this is an initialization request + is_initialization_request = ( + isinstance(message.root, JSONRPCRequest) + and message.root.method == "initialize" + ) + + if is_initialization_request: + # Check if the server already has an established session + if self.mcp_session_id: + # Check if request has a session ID + request_session_id = self._get_session_id(request) + + # If request has a session ID but doesn't match, return 404 + if request_session_id and request_session_id != self.mcp_session_id: + response = self._create_error_response( + "Not Found: Invalid or expired session ID", + HTTPStatus.NOT_FOUND, + ) + await response(scope, receive, send) + return + # For non-initialization requests, validate the session + elif not await self._validate_session(request, send): + return + + # For notifications and responses only, return 202 Accepted + if not isinstance(message.root, JSONRPCRequest): + # Create response object and send it + response = self._create_json_response( + None, + HTTPStatus.ACCEPTED, + ) + await response(scope, receive, send) + + # Process the message after sending the response + await writer.send(message) + + return + + # Extract the request ID outside the try block for proper scope + request_id = str(message.root.id) + # Create promise stream for getting response + request_stream_writer, request_stream_reader = ( + anyio.create_memory_object_stream[JSONRPCMessage](0) + ) + + # Register this stream for the request ID + self._request_streams[request_id] = request_stream_writer + + if self.is_json_response_enabled: + # Process the message + await writer.send(message) + try: + # Process messages from the request-specific stream + # We need to collect all messages until we get a response + response_message = None + + # Use similar approach to SSE writer for consistency + async for received_message in request_stream_reader: + # If it's a response, this is what we're waiting for + if isinstance( + received_message.root, JSONRPCResponse | JSONRPCError + ): + response_message = received_message + break + # For notifications and request, keep waiting + else: + logger.debug(f"received: {received_message.root.method}") + + # At this point we should have a response + if response_message: + # Create JSON response + response = self._create_json_response(response_message) + await response(scope, receive, send) + else: + # This shouldn't happen in normal operation + logger.error( + "No response message received before stream closed" + ) + response = self._create_error_response( + "Error processing request: No response received", + HTTPStatus.INTERNAL_SERVER_ERROR, + ) + await response(scope, receive, send) + except Exception as e: + logger.exception(f"Error processing JSON response: {e}") + response = self._create_error_response( + f"Error processing request: {str(e)}", + HTTPStatus.INTERNAL_SERVER_ERROR, + INTERNAL_ERROR, + ) + await response(scope, receive, send) + finally: + # Clean up the request stream + if request_id in self._request_streams: + self._request_streams.pop(request_id, None) + await request_stream_reader.aclose() + await request_stream_writer.aclose() + else: + # Create SSE stream + sse_stream_writer, sse_stream_reader = ( + anyio.create_memory_object_stream[dict[str, Any]](0) + ) + + async def sse_writer(): + # Get the request ID from the incoming request message + try: + async with sse_stream_writer, request_stream_reader: + # Process messages from the request-specific stream + async for received_message in request_stream_reader: + # Build the event data + event_data = { + "event": "message", + "data": received_message.model_dump_json( + by_alias=True, exclude_none=True + ), + } + + await sse_stream_writer.send(event_data) + + # If response, remove from pending streams and close + if isinstance( + received_message.root, + JSONRPCResponse | JSONRPCError, + ): + if request_id: + self._request_streams.pop(request_id, None) + break + except Exception as e: + logger.exception(f"Error in SSE writer: {e}") + finally: + logger.debug("Closing SSE writer") + # Clean up the request-specific streams + if request_id and request_id in self._request_streams: + self._request_streams.pop(request_id, None) + + # Create and start EventSourceResponse + # SSE stream mode (original behavior) + # Set up headers + headers = { + "Cache-Control": "no-cache, no-transform", + "Connection": "keep-alive", + "Content-Type": CONTENT_TYPE_SSE, + **( + {MCP_SESSION_ID_HEADER: self.mcp_session_id} + if self.mcp_session_id + else {} + ), + } + response = EventSourceResponse( + content=sse_stream_reader, + data_sender_callable=sse_writer, + headers=headers, + ) + + # Start the SSE response (this will send headers immediately) + try: + # First send the response to establish the SSE connection + async with anyio.create_task_group() as tg: + tg.start_soon(response, scope, receive, send) + # Then send the message to be processed by the server + await writer.send(message) + except Exception: + logger.exception("SSE response error") + # Clean up the request stream if something goes wrong + if request_id and request_id in self._request_streams: + self._request_streams.pop(request_id, None) + + except Exception as err: + logger.exception("Error handling POST request") + response = self._create_error_response( + f"Error handling POST request: {err}", + HTTPStatus.INTERNAL_SERVER_ERROR, + INTERNAL_ERROR, + ) + await response(scope, receive, send) + if writer: + await writer.send(err) + return + + async def _handle_get_request(self, request: Request, send: Send) -> None: + """Handle GET requests for SSE stream establishment.""" + # Validate session ID if server has one + if not await self._validate_session(request, send): + return + # Validate Accept header - must include text/event-stream + _, has_sse = self._check_accept_headers(request) + + if not has_sse: + response = self._create_error_response( + "Not Acceptable: Client must accept text/event-stream", + HTTPStatus.NOT_ACCEPTABLE, + ) + await response(request.scope, request.receive, send) + return + + # TODO: Implement SSE stream for GET requests + # For now, return 405 Method Not Allowed + response = self._create_error_response( + "SSE stream from GET request not implemented yet", + HTTPStatus.METHOD_NOT_ALLOWED, + ) + await response(request.scope, request.receive, send) + + async def _handle_delete_request(self, request: Request, send: Send) -> None: + """Handle DELETE requests for explicit session termination.""" + # Validate session ID + if not self.mcp_session_id: + # If no session ID set, return Method Not Allowed + response = self._create_error_response( + "Method Not Allowed: Session termination not supported", + HTTPStatus.METHOD_NOT_ALLOWED, + ) + await response(request.scope, request.receive, send) + return + + if not await self._validate_session(request, send): + return + + self._terminate_session() + + response = self._create_json_response( + None, + HTTPStatus.OK, + ) + await response(request.scope, request.receive, send) + + def _terminate_session(self) -> None: + """Terminate the current session, closing all streams. + + Once terminated, all requests with this session ID will receive 404 Not Found. + """ + + self._terminated = True + logger.info(f"Terminating session: {self.mcp_session_id}") + + # We need a copy of the keys to avoid modification during iteration + request_stream_keys = list(self._request_streams.keys()) + + # Close all request streams (synchronously) + for key in request_stream_keys: + try: + # Get the stream + stream = self._request_streams.get(key) + if stream: + # We must use close() here, not aclose() since this is a sync method + stream.close() + except Exception as e: + logger.debug(f"Error closing stream {key} during termination: {e}") + + # Clear the request streams dictionary immediately + self._request_streams.clear() + + async def _handle_unsupported_request(self, request: Request, send: Send) -> None: + """Handle unsupported HTTP methods.""" + headers = { + "Content-Type": CONTENT_TYPE_JSON, + "Allow": "GET, POST, DELETE", + } + if self.mcp_session_id: + headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + response = self._create_error_response( + "Method Not Allowed", + HTTPStatus.METHOD_NOT_ALLOWED, + headers=headers, + ) + await response(request.scope, request.receive, send) + + async def _validate_session(self, request: Request, send: Send) -> bool: + """Validate the session ID in the request.""" + if not self.mcp_session_id: + # If we're not using session IDs, return True + return True + + # Get the session ID from the request headers + request_session_id = self._get_session_id(request) + + # If no session ID provided but required, return error + if not request_session_id: + response = self._create_error_response( + "Bad Request: Missing session ID", + HTTPStatus.BAD_REQUEST, + ) + await response(request.scope, request.receive, send) + return False + + # If session ID doesn't match, return error + if request_session_id != self.mcp_session_id: + response = self._create_error_response( + "Not Found: Invalid or expired session ID", + HTTPStatus.NOT_FOUND, + ) + await response(request.scope, request.receive, send) + return False + + return True + + @asynccontextmanager + async def connect( + self, + ) -> AsyncGenerator[ + tuple[ + MemoryObjectReceiveStream[JSONRPCMessage | Exception], + MemoryObjectSendStream[JSONRPCMessage], + ], + None, + ]: + """Context manager that provides read and write streams for a connection. + + Yields: + Tuple of (read_stream, write_stream) for bidirectional communication + """ + + # Create the memory streams for this connection + + read_stream_writer, read_stream = anyio.create_memory_object_stream[ + JSONRPCMessage | Exception + ](0) + write_stream, write_stream_reader = anyio.create_memory_object_stream[ + JSONRPCMessage + ](0) + + # Store the streams + self._read_stream_writer = read_stream_writer + self._write_stream_reader = write_stream_reader + + # Start a task group for message routing + async with anyio.create_task_group() as tg: + # Create a message router that distributes messages to request streams + async def message_router(): + try: + async for message in write_stream_reader: + # Determine which request stream(s) should receive this message + target_request_id = None + if isinstance( + message.root, JSONRPCNotification | JSONRPCRequest + ): + # Extract related_request_id from meta if it exists + if ( + (params := getattr(message.root, "params", None)) + and (meta := params.get("_meta")) + and (related_id := meta.get("related_request_id")) + is not None + ): + target_request_id = str(related_id) + else: + target_request_id = str(message.root.id) + + # Send to the specific request stream if available + if ( + target_request_id + and target_request_id in self._request_streams + ): + try: + await self._request_streams[target_request_id].send( + message + ) + except ( + anyio.BrokenResourceError, + anyio.ClosedResourceError, + ): + # Stream might be closed, remove from registry + self._request_streams.pop(target_request_id, None) + except Exception as e: + logger.exception(f"Error in message router: {e}") + + # Start the message router + tg.start_soon(message_router) + + try: + # Yield the streams for the caller to use + yield read_stream, write_stream + finally: + for stream in list(self._request_streams.values()): + try: + await stream.aclose() + except Exception: + pass + self._request_streams.clear() diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 11daedc9..c1259da7 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -6,7 +6,6 @@ from typing import Any, Generic, TypeVar import anyio -import anyio.lowlevel import httpx from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import BaseModel @@ -24,6 +23,7 @@ JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, + NotificationParams, RequestParams, ServerNotification, ServerRequest, @@ -274,16 +274,32 @@ async def send_request( await response_stream.aclose() await response_stream_reader.aclose() - async def send_notification(self, notification: SendNotificationT) -> None: + async def send_notification( + self, + notification: SendNotificationT, + related_request_id: RequestId | None = None, + ) -> None: """ Emits a notification, which is a one-way message that does not expect a response. """ + # Some transport implementations may need to set the related_request_id + # to attribute to the notifications to the request that triggered them. + if related_request_id is not None and notification.root.params is not None: + # Create meta if it doesn't exist + if notification.root.params.meta is None: + meta_dict = {"related_request_id": related_request_id} + + else: + meta_dict = notification.root.params.meta.model_dump( + by_alias=True, mode="json", exclude_none=True + ) + meta_dict["related_request_id"] = related_request_id + notification.root.params.meta = NotificationParams.Meta(**meta_dict) jsonrpc_notification = JSONRPCNotification( jsonrpc="2.0", **notification.model_dump(by_alias=True, mode="json", exclude_none=True), ) - await self._write_stream.send(JSONRPCMessage(jsonrpc_notification)) async def _send_response( diff --git a/tests/client/test_logging_callback.py b/tests/client/test_logging_callback.py index 797f817e..588fa649 100644 --- a/tests/client/test_logging_callback.py +++ b/tests/client/test_logging_callback.py @@ -9,6 +9,7 @@ from mcp.shared.session import RequestResponder from mcp.types import ( LoggingMessageNotificationParams, + NotificationParams, TextContent, ) @@ -78,6 +79,11 @@ async def message_handler( ) assert log_result.isError is False assert len(logging_collector.log_messages) == 1 - assert logging_collector.log_messages[0] == LoggingMessageNotificationParams( - level="info", logger="test_logger", data="Test log message" - ) + # Create meta object with related_request_id added dynamically + meta = NotificationParams.Meta() + setattr(meta, "related_request_id", "2") + log = logging_collector.log_messages[0] + assert log.level == "info" + assert log.logger == "test_logger" + assert log.data == "Test log message" + assert log.meta == meta diff --git a/tests/issues/test_188_concurrency.py b/tests/issues/test_188_concurrency.py index 2aa6c49c..d0a86885 100644 --- a/tests/issues/test_188_concurrency.py +++ b/tests/issues/test_188_concurrency.py @@ -35,7 +35,7 @@ async def slow_resource(): end_time = anyio.current_time() duration = end_time - start_time - assert duration < 3 * _sleep_time_seconds + assert duration < 6 * _sleep_time_seconds print(duration) diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py index e76e59c5..772c4152 100644 --- a/tests/server/fastmcp/test_server.py +++ b/tests/server/fastmcp/test_server.py @@ -544,14 +544,28 @@ async def logging_tool(msg: str, ctx: Context) -> str: assert mock_log.call_count == 4 mock_log.assert_any_call( - level="debug", data="Debug message", logger=None + level="debug", + data="Debug message", + logger=None, + related_request_id="1", ) - mock_log.assert_any_call(level="info", data="Info message", logger=None) mock_log.assert_any_call( - level="warning", data="Warning message", logger=None + level="info", + data="Info message", + logger=None, + related_request_id="1", ) mock_log.assert_any_call( - level="error", data="Error message", logger=None + level="warning", + data="Warning message", + logger=None, + related_request_id="1", + ) + mock_log.assert_any_call( + level="error", + data="Error message", + logger=None, + related_request_id="1", ) @pytest.mark.anyio diff --git a/tests/server/test_streamableHttp.py b/tests/server/test_streamableHttp.py new file mode 100644 index 00000000..8904bf4f --- /dev/null +++ b/tests/server/test_streamableHttp.py @@ -0,0 +1,543 @@ +""" +Tests for the StreamableHTTP server transport validation. + +This file contains tests for request validation in the StreamableHTTP transport. +""" + +import contextlib +import multiprocessing +import socket +import time +from collections.abc import Generator +from http import HTTPStatus +from uuid import uuid4 + +import anyio +import pytest +import requests +import uvicorn +from pydantic import AnyUrl +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import Mount + +from mcp.server import Server +from mcp.server.streamableHttp import ( + MCP_SESSION_ID_HEADER, + SESSION_ID_PATTERN, + StreamableHTTPServerTransport, +) +from mcp.shared.exceptions import McpError +from mcp.types import ( + ErrorData, + TextContent, + Tool, +) + +# Test constants +SERVER_NAME = "test_streamable_http_server" +TEST_SESSION_ID = "test-session-id-12345" +INIT_REQUEST = { + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "clientInfo": {"name": "test-client", "version": "1.0"}, + "protocolVersion": "2025-03-26", + "capabilities": {}, + }, + "id": "init-1", +} + + +# Test server implementation that follows MCP protocol +class ServerTest(Server): + def __init__(self): + super().__init__(SERVER_NAME) + + @self.read_resource() + async def handle_read_resource(uri: AnyUrl) -> str | bytes: + if uri.scheme == "foobar": + return f"Read {uri.host}" + elif uri.scheme == "slow": + # Simulate a slow resource + await anyio.sleep(2.0) + return f"Slow response from {uri.host}" + + raise McpError( + error=ErrorData( + code=404, message="OOPS! no resource with that URI was found" + ) + ) + + @self.list_tools() + async def handle_list_tools() -> list[Tool]: + return [ + Tool( + name="test_tool", + description="A test tool", + inputSchema={"type": "object", "properties": {}}, + ) + ] + + @self.call_tool() + async def handle_call_tool(name: str, args: dict) -> list[TextContent]: + return [TextContent(type="text", text=f"Called {name}")] + + +def create_app(is_json_response_enabled=False) -> Starlette: + """Create a Starlette application for testing that matches the example server. + + Args: + is_json_response_enabled: If True, use JSON responses instead of SSE streams. + """ + # Create server instance + server = ServerTest() + + server_instances = {} + # Lock to prevent race conditions when creating new sessions + session_creation_lock = anyio.Lock() + task_group = None + + @contextlib.asynccontextmanager + async def lifespan(app): + """Application lifespan context manager for managing task group.""" + nonlocal task_group + + async with anyio.create_task_group() as tg: + task_group = tg + print("Application started, task group initialized!") + try: + yield + finally: + print("Application shutting down, cleaning up resources...") + if task_group: + tg.cancel_scope.cancel() + task_group = None + print("Resources cleaned up successfully.") + + async def handle_streamable_http(scope, receive, send): + request = Request(scope, receive) + request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) + + # Use existing transport if session ID matches + if ( + request_mcp_session_id is not None + and request_mcp_session_id in server_instances + ): + transport = server_instances[request_mcp_session_id] + + await transport.handle_request(scope, receive, send) + elif request_mcp_session_id is None: + async with session_creation_lock: + new_session_id = uuid4().hex + + http_transport = StreamableHTTPServerTransport( + mcp_session_id=new_session_id, + is_json_response_enabled=is_json_response_enabled, + ) + + async with http_transport.connect() as streams: + read_stream, write_stream = streams + + async def run_server(): + try: + await server.run( + read_stream, + write_stream, + server.create_initialization_options(), + ) + except Exception as e: + print(f"Server exception: {e}") + + if task_group is None: + response = Response( + "Internal Server Error: Task group is not initialized", + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + ) + await response(scope, receive, send) + return + + # Store the instance before starting the task to prevent races + server_instances[http_transport.mcp_session_id] = http_transport + task_group.start_soon(run_server) + + await http_transport.handle_request(scope, receive, send) + else: + response = Response( + "Bad Request: No valid session ID provided", + status_code=HTTPStatus.BAD_REQUEST, + ) + await response(scope, receive, send) + + # Create an ASGI application + app = Starlette( + debug=True, + routes=[ + Mount("/mcp", app=handle_streamable_http), + ], + lifespan=lifespan, + ) + + return app + + +def run_server(port: int, is_json_response_enabled=False) -> None: + """Run the test server. + + Args: + port: Port to listen on. + is_json_response_enabled: If True, use JSON responses instead of SSE streams. + """ + print( + f"Starting test server on port {port} with " + f"json_enabled={is_json_response_enabled}" + ) + + app = create_app(is_json_response_enabled) + # Configure server + config = uvicorn.Config( + app=app, + host="127.0.0.1", + port=port, + log_level="info", + limit_concurrency=10, + timeout_keep_alive=5, + access_log=False, + ) + + # Start the server + server = uvicorn.Server(config=config) + + # This is important to catch exceptions and prevent test hangs + try: + print("Server starting...") + server.run() + except Exception as e: + print(f"ERROR: Server failed to run: {e}") + import traceback + + traceback.print_exc() + + print("Server shutdown") + + +# Test fixtures - using same approach as SSE tests +@pytest.fixture +def basic_server_port() -> int: + """Find an available port for the basic server.""" + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def json_server_port() -> int: + """Find an available port for the JSON response server.""" + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def basic_server(basic_server_port: int) -> Generator[None, None, None]: + """Start a basic server.""" + proc = multiprocessing.Process( + target=run_server, kwargs={"port": basic_server_port}, daemon=True + ) + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", basic_server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + + yield + + # Clean up + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("server process failed to terminate") + + +@pytest.fixture +def json_response_server(json_server_port: int) -> Generator[None, None, None]: + """Start a server with JSON response enabled.""" + proc = multiprocessing.Process( + target=run_server, + kwargs={"port": json_server_port, "is_json_response_enabled": True}, + daemon=True, + ) + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", json_server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + + yield + + # Clean up + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("server process failed to terminate") + + +@pytest.fixture +def basic_server_url(basic_server_port: int) -> str: + """Get the URL for the basic test server.""" + return f"http://127.0.0.1:{basic_server_port}" + + +@pytest.fixture +def json_server_url(json_server_port: int) -> str: + """Get the URL for the JSON response test server.""" + return f"http://127.0.0.1:{json_server_port}" + + +# Basic request validation tests +def test_accept_header_validation(basic_server, basic_server_url): + """Test that Accept header is properly validated.""" + # Test without Accept header + response = requests.post( + f"{basic_server_url}/mcp", + headers={"Content-Type": "application/json"}, + json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text + + +def test_content_type_validation(basic_server, basic_server_url): + """Test that Content-Type header is properly validated.""" + # Test with incorrect Content-Type + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "text/plain", + }, + data="This is not JSON", + ) + assert response.status_code == 415 + assert "Unsupported Media Type" in response.text + + +def test_json_validation(basic_server, basic_server_url): + """Test that JSON content is properly validated.""" + # Test with invalid JSON + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + data="this is not valid json", + ) + assert response.status_code == 400 + assert "Parse error" in response.text + + +def test_json_parsing(basic_server, basic_server_url): + """Test that JSON content is properly parse.""" + # Test with valid JSON but invalid JSON-RPC + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={"foo": "bar"}, + ) + assert response.status_code == 400 + assert "Validation error" in response.text + + +def test_method_not_allowed(basic_server, basic_server_url): + """Test that unsupported HTTP methods are rejected.""" + # Test with unsupported method (PUT) + response = requests.put( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, + ) + assert response.status_code == 405 + assert "Method Not Allowed" in response.text + + +def test_session_validation(basic_server, basic_server_url): + """Test session ID validation.""" + # session_id not used directly in this test + + # Test without session ID + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={"jsonrpc": "2.0", "method": "list_tools", "id": 1}, + ) + assert response.status_code == 400 + assert "Missing session ID" in response.text + + +def test_session_id_pattern(): + """Test that SESSION_ID_PATTERN correctly validates session IDs.""" + # Valid session IDs (visible ASCII characters from 0x21 to 0x7E) + valid_session_ids = [ + "test-session-id", + "1234567890", + "session!@#$%^&*()_+-=[]{}|;:,.<>?/", + "~`", + ] + + for session_id in valid_session_ids: + assert SESSION_ID_PATTERN.match(session_id) is not None + # Ensure fullmatch matches too (whole string) + assert SESSION_ID_PATTERN.fullmatch(session_id) is not None + + # Invalid session IDs + invalid_session_ids = [ + "", # Empty string + " test", # Space (0x20) + "test\t", # Tab + "test\n", # Newline + "test\r", # Carriage return + "test" + chr(0x7F), # DEL character + "test" + chr(0x80), # Extended ASCII + "test" + chr(0x00), # Null character + "test" + chr(0x20), # Space (0x20) + ] + + for session_id in invalid_session_ids: + # For invalid IDs, either match will fail or fullmatch will fail + if SESSION_ID_PATTERN.match(session_id) is not None: + # If match succeeds, fullmatch should fail (partial match case) + assert SESSION_ID_PATTERN.fullmatch(session_id) is None + + +def test_streamable_http_transport_init_validation(): + """Test that StreamableHTTPServerTransport validates session ID on init.""" + # Valid session ID should initialize without errors + valid_transport = StreamableHTTPServerTransport(mcp_session_id="valid-id") + assert valid_transport.mcp_session_id == "valid-id" + + # None should be accepted + none_transport = StreamableHTTPServerTransport(mcp_session_id=None) + assert none_transport.mcp_session_id is None + + # Invalid session ID should raise ValueError + with pytest.raises(ValueError) as excinfo: + StreamableHTTPServerTransport(mcp_session_id="invalid id with space") + assert "Session ID must only contain visible ASCII characters" in str(excinfo.value) + + # Test with control characters + with pytest.raises(ValueError): + StreamableHTTPServerTransport(mcp_session_id="test\nid") + + with pytest.raises(ValueError): + StreamableHTTPServerTransport(mcp_session_id="test\n") + + +def test_session_termination(basic_server, basic_server_url): + """Test session termination via DELETE and subsequent request handling.""" + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + + # Now terminate the session + session_id = response.headers.get(MCP_SESSION_ID_HEADER) + response = requests.delete( + f"{basic_server_url}/mcp", + headers={MCP_SESSION_ID_HEADER: session_id}, + ) + assert response.status_code == 200 + + # Try to use the terminated session + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + }, + json={"jsonrpc": "2.0", "method": "ping", "id": 2}, + ) + assert response.status_code == 404 + assert "Session has been terminated" in response.text + + +def test_response(basic_server, basic_server_url): + """Test response handling for a valid request.""" + mcp_url = f"{basic_server_url}/mcp" + response = requests.post( + mcp_url, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + + # Now terminate the session + session_id = response.headers.get(MCP_SESSION_ID_HEADER) + + # Try to use the terminated session + tools_response = requests.post( + mcp_url, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, # Use the session ID we got earlier + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-1"}, + stream=True, + ) + assert tools_response.status_code == 200 + assert tools_response.headers.get("Content-Type") == "text/event-stream" + + +def test_json_response(json_response_server, json_server_url): + """Test response handling when is_json_response_enabled is True.""" + mcp_url = f"{json_server_url}/mcp" + response = requests.post( + mcp_url, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + assert response.headers.get("Content-Type") == "application/json" diff --git a/uv.lock b/uv.lock index fdb788a7..01726cc9 100644 --- a/uv.lock +++ b/uv.lock @@ -10,6 +10,7 @@ members = [ "mcp", "mcp-simple-prompt", "mcp-simple-resource", + "mcp-simple-streamablehttp", "mcp-simple-tool", ] @@ -632,6 +633,43 @@ dev = [ { name = "ruff", specifier = ">=0.6.9" }, ] +[[package]] +name = "mcp-simple-streamablehttp" +version = "0.1.0" +source = { editable = "examples/servers/simple-streamablehttp" } +dependencies = [ + { name = "anyio" }, + { name = "click" }, + { name = "httpx" }, + { name = "mcp" }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "pytest" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "anyio", specifier = ">=4.5" }, + { name = "click", specifier = ">=8.1.0" }, + { name = "httpx", specifier = ">=0.27" }, + { name = "mcp", editable = "." }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.378" }, + { name = "pytest", specifier = ">=8.3.3" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + [[package]] name = "mcp-simple-tool" version = "0.1.0"