From 9c64d6854add24227456c2f2118726a30e7b1dbb Mon Sep 17 00:00:00 2001
From: David Montague <35119617+dmontagu@users.noreply.github.com>
Date: Sat, 22 Feb 2025 09:52:34 -0700
Subject: [PATCH] WIP: Add support for using response_format to force a
 particular JSON schema for the response

---
 .../pydantic_ai_examples/flight_booking.py    |  4 +-
 examples/pydantic_ai_examples/rag.py          | 12 ++--
 pydantic_ai_slim/pydantic_ai/_agent_graph.py  | 43 +++++++++++++
 pydantic_ai_slim/pydantic_ai/_result.py       |  1 +
 pydantic_ai_slim/pydantic_ai/messages.py      | 48 +++++++++++++-
 pydantic_ai_slim/pydantic_ai/models/cohere.py | 34 +++++++++-
 .../pydantic_ai/models/function.py            |  6 +-
 pydantic_ai_slim/pydantic_ai/models/gemini.py | 32 ++++++++--
 pydantic_ai_slim/pydantic_ai/models/groq.py   |  1 +
 pydantic_ai_slim/pydantic_ai/models/openai.py | 64 +++++++++++++++----
 pydantic_ai_slim/pydantic_ai/models/test.py   | 12 ++--
 pydantic_ai_slim/pydantic_ai/result.py        |  2 +-
 pydantic_ai_slim/pydantic_ai/settings.py      | 19 ++++++
 13 files changed, 236 insertions(+), 42 deletions(-)

diff --git a/examples/pydantic_ai_examples/flight_booking.py b/examples/pydantic_ai_examples/flight_booking.py
index 8935d711d..209e2adfd 100644
--- a/examples/pydantic_ai_examples/flight_booking.py
+++ b/examples/pydantic_ai_examples/flight_booking.py
@@ -105,7 +105,9 @@ class Failed(BaseModel):
 
 
 # This agent is responsible for extracting the user's seat selection
-seat_preference_agent = Agent[None, SeatPreference | Failed](
+seat_preference_agent = Agent[
+    None, SeatPreference | Failed
+](
     'openai:gpt-4o',
     result_type=SeatPreference | Failed,  # type: ignore
     system_prompt=(
diff --git a/examples/pydantic_ai_examples/rag.py b/examples/pydantic_ai_examples/rag.py
index b7dd4c4b9..0ad864331 100644
--- a/examples/pydantic_ai_examples/rag.py
+++ b/examples/pydantic_ai_examples/rag.py
@@ -67,9 +67,9 @@ async def retrieve(context: RunContext[Deps], search_query: str) -> str:
             model='text-embedding-3-small',
         )
 
-        assert len(embedding.data) == 1, (
-            f'Expected 1 embedding, got {len(embedding.data)}, doc query: {search_query!r}'
-        )
+        assert (
+            len(embedding.data) == 1
+        ), f'Expected 1 embedding, got {len(embedding.data)}, doc query: {search_query!r}'
         embedding = embedding.data[0].embedding
         embedding_json = pydantic_core.to_json(embedding).decode()
         rows = await context.deps.pool.fetch(
@@ -149,9 +149,9 @@ async def insert_doc_section(
             input=section.embedding_content(),
             model='text-embedding-3-small',
         )
-        assert len(embedding.data) == 1, (
-            f'Expected 1 embedding, got {len(embedding.data)}, doc section: {section}'
-        )
+        assert (
+            len(embedding.data) == 1
+        ), f'Expected 1 embedding, got {len(embedding.data)}, doc section: {section}'
         embedding = embedding.data[0].embedding
         embedding_json = pydantic_core.to_json(embedding).decode()
         await pool.execute(
diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
index b080acfc0..a9f303449 100644
--- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py
+++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -10,6 +10,7 @@
 from typing import Any, Generic, Literal, Union, cast
 
 import logfire_api
+from pydantic import ValidationError
 from typing_extensions import TypeVar, assert_never
 
 from pydantic_graph import BaseNode, Graph, GraphRunContext
@@ -370,12 +371,16 @@ async def _run_stream(
 
         async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
             texts: list[str] = []
+            structured_outputs: list[str] = []
             tool_calls: list[_messages.ToolCallPart] = []
             for part in self.model_response.parts:
                 if isinstance(part, _messages.TextPart):
                     # ignore empty content for text parts, see #437
                     if part.content:
                         texts.append(part.content)
+                elif isinstance(part, _messages.StructuredOutputPart):
+                    if part.content:
+                        structured_outputs.append(part.content)
                 elif isinstance(part, _messages.ToolCallPart):
                     tool_calls.append(part)
                 else:
@@ -391,6 +396,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
             elif texts:
                 # No events are emitted during the handling of text responses, so we don't need to yield anything
                 self._next_node = await self._handle_text_response(ctx, texts)
+            elif structured_outputs:
+                # No events are emitted during the handling of structured outputs, so we don't need to yield anything
+                self._next_node = await self._handle_structured_outputs_response(ctx, structured_outputs)
             else:
                 raise exceptions.UnexpectedModelBehavior('Received empty model response')
 
@@ -487,6 +495,41 @@ async def _handle_text_response(
                 )
             )
 
+    async def _handle_structured_outputs_response(
+        self,
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
+        structured_outputs: list[str],
+    ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
+        if len(structured_outputs) != 1:
+            raise exceptions.UnexpectedModelBehavior('Received multiple structured outputs in a single response')
+        result_schema = ctx.deps.result_schema
+        if not result_schema:
+            raise exceptions.UnexpectedModelBehavior('Must specify a non-str result_type when using structured outputs')
+
+        structured_output = structured_outputs[0]
+        try:
+            result_data_input = result_schema.structured_output_validator.validate_json(structured_output)
+        except ValidationError as e:
+            ctx.state.increment_retries(ctx.deps.max_result_retries)
+            return ModelRequestNode[DepsT, NodeRunEndT](
+                _messages.ModelRequest(
+                    parts=[
+                        _messages.RetryPromptPart(
+                            content='Structured output validation failed: ' + str(e),
+                        )
+                    ]
+                )
+            )
+
+        try:
+            result_data = await _validate_result(result_data_input, ctx, None)
+        except _result.ToolRetryError as e:
+            ctx.state.increment_retries(ctx.deps.max_result_retries)
+            return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
+        else:
+            # Structured outputs don't come from a result tool, so there is no tool_name to record
+            return self._handle_final_result(ctx, result.FinalResult(result_data, tool_name=None), [])
+
 
 def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
     """Build a `RunContext` object from the current agent graph run context."""
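Review note (not part of the patch): the behaviour of `_handle_structured_outputs_response` boils down to
'validate the raw JSON against the result type; on failure, send the model a retry prompt'. A minimal
self-contained sketch of that flow, using an illustrative `Box` result type in place of the real agent
graph machinery:

    from pydantic import BaseModel, TypeAdapter, ValidationError


    class Box(BaseModel):
        width: int
        height: int


    validator = TypeAdapter(Box)  # plays the role of structured_output_validator


    def handle_structured_output(raw_json: str) -> Box | str:
        """Return the parsed result, or the retry-prompt text on validation failure."""
        try:
            return validator.validate_json(raw_json)
        except ValidationError as e:
            return 'Structured output validation failed: ' + str(e)


    print(handle_structured_output('{"width": 10, "height": 20}'))  # Box(width=10, height=20)
    print(handle_structured_output('{"width": "wide"}'))  # retry-prompt text
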
""" + structured_output_validator: TypeAdapter[ResultDataT] tools: dict[str, ResultTool[ResultDataT]] allow_text_result: bool diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index c6775c838..805001f9e 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -217,7 +217,22 @@ def has_content(self) -> bool: return bool(self.args) -ModelResponsePart = Annotated[Union[TextPart, ToolCallPart], pydantic.Discriminator('part_kind')] +@dataclass +class StructuredOutputPart: + """A structured output response from a model.""" + + content: str + """The structured content of the response as a JSON-serialized string.""" + + part_kind: Literal['structured-output'] = 'structured-output' + """Part type identifier, this is available on all parts as a discriminator.""" + + def has_content(self) -> bool: + """Return `True` if the structured content is non-empty.""" + return bool(self.content) + + +ModelResponsePart = Annotated[Union[TextPart, ToolCallPart, StructuredOutputPart], pydantic.Discriminator('part_kind')] """A message part returned by a model.""" @@ -275,6 +290,33 @@ def apply(self, part: ModelResponsePart) -> TextPart: return replace(part, content=part.content + self.content_delta) +@dataclass +class StructuredOutputPartDelta: + """A partial update (delta) for a `StructuredOutputPart` to append new text content.""" + + content_delta: str + """The incremental text content to add to the existing `StructuredOutputPart` content.""" + + part_delta_kind: Literal['structured-output'] = 'structured-output' + """Part delta type identifier, used as a discriminator.""" + + def apply(self, part: ModelResponsePart) -> StructuredOutputPart: + """Apply this text delta to an existing `TextPart`. + + Args: + part: The existing model response part, which must be a `TextPart`. + + Returns: + A new `TextPart` with updated text content. + + Raises: + ValueError: If `part` is not a `TextPart`. 
+ """ + if not isinstance(part, StructuredOutputPart): + raise ValueError('Cannot apply TextPartDeltas to non-TextParts') + return replace(part, content=part.content + self.content_delta) + + @dataclass class ToolCallPartDelta: """A partial update (delta) for a `ToolCallPart` to modify tool name, arguments, or tool call ID.""" @@ -408,7 +450,9 @@ def _apply_to_part(self, part: ToolCallPart) -> ToolCallPart: return part -ModelResponsePartDelta = Annotated[Union[TextPartDelta, ToolCallPartDelta], pydantic.Discriminator('part_delta_kind')] +ModelResponsePartDelta = Annotated[ + Union[TextPartDelta, ToolCallPartDelta, StructuredOutputPartDelta], pydantic.Discriminator('part_delta_kind') +] """A partial update (delta) for any model response part.""" diff --git a/pydantic_ai_slim/pydantic_ai/models/cohere.py b/pydantic_ai_slim/pydantic_ai/models/cohere.py index 58f58627d..2d368b661 100644 --- a/pydantic_ai_slim/pydantic_ai/models/cohere.py +++ b/pydantic_ai_slim/pydantic_ai/models/cohere.py @@ -5,7 +5,6 @@ from itertools import chain from typing import Literal, Union, cast -from cohere import TextAssistantMessageContentItem from httpx import AsyncClient as AsyncHTTPClient from typing_extensions import assert_never @@ -17,6 +16,7 @@ ModelResponse, ModelResponsePart, RetryPromptPart, + StructuredOutputPart, SystemPromptPart, TextPart, ToolCallPart, @@ -37,7 +37,11 @@ AsyncClientV2, ChatMessageV2, ChatResponse, + JsonObjectResponseFormatV2, + ResponseFormatV2, SystemChatMessageV2, + TextAssistantMessageContentItem, + TextResponseFormatV2, ToolCallV2, ToolCallV2Function, ToolChatMessageV2, @@ -152,7 +156,30 @@ async def _chat( model_settings: CohereModelSettings, model_request_parameters: ModelRequestParameters, ) -> ChatResponse: - tools = self._get_tools(model_request_parameters) + if model_settings.get('force_response_format', False): + tools: list[ToolV2] = OMIT + response_format: ResponseFormatV2 + if (n_result_tools := len(model_request_parameters.result_tools)) == 0: + response_format = TextResponseFormatV2() + elif n_result_tools == 1 and not model_request_parameters.allow_text_result: + result_tool = model_request_parameters.result_tools[0] + response_format = JsonObjectResponseFormatV2( + type='json_object', + json_schema=result_tool.parameters_json_schema, + ) + else: + json_schemas = [t.parameters_json_schema for t in model_request_parameters.result_tools] + if model_request_parameters.allow_text_result: + json_schemas.append({'type': 'string'}) + response_format = JsonObjectResponseFormatV2( + type='json_object', + json_schema={'anyOf': json_schemas}, + ) + else: + # standalone function to make it easier to override + tools = self._get_tools(model_request_parameters) + response_format = OMIT + cohere_messages = list(chain(*(self._map_message(m) for m in messages))) return await self.client.chat( model=self._model_name, @@ -162,6 +189,7 @@ async def _chat( temperature=model_settings.get('temperature', OMIT), p=model_settings.get('top_p', OMIT), seed=model_settings.get('seed', OMIT), + response_format=response_format, presence_penalty=model_settings.get('presence_penalty', OMIT), frequency_penalty=model_settings.get('frequency_penalty', OMIT), ) @@ -193,7 +221,7 @@ def _map_message(self, message: ModelMessage) -> Iterable[ChatMessageV2]: texts: list[str] = [] tool_calls: list[ToolCallV2] = [] for item in message.parts: - if isinstance(item, TextPart): + if isinstance(item, (TextPart, StructuredOutputPart)): texts.append(item.content) elif isinstance(item, ToolCallPart): 
diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py
index 7cd81f2b7..9cecfe709 100644
--- a/pydantic_ai_slim/pydantic_ai/models/function.py
+++ b/pydantic_ai_slim/pydantic_ai/models/function.py
@@ -109,9 +109,9 @@ async def request_stream(
             model_settings,
         )
 
-        assert self.stream_function is not None, (
-            'FunctionModel must receive a `stream_function` to support streamed requests'
-        )
+        assert (
+            self.stream_function is not None
+        ), 'FunctionModel must receive a `stream_function` to support streamed requests'
 
         response_stream = PeekableAsyncStream(self.stream_function(messages, agent_info))
 
diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py
index 8f869690d..ee7e27de4 100644
--- a/pydantic_ai_slim/pydantic_ai/models/gemini.py
+++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -183,19 +183,32 @@ async def _make_request(
         model_settings: GeminiModelSettings,
         model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[HTTPResponse]:
-        tools = self._get_tools(model_request_parameters)
-        tool_config = self._get_tool_config(model_request_parameters, tools)
         sys_prompt_parts, contents = self._message_to_gemini_content(messages)
-
         request_data = _GeminiRequest(contents=contents)
         if sys_prompt_parts:
             request_data['system_instruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts)
-        if tools is not None:
-            request_data['tools'] = tools
-        if tool_config is not None:
-            request_data['tool_config'] = tool_config
 
         generation_config: _GeminiGenerationConfig = {}
+        if model_settings.get('force_response_format', False):
+            if (n_result_tools := len(model_request_parameters.result_tools)) == 0:
+                generation_config['response_mime_type'] = 'text/plain'
+            elif n_result_tools == 1 and not model_request_parameters.allow_text_result:
+                generation_config['response_mime_type'] = 'application/json'
+                generation_config['response_schema'] = model_request_parameters.result_tools[0].parameters_json_schema
+            else:
+                json_schemas = [t.parameters_json_schema for t in model_request_parameters.result_tools]
+                if model_request_parameters.allow_text_result:
+                    json_schemas.append({'type': 'string'})
+                generation_config['response_mime_type'] = 'application/json'
+                generation_config['response_schema'] = {'anyOf': json_schemas}
+        else:
+            tools = self._get_tools(model_request_parameters)
+            tool_config = self._get_tool_config(model_request_parameters, tools)
+            if tools is not None:
+                request_data['tools'] = tools
+            if tool_config is not None:
+                request_data['tool_config'] = tool_config
+
         if model_settings:
             if (max_tokens := model_settings.get('max_tokens')) is not None:
                 generation_config['max_output_tokens'] = max_tokens
@@ -203,6 +215,8 @@ async def _make_request(
                 generation_config['temperature'] = temperature
             if (top_p := model_settings.get('top_p')) is not None:
                 generation_config['top_p'] = top_p
+            if (seed := model_settings.get('seed')) is not None:
+                generation_config['seed'] = seed
             if (presence_penalty := model_settings.get('presence_penalty')) is not None:
                 generation_config['presence_penalty'] = presence_penalty
             if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
@@ -465,9 +479,12 @@ class _GeminiGenerationConfig(TypedDict, total=False):
 
     See for API docs.
     """
 
+    response_mime_type: Literal['text/plain', 'application/json']
+    response_schema: dict[str, Any]
     max_output_tokens: int
     temperature: float
     top_p: float
+    seed: int
     presence_penalty: float
     frequency_penalty: float
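Review note (not part of the patch): for the single-result-tool case, the Gemini branch above amounts to
sending a generation config like the following. This is an illustrative payload for a hypothetical `Box`
result type, and it assumes the TypedDict keys are serialized in the camelCase form the Gemini REST API
expects:

    # What the forced-response-format branch puts on the wire, conceptually:
    generation_config = {
        'responseMimeType': 'application/json',
        'responseSchema': {
            'type': 'object',
            'properties': {
                'width': {'type': 'integer'},
                'height': {'type': 'integer'},
            },
            'required': ['width', 'height'],
        },
    }
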
""" + response_mimetype: Literal['text/plain', 'application/json'] + response_schema: dict[str, Any] max_output_tokens: int temperature: float top_p: float + seed: int presence_penalty: float frequency_penalty: float diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index b5829627c..38db55e54 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -202,6 +202,7 @@ async def _completions_create( tools=tools or NOT_GIVEN, tool_choice=tool_choice or NOT_GIVEN, stream=stream, + response_format=response_format, max_tokens=model_settings.get('max_tokens', NOT_GIVEN), temperature=model_settings.get('temperature', NOT_GIVEN), top_p=model_settings.get('top_p', NOT_GIVEN), diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 15dc93438..383d1b702 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -20,6 +20,7 @@ ModelResponsePart, ModelResponseStreamEvent, RetryPromptPart, + StructuredOutputPart, SystemPromptPart, TextPart, ToolCallPart, @@ -37,7 +38,7 @@ ) try: - from openai import NOT_GIVEN, AsyncOpenAI, AsyncStream + from openai import NOT_GIVEN, AsyncOpenAI, AsyncStream, NotGiven from openai.types import ChatModel, chat from openai.types.chat import ChatCompletionChunk except ImportError as _import_error: @@ -145,7 +146,9 @@ async def request( response = await self._completions_create( messages, False, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters ) - return self._process_response(response), _map_usage(response) + return self._process_response(response, model_settings.get('force_response_format', False)), _map_usage( + response + ) @asynccontextmanager async def request_stream( @@ -198,18 +201,49 @@ async def _completions_create( model_settings: OpenAIModelSettings, model_request_parameters: ModelRequestParameters, ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]: - tools = self._get_tools(model_request_parameters) - - # standalone function to make it easier to override - if not tools: - tool_choice: Literal['none', 'required', 'auto'] | None = None - elif not model_request_parameters.allow_text_result: - tool_choice = 'required' + if model_settings.get('force_response_format', False): + tools: list[chat.ChatCompletionToolParam] | NotGiven = NOT_GIVEN + tool_choice: Literal['none', 'required', 'auto'] | None | NotGiven = NOT_GIVEN + response_format: chat.completion_create_params.ResponseFormat | NotGiven + + if (n_result_tools := len(model_request_parameters.result_tools)) == 0: + response_format = chat.completion_create_params.ResponseFormatText(type='text') # pyright: ignore[reportPrivateImportUsage] + elif n_result_tools == 1 and not model_request_parameters.allow_text_result: + result_tool = model_request_parameters.result_tools[0] + response_format = chat.completion_create_params.ResponseFormatJSONSchema( # pyright: ignore[reportPrivateImportUsage] + type='json_schema', + json_schema={ + 'name': result_tool.name, + 'description': result_tool.description, + 'schema': result_tool.parameters_json_schema, + 'strict': False, # TODO: Expose this via a model setting? 
+ }, + ) + else: + json_schemas = [t.parameters_json_schema for t in model_request_parameters.result_tools] + if model_request_parameters.allow_text_result: + json_schemas.append({'type': 'string'}) + response_format = chat.completion_create_params.ResponseFormatJSONSchema( # pyright: ignore[reportPrivateImportUsage] + type='json_schema', + json_schema={ + 'name': 'final_result', + 'description': 'The final result of the model', + 'schema': {'anyOf': json_schemas}, + 'strict': False, + }, + ) else: - tool_choice = 'auto' + # standalone function to make it easier to override + tools = self._get_tools(model_request_parameters) + if not tools: + tool_choice = None + elif not model_request_parameters.allow_text_result: + tool_choice = 'required' + else: + tool_choice = 'auto' + response_format = NOT_GIVEN openai_messages = list(chain(*(self._map_message(m) for m in messages))) - return await self.client.chat.completions.create( model=self._model_name, messages=openai_messages, @@ -223,6 +257,7 @@ async def _completions_create( temperature=model_settings.get('temperature', NOT_GIVEN), top_p=model_settings.get('top_p', NOT_GIVEN), timeout=model_settings.get('timeout', NOT_GIVEN), + response_format=response_format, seed=model_settings.get('seed', NOT_GIVEN), presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN), frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN), @@ -230,13 +265,16 @@ async def _completions_create( reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN), ) - def _process_response(self, response: chat.ChatCompletion) -> ModelResponse: + def _process_response(self, response: chat.ChatCompletion, force_response_format: bool) -> ModelResponse: """Process a non-streamed response, and prepare a message to return.""" timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc) choice = response.choices[0] items: list[ModelResponsePart] = [] if choice.message.content is not None: - items.append(TextPart(choice.message.content)) + if force_response_format: + items.append(StructuredOutputPart(choice.message.content)) + else: + items.append(TextPart(choice.message.content)) if choice.message.tool_calls is not None: for c in choice.message.tool_calls: items.append(ToolCallPart(c.function.name, c.function.arguments, c.id)) diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index 6890181b4..f1ddd0d61 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -130,15 +130,15 @@ def _get_tool_calls(self, model_request_parameters: ModelRequestParameters) -> l def _get_result(self, model_request_parameters: ModelRequestParameters) -> _TextResult | _FunctionToolResult: if self.custom_result_text is not None: - assert model_request_parameters.allow_text_result, ( - 'Plain response not allowed, but `custom_result_text` is set.' - ) + assert ( + model_request_parameters.allow_text_result + ), 'Plain response not allowed, but `custom_result_text` is set.' assert self.custom_result_args is None, 'Cannot set both `custom_result_text` and `custom_result_args`.' return _TextResult(self.custom_result_text) elif self.custom_result_args is not None: - assert model_request_parameters.result_tools is not None, ( - 'No result tools provided, but `custom_result_args` is set.' - ) + assert ( + model_request_parameters.result_tools is not None + ), 'No result tools provided, but `custom_result_args` is set.' 
diff --git a/pydantic_ai_slim/pydantic_ai/settings.py b/pydantic_ai_slim/pydantic_ai/settings.py
index 1380f83d3..a0087a306 100644
--- a/pydantic_ai_slim/pydantic_ai/settings.py
+++ b/pydantic_ai_slim/pydantic_ai/settings.py
@@ -92,6 +92,7 @@ class ModelSettings(TypedDict, total=False):
     Supported by:
 
     * OpenAI
+    * Gemini
     * Groq
     * Cohere
     * Mistral
@@ -130,6 +131,24 @@ class ModelSettings(TypedDict, total=False):
     * Groq
     """
 
+    force_response_format: bool
+    """Whether to force the model to return its final result as structured output matching the result schema.
+
+    When enabled, the result type's JSON schema is passed via the provider's native structured-output
+    support (e.g. `response_format`) instead of being exposed as a result tool. This tends to work better
+    than tool calling with many "dumber" models, but it forces the model to generate structured output
+    directly, so the agent cannot make use of data-retrieval tool calls before generating a final response.
+    It is therefore best set on a model you know will be used by agents in a way that doesn't require tool calls.
+
+    Supported by:
+
+    * Cohere
+    * Gemini
+    * Groq
+    * OpenAI
+    """
+    # TODO: I think Mistral should support this too, but need to confirm; that model is implemented quite differently
+
 
 def merge_model_settings(base: ModelSettings | None, overrides: ModelSettings | None) -> ModelSettings | None:
     """Merge two sets of model settings, preferring the overrides."""
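Review note (not part of the patch): a sketch of how the new setting would be used from agent code, assuming
this branch is installed; the result type and prompt are illustrative:

    from pydantic import BaseModel

    from pydantic_ai import Agent


    class CityInfo(BaseModel):
        city: str
        country: str


    agent = Agent(
        'openai:gpt-4o',
        result_type=CityInfo,
        # Constrain the response to CityInfo's JSON schema via response_format
        # instead of routing the final result through a tool call.
        model_settings={'force_response_format': True},
    )

    result = agent.run_sync('What is the windiest capital city in the world?')
    print(result.data)  # e.g. CityInfo(city='Wellington', country='New Zealand')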