From 19f8ad26369346d02c182806c6e6759b22153258 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 13 May 2025 19:05:50 +0200 Subject: [PATCH 1/2] WIP: With OutputPart --- pydantic_ai_slim/pydantic_ai/__init__.py | 3 +- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 68 +++- pydantic_ai_slim/pydantic_ai/_output.py | 358 +++++++++++++++--- .../pydantic_ai/_parts_manager.py | 2 + pydantic_ai_slim/pydantic_ai/agent.py | 33 +- pydantic_ai_slim/pydantic_ai/messages.py | 48 ++- .../pydantic_ai/models/__init__.py | 16 +- .../pydantic_ai/models/anthropic.py | 3 +- pydantic_ai_slim/pydantic_ai/models/cohere.py | 3 +- .../pydantic_ai/models/function.py | 3 +- pydantic_ai_slim/pydantic_ai/models/gemini.py | 13 +- pydantic_ai_slim/pydantic_ai/models/groq.py | 3 +- .../pydantic_ai/models/mistral.py | 9 +- pydantic_ai_slim/pydantic_ai/models/openai.py | 104 ++++- pydantic_ai_slim/pydantic_ai/models/test.py | 3 +- pydantic_ai_slim/pydantic_ai/result.py | 138 +++---- pydantic_ai_slim/pydantic_ai/tools.py | 5 +- .../test_openai_json_schema_output.yaml | 223 +++++++++++ .../test_openai_manual_json_output.yaml | 211 +++++++++++ .../test_openai/test_openai_tool_output.yaml | 227 +++++++++++ tests/models/test_fallback.py | 6 +- tests/models/test_gemini.py | 60 ++- tests/models/test_instrumented.py | 18 +- tests/models/test_model_request_parameters.py | 6 +- tests/models/test_openai.py | 257 ++++++++++++- tests/test_logfire.py | 4 +- 26 files changed, 1604 insertions(+), 220 deletions(-) create mode 100644 tests/models/cassettes/test_openai/test_openai_json_schema_output.yaml create mode 100644 tests/models/cassettes/test_openai/test_openai_manual_json_output.yaml create mode 100644 tests/models/cassettes/test_openai/test_openai_tool_output.yaml diff --git a/pydantic_ai_slim/pydantic_ai/__init__.py b/pydantic_ai_slim/pydantic_ai/__init__.py index 21ef4dec6..e49c31a03 100644 --- a/pydantic_ai_slim/pydantic_ai/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/__init__.py @@ -12,7 +12,7 @@ ) from .format_prompt import format_as_xml from .messages import AudioUrl, BinaryContent, DocumentUrl, ImageUrl, VideoUrl -from .result import ToolOutput +from .result import JSONSchemaOutput, ToolOutput from .tools import RunContext, Tool __all__ = ( @@ -43,6 +43,7 @@ 'RunContext', # result 'ToolOutput', + 'JSONSchemaOutput', # format_prompt 'format_as_xml', ) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index ccc1d18f7..7bf26e69a 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -24,7 +24,7 @@ result, usage as _usage, ) -from .result import OutputDataT, ToolOutput +from .result import OutputDataT from .settings import ModelSettings, merge_model_settings from .tools import RunContext, Tool, ToolDefinition, ToolsPrepareFunc @@ -249,10 +249,29 @@ async def add_mcp_server_tools(server: MCPServer) -> None: function_tool_defs = await ctx.deps.prepare_tools(run_context, function_tool_defs) or [] output_schema = ctx.deps.output_schema + model = ctx.deps.model + + output_mode = None + output_object = None + output_tools = [] + allow_text_output = _output.allow_text_output(output_schema) + if output_schema: + output_mode = output_schema.forced_mode or model.default_output_mode + output_object = output_schema.object_schema.definition + output_tools = output_schema.tool_defs() + if output_mode != 'tool': + allow_text_output = False + + supported_modes = model.supported_output_modes + if output_mode not in supported_modes: 
+ raise exceptions.UserError(f"Output mode '{output_mode}' is not among supported modes: {supported_modes}") + return models.ModelRequestParameters( function_tools=function_tool_defs, - allow_text_output=allow_text_output(output_schema), - output_tools=output_schema.tool_defs() if output_schema is not None else [], + output_mode=output_mode, + output_object=output_object, + output_tools=output_tools, + allow_text_output=allow_text_output, ) @@ -403,7 +422,7 @@ async def stream( async for _event in stream: pass - async def _run_stream( + async def _run_stream( # noqa: C901 self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] ) -> AsyncIterator[_messages.HandleResponseEvent]: if self._events_iterator is None: @@ -411,12 +430,16 @@ async def _run_stream( async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: texts: list[str] = [] + outputs: list[str] = [] tool_calls: list[_messages.ToolCallPart] = [] for part in self.model_response.parts: if isinstance(part, _messages.TextPart): # ignore empty content for text parts, see #437 if part.content: texts.append(part.content) + elif isinstance(part, _messages.OutputPart): + if part.content: + outputs.append(part.content) elif isinstance(part, _messages.ToolCallPart): tool_calls.append(part) else: @@ -429,6 +452,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: if tool_calls: async for event in self._handle_tool_calls(ctx, tool_calls): yield event + elif outputs: # TODO: Can we have tool calls and structured output? Should we handle both? + # No events are emitted during the handling of structured outputs, so we don't need to yield anything + self._next_node = await self._handle_outputs(ctx, outputs) elif texts: # No events are emitted during the handling of text responses, so we don't need to yield anything self._next_node = await self._handle_text_response(ctx, texts) @@ -437,7 +463,7 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # when the model has already returned text along side tool calls # in this scenario, if text responses are allowed, we return text from the most recent model # response, if any - if allow_text_output(ctx.deps.output_schema): + if _output.allow_text_output(ctx.deps.output_schema): for message in reversed(ctx.state.message_history): if isinstance(message, _messages.ModelResponse): last_texts = [p.content for p in message.parts if isinstance(p, _messages.TextPart)] @@ -520,7 +546,7 @@ async def _handle_text_response( output_schema = ctx.deps.output_schema text = '\n\n'.join(texts) - if allow_text_output(output_schema): + if _output.allow_text_output(output_schema): # The following cast is safe because we know `str` is an allowed result type result_data_input = cast(NodeRunEndT, text) try: @@ -542,6 +568,27 @@ async def _handle_text_response( ) ) + async def _handle_outputs( + self, + ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], + outputs: list[str], + ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]: + if len(outputs) != 1: + raise exceptions.UnexpectedModelBehavior('Received multiple structured outputs in a single response') + output_schema = ctx.deps.output_schema + if not output_schema: + raise exceptions.UnexpectedModelBehavior('Must specify a non-str result_type when using structured outputs') + + structured_output = outputs[0] + try: + result_data = output_schema.validate(structured_output) + result_data = await _validate_output(result_data, ctx, None) + 
except _output.ToolRetryError as e: + ctx.state.increment_retries(ctx.deps.max_result_retries) + return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry])) + else: + return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), []) + def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]: """Build a `RunContext` object from the current agent graph run context.""" @@ -782,11 +829,6 @@ async def _validate_output( return result_data -def allow_text_output(output_schema: _output.OutputSchema[Any] | None) -> bool: - """Check if the result schema allows text results.""" - return output_schema is None or output_schema.allow_text_output - - @dataclasses.dataclass class _RunMessages: messages: list[_messages.ModelMessage] @@ -836,7 +878,9 @@ def get_captured_run_messages() -> _RunMessages: def build_agent_graph( - name: str | None, deps_type: type[DepsT], output_type: type[OutputT] | ToolOutput[OutputT] + name: str | None, + deps_type: type[DepsT], + output_type: _output.OutputType[OutputT], ) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[OutputT]], result.FinalResult[OutputT]]: """Build the execution [Graph][pydantic_graph.Graph] for a given agent.""" nodes = ( diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index f2246ffb9..288ff388f 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -1,8 +1,10 @@ from __future__ import annotations as _annotations import inspect +import json from collections.abc import Awaitable, Iterable, Iterator from dataclasses import dataclass, field +from textwrap import dedent from typing import Any, Callable, Generic, Literal, Union, cast from pydantic import TypeAdapter, ValidationError @@ -12,11 +14,54 @@ from . import _utils, messages as _messages from .exceptions import ModelRetry -from .result import DEFAULT_OUTPUT_TOOL_NAME, OutputDataT, OutputDataT_inv, OutputValidatorFunc, ToolOutput -from .tools import AgentDepsT, GenerateToolJsonSchema, RunContext, ToolDefinition +from .tools import AgentDepsT, GenerateToolJsonSchema, ObjectJsonSchema, RunContext, ToolDefinition T = TypeVar('T') """An invariant TypeVar.""" +OutputDataT_inv = TypeVar('OutputDataT_inv', default=str) +""" +An invariant type variable for the result data of a model. + +We need to use an invariant typevar for `OutputValidator` and `OutputValidatorFunc` because the output data type is used +in both the input and output of a `OutputValidatorFunc`. This can theoretically lead to some issues assuming that types +possessing OutputValidator's are covariant in the result data type, but in practice this is rarely an issue, and +changing it would have negative consequences for the ergonomics of the library. + +At some point, it may make sense to change the input to OutputValidatorFunc to be `Any` or `object` as doing that would +resolve these potential variance issues. 
+""" +OutputDataT = TypeVar('OutputDataT', default=str, covariant=True) +"""Covariant type variable for the result data type of a run.""" + +OutputValidatorFunc = Union[ + Callable[[RunContext[AgentDepsT], OutputDataT_inv], OutputDataT_inv], + Callable[[RunContext[AgentDepsT], OutputDataT_inv], Awaitable[OutputDataT_inv]], + Callable[[OutputDataT_inv], OutputDataT_inv], + Callable[[OutputDataT_inv], Awaitable[OutputDataT_inv]], +] +""" +A function that always takes and returns the same type of data (which is the result type of an agent run), and: + +* may or may not take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument +* may or may not be async + +Usage `OutputValidatorFunc[AgentDepsT, T]`. +""" + + +DEFAULT_OUTPUT_TOOL_NAME = 'final_result' +DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation' +DEFAULT_MANUAL_JSON_PROMPT = dedent( + """ + Always respond with a JSON object matching this description and schema: + + {description} + + {schema} + + Don't include any text or Markdown fencing before or after. + """ +) @dataclass @@ -76,69 +121,183 @@ def __init__(self, tool_retry: _messages.RetryPromptPart): super().__init__() +@dataclass(init=False) +class ToolOutput(Generic[OutputDataT]): + """Marker class to use tools for outputs, and customize the tool.""" + + output_type: type[OutputDataT] + # TODO: Add `output_call` support, for calling a function to get the output + # output_call: Callable[..., OutputDataT] | None + name: str | None + description: str | None + max_retries: int | None + strict: bool | None + + def __init__( + self, + *, + type_: type[OutputDataT], + # call: Callable[..., OutputDataT] | None = None, + name: str | None = None, + description: str | None = None, + max_retries: int | None = None, + strict: bool | None = None, + ): + self.output_type = type_ + self.name = name + self.description = description + self.max_retries = max_retries + self.strict = strict + + # TODO: add support for call and make type_ optional, with the following logic: + # if type_ is None and call is None: + # raise ValueError('Either type_ or call must be provided') + # if call is not None: + # if type_ is None: + # type_ = get_type_hints(call).get('return') + # if type_ is None: + # raise ValueError('Unable to determine type_ from call signature; please provide it explicitly') + # self.output_call = call + + +@dataclass(init=False) +class JSONSchemaOutput(Generic[OutputDataT]): + """Marker class to use JSON schema output for outputs.""" + + output_type: type[OutputDataT] + name: str | None + description: str | None + strict: bool | None + + def __init__( + self, + *, + type_: type[OutputDataT], + name: str | None = None, + description: str | None = None, + strict: bool | None = None, + ): + self.output_type = type_ + self.name = name + self.description = description + self.strict = strict + + +class ManualJSONOutput(Generic[OutputDataT]): + """Marker class to use manual JSON mode for outputs.""" + + output_type: type[OutputDataT] + name: str | None + description: str | None + + def __init__( + self, + *, + type_: type[OutputDataT], + name: str | None = None, + description: str | None = None, + ): + self.output_type = type_ + self.name = name + self.description = description + + +# TODO: Use TypeAliasType +type OutputType[OutputDataT] = ( + type[OutputDataT] | ToolOutput[OutputDataT] | JSONSchemaOutput[OutputDataT] | ManualJSONOutput[OutputDataT] +) +# TODO: Add `json_object` for old OpenAI models, or rename `json_schema` to `json` and choose 
automatically, relying on Pydantic validation +type OutputMode = Literal['tool', 'json_schema', 'manual_json'] + + @dataclass class OutputSchema(Generic[OutputDataT]): - """Model the final response from an agent run. + """Model the final output from an agent run. Similar to `Tool` but for the final output of running an agent. """ - tools: dict[str, OutputSchemaTool[OutputDataT]] - allow_text_output: bool + forced_mode: OutputMode | None + object_schema: OutputObjectSchema[OutputDataT] + tools: dict[str, OutputTool[OutputDataT]] + allow_text_output: bool # TODO: Verify structured output works correctly with string as a union member @classmethod def build( cls: type[OutputSchema[T]], - output_type: type[T] | ToolOutput[T], + output_type: OutputType[T], # TODO: Support a list of output types/markers name: str | None = None, description: str | None = None, strict: bool | None = None, ) -> OutputSchema[T] | None: - """Build an OutputSchema dataclass from a response type.""" + """Build an OutputSchema dataclass from an output type.""" if output_type is str: return None + forced_mode = None + tool_output_type = None + allow_text_output = False if isinstance(output_type, ToolOutput): # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads + forced_mode = 'tool' name = output_type.name description = output_type.description output_type_ = output_type.output_type strict = output_type.strict + elif isinstance(output_type, JSONSchemaOutput): + forced_mode = 'json_schema' + name = output_type.name + description = output_type.description + output_type_ = output_type.output_type + strict = output_type.strict + elif isinstance(output_type, ManualJSONOutput): + forced_mode = 'manual_json' + name = output_type.name + description = output_type.description + output_type_ = output_type.output_type else: output_type_ = output_type - if output_type_option := extract_str_from_union(output_type): - output_type_ = output_type_option.value - allow_text_output = True - else: - allow_text_output = False + if output_type_other_than_str := extract_str_from_union(output_type): + allow_text_output = True + tool_output_type = output_type_other_than_str.value + + output_object_schema = OutputObjectSchema( + output_type=output_type_, name=name, description=description, strict=strict + ) - tools: dict[str, OutputSchemaTool[T]] = {} - if args := get_union_args(output_type_): + tool_output_type = tool_output_type or output_type_ + + tools: dict[str, OutputTool[T]] = {} + if args := get_union_args(tool_output_type): for i, arg in enumerate(args, start=1): tool_name = raw_tool_name = union_tool_name(name, arg) while tool_name in tools: tool_name = f'{raw_tool_name}_{i}' + + parameters_schema = OutputObjectSchema(output_type=arg, description=description, strict=strict) tools[tool_name] = cast( - OutputSchemaTool[T], - OutputSchemaTool( - output_type=arg, name=tool_name, description=description, multiple=True, strict=strict - ), + OutputTool[T], + OutputTool(name=tool_name, parameters_schema=parameters_schema, multiple=True), ) else: - name = name or DEFAULT_OUTPUT_TOOL_NAME - tools[name] = cast( - OutputSchemaTool[T], - OutputSchemaTool( - output_type=output_type_, name=name, description=description, multiple=False, strict=strict - ), + tool_name = name or DEFAULT_OUTPUT_TOOL_NAME + parameters_schema = OutputObjectSchema(output_type=tool_output_type, description=description, strict=strict) + tools[tool_name] = cast( + OutputTool[T], + OutputTool(name=tool_name, 
parameters_schema=parameters_schema, multiple=False), ) - return cls(tools=tools, allow_text_output=allow_text_output) + return cls( + forced_mode=forced_mode, + object_schema=output_object_schema, + tools=tools, + allow_text_output=allow_text_output, + ) def find_named_tool( self, parts: Iterable[_messages.ModelResponsePart], tool_name: str - ) -> tuple[_messages.ToolCallPart, OutputSchemaTool[OutputDataT]] | None: + ) -> tuple[_messages.ToolCallPart, OutputTool[OutputDataT]] | None: """Find a tool that matches one of the calls, with a specific name.""" for part in parts: # pragma: no branch if isinstance(part, _messages.ToolCallPart): # pragma: no branch @@ -148,7 +307,7 @@ def find_named_tool( def find_tool( self, parts: Iterable[_messages.ModelResponsePart], - ) -> Iterator[tuple[_messages.ToolCallPart, OutputSchemaTool[OutputDataT]]]: + ) -> Iterator[tuple[_messages.ToolCallPart, OutputTool[OutputDataT]]]: """Find a tool that matches one of the calls.""" for part in parts: if isinstance(part, _messages.ToolCallPart): # pragma: no branch @@ -163,56 +322,141 @@ def tool_defs(self) -> list[ToolDefinition]: """Get tool definitions to register with the model.""" return [t.tool_def for t in self.tools.values()] + def validate( + self, data: str | dict[str, Any], allow_partial: bool = False, wrap_validation_errors: bool = True + ) -> OutputDataT: + """Validate an output message. -DEFAULT_DESCRIPTION = 'The final response which ends this conversation' + Args: + data: The output data to validate. + allow_partial: If true, allow partial validation. + wrap_validation_errors: If true, wrap the validation errors in a retry message. + + Returns: + Either the validated output data (left) or a retry message (right). + """ + return self.object_schema.validate( + data, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + ) + + +def allow_text_output(output_schema: OutputSchema[Any] | None) -> bool: + """Check if the result schema allows text results.""" + return output_schema is None or output_schema.allow_text_output + + +@dataclass +class OutputObjectDefinition: + name: str + json_schema: ObjectJsonSchema + description: str | None = None + strict: bool | None = None + + @property + def manual_json_instructions(self) -> str: + """Get instructions for model to output manual JSON matching the schema.""" + description = ': '.join([v for v in [self.name, self.description] if v]) + return DEFAULT_MANUAL_JSON_PROMPT.format(schema=json.dumps(self.json_schema), description=description) @dataclass(init=False) -class OutputSchemaTool(Generic[OutputDataT]): - tool_def: ToolDefinition +class OutputObjectSchema(Generic[OutputDataT]): + definition: OutputObjectDefinition type_adapter: TypeAdapter[Any] + outer_typed_dict_key: str | None = None def __init__( - self, *, output_type: type[OutputDataT], name: str, description: str | None, multiple: bool, strict: bool | None + self, + *, + output_type: type[OutputDataT], + name: str | None = None, + description: str | None = None, + strict: bool | None = None, ): - """Build a OutputSchemaTool from a response type.""" if _utils.is_model_like(output_type): self.type_adapter = TypeAdapter(output_type) - outer_typed_dict_key: str | None = None - # noinspection PyArgumentList - parameters_json_schema = _utils.check_object_json_schema( - self.type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) - ) else: + self.outer_typed_dict_key = 'response' response_data_typed_dict = TypedDict( # noqa: UP013 'response_data_typed_dict', {'response': 
output_type}, # pyright: ignore[reportInvalidTypeForm] ) self.type_adapter = TypeAdapter(response_data_typed_dict) - outer_typed_dict_key = 'response' - # noinspection PyArgumentList - parameters_json_schema = _utils.check_object_json_schema( - self.type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) - ) + + json_schema = _utils.check_object_json_schema( + self.type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) + ) + if self.outer_typed_dict_key: # including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM - parameters_json_schema.pop('title') + json_schema.pop('title') - if json_schema_description := parameters_json_schema.pop('description', None): + if json_schema_description := json_schema.pop('description', None): if description is None: - tool_description = json_schema_description + description = json_schema_description else: - tool_description = f'{description}. {json_schema_description}' # pragma: no cover + description = f'{description}. {json_schema_description}' + + self.definition = OutputObjectDefinition( + name=name or getattr(output_type, '__name__', DEFAULT_OUTPUT_TOOL_NAME), + description=description, + json_schema=json_schema, + strict=strict, + ) + + def validate( + self, data: str | dict[str, Any], allow_partial: bool = False, wrap_validation_errors: bool = True + ) -> OutputDataT: + """Validate an output message. + + Args: + data: The output data to validate. + allow_partial: If true, allow partial validation. + wrap_validation_errors: If true, wrap the validation errors in a retry message. + + Returns: + Either the validated output data (left) or a retry message (right). + """ + try: + pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' + if isinstance(data, str): + output = self.type_adapter.validate_json(data, experimental_allow_partial=pyd_allow_partial) + else: + output = self.type_adapter.validate_python(data, experimental_allow_partial=pyd_allow_partial) + except ValidationError as e: + if wrap_validation_errors: + m = _messages.RetryPromptPart( + content=e.errors(include_url=False), + ) + raise ToolRetryError(m) from e + else: + raise else: - tool_description = description or DEFAULT_DESCRIPTION + if k := self.outer_typed_dict_key: + output = output[k] + return output + + +@dataclass(init=False) +class OutputTool(Generic[OutputDataT]): + parameters_schema: OutputObjectSchema[OutputDataT] + tool_def: ToolDefinition + + def __init__(self, *, name: str, parameters_schema: OutputObjectSchema[OutputDataT], multiple: bool): + self.parameters_schema = parameters_schema + definition = parameters_schema.definition + + description = definition.description + if not description: + description = DEFAULT_OUTPUT_TOOL_DESCRIPTION if multiple: - tool_description = f'{union_arg_name(output_type)}: {tool_description}' + description = f'{definition.name}: {description}' self.tool_def = ToolDefinition( name=name, - description=tool_description, - parameters_json_schema=parameters_json_schema, - outer_typed_dict_key=outer_typed_dict_key, - strict=strict, + description=description, + parameters_json_schema=definition.json_schema, + strict=definition.strict, + outer_typed_dict_key=parameters_schema.outer_typed_dict_key, ) def validate( @@ -229,11 +473,9 @@ def validate( Either the validated output data (left) or a retry message (right). 
""" try: - pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' - if isinstance(tool_call.args, str): - output = self.type_adapter.validate_json(tool_call.args, experimental_allow_partial=pyd_allow_partial) - else: - output = self.type_adapter.validate_python(tool_call.args, experimental_allow_partial=pyd_allow_partial) + output = self.parameters_schema.validate( + tool_call.args, allow_partial=allow_partial, wrap_validation_errors=False + ) except ValidationError as e: if wrap_validation_errors: m = _messages.RetryPromptPart( @@ -245,8 +487,6 @@ def validate( else: raise # pragma: lax no cover else: - if k := self.tool_def.outer_typed_dict_key: - output = output[k] return output diff --git a/pydantic_ai_slim/pydantic_ai/_parts_manager.py b/pydantic_ai_slim/pydantic_ai/_parts_manager.py index d99253d1e..8b9d1ca62 100644 --- a/pydantic_ai_slim/pydantic_ai/_parts_manager.py +++ b/pydantic_ai_slim/pydantic_ai/_parts_manager.py @@ -91,6 +91,8 @@ def handle_text_delta( """ existing_text_part_and_index: tuple[TextPart, int] | None = None + # TODO: Parse out structured output or manual JSON, with a separate message? + if vendor_part_id is None: # If the vendor_part_id is None, check if the latest part is a TextPart to update if self._parts: diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 324247886..b4b35a076 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -29,7 +29,7 @@ usage as _usage, ) from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model -from .result import FinalResult, OutputDataT, StreamedRunResult, ToolOutput +from .result import FinalResult, OutputDataT, StreamedRunResult from .settings import ModelSettings, merge_model_settings from .tools import ( AgentDepsT, @@ -127,7 +127,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]): be merged with this value, with the runtime argument taking priority. """ - output_type: type[OutputDataT] | ToolOutput[OutputDataT] + output_type: _output.OutputType[OutputDataT] """ The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`. 
""" @@ -166,7 +166,7 @@ def __init__( self, model: models.Model | models.KnownModelName | str | None = None, *, - output_type: type[OutputDataT] | ToolOutput[OutputDataT] = str, + output_type: _output.OutputType[OutputDataT] = str, instructions: str | _system_prompt.SystemPromptFunc[AgentDepsT] | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] @@ -203,7 +203,7 @@ def __init__( name: str | None = None, model_settings: ModelSettings | None = None, retries: int = 1, - result_tool_name: str = 'final_result', + result_tool_name: str = _output.DEFAULT_OUTPUT_TOOL_NAME, result_tool_description: str | None = None, result_retries: int | None = None, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), @@ -218,7 +218,7 @@ def __init__( self, model: models.Model | models.KnownModelName | str | None = None, *, - # TODO change this back to `output_type: type[OutputDataT] | ToolOutput[OutputDataT] = str,` when we remove the overloads + # TODO change this back to `output_type: _output.OutputType[OutputDataT] = str,` when we remove the overloads output_type: Any = str, instructions: str | _system_prompt.SystemPromptFunc[AgentDepsT] @@ -330,6 +330,7 @@ def __init__( self._instructions_functions = [] if isinstance(instructions, (str, Callable)): instructions = [instructions] + # TODO: Add OutputSchema to the instructions in JSON mode for instruction in instructions or []: if isinstance(instruction, str): self._instructions += instruction + '\n' @@ -378,7 +379,7 @@ async def run( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT], + output_type: _output.OutputType[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -408,7 +409,7 @@ async def run( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None, + output_type: _output.OutputType[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -496,7 +497,7 @@ def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None, *, - output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT], + output_type: _output.OutputType[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -528,7 +529,7 @@ async def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None, + output_type: _output.OutputType[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -783,7 +784,7 @@ def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None, + output_type: _output.OutputType[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -813,7 +814,7 @@ def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - 
output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None, + output_type: _output.OutputType[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -896,7 +897,7 @@ def run_stream( self, user_prompt: str | Sequence[_messages.UserContent], *, - output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT], + output_type: _output.OutputType[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -927,7 +928,7 @@ async def run_stream( # noqa C901 self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None, + output_type: _output.OutputType[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -1007,11 +1008,13 @@ async def stream_to_final( if isinstance(maybe_part_event, _messages.PartStartEvent): new_part = maybe_part_event.part if isinstance(new_part, _messages.TextPart): - if _agent_graph.allow_text_output(output_schema): + if _output.allow_text_output(output_schema): return FinalResult(s, None, None) elif isinstance(new_part, _messages.ToolCallPart) and output_schema: for call, _ in output_schema.find_tool([new_part]): return FinalResult(s, call.tool_name, call.tool_call_id) + elif isinstance(new_part, _messages.OutputPart) and output_schema: + return FinalResult(s, None, None) return None final_result_details = await stream_to_final(streamed_response) @@ -1641,7 +1644,7 @@ def last_run_messages(self) -> list[_messages.ModelMessage]: raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.') def _prepare_output_schema( - self, output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None + self, output_type: _output.OutputType[RunOutputDataT] | None ) -> _output.OutputSchema[RunOutputDataT] | None: if output_type is not None: if self._output_validators: diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index e972f9498..36f40f4b9 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -479,6 +479,21 @@ def has_content(self) -> bool: return bool(self.content) +@dataclass +class OutputPart: + """An output response from a model.""" + + content: str + """The output content of the response as a JSON-serialized string.""" + + part_kind: Literal['output'] = 'output' + """Part type identifier, this is available on all parts as a discriminator.""" + + def has_content(self) -> bool: + """Return `True` if the output content is non-empty.""" + return bool(self.content) + + @dataclass class ToolCallPart: """A tool call from a model.""" @@ -533,7 +548,7 @@ def has_content(self) -> bool: return bool(self.args) -ModelResponsePart = Annotated[Union[TextPart, ToolCallPart], pydantic.Discriminator('part_kind')] +ModelResponsePart = Annotated[Union[TextPart, OutputPart, ToolCallPart], pydantic.Discriminator('part_kind')] """A message part returned by a model.""" @@ -639,6 +654,33 @@ def apply(self, part: ModelResponsePart) -> TextPart: return replace(part, content=part.content + self.content_delta) +@dataclass +class OutputPartDelta: + """A partial update (delta) for a 
`OutputPart` to append new structured output content.""" + + content_delta: str + """The incremental structured output content to add to the existing `OutputPart` content.""" + + part_delta_kind: Literal['output'] = 'output' + """Part delta type identifier, used as a discriminator.""" + + def apply(self, part: ModelResponsePart) -> OutputPart: + """Apply this structured output delta to an existing `OutputPart`. + + Args: + part: The existing model response part, which must be a `OutputPart`. + + Returns: + A new `OutputPart` with updated structured output content. + + Raises: + ValueError: If `part` is not a `OutputPart`. + """ + if not isinstance(part, OutputPart): + raise ValueError('Cannot apply OutputPartDeltas to non-OutputParts') + return replace(part, content=part.content + self.content_delta) + + @dataclass class ToolCallPartDelta: """A partial update (delta) for a `ToolCallPart` to modify tool name, arguments, or tool call ID.""" @@ -756,7 +798,9 @@ def _apply_to_part(self, part: ToolCallPart) -> ToolCallPart: return part -ModelResponsePartDelta = Annotated[Union[TextPartDelta, ToolCallPartDelta], pydantic.Discriminator('part_delta_kind')] +ModelResponsePartDelta = Annotated[ + Union[TextPartDelta, OutputPartDelta, ToolCallPartDelta], pydantic.Discriminator('part_delta_kind') +] """A partial update (delta) for any model response part.""" diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 40052df76..8cdc53bb3 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -16,6 +16,7 @@ import httpx from typing_extensions import Literal, TypeAliasType +from .._output import OutputMode, OutputObjectDefinition from .._parts_manager import ModelResponsePartsManager from ..exceptions import UserError from ..messages import ModelMessage, ModelRequest, ModelResponse, ModelResponseStreamEvent @@ -261,8 +262,11 @@ class ModelRequestParameters: """Configuration for an agent's request to a model, specifically related to tools and output handling.""" function_tools: list[ToolDefinition] = field(default_factory=list) - allow_text_output: bool = True + + output_mode: OutputMode | None = None + output_object: OutputObjectDefinition | None = None output_tools: list[ToolDefinition] = field(default_factory=list) + allow_text_output: bool = True class Model(ABC): @@ -367,6 +371,16 @@ def _get_instructions(messages: list[ModelMessage]) -> str | None: return None + @property + def supported_output_modes(self) -> set[OutputMode]: + """The supported output modes for the model.""" + return {'tool'} # TODO: Support manual_json on all + + @property + def default_output_mode(self) -> OutputMode: + """The default output mode for the model.""" + return 'tool' + @dataclass class StreamedResponse(ABC): diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 68d3f1ab6..0d5592ff7 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -21,6 +21,7 @@ ModelResponse, ModelResponsePart, ModelResponseStreamEvent, + OutputPart, RetryPromptPart, SystemPromptPart, TextPart, @@ -321,7 +322,7 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[Me elif isinstance(m, ModelResponse): assistant_content_params: list[TextBlockParam | ToolUseBlockParam] = [] for response_part in m.parts: - if isinstance(response_part, TextPart): + if isinstance(response_part, 
(TextPart, OutputPart)): assistant_content_params.append(TextBlockParam(text=response_part.content, type='text')) else: tool_use_block_param = ToolUseBlockParam( diff --git a/pydantic_ai_slim/pydantic_ai/models/cohere.py b/pydantic_ai_slim/pydantic_ai/models/cohere.py index 5c2ef6bb9..a3fa1ecbb 100644 --- a/pydantic_ai_slim/pydantic_ai/models/cohere.py +++ b/pydantic_ai_slim/pydantic_ai/models/cohere.py @@ -13,6 +13,7 @@ ModelRequest, ModelResponse, ModelResponsePart, + OutputPart, RetryPromptPart, SystemPromptPart, TextPart, @@ -205,7 +206,7 @@ def _map_messages(self, messages: list[ModelMessage]) -> list[ChatMessageV2]: texts: list[str] = [] tool_calls: list[ToolCallV2] = [] for item in message.parts: - if isinstance(item, TextPart): + if isinstance(item, (TextPart, OutputPart)): texts.append(item.content) elif isinstance(item, ToolCallPart): tool_calls.append(self._map_tool_call(item)) diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 22bcddffb..d20277715 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -21,6 +21,7 @@ ModelRequest, ModelResponse, ModelResponseStreamEvent, + OutputPart, RetryPromptPart, SystemPromptPart, TextPart, @@ -266,7 +267,7 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage: assert_never(part) elif isinstance(message, ModelResponse): for part in message.parts: - if isinstance(part, TextPart): + if isinstance(part, (TextPart, OutputPart)): response_tokens += _estimate_string_tokens(part.content) elif isinstance(part, ToolCallPart): call = part diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 4390bc7d6..ba0ab499f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -17,6 +17,7 @@ from pydantic_ai.providers import Provider, infer_provider from .. 
import ModelHTTPError, UnexpectedModelBehavior, UserError, _utils, usage +from .._output import OutputObjectDefinition from ..messages import ( AudioUrl, BinaryContent, @@ -27,6 +28,7 @@ ModelResponse, ModelResponsePart, ModelResponseStreamEvent, + OutputPart, RetryPromptPart, SystemPromptPart, TextPart, @@ -171,10 +173,17 @@ def customize_request_parameters(self, model_request_parameters: ModelRequestPar def _customize_tool_def(t: ToolDefinition): return replace(t, parameters_json_schema=_GeminiJsonSchema(t.parameters_json_schema).walk()) + def _customize_output_object_def(o: OutputObjectDefinition): + return replace(o, json_schema=_GeminiJsonSchema(o.json_schema).walk()) + return ModelRequestParameters( function_tools=[_customize_tool_def(tool) for tool in model_request_parameters.function_tools], - allow_text_output=model_request_parameters.allow_text_output, + output_mode=model_request_parameters.output_mode, + output_object=_customize_output_object_def(model_request_parameters.output_object) + if model_request_parameters.output_object + else None, output_tools=[_customize_tool_def(tool) for tool in model_request_parameters.output_tools], + allow_text_output=model_request_parameters.allow_text_output, ) @property @@ -554,7 +563,7 @@ def _content_model_response(m: ModelResponse) -> _GeminiContent: for item in m.parts: if isinstance(item, ToolCallPart): parts.append(_function_call_part_from_call(item)) - elif isinstance(item, TextPart): + elif isinstance(item, (TextPart, OutputPart)): if item.content: parts.append(_GeminiTextPart(text=item.content)) else: diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 47b0f693d..28e9287b0 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -20,6 +20,7 @@ ModelResponse, ModelResponsePart, ModelResponseStreamEvent, + OutputPart, RetryPromptPart, SystemPromptPart, TextPart, @@ -274,7 +275,7 @@ def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletio texts: list[str] = [] tool_calls: list[chat.ChatCompletionMessageToolCallParam] = [] for item in message.parts: - if isinstance(item, TextPart): + if isinstance(item, (TextPart, OutputPart)): texts.append(item.content) elif isinstance(item, ToolCallPart): tool_calls.append(self._map_tool_call(item)) diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index ef76996af..7a026103e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -22,6 +22,7 @@ ModelResponse, ModelResponsePart, ModelResponseStreamEvent, + OutputPart, RetryPromptPart, SystemPromptPart, TextPart, @@ -246,6 +247,7 @@ async def _stream_completions_create( ) elif model_request_parameters.output_tools: + # TODO: Port to native "manual JSON" mode # Json Mode parameters_json_schemas = [tool.parameters_json_schema for tool in model_request_parameters.output_tools] user_output_format_message = self._generate_user_output_format(parameters_json_schemas) @@ -254,7 +256,9 @@ async def _stream_completions_create( response = await self.client.chat.stream_async( model=str(self._model_name), messages=mistral_messages, - response_format={'type': 'json_object'}, + response_format={ + 'type': 'json_object' + }, # TODO: Should be able to use json_schema now: https://docs.mistral.ai/capabilities/structured-output/custom_structured_output/, 
https://github.com/mistralai/client-python/blob/bc4adf335968c8a272e1ab7da8461c9943d8e701/src/mistralai/extra/utils/response_format.py#L9 stream=True, http_headers={'User-Agent': get_user_agent()}, ) @@ -478,7 +482,7 @@ def _map_messages(self, messages: list[ModelMessage]) -> list[MistralMessages]: tool_calls: list[MistralToolCall] = [] for part in message.parts: - if isinstance(part, TextPart): + if isinstance(part, (TextPart, OutputPart)): content_chunks.append(MistralTextChunk(text=part.content)) elif isinstance(part, ToolCallPart): tool_calls.append(self._map_tool_call(part)) @@ -562,6 +566,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # Attempt to produce an output tool call from the received text if self._output_tools: self._delta_content += text + # TODO: Port to native "manual JSON" mode maybe_tool_call_part = self._try_get_output_tool_from_text(self._delta_content, self._output_tools) if maybe_tool_call_part: yield self._parts_manager.handle_tool_call_part( diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 14156ab33..3778d4951 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -14,6 +14,7 @@ from pydantic_ai.providers import Provider, infer_provider from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage +from .._output import OutputMode, OutputObjectDefinition from .._utils import guard_tool_call_id as _guard_tool_call_id from ..messages import ( AudioUrl, @@ -25,6 +26,7 @@ ModelResponse, ModelResponsePart, ModelResponseStreamEvent, + OutputPart, RetryPromptPart, SystemPromptPart, TextPart, @@ -203,7 +205,7 @@ async def request( response = await self._completions_create( messages, False, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters ) - model_response = self._process_response(response) + model_response = self._process_response(response, model_request_parameters) model_response.usage.requests = 1 return model_response @@ -234,6 +236,11 @@ def system(self) -> str: """The system / model provider.""" return self._system + @property + def supported_output_modes(self) -> set[OutputMode]: + """The supported output modes for the model.""" + return {'tool', 'json_schema', 'manual_json'} + @overload async def _completions_create( self, @@ -259,18 +266,28 @@ async def _completions_create( model_settings: OpenAIModelSettings, model_request_parameters: ModelRequestParameters, ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]: - tools = self._get_tools(model_request_parameters) - - # standalone function to make it easier to override - if not tools: - tool_choice: Literal['none', 'required', 'auto'] | None = None - elif not model_request_parameters.allow_text_output: - tool_choice = 'required' - else: - tool_choice = 'auto' - openai_messages = await self._map_messages(messages) + tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] + tool_choice: Literal['none', 'required', 'auto'] | NotGiven = NOT_GIVEN + response_format: chat.completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN + + if model_request_parameters.output_mode == 'tool': + tools.extend(self._map_tool_definition(r) for r in model_request_parameters.output_tools) + + if not model_request_parameters.allow_text_output: + tool_choice = 'required' + elif output_object := model_request_parameters.output_object: + if model_request_parameters.output_mode == 'json_schema': + 
response_format = self._map_output_object_definition(output_object) + elif model_request_parameters.output_mode == 'manual_json': + openai_messages.insert( + 0, + chat.ChatCompletionSystemMessageParam( + content=output_object.manual_json_instructions, role='system' + ), + ) + try: extra_headers = model_settings.get('extra_headers', {}) extra_headers.setdefault('User-Agent', get_user_agent()) @@ -288,6 +305,7 @@ async def _completions_create( temperature=model_settings.get('temperature', NOT_GIVEN), top_p=model_settings.get('top_p', NOT_GIVEN), timeout=model_settings.get('timeout', NOT_GIVEN), + response_format=response_format, seed=model_settings.get('seed', NOT_GIVEN), presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN), frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN), @@ -304,7 +322,9 @@ async def _completions_create( raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e raise # pragma: lax no cover - def _process_response(self, response: chat.ChatCompletion) -> ModelResponse: + def _process_response( + self, response: chat.ChatCompletion, model_request_parameters: ModelRequestParameters + ) -> ModelResponse: """Process a non-streamed response, and prepare a message to return.""" timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc) choice = response.choices[0] @@ -329,7 +349,11 @@ def _process_response(self, response: chat.ChatCompletion) -> ModelResponse: } if choice.message.content is not None: - items.append(TextPart(choice.message.content)) + if model_request_parameters.output_mode in {'json_schema', 'manual_json'}: + # TODO: Strip Markdown fence and text before/after + items.append(OutputPart(choice.message.content)) + else: + items.append(TextPart(choice.message.content)) if choice.message.tool_calls is not None: for c in choice.message.tool_calls: items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id)) @@ -374,7 +398,7 @@ async def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCom texts: list[str] = [] tool_calls: list[chat.ChatCompletionMessageToolCallParam] = [] for item in message.parts: - if isinstance(item, TextPart): + if isinstance(item, (TextPart, OutputPart)): texts.append(item.content) elif isinstance(item, ToolCallPart): tool_calls.append(self._map_tool_call(item)) @@ -402,6 +426,22 @@ def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam: function={'name': t.tool_name, 'arguments': t.args_as_json_str()}, ) + @staticmethod + def _map_output_object_definition(o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat: + # TODO: Use ResponseFormatJSONObject on older models + response_format_param: chat.completion_create_params.ResponseFormatJSONSchema = { # pyright: ignore[reportPrivateImportUsage] + 'type': 'json_schema', + 'json_schema': { + 'name': o.name, + 'schema': o.json_schema, + }, + } + if o.description: + response_format_param['json_schema']['description'] = o.description + if o.strict: + response_format_param['json_schema']['strict'] = o.strict + return response_format_param + @staticmethod def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam: tool_param: chat.ChatCompletionToolParam = { @@ -557,6 +597,11 @@ def system(self) -> str: """The system / model provider.""" return self._system + @property + def supported_output_modes(self) -> set[OutputMode]: + """The supported output modes for the model.""" + return {'tool', 'json_schema'} + async def request( 
self, messages: list[ModelRequest | ModelResponse], @@ -567,7 +612,7 @@ async def request( response = await self._responses_create( messages, False, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters ) - return self._process_response(response) + return self._process_response(response, model_request_parameters) @asynccontextmanager async def request_stream( @@ -586,11 +631,17 @@ async def request_stream( def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters: return _customize_request_parameters(model_request_parameters) - def _process_response(self, response: responses.Response) -> ModelResponse: + def _process_response( + self, response: responses.Response, model_request_parameters: ModelRequestParameters + ) -> ModelResponse: """Process a non-streamed response, and prepare a message to return.""" timestamp = datetime.fromtimestamp(response.created_at, tz=timezone.utc) items: list[ModelResponsePart] = [] - items.append(TextPart(response.output_text)) + # TODO: Parse out manual JSON, a la split_content_into_text_and_thinking + if model_request_parameters.output_mode == 'json_schema': + items.append(OutputPart(response.output_text)) + else: + items.append(TextPart(response.output_text)) for item in response.output: if item.type == 'function_call': items.append(ToolCallPart(item.name, item.arguments, tool_call_id=item.call_id)) @@ -654,6 +705,8 @@ async def _responses_create( try: extra_headers = model_settings.get('extra_headers', {}) extra_headers.setdefault('User-Agent', get_user_agent()) + # TODO: Pass text.format = ResponseFormatTextJSONSchemaConfigParam(...): {'type': 'json_schema', 'strict': True, 'name': '...', 'schema': ...} + # TODO: Fall back on ResponseFormatJSONObject/json_object on older models? 
return await self.client.responses.create( input=openai_messages, model=self._model_name, @@ -740,7 +793,7 @@ async def _map_messages( assert_never(part) elif isinstance(message, ModelResponse): for item in message.parts: - if isinstance(item, TextPart): + if isinstance(item, (TextPart, OutputPart)): openai_messages.append(responses.EasyInputMessageParam(role='assistant', content=item.content)) elif isinstance(item, ToolCallPart): openai_messages.append(self._map_tool_call(item)) @@ -850,6 +903,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # Handle the text part of the response content = choice.delta.content if content is not None: + # TODO: Handle structured output yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content) for dtc in choice.delta.tool_calls or []: @@ -931,6 +985,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: pass elif isinstance(chunk, responses.ResponseTextDeltaEvent): + # TODO: Handle structured output yield self._parts_manager.handle_text_delta(vendor_part_id=chunk.content_index, content=chunk.delta) elif isinstance(chunk, responses.ResponseTextDoneEvent): @@ -1115,8 +1170,19 @@ def _customize_tool_def(t: ToolDefinition): t = replace(t, strict=schema_transformer.is_strict_compatible) return replace(t, parameters_json_schema=parameters_json_schema) + def _customize_output_object_def(o: OutputObjectDefinition): + schema_transformer = _OpenAIJsonSchema(o.json_schema, strict=o.strict) + parameters_json_schema = schema_transformer.walk() + if o.strict is None: + o = replace(o, strict=schema_transformer.is_strict_compatible) + return replace(o, json_schema=parameters_json_schema) + return ModelRequestParameters( function_tools=[_customize_tool_def(tool) for tool in model_request_parameters.function_tools], - allow_text_output=model_request_parameters.allow_text_output, + output_mode=model_request_parameters.output_mode, + output_object=_customize_output_object_def(model_request_parameters.output_object) + if model_request_parameters.output_object + else None, output_tools=[_customize_tool_def(tool) for tool in model_request_parameters.output_tools], + allow_text_output=model_request_parameters.allow_text_output, ) diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index 0daad25bc..01f459566 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -17,6 +17,7 @@ ModelResponse, ModelResponsePart, ModelResponseStreamEvent, + OutputPart, RetryPromptPart, TextPart, ToolCallPart, @@ -241,7 +242,7 @@ def __post_init__(self, _messages: Iterable[ModelMessage]): async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: for i, part in enumerate(self._structured_response.parts): - if isinstance(part, TextPart): + if isinstance(part, (TextPart, OutputPart)): text = part.content *words, last_word = text.split(' ') words = [f'{word} ' for word in words] diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 6d9d39739..f8b5658af 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -5,100 +5,36 @@ from copy import copy from dataclasses import dataclass, field from datetime import datetime -from typing import TYPE_CHECKING, Generic, Union, cast +from typing import Generic, cast from typing_extensions import TypeVar, assert_type, deprecated, overload -from . 
import _utils, exceptions, messages as _messages, models +from . import _output, _utils, exceptions, messages as _messages, models +from ._output import ( + JSONSchemaOutput, + OutputDataT, + OutputDataT_inv, + OutputSchema, + OutputValidator, + OutputValidatorFunc, + ToolOutput, +) from .messages import AgentStreamEvent, FinalResultEvent from .tools import AgentDepsT, RunContext from .usage import Usage, UsageLimits -if TYPE_CHECKING: - from . import _output - -__all__ = 'OutputDataT', 'OutputDataT_inv', 'ToolOutput', 'OutputValidatorFunc' +__all__ = 'OutputDataT', 'OutputDataT_inv', 'ToolOutput', 'JSONSchemaOutput', 'OutputValidatorFunc' T = TypeVar('T') """An invariant TypeVar.""" -OutputDataT_inv = TypeVar('OutputDataT_inv', default=str) -""" -An invariant type variable for the result data of a model. - -We need to use an invariant typevar for `OutputValidator` and `OutputValidatorFunc` because the output data type is used -in both the input and output of a `OutputValidatorFunc`. This can theoretically lead to some issues assuming that types -possessing OutputValidator's are covariant in the result data type, but in practice this is rarely an issue, and -changing it would have negative consequences for the ergonomics of the library. - -At some point, it may make sense to change the input to OutputValidatorFunc to be `Any` or `object` as doing that would -resolve these potential variance issues. -""" -OutputDataT = TypeVar('OutputDataT', default=str, covariant=True) -"""Covariant type variable for the result data type of a run.""" - -OutputValidatorFunc = Union[ - Callable[[RunContext[AgentDepsT], OutputDataT_inv], OutputDataT_inv], - Callable[[RunContext[AgentDepsT], OutputDataT_inv], Awaitable[OutputDataT_inv]], - Callable[[OutputDataT_inv], OutputDataT_inv], - Callable[[OutputDataT_inv], Awaitable[OutputDataT_inv]], -] -""" -A function that always takes and returns the same type of data (which is the result type of an agent run), and: - -* may or may not take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument -* may or may not be async - -Usage `OutputValidatorFunc[AgentDepsT, T]`. 
-""" - -DEFAULT_OUTPUT_TOOL_NAME = 'final_result' - - -@dataclass(init=False) -class ToolOutput(Generic[OutputDataT]): - """Marker class to use tools for structured outputs, and customize the tool.""" - - output_type: type[OutputDataT] - # TODO: Add `output_call` support, for calling a function to get the output - # output_call: Callable[..., OutputDataT] | None - name: str - description: str | None - max_retries: int | None - strict: bool | None - - def __init__( - self, - *, - type_: type[OutputDataT], - # call: Callable[..., OutputDataT] | None = None, - name: str = 'final_result', - description: str | None = None, - max_retries: int | None = None, - strict: bool | None = None, - ): - self.output_type = type_ - self.name = name - self.description = description - self.max_retries = max_retries - self.strict = strict - - # TODO: add support for call and make type_ optional, with the following logic: - # if type_ is None and call is None: - # raise ValueError('Either type_ or call must be provided') - # if call is not None: - # if type_ is None: - # type_ = get_type_hints(call).get('return') - # if type_ is None: - # raise ValueError('Unable to determine type_ from call signature; please provide it explicitly') - # self.output_call = call @dataclass class AgentStream(Generic[AgentDepsT, OutputDataT]): _raw_stream_response: models.StreamedResponse - _output_schema: _output.OutputSchema[OutputDataT] | None - _output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]] + _output_schema: OutputSchema[OutputDataT] | None + _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]] _run_ctx: RunContext[AgentDepsT] _usage_limits: UsageLimits | None @@ -157,6 +93,18 @@ async def _validate_response( for validator in self._output_validators: result_data = await validator.validate(result_data, call, self._run_ctx) return result_data + elif ( + self._output_schema is not None + and len(message.parts) > 0 + and (part := message.parts[0]) + and isinstance(part, _messages.OutputPart) + ): + result_data = self._output_schema.validate( + part.content, allow_partial=allow_partial, wrap_validation_errors=False + ) + for validator in self._output_validators: + result_data = await validator.validate(result_data, None, self._run_ctx) + return result_data else: text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) for validator in self._output_validators: @@ -180,7 +128,6 @@ def __aiter__(self) -> AsyncIterator[AgentStreamEvent]: async def aiter(): output_schema = self._output_schema - allow_text_output = output_schema is None or output_schema.allow_text_output def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages.FinalResultEvent | None: """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result.""" @@ -192,7 +139,10 @@ def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages. 
return _messages.FinalResultEvent(
                                     tool_name=call.tool_name, tool_call_id=call.tool_call_id
                                 )
-                        elif allow_text_output:  # pragma: no branch
+                        elif isinstance(new_part, _messages.OutputPart):
+                            if output_schema:
+                                return _messages.FinalResultEvent(tool_name=None, tool_call_id=None)
+                        elif _output.allow_text_output(output_schema):  # pragma: no branch
                             assert_type(e, _messages.PartStartEvent)
                             return _messages.FinalResultEvent(tool_name=None, tool_call_id=None)
 
@@ -224,9 +174,9 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
     _usage_limits: UsageLimits | None
 
     _stream_response: models.StreamedResponse
-    _output_schema: _output.OutputSchema[OutputDataT] | None
+    _output_schema: OutputSchema[OutputDataT] | None
     _run_ctx: RunContext[AgentDepsT]
-    _output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]]
+    _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]]
     _output_tool_name: str | None
     _on_complete: Callable[[], Awaitable[None]]
 
@@ -471,6 +421,18 @@ async def validate_structured_output(
             for validator in self._output_validators:
                 result_data = await validator.validate(result_data, call, self._run_ctx)  # pragma: no cover
             return result_data
+        elif (
+            self._output_schema is not None
+            and len(message.parts) > 0
+            and (part := message.parts[0])
+            and isinstance(part, _messages.OutputPart)
+        ):
+            result_data = self._output_schema.validate(
+                part.content, allow_partial=allow_partial, wrap_validation_errors=False
+            )
+            for validator in self._output_validators:
+                result_data = await validator.validate(result_data, None, self._run_ctx)  # pragma: no cover
+            return result_data
         else:
             text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
             for validator in self._output_validators:
diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py
index cee99dfd2..584d77e02 100644
--- a/pydantic_ai_slim/pydantic_ai/tools.py
+++ b/pydantic_ai_slim/pydantic_ai/tools.py
@@ -13,10 +13,11 @@
 from pydantic_core import SchemaValidator, core_schema
 from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar
 
-from . import _pydantic, _utils, messages as _messages, models
+from . 
import _pydantic, _utils, messages as _messages from .exceptions import ModelRetry, UnexpectedModelBehavior if TYPE_CHECKING: + from .models import Model from .result import Usage __all__ = ( @@ -45,7 +46,7 @@ class RunContext(Generic[AgentDepsT]): deps: AgentDepsT """Dependencies for the agent.""" - model: models.Model + model: Model """The model used in this run.""" usage: Usage """LLM usage associated with the run.""" diff --git a/tests/models/cassettes/test_openai/test_openai_json_schema_output.yaml b/tests/models/cassettes/test_openai/test_openai_json_schema_output.yaml new file mode 100644 index 000000000..ff4477f3d --- /dev/null +++ b/tests/models/cassettes/test_openai/test_openai_json_schema_output.yaml @@ -0,0 +1,223 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '522' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is the largest city in the user country? + role: user + model: gpt-4o + n: 1 + response_format: + json_schema: + name: result + schema: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: false + type: json_schema + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1066' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '341' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_PkRGedQNRFUzJp2R7dO7avWR + type: function + created: 1746142582 + id: chatcmpl-BSXjyBwGuZrtuuSzNCeaWMpGv2MZ3 + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_f5bdcc3276 + usage: + completion_tokens: 12 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 71 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 83 + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '753' + content-type: + - application/json + cookie: + - __cf_bm=dOa3_E1SoWV.vbgZ8L7tx8o9S.XyNTE.YS0K0I3JHq4-1746142583-1.0.1.1-0TuvhdYsoD.J1522DBXH0yrAP_M9MlzvlcpyfwQQNZy.KO5gri6ejQ.gFuwLV5hGhuY0W2uI1dN7ZF1lirVHKeEnEz5s_89aJjrMWjyBd8M; + _cfuvid=xQIJVHkOP28w5fPnAvDHPiCRlU7kkNj6iFV87W4u8Ds-1746142583128-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is the largest city in the user country? 
+ role: user + - role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_PkRGedQNRFUzJp2R7dO7avWR + type: function + - content: Mexico + role: tool + tool_call_id: call_PkRGedQNRFUzJp2R7dO7avWR + model: gpt-4o + n: 1 + response_format: + json_schema: + name: result + schema: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: false + type: json_schema + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '852' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '553' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + annotations: [] + content: '{"city":"Mexico City","country":"Mexico"}' + refusal: null + role: assistant + created: 1746142583 + id: chatcmpl-BSXjzYGu67dhTy5r8KmjJvQ4HhDVO + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_f5bdcc3276 + usage: + completion_tokens: 15 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 92 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 107 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai/test_openai_manual_json_output.yaml b/tests/models/cassettes/test_openai/test_openai_manual_json_output.yaml new file mode 100644 index 000000000..56023f426 --- /dev/null +++ b/tests/models/cassettes/test_openai/test_openai_manual_json_output.yaml @@ -0,0 +1,211 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '627' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: |2 + + Always respond with a JSON object matching this description and schema: + + CityLocation + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: system + - content: What is the largest city in the user country? 
+ role: user + model: gpt-4o + n: 1 + stream: false + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1068' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '430' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_uTjt2vMkeTr0GYqQyQYrUUhl + type: function + created: 1747154400 + id: chatcmpl-BWmxcJf2wXM37xTze50IfAoyuaoKb + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_55d88aaf2f + usage: + completion_tokens: 12 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 106 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 118 + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '858' + content-type: + - application/json + cookie: + - __cf_bm=95NT6qevASASUyV3RVHQoxZGp8lnU1dQzcdShJ0rQ8o-1747154400-1.0.1.1-zowTt2i3mTZlYQ8gezUuRRLY_0dw6L6iD5qfaNySs0KmHmLd2JFwYun1kZJ61S03BecMhUdxy.FiOWLq2LdY.RuTR7vePLyoCrMmCDa4vpk; + _cfuvid=hgD2spnngVs.0HuyvQx7_W1uCro2gMmGvsKkZTUk3H0-1747154400314-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: |2 + + Always respond with a JSON object matching this description and schema: + + CityLocation + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: system + - content: What is the largest city in the user country? 
+ role: user + - role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_uTjt2vMkeTr0GYqQyQYrUUhl + type: function + - content: Mexico + role: tool + tool_call_id: call_uTjt2vMkeTr0GYqQyQYrUUhl + model: gpt-4o + n: 1 + stream: false + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '853' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '2453' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + annotations: [] + content: '{"city":"Mexico City","country":"Mexico"}' + refusal: null + role: assistant + created: 1747154401 + id: chatcmpl-BWmxdIs5pO5RCbQ9qRxxtWVItB4NU + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_d8864f8b6b + usage: + completion_tokens: 12 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 127 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 139 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai/test_openai_tool_output.yaml b/tests/models/cassettes/test_openai/test_openai_tool_output.yaml new file mode 100644 index 000000000..56f7441f1 --- /dev/null +++ b/tests/models/cassettes/test_openai/test_openai_tool_output.yaml @@ -0,0 +1,227 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '561' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is the largest city in the user country? 
+ role: user + model: gpt-4o + n: 1 + stream: false + tool_choice: required + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: The final response which ends this conversation + name: final_result + parameters: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1066' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '348' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_iXFttys57ap0o16JSlC8yhYo + type: function + created: 1746142584 + id: chatcmpl-BSXk0dWkG4hfPt0lph4oFO35iT73I + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_f5bdcc3276 + usage: + completion_tokens: 12 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 68 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 80 + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '792' + content-type: + - application/json + cookie: + - __cf_bm=yM.C6I_kAzJk3Dm7H52actN1zAEW8fj.Gd2yeJ7tKN0-1746142584-1.0.1.1-xk91aElDtLLC8aROrOKHlp5vck_h.R.zQkS6OrsiBOwuFA8rE1kGswpactMEtYxV9WgWDN2B4S2B4zs8heyxmcfiNjmOf075n.OPqYpVla4; + _cfuvid=JCllInpf6fg1JdOS7xSj3bZOXYf9PYJ8uoamRTx7ku4-1746142584855-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is the largest city in the user country? 
+ role: user + - role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_iXFttys57ap0o16JSlC8yhYo + type: function + - content: Mexico + role: tool + tool_call_id: call_iXFttys57ap0o16JSlC8yhYo + model: gpt-4o + n: 1 + stream: false + tool_choice: required + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: The final response which ends this conversation + name: final_result + parameters: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1113' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '1919' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{"city": "Mexico City", "country": "Mexico"}' + name: final_result + id: call_gmD2oUZUzSoCkmNmp3JPUF7R + type: function + created: 1746142585 + id: chatcmpl-BSXk1xGHYzbhXgUkSutK08bdoNv5s + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_f5bdcc3276 + usage: + completion_tokens: 36 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 89 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 125 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/test_fallback.py b/tests/models/test_fallback.py index db6277527..805281ced 100644 --- a/tests/models/test_fallback.py +++ b/tests/models/test_fallback.py @@ -127,7 +127,7 @@ def test_first_failed_instrumented(capfire: CaptureLogfire) -> None: 'end_time': 3000000000, 'attributes': { 'gen_ai.operation.name': 'chat', - 'model_request_parameters': '{"function_tools": [], "allow_text_output": true, "output_tools": []}', + 'model_request_parameters': '{"function_tools": [], "preferred_output_mode": null, "allow_text_output": true, "output_tools": [], "output_object": null}', 'logfire.span_type': 'span', 'logfire.msg': 'chat fallback:function:failure_response:,function:success_response:', 'gen_ai.system': 'function', @@ -200,7 +200,7 @@ async def test_first_failed_instrumented_stream(capfire: CaptureLogfire) -> None 'end_time': 3000000000, 'attributes': { 'gen_ai.operation.name': 'chat', - 'model_request_parameters': '{"function_tools": [], "allow_text_output": true, "output_tools": []}', + 'model_request_parameters': '{"function_tools": [], "preferred_output_mode": null, "allow_text_output": true, "output_tools": [], "output_object": null}', 'logfire.span_type': 'span', 'logfire.msg': 'chat fallback:function::failure_response_stream,function::success_response_stream', 'gen_ai.system': 'function', @@ -272,7 +272,7 @@ def test_all_failed_instrumented(capfire: CaptureLogfire) -> None: 'gen_ai.operation.name': 'chat', 'gen_ai.system': 'fallback:function,function', 'gen_ai.request.model': 
'fallback:function:failure_response:,function:failure_response:', - 'model_request_parameters': '{"function_tools": [], "allow_text_output": true, "output_tools": []}', + 'model_request_parameters': '{"function_tools": [], "preferred_output_mode": null, "allow_text_output": true, "output_tools": [], "output_object": null}', 'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}', 'logfire.span_type': 'span', 'logfire.msg': 'chat fallback:function:failure_response:,function:failure_response:', diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index ef21c5ac0..84859fae8 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -63,7 +63,9 @@ async def test_model_simple(allow_model_requests: None): assert m.model_name == 'gemini-1.5-flash' assert 'x-goog-api-key' in m.client.headers - mrp = ModelRequestParameters(function_tools=[], allow_text_output=True, output_tools=[]) + mrp = ModelRequestParameters( + function_tools=[], allow_text_output=True, output_tools=[], output_mode=None, output_object=None + ) mrp = m.customize_request_parameters(mrp) tools = m._get_tools(mrp) tool_config = m._get_tool_config(mrp, tools) @@ -96,7 +98,13 @@ async def test_model_tools(allow_model_requests: None): {'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}, 'required': ['spam']}, ) - mrp = ModelRequestParameters(function_tools=tools, allow_text_output=True, output_tools=[output_tool]) + mrp = ModelRequestParameters( + function_tools=tools, + allow_text_output=True, + output_tools=[output_tool], + output_mode=None, + output_object=None, + ) mrp = m.customize_request_parameters(mrp) tools = m._get_tools(mrp) tool_config = m._get_tool_config(mrp, tools) @@ -138,7 +146,13 @@ async def test_require_response_tool(allow_model_requests: None): 'This is the tool for the final Result', {'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}}, ) - mrp = ModelRequestParameters(function_tools=[], allow_text_output=False, output_tools=[output_tool]) + mrp = ModelRequestParameters( + function_tools=[], + allow_text_output=False, + output_tools=[output_tool], + output_mode=None, + output_object=None, + ) mrp = m.customize_request_parameters(mrp) tools = m._get_tools(mrp) tool_config = m._get_tool_config(mrp, tools) @@ -219,7 +233,13 @@ class Locations(BaseModel): 'This is the tool for the final Result', json_schema, ) - mrp = ModelRequestParameters(function_tools=[], allow_text_output=True, output_tools=[output_tool]) + mrp = ModelRequestParameters( + function_tools=[], + allow_text_output=True, + output_tools=[output_tool], + output_mode=None, + output_object=None, + ) mrp = m.customize_request_parameters(mrp) assert m._get_tools(mrp) == snapshot( { @@ -298,7 +318,13 @@ class QueryDetails(BaseModel): 'This is the tool for the final Result', json_schema, ) - mrp = ModelRequestParameters(function_tools=[], allow_text_output=True, output_tools=[output_tool]) + mrp = ModelRequestParameters( + function_tools=[], + output_mode=None, + allow_text_output=True, + output_tools=[output_tool], + output_object=None, + ) mrp = m.customize_request_parameters(mrp) # This tests that the enum values are properly converted to strings for Gemini @@ -340,7 +366,13 @@ class Locations(BaseModel): 'This is the tool for the final Result', json_schema, ) - mrp = ModelRequestParameters(function_tools=[], allow_text_output=True, output_tools=[output_tool]) + mrp = ModelRequestParameters( + 
function_tools=[], + allow_text_output=True, + output_tools=[output_tool], + output_mode=None, + output_object=None, + ) mrp = m.customize_request_parameters(mrp) assert m._get_tools(mrp) == snapshot( _GeminiTools( @@ -404,7 +436,13 @@ class Location(BaseModel): json_schema, ) with pytest.raises(UserError, match=r'Recursive `\$ref`s in JSON Schema are not supported by Gemini'): - mrp = ModelRequestParameters(function_tools=[], allow_text_output=True, output_tools=[output_tool]) + mrp = ModelRequestParameters( + function_tools=[], + allow_text_output=True, + output_tools=[output_tool], + output_mode=None, + output_object=None, + ) mrp = m.customize_request_parameters(mrp) @@ -436,7 +474,13 @@ class FormattedStringFields(BaseModel): 'This is the tool for the final Result', json_schema, ) - mrp = ModelRequestParameters(function_tools=[], allow_text_output=True, output_tools=[output_tool]) + mrp = ModelRequestParameters( + function_tools=[], + allow_text_output=True, + output_tools=[output_tool], + output_mode=None, + output_object=None, + ) mrp = m.customize_request_parameters(mrp) assert m._get_tools(mrp) == snapshot( _GeminiTools( diff --git a/tests/models/test_instrumented.py b/tests/models/test_instrumented.py index f7caad399..51e743975 100644 --- a/tests/models/test_instrumented.py +++ b/tests/models/test_instrumented.py @@ -134,6 +134,8 @@ async def test_instrumented_model(capfire: CaptureLogfire): function_tools=[], allow_text_output=True, output_tools=[], + output_mode=None, + output_object=None, ), ) @@ -151,7 +153,7 @@ async def test_instrumented_model(capfire: CaptureLogfire): 'gen_ai.request.model': 'my_model', 'server.address': 'example.com', 'server.port': 8000, - 'model_request_parameters': '{"function_tools": [], "allow_text_output": true, "output_tools": []}', + 'model_request_parameters': '{"function_tools": [], "preferred_output_mode": null, "allow_text_output": true, "output_tools": [], "output_object": null}', 'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}', 'gen_ai.request.temperature': 1, 'logfire.msg': 'chat my_model', @@ -330,6 +332,8 @@ async def test_instrumented_model_not_recording(): function_tools=[], allow_text_output=True, output_tools=[], + output_mode=None, + output_object=None, ), ) @@ -352,6 +356,8 @@ async def test_instrumented_model_stream(capfire: CaptureLogfire): function_tools=[], allow_text_output=True, output_tools=[], + output_mode=None, + output_object=None, ), ) as response_stream: assert [event async for event in response_stream] == snapshot( @@ -375,7 +381,7 @@ async def test_instrumented_model_stream(capfire: CaptureLogfire): 'gen_ai.request.model': 'my_model', 'server.address': 'example.com', 'server.port': 8000, - 'model_request_parameters': '{"function_tools": [], "allow_text_output": true, "output_tools": []}', + 'model_request_parameters': '{"function_tools": [], "preferred_output_mode": null, "allow_text_output": true, "output_tools": [], "output_object": null}', 'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}', 'gen_ai.request.temperature': 1, 'logfire.msg': 'chat my_model', @@ -440,6 +446,8 @@ async def test_instrumented_model_stream_break(capfire: CaptureLogfire): function_tools=[], allow_text_output=True, output_tools=[], + output_mode=None, + output_object=None, ), ) as response_stream: async for event in response_stream: # pragma: no branch @@ -460,7 +468,7 @@ async def test_instrumented_model_stream_break(capfire: 
CaptureLogfire): 'gen_ai.request.model': 'my_model', 'server.address': 'example.com', 'server.port': 8000, - 'model_request_parameters': '{"function_tools": [], "allow_text_output": true, "output_tools": []}', + 'model_request_parameters': '{"function_tools": [], "preferred_output_mode": null, "allow_text_output": true, "output_tools": [], "output_object": null}', 'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}', 'gen_ai.request.temperature': 1, 'logfire.msg': 'chat my_model', @@ -543,6 +551,8 @@ async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire): function_tools=[], allow_text_output=True, output_tools=[], + output_mode=None, + output_object=None, ), ) @@ -560,7 +570,7 @@ async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire): 'gen_ai.request.model': 'my_model', 'server.address': 'example.com', 'server.port': 8000, - 'model_request_parameters': '{"function_tools": [], "allow_text_output": true, "output_tools": []}', + 'model_request_parameters': '{"function_tools": [], "preferred_output_mode": null, "allow_text_output": true, "output_tools": [], "output_object": null}', 'gen_ai.request.temperature': 1, 'logfire.msg': 'chat my_model', 'logfire.span_type': 'span', diff --git a/tests/models/test_model_request_parameters.py b/tests/models/test_model_request_parameters.py index 03910db11..008648f34 100644 --- a/tests/models/test_model_request_parameters.py +++ b/tests/models/test_model_request_parameters.py @@ -4,9 +4,13 @@ def test_model_request_parameters_are_serializable(): - params = ModelRequestParameters(function_tools=[], allow_text_output=False, output_tools=[]) + params = ModelRequestParameters( + function_tools=[], output_mode=None, allow_text_output=False, output_tools=[], output_object=None + ) assert TypeAdapter(ModelRequestParameters).dump_python(params) == { 'function_tools': [], + 'preferred_output_mode': None, 'allow_text_output': False, 'output_tools': [], + 'output_object': None, } diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index bc146fcd4..4816f31d2 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -15,6 +15,7 @@ from typing_extensions import TypedDict from pydantic_ai import Agent, ModelHTTPError, ModelRetry, UnexpectedModelBehavior +from pydantic_ai._output import ManualJSONOutput from pydantic_ai.messages import ( AudioUrl, BinaryContent, @@ -22,6 +23,7 @@ ImageUrl, ModelRequest, ModelResponse, + OutputPart, RetryPromptPart, SystemPromptPart, TextPart, @@ -31,7 +33,7 @@ ) from pydantic_ai.models.gemini import GeminiModel from pydantic_ai.providers.google_gla import GoogleGLAProvider -from pydantic_ai.result import Usage +from pydantic_ai.result import JSONSchemaOutput, ToolOutput, Usage from pydantic_ai.settings import ModelSettings from ..conftest import IsDatetime, IsNow, IsStr, raise_if_exception, try_import @@ -1604,3 +1606,256 @@ async def test_openai_instructions_with_logprobs(allow_model_requests: None): 'top_logprobs': [], } ] + + +@pytest.mark.vcr() +async def test_openai_tool_output(allow_model_requests: None, openai_api_key: str): + m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=ToolOutput(type_=CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert 
result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id=IsStr())], + usage=Usage( + requests=1, + request_tokens=68, + response_tokens=12, + total_tokens=80, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-BSXk0dWkG4hfPt0lph4oFO35iT73I', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args='{"city": "Mexico City", "country": "Mexico"}', + tool_call_id=IsStr(), + ) + ], + usage=Usage( + requests=1, + request_tokens=89, + response_tokens=36, + total_tokens=125, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-BSXk1xGHYzbhXgUkSutK08bdoNv5s', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ] + ) + + +@pytest.mark.vcr() +async def test_openai_json_schema_output(allow_model_requests: None, openai_api_key: str): + m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=JSONSchemaOutput(type_=CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id=IsStr())], + usage=Usage( + requests=1, + request_tokens=71, + response_tokens=12, + total_tokens=83, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-BSXjyBwGuZrtuuSzNCeaWMpGv2MZ3', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[OutputPart(content='{"city":"Mexico City","country":"Mexico"}')], + usage=Usage( + requests=1, + request_tokens=92, + response_tokens=15, + total_tokens=107, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-BSXjzYGu67dhTy5r8KmjJvQ4HhDVO', + ), + ] + ) + + +@pytest.mark.vcr() +async def 
test_openai_manual_json_output(allow_model_requests: None, openai_api_key: str): + m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=ManualJSONOutput(type_=CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id=IsStr())], + usage=Usage( + requests=1, + request_tokens=106, + response_tokens=12, + total_tokens=118, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-BWmxcJf2wXM37xTze50IfAoyuaoKb', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[OutputPart(content='{"city":"Mexico City","country":"Mexico"}')], + usage=Usage( + requests=1, + request_tokens=127, + response_tokens=12, + total_tokens=139, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-BWmxdIs5pO5RCbQ9qRxxtWVItB4NU', + ), + ] + ) diff --git a/tests/test_logfire.py b/tests/test_logfire.py index e63f358f3..57ee697e6 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -223,8 +223,10 @@ async def my_ret(x: int) -> str: 'strict': None, } ], + 'preferred_output_mode': None, 'allow_text_output': True, 'output_tools': [], + 'output_object': None, } ) ), @@ -404,7 +406,7 @@ async def test_feedback(capfire: CaptureLogfire) -> None: 'gen_ai.operation.name': 'chat', 'gen_ai.system': 'test', 'gen_ai.request.model': 'test', - 'model_request_parameters': '{"function_tools": [], "allow_text_output": true, "output_tools": []}', + 'model_request_parameters': '{"function_tools": [], "preferred_output_mode": null, "allow_text_output": true, "output_tools": [], "output_object": null}', 'logfire.span_type': 'span', 'logfire.msg': 'chat test', 'gen_ai.usage.input_tokens': 51, From 64a3ea63a65d8a34e68ca2a595d8dbb70af7b21b Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 15 May 2025 11:50:17 -0600 Subject: [PATCH 2/2] WIP: Remove OutputPart, work around allow_text_output instead --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 58 +++---------- pydantic_ai_slim/pydantic_ai/_output.py | 27 +++--- pydantic_ai_slim/pydantic_ai/agent.py | 2 - pydantic_ai_slim/pydantic_ai/messages.py | 48 +---------- .../pydantic_ai/models/__init__.py | 2 +- .../pydantic_ai/models/anthropic.py | 5 +- .../pydantic_ai/models/bedrock.py | 2 +- pydantic_ai_slim/pydantic_ai/models/cohere.py | 3 +- .../pydantic_ai/models/function.py | 7 +- pydantic_ai_slim/pydantic_ai/models/gemini.py | 7 +- pydantic_ai_slim/pydantic_ai/models/groq.py | 5 +- .../pydantic_ai/models/mistral.py | 5 +- pydantic_ai_slim/pydantic_ai/models/openai.py | 36 ++++---- 
pydantic_ai_slim/pydantic_ai/models/test.py | 7 +- pydantic_ai_slim/pydantic_ai/result.py | 85 ++++++------------- tests/models/test_fallback.py | 6 +- tests/models/test_gemini.py | 16 ++-- tests/models/test_instrumented.py | 18 ++-- tests/models/test_model_request_parameters.py | 4 +- tests/models/test_openai.py | 17 ++-- tests/test_agent.py | 13 +-- tests/test_logfire.py | 6 +- 22 files changed, 133 insertions(+), 246 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 7bf26e69a..297da5ce5 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -254,13 +254,12 @@ async def add_mcp_server_tools(server: MCPServer) -> None: output_mode = None output_object = None output_tools = [] - allow_text_output = _output.allow_text_output(output_schema) + require_tool_use = False if output_schema: output_mode = output_schema.forced_mode or model.default_output_mode output_object = output_schema.object_schema.definition output_tools = output_schema.tool_defs() - if output_mode != 'tool': - allow_text_output = False + require_tool_use = output_mode == 'tool' and not output_schema.allow_plain_text_output supported_modes = model.supported_output_modes if output_mode not in supported_modes: @@ -271,7 +270,7 @@ async def add_mcp_server_tools(server: MCPServer) -> None: output_mode=output_mode, output_object=output_object, output_tools=output_tools, - allow_text_output=allow_text_output, + require_tool_use=require_tool_use, ) @@ -422,7 +421,7 @@ async def stream( async for _event in stream: pass - async def _run_stream( # noqa: C901 + async def _run_stream( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] ) -> AsyncIterator[_messages.HandleResponseEvent]: if self._events_iterator is None: @@ -430,16 +429,12 @@ async def _run_stream( # noqa: C901 async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: texts: list[str] = [] - outputs: list[str] = [] tool_calls: list[_messages.ToolCallPart] = [] for part in self.model_response.parts: if isinstance(part, _messages.TextPart): # ignore empty content for text parts, see #437 if part.content: texts.append(part.content) - elif isinstance(part, _messages.OutputPart): - if part.content: - outputs.append(part.content) elif isinstance(part, _messages.ToolCallPart): tool_calls.append(part) else: @@ -452,9 +447,6 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: if tool_calls: async for event in self._handle_tool_calls(ctx, tool_calls): yield event - elif outputs: # TODO: Can we have tool calls and structured output? Should we handle both? 
- # No events are emitted during the handling of structured outputs, so we don't need to yield anything - self._next_node = await self._handle_outputs(ctx, outputs) elif texts: # No events are emitted during the handling of text responses, so we don't need to yield anything self._next_node = await self._handle_text_response(ctx, texts) @@ -546,42 +538,18 @@ async def _handle_text_response( output_schema = ctx.deps.output_schema text = '\n\n'.join(texts) - if _output.allow_text_output(output_schema): - # The following cast is safe because we know `str` is an allowed result type - result_data_input = cast(NodeRunEndT, text) - try: - result_data = await _validate_output(result_data_input, ctx, None) - except _output.ToolRetryError as e: - ctx.state.increment_retries(ctx.deps.max_result_retries) - return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry])) + try: + if output_schema is None or output_schema.allow_plain_text_output: + # The following cast is safe because we know `str` is an allowed result type + result_data = cast(NodeRunEndT, text) + elif output_schema.allow_json_text_output: + result_data = output_schema.validate(text) else: - return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), []) - else: - ctx.state.increment_retries(ctx.deps.max_result_retries) - return ModelRequestNode[DepsT, NodeRunEndT]( - _messages.ModelRequest( - parts=[ - _messages.RetryPromptPart( - content='Plain text responses are not permitted, please include your response in a tool call', - ) - ] + m = _messages.RetryPromptPart( + content='Plain text responses are not permitted, please include your response in a tool call', ) - ) + raise _output.ToolRetryError(m) - async def _handle_outputs( - self, - ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], - outputs: list[str], - ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]: - if len(outputs) != 1: - raise exceptions.UnexpectedModelBehavior('Received multiple structured outputs in a single response') - output_schema = ctx.deps.output_schema - if not output_schema: - raise exceptions.UnexpectedModelBehavior('Must specify a non-str result_type when using structured outputs') - - structured_output = outputs[0] - try: - result_data = output_schema.validate(structured_output) result_data = await _validate_output(result_data, ctx, None) except _output.ToolRetryError as e: ctx.state.increment_retries(ctx.deps.max_result_retries) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 288ff388f..825d83e7c 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -220,7 +220,8 @@ class OutputSchema(Generic[OutputDataT]): forced_mode: OutputMode | None object_schema: OutputObjectSchema[OutputDataT] tools: dict[str, OutputTool[OutputDataT]] - allow_text_output: bool # TODO: Verify structured output works correctly with string as a union member + allow_plain_text_output: bool + allow_json_text_output: bool # TODO: Turn into allowed_text_output: Literal['plain', 'json'] | None @classmethod def build( @@ -235,11 +236,14 @@ def build( return None forced_mode = None + allow_json_text_output = True + allow_plain_text_output = False tool_output_type = None - allow_text_output = False if isinstance(output_type, ToolOutput): - # do we need to error on conflicts here? 
(DavidM): If this is internal maybe doesn't matter, if public, use overloads forced_mode = 'tool' + allow_json_text_output = False + + # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads name = output_type.name description = output_type.description output_type_ = output_type.output_type @@ -255,12 +259,15 @@ def build( name = output_type.name description = output_type.description output_type_ = output_type.output_type - else: + elif output_type_other_than_str := extract_str_from_union(output_type): + forced_mode = 'tool' output_type_ = output_type - if output_type_other_than_str := extract_str_from_union(output_type): - allow_text_output = True - tool_output_type = output_type_other_than_str.value + allow_json_text_output = False + allow_plain_text_output = True + tool_output_type = output_type_other_than_str.value + else: + output_type_ = output_type output_object_schema = OutputObjectSchema( output_type=output_type_, name=name, description=description, strict=strict @@ -292,7 +299,8 @@ def build( forced_mode=forced_mode, object_schema=output_object_schema, tools=tools, - allow_text_output=allow_text_output, + allow_plain_text_output=allow_plain_text_output, + allow_json_text_output=allow_json_text_output, ) def find_named_tool( @@ -341,8 +349,7 @@ def validate( def allow_text_output(output_schema: OutputSchema[Any] | None) -> bool: - """Check if the result schema allows text results.""" - return output_schema is None or output_schema.allow_text_output + return output_schema is None or output_schema.allow_plain_text_output or output_schema.allow_json_text_output @dataclass diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index b4b35a076..0b31aee54 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1013,8 +1013,6 @@ async def stream_to_final( elif isinstance(new_part, _messages.ToolCallPart) and output_schema: for call, _ in output_schema.find_tool([new_part]): return FinalResult(s, call.tool_name, call.tool_call_id) - elif isinstance(new_part, _messages.OutputPart) and output_schema: - return FinalResult(s, None, None) return None final_result_details = await stream_to_final(streamed_response) diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 36f40f4b9..e972f9498 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -479,21 +479,6 @@ def has_content(self) -> bool: return bool(self.content) -@dataclass -class OutputPart: - """An output response from a model.""" - - content: str - """The output content of the response as a JSON-serialized string.""" - - part_kind: Literal['output'] = 'output' - """Part type identifier, this is available on all parts as a discriminator.""" - - def has_content(self) -> bool: - """Return `True` if the output content is non-empty.""" - return bool(self.content) - - @dataclass class ToolCallPart: """A tool call from a model.""" @@ -548,7 +533,7 @@ def has_content(self) -> bool: return bool(self.args) -ModelResponsePart = Annotated[Union[TextPart, OutputPart, ToolCallPart], pydantic.Discriminator('part_kind')] +ModelResponsePart = Annotated[Union[TextPart, ToolCallPart], pydantic.Discriminator('part_kind')] """A message part returned by a model.""" @@ -654,33 +639,6 @@ def apply(self, part: ModelResponsePart) -> TextPart: return replace(part, content=part.content + self.content_delta) -@dataclass -class 
OutputPartDelta: - """A partial update (delta) for a `OutputPart` to append new structured output content.""" - - content_delta: str - """The incremental structured output content to add to the existing `OutputPart` content.""" - - part_delta_kind: Literal['output'] = 'output' - """Part delta type identifier, used as a discriminator.""" - - def apply(self, part: ModelResponsePart) -> OutputPart: - """Apply this structured output delta to an existing `OutputPart`. - - Args: - part: The existing model response part, which must be a `OutputPart`. - - Returns: - A new `OutputPart` with updated structured output content. - - Raises: - ValueError: If `part` is not a `OutputPart`. - """ - if not isinstance(part, OutputPart): - raise ValueError('Cannot apply OutputPartDeltas to non-OutputParts') - return replace(part, content=part.content + self.content_delta) - - @dataclass class ToolCallPartDelta: """A partial update (delta) for a `ToolCallPart` to modify tool name, arguments, or tool call ID.""" @@ -798,9 +756,7 @@ def _apply_to_part(self, part: ToolCallPart) -> ToolCallPart: return part -ModelResponsePartDelta = Annotated[ - Union[TextPartDelta, OutputPartDelta, ToolCallPartDelta], pydantic.Discriminator('part_delta_kind') -] +ModelResponsePartDelta = Annotated[Union[TextPartDelta, ToolCallPartDelta], pydantic.Discriminator('part_delta_kind')] """A partial update (delta) for any model response part.""" diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 8cdc53bb3..5eb7c8256 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -266,7 +266,7 @@ class ModelRequestParameters: output_mode: OutputMode | None = None output_object: OutputObjectDefinition | None = None output_tools: list[ToolDefinition] = field(default_factory=list) - allow_text_output: bool = True + require_tool_use: bool = True class Model(ABC): diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 0d5592ff7..1618bc17a 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -21,7 +21,6 @@ ModelResponse, ModelResponsePart, ModelResponseStreamEvent, - OutputPart, RetryPromptPart, SystemPromptPart, TextPart, @@ -213,7 +212,7 @@ async def _messages_create( if not tools: tool_choice = None else: - if not model_request_parameters.allow_text_output: + if model_request_parameters.require_tool_use: tool_choice = {'type': 'any'} else: tool_choice = {'type': 'auto'} @@ -322,7 +321,7 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[Me elif isinstance(m, ModelResponse): assistant_content_params: list[TextBlockParam | ToolUseBlockParam] = [] for response_part in m.parts: - if isinstance(response_part, (TextPart, OutputPart)): + if isinstance(response_part, TextPart): assistant_content_params.append(TextBlockParam(text=response_part.content, type='text')) else: tool_use_block_param = ToolUseBlockParam( diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index 1b86f0291..b87816800 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -305,7 +305,7 @@ async def _messages_create( support_tools_choice = self.model_name.startswith(('anthropic', 'us.anthropic')) if not tools or not support_tools_choice: tool_choice: ToolChoiceTypeDef = {} - elif not 
model_request_parameters.allow_text_output: + elif model_request_parameters.require_tool_use: tool_choice = {'any': {}} # pragma: no cover else: tool_choice = {'auto': {}} diff --git a/pydantic_ai_slim/pydantic_ai/models/cohere.py b/pydantic_ai_slim/pydantic_ai/models/cohere.py index a3fa1ecbb..5c2ef6bb9 100644 --- a/pydantic_ai_slim/pydantic_ai/models/cohere.py +++ b/pydantic_ai_slim/pydantic_ai/models/cohere.py @@ -13,7 +13,6 @@ ModelRequest, ModelResponse, ModelResponsePart, - OutputPart, RetryPromptPart, SystemPromptPart, TextPart, @@ -206,7 +205,7 @@ def _map_messages(self, messages: list[ModelMessage]) -> list[ChatMessageV2]: texts: list[str] = [] tool_calls: list[ToolCallV2] = [] for item in message.parts: - if isinstance(item, (TextPart, OutputPart)): + if isinstance(item, TextPart): texts.append(item.content) elif isinstance(item, ToolCallPart): tool_calls.append(self._map_tool_call(item)) diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index d20277715..934707cce 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -21,7 +21,6 @@ ModelRequest, ModelResponse, ModelResponseStreamEvent, - OutputPart, RetryPromptPart, SystemPromptPart, TextPart, @@ -92,7 +91,7 @@ async def request( ) -> ModelResponse: agent_info = AgentInfo( model_request_parameters.function_tools, - model_request_parameters.allow_text_output, + not model_request_parameters.require_tool_use, model_request_parameters.output_tools, model_settings, ) @@ -121,7 +120,7 @@ async def request_stream( ) -> AsyncIterator[StreamedResponse]: agent_info = AgentInfo( model_request_parameters.function_tools, - model_request_parameters.allow_text_output, + not model_request_parameters.require_tool_use, model_request_parameters.output_tools, model_settings, ) @@ -267,7 +266,7 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage: assert_never(part) elif isinstance(message, ModelResponse): for part in message.parts: - if isinstance(part, (TextPart, OutputPart)): + if isinstance(part, TextPart): response_tokens += _estimate_string_tokens(part.content) elif isinstance(part, ToolCallPart): call = part diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index ba0ab499f..c256c5ec9 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -28,7 +28,6 @@ ModelResponse, ModelResponsePart, ModelResponseStreamEvent, - OutputPart, RetryPromptPart, SystemPromptPart, TextPart, @@ -183,7 +182,7 @@ def _customize_output_object_def(o: OutputObjectDefinition): if model_request_parameters.output_object else None, output_tools=[_customize_tool_def(tool) for tool in model_request_parameters.output_tools], - allow_text_output=model_request_parameters.allow_text_output, + require_tool_use=model_request_parameters.require_tool_use, ) @property @@ -205,7 +204,7 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _Gemin def _get_tool_config( self, model_request_parameters: ModelRequestParameters, tools: _GeminiTools | None ) -> _GeminiToolConfig | None: - if model_request_parameters.allow_text_output: + if not model_request_parameters.require_tool_use: return None elif tools: return _tool_config([t['name'] for t in tools['function_declarations']]) @@ -563,7 +562,7 @@ def _content_model_response(m: ModelResponse) -> _GeminiContent: for item in m.parts: if isinstance(item, 
ToolCallPart): parts.append(_function_call_part_from_call(item)) - elif isinstance(item, (TextPart, OutputPart)): + elif isinstance(item, TextPart): if item.content: parts.append(_GeminiTextPart(text=item.content)) else: diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 28e9287b0..1b0cb4ea2 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -20,7 +20,6 @@ ModelResponse, ModelResponsePart, ModelResponseStreamEvent, - OutputPart, RetryPromptPart, SystemPromptPart, TextPart, @@ -195,7 +194,7 @@ async def _completions_create( # standalone function to make it easier to override if not tools: tool_choice: Literal['none', 'required', 'auto'] | None = None - elif not model_request_parameters.allow_text_output: + elif model_request_parameters.require_tool_use: tool_choice = 'required' else: tool_choice = 'auto' @@ -275,7 +274,7 @@ def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletio texts: list[str] = [] tool_calls: list[chat.ChatCompletionMessageToolCallParam] = [] for item in message.parts: - if isinstance(item, (TextPart, OutputPart)): + if isinstance(item, TextPart): texts.append(item.content) elif isinstance(item, ToolCallPart): tool_calls.append(self._map_tool_call(item)) diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 7a026103e..394c49fad 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -22,7 +22,6 @@ ModelResponse, ModelResponsePart, ModelResponseStreamEvent, - OutputPart, RetryPromptPart, SystemPromptPart, TextPart, @@ -284,7 +283,7 @@ def _get_tool_choice(self, model_request_parameters: ModelRequestParameters) -> """ if not model_request_parameters.function_tools and not model_request_parameters.output_tools: return None - elif not model_request_parameters.allow_text_output: + elif model_request_parameters.require_tool_use: return 'required' else: return 'auto' @@ -482,7 +481,7 @@ def _map_messages(self, messages: list[ModelMessage]) -> list[MistralMessages]: tool_calls: list[MistralToolCall] = [] for part in message.parts: - if isinstance(part, (TextPart, OutputPart)): + if isinstance(part, TextPart): content_chunks.append(MistralTextChunk(text=part.content)) elif isinstance(part, ToolCallPart): tool_calls.append(self._map_tool_call(part)) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 3778d4951..46e03381b 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -26,7 +26,6 @@ ModelResponse, ModelResponsePart, ModelResponseStreamEvent, - OutputPart, RetryPromptPart, SystemPromptPart, TextPart, @@ -269,14 +268,10 @@ async def _completions_create( openai_messages = await self._map_messages(messages) tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] - tool_choice: Literal['none', 'required', 'auto'] | NotGiven = NOT_GIVEN - response_format: chat.completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN + response_format: chat.completion_create_params.ResponseFormat | None = None if model_request_parameters.output_mode == 'tool': tools.extend(self._map_tool_definition(r) for r in model_request_parameters.output_tools) - - if not model_request_parameters.allow_text_output: - tool_choice = 'required' elif output_object := 
model_request_parameters.output_object: if model_request_parameters.output_mode == 'json_schema': response_format = self._map_output_object_definition(output_object) @@ -288,6 +283,13 @@ async def _completions_create( ), ) + if not tools: + tool_choice: Literal['none', 'required', 'auto'] | None = None + elif model_request_parameters.require_tool_use: + tool_choice = 'required' + else: + tool_choice = 'auto' + try: extra_headers = model_settings.get('extra_headers', {}) extra_headers.setdefault('User-Agent', get_user_agent()) @@ -305,7 +307,7 @@ async def _completions_create( temperature=model_settings.get('temperature', NOT_GIVEN), top_p=model_settings.get('top_p', NOT_GIVEN), timeout=model_settings.get('timeout', NOT_GIVEN), - response_format=response_format, + response_format=response_format or NOT_GIVEN, seed=model_settings.get('seed', NOT_GIVEN), presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN), frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN), @@ -349,11 +351,7 @@ def _process_response( } if choice.message.content is not None: - if model_request_parameters.output_mode in {'json_schema', 'manual_json'}: - # TODO: Strip Markdown fence and text before/after - items.append(OutputPart(choice.message.content)) - else: - items.append(TextPart(choice.message.content)) + items.append(TextPart(choice.message.content)) if choice.message.tool_calls is not None: for c in choice.message.tool_calls: items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id)) @@ -398,7 +396,7 @@ async def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCom texts: list[str] = [] tool_calls: list[chat.ChatCompletionMessageToolCallParam] = [] for item in message.parts: - if isinstance(item, (TextPart, OutputPart)): + if isinstance(item, TextPart): texts.append(item.content) elif isinstance(item, ToolCallPart): tool_calls.append(self._map_tool_call(item)) @@ -637,11 +635,7 @@ def _process_response( """Process a non-streamed response, and prepare a message to return.""" timestamp = datetime.fromtimestamp(response.created_at, tz=timezone.utc) items: list[ModelResponsePart] = [] - # TODO: Parse out manual JSON, a la split_content_into_text_and_thinking - if model_request_parameters.output_mode == 'json_schema': - items.append(OutputPart(response.output_text)) - else: - items.append(TextPart(response.output_text)) + items.append(TextPart(response.output_text)) for item in response.output: if item.type == 'function_call': items.append(ToolCallPart(item.name, item.arguments, tool_call_id=item.call_id)) @@ -694,7 +688,7 @@ async def _responses_create( # standalone function to make it easier to override if not tools: tool_choice: Literal['none', 'required', 'auto'] | None = None - elif not model_request_parameters.allow_text_output: + elif model_request_parameters.require_tool_use: tool_choice = 'required' else: tool_choice = 'auto' @@ -793,7 +787,7 @@ async def _map_messages( assert_never(part) elif isinstance(message, ModelResponse): for item in message.parts: - if isinstance(item, (TextPart, OutputPart)): + if isinstance(item, TextPart): openai_messages.append(responses.EasyInputMessageParam(role='assistant', content=item.content)) elif isinstance(item, ToolCallPart): openai_messages.append(self._map_tool_call(item)) @@ -1184,5 +1178,5 @@ def _customize_output_object_def(o: OutputObjectDefinition): if model_request_parameters.output_object else None, output_tools=[_customize_tool_def(tool) for tool in model_request_parameters.output_tools], - 
allow_text_output=model_request_parameters.allow_text_output, + require_tool_use=model_request_parameters.require_tool_use, ) diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index 01f459566..9079afa5e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -17,7 +17,6 @@ ModelResponse, ModelResponsePart, ModelResponseStreamEvent, - OutputPart, RetryPromptPart, TextPart, ToolCallPart, @@ -131,7 +130,7 @@ def _get_tool_calls(self, model_request_parameters: ModelRequestParameters) -> l def _get_output(self, model_request_parameters: ModelRequestParameters) -> _WrappedTextOutput | _WrappedToolOutput: if self.custom_output_text is not None: - assert model_request_parameters.allow_text_output, ( + assert not model_request_parameters.require_tool_use, ( 'Plain response not allowed, but `custom_output_text` is set.' ) assert self.custom_output_args is None, 'Cannot set both `custom_output_text` and `custom_output_args`.' @@ -146,7 +145,7 @@ def _get_output(self, model_request_parameters: ModelRequestParameters) -> _Wrap return _WrappedToolOutput({k: self.custom_output_args}) else: return _WrappedToolOutput(self.custom_output_args) - elif model_request_parameters.allow_text_output: + elif not model_request_parameters.require_tool_use: return _WrappedTextOutput(None) elif model_request_parameters.output_tools: return _WrappedToolOutput(None) @@ -242,7 +241,7 @@ def __post_init__(self, _messages: Iterable[ModelMessage]): async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: for i, part in enumerate(self._structured_response.parts): - if isinstance(part, (TextPart, OutputPart)): + if isinstance(part, TextPart): text = part.content *words, last_word = text.split(' ') words = [f'{word} ' for word in words] diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index f8b5658af..bb9f0026a 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -80,6 +80,7 @@ async def _validate_response( self, message: _messages.ModelResponse, output_tool_name: str | None, *, allow_partial: bool = False ) -> OutputDataT: """Validate a structured result message.""" + call = None if self._output_schema is not None and output_tool_name is not None: match = self._output_schema.find_named_tool(message.parts, output_tool_name) if match is None: @@ -89,32 +90,20 @@ async def _validate_response( call, output_tool = match result_data = output_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False) - - for validator in self._output_validators: - result_data = await validator.validate(result_data, call, self._run_ctx) - return result_data - elif ( - self._output_schema is not None - and len(message.parts) > 0 - and (part := message.parts[0]) - and isinstance(part, _messages.OutputPart) - ): - result_data = self._output_schema.validate( - part.content, allow_partial=allow_partial, wrap_validation_errors=False - ) - for validator in self._output_validators: - result_data = await validator.validate(result_data, None, self._run_ctx) - return result_data else: text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) - for validator in self._output_validators: - text = await validator.validate( - text, - None, - self._run_ctx, + + if self._output_schema is None or self._output_schema.allow_plain_text_output: + # The following cast is safe because we know `str` is an 
allowed output type + result_data = cast(OutputDataT, text) + else: + result_data = self._output_schema.validate( + text, allow_partial=allow_partial, wrap_validation_errors=False ) - # Since there is no output tool, we can assume that str is compatible with OutputDataT - return cast(OutputDataT, text) + + for validator in self._output_validators: + result_data = await validator.validate(result_data, call, self._run_ctx) + return result_data def __aiter__(self) -> AsyncIterator[AgentStreamEvent]: """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s. @@ -139,9 +128,6 @@ def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages. return _messages.FinalResultEvent( tool_name=call.tool_name, tool_call_id=call.tool_call_id ) - elif isinstance(new_part, _messages.OutputPart): - if output_schema: - return _messages.FinalResultEvent(tool_name=None, tool_call_id=None) elif _output.allow_text_output(output_schema): # pragma: no branch assert_type(e, _messages.PartStartEvent) return _messages.FinalResultEvent(tool_name=None, tool_call_id=None) @@ -330,7 +316,7 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None = Debouncing is particularly important for long structured responses to reduce the overhead of performing validation as each token is received. """ - if self._output_schema and not self._output_schema.allow_text_output: + if self._output_schema and not self._output_schema.allow_plain_text_output: raise exceptions.UserError('stream_text() can only be used with text responses') if delta: @@ -408,6 +394,7 @@ async def validate_structured_output( self, message: _messages.ModelResponse, *, allow_partial: bool = False ) -> OutputDataT: """Validate a structured result message.""" + call = None if self._output_schema is not None and self._output_tool_name is not None: match = self._output_schema.find_named_tool(message.parts, self._output_tool_name) if match is None: @@ -417,40 +404,20 @@ async def validate_structured_output( call, output_tool = match result_data = output_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False) - - for validator in self._output_validators: - result_data = await validator.validate(result_data, call, self._run_ctx) # pragma: no cover - return result_data - elif ( - self._output_schema is not None - and len(message.parts) > 0 - and (part := message.parts[0]) - and isinstance(part, _messages.OutputPart) - ): - result_data = self._output_schema.validate( - part.content, allow_partial=allow_partial, wrap_validation_errors=False - ) - for validator in self._output_validators: - result_data = await validator.validate(result_data, None, self._run_ctx) # pragma: no cover - return result_data - elif ( - self._output_schema is not None - and len(message.parts) > 0 - and (part := message.parts[0]) - and isinstance(part, _messages.OutputPart) - ): - result_data = self._output_schema.validate( - part.content, allow_partial=allow_partial, wrap_validation_errors=False - ) - for validator in self._output_validators: - result_data = await validator.validate(result_data, None, self._run_ctx) - return result_data else: text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) - for validator in self._output_validators: - text = await validator.validate(text, None, self._run_ctx) # pragma: no cover - # Since there is no output tool, we can assume that str is compatible with OutputDataT - return cast(OutputDataT, text) + + if self._output_schema is None or 
self._output_schema.allow_plain_text_output: + # The following cast is safe because we know `str` is an allowed output type + result_data = cast(OutputDataT, text) + else: + result_data = self._output_schema.validate( + text, allow_partial=allow_partial, wrap_validation_errors=False + ) + + for validator in self._output_validators: + result_data = await validator.validate(result_data, call, self._run_ctx) # pragma: no cover + return result_data async def _validate_text_output(self, text: str) -> str: for validator in self._output_validators: diff --git a/tests/models/test_fallback.py b/tests/models/test_fallback.py index 805281ced..ad6c1bab1 100644 --- a/tests/models/test_fallback.py +++ b/tests/models/test_fallback.py @@ -127,7 +127,7 @@ def test_first_failed_instrumented(capfire: CaptureLogfire) -> None: 'end_time': 3000000000, 'attributes': { 'gen_ai.operation.name': 'chat', - 'model_request_parameters': '{"function_tools": [], "preferred_output_mode": null, "allow_text_output": true, "output_tools": [], "output_object": null}', + 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "require_tool_use": false}', 'logfire.span_type': 'span', 'logfire.msg': 'chat fallback:function:failure_response:,function:success_response:', 'gen_ai.system': 'function', @@ -200,7 +200,7 @@ async def test_first_failed_instrumented_stream(capfire: CaptureLogfire) -> None 'end_time': 3000000000, 'attributes': { 'gen_ai.operation.name': 'chat', - 'model_request_parameters': '{"function_tools": [], "preferred_output_mode": null, "allow_text_output": true, "output_tools": [], "output_object": null}', + 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "require_tool_use": false}', 'logfire.span_type': 'span', 'logfire.msg': 'chat fallback:function::failure_response_stream,function::success_response_stream', 'gen_ai.system': 'function', @@ -272,7 +272,7 @@ def test_all_failed_instrumented(capfire: CaptureLogfire) -> None: 'gen_ai.operation.name': 'chat', 'gen_ai.system': 'fallback:function,function', 'gen_ai.request.model': 'fallback:function:failure_response:,function:failure_response:', - 'model_request_parameters': '{"function_tools": [], "preferred_output_mode": null, "allow_text_output": true, "output_tools": [], "output_object": null}', + 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "require_tool_use": false}', 'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}', 'logfire.span_type': 'span', 'logfire.msg': 'chat fallback:function:failure_response:,function:failure_response:', diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 84859fae8..db0ac6210 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -64,7 +64,7 @@ async def test_model_simple(allow_model_requests: None): assert 'x-goog-api-key' in m.client.headers mrp = ModelRequestParameters( - function_tools=[], allow_text_output=True, output_tools=[], output_mode=None, output_object=None + function_tools=[], require_tool_use=False, output_tools=[], output_mode=None, output_object=None ) mrp = m.customize_request_parameters(mrp) tools = m._get_tools(mrp) @@ -100,7 +100,7 @@ async def test_model_tools(allow_model_requests: None): mrp = ModelRequestParameters( function_tools=tools, - allow_text_output=True, + require_tool_use=False, 
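For orientation, a rough end-user sketch of the JSON-schema output mode these result-validation and fixture changes exercise. `ToolOutput(type_=...)` appears verbatim in the tests further below, but the `JSONSchemaOutput` constructor is not shown anywhere in this diff, so its `type_` keyword and the model name used here are assumptions, not the final API:

from pydantic import BaseModel

from pydantic_ai import Agent, JSONSchemaOutput


class CityLocation(BaseModel):
    city: str
    country: str


# Assumed keyword `type_`, mirroring ToolOutput(type_=tuple[str, str]) used in tests/test_agent.py.
agent = Agent('openai:gpt-4o', output_type=JSONSchemaOutput(type_=CityLocation))
result = agent.run_sync('What is the largest city in the user country?')
print(result.output)  # e.g. CityLocation(city='Mexico City', country='Mexico')
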
output_tools=[output_tool], output_mode=None, output_object=None, @@ -148,7 +148,7 @@ async def test_require_response_tool(allow_model_requests: None): ) mrp = ModelRequestParameters( function_tools=[], - allow_text_output=False, + require_tool_use=True, output_tools=[output_tool], output_mode=None, output_object=None, @@ -235,7 +235,7 @@ class Locations(BaseModel): ) mrp = ModelRequestParameters( function_tools=[], - allow_text_output=True, + require_tool_use=False, output_tools=[output_tool], output_mode=None, output_object=None, @@ -321,7 +321,7 @@ class QueryDetails(BaseModel): mrp = ModelRequestParameters( function_tools=[], output_mode=None, - allow_text_output=True, + require_tool_use=False, output_tools=[output_tool], output_object=None, ) @@ -368,7 +368,7 @@ class Locations(BaseModel): ) mrp = ModelRequestParameters( function_tools=[], - allow_text_output=True, + require_tool_use=False, output_tools=[output_tool], output_mode=None, output_object=None, @@ -438,7 +438,7 @@ class Location(BaseModel): with pytest.raises(UserError, match=r'Recursive `\$ref`s in JSON Schema are not supported by Gemini'): mrp = ModelRequestParameters( function_tools=[], - allow_text_output=True, + require_tool_use=False, output_tools=[output_tool], output_mode=None, output_object=None, @@ -476,7 +476,7 @@ class FormattedStringFields(BaseModel): ) mrp = ModelRequestParameters( function_tools=[], - allow_text_output=True, + require_tool_use=False, output_tools=[output_tool], output_mode=None, output_object=None, diff --git a/tests/models/test_instrumented.py b/tests/models/test_instrumented.py index 51e743975..5bb8204db 100644 --- a/tests/models/test_instrumented.py +++ b/tests/models/test_instrumented.py @@ -132,7 +132,7 @@ async def test_instrumented_model(capfire: CaptureLogfire): model_settings=ModelSettings(temperature=1), model_request_parameters=ModelRequestParameters( function_tools=[], - allow_text_output=True, + require_tool_use=False, output_tools=[], output_mode=None, output_object=None, @@ -153,7 +153,7 @@ async def test_instrumented_model(capfire: CaptureLogfire): 'gen_ai.request.model': 'my_model', 'server.address': 'example.com', 'server.port': 8000, - 'model_request_parameters': '{"function_tools": [], "preferred_output_mode": null, "allow_text_output": true, "output_tools": [], "output_object": null}', + 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "require_tool_use": false}', 'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}', 'gen_ai.request.temperature': 1, 'logfire.msg': 'chat my_model', @@ -330,7 +330,7 @@ async def test_instrumented_model_not_recording(): model_settings=ModelSettings(temperature=1), model_request_parameters=ModelRequestParameters( function_tools=[], - allow_text_output=True, + require_tool_use=False, output_tools=[], output_mode=None, output_object=None, @@ -354,7 +354,7 @@ async def test_instrumented_model_stream(capfire: CaptureLogfire): model_settings=ModelSettings(temperature=1), model_request_parameters=ModelRequestParameters( function_tools=[], - allow_text_output=True, + require_tool_use=False, output_tools=[], output_mode=None, output_object=None, @@ -381,7 +381,7 @@ async def test_instrumented_model_stream(capfire: CaptureLogfire): 'gen_ai.request.model': 'my_model', 'server.address': 'example.com', 'server.port': 8000, - 'model_request_parameters': '{"function_tools": [], "preferred_output_mode": null, "allow_text_output": true, 
"output_tools": [], "output_object": null}', + 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "require_tool_use": false}', 'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}', 'gen_ai.request.temperature': 1, 'logfire.msg': 'chat my_model', @@ -444,7 +444,7 @@ async def test_instrumented_model_stream_break(capfire: CaptureLogfire): model_settings=ModelSettings(temperature=1), model_request_parameters=ModelRequestParameters( function_tools=[], - allow_text_output=True, + require_tool_use=False, output_tools=[], output_mode=None, output_object=None, @@ -468,7 +468,7 @@ async def test_instrumented_model_stream_break(capfire: CaptureLogfire): 'gen_ai.request.model': 'my_model', 'server.address': 'example.com', 'server.port': 8000, - 'model_request_parameters': '{"function_tools": [], "preferred_output_mode": null, "allow_text_output": true, "output_tools": [], "output_object": null}', + 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "require_tool_use": false}', 'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}', 'gen_ai.request.temperature': 1, 'logfire.msg': 'chat my_model', @@ -549,7 +549,7 @@ async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire): model_settings=ModelSettings(temperature=1), model_request_parameters=ModelRequestParameters( function_tools=[], - allow_text_output=True, + require_tool_use=False, output_tools=[], output_mode=None, output_object=None, @@ -570,7 +570,7 @@ async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire): 'gen_ai.request.model': 'my_model', 'server.address': 'example.com', 'server.port': 8000, - 'model_request_parameters': '{"function_tools": [], "preferred_output_mode": null, "allow_text_output": true, "output_tools": [], "output_object": null}', + 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "require_tool_use": false}', 'gen_ai.request.temperature': 1, 'logfire.msg': 'chat my_model', 'logfire.span_type': 'span', diff --git a/tests/models/test_model_request_parameters.py b/tests/models/test_model_request_parameters.py index 008648f34..5a0918211 100644 --- a/tests/models/test_model_request_parameters.py +++ b/tests/models/test_model_request_parameters.py @@ -5,12 +5,12 @@ def test_model_request_parameters_are_serializable(): params = ModelRequestParameters( - function_tools=[], output_mode=None, allow_text_output=False, output_tools=[], output_object=None + function_tools=[], output_mode=None, require_tool_use=False, output_tools=[], output_object=None ) assert TypeAdapter(ModelRequestParameters).dump_python(params) == { 'function_tools': [], 'preferred_output_mode': None, - 'allow_text_output': False, + 'require_tool_use': False, 'output_tools': [], 'output_object': None, } diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 4816f31d2..e3a37f541 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -23,7 +23,6 @@ ImageUrl, ModelRequest, ModelResponse, - OutputPart, RetryPromptPart, SystemPromptPart, TextPart, @@ -1731,7 +1730,9 @@ async def get_user_country() -> str: ] ), ModelResponse( - parts=[ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id=IsStr())], + parts=[ + ToolCallPart(tool_name='get_user_country', args='{}', 
tool_call_id='call_PkRGedQNRFUzJp2R7dO7avWR') + ], usage=Usage( requests=1, request_tokens=71, @@ -1754,13 +1755,13 @@ async def get_user_country() -> str: ToolReturnPart( tool_name='get_user_country', content='Mexico', - tool_call_id=IsStr(), + tool_call_id='call_PkRGedQNRFUzJp2R7dO7avWR', timestamp=IsDatetime(), ) ] ), ModelResponse( - parts=[OutputPart(content='{"city":"Mexico City","country":"Mexico"}')], + parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], usage=Usage( requests=1, request_tokens=92, @@ -1810,7 +1811,9 @@ async def get_user_country() -> str: ] ), ModelResponse( - parts=[ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id=IsStr())], + parts=[ + ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_uTjt2vMkeTr0GYqQyQYrUUhl') + ], usage=Usage( requests=1, request_tokens=106, @@ -1833,13 +1836,13 @@ async def get_user_country() -> str: ToolReturnPart( tool_name='get_user_country', content='Mexico', - tool_call_id=IsStr(), + tool_call_id='call_uTjt2vMkeTr0GYqQyQYrUUhl', timestamp=IsDatetime(), ) ] ), ModelResponse( - parts=[OutputPart(content='{"city":"Mexico City","country":"Mexico"}')], + parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], usage=Usage( requests=1, request_tokens=127, diff --git a/tests/test_agent.py b/tests/test_agent.py index 5e93df4e5..70e4d3d6d 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -12,6 +12,7 @@ from pydantic_core import to_json from pydantic_ai import Agent, ModelRetry, RunContext, UnexpectedModelBehavior, UserError, capture_run_messages +from pydantic_ai._output import ToolOutput from pydantic_ai.agent import AgentRunResult from pydantic_ai.messages import ( BinaryContent, @@ -260,7 +261,7 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: args_json = '{"response": ["foo", "bar"]}' return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) - agent = Agent(FunctionModel(return_tuple), output_type=tuple[str, str]) + agent = Agent(FunctionModel(return_tuple), output_type=ToolOutput(type_=tuple[str, str])) result = agent.run_sync('Hello') assert result.output == ('foo', 'bar') @@ -352,14 +353,14 @@ def test_response_tuple(): m = TestModel() agent = Agent(m, output_type=tuple[str, str]) - assert agent._output_schema.allow_text_output is False # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess] + assert agent._output_schema.allow_plain_text_output is False # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess] result = agent.run_sync('Hello') assert result.output == snapshot(('a', 'a')) assert m.last_model_request_parameters is not None assert m.last_model_request_parameters.function_tools == snapshot([]) - assert m.last_model_request_parameters.allow_text_output is False + assert m.last_model_request_parameters.require_tool_use is True assert m.last_model_request_parameters.output_tools is not None assert len(m.last_model_request_parameters.output_tools) == 1 @@ -409,7 +410,7 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: got_tool_call_name = ctx.tool_name return o - assert agent._output_schema.allow_text_output is True # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess] + assert agent._output_schema.allow_plain_text_output is True # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess] result = agent.run_sync('Hello') assert result.output == snapshot('success (no tool calls)') @@ -417,7 +418,7 @@ def validate_output(ctx: RunContext[None], o: Any) -> 
Any: assert m.last_model_request_parameters is not None assert m.last_model_request_parameters.function_tools == snapshot([]) - assert m.last_model_request_parameters.allow_text_output is True + assert m.last_model_request_parameters.require_tool_use is False assert m.last_model_request_parameters.output_tools is not None assert len(m.last_model_request_parameters.output_tools) == 1 @@ -493,7 +494,7 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: assert m.last_model_request_parameters is not None assert m.last_model_request_parameters.function_tools == snapshot([]) - assert m.last_model_request_parameters.allow_text_output is False + assert m.last_model_request_parameters.require_tool_use is True assert m.last_model_request_parameters.output_tools is not None assert len(m.last_model_request_parameters.output_tools) == 2 diff --git a/tests/test_logfire.py b/tests/test_logfire.py index 57ee697e6..eff2c62ba 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -223,10 +223,10 @@ async def my_ret(x: int) -> str: 'strict': None, } ], - 'preferred_output_mode': None, - 'allow_text_output': True, + 'output_mode': None, 'output_tools': [], 'output_object': None, + 'require_tool_use': False, } ) ), @@ -406,7 +406,7 @@ async def test_feedback(capfire: CaptureLogfire) -> None: 'gen_ai.operation.name': 'chat', 'gen_ai.system': 'test', 'gen_ai.request.model': 'test', - 'model_request_parameters': '{"function_tools": [], "preferred_output_mode": null, "allow_text_output": true, "output_tools": [], "output_object": null}', + 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "require_tool_use": false}', 'logfire.span_type': 'span', 'logfire.msg': 'chat test', 'gen_ai.usage.input_tokens': 51,
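Condensed into runnable form, the `require_tool_use` semantics asserted in the test_agent.py changes above amount to the following sketch, using `TestModel` the same way the tests do (illustrative only):

from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel

m = TestModel()

# Plain-text output is allowed, so the model is not forced to call a tool.
Agent(m, output_type=str).run_sync('Hello')
assert m.last_model_request_parameters is not None
assert m.last_model_request_parameters.require_tool_use is False

# Structured-only output: the model must answer via an output tool.
Agent(m, output_type=tuple[str, str]).run_sync('Hello')
assert m.last_model_request_parameters.require_tool_use is True
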