Skip to content

Commit 0b43f65

Browse files
committed
Add builtin tools
1 parent 3ad6d38 commit 0b43f65

File tree

13 files changed

+407
-160
lines changed

13 files changed

+407
-160
lines changed

pydantic_ai_slim/pydantic_ai/builtin_tools.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations as _annotations
22

33
from abc import ABC
4-
from dataclasses import dataclass, field
4+
from dataclasses import dataclass
55
from typing import Literal
66

77
from typing_extensions import TypedDict
@@ -19,20 +19,6 @@ class AbstractBuiltinTool(ABC):
1919
"""
2020

2121

22-
class UserLocation(TypedDict, total=False):
23-
"""Allows you to localize search results based on a user's location.
24-
25-
Supported by:
26-
* Anthropic
27-
* OpenAI
28-
"""
29-
30-
city: str
31-
country: str
32-
region: str
33-
timezone: str
34-
35-
3622
@dataclass
3723
class WebSearchTool(AbstractBuiltinTool):
3824
"""A builtin tool that allows your agent to search the web for information.
@@ -47,7 +33,7 @@ class WebSearchTool(AbstractBuiltinTool):
4733
* OpenAI
4834
"""
4935

50-
user_location: UserLocation = field(default_factory=UserLocation)
36+
user_location: UserLocation | None = None
5137
"""The `user_location` parameter allows you to localize search results based on a user's location.
5238
5339
Supported by:
@@ -82,3 +68,26 @@ class WebSearchTool(AbstractBuiltinTool):
8268
Supported by:
8369
* Anthropic
8470
"""
71+
72+
73+
class UserLocation(TypedDict, total=False):
74+
"""Allows you to localize search results based on a user's location.
75+
76+
Supported by:
77+
* Anthropic
78+
* OpenAI
79+
"""
80+
81+
city: str
82+
country: str
83+
region: str
84+
timezone: str
85+
86+
87+
class CodeExecutionTool(AbstractBuiltinTool):
88+
"""A builtin tool that allows your agent to execute code.
89+
90+
Supported by:
91+
* Anthropic
92+
* OpenAI
93+
"""

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -349,8 +349,8 @@ def otel_event(self, settings: InstrumentationSettings) -> Event:
349349

350350

351351
@dataclass(repr=False)
352-
class ToolReturnPart:
353-
"""A tool return message, this encodes the result of running a tool."""
352+
class BaseToolReturnPart:
353+
"""Base class for tool return parts."""
354354

355355
tool_name: str
356356
"""The name of the "tool" was called."""
@@ -364,9 +364,6 @@ class ToolReturnPart:
364364
timestamp: datetime = field(default_factory=_now_utc)
365365
"""The timestamp, when the tool returned."""
366366

367-
part_kind: Literal['tool-return'] = 'tool-return'
368-
"""Part type identifier, this is available on all parts as a discriminator."""
369-
370367
def model_response_str(self) -> str:
371368
"""Return a string representation of the content for the model."""
372369
if isinstance(self.content, str):
@@ -391,6 +388,22 @@ def otel_event(self, _settings: InstrumentationSettings) -> Event:
391388
__repr__ = _utils.dataclasses_no_defaults_repr
392389

393390

391+
@dataclass(repr=False)
392+
class ToolReturnPart(BaseToolReturnPart):
393+
"""A tool return message, this encodes the result of running a tool."""
394+
395+
part_kind: Literal['tool-return'] = 'tool-return'
396+
"""Part type identifier, this is available on all parts as a discriminator."""
397+
398+
399+
@dataclass(repr=False)
400+
class ServerToolReturnPart(BaseToolReturnPart):
401+
"""A tool return message from a server tool."""
402+
403+
part_kind: Literal['server-tool-return'] = 'server-tool-return'
404+
"""Part type identifier, this is available on all parts as a discriminator."""
405+
406+
394407
error_details_ta = pydantic.TypeAdapter(list[pydantic_core.ErrorDetails], config=pydantic.ConfigDict(defer_build=True))
395408

396409

@@ -503,7 +516,7 @@ def has_content(self) -> bool:
503516

504517

505518
@dataclass(repr=False)
506-
class ToolCallPart:
519+
class BaseToolCallPart:
507520
"""A tool call from a model."""
508521

509522
tool_name: str
@@ -521,9 +534,6 @@ class ToolCallPart:
521534
In case the tool call id is not provided by the model, PydanticAI will generate a random one.
522535
"""
523536

