Skip to content

Commit 47851ef

Browse files
committed
fix: bring coverage back up, addressing edge cases
1 parent 33c41d6 commit 47851ef

File tree

9 files changed

+907
-63
lines changed

9 files changed

+907
-63
lines changed

src/agents/handoffs/history.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,11 @@ def _build_summary_message(transcript: list[TResponseInputItem]) -> TResponseInp
126126
end_marker,
127127
]
128128
content = "\n".join(content_lines)
129-
assistant_message: dict[str, Any] = {
130-
"role": "assistant",
129+
summary_message: dict[str, Any] = {
130+
"role": "system",
131131
"content": content,
132132
}
133-
return cast(TResponseInputItem, assistant_message)
133+
return cast(TResponseInputItem, summary_message)
134134

135135

136136
def _format_transcript_item(item: TResponseInputItem) -> str:

src/agents/items.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import abc
4+
import json
45
import weakref
56
from dataclasses import dataclass, field
67
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, Union, cast
@@ -56,6 +57,44 @@
5657
)
5758
from .usage import Usage
5859

60+
61+
def normalize_function_call_output_payload(payload: dict[str, Any]) -> dict[str, Any]:
    """Coerce the ``output`` field of a function-call output payload.

    Payloads whose ``type`` is neither ``function_call_output`` nor
    ``function_call_result`` are returned untouched. For the two handled types,
    ``payload`` is updated in place and returned:

    * a missing or ``None`` output becomes ``""``;
    * a string output is kept as-is;
    * a list output is kept only when every entry is a dict whose ``type`` is in
      ``_ALLOWED_FUNCTION_CALL_OUTPUT_TYPES``; otherwise the list is JSON-encoded;
    * a dict output with an allowed ``type`` is wrapped in a one-element list,
      any other dict is JSON-encoded;
    * every other value is JSON-encoded.

    Args:
        payload: The item dict to normalize. It is mutated, so callers that need
            the original should pass a copy.

    Returns:
        The same ``payload`` object that was passed in.
    """
    if payload.get("type") not in {"function_call_output", "function_call_result"}:
        return payload

    current = payload.get("output")

    if current is None:
        payload["output"] = ""
    elif isinstance(current, str):
        # Plain strings need no conversion.
        pass
    elif isinstance(current, list):
        has_unsupported_entry = any(
            not isinstance(part, dict)
            or part.get("type") not in _ALLOWED_FUNCTION_CALL_OUTPUT_TYPES
            for part in current
        )
        if has_unsupported_entry:
            payload["output"] = json.dumps(current)
    elif (
        isinstance(current, dict)
        and current.get("type") in _ALLOWED_FUNCTION_CALL_OUTPUT_TYPES
    ):
        # A single structured content part is sent as a one-element list.
        payload["output"] = [current]
    else:
        # Unsupported dicts and any remaining value types are serialized.
        payload["output"] = json.dumps(current)

    return payload
96+
97+
5998
if TYPE_CHECKING:
6099
from .agent import Agent
61100

@@ -75,6 +114,15 @@
75114

76115
# Distinguish a missing dict entry from an explicit None value.
77116
_MISSING_ATTR_SENTINEL = object()
117+
# Content-part ``type`` values that normalize_function_call_output_payload keeps in
# structured (list) form; list or dict outputs with other types are JSON-encoded.
_ALLOWED_FUNCTION_CALL_OUTPUT_TYPES: set[str] = {
    "computer_screenshot",
    "input_file",
    "input_image",
    "input_text",
    "output_text",
    "refusal",
    "summary_text",
}
78126

79127

80128
@dataclass
@@ -220,6 +268,21 @@ def release_agent(self) -> None:
220268
# Preserve dataclass fields for repr/asdict while dropping strong refs.
221269
self.__dict__["target_agent"] = None
222270

271+
def to_input_item(self) -> TResponseInputItem:
    """Return this handoff output as an input item in API format.

    Dict-shaped raw items are shallow-copied before any change, so the keys of
    ``self.raw_item`` are not modified. A ``function_call_result`` entry is
    renamed to ``function_call_output`` and loses its ``name`` and ``status``
    keys; the copy is then passed through
    ``normalize_function_call_output_payload``. Raw items that are not dicts
    fall back to the base-class conversion.
    """
    raw = self.raw_item
    if not isinstance(raw, dict):
        return super().to_input_item()

    api_payload = dict(raw)
    if api_payload.get("type") == "function_call_result":
        api_payload["type"] = "function_call_output"
        # These keys are dropped together with the type rename.
        for dropped_key in ("name", "status"):
            api_payload.pop(dropped_key, None)

    api_payload = normalize_function_call_output_payload(api_payload)
    return cast(TResponseInputItem, api_payload)
