Skip to content

Commit 800710b

Browse files
committed
Support Thinking part
1 parent bc8f36c commit 800710b

File tree

4 files changed

+62
-4
lines changed

4 files changed

+62
-4
lines changed

main.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import asyncio
2+
3+
from pydantic_ai import Agent
4+
5+
agent = Agent(model='anthropic:claude-3-7-sonnet-latest')
6+
7+
8+
@agent.tool_plain
9+
def sum(a: int, b: int) -> int:
10+
"""Get the sum of two numbers.
11+
12+
Args:
13+
a: The first number.
14+
b: The second number.
15+
16+
Returns:
17+
The sum of the two numbers.
18+
"""
19+
return a + b
20+
21+
22+
async def main():
23+
async with agent.iter('Get me the sum of 1 and 2, using the sum tool.') as agent_run:
24+
async for node in agent_run:
25+
print(node)
26+
print()
27+
print(agent_run.result)
28+
29+
30+
if __name__ == '__main__':
31+
asyncio.run(main())

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,8 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
396396
texts.append(part.content)
397397
elif isinstance(part, _messages.ToolCallPart):
398398
tool_calls.append(part)
399+
elif isinstance(part, _messages.ThinkingPart):
400+
...
399401
else:
400402
assert_never(part)
401403

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,20 @@ def has_content(self) -> bool:
390390
return bool(self.content)
391391

392392

393+
@dataclass
394+
class ThinkingPart:
395+
"""A thinking response from a model."""
396+
397+
content: str
398+
"""The thinking content of the response."""
399+
400+
signature: str | None = None
401+
"""The signature of the thinking."""
402+
403+
part_kind: Literal['thinking'] = 'thinking'
404+
"""Part type identifier, this is available on all parts as a discriminator."""
405+
406+
393407
@dataclass
394408
class ToolCallPart:
395409
"""A tool call from a model."""
@@ -439,7 +453,7 @@ def has_content(self) -> bool:
439453
return bool(self.args)
440454

441455

442-
ModelResponsePart = Annotated[Union[TextPart, ToolCallPart], pydantic.Discriminator('part_kind')]
456+
ModelResponsePart = Annotated[Union[TextPart, ToolCallPart, ThinkingPart], pydantic.Discriminator('part_kind')]
443457
"""A message part returned by a model."""
444458

445459

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from json import JSONDecodeError, loads as json_loads
1010
from typing import Any, Literal, Union, cast, overload
1111

12-
from anthropic.types import DocumentBlockParam
12+
from anthropic.types import DocumentBlockParam, ThinkingBlock, ThinkingBlockParam
1313
from httpx import AsyncClient as AsyncHTTPClient
1414
from typing_extensions import assert_never
1515

@@ -27,6 +27,7 @@
2727
RetryPromptPart,
2828
SystemPromptPart,
2929
TextPart,
30+
ThinkingPart,
3031
ToolCallPart,
3132
ToolReturnPart,
3233
UserPromptPart,
@@ -227,13 +228,14 @@ async def _messages_create(
227228

228229
try:
229230
return await self.client.messages.create(
230-
max_tokens=model_settings.get('max_tokens', 1024),
231+
max_tokens=model_settings.get('max_tokens', 2048),
231232
system=system_prompt or NOT_GIVEN,
232233
messages=anthropic_messages,
233234
model=self._model_name,
234235
tools=tools or NOT_GIVEN,
235236
tool_choice=tool_choice or NOT_GIVEN,
236237
stream=stream,
238+
thinking={'budget_tokens': 1024, 'type': 'enabled'},
237239
temperature=model_settings.get('temperature', NOT_GIVEN),
238240
top_p=model_settings.get('top_p', NOT_GIVEN),
239241
timeout=model_settings.get('timeout', NOT_GIVEN),
@@ -250,6 +252,8 @@ def _process_response(self, response: AnthropicMessage) -> ModelResponse:
250252
for item in response.content:
251253
if isinstance(item, TextBlock):
252254
items.append(TextPart(content=item.text))
255+
elif isinstance(item, ThinkingBlock):
256+
items.append(ThinkingPart(content=item.thinking, signature=item.signature))
253257
else:
254258
assert isinstance(item, ToolUseBlock), 'unexpected item type'
255259
items.append(
@@ -316,10 +320,17 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[Me
316320
user_content_params.append(retry_param)
317321
anthropic_messages.append(MessageParam(role='user', content=user_content_params))
318322
elif isinstance(m, ModelResponse):
319-
assistant_content_params: list[TextBlockParam | ToolUseBlockParam] = []
323+
assistant_content_params: list[TextBlockParam | ToolUseBlockParam | ThinkingBlockParam] = []
320324
for response_part in m.parts:
321325
if isinstance(response_part, TextPart):
322326
assistant_content_params.append(TextBlockParam(text=response_part.content, type='text'))
327+
elif isinstance(response_part, ThinkingPart):
328+
assert response_part.signature is not None, 'Thinking part must have a signature'
329+
assistant_content_params.append(
330+
ThinkingBlockParam(
331+
thinking=response_part.content, signature=response_part.signature, type='thinking'
332+
)
333+
)
323334
else:
324335
tool_use_block_param = ToolUseBlockParam(
325336
id=_guard_tool_call_id(t=response_part, model_source='Anthropic'),

0 commit comments

Comments
 (0)