Skip to content

Allow function argument to be excluded from the tool's JSON schema #574

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions src/mcp/server/fastmcp/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
from pydantic import BaseModel, Field

from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
from mcp.server.fastmcp.utilities.func_metadata import (
FuncMetadata,
filter_args_by_arg_model,
func_metadata,
)

if TYPE_CHECKING:
from mcp.server.fastmcp.server import Context
Expand Down Expand Up @@ -85,10 +89,15 @@ async def run(
return await self.fn_metadata.call_fn_with_arg_validation(
self.fn,
self.is_async,
arguments,
{self.context_kwarg: context}
if self.context_kwarg is not None
else None,
filter_args_by_arg_model(arguments, self.fn_metadata.arg_model),
filter_args_by_arg_model(
arguments, self.fn_metadata.client_provided_arg_model
)
| (
{self.context_kwarg: context}
if self.context_kwarg is not None
else {}
),
)
except Exception as e:
raise ToolError(f"Error executing tool {self.name}: {e}") from e
72 changes: 67 additions & 5 deletions src/mcp/server/fastmcp/utilities/func_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@
logger = get_logger(__name__)


class ClientProvidedArg:
"""A class to annotate an argument that is to be provided by client at call
time and to be skipped from JSON schema generation."""

def __init__(self):
pass


class ArgModelBase(BaseModel):
"""A model representing the arguments to a function."""

Expand All @@ -36,8 +44,42 @@ def model_dump_one_level(self) -> dict[str, Any]:
)


def filter_args_by_arg_model(
arguments: dict[str, Any], model_filter: type[ArgModelBase] | None = None
) -> dict[str, Any]:
"""Filter the arguments dictionary to only include keys that are present in
`model_filter`."""
if not model_filter:
return arguments
filtered_args: dict[str, Any] = {}
for key in arguments.keys():
if key in model_filter.model_fields.keys():
filtered_args[key] = arguments[key]
return filtered_args


class FuncMetadata(BaseModel):
"""Metadata about a function, including Pydantic models for argument validation.

This class manages the arguments required by a function, separating them into two
categories:

* `arg_model`: A Pydantic model representing the function's standard arguments.
These arguments will be included in the JSON schema when the tool is listed,
allowing for automatic argument parsing. This defines the structure of the
expected input.

* `client_provided_arg_model` (Optional): A Pydantic model representing arguments
that need to be provided directly by the client and will not be included in the
JSON schema.

"""

arg_model: Annotated[type[ArgModelBase], WithJsonSchema(None)]
client_provided_arg_model: (
Annotated[type[ArgModelBase], WithJsonSchema(None)] | None
) = None

