Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: reorganize message handling for better type safety and clarity #239

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from datetime import timedelta
from typing import Any, Protocol

from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl, TypeAdapter

import mcp.types as types
from mcp.shared.context import RequestContext
from mcp.shared.session import BaseSession, RequestResponder
from mcp.shared.session import BaseSession, ReadStream, RequestResponder, WriteStream
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS


Expand Down Expand Up @@ -57,8 +56,8 @@ class ClientSession(
):
def __init__(
self,
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
read_stream: ReadStream,
write_stream: WriteStream,
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
Expand Down
23 changes: 16 additions & 7 deletions src/mcp/client/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,16 @@
import anyio
import httpx
from anyio.abc import TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import aconnect_sse

import mcp.types as types
from mcp.shared.session import (
ReadStream,
ReadStreamWriter,
WriteStream,
WriteStreamReader,
)
from mcp.types import MessageFrame

logger = logging.getLogger(__name__)

Expand All @@ -31,11 +37,11 @@ async def sse_client(
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.
"""
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
read_stream: ReadStream
read_stream_writer: ReadStreamWriter

write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
write_stream: WriteStream
write_stream_reader: WriteStreamReader

read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
Expand Down Expand Up @@ -84,8 +90,11 @@ async def sse_reader(

case "message":
try:
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
sse.data
message = MessageFrame(
root=types.JSONRPCMessage.model_validate_json( # noqa: E501
sse.data
),
raw=sse,
)
logger.debug(
f"Received server message: {message}"
Expand Down
7 changes: 3 additions & 4 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ async def main():
from typing import Any, AsyncIterator, Generic, TypeVar

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl

import mcp.types as types
Expand All @@ -82,7 +81,7 @@ async def main():
from mcp.server.stdio import stdio_server as stdio_server
from mcp.shared.context import RequestContext
from mcp.shared.exceptions import McpError
from mcp.shared.session import RequestResponder
from mcp.shared.session import ReadStream, RequestResponder, WriteStream

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -472,8 +471,8 @@ async def handler(req: types.CompleteRequest):

async def run(
self,
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
read_stream: ReadStream,
write_stream: WriteStream,
initialization_options: InitializationOptions,
# When False, exceptions are returned as messages to the client.
# When True, exceptions are raised, which will cause the server to shut down
Expand Down
4 changes: 1 addition & 3 deletions src/mcp/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@

from pydantic import BaseModel

from mcp.types import (
ServerCapabilities,
)
from mcp.types import ServerCapabilities


class InitializationOptions(BaseModel):
Expand Down
7 changes: 4 additions & 3 deletions src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,15 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:

import anyio
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl

import mcp.types as types
from mcp.server.models import InitializationOptions
from mcp.shared.session import (
BaseSession,
ReadStream,
RequestResponder,
WriteStream,
)


Expand All @@ -73,8 +74,8 @@ class ServerSession(

def __init__(
self,
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
read_stream: ReadStream,
write_stream: WriteStream,
init_options: InitializationOptions,
) -> None:
super().__init__(
Expand Down
22 changes: 13 additions & 9 deletions src/mcp/server/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,20 @@ async def handle_sse(request):
from uuid import UUID, uuid4

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import ValidationError
from sse_starlette import EventSourceResponse
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import Receive, Scope, Send

import mcp.types as types
from mcp.shared.session import (
ReadStream,
ReadStreamWriter,
WriteStream,
WriteStreamReader,
)
from mcp.types import MessageFrame

logger = logging.getLogger(__name__)

Expand All @@ -63,9 +69,7 @@ class SseServerTransport:
"""

_endpoint: str
_read_stream_writers: dict[
UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception]
]
_read_stream_writers: dict[UUID, ReadStreamWriter]

def __init__(self, endpoint: str) -> None:
"""
Expand All @@ -85,11 +89,11 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
raise ValueError("connect_sse can only handle HTTP requests")

logger.debug("Setting up SSE connection")
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
read_stream: ReadStream
read_stream_writer: ReadStreamWriter

write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
write_stream: WriteStream
write_stream_reader: WriteStreamReader

read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
Expand Down Expand Up @@ -172,4 +176,4 @@ async def handle_post_message(
logger.debug(f"Sending message to writer: {message}")
response = Response("Accepted", status_code=202)
await response(scope, receive, send)
await writer.send(message)
await writer.send(MessageFrame(root=message, raw=request))
19 changes: 13 additions & 6 deletions src/mcp/server/stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,15 @@ async def run_server():

import anyio
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

import mcp.types as types
from mcp.shared.session import (
ReadStream,
ReadStreamWriter,
WriteStream,
WriteStreamReader,
)
from mcp.types import MessageFrame


@asynccontextmanager
Expand All @@ -47,11 +53,11 @@ async def stdio_server(
if not stdout:
stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8"))

read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
read_stream: ReadStream
read_stream_writer: ReadStreamWriter

write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
write_stream: WriteStream
write_stream_reader: WriteStreamReader

read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
Expand All @@ -66,14 +72,15 @@ async def stdin_reader():
await read_stream_writer.send(exc)
continue

await read_stream_writer.send(message)
await read_stream_writer.send(MessageFrame(root=message, raw=line))
except anyio.ClosedResourceError:
await anyio.lowlevel.checkpoint()

async def stdout_writer():
try:
async with write_stream_reader:
async for message in write_stream_reader:
# Extract the inner JSONRPCRequest/JSONRPCResponse from MessageFrame
json = message.model_dump_json(by_alias=True, exclude_none=True)
await stdout.write(json + "\n")
await stdout.flush()
Expand Down
20 changes: 14 additions & 6 deletions src/mcp/server/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@
from contextlib import asynccontextmanager

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from starlette.types import Receive, Scope, Send
from starlette.websockets import WebSocket

import mcp.types as types
from mcp.shared.session import (
ReadStream,
ReadStreamWriter,
WriteStream,
WriteStreamReader,
)
from mcp.types import MessageFrame

logger = logging.getLogger(__name__)

Expand All @@ -21,11 +27,11 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
websocket = WebSocket(scope, receive, send)
await websocket.accept(subprotocol="mcp")

read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
read_stream: ReadStream
read_stream_writer: ReadStreamWriter

write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
write_stream: WriteStream
write_stream_reader: WriteStreamReader

read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
Expand All @@ -40,7 +46,9 @@ async def ws_reader():
await read_stream_writer.send(exc)
continue

await read_stream_writer.send(client_message)
await read_stream_writer.send(
MessageFrame(root=client_message, raw=message)
)
except anyio.ClosedResourceError:
await websocket.close()

Expand Down
17 changes: 7 additions & 10 deletions src/mcp/shared/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@

from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT
from mcp.server import Server
from mcp.types import JSONRPCMessage
from mcp.types import MessageFrame

MessageStream = tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage],
MemoryObjectReceiveStream[MessageFrame | Exception],
MemoryObjectSendStream[MessageFrame],
]


Expand All @@ -32,10 +32,10 @@ async def create_client_server_memory_streams() -> (
"""
# Create streams for both directions
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
JSONRPCMessage | Exception
MessageFrame | Exception
](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
JSONRPCMessage | Exception
MessageFrame | Exception
](1)

client_streams = (server_to_client_receive, client_to_server_send)
Expand All @@ -60,12 +60,9 @@ async def create_connected_server_and_client_session(
) -> AsyncGenerator[ClientSession, None]:
"""Creates a ClientSession that is connected to a running MCP server."""
async with create_client_server_memory_streams() as (
client_streams,
server_streams,
(client_read, client_write),
(server_read, server_write),
):
client_read, client_write = client_streams
server_read, server_write = server_streams

# Create a cancel scope for the server task
async with anyio.create_task_group() as tg:
tg.start_soon(
Expand Down
Loading