Skip to content

Wrap JSONRPC messages with SessionMessage for metadata support #590

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: ihrpr/resumability-server
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/mcp/client/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.shared.message import SessionMessage
from mcp.shared.session import RequestResponder
from mcp.types import JSONRPCMessage

if not sys.warnoptions:
import warnings
Expand All @@ -36,8 +36,8 @@ async def message_handler(


async def run_session(
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
write_stream: MemoryObjectSendStream[SessionMessage],
client_info: types.Implementation | None = None,
):
async with ClientSession(
Expand Down
5 changes: 3 additions & 2 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import mcp.types as types
from mcp.shared.context import RequestContext
from mcp.shared.message import SessionMessage
from mcp.shared.session import BaseSession, RequestResponder
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS

Expand Down Expand Up @@ -92,8 +93,8 @@ class ClientSession(
):
def __init__(
self,
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
write_stream: MemoryObjectSendStream[SessionMessage],
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
Expand Down
20 changes: 12 additions & 8 deletions src/mcp/client/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from httpx_sse import aconnect_sse

import mcp.types as types
from mcp.shared.message import SessionMessage

logger = logging.getLogger(__name__)

Expand All @@ -31,11 +32,11 @@ async def sse_client(
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.
"""
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]

write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
write_stream: MemoryObjectSendStream[SessionMessage]
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]

read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
Expand Down Expand Up @@ -97,7 +98,8 @@ async def sse_reader(
await read_stream_writer.send(exc)
continue

await read_stream_writer.send(message)
session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
case _:
logger.warning(
f"Unknown SSE event: {sse.event}"
Expand All @@ -111,11 +113,13 @@ async def sse_reader(
async def post_writer(endpoint_url: str):
try:
async with write_stream_reader:
async for message in write_stream_reader:
logger.debug(f"Sending client message: {message}")
async for session_message in write_stream_reader:
logger.debug(
f"Sending client message: {session_message}"
)
response = await client.post(
endpoint_url,
json=message.model_dump(
json=session_message.message.model_dump(
by_alias=True,
mode="json",
exclude_none=True,
Expand Down
18 changes: 11 additions & 7 deletions src/mcp/client/stdio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pydantic import BaseModel, Field

import mcp.types as types
from mcp.shared.message import SessionMessage

from .win32 import (
create_windows_process,
Expand Down Expand Up @@ -98,11 +99,11 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
Client transport for stdio: this will connect to a server by spawning a
process and communicating with it over stdin/stdout.
"""
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]

write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
write_stream: MemoryObjectSendStream[SessionMessage]
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]

read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
Expand Down Expand Up @@ -143,7 +144,8 @@ async def stdout_reader():
await read_stream_writer.send(exc)
continue

await read_stream_writer.send(message)
session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
except anyio.ClosedResourceError:
await anyio.lowlevel.checkpoint()

Expand All @@ -152,8 +154,10 @@ async def stdin_writer():

try:
async with write_stream_reader:
async for message in write_stream_reader:
json = message.model_dump_json(by_alias=True, exclude_none=True)
async for session_message in write_stream_reader:
json = session_message.message.model_dump_json(
by_alias=True, exclude_none=True
)
await process.stdin.send(
(json + "\n").encode(
encoding=server.encoding,
Expand Down
23 changes: 16 additions & 7 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import httpx
from httpx_sse import EventSource, aconnect_sse

from mcp.shared.message import SessionMessage
from mcp.types import (
ErrorData,
JSONRPCError,
Expand Down Expand Up @@ -52,10 +53,10 @@ async def streamablehttp_client(
"""

read_stream_writer, read_stream = anyio.create_memory_object_stream[
JSONRPCMessage | Exception
SessionMessage | Exception
](0)
write_stream, write_stream_reader = anyio.create_memory_object_stream[
JSONRPCMessage
SessionMessage
](0)

async def get_stream():
Expand Down Expand Up @@ -86,7 +87,8 @@ async def get_stream():
try:
message = JSONRPCMessage.model_validate_json(sse.data)
logger.debug(f"GET message: {message}")
await read_stream_writer.send(message)
session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
except Exception as exc:
logger.error(f"Error parsing GET message: {exc}")
await read_stream_writer.send(exc)
Expand All @@ -100,7 +102,8 @@ async def post_writer(client: httpx.AsyncClient):
nonlocal session_id
try:
async with write_stream_reader:
async for message in write_stream_reader:
async for session_message in write_stream_reader:
message = session_message.message
# Add session ID to headers if we have one
post_headers = request_headers.copy()
if session_id:
Expand Down Expand Up @@ -141,9 +144,10 @@ async def post_writer(client: httpx.AsyncClient):
message="Session terminated",
),
)
await read_stream_writer.send(
session_message = SessionMessage(
JSONRPCMessage(jsonrpc_error)
)
await read_stream_writer.send(session_message)
continue
response.raise_for_status()

Expand All @@ -163,7 +167,8 @@ async def post_writer(client: httpx.AsyncClient):
json_message = JSONRPCMessage.model_validate_json(
content
)
await read_stream_writer.send(json_message)
session_message = SessionMessage(json_message)
await read_stream_writer.send(session_message)
except Exception as exc:
logger.error(f"Error parsing JSON response: {exc}")
await read_stream_writer.send(exc)
Expand All @@ -175,11 +180,15 @@ async def post_writer(client: httpx.AsyncClient):
async for sse in event_source.aiter_sse():
if sse.event == "message":
try:
await read_stream_writer.send(
message = (
JSONRPCMessage.model_validate_json(
sse.data
)
)
session_message = SessionMessage(message)
await read_stream_writer.send(
session_message
)
except Exception as exc:
logger.exception("Error parsing message")
await read_stream_writer.send(exc)
Expand Down
20 changes: 11 additions & 9 deletions src/mcp/client/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from websockets.typing import Subprotocol

import mcp.types as types
from mcp.shared.message import SessionMessage

logger = logging.getLogger(__name__)

Expand All @@ -19,8 +20,8 @@ async def websocket_client(
url: str,
) -> AsyncGenerator[
tuple[
MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
MemoryObjectSendStream[types.JSONRPCMessage],
MemoryObjectReceiveStream[SessionMessage | Exception],
MemoryObjectSendStream[SessionMessage],
],
None,
]:
Expand All @@ -39,10 +40,10 @@ async def websocket_client(
# Create two in-memory streams:
# - One for incoming messages (read_stream, written by ws_reader)
# - One for outgoing messages (write_stream, read by ws_writer)
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
write_stream: MemoryObjectSendStream[SessionMessage]
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]

read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
Expand All @@ -59,7 +60,8 @@ async def ws_reader():
async for raw_text in ws:
try:
message = types.JSONRPCMessage.model_validate_json(raw_text)
await read_stream_writer.send(message)
session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
except ValidationError as exc:
# If JSON parse or model validation fails, send the exception
await read_stream_writer.send(exc)
Expand All @@ -70,9 +72,9 @@ async def ws_writer():
sends them to the server.
"""
async with write_stream_reader:
async for message in write_stream_reader:
async for session_message in write_stream_reader:
# Convert to a dict, then to JSON
msg_dict = message.model_dump(
msg_dict = session_message.message.model_dump(
by_alias=True, mode="json", exclude_none=True
)
await ws.send(json.dumps(msg_dict))
Expand Down
5 changes: 3 additions & 2 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ async def main():
from mcp.server.stdio import stdio_server as stdio_server
from mcp.shared.context import RequestContext
from mcp.shared.exceptions import McpError
from mcp.shared.message import SessionMessage
from mcp.shared.session import RequestResponder

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -471,8 +472,8 @@ async def handler(req: types.CompleteRequest):

async def run(
self,
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
write_stream: MemoryObjectSendStream[SessionMessage],
initialization_options: InitializationOptions,
# When False, exceptions are returned as messages to the client.
# When True, exceptions are raised, which will cause the server to shut down
Expand Down
5 changes: 3 additions & 2 deletions src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:

import mcp.types as types
from mcp.server.models import InitializationOptions
from mcp.shared.message import SessionMessage
from mcp.shared.session import (
BaseSession,
RequestResponder,
Expand Down Expand Up @@ -82,8 +83,8 @@ class ServerSession(

def __init__(
self,
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
write_stream: MemoryObjectSendStream[SessionMessage],
init_options: InitializationOptions,
standalone_mode: bool = False,
) -> None:
Expand Down
24 changes: 12 additions & 12 deletions src/mcp/server/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ async def handle_sse(request):
from starlette.types import Receive, Scope, Send

import mcp.types as types
from mcp.shared.message import SessionMessage

logger = logging.getLogger(__name__)

Expand All @@ -63,9 +64,7 @@ class SseServerTransport:
"""

_endpoint: str
_read_stream_writers: dict[
UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception]
]
_read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]]

def __init__(self, endpoint: str) -> None:
"""
Expand All @@ -85,11 +84,11 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
raise ValueError("connect_sse can only handle HTTP requests")

logger.debug("Setting up SSE connection")
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]

write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
write_stream: MemoryObjectSendStream[SessionMessage]
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]

read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
Expand All @@ -109,12 +108,12 @@ async def sse_writer():
await sse_stream_writer.send({"event": "endpoint", "data": session_uri})
logger.debug(f"Sent endpoint event: {session_uri}")

async for message in write_stream_reader:
logger.debug(f"Sending message via SSE: {message}")
async for session_message in write_stream_reader:
logger.debug(f"Sending message via SSE: {session_message}")
await sse_stream_writer.send(
{
"event": "message",
"data": message.model_dump_json(
"data": session_message.message.model_dump_json(
by_alias=True, exclude_none=True
),
}
Expand Down Expand Up @@ -169,7 +168,8 @@ async def handle_post_message(
await writer.send(err)
return

logger.debug(f"Sending message to writer: {message}")
session_message = SessionMessage(message)
logger.debug(f"Sending session message to writer: {session_message}")
response = Response("Accepted", status_code=202)
await response(scope, receive, send)
await writer.send(message)
await writer.send(session_message)
Loading
Loading