24
24
result ,
25
25
usage as _usage ,
26
26
)
27
- from .result import OutputDataT , ToolOutput
27
+ from .result import OutputDataT , StructuredOutput , ToolOutput
28
28
from .settings import ModelSettings , merge_model_settings
29
29
from .tools import RunContext , Tool , ToolDefinition
30
30
@@ -125,9 +125,6 @@ def is_agent_node(
125
125
class UserPromptNode (AgentNode [DepsT , NodeRunEndT ]):
126
126
user_prompt : str | Sequence [_messages .UserContent ] | None
127
127
128
- instructions : str | None
129
- instructions_functions : list [_system_prompt .SystemPromptRunner [DepsT ]]
130
-
131
128
system_prompts : tuple [str , ...]
132
129
system_prompt_functions : list [_system_prompt .SystemPromptRunner [DepsT ]]
133
130
system_prompt_dynamic_functions : dict [str , _system_prompt .SystemPromptRunner [DepsT ]]
@@ -244,6 +241,8 @@ async def add_mcp_server_tools(server: MCPServer) -> None:
244
241
function_tools = function_tool_defs ,
245
242
allow_text_output = allow_text_output (output_schema ),
246
243
output_tools = output_schema .tool_defs () if output_schema is not None else [],
244
+ output_object = output_schema .output_object_schema .definition if output_schema is not None else None ,
245
+ preferred_output_mode = output_schema .preferred_mode if output_schema is not None else None ,
247
246
)
248
247
249
248
@@ -396,20 +395,24 @@ async def stream(
396
395
async for _event in stream :
397
396
pass
398
397
399
- async def _run_stream (
398
+ async def _run_stream ( # noqa: C901
400
399
self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]]
401
400
) -> AsyncIterator [_messages .HandleResponseEvent ]:
402
401
if self ._events_iterator is None :
403
402
# Ensure that the stream is only run once
404
403
405
404
async def _run_stream () -> AsyncIterator [_messages .HandleResponseEvent ]:
406
405
texts : list [str ] = []
406
+ structured_outputs : list [str ] = []
407
407
tool_calls : list [_messages .ToolCallPart ] = []
408
408
for part in self .model_response .parts :
409
409
if isinstance (part , _messages .TextPart ):
410
410
# ignore empty content for text parts, see #437
411
411
if part .content :
412
412
texts .append (part .content )
413
+ elif isinstance (part , _messages .StructuredOutputPart ):
414
+ if part .content :
415
+ structured_outputs .append (part .content )
413
416
elif isinstance (part , _messages .ToolCallPart ):
414
417
tool_calls .append (part )
415
418
else :
@@ -422,6 +425,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
422
425
if tool_calls :
423
426
async for event in self ._handle_tool_calls (ctx , tool_calls ):
424
427
yield event
428
+ elif structured_outputs :
429
+ # No events are emitted during the handling of structured outputs, so we don't need to yield anything
430
+ self ._next_node = await self ._handle_structured_outputs (ctx , structured_outputs )
425
431
elif texts :
426
432
# No events are emitted during the handling of text responses, so we don't need to yield anything
427
433
self ._next_node = await self ._handle_text_response (ctx , texts )
@@ -535,6 +541,27 @@ async def _handle_text_response(
535
541
)
536
542
)
537
543
544
+ async def _handle_structured_outputs (
545
+ self ,
546
+ ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]],
547
+ structured_outputs : list [str ],
548
+ ) -> ModelRequestNode [DepsT , NodeRunEndT ] | End [result .FinalResult [NodeRunEndT ]]:
549
+ if len (structured_outputs ) != 1 :
550
+ raise exceptions .UnexpectedModelBehavior ('Received multiple structured outputs in a single response' )
551
+ output_schema = ctx .deps .output_schema
552
+ if not output_schema :
553
+ raise exceptions .UnexpectedModelBehavior ('Must specify a non-str result_type when using structured outputs' )
554
+
555
+ structured_output = structured_outputs [0 ]
556
+ try :
557
+ result_data = output_schema .validate (structured_output )
558
+ result_data = await _validate_output (result_data , ctx , None )
559
+ except _output .ToolRetryError as e :
560
+ ctx .state .increment_retries (ctx .deps .max_result_retries )
561
+ return ModelRequestNode [DepsT , NodeRunEndT ](_messages .ModelRequest (parts = [e .tool_retry ]))
562
+ else :
563
+ return self ._handle_final_result (ctx , result .FinalResult (result_data , None , None ), [])
564
+
538
565
539
566
def build_run_context (ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , Any ]]) -> RunContext [DepsT ]:
540
567
"""Build a `RunContext` object from the current agent graph run context."""
@@ -829,7 +856,9 @@ def get_captured_run_messages() -> _RunMessages:
829
856
830
857
831
858
def build_agent_graph (
832
- name : str | None , deps_type : type [DepsT ], output_type : type [OutputT ] | ToolOutput [OutputT ]
859
+ name : str | None ,
860
+ deps_type : type [DepsT ],
861
+ output_type : type [OutputT ] | ToolOutput [OutputT ] | StructuredOutput [OutputT ],
833
862
) -> Graph [GraphAgentState , GraphAgentDeps [DepsT , result .FinalResult [OutputT ]], result .FinalResult [OutputT ]]:
834
863
"""Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
835
864
nodes = (
0 commit comments