diff --git a/src/strands_evals/extractors/tools_use_extractor.py b/src/strands_evals/extractors/tools_use_extractor.py index d7ef14e..9dad892 100644 --- a/src/strands_evals/extractors/tools_use_extractor.py +++ b/src/strands_evals/extractors/tools_use_extractor.py @@ -22,37 +22,38 @@ def extract_agent_tools_used_from_messages(agent_messages): if message.get("role") == "assistant": message_info = message.get("content") if len(message_info) > 0: - tool = None + tools = [] for message in message_info: if "toolUse" in message: - tool = message.get("toolUse") - - if tool: - tool_name = tool.get("name") - tool_input = tool.get("input") - tool_id = tool.get("toolUseId") - # get the tool result from the next message - tool_result = None - is_error = False - next_message_i = i + 1 - while next_message_i < len(agent_messages): - next_message = agent_messages[next_message_i] - next_message_i += 1 - - if next_message.get("role") == "user": - content = next_message.get("content") - if content: - tool_result_dict = content[0].get("toolResult") - if tool_result_dict.get("toolUseId") == tool_id: - tool_result_content = tool_result_dict.get("content", []) - if len(tool_result_content) > 0: - tool_result = tool_result_content[0].get("text") - is_error = tool_result_dict.get("status") == "error" - break - - tools_used.append( - {"name": tool_name, "input": tool_input, "tool_result": tool_result, "is_error": is_error} - ) + tools.append(message.get("toolUse")) + + for tool in tools: + if tool: + tool_name = tool.get("name") + tool_input = tool.get("input") + tool_id = tool.get("toolUseId") + # get the tool result from the next message + tool_result = None + is_error = False + next_message_i = i + 1 + while next_message_i < len(agent_messages): + next_message = agent_messages[next_message_i] + next_message_i += 1 + + if next_message.get("role") == "user": + content = next_message.get("content") + if content: + tool_result_dict = content[0].get("toolResult") + if tool_result_dict.get("toolUseId") == tool_id: + tool_result_content = tool_result_dict.get("content", []) + if len(tool_result_content) > 0: + tool_result = tool_result_content[0].get("text") + is_error = tool_result_dict.get("status") == "error" + break + + tools_used.append( + {"name": tool_name, "input": tool_input, "tool_result": tool_result, "is_error": is_error} + ) return tools_used diff --git a/tests/strands_evals/extractors/test_tools_use_extractor.py b/tests/strands_evals/extractors/test_tools_use_extractor.py index 3ff70d3..1a0efb3 100644 --- a/tests/strands_evals/extractors/test_tools_use_extractor.py +++ b/tests/strands_evals/extractors/test_tools_use_extractor.py @@ -48,6 +48,70 @@ def test_tools_use_extractor_extract_from_messages_with_tools(): assert result[0]["is_error"] is False +def test_tools_use_extractor_extract_from_messages_with_multiple_tools(): + """Test extracting multiple tool usages from messages""" + messages = [ + {"role": "user", "content": [{"text": "Calculate 2+2 and search for weather"}]}, + { + "role": "assistant", + "content": [ + {"text": "I'll calculate and search for you."}, + { + "toolUse": { + "toolUseId": "tool1", + "name": "calculator", + "input": {"expression": "2+2"}, + } + }, + { + "toolUse": { + "toolUseId": "tool2", + "name": "web_search", + "input": {"query": "current weather"}, + } + }, + ], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "status": "success", + "content": [{"text": "Result: 4"}], + "toolUseId": "tool1", + } + } + ], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "status": "success", + "content": [{"text": "Sunny, 25°C"}], + "toolUseId": "tool2", + } + } + ], + }, + {"role": "assistant", "content": [{"text": "Results: 4 and sunny weather."}]}, + ] + + result = extract_agent_tools_used_from_messages(messages) + + assert len(result) == 2 + assert result[0]["name"] == "calculator" + assert result[0]["input"] == {"expression": "2+2"} + assert result[0]["tool_result"] == "Result: 4" + assert result[0]["is_error"] is False + assert result[1]["name"] == "web_search" + assert result[1]["input"] == {"query": "current weather"} + assert result[1]["tool_result"] == "Sunny, 25°C" + assert result[1]["is_error"] is False + + def test_tools_use_extractor_extract_from_messages_no_tools(): """Test extracting tool usage from messages without tool usage""" messages = [