diff --git a/src/mcp/server/fastmcp/tools/base.py b/src/mcp/server/fastmcp/tools/base.py index 92a216f5..1cad7bed 100644 --- a/src/mcp/server/fastmcp/tools/base.py +++ b/src/mcp/server/fastmcp/tools/base.py @@ -7,7 +7,11 @@ from pydantic import BaseModel, Field from mcp.server.fastmcp.exceptions import ToolError -from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata +from mcp.server.fastmcp.utilities.func_metadata import ( + FuncMetadata, + filter_args_by_arg_model, + func_metadata, +) if TYPE_CHECKING: from mcp.server.fastmcp.server import Context @@ -85,10 +89,15 @@ async def run( return await self.fn_metadata.call_fn_with_arg_validation( self.fn, self.is_async, - arguments, - {self.context_kwarg: context} - if self.context_kwarg is not None - else None, + filter_args_by_arg_model(arguments, self.fn_metadata.arg_model), + filter_args_by_arg_model( + arguments, self.fn_metadata.client_provided_arg_model + ) + | ( + {self.context_kwarg: context} + if self.context_kwarg is not None + else {} + ), ) except Exception as e: raise ToolError(f"Error executing tool {self.name}: {e}") from e diff --git a/src/mcp/server/fastmcp/utilities/func_metadata.py b/src/mcp/server/fastmcp/utilities/func_metadata.py index 37439132..860e8a3d 100644 --- a/src/mcp/server/fastmcp/utilities/func_metadata.py +++ b/src/mcp/server/fastmcp/utilities/func_metadata.py @@ -18,6 +18,14 @@ logger = get_logger(__name__) +class ClientProvidedArg: + """A class to annotate an argument that is to be provided by client at call + time and to be skipped from JSON schema generation.""" + + def __init__(self): + pass + + class ArgModelBase(BaseModel): """A model representing the arguments to a function.""" @@ -36,8 +44,42 @@ def model_dump_one_level(self) -> dict[str, Any]: ) +def filter_args_by_arg_model( + arguments: dict[str, Any], model_filter: type[ArgModelBase] | None = None +) -> dict[str, Any]: + """Filter the arguments dictionary to only include keys that are present in + `model_filter`.""" + if not model_filter: + return arguments + filtered_args: dict[str, Any] = {} + for key in arguments.keys(): + if key in model_filter.model_fields.keys(): + filtered_args[key] = arguments[key] + return filtered_args + + class FuncMetadata(BaseModel): + """Metadata about a function, including Pydantic models for argument validation. + + This class manages the arguments required by a function, separating them into two + categories: + + * `arg_model`: A Pydantic model representing the function's standard arguments. + These arguments will be included in the JSON schema when the tool is listed, + allowing for automatic argument parsing. This defines the structure of the + expected input. + + * `client_provided_arg_model` (Optional): A Pydantic model representing arguments + that need to be provided directly by the client and will not be included in the + JSON schema. + + """ + arg_model: Annotated[type[ArgModelBase], WithJsonSchema(None)] + client_provided_arg_model: ( + Annotated[type[ArgModelBase], WithJsonSchema(None)] | None + ) = None + # We can add things in the future like # - Maybe some args are excluded from attempting to parse from JSON # - Maybe some args are special (like context) for dependency injection @@ -127,7 +169,8 @@ def func_metadata( """ sig = _get_typed_signature(func) params = sig.parameters - dynamic_pydantic_model_params: dict[str, Any] = {} + dynamic_pydantic_arg_model_params: dict[str, Any] = {} + dynamic_pydantic_client_provided_arg_model_params: dict[str, Any] = {} globalns = getattr(func, "__globals__", {}) for param in params.values(): if param.name.startswith("_"): @@ -164,15 +207,34 @@ def func_metadata( if param.default is not inspect.Parameter.empty else PydanticUndefined, ) - dynamic_pydantic_model_params[param.name] = (field_info.annotation, field_info) - continue + + # loop through annotations, + # use ClientProvidedArg metadata to split the arguments + if any(isinstance(m, ClientProvidedArg) for m in field_info.metadata): + dynamic_pydantic_client_provided_arg_model_params[param.name] = ( + field_info.annotation, + field_info, + ) + else: + dynamic_pydantic_arg_model_params[param.name] = ( + field_info.annotation, + field_info, + ) arguments_model = create_model( f"{func.__name__}Arguments", - **dynamic_pydantic_model_params, + **dynamic_pydantic_arg_model_params, __base__=ArgModelBase, ) - resp = FuncMetadata(arg_model=arguments_model) + + provided_arguments_model = create_model( + f"{func.__name__}ClientProvidedArguments", + **dynamic_pydantic_client_provided_arg_model_params, + __base__=ArgModelBase, + ) + resp = FuncMetadata( + arg_model=arguments_model, client_provided_arg_model=provided_arguments_model + ) return resp diff --git a/tests/server/fastmcp/test_func_metadata.py b/tests/server/fastmcp/test_func_metadata.py index b1828ffe..5ceff4f2 100644 --- a/tests/server/fastmcp/test_func_metadata.py +++ b/tests/server/fastmcp/test_func_metadata.py @@ -4,7 +4,7 @@ import pytest from pydantic import BaseModel, Field -from mcp.server.fastmcp.utilities.func_metadata import func_metadata +from mcp.server.fastmcp.utilities.func_metadata import ClientProvidedArg, func_metadata class SomeInputModelA(BaseModel): @@ -414,3 +414,38 @@ def func_with_str_and_int(a: str, b: int): result = meta.pre_parse_json({"a": "123", "b": 123}) assert result["a"] == "123" assert result["b"] == 123 + + +def test_func_with_client_provided_args(): + """Test that client-provided arguments are correctly parsed and validated""" + + def func_with_client_provided_args( + a: int, + b: str, + c: Annotated[int, ClientProvidedArg()], + d: Annotated[str, ClientProvidedArg()], + ): + return a, b, c, d + + meta = func_metadata(func_with_client_provided_args) + + # Test schema + assert meta.arg_model.model_json_schema() == { + "properties": { + "a": {"title": "A", "type": "integer"}, + "b": {"title": "B", "type": "string"}, + }, + "required": ["a", "b"], + "title": "func_with_client_provided_argsArguments", + "type": "object", + } + assert meta.client_provided_arg_model is not None + assert meta.client_provided_arg_model.model_json_schema() == { + "properties": { + "c": {"title": "C", "type": "integer"}, + "d": {"title": "D", "type": "string"}, + }, + "required": ["c", "d"], + "title": "func_with_client_provided_argsClientProvidedArguments", + "type": "object", + } diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py index e76e59c5..40b38d30 100644 --- a/tests/server/fastmcp/test_server.py +++ b/tests/server/fastmcp/test_server.py @@ -1,6 +1,6 @@ import base64 from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Annotated import pytest from pydantic import AnyUrl @@ -8,6 +8,7 @@ from mcp.server.fastmcp import Context, FastMCP from mcp.server.fastmcp.prompts.base import EmbeddedResource, Message, UserMessage from mcp.server.fastmcp.resources import FileResource, FunctionResource +from mcp.server.fastmcp.utilities.func_metadata import ClientProvidedArg from mcp.server.fastmcp.utilities.types import Image from mcp.shared.exceptions import McpError from mcp.shared.memory import ( @@ -106,6 +107,12 @@ def tool_fn(x: int, y: int) -> int: return x + y +def tool_with_client_provided_args_fn( + x: int, y: Annotated[int, ClientProvidedArg()], z: str +) -> str: + return f"{x} + {y} = {z}" + + def error_tool_fn() -> None: raise ValueError("Test error") @@ -129,6 +136,13 @@ async def test_add_tool(self): mcp.add_tool(tool_fn) assert len(mcp._tool_manager.list_tools()) == 1 + @pytest.mark.anyio + async def test_add_tool_with_client_provided_arg(self): + mcp = FastMCP() + mcp.add_tool(tool_fn) + mcp.add_tool(tool_with_client_provided_args_fn) + assert len(mcp._tool_manager.list_tools()) == 2 + @pytest.mark.anyio async def test_list_tools(self): mcp = FastMCP() diff --git a/tests/server/fastmcp/test_tool_manager.py b/tests/server/fastmcp/test_tool_manager.py index 8f52e3d8..03f10bac 100644 --- a/tests/server/fastmcp/test_tool_manager.py +++ b/tests/server/fastmcp/test_tool_manager.py @@ -1,5 +1,6 @@ import json import logging +from typing import Annotated import pytest from pydantic import BaseModel @@ -7,6 +8,7 @@ from mcp.server.fastmcp import Context, FastMCP from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.tools import ToolManager +from mcp.server.fastmcp.utilities.func_metadata import ClientProvidedArg from mcp.server.session import ServerSessionT from mcp.shared.context import LifespanContextT @@ -220,6 +222,31 @@ def name_shrimp(tank: MyShrimpTank, ctx: Context) -> list[str]: ) assert result == ["rex", "gertrude"] + @pytest.mark.anyio + async def test_call_tool_with_client_provided_arg(self): + class ClientArgModel(BaseModel): + arg1: int + arg2: str + + def afunc(args: Annotated[dict, ClientProvidedArg()], ctx: Context) -> str: + args_obj = ClientArgModel.model_validate(args) + return f"{args_obj.arg1} {args_obj.arg2}" + + manager = ToolManager() + manager.add_tool(afunc) + result = await manager.call_tool( + "afunc", + {"args": {"arg1": 3, "arg2": "apple"}}, + ) + assert result == "3 apple" + + with pytest.raises(ToolError): + # Raises an error because it misses the required args + result = await manager.call_tool( + "afunc", + {}, + ) + class TestToolSchema: @pytest.mark.anyio @@ -233,6 +260,18 @@ def something(a: int, ctx: Context) -> int: assert "Context" not in json.dumps(tool.parameters) assert "ctx" not in tool.fn_metadata.arg_model.model_fields + @pytest.mark.anyio + async def test_client_provided_arg_excluded_from_schema(self): + def something(a: int, b: Annotated[int, ClientProvidedArg()]) -> int: + return a + b + + manager = ToolManager() + tool = manager.add_tool(something) + assert "properties" in tool.parameters + assert "b" not in tool.parameters["properties"] + assert tool.fn_metadata.client_provided_arg_model is not None + assert "b" in tool.fn_metadata.client_provided_arg_model.model_fields + class TestContextHandling: """Test context handling in the tool manager."""