Skip to content

Commit f305aa5

Browse files
committed
WIP: Support structured and manual JSON output_type modes in addition to tool calls
1 parent c0afc0d commit f305aa5

25 files changed

+1181
-189
lines changed

pydantic_ai_slim/pydantic_ai/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
)
1313
from .format_prompt import format_as_xml
1414
from .messages import AudioUrl, BinaryContent, DocumentUrl, ImageUrl, VideoUrl
15-
from .result import ToolOutput
15+
from .result import StructuredOutput, ToolOutput
1616
from .tools import RunContext, Tool
1717

1818
__all__ = (
@@ -43,6 +43,7 @@
4343
'RunContext',
4444
# result
4545
'ToolOutput',
46+
'StructuredOutput',
4647
# format_prompt
4748
'format_as_xml',
4849
)

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
result,
2525
usage as _usage,
2626
)
27-
from .result import OutputDataT, ToolOutput
27+
from .result import OutputDataT, StructuredOutput, ToolOutput
2828
from .settings import ModelSettings, merge_model_settings
2929
from .tools import RunContext, Tool, ToolDefinition
3030

@@ -244,6 +244,8 @@ async def add_mcp_server_tools(server: MCPServer) -> None:
244244
function_tools=function_tool_defs,
245245
allow_text_output=allow_text_output(output_schema),
246246
output_tools=output_schema.tool_defs() if output_schema is not None else [],
247+
output_object=output_schema.output_object_schema.definition if output_schema is not None else None,
248+
preferred_output_mode=output_schema.preferred_mode if output_schema is not None else None,
247249
)
248250

249251

@@ -394,20 +396,24 @@ async def stream(
394396
async for _event in stream:
395397
pass
396398

397-
async def _run_stream(
399+
async def _run_stream( # noqa: C901
398400
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
399401
) -> AsyncIterator[_messages.HandleResponseEvent]:
400402
if self._events_iterator is None:
401403
# Ensure that the stream is only run once
402404

403405
async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
404406
texts: list[str] = []
407+
structured_outputs: list[str] = []
405408
tool_calls: list[_messages.ToolCallPart] = []
406409
for part in self.model_response.parts:
407410
if isinstance(part, _messages.TextPart):
408411
# ignore empty content for text parts, see #437
409412
if part.content:
410413
texts.append(part.content)
414+
elif isinstance(part, _messages.StructuredOutputPart):
415+
if part.content:
416+
structured_outputs.append(part.content)
411417
elif isinstance(part, _messages.ToolCallPart):
412418
tool_calls.append(part)
413419
else:
@@ -420,6 +426,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
420426
if tool_calls:
421427
async for event in self._handle_tool_calls(ctx, tool_calls):
422428
yield event
429+
elif structured_outputs:
430+
# No events are emitted during the handling of structured outputs, so we don't need to yield anything
431+
self._next_node = await self._handle_structured_outputs(ctx, structured_outputs)
423432
elif texts:
424433
# No events are emitted during the handling of text responses, so we don't need to yield anything
425434
self._next_node = await self._handle_text_response(ctx, texts)
@@ -533,6 +542,27 @@ async def _handle_text_response(
533542
)
534543
)
535544

545+
async def _handle_structured_outputs(
546+
self,
547+
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
548+
structured_outputs: list[str],
549+
) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
550+
if len(structured_outputs) != 1:
551+
raise exceptions.UnexpectedModelBehavior('Received multiple structured outputs in a single response')
552+
output_schema = ctx.deps.output_schema
553+
if not output_schema:
554+
raise exceptions.UnexpectedModelBehavior('Must specify a non-str result_type when using structured outputs')
555+
556+
structured_output = structured_outputs[0]
557+
try:
558+
result_data = output_schema.validate(structured_output)
559+
result_data = await _validate_output(result_data, ctx, None)
560+
except _output.ToolRetryError as e:
561+
ctx.state.increment_retries(ctx.deps.max_result_retries)
562+
return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
563+
else:
564+
return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), [])
565+
536566

537567
def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
538568
"""Build a `RunContext` object from the current agent graph run context."""
@@ -827,7 +857,9 @@ def get_captured_run_messages() -> _RunMessages:
827857

828858

829859
def build_agent_graph(
830-
name: str | None, deps_type: type[DepsT], output_type: type[OutputT] | ToolOutput[OutputT]
860+
name: str | None,
861+
deps_type: type[DepsT],
862+
output_type: type[OutputT] | ToolOutput[OutputT] | StructuredOutput[OutputT],
831863
) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[OutputT]], result.FinalResult[OutputT]]:
832864
"""Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
833865
nodes = (

0 commit comments

Comments
 (0)