
Add support for structured output with OpenAI #807


Closed · wants to merge 4 commits
1 change: 0 additions & 1 deletion docs/api/models/base.md
@@ -5,7 +5,6 @@
 members:
   - KnownModelName
   - Model
-  - AgentModel
   - AbstractToolDefinition
   - StreamedResponse
   - ALLOW_MODEL_REQUESTS
4 changes: 2 additions & 2 deletions docs/api/models/vertexai.md
@@ -2,8 +2,8 @@

 Custom interface to the `*-aiplatform.googleapis.com` API for Gemini models.

-This model uses [`GeminiAgentModel`][pydantic_ai.models.gemini.GeminiAgentModel] with just the URL and auth method
-changed from [`GeminiModel`][pydantic_ai.models.gemini.GeminiModel], it relies on the VertexAI
+This model inherits from [`GeminiModel`][pydantic_ai.models.gemini.GeminiModel] with just the URL and auth method
+changed; it relies on the VertexAI
 [`generateContent`](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/generateContent)
 and
 [`streamGenerateContent`](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/streamGenerateContent)
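For reviewers sanity-checking the inheritance change: usage should be unaffected. A minimal sketch, assuming the current `VertexAIModel` constructor and application-default credentials (the model name is illustrative):

```python
from pydantic_ai import Agent
from pydantic_ai.models.vertexai import VertexAIModel

# URL and auth handling stay in VertexAIModel; request logic is now
# inherited directly from GeminiModel rather than via GeminiAgentModel.
model = VertexAIModel('gemini-1.5-flash')
agent = Agent(model)
result = agent.run_sync('Tell me a joke.')
print(result.data)
```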
3 changes: 1 addition & 2 deletions docs/models.md
@@ -515,9 +515,8 @@ agent = Agent(model)

 To implement support for models not already supported, you will need to subclass the [`Model`][pydantic_ai.models.Model] abstract base class.

-This in turn will require you to implement the following other abstract base classes:
+For streaming, you'll also need to implement the following abstract base class:

-* [`AgentModel`][pydantic_ai.models.AgentModel]
 * [`StreamedResponse`][pydantic_ai.models.StreamedResponse]

 The best place to start is to review the source code for existing implementations, e.g. [`OpenAIModel`](https://github.com/pydantic/pydantic-ai/blob/main/pydantic_ai_slim/pydantic_ai/models/openai.py).
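A minimal sketch of what a third-party model looks like against the reworked interface. `EchoModel` and its behaviour are hypothetical, error handling and streaming are elided, and the `Usage` import path is assumed:

```python
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models import AgentRequestConfig, Model
from pydantic_ai.settings import ModelSettings
from pydantic_ai.usage import Usage


class EchoModel(Model):
    """Hypothetical model for illustration: always answers 'echo'."""

    def name(self) -> str:
        return 'echo'

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        agent_request_config: AgentRequestConfig,
    ) -> tuple[ModelResponse, Usage]:
        # A real model would build its provider tool list here from
        # agent_request_config.function_tools and .result_tools.
        return ModelResponse(parts=[TextPart('echo')]), Usage()
```

Without a `request_stream` override, streaming falls back to the base-class stub, which matches the docs change above.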
18 changes: 11 additions & 7 deletions pydantic_ai_slim/pydantic_ai/agent.py
@@ -296,10 +296,12 @@ async def main():

                 run_context.run_step += 1
                 with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
-                    agent_model = await self._prepare_model(run_context, result_schema)
+                    agent_request_config = await self._prepare_agent_request_config(run_context, result_schema)

                 with _logfire.span('model request', run_step=run_context.run_step) as model_req_span:
-                    model_response, request_usage = await agent_model.request(messages, model_settings)
+                    model_response, request_usage = await model_used.request(
+                        messages, model_settings, agent_request_config
+                    )
                     model_req_span.set_attribute('response', model_response)
                     model_req_span.set_attribute('usage', request_usage)

@@ -527,10 +529,12 @@ async def main():
                 usage_limits.check_before_request(run_context.usage)

                 with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
-                    agent_model = await self._prepare_model(run_context, result_schema)
+                    agent_request_config = await self._prepare_agent_request_config(run_context, result_schema)

                 with _logfire.span('model request {run_step=}', run_step=run_context.run_step) as model_req_span:
-                    async with agent_model.request_stream(messages, model_settings) as model_response:
+                    async with model_used.request_stream(
+                        messages, model_settings, agent_request_config
+                    ) as model_response:
                         run_context.usage.requests += 1
                         model_req_span.set_attribute('response_type', model_response.__class__.__name__)
                         # We want to end the "model request" span here, but we can't exit the context manager
@@ -998,9 +1002,9 @@ async def _get_model(self, model: models.Model | models.KnownModelName | None) -

         return model_

-    async def _prepare_model(
+    async def _prepare_agent_request_config(
         self, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[RunResultDataT] | None
-    ) -> models.AgentModel:
+    ) -> models.AgentRequestConfig:
         """Build tools and create an agent model."""
         function_tools: list[ToolDefinition] = []

@@ -1011,7 +1015,7 @@ async def add_tool(tool: Tool[AgentDepsT]) -> None:

         await asyncio.gather(*map(add_tool, self._function_tools.values()))

-        return await run_context.model.agent_model(
+        return models.AgentRequestConfig(
             function_tools=function_tools,
             allow_text_result=self._allow_text_result(result_schema),
             result_tools=result_schema.tool_defs() if result_schema is not None else [],
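The shape of the change on the agent side: per-step preparation now yields a plain data object instead of awaiting a per-step model from `Model.agent_model()`. Roughly, with illustrative tool values:

```python
from pydantic_ai import models
from pydantic_ai.tools import ToolDefinition

# Approximately what _prepare_agent_request_config assembles each run step.
config = models.AgentRequestConfig(
    function_tools=[
        ToolDefinition(
            name='get_weather',  # illustrative agent tool
            description='Get the weather for a city.',
            parameters_json_schema={
                'type': 'object',
                'properties': {'city': {'type': 'string'}},
            },
        )
    ],
    allow_text_result=True,
    result_tools=[],
)
# ...then passed straight through:
#   await model_used.request(messages, model_settings, config)
```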
43 changes: 15 additions & 28 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -171,49 +171,36 @@
 """


+@dataclass
+class AgentRequestConfig:
+    function_tools: list[ToolDefinition]
+    allow_text_result: bool
+    result_tools: list[ToolDefinition]
+
+
 class Model(ABC):
     """Abstract class for a model."""

-    @abstractmethod
-    async def agent_model(
-        self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
-        """Create an agent model, this is called for each step of an agent run.
-
-        This is async in case slow/async config checks need to be performed that can't be done in `__init__`.
-
-        Args:
-            function_tools: The tools available to the agent.
-            allow_text_result: Whether a plain text final response/result is permitted.
-            result_tools: Tool definitions for the final result tool(s), if any.
-
-        Returns:
-            An agent model.
-        """
-        raise NotImplementedError()
-
     @abstractmethod
     def name(self) -> str:
         raise NotImplementedError()

-
-class AgentModel(ABC):
-    """Model configured for each step of an Agent run."""
-
     @abstractmethod
     async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        agent_request_config: AgentRequestConfig,
     ) -> tuple[ModelResponse, Usage]:
         """Make a request to the model."""
         raise NotImplementedError()

     @asynccontextmanager
     async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        agent_request_config: AgentRequestConfig,
     ) -> AsyncIterator[StreamedResponse]:
         """Make a request to the model and return a streaming response."""
         # This method is not required, but you need to implement it if you want to support streamed responses
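For downstream `Model` implementers the migration is mechanical: the two-step `agent_model()` + `AgentModel.request()` dance collapses into a single call. A hedged before/after sketch (the old API is shown in comments):

```python
from pydantic_ai.messages import ModelMessage
from pydantic_ai.models import AgentRequestConfig, Model


async def call_model(model: Model, messages: list[ModelMessage]):
    # Before this PR (two steps, with a per-step AgentModel object):
    #     agent_model = await model.agent_model(
    #         function_tools=[], allow_text_result=True, result_tools=[])
    #     return await agent_model.request(messages, None)
    # After (one step; the per-step state travels with the request):
    config = AgentRequestConfig(function_tools=[], allow_text_result=True, result_tools=[])
    return await model.request(messages, None, config)
```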
78 changes: 40 additions & 38 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -28,7 +28,7 @@
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
+    AgentRequestConfig,
     Model,
     StreamedResponse,
     cached_async_http_client,
@@ -134,24 +134,6 @@ def __init__(
         else:
             self.client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client())

-    async def agent_model(
-        self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
-        check_allow_model_requests()
-        tools = [self._map_tool_definition(r) for r in function_tools]
-        if result_tools:
-            tools += [self._map_tool_definition(r) for r in result_tools]
-        return AnthropicAgentModel(
-            self.client,
-            self.model_name,
-            allow_text_result,
-            tools,
-        )
-
     def name(self) -> str:
         return f'anthropic:{self.model_name}'

@@ -163,52 +145,72 @@ def _map_tool_definition(f: ToolDefinition) -> ToolParam:
             'input_schema': f.parameters_json_schema,
         }

-
-@dataclass
-class AnthropicAgentModel(AgentModel):
-    """Implementation of `AgentModel` for Anthropic models."""
-
-    client: AsyncAnthropic
-    model_name: AnthropicModelName
-    allow_text_result: bool
-    tools: list[ToolParam]
+    def _get_tools(self, agent_request_config: AgentRequestConfig) -> list[ToolParam]:
+        tools = [self._map_tool_definition(r) for r in agent_request_config.function_tools]
+        if agent_request_config.result_tools:
+            tools += [self._map_tool_definition(r) for r in agent_request_config.result_tools]
+        return tools

     async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        agent_request_config: AgentRequestConfig,
     ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._messages_create(messages, False, cast(AnthropicModelSettings, model_settings or {}))
+        check_allow_model_requests()
+        response = await self._messages_create(
+            messages, False, cast(AnthropicModelSettings, model_settings or {}), agent_request_config
+        )
         return self._process_response(response), _map_usage(response)

     @asynccontextmanager
     async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        agent_request_config: AgentRequestConfig,
     ) -> AsyncIterator[StreamedResponse]:
-        response = await self._messages_create(messages, True, cast(AnthropicModelSettings, model_settings or {}))
+        response = await self._messages_create(
+            messages, True, cast(AnthropicModelSettings, model_settings or {}), agent_request_config
+        )
         async with response:
             yield await self._process_streamed_response(response)

     @overload
     async def _messages_create(
-        self, messages: list[ModelMessage], stream: Literal[True], model_settings: AnthropicModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[True],
+        model_settings: AnthropicModelSettings,
+        agent_request_config: AgentRequestConfig,
     ) -> AsyncStream[RawMessageStreamEvent]:
         pass

     @overload
     async def _messages_create(
-        self, messages: list[ModelMessage], stream: Literal[False], model_settings: AnthropicModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[False],
+        model_settings: AnthropicModelSettings,
+        agent_request_config: AgentRequestConfig,
    ) -> AnthropicMessage:
         pass

     async def _messages_create(
-        self, messages: list[ModelMessage], stream: bool, model_settings: AnthropicModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: bool,
+        model_settings: AnthropicModelSettings,
+        agent_request_config: AgentRequestConfig,
     ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
         # standalone function to make it easier to override
+        tools = self._get_tools(agent_request_config)
         tool_choice: ToolChoiceParam | None

-        if not self.tools:
+        if not tools:
             tool_choice = None
         else:
-            if not self.allow_text_result:
+            if not agent_request_config.allow_text_result:
                 tool_choice = {'type': 'any'}
             else:
                 tool_choice = {'type': 'auto'}
@@ -223,7 +225,7 @@ async def _messages_create(
             system=system_prompt or NOT_GIVEN,
             messages=anthropic_messages,
             model=self.model_name,
-            tools=self.tools or NOT_GIVEN,
+            tools=tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
             temperature=model_settings.get('temperature', NOT_GIVEN),
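The `tool_choice` branch is the behavioural heart of this file's change, unchanged in substance but now sourced from the config rather than instance state. Restated standalone for review (`pick_tool_choice` is a hypothetical name, not part of the PR):

```python
from anthropic.types import ToolChoiceParam, ToolParam


def pick_tool_choice(tools: list[ToolParam], allow_text_result: bool) -> ToolChoiceParam | None:
    """Condensed restatement of the branch in _messages_create above."""
    if not tools:
        return None  # no tools to offer, so plain text is the only option
    if not allow_text_result:
        return {'type': 'any'}  # some tool (e.g. a result tool) must be called
    return {'type': 'auto'}  # the model may answer in text or call a tool
```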
46 changes: 15 additions & 31 deletions pydantic_ai_slim/pydantic_ai/models/cohere.py
@@ -25,7 +25,7 @@
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
+    AgentRequestConfig,
     Model,
     check_allow_model_requests,
 )
@@ -114,24 +114,6 @@ def __init__(
         else:
             self.client = AsyncClientV2(api_key=api_key)  # type: ignore

-    async def agent_model(
-        self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
-        check_allow_model_requests()
-        tools = [self._map_tool_definition(r) for r in function_tools]
-        if result_tools:
-            tools += [self._map_tool_definition(r) for r in result_tools]
-        return CohereAgentModel(
-            self.client,
-            self.model_name,
-            allow_text_result,
-            tools,
-        )
-
     def name(self) -> str:
         return f'cohere:{self.model_name}'

@@ -146,32 +128,34 @@ def _map_tool_definition(f: ToolDefinition) -> ToolV2:
             ),
         )

-
-@dataclass
-class CohereAgentModel(AgentModel):
-    """Implementation of `AgentModel` for Cohere models."""
-
-    client: AsyncClientV2
-    model_name: CohereModelName
-    allow_text_result: bool
-    tools: list[ToolV2]
+    def _get_tools(self, agent_request_config: AgentRequestConfig) -> list[ToolV2]:
+        tools = [self._map_tool_definition(r) for r in agent_request_config.function_tools]
+        if agent_request_config.result_tools:
+            tools += [self._map_tool_definition(r) for r in agent_request_config.result_tools]
+        return tools

     async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        agent_request_config: AgentRequestConfig,
     ) -> tuple[ModelResponse, result.Usage]:
-        response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}))
+        check_allow_model_requests()
+        response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}), agent_request_config)
         return self._process_response(response), _map_usage(response)

     async def _chat(
         self,
         messages: list[ModelMessage],
         model_settings: CohereModelSettings,
+        agent_request_config: AgentRequestConfig,
     ) -> ChatResponse:
+        tools = self._get_tools(agent_request_config)
         cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
         return await self.client.chat(
             model=self.model_name,
             messages=cohere_messages,
-            tools=self.tools or OMIT,
+            tools=tools or OMIT,
             max_tokens=model_settings.get('max_tokens', OMIT),
             temperature=model_settings.get('temperature', OMIT),
             p=model_settings.get('top_p', OMIT),
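`CohereModel` gains no `request_stream` override, so `request` is the whole surface here. A hedged sketch of calling it directly under the new signature; the model name is illustrative and auth (an API key, typically via `CO_API_KEY`) is assumed:

```python
import asyncio

from pydantic_ai.messages import ModelRequest, UserPromptPart
from pydantic_ai.models import AgentRequestConfig, Model
from pydantic_ai.models.cohere import CohereModel


async def main() -> None:
    model: Model = CohereModel('command-r-plus')  # illustrative model name
    config = AgentRequestConfig(function_tools=[], allow_text_result=True, result_tools=[])
    messages = [ModelRequest(parts=[UserPromptPart('What is structured output?')])]
    # check_allow_model_requests() now runs inside request() itself.
    response, usage = await model.request(messages, None, config)
    print(response, usage)


# asyncio.run(main())  # requires the cohere extra and a valid API key
```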