Skip to content

Commit f74cf39

Browse files
committed
WIP: Support structured and manual JSON output_type modes in addition to tool calls
1 parent 54fb56a commit f74cf39

21 files changed

+1033
-109
lines changed

pydantic_ai_slim/pydantic_ai/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
)
1313
from .format_prompt import format_as_xml
1414
from .messages import AudioUrl, BinaryContent, DocumentUrl, ImageUrl, VideoUrl
15-
from .result import ToolOutput
15+
from .result import StructuredOutput, ToolOutput
1616
from .tools import RunContext, Tool
1717

1818
__all__ = (
@@ -43,6 +43,7 @@
4343
'RunContext',
4444
# result
4545
'ToolOutput',
46+
'StructuredOutput',
4647
# format_prompt
4748
'format_as_xml',
4849
)

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
result,
2525
usage as _usage,
2626
)
27-
from .result import OutputDataT, ToolOutput
27+
from .result import OutputDataT, StructuredOutput, ToolOutput
2828
from .settings import ModelSettings, merge_model_settings
2929
from .tools import RunContext, Tool, ToolDefinition
3030

@@ -125,9 +125,6 @@ def is_agent_node(
125125
class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
126126
user_prompt: str | Sequence[_messages.UserContent] | None
127127

128-
instructions: str | None
129-
instructions_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
130-
131128
system_prompts: tuple[str, ...]
132129
system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
133130
system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]
@@ -244,6 +241,8 @@ async def add_mcp_server_tools(server: MCPServer) -> None:
244241
function_tools=function_tool_defs,
245242
allow_text_output=allow_text_output(output_schema),
246243
output_tools=output_schema.tool_defs() if output_schema is not None else [],
244+
output_object=output_schema.output_object_schema.definition if output_schema is not None else None,
245+
preferred_output_mode=output_schema.preferred_mode if output_schema is not None else None,
247246
)
248247

249248

@@ -396,20 +395,24 @@ async def stream(
396395
async for _event in stream:
397396
pass
398397

399-
async def _run_stream(
398+
async def _run_stream( # noqa: C901
400399
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
401400
) -> AsyncIterator[_messages.HandleResponseEvent]:
402401
if self._events_iterator is None:
403402
# Ensure that the stream is only run once
404403

405404
async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
406405
texts: list[str] = []
406+
structured_outputs: list[str] = []
407407
tool_calls: list[_messages.ToolCallPart] = []
408408
for part in self.model_response.parts:
409409
if isinstance(part, _messages.TextPart):
410410
# ignore empty content for text parts, see #437
411411
if part.content:
412412
texts.append(part.content)
413+
elif isinstance(part, _messages.StructuredOutputPart):
414+
if part.content:
415+
structured_outputs.append(part.content)
413416
elif isinstance(part, _messages.ToolCallPart):
414417
tool_calls.append(part)
415418
else:
@@ -422,6 +425,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
422425
if tool_calls:
423426
async for event in self._handle_tool_calls(ctx, tool_calls):
424427
yield event
428+
elif structured_outputs:
429+
# No events are emitted during the handling of structured outputs, so we don't need to yield anything
430+
self._next_node = await self._handle_structured_outputs(ctx, structured_outputs)
425431
elif texts:
426432
# No events are emitted during the handling of text responses, so we don't need to yield anything
427433
self._next_node = await self._handle_text_response(ctx, texts)
@@ -535,6 +541,27 @@ async def _handle_text_response(
535541
)
536542
)
537543

544+
async def _handle_structured_outputs(
545+
self,
546+
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
547+
structured_outputs: list[str],
548+
) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
549+
if len(structured_outputs) != 1:
550+
raise exceptions.UnexpectedModelBehavior('Received multiple structured outputs in a single response')
551+
output_schema = ctx.deps.output_schema
552+
if not output_schema:
553+
raise exceptions.UnexpectedModelBehavior('Must specify a non-str result_type when using structured outputs')
554+
555+
structured_output = structured_outputs[0]
556+
try:
557+
result_data = output_schema.validate(structured_output)
558+
result_data = await _validate_output(result_data, ctx, None)
559+
except _output.ToolRetryError as e:
560+
ctx.state.increment_retries(ctx.deps.max_result_retries)
561+
return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
562+
else:
563+
return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), [])
564+
538565

539566
def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
540567
"""Build a `RunContext` object from the current agent graph run context."""
@@ -829,7 +856,9 @@ def get_captured_run_messages() -> _RunMessages:
829856

830857

831858
def build_agent_graph(
832-
name: str | None, deps_type: type[DepsT], output_type: type[OutputT] | ToolOutput[OutputT]
859+
name: str | None,
860+
deps_type: type[DepsT],
861+
output_type: type[OutputT] | ToolOutput[OutputT] | StructuredOutput[OutputT],
833862
) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[OutputT]], result.FinalResult[OutputT]]:
834863
"""Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
835864
nodes = (

0 commit comments

Comments
 (0)