Skip to content

Commit 36985b5

Browse files
fix: workflow duplicate task name hanging issue (#279)
1 parent bbffa8c commit 36985b5

File tree

4 files changed

+57
-50
lines changed

4 files changed

+57
-50
lines changed

src/strands_tools/batch.py

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -96,83 +96,73 @@ def batch(tool: ToolUse, **kwargs) -> ToolResult:
9696
agent = kwargs.get("agent")
9797
invocations = kwargs.get("invocations", [])
9898
results = []
99-
99+
100100
try:
101101
if not hasattr(agent, "tool") or agent.tool is None:
102102
raise AttributeError("Agent does not have a valid 'tool' attribute.")
103-
103+
104104
for invocation in invocations:
105105
tool_name = invocation.get("name")
106106
arguments = invocation.get("arguments", {})
107107
tool_fn = getattr(agent.tool, tool_name, None)
108-
108+
109109
if callable(tool_fn):
110110
try:
111111
# Call the tool function with the provided arguments
112112
result = tool_fn(**arguments)
113-
113+
114114
# Create a consistent result structure
115-
batch_result = {
116-
"name": tool_name,
117-
"status": "success",
118-
"result": result
119-
}
115+
batch_result = {"name": tool_name, "status": "success", "result": result}
120116
results.append(batch_result)
121-
117+
122118
except Exception as e:
123119
error_msg = f"Error executing tool '{tool_name}': {str(e)}"
124120
console.print(error_msg)
125-
121+
126122
batch_result = {
127123
"name": tool_name,
128124
"status": "error",
129125
"error": str(e),
130-
"traceback": traceback.format_exc()
126+
"traceback": traceback.format_exc(),
131127
}
132128
results.append(batch_result)
133129
else:
134130
error_msg = f"Tool '{tool_name}' not found in agent"
135131
console.print(error_msg)
136-
137-
batch_result = {
138-
"name": tool_name,
139-
"status": "error",
140-
"error": error_msg
141-
}
132+
133+
batch_result = {"name": tool_name, "status": "error", "error": error_msg}
142134
results.append(batch_result)
143-
135+
144136
# Create a readable summary for the agent
145137
summary_lines = []
146138
summary_lines.append(f"Batch execution completed with {len(results)} tool(s):")
147-
139+
148140
for result in results:
149141
if result["status"] == "success":
150142
summary_lines.append(f"✓ {result['name']}: Success")
151143
else:
152144
summary_lines.append(f"✗ {result['name']}: Error - {result['error']}")
153-
145+
154146
summary_text = "\n".join(summary_lines)
155-
147+
156148
return {
157149
"toolUseId": tool_use_id,
158150
"status": "success",
159151
"content": [
160-
{
161-
"text": summary_text
162-
},
152+
{"text": summary_text},
163153
{
164154
"json": {
165155
"batch_summary": {
166156
"total_tools": len(results),
167157
"successful": len([r for r in results if r["status"] == "success"]),
168-
"failed": len([r for r in results if r["status"] == "error"])
158+
"failed": len([r for r in results if r["status"] == "error"]),
169159
},
170-
"results": results
160+
"results": results,
171161
}
172-
}
173-
]
162+
},
163+
],
174164
}
175-
165+
176166
except Exception as e:
177167
error_msg = f"Error in batch tool: {str(e)}\n{traceback.format_exc()}"
178168
console.print(f"Error in batch tool: {str(e)}")