524-
part_kind: Literal['tool-call'] = 'tool-call'
525-
"""Part type identifier, this is available on all parts as a discriminator."""
526-
527537
def args_as_dict(self) -> dict[str, Any]:
528538
"""Return the arguments as a Python dictionary.
529539
@@ -560,7 +570,28 @@ def has_content(self) -> bool:
560570
__repr__ = _utils.dataclasses_no_defaults_repr
561571

562572

563-
ModelResponsePart = Annotated[Union[TextPart, ToolCallPart], pydantic.Discriminator('part_kind')]
573+
@dataclass(repr=False)
574+
class ToolCallPart(BaseToolCallPart):
575+
"""A tool call from a model."""
576+
577+
part_kind: Literal['tool-call'] = 'tool-call'
578+
"""Part type identifier, this is available on all parts as a discriminator."""
579+
580+
581+
@dataclass(repr=False)
582+
class ServerToolCallPart(BaseToolCallPart):
583+
"""A tool call from a server tool."""
584+
585+
model_name: str | None = None
586+
"""The name of the model that generated the response."""
587+
588+
part_kind: Literal['server-tool-call'] = 'server-tool-call'
589+
"""Part type identifier, this is available on all parts as a discriminator."""
590+
591+
592+
ModelResponsePart = Annotated[
593+
Union[TextPart, ToolCallPart, ServerToolCallPart, ServerToolReturnPart], pydantic.Discriminator('part_kind')
594+
]
564595
"""A message part returned by a model."""
565596

566597

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,10 @@
77
from datetime import datetime, timezone
88
from typing import Any, Literal, Union, cast, overload
99

10-
from anthropic.types import ServerToolUseBlock, ToolUnionParam, WebSearchTool20250305Param, WebSearchToolResultBlock
11-
from anthropic.types.web_search_tool_20250305_param import UserLocation
10+
from anthropic.types.beta import BetaMessage, BetaRawMessageStreamEvent, BetaToolUnionParam
1211
from typing_extensions import assert_never
1312

14-
from pydantic_ai.builtin_tools import WebSearchTool
13+
from pydantic_ai.builtin_tools import CodeExecutionTool, WebSearchTool
1514

1615
from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
1716
from .._utils import guard_tool_call_id as _guard_tool_call_id
@@ -25,6 +24,8 @@
2524
ModelResponsePart,
2625
ModelResponseStreamEvent,
2726
RetryPromptPart,
27+
ServerToolCallPart,
28+
ServerToolReturnPart,
2829
SystemPromptPart,
2930
TextPart,
3031
ToolCallPart,
@@ -61,15 +62,21 @@
6162
RawMessageStartEvent,
6263
RawMessageStopEvent,
6364
RawMessageStreamEvent,
65+
ServerToolUseBlock,
6466
TextBlock,
6567
TextBlockParam,
6668
TextDelta,
6769
ToolChoiceParam,
6870
ToolParam,
6971
ToolResultBlockParam,
72+
ToolUnionParam,
7073
ToolUseBlock,
7174
ToolUseBlockParam,
75+
WebSearchTool20250305Param,
76+
WebSearchToolResultBlock,
7277
)
78+
from anthropic.types.beta.beta_code_execution_tool_20250522_param import BetaCodeExecutionTool20250522Param
79+
from anthropic.types.web_search_tool_20250305_param import UserLocation
7380
except ImportError as _import_error:
7481
raise ImportError(
7582
'Please install `anthropic` to use the Anthropic model, '
@@ -207,10 +214,11 @@ async def _messages_create(
207214
stream: bool,
208215
model_settings: AnthropicModelSettings,
209216
model_request_parameters: ModelRequestParameters,
210-
) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
217+
) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent] | BetaMessage | AsyncStream[BetaRawMessageStreamEvent]:
211218
# standalone function to make it easier to override
212219
tools = self._get_tools(model_request_parameters)
213220
tools += self._get_builtin_tools(model_request_parameters)
221+
beta_tools = self._get_beta_tools(model_request_parameters)
214222
tool_choice: ToolChoiceParam | None
215223

216224
if not tools:
@@ -256,11 +264,25 @@ def _process_response(self, response: AnthropicMessage) -> ModelResponse:
256264
for item in response.content:
257265
if isinstance(item, TextBlock):
258266
items.append(TextPart(content=item.text))
259-
if isinstance(item, WebSearchToolResultBlock):
260-
# TODO(Marcelo): We should send something back to the user, because we need to send it back on the next request.
261-
...
267+
elif isinstance(item, WebSearchToolResultBlock):
268+
items.append(
269+
ServerToolReturnPart(
270+
tool_name='web_search',
271+
content=item.content,
272+
tool_call_id=item.tool_use_id,
273+
)
274+
)
275+
elif isinstance(item, ServerToolUseBlock):
276+
items.append(
277+
ServerToolCallPart(
278+
model_name='anthropic',
279+
tool_name=item.name,
280+
args=cast(dict[str, Any], item.input),
281+
tool_call_id=item.id,
282+
)
283+
)
262284
else:
263-
assert isinstance(item, (ToolUseBlock, ServerToolUseBlock)), f'unexpected item type {type(item)}'
285+
assert isinstance(item, ToolUseBlock), f'unexpected item type {type(item)}'
264286
items.append(
265287
ToolCallPart(
266288
tool_name=item.name,
@@ -305,6 +327,13 @@ def _get_builtin_tools(self, model_request_parameters: ModelRequestParameters) -
305327
)
306328
return tools
307329

330+
def _get_beta_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolUnionParam]:
331+
tools: list[BetaToolUnionParam] = []
332+
for tool in model_request_parameters.builtin_tools:
333+
if isinstance(tool, CodeExecutionTool):
334+
tools.append(BetaCodeExecutionTool20250522Param(name='code_execution', type='code_execution_20250522'))
335+
return tools
336+
308337
async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
309338
"""Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
310339
system_prompt_parts: list[str] = []

