diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index 778a0f26f2..4d4f6a3838 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -254,12 +254,63 @@ async def direct_call_tool( ): # The MCP SDK wraps primitives and generic types like list in a `result` key, but we want to use the raw value returned by the tool function. # See https://github.com/modelcontextprotocol/python-sdk#structured-output - if isinstance(structured, dict) and len(structured) == 1 and 'result' in structured: - return structured['result'] - return structured - - mapped = [await self._map_tool_result_part(part) for part in result.content] - return mapped[0] if len(mapped) == 1 else mapped + return_value = ( + structured['result'] + if isinstance(structured, dict) and len(structured) == 1 and 'result' in structured + else structured + ) + return messages.ToolReturn(return_value=return_value, metadata=result.meta) if result.meta else return_value + + parts_with_metadata = [await self._map_tool_result_part(part) for part in result.content] + parts_only = [mapped_part for mapped_part, _ in parts_with_metadata] + any_part_has_metadata = any(metadata is not None for _, metadata in parts_with_metadata) + if not any_part_has_metadata and result.meta is None: + # There is no metadata in the tool result or its parts, return just the mapped values + return parts_only[0] if len(parts_only) == 1 else parts_only + elif not any_part_has_metadata and result.meta is not None: + # There is no metadata in the tool result parts, but there is metadata in the tool result + return messages.ToolReturn( + return_value=(parts_only[0] if len(parts_only) == 1 else parts_only), + metadata=result.meta, + ) + else: + # There is metadata in the tool result parts and there may be a metadata in the tool result, return a ToolReturn object + return_values: list[Any] = [] + user_contents: list[Any] = [] + return_metadata: dict[str, Any] = {} + return_metadata.setdefault('content', []) + for idx, (mapped_part, part_metadata) in enumerate(parts_with_metadata): + if part_metadata is not None: + # Merge the metadata dictionaries, with part metadata taking precedence + return_metadata['content'].append({str(idx): part_metadata}) + if isinstance(mapped_part, messages.BinaryContent): + identifier = mapped_part.identifier + + return_values.append(f'See file {identifier}') + user_contents.append([f'This is file {identifier}:', mapped_part]) + else: + user_contents.append(mapped_part) + + if result.meta is not None and return_metadata.get('content', None) is not None: + # Merge the tool result metadata into the return metadata, with part metadata taking precedence + return_metadata['result'] = result.meta + elif result.meta is not None and return_metadata.get('content', None) is None: + return_metadata = result.meta + elif ( + result.meta is None + and return_metadata.get('content', None) is not None + and len(return_metadata['content']) == 1 + ): + # If there is only one content metadata, unwrap it + return_metadata = return_metadata['content'][0] + # TODO: What else should we cover here? + + # Finally, construct and return the ToolReturn object + return messages.ToolReturn( + return_value=return_values, + content=user_contents, + metadata=return_metadata, + ) async def call_tool( self, @@ -374,35 +425,32 @@ async def _sampling_callback( async def _map_tool_result_part( self, part: mcp_types.ContentBlock - ) -> str | messages.BinaryContent | dict[str, Any] | list[Any]: + ) -> tuple[str | messages.BinaryContent | dict[str, Any] | list[Any], dict[str, Any] | None]: # See https://github.com/jlowin/fastmcp/blob/main/docs/servers/tools.mdx#return-values + metadata: dict[str, Any] | None = part.meta if isinstance(part, mcp_types.TextContent): text = part.text if text.startswith(('[', '{')): try: - return pydantic_core.from_json(text) + return pydantic_core.from_json(text), metadata except ValueError: pass - return text + return text, metadata elif isinstance(part, mcp_types.ImageContent): - return messages.BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType) - elif isinstance(part, mcp_types.AudioContent): + return messages.BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType), metadata + elif isinstance(part, mcp_types.AudioContent): # pragma: no cover # NOTE: The FastMCP server doesn't support audio content. # See for more details. - return messages.BinaryContent( - data=base64.b64decode(part.data), media_type=part.mimeType - ) # pragma: no cover + return messages.BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType), metadata elif isinstance(part, mcp_types.EmbeddedResource): - resource = part.resource - return self._get_content(resource) + return self._get_content(part.resource), metadata elif isinstance(part, mcp_types.ResourceLink): resource_result: mcp_types.ReadResourceResult = await self._client.read_resource(part.uri) - return ( - self._get_content(resource_result.contents[0]) - if len(resource_result.contents) == 1 - else [self._get_content(resource) for resource in resource_result.contents] - ) + if len(resource_result.contents) == 1: + return self._get_content(resource_result.contents[0]), metadata + else: + return [self._get_content(resource) for resource in resource_result.contents], metadata else: assert_never(part) @@ -875,6 +923,7 @@ def __eq__(self, value: object, /) -> bool: ToolResult = ( str | messages.BinaryContent + | messages.ToolReturn | dict[str, Any] | list[Any] | Sequence[str | messages.BinaryContent | dict[str, Any] | list[Any]] diff --git a/tests/mcp_server.py b/tests/mcp_server.py index 54b105ab29..80ba0d806d 100644 --- a/tests/mcp_server.py +++ b/tests/mcp_server.py @@ -68,6 +68,45 @@ async def get_image_resource_link() -> ResourceLink: ) +@mcp.tool(structured_output=False, annotations=ToolAnnotations(title='Collatz Conjecture sequence generator')) +async def get_collatz_conjecture(n: int) -> TextContent: + """Generate the Collatz conjecture sequence for a given number. + This tool attaches response metadata. + + Args: + n: The starting number for the Collatz sequence. + Returns: + A list representing the Collatz sequence with attached metadata. + """ + if n <= 0: + raise ValueError('Starting number for the Collatz conjecture must be a positive integer.') + + input_param_n = n # store the original input value + + sequence = [n] + while n != 1: + if n % 2 == 0: + n = n // 2 + else: + n = 3 * n + 1 + sequence.append(n) + + return TextContent( + type='text', + text=str(sequence), + _meta={'pydantic_ai': {'tool': 'collatz_conjecture', 'n': input_param_n, 'length': len(sequence)}}, + ) + + +@mcp.tool() +async def get_structured_text_content_with_metadata() -> dict[str, Any]: + """Return structured dict with metadata.""" + return { + 'result': 'This is some text content.', + '_meta': {'pydantic_ai': {'source': 'get_structured_text_content_with_metadata'}}, + } + + @mcp.resource('resource://kiwi.png', mime_type='image/png') async def kiwi_resource() -> bytes: return Path(__file__).parent.joinpath('assets/kiwi.png').read_bytes() diff --git a/tests/test_mcp.py b/tests/test_mcp.py index fc7a8e5dc2..e1dadee428 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -26,6 +26,7 @@ from pydantic_ai.agent import Agent from pydantic_ai.exceptions import ModelRetry, UnexpectedModelBehavior, UserError from pydantic_ai.mcp import MCPServerStreamableHTTP, load_mcp_servers +from pydantic_ai.messages import ToolReturn from pydantic_ai.models import Model from pydantic_ai.models.test import TestModel from pydantic_ai.tools import RunContext @@ -77,7 +78,7 @@ async def test_stdio_server(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) async with server: tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()] - assert len(tools) == snapshot(18) + assert len(tools) == snapshot(20) assert tools[0].name == 'celsius_to_fahrenheit' assert isinstance(tools[0].description, str) assert tools[0].description.startswith('Convert Celsius to Fahrenheit.') @@ -87,6 +88,36 @@ async def test_stdio_server(run_context: RunContext[int]): assert result == snapshot(32.0) +async def test_tool_response_metadata(run_context: RunContext[int]): + server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + async with server: + tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()] + assert len(tools) == snapshot(20) + assert tools[4].name == 'get_collatz_conjecture' + assert isinstance(tools[4].description, str) + assert tools[4].description.startswith('Generate the Collatz conjecture sequence for a given number.') + + result = await server.direct_call_tool('get_collatz_conjecture', {'n': 7}) + assert isinstance(result, ToolReturn) + assert result.return_value == snapshot([7, 22, 11, 34, 17, 52, 26, 13, 40, 20, 10, 5, 16, 8, 4, 2, 1]) + assert result.metadata == snapshot({'pydantic_ai': {'tool': 'collatz_conjecture', 'n': 7, 'length': 17}}) + + +async def test_tool_structured_response_metadata(run_context: RunContext[int]): + server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + async with server: + tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()] + assert len(tools) == snapshot(20) + assert tools[5].name == 'get_structured_text_content_with_metadata' + assert isinstance(tools[5].description, str) + assert tools[5].description.startswith('Return structured dict with metadata.') + + result = await server.direct_call_tool('get_structured_text_content_with_metadata', {}) + assert isinstance(result, ToolReturn) + assert result.return_value == 'This is some text content.' + assert result.metadata == snapshot({'pydantic_ai': {'source': 'get_structured_text_content_with_metadata'}}) + + async def test_reentrant_context_manager(): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) async with server: @@ -138,7 +169,7 @@ async def test_stdio_server_with_cwd(run_context: RunContext[int]): server = MCPServerStdio('python', ['mcp_server.py'], cwd=test_dir) async with server: tools = await server.get_tools(run_context) - assert len(tools) == snapshot(18) + assert len(tools) == snapshot(20) async def test_process_tool_call(run_context: RunContext[int]) -> int: