diff --git a/vllm_mlx/mcp/executor.py b/vllm_mlx/mcp/executor.py index 18918158e..11c7afcbe 100644 --- a/vllm_mlx/mcp/executor.py +++ b/vllm_mlx/mcp/executor.py @@ -95,6 +95,10 @@ def __init__( self.default_timeout = default_timeout or manager.config.default_timeout self.validate_arguments = validate_arguments self.sandbox = sandbox or get_sandbox() + # Lazily-built lookup index: full_name and short name → MCPTool. + # Rebuilt whenever the tool count changes (e.g. after refresh_tools). + self._tool_index: dict[str, "MCPTool"] = {} + self._tool_index_size: int = -1 async def execute_tool_calls( self, @@ -119,17 +123,21 @@ async def execute_tool_calls( else: return await self._execute_sequential(tool_calls) + def _get_tool_index(self) -> dict[str, "MCPTool"]: + """Return a name→tool index, rebuilding it if the tool list changed.""" + tools = self.manager.get_all_tools() + if len(tools) != self._tool_index_size: + idx: dict[str, "MCPTool"] = {} + for t in tools: + idx[t.full_name] = t + idx.setdefault(t.name, t) # short name maps to first match + self._tool_index = idx + self._tool_index_size = len(tools) + return self._tool_index + def _get_tool_by_name(self, full_name: str) -> Optional[MCPTool]: """Get a tool by its full name (server__tool or just tool).""" - for tool in self.manager.get_all_tools(): - if tool.full_name == full_name: - return tool - # Try without server prefix - if "__" not in full_name: - for tool in self.manager.get_all_tools(): - if tool.name == full_name: - return tool - return None + return self._get_tool_index().get(full_name) def _validate_tool_call(self, tool_call: Dict[str, Any]) -> Optional[str]: """ @@ -186,11 +194,8 @@ def _get_server_for_tool(self, full_name: str) -> str: """Extract server name from full tool name or find it.""" if "__" in full_name: return full_name.split("__")[0] - # Find which server has this tool - for tool in self.manager.get_all_tools(): - if tool.name == full_name: - return tool.server_name - return "unknown" + tool = self._get_tool_index().get(full_name) + return tool.server_name if tool else "unknown" async def _execute_parallel( self, @@ -465,18 +470,7 @@ def extract_and_validate( def _tool_exists(self, full_name: str) -> bool: """Check if a tool exists in any connected server.""" - # Check by full name - for tool in self.manager.get_all_tools(): - if tool.full_name == full_name: - return True - - # Check by just tool name (without server prefix) - if "__" not in full_name: - for tool in self.manager.get_all_tools(): - if tool.name == full_name: - return True - - return False + return full_name in self._get_tool_index() async def execute_single_tool(