Skip to content

Adding roots changed and initialized notification handlers #1043

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 52 additions & 10 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ async def main():
from __future__ import annotations as _annotations

import contextvars
import inspect
import json
import logging
import warnings
Expand Down Expand Up @@ -104,6 +105,9 @@ async def main():
# This will be properly typed in each Server instance's context
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = contextvars.ContextVar("request_ctx")

# Context variable to hold the current ServerSession, accessible by notification handlers
current_session_ctx: contextvars.ContextVar[ServerSession] = contextvars.ContextVar("current_server_session")


class NotificationOptions:
def __init__(
Expand Down Expand Up @@ -520,6 +524,36 @@ async def handler(req: types.ProgressNotification):

return decorator

def initialized_notification(self):
"""Decorator to register a handler for InitializedNotification."""

def decorator(
func: (
Callable[[types.InitializedNotification, ServerSession], Awaitable[None]]
| Callable[[types.InitializedNotification], Awaitable[None]]
),
):
logger.debug("Registering handler for InitializedNotification")
self.notification_handlers[types.InitializedNotification] = func
return func

return decorator

def roots_list_changed_notification(self):
"""Decorator to register a handler for RootsListChangedNotification."""

def decorator(
func: (
Callable[[types.RootsListChangedNotification, ServerSession], Awaitable[None]]
| Callable[[types.RootsListChangedNotification], Awaitable[None]]
),
):
logger.debug("Registering handler for RootsListChangedNotification")
self.notification_handlers[types.RootsListChangedNotification] = func
return func

return decorator

def completion(self):
"""Provides completions for prompts and resource templates"""

Expand Down Expand Up @@ -591,22 +625,26 @@ async def run(

async def _handle_message(
self,
message: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception,
message: (RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception),
session: ServerSession,
lifespan_context: LifespanResultT,
raise_exceptions: bool = False,
):
with warnings.catch_warnings(record=True) as w:
# TODO(Marcelo): We should be checking if message is Exception here.
match message: # type: ignore[reportMatchNotExhaustive]
case RequestResponder(request=types.ClientRequest(root=req)) as responder:
with responder:
await self._handle_request(message, req, session, lifespan_context, raise_exceptions)
case types.ClientNotification(root=notify):
await self._handle_notification(notify)
session_token = current_session_ctx.set(session)
try:
with warnings.catch_warnings(record=True) as w:
# TODO(Marcelo): We should be checking if message is Exception here.
match message: # type: ignore[reportMatchNotExhaustive]
case RequestResponder(request=types.ClientRequest(root=req)) as responder:
with responder:
await self._handle_request(message, req, session, lifespan_context, raise_exceptions)
case types.ClientNotification(root=notify):
await self._handle_notification(notify)

for warning in w:
logger.info("Warning: %s: %s", warning.category.__name__, warning.message)
finally:
current_session_ctx.reset(session_token)

async def _handle_request(
self,
Expand Down Expand Up @@ -666,7 +704,11 @@ async def _handle_notification(self, notify: Any):
logger.debug("Dispatching notification of type %s", type(notify).__name__)

try:
await handler(notify)
sig = inspect.signature(handler)
if "session" in sig.parameters:
await handler(notify, current_session_ctx.get())
else:
await handler(notify)
except Exception:
logger.exception("Uncaught exception in notification handler")

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from typing import Any, cast

import anyio
Expand All @@ -10,11 +11,11 @@
from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession
from mcp.shared.context import RequestContext
from mcp.shared.message import SessionMessage
from mcp.shared.progress import progress
from mcp.shared.session import (
BaseSession,
RequestResponder,
SessionMessage,
)


Expand Down Expand Up @@ -333,3 +334,191 @@ async def handle_client_message(
assert server_progress_updates[3]["progress"] == 100
assert server_progress_updates[3]["total"] == 100
assert server_progress_updates[3]["message"] == "Processing results..."


@pytest.mark.anyio
async def test_initialized_notification():
"""Test that the server receives and handles InitializedNotification."""
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)

server = Server("test")
initialized_received = asyncio.Event()

@server.initialized_notification()
async def handle_initialized(notification: types.InitializedNotification):
initialized_received.set()

async def run_server():
await server.run(
client_to_server_receive,
server_to_client_send,
server.create_initialization_options(),
)

async def message_handler(
message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception),
) -> None:
if isinstance(message, Exception):
raise message

async with (
ClientSession(
server_to_client_receive,
client_to_server_send,
message_handler=message_handler,
) as client_session,
anyio.create_task_group() as tg,
):
tg.start_soon(run_server)
await client_session.initialize()
await initialized_received.wait()
tg.cancel_scope.cancel()

assert initialized_received.is_set()


@pytest.mark.anyio
async def test_roots_list_changed_notification():
"""Test that the server receives and handles RootsListChangedNotification."""
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)

server = Server("test")
roots_list_changed_received = asyncio.Event()

@server.roots_list_changed_notification()
async def handle_roots_list_changed(
notification: types.RootsListChangedNotification,
):
roots_list_changed_received.set()

async def run_server():
await server.run(
client_to_server_receive,
server_to_client_send,
server.create_initialization_options(),
)

async def message_handler(
message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception),
) -> None:
if isinstance(message, Exception):
raise message

