Skip to content

Support structured and manual JSON output_type modes in addition to tool calls #1628

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pydantic_ai_slim/pydantic_ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)
from .format_prompt import format_as_xml
from .messages import AudioUrl, BinaryContent, DocumentUrl, ImageUrl, VideoUrl
from .result import ToolOutput
from .result import JSONSchemaOutput, ToolOutput
from .tools import RunContext, Tool

__all__ = (
Expand Down Expand Up @@ -43,6 +43,7 @@
'RunContext',
# result
'ToolOutput',
'JSONSchemaOutput',
# format_prompt
'format_as_xml',
)
Expand Down
70 changes: 41 additions & 29 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
result,
usage as _usage,
)
from .result import OutputDataT, ToolOutput
from .result import OutputDataT
from .settings import ModelSettings, merge_model_settings
from .tools import RunContext, Tool, ToolDefinition, ToolsPrepareFunc

Expand Down Expand Up @@ -249,10 +249,28 @@ async def add_mcp_server_tools(server: MCPServer) -> None:
function_tool_defs = await ctx.deps.prepare_tools(run_context, function_tool_defs) or []

output_schema = ctx.deps.output_schema
model = ctx.deps.model

output_mode = None
output_object = None
output_tools = []
require_tool_use = False
if output_schema:
output_mode = output_schema.forced_mode or model.default_output_mode
output_object = output_schema.object_schema.definition
output_tools = output_schema.tool_defs()
require_tool_use = output_mode == 'tool' and not output_schema.allow_plain_text_output

supported_modes = model.supported_output_modes
if output_mode not in supported_modes:
raise exceptions.UserError(f"Output mode '{output_mode}' is not among supported modes: {supported_modes}")

return models.ModelRequestParameters(
function_tools=function_tool_defs,
allow_text_output=allow_text_output(output_schema),
output_tools=output_schema.tool_defs() if output_schema is not None else [],
output_mode=output_mode,
output_object=output_object,
output_tools=output_tools,
require_tool_use=require_tool_use,
)


Expand Down Expand Up @@ -437,7 +455,7 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
# when the model has already returned text alongside tool calls
# in this scenario, if text responses are allowed, we return text from the most recent model
# response, if any
if allow_text_output(ctx.deps.output_schema):
if _output.allow_text_output(ctx.deps.output_schema):
for message in reversed(ctx.state.message_history):
if isinstance(message, _messages.ModelResponse):
last_texts = [p.content for p in message.parts if isinstance(p, _messages.TextPart)]
Expand Down Expand Up @@ -520,27 +538,24 @@ async def _handle_text_response(
output_schema = ctx.deps.output_schema

text = '\n\n'.join(texts)
if allow_text_output(output_schema):
# The following cast is safe because we know `str` is an allowed result type
result_data_input = cast(NodeRunEndT, text)
try:
result_data = await _validate_output(result_data_input, ctx, None)
except _output.ToolRetryError as e:
ctx.state.increment_retries(ctx.deps.max_result_retries)
return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
try:
if output_schema is None or output_schema.allow_plain_text_output:
# The following cast is safe because we know `str` is an allowed result type
result_data = cast(NodeRunEndT, text)
elif output_schema.allow_json_text_output:
result_data = output_schema.validate(text)
else:
return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), [])
else:
ctx.state.increment_retries(ctx.deps.max_result_retries)
return ModelRequestNode[DepsT, NodeRunEndT](
_messages.ModelRequest(
parts=[
_messages.RetryPromptPart(
content='Plain text responses are not permitted, please include your response in a tool call',
)
]
m = _messages.RetryPromptPart(
content='Plain text responses are not permitted, please include your response in a tool call',
)
)
raise _output.ToolRetryError(m)

result_data = await _validate_output(result_data, ctx, None)
except _output.ToolRetryError as e:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unfortunate exception name now

Copy link
Contributor

@dmontagu dmontagu May 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we use ModelRetry directly here? Maybe that's hard/annoying if things are set up to convert those into ToolRetryError (because it was assumed to be handled in a tool). I don't remember the details..

ctx.state.increment_retries(ctx.deps.max_result_retries)
return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
else:
return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), [])


def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
Expand Down Expand Up @@ -782,11 +797,6 @@ async def _validate_output(
return result_data


def allow_text_output(output_schema: _output.OutputSchema[Any] | None) -> bool:
    """Return whether plain-text model responses are acceptable.

    Text output is allowed when no output schema was configured at all,
    or when the configured schema explicitly permits text results.
    """
    if output_schema is None:
        return True
    return output_schema.allow_text_output


@dataclasses.dataclass
class _RunMessages:
messages: list[_messages.ModelMessage]
Expand Down Expand Up @@ -836,7 +846,9 @@ def get_captured_run_messages() -> _RunMessages:


def build_agent_graph(
name: str | None, deps_type: type[DepsT], output_type: type[OutputT] | ToolOutput[OutputT]
name: str | None,
deps_type: type[DepsT],
output_type: _output.OutputType[OutputT],
) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[OutputT]], result.FinalResult[OutputT]]:
"""Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
nodes = (
Expand Down
Loading
Loading