Skip to content

Commit c7b3988

Browse files
committed
fix: update RunState with current turn persisted item tracking
1 parent 1406aa0 commit c7b3988

File tree

2 files changed

+60
-28
lines changed

2 files changed

+60
-28
lines changed

src/agents/run.py

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -831,9 +831,6 @@ async def run(
831831
# If resuming from an interrupted state, execute approved tools first
832832
if is_resumed_state and run_state is not None and run_state._current_step is not None:
833833
if isinstance(run_state._current_step, NextStepInterruption):
834-
# Track items before executing approved tools
835-
items_before_execution = len(generated_items)
836-
837834
# We're resuming from an interruption - execute approved tools
838835
await self._execute_approved_tools(
839836
agent=current_agent,
@@ -844,14 +841,9 @@ async def run(
844841
hooks=hooks,
845842
)
846843

847-
# Save the newly executed tool outputs to the session
848-
new_tool_outputs: list[RunItem] = [
849-
item
850-
for item in generated_items[items_before_execution:]
851-
if item.type == "tool_call_output_item"
852-
]
853-
if new_tool_outputs and session is not None:
854-
await self._save_result_to_session(session, [], new_tool_outputs)
844+
# Save new items (counter tracks what's already saved)
845+
if session is not None:
846+
await self._save_result_to_session(session, [], generated_items, run_state)
855847

856848
# Clear the current step since we've handled it
857849
run_state._current_step = None
@@ -881,6 +873,9 @@ async def run(
881873
current_span.span_data.tools = [t.name for t in all_tools]
882874

883875
current_turn += 1
876+
if run_state is not None:
877+
run_state._current_turn_persisted_item_count = 0
878+
884879
if current_turn > max_turns:
885880
_error_tracing.attach_error_to_span(
886881
current_span,
@@ -995,7 +990,7 @@ async def run(
995990
for guardrail_result in input_guardrail_results
996991
):
997992
await self._save_result_to_session(
998-
session, [], turn_result.new_step_items
993+
session, [], turn_result.new_step_items, run_state
999994
)
1000995
return result
1001996
elif isinstance(turn_result.next_step, NextStepInterruption):
@@ -1035,7 +1030,7 @@ async def run(
10351030
for guardrail_result in input_guardrail_results
10361031
):
10371032
await self._save_result_to_session(
1038-
session, [], turn_result.new_step_items
1033+
session, [], turn_result.new_step_items, run_state
10391034
)
10401035
else:
10411036
raise AgentsException(
@@ -1416,9 +1411,6 @@ async def _start_streaming(
14161411
# If resuming from an interrupted state, execute approved tools first
14171412
if run_state is not None and run_state._current_step is not None:
14181413
if isinstance(run_state._current_step, NextStepInterruption):
1419-
# Track items before executing approved tools
1420-
items_before_execution = len(streamed_result.new_items)
1421-
14221414
# We're resuming from an interruption - execute approved tools
14231415
await cls._execute_approved_tools_static(
14241416
agent=current_agent,
@@ -1429,14 +1421,11 @@ async def _start_streaming(
14291421
hooks=hooks,
14301422
)
14311423

1432-
# Save the newly executed tool outputs to the session
1433-
new_tool_outputs: list[RunItem] = [
1434-
item
1435-
for item in streamed_result.new_items[items_before_execution:]
1436-
if item.type == "tool_call_output_item"
1437-
]
1438-
if new_tool_outputs and session is not None:
1439-
await cls._save_result_to_session(session, [], new_tool_outputs)
1424+
# Save new items (counter tracks what's already saved)
1425+
if session is not None:
1426+
await cls._save_result_to_session(
1427+
session, [], streamed_result.new_items, run_state
1428+
)
14401429

14411430
# Clear the current step since we've handled it
14421431
run_state._current_step = None
@@ -1475,6 +1464,8 @@ async def _start_streaming(
14751464
current_span.span_data.tools = tool_names
14761465
current_turn += 1
14771466
streamed_result.current_turn = current_turn
1467+
if run_state is not None:
1468+
run_state._current_turn_persisted_item_count = 0
14781469

14791470
if current_turn > max_turns:
14801471
_error_tracing.attach_error_to_span(
@@ -1604,7 +1595,7 @@ async def _start_streaming(
16041595
)
16051596
if should_skip_session_save is False:
16061597
await AgentRunner._save_result_to_session(
1607-
session, [], turn_result.new_step_items
1598+
session, [], turn_result.new_step_items, run_state
16081599
)
16091600

16101601
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
@@ -1623,7 +1614,7 @@ async def _start_streaming(
16231614
)
16241615
if should_skip_session_save is False:
16251616
await AgentRunner._save_result_to_session(
1626-
session, [], turn_result.new_step_items
1617+
session, [], turn_result.new_step_items, run_state
16271618
)
16281619

16291620
# Check for soft cancel after turn completion
@@ -2494,9 +2485,14 @@ async def _save_result_to_session(
24942485
session: Session | None,
24952486
original_input: str | list[TResponseInputItem],
24962487
new_items: list[RunItem],
2488+
run_state: RunState[Any] | None = None,
24972489
) -> None:
24982490
"""
2499-
Save the conversation turn to session.
2491+
Save the conversation turn to session with incremental tracking.
2492+
2493+
Uses run_state._current_turn_persisted_item_count to track which items
2494+
have already been persisted, allowing partial saves within a turn.
2495+
25002496
It does not account for any filtering or modification performed by
25012497
`RunConfig.session_input_callback`.
25022498
"""
@@ -2506,13 +2502,34 @@ async def _save_result_to_session(
25062502
# Convert original input to list format if needed
25072503
input_list = ItemHelpers.input_to_new_input_list(original_input)
25082504

2505+
# Track which items have already been persisted this turn
2506+
already_persisted = 0
2507+
if run_state is not None:
2508+
already_persisted = run_state._current_turn_persisted_item_count
2509+
2510+
# Only save items that haven't been persisted yet
2511+
new_run_items = new_items[already_persisted:]
2512+
25092513
# Convert new items to input format
2510-
new_items_as_input = [item.to_input_item() for item in new_items]
2514+
new_items_as_input = [item.to_input_item() for item in new_run_items]
25112515

25122516
# Save all items from this turn
25132517
items_to_save = input_list + new_items_as_input
2518+
2519+
if len(items_to_save) == 0:
2520+
# Update counter even if nothing to save
2521+
if run_state is not None:
2522+
run_state._current_turn_persisted_item_count = already_persisted + len(
2523+
new_run_items
2524+
)
2525+
return
2526+
25142527
await session.add_items(items_to_save)
25152528

2529+
# Update the counter after successful save
2530+
if run_state is not None:
2531+
run_state._current_turn_persisted_item_count = already_persisted + len(new_run_items)
2532+
25162533
@staticmethod
25172534
async def _input_guardrail_tripwire_triggered_for_stream(
25182535
streamed_result: RunResultStreaming,

src/agents/run_state.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,14 @@ class RunState(Generic[TContext, TAgent]):
4848
_current_turn: int = 0
4949
"""Current turn number in the conversation."""
5050

51+
_current_turn_persisted_item_count: int = 0
52+
"""Tracks how many generated run items from this turn were already persisted to session.
53+
54+
When saving to session, we slice off only new entries. When a turn is interrupted
55+
(e.g., awaiting tool approval) and later resumed, we rewind this counter before
56+
continuing so pending tool outputs still get stored.
57+
"""
58+
5159
_current_agent: TAgent | None = None
5260
"""The agent currently handling the conversation."""
5361

@@ -337,6 +345,7 @@ def to_json(self) -> dict[str, Any]:
337345
if self._last_processed_response
338346
else None
339347
)
348+
result["currentTurnPersistedItemCount"] = self._current_turn_persisted_item_count
340349
result["trace"] = None
341350

342351
return result
@@ -571,6 +580,9 @@ async def from_string(
571580
)
572581

573582
state._current_turn = state_json["currentTurn"]
583+
state._current_turn_persisted_item_count = state_json.get(
584+
"currentTurnPersistedItemCount", 0
585+
)
574586

575587
# Reconstruct model responses
576588
state._model_responses = _deserialize_model_responses(state_json.get("modelResponses", []))
@@ -676,6 +688,9 @@ async def from_json(
676688
)
677689

678690
state._current_turn = state_json["currentTurn"]
691+
state._current_turn_persisted_item_count = state_json.get(
692+
"currentTurnPersistedItemCount", 0
693+
)
679694

680695
# Reconstruct model responses
681696
state._model_responses = _deserialize_model_responses(state_json.get("modelResponses", []))

0 commit comments

Comments
 (0)