pydantic_ai_slim/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ openai = ["openai>=1.75.0"]
6565
cohere = ["cohere>=5.13.11; platform_system != 'Emscripten'"]
6666
vertexai = ["google-auth>=2.36.0", "requests>=2.32.2"]
6767
google = ["google-genai>=1.15.0"]
68-
anthropic = ["anthropic>=0.49.0"]
68+
anthropic = ["anthropic>=0.52.0"]
6969
groq = ["groq>=0.15.0"]
7070
mistral = ["mistralai>=1.2.5"]
7171
bedrock = ["boto3>=1.35.74"]
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
interactions:
2+
- request:
3+
headers:
4+
accept:
5+
- application/json
6+
accept-encoding:
7+
- gzip, deflate
8+
connection:
9+
- keep-alive
10+
content-length:
11+
- '143'
12+
content-type:
13+
- application/json
14+
host:
15+
- api.anthropic.com
16+
method: POST
17+
parsed_body:
18+
max_tokens: 1024
19+
messages:
20+
- content:
21+
- text: How much is 3 * 12390?
22+
type: text
23+
role: user
24+
model: claude-3-5-sonnet-latest
25+
uri: https://api.anthropic.com/v1/messages?beta=true
26+
response:
27+
headers:
28+
connection:
29+
- keep-alive
30+
content-length:
31+
- '370'
32+
content-type:
33+
- application/json
34+
strict-transport-security:
35+
- max-age=31536000; includeSubDomains; preload
36+
transfer-encoding:
37+
- chunked
38+
parsed_body:
39+
content:
40+
- text: |-
41+
Let me calculate that:
42+
43+
3 * 12390 = 37170
44+
type: text
45+
id: msg_01D3uRsKuEcst7jtMdwoqYUi
46+
model: claude-3-5-sonnet-20241022
47+
role: assistant
48+
stop_reason: end_turn
49+
stop_sequence: null
50+
type: message
51+
usage:
52+
cache_creation_input_tokens: 0
53+
cache_read_input_tokens: 0
54+
input_tokens: 18
55+
output_tokens: 20
56+
service_tier: standard
57+
status:
58+
code: 200
59+
message: OK
60+
version: 1

0 commit comments

Comments
 (0)