
Add support for the use of response_format to force a particular json schema for the response #959
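This PR teaches the agent graph to consume a model's native structured-output mode (e.g. OpenAI-style `response_format`, Gemini's `response_schema`, Cohere's `ResponseFormatV2`) instead of always routing the final result through a tool call. A minimal usage sketch, assuming the `force_response_format` model setting introduced in this diff keeps its name and the public `Agent` API is otherwise unchanged:

```python
from pydantic import BaseModel

from pydantic_ai import Agent


class SeatPreference(BaseModel):
    row: int
    seat: str


# `force_response_format` is the setting added in this PR (wired up here for
# Cohere, Gemini, and Groq); the model name is only an example.
agent = Agent(
    'cohere:command-r-plus',
    result_type=SeatPreference,
    model_settings={'force_response_format': True},
)

result = agent.run_sync('Window seat, as far forward as possible')
print(result.data)
```

With the setting enabled, the result schema is sent as a response-format constraint and the reply comes back as a `StructuredOutputPart` rather than a tool call.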


Closed · wants to merge 1 commit
4 changes: 3 additions & 1 deletion examples/pydantic_ai_examples/flight_booking.py
@@ -105,7 +105,9 @@ class Failed(BaseModel):


# This agent is responsible for extracting the user's seat selection
seat_preference_agent = Agent[None, SeatPreference | Failed](
seat_preference_agent = Agent[
None, SeatPreference | Failed
](
'openai:gpt-4o',
result_type=SeatPreference | Failed, # type: ignore
system_prompt=(
12 changes: 6 additions & 6 deletions examples/pydantic_ai_examples/rag.py
@@ -67,9 +67,9 @@ async def retrieve(context: RunContext[Deps], search_query: str) -> str:
model='text-embedding-3-small',
)

assert len(embedding.data) == 1, (
f'Expected 1 embedding, got {len(embedding.data)}, doc query: {search_query!r}'
)
assert (
len(embedding.data) == 1
), f'Expected 1 embedding, got {len(embedding.data)}, doc query: {search_query!r}'
embedding = embedding.data[0].embedding
embedding_json = pydantic_core.to_json(embedding).decode()
rows = await context.deps.pool.fetch(
@@ -149,9 +149,9 @@ async def insert_doc_section(
input=section.embedding_content(),
model='text-embedding-3-small',
)
assert len(embedding.data) == 1, (
f'Expected 1 embedding, got {len(embedding.data)}, doc section: {section}'
)
assert (
len(embedding.data) == 1
), f'Expected 1 embedding, got {len(embedding.data)}, doc section: {section}'
embedding = embedding.data[0].embedding
embedding_json = pydantic_core.to_json(embedding).decode()
await pool.execute(
43 changes: 43 additions & 0 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -10,6 +10,7 @@
from typing import Any, Generic, Literal, Union, cast

import logfire_api
from pydantic import ValidationError
from typing_extensions import TypeVar, assert_never

from pydantic_graph import BaseNode, Graph, GraphRunContext
@@ -370,12 +371,16 @@ async def _run_stream(

async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
texts: list[str] = []
structured_outputs: list[str] = []
tool_calls: list[_messages.ToolCallPart] = []
for part in self.model_response.parts:
if isinstance(part, _messages.TextPart):
# ignore empty content for text parts, see #437
if part.content:
texts.append(part.content)
elif isinstance(part, _messages.StructuredOutputPart):
if part.content:
structured_outputs.append(part.content)
elif isinstance(part, _messages.ToolCallPart):
tool_calls.append(part)
else:
@@ -391,6 +396,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
elif texts:
# No events are emitted during the handling of text responses, so we don't need to yield anything
self._next_node = await self._handle_text_response(ctx, texts)
elif structured_outputs:
# No events are emitted during the handling of structured outputs, so we don't need to yield anything
self._next_node = await self._handle_structured_outputs_response(ctx, structured_outputs)
else:
raise exceptions.UnexpectedModelBehavior('Received empty model response')

@@ -487,6 +495,41 @@ async def _handle_text_response(
)
)

async def _handle_structured_outputs_response(
self,
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
structured_outputs: list[str],
) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
if len(structured_outputs) != 1:
raise exceptions.UnexpectedModelBehavior('Received multiple structured outputs in a single response')
result_schema = ctx.deps.result_schema
if not result_schema:
raise exceptions.UnexpectedModelBehavior('Must specify a non-str result_type when using structured outputs')

structured_output = structured_outputs[0]
try:
result_data_input = result_schema.structured_output_validator.validate_json(structured_output)
except ValidationError as e:
ctx.state.increment_retries(ctx.deps.max_result_retries)
return ModelRequestNode[DepsT, NodeRunEndT](
_messages.ModelRequest(
parts=[
_messages.RetryPromptPart(
content='Structured output validation failed: ' + str(e),
)
]
)
)

try:
result_data = await _validate_result(result_data_input, ctx, None)
except _result.ToolRetryError as e:
ctx.state.increment_retries(ctx.deps.max_result_retries)
return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
else:
return self._handle_final_result(ctx, result.FinalResult(result_data, tool_name=None), [])


def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
"""Build a `RunContext` object from the current agent graph run context."""
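For reference, a minimal sketch of the validation step `_handle_structured_outputs_response` performs, with a stand-in `TypeAdapter` playing the role of `result_schema.structured_output_validator` (the model and JSON payload are invented for illustration):

```python
from pydantic import BaseModel, TypeAdapter, ValidationError


class SeatPreference(BaseModel):
    row: int
    seat: str


validator = TypeAdapter(SeatPreference)

raw = '{"row": 14, "seat": "A"}'  # content of a StructuredOutputPart
try:
    seat = validator.validate_json(raw)
except ValidationError as exc:
    # In the graph above this becomes a RetryPromptPart and the model is asked again.
    print(f'retry: {exc}')
else:
    print(seat)
```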
1 change: 1 addition & 0 deletions pydantic_ai_slim/pydantic_ai/_result.py
@@ -83,6 +83,7 @@ class ResultSchema(Generic[ResultDataT]):
Similar to `Tool` but for the final result of running an agent.
"""

structured_output_validator: TypeAdapter[ResultDataT]
tools: dict[str, ResultTool[ResultDataT]]
allow_text_result: bool

48 changes: 46 additions & 2 deletions pydantic_ai_slim/pydantic_ai/messages.py
@@ -217,7 +217,22 @@ def has_content(self) -> bool:
return bool(self.args)


ModelResponsePart = Annotated[Union[TextPart, ToolCallPart], pydantic.Discriminator('part_kind')]
@dataclass
class StructuredOutputPart:
"""A structured output response from a model."""

content: str
"""The structured content of the response as a JSON-serialized string."""

part_kind: Literal['structured-output'] = 'structured-output'
"""Part type identifier, this is available on all parts as a discriminator."""

def has_content(self) -> bool:
"""Return `True` if the structured content is non-empty."""
return bool(self.content)


ModelResponsePart = Annotated[Union[TextPart, ToolCallPart, StructuredOutputPart], pydantic.Discriminator('part_kind')]
"""A message part returned by a model."""


@@ -275,6 +290,33 @@ def apply(self, part: ModelResponsePart) -> TextPart:
return replace(part, content=part.content + self.content_delta)


@dataclass
class StructuredOutputPartDelta:
"""A partial update (delta) for a `StructuredOutputPart` to append new text content."""

content_delta: str
"""The incremental text content to add to the existing `StructuredOutputPart` content."""

part_delta_kind: Literal['structured-output'] = 'structured-output'
"""Part delta type identifier, used as a discriminator."""

def apply(self, part: ModelResponsePart) -> StructuredOutputPart:
"""Apply this text delta to an existing `TextPart`.

Args:
part: The existing model response part, which must be a `TextPart`.

Returns:
A new `TextPart` with updated text content.

Raises:
ValueError: If `part` is not a `TextPart`.
"""
if not isinstance(part, StructuredOutputPart):
raise ValueError('Cannot apply TextPartDeltas to non-TextParts')
return replace(part, content=part.content + self.content_delta)


@dataclass
class ToolCallPartDelta:
"""A partial update (delta) for a `ToolCallPart` to modify tool name, arguments, or tool call ID."""
@@ -408,7 +450,9 @@ def _apply_to_part(self, part: ToolCallPart) -> ToolCallPart:
return part


ModelResponsePartDelta = Annotated[Union[TextPartDelta, ToolCallPartDelta], pydantic.Discriminator('part_delta_kind')]
ModelResponsePartDelta = Annotated[
Union[TextPartDelta, ToolCallPartDelta, StructuredOutputPartDelta], pydantic.Discriminator('part_delta_kind')
]
"""A partial update (delta) for any model response part."""


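A quick sketch of how the new delta type composes during streaming (class names from this diff; the JSON fragments are invented):

```python
from pydantic_ai.messages import StructuredOutputPart, StructuredOutputPartDelta

part = StructuredOutputPart(content='{"row": 14, ')
for chunk in ('"seat"', ': "A"}'):
    part = StructuredOutputPartDelta(content_delta=chunk).apply(part)

assert part.content == '{"row": 14, "seat": "A"}'
```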
34 changes: 31 additions & 3 deletions pydantic_ai_slim/pydantic_ai/models/cohere.py
@@ -5,7 +5,6 @@
from itertools import chain
from typing import Literal, Union, cast

from cohere import TextAssistantMessageContentItem
from httpx import AsyncClient as AsyncHTTPClient
from typing_extensions import assert_never

@@ -17,6 +16,7 @@
ModelResponse,
ModelResponsePart,
RetryPromptPart,
StructuredOutputPart,
SystemPromptPart,
TextPart,
ToolCallPart,
@@ -37,7 +37,11 @@
AsyncClientV2,
ChatMessageV2,
ChatResponse,
JsonObjectResponseFormatV2,
ResponseFormatV2,
SystemChatMessageV2,
TextAssistantMessageContentItem,
TextResponseFormatV2,
ToolCallV2,
ToolCallV2Function,
ToolChatMessageV2,
@@ -152,7 +156,30 @@ async def _chat(
model_settings: CohereModelSettings,
model_request_parameters: ModelRequestParameters,
) -> ChatResponse:
tools = self._get_tools(model_request_parameters)
if model_settings.get('force_response_format', False):
tools: list[ToolV2] = OMIT
response_format: ResponseFormatV2
if (n_result_tools := len(model_request_parameters.result_tools)) == 0:
response_format = TextResponseFormatV2()
elif n_result_tools == 1 and not model_request_parameters.allow_text_result:
result_tool = model_request_parameters.result_tools[0]
response_format = JsonObjectResponseFormatV2(
type='json_object',
json_schema=result_tool.parameters_json_schema,
)
else:
json_schemas = [t.parameters_json_schema for t in model_request_parameters.result_tools]
if model_request_parameters.allow_text_result:
json_schemas.append({'type': 'string'})
response_format = JsonObjectResponseFormatV2(
type='json_object',
json_schema={'anyOf': json_schemas},
)
else:
# standalone function to make it easier to override
tools = self._get_tools(model_request_parameters)
response_format = OMIT

cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
return await self.client.chat(
model=self._model_name,
@@ -162,6 +189,7 @@
temperature=model_settings.get('temperature', OMIT),
p=model_settings.get('top_p', OMIT),
seed=model_settings.get('seed', OMIT),
response_format=response_format,
presence_penalty=model_settings.get('presence_penalty', OMIT),
frequency_penalty=model_settings.get('frequency_penalty', OMIT),
)
@@ -193,7 +221,7 @@ def _map_message(self, message: ModelMessage) -> Iterable[ChatMessageV2]:
texts: list[str] = []
tool_calls: list[ToolCallV2] = []
for item in message.parts:
if isinstance(item, TextPart):
if isinstance(item, (TextPart, StructuredOutputPart)):
texts.append(item.content)
elif isinstance(item, ToolCallPart):
tool_calls.append(self._map_tool_call(item))
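The multi-tool branch above folds every result tool into a single schema. A sketch of the `anyOf` union it builds, with stand-in schemas for illustration (the resulting dict is what gets passed as `JsonObjectResponseFormatV2(json_schema=...)`):

```python
# Stand-ins for each result_tool.parameters_json_schema:
json_schemas = [
    {'type': 'object', 'properties': {'row': {'type': 'integer'}, 'seat': {'type': 'string'}}},
    {'type': 'object', 'properties': {'reason': {'type': 'string'}}},  # e.g. a Failed result
]
allow_text_result = True
if allow_text_result:
    json_schemas.append({'type': 'string'})  # plain text stays a legal reply

response_schema = {'anyOf': json_schemas}
```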
6 changes: 3 additions & 3 deletions pydantic_ai_slim/pydantic_ai/models/function.py
@@ -109,9 +109,9 @@ async def request_stream(
model_settings,
)

assert self.stream_function is not None, (
'FunctionModel must receive a `stream_function` to support streamed requests'
)
assert (
self.stream_function is not None
), 'FunctionModel must receive a `stream_function` to support streamed requests'

response_stream = PeekableAsyncStream(self.stream_function(messages, agent_info))

31 changes: 24 additions & 7 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -183,26 +183,40 @@ async def _make_request(
model_settings: GeminiModelSettings,
model_request_parameters: ModelRequestParameters,
) -> AsyncIterator[HTTPResponse]:
tools = self._get_tools(model_request_parameters)
tool_config = self._get_tool_config(model_request_parameters, tools)
sys_prompt_parts, contents = self._message_to_gemini_content(messages)

request_data = _GeminiRequest(contents=contents)
if sys_prompt_parts:
request_data['system_instruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts)
if tools is not None:
request_data['tools'] = tools
if tool_config is not None:
request_data['tool_config'] = tool_config

generation_config: _GeminiGenerationConfig = {}
if model_settings.get('force_response_format', False):
if (n_result_tools := len(model_request_parameters.result_tools)) == 0:
generation_config['response_mime_type'] = 'text/plain'
elif n_result_tools == 1 and not model_request_parameters.allow_text_result:
generation_config['response_mime_type'] = 'application/json'
generation_config['response_schema'] = model_request_parameters.result_tools[0].parameters_json_schema
else:
json_schemas = [t.parameters_json_schema for t in model_request_parameters.result_tools]
if model_request_parameters.allow_text_result:
json_schemas.append({'type': 'string'})
generation_config['response_mime_type'] = 'application/json'
generation_config['response_schema'] = {'anyOf': json_schemas}
else:
tools = self._get_tools(model_request_parameters)
tool_config = self._get_tool_config(model_request_parameters, tools)
if tools is not None:
request_data['tools'] = tools
if tool_config is not None:
request_data['tool_config'] = tool_config

if model_settings:
if (max_tokens := model_settings.get('max_tokens')) is not None:
generation_config['max_output_tokens'] = max_tokens
if (temperature := model_settings.get('temperature')) is not None:
generation_config['temperature'] = temperature
if (top_p := model_settings.get('top_p')) is not None:
generation_config['top_p'] = top_p
if (seed := model_settings.get('seed')) is not None:
generation_config['seed'] = seed
if (presence_penalty := model_settings.get('presence_penalty')) is not None:
generation_config['presence_penalty'] = presence_penalty
if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
@@ -465,9 +479,12 @@ class _GeminiGenerationConfig(TypedDict, total=False):
See <https://ai.google.dev/api/generate-content#generationconfig> for API docs.
"""

response_mime_type: Literal['text/plain', 'application/json']
response_schema: dict[str, Any]
max_output_tokens: int
temperature: float
top_p: float
seed: int
presence_penalty: float
frequency_penalty: float

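For the single-result-tool case, the generation config above ends up shaped like this (the schema is a stand-in; this assumes the request serializer camel-cases these keys into the `responseMimeType`/`responseSchema` fields the Gemini REST API expects):

```python
generation_config = {
    'response_mime_type': 'application/json',
    'response_schema': {
        'type': 'object',
        'properties': {'row': {'type': 'integer'}, 'seat': {'type': 'string'}},
        'required': ['row', 'seat'],
    },
}
```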
1 change: 1 addition & 0 deletions pydantic_ai_slim/pydantic_ai/models/groq.py
@@ -202,6 +202,7 @@ async def _completions_create(
tools=tools or NOT_GIVEN,
tool_choice=tool_choice or NOT_GIVEN,
stream=stream,
response_format=response_format,
max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
temperature=model_settings.get('temperature', NOT_GIVEN),
top_p=model_settings.get('top_p', NOT_GIVEN),