From 04c4a727dacf42609032a13cb00c6a967ea21ed3 Mon Sep 17 00:00:00 2001 From: David Savage Date: Wed, 25 Jun 2025 19:38:17 +0000 Subject: [PATCH 1/6] add authorizer plugin to enable fine granded authorization checks on tools/resources/prompts --- src/mcp/server/fastmcp/authorizer.py | 107 ++++++++++++++++++ src/mcp/server/fastmcp/prompts/manager.py | 21 +++- .../fastmcp/resources/resource_manager.py | 22 +++- src/mcp/server/fastmcp/server.py | 19 +++- src/mcp/server/fastmcp/tools/tool_manager.py | 16 ++- tests/server/fastmcp/test_tool_manager.py | 27 ++++- 6 files changed, 193 insertions(+), 19 deletions(-) create mode 100644 src/mcp/server/fastmcp/authorizer.py diff --git a/src/mcp/server/fastmcp/authorizer.py b/src/mcp/server/fastmcp/authorizer.py new file mode 100644 index 000000000..4973b3bd8 --- /dev/null +++ b/src/mcp/server/fastmcp/authorizer.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import abc +from typing import TYPE_CHECKING, Any +from pydantic import AnyUrl + +from mcp.shared.context import LifespanContextT, RequestT + +if TYPE_CHECKING: + from mcp.server.fastmcp.server import Context + from mcp.server.session import ServerSessionT + + +class Authorizer: + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def permit_get_tool(self, name: str) -> bool: + """Check if the specified tool can be retrieved from the associated mcp server""" + return False + + @abc.abstractmethod + def permit_list_tool(self, name: str) -> bool: + """Check if the specified tool can be listed from the associated mcp server""" + return False + + @abc.abstractmethod + def permit_call_tool( + self, + name: str, + arguments: dict[str, Any], + context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, + ) -> bool: + """Check if the specified tool can be called from the associated mcp server""" + return False + + @abc.abstractmethod + def permit_get_resource(self, resource: AnyUrl | str) -> bool: + """Check if the specified resource can be retrieved from the associated mcp server""" + return False + + @abc.abstractmethod + def permit_create_resource(self, uri: str, params: dict[str, Any]) -> bool: + """Check if the specified resource can be created on the associated mcp server""" + return False + + @abc.abstractmethod + def permit_list_resource(self, resource: AnyUrl | str) -> bool: + """Check if the specified resource can be listed from the associated mcp server""" + return False + + @abc.abstractmethod + def permit_list_template(self, resource: AnyUrl | str) -> bool: + """Check if the specified template can be listed from the associated mcp server""" + return False + + @abc.abstractmethod + def permit_get_prompt(self, name: str) -> bool: + """Check if the specified prompt can be retrieved from the associated mcp server""" + return False + + @abc.abstractmethod + def permit_list_prompt(self, name: str) -> bool: + """Check if the specified prompt can be listed from the associated mcp server""" + return False + + @abc.abstractmethod + def permit_render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> bool: + """Check if the specified prompt can be rendered from the associated mcp server""" + return False + +class AllAllAuthorizer(Authorizer): + def permit_get_tool(self, name: str) -> bool: + return True + + def permit_list_tool(self, name: str) -> bool: + return True + + def permit_call_tool( + self, + name: str, + arguments: dict[str, Any], + context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, + ) -> bool: + return True + + def permit_get_resource(self, resource: AnyUrl | str) -> bool: + return True + + def permit_create_resource(self, uri: str, params: dict[str, Any]) -> bool: + return True + + def permit_list_resource(self, resource: AnyUrl | str) -> bool: + return True + + def permit_list_template(self, resource: AnyUrl | str) -> bool: + return True + + def permit_get_prompt(self, name: str) -> bool: + return True + + def permit_list_prompt(self, name: str) -> bool: + return True + + def permit_render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> bool: + return True + diff --git a/src/mcp/server/fastmcp/prompts/manager.py b/src/mcp/server/fastmcp/prompts/manager.py index 6b01d91cd..1f7d382f5 100644 --- a/src/mcp/server/fastmcp/prompts/manager.py +++ b/src/mcp/server/fastmcp/prompts/manager.py @@ -2,6 +2,7 @@ from typing import Any +from mcp.server.fastmcp.authorizer import AllAllAuthorizer, Authorizer from mcp.server.fastmcp.prompts.base import Message, Prompt from mcp.server.fastmcp.utilities.logging import get_logger @@ -11,17 +12,25 @@ class PromptManager: """Manages FastMCP prompts.""" - def __init__(self, warn_on_duplicate_prompts: bool = True): + def __init__( + self, + warn_on_duplicate_prompts: bool = True, + authorizer: Authorizer = AllAllAuthorizer(), + ): self._prompts: dict[str, Prompt] = {} + self._authorizer = authorizer self.warn_on_duplicate_prompts = warn_on_duplicate_prompts def get_prompt(self, name: str) -> Prompt | None: """Get prompt by name.""" - return self._prompts.get(name) + if self._authorizer.permit_get_prompt(name): + return self._prompts.get(name) + else: + return None def list_prompts(self) -> list[Prompt]: """List all registered prompts.""" - return list(self._prompts.values()) + return [prompt for name, prompt in self._prompts.items() if self._authorizer.permit_list_prompt(name)] def add_prompt( self, @@ -44,5 +53,7 @@ async def render_prompt(self, name: str, arguments: dict[str, Any] | None = None prompt = self.get_prompt(name) if not prompt: raise ValueError(f"Unknown prompt: {name}") - - return await prompt.render(arguments) + if self._authorizer.permit_render_prompt(name, arguments): + return await prompt.render(arguments) + else: + raise ValueError(f"Unknown prompt: {name}") diff --git a/src/mcp/server/fastmcp/resources/resource_manager.py b/src/mcp/server/fastmcp/resources/resource_manager.py index 35e4ec04d..19f525b60 100644 --- a/src/mcp/server/fastmcp/resources/resource_manager.py +++ b/src/mcp/server/fastmcp/resources/resource_manager.py @@ -5,6 +5,7 @@ from pydantic import AnyUrl +from mcp.server.fastmcp.authorizer import AllAllAuthorizer, Authorizer from mcp.server.fastmcp.resources.base import Resource from mcp.server.fastmcp.resources.templates import ResourceTemplate from mcp.server.fastmcp.utilities.logging import get_logger @@ -15,10 +16,15 @@ class ResourceManager: """Manages FastMCP resources.""" - def __init__(self, warn_on_duplicate_resources: bool = True): + def __init__( + self, + warn_on_duplicate_resources: bool = True, + authorizer: Authorizer = AllAllAuthorizer(), + ): self._resources: dict[str, Resource] = {} self._templates: dict[str, ResourceTemplate] = {} self.warn_on_duplicate_resources = warn_on_duplicate_resources + self._authorizer = authorizer def add_resource(self, resource: Resource) -> Resource: """Add a resource to the manager. @@ -74,13 +80,19 @@ async def get_resource(self, uri: AnyUrl | str) -> Resource | None: # First check concrete resources if resource := self._resources.get(uri_str): - return resource + if self._authorizer.permit_get_resource(uri_str): + return resource + else: + raise ValueError(f"Unknown resource: {uri}") # Then check templates for template in self._templates.values(): if params := template.matches(uri_str): try: - return await template.create_resource(uri_str, params) + if self._authorizer.permit_create_resource(uri_str, params): + return await template.create_resource(uri_str, params) + else: + raise ValueError(f"Unknown resource: {uri}") except Exception as e: raise ValueError(f"Error creating resource from template: {e}") @@ -89,9 +101,9 @@ async def get_resource(self, uri: AnyUrl | str) -> Resource | None: def list_resources(self) -> list[Resource]: """List all registered resources.""" logger.debug("Listing resources", extra={"count": len(self._resources)}) - return list(self._resources.values()) + return [resource for uri, resource in self._resources.items() if self._authorizer.permit_list_resource(uri)] def list_templates(self) -> list[ResourceTemplate]: """List all registered templates.""" logger.debug("Listing templates", extra={"count": len(self._templates)}) - return list(self._templates.values()) + return [template for uri, template in self._templates.items() if self._authorizer.permit_list_template(uri)] diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 668c6df82..2af68025a 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -33,6 +33,7 @@ from mcp.server.auth.provider import OAuthAuthorizationServerProvider, ProviderTokenVerifier, TokenVerifier from mcp.server.auth.settings import AuthSettings from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, elicit_with_validation +from mcp.server.fastmcp.authorizer import AllAllAuthorizer, Authorizer from mcp.server.fastmcp.exceptions import ResourceError from mcp.server.fastmcp.prompts import Prompt, PromptManager from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager @@ -120,6 +121,8 @@ class Settings(BaseSettings, Generic[LifespanResultT]): # Transport security settings (DNS rebinding protection) transport_security: TransportSecuritySettings | None = None + authorizer: Authorizer = AllAllAuthorizer() + def lifespan_wrapper( app: FastMCP, @@ -152,9 +155,19 @@ def __init__( instructions=instructions, lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan), ) - self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) - self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) - self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts) + self._tool_manager = ToolManager( + tools=tools, + warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools, + authorizer=self.settings.authorizer, + ) + self._resource_manager = ResourceManager( + warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources, + authorizer=self.settings.authorizer, + ) + self._prompt_manager = PromptManager( + warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts, + authorizer=self.settings.authorizer, + ) # Validate auth configuration if self.settings.auth is not None: if auth_server_provider and token_verifier: diff --git a/src/mcp/server/fastmcp/tools/tool_manager.py b/src/mcp/server/fastmcp/tools/tool_manager.py index b9ca1655d..caed79032 100644 --- a/src/mcp/server/fastmcp/tools/tool_manager.py +++ b/src/mcp/server/fastmcp/tools/tool_manager.py @@ -3,6 +3,7 @@ from collections.abc import Callable from typing import TYPE_CHECKING, Any +from mcp.server.fastmcp.authorizer import AllAllAuthorizer, Authorizer from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.tools.base import Tool from mcp.server.fastmcp.utilities.logging import get_logger @@ -24,6 +25,7 @@ def __init__( warn_on_duplicate_tools: bool = True, *, tools: list[Tool] | None = None, + authorizer: Authorizer = AllAllAuthorizer(), ): self._tools: dict[str, Tool] = {} if tools is not None: @@ -32,15 +34,19 @@ def __init__( logger.warning(f"Tool already exists: {tool.name}") self._tools[tool.name] = tool - self.warn_on_duplicate_tools = warn_on_duplicate_tools + self.warn_on_duplicate_tools = (warn_on_duplicate_tools,) + self._authorizer = authorizer def get_tool(self, name: str) -> Tool | None: """Get tool by name.""" - return self._tools.get(name) + if self._authorizer.permit_get_tool(name): + return self._tools.get(name) + else: + return None def list_tools(self) -> list[Tool]: """List all registered tools.""" - return list(self._tools.values()) + return [tool for name, tool in self._tools.items() if self._authorizer.permit_list_tool(name)] def add_tool( self, @@ -67,8 +73,8 @@ async def call_tool( context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, ) -> Any: """Call a tool by name with arguments.""" - tool = self.get_tool(name) - if not tool: + tool = self._tools.get(name) + if not tool or not self._authorizer.permit_call_tool(name, arguments, context): raise ToolError(f"Unknown tool: {name}") return await tool.run(arguments, context=context) diff --git a/tests/server/fastmcp/test_tool_manager.py b/tests/server/fastmcp/test_tool_manager.py index 206df42d7..76aa7e5ad 100644 --- a/tests/server/fastmcp/test_tool_manager.py +++ b/tests/server/fastmcp/test_tool_manager.py @@ -5,6 +5,7 @@ from pydantic import BaseModel from mcp.server.fastmcp import Context, FastMCP +from mcp.server.fastmcp.authorizer import Authorizer from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.tools import Tool, ToolManager from mcp.server.fastmcp.utilities.func_metadata import ArgModelBase, FuncMetadata @@ -171,7 +172,7 @@ def f(x: int) -> int: manager = ToolManager() manager.add_tool(f) - manager.warn_on_duplicate_tools = False + manager.warn_on_duplicate_tools = False # type: ignore with caplog.at_level(logging.WARNING): manager.add_tool(f) assert "Tool already exists: f" not in caplog.text @@ -311,6 +312,30 @@ def name_shrimp(tank: MyShrimpTank, ctx: Context) -> list[str]: ) assert result == ["rex", "gertrude"] + @pytest.mark.anyio + async def test_call_tool_not_permitted(self): + async def double(n: int) -> int: + """Double a number.""" + return n * 2 + + class TestAuthorizer(Authorizer): + allow: bool = True + + def permit_list_tool(self, name): + return self.allow + + def permit_call_tool(self, name, arguments, context=None): + return self.allow + + authorizer = TestAuthorizer() + manager = ToolManager(authorizer=authorizer) + manager.add_tool(double) + result = await manager.call_tool("double", {"n": 5}) + assert result == 10 + authorizer.allow = False + with pytest.raises(ToolError, match="Unknown tool: double"): + await manager.call_tool("double", {"n": 5}) + class TestToolSchema: @pytest.mark.anyio From 55581f627feabdcefca534c64312e84f9a460aea Mon Sep 17 00:00:00 2001 From: David Savage Date: Wed, 25 Jun 2025 19:41:34 +0000 Subject: [PATCH 2/6] ruff check fix --- src/mcp/server/fastmcp/authorizer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/mcp/server/fastmcp/authorizer.py b/src/mcp/server/fastmcp/authorizer.py index 4973b3bd8..7e6a19b9d 100644 --- a/src/mcp/server/fastmcp/authorizer.py +++ b/src/mcp/server/fastmcp/authorizer.py @@ -2,6 +2,7 @@ import abc from typing import TYPE_CHECKING, Any + from pydantic import AnyUrl from mcp.shared.context import LifespanContextT, RequestT From a80d47f006df2868d03fefbf0859a7c8513e52ac Mon Sep 17 00:00:00 2001 From: David Savage Date: Wed, 25 Jun 2025 19:42:02 +0000 Subject: [PATCH 3/6] ruff format --- src/mcp/server/fastmcp/authorizer.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/mcp/server/fastmcp/authorizer.py b/src/mcp/server/fastmcp/authorizer.py index 7e6a19b9d..ed373d77f 100644 --- a/src/mcp/server/fastmcp/authorizer.py +++ b/src/mcp/server/fastmcp/authorizer.py @@ -19,7 +19,7 @@ class Authorizer: def permit_get_tool(self, name: str) -> bool: """Check if the specified tool can be retrieved from the associated mcp server""" return False - + @abc.abstractmethod def permit_list_tool(self, name: str) -> bool: """Check if the specified tool can be listed from the associated mcp server""" @@ -54,26 +54,27 @@ def permit_list_resource(self, resource: AnyUrl | str) -> bool: def permit_list_template(self, resource: AnyUrl | str) -> bool: """Check if the specified template can be listed from the associated mcp server""" return False - + @abc.abstractmethod def permit_get_prompt(self, name: str) -> bool: """Check if the specified prompt can be retrieved from the associated mcp server""" return False - + @abc.abstractmethod def permit_list_prompt(self, name: str) -> bool: """Check if the specified prompt can be listed from the associated mcp server""" return False - + @abc.abstractmethod - def permit_render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> bool: + def permit_render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> bool: """Check if the specified prompt can be rendered from the associated mcp server""" return False - + + class AllAllAuthorizer(Authorizer): def permit_get_tool(self, name: str) -> bool: return True - + def permit_list_tool(self, name: str) -> bool: return True @@ -90,19 +91,18 @@ def permit_get_resource(self, resource: AnyUrl | str) -> bool: def permit_create_resource(self, uri: str, params: dict[str, Any]) -> bool: return True - + def permit_list_resource(self, resource: AnyUrl | str) -> bool: return True def permit_list_template(self, resource: AnyUrl | str) -> bool: return True - + def permit_get_prompt(self, name: str) -> bool: return True - + def permit_list_prompt(self, name: str) -> bool: return True - def permit_render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> bool: + def permit_render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> bool: return True - From 092d769f3455e8f24641680e60a07d3fb079ce96 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sat, 28 Jun 2025 09:33:15 +0000 Subject: [PATCH 4/6] pass context to all permit methods to allow checking other request params such as http headers --- src/mcp/server/fastmcp/authorizer.py | 74 +++++++++++++------ src/mcp/server/fastmcp/prompts/manager.py | 27 +++++-- .../fastmcp/resources/resource_manager.py | 27 +++++-- src/mcp/server/fastmcp/server.py | 27 ++++--- src/mcp/server/fastmcp/tools/tool_manager.py | 15 ++-- tests/server/fastmcp/test_tool_manager.py | 2 +- 6 files changed, 118 insertions(+), 54 deletions(-) diff --git a/src/mcp/server/fastmcp/authorizer.py b/src/mcp/server/fastmcp/authorizer.py index ed373d77f..1e6314422 100644 --- a/src/mcp/server/fastmcp/authorizer.py +++ b/src/mcp/server/fastmcp/authorizer.py @@ -4,24 +4,28 @@ from typing import TYPE_CHECKING, Any from pydantic import AnyUrl +from starlette.requests import Request -from mcp.shared.context import LifespanContextT, RequestT +from mcp.server.session import ServerSession if TYPE_CHECKING: from mcp.server.fastmcp.server import Context - from mcp.server.session import ServerSessionT class Authorizer: __metaclass__ = abc.ABCMeta @abc.abstractmethod - def permit_get_tool(self, name: str) -> bool: + def permit_get_tool(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool: """Check if the specified tool can be retrieved from the associated mcp server""" return False @abc.abstractmethod - def permit_list_tool(self, name: str) -> bool: + def permit_list_tool( + self, + name: str, + context: Context[ServerSession, object, Request] | None = None, + ) -> bool: """Check if the specified tool can be listed from the associated mcp server""" return False @@ -30,79 +34,105 @@ def permit_call_tool( self, name: str, arguments: dict[str, Any], - context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, + context: Context[ServerSession, object, Request] | None = None, ) -> bool: """Check if the specified tool can be called from the associated mcp server""" return False @abc.abstractmethod - def permit_get_resource(self, resource: AnyUrl | str) -> bool: + def permit_get_resource( + self, resource: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None + ) -> bool: """Check if the specified resource can be retrieved from the associated mcp server""" return False @abc.abstractmethod - def permit_create_resource(self, uri: str, params: dict[str, Any]) -> bool: + def permit_create_resource( + self, uri: str, params: dict[str, Any], context: Context[ServerSession, object, Request] | None = None + ) -> bool: """Check if the specified resource can be created on the associated mcp server""" return False @abc.abstractmethod - def permit_list_resource(self, resource: AnyUrl | str) -> bool: + def permit_list_resource( + self, resource: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None + ) -> bool: """Check if the specified resource can be listed from the associated mcp server""" return False @abc.abstractmethod - def permit_list_template(self, resource: AnyUrl | str) -> bool: + def permit_list_template( + self, resource: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None + ) -> bool: """Check if the specified template can be listed from the associated mcp server""" return False @abc.abstractmethod - def permit_get_prompt(self, name: str) -> bool: + def permit_get_prompt(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool: """Check if the specified prompt can be retrieved from the associated mcp server""" return False @abc.abstractmethod - def permit_list_prompt(self, name: str) -> bool: + def permit_list_prompt(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool: """Check if the specified prompt can be listed from the associated mcp server""" return False @abc.abstractmethod - def permit_render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> bool: + def permit_render_prompt( + self, + name: str, + arguments: dict[str, Any] | None = None, + context: Context[ServerSession, object, Request] | None = None, + ) -> bool: """Check if the specified prompt can be rendered from the associated mcp server""" return False class AllAllAuthorizer(Authorizer): - def permit_get_tool(self, name: str) -> bool: + def permit_get_tool(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool: return True - def permit_list_tool(self, name: str) -> bool: + def permit_list_tool(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool: return True def permit_call_tool( self, name: str, arguments: dict[str, Any], - context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, + context: Context[ServerSession, object, Request] | None = None, ) -> bool: return True - def permit_get_resource(self, resource: AnyUrl | str) -> bool: + def permit_get_resource( + self, resource: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None + ) -> bool: return True - def permit_create_resource(self, uri: str, params: dict[str, Any]) -> bool: + def permit_create_resource( + self, uri: str, params: dict[str, Any], context: Context[ServerSession, object, Request] | None = None + ) -> bool: return True - def permit_list_resource(self, resource: AnyUrl | str) -> bool: + def permit_list_resource( + self, resource: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None + ) -> bool: return True - def permit_list_template(self, resource: AnyUrl | str) -> bool: + def permit_list_template( + self, resource: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None + ) -> bool: return True - def permit_get_prompt(self, name: str) -> bool: + def permit_get_prompt(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool: return True - def permit_list_prompt(self, name: str) -> bool: + def permit_list_prompt(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool: return True - def permit_render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> bool: + def permit_render_prompt( + self, + name: str, + arguments: dict[str, Any] | None = None, + context: Context[ServerSession, object, Request] | None = None, + ) -> bool: return True diff --git a/src/mcp/server/fastmcp/prompts/manager.py b/src/mcp/server/fastmcp/prompts/manager.py index 1f7d382f5..0d15a1c9a 100644 --- a/src/mcp/server/fastmcp/prompts/manager.py +++ b/src/mcp/server/fastmcp/prompts/manager.py @@ -1,10 +1,18 @@ """Prompt management functionality.""" -from typing import Any +from __future__ import annotations as _annotations + +from typing import TYPE_CHECKING, Any + +from starlette.requests import Request from mcp.server.fastmcp.authorizer import AllAllAuthorizer, Authorizer from mcp.server.fastmcp.prompts.base import Message, Prompt from mcp.server.fastmcp.utilities.logging import get_logger +from mcp.server.session import ServerSession + +if TYPE_CHECKING: + from mcp.server.fastmcp.server import Context logger = get_logger(__name__) @@ -21,16 +29,16 @@ def __init__( self._authorizer = authorizer self.warn_on_duplicate_prompts = warn_on_duplicate_prompts - def get_prompt(self, name: str) -> Prompt | None: + def get_prompt(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> Prompt | None: """Get prompt by name.""" - if self._authorizer.permit_get_prompt(name): + if self._authorizer.permit_get_prompt(name, context): return self._prompts.get(name) else: return None - def list_prompts(self) -> list[Prompt]: + def list_prompts(self, context: Context[ServerSession, object, Request] | None = None) -> list[Prompt]: """List all registered prompts.""" - return [prompt for name, prompt in self._prompts.items() if self._authorizer.permit_list_prompt(name)] + return [prompt for name, prompt in self._prompts.items() if self._authorizer.permit_list_prompt(name, context)] def add_prompt( self, @@ -48,12 +56,17 @@ def add_prompt( self._prompts[prompt.name] = prompt return prompt - async def render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> list[Message]: + async def render_prompt( + self, + name: str, + arguments: dict[str, Any] | None = None, + context: Context[ServerSession, object, Request] | None = None, + ) -> list[Message]: """Render a prompt by name with arguments.""" prompt = self.get_prompt(name) if not prompt: raise ValueError(f"Unknown prompt: {name}") - if self._authorizer.permit_render_prompt(name, arguments): + if self._authorizer.permit_render_prompt(name, arguments, context): return await prompt.render(arguments) else: raise ValueError(f"Unknown prompt: {name}") diff --git a/src/mcp/server/fastmcp/resources/resource_manager.py b/src/mcp/server/fastmcp/resources/resource_manager.py index 19f525b60..709af70e8 100644 --- a/src/mcp/server/fastmcp/resources/resource_manager.py +++ b/src/mcp/server/fastmcp/resources/resource_manager.py @@ -1,14 +1,21 @@ """Resource manager functionality.""" +from __future__ import annotations as _annotations + from collections.abc import Callable -from typing import Any +from typing import TYPE_CHECKING, Any from pydantic import AnyUrl +from starlette.requests import Request from mcp.server.fastmcp.authorizer import AllAllAuthorizer, Authorizer from mcp.server.fastmcp.resources.base import Resource from mcp.server.fastmcp.resources.templates import ResourceTemplate from mcp.server.fastmcp.utilities.logging import get_logger +from mcp.server.session import ServerSession + +if TYPE_CHECKING: + from mcp.server.fastmcp.server import Context logger = get_logger(__name__) @@ -73,14 +80,16 @@ def add_template( self._templates[template.uri_template] = template return template - async def get_resource(self, uri: AnyUrl | str) -> Resource | None: + async def get_resource( + self, uri: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None + ) -> Resource | None: """Get resource by URI, checking concrete resources first, then templates.""" uri_str = str(uri) logger.debug("Getting resource", extra={"uri": uri_str}) # First check concrete resources if resource := self._resources.get(uri_str): - if self._authorizer.permit_get_resource(uri_str): + if self._authorizer.permit_get_resource(uri_str, context): return resource else: raise ValueError(f"Unknown resource: {uri}") @@ -98,12 +107,16 @@ async def get_resource(self, uri: AnyUrl | str) -> Resource | None: raise ValueError(f"Unknown resource: {uri}") - def list_resources(self) -> list[Resource]: + def list_resources(self, context: Context[ServerSession, object, Request] | None = None) -> list[Resource]: """List all registered resources.""" logger.debug("Listing resources", extra={"count": len(self._resources)}) - return [resource for uri, resource in self._resources.items() if self._authorizer.permit_list_resource(uri)] + return [ + resource for uri, resource in self._resources.items() if self._authorizer.permit_list_resource(uri, context) + ] - def list_templates(self) -> list[ResourceTemplate]: + def list_templates(self, context: Context[ServerSession, object, Request] | None = None) -> list[ResourceTemplate]: """List all registered templates.""" logger.debug("Listing templates", extra={"count": len(self._templates)}) - return [template for uri, template in self._templates.items() if self._authorizer.permit_list_template(uri)] + return [ + template for uri, template in self._templates.items() if self._authorizer.permit_list_template(uri, context) + ] diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 2af68025a..01bee20b1 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -257,7 +257,8 @@ def _setup_handlers(self) -> None: async def list_tools(self) -> list[MCPTool]: """List all available tools.""" - tools = self._tool_manager.list_tools() + context = self.get_context() + tools = self._tool_manager.list_tools(context) return [ MCPTool( name=info.name, @@ -289,8 +290,8 @@ async def call_tool(self, name: str, arguments: dict[str, Any]) -> Sequence[Cont async def list_resources(self) -> list[MCPResource]: """List all available resources.""" - - resources = self._resource_manager.list_resources() + context = self.get_context() + resources = self._resource_manager.list_resources(context) return [ MCPResource( uri=resource.uri, @@ -303,7 +304,8 @@ async def list_resources(self) -> list[MCPResource]: ] async def list_resource_templates(self) -> list[MCPResourceTemplate]: - templates = self._resource_manager.list_templates() + context = self.get_context() + templates = self._resource_manager.list_templates(context) return [ MCPResourceTemplate( uriTemplate=template.uri_template, @@ -316,8 +318,8 @@ async def list_resource_templates(self) -> list[MCPResourceTemplate]: async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContents]: """Read a resource by URI.""" - - resource = await self._resource_manager.get_resource(uri) + context = self.get_context() + resource = await self._resource_manager.get_resource(uri, context) if not resource: raise ResourceError(f"Unknown resource: {uri}") @@ -924,9 +926,9 @@ async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> lifespan=lambda app: self.session_manager.run(), ) - async def list_prompts(self) -> list[MCPPrompt]: + async def list_prompts(self, context: Context[ServerSession, object, Request] | None = None) -> list[MCPPrompt]: """List all available prompts.""" - prompts = self._prompt_manager.list_prompts() + prompts = self._prompt_manager.list_prompts(context) return [ MCPPrompt( name=prompt.name, @@ -944,10 +946,15 @@ async def list_prompts(self) -> list[MCPPrompt]: for prompt in prompts ] - async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> GetPromptResult: + async def get_prompt( + self, + name: str, + arguments: dict[str, Any] | None = None, + context: Context[ServerSession, object, Request] | None = None, + ) -> GetPromptResult: """Get a prompt by name with arguments.""" try: - messages = await self._prompt_manager.render_prompt(name, arguments) + messages = await self._prompt_manager.render_prompt(name, arguments, context) return GetPromptResult(messages=pydantic_core.to_jsonable_python(messages)) except Exception as e: diff --git a/src/mcp/server/fastmcp/tools/tool_manager.py b/src/mcp/server/fastmcp/tools/tool_manager.py index caed79032..160863814 100644 --- a/src/mcp/server/fastmcp/tools/tool_manager.py +++ b/src/mcp/server/fastmcp/tools/tool_manager.py @@ -3,16 +3,17 @@ from collections.abc import Callable from typing import TYPE_CHECKING, Any +from starlette.requests import Request + from mcp.server.fastmcp.authorizer import AllAllAuthorizer, Authorizer from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.tools.base import Tool from mcp.server.fastmcp.utilities.logging import get_logger -from mcp.shared.context import LifespanContextT, RequestT +from mcp.server.session import ServerSession from mcp.types import ToolAnnotations if TYPE_CHECKING: from mcp.server.fastmcp.server import Context - from mcp.server.session import ServerSessionT logger = get_logger(__name__) @@ -37,16 +38,16 @@ def __init__( self.warn_on_duplicate_tools = (warn_on_duplicate_tools,) self._authorizer = authorizer - def get_tool(self, name: str) -> Tool | None: + def get_tool(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> Tool | None: """Get tool by name.""" - if self._authorizer.permit_get_tool(name): + if self._authorizer.permit_get_tool(name, context): return self._tools.get(name) else: return None - def list_tools(self) -> list[Tool]: + def list_tools(self, context: Context[ServerSession, object, Request] | None = None) -> list[Tool]: """List all registered tools.""" - return [tool for name, tool in self._tools.items() if self._authorizer.permit_list_tool(name)] + return [tool for name, tool in self._tools.items() if self._authorizer.permit_list_tool(name, context)] def add_tool( self, @@ -70,7 +71,7 @@ async def call_tool( self, name: str, arguments: dict[str, Any], - context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, + context: Context[ServerSession, object, Request] | None = None, ) -> Any: """Call a tool by name with arguments.""" tool = self._tools.get(name) diff --git a/tests/server/fastmcp/test_tool_manager.py b/tests/server/fastmcp/test_tool_manager.py index 76aa7e5ad..a7b2f305e 100644 --- a/tests/server/fastmcp/test_tool_manager.py +++ b/tests/server/fastmcp/test_tool_manager.py @@ -321,7 +321,7 @@ async def double(n: int) -> int: class TestAuthorizer(Authorizer): allow: bool = True - def permit_list_tool(self, name): + def permit_list_tool(self, name, context=None): return self.allow def permit_call_tool(self, name, arguments, context=None): From 4f3a5558b3927c18436f977f38ef0a5736362c04 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sat, 28 Jun 2025 12:11:34 +0100 Subject: [PATCH 5/6] fix name of default policy --- src/mcp/server/fastmcp/authorizer.py | 2 +- src/mcp/server/fastmcp/prompts/manager.py | 4 ++-- src/mcp/server/fastmcp/resources/resource_manager.py | 4 ++-- src/mcp/server/fastmcp/server.py | 4 ++-- src/mcp/server/fastmcp/tools/tool_manager.py | 4 ++-- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/mcp/server/fastmcp/authorizer.py b/src/mcp/server/fastmcp/authorizer.py index 1e6314422..3fa4932e2 100644 --- a/src/mcp/server/fastmcp/authorizer.py +++ b/src/mcp/server/fastmcp/authorizer.py @@ -88,7 +88,7 @@ def permit_render_prompt( return False -class AllAllAuthorizer(Authorizer): +class AllowAllAuthorizer(Authorizer): def permit_get_tool(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool: return True diff --git a/src/mcp/server/fastmcp/prompts/manager.py b/src/mcp/server/fastmcp/prompts/manager.py index 0d15a1c9a..fafe493b5 100644 --- a/src/mcp/server/fastmcp/prompts/manager.py +++ b/src/mcp/server/fastmcp/prompts/manager.py @@ -6,7 +6,7 @@ from starlette.requests import Request -from mcp.server.fastmcp.authorizer import AllAllAuthorizer, Authorizer +from mcp.server.fastmcp.authorizer import AllowAllAuthorizer, Authorizer from mcp.server.fastmcp.prompts.base import Message, Prompt from mcp.server.fastmcp.utilities.logging import get_logger from mcp.server.session import ServerSession @@ -23,7 +23,7 @@ class PromptManager: def __init__( self, warn_on_duplicate_prompts: bool = True, - authorizer: Authorizer = AllAllAuthorizer(), + authorizer: Authorizer = AllowAllAuthorizer(), ): self._prompts: dict[str, Prompt] = {} self._authorizer = authorizer diff --git a/src/mcp/server/fastmcp/resources/resource_manager.py b/src/mcp/server/fastmcp/resources/resource_manager.py index 709af70e8..ea65ce81b 100644 --- a/src/mcp/server/fastmcp/resources/resource_manager.py +++ b/src/mcp/server/fastmcp/resources/resource_manager.py @@ -8,7 +8,7 @@ from pydantic import AnyUrl from starlette.requests import Request -from mcp.server.fastmcp.authorizer import AllAllAuthorizer, Authorizer +from mcp.server.fastmcp.authorizer import AllowAllAuthorizer, Authorizer from mcp.server.fastmcp.resources.base import Resource from mcp.server.fastmcp.resources.templates import ResourceTemplate from mcp.server.fastmcp.utilities.logging import get_logger @@ -26,7 +26,7 @@ class ResourceManager: def __init__( self, warn_on_duplicate_resources: bool = True, - authorizer: Authorizer = AllAllAuthorizer(), + authorizer: Authorizer = AllowAllAuthorizer(), ): self._resources: dict[str, Resource] = {} self._templates: dict[str, ResourceTemplate] = {} diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 01bee20b1..d22f48748 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -33,7 +33,7 @@ from mcp.server.auth.provider import OAuthAuthorizationServerProvider, ProviderTokenVerifier, TokenVerifier from mcp.server.auth.settings import AuthSettings from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, elicit_with_validation -from mcp.server.fastmcp.authorizer import AllAllAuthorizer, Authorizer +from mcp.server.fastmcp.authorizer import AllowAllAuthorizer, Authorizer from mcp.server.fastmcp.exceptions import ResourceError from mcp.server.fastmcp.prompts import Prompt, PromptManager from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager @@ -121,7 +121,7 @@ class Settings(BaseSettings, Generic[LifespanResultT]): # Transport security settings (DNS rebinding protection) transport_security: TransportSecuritySettings | None = None - authorizer: Authorizer = AllAllAuthorizer() + authorizer: Authorizer = AllowAllAuthorizer() def lifespan_wrapper( diff --git a/src/mcp/server/fastmcp/tools/tool_manager.py b/src/mcp/server/fastmcp/tools/tool_manager.py index 160863814..9a37594e4 100644 --- a/src/mcp/server/fastmcp/tools/tool_manager.py +++ b/src/mcp/server/fastmcp/tools/tool_manager.py @@ -5,7 +5,7 @@ from starlette.requests import Request -from mcp.server.fastmcp.authorizer import AllAllAuthorizer, Authorizer +from mcp.server.fastmcp.authorizer import AllowAllAuthorizer, Authorizer from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.tools.base import Tool from mcp.server.fastmcp.utilities.logging import get_logger @@ -26,7 +26,7 @@ def __init__( warn_on_duplicate_tools: bool = True, *, tools: list[Tool] | None = None, - authorizer: Authorizer = AllAllAuthorizer(), + authorizer: Authorizer = AllowAllAuthorizer(), ): self._tools: dict[str, Tool] = {} if tools is not None: From ae640eeb6cefd9bd6bbfe5184e8d156171a8bb13 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sat, 28 Jun 2025 13:12:21 +0000 Subject: [PATCH 6/6] fixed broken merge --- src/mcp/server/fastmcp/tools/tool_manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/mcp/server/fastmcp/tools/tool_manager.py b/src/mcp/server/fastmcp/tools/tool_manager.py index c52fa16d4..95d2584eb 100644 --- a/src/mcp/server/fastmcp/tools/tool_manager.py +++ b/src/mcp/server/fastmcp/tools/tool_manager.py @@ -80,7 +80,6 @@ async def call_tool( name: str, arguments: dict[str, Any], context: Context[ServerSession, object, Request] | None = None, - context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, convert_result: bool = False, ) -> Any: """Call a tool by name with arguments."""