Skip to content

Commit 321bb68

Browse files
committed
fix: addressing edge cases when resuming (continued)
1 parent e8f4cdb commit 321bb68

File tree

3 files changed

+226
-33
lines changed

3 files changed

+226
-33
lines changed

src/agents/run.py

Lines changed: 101 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,13 @@ def prepare_input(
166166

167167
# On first call (when there are no generated items yet), include the original input
168168
if not generated_items:
169-
input_items.extend(ItemHelpers.input_to_new_input_list(original_input))
169+
# Normalize original_input items to ensure field names are in snake_case
170+
# (items from RunState deserialization may have camelCase)
171+
raw_input_list = ItemHelpers.input_to_new_input_list(original_input)
172+
# Filter out function_call items that don't have corresponding function_call_output
173+
# (API requires every function_call to have a function_call_output)
174+
filtered_input_list = AgentRunner._filter_incomplete_function_calls(raw_input_list)
175+
input_items.extend(AgentRunner._normalize_input_items(filtered_input_list))
170176

171177
# First, collect call_ids from tool_call_output_item items
172178
# (completed tool calls with outputs) and build a map of
@@ -753,8 +759,8 @@ async def run(
753759
original_user_input = run_state._original_input
754760
# Normalize items to remove top-level providerData (API doesn't accept it there)
755761
if isinstance(original_user_input, list):
756-
prepared_input: str | list[TResponseInputItem] = (
757-
AgentRunner._normalize_input_items(original_user_input)
762+
prepared_input: str | list[TResponseInputItem] = AgentRunner._normalize_input_items(
763+
original_user_input
758764
)
759765
else:
760766
prepared_input = original_user_input
@@ -856,8 +862,7 @@ async def run(
856862
if session is not None and generated_items:
857863
# Save tool_call_output_item items (the outputs)
858864
tool_output_items: list[RunItem] = [
859-
item for item in generated_items
860-
if item.type == "tool_call_output_item"
865+
item for item in generated_items if item.type == "tool_call_output_item"
861866
]
862867
# Also find and save the corresponding function_call items
863868
# (they might not be in session if the run was interrupted before saving)
@@ -1455,9 +1460,12 @@ async def _start_streaming(
14551460
# state's input, causing duplicate items.
14561461
if run_state is not None:
14571462
# Resuming from state - normalize items to remove top-level providerData
1463+
# and filter incomplete function_call pairs
14581464
if isinstance(starting_input, list):
1465+
# Filter incomplete function_call pairs before normalizing
1466+
filtered = AgentRunner._filter_incomplete_function_calls(starting_input)
14591467
prepared_input: str | list[TResponseInputItem] = (
1460-
AgentRunner._normalize_input_items(starting_input)
1468+
AgentRunner._normalize_input_items(filtered)
14611469
)
14621470
else:
14631471
prepared_input = starting_input
@@ -2467,20 +2475,82 @@ def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:
24672475

24682476
return run_config.model_provider.get_model(agent.model)
24692477

2478+
@staticmethod
def _filter_incomplete_function_calls(
    items: list[TResponseInputItem],
) -> list[TResponseInputItem]:
    """Drop function_call items that have no matching output item.

    The OpenAI API requires every function_call to be paired with a
    function_call_output. A function_call dict is kept only when some dict in
    ``items`` of type ``function_call_output`` (API format) or
    ``function_call_result`` (protocol format) carries the same call id.

    Everything else is preserved in its original order: messages, output
    items, and any non-dict items (which are passed through uninspected).
    The input list itself is not modified.

    Args:
        items: Input items to filter. Call ids are read from either the
            ``call_id`` (snake_case) or ``callId`` (camelCase) key.

    Returns:
        A new list without the unpaired function_call items.
    """

    def _call_id_of(entry: Any) -> str | None:
        """Return the entry's call id, or None when it is missing or not a string."""
        raw_id = entry.get("call_id") or entry.get("callId")
        # Non-string ids are treated as missing in BOTH passes. Besides keeping
        # the two passes consistent, this avoids a TypeError from testing an
        # unhashable value (e.g. a list from malformed JSON) against the set.
        return raw_id if raw_id and isinstance(raw_id, str) else None

    # Accept both the API name and the protocol name for tool outputs.
    output_types = ("function_call_output", "function_call_result")

    # First pass: collect the call ids that have an output.
    completed_call_ids: set[str] = set()
    for item in items:
        if isinstance(item, dict) and item.get("type") in output_types:
            call_id = _call_id_of(item)
            if call_id is not None:
                completed_call_ids.add(call_id)

    # Second pass: keep everything except function_call dicts without an output.
    filtered: list[TResponseInputItem] = []
    for item in items:
        if isinstance(item, dict) and item.get("type") == "function_call":
            # A missing/invalid id yields None, which is never in the set.
            if _call_id_of(item) not in completed_call_ids:
                continue
        filtered.append(item)

    return filtered
2530+
24702531
@staticmethod
24712532
def _normalize_input_items(items: list[TResponseInputItem]) -> list[TResponseInputItem]:
2472-
"""Normalize input items by removing top-level providerData/provider_data.
2473-
2533+
"""Normalize input items by removing top-level providerData/provider_data
2534+
and normalizing field names (callId -> call_id).
2535+
24742536
The OpenAI API doesn't accept providerData at the top level of input items.
24752537
providerData should only be in content where it belongs. This function removes
24762538
top-level providerData while preserving it in content.
2477-
2539+
2540+
Also normalizes field names from camelCase (callId) to snake_case (call_id)
2541+
to match API expectations.
2542+
2543+
Normalizes item types: converts 'function_call_result' to 'function_call_output'
2544+
to match API expectations.
2545+
24782546
Args:
24792547
items: List of input items to normalize
2480-
2548+
24812549
Returns:
24822550
Normalized list of input items
24832551
"""
2552+
from .run_state import _normalize_field_names
2553+
24842554
normalized: list[TResponseInputItem] = []
24852555
for item in items:
24862556
if isinstance(item, dict):
@@ -2490,6 +2560,18 @@ def _normalize_input_items(items: list[TResponseInputItem]) -> list[TResponseInp
24902560
# The API doesn't accept providerData at the top level of input items
24912561
normalized_item.pop("providerData", None)
24922562
normalized_item.pop("provider_data", None)
2563+
# Normalize item type: API expects 'function_call_output',
2564+
# not 'function_call_result'
2565+
item_type = normalized_item.get("type")
2566+
if item_type == "function_call_result":
2567+
normalized_item["type"] = "function_call_output"
2568+
item_type = "function_call_output"
2569+
# Remove invalid fields based on item type
2570+
# function_call_output items should not have 'name' field
2571+
if item_type == "function_call_output":
2572+
normalized_item.pop("name", None)
2573+
# Normalize field names (callId -> call_id, responseId -> response_id)
2574+
normalized_item = _normalize_field_names(normalized_item)
24932575
normalized.append(cast(TResponseInputItem, normalized_item))
24942576
else:
24952577
# For non-dict items, keep as-is (they should already be in correct format)
@@ -2536,10 +2618,14 @@ async def _prepare_input_with_session(
25362618
f"Invalid `session_input_callback` value: {session_input_callback}. "
25372619
"Choose between `None` or a custom callable function."
25382620
)
2539-
2621+
2622+
# Filter incomplete function_call pairs before normalizing
2623+
# (API requires every function_call to have a function_call_output)
2624+
filtered = cls._filter_incomplete_function_calls(merged)
2625+
25402626
# Normalize items to remove top-level providerData and deduplicate by ID
2541-
normalized = cls._normalize_input_items(merged)
2542-
2627+
normalized = cls._normalize_input_items(filtered)
2628+
25432629
# Deduplicate items by ID to prevent sending duplicate items to the API
25442630
# This can happen when resuming from state and items are already in the session
25452631
seen_ids: set[str] = set()
@@ -2551,13 +2637,13 @@ async def _prepare_input_with_session(
25512637
item_id = cast(str | None, item.get("id"))
25522638
elif hasattr(item, "id"):
25532639
item_id = cast(str | None, getattr(item, "id", None))
2554-
2640+
25552641
# Only add items we haven't seen before (or items without IDs)
25562642
if item_id is None or item_id not in seen_ids:
25572643
deduplicated.append(item)
25582644
if item_id:
25592645
seen_ids.add(item_id)
2560-
2646+
25612647
return deduplicated
25622648

25632649
@classmethod

src/agents/run_state.py

Lines changed: 81 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,6 @@ class RunState(Generic[TContext, TAgent]):
4848
_current_turn: int = 0
4949
"""Current turn number in the conversation."""
5050

51-
_current_turn_persisted_item_count: int = 0
52-
"""Tracks how many generated run items from this turn were already persisted to session.
53-
54-
When saving to session, we slice off only new entries. When a turn is interrupted
55-
(e.g., awaiting tool approval) and later resumed, we rewind this counter before
56-
continuing so pending tool outputs still get stored.
57-
"""
58-
5951
_current_agent: TAgent | None = None
6052
"""The agent currently handling the conversation."""
6153

@@ -250,13 +242,63 @@ def to_json(self) -> dict[str, Any]:
250242
}
251243
model_responses.append(response_dict)
252244

245+
# Normalize and camelize originalInput if it's a list of items
246+
# Convert API format to protocol format to match TypeScript schema
247+
# Protocol expects function_call_result (not function_call_output)
248+
original_input_serialized = self._original_input
249+
if isinstance(original_input_serialized, list):
250+
# First pass: build a map of call_id -> function_call name
251+
# to help convert function_call_output to function_call_result
252+
call_id_to_name: dict[str, str] = {}
253+
for item in original_input_serialized:
254+
if isinstance(item, dict):
255+
item_type = item.get("type")
256+
call_id = item.get("call_id") or item.get("callId")
257+
name = item.get("name")
258+
if item_type == "function_call" and call_id and name:
259+
call_id_to_name[call_id] = name
260+
261+
normalized_items = []
262+
for item in original_input_serialized:
263+
if isinstance(item, dict):
264+
# Create a copy to avoid modifying the original
265+
normalized_item = dict(item)
266+
# Remove session/conversation metadata fields that shouldn't be in originalInput
267+
# These are not part of the input protocol schema
268+
normalized_item.pop("id", None)
269+
normalized_item.pop("created_at", None)
270+
# Remove top-level providerData/provider_data (protocol allows it but
271+
# we remove it for cleaner serialization)
272+
normalized_item.pop("providerData", None)
273+
normalized_item.pop("provider_data", None)
274+
# Convert API format to protocol format
275+
# API uses function_call_output, protocol uses function_call_result
276+
item_type = normalized_item.get("type")
277+
call_id = normalized_item.get("call_id") or normalized_item.get("callId")
278+
if item_type == "function_call_output":
279+
# Convert to protocol format: function_call_result
280+
normalized_item["type"] = "function_call_result"
281+
# Protocol format requires status field (default to 'completed')
282+
if "status" not in normalized_item:
283+
normalized_item["status"] = "completed"
284+
# Protocol format requires name field
285+
# Look it up from the corresponding function_call if missing
286+
if "name" not in normalized_item and call_id:
287+
normalized_item["name"] = call_id_to_name.get(call_id, "")
288+
# Normalize field names to camelCase for JSON (call_id -> callId)
289+
normalized_item = self._camelize_field_names(normalized_item)
290+
normalized_items.append(normalized_item)
291+
else:
292+
normalized_items.append(item)
293+
original_input_serialized = normalized_items
294+
253295
result = {
254296
"$schemaVersion": CURRENT_SCHEMA_VERSION,
255297
"currentTurn": self._current_turn,
256298
"currentAgent": {
257299
"name": self._current_agent.name,
258300
},
259-
"originalInput": self._original_input,
301+
"originalInput": original_input_serialized,
260302
"modelResponses": model_responses,
261303
"context": {
262304
"usage": {
@@ -345,7 +387,6 @@ def to_json(self) -> dict[str, Any]:
345387
if self._last_processed_response
346388
else None
347389
)
348-
result["currentTurnPersistedItemCount"] = self._current_turn_persisted_item_count
349390
result["trace"] = None
350391

351392
return result
@@ -571,18 +612,29 @@ async def from_string(
571612
context.usage = usage
572613
context._rebuild_approvals(context_data.get("approvals", {}))
573614

615+
# Normalize originalInput to remove providerData fields that may have been
616+
# included by TypeScript serialization. These fields are metadata and should
617+
# not be sent to the API.
618+
original_input_raw = state_json["originalInput"]
619+
if isinstance(original_input_raw, list):
620+
# Normalize each item in the list to remove providerData fields
621+
normalized_original_input = [
622+
_normalize_field_names(item) if isinstance(item, dict) else item
623+
for item in original_input_raw
624+
]
625+
else:
626+
# If it's a string, use it as-is
627+
normalized_original_input = original_input_raw
628+
574629
# Create the RunState instance
575630
state = RunState(
576631
context=context,
577-
original_input=state_json["originalInput"],
632+
original_input=normalized_original_input,
578633
starting_agent=current_agent,
579634
max_turns=state_json["maxTurns"],
580635
)
581636

582637
state._current_turn = state_json["currentTurn"]
583-
state._current_turn_persisted_item_count = state_json.get(
584-
"currentTurnPersistedItemCount", 0
585-
)
586638

587639
# Reconstruct model responses
588640
state._model_responses = _deserialize_model_responses(state_json.get("modelResponses", []))
@@ -679,18 +731,29 @@ async def from_json(
679731
context.usage = usage
680732
context._rebuild_approvals(context_data.get("approvals", {}))
681733

734+
# Normalize originalInput to remove providerData fields that may have been
735+
# included by TypeScript serialization. These fields are metadata and should
736+
# not be sent to the API.
737+
original_input_raw = state_json["originalInput"]
738+
if isinstance(original_input_raw, list):
739+
# Normalize each item in the list to remove providerData fields
740+
normalized_original_input = [
741+
_normalize_field_names(item) if isinstance(item, dict) else item
742+
for item in original_input_raw
743+
]
744+
else:
745+
# If it's a string, use it as-is
746+
normalized_original_input = original_input_raw
747+
682748
# Create the RunState instance
683749
state = RunState(
684750
context=context,
685-
original_input=state_json["originalInput"],
751+
original_input=normalized_original_input,
686752
starting_agent=current_agent,
687753
max_turns=state_json["maxTurns"],
688754
)
689755

690756
state._current_turn = state_json["currentTurn"]
691-
state._current_turn_persisted_item_count = state_json.get(
692-
"currentTurnPersistedItemCount", 0
693-
)
694757

695758
# Reconstruct model responses
696759
state._model_responses = _deserialize_model_responses(state_json.get("modelResponses", []))

tests/test_run_state.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,50 @@ async def test_deserializes_various_item_types(self):
507507
assert isinstance(new_state._generated_items[1], ToolCallItem)
508508
assert isinstance(new_state._generated_items[2], ToolCallOutputItem)
509509

510+
async def test_serializes_original_input_with_function_call_output(self):
    """Check that to_json() emits protocol-format items for API-format originalInput.

    The state is built with a function_call / function_call_output pair (API
    format, snake_case keys). After serialization the output item is expected
    to be a function_call_result with camelCase keys, a name taken from the
    matching function_call, and a default status of "completed".
    """
    # API-format pair, as it would come back from a session.
    tool_call = {
        "type": "function_call",
        "call_id": "call_123",
        "name": "test_tool",
        "arguments": '{"arg": "value"}',
    }
    tool_output = {
        "type": "function_call_output",
        "call_id": "call_123",
        "output": "result",
    }

    wrapper: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={})
    state = RunState(
        context=wrapper,
        original_input=[tool_call, tool_output],
        starting_agent=Agent(name="TestAgent"),
        max_turns=5,
    )

    serialized_input = state.to_json()["originalInput"]

    # Both items survive serialization.
    assert isinstance(serialized_input, list)
    assert len(serialized_input) == 2

    # The function_call keeps its type and gains camelCase keys.
    call_item = serialized_input[0]
    assert call_item["type"] == "function_call"
    assert call_item["callId"] == "call_123"
    assert call_item["name"] == "test_tool"

    # The output item is rewritten to the protocol's function_call_result.
    result_item = serialized_input[1]
    assert result_item["type"] == "function_call_result"
    assert result_item["callId"] == "call_123"
    assert result_item["name"] == "test_tool"  # filled in from the function_call
    assert result_item["status"] == "completed"  # default when absent
    assert result_item["output"] == "result"
553+
510554
async def test_deserialization_handles_unknown_agent_gracefully(self):
511555
"""Test that deserialization skips items with unknown agents."""
512556
context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={})

0 commit comments

Comments
 (0)