@@ -71,6 +71,39 @@ def create_user(user: UserInput, flag: bool) -> dict:
71
71
assert "age" in tool .parameters ["$defs" ]["UserInput" ]["properties" ]
72
72
assert "flag" in tool .parameters ["properties" ]
73
73
74
+ def test_add_callable_object (self ):
75
+ """Test registering a callable object."""
76
+
77
+ class MyTool :
78
+ def __init__ (self ):
79
+ self .__name__ = "MyTool"
80
+
81
+ def __call__ (self , x : int ) -> int :
82
+ return x * 2
83
+
84
+ manager = ToolManager ()
85
+ tool = manager .add_tool (MyTool ())
86
+ assert tool .name == "MyTool"
87
+ assert tool .is_async is False
88
+ assert tool .parameters ["properties" ]["x" ]["type" ] == "integer"
89
+
90
+ @pytest .mark .anyio
91
+ async def test_add_async_callable_object (self ):
92
+ """Test registering an async callable object."""
93
+
94
+ class MyAsyncTool :
95
+ def __init__ (self ):
96
+ self .__name__ = "MyAsyncTool"
97
+
98
+ async def __call__ (self , x : int ) -> int :
99
+ return x * 2
100
+
101
+ manager = ToolManager ()
102
+ tool = manager .add_tool (MyAsyncTool ())
103
+ assert tool .name == "MyAsyncTool"
104
+ assert tool .is_async is True
105
+ assert tool .parameters ["properties" ]["x" ]["type" ] == "integer"
106
+
74
107
def test_add_invalid_tool (self ):
75
108
manager = ToolManager ()
76
109
with pytest .raises (AttributeError ):
@@ -137,6 +170,34 @@ async def double(n: int) -> int:
137
170
result = await manager .call_tool ("double" , {"n" : 5 })
138
171
assert result == 10
139
172
173
+ @pytest .mark .anyio
174
+ async def test_call_object_tool (self ):
175
+ class MyTool :
176
+ def __init__ (self ):
177
+ self .__name__ = "MyTool"
178
+
179
+ def __call__ (self , x : int ) -> int :
180
+ return x * 2
181
+
182
+ manager = ToolManager ()
183
+ tool = manager .add_tool (MyTool ())
184
+ result = await tool .run ({"x" : 5 })
185
+ assert result == 10
186
+
187
+ @pytest .mark .anyio
188
+ async def test_call_async_object_tool (self ):
189
+ class MyAsyncTool :
190
+ def __init__ (self ):
191
+ self .__name__ = "MyAsyncTool"
192
+
193
+ async def __call__ (self , x : int ) -> int :
194
+ return x * 2
195
+
196
+ manager = ToolManager ()
197
+ tool = manager .add_tool (MyAsyncTool ())
198
+ result = await tool .run ({"x" : 5 })
199
+ assert result == 10
200
+
140
201
@pytest .mark .anyio
141
202
async def test_call_tool_with_default_args (self ):
142
203
def add (a : int , b : int = 1 ) -> int :
0 commit comments