async with (
ClientSession(
server_to_client_receive,
client_to_server_send,
message_handler=message_handler,
) as client_session,
anyio.create_task_group() as tg,
):
tg.start_soon(run_server)
await client_session.initialize()
await client_session.send_notification(
types.ClientNotification(
root=types.RootsListChangedNotification(method="notifications/roots/list_changed", params=None)
)
)
await roots_list_changed_received.wait()
tg.cancel_scope.cancel()

assert roots_list_changed_received.is_set()


@pytest.mark.anyio
async def test_initialized_notification_with_session():
"""Test that the server receives and handles InitializedNotification with a session."""
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)

server = Server("test")
initialized_received = asyncio.Event()
received_session = None

@server.initialized_notification()
async def handle_initialized(notification: types.InitializedNotification, session: ServerSession):
nonlocal received_session
received_session = session
initialized_received.set()

async def run_server():
await server.run(
client_to_server_receive,
server_to_client_send,
server.create_initialization_options(),
)

async def message_handler(
message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception),
) -> None:
if isinstance(message, Exception):
raise message

async with (
ClientSession(
server_to_client_receive,
client_to_server_send,
message_handler=message_handler,
) as client_session,
anyio.create_task_group() as tg,
):
tg.start_soon(run_server)
await client_session.initialize()
await initialized_received.wait()
tg.cancel_scope.cancel()

assert initialized_received.is_set()
assert isinstance(received_session, ServerSession)


@pytest.mark.anyio
async def test_roots_list_changed_notification_with_session():
"""Test that the server receives and handles RootsListChangedNotification with a session."""
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)

server = Server("test")
roots_list_changed_received = asyncio.Event()
received_session = None

@server.roots_list_changed_notification()
async def handle_roots_list_changed(notification: types.RootsListChangedNotification, session: ServerSession):
nonlocal received_session
received_session = session
roots_list_changed_received.set()

async def run_server():
await server.run(
client_to_server_receive,
server_to_client_send,
server.create_initialization_options(),
)

async def message_handler(
message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception),
) -> None:
if isinstance(message, Exception):
raise message

async with (
ClientSession(
server_to_client_receive,
client_to_server_send,
message_handler=message_handler,
) as client_session,
anyio.create_task_group() as tg,
):
tg.start_soon(run_server)
await client_session.initialize()
await client_session.send_notification(
types.ClientNotification(
root=types.RootsListChangedNotification(method="notifications/roots/list_changed", params=None)
)
)
await roots_list_changed_received.wait()
tg.cancel_scope.cancel()

assert roots_list_changed_received.is_set()
assert isinstance(received_session, ServerSession)
Loading