diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index faad95aca..8ae261117 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -68,6 +68,7 @@ async def main(): from __future__ import annotations as _annotations import contextvars +import inspect import json import logging import warnings @@ -104,6 +105,9 @@ async def main(): # This will be properly typed in each Server instance's context request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = contextvars.ContextVar("request_ctx") +# Context variable to hold the current ServerSession, accessible by notification handlers +current_session_ctx: contextvars.ContextVar[ServerSession] = contextvars.ContextVar("current_server_session") + class NotificationOptions: def __init__( @@ -520,6 +524,36 @@ async def handler(req: types.ProgressNotification): return decorator + def initialized_notification(self): + """Decorator to register a handler for InitializedNotification.""" + + def decorator( + func: ( + Callable[[types.InitializedNotification, ServerSession], Awaitable[None]] + | Callable[[types.InitializedNotification], Awaitable[None]] + ), + ): + logger.debug("Registering handler for InitializedNotification") + self.notification_handlers[types.InitializedNotification] = func + return func + + return decorator + + def roots_list_changed_notification(self): + """Decorator to register a handler for RootsListChangedNotification.""" + + def decorator( + func: ( + Callable[[types.RootsListChangedNotification, ServerSession], Awaitable[None]] + | Callable[[types.RootsListChangedNotification], Awaitable[None]] + ), + ): + logger.debug("Registering handler for RootsListChangedNotification") + self.notification_handlers[types.RootsListChangedNotification] = func + return func + + return decorator + def completion(self): """Provides completions for prompts and resource templates""" @@ -591,22 +625,26 @@ async def run( async def _handle_message( self, - message: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception, + message: (RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception), session: ServerSession, lifespan_context: LifespanResultT, raise_exceptions: bool = False, ): - with warnings.catch_warnings(record=True) as w: - # TODO(Marcelo): We should be checking if message is Exception here. - match message: # type: ignore[reportMatchNotExhaustive] - case RequestResponder(request=types.ClientRequest(root=req)) as responder: - with responder: - await self._handle_request(message, req, session, lifespan_context, raise_exceptions) - case types.ClientNotification(root=notify): - await self._handle_notification(notify) + session_token = current_session_ctx.set(session) + try: + with warnings.catch_warnings(record=True) as w: + # TODO(Marcelo): We should be checking if message is Exception here. + match message: # type: ignore[reportMatchNotExhaustive] + case RequestResponder(request=types.ClientRequest(root=req)) as responder: + with responder: + await self._handle_request(message, req, session, lifespan_context, raise_exceptions) + case types.ClientNotification(root=notify): + await self._handle_notification(notify) for warning in w: logger.info("Warning: %s: %s", warning.category.__name__, warning.message) + finally: + current_session_ctx.reset(session_token) async def _handle_request( self, @@ -666,7 +704,11 @@ async def _handle_notification(self, notify: Any): logger.debug("Dispatching notification of type %s", type(notify).__name__) try: - await handler(notify) + sig = inspect.signature(handler) + if "session" in sig.parameters: + await handler(notify, current_session_ctx.get()) + else: + await handler(notify) except Exception: logger.exception("Uncaught exception in notification handler") diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_notifications.py similarity index 63% rename from tests/shared/test_progress_notifications.py rename to tests/shared/test_notifications.py index 08bcb2662..fe835cd9e 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_notifications.py @@ -1,3 +1,4 @@ +import asyncio from typing import Any, cast import anyio @@ -10,11 +11,11 @@ from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.shared.context import RequestContext +from mcp.shared.message import SessionMessage from mcp.shared.progress import progress from mcp.shared.session import ( BaseSession, RequestResponder, - SessionMessage, ) @@ -333,3 +334,191 @@ async def handle_client_message( assert server_progress_updates[3]["progress"] == 100 assert server_progress_updates[3]["total"] == 100 assert server_progress_updates[3]["message"] == "Processing results..." + + +@pytest.mark.anyio +async def test_initialized_notification(): + """Test that the server receives and handles InitializedNotification.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + + server = Server("test") + initialized_received = asyncio.Event() + + @server.initialized_notification() + async def handle_initialized(notification: types.InitializedNotification): + initialized_received.set() + + async def run_server(): + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options(), + ) + + async def message_handler( + message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception), + ) -> None: + if isinstance(message, Exception): + raise message + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session, + anyio.create_task_group() as tg, + ): + tg.start_soon(run_server) + await client_session.initialize() + await initialized_received.wait() + tg.cancel_scope.cancel() + + assert initialized_received.is_set() + + +@pytest.mark.anyio +async def test_roots_list_changed_notification(): + """Test that the server receives and handles RootsListChangedNotification.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + + server = Server("test") + roots_list_changed_received = asyncio.Event() + + @server.roots_list_changed_notification() + async def handle_roots_list_changed( + notification: types.RootsListChangedNotification, + ): + roots_list_changed_received.set() + + async def run_server(): + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options(), + ) + + async def message_handler( + message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception), + ) -> None: + if isinstance(message, Exception): + raise message + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session, + anyio.create_task_group() as tg, + ): + tg.start_soon(run_server) + await client_session.initialize() + await client_session.send_notification( + types.ClientNotification( + root=types.RootsListChangedNotification(method="notifications/roots/list_changed", params=None) + ) + ) + await roots_list_changed_received.wait() + tg.cancel_scope.cancel() + + assert roots_list_changed_received.is_set() + + +@pytest.mark.anyio +async def test_initialized_notification_with_session(): + """Test that the server receives and handles InitializedNotification with a session.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + + server = Server("test") + initialized_received = asyncio.Event() + received_session = None + + @server.initialized_notification() + async def handle_initialized(notification: types.InitializedNotification, session: ServerSession): + nonlocal received_session + received_session = session + initialized_received.set() + + async def run_server(): + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options(), + ) + + async def message_handler( + message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception), + ) -> None: + if isinstance(message, Exception): + raise message + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session, + anyio.create_task_group() as tg, + ): + tg.start_soon(run_server) + await client_session.initialize() + await initialized_received.wait() + tg.cancel_scope.cancel() + + assert initialized_received.is_set() + assert isinstance(received_session, ServerSession) + + +@pytest.mark.anyio +async def test_roots_list_changed_notification_with_session(): + """Test that the server receives and handles RootsListChangedNotification with a session.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + + server = Server("test") + roots_list_changed_received = asyncio.Event() + received_session = None + + @server.roots_list_changed_notification() + async def handle_roots_list_changed(notification: types.RootsListChangedNotification, session: ServerSession): + nonlocal received_session + received_session = session + roots_list_changed_received.set() + + async def run_server(): + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options(), + ) + + async def message_handler( + message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception), + ) -> None: + if isinstance(message, Exception): + raise message + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session, + anyio.create_task_group() as tg, + ): + tg.start_soon(run_server) + await client_session.initialize() + await client_session.send_notification( + types.ClientNotification( + root=types.RootsListChangedNotification(method="notifications/roots/list_changed", params=None) + ) + ) + await roots_list_changed_received.wait() + tg.cancel_scope.cancel() + + assert roots_list_changed_received.is_set() + assert isinstance(received_session, ServerSession)