Skip to content

Commit 64a3ea6

Browse files
committed
WIP: Remove OutputPart, work around allow_text_output instead
1 parent 19f8ad2 commit 64a3ea6

22 files changed: +133 additions, −246 deletions

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 13 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -254,13 +254,12 @@ async def add_mcp_server_tools(server: MCPServer) -> None:
254254
output_mode = None
255255
output_object = None
256256
output_tools = []
257-
allow_text_output = _output.allow_text_output(output_schema)
257+
require_tool_use = False
258258
if output_schema:
259259
output_mode = output_schema.forced_mode or model.default_output_mode
260260
output_object = output_schema.object_schema.definition
261261
output_tools = output_schema.tool_defs()
262-
if output_mode != 'tool':
263-
allow_text_output = False
262+
require_tool_use = output_mode == 'tool' and not output_schema.allow_plain_text_output
264263

265264
supported_modes = model.supported_output_modes
266265
if output_mode not in supported_modes:
@@ -271,7 +270,7 @@ async def add_mcp_server_tools(server: MCPServer) -> None:
271270
output_mode=output_mode,
272271
output_object=output_object,
273272
output_tools=output_tools,
274-
allow_text_output=allow_text_output,
273+
require_tool_use=require_tool_use,
275274
)
276275

277276

@@ -422,24 +421,20 @@ async def stream(
422421
async for _event in stream:
423422
pass
424423

425-
async def _run_stream( # noqa: C901
424+
async def _run_stream(
426425
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
427426
) -> AsyncIterator[_messages.HandleResponseEvent]:
428427
if self._events_iterator is None:
429428
# Ensure that the stream is only run once
430429

431430
async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
432431
texts: list[str] = []
433-
outputs: list[str] = []
434432
tool_calls: list[_messages.ToolCallPart] = []
435433
for part in self.model_response.parts:
436434
if isinstance(part, _messages.TextPart):
437435
# ignore empty content for text parts, see #437
438436
if part.content:
439437
texts.append(part.content)
440-
elif isinstance(part, _messages.OutputPart):
441-
if part.content:
442-
outputs.append(part.content)
443438
elif isinstance(part, _messages.ToolCallPart):
444439
tool_calls.append(part)
445440
else:
@@ -452,9 +447,6 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
452447
if tool_calls:
453448
async for event in self._handle_tool_calls(ctx, tool_calls):
454449
yield event
455-
elif outputs: # TODO: Can we have tool calls and structured output? Should we handle both?
456-
# No events are emitted during the handling of structured outputs, so we don't need to yield anything
457-
self._next_node = await self._handle_outputs(ctx, outputs)
458450
elif texts:
459451
# No events are emitted during the handling of text responses, so we don't need to yield anything
460452
self._next_node = await self._handle_text_response(ctx, texts)
@@ -546,42 +538,18 @@ async def _handle_text_response(
546538
output_schema = ctx.deps.output_schema
547539

548540
text = '\n\n'.join(texts)
549-
if _output.allow_text_output(output_schema):
550-
# The following cast is safe because we know `str` is an allowed result type
551-
result_data_input = cast(NodeRunEndT, text)
552-
try:
553-
result_data = await _validate_output(result_data_input, ctx, None)
554-
except _output.ToolRetryError as e:
555-
ctx.state.increment_retries(ctx.deps.max_result_retries)
556-
return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
541+
try:
542+
if output_schema is None or output_schema.allow_plain_text_output:
543+
# The following cast is safe because we know `str` is an allowed result type
544+
result_data = cast(NodeRunEndT, text)
545+
elif output_schema.allow_json_text_output:
546+
result_data = output_schema.validate(text)
557547
else:
558-
return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), [])
559-
else:
560-
ctx.state.increment_retries(ctx.deps.max_result_retries)
561-
return ModelRequestNode[DepsT, NodeRunEndT](
562-
_messages.ModelRequest(
563-
parts=[
564-
_messages.RetryPromptPart(
565-
content='Plain text responses are not permitted, please include your response in a tool call',
566-
)
567-
]
548+
m = _messages.RetryPromptPart(
549+
content='Plain text responses are not permitted, please include your response in a tool call',
568550
)
569-
)
551+
raise _output.ToolRetryError(m)
570552

571-
async def _handle_outputs(
572-
self,
573-
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
574-
outputs: list[str],
575-
) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
576-
if len(outputs) != 1:
577-
raise exceptions.UnexpectedModelBehavior('Received multiple structured outputs in a single response')
578-
output_schema = ctx.deps.output_schema
579-
if not output_schema:
580-
raise exceptions.UnexpectedModelBehavior('Must specify a non-str result_type when using structured outputs')
581-
582-
structured_output = outputs[0]
583-
try:
584-
result_data = output_schema.validate(structured_output)
585553
result_data = await _validate_output(result_data, ctx, None)
586554
except _output.ToolRetryError as e:
587555
ctx.state.increment_retries(ctx.deps.max_result_retries)

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,8 @@ class OutputSchema(Generic[OutputDataT]):
220220
forced_mode: OutputMode | None
221221
object_schema: OutputObjectSchema[OutputDataT]
222222
tools: dict[str, OutputTool[OutputDataT]]
223-
allow_text_output: bool # TODO: Verify structured output works correctly with string as a union member
223+
allow_plain_text_output: bool
224+
allow_json_text_output: bool # TODO: Turn into allowed_text_output: Literal['plain', 'json'] | None
224225

225226
@classmethod
226227
def build(
@@ -235,11 +236,14 @@ def build(
235236
return None
236237

237238
forced_mode = None
239+
allow_json_text_output = True
240+
allow_plain_text_output = False
238241
tool_output_type = None
239-
allow_text_output = False
240242
if isinstance(output_type, ToolOutput):
241-
# do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
242243
forced_mode = 'tool'
244+
allow_json_text_output = False
245+
246+
# do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
243247
name = output_type.name
244248
description = output_type.description
245249
output_type_ = output_type.output_type
@@ -255,12 +259,15 @@ def build(
255259
name = output_type.name
256260
description = output_type.description
257261
output_type_ = output_type.output_type
258-
else:
262+
elif output_type_other_than_str := extract_str_from_union(output_type):
263+
forced_mode = 'tool'
259264
output_type_ = output_type
260265

261-
if output_type_other_than_str := extract_str_from_union(output_type):
262-
allow_text_output = True
263-
tool_output_type = output_type_other_than_str.value
266+
allow_json_text_output = False
267+
allow_plain_text_output = True
268+
tool_output_type = output_type_other_than_str.value
269+
else:
270+
output_type_ = output_type
264271

265272
output_object_schema = OutputObjectSchema(
266273
output_type=output_type_, name=name, description=description, strict=strict
@@ -292,7 +299,8 @@ def build(
292299
forced_mode=forced_mode,
293300
object_schema=output_object_schema,
294301
tools=tools,
295-
allow_text_output=allow_text_output,
302+
allow_plain_text_output=allow_plain_text_output,
303+
allow_json_text_output=allow_json_text_output,
296304
)
297305

298306
def find_named_tool(
@@ -341,8 +349,7 @@ def validate(
341349

342350

343351
def allow_text_output(output_schema: OutputSchema[Any] | None) -> bool:
344-
"""Check if the result schema allows text results."""
345-
return output_schema is None or output_schema.allow_text_output
352+
return output_schema is None or output_schema.allow_plain_text_output or output_schema.allow_json_text_output
346353

347354

348355
@dataclass

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,8 +1013,6 @@ async def stream_to_final(
10131013
elif isinstance(new_part, _messages.ToolCallPart) and output_schema:
10141014
for call, _ in output_schema.find_tool([new_part]):
10151015
return FinalResult(s, call.tool_name, call.tool_call_id)
1016-
elif isinstance(new_part, _messages.OutputPart) and output_schema:
1017-
return FinalResult(s, None, None)
10181016
return None
10191017

10201018
final_result_details = await stream_to_final(streamed_response)

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 2 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -479,21 +479,6 @@ def has_content(self) -> bool:
479479
return bool(self.content)
480480

481481

482-
@dataclass
483-
class OutputPart:
484-
"""An output response from a model."""
485-
486-
content: str
487-
"""The output content of the response as a JSON-serialized string."""
488-
489-
part_kind: Literal['output'] = 'output'
490-
"""Part type identifier, this is available on all parts as a discriminator."""
491-
492-
def has_content(self) -> bool:
493-
"""Return `True` if the output content is non-empty."""
494-
return bool(self.content)
495-
496-
497482
@dataclass
498483
class ToolCallPart:
499484
"""A tool call from a model."""
@@ -548,7 +533,7 @@ def has_content(self) -> bool:
548533
return bool(self.args)
549534

550535

551-
ModelResponsePart = Annotated[Union[TextPart, OutputPart, ToolCallPart], pydantic.Discriminator('part_kind')]
536+
ModelResponsePart = Annotated[Union[TextPart, ToolCallPart], pydantic.Discriminator('part_kind')]
552537
"""A message part returned by a model."""
553538

554539

@@ -654,33 +639,6 @@ def apply(self, part: ModelResponsePart) -> TextPart:
654639
return replace(part, content=part.content + self.content_delta)
655640

656641

657-
@dataclass
658-
class OutputPartDelta:
659-
"""A partial update (delta) for a `OutputPart` to append new structured output content."""
660-
661-
content_delta: str
662-
"""The incremental structured output content to add to the existing `OutputPart` content."""
663-
664-
part_delta_kind: Literal['output'] = 'output'
665-
"""Part delta type identifier, used as a discriminator."""
666-
667-
def apply(self, part: ModelResponsePart) -> OutputPart:
668-
"""Apply this structured output delta to an existing `OutputPart`.
669-
670-
Args:
671-
part: The existing model response part, which must be a `OutputPart`.
672-
673-
Returns:
674-
A new `OutputPart` with updated structured output content.
675-
676-
Raises:
677-
ValueError: If `part` is not a `OutputPart`.
678-
"""
679-
if not isinstance(part, OutputPart):
680-
raise ValueError('Cannot apply OutputPartDeltas to non-OutputParts')
681-
return replace(part, content=part.content + self.content_delta)
682-
683-
684642
@dataclass
685643
class ToolCallPartDelta:
686644
"""A partial update (delta) for a `ToolCallPart` to modify tool name, arguments, or tool call ID."""
@@ -798,9 +756,7 @@ def _apply_to_part(self, part: ToolCallPart) -> ToolCallPart:
798756
return part
799757

800758

801-
ModelResponsePartDelta = Annotated[
802-
Union[TextPartDelta, OutputPartDelta, ToolCallPartDelta], pydantic.Discriminator('part_delta_kind')
803-
]
759+
ModelResponsePartDelta = Annotated[Union[TextPartDelta, ToolCallPartDelta], pydantic.Discriminator('part_delta_kind')]
804760
"""A partial update (delta) for any model response part."""
805761

806762

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ class ModelRequestParameters:
266266
output_mode: OutputMode | None = None
267267
output_object: OutputObjectDefinition | None = None
268268
output_tools: list[ToolDefinition] = field(default_factory=list)
269-
allow_text_output: bool = True
269+
require_tool_use: bool = True
270270

271271

272272
class Model(ABC):

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
ModelResponse,
2222
ModelResponsePart,
2323
ModelResponseStreamEvent,
24-
OutputPart,
2524
RetryPromptPart,
2625
SystemPromptPart,
2726
TextPart,
@@ -213,7 +212,7 @@ async def _messages_create(
213212
if not tools:
214213
tool_choice = None
215214
else:
216-
if not model_request_parameters.allow_text_output:
215+
if model_request_parameters.require_tool_use:
217216
tool_choice = {'type': 'any'}
218217
else:
219218
tool_choice = {'type': 'auto'}
@@ -322,7 +321,7 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[Me
322321
elif isinstance(m, ModelResponse):
323322
assistant_content_params: list[TextBlockParam | ToolUseBlockParam] = []
324323
for response_part in m.parts:
325-
if isinstance(response_part, (TextPart, OutputPart)):
324+
if isinstance(response_part, TextPart):
326325
assistant_content_params.append(TextBlockParam(text=response_part.content, type='text'))
327326
else:
328327
tool_use_block_param = ToolUseBlockParam(

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ async def _messages_create(
305305
support_tools_choice = self.model_name.startswith(('anthropic', 'us.anthropic'))
306306
if not tools or not support_tools_choice:
307307
tool_choice: ToolChoiceTypeDef = {}
308-
elif not model_request_parameters.allow_text_output:
308+
elif model_request_parameters.require_tool_use:
309309
tool_choice = {'any': {}} # pragma: no cover
310310
else:
311311
tool_choice = {'auto': {}}

pydantic_ai_slim/pydantic_ai/models/cohere.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
ModelRequest,
1414
ModelResponse,
1515
ModelResponsePart,
16-
OutputPart,
1716
RetryPromptPart,
1817
SystemPromptPart,
1918
TextPart,
@@ -206,7 +205,7 @@ def _map_messages(self, messages: list[ModelMessage]) -> list[ChatMessageV2]:
206205
texts: list[str] = []
207206
tool_calls: list[ToolCallV2] = []
208207
for item in message.parts:
209-
if isinstance(item, (TextPart, OutputPart)):
208+
if isinstance(item, TextPart):
210209
texts.append(item.content)
211210
elif isinstance(item, ToolCallPart):
212211
tool_calls.append(self._map_tool_call(item))

pydantic_ai_slim/pydantic_ai/models/function.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
ModelRequest,
2222
ModelResponse,
2323
ModelResponseStreamEvent,
24-
OutputPart,
2524
RetryPromptPart,
2625
SystemPromptPart,
2726
TextPart,
@@ -92,7 +91,7 @@ async def request(
9291
) -> ModelResponse:
9392
agent_info = AgentInfo(
9493
model_request_parameters.function_tools,
95-
model_request_parameters.allow_text_output,
94+
not model_request_parameters.require_tool_use,
9695
model_request_parameters.output_tools,
9796
model_settings,
9897
)
@@ -121,7 +120,7 @@ async def request_stream(
121120
) -> AsyncIterator[StreamedResponse]:
122121
agent_info = AgentInfo(
123122
model_request_parameters.function_tools,
124-
model_request_parameters.allow_text_output,
123+
not model_request_parameters.require_tool_use,
125124
model_request_parameters.output_tools,
126125
model_settings,
127126
)
@@ -267,7 +266,7 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
267266
assert_never(part)
268267
elif isinstance(message, ModelResponse):
269268
for part in message.parts:
270-
if isinstance(part, (TextPart, OutputPart)):
269+
if isinstance(part, TextPart):
271270
response_tokens += _estimate_string_tokens(part.content)
272271
elif isinstance(part, ToolCallPart):
273272
call = part

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
ModelResponse,
2929
ModelResponsePart,
3030
ModelResponseStreamEvent,
31-
OutputPart,
3231
RetryPromptPart,
3332
SystemPromptPart,
3433
TextPart,
@@ -183,7 +182,7 @@ def _customize_output_object_def(o: OutputObjectDefinition):
183182
if model_request_parameters.output_object
184183
else None,
185184
output_tools=[_customize_tool_def(tool) for tool in model_request_parameters.output_tools],
186-
allow_text_output=model_request_parameters.allow_text_output,
185+
require_tool_use=model_request_parameters.require_tool_use,
187186
)
188187

189188
@property
@@ -205,7 +204,7 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _Gemin
205204
def _get_tool_config(
206205
self, model_request_parameters: ModelRequestParameters, tools: _GeminiTools | None
207206
) -> _GeminiToolConfig | None:
208-
if model_request_parameters.allow_text_output:
207+
if not model_request_parameters.require_tool_use:
209208
return None
210209
elif tools:
211210
return _tool_config([t['name'] for t in tools['function_declarations']])
@@ -563,7 +562,7 @@ def _content_model_response(m: ModelResponse) -> _GeminiContent:
563562
for item in m.parts:
564563
if isinstance(item, ToolCallPart):
565564
parts.append(_function_call_part_from_call(item))
566-
elif isinstance(item, (TextPart, OutputPart)):
565+
elif isinstance(item, TextPart):
567566
if item.content:
568567
parts.append(_GeminiTextPart(text=item.content))
569568
else:

Commit comments (0)