285+
223286

224287
ToolCallItemTypes: TypeAlias = Union[
225288
ResponseFunctionToolCall,
@@ -273,15 +336,25 @@ def to_input_item(self) -> TResponseInputItem:
273336
Hosted tool outputs (e.g. shell/apply_patch) carry a `status` field for the SDK's
274337
book-keeping, but the Responses API does not yet accept that parameter. Strip it from the
275338
payload we send back to the model while keeping the original raw item intact.
339+
340+
Also converts protocol format (function_call_result) to API format (function_call_output).
276341
"""
277342

278343
if isinstance(self.raw_item, dict):
279344
payload = dict(self.raw_item)
280345
payload_type = payload.get("type")
281-
if payload_type == "shell_call_output":
346+
# Convert protocol format to API format
347+
# Protocol uses function_call_result, API expects function_call_output
348+
if payload_type == "function_call_result":
349+
payload["type"] = "function_call_output"
350+
# Remove fields that are in protocol format but not in API format
351+
payload.pop("name", None)
352+
payload.pop("status", None)
353+
elif payload_type == "shell_call_output":
282354
payload.pop("status", None)
283355
payload.pop("shell_output", None)
284356
payload.pop("provider_data", None)
357+
payload = normalize_function_call_output_payload(payload)
285358
return cast(TResponseInputItem, payload)
286359

287360
return super().to_input_item()
@@ -392,6 +465,17 @@ def arguments(self) -> str | None:
392465
return self.raw_item.arguments
393466
return None
394467

468+
def to_input_item(self) -> TResponseInputItem:
    """Always raise: a ToolApprovalItem has no input-item representation.

    Items of this kind stand for pending approvals and are meant to be filtered
    out before input is prepared for the API; raising here guards against
    accidental conversion.

    Raises:
        AgentsException: Unconditionally.
    """
    message = (
        "ToolApprovalItem cannot be converted to an input item. "
        "These items should be filtered out before preparing input for the API."
    )
    raise AgentsException(message)
478+
395479

396480
RunItem: TypeAlias = Union[
397481
MessageOutputItem,

src/agents/run.py

Lines changed: 131 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
ToolCallItem,
6060
ToolCallItemTypes,
6161
TResponseInputItem,
62+
normalize_function_call_output_payload,
6263
)
6364
from .lifecycle import AgentHooksBase, RunHooks, RunHooksBase
6465
from .logger import logger
@@ -758,10 +759,15 @@ async def run(
758759
# Resuming from a saved state
759760
run_state = cast(RunState[TContext], input)
760761
original_user_input = run_state._original_input
761-
# Normalize items to remove top-level providerData (API doesn't accept it there)
762+
# Normalize items to remove top-level providerData and convert protocol to API format
763+
# Then filter incomplete function calls to ensure API compatibility
762764
if isinstance(original_user_input, list):
763-
prepared_input: str | list[TResponseInputItem] = AgentRunner._normalize_input_items(
764-
original_user_input
765+
# Normalize first (converts protocol format to API format, normalizes field names)
766+
normalized = AgentRunner._normalize_input_items(original_user_input)
767+
# Filter incomplete function calls after normalization
768+
# This ensures consistent field names (call_id vs callId) for matching
769+
prepared_input: str | list[TResponseInputItem] = (
770+
AgentRunner._filter_incomplete_function_calls(normalized)
765771
)
766772
else:
767773
prepared_input = original_user_input
@@ -810,12 +816,16 @@ async def run(
810816
if is_resumed_state and run_state is not None:
811817
# Restore state from RunState
812818
current_turn = run_state._current_turn
813-
# Normalize original_input to remove top-level providerData
814-
# (API doesn't accept it there)
819+
# Normalize original_input: remove top-level providerData,
820+
# convert protocol to API format, then filter incomplete function calls
815821
raw_original_input = run_state._original_input
816822
if isinstance(raw_original_input, list):
823+
# Normalize first (converts protocol to API format, normalizes field names)
824+
normalized = AgentRunner._normalize_input_items(raw_original_input)
825+
# Filter incomplete function calls after normalization
826+
# This ensures consistent field names (call_id vs callId) for matching
817827
original_input: str | list[TResponseInputItem] = (
818-
AgentRunner._normalize_input_items(raw_original_input)
828+
AgentRunner._filter_incomplete_function_calls(normalized)
819829
)
820830
else:
821831
original_input = raw_original_input
@@ -884,8 +894,40 @@ async def run(
884894
)
885895
in output_call_ids
886896
]
887-
# Save both function_call and function_call_output together
888-
items_to_save = tool_call_items + tool_output_items
897+
# Check which items are already in the session to avoid duplicates
898+
# Get existing items from session and extract their call_ids
899+
existing_items = await session.get_items()
900+
existing_call_ids: set[str] = set()
901+
for existing_item in existing_items:
902+
if isinstance(existing_item, dict):
903+
item_type = existing_item.get("type")
904+
if item_type in ("function_call", "function_call_output"):
905+
existing_call_id = existing_item.get(
906+
"call_id"
907+
) or existing_item.get("callId")
908+
if existing_call_id and isinstance(existing_call_id, str):
909+
existing_call_ids.add(existing_call_id)
910+
911+
# Filter out items that are already in the session
912+
items_to_save: list[RunItem] = []
913+
for item in tool_call_items + tool_output_items:
914+
item_call_id: str | None = None
915+
if isinstance(item.raw_item, dict):
916+
raw_call_id = item.raw_item.get("call_id") or item.raw_item.get(
917+
"callId"
918+
)
919+
item_call_id = (
920+
cast(str | None, raw_call_id) if raw_call_id else None
921+
)
922+
elif hasattr(item.raw_item, "call_id"):
923+
item_call_id = cast(
924+
str | None, getattr(item.raw_item, "call_id", None)
925+
)
926+
927+
# Only save if not already in session
928+
if item_call_id is None or item_call_id not in existing_call_ids:
929+
items_to_save.append(item)
930+
889931
if items_to_save:
890932
await self._save_result_to_session(session, [], items_to_save)
891933
# Clear the current step since we've handled it
@@ -1463,11 +1505,12 @@ async def _start_streaming(
14631505
# Resuming from state - normalize items to remove top-level providerData
14641506
# and filter incomplete function_call pairs
14651507
if isinstance(starting_input, list):
1466-
# Filter incomplete function_call pairs before normalizing
1467-
filtered = AgentRunner._filter_incomplete_function_calls(starting_input)
1468-
prepared_input: str | list[TResponseInputItem] = (
1469-
AgentRunner._normalize_input_items(filtered)
1470-
)
1508+
# Normalize field names first (camelCase -> snake_case) to ensure
1509+
# consistent field names for filtering
1510+
normalized_input = AgentRunner._normalize_input_items(starting_input)
1511+
# Filter incomplete function_call pairs after normalizing
1512+
filtered = AgentRunner._filter_incomplete_function_calls(normalized_input)
1513+
prepared_input: str | list[TResponseInputItem] = filtered
14711514
else:
14721515
prepared_input = starting_input
14731516
else:
@@ -2653,33 +2696,67 @@ def _normalize_input_items(items: list[TResponseInputItem]) -> list[TResponseInp
26532696
"""
26542697
from .run_state import _normalize_field_names
26552698

2699+
def _coerce_to_dict(value: TResponseInputItem) -> dict[str, Any] | None:
    """Return a dict view of ``value``, or ``None`` when it cannot be converted.

    Dicts are shallow-copied. Objects exposing ``model_dump`` are dumped with
    ``exclude_unset=True``; if that call raises, the failure is swallowed and
    ``None`` is returned so the caller can keep the item unchanged.
    """
    if isinstance(value, dict):
        return dict(value)

    if not hasattr(value, "model_dump"):
        return None

    try:
        dumped = value.model_dump(exclude_unset=True)
    except Exception:
        # Best effort: an item that cannot be dumped is reported as unconvertible.
        return None
    return cast(dict[str, Any], dumped)
2708+
26562709
normalized: list[TResponseInputItem] = []
26572710
for item in items:
2658-
if isinstance(item, dict):
2659-
# Create a copy to avoid modifying the original
2660-
normalized_item = dict(item)
2661-
# Remove top-level providerData/provider_data - these should only be in content
2662-
# The API doesn't accept providerData at the top level of input items
2663-
normalized_item.pop("providerData", None)
2664-
normalized_item.pop("provider_data", None)
2665-
# Normalize item type: API expects 'function_call_output',
2666-
# not 'function_call_result'
2667-
item_type = normalized_item.get("type")
2668-
if item_type == "function_call_result":
2669-
normalized_item["type"] = "function_call_output"
2670-
item_type = "function_call_output"
2671-
# Remove invalid fields based on item type
2672-
# function_call_output items should not have 'name' field
2673-
if item_type == "function_call_output":
2674-
normalized_item.pop("name", None)
2675-
# Normalize field names (callId -> call_id, responseId -> response_id)
2676-
normalized_item = _normalize_field_names(normalized_item)
2677-
normalized.append(cast(TResponseInputItem, normalized_item))
2678-
else:
2679-
# For non-dict items, keep as-is (they should already be in correct format)
2711+
coerced = _coerce_to_dict(item)
2712+
if coerced is None:
26802713
normalized.append(item)
2714+
continue
2715+
2716+
normalized_item = dict(coerced)
2717+
normalized_item.pop("providerData", None)
2718+
normalized_item.pop("provider_data", None)
2719+
item_type = normalized_item.get("type")
2720+
if item_type == "function_call_result":
2721+
normalized_item["type"] = "function_call_output"
2722+
item_type = "function_call_output"
2723+
if item_type == "function_call_output":
2724+
normalized_item.pop("name", None)
2725+
normalized_item.pop("status", None)
2726+
normalized_item = normalize_function_call_output_payload(normalized_item)
2727+
normalized_item = _normalize_field_names(normalized_item)
2728+
normalized.append(cast(TResponseInputItem, normalized_item))
26812729
return normalized
26822730

2731+
@staticmethod
def _ensure_api_input_item(item: TResponseInputItem) -> TResponseInputItem:
    """Return ``item`` with protocol-style function results converted to API format.

    Dict items are shallow-copied; objects exposing ``model_dump`` are dumped
    with ``exclude_unset=True`` and returned as plain dicts. Anything else, or an
    object whose dump raises, is returned unchanged.

    On the resulting dict, a ``function_call_result`` type is renamed to
    ``function_call_output`` and its ``name`` and ``status`` keys are dropped.
    Every ``function_call_output`` item is then passed through
    ``normalize_function_call_output_payload``. No other keys are renamed or
    removed by this method.
    """
    snapshot: dict[str, Any] | None
    if isinstance(item, dict):
        snapshot = dict(item)
    elif hasattr(item, "model_dump"):
        try:
            snapshot = cast(dict[str, Any], item.model_dump(exclude_unset=True))
        except Exception:
            # Best effort: an item that cannot be dumped is passed through as-is.
            snapshot = None
    else:
        snapshot = None

    if snapshot is None:
        return item

    api_item = dict(snapshot)
    if api_item.get("type") == "function_call_result":
        api_item["type"] = "function_call_output"
        for dropped_key in ("name", "status"):
            api_item.pop(dropped_key, None)

    if api_item.get("type") == "function_call_output":
        api_item = normalize_function_call_output_payload(api_item)
    return cast(TResponseInputItem, api_item)
2759+
26832760
@classmethod
26842761
async def _prepare_input_with_session(
26852762
cls,
@@ -2704,13 +2781,19 @@ async def _prepare_input_with_session(
27042781
# Get previous conversation history
27052782
history = await session.get_items()
27062783

2784+
# Convert protocol format items from session to API format.
2785+
# TypeScript may save protocol format (function_call_result) to sessions,
2786+
# but the API expects API format (function_call_output).
2787+
converted_history = [cls._ensure_api_input_item(item) for item in history]
2788+
27072789
# Convert input to list format
27082790
new_input_list = ItemHelpers.input_to_new_input_list(input)
2791+
new_input_list = [cls._ensure_api_input_item(item) for item in new_input_list]
27092792

27102793
if session_input_callback is None:
2711-
merged = history + new_input_list
2794+
merged = converted_history + new_input_list
27122795
elif callable(session_input_callback):
2713-
res = session_input_callback(history, new_input_list)
2796+
res = session_input_callback(converted_history, new_input_list)
27142797
if inspect.isawaitable(res):
27152798
merged = await res
27162799
else:
@@ -2764,10 +2847,19 @@ async def _save_result_to_session(
27642847
return
27652848

27662849
# Convert original input to list format if needed
2767-
input_list = ItemHelpers.input_to_new_input_list(original_input)
2850+
input_list = [
2851+
cls._ensure_api_input_item(item)
2852+
for item in ItemHelpers.input_to_new_input_list(original_input)
2853+
]
2854+
2855+
# Filter out tool_approval_item items before converting to input format
2856+
# These items represent pending approvals and shouldn't be sent to the API
2857+
items_to_convert = [item for item in new_items if item.type != "tool_approval_item"]
27682858

27692859
# Convert new items to input format
2770-
new_items_as_input = [item.to_input_item() for item in new_items]
2860+
new_items_as_input = [
2861+
cls._ensure_api_input_item(item.to_input_item()) for item in items_to_convert
2862+
]
27712863

27722864
# Save all items from this turn
27732865
items_to_save = input_list + new_items_as_input

0 commit comments

Comments
 (0)