Skip to content

Commit 5dd00a2

Browse files
committed
fix: enhance agent state management during resume, ensuring correct agent usage and saving tool outputs to session
1 parent 2874402 commit 5dd00a2

File tree

1 file changed

+76
-2
lines changed

1 file changed

+76
-2
lines changed

src/agents/run.py

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -851,7 +851,12 @@ async def run(
851851
tool_output_guardrail_results: list[ToolOutputGuardrailResult] = []
852852

853853
current_span: Span[AgentSpanData] | None = None
854-
current_agent = starting_agent
854+
# When resuming from state, use the current agent from the state (which may be different
855+
# from starting_agent if a handoff occurred). Otherwise use starting_agent.
856+
if is_resumed_state and run_state is not None and run_state._current_agent is not None:
857+
current_agent = run_state._current_agent
858+
else:
859+
current_agent = starting_agent
855860
should_run_agent_start_hooks = True
856861

857862
# save only the new user input to the session, not the combined history
@@ -1472,7 +1477,12 @@ async def _start_streaming(
14721477
streamed_result.trace.start(mark_as_current=True)
14731478

14741479
current_span: Span[AgentSpanData] | None = None
1475-
current_agent = starting_agent
1480+
# When resuming from state, use the current agent from the state (which may be different
1481+
# from starting_agent if a handoff occurred). Otherwise use starting_agent.
1482+
if run_state is not None and run_state._current_agent is not None:
1483+
current_agent = run_state._current_agent
1484+
else:
1485+
current_agent = starting_agent
14761486
current_turn = 0
14771487
should_run_agent_start_hooks = True
14781488
tool_use_tracker = AgentToolUseTracker()
@@ -1542,6 +1552,70 @@ async def _start_streaming(
15421552
run_config=run_config,
15431553
hooks=hooks,
15441554
)
1555+
# Save tool outputs to session immediately after approval
1556+
# This ensures incomplete function calls in the session are completed
1557+
if session is not None and streamed_result.new_items:
1558+
# Save tool_call_output_item items (the outputs)
1559+
tool_output_items: list[RunItem] = [
1560+
item
1561+
for item in streamed_result.new_items
1562+
if item.type == "tool_call_output_item"
1563+
]
1564+
# Also find and save the corresponding function_call items
1565+
# (they might not be in session if the run was interrupted before saving)
1566+
output_call_ids = {
1567+
item.raw_item.get("call_id")
1568+
if isinstance(item.raw_item, dict)
1569+
else getattr(item.raw_item, "call_id", None)
1570+
for item in tool_output_items
1571+
}
1572+
tool_call_items: list[RunItem] = [
1573+
item
1574+
for item in streamed_result.new_items
1575+
if item.type == "tool_call_item"
1576+
and (
1577+
item.raw_item.get("call_id")
1578+
if isinstance(item.raw_item, dict)
1579+
else getattr(item.raw_item, "call_id", None)
1580+
)
1581+
in output_call_ids
1582+
]
1583+
# Check which items are already in the session to avoid duplicates
1584+
# Get existing items from session and extract their call_ids
1585+
existing_items = await session.get_items()
1586+
existing_call_ids: set[str] = set()
1587+
for existing_item in existing_items:
1588+
if isinstance(existing_item, dict):
1589+
item_type = existing_item.get("type")
1590+
if item_type in ("function_call", "function_call_output"):
1591+
existing_call_id = existing_item.get(
1592+
"call_id"
1593+
) or existing_item.get("callId")
1594+
if existing_call_id and isinstance(existing_call_id, str):
1595+
existing_call_ids.add(existing_call_id)
1596+
1597+
# Filter out items that are already in the session
1598+
items_to_save: list[RunItem] = []
1599+
for item in tool_call_items + tool_output_items:
1600+
item_call_id: str | None = None
1601+
if isinstance(item.raw_item, dict):
1602+
raw_call_id = item.raw_item.get("call_id") or item.raw_item.get(
1603+
"callId"
1604+
)
1605+
item_call_id = (
1606+
cast(str | None, raw_call_id) if raw_call_id else None
1607+
)
1608+
elif hasattr(item.raw_item, "call_id"):
1609+
item_call_id = cast(
1610+
str | None, getattr(item.raw_item, "call_id", None)
1611+
)
1612+
1613+
# Only save if not already in session
1614+
if item_call_id is None or item_call_id not in existing_call_ids:
1615+
items_to_save.append(item)
1616+
1617+
if items_to_save:
1618+
await AgentRunner._save_result_to_session(session, [], items_to_save)
15451619
# Clear the current step since we've handled it
15461620
run_state._current_step = None
15471621

0 commit comments

Comments
 (0)