# We can add things in the future like
# - Maybe some args are excluded from attempting to parse from JSON
# - Maybe some args are special (like context) for dependency injection
Expand Down Expand Up @@ -127,7 +169,8 @@ def func_metadata(
"""
sig = _get_typed_signature(func)
params = sig.parameters
dynamic_pydantic_model_params: dict[str, Any] = {}
dynamic_pydantic_arg_model_params: dict[str, Any] = {}
dynamic_pydantic_client_provided_arg_model_params: dict[str, Any] = {}
globalns = getattr(func, "__globals__", {})
for param in params.values():
if param.name.startswith("_"):
Expand Down Expand Up @@ -164,15 +207,34 @@ def func_metadata(
if param.default is not inspect.Parameter.empty
else PydanticUndefined,
)
dynamic_pydantic_model_params[param.name] = (field_info.annotation, field_info)
continue

# loop through annotations,
# use ClientProvidedArg metadata to split the arguments
if any(isinstance(m, ClientProvidedArg) for m in field_info.metadata):
dynamic_pydantic_client_provided_arg_model_params[param.name] = (
field_info.annotation,
field_info,
)
else:
dynamic_pydantic_arg_model_params[param.name] = (
field_info.annotation,
field_info,
)

arguments_model = create_model(
f"{func.__name__}Arguments",
**dynamic_pydantic_model_params,
**dynamic_pydantic_arg_model_params,
__base__=ArgModelBase,
)
resp = FuncMetadata(arg_model=arguments_model)

provided_arguments_model = create_model(
f"{func.__name__}ClientProvidedArguments",
**dynamic_pydantic_client_provided_arg_model_params,
__base__=ArgModelBase,
)
resp = FuncMetadata(
arg_model=arguments_model, client_provided_arg_model=provided_arguments_model
)
return resp


Expand Down
37 changes: 36 additions & 1 deletion tests/server/fastmcp/test_func_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from pydantic import BaseModel, Field

from mcp.server.fastmcp.utilities.func_metadata import func_metadata
from mcp.server.fastmcp.utilities.func_metadata import ClientProvidedArg, func_metadata


class SomeInputModelA(BaseModel):
Expand Down Expand Up @@ -414,3 +414,38 @@ def func_with_str_and_int(a: str, b: int):
result = meta.pre_parse_json({"a": "123", "b": 123})
assert result["a"] == "123"
assert result["b"] == 123


def test_func_with_client_provided_args():
"""Test that client-provided arguments are correctly parsed and validated"""

def func_with_client_provided_args(
a: int,
b: str,
c: Annotated[int, ClientProvidedArg()],
d: Annotated[str, ClientProvidedArg()],
):
return a, b, c, d

meta = func_metadata(func_with_client_provided_args)

# Test schema
assert meta.arg_model.model_json_schema() == {
"properties": {
"a": {"title": "A", "type": "integer"},
"b": {"title": "B", "type": "string"},
},
"required": ["a", "b"],
"title": "func_with_client_provided_argsArguments",
"type": "object",
}
assert meta.client_provided_arg_model is not None
assert meta.client_provided_arg_model.model_json_schema() == {
"properties": {
"c": {"title": "C", "type": "integer"},
"d": {"title": "D", "type": "string"},
},
"required": ["c", "d"],
"title": "func_with_client_provided_argsClientProvidedArguments",
"type": "object",
}
16 changes: 15 additions & 1 deletion tests/server/fastmcp/test_server.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import base64
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Annotated

import pytest
from pydantic import AnyUrl

from mcp.server.fastmcp import Context, FastMCP
from mcp.server.fastmcp.prompts.base import EmbeddedResource, Message, UserMessage
from mcp.server.fastmcp.resources import FileResource, FunctionResource
from mcp.server.fastmcp.utilities.func_metadata import ClientProvidedArg
from mcp.server.fastmcp.utilities.types import Image
from mcp.shared.exceptions import McpError
from mcp.shared.memory import (
Expand Down Expand Up @@ -106,6 +107,12 @@ def tool_fn(x: int, y: int) -> int:
return x + y


def tool_with_client_provided_args_fn(
x: int, y: Annotated[int, ClientProvidedArg()], z: str
) -> str:
return f"{x} + {y} = {z}"


def error_tool_fn() -> None:
raise ValueError("Test error")

Expand All @@ -129,6 +136,13 @@ async def test_add_tool(self):
mcp.add_tool(tool_fn)
assert len(mcp._tool_manager.list_tools()) == 1

@pytest.mark.anyio
async def test_add_tool_with_client_provided_arg(self):
mcp = FastMCP()
mcp.add_tool(tool_fn)
mcp.add_tool(tool_with_client_provided_args_fn)
assert len(mcp._tool_manager.list_tools()) == 2

@pytest.mark.anyio
async def test_list_tools(self):
mcp = FastMCP()
Expand Down
39 changes: 39 additions & 0 deletions tests/server/fastmcp/test_tool_manager.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import json
import logging
from typing import Annotated

import pytest
from pydantic import BaseModel

from mcp.server.fastmcp import Context, FastMCP
from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.tools import ToolManager
from mcp.server.fastmcp.utilities.func_metadata import ClientProvidedArg
from mcp.server.session import ServerSessionT
from mcp.shared.context import LifespanContextT

Expand Down Expand Up @@ -220,6 +222,31 @@ def name_shrimp(tank: MyShrimpTank, ctx: Context) -> list[str]:
)
assert result == ["rex", "gertrude"]

@pytest.mark.anyio
async def test_call_tool_with_client_provided_arg(self):
class ClientArgModel(BaseModel):
arg1: int
arg2: str

def afunc(args: Annotated[dict, ClientProvidedArg()], ctx: Context) -> str:
args_obj = ClientArgModel.model_validate(args)
return f"{args_obj.arg1} {args_obj.arg2}"

manager = ToolManager()
manager.add_tool(afunc)
result = await manager.call_tool(
"afunc",
{"args": {"arg1": 3, "arg2": "apple"}},
)
assert result == "3 apple"

with pytest.raises(ToolError):
# Raises an error because it misses the required args
result = await manager.call_tool(
"afunc",
{},
)


class TestToolSchema:
@pytest.mark.anyio
Expand All @@ -233,6 +260,18 @@ def something(a: int, ctx: Context) -> int:
assert "Context" not in json.dumps(tool.parameters)
assert "ctx" not in tool.fn_metadata.arg_model.model_fields

@pytest.mark.anyio
async def test_client_provided_arg_excluded_from_schema(self):
def something(a: int, b: Annotated[int, ClientProvidedArg()]) -> int:
return a + b

manager = ToolManager()
tool = manager.add_tool(something)
assert "properties" in tool.parameters
assert "b" not in tool.parameters["properties"]
assert tool.fn_metadata.client_provided_arg_model is not None
assert "b" in tool.fn_metadata.client_provided_arg_model.model_fields


class TestContextHandling:
"""Test context handling in the tool manager."""
Expand Down
Loading