Skip to content

Commit 5fab108

Browse files
authored
Handle multi-modal and error responses from MCP tool calls (#1618)
1 parent cbfb311 commit 5fab108

29 files changed

+4025
-382
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import dataclasses
5+
import hashlib
56
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
67
from contextlib import asynccontextmanager, contextmanager
78
from contextvars import ContextVar
@@ -92,6 +93,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
9293

9394
function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
9495
mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
96+
default_retries: int
9597

9698
tracer: Tracer
9799

@@ -546,6 +548,13 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
546548
)
547549

548550

551+
def multi_modal_content_identifier(identifier: str | bytes) -> str:
552+
"""Generate stable identifier for multi-modal content to help LLM in finding a specific file in tool call responses."""
553+
if isinstance(identifier, str):
554+
identifier = identifier.encode('utf-8')
555+
return hashlib.sha1(identifier).hexdigest()[:6]
556+
557+
549558
async def process_function_tools( # noqa C901
550559
tool_calls: list[_messages.ToolCallPart],
551560
output_tool_name: str | None,
@@ -648,8 +657,6 @@ async def process_function_tools( # noqa C901
648657
for tool, call in calls_to_run
649658
]
650659

651-
file_index = 1
652-
653660
pending = tasks
654661
while pending:
655662
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
@@ -661,17 +668,38 @@ async def process_function_tools( # noqa C901
661668
if isinstance(result, _messages.RetryPromptPart):
662669
results_by_index[index] = result
663670
elif isinstance(result, _messages.ToolReturnPart):
664-
if isinstance(result.content, _messages.MultiModalContentTypes):
665-
user_parts.append(
666-
_messages.UserPromptPart(
667-
content=[f'This is file {file_index}:', result.content],
668-
timestamp=result.timestamp,
669-
part_kind='user-prompt',
671+
contents: list[Any]
672+
single_content: bool
673+
if isinstance(result.content, list):
674+
contents = result.content # type: ignore
675+
single_content = False
676+
else:
677+
contents = [result.content]
678+
single_content = True
679+
680+
processed_contents: list[Any] = []
681+
for content in contents:
682+
if isinstance(content, _messages.MultiModalContentTypes):
683+
if isinstance(content, _messages.BinaryContent):
684+
identifier = multi_modal_content_identifier(content.data)
685+
else:
686+
identifier = multi_modal_content_identifier(content.url)
687+
688+
user_parts.append(
689+
_messages.UserPromptPart(
690+
content=[f'This is file {identifier}:', content],
691+
timestamp=result.timestamp,
692+
part_kind='user-prompt',
693+
)
670694
)
671-
)
695+
processed_contents.append(f'See file {identifier}')
696+
else:
697+
processed_contents.append(content)
672698

673-
result.content = f'See file {file_index}.'
674-
file_index += 1
699+
if single_content:
700+
result.content = processed_contents[0]
701+
else:
702+
result.content = processed_contents
675703

676704
results_by_index[index] = result
677705
else:
@@ -710,7 +738,7 @@ async def run_tool(ctx: RunContext[DepsT], **args: Any) -> Any:
710738
for server in ctx.deps.mcp_servers:
711739
tools = await server.list_tools()
712740
if tool_name in {tool.name for tool in tools}:
713-
return Tool(name=tool_name, function=run_tool, takes_ctx=True)
741+
return Tool(name=tool_name, function=run_tool, takes_ctx=True, max_retries=ctx.deps.default_retries)
714742
return None
715743

716744

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
658658
output_validators=output_validators,
659659
function_tools=self._function_tools,
660660
mcp_servers=self._mcp_servers,
661+
default_retries=self._default_retries,
661662
tracer=tracer,
662663
get_instructions=get_instructions,
663664
)

pydantic_ai_slim/pydantic_ai/mcp.py

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import base64
4+
import json
35
from abc import ABC, abstractmethod
46
from collections.abc import AsyncIterator, Sequence
57
from contextlib import AsyncExitStack, asynccontextmanager
@@ -9,16 +11,25 @@
911
from typing import Any
1012

1113
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
12-
from mcp.types import JSONRPCMessage, LoggingLevel
13-
from typing_extensions import Self
14-
14+
from mcp.types import (
15+
BlobResourceContents,
16+
EmbeddedResource,
17+
ImageContent,
18+
JSONRPCMessage,
19+
LoggingLevel,
20+
TextContent,
21+
TextResourceContents,
22+
)
23+
from typing_extensions import Self, assert_never
24+
25+
from pydantic_ai.exceptions import ModelRetry
26+
from pydantic_ai.messages import BinaryContent
1527
from pydantic_ai.tools import ToolDefinition
1628

1729
try:
1830
from mcp.client.session import ClientSession
1931
from mcp.client.sse import sse_client
2032
from mcp.client.stdio import StdioServerParameters, stdio_client
21-
from mcp.types import CallToolResult
2233
except ImportError as _import_error:
2334
raise ImportError(
2435
'Please install the `mcp` package to use the MCP server, '
@@ -74,7 +85,9 @@ async def list_tools(self) -> list[ToolDefinition]:
7485
for tool in tools.tools
7586
]
7687

77-
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> CallToolResult:
88+
async def call_tool(
89+
self, tool_name: str, arguments: dict[str, Any]
90+
) -> str | BinaryContent | dict[str, Any] | list[Any] | Sequence[str | BinaryContent | dict[str, Any] | list[Any]]:
7891
"""Call a tool on the server.
7992
8093
Args:
@@ -83,8 +96,21 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> CallTool
8396
8497
Returns:
8598
The result of the tool call.
99+
100+
Raises:
101+
ModelRetry: If the tool call fails.
86102
"""
87-
return await self._client.call_tool(tool_name, arguments)
103+
result = await self._client.call_tool(tool_name, arguments)
104+
105+
content = [self._map_tool_result_part(part) for part in result.content]
106+
107+
if result.isError:
108+
text = '\n'.join(str(part) for part in content)
109+
raise ModelRetry(text)
110+
111+
if len(content) == 1:
112+
return content[0]
113+
return content
88114

89115
async def __aenter__(self) -> Self:
90116
self._exit_stack = AsyncExitStack()
@@ -105,6 +131,35 @@ async def __aexit__(
105131
await self._exit_stack.aclose()
106132
self.is_running = False
107133

134+
def _map_tool_result_part(
135+
self, part: TextContent | ImageContent | EmbeddedResource
136+
) -> str | BinaryContent | dict[str, Any] | list[Any]:
137+
# See https://github.com/jlowin/fastmcp/blob/main/docs/servers/tools.mdx#return-values
138+
139+
if isinstance(part, TextContent):
140+
text = part.text
141+
if text.startswith(('[', '{')):
142+
try:
143+
return json.loads(text)
144+
except ValueError:
145+
pass
146+
return text
147+
elif isinstance(part, ImageContent):
148+
return BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType)
149+
elif isinstance(part, EmbeddedResource):
150+
resource = part.resource
151+
if isinstance(resource, TextResourceContents):
152+
return resource.text
153+
elif isinstance(resource, BlobResourceContents):
154+
return BinaryContent(
155+
data=base64.b64decode(resource.blob),
156+
media_type=resource.mimeType or 'application/octet-stream',
157+
)
158+
else:
159+
assert_never(resource)
160+
else:
161+
assert_never(part)
162+
108163

109164
@dataclass
110165
class MCPServerStdio(MCPServer):

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ snap = ["create"]
243243

244244
[tool.codespell]
245245
# Ref: https://github.com/codespell-project/codespell#using-a-config-file
246-
skip = '.git*,*.svg,*.lock,*.css'
246+
skip = '.git*,*.svg,*.lock,*.css,*.yaml'
247247
check-hidden = true
248248
# Ignore "formatting" like **L**anguage
249249
ignore-regex = '\*\*[A-Z]\*\*[a-z]+\b'

0 commit comments

Comments
 (0)