Skip to content

Commit 9c64d68

Browse files
committed
WIP: Add support for the use of response_format to force a particular json schema for the response
1 parent 9b4de86 commit 9c64d68

File tree

13 files changed

+235
-42
lines changed

13 files changed

+235
-42
lines changed

examples/pydantic_ai_examples/flight_booking.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,9 @@ class Failed(BaseModel):
105105

106106

107107
# This agent is responsible for extracting the user's seat selection
108-
seat_preference_agent = Agent[None, SeatPreference | Failed](
108+
seat_preference_agent = Agent[
109+
None, SeatPreference | Failed
110+
](
109111
'openai:gpt-4o',
110112
result_type=SeatPreference | Failed, # type: ignore
111113
system_prompt=(

examples/pydantic_ai_examples/rag.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,9 @@ async def retrieve(context: RunContext[Deps], search_query: str) -> str:
6767
model='text-embedding-3-small',
6868
)
6969

70-
assert len(embedding.data) == 1, (
71-
f'Expected 1 embedding, got {len(embedding.data)}, doc query: {search_query!r}'
72-
)
70+
assert (
71+
len(embedding.data) == 1
72+
), f'Expected 1 embedding, got {len(embedding.data)}, doc query: {search_query!r}'
7373
embedding = embedding.data[0].embedding
7474
embedding_json = pydantic_core.to_json(embedding).decode()
7575
rows = await context.deps.pool.fetch(
@@ -149,9 +149,9 @@ async def insert_doc_section(
149149
input=section.embedding_content(),
150150
model='text-embedding-3-small',
151151
)
152-
assert len(embedding.data) == 1, (
153-
f'Expected 1 embedding, got {len(embedding.data)}, doc section: {section}'
154-
)
152+
assert (
153+
len(embedding.data) == 1
154+
), f'Expected 1 embedding, got {len(embedding.data)}, doc section: {section}'
155155
embedding = embedding.data[0].embedding
156156
embedding_json = pydantic_core.to_json(embedding).decode()
157157
await pool.execute(

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typing import Any, Generic, Literal, Union, cast
1111

1212
import logfire_api
13+
from pydantic import ValidationError
1314
from typing_extensions import TypeVar, assert_never
1415

1516
from pydantic_graph import BaseNode, Graph, GraphRunContext
@@ -370,12 +371,16 @@ async def _run_stream(
370371

371372
async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
372373
texts: list[str] = []
374+
structured_outputs: list[str] = []
373375
tool_calls: list[_messages.ToolCallPart] = []
374376
for part in self.model_response.parts:
375377
if isinstance(part, _messages.TextPart):
376378
# ignore empty content for text parts, see #437
377379
if part.content:
378380
texts.append(part.content)
381+
elif isinstance(part, _messages.StructuredOutputPart):
382+
if part.content:
383+
structured_outputs.append(part.content)
379384
elif isinstance(part, _messages.ToolCallPart):
380385
tool_calls.append(part)
381386
else:
@@ -391,6 +396,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
391396
elif texts:
392397
# No events are emitted during the handling of text responses, so we don't need to yield anything
393398
self._next_node = await self._handle_text_response(ctx, texts)
399+
elif structured_outputs:
400+
# No events are emitted during the handling of structured output responses, so we don't need to yield anything
401+
self._next_node = await self._handle_structured_outputs_response(ctx, texts)
394402
else:
395403
raise exceptions.UnexpectedModelBehavior('Received empty model response')
396404

@@ -487,6 +495,41 @@ async def _handle_text_response(
487495
)
488496
)
489497

498+
async def _handle_structured_outputs_response(
    self,
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
    structured_outputs: list[str],
) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
    """Validate a structured-output response and pick the node to run next.

    Args:
        ctx: The graph run context. Its deps supply `result_schema` and
            `max_result_retries`; its state is asked to increment the retry count on failure.
        structured_outputs: The JSON strings collected from the structured-output
            parts of the model response.

    Returns:
        Whatever `_handle_final_result` returns when both validation steps succeed,
        otherwise a `ModelRequestNode` carrying a retry prompt.

    Raises:
        UnexpectedModelBehavior: If the response does not hold exactly one structured
            output, or if no result schema is configured.
    """
    if len(structured_outputs) != 1:
        raise exceptions.UnexpectedModelBehavior('Received multiple structured outputs in a single response')

    schema = ctx.deps.result_schema
    if not schema:
        raise exceptions.UnexpectedModelBehavior('Must specify a non-str result_type when using structured outputs')

    (raw_json,) = structured_outputs

    # Step 1: parse the raw JSON against the result schema.
    try:
        parsed = schema.structured_output_validator.validate_json(raw_json)
    except ValidationError as e:
        # Count the retry before building the prompt, so the retry-limit check runs first.
        ctx.state.increment_retries(ctx.deps.max_result_retries)
        retry_part = _messages.RetryPromptPart(
            content='Structured output validation failed: ' + str(e),
        )
        return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[retry_part]))

    # Step 2: run the shared result validation; no tool call is associated, hence `None`.
    try:
        validated = await _validate_result(parsed, ctx, None)
    except _result.ToolRetryError as e:
        ctx.state.increment_retries(ctx.deps.max_result_retries)
        retry_request = _messages.ModelRequest(parts=[e.tool_retry])
        return ModelRequestNode[DepsT, NodeRunEndT](retry_request)

    # Kept outside the `try` so errors from final-result handling are never treated as retries.
    final = result.FinalResult(validated, tool_name=None)
    return self._handle_final_result(ctx, final, [])
532+
490533

491534
def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
492535
"""Build a `RunContext` object from the current agent graph run context."""

pydantic_ai_slim/pydantic_ai/_result.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ class ResultSchema(Generic[ResultDataT]):
8383
Similar to `Tool` but for the final result of running an agent.
8484
"""
8585

86+
structured_output_validator: TypeAdapter[ResultDataT]
8687
tools: dict[str, ResultTool[ResultDataT]]
8788
allow_text_result: bool
8889

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,22 @@ def has_content(self) -> bool:
217217
return bool(self.args)
218218

219219

220-
ModelResponsePart = Annotated[Union[TextPart, ToolCallPart], pydantic.Discriminator('part_kind')]
220+
@dataclass
class StructuredOutputPart:
    """Model response part that carries structured output."""

    content: str
    """JSON-serialized string holding the structured content of the response."""

    part_kind: Literal['structured-output'] = 'structured-output'
    """Discriminator naming the part type; every part exposes one."""

    def has_content(self) -> bool:
        """Report whether the structured content is non-empty."""
        return True if self.content else False
233+
234+
235+
ModelResponsePart = Annotated[Union[TextPart, ToolCallPart, StructuredOutputPart], pydantic.Discriminator('part_kind')]
221236
"""A message part returned by a model."""
222237

223238

@@ -275,6 +290,33 @@ def apply(self, part: ModelResponsePart) -> TextPart:
275290
return replace(part, content=part.content + self.content_delta)
276291

277292

293+
@dataclass
class StructuredOutputPartDelta:
    """A partial update (delta) for a `StructuredOutputPart` to append new structured content."""

    content_delta: str
    """The incremental content to add to the existing `StructuredOutputPart` content."""

    part_delta_kind: Literal['structured-output'] = 'structured-output'
    """Part delta type identifier, used as a discriminator."""

    def apply(self, part: ModelResponsePart) -> StructuredOutputPart:
        """Apply this delta to an existing `StructuredOutputPart`.

        Args:
            part: The existing model response part, which must be a `StructuredOutputPart`.

        Returns:
            A new `StructuredOutputPart` whose content is the existing content followed by `content_delta`.

        Raises:
            ValueError: If `part` is not a `StructuredOutputPart`.
        """
        if not isinstance(part, StructuredOutputPart):
            # Name the actual types involved; the previous message was copied from `TextPartDelta`.
            raise ValueError('Cannot apply StructuredOutputPartDeltas to non-StructuredOutputParts')
        return replace(part, content=part.content + self.content_delta)
318+
319+
278320
@dataclass
279321
class ToolCallPartDelta:
280322
"""A partial update (delta) for a `ToolCallPart` to modify tool name, arguments, or tool call ID."""
@@ -408,7 +450,9 @@ def _apply_to_part(self, part: ToolCallPart) -> ToolCallPart:
408450
return part
409451

410452

411-
ModelResponsePartDelta = Annotated[Union[TextPartDelta, ToolCallPartDelta], pydantic.Discriminator('part_delta_kind')]
453+
ModelResponsePartDelta = Annotated[
454+
Union[TextPartDelta, ToolCallPartDelta, StructuredOutputPartDelta], pydantic.Discriminator('part_delta_kind')
455+
]
412456
"""A partial update (delta) for any model response part."""
413457

414458

pydantic_ai_slim/pydantic_ai/models/cohere.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from itertools import chain
66
from typing import Literal, Union, cast
77

8-
from cohere import TextAssistantMessageContentItem
98
from httpx import AsyncClient as AsyncHTTPClient
109
from typing_extensions import assert_never
1110

@@ -17,6 +16,7 @@
1716
ModelResponse,
1817
ModelResponsePart,
1918
RetryPromptPart,
19+
StructuredOutputPart,
2020
SystemPromptPart,
2121
TextPart,
2222
ToolCallPart,
@@ -37,7 +37,11 @@
3737
AsyncClientV2,
3838
ChatMessageV2,
3939
ChatResponse,
40+
JsonObjectResponseFormatV2,
41+
ResponseFormatV2,
4042
SystemChatMessageV2,
43+
TextAssistantMessageContentItem,
44+
TextResponseFormatV2,
4145
ToolCallV2,
4246
ToolCallV2Function,
4347
ToolChatMessageV2,
@@ -152,7 +156,30 @@ async def _chat(
152156
model_settings: CohereModelSettings,
153157
model_request_parameters: ModelRequestParameters,
154158
) -> ChatResponse:
155-
tools = self._get_tools(model_request_parameters)
159+
if model_settings.get('force_response_format', False):
160+
tools: list[ToolV2] = OMIT
161+
response_format: ResponseFormatV2
162+
if (n_result_tools := len(model_request_parameters.result_tools)) == 0:
163+
response_format = TextResponseFormatV2()
164+
elif n_result_tools == 1 and not model_request_parameters.allow_text_result:
165+
result_tool = model_request_parameters.result_tools[0]
166+
response_format = JsonObjectResponseFormatV2(
167+
type='json_object',
168+
json_schema=result_tool.parameters_json_schema,
169+
)
170+
else:
171+
json_schemas = [t.parameters_json_schema for t in model_request_parameters.result_tools]
172+
if model_request_parameters.allow_text_result:
173+
json_schemas.append({'type': 'string'})
174+
response_format = JsonObjectResponseFormatV2(
175+
type='json_object',
176+
json_schema={'anyOf': json_schemas},
177+
)
178+
else:
179+
# standalone function to make it easier to override
180+
tools = self._get_tools(model_request_parameters)
181+
response_format = OMIT
182+
156183
cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
157184
return await self.client.chat(
158185
model=self._model_name,
@@ -162,6 +189,7 @@ async def _chat(
162189
temperature=model_settings.get('temperature', OMIT),
163190
p=model_settings.get('top_p', OMIT),
164191
seed=model_settings.get('seed', OMIT),
192+
response_format=response_format,
165193
presence_penalty=model_settings.get('presence_penalty', OMIT),
166194
frequency_penalty=model_settings.get('frequency_penalty', OMIT),
167195
)
@@ -193,7 +221,7 @@ def _map_message(self, message: ModelMessage) -> Iterable[ChatMessageV2]:
193221
texts: list[str] = []
194222
tool_calls: list[ToolCallV2] = []
195223
for item in message.parts:
196-
if isinstance(item, TextPart):
224+
if isinstance(item, (TextPart, StructuredOutputPart)):
197225
texts.append(item.content)
198226
elif isinstance(item, ToolCallPart):
199227
tool_calls.append(self._map_tool_call(item))

pydantic_ai_slim/pydantic_ai/models/function.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,9 @@ async def request_stream(
109109
model_settings,
110110
)
111111

112-
assert self.stream_function is not None, (
113-
'FunctionModel must receive a `stream_function` to support streamed requests'
114-
)
112+
assert (
113+
self.stream_function is not None
114+
), 'FunctionModel must receive a `stream_function` to support streamed requests'
115115

116116
response_stream = PeekableAsyncStream(self.stream_function(messages, agent_info))
117117

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -183,26 +183,40 @@ async def _make_request(
183183
model_settings: GeminiModelSettings,
184184
model_request_parameters: ModelRequestParameters,
185185
) -> AsyncIterator[HTTPResponse]:
186-
tools = self._get_tools(model_request_parameters)
187-
tool_config = self._get_tool_config(model_request_parameters, tools)
188186
sys_prompt_parts, contents = self._message_to_gemini_content(messages)
189-
190187
request_data = _GeminiRequest(contents=contents)
191188
if sys_prompt_parts:
192189
request_data['system_instruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts)
193-
if tools is not None:
194-
request_data['tools'] = tools
195-
if tool_config is not None:
196-
request_data['tool_config'] = tool_config
197190

198191
generation_config: _GeminiGenerationConfig = {}
192+
if model_settings.get('force_response_format', False):
193+
if (n_result_tools := len(model_request_parameters.result_tools)) == 0:
194+
generation_config['response_mimetype'] = 'text/plain'
195+
elif n_result_tools == 1 and not model_request_parameters.allow_text_result:
196+
generation_config['response_mimetype'] = 'application/json'
197+
generation_config['response_schema'] = model_request_parameters.result_tools[0].parameters_json_schema
198+
else:
199+
json_schemas = [t.parameters_json_schema for t in model_request_parameters.result_tools]
200+
if model_request_parameters.allow_text_result:
201+
json_schemas.append({'type': 'string'})
202+
generation_config['response_schema'] = {'anyOf': json_schemas}
203+
else:
204+
tools = self._get_tools(model_request_parameters)
205+
tool_config = self._get_tool_config(model_request_parameters, tools)
206+
if tools is not None:
207+
request_data['tools'] = tools
208+
if tool_config is not None:
209+
request_data['tool_config'] = tool_config
210+
199211
if model_settings:
200212
if (max_tokens := model_settings.get('max_tokens')) is not None:
201213
generation_config['max_output_tokens'] = max_tokens
202214
if (temperature := model_settings.get('temperature')) is not None:
203215
generation_config['temperature'] = temperature
204216
if (top_p := model_settings.get('top_p')) is not None:
205217
generation_config['top_p'] = top_p
218+
if (seed := model_settings.get('seed')) is not None:
219+
generation_config['seed'] = seed
206220
if (presence_penalty := model_settings.get('presence_penalty')) is not None:
207221
generation_config['presence_penalty'] = presence_penalty
208222
if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
@@ -465,9 +479,12 @@ class _GeminiGenerationConfig(TypedDict, total=False):
465479
See <https://ai.google.dev/api/generate-content#generationconfig> for API docs.
466480
"""
467481

482+
response_mimetype: Literal['text/plain', 'application/json']
483+
response_schema: dict[str, Any]
468484
max_output_tokens: int
469485
temperature: float
470486
top_p: float
487+
seed: int
471488
presence_penalty: float
472489
frequency_penalty: float
473490

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ async def _completions_create(
202202
tools=tools or NOT_GIVEN,
203203
tool_choice=tool_choice or NOT_GIVEN,
204204
stream=stream,
205+
response_format=response_format,
205206
max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
206207
temperature=model_settings.get('temperature', NOT_GIVEN),
207208
top_p=model_settings.get('top_p', NOT_GIVEN),

0 commit comments

Comments
 (0)