
Commit c63491f

Fix reasoning

1 parent: d292ac4

File tree (4 files changed, +42 -18 lines):
- pydantic_ai_slim/pydantic_ai/messages.py
- pydantic_ai_slim/pydantic_ai/models/openai.py
- pydantic_ai_slim/pyproject.toml
- uv.lock


pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 6 additions & 1 deletion

@@ -351,6 +351,8 @@ class ToolReturnPart:
     tool_call_id: str
     """The tool call identifier, this is used by some models including OpenAI."""

+    id: str | None = None
+
     timestamp: datetime = field(default_factory=_now_utc)
     """The timestamp, when the tool returned."""

@@ -510,6 +512,8 @@ class ToolCallPart:
     This is stored either as a JSON string or a Python dictionary depending on how data was received.
     """

+    id: str | None = None
+
     tool_call_id: str = field(default_factory=_generate_tool_call_id)
     """The tool call identifier, this is used by some models including OpenAI.

@@ -799,8 +803,9 @@ def _apply_to_part(self, part: ToolCallPart) -> ToolCallPart:
             updated_dict = {**(part.args or {}), **self.args_delta}
             part = replace(part, args=updated_dict)

+        # Does this ever change? Not sure why this is needed
         if self.tool_call_id:
-            part = replace(part, tool_call_id=self.tool_call_id)
+            part = replace(part, id=self.tool_call_id)
         return part

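For orientation, a minimal sketch of how the new optional field is meant to be used, assuming the dataclass constructors otherwise behave as before; the tool name, arguments, and id values below are invented for illustration:

from pydantic_ai.messages import ToolCallPart

# Existing call sites are unaffected: `id` defaults to None.
part = ToolCallPart(tool_name='get_weather', args={'city': 'Paris'})
assert part.id is None

# When the provider exposes an output-item id (e.g. an OpenAI Responses item id),
# it can now be stored alongside the provider's call id.
part_with_id = ToolCallPart(
    tool_name='get_weather',
    args={'city': 'Paris'},
    tool_call_id='call_123',  # invented example value
    id='fc_456',              # invented example value
)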

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 31 additions & 12 deletions

@@ -723,24 +723,22 @@ async def _map_messages(
                     else:
                         assert_never(part)
             elif isinstance(message, ModelResponse):
-                thinking_parts: list[ThinkingPart] = []
                 for item in message.parts:
                     if isinstance(item, TextPart):
                         openai_messages.append(responses.EasyInputMessageParam(role='assistant', content=item.content))
                     elif isinstance(item, ToolCallPart):
                         openai_messages.append(self._map_tool_call(item))
                     elif isinstance(item, ThinkingPart):
-                        thinking_parts.append(item)
+                        # Reasoning needs to precede the tool call, it shouldn't get added last
+                        openai_messages.append(
+                            responses.ResponseReasoningItemParam(
+                                id=item.signature or '',
+                                summary=[Summary(text=item.content, type='summary_text')],
+                                type='reasoning',
+                            )
+                        )
                     else:
                         assert_never(item)
-                if thinking_parts:
-                    openai_messages.append(
-                        responses.ResponseReasoningItemParam(
-                            id=thinking_parts[0].signature or '',
-                            summary=[Summary(text=item.content, type='summary_text') for item in thinking_parts],
-                            type='reasoning',
-                        )
-                    )
             else:
                 assert_never(message)
         instructions = self._get_instructions(messages) or NOT_GIVEN
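
The practical effect of mapping each ThinkingPart where it occurs, instead of collecting them and appending one combined reasoning item after the loop: for a response whose parts are a ThinkingPart followed by a ToolCallPart, the reasoning item now precedes the function call item in the Responses input. A rough sketch under that assumption (part contents invented):

from pydantic_ai.messages import ModelResponse, ThinkingPart, ToolCallPart

response = ModelResponse(
    parts=[
        ThinkingPart(content='Need the forecast before answering.', signature='rs_123'),  # invented
        ToolCallPart(tool_name='get_weather', args={'city': 'Paris'}),
    ]
)

# Before: mapped items ended up as [function_call, reasoning] (reasoning appended after the loop).
# After:  mapped items are [reasoning, function_call] (each part mapped in order).
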
@@ -749,6 +747,7 @@ async def _map_messages(
     @staticmethod
     def _map_tool_call(t: ToolCallPart) -> responses.ResponseFunctionToolCallParam:
         return responses.ResponseFunctionToolCallParam(
+            id=t.id or '',
             arguments=t.args_as_json_str(),
             call_id=_guard_tool_call_id(t=t),
             name=t.tool_name,
@@ -892,6 +891,15 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
             elif isinstance(chunk, responses.ResponseContentPartDoneEvent):
                 pass  # there's nothing we need to do here

+            elif isinstance(chunk, responses.ResponseReasoningSummaryPartAddedEvent):
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseReasoningSummaryPartDoneEvent):
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseReasoningSummaryTextDoneEvent):
+                pass  # there's nothing we need to do here
+
             elif isinstance(chunk, responses.ResponseCreatedEvent):
                 pass  # there's nothing we need to do here

@@ -923,11 +931,15 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
                         vendor_part_id=chunk.item.id,
                         tool_name=chunk.item.name,
                         args=chunk.item.arguments,
-                        tool_call_id=chunk.item.id,
+                        tool_call_id=chunk.item.call_id,
                     )
                 elif isinstance(chunk.item, responses.ResponseReasoningItem):
                     content = chunk.item.summary[0].text if chunk.item.summary else ''
-                    yield self._parts_manager.handle_thinking_delta(vendor_part_id=chunk.item.id, content=content)
+                    yield self._parts_manager.handle_thinking_delta(
+                        vendor_part_id=chunk.item.id,
+                        content=content,
+                        signature=chunk.item.id,
+                    )
                 elif isinstance(chunk.item, responses.ResponseOutputMessage):
                     pass
                 else:
@@ -940,6 +952,13 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
             elif isinstance(chunk, responses.ResponseTextDeltaEvent):
                 yield self._parts_manager.handle_text_delta(vendor_part_id=chunk.content_index, content=chunk.delta)

+            elif isinstance(chunk, responses.ResponseReasoningSummaryTextDeltaEvent):
+                yield self._parts_manager.handle_thinking_delta(
+                    vendor_part_id=chunk.item_id,
+                    content=chunk.delta,
+                    signature=chunk.item_id,
+                )
+
             elif isinstance(chunk, responses.ResponseTextDoneEvent):
                 pass  # there's nothing we need to do here

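Both handle_thinking_delta call sites pass the reasoning item id as vendor_part_id (and as the signature), which is what lets the parts manager fold the initial summary text and the later ResponseReasoningSummaryTextDeltaEvent deltas into a single ThinkingPart. A rough sketch of the aggregation contract this relies on; this is not the real ModelResponsePartsManager, only an illustration of the assumed behaviour:

from dataclasses import dataclass, field


@dataclass
class ThinkingAccumulatorSketch:
    """Illustration only: deltas sharing a vendor_part_id build up one thinking part."""

    content: dict[str, str] = field(default_factory=dict)
    signature: dict[str, str] = field(default_factory=dict)

    def handle_thinking_delta(self, vendor_part_id: str, content: str, signature: str | None = None) -> None:
        # Append to the existing entry for this id, or start a new one.
        self.content[vendor_part_id] = self.content.get(vendor_part_id, '') + content
        if signature is not None:
            self.signature[vendor_part_id] = signature


acc = ThinkingAccumulatorSketch()
acc.handle_thinking_delta('rs_123', 'Need the forecast', signature='rs_123')   # from ResponseReasoningItem
acc.handle_thinking_delta('rs_123', ' before answering.', signature='rs_123')  # from a summary text delta event
assert acc.content['rs_123'] == 'Need the forecast before answering.'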

pydantic_ai_slim/pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -56,7 +56,7 @@ dependencies = [
 # WARNING if you add optional groups, please update docs/install.md
 logfire = ["logfire>=3.11.0"]
 # Models
-openai = ["openai>=1.75.0"]
+openai = ["openai>=1.76.0"]
 cohere = ["cohere>=5.13.11; platform_system != 'Emscripten'"]
 vertexai = ["google-auth>=2.36.0", "requests>=2.32.3"]
 anthropic = ["anthropic>=0.49.0"]
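
The floor bump lines up with the streaming changes above: the ResponseReasoningSummary* event types referenced in openai.py are assumed to be available only from newer openai SDK releases, presumably 1.76.0 onwards. A quick sanity check that can be run against an installed environment:

# Sanity check: the reasoning-summary streaming event types used in openai.py
# must be importable from openai.types.responses; their absence would indicate
# an openai SDK older than the new floor (the 1.76.0 attribution is an assumption).
from openai.types import responses

for name in (
    'ResponseReasoningSummaryPartAddedEvent',
    'ResponseReasoningSummaryPartDoneEvent',
    'ResponseReasoningSummaryTextDeltaEvent',
    'ResponseReasoningSummaryTextDoneEvent',
):
    assert hasattr(responses, name), f'{name} missing from openai.types.responses'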

uv.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default (generated lockfile; diff not shown).
