Skip to content

Commit e334bba

Browse files
committed
WIP: Support structured and manual JSON output_type modes in addition to tool calls
1 parent 54fb56a commit e334bba

21 files changed

+936
-70
lines changed

pydantic_ai_slim/pydantic_ai/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
)
1313
from .format_prompt import format_as_xml
1414
from .messages import AudioUrl, BinaryContent, DocumentUrl, ImageUrl, VideoUrl
15-
from .result import ToolOutput
15+
from .result import StructuredOutput, ToolOutput
1616
from .tools import RunContext, Tool
1717

1818
__all__ = (
@@ -43,6 +43,7 @@
4343
'RunContext',
4444
# result
4545
'ToolOutput',
46+
'StructuredOutput',
4647
# format_prompt
4748
'format_as_xml',
4849
)

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
result,
2525
usage as _usage,
2626
)
27-
from .result import OutputDataT, ToolOutput
27+
from .result import OutputDataT, StructuredOutput, ToolOutput
2828
from .settings import ModelSettings, merge_model_settings
2929
from .tools import RunContext, Tool, ToolDefinition
3030

@@ -125,9 +125,6 @@ def is_agent_node(
125125
class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
126126
user_prompt: str | Sequence[_messages.UserContent] | None
127127

128-
instructions: str | None
129-
instructions_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
130-
131128
system_prompts: tuple[str, ...]
132129
system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
133130
system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]
@@ -244,6 +241,8 @@ async def add_mcp_server_tools(server: MCPServer) -> None:
244241
function_tools=function_tool_defs,
245242
allow_text_output=allow_text_output(output_schema),
246243
output_tools=output_schema.tool_defs() if output_schema is not None else [],
244+
output_schema=output_schema.json_schema if output_schema is not None else None,
245+
preferred_output_mode=output_schema.preferred_mode if output_schema is not None else None,
247246
)
248247

249248

@@ -396,20 +395,24 @@ async def stream(
396395
async for _event in stream:
397396
pass
398397

399-
async def _run_stream(
398+
async def _run_stream( # noqa: C901
400399
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
401400
) -> AsyncIterator[_messages.HandleResponseEvent]:
402401
if self._events_iterator is None:
403402
# Ensure that the stream is only run once
404403

405404
async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
406405
texts: list[str] = []
406+
structured_outputs: list[str] = []
407407
tool_calls: list[_messages.ToolCallPart] = []
408408
for part in self.model_response.parts:
409409
if isinstance(part, _messages.TextPart):
410410
# ignore empty content for text parts, see #437
411411
if part.content:
412412
texts.append(part.content)
413+
elif isinstance(part, _messages.StructuredOutputPart):
414+
if part.content:
415+
structured_outputs.append(part.content)
413416
elif isinstance(part, _messages.ToolCallPart):
414417
tool_calls.append(part)
415418
else:
@@ -422,6 +425,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
422425
if tool_calls:
423426
async for event in self._handle_tool_calls(ctx, tool_calls):
424427
yield event
428+
elif structured_outputs:
429+
# No events are emitted during the handling of structured outputs, so we don't need to yield anything
430+
self._next_node = await self._handle_structured_outputs(ctx, structured_outputs)
425431
elif texts:
426432
# No events are emitted during the handling of text responses, so we don't need to yield anything
427433
self._next_node = await self._handle_text_response(ctx, texts)
@@ -535,6 +541,27 @@ async def _handle_text_response(
535541
)
536542
)
537543

544+
async def _handle_structured_outputs(
545+
self,
546+
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
547+
structured_outputs: list[str],
548+
) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
549+
if len(structured_outputs) != 1:
550+
raise exceptions.UnexpectedModelBehavior('Received multiple structured outputs in a single response')
551+
output_schema = ctx.deps.output_schema
552+
if not output_schema:
553+
raise exceptions.UnexpectedModelBehavior('Must specify a non-str result_type when using structured outputs')
554+
555+
structured_output = structured_outputs[0]
556+
try:
557+
result_data = output_schema.validate(structured_output)
558+
result_data = await _validate_output(result_data, ctx, None)
559+
except _output.ToolRetryError as e:
560+
ctx.state.increment_retries(ctx.deps.max_result_retries)
561+
return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
562+
else:
563+
return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), [])
564+
538565

539566
def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
540567
"""Build a `RunContext` object from the current agent graph run context."""
@@ -829,7 +856,9 @@ def get_captured_run_messages() -> _RunMessages:
829856

830857

831858
def build_agent_graph(
832-
name: str | None, deps_type: type[DepsT], output_type: type[OutputT] | ToolOutput[OutputT]
859+
name: str | None,
860+
deps_type: type[DepsT],
861+
output_type: type[OutputT] | ToolOutput[OutputT] | StructuredOutput[OutputT],
833862
) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[OutputT]], result.FinalResult[OutputT]]:
834863
"""Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
835864
nodes = (

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 71 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,15 @@
1212

1313
from . import _utils, messages as _messages
1414
from .exceptions import ModelRetry
15-
from .result import DEFAULT_OUTPUT_TOOL_NAME, OutputDataT, OutputDataT_inv, OutputValidatorFunc, ToolOutput
16-
from .tools import AgentDepsT, GenerateToolJsonSchema, RunContext, ToolDefinition
15+
from .result import (
16+
DEFAULT_OUTPUT_TOOL_NAME,
17+
OutputDataT,
18+
OutputDataT_inv,
19+
OutputValidatorFunc,
20+
StructuredOutput,
21+
ToolOutput,
22+
)
23+
from .tools import AgentDepsT, GenerateToolJsonSchema, ObjectJsonSchema, RunContext, ToolDefinition
1724

1825
T = TypeVar('T')
1926
"""An invariant TypeVar."""
@@ -83,13 +90,18 @@ class OutputSchema(Generic[OutputDataT]):
8390
Similar to `Tool` but for the final output of running an agent.
8491
"""
8592

93+
# TODO: Since this is currently called "preferred", models that don't have structured output implemented yet ignore it and use tools (except for Mistral).
94+
# We should likely raise an error if an unsupported mode is used, _and_ allow the model to pick its own preferred mode if none is forced.
95+
preferred_mode: Literal['tool', 'structured'] | None # TODO: Add mode for manual JSON
96+
type_adapter: TypeAdapter[OutputDataT]
8697
tools: dict[str, OutputSchemaTool[OutputDataT]]
87-
allow_text_output: bool
98+
allow_text_output: bool # TODO: Verify structured output works correctly with string as a union member
99+
json_schema: ObjectJsonSchema # TODO: Verify structured output works correctly with a union
88100

89101
@classmethod
90102
def build(
91103
cls: type[OutputSchema[T]],
92-
output_type: type[T] | ToolOutput[T],
104+
output_type: type[T] | ToolOutput[T] | StructuredOutput[T], # TODO: Support a list of output types/markers
93105
name: str | None = None,
94106
description: str | None = None,
95107
strict: bool | None = None,
@@ -98,15 +110,34 @@ def build(
98110
if output_type is str:
99111
return None
100112

113+
preferred_mode = None
101114
if isinstance(output_type, ToolOutput):
102115
# do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
103116
name = output_type.name
104117
description = output_type.description
105118
output_type_ = output_type.output_type
106119
strict = output_type.strict
120+
preferred_mode = 'tool'
121+
elif isinstance(output_type, StructuredOutput):
122+
name = output_type.name # TODO: Get this to the response_format model request arg
123+
description = output_type.description # TODO: Get this to the response_format model request arg
124+
output_type_ = output_type.output_type
125+
strict = output_type.strict # TODO: Get this to the response_format model request arg
126+
preferred_mode = 'structured'
107127
else:
108128
output_type_ = output_type
109129

130+
type_adapter = cast(TypeAdapter[T], TypeAdapter(output_type_))
131+
json_schema = _utils.check_object_json_schema(type_adapter.json_schema(schema_generator=GenerateToolJsonSchema))
132+
133+
# TODO: Make this description available to the model params
134+
if json_schema_description := json_schema.pop('description', None):
135+
if description is None:
136+
description = json_schema_description
137+
else:
138+
description = f'{description}. {json_schema_description}'
139+
140+
# No need to include an output tool for string output
110141
if output_type_option := extract_str_from_union(output_type):
111142
output_type_ = output_type_option.value
112143
allow_text_output = True
@@ -134,7 +165,13 @@ def build(
134165
),
135166
)
136167

137-
return cls(tools=tools, allow_text_output=allow_text_output)
168+
return cls(
169+
preferred_mode=preferred_mode,
170+
tools=tools,
171+
allow_text_output=allow_text_output,
172+
type_adapter=type_adapter,
173+
json_schema=json_schema,
174+
)
138175

139176
def find_named_tool(
140177
self, parts: Iterable[_messages.ModelResponsePart], tool_name: str
@@ -163,6 +200,35 @@ def tool_defs(self) -> list[ToolDefinition]:
163200
"""Get tool definitions to register with the model."""
164201
return [t.tool_def for t in self.tools.values()]
165202

203+
def validate(
204+
self, output_text: str, allow_partial: bool = False, wrap_validation_errors: bool = True
205+
) -> OutputDataT:
206+
"""Validate a structured output message.
207+
208+
Args:
209+
output_text: The structured output from the LLM to validate.
210+
allow_partial: If true, allow partial validation.
211+
wrap_validation_errors: If true, wrap the validation errors in a retry message.
212+
213+
Returns:
214+
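The validated output data. On validation failure, raises `ToolRetryError` wrapping a retry prompt if `wrap_validation_errors` is true; otherwise re-raises the original `ValidationError`.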
215+
"""
216+
try:
217+
pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
218+
output = self.type_adapter.validate_json(output_text, experimental_allow_partial=pyd_allow_partial)
219+
except ValidationError as e:
220+
if wrap_validation_errors:
221+
m = _messages.RetryPromptPart(
222+
content=e.errors(include_url=False),
223+
)
224+
raise ToolRetryError(m) from e
225+
else:
226+
raise
227+
else:
228+
return output
229+
230+
# TODO: Build instructions for manual JSON
231+
166232

167233
DEFAULT_DESCRIPTION = 'The final response which ends this conversation'
168234

pydantic_ai_slim/pydantic_ai/_parts_manager.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
ModelResponseStreamEvent,
2424
PartDeltaEvent,
2525
PartStartEvent,
26+
StructuredOutputPartDelta,
2627
TextPart,
2728
TextPartDelta,
2829
ToolCallPart,
@@ -57,12 +58,12 @@ class ModelResponsePartsManager:
5758
"""Maps a vendor's "part" ID (if provided) to the index in `_parts` where that part resides."""
5859

5960
def get_parts(self) -> list[ModelResponsePart]:
60-
"""Return only model response parts that are complete (i.e., not ToolCallPartDelta's).
61+
"""Return only model response parts that are complete (i.e., not ToolCallPartDelta or StructuredOutputPartDelta objects).
6162
6263
Returns:
63-
A list of ModelResponsePart objects. ToolCallPartDelta objects are excluded.
64+
A list of ModelResponsePart objects. ToolCallPartDelta and StructuredOutputPartDelta objects are excluded.
6465
"""
65-
return [p for p in self._parts if not isinstance(p, ToolCallPartDelta)]
66+
return [p for p in self._parts if not isinstance(p, (ToolCallPartDelta, StructuredOutputPartDelta))]
6667

6768
def handle_text_delta(
6869
self,
@@ -91,6 +92,8 @@ def handle_text_delta(
9192
"""
9293
existing_text_part_and_index: tuple[TextPart, int] | None = None
9394

95+
# TODO: Parse out structured output or manual JSON, with a separate message?
96+
9497
if vendor_part_id is None:
9598
# If the vendor_part_id is None, check if the latest part is a TextPart to update
9699
if self._parts:

0 commit comments

Comments
 (0)