diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 948817140..a2ba54d0b 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -282,8 +282,12 @@ async def call_tool( arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, progress_callback: ProgressFnT | None = None, + _meta: dict[str, Any] | None = None, ) -> types.CallToolResult: - """Send a tools/call request with optional progress callback support.""" + """Send a tools/call request with optional progress callback.""" + + # Create the Meta object if _meta is provided + meta_obj = types.RequestParams.Meta(**_meta) if _meta else None return await self.send_request( types.ClientRequest( @@ -292,6 +296,7 @@ async def call_tool( params=types.CallToolRequestParams( name=name, arguments=arguments, + _meta=meta_obj, ), ) ), diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 700b5417f..9d8daa0d0 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -172,11 +172,13 @@ def tools(self) -> dict[str, types.Tool]: """Returns the tools as a dictionary of names to tools.""" return self._tools - async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult: + async def call_tool( + self, name: str, args: dict[str, Any], _meta: dict[str, Any] | None = None + ) -> types.CallToolResult: """Executes a tool given its name and arguments.""" session = self._tool_to_session[name] session_tool_name = self.tools[name].name - return await session.call_tool(session_tool_name, args) + return await session.call_tool(session_tool_name, args, _meta=_meta) async def disconnect_from_server(self, session: mcp.ClientSession) -> None: """Disconnects from a single MCP server.""" diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 327d1a9e4..2d4f30616 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -495,3 +495,82 @@ async def mock_server(): assert received_capabilities.roots is not None # Custom list_roots callback provided assert isinstance(received_capabilities.roots, types.RootsCapability) assert received_capabilities.roots.listChanged is True # Should be True for custom callback + + +@pytest.mark.anyio +async def test_client_session_call_tool_with_meta(): + """Test that call_tool properly handles the _meta parameter.""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + received_request = None + + async def mock_server(): + nonlocal received_request + + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + received_request = request + + # Send a successful response + result = ServerResult( + types.CallToolResult( + content=[types.TextContent(type="text", text="Tool executed successfully")], + isError=False, + ) + ) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + + # Test call_tool with _meta parameter + meta_data = {"user_id": "12345", "session_id": "abc123"} + result = await session.call_tool( + name="test_tool", + arguments={"param1": "value1"}, + _meta=meta_data, + ) + + # Assert that the request was sent with the correct meta data + assert received_request is not None + assert isinstance(received_request.root, types.CallToolRequest) + assert received_request.root.params.name == "test_tool" + assert received_request.root.params.arguments == {"param1": "value1"} + assert received_request.root.params.meta is not None + assert received_request.root.params.meta.progressToken is None # No progressToken in our test meta + # The meta object should contain our custom data + assert hasattr(received_request.root.params.meta, "user_id") + assert hasattr(received_request.root.params.meta, "session_id") + + # Assert the result + assert isinstance(result, types.CallToolResult) + assert len(result.content) == 1 + assert isinstance(result.content[0], types.TextContent) + assert result.content[0].text == "Tool executed successfully" diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 16a887e00..5890204f7 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -77,6 +77,40 @@ def hook(name, server_info): mock_session.call_tool.assert_called_once_with( "my_tool", {"name": "value1", "args": {}}, + _meta=None, + ) + + async def test_call_tool_with_meta(self): + # --- Mock Dependencies --- + mock_session = mock.AsyncMock() + + # --- Prepare Session Group --- + def hook(name, server_info): + return f"{(server_info.name)}-{name}" + + mcp_session_group = ClientSessionGroup(component_name_hook=hook) + mcp_session_group._tools = {"server1-my_tool": types.Tool(name="my_tool", inputSchema={})} + mcp_session_group._tool_to_session = {"server1-my_tool": mock_session} + text_content = types.TextContent(type="text", text="OK") + mock_session.call_tool.return_value = types.CallToolResult(content=[text_content]) + + # --- Test Execution with _meta --- + meta_data = {"user_id": "12345", "session_id": "abc123"} + result = await mcp_session_group.call_tool( + name="server1-my_tool", + args={ + "name": "value1", + "args": {}, + }, + _meta=meta_data, + ) + + # --- Assertions --- + assert result.content == [text_content] + mock_session.call_tool.assert_called_once_with( + "my_tool", + {"name": "value1", "args": {}}, + _meta=meta_data, ) async def test_connect_to_server(self, mock_exit_stack):