Skip to content

Commit 92d9426

Browse files
committed
Fix async callable object tools
1 parent babb477 commit 92d9426

File tree

2 files changed

+72
-1
lines changed

2 files changed

+72
-1
lines changed

src/mcp/server/fastmcp/tools/base.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations as _annotations
22

3+
import functools
34
import inspect
45
from collections.abc import Callable
56
from typing import TYPE_CHECKING, Any, get_origin
@@ -48,7 +49,7 @@ def from_function(
4849
raise ValueError("You must provide a name for lambda functions")
4950

5051
func_doc = description or fn.__doc__ or ""
51-
is_async = inspect.iscoroutinefunction(fn)
52+
is_async = _is_async_callable(fn)
5253

5354
if context_kwarg is None:
5455
sig = inspect.signature(fn)
@@ -92,3 +93,12 @@ async def run(
9293
)
9394
except Exception as e:
9495
raise ToolError(f"Error executing tool {self.name}: {e}") from e
96+
97+
98+
def _is_async_callable(obj: Any) -> bool:
99+
while isinstance(obj, functools.partial):
100+
obj = obj.func
101+
102+
return inspect.iscoroutinefunction(obj) or (
103+
callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None))
104+
)

tests/server/fastmcp/test_tool_manager.py

+61
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,39 @@ def create_user(user: UserInput, flag: bool) -> dict:
7171
assert "age" in tool.parameters["$defs"]["UserInput"]["properties"]
7272
assert "flag" in tool.parameters["properties"]
7373

74+
def test_add_callable_object(self):
75+
"""Test registering a callable object."""
76+
77+
class MyTool:
78+
def __init__(self):
79+
self.__name__ = "MyTool"
80+
81+
def __call__(self, x: int) -> int:
82+
return x * 2
83+
84+
manager = ToolManager()
85+
tool = manager.add_tool(MyTool())
86+
assert tool.name == "MyTool"
87+
assert tool.is_async is False
88+
assert tool.parameters["properties"]["x"]["type"] == "integer"
89+
90+
@pytest.mark.anyio
91+
async def test_add_async_callable_object(self):
92+
"""Test registering an async callable object."""
93+
94+
class MyAsyncTool:
95+
def __init__(self):
96+
self.__name__ = "MyAsyncTool"
97+
98+
async def __call__(self, x: int) -> int:
99+
return x * 2
100+
101+
manager = ToolManager()
102+
tool = manager.add_tool(MyAsyncTool())
103+
assert tool.name == "MyAsyncTool"
104+
assert tool.is_async is True
105+
assert tool.parameters["properties"]["x"]["type"] == "integer"
106+
74107
def test_add_invalid_tool(self):
75108
manager = ToolManager()
76109
with pytest.raises(AttributeError):
@@ -137,6 +170,34 @@ async def double(n: int) -> int:
137170
result = await manager.call_tool("double", {"n": 5})
138171
assert result == 10
139172

173+
@pytest.mark.anyio
174+
async def test_call_object_tool(self):
175+
class MyTool:
176+
def __init__(self):
177+
self.__name__ = "MyTool"
178+
179+
def __call__(self, x: int) -> int:
180+
return x * 2
181+
182+
manager = ToolManager()
183+
tool = manager.add_tool(MyTool())
184+
result = await tool.run({"x": 5})
185+
assert result == 10
186+
187+
@pytest.mark.anyio
188+
async def test_call_async_object_tool(self):
189+
class MyAsyncTool:
190+
def __init__(self):
191+
self.__name__ = "MyAsyncTool"
192+
193+
async def __call__(self, x: int) -> int:
194+
return x * 2
195+
196+
manager = ToolManager()
197+
tool = manager.add_tool(MyAsyncTool())
198+
result = await tool.run({"x": 5})
199+
assert result == 10
200+
140201
@pytest.mark.anyio
141202
async def test_call_tool_with_default_args(self):
142203
def add(a: int, b: int = 1) -> int:

0 commit comments

Comments
 (0)