24
24
result ,
25
25
usage as _usage ,
26
26
)
27
- from .result import OutputDataT , ToolOutput
27
+ from .result import OutputDataT , StructuredOutput , ToolOutput
28
28
from .settings import ModelSettings , merge_model_settings
29
29
from .tools import RunContext , Tool , ToolDefinition
30
30
@@ -244,6 +244,8 @@ async def add_mcp_server_tools(server: MCPServer) -> None:
244
244
function_tools = function_tool_defs ,
245
245
allow_text_output = allow_text_output (output_schema ),
246
246
output_tools = output_schema .tool_defs () if output_schema is not None else [],
247
+ output_object = output_schema .output_object_schema .definition if output_schema is not None else None ,
248
+ preferred_output_mode = output_schema .preferred_mode if output_schema is not None else None ,
247
249
)
248
250
249
251
@@ -394,20 +396,24 @@ async def stream(
394
396
async for _event in stream :
395
397
pass
396
398
397
- async def _run_stream (
399
+ async def _run_stream ( # noqa: C901
398
400
self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]]
399
401
) -> AsyncIterator [_messages .HandleResponseEvent ]:
400
402
if self ._events_iterator is None :
401
403
# Ensure that the stream is only run once
402
404
403
405
async def _run_stream () -> AsyncIterator [_messages .HandleResponseEvent ]:
404
406
texts : list [str ] = []
407
+ structured_outputs : list [str ] = []
405
408
tool_calls : list [_messages .ToolCallPart ] = []
406
409
for part in self .model_response .parts :
407
410
if isinstance (part , _messages .TextPart ):
408
411
# ignore empty content for text parts, see #437
409
412
if part .content :
410
413
texts .append (part .content )
414
+ elif isinstance (part , _messages .StructuredOutputPart ):
415
+ if part .content :
416
+ structured_outputs .append (part .content )
411
417
elif isinstance (part , _messages .ToolCallPart ):
412
418
tool_calls .append (part )
413
419
else :
@@ -420,6 +426,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
420
426
if tool_calls :
421
427
async for event in self ._handle_tool_calls (ctx , tool_calls ):
422
428
yield event
429
+ elif structured_outputs :
430
+ # No events are emitted during the handling of structured outputs, so we don't need to yield anything
431
+ self ._next_node = await self ._handle_structured_outputs (ctx , structured_outputs )
423
432
elif texts :
424
433
# No events are emitted during the handling of text responses, so we don't need to yield anything
425
434
self ._next_node = await self ._handle_text_response (ctx , texts )
@@ -533,6 +542,27 @@ async def _handle_text_response(
533
542
)
534
543
)
535
544
545
+ async def _handle_structured_outputs (
546
+ self ,
547
+ ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]],
548
+ structured_outputs : list [str ],
549
+ ) -> ModelRequestNode [DepsT , NodeRunEndT ] | End [result .FinalResult [NodeRunEndT ]]:
550
+ if len (structured_outputs ) != 1 :
551
+ raise exceptions .UnexpectedModelBehavior ('Received multiple structured outputs in a single response' )
552
+ output_schema = ctx .deps .output_schema
553
+ if not output_schema :
554
+ raise exceptions .UnexpectedModelBehavior ('Must specify a non-str result_type when using structured outputs' )
555
+
556
+ structured_output = structured_outputs [0 ]
557
+ try :
558
+ result_data = output_schema .validate (structured_output )
559
+ result_data = await _validate_output (result_data , ctx , None )
560
+ except _output .ToolRetryError as e :
561
+ ctx .state .increment_retries (ctx .deps .max_result_retries )
562
+ return ModelRequestNode [DepsT , NodeRunEndT ](_messages .ModelRequest (parts = [e .tool_retry ]))
563
+ else :
564
+ return self ._handle_final_result (ctx , result .FinalResult (result_data , None , None ), [])
565
+
536
566
537
567
def build_run_context (ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , Any ]]) -> RunContext [DepsT ]:
538
568
"""Build a `RunContext` object from the current agent graph run context."""
@@ -827,7 +857,9 @@ def get_captured_run_messages() -> _RunMessages:
827
857
828
858
829
859
def build_agent_graph (
830
- name : str | None , deps_type : type [DepsT ], output_type : type [OutputT ] | ToolOutput [OutputT ]
860
+ name : str | None ,
861
+ deps_type : type [DepsT ],
862
+ output_type : type [OutputT ] | ToolOutput [OutputT ] | StructuredOutput [OutputT ],
831
863
) -> Graph [GraphAgentState , GraphAgentDeps [DepsT , result .FinalResult [OutputT ]], result .FinalResult [OutputT ]]:
832
864
"""Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
833
865
nodes = (
0 commit comments