diff --git a/docs/api/models/base.md b/docs/api/models/base.md index bf72de7e6..24fcb9bb8 100644 --- a/docs/api/models/base.md +++ b/docs/api/models/base.md @@ -5,7 +5,6 @@ members: - KnownModelName - Model - - AgentModel - AbstractToolDefinition - StreamedResponse - ALLOW_MODEL_REQUESTS diff --git a/docs/api/models/vertexai.md b/docs/api/models/vertexai.md index d59968c79..0c4d48f0c 100644 --- a/docs/api/models/vertexai.md +++ b/docs/api/models/vertexai.md @@ -2,8 +2,8 @@ Custom interface to the `*-aiplatform.googleapis.com` API for Gemini models. -This model uses [`GeminiAgentModel`][pydantic_ai.models.gemini.GeminiAgentModel] with just the URL and auth method -changed from [`GeminiModel`][pydantic_ai.models.gemini.GeminiModel], it relies on the VertexAI +This model inherits from [`GeminiModel`][pydantic_ai.models.gemini.GeminiModel] with just the URL and auth method +changed; it relies on the VertexAI [`generateContent`](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/generateContent) and [`streamGenerateContent`](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/streamGenerateContent) diff --git a/docs/models.md b/docs/models.md index 76fc53dae..945043491 100644 --- a/docs/models.md +++ b/docs/models.md @@ -515,9 +515,8 @@ agent = Agent(model) To implement support for models not already supported, you will need to subclass the [`Model`][pydantic_ai.models.Model] abstract base class. -This in turn will require you to implement the following other abstract base classes: +For streaming, you'll also need to implement the following abstract base class: -* [`AgentModel`][pydantic_ai.models.AgentModel] * [`StreamedResponse`][pydantic_ai.models.StreamedResponse] The best place to start is to review the source code for existing implementations, e.g. [`OpenAIModel`](https://github.com/pydantic/pydantic-ai/blob/main/pydantic_ai_slim/pydantic_ai/models/openai.py).
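To make the custom-model docs change above concrete, here is a minimal sketch of what such a subclass could look like once this refactor lands. It is not part of the patch: it assumes the `AgentRequestConfig` dataclass and the three-argument `request()` signature introduced later in this diff, plus the existing `pydantic_ai` message and usage types; the `EchoModel` class and its canned reply are purely illustrative.

```python
from __future__ import annotations

from pydantic_ai import Agent, usage
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models import AgentRequestConfig, Model
from pydantic_ai.settings import ModelSettings


class EchoModel(Model):
    """Toy model that ignores tools and always answers with fixed text."""

    def name(self) -> str:
        return 'echo'

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        agent_request_config: AgentRequestConfig,
    ) -> tuple[ModelResponse, usage.Usage]:
        # A real implementation would map agent_request_config.function_tools and
        # agent_request_config.result_tools to the provider's tool format, and force
        # a tool call when agent_request_config.allow_text_result is False.
        part = TextPart('hello from EchoModel')
        return ModelResponse(parts=[part], model_name=self.name()), usage.Usage()


agent = Agent(EchoModel())
```

Streaming support would additionally mean overriding `request_stream()` and yielding a `StreamedResponse` implementation, as the updated `docs/models.md` text notes.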
diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index d48ecde20..a05d8fa9b 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -296,10 +296,12 @@ async def main(): run_context.run_step += 1 with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step): - agent_model = await self._prepare_model(run_context, result_schema) + agent_request_config = await self._prepare_agent_request_config(run_context, result_schema) with _logfire.span('model request', run_step=run_context.run_step) as model_req_span: - model_response, request_usage = await agent_model.request(messages, model_settings) + model_response, request_usage = await model_used.request( + messages, model_settings, agent_request_config + ) model_req_span.set_attribute('response', model_response) model_req_span.set_attribute('usage', request_usage) @@ -527,10 +529,12 @@ async def main(): usage_limits.check_before_request(run_context.usage) with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step): - agent_model = await self._prepare_model(run_context, result_schema) + agent_request_config = await self._prepare_agent_request_config(run_context, result_schema) with _logfire.span('model request {run_step=}', run_step=run_context.run_step) as model_req_span: - async with agent_model.request_stream(messages, model_settings) as model_response: + async with model_used.request_stream( + messages, model_settings, agent_request_config + ) as model_response: run_context.usage.requests += 1 model_req_span.set_attribute('response_type', model_response.__class__.__name__) # We want to end the "model request" span here, but we can't exit the context manager @@ -998,9 +1002,9 @@ async def _get_model(self, model: models.Model | models.KnownModelName | None) - return model_ - async def _prepare_model( + async def _prepare_agent_request_config( self, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[RunResultDataT] | None - ) -> models.AgentModel: + ) -> models.AgentRequestConfig: - """Build tools and create an agent model.""" + """Build the tool definitions and request configuration for this step of the agent run.""" function_tools: list[ToolDefinition] = [] @@ -1011,7 +1015,7 @@ async def add_tool(tool: Tool[AgentDepsT]) -> None: await asyncio.gather(*map(add_tool, self._function_tools.values())) - return await run_context.model.agent_model( + return models.AgentRequestConfig( function_tools=function_tools, allow_text_result=self._allow_text_result(result_schema), result_tools=result_schema.tool_defs() if result_schema is not None else [], diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index d974c7b0b..8d1851d20 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -171,49 +171,36 @@ """ +@dataclass +class AgentRequestConfig: + function_tools: list[ToolDefinition] + allow_text_result: bool + result_tools: list[ToolDefinition] + + class Model(ABC): """Abstract class for a model.""" - @abstractmethod - async def agent_model( - self, - *, - function_tools: list[ToolDefinition], - allow_text_result: bool, - result_tools: list[ToolDefinition], - ) -> AgentModel: - """Create an agent model, this is called for each step of an agent run. - - This is async in case slow/async config checks need to be performed that can't be done in `__init__`. - - Args: - function_tools: The tools available to the agent.
- allow_text_result: Whether a plain text final response/result is permitted. - result_tools: Tool definitions for the final result tool(s), if any. - - Returns: - An agent model. - """ - raise NotImplementedError() - @abstractmethod def name(self) -> str: raise NotImplementedError() - -class AgentModel(ABC): - """Model configured for each step of an Agent run.""" - @abstractmethod async def request( - self, messages: list[ModelMessage], model_settings: ModelSettings | None + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + agent_request_config: AgentRequestConfig, ) -> tuple[ModelResponse, Usage]: """Make a request to the model.""" raise NotImplementedError() @asynccontextmanager async def request_stream( - self, messages: list[ModelMessage], model_settings: ModelSettings | None + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + agent_request_config: AgentRequestConfig, ) -> AsyncIterator[StreamedResponse]: """Make a request to the model and return a streaming response.""" # This method is not required, but you need to implement it if you want to support streamed responses diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index e1a482f87..e7d164843 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -28,7 +28,7 @@ from ..settings import ModelSettings from ..tools import ToolDefinition from . import ( - AgentModel, + AgentRequestConfig, Model, StreamedResponse, cached_async_http_client, @@ -134,24 +134,6 @@ def __init__( else: self.client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client()) - async def agent_model( - self, - *, - function_tools: list[ToolDefinition], - allow_text_result: bool, - result_tools: list[ToolDefinition], - ) -> AgentModel: - check_allow_model_requests() - tools = [self._map_tool_definition(r) for r in function_tools] - if result_tools: - tools += [self._map_tool_definition(r) for r in result_tools] - return AnthropicAgentModel( - self.client, - self.model_name, - allow_text_result, - tools, - ) - def name(self) -> str: return f'anthropic:{self.model_name}' @@ -163,52 +145,72 @@ def _map_tool_definition(f: ToolDefinition) -> ToolParam: 'input_schema': f.parameters_json_schema, } - -@dataclass -class AnthropicAgentModel(AgentModel): - """Implementation of `AgentModel` for Anthropic models.""" - - client: AsyncAnthropic - model_name: AnthropicModelName - allow_text_result: bool - tools: list[ToolParam] + def _get_tools(self, agent_request_config: AgentRequestConfig) -> list[ToolParam]: + tools = [self._map_tool_definition(r) for r in agent_request_config.function_tools] + if agent_request_config.result_tools: + tools += [self._map_tool_definition(r) for r in agent_request_config.result_tools] + return tools async def request( - self, messages: list[ModelMessage], model_settings: ModelSettings | None + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + agent_request_config: AgentRequestConfig, ) -> tuple[ModelResponse, usage.Usage]: - response = await self._messages_create(messages, False, cast(AnthropicModelSettings, model_settings or {})) + check_allow_model_requests() + response = await self._messages_create( + messages, False, cast(AnthropicModelSettings, model_settings or {}), agent_request_config + ) return self._process_response(response), _map_usage(response) @asynccontextmanager async def request_stream( - self, 
messages: list[ModelMessage], model_settings: ModelSettings | None + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + agent_request_config: AgentRequestConfig, ) -> AsyncIterator[StreamedResponse]: - response = await self._messages_create(messages, True, cast(AnthropicModelSettings, model_settings or {})) + response = await self._messages_create( + messages, True, cast(AnthropicModelSettings, model_settings or {}), agent_request_config + ) async with response: yield await self._process_streamed_response(response) @overload async def _messages_create( - self, messages: list[ModelMessage], stream: Literal[True], model_settings: AnthropicModelSettings + self, + messages: list[ModelMessage], + stream: Literal[True], + model_settings: AnthropicModelSettings, + agent_request_config: AgentRequestConfig, ) -> AsyncStream[RawMessageStreamEvent]: pass @overload async def _messages_create( - self, messages: list[ModelMessage], stream: Literal[False], model_settings: AnthropicModelSettings + self, + messages: list[ModelMessage], + stream: Literal[False], + model_settings: AnthropicModelSettings, + agent_request_config: AgentRequestConfig, ) -> AnthropicMessage: pass async def _messages_create( - self, messages: list[ModelMessage], stream: bool, model_settings: AnthropicModelSettings + self, + messages: list[ModelMessage], + stream: bool, + model_settings: AnthropicModelSettings, + agent_request_config: AgentRequestConfig, ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]: # standalone function to make it easier to override + tools = self._get_tools(agent_request_config) tool_choice: ToolChoiceParam | None - if not self.tools: + if not tools: tool_choice = None else: - if not self.allow_text_result: + if not agent_request_config.allow_text_result: tool_choice = {'type': 'any'} else: tool_choice = {'type': 'auto'} @@ -223,7 +225,7 @@ async def _messages_create( system=system_prompt or NOT_GIVEN, messages=anthropic_messages, model=self.model_name, - tools=self.tools or NOT_GIVEN, + tools=tools or NOT_GIVEN, tool_choice=tool_choice or NOT_GIVEN, stream=stream, temperature=model_settings.get('temperature', NOT_GIVEN), diff --git a/pydantic_ai_slim/pydantic_ai/models/cohere.py b/pydantic_ai_slim/pydantic_ai/models/cohere.py index d66eade82..c315017e4 100644 --- a/pydantic_ai_slim/pydantic_ai/models/cohere.py +++ b/pydantic_ai_slim/pydantic_ai/models/cohere.py @@ -25,7 +25,7 @@ from ..settings import ModelSettings from ..tools import ToolDefinition from . 
import ( - AgentModel, + AgentRequestConfig, Model, check_allow_model_requests, ) @@ -114,24 +114,6 @@ def __init__( else: self.client = AsyncClientV2(api_key=api_key) # type: ignore - async def agent_model( - self, - *, - function_tools: list[ToolDefinition], - allow_text_result: bool, - result_tools: list[ToolDefinition], - ) -> AgentModel: - check_allow_model_requests() - tools = [self._map_tool_definition(r) for r in function_tools] - if result_tools: - tools += [self._map_tool_definition(r) for r in result_tools] - return CohereAgentModel( - self.client, - self.model_name, - allow_text_result, - tools, - ) - def name(self) -> str: return f'cohere:{self.model_name}' @@ -146,32 +128,34 @@ def _map_tool_definition(f: ToolDefinition) -> ToolV2: ), ) - -@dataclass -class CohereAgentModel(AgentModel): - """Implementation of `AgentModel` for Cohere models.""" - - client: AsyncClientV2 - model_name: CohereModelName - allow_text_result: bool - tools: list[ToolV2] + def _get_tools(self, agent_request_config: AgentRequestConfig) -> list[ToolV2]: + tools = [self._map_tool_definition(r) for r in agent_request_config.function_tools] + if agent_request_config.result_tools: + tools += [self._map_tool_definition(r) for r in agent_request_config.result_tools] + return tools async def request( - self, messages: list[ModelMessage], model_settings: ModelSettings | None + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + agent_request_config: AgentRequestConfig, ) -> tuple[ModelResponse, result.Usage]: - response = await self._chat(messages, cast(CohereModelSettings, model_settings or {})) + check_allow_model_requests() + response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}), agent_request_config) return self._process_response(response), _map_usage(response) async def _chat( self, messages: list[ModelMessage], model_settings: CohereModelSettings, + agent_request_config: AgentRequestConfig, ) -> ChatResponse: + tools = self._get_tools(agent_request_config) cohere_messages = list(chain(*(self._map_message(m) for m in messages))) return await self.client.chat( model=self.model_name, messages=cohere_messages, - tools=self.tools or OMIT, + tools=tools or OMIT, max_tokens=model_settings.get('max_tokens', OMIT), temperature=model_settings.get('temperature', OMIT), p=model_settings.get('top_p', OMIT), diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 58390b3df..56792dda9 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -4,7 +4,7 @@ import re from collections.abc import AsyncIterator, Awaitable, Iterable from contextlib import asynccontextmanager -from dataclasses import dataclass, field, replace +from dataclasses import dataclass, field from datetime import datetime from itertools import chain from typing import Callable, Union @@ -27,7 +27,7 @@ ) from ..settings import ModelSettings from ..tools import ToolDefinition -from . import AgentModel, Model, StreamedResponse +from . 
import AgentRequestConfig, Model, StreamedResponse @dataclass(init=False) @@ -63,24 +63,64 @@ def __init__(self, function: FunctionDef | None = None, *, stream_function: Stre self.function = function self.stream_function = stream_function - async def agent_model( - self, - *, - function_tools: list[ToolDefinition], - allow_text_result: bool, - result_tools: list[ToolDefinition], - ) -> AgentModel: - return FunctionAgentModel( - self.function, - self.stream_function, - AgentInfo(function_tools, allow_text_result, result_tools, None), - ) - def name(self) -> str: function_name = self.function.__name__ if self.function is not None else '' stream_function_name = self.stream_function.__name__ if self.stream_function is not None else '' return f'function:{function_name}:{stream_function_name}' + async def request( + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + agent_request_config: AgentRequestConfig, + ) -> tuple[ModelResponse, usage.Usage]: + agent_info = AgentInfo( + agent_request_config.function_tools, + agent_request_config.allow_text_result, + agent_request_config.result_tools, + model_settings, + ) + + assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests' + model_name = f'function:{self.function.__name__}' + + if inspect.iscoroutinefunction(self.function): + response = await self.function(messages, agent_info) + else: + response_ = await _utils.run_in_executor(self.function, messages, agent_info) + assert isinstance(response_, ModelResponse), response_ + response = response_ + response.model_name = model_name + # TODO is `messages` right here? Should it just be new messages? + return response, _estimate_usage(chain(messages, [response])) + + @asynccontextmanager + async def request_stream( + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + agent_request_config: AgentRequestConfig, + ) -> AsyncIterator[StreamedResponse]: + agent_info = AgentInfo( + agent_request_config.function_tools, + agent_request_config.allow_text_result, + agent_request_config.result_tools, + model_settings, + ) + + assert ( + self.stream_function is not None + ), 'FunctionModel must receive a `stream_function` to support streamed requests' + model_name = f'function:{self.stream_function.__name__}' + + response_stream = PeekableAsyncStream(self.stream_function(messages, agent_info)) + + first = await response_stream.peek() + if isinstance(first, _utils.Unset): + raise ValueError('Stream function must return at least one item') + + yield FunctionStreamedResponse(_model_name=model_name, _iter=response_stream) + @dataclass(frozen=True) class AgentInfo: @@ -119,9 +159,11 @@ class DeltaToolCall: DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall] """A mapping of tool call IDs to incremental changes.""" +# TODO: Change the signature to Callable[[list[ModelMessage], ModelSettings, AgentRequestConfig], ...] FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], Union[ModelResponse, Awaitable[ModelResponse]]] """A function used to generate a non-streamed response.""" +# TODO: Change signature as indicated above StreamFunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls]]] """A function used to generate a streamed response. 
@@ -132,50 +174,6 @@ class DeltaToolCall: """ -@dataclass -class FunctionAgentModel(AgentModel): - """Implementation of `AgentModel` for [FunctionModel][pydantic_ai.models.function.FunctionModel].""" - - function: FunctionDef | None - stream_function: StreamFunctionDef | None - agent_info: AgentInfo - - async def request( - self, messages: list[ModelMessage], model_settings: ModelSettings | None - ) -> tuple[ModelResponse, usage.Usage]: - agent_info = replace(self.agent_info, model_settings=model_settings) - - assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests' - model_name = f'function:{self.function.__name__}' - - if inspect.iscoroutinefunction(self.function): - response = await self.function(messages, agent_info) - else: - response_ = await _utils.run_in_executor(self.function, messages, agent_info) - assert isinstance(response_, ModelResponse), response_ - response = response_ - response.model_name = model_name - # TODO is `messages` right here? Should it just be new messages? - return response, _estimate_usage(chain(messages, [response])) - - @asynccontextmanager - async def request_stream( - self, messages: list[ModelMessage], model_settings: ModelSettings | None - ) -> AsyncIterator[StreamedResponse]: - assert ( - self.stream_function is not None - ), 'FunctionModel must receive a `stream_function` to support streamed requests' - model_name = f'function:{self.stream_function.__name__}' - - response_stream = PeekableAsyncStream(self.stream_function(messages, self.agent_info)) - - first = await response_stream.peek() - if isinstance(first, _utils.Unset): - raise ValueError('Stream function must return at least one item') - - yield FunctionStreamedResponse(_model_name=model_name, _iter=response_stream) - - @dataclass class FunctionStreamedResponse(StreamedResponse): """Implementation of `StreamedResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel].""" diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index e49c62eca..eb341e3d9 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -31,7 +31,7 @@ from ..settings import ModelSettings from ..tools import ToolDefinition from . 
import ( - AgentModel, + AgentRequestConfig, Model, StreamedResponse, cached_async_http_client, @@ -65,9 +65,10 @@ class GeminiModel(Model): """ model_name: GeminiModelName - auth: AuthProtocol http_client: AsyncHTTPClient - url: str + + _auth: AuthProtocol | None + _url: str | None def __init__( self, @@ -94,115 +95,84 @@ def __init__( api_key = env_api_key else: raise exceptions.UserError('API key must be provided or set in the GEMINI_API_KEY environment variable') - self.auth = ApiKeyAuth(api_key) self.http_client = http_client or cached_async_http_client() - self.url = url_template.format(model=model_name) + self._auth = ApiKeyAuth(api_key) + self._url = url_template.format(model=model_name) - async def agent_model( - self, - *, - function_tools: list[ToolDefinition], - allow_text_result: bool, - result_tools: list[ToolDefinition], - ) -> GeminiAgentModel: - check_allow_model_requests() - return GeminiAgentModel( - http_client=self.http_client, - model_name=self.model_name, - auth=self.auth, - url=self.url, - function_tools=function_tools, - allow_text_result=allow_text_result, - result_tools=result_tools, - ) + @property + def auth(self) -> AuthProtocol: + assert self._auth is not None, 'Auth not initialized' + return self._auth + + @property + def url(self) -> str: + assert self._url is not None, 'URL not initialized' + return self._url def name(self) -> str: return f'google-gla:{self.model_name}' - -class AuthProtocol(Protocol): - """Abstract definition for Gemini authentication.""" - - async def headers(self) -> dict[str, str]: ... - - -@dataclass -class ApiKeyAuth: - """Authentication using an API key for the `X-Goog-Api-Key` header.""" - - api_key: str - - async def headers(self) -> dict[str, str]: - # https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest - return {'X-Goog-Api-Key': self.api_key} - - -@dataclass(init=False) -class GeminiAgentModel(AgentModel): - """Implementation of `AgentModel` for Gemini models.""" - - http_client: AsyncHTTPClient - model_name: GeminiModelName - auth: AuthProtocol - tools: _GeminiTools | None - tool_config: _GeminiToolConfig | None - url: str - - def __init__( - self, - http_client: AsyncHTTPClient, - model_name: GeminiModelName, - auth: AuthProtocol, - url: str, - function_tools: list[ToolDefinition], - allow_text_result: bool, - result_tools: list[ToolDefinition], - ): - tools = [_function_from_abstract_tool(t) for t in function_tools] - if result_tools: - tools += [_function_from_abstract_tool(t) for t in result_tools] - - if allow_text_result: - tool_config = None - else: - tool_config = _tool_config([t['name'] for t in tools]) - - self.http_client = http_client - self.model_name = model_name - self.auth = auth - self.tools = _GeminiTools(function_declarations=tools) if tools else None - self.tool_config = tool_config - self.url = url - async def request( - self, messages: list[ModelMessage], model_settings: ModelSettings | None + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + agent_request_config: AgentRequestConfig, ) -> tuple[ModelResponse, usage.Usage]: + check_allow_model_requests() async with self._make_request( - messages, False, cast(GeminiModelSettings, model_settings or {}) + messages, False, cast(GeminiModelSettings, model_settings or {}), agent_request_config ) as http_response: response = _gemini_response_ta.validate_json(await http_response.aread()) return self._process_response(response), _metadata_as_usage(response) @asynccontextmanager async def request_stream( - self, 
messages: list[ModelMessage], model_settings: ModelSettings | None + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + agent_request_config: AgentRequestConfig, ) -> AsyncIterator[StreamedResponse]: - async with self._make_request(messages, True, cast(GeminiModelSettings, model_settings or {})) as http_response: + check_allow_model_requests() + async with self._make_request( + messages, True, cast(GeminiModelSettings, model_settings or {}), agent_request_config + ) as http_response: yield await self._process_streamed_response(http_response) + def _get_tools(self, agent_request_config: AgentRequestConfig) -> _GeminiTools | None: + tools = [_function_from_abstract_tool(t) for t in agent_request_config.function_tools] + if agent_request_config.result_tools: + tools += [_function_from_abstract_tool(t) for t in agent_request_config.result_tools] + return _GeminiTools(function_declarations=tools) if tools else None + + def _get_tool_config( + self, agent_request_config: AgentRequestConfig, tools: _GeminiTools | None + ) -> _GeminiToolConfig | None: + if agent_request_config.allow_text_result: + return None + elif tools: + return _tool_config([t['name'] for t in tools['function_declarations']]) + else: + return _tool_config([]) + @asynccontextmanager async def _make_request( - self, messages: list[ModelMessage], streamed: bool, model_settings: GeminiModelSettings + self, + messages: list[ModelMessage], + streamed: bool, + model_settings: GeminiModelSettings, + agent_request_config: AgentRequestConfig, ) -> AsyncIterator[HTTPResponse]: + tools = self._get_tools(agent_request_config) + tool_config = self._get_tool_config(agent_request_config, tools) sys_prompt_parts, contents = self._message_to_gemini_content(messages) request_data = _GeminiRequest(contents=contents) if sys_prompt_parts: request_data['system_instruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts) - if self.tools is not None: - request_data['tools'] = self.tools - if self.tool_config is not None: - request_data['tool_config'] = self.tool_config + if tools is not None: + request_data['tools'] = tools + if tool_config is not None: + request_data['tool_config'] = tool_config generation_config: _GeminiGenerationConfig = {} if model_settings: @@ -306,6 +276,23 @@ def _message_to_gemini_content( return sys_prompt_parts, contents +class AuthProtocol(Protocol): + """Abstract definition for Gemini authentication.""" + + async def headers(self) -> dict[str, str]: ... + + +@dataclass +class ApiKeyAuth: + """Authentication using an API key for the `X-Goog-Api-Key` header.""" + + api_key: str + + async def headers(self) -> dict[str, str]: + # https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest + return {'X-Goog-Api-Key': self.api_key} + + @dataclass class GeminiStreamedResponse(StreamedResponse): """Implementation of `StreamedResponse` for the Gemini model.""" diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 5b9787083..1c833210d 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -28,7 +28,7 @@ from ..settings import ModelSettings from ..tools import ToolDefinition from . 
import ( - AgentModel, + AgentRequestConfig, Model, StreamedResponse, cached_async_http_client, @@ -112,24 +112,6 @@ def __init__( else: self.client = AsyncGroq(api_key=api_key, http_client=cached_async_http_client()) - async def agent_model( - self, - *, - function_tools: list[ToolDefinition], - allow_text_result: bool, - result_tools: list[ToolDefinition], - ) -> AgentModel: - check_allow_model_requests() - tools = [self._map_tool_definition(r) for r in function_tools] - if result_tools: - tools += [self._map_tool_definition(r) for r in result_tools] - return GroqAgentModel( - self.client, - self.model_name, - allow_text_result, - tools, - ) - def name(self) -> str: return f'groq:{self.model_name}' @@ -144,49 +126,70 @@ def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam: }, } - -@dataclass -class GroqAgentModel(AgentModel): - """Implementation of `AgentModel` for Groq models.""" - - client: AsyncGroq - model_name: str - allow_text_result: bool - tools: list[chat.ChatCompletionToolParam] + def _get_tools(self, agent_request_config: AgentRequestConfig) -> list[chat.ChatCompletionToolParam]: + tools = [self._map_tool_definition(r) for r in agent_request_config.function_tools] + if agent_request_config.result_tools: + tools += [self._map_tool_definition(r) for r in agent_request_config.result_tools] + return tools async def request( - self, messages: list[ModelMessage], model_settings: ModelSettings | None + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + agent_request_config: AgentRequestConfig, ) -> tuple[ModelResponse, usage.Usage]: - response = await self._completions_create(messages, False, cast(GroqModelSettings, model_settings or {})) + check_allow_model_requests() + response = await self._completions_create( + messages, False, cast(GroqModelSettings, model_settings or {}), agent_request_config + ) return self._process_response(response), _map_usage(response) @asynccontextmanager async def request_stream( - self, messages: list[ModelMessage], model_settings: ModelSettings | None + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + agent_request_config: AgentRequestConfig, ) -> AsyncIterator[StreamedResponse]: - response = await self._completions_create(messages, True, cast(GroqModelSettings, model_settings or {})) + check_allow_model_requests() + response = await self._completions_create( + messages, True, cast(GroqModelSettings, model_settings or {}), agent_request_config + ) async with response: yield await self._process_streamed_response(response) @overload async def _completions_create( - self, messages: list[ModelMessage], stream: Literal[True], model_settings: GroqModelSettings + self, + messages: list[ModelMessage], + stream: Literal[True], + model_settings: GroqModelSettings, + agent_request_config: AgentRequestConfig, ) -> AsyncStream[ChatCompletionChunk]: pass @overload async def _completions_create( - self, messages: list[ModelMessage], stream: Literal[False], model_settings: GroqModelSettings + self, + messages: list[ModelMessage], + stream: Literal[False], + model_settings: GroqModelSettings, + agent_request_config: AgentRequestConfig, ) -> chat.ChatCompletion: pass async def _completions_create( - self, messages: list[ModelMessage], stream: bool, model_settings: GroqModelSettings + self, + messages: list[ModelMessage], + stream: bool, + model_settings: GroqModelSettings, + agent_request_config: AgentRequestConfig, ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]: + tools = 
self._get_tools(agent_request_config) # standalone function to make it easier to override - if not self.tools: + if not tools: tool_choice: Literal['none', 'required', 'auto'] | None = None - elif not self.allow_text_result: + elif not agent_request_config.allow_text_result: tool_choice = 'required' else: tool_choice = 'auto' @@ -198,7 +201,7 @@ async def _completions_create( messages=groq_messages, n=1, parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN), - tools=self.tools or NOT_GIVEN, + tools=tools or NOT_GIVEN, tool_choice=tool_choice or NOT_GIVEN, stream=stream, max_tokens=model_settings.get('max_tokens', NOT_GIVEN), diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 0c8bfe497..49a9eea30 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -31,7 +31,7 @@ from ..settings import ModelSettings from ..tools import ToolDefinition from . import ( - AgentModel, + AgentRequestConfig, Model, StreamedResponse, cached_async_http_client, @@ -101,6 +101,7 @@ class MistralModel(Model): model_name: MistralModelName client: Mistral = field(repr=False) + json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n""" def __init__( self, @@ -109,6 +110,7 @@ def __init__( api_key: str | Callable[[], str | None] | None = None, client: Mistral | None = None, http_client: AsyncHTTPClient | None = None, + json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n""", ): """Initialize a Mistral model. @@ -117,8 +119,10 @@ def __init__( api_key: The API key to use for authentication, if unset uses `MISTRAL_API_KEY` environment variable. client: An existing `Mistral` client to use, if provided, `api_key` and `http_client` must be `None`. http_client: An existing `httpx.AsyncClient` to use for making HTTP requests. + json_mode_schema_prompt: The prompt used to instruct the model to respond with a JSON object matching the given schema.
""" self.model_name = model_name + self.json_mode_schema_prompt = json_mode_schema_prompt if client is not None: assert http_client is None, 'Cannot provide both `mistral_client` and `http_client`' @@ -128,64 +132,50 @@ def __init__( api_key = os.getenv('MISTRAL_API_KEY') if api_key is None else api_key self.client = Mistral(api_key=api_key, async_client=http_client or cached_async_http_client()) - async def agent_model( - self, - *, - function_tools: list[ToolDefinition], - allow_text_result: bool, - result_tools: list[ToolDefinition], - ) -> AgentModel: - """Create an agent model, this is called for each step of an agent run from Pydantic AI call.""" - check_allow_model_requests() - return MistralAgentModel( - self.client, - self.model_name, - allow_text_result, - function_tools, - result_tools, - ) - def name(self) -> str: return f'mistral:{self.model_name}' - -@dataclass -class MistralAgentModel(AgentModel): - """Implementation of `AgentModel` for Mistral models.""" - - client: Mistral - model_name: MistralModelName - allow_text_result: bool - function_tools: list[ToolDefinition] - result_tools: list[ToolDefinition] - json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n""" - async def request( - self, messages: list[ModelMessage], model_settings: ModelSettings | None + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + agent_request_config: AgentRequestConfig, ) -> tuple[ModelResponse, Usage]: """Make a non-streaming request to the model from Pydantic AI call.""" - response = await self._completions_create(messages, cast(MistralModelSettings, model_settings or {})) + check_allow_model_requests() + response = await self._completions_create( + messages, cast(MistralModelSettings, model_settings or {}), agent_request_config + ) return self._process_response(response), _map_usage(response) @asynccontextmanager async def request_stream( - self, messages: list[ModelMessage], model_settings: ModelSettings | None + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + agent_request_config: AgentRequestConfig, ) -> AsyncIterator[StreamedResponse]: """Make a streaming request to the model from Pydantic AI call.""" - response = await self._stream_completions_create(messages, cast(MistralModelSettings, model_settings or {})) + check_allow_model_requests() + response = await self._stream_completions_create( + messages, cast(MistralModelSettings, model_settings or {}), agent_request_config + ) async with response: - yield await self._process_streamed_response(self.result_tools, response) + yield await self._process_streamed_response(agent_request_config.result_tools, response) async def _completions_create( - self, messages: list[ModelMessage], model_settings: MistralModelSettings + self, + messages: list[ModelMessage], + model_settings: MistralModelSettings, + agent_request_config: AgentRequestConfig, ) -> MistralChatCompletionResponse: """Make a non-streaming request to the model.""" response = await self.client.chat.complete_async( model=str(self.model_name), messages=list(chain(*(self._map_message(m) for m in messages))), n=1, - tools=self._map_function_and_result_tools_definition() or UNSET, - tool_choice=self._get_tool_choice(), + tools=self._map_function_and_result_tools_definition(agent_request_config) or UNSET, + tool_choice=self._get_tool_choice(agent_request_config), stream=False, max_tokens=model_settings.get('max_tokens', UNSET), temperature=model_settings.get('temperature', UNSET), @@ 
-200,19 +190,24 @@ async def _stream_completions_create( self, messages: list[ModelMessage], model_settings: MistralModelSettings, + agent_request_config: AgentRequestConfig, ) -> MistralEventStreamAsync[MistralCompletionEvent]: """Create a streaming completion request to the Mistral model.""" response: MistralEventStreamAsync[MistralCompletionEvent] | None mistral_messages = list(chain(*(self._map_message(m) for m in messages))) - if self.result_tools and self.function_tools or self.function_tools: + if ( + agent_request_config.result_tools + and agent_request_config.function_tools + or agent_request_config.function_tools + ): # Function Calling response = await self.client.chat.stream_async( model=str(self.model_name), messages=mistral_messages, n=1, - tools=self._map_function_and_result_tools_definition() or UNSET, - tool_choice=self._get_tool_choice(), + tools=self._map_function_and_result_tools_definition(agent_request_config) or UNSET, + tool_choice=self._get_tool_choice(agent_request_config), temperature=model_settings.get('temperature', UNSET), top_p=model_settings.get('top_p', 1), max_tokens=model_settings.get('max_tokens', UNSET), @@ -221,9 +216,9 @@ async def _stream_completions_create( frequency_penalty=model_settings.get('frequency_penalty'), ) - elif self.result_tools: + elif agent_request_config.result_tools: # Json Mode - parameters_json_schemas = [tool.parameters_json_schema for tool in self.result_tools] + parameters_json_schemas = [tool.parameters_json_schema for tool in agent_request_config.result_tools] user_output_format_message = self._generate_user_output_format(parameters_json_schemas) mistral_messages.append(user_output_format_message) @@ -244,7 +239,7 @@ async def _stream_completions_create( assert response, 'A unexpected empty response from Mistral.' return response - def _get_tool_choice(self) -> MistralToolChoiceEnum | None: + def _get_tool_choice(self, agent_request_config: AgentRequestConfig) -> MistralToolChoiceEnum | None: """Get tool choice for the model. - "auto": Default mode. Model decides if it uses the tool or not. @@ -252,19 +247,21 @@ def _get_tool_choice(self) -> MistralToolChoiceEnum | None: - "none": Prevents tool use. - "required": Forces tool use. """ - if not self.function_tools and not self.result_tools: + if not agent_request_config.function_tools and not agent_request_config.result_tools: return None - elif not self.allow_text_result: + elif not agent_request_config.allow_text_result: return 'required' else: return 'auto' - def _map_function_and_result_tools_definition(self) -> list[MistralTool] | None: + def _map_function_and_result_tools_definition( + self, agent_request_config: AgentRequestConfig + ) -> list[MistralTool] | None: """Map function and result tools to MistralTool format. Returns None if both function_tools and result_tools are empty. 
""" - all_tools: list[ToolDefinition] = self.function_tools + self.result_tools + all_tools: list[ToolDefinition] = agent_request_config.function_tools + agent_request_config.result_tools tools = [ MistralTool( function=MistralFunction(name=r.name, parameters=r.parameters_json_schema, description=r.description) diff --git a/pydantic_ai_slim/pydantic_ai/models/ollama.py b/pydantic_ai_slim/pydantic_ai/models/ollama.py index ae3820167..5d6cf593a 100644 --- a/pydantic_ai_slim/pydantic_ai/models/ollama.py +++ b/pydantic_ai_slim/pydantic_ai/models/ollama.py @@ -3,16 +3,6 @@ from dataclasses import dataclass from typing import Literal, Union -from httpx import AsyncClient as AsyncHTTPClient - -from ..tools import ToolDefinition -from . import ( - AgentModel, - Model, - cached_async_http_client, - check_allow_model_requests, -) - try: from openai import AsyncOpenAI except ImportError as e: @@ -58,7 +48,7 @@ @dataclass(init=False) -class OllamaModel(Model): +class OllamaModel(OpenAIModel): """A model that implements Ollama using the OpenAI API. Internally, this uses the [OpenAI Python client](https://github.com/openai/openai-python) to interact with the Ollama server. @@ -66,58 +56,5 @@ class OllamaModel(Model): Apart from `__init__`, all methods are private or match those of the base class. """ - model_name: OllamaModelName - openai_model: OpenAIModel - - def __init__( - self, - model_name: OllamaModelName, - *, - base_url: str | None = 'http://localhost:11434/v1/', - api_key: str = 'ollama', - openai_client: AsyncOpenAI | None = None, - http_client: AsyncHTTPClient | None = None, - ): - """Initialize an Ollama model. - - Ollama has built-in compatibility for the OpenAI chat completions API ([source](https://ollama.com/blog/openai-compatibility)), so we reuse the - [`OpenAIModel`][pydantic_ai.models.openai.OpenAIModel] here. - - Args: - model_name: The name of the Ollama model to use. List of models available [here](https://ollama.com/library) - You must first download the model (`ollama pull `) in order to use the model - base_url: The base url for the ollama requests. The default value is the ollama default - api_key: The API key to use for authentication. Defaults to 'ollama' for local instances, - but can be customized for proxy setups that require authentication - openai_client: An existing - [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage) - client to use, if provided, `base_url` and `http_client` must be `None`. - http_client: An existing `httpx.AsyncClient` to use for making HTTP requests. 
- """ - self.model_name = model_name - if openai_client is not None: - assert base_url is None, 'Cannot provide both `openai_client` and `base_url`' - assert http_client is None, 'Cannot provide both `openai_client` and `http_client`' - self.openai_model = OpenAIModel(model_name=model_name, openai_client=openai_client) - else: - # API key is not required for ollama but a value is required to create the client - http_client_ = http_client or cached_async_http_client() - oai_client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client_) - self.openai_model = OpenAIModel(model_name=model_name, openai_client=oai_client) - - async def agent_model( - self, - *, - function_tools: list[ToolDefinition], - allow_text_result: bool, - result_tools: list[ToolDefinition], - ) -> AgentModel: - check_allow_model_requests() - return await self.openai_model.agent_model( - function_tools=function_tools, - allow_text_result=allow_text_result, - result_tools=result_tools, - ) - def name(self) -> str: return f'ollama:{self.model_name}' diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 7ad5e1482..885002d73 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -28,7 +28,7 @@ from ..settings import ModelSettings from ..tools import ToolDefinition from . import ( - AgentModel, + AgentRequestConfig, Model, StreamedResponse, cached_async_http_client, @@ -54,10 +54,10 @@ OpenAISystemPromptRole = Literal['system', 'developer', 'user'] -class OpenAIModelSettings(ModelSettings): +class OpenAIModelSettings(ModelSettings, total=False): """Settings used for an OpenAI model request.""" - # This class is a placeholder for any future openai-specific settings + use_structured_response_format: bool @dataclass(init=False) @@ -112,25 +112,6 @@ def __init__( self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client()) self.system_prompt_role = system_prompt_role - async def agent_model( - self, - *, - function_tools: list[ToolDefinition], - allow_text_result: bool, - result_tools: list[ToolDefinition], - ) -> AgentModel: - check_allow_model_requests() - tools = [self._map_tool_definition(r) for r in function_tools] - if result_tools: - tools += [self._map_tool_definition(r) for r in result_tools] - return OpenAIAgentModel( - self.client, - self.model_name, - allow_text_result, - tools, - self.system_prompt_role, - ) - def name(self) -> str: return f'openai:{self.model_name}' @@ -145,50 +126,84 @@ def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam: }, } - -@dataclass -class OpenAIAgentModel(AgentModel): - """Implementation of `AgentModel` for OpenAI models.""" - - client: AsyncOpenAI - model_name: OpenAIModelName - allow_text_result: bool - tools: list[chat.ChatCompletionToolParam] - system_prompt_role: OpenAISystemPromptRole | None - async def request( - self, messages: list[ModelMessage], model_settings: ModelSettings | None + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + agent_request_config: AgentRequestConfig, ) -> tuple[ModelResponse, usage.Usage]: - response = await self._completions_create(messages, False, cast(OpenAIModelSettings, model_settings or {})) + check_allow_model_requests() + response = await self._completions_create( + messages, False, cast(OpenAIModelSettings, model_settings or {}), agent_request_config + ) return self._process_response(response), _map_usage(response) + def 
_get_tools(self, agent_request_config: AgentRequestConfig) -> list[chat.ChatCompletionToolParam]: + tools = [self._map_tool_definition(r) for r in agent_request_config.function_tools] + if agent_request_config.result_tools: + tools += [self._map_tool_definition(r) for r in agent_request_config.result_tools] + return tools + @asynccontextmanager async def request_stream( - self, messages: list[ModelMessage], model_settings: ModelSettings | None + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + agent_request_config: AgentRequestConfig, ) -> AsyncIterator[StreamedResponse]: - response = await self._completions_create(messages, True, cast(OpenAIModelSettings, model_settings or {})) + check_allow_model_requests() + response = await self._completions_create( + messages, True, cast(OpenAIModelSettings, model_settings or {}), agent_request_config + ) async with response: yield await self._process_streamed_response(response) @overload async def _completions_create( - self, messages: list[ModelMessage], stream: Literal[True], model_settings: OpenAIModelSettings + self, + messages: list[ModelMessage], + stream: Literal[True], + model_settings: OpenAIModelSettings, + agent_request_config: AgentRequestConfig, ) -> AsyncStream[ChatCompletionChunk]: pass @overload async def _completions_create( - self, messages: list[ModelMessage], stream: Literal[False], model_settings: OpenAIModelSettings + self, + messages: list[ModelMessage], + stream: Literal[False], + model_settings: OpenAIModelSettings, + agent_request_config: AgentRequestConfig, ) -> chat.ChatCompletion: pass async def _completions_create( - self, messages: list[ModelMessage], stream: bool, model_settings: OpenAIModelSettings + self, + messages: list[ModelMessage], + stream: bool, + model_settings: OpenAIModelSettings, + agent_request_config: AgentRequestConfig, ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]: + tools = self._get_tools(agent_request_config) + + if model_settings.get('use_structured_response_format'): + tools = [] + result_tools = agent_request_config.result_tools + if len(result_tools) == 0: + raise ValueError('structured response_format requires at least one result tool') + elif len(result_tools) == 1: + json_schema = agent_request_config.result_tools[0].parameters_json_schema + else: + json_schema = {'anyOf': [tool.parameters_json_schema for tool in result_tools]} + response_format = {'type': 'json_schema', 'json_schema': json_schema} + else: + response_format = NOT_GIVEN + # standalone function to make it easier to override - if not self.tools: + if not tools: tool_choice: Literal['none', 'required', 'auto'] | None = None - elif not self.allow_text_result: + elif not agent_request_config.allow_text_result: tool_choice = 'required' else: tool_choice = 'auto' @@ -200,9 +215,10 @@ async def _completions_create( messages=openai_messages, n=1, parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN), - tools=self.tools or NOT_GIVEN, + tools=tools or NOT_GIVEN, tool_choice=tool_choice or NOT_GIVEN, stream=stream, + response_format=response_format, stream_options={'include_usage': True} if stream else NOT_GIVEN, max_tokens=model_settings.get('max_tokens', NOT_GIVEN), temperature=model_settings.get('temperature', NOT_GIVEN), diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index 878df563f..7cd51d80d 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -26,7 +26,7 @@ 
from ..settings import ModelSettings from ..tools import ToolDefinition from . import ( - AgentModel, + AgentRequestConfig, Model, StreamedResponse, ) @@ -87,86 +87,79 @@ class TestModel(Model): This is set when the model is called, so will reflect the result tools from the last step of the last run. """ - async def agent_model( - self, - *, - function_tools: list[ToolDefinition], - allow_text_result: bool, - result_tools: list[ToolDefinition], - ) -> AgentModel: - self.agent_model_function_tools = function_tools - self.agent_model_allow_text_result = allow_text_result - self.agent_model_result_tools = result_tools - - if self.call_tools == 'all': - tool_calls = [(r.name, r) for r in function_tools] - else: - function_tools_lookup = {t.name: t for t in function_tools} - tools_to_call = (function_tools_lookup[name] for name in self.call_tools) - tool_calls = [(r.name, r) for r in tools_to_call] - - if self.custom_result_text is not None: - assert allow_text_result, 'Plain response not allowed, but `custom_result_text` is set.' - assert self.custom_result_args is None, 'Cannot set both `custom_result_text` and `custom_result_args`.' - result: _TextResult | _FunctionToolResult = _TextResult(self.custom_result_text) - elif self.custom_result_args is not None: - assert result_tools is not None, 'No result tools provided, but `custom_result_args` is set.' - result_tool = result_tools[0] - - if k := result_tool.outer_typed_dict_key: - result = _FunctionToolResult({k: self.custom_result_args}) - else: - result = _FunctionToolResult(self.custom_result_args) - elif allow_text_result: - result = _TextResult(None) - elif result_tools: - result = _FunctionToolResult(None) - else: - result = _TextResult(None) - - return TestAgentModel(tool_calls, result, result_tools, self.seed) - def name(self) -> str: return 'test-model' - -@dataclass -class TestAgentModel(AgentModel): - """Implementation of `AgentModel` for testing purposes.""" - - # NOTE: Avoid test discovery by pytest. 
- __test__ = False - - tool_calls: list[tuple[str, ToolDefinition]] - # left means the text is plain text; right means it's a function call - result: _TextResult | _FunctionToolResult - result_tools: list[ToolDefinition] - seed: int - model_name: str = 'test' - async def request( - self, messages: list[ModelMessage], model_settings: ModelSettings | None + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + agent_request_config: AgentRequestConfig, ) -> tuple[ModelResponse, Usage]: - model_response = self._request(messages, model_settings) + model_response = self._request(messages, model_settings, agent_request_config) usage = _estimate_usage([*messages, model_response]) return model_response, usage @asynccontextmanager async def request_stream( - self, messages: list[ModelMessage], model_settings: ModelSettings | None + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + agent_request_config: AgentRequestConfig, ) -> AsyncIterator[StreamedResponse]: - model_response = self._request(messages, model_settings) - yield TestStreamedResponse(_model_name=self.model_name, _structured_response=model_response, _messages=messages) + model_response = self._request(messages, model_settings, agent_request_config) + yield TestStreamedResponse(_model_name=self.name(), _structured_response=model_response, _messages=messages) def gen_tool_args(self, tool_def: ToolDefinition) -> Any: return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate() - def _request(self, messages: list[ModelMessage], model_settings: ModelSettings | None) -> ModelResponse: + def _get_tool_calls(self, agent_request_config: AgentRequestConfig) -> list[tuple[str, ToolDefinition]]: + if self.call_tools == 'all': + return [(r.name, r) for r in agent_request_config.function_tools] + else: + function_tools_lookup = {t.name: t for t in agent_request_config.function_tools} + tools_to_call = (function_tools_lookup[name] for name in self.call_tools) + return [(r.name, r) for r in tools_to_call] + + def _get_result(self, agent_request_config: AgentRequestConfig) -> _TextResult | _FunctionToolResult: + if self.custom_result_text is not None: + assert ( + agent_request_config.allow_text_result + ), 'Plain response not allowed, but `custom_result_text` is set.' + assert self.custom_result_args is None, 'Cannot set both `custom_result_text` and `custom_result_args`.' + return _TextResult(self.custom_result_text) + elif self.custom_result_args is not None: + assert ( + agent_request_config.result_tools is not None + ), 'No result tools provided, but `custom_result_args` is set.' 
+ result_tool = agent_request_config.result_tools[0] + + if k := result_tool.outer_typed_dict_key: + return _FunctionToolResult({k: self.custom_result_args}) + else: + return _FunctionToolResult(self.custom_result_args) + elif agent_request_config.allow_text_result: + return _TextResult(None) + elif agent_request_config.result_tools: + return _FunctionToolResult(None) + else: + return _TextResult(None) + + def _request( + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + agent_request_config: AgentRequestConfig, + ) -> ModelResponse: + tool_calls = self._get_tool_calls(agent_request_config) + result = self._get_result(agent_request_config) + result_tools = agent_request_config.result_tools + # if there are tools, the first thing we want to do is call all of them - if self.tool_calls and not any(isinstance(m, ModelResponse) for m in messages): + if tool_calls and not any(isinstance(m, ModelResponse) for m in messages): return ModelResponse( - parts=[ToolCallPart(name, self.gen_tool_args(args)) for name, args in self.tool_calls], - model_name=self.model_name, + parts=[ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls], + model_name=self.name(), ) if messages: @@ -179,28 +172,26 @@ def _request(self, messages: list[ModelMessage], model_settings: ModelSettings | # Handle retries for both function tools and result tools # Check function tools first retry_parts: list[ModelResponsePart] = [ - ToolCallPart(name, self.gen_tool_args(args)) - for name, args in self.tool_calls - if name in new_retry_names + ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls if name in new_retry_names ] # Check result tools - if self.result_tools: + if result_tools: retry_parts.extend( [ ToolCallPart( tool.name, - self.result.value - if isinstance(self.result, _FunctionToolResult) and self.result.value is not None + result.value + if isinstance(result, _FunctionToolResult) and result.value is not None else self.gen_tool_args(tool), ) - for tool in self.result_tools + for tool in result_tools if tool.name in new_retry_names ] ) - return ModelResponse(parts=retry_parts, model_name=self.model_name) + return ModelResponse(parts=retry_parts, model_name=self.name()) - if isinstance(self.result, _TextResult): - if (response_text := self.result.value) is None: + if isinstance(result, _TextResult): + if (response_text := result.value) is None: # build up details of tool responses output: dict[str, Any] = {} for message in messages: @@ -210,23 +201,21 @@ def _request(self, messages: list[ModelMessage], model_settings: ModelSettings | output[part.tool_name] = part.content if output: return ModelResponse( - parts=[TextPart(pydantic_core.to_json(output).decode())], model_name=self.model_name + parts=[TextPart(pydantic_core.to_json(output).decode())], model_name=self.name() ) else: - return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self.model_name) + return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self.name()) else: - return ModelResponse(parts=[TextPart(response_text)], model_name=self.model_name) + return ModelResponse(parts=[TextPart(response_text)], model_name=self.name()) else: - assert self.result_tools, 'No result tools provided' - custom_result_args = self.result.value - result_tool = self.result_tools[self.seed % len(self.result_tools)] + assert result_tools, 'No result tools provided' + custom_result_args = result.value + result_tool = result_tools[self.seed % len(result_tools)] if 
-                return ModelResponse(
-                    parts=[ToolCallPart(result_tool.name, custom_result_args)], model_name=self.model_name
-                )
+                return ModelResponse(parts=[ToolCallPart(result_tool.name, custom_result_args)], model_name=self.name())
             else:
                 response_args = self.gen_tool_args(result_tool)
-                return ModelResponse(parts=[ToolCallPart(result_tool.name, response_args)], model_name=self.model_name)
+                return ModelResponse(parts=[ToolCallPart(result_tool.name, response_args)], model_name=self.name())


 @dataclass
diff --git a/pydantic_ai_slim/pydantic_ai/models/vertexai.py b/pydantic_ai_slim/pydantic_ai/models/vertexai.py
index 2bfdd195f..cb006dfd6 100644
--- a/pydantic_ai_slim/pydantic_ai/models/vertexai.py
+++ b/pydantic_ai_slim/pydantic_ai/models/vertexai.py
@@ -1,5 +1,7 @@
 from __future__ import annotations as _annotations

+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timedelta
 from pathlib import Path
@@ -7,11 +9,13 @@
 from httpx import AsyncClient as AsyncHTTPClient

+from .. import usage
 from .._utils import run_in_executor
 from ..exceptions import UserError
-from ..tools import ToolDefinition
-from . import Model, cached_async_http_client, check_allow_model_requests
-from .gemini import GeminiAgentModel, GeminiModelName
+from ..messages import ModelMessage, ModelResponse
+from ..settings import ModelSettings
+from . import AgentRequestConfig, StreamedResponse, cached_async_http_client
+from .gemini import GeminiModel, GeminiModelName

 try:
     import google.auth
@@ -52,20 +56,15 @@


 @dataclass(init=False)
-class VertexAIModel(Model):
+class VertexAIModel(GeminiModel):
     """A model that uses Gemini via the `*-aiplatform.googleapis.com` VertexAI API."""

-    model_name: GeminiModelName
     service_account_file: Path | str | None
     project_id: str | None
     region: VertexAiRegion
     model_publisher: Literal['google']
-    http_client: AsyncHTTPClient
     url_template: str

-    auth: BearerTokenAuth | None
-    url: str | None
-
     # TODO __init__ can be removed once we drop 3.9 and we can set kw_only correctly on the dataclass
     def __init__(
         self,
@@ -104,35 +103,16 @@ def __init__(
         self.http_client = http_client or cached_async_http_client()
         self.url_template = url_template

-        self.auth = None
-        self.url = None
-
-    async def agent_model(
-        self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> GeminiAgentModel:
-        check_allow_model_requests()
-        url, auth = await self.ainit()
-        return GeminiAgentModel(
-            http_client=self.http_client,
-            model_name=self.model_name,
-            auth=auth,
-            url=url,
-            function_tools=function_tools,
-            allow_text_result=allow_text_result,
-            result_tools=result_tools,
-        )
+        self._auth = None
+        self._url = None

-    async def ainit(self) -> tuple[str, BearerTokenAuth]:
+    async def ainit(self) -> None:
         """Initialize the model, setting the URL and auth.

         This will raise an error if authentication fails.
""" - if self.url is not None and self.auth is not None: - return self.url, self.auth + if self._url is not None and self._auth is not None: + return if self.service_account_file is not None: creds: BaseCredentials | ServiceAccountCredentials = _creds_from_file(self.service_account_file) @@ -155,18 +135,37 @@ async def ainit(self) -> tuple[str, BearerTokenAuth]: ) project_id = self.project_id - self.url = url = self.url_template.format( + self._url = self.url_template.format( region=self.region, project_id=project_id, model_publisher=self.model_publisher, model=self.model_name, ) - self.auth = auth = BearerTokenAuth(creds) - return url, auth + self._auth = BearerTokenAuth(creds) def name(self) -> str: return f'google-vertex:{self.model_name}' + async def request( + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + agent_request_config: AgentRequestConfig, + ) -> tuple[ModelResponse, usage.Usage]: + await self.ainit() + return await super().request(messages, model_settings, agent_request_config) + + @asynccontextmanager + async def request_stream( + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + agent_request_config: AgentRequestConfig, + ) -> AsyncIterator[StreamedResponse]: + await self.ainit() + async with super().request_stream(messages, model_settings, agent_request_config) as value: + yield value + # pyright: reportUnknownMemberType=false def _creds_from_file(service_account_file: str | Path) -> ServiceAccountCredentials: diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 8f96d093e..d214e74fc 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -24,6 +24,7 @@ ToolReturnPart, UserPromptPart, ) +from pydantic_ai.models import AgentRequestConfig from pydantic_ai.models.gemini import ( ApiKeyAuth, GeminiModel, @@ -77,13 +78,17 @@ def test_api_key_empty(env: TestEnv): async def test_agent_model_simple(allow_model_requests: None): m = GeminiModel('gemini-1.5-flash', api_key='via-arg') - agent_model = await m.agent_model(function_tools=[], allow_text_result=True, result_tools=[]) - assert isinstance(agent_model.http_client, httpx.AsyncClient) - assert agent_model.model_name == 'gemini-1.5-flash' - assert isinstance(agent_model.auth, ApiKeyAuth) - assert agent_model.auth.api_key == 'via-arg' - assert agent_model.tools is None - assert agent_model.tool_config is None + # agent_model = await m.agent_model(function_tools=[], allow_text_result=True, result_tools=[]) + assert isinstance(m.http_client, httpx.AsyncClient) + assert m.model_name == 'gemini-1.5-flash' + assert isinstance(m.auth, ApiKeyAuth) + assert m.auth.api_key == 'via-arg' + + arc = AgentRequestConfig(function_tools=[], allow_text_result=True, result_tools=[]) + tools = m._get_tools(arc) + tool_config = m._get_tool_config(arc, tools) + assert tools is None + assert tool_config is None async def test_agent_model_tools(allow_model_requests: None): @@ -110,8 +115,11 @@ async def test_agent_model_tools(allow_model_requests: None): 'This is the tool for the final Result', {'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}, 'required': ['spam']}, ) - agent_model = await m.agent_model(function_tools=tools, allow_text_result=True, result_tools=[result_tool]) - assert agent_model.tools == snapshot( + + arc = AgentRequestConfig(function_tools=tools, allow_text_result=True, result_tools=[result_tool]) + tools = m._get_tools(arc) + tool_config = m._get_tool_config(arc, tools) + assert tools == 
         _GeminiTools(
             function_declarations=[
                 _GeminiFunction(
@@ -139,7 +147,7 @@ async def test_agent_model_tools(allow_model_requests: None):
             ]
         )
     )
-    assert agent_model.tool_config is None
+    assert tool_config is None


 async def test_require_response_tool(allow_model_requests: None):
@@ -149,8 +157,10 @@ async def test_require_response_tool(allow_model_requests: None):
         'This is the tool for the final Result',
         {'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}},
     )
-    agent_model = await m.agent_model(function_tools=[], allow_text_result=False, result_tools=[result_tool])
-    assert agent_model.tools == snapshot(
+    arc = AgentRequestConfig(function_tools=[], allow_text_result=False, result_tools=[result_tool])
+    tools = m._get_tools(arc)
+    tool_config = m._get_tool_config(arc, tools)
+    assert tools == snapshot(
         _GeminiTools(
             function_declarations=[
                 _GeminiFunction(
@@ -164,7 +174,7 @@ async def test_require_response_tool(allow_model_requests: None):
             ]
         )
     )
-    assert agent_model.tool_config == snapshot(
+    assert tool_config == snapshot(
         _GeminiToolConfig(
             function_calling_config=_GeminiFunctionCallingConfig(mode='ANY', allowed_function_names=['result'])
         )
@@ -206,8 +216,9 @@ class Locations(BaseModel):
         'This is the tool for the final Result',
         json_schema,
     )
-    agent_model = await m.agent_model(function_tools=[], allow_text_result=True, result_tools=[result_tool])
-    assert agent_model.tools == snapshot(
+    assert m._get_tools(
+        AgentRequestConfig(function_tools=[], allow_text_result=True, result_tools=[result_tool])
+    ) == snapshot(
         _GeminiTools(
             function_declarations=[
                 _GeminiFunction(
@@ -252,8 +263,9 @@ class Locations(BaseModel):
         'This is the tool for the final Result',
         json_schema,
     )
-    agent_model = await m.agent_model(function_tools=[], allow_text_result=True, result_tools=[result_tool])
-    assert agent_model.tools == snapshot(
+    assert m._get_tools(
+        AgentRequestConfig(function_tools=[], allow_text_result=True, result_tools=[result_tool])
+    ) == snapshot(
         _GeminiTools(
             function_declarations=[
                 _GeminiFunction(
@@ -315,7 +327,7 @@ class Location(BaseModel):
         json_schema,
     )
     with pytest.raises(UserError, match=r'Recursive `\$ref`s in JSON Schema are not supported by Gemini'):
-        await m.agent_model(function_tools=[], allow_text_result=True, result_tools=[result_tool])
+        m._get_tools(AgentRequestConfig(function_tools=[], allow_text_result=True, result_tools=[result_tool]))


 async def test_json_def_date(allow_model_requests: None):
@@ -346,8 +358,9 @@ class FormattedStringFields(BaseModel):
         'This is the tool for the final Result',
         json_schema,
     )
-    agent_model = await m.agent_model(function_tools=[], allow_text_result=True, result_tools=[result_tool])
-    assert agent_model.tools == snapshot(
+    assert m._get_tools(
+        AgentRequestConfig(function_tools=[], allow_text_result=True, result_tools=[result_tool])
+    ) == snapshot(
         _GeminiTools(
             function_declarations=[
                 _GeminiFunction(
diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py
index d03e86580..ae5eb7b30 100644
--- a/tests/models/test_mistral.py
+++ b/tests/models/test_mistral.py
@@ -48,7 +48,6 @@
 from mistralai.types.basemodel import Unset as MistralUnset

 from pydantic_ai.models.mistral import (
-    MistralAgentModel,
     MistralModel,
     MistralStreamedResponse,
 )
@@ -1668,8 +1667,8 @@ def test_generate_user_output_format_complex():
             'prop_unrecognized_type': {'type': 'customSomething'},
         }
     }
-    mam = MistralAgentModel(Mistral(api_key=''), '', False, [], [], '{schema}')
-    result = mam._generate_user_output_format([schema])  # pyright: ignore[reportPrivateUsage]
+    m = MistralModel('', json_mode_schema_prompt='{schema}')
+    result = m._generate_user_output_format([schema])  # pyright: ignore[reportPrivateUsage]
     assert result.content == (
         "{'prop_anyOf': 'Optional[str]', "
         "'prop_no_type': 'Any', "
@@ -1685,8 +1684,8 @@ def test_generate_user_output_format_complex():

 def test_generate_user_output_format_multiple():
     schema = {'properties': {'prop_anyOf': {'anyOf': [{'type': 'string'}, {'type': 'integer'}]}}}
-    mam = MistralAgentModel(Mistral(api_key=''), '', False, [], [], '{schema}')
-    result = mam._generate_user_output_format([schema, schema])  # pyright: ignore[reportPrivateUsage]
+    m = MistralModel('', json_mode_schema_prompt='{schema}')
+    result = m._generate_user_output_format([schema, schema])  # pyright: ignore[reportPrivateUsage]
     assert result.content == "[{'prop_anyOf': 'Optional[str]'}, {'prop_anyOf': 'Optional[str]'}]"

diff --git a/tests/models/test_vertexai.py b/tests/models/test_vertexai.py
index 679f0e128..4e434488f 100644
--- a/tests/models/test_vertexai.py
+++ b/tests/models/test_vertexai.py
@@ -33,7 +33,7 @@ async def test_init_service_account(tmp_path: Path, allow_model_requests: None):
     assert model.url is None
     assert model.auth is None

-    await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
+    await model.ainit()

     assert model.url == snapshot(
         'https://us-central1-aiplatform.googleapis.com/v1/projects/my-project-id/locations/us-central1/'
@@ -58,7 +58,7 @@ async def test_init_env(mocker: MockerFixture, allow_model_requests: None):

     assert patch.call_count == 0

-    await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
+    await model.ainit()

     assert patch.call_count == 1

@@ -69,7 +69,7 @@ async def test_init_env(mocker: MockerFixture, allow_model_requests: None):
     assert model.auth is not None
     assert model.name() == snapshot('google-vertex:gemini-1.5-flash')

-    await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
+    await model.ainit()

     assert model.url is not None
     assert model.auth is not None
     assert patch.call_count == 1
@@ -83,7 +83,7 @@ async def test_init_right_project_id(tmp_path: Path, allow_model_requests: None)
     assert model.url is None
     assert model.auth is None

-    await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
+    await model.ainit()

     assert model.url == snapshot(
         'https://us-central1-aiplatform.googleapis.com/v1/projects/my-project-id/locations/us-central1/'
@@ -99,7 +99,7 @@ async def test_init_service_account_wrong_project_id(tmp_path: Path, allow_model
     model = VertexAIModel('gemini-1.5-flash', service_account_file=service_account_path, project_id='different')

     with pytest.raises(UserError) as exc_info:
-        await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
+        await model.ainit()
     assert str(exc_info.value) == snapshot(
         "The project_id you provided does not match the one from service account file: 'different' != 'my-project-id'"
     )
@@ -110,7 +110,7 @@ async def test_init_env_wrong_project_id(mocker: MockerFixture, allow_model_requ
     model = VertexAIModel('gemini-1.5-flash', project_id='different')

     with pytest.raises(UserError) as exc_info:
-        await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
+        await model.ainit()
     assert str(exc_info.value) == snapshot(
         "The project_id you provided does not match the one from `google.auth.default()`: 'different' != 'my-project-id'"
     )
@@ -124,7 +124,7 @@ async def test_init_env_no_project_id(mocker: MockerFixture, allow_model_request
     model = VertexAIModel('gemini-1.5-flash')

     with pytest.raises(UserError) as exc_info:
-        await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
+        await model.ainit()
     assert str(exc_info.value) == snapshot('No project_id provided and none found in `google.auth.default()`')
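
For reference, a minimal sketch of how a downstream Model subclass adapts to the new per-request signature and the AgentRequestConfig dataclass introduced by this change. EchoModel and its behaviour are hypothetical and only illustrate the call shape; a real implementation would also provide request_stream.

from __future__ import annotations

from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models import AgentRequestConfig, Model
from pydantic_ai.settings import ModelSettings
from pydantic_ai.usage import Usage


class EchoModel(Model):
    """Hypothetical toy model: ignores tools and always answers with plain text."""

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        agent_request_config: AgentRequestConfig,
    ) -> tuple[ModelResponse, Usage]:
        # Tool definitions and result settings now arrive with every request,
        # instead of being baked into a separate AgentModel up front.
        assert agent_request_config.allow_text_result, 'EchoModel only produces text'
        return ModelResponse(parts=[TextPart('echo')], model_name=self.name()), Usage()

    def name(self) -> str:
        return 'echo'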