diff --git a/examples/clients/simple-auth-client-client-credentials/README.md b/examples/clients/simple-auth-client-client-credentials/README.md new file mode 100644 index 000000000..6608f8f05 --- /dev/null +++ b/examples/clients/simple-auth-client-client-credentials/README.md @@ -0,0 +1,85 @@ +# Simple Auth Client Example + +A demonstration of how to use the MCP Python SDK with OAuth authentication using client credentials over streamable HTTP or SSE transport. +This example demonstrates integration with an authorization server that does not implement Dynamic Client Registration. + +## Features + +- OAuth 2.0 authentication with the `client_credentials` flow +- Support for both StreamableHTTP and SSE transports +- Interactive command-line interface + +## Installation + +```bash +cd examples/clients/simple-auth-client-client-credentials +uv sync --reinstall +``` + +## Usage + +### 1. Start an MCP server with OAuth support using client credentials + +```bash +# Example with mcp-simple-auth-client-credentials +cd path/to/mcp-simple-auth-client-credentials +uv run mcp-simple-auth-client-credentials --transport streamable-http --port 3001 +``` + +### 2. Run the client + +```bash +uv run mcp-simple-auth-client + +# Or with custom server URL +MCP_SERVER_PORT=3001 uv run mcp-simple-auth-client + +# Use SSE transport +MCP_TRANSPORT_TYPE=sse uv run mcp-simple-auth-client +``` + +### 3. Complete OAuth flow + +The client will automatically authenticate using dummy client credentials for the demo authorization server. After completing OAuth, you can use commands: + +- `list` - List available tools +- `call [args]` - Call a tool with optional JSON arguments +- `quit` - Exit + +## Example + +``` +šŸš€ Simple MCP Auth Client +Connecting to: http://localhost:8001/mcp +Transport type: streamable_http +šŸ”— Attempting to connect to http://localhost:8001/mcp... +šŸ“” Opening StreamableHTTP transport connection with auth... +šŸ¤ Initializing MCP session... +⚔ Starting session initialization... +✨ Session initialization complete! + +āœ… Connected to MCP server at http://localhost:8001/mcp +Session ID: ... + +šŸŽÆ Interactive MCP Client +Commands: + list - List available tools + call [args] - Call a tool + quit - Exit the client + +mcp> list +šŸ“‹ Available tools: +1. echo - Echo back the input text + +mcp> call echo {"text": "Hello, world!"} +šŸ”§ Tool 'echo' result: +Hello, world! + +mcp> quit +šŸ‘‹ Goodbye! +``` + +## Configuration + +- `MCP_SERVER_PORT` - Server URL (default: 8000) +- `MCP_TRANSPORT_TYPE` - Transport type: `streamable_http` (default) or `sse` diff --git a/examples/clients/simple-auth-client-client-credentials/mcp_simple_auth_client_client_credentials/__init__.py b/examples/clients/simple-auth-client-client-credentials/mcp_simple_auth_client_client_credentials/__init__.py new file mode 100644 index 000000000..06eb1f29d --- /dev/null +++ b/examples/clients/simple-auth-client-client-credentials/mcp_simple_auth_client_client_credentials/__init__.py @@ -0,0 +1 @@ +"""Simple OAuth client for MCP simple-auth server.""" diff --git a/examples/clients/simple-auth-client-client-credentials/mcp_simple_auth_client_client_credentials/main.py b/examples/clients/simple-auth-client-client-credentials/mcp_simple_auth_client_client_credentials/main.py new file mode 100644 index 000000000..9674dd158 --- /dev/null +++ b/examples/clients/simple-auth-client-client-credentials/mcp_simple_auth_client_client_credentials/main.py @@ -0,0 +1,247 @@ +#!/usr/bin/env python3 +""" +Simple MCP client example with OAuth authentication support. + +This client connects to an MCP server using streamable HTTP transport with OAuth. + +""" + +import asyncio +import os +from datetime import timedelta +from typing import Any + +from mcp.client.auth import OAuthClientProvider, TokenStorage +from mcp.client.session import ClientSession +from mcp.client.sse import sse_client +from mcp.client.streamable_http import streamablehttp_client +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken + +# Hardcoded credentials assuming a preconfigured client, to demonstrate +# working with an AS that does not have DCR support +MCP_CLIENT_ID = "0000000000000000000" +MCP_CLIENT_SECRET = "aaaaaaaaaaaaaaaaaaa" + + +class InMemoryTokenStorage(TokenStorage): + """Simple in-memory token storage implementation.""" + + def __init__(self, client_id: str | None, client_secret: str | None): + self._tokens: OAuthToken | None = None + self._client_info = OAuthClientInformationFull( + client_id=client_id, + client_secret=client_secret, + redirect_uris=None, + ) + + async def get_tokens(self) -> OAuthToken | None: + return self._tokens + + async def set_tokens(self, tokens: OAuthToken) -> None: + self._tokens = tokens + + async def get_client_info(self) -> OAuthClientInformationFull | None: + return self._client_info + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + self._client_info = client_info + + +class SimpleAuthClient: + """Simple MCP client with auth support.""" + + def __init__(self, server_url: str, transport_type: str = "streamable_http"): + self.server_url = server_url + self.transport_type = transport_type + self.session: ClientSession | None = None + + async def connect(self): + """Connect to the MCP server.""" + print(f"šŸ”— Attempting to connect to {self.server_url}...") + + try: + client_metadata_dict = { + "client_name": "Simple Auth Client", + "redirect_uris": None, + "grant_types": ["client_credentials"], + "response_types": ["code"], + "token_endpoint_auth_method": "client_secret_basic", + "scope": "identify", + } + + # Create OAuth authentication handler using the new interface + oauth_auth = OAuthClientProvider( + server_url=self.server_url.replace("/mcp", ""), + client_metadata=OAuthClientMetadata.model_validate(client_metadata_dict), + storage=InMemoryTokenStorage( + client_id=MCP_CLIENT_ID, + client_secret=MCP_CLIENT_SECRET, + ), + ) + oauth_auth.context.client_info = OAuthClientInformationFull( + redirect_uris=None, + ) + + # Create transport with auth handler based on transport type + if self.transport_type == "sse": + print("šŸ“” Opening SSE transport connection with auth...") + async with sse_client( + url=self.server_url, + auth=oauth_auth, + timeout=60, + ) as (read_stream, write_stream): + await self._run_session(read_stream, write_stream, None) + else: + print("šŸ“” Opening StreamableHTTP transport connection with auth...") + async with streamablehttp_client( + url=self.server_url, + auth=oauth_auth, + timeout=timedelta(seconds=60), + ) as (read_stream, write_stream, get_session_id): + await self._run_session(read_stream, write_stream, get_session_id) + + except Exception as e: + print(f"āŒ Failed to connect: {e}") + import traceback + + traceback.print_exc() + + async def _run_session(self, read_stream, write_stream, get_session_id): + """Run the MCP session with the given streams.""" + print("šŸ¤ Initializing MCP session...") + async with ClientSession(read_stream, write_stream) as session: + self.session = session + print("⚔ Starting session initialization...") + await session.initialize() + print("✨ Session initialization complete!") + + print(f"\nāœ… Connected to MCP server at {self.server_url}") + if get_session_id: + session_id = get_session_id() + if session_id: + print(f"Session ID: {session_id}") + + # Run interactive loop + await self.interactive_loop() + + async def list_tools(self): + """List available tools from the server.""" + if not self.session: + print("āŒ Not connected to server") + return + + try: + result = await self.session.list_tools() + if hasattr(result, "tools") and result.tools: + print("\nšŸ“‹ Available tools:") + for i, tool in enumerate(result.tools, 1): + print(f"{i}. {tool.name}") + if tool.description: + print(f" Description: {tool.description}") + print() + else: + print("No tools available") + except Exception as e: + print(f"āŒ Failed to list tools: {e}") + + async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None = None): + """Call a specific tool.""" + if not self.session: + print("āŒ Not connected to server") + return + + try: + result = await self.session.call_tool(tool_name, arguments or {}) + print(f"\nšŸ”§ Tool '{tool_name}' result:") + if hasattr(result, "content"): + for content in result.content: + if content.type == "text": + print(content.text) + else: + print(content) + else: + print(result) + except Exception as e: + print(f"āŒ Failed to call tool '{tool_name}': {e}") + + async def interactive_loop(self): + """Run interactive command loop.""" + print("\nšŸŽÆ Interactive MCP Client") + print("Commands:") + print(" list - List available tools") + print(" call [args] - Call a tool") + print(" quit - Exit the client") + print() + + while True: + try: + command = input("mcp> ").strip() + + if not command: + continue + + if command == "quit": + break + + elif command == "list": + await self.list_tools() + + elif command.startswith("call "): + parts = command.split(maxsplit=2) + tool_name = parts[1] if len(parts) > 1 else "" + + if not tool_name: + print("āŒ Please specify a tool name") + continue + + # Parse arguments (simple JSON-like format) + arguments = {} + if len(parts) > 2: + import json + + try: + arguments = json.loads(parts[2]) + except json.JSONDecodeError: + print("āŒ Invalid arguments format (expected JSON)") + continue + + await self.call_tool(tool_name, arguments) + + else: + print("āŒ Unknown command. Try 'list', 'call ', or 'quit'") + + except KeyboardInterrupt: + print("\n\nšŸ‘‹ Goodbye!") + break + except EOFError: + break + + +async def main(): + """Main entry point.""" + # Default server URL - can be overridden with environment variable + # Most MCP streamable HTTP servers use /mcp as the endpoint + server_url = os.getenv("MCP_SERVER_PORT", 8000) + transport_type = os.getenv("MCP_TRANSPORT_TYPE", "streamable_http") + server_url = ( + f"http://localhost:{server_url}/mcp" + if transport_type == "streamable_http" + else f"http://localhost:{server_url}/sse" + ) + + print("šŸš€ Simple MCP Auth Client") + print(f"Connecting to: {server_url}") + print(f"Transport type: {transport_type}") + + # Start connection flow - OAuth will be handled automatically + client = SimpleAuthClient(server_url, transport_type) + await client.connect() + + +def cli(): + """CLI entry point for uv script.""" + asyncio.run(main()) + + +if __name__ == "__main__": + cli() diff --git a/examples/clients/simple-auth-client-client-credentials/pyproject.toml b/examples/clients/simple-auth-client-client-credentials/pyproject.toml new file mode 100644 index 000000000..112e0a575 --- /dev/null +++ b/examples/clients/simple-auth-client-client-credentials/pyproject.toml @@ -0,0 +1,30 @@ +[project] +name = "mcp-simple-auth-client-client-credentials" +version = "0.1.0" +description = "A simple OAuth client for the MCP simple-auth server" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic" }] +keywords = ["mcp", "oauth", "client", "auth"] +license = { text = "MIT" } +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", +] +dependencies = ["click>=8.0.0", "mcp"] + +[project.scripts] +mcp-simple-auth-client-client-credentials = "mcp_simple_auth_client_client_credentials.main:cli" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_simple_auth_client_client_credentials"] + +[tool.uv] +dev-dependencies = ["pyright>=1.1.379", "pytest>=8.3.3", "ruff>=0.6.9"] diff --git a/examples/clients/simple-auth-client/pyproject.toml b/examples/clients/simple-auth-client/pyproject.toml index 5ae7c6b9d..bbe7f9b8f 100644 --- a/examples/clients/simple-auth-client/pyproject.toml +++ b/examples/clients/simple-auth-client/pyproject.toml @@ -14,10 +14,7 @@ classifiers = [ "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.10", ] -dependencies = [ - "click>=8.0.0", - "mcp>=1.0.0", -] +dependencies = ["click>=8.0.0", "mcp"] [project.scripts] mcp-simple-auth-client = "mcp_simple_auth_client.main:cli" @@ -44,9 +41,3 @@ target-version = "py310" [tool.uv] dev-dependencies = ["pyright>=1.1.379", "pytest>=8.3.3", "ruff>=0.6.9"] - -[tool.uv.sources] -mcp = { path = "../../../" } - -[[tool.uv.index]] -url = "https://pypi.org/simple" diff --git a/examples/servers/simple-auth-client-credentials/README.md b/examples/servers/simple-auth-client-credentials/README.md new file mode 100644 index 000000000..22a5203a7 --- /dev/null +++ b/examples/servers/simple-auth-client-credentials/README.md @@ -0,0 +1,104 @@ +# MCP OAuth Authentication Demo + +This example demonstrates OAuth 2.0 authentication with the Model Context Protocol as an OAuth 2.0 Resource Server using the `client_credentials` token exchange, with +an Authorization Server that does not support Dynamic Client Registration. + +--- + +## Setup Requirements + +**Create a Discord OAuth App:** + +- Go to the [Discord Developer Portal](https://discord.com/developers/applications) > New Application +- Navigate to Settings > OAuth2 +- Note down your **Client ID** +- Reset your **Client Secret** and note it down + +**Set environment variables:** + +```bash +export MCP_DISCORD_CLIENT_ID="your_client_id_here" +export MCP_DISCORD_CLIENT_SECRET="your_client_secret_here" +``` + +--- + +## Running the Servers + +### Step 1: Start Authorization Server + +```bash +# Navigate to the simple-auth-client-credentials directory +cd examples/servers/simple-auth-client-credentials + +# Start Authorization Server on port 9000 +uv run mcp-simple-auth-as --port=9000 +``` + +**What it provides:** + +- OAuth 2.0 flows (registration, authorization, token exchange) +- Discord OAuth integration for user authentication + +--- + +### Step 2: Start Resource Server (MCP Server) + +```bash +# In another terminal, navigate to the simple-auth-client-credentials directory +cd examples/servers/simple-auth-client-credentials + +# Start Resource Server on port 8001, connected to Authorization Server +uv run mcp-simple-auth-rs --port=8001 --auth-server=http://localhost:9000 --transport=streamable-http +``` + +### Step 3: Test with Client + +```bash +cd examples/clients/simple-auth-client-client-credentials +# Start client with streamable HTTP +MCP_SERVER_PORT=8001 MCP_TRANSPORT_TYPE=streamable_http uv run mcp-simple-auth-client-client-credentials +``` + +## How It Works + +### RFC 9728 Discovery + +**Client → Resource Server:** + +```bash +curl http://localhost:8001/.well-known/oauth-protected-resource +``` + +```json +{ + "resource": "http://localhost:8001", + "authorization_servers": ["http://localhost:9000"] +} +``` + +**Client → Authorization Server:** + +```bash +curl http://localhost:9000/.well-known/oauth-authorization-server +``` + +```json +{ + "issuer": "http://localhost:9000", + "authorization_endpoint": "http://localhost:9000/authorize", + "token_endpoint": "http://localhost:9000/token" +} +``` + +## Manual Testing + +### Test Discovery + +```bash +# Test Resource Server discovery endpoint (new architecture) +curl -v http://localhost:8001/.well-known/oauth-protected-resource + +# Test Authorization Server metadata +curl -v http://localhost:9000/.well-known/oauth-authorization-server +``` diff --git a/examples/servers/simple-auth-client-credentials/mcp_simple_auth_client_credentials/__init__.py b/examples/servers/simple-auth-client-credentials/mcp_simple_auth_client_credentials/__init__.py new file mode 100644 index 000000000..35ed549de --- /dev/null +++ b/examples/servers/simple-auth-client-credentials/mcp_simple_auth_client_credentials/__init__.py @@ -0,0 +1 @@ +"""Simple MCP server with Discord OAuth authentication over client credentials.""" diff --git a/examples/servers/simple-auth-client-credentials/mcp_simple_auth_client_credentials/__main__.py b/examples/servers/simple-auth-client-credentials/mcp_simple_auth_client_credentials/__main__.py new file mode 100644 index 000000000..468c339b4 --- /dev/null +++ b/examples/servers/simple-auth-client-credentials/mcp_simple_auth_client_credentials/__main__.py @@ -0,0 +1,7 @@ +"""Main entry point for simple MCP server with Discord OAuth authentication over client credentials.""" + +import sys + +from mcp_simple_auth_client_credentials.server import main + +sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/servers/simple-auth-client-credentials/mcp_simple_auth_client_credentials/auth_server.py b/examples/servers/simple-auth-client-credentials/mcp_simple_auth_client_credentials/auth_server.py new file mode 100644 index 000000000..0c6475104 --- /dev/null +++ b/examples/servers/simple-auth-client-credentials/mcp_simple_auth_client_credentials/auth_server.py @@ -0,0 +1,290 @@ +""" +Authorization Server for MCP Split Demo. + +This server handles OAuth flows, client registration, and token issuance. +Can be replaced with enterprise authorization servers like Auth0, Entra ID, etc. + +NOTE: this is a simplified example for demonstration purposes. +This is not a production-ready implementation. + +Usage: + python -m mcp_simple_auth.auth_server --port=9000 +""" + +import asyncio +import logging +import secrets +from base64 import b64decode, b64encode +from typing import Literal + +import click +from pydantic import AnyHttpUrl, BaseModel +from pydantic_settings import BaseSettings, SettingsConfigDict +from starlette.applications import Starlette +from starlette.endpoints import HTTPEndpoint +from starlette.requests import Request +from starlette.responses import JSONResponse, Response +from starlette.routing import Route +from starlette.types import Receive, Scope, Send +from uvicorn import Config, Server + +from mcp.server.auth.handlers.metadata import MetadataHandler +from mcp.server.auth.routes import cors_middleware +from mcp.shared._httpx_utils import create_mcp_http_client +from mcp.shared.auth import OAuthMetadata, OAuthToken + +logger = logging.getLogger(__name__) + +API_ENDPOINT = "https://discord.com/api/v10" + + +class DiscordOAuthSettings(BaseSettings): + """Discord OAuth settings.""" + + model_config = SettingsConfigDict(env_prefix="MCP_") + + # Discord OAuth settings - MUST be provided via environment variables + discord_client_id: str | None = None + discord_client_secret: str | None = None + + token_endpoint_auth_method: Literal["client_secret_basic", "client_secret_post"] = "client_secret_basic" + + # Discord OAuth URL + discord_token_url: str = f"{API_ENDPOINT}/oauth2/token" + + discord_scope: str = "identify" + + +class AuthServerSettings(BaseModel): + """Settings for the Authorization Server.""" + + # Server settings + host: str = "localhost" + port: int = 9000 + server_url: AnyHttpUrl = AnyHttpUrl("http://localhost:9000") + + +# Hardcoded credentials assuming a preconfigured client, to demonstrate +# working with an AS that does not have DCR support +MCP_CLIENT_ID = "0000000000000000000" +MCP_CLIENT_SECRET = "aaaaaaaaaaaaaaaaaaa" + +# Map of MCP server tokens to Discord API tokens +TOKEN_MAP: dict[str, str] = {} + + +class TokenEndpoint(HTTPEndpoint): + # Map of MCP client IDs to Discord client IDs + client_map: dict[str, str] = {} + client_credentials: dict[str, str] = {} + + discord_client_credentials: dict[str, str] = {} + + def __init__(self, scope: Scope, receive: Receive, send: Send): + super().__init__(scope, receive, send) + self.discord_settings = DiscordOAuthSettings() + + assert self.discord_settings.discord_client_id is not None, "Discord client ID not set" + assert self.discord_settings.discord_client_secret is not None, "Discord client secret not set" + + # Assume a preconfigured client ID to demonstrate working with an AS that does not have DCR support + self.client_map = { + MCP_CLIENT_ID: self.discord_settings.discord_client_id, + } + self.client_credentials = { + MCP_CLIENT_ID: MCP_CLIENT_SECRET, + } + self.discord_client_credentials = { + self.discord_settings.discord_client_id: self.discord_settings.discord_client_secret, + } + + async def post(self, request: Request) -> Response: + # Get request data (application/x-www-form-urlencoded) + data = await request.form() + + if self.discord_settings.token_endpoint_auth_method == "client_secret_basic": + # Get client_id and client_secret from Basic auth header + auth_header = request.headers.get("Authorization", "") + if not auth_header.startswith("Basic "): + return JSONResponse({"error": "Invalid authorization header"}, status_code=401) + auth_header_encoded = auth_header.split(" ")[1] + auth_header_raw = b64decode(auth_header_encoded).decode("utf-8") + client_id, client_secret = auth_header_raw.split(":") + else: + # Get from body + client_id = str(data.get("client_id")) + client_secret = str(data.get("client_secret")) + + # Validate MCP client + if client_id not in self.client_map: + return JSONResponse({"error": "Invalid client"}, status_code=401) + # Check if client secret matches + if client_secret != self.client_credentials[client_id]: + return JSONResponse({"error": "Invalid client secret"}, status_code=401) + + # Get mapped credentials + discord_client_id = self.client_map[client_id] + discord_client_secret = self.discord_client_credentials[discord_client_id] + + # Validate scopes + scopes = str(data.get("scope", "")).split(" ") + if not set(scopes).issubset(set(self.discord_settings.discord_scope.split(" "))): + return JSONResponse({"error": "Invalid scope"}, status_code=400) + + # Set credentials in HTTP client + headers = { + "Authorization": f"Basic {b64encode(f'{discord_client_id}:{discord_client_secret}'.encode()).decode()}" + } + + # Create HTTP client + async with create_mcp_http_client() as http_client: + # Forward request to Discord API + method = getattr(http_client, request.method.lower()) + response = await method(self.discord_settings.discord_token_url, data=data, headers=headers) + if response.status_code != 200: + body = await response.aread() + return Response(body, status_code=response.status_code, headers=response.headers) + + # Generate MCP access token + mcp_token = f"mcp_{secrets.token_hex(32)}" + + # Store mapped access token + TOKEN_MAP[mcp_token] = response.json()["access_token"] + + # Return response + return JSONResponse( + OAuthToken( + access_token=mcp_token, + token_type="Bearer", + expires_in=response.json()["expires_in"], + scope=self.discord_settings.discord_scope, + ).model_dump(), + status_code=response.status_code, + ) + + +class DiscordAPIProxy(HTTPEndpoint): + """Proxy for Discord API.""" + + async def get(self, request: Request) -> Response: + """Proxy GET requests to Discord API.""" + return await self.handle(request) + + async def post(self, request: Request) -> Response: + """Proxy POST requests to Discord API.""" + return await self.handle(request) + + async def handle(self, request: Request) -> Response: + """Proxy requests to Discord API.""" + path = request.url.path[len("/discord") :] + query = request.url.query + + # Get access token from Authorization header + access_token = request.headers.get("Authorization", "").split(" ")[1] + if not access_token: + return JSONResponse({"error": "Missing access token"}, status_code=401) + + # Map access token to Discord access token + access_token = TOKEN_MAP.get(access_token, None) + if not access_token: + return JSONResponse({"error": "Invalid access token"}, status_code=401) + + # Set mapped access token in HTTP client + headers = {"Authorization": f"Bearer {access_token}"} + + # Create HTTP client + async with create_mcp_http_client() as http_client: + # Forward request to Discord API + response = await http_client.get(f"{API_ENDPOINT}{path}?{query}", headers=headers) + + # Return response + return JSONResponse(response.json(), status_code=response.status_code) + + +def create_authorization_server( + server_settings: AuthServerSettings, discord_settings: DiscordOAuthSettings +) -> Starlette: + """Create the Authorization Server application.""" + + routes = [ + # Create RFC 8414 authorization server metadata endpoint + Route( + "/.well-known/oauth-authorization-server", + endpoint=cors_middleware( + MetadataHandler( + metadata=OAuthMetadata( + issuer=server_settings.server_url, + authorization_endpoint=AnyHttpUrl(f"{server_settings.server_url}authorize"), + token_endpoint=AnyHttpUrl(f"{server_settings.server_url}token"), + token_endpoint_auth_methods_supported=["client_secret_post"], + response_types_supported=["code"], + grant_types_supported=["client_credentials"], + scopes_supported=[discord_settings.discord_scope], + ) + ).handle, + ["GET", "OPTIONS"], + ), + methods=["GET", "OPTIONS"], + ), + # Create OAuth 2.0 token endpoint + Route("/token", TokenEndpoint), + # Create API proxy endpoint + Route("/discord/{path:path}", DiscordAPIProxy), + ] + + return Starlette(routes=routes) + + +async def run_server(server_settings: AuthServerSettings, discord_settings: DiscordOAuthSettings): + """Run the Authorization Server.""" + auth_server = create_authorization_server(server_settings, discord_settings) + + config = Config( + auth_server, + host=server_settings.host, + port=server_settings.port, + log_level="info", + ) + server = Server(config) + + logger.info("=" * 80) + logger.info("MCP AUTHORIZATION PROXY SERVER") + logger.info("=" * 80) + logger.info(f"Server URL: {server_settings.server_url}") + logger.info("Endpoints:") + logger.info(f" - OAuth Metadata: {server_settings.server_url}.well-known/oauth-authorization-server") + logger.info(f" - Token Exchange: {server_settings.server_url}token") + logger.info(f" - Discord API Proxy: {server_settings.server_url}discord") + logger.info("") + logger.info("=" * 80) + + await server.serve() + + +@click.command() +@click.option("--port", default=9000, help="Port to listen on") +def main(port: int) -> int: + """ + Run the MCP Authorization Server. + + This server handles OAuth flows and can be used by multiple Resource Servers. + """ + logging.basicConfig(level=logging.INFO) + + # Create server settings + host = "localhost" + server_url = f"http://{host}:{port}" + server_settings = AuthServerSettings( + host=host, + port=port, + server_url=AnyHttpUrl(server_url), + ) + + discord_settings = DiscordOAuthSettings() + + asyncio.run(run_server(server_settings, discord_settings)) + return 0 + + +if __name__ == "__main__": + main() # type: ignore[call-arg] diff --git a/examples/servers/simple-auth-client-credentials/mcp_simple_auth_client_credentials/server.py b/examples/servers/simple-auth-client-credentials/mcp_simple_auth_client_credentials/server.py new file mode 100644 index 000000000..03c160bdc --- /dev/null +++ b/examples/servers/simple-auth-client-credentials/mcp_simple_auth_client_credentials/server.py @@ -0,0 +1,195 @@ +""" +MCP Resource Server. + +Usage: + python -m mcp_simple_auth.server --port=8001 +""" + +import logging +from typing import Any, Literal + +import click +import httpx +from pydantic import AnyHttpUrl +from pydantic_settings import BaseSettings, SettingsConfigDict + +from mcp.server.auth.middleware.auth_context import get_access_token +from mcp.server.auth.settings import AuthSettings +from mcp.server.fastmcp.server import FastMCP + +from .token_verifier import PartialIntrospectionTokenVerifier + +logger = logging.getLogger(__name__) + +API_ENDPOINT = "https://discord.com/api/v10" + + +class ResourceServerSettings(BaseSettings): + """Settings for the MCP Resource Server.""" + + model_config = SettingsConfigDict(env_prefix="MCP_RESOURCE_") + + # Server settings + host: str = "localhost" + port: int = 8001 + server_url: AnyHttpUrl = AnyHttpUrl("http://localhost:8001") + + # Authorization Server settings + auth_server_url: AnyHttpUrl = AnyHttpUrl("http://localhost:9000") + auth_server_introspection_endpoint: str = "http://localhost:9000/discord/oauth2/@me" + auth_server_discord_user_endpoint: str = "http://localhost:9000/discord/users/@me" + + # MCP settings + mcp_scope: str = "identify" + + def __init__(self, **data): + """Initialize settings with values from environment variables.""" + super().__init__(**data) + + +def create_resource_server(settings: ResourceServerSettings) -> FastMCP: + """ + Create MCP Resource Server. + """ + + # Create partial token verifier + token_verifier = PartialIntrospectionTokenVerifier( + introspection_endpoint=settings.auth_server_introspection_endpoint, + server_url=str(settings.server_url), + ) + + # Create FastMCP server as a Resource Server + app = FastMCP( + name="MCP Resource Server", + host=settings.host, + port=settings.port, + debug=True, + token_verifier=token_verifier, + auth=AuthSettings( + issuer_url=settings.auth_server_url, + required_scopes=[settings.mcp_scope], + resource_server_url=settings.server_url, + ), + ) + + async def get_discord_user_data() -> dict[str, Any]: + """ + Get Discord user data via the Discord API. + """ + access_token = get_access_token() + if not access_token: + raise ValueError("Not authenticated") + + async with httpx.AsyncClient() as client: + response = await client.get( + settings.auth_server_discord_user_endpoint, + headers={ + "Authorization": f"Bearer {access_token.token}", + }, + ) + + if response.status_code != 200: + raise ValueError(f"Discord user data fetch failed: {response.status_code} - {response.text}") + + return response.json() + + @app.tool() + async def get_user_profile() -> dict[str, Any]: + """ + Get the authenticated user's Discord profile information. + """ + return await get_discord_user_data() + + @app.tool() + async def get_user_info() -> dict[str, Any]: + """ + Get information about the currently authenticated user. + + Returns token and scope information from the Resource Server's perspective. + """ + access_token = get_access_token() + if not access_token: + raise ValueError("Not authenticated") + + return { + "authenticated": True, + "client_id": access_token.client_id, + "scopes": access_token.scopes, + "token_expires_at": access_token.expires_at, + "token_type": "Bearer", + "resource_server": str(settings.server_url), + "authorization_server": str(settings.auth_server_url), + } + + return app + + +@click.command() +@click.option("--port", default=8001, help="Port to listen on") +@click.option("--auth-server", default="http://localhost:9000", help="Authorization Server URL") +@click.option( + "--transport", + default="streamable-http", + type=click.Choice(["sse", "streamable-http"]), + help="Transport protocol to use ('sse' or 'streamable-http')", +) +def main(port: int, auth_server: str, transport: Literal["sse", "streamable-http"]) -> int: + """ + Run the MCP Resource Server. + """ + logging.basicConfig(level=logging.INFO) + + try: + # Parse auth server URL + auth_server_url = AnyHttpUrl(auth_server) + + # Create settings + host = "localhost" + server_url = f"http://{host}:{port}" + settings = ResourceServerSettings( + host=host, + port=port, + server_url=AnyHttpUrl(server_url), + auth_server_url=auth_server_url, + auth_server_introspection_endpoint=f"{auth_server_url}discord/oauth2/@me", + auth_server_discord_user_endpoint=f"{auth_server_url}discord/users/@me", + ) + except ValueError as e: + logger.error(f"Configuration error: {e}") + logger.error("Make sure to provide a valid Authorization Server URL") + return 1 + + try: + mcp_server = create_resource_server(settings) + + logger.info("=" * 80) + logger.info("šŸ“¦ MCP RESOURCE SERVER") + logger.info("=" * 80) + logger.info(f"🌐 Server URL: {settings.server_url}") + logger.info(f"šŸ”‘ Authorization Server: {settings.auth_server_url}") + logger.info("šŸ“‹ Endpoints:") + logger.info(f" ā”Œā”€ Protected Resource Metadata: {settings.server_url}.well-known/oauth-protected-resource") + mcp_path = "sse" if transport == "sse" else "mcp" + logger.info(f" ā”œā”€ MCP Protocol: {settings.server_url}{mcp_path}") + logger.info(f" └─ Token Introspection: {settings.auth_server_introspection_endpoint}") + logger.info("") + logger.info("šŸ› ļø Available Tools:") + logger.info(" ā”œā”€ get_user_profile() - Get Discord user profile") + logger.info(" └─ get_user_info() - Get authentication status") + logger.info("") + logger.info("šŸ” Tokens validated via Authorization Server introspection") + logger.info("šŸ“± Clients discover Authorization Server via Protected Resource Metadata") + logger.info("=" * 80) + + # Run the server - this should block and keep running + mcp_server.run(transport=transport) + logger.info("Server stopped") + return 0 + except Exception as e: + logger.error(f"Server error: {e}") + logger.exception("Exception details:") + return 1 + + +if __name__ == "__main__": + main() # type: ignore[call-arg] diff --git a/examples/servers/simple-auth-client-credentials/mcp_simple_auth_client_credentials/token_verifier.py b/examples/servers/simple-auth-client-credentials/mcp_simple_auth_client_credentials/token_verifier.py new file mode 100644 index 000000000..9131534e5 --- /dev/null +++ b/examples/servers/simple-auth-client-credentials/mcp_simple_auth_client_credentials/token_verifier.py @@ -0,0 +1,68 @@ +"""Example token verifier implementation.""" + +import logging +from datetime import datetime + +from mcp.server.auth.provider import AccessToken, TokenVerifier +from mcp.shared.auth_utils import resource_url_from_server_url + +logger = logging.getLogger(__name__) + + +class PartialIntrospectionTokenVerifier(TokenVerifier): + """ + Example token verifier. + + Discord doesn't actually support token introspection, but this is required by FastMCP, so + we shim a non-strict verifier on top of it that leverages the "current application" endpoint. + """ + + def __init__( + self, + introspection_endpoint: str, + server_url: str, + ): + self.introspection_endpoint = introspection_endpoint + self.server_url = server_url + self.resource_url = resource_url_from_server_url(server_url) + + async def verify_token(self, token: str) -> AccessToken | None: + """Verify token via introspection endpoint.""" + import httpx + + # Validate URL to prevent SSRF attacks + if not self.introspection_endpoint.startswith(("https://", "http://localhost", "http://127.0.0.1")): + logger.warning(f"Rejecting introspection endpoint with unsafe scheme: {self.introspection_endpoint}") + return None + + # Configure secure HTTP client + timeout = httpx.Timeout(10.0, connect=5.0) + limits = httpx.Limits(max_connections=10, max_keepalive_connections=5) + + async with httpx.AsyncClient( + timeout=timeout, + limits=limits, + verify=True, # Enforce SSL verification + headers={ + "Authorization": f"Bearer {token}", + }, + ) as client: + try: + response = await client.get( + self.introspection_endpoint, + ) + + if response.status_code != 200: + logger.debug(f"Token introspection returned status {response.status_code}") + return None + + data = response.json() + return AccessToken( + token=token, + client_id=data.get("application", {"id": "unknown"}).get("id", "unknown"), + scopes=data.get("scopes", "") if data.get("scopes") else [], + expires_at=int(datetime.fromisoformat(data.get("expires")).timestamp()), + ) + except Exception as e: + logger.warning(f"Token introspection failed: {e}") + return None diff --git a/examples/servers/simple-auth-client-credentials/pyproject.toml b/examples/servers/simple-auth-client-credentials/pyproject.toml new file mode 100644 index 000000000..8def0f7a2 --- /dev/null +++ b/examples/servers/simple-auth-client-credentials/pyproject.toml @@ -0,0 +1,32 @@ +[project] +name = "mcp-simple-auth-client-credentials" +version = "0.1.0" +description = "A simple MCP server demonstrating OAuth authentication with client credentials (2LO)" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." }] +license = { text = "MIT" } +dependencies = [ + "anyio>=4.5", + "click>=8.1.0", + "httpx>=0.27", + "mcp", + "pydantic>=2.0", + "pydantic-settings>=2.5.2", + "sse-starlette>=1.6.1", + "uvicorn>=0.23.1; sys_platform != 'emscripten'", +] + +[project.scripts] +mcp-simple-auth-rs = "mcp_simple_auth_client_credentials.server:main" +mcp-simple-auth-as = "mcp_simple_auth_client_credentials.auth_server:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_simple_auth_client_credentials"] + +[tool.uv] +dev-dependencies = ["pyright>=1.1.391", "pytest>=8.3.4", "ruff>=0.8.5"] diff --git a/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py index 9ae189b84..69610ac5e 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py +++ b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py @@ -73,6 +73,8 @@ async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: async def register_client(self, client_info: OAuthClientInformationFull): """Register a new OAuth client.""" + if not client_info.client_id: + raise ValueError("No client_id provided") self.clients[client_info.client_id] = client_info async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: @@ -209,6 +211,8 @@ async def exchange_authorization_code( """Exchange authorization code for tokens.""" if authorization_code.code not in self.auth_codes: raise ValueError("Invalid authorization code") + if not client.client_id: + raise ValueError("No client_id provided") # Generate MCP access token mcp_token = f"mcp_{secrets.token_hex(32)}" diff --git a/pyproject.toml b/pyproject.toml index 9b617f667..6b20736e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,7 +105,7 @@ target-version = "py310" "tests/server/fastmcp/test_func_metadata.py" = ["E501"] [tool.uv.workspace] -members = ["examples/servers/*"] +members = ["examples/clients/*", "examples/servers/*"] [tool.uv.sources] mcp = { workspace = true } @@ -123,5 +123,5 @@ filterwarnings = [ # This should be fixed on Uvicorn's side. "ignore::DeprecationWarning:websockets", "ignore:websockets.server.WebSocketServerProtocol is deprecated:DeprecationWarning", - "ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel" + "ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel", ] diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index 769e9b4c8..4fee9e4bb 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -87,8 +87,8 @@ class OAuthContext: server_url: str client_metadata: OAuthClientMetadata storage: TokenStorage - redirect_handler: Callable[[str], Awaitable[None]] - callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] + redirect_handler: Callable[[str], Awaitable[None]] | None + callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None timeout: float = 300.0 # Discovered metadata @@ -188,8 +188,8 @@ def __init__( server_url: str, client_metadata: OAuthClientMetadata, storage: TokenStorage, - redirect_handler: Callable[[str], Awaitable[None]], - callback_handler: Callable[[], Awaitable[tuple[str, str | None]]], + redirect_handler: Callable[[str], Awaitable[None]] | None = None, + callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None, timeout: float = 300.0, ): """Initialize OAuth2 authentication.""" @@ -318,8 +318,25 @@ async def _handle_registration_response(self, response: httpx.Response) -> None: except ValidationError as e: raise OAuthRegistrationError(f"Invalid registration response: {e}") - async def _perform_authorization(self) -> tuple[str, str]: + async def _perform_authorization(self) -> httpx.Request: + """Perform the authorization flow.""" + if "client_credentials" in self.context.client_metadata.grant_types: + token_request = await self._exchange_token_client_credentials() + return token_request + else: + auth_code, code_verifier = await self._perform_authorization_code_grant() + token_request = await self._exchange_token_authorization_code(auth_code, code_verifier) + return token_request + + async def _perform_authorization_code_grant(self) -> tuple[str, str]: """Perform the authorization redirect and get auth code.""" + if self.context.client_metadata.redirect_uris is None: + raise OAuthFlowError("No redirect URIs provided for authorization code grant") + if not self.context.redirect_handler: + raise OAuthFlowError("No redirect handler provided for authorization code grant") + if not self.context.callback_handler: + raise OAuthFlowError("No callback handler provided for authorization code grant") + if self.context.oauth_metadata and self.context.oauth_metadata.authorization_endpoint: auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint) else: @@ -364,8 +381,10 @@ async def _perform_authorization(self) -> tuple[str, str]: # Return auth code and code verifier for token exchange return auth_code, pkce_params.code_verifier - async def _exchange_token(self, auth_code: str, code_verifier: str) -> httpx.Request: - """Build token exchange request.""" + async def _exchange_token_authorization_code(self, auth_code: str, code_verifier: str) -> httpx.Request: + """Build token exchange request for authorization_code flow.""" + if self.context.client_metadata.redirect_uris is None: + raise OAuthFlowError("No redirect URIs provided for authorization code grant") if not self.context.client_info: raise OAuthFlowError("Missing client info") @@ -394,10 +413,50 @@ async def _exchange_token(self, auth_code: str, code_verifier: str) -> httpx.Req "POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"} ) + async def _exchange_token_client_credentials(self) -> httpx.Request: + """Build token exchange request for client_credentials flow.""" + if not self.context.client_info: + raise OAuthFlowError("Missing client info") + + if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint: + token_url = str(self.context.oauth_metadata.token_endpoint) + else: + auth_base_url = self.context.get_authorization_base_url(self.context.server_url) + token_url = urljoin(auth_base_url, "/token") + + token_data = { + "grant_type": "client_credentials", + "resource": self.context.get_resource_url(), # RFC 8707 + } + + headers = {"Content-Type": "application/x-www-form-urlencoded"} + + if self.context.client_metadata.scope: + token_data["scope"] = self.context.client_metadata.scope + + if self.context.client_metadata.token_endpoint_auth_method == "client_secret_post": + # Include in request body + if self.context.client_info.client_id: + token_data["client_id"] = self.context.client_info.client_id + if self.context.client_info.client_secret: + token_data["client_secret"] = self.context.client_info.client_secret + elif self.context.client_metadata.token_endpoint_auth_method == "client_secret_basic": + # Include as Basic auth header + if not self.context.client_info.client_id: + raise OAuthTokenError("Missing client_id in Basic auth flow") + if not self.context.client_info.client_secret: + raise OAuthTokenError("Missing client_secret in Basic auth flow") + raw_auth = f"{self.context.client_info.client_id}:{self.context.client_info.client_secret}" + headers["Authorization"] = f"Basic {base64.b64encode(raw_auth.encode()).decode()}" + + return httpx.Request("POST", token_url, data=token_data, headers=headers) + async def _handle_token_response(self, response: httpx.Response) -> None: """Handle token exchange response.""" if response.status_code != 200: - raise OAuthTokenError(f"Token exchange failed: {response.status_code}") + body = await response.aread() + body = body.decode("utf-8") + raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body}") try: content = await response.aread() @@ -515,12 +574,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. registration_response = yield registration_request await self._handle_registration_response(registration_response) - # Step 4: Perform authorization - auth_code, code_verifier = await self._perform_authorization() - - # Step 5: Exchange authorization code for tokens - token_request = await self._exchange_token(auth_code, code_verifier) - token_response = yield token_request + # Step 4: Perform authorization and complete token exchange + token_response = yield await self._perform_authorization() await self._handle_token_response(token_response) except Exception as e: logger.error(f"OAuth flow error: {e}") @@ -567,12 +622,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. registration_response = yield registration_request await self._handle_registration_response(registration_response) - # Step 4: Perform authorization - auth_code, code_verifier = await self._perform_authorization() - - # Step 5: Exchange authorization code for tokens - token_request = await self._exchange_token(auth_code, code_verifier) - token_response = yield token_request + # Step 4: Perform authorization and complete token exchange + token_response = yield await self._perform_authorization() await self._handle_token_response(token_response) # Retry with new tokens diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index 1f2d1659a..97efb019f 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -41,13 +41,11 @@ class OAuthClientMetadata(BaseModel): for the full specification. """ - redirect_uris: list[AnyUrl] = Field(..., min_length=1) - # token_endpoint_auth_method: this implementation only supports none & - # client_secret_post; - # ie: we do not support client_secret_basic - token_endpoint_auth_method: Literal["none", "client_secret_post"] = "client_secret_post" - # grant_types: this implementation only supports authorization_code & refresh_token - grant_types: list[Literal["authorization_code", "refresh_token"]] = [ + redirect_uris: list[AnyUrl] | None = Field(..., min_length=1) + # supported auth methods for the token endpoint + token_endpoint_auth_method: Literal["none", "client_secret_basic", "client_secret_post"] = "client_secret_post" + # supported grant_types of this implementation + grant_types: list[Literal["authorization_code", "client_credentials", "refresh_token"]] = [ "authorization_code", "refresh_token", ] @@ -81,10 +79,10 @@ def validate_scope(self, requested_scope: str | None) -> list[str] | None: def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl: if redirect_uri is not None: # Validate redirect_uri against client's registered redirect URIs - if redirect_uri not in self.redirect_uris: + if self.redirect_uris is None or redirect_uri not in self.redirect_uris: raise InvalidRedirectUriError(f"Redirect URI '{redirect_uri}' not registered for client") return redirect_uri - elif len(self.redirect_uris) == 1: + elif self.redirect_uris is not None and len(self.redirect_uris) == 1: return self.redirect_uris[0] else: raise InvalidRedirectUriError("redirect_uri must be specified when client " "has multiple registered URIs") @@ -96,7 +94,7 @@ class OAuthClientInformationFull(OAuthClientMetadata): (client information plus metadata). """ - client_id: str + client_id: str | None = None client_secret: str | None = None client_id_issued_at: int | None = None client_secret_expires_at: int | None = None diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 8e6b4f54d..fe30b08cc 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -2,7 +2,10 @@ Tests for refactored OAuth client authentication implementation. """ +import base64 import time +import urllib +import urllib.parse import httpx import pytest @@ -387,7 +390,7 @@ async def test_register_client_skip_if_registered(self, oauth_provider, mock_sto assert request is None @pytest.mark.anyio - async def test_token_exchange_request(self, oauth_provider): + async def test_token_exchange_request_authorization_code(self, oauth_provider): """Test token exchange request building.""" # Set up required context oauth_provider.context.client_info = OAuthClientInformationFull( @@ -396,7 +399,7 @@ async def test_token_exchange_request(self, oauth_provider): redirect_uris=[AnyUrl("http://localhost:3030/callback")], ) - request = await oauth_provider._exchange_token("test_auth_code", "test_verifier") + request = await oauth_provider._exchange_token_authorization_code("test_auth_code", "test_verifier") assert request.method == "POST" assert str(request.url) == "https://api.example.com/token" @@ -410,6 +413,65 @@ async def test_token_exchange_request(self, oauth_provider): assert "client_id=test_client" in content assert "client_secret=test_secret" in content + @pytest.mark.anyio + async def test_token_exchange_request_client_credentials_basic(self, oauth_provider): + """Test token exchange request building.""" + # Set up required context + oauth_provider.context.client_info = oauth_provider.context.client_metadata = OAuthClientInformationFull( + grant_types=["client_credentials"], + token_endpoint_auth_method="client_secret_basic", + client_id="test_client", + client_secret="test_secret", + redirect_uris=None, + scope="read write", + ) + + request = await oauth_provider._exchange_token_client_credentials() + + assert request.method == "POST" + assert str(request.url) == "https://api.example.com/token" + assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" + + # Check form data + content = urllib.parse.unquote_plus(request.content.decode()) + assert "grant_type=client_credentials" in content + assert "scope=read write" in content + assert "resource=https://api.example.com/v1/mcp" in content + assert "client_id=test_client" not in content + assert "client_secret=test_secret" not in content + + # Check auth header + assert "Authorization" in request.headers + assert request.headers["Authorization"].startswith("Basic ") + assert base64.b64decode(request.headers["Authorization"].split(" ")[1]).decode() == "test_client:test_secret" + + @pytest.mark.anyio + async def test_token_exchange_request_client_credentials_post(self, oauth_provider): + """Test token exchange request building.""" + # Set up required context + oauth_provider.context.client_info = oauth_provider.context.client_metadata = OAuthClientInformationFull( + grant_types=["client_credentials"], + token_endpoint_auth_method="client_secret_post", + client_id="test_client", + client_secret="test_secret", + redirect_uris=None, + scope="read write", + ) + + request = await oauth_provider._exchange_token_client_credentials() + + assert request.method == "POST" + assert str(request.url) == "https://api.example.com/token" + assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" + + # Check form data + content = urllib.parse.unquote_plus(request.content.decode()) + assert "grant_type=client_credentials" in content + assert "scope=read write" in content + assert "resource=https://api.example.com/v1/mcp" in content + assert "client_id=test_client" in content + assert "client_secret=test_secret" in content + @pytest.mark.anyio async def test_refresh_token_request(self, oauth_provider, valid_tokens): """Test refresh token request building.""" @@ -450,7 +512,7 @@ async def test_resource_param_included_with_recent_protocol_version(self, oauth_ ) # Test in token exchange - request = await oauth_provider._exchange_token("test_code", "test_verifier") + request = await oauth_provider._exchange_token_authorization_code("test_code", "test_verifier") content = request.content.decode() assert "resource=" in content # Check URL-encoded resource parameter @@ -481,7 +543,7 @@ async def test_resource_param_excluded_with_old_protocol_version(self, oauth_pro ) # Test in token exchange - request = await oauth_provider._exchange_token("test_code", "test_verifier") + request = await oauth_provider._exchange_token_authorization_code("test_code", "test_verifier") content = request.content.decode() assert "resource=" not in content @@ -511,7 +573,7 @@ async def test_resource_param_included_with_protected_resource_metadata(self, oa ) # Test in token exchange - request = await oauth_provider._exchange_token("test_code", "test_verifier") + request = await oauth_provider._exchange_token_authorization_code("test_code", "test_verifier") content = request.content.decode() assert "resource=" in content diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 5db5d58c2..3f902db92 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -50,6 +50,7 @@ async def register_client(self, client_info: OAuthClientInformationFull): async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: # toy authorize implementation which just immediately generates an authorization # code and completes the redirect + assert client.client_id is not None code = AuthorizationCode( code=f"code_{int(time.time())}", client_id=client.client_id, @@ -78,6 +79,7 @@ async def exchange_authorization_code( refresh_token = f"refresh_{secrets.token_hex(32)}" # Store the tokens + assert client.client_id is not None self.tokens[access_token] = AccessToken( token=access_token, client_id=client.client_id, @@ -139,6 +141,7 @@ async def exchange_refresh_token( new_refresh_token = f"refresh_{secrets.token_hex(32)}" # Store the new tokens + assert client.client_id is not None self.tokens[new_access_token] = AccessToken( token=new_access_token, client_id=client.client_id, diff --git a/uv.lock b/uv.lock index cfcc8238e..45d3799b6 100644 --- a/uv.lock +++ b/uv.lock @@ -9,6 +9,10 @@ resolution-mode = "lowest-direct" members = [ "mcp", "mcp-simple-auth", + "mcp-simple-auth-client", + "mcp-simple-auth-client-client-credentials", + "mcp-simple-auth-client-credentials", + "mcp-simple-chatbot", "mcp-simple-prompt", "mcp-simple-resource", "mcp-simple-streamablehttp", @@ -678,6 +682,138 @@ dev = [ { name = "ruff", specifier = ">=0.8.5" }, ] +[[package]] +name = "mcp-simple-auth-client" +version = "0.1.0" +source = { editable = "examples/clients/simple-auth-client" } +dependencies = [ + { name = "click" }, + { name = "mcp" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "pytest" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "click", specifier = ">=8.0.0" }, + { name = "mcp", editable = "." }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.379" }, + { name = "pytest", specifier = ">=8.3.3" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + +[[package]] +name = "mcp-simple-auth-client-client-credentials" +version = "0.1.0" +source = { editable = "examples/clients/simple-auth-client-client-credentials" } +dependencies = [ + { name = "click" }, + { name = "mcp" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "pytest" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "click", specifier = ">=8.0.0" }, + { name = "mcp", editable = "." }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.379" }, + { name = "pytest", specifier = ">=8.3.3" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + +[[package]] +name = "mcp-simple-auth-client-credentials" +version = "0.1.0" +source = { editable = "examples/servers/simple-auth-client-credentials" } +dependencies = [ + { name = "anyio" }, + { name = "click" }, + { name = "httpx" }, + { name = "mcp" }, + { name = "pydantic" }, + { name = "pydantic-settings" }, + { name = "sse-starlette" }, + { name = "uvicorn", marker = "sys_platform != 'emscripten'" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "pytest" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "anyio", specifier = ">=4.5" }, + { name = "click", specifier = ">=8.1.0" }, + { name = "httpx", specifier = ">=0.27" }, + { name = "mcp", editable = "." }, + { name = "pydantic", specifier = ">=2.0" }, + { name = "pydantic-settings", specifier = ">=2.5.2" }, + { name = "sse-starlette", specifier = ">=1.6.1" }, + { name = "uvicorn", marker = "sys_platform != 'emscripten'", specifier = ">=0.23.1" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.391" }, + { name = "pytest", specifier = ">=8.3.4" }, + { name = "ruff", specifier = ">=0.8.5" }, +] + +[[package]] +name = "mcp-simple-chatbot" +version = "0.1.0" +source = { editable = "examples/clients/simple-chatbot" } +dependencies = [ + { name = "mcp" }, + { name = "python-dotenv" }, + { name = "requests" }, + { name = "uvicorn" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "pytest" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "mcp", editable = "." }, + { name = "python-dotenv", specifier = ">=1.0.0" }, + { name = "requests", specifier = ">=2.31.0" }, + { name = "uvicorn", specifier = ">=0.32.1" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.379" }, + { name = "pytest", specifier = ">=8.3.3" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + [[package]] name = "mcp-simple-prompt" version = "0.1.0" @@ -1858,16 +1994,16 @@ wheels = [ [[package]] name = "uvicorn" -version = "0.30.0" +version = "0.32.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, { name = "h11" }, { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d3/f7/4ad826703a49b320a4adf2470fdd2a3481ea13f4460cb615ad12c75be003/uvicorn-0.30.0.tar.gz", hash = "sha256:f678dec4fa3a39706bbf49b9ec5fc40049d42418716cea52b53f07828a60aa37", size = 42560, upload-time = "2024-05-28T07:20:42.231Z" } +sdist = { url = "https://files.pythonhosted.org/packages/6a/3c/21dba3e7d76138725ef307e3d7ddd29b763119b3aa459d02cc05fefcff75/uvicorn-0.32.1.tar.gz", hash = "sha256:ee9519c246a72b1c084cea8d3b44ed6026e78a4a309cbedae9c37e4cb9fbb175", size = 77630, upload-time = "2024-11-20T19:41:13.341Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/a1/d57e38417a8dabb22df02b6aebc209dc73485792e6c5620e501547133d0b/uvicorn-0.30.0-py3-none-any.whl", hash = "sha256:78fa0b5f56abb8562024a59041caeb555c86e48d0efdd23c3fe7de7a4075bdab", size = 62388, upload-time = "2024-05-28T07:20:38.256Z" }, + { url = "https://files.pythonhosted.org/packages/50/c1/2d27b0a15826c2b71dcf6e2f5402181ef85acf439617bb2f1453125ce1f3/uvicorn-0.32.1-py3-none-any.whl", hash = "sha256:82ad92fd58da0d12af7482ecdb5f2470a04c9c9a53ced65b9bbb4a205377602e", size = 63828, upload-time = "2024-11-20T19:41:11.244Z" }, ] [[package]]