src/strands_tools/workflow.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -612,9 +612,11 @@ def start_workflow(self, workflow_id: str) -> Dict:
612612
for task in ready_tasks[:current_batch_size]:
613613
task_id = task["task_id"]
614614
if task_id not in active_futures and task_id not in completed_tasks:
615+
# Namespace task_id with workflow_id to prevent conflicts
616+
namespaced_task_id = f"{workflow_id}:{task_id}"
615617
tasks_to_submit.append(
616618
(
617-
task_id,
619+
namespaced_task_id,
618620
self.execute_task,
619621
(task, workflow),
620622
{},
@@ -633,9 +635,13 @@ def start_workflow(self, workflow_id: str) -> Dict:
633635

634636
# Process completed tasks
635637
completed_task_ids = []
636-
for task_id, future in active_futures.items():
638+
for namespaced_task_id, future in active_futures.items():
637639
if future in done:
638-
completed_task_ids.append(task_id)
640+
# Extract original task_id from namespaced version
641+
task_id = (
642+
namespaced_task_id.split(":", 1)[1] if ":" in namespaced_task_id else namespaced_task_id
643+
)
644+
completed_task_ids.append(namespaced_task_id)
639645
try:
640646
result = future.result()
641647

tests/test_batch.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,29 +8,29 @@
88
def mock_agent():
99
"""Fixture to create a mock agent with tools."""
1010
agent = MagicMock()
11-
11+
1212
# Create a mock tool registry that mimics the real agent's tool access pattern
1313
mock_tool_registry = MagicMock()
1414
mock_tool_registry.registry = {
1515
"http_request": MagicMock(return_value={"status": "success", "result": {"ip": "127.0.0.1"}}),
1616
"use_aws": MagicMock(return_value={"status": "success", "result": {"buckets": ["bucket1", "bucket2"]}}),
17-
"error_tool": MagicMock(side_effect=Exception("Tool execution failed"))
17+
"error_tool": MagicMock(side_effect=Exception("Tool execution failed")),
1818
}
1919
agent.tool_registry = mock_tool_registry
20-
20+
2121
# Create a custom mock tool object that properly handles getattr
2222
class MockTool:
2323
def __init__(self):
2424
self.http_request = mock_tool_registry.registry["http_request"]
2525
self.use_aws = mock_tool_registry.registry["use_aws"]
2626
self.error_tool = mock_tool_registry.registry["error_tool"]
27-
27+
2828
def __getattr__(self, name):
2929
# Return None for non-existent tools (this will make callable() return False)
3030
return None
31-
31+
3232
agent.tool = MockTool()
33-
33+
3434
return agent
3535

3636

@@ -47,18 +47,18 @@ def test_batch_success(mock_agent):
4747
assert result["toolUseId"] == "mock_tool_id"
4848
assert result["status"] == "success"
4949
assert len(result["content"]) == 2
50-
50+
5151
# Check the summary text
5252
assert "Batch execution completed with 2 tool(s):" in result["content"][0]["text"]
5353
assert "✓ http_request: Success" in result["content"][0]["text"]
5454
assert "✓ use_aws: Success" in result["content"][0]["text"]
55-
55+
5656
# Check the JSON results
5757
json_content = result["content"][1]["json"]
5858
assert json_content["batch_summary"]["total_tools"] == 2
5959
assert json_content["batch_summary"]["successful"] == 2
6060
assert json_content["batch_summary"]["failed"] == 0
61-
61+
6262
results = json_content["results"]
6363
assert len(results) == 2
6464
assert results[0]["name"] == "http_request"
@@ -81,17 +81,17 @@ def test_batch_missing_tool(mock_agent):
8181
assert result["toolUseId"] == "mock_tool_id"
8282
assert result["status"] == "success"
8383
assert len(result["content"]) == 2
84-
84+
8585
# Check the summary text
8686
assert "Batch execution completed with 1 tool(s):" in result["content"][0]["text"]
8787
assert "✗ non_existent_tool: Error" in result["content"][0]["text"]
88-
88+
8989
# Check the JSON results
9090
json_content = result["content"][1]["json"]
9191
assert json_content["batch_summary"]["total_tools"] == 1
9292
assert json_content["batch_summary"]["successful"] == 0
9393
assert json_content["batch_summary"]["failed"] == 1
94-
94+
9595
results = json_content["results"]
9696
assert len(results) == 1
9797
assert results[0]["name"] == "non_existent_tool"
@@ -111,17 +111,17 @@ def test_batch_tool_error(mock_agent):
111111
assert result["toolUseId"] == "mock_tool_id"
112112
assert result["status"] == "success"
113113
assert len(result["content"]) == 2
114-
114+
115115
# Check the summary text
116116
assert "Batch execution completed with 1 tool(s):" in result["content"][0]["text"]
117117
assert "✗ error_tool: Error" in result["content"][0]["text"]
118-
118+
119119
# Check the JSON results
120120
json_content = result["content"][1]["json"]
121121
assert json_content["batch_summary"]["total_tools"] == 1
122122
assert json_content["batch_summary"]["successful"] == 0
123123
assert json_content["batch_summary"]["failed"] == 1
124-
124+
125125
results = json_content["results"]
126126
assert len(results) == 1
127127
assert results[0]["name"] == "error_tool"
@@ -140,10 +140,10 @@ def test_batch_no_invocations(mock_agent):
140140
assert result["toolUseId"] == "mock_tool_id"
141141
assert result["status"] == "success"
142142
assert len(result["content"]) == 2
143-
143+
144144
# Check the summary text
145145
assert "Batch execution completed with 0 tool(s):" in result["content"][0]["text"]
146-
146+
147147
# Check the JSON results
148148
json_content = result["content"][1]["json"]
149149
assert json_content["batch_summary"]["total_tools"] == 0

tests/test_workflow.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,17 @@ def test_start_workflow_not_found(self, mock_parent_agent):
247247
assert result["status"] == "error"
248248
assert "not found" in result["content"][0]["text"]
249249

250+
def test_task_id_namespacing(self):
251+
"""Test task ID namespacing and extraction logic."""
252+
workflow_id = "test_workflow"
253+
task_id = "task1"
254+
255+
namespaced_task_id = f"{workflow_id}:{task_id}"
256+
assert namespaced_task_id == "test_workflow:task1"
257+
258+
extracted_id = namespaced_task_id.split(":", 1)[1] if ":" in namespaced_task_id else namespaced_task_id
259+
assert extracted_id == "task1"
260+
250261

251262
class TestWorkflowStatus:
252263
"""Test workflow status functionality."""

0 commit comments

Comments
 (0)