diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index e7bc4430..b5faffed 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -11,6 +11,7 @@ MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport, ) +from pydantic import AnyUrl from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response @@ -92,6 +93,9 @@ async def call_tool( if i < count - 1: # Don't wait after the last notification await anyio.sleep(interval) + # This will send a resource notificaiton though standalone SSE + # established by GET request + await ctx.session.send_resource_updated(uri=AnyUrl("http:///test_resource")) return [ types.TextContent( type="text", diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index 2e0f7090..8faff016 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -50,6 +50,9 @@ CONTENT_TYPE_JSON = "application/json" CONTENT_TYPE_SSE = "text/event-stream" +# Special key for the standalone GET stream +GET_STREAM_KEY = "_GET_stream" + # Session ID validation pattern (visible ASCII characters ranging from 0x21 to 0x7E) # Pattern ensures entire string contains only valid characters by using ^ and $ anchors SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$") @@ -443,10 +446,19 @@ async def sse_writer(): return async def _handle_get_request(self, request: Request, send: Send) -> None: - """Handle GET requests for SSE stream establishment.""" - # Validate session ID if server has one - if not await self._validate_session(request, send): - return + """ + Handle GET request to establish SSE. + + This allows the server to communicate to the client without the client + first sending data via HTTP POST. The server can send JSON-RPC requests + and notifications on this stream. + """ + writer = self._read_stream_writer + if writer is None: + raise ValueError( + "No read stream writer available. Ensure connect() is called first." + ) + # Validate Accept header - must include text/event-stream _, has_sse = self._check_accept_headers(request) @@ -458,13 +470,80 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: await response(request.scope, request.receive, send) return - # TODO: Implement SSE stream for GET requests - # For now, return 405 Method Not Allowed - response = self._create_error_response( - "SSE stream from GET request not implemented yet", - HTTPStatus.METHOD_NOT_ALLOWED, + if not await self._validate_session(request, send): + return + + headers = { + "Cache-Control": "no-cache, no-transform", + "Connection": "keep-alive", + "Content-Type": CONTENT_TYPE_SSE, + } + + if self.mcp_session_id: + headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + # Check if we already have an active GET stream + if GET_STREAM_KEY in self._request_streams: + response = self._create_error_response( + "Conflict: Only one SSE stream is allowed per session", + HTTPStatus.CONFLICT, + ) + await response(request.scope, request.receive, send) + return + + # Create SSE stream + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[ + dict[str, Any] + ](0) + + async def standalone_sse_writer(): + try: + # Create a standalone message stream for server-initiated messages + standalone_stream_writer, standalone_stream_reader = ( + anyio.create_memory_object_stream[JSONRPCMessage](0) + ) + + # Register this stream using the special key + self._request_streams[GET_STREAM_KEY] = standalone_stream_writer + + async with sse_stream_writer, standalone_stream_reader: + # Process messages from the standalone stream + async for received_message in standalone_stream_reader: + # For the standalone stream, we handle: + # - JSONRPCNotification (server sends notifications to client) + # - JSONRPCRequest (server sends requests to client) + # We should NOT receive JSONRPCResponse + + # Send the message via SSE + event_data = { + "event": "message", + "data": received_message.model_dump_json( + by_alias=True, exclude_none=True + ), + } + + await sse_stream_writer.send(event_data) + except Exception as e: + logger.exception(f"Error in standalone SSE writer: {e}") + finally: + logger.debug("Closing standalone SSE writer") + # Remove the stream from request_streams + self._request_streams.pop(GET_STREAM_KEY, None) + + # Create and start EventSourceResponse + response = EventSourceResponse( + content=sse_stream_reader, + data_sender_callable=standalone_sse_writer, + headers=headers, ) - await response(request.scope, request.receive, send) + + try: + # This will send headers immediately and establish the SSE connection + await response(request.scope, request.receive, send) + except Exception as e: + logger.exception(f"Error in standalone SSE response: {e}") + # Clean up the request stream + self._request_streams.pop(GET_STREAM_KEY, None) async def _handle_delete_request(self, request: Request, send: Send) -> None: """Handle DELETE requests for explicit session termination.""" @@ -611,13 +690,10 @@ async def message_router(): else: target_request_id = str(message.root.id) - # Send to the specific request stream if available - if ( - target_request_id - and target_request_id in self._request_streams - ): + request_stream_id = target_request_id or GET_STREAM_KEY + if request_stream_id in self._request_streams: try: - await self._request_streams[target_request_id].send( + await self._request_streams[request_stream_id].send( message ) except ( @@ -625,7 +701,7 @@ async def message_router(): anyio.ClosedResourceError, ): # Stream might be closed, remove from registry - self._request_streams.pop(target_request_id, None) + self._request_streams.pop(request_stream_id, None) except Exception as e: logger.exception(f"Error in message router: {e}") diff --git a/tests/server/test_streamableHttp.py b/tests/server/test_streamableHttp.py index 8904bf4f..f612575c 100644 --- a/tests/server/test_streamableHttp.py +++ b/tests/server/test_streamableHttp.py @@ -541,3 +541,92 @@ def test_json_response(json_response_server, json_server_url): ) assert response.status_code == 200 assert response.headers.get("Content-Type") == "application/json" + + +def test_get_sse_stream(basic_server, basic_server_url): + """Test establishing an SSE stream via GET request.""" + # First, we need to initialize a session + mcp_url = f"{basic_server_url}/mcp" + init_response = requests.post( + mcp_url, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 + + # Get the session ID + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + + # Now attempt to establish an SSE stream via GET + get_response = requests.get( + mcp_url, + headers={ + "Accept": "text/event-stream", + MCP_SESSION_ID_HEADER: session_id, + }, + stream=True, + ) + + # Verify we got a successful response with the right content type + assert get_response.status_code == 200 + assert get_response.headers.get("Content-Type") == "text/event-stream" + + # Test that a second GET request gets rejected (only one stream allowed) + second_get = requests.get( + mcp_url, + headers={ + "Accept": "text/event-stream", + MCP_SESSION_ID_HEADER: session_id, + }, + stream=True, + ) + + # Should get CONFLICT (409) since there's already a stream + # Note: This might fail if the first stream fully closed before this runs, + # but generally it should work in the test environment where it runs quickly + assert second_get.status_code == 409 + + +def test_get_validation(basic_server, basic_server_url): + """Test validation for GET requests.""" + # First, we need to initialize a session + mcp_url = f"{basic_server_url}/mcp" + init_response = requests.post( + mcp_url, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 + + # Get the session ID + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + + # Test without Accept header + response = requests.get( + mcp_url, + headers={ + MCP_SESSION_ID_HEADER: session_id, + }, + stream=True, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text + + # Test with wrong Accept header + response = requests.get( + mcp_url, + headers={ + "Accept": "application/json", + MCP_SESSION_ID_HEADER: session_id, + }, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text