Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ def reply_func(
"config": copy.copy(config),
"init_config": config,
"reset_config": reset_config,
"ignore_async_in_sync_chat": ignore_async_in_sync_chat and inspect.iscoroutinefunction(reply_func),
"ignore_async_in_sync_chat": ignore_async_in_sync_chat and is_coroutine_callable(reply_func),
},
)

Expand Down Expand Up @@ -908,9 +908,7 @@ def reply_func_from_nested_chats(
if use_async:
if reply_func_from_nested_chats == "summary_from_nested_chats":
reply_func_from_nested_chats = self._a_summary_from_nested_chats
if not callable(reply_func_from_nested_chats) or not inspect.iscoroutinefunction(
reply_func_from_nested_chats
):
if not callable(reply_func_from_nested_chats) or not is_coroutine_callable(reply_func_from_nested_chats):
raise ValueError("reply_func_from_nested_chats must be a callable and a coroutine")

async def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
Expand Down Expand Up @@ -1299,7 +1297,7 @@ def _raise_exception_on_async_reply_functions(self) -> None:
f["reply_func"] for f in self._reply_func_list if not f.get("ignore_async_in_sync_chat", False)
}

async_reply_functions = [f for f in reply_functions if inspect.iscoroutinefunction(f)]
async_reply_functions = [f for f in reply_functions if is_coroutine_callable(f)]
if async_reply_functions:
msg = (
"Async reply functions can only be used with ConversableAgent.a_initiate_chat(). The following async reply functions are found: "
Expand Down Expand Up @@ -2391,7 +2389,7 @@ def generate_function_call_reply(
call_id = message.get("id", None)
func_call = message["function_call"]
func = self._function_map.get(func_call.get("name", None), None)
if inspect.iscoroutinefunction(func):
if is_coroutine_callable(func):
coro = self.a_execute_function(func_call, call_id=call_id)
_, func_return = self._run_async_in_thread(coro)
else:
Expand Down Expand Up @@ -2420,7 +2418,7 @@ async def a_generate_function_call_reply(
func_call = message["function_call"]
func_name = func_call.get("name", "")
func = self._function_map.get(func_name, None)
if func and inspect.iscoroutinefunction(func):
if func and is_coroutine_callable(func):
_, func_return = await self.a_execute_function(func_call, call_id=call_id)
else:
_, func_return = self.execute_function(func_call, call_id=call_id)
Expand Down Expand Up @@ -2875,7 +2873,7 @@ def generate_reply(
reply_func = reply_func_tuple["reply_func"]
if reply_func in exclude:
continue
if inspect.iscoroutinefunction(reply_func):
if is_coroutine_callable(reply_func):
continue
if self._match_trigger(reply_func_tuple["trigger"], sender):
final, reply = reply_func(self, messages=messages, sender=sender, config=reply_func_tuple["config"])
Expand Down Expand Up @@ -2950,7 +2948,7 @@ async def a_generate_reply(
continue

if self._match_trigger(reply_func_tuple["trigger"], sender):
if inspect.iscoroutinefunction(reply_func):
if is_coroutine_callable(reply_func):
final, reply = await reply_func(
self,
messages=messages,
Expand Down Expand Up @@ -3171,6 +3169,12 @@ def execute_function(
)
try:
content = func(**arguments)
if inspect.isawaitable(content):

async def _await_result(awaitable):
return await awaitable

content = self._run_async_in_thread(_await_result(content))
is_exec_success = True
except Exception as e:
content = f"Error: {e}"
Expand Down Expand Up @@ -3238,7 +3242,7 @@ async def a_execute_function(
ExecuteFunctionEvent(func_name=func_name, call_id=call_id, arguments=arguments, recipient=self)
)
try:
if inspect.iscoroutinefunction(func):
if is_coroutine_callable(func):
content = await func(**arguments)
else:
# Fallback to sync function if the function is not async
Expand Down Expand Up @@ -3555,7 +3559,7 @@ async def _a_wrapped_func(*args, **kwargs):
log_function_use(self, func, kwargs, retval)
return serialize_to_str(retval) if serialize else retval

wrapped_func = _a_wrapped_func if inspect.iscoroutinefunction(func) else _wrapped_func
wrapped_func = _a_wrapped_func if is_coroutine_callable(func) else _wrapped_func

# needed for testing
wrapped_func._origin = func
Expand Down
32 changes: 28 additions & 4 deletions autogen/oai/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,10 +534,34 @@ def format_tools(tools: list[dict[str, Any]]) -> dict[Literal["tools"], list[dic
}

for prop_name, prop_details in function["parameters"]["properties"].items():
converted_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop_name] = {
"type": prop_details["type"],
"description": prop_details.get("description", ""),
}
if not isinstance(prop_details, dict):
raise TypeError(f"Property '{prop_name}' schema must be a dict, got {type(prop_details)!r}")

prop_schema: dict[str, Any] = {"description": prop_details.get("description", "")}

for key in (
"type",
"enum",
"default",
"anyOf",
"oneOf",
"allOf",
"items",
"const",
"format",
"minimum",
"maximum",
"minItems",
"maxItems",
"minLength",
"maxLength",
"pattern",
"additionalProperties",
):
if key in prop_details:
prop_schema[key] = prop_details[key]

converted_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop_name] = prop_schema
if "enum" in prop_details:
converted_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop_name]["enum"] = prop_details[
"enum"
Expand Down
50 changes: 50 additions & 0 deletions test/agentchat/test_conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import asyncio
import copy
import json
import os
import threading
import time
Expand Down Expand Up @@ -1968,6 +1969,55 @@ def sample_tool_func(my_prop: str) -> str:
assert "tool2" in tool_schemas


def test_execute_function_resolves_async_tool(mock_credentials: Credentials):
"""execute_function should await async tools instead of returning coroutine reprs."""
agent = ConversableAgent(name="agent", llm_config=mock_credentials.llm_config)
observed_inputs: list[str] = []

@agent.register_for_execution()
@agent.register_for_llm(description="Uppercase text asynchronously")
async def uppercase_tool(text: str) -> str:
observed_inputs.append(text)
await asyncio.sleep(0)
return text.upper()

success, payload = agent.execute_function(
{"name": "uppercase_tool", "arguments": json.dumps({"text": "nyc"})},
call_id="tool-call-1",
)

assert success is True
assert payload["content"] == "NYC"
assert observed_inputs == ["nyc"]


def test_generate_tool_calls_reply_handles_async_tool(mock_credentials: Credentials):
"""generate_tool_calls_reply should await async tools registered for execution."""
agent = ConversableAgent(name="agent", llm_config=mock_credentials.llm_config)

@agent.register_for_execution()
@agent.register_for_llm(description="Title case text asynchronously")
async def title_tool(text: str) -> str:
await asyncio.sleep(0)
return text.title()

message = {
"role": "assistant",
"tool_calls": [
{
"id": "call-xyz",
"function": {"name": "title_tool", "arguments": json.dumps({"text": "new york"})},
}
],
}

handled, response = agent.generate_tool_calls_reply(messages=[message])
assert handled is True
tool_response = response["tool_responses"][0]
assert tool_response["tool_call_id"] == "call-xyz"
assert tool_response["content"] == "New York"


def test_create_or_get_executor(mock_credentials: Credentials):
agent = ConversableAgent(name="agent", llm_config=mock_credentials.llm_config)
executor_agent = None
Expand Down
4 changes: 1 addition & 3 deletions test/agentchat/test_function_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,7 @@ async def add_num(num_to_be_added):
user = UserProxyAgent(name="test", function_map={"add_num": add_num})
correct_args = {"name": "add_num", "arguments": '{ "num_to_be_added": 5 }'}

# Asset coroutine doesn't match.
assert user.execute_function(func_call=correct_args)[1]["content"] != "15"
# Asset awaited coroutine does match.
assert user.execute_function(func_call=correct_args)[1]["content"] == "15"
assert (await user.a_execute_function(func_call=correct_args))[1]["content"] == "15"

# function name called is wrong or doesn't exist
Expand Down
90 changes: 89 additions & 1 deletion test/oai/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from autogen.import_utils import run_for_optional_imports
from autogen.llm_config import LLMConfig
from autogen.oai.bedrock import BedrockClient, BedrockLLMConfigEntry, oai_messages_to_bedrock_messages
from autogen.oai.bedrock import BedrockClient, BedrockLLMConfigEntry, format_tools, oai_messages_to_bedrock_messages


# Fixtures for mock data
Expand Down Expand Up @@ -341,3 +341,91 @@ def test_oai_messages_to_bedrock_messages(bedrock_client: BedrockClient):
]

assert messages == expected_messages, "'Please continue' message was not appended."


def test_format_tools_handles_various_property_shapes():
"""format_tools should faithfully copy every supported JSON Schema shape (scalars, enums, unions, arrays, nested objects)."""
cases = [
(
"simple_type",
{"type": "string", "description": "plain type"},
{"type": "string", "description": "plain type"},
),
(
"enum_default",
{"type": "integer", "enum": [1, 2], "default": 1},
{"type": "integer", "enum": [1, 2], "default": 1, "description": ""},
),
(
"union_anyof",
{
"anyOf": [{"type": "string"}, {"type": "null"}],
"default": None,
"description": "optional text",
},
{
"anyOf": [{"type": "string"}, {"type": "null"}],
"default": None,
"description": "optional text",
},
),
(
"array_items",
{"type": "array", "items": {"type": "number"}, "minItems": 1},
{"type": "array", "items": {"type": "number"}, "minItems": 1, "description": ""},
),
(
"object_additional",
{
"type": "object",
"additionalProperties": {"type": "boolean"},
"required": [],
},
{
"type": "object",
"additionalProperties": {"type": "boolean"},
"description": "",
},
),
]

tools = [
{
"type": "function",
"function": {
"name": "schema_tester",
"description": "verifies schema copying",
"parameters": {
"type": "object",
"properties": {name: prop for name, prop, _ in cases},
},
},
}
]

converted_props = format_tools(tools)["tools"][0]["toolSpec"]["inputSchema"]["json"]["properties"]

for name, _, expected in cases:
assert converted_props[name] == expected, f"schema mismatch for {name}"


def test_format_tools_rejects_non_dict_properties():
"""format_tools should raise TypeError when a property schema is not a dict, mirroring runtime validation."""
tools = [
{
"type": "function",
"function": {
"name": "bad_prop",
"description": "schema with malformed property",
"parameters": {
"type": "object",
"properties": {
"oops": "not a dict",
},
},
},
}
]

with pytest.raises(TypeError, match="Property 'oops' schema must be a dict"):
format_tools(tools)
Loading