From f70bffe01b2250525c60eefea33ad5ed5a1b07bf Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Thu, 17 Apr 2025 18:14:28 -0600 Subject: [PATCH] Add support for schema_generator to ToolOutput --- pydantic_ai_slim/pydantic_ai/_output.py | 30 ++++++++++++++++++++----- pydantic_ai_slim/pydantic_ai/result.py | 5 ++++- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 2f10d3fe3..efd877503 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -6,6 +6,7 @@ from typing import Any, Callable, Generic, Literal, Union, cast from pydantic import TypeAdapter, ValidationError +from pydantic.json_schema import GenerateJsonSchema from typing_extensions import TypedDict, TypeVar, get_args, get_origin from typing_inspection import typing_objects from typing_inspection.introspection import is_union_origin @@ -104,8 +105,10 @@ def build( description = output_type.description output_type_ = output_type.output_type strict = output_type.strict + schema_generator = output_type.schema_generator else: output_type_ = output_type + schema_generator = GenerateToolJsonSchema if output_type_option := extract_str_from_union(output_type): output_type_ = output_type_option.value @@ -122,7 +125,12 @@ def build( tools[tool_name] = cast( OutputSchemaTool[T], OutputSchemaTool( - output_type=arg, name=tool_name, description=description, multiple=True, strict=strict + output_type=arg, + name=tool_name, + description=description, + multiple=True, + strict=strict, + schema_generator=schema_generator, ), ) else: @@ -130,7 +138,12 @@ def build( tools[name] = cast( OutputSchemaTool[T], OutputSchemaTool( - output_type=output_type_, name=name, description=description, multiple=False, strict=strict + output_type=output_type_, + name=name, + description=description, + multiple=False, + strict=strict, + schema_generator=schema_generator, ), ) @@ -173,7 +186,14 @@ class OutputSchemaTool(Generic[OutputDataT]): type_adapter: TypeAdapter[Any] def __init__( - self, *, output_type: type[OutputDataT], name: str, description: str | None, multiple: bool, strict: bool | None + self, + *, + output_type: type[OutputDataT], + name: str, + description: str | None, + multiple: bool, + strict: bool | None, + schema_generator: type[GenerateJsonSchema], ): """Build a OutputSchemaTool from a response type.""" if _utils.is_model_like(output_type): @@ -181,7 +201,7 @@ def __init__( outer_typed_dict_key: str | None = None # noinspection PyArgumentList parameters_json_schema = _utils.check_object_json_schema( - self.type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) + self.type_adapter.json_schema(schema_generator=schema_generator) ) else: response_data_typed_dict = TypedDict( # noqa: UP013 @@ -192,7 +212,7 @@ def __init__( outer_typed_dict_key = 'response' # noinspection PyArgumentList parameters_json_schema = _utils.check_object_json_schema( - self.type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) + self.type_adapter.json_schema(schema_generator=schema_generator) ) # including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM parameters_json_schema.pop('title') diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index d074f3042..5c75cf4fa 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -7,11 +7,12 @@ from datetime import datetime from typing import TYPE_CHECKING, Generic, Union, cast +from pydantic.json_schema import GenerateJsonSchema from typing_extensions import TypeVar, assert_type, deprecated, overload from . import _utils, exceptions, messages as _messages, models from .messages import AgentStreamEvent, FinalResultEvent -from .tools import AgentDepsT, RunContext +from .tools import AgentDepsT, GenerateToolJsonSchema, RunContext from .usage import Usage, UsageLimits if TYPE_CHECKING: @@ -76,12 +77,14 @@ def __init__( description: str | None = None, max_retries: int | None = None, strict: bool | None = None, + schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, ): self.output_type = type_ self.name = name self.description = description self.max_retries = max_retries self.strict = strict + self.schema_generator = schema_generator # TODO: add support for call and make type_ optional, with the following logic: # if type_ is None and call is None: