diff --git a/.github/workflows/shared.yml b/.github/workflows/shared.yml index 0e6d18364..1f211bd19 100644 --- a/.github/workflows/shared.yml +++ b/.github/workflows/shared.yml @@ -3,6 +3,9 @@ name: Shared Checks on: workflow_call: +permissions: + contents: read + jobs: pre-commit: runs-on: ubuntu-latest @@ -46,3 +49,19 @@ jobs: - name: Run pytest run: uv run --frozen --no-sync pytest continue-on-error: true + + readme-snippets: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + version: 0.7.2 + + - name: Install dependencies + run: uv sync --frozen --all-extras --python 3.10 + + - name: Check README snippets are up to date + run: uv run --frozen scripts/update_readme_snippets.py --check diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 35e12261a..6eb8fc4ff 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,6 +23,7 @@ repos: types: [python] language: system pass_filenames: false + exclude: ^README\.md$ - id: pyright name: pyright entry: uv run pyright @@ -36,3 +37,9 @@ repos: language: system files: ^(pyproject\.toml|uv\.lock)$ pass_filenames: false + - id: readme-snippets + name: Check README snippets are up to date + entry: uv run scripts/update_readme_snippets.py --check + language: system + files: ^(README\.md|examples/.*\.py|scripts/update_readme_snippets\.py)$ + pass_filenames: false diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 929e5f504..a263678a2 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -40,7 +40,12 @@ uv run ruff check . uv run ruff format . ``` -7. Submit a pull request to the same branch you branched from +7. Update README snippets if you modified example code: +```bash +uv run scripts/update_readme_snippets.py +``` + +8. Submit a pull request to the same branch you branched from ## Code Style diff --git a/README.md b/README.md index 412143a9f..c59bad11c 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,8 @@ - [Context](#context) - [Completions](#completions) - [Elicitation](#elicitation) + - [Sampling](#sampling) + - [Logging and Notifications](#logging-and-notifications) - [Authentication](#authentication) - [Running Your Server](#running-your-server) - [Development Mode](#development-mode) @@ -119,7 +121,7 @@ mcp = FastMCP("Demo") # Add an addition tool @mcp.tool() -def add(a: int, b: int) -> int: +def sum(a: int, b: int) -> int: """Add two numbers""" return a + b @@ -207,48 +209,57 @@ def query_db() -> str: Resources are how you expose data to LLMs. They're similar to GET endpoints in a REST API - they provide data but shouldn't perform significant computation or have side effects: + ```python from mcp.server.fastmcp import FastMCP -mcp = FastMCP("My App") +mcp = FastMCP(name="Resource Example") -@mcp.resource("config://app", title="Application Configuration") -def get_config() -> str: - """Static configuration data""" - return "App configuration here" +@mcp.resource("file://documents/{name}") +def read_document(name: str) -> str: + """Read a document by name.""" + # This would normally read from disk + return f"Content of {name}" -@mcp.resource("users://{user_id}/profile", title="User Profile") -def get_user_profile(user_id: str) -> str: - """Dynamic user data""" - return f"Profile data for user {user_id}" +@mcp.resource("config://settings") +def get_settings() -> str: + """Get application settings.""" + return """{ + "theme": "dark", + "language": "en", + "debug": false +}""" ``` +_Full example: [examples/snippets/servers/basic_resource.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/basic_resource.py)_ + ### Tools Tools let LLMs take actions through your server. Unlike resources, tools are expected to perform computation and have side effects: + ```python -import httpx from mcp.server.fastmcp import FastMCP -mcp = FastMCP("My App") +mcp = FastMCP(name="Tool Example") -@mcp.tool(title="BMI Calculator") -def calculate_bmi(weight_kg: float, height_m: float) -> float: - """Calculate BMI given weight in kg and height in meters""" - return weight_kg / (height_m**2) +@mcp.tool() +def sum(a: int, b: int) -> int: + """Add two numbers together.""" + return a + b -@mcp.tool(title="Weather Fetcher") -async def fetch_weather(city: str) -> str: - """Fetch current weather for a city""" - async with httpx.AsyncClient() as client: - response = await client.get(f"https://api.weather.com/{city}") - return response.text +@mcp.tool() +def get_weather(city: str, unit: str = "celsius") -> str: + """Get weather for a city.""" + # This would normally call a weather API + return f"Weather in {city}: 22degrees{unit[0].upper()}" ``` +_Full example: [examples/snippets/servers/basic_tool.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/basic_tool.py)_ + #### Structured Output @@ -375,11 +386,12 @@ def get_temperature(city: str) -> float: Prompts are reusable templates that help LLMs interact with your server effectively: + ```python from mcp.server.fastmcp import FastMCP from mcp.server.fastmcp.prompts import base -mcp = FastMCP("My App") +mcp = FastMCP(name="Prompt Example") @mcp.prompt(title="Code Review") @@ -395,6 +407,8 @@ def debug_error(error: str) -> list[base.Message]: base.AssistantMessage("I'll help debug that. What have you tried so far?"), ] ``` +_Full example: [examples/snippets/servers/basic_prompt.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/basic_prompt.py)_ + ### Images @@ -419,21 +433,31 @@ def create_thumbnail(image_path: str) -> Image: The Context object gives your tools and resources access to MCP capabilities: + ```python -from mcp.server.fastmcp import FastMCP, Context +from mcp.server.fastmcp import Context, FastMCP -mcp = FastMCP("My App") +mcp = FastMCP(name="Progress Example") @mcp.tool() -async def long_task(files: list[str], ctx: Context) -> str: - """Process multiple files with progress tracking""" - for i, file in enumerate(files): - ctx.info(f"Processing {file}") - await ctx.report_progress(i, len(files)) - data, mime_type = await ctx.read_resource(f"file://{file}") - return "Processing complete" +async def long_running_task(task_name: str, ctx: Context, steps: int = 5) -> str: + """Execute a task with progress updates.""" + await ctx.info(f"Starting: {task_name}") + + for i in range(steps): + progress = (i + 1) / steps + await ctx.report_progress( + progress=progress, + total=1.0, + message=f"Step {i + 1}/{steps}", + ) + await ctx.debug(f"Completed step {i + 1}") + + return f"Task '{task_name}' completed" ``` +_Full example: [examples/snippets/servers/tool_progress.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/tool_progress.py)_ + ### Completions @@ -465,8 +489,10 @@ async def use_completion(session: ClientSession): ``` Server implementation: + + ```python -from mcp.server import Server +from mcp.server.fastmcp import FastMCP from mcp.types import ( Completion, CompletionArgument, @@ -475,72 +501,167 @@ from mcp.types import ( ResourceTemplateReference, ) -server = Server("example-server") +mcp = FastMCP(name="Example") + +@mcp.resource("github://repos/{owner}/{repo}") +def github_repo(owner: str, repo: str) -> str: + """GitHub repository resource.""" + return f"Repository: {owner}/{repo}" -@server.completion() + +@mcp.prompt(description="Code review prompt") +def review_code(language: str, code: str) -> str: + """Generate a code review.""" + return f"Review this {language} code:\n{code}" + + +@mcp.completion() async def handle_completion( ref: PromptReference | ResourceTemplateReference, argument: CompletionArgument, context: CompletionContext | None, ) -> Completion | None: + """Provide completions for prompts and resources.""" + + # Complete programming languages for the prompt + if isinstance(ref, PromptReference): + if ref.name == "review_code" and argument.name == "language": + languages = ["python", "javascript", "typescript", "go", "rust"] + return Completion( + values=[lang for lang in languages if lang.startswith(argument.value)], + hasMore=False, + ) + + # Complete repository names for GitHub resources if isinstance(ref, ResourceTemplateReference): if ref.uri == "github://repos/{owner}/{repo}" and argument.name == "repo": - # Use context to provide owner-specific repos - if context and context.arguments: - owner = context.arguments.get("owner") - if owner == "modelcontextprotocol": - repos = ["python-sdk", "typescript-sdk", "specification"] - # Filter based on partial input - filtered = [r for r in repos if r.startswith(argument.value)] - return Completion(values=filtered) + if context and context.arguments and context.arguments.get("owner") == "modelcontextprotocol": + repos = ["python-sdk", "typescript-sdk", "specification"] + return Completion(values=repos, hasMore=False) + return None ``` +_Full example: [examples/snippets/servers/completion.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/completion.py)_ + ### Elicitation Request additional information from users during tool execution: + ```python -from mcp.server.fastmcp import FastMCP, Context -from mcp.server.elicitation import ( - AcceptedElicitation, - DeclinedElicitation, - CancelledElicitation, -) from pydantic import BaseModel, Field -mcp = FastMCP("Booking System") +from mcp.server.fastmcp import Context, FastMCP +mcp = FastMCP(name="Elicitation Example") -@mcp.tool() -async def book_table(date: str, party_size: int, ctx: Context) -> str: - """Book a table with confirmation""" - # Schema must only contain primitive types (str, int, float, bool) - class ConfirmBooking(BaseModel): - confirm: bool = Field(description="Confirm booking?") - notes: str = Field(default="", description="Special requests") +class BookingPreferences(BaseModel): + """Schema for collecting user preferences.""" - result = await ctx.elicit( - message=f"Confirm booking for {party_size} on {date}?", schema=ConfirmBooking + checkAlternative: bool = Field(description="Would you like to check another date?") + alternativeDate: str = Field( + default="2024-12-26", + description="Alternative date (YYYY-MM-DD)", ) - match result: - case AcceptedElicitation(data=data): - if data.confirm: - return f"Booked! Notes: {data.notes or 'None'}" - return "Booking cancelled" - case DeclinedElicitation(): - return "Booking declined" - case CancelledElicitation(): - return "Booking cancelled" + +@mcp.tool() +async def book_table( + date: str, + time: str, + party_size: int, + ctx: Context, +) -> str: + """Book a table with date availability check.""" + # Check if date is available + if date == "2024-12-25": + # Date unavailable - ask user for alternative + result = await ctx.elicit( + message=(f"No tables available for {party_size} on {date}. Would you like to try another date?"), + schema=BookingPreferences, + ) + + if result.action == "accept" and result.data: + if result.data.checkAlternative: + return f"[SUCCESS] Booked for {result.data.alternativeDate}" + return "[CANCELLED] No booking made" + return "[CANCELLED] Booking cancelled" + + # Date available + return f"[SUCCESS] Booked for {date} at {time}" ``` +_Full example: [examples/snippets/servers/elicitation.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/elicitation.py)_ + The `elicit()` method returns an `ElicitationResult` with: - `action`: "accept", "decline", or "cancel" - `data`: The validated response (only when accepted) - `validation_error`: Any validation error message +### Sampling + +Tools can interact with LLMs through sampling (generating text): + + +```python +from mcp.server.fastmcp import Context, FastMCP +from mcp.types import SamplingMessage, TextContent + +mcp = FastMCP(name="Sampling Example") + + +@mcp.tool() +async def generate_poem(topic: str, ctx: Context) -> str: + """Generate a poem using LLM sampling.""" + prompt = f"Write a short poem about {topic}" + + result = await ctx.session.create_message( + messages=[ + SamplingMessage( + role="user", + content=TextContent(type="text", text=prompt), + ) + ], + max_tokens=100, + ) + + if result.content.type == "text": + return result.content.text + return str(result.content) +``` +_Full example: [examples/snippets/servers/sampling.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/sampling.py)_ + + +### Logging and Notifications + +Tools can send logs and notifications through the context: + + +```python +from mcp.server.fastmcp import Context, FastMCP + +mcp = FastMCP(name="Notifications Example") + + +@mcp.tool() +async def process_data(data: str, ctx: Context) -> str: + """Process data with logging.""" + # Different log levels + await ctx.debug(f"Debug: Processing '{data}'") + await ctx.info("Info: Starting processing") + await ctx.warning("Warning: This is experimental") + await ctx.error("Error: (This is just a demo)") + + # Notify about resource changes + await ctx.session.send_resource_list_changed() + + return f"Processed: {data}" +``` +_Full example: [examples/snippets/servers/notifications.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/notifications.py)_ + + ### Authentication Authentication can be used by servers that want to expose tools accessing protected resources. @@ -664,8 +785,9 @@ from mcp.server.fastmcp import FastMCP mcp = FastMCP(name="EchoServer", stateless_http=True) -@mcp.tool(description="A simple echo tool") +@mcp.tool() def echo(message: str) -> str: + """A simple echo tool""" return f"Echo: {message}" ``` @@ -676,8 +798,9 @@ from mcp.server.fastmcp import FastMCP mcp = FastMCP(name="MathServer", stateless_http=True) -@mcp.tool(description="A simple add tool") +@mcp.tool() def add_two(n: int) -> int: + """Tool to add two to the input""" return n + 2 ``` diff --git a/examples/fastmcp/desktop.py b/examples/fastmcp/desktop.py index 8fd71b263..add7f515b 100644 --- a/examples/fastmcp/desktop.py +++ b/examples/fastmcp/desktop.py @@ -20,6 +20,6 @@ def desktop() -> list[str]: @mcp.tool() -def add(a: int, b: int) -> int: +def sum(a: int, b: int) -> int: """Add two numbers""" return a + b diff --git a/examples/fastmcp/readme-quickstart.py b/examples/fastmcp/readme-quickstart.py index d1c522a81..e1abf7c51 100644 --- a/examples/fastmcp/readme-quickstart.py +++ b/examples/fastmcp/readme-quickstart.py @@ -6,7 +6,7 @@ # Add an addition tool @mcp.tool() -def add(a: int, b: int) -> int: +def sum(a: int, b: int) -> int: """Add two numbers""" return a + b diff --git a/examples/fastmcp/weather_structured.py b/examples/fastmcp/weather_structured.py index 8c26fc39e..20cbf7957 100644 --- a/examples/fastmcp/weather_structured.py +++ b/examples/fastmcp/weather_structured.py @@ -90,7 +90,7 @@ def get_weather_alerts(region: str) -> list[WeatherAlert]: WeatherAlert( severity="high", title="Heat Wave Warning", - description="Temperatures expected to exceed 40°C", + description="Temperatures expected to exceed 40 degrees", affected_areas=["Los Angeles", "San Diego", "Riverside"], valid_until=datetime(2024, 7, 15, 18, 0), ), diff --git a/examples/snippets/pyproject.toml b/examples/snippets/pyproject.toml new file mode 100644 index 000000000..832f6495b --- /dev/null +++ b/examples/snippets/pyproject.toml @@ -0,0 +1,18 @@ +[project] +name = "mcp-snippets" +version = "0.1.0" +description = "MCP Example Snippets" +requires-python = ">=3.10" +dependencies = [ + "mcp", +] + +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +packages = ["servers"] + +[project.scripts] +server = "servers:run_server" \ No newline at end of file diff --git a/examples/snippets/servers/__init__.py b/examples/snippets/servers/__init__.py new file mode 100644 index 000000000..e3b778420 --- /dev/null +++ b/examples/snippets/servers/__init__.py @@ -0,0 +1,36 @@ +"""MCP Snippets. + +This package contains simple examples of MCP server features. +Each server demonstrates a single feature and can be run as a standalone server. + +To run a server, use the command: + uv run server basic_tool sse +""" + +import importlib +import sys +from typing import Literal, cast + + +def run_server(): + """Run a server by name with optional transport. + + Usage: server [transport] + Example: server basic_tool sse + """ + if len(sys.argv) < 2: + print("Usage: server [transport]") + print("Available servers: basic_tool, basic_resource, basic_prompt, tool_progress,") + print(" sampling, elicitation, completion, notifications") + print("Available transports: stdio (default), sse, streamable-http") + sys.exit(1) + + server_name = sys.argv[1] + transport = sys.argv[2] if len(sys.argv) > 2 else "stdio" + + try: + module = importlib.import_module(f".{server_name}", package=__name__) + module.mcp.run(cast(Literal["stdio", "sse", "streamable-http"], transport)) + except ImportError: + print(f"Error: Server '{server_name}' not found") + sys.exit(1) diff --git a/examples/snippets/servers/basic_prompt.py b/examples/snippets/servers/basic_prompt.py new file mode 100644 index 000000000..40f606ba6 --- /dev/null +++ b/examples/snippets/servers/basic_prompt.py @@ -0,0 +1,18 @@ +from mcp.server.fastmcp import FastMCP +from mcp.server.fastmcp.prompts import base + +mcp = FastMCP(name="Prompt Example") + + +@mcp.prompt(title="Code Review") +def review_code(code: str) -> str: + return f"Please review this code:\n\n{code}" + + +@mcp.prompt(title="Debug Assistant") +def debug_error(error: str) -> list[base.Message]: + return [ + base.UserMessage("I'm seeing this error:"), + base.UserMessage(error), + base.AssistantMessage("I'll help debug that. What have you tried so far?"), + ] diff --git a/examples/snippets/servers/basic_resource.py b/examples/snippets/servers/basic_resource.py new file mode 100644 index 000000000..5c1973059 --- /dev/null +++ b/examples/snippets/servers/basic_resource.py @@ -0,0 +1,20 @@ +from mcp.server.fastmcp import FastMCP + +mcp = FastMCP(name="Resource Example") + + +@mcp.resource("file://documents/{name}") +def read_document(name: str) -> str: + """Read a document by name.""" + # This would normally read from disk + return f"Content of {name}" + + +@mcp.resource("config://settings") +def get_settings() -> str: + """Get application settings.""" + return """{ + "theme": "dark", + "language": "en", + "debug": false +}""" diff --git a/examples/snippets/servers/basic_tool.py b/examples/snippets/servers/basic_tool.py new file mode 100644 index 000000000..550e24080 --- /dev/null +++ b/examples/snippets/servers/basic_tool.py @@ -0,0 +1,16 @@ +from mcp.server.fastmcp import FastMCP + +mcp = FastMCP(name="Tool Example") + + +@mcp.tool() +def sum(a: int, b: int) -> int: + """Add two numbers together.""" + return a + b + + +@mcp.tool() +def get_weather(city: str, unit: str = "celsius") -> str: + """Get weather for a city.""" + # This would normally call a weather API + return f"Weather in {city}: 22degrees{unit[0].upper()}" diff --git a/examples/snippets/servers/completion.py b/examples/snippets/servers/completion.py new file mode 100644 index 000000000..2a31541dd --- /dev/null +++ b/examples/snippets/servers/completion.py @@ -0,0 +1,49 @@ +from mcp.server.fastmcp import FastMCP +from mcp.types import ( + Completion, + CompletionArgument, + CompletionContext, + PromptReference, + ResourceTemplateReference, +) + +mcp = FastMCP(name="Example") + + +@mcp.resource("github://repos/{owner}/{repo}") +def github_repo(owner: str, repo: str) -> str: + """GitHub repository resource.""" + return f"Repository: {owner}/{repo}" + + +@mcp.prompt(description="Code review prompt") +def review_code(language: str, code: str) -> str: + """Generate a code review.""" + return f"Review this {language} code:\n{code}" + + +@mcp.completion() +async def handle_completion( + ref: PromptReference | ResourceTemplateReference, + argument: CompletionArgument, + context: CompletionContext | None, +) -> Completion | None: + """Provide completions for prompts and resources.""" + + # Complete programming languages for the prompt + if isinstance(ref, PromptReference): + if ref.name == "review_code" and argument.name == "language": + languages = ["python", "javascript", "typescript", "go", "rust"] + return Completion( + values=[lang for lang in languages if lang.startswith(argument.value)], + hasMore=False, + ) + + # Complete repository names for GitHub resources + if isinstance(ref, ResourceTemplateReference): + if ref.uri == "github://repos/{owner}/{repo}" and argument.name == "repo": + if context and context.arguments and context.arguments.get("owner") == "modelcontextprotocol": + repos = ["python-sdk", "typescript-sdk", "specification"] + return Completion(values=repos, hasMore=False) + + return None diff --git a/examples/snippets/servers/elicitation.py b/examples/snippets/servers/elicitation.py new file mode 100644 index 000000000..6d150cd6c --- /dev/null +++ b/examples/snippets/servers/elicitation.py @@ -0,0 +1,41 @@ +from pydantic import BaseModel, Field + +from mcp.server.fastmcp import Context, FastMCP + +mcp = FastMCP(name="Elicitation Example") + + +class BookingPreferences(BaseModel): + """Schema for collecting user preferences.""" + + checkAlternative: bool = Field(description="Would you like to check another date?") + alternativeDate: str = Field( + default="2024-12-26", + description="Alternative date (YYYY-MM-DD)", + ) + + +@mcp.tool() +async def book_table( + date: str, + time: str, + party_size: int, + ctx: Context, +) -> str: + """Book a table with date availability check.""" + # Check if date is available + if date == "2024-12-25": + # Date unavailable - ask user for alternative + result = await ctx.elicit( + message=(f"No tables available for {party_size} on {date}. Would you like to try another date?"), + schema=BookingPreferences, + ) + + if result.action == "accept" and result.data: + if result.data.checkAlternative: + return f"[SUCCESS] Booked for {result.data.alternativeDate}" + return "[CANCELLED] No booking made" + return "[CANCELLED] Booking cancelled" + + # Date available + return f"[SUCCESS] Booked for {date} at {time}" diff --git a/examples/snippets/servers/notifications.py b/examples/snippets/servers/notifications.py new file mode 100644 index 000000000..96f0bc141 --- /dev/null +++ b/examples/snippets/servers/notifications.py @@ -0,0 +1,18 @@ +from mcp.server.fastmcp import Context, FastMCP + +mcp = FastMCP(name="Notifications Example") + + +@mcp.tool() +async def process_data(data: str, ctx: Context) -> str: + """Process data with logging.""" + # Different log levels + await ctx.debug(f"Debug: Processing '{data}'") + await ctx.info("Info: Starting processing") + await ctx.warning("Warning: This is experimental") + await ctx.error("Error: (This is just a demo)") + + # Notify about resource changes + await ctx.session.send_resource_list_changed() + + return f"Processed: {data}" diff --git a/examples/snippets/servers/sampling.py b/examples/snippets/servers/sampling.py new file mode 100644 index 000000000..230b15fcf --- /dev/null +++ b/examples/snippets/servers/sampling.py @@ -0,0 +1,24 @@ +from mcp.server.fastmcp import Context, FastMCP +from mcp.types import SamplingMessage, TextContent + +mcp = FastMCP(name="Sampling Example") + + +@mcp.tool() +async def generate_poem(topic: str, ctx: Context) -> str: + """Generate a poem using LLM sampling.""" + prompt = f"Write a short poem about {topic}" + + result = await ctx.session.create_message( + messages=[ + SamplingMessage( + role="user", + content=TextContent(type="text", text=prompt), + ) + ], + max_tokens=100, + ) + + if result.content.type == "text": + return result.content.text + return str(result.content) diff --git a/examples/snippets/servers/tool_progress.py b/examples/snippets/servers/tool_progress.py new file mode 100644 index 000000000..d62e62dd1 --- /dev/null +++ b/examples/snippets/servers/tool_progress.py @@ -0,0 +1,20 @@ +from mcp.server.fastmcp import Context, FastMCP + +mcp = FastMCP(name="Progress Example") + + +@mcp.tool() +async def long_running_task(task_name: str, ctx: Context, steps: int = 5) -> str: + """Execute a task with progress updates.""" + await ctx.info(f"Starting: {task_name}") + + for i in range(steps): + progress = (i + 1) / steps + await ctx.report_progress( + progress=progress, + total=1.0, + message=f"Step {i + 1}/{steps}", + ) + await ctx.debug(f"Completed step {i + 1}") + + return f"Task '{task_name}' completed" diff --git a/pyproject.toml b/pyproject.toml index 9b617f667..a4de04266 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,13 +99,14 @@ ignore = ["PERF203"] [tool.ruff] line-length = 120 target-version = "py310" +extend-exclude = ["README.md"] [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401"] "tests/server/fastmcp/test_func_metadata.py" = ["E501"] [tool.uv.workspace] -members = ["examples/servers/*"] +members = ["examples/servers/*", "examples/snippets"] [tool.uv.sources] mcp = { workspace = true } diff --git a/scripts/update_readme_snippets.py b/scripts/update_readme_snippets.py new file mode 100755 index 000000000..601d176e8 --- /dev/null +++ b/scripts/update_readme_snippets.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 +""" +Update README.md with live code snippets from example files. + +This script finds specially marked code blocks in README.md and updates them +with the actual code from the referenced files. + +Usage: + python scripts/update_readme_snippets.py + python scripts/update_readme_snippets.py --check # Check mode for CI +""" + +import argparse +import re +import sys +from pathlib import Path + + +def get_github_url(file_path: str) -> str: + """Generate a GitHub URL for the file. + + Args: + file_path: Path to the file relative to repo root + + Returns: + GitHub URL + """ + base_url = "https://github.com/modelcontextprotocol/python-sdk/blob/main" + return f"{base_url}/{file_path}" + + +def process_snippet_block(match: re.Match, check_mode: bool = False) -> str: + """Process a single snippet-source block. + + Args: + match: The regex match object + check_mode: If True, return original if no changes needed + + Returns: + The updated block content + """ + full_match = match.group(0) + indent = match.group(1) + file_path = match.group(2) + + try: + # Read the entire file + file = Path(file_path) + if not file.exists(): + print(f"Warning: File not found: {file_path}") + return full_match + + code = file.read_text().rstrip() + github_url = get_github_url(file_path) + + # Build the replacement block + indented_code = code.replace("\n", f"\n{indent}") + replacement = f"""{indent} +{indent}```python +{indent}{indented_code} +{indent}``` +{indent}_Full example: [{file_path}]({github_url})_ +{indent}""" + + # In check mode, only check if code has changed + if check_mode: + # Extract existing code from the match + existing_content = match.group(3) + if existing_content is not None: + existing_lines = existing_content.strip().split("\n") + # Find code between ```python and ``` + code_lines = [] + in_code = False + for line in existing_lines: + if line.strip() == "```python": + in_code = True + elif line.strip() == "```": + break + elif in_code: + code_lines.append(line) + existing_code = "\n".join(code_lines).strip() + # Compare with the indented version we would generate + expected_code = code.replace("\n", f"\n{indent}").strip() + if existing_code == expected_code: + return full_match + + return replacement + + except Exception as e: + print(f"Error processing {file_path}: {e}") + return full_match + + +def update_readme_snippets(readme_path: Path = Path("README.md"), check_mode: bool = False) -> bool: + """Update code snippets in README.md with live code from source files. + + Args: + readme_path: Path to the README file + check_mode: If True, only check if updates are needed without modifying + + Returns: + True if file is up to date or was updated, False if check failed + """ + if not readme_path.exists(): + print(f"Error: README file not found: {readme_path}") + return False + + content = readme_path.read_text() + original_content = content + + # Pattern to match snippet-source blocks + # Matches: + # ... any content ... + # + pattern = r"^(\s*)\n" r"(.*?)" r"^\1" + + # Process all snippet-source blocks + updated_content = re.sub( + pattern, lambda m: process_snippet_block(m, check_mode), content, flags=re.MULTILINE | re.DOTALL + ) + + if check_mode: + if updated_content != original_content: + print( + f"Error: {readme_path} has outdated code snippets. " + "Run 'python scripts/update_readme_snippets.py' to update." + ) + return False + else: + print(f"✓ {readme_path} code snippets are up to date") + return True + else: + if updated_content != original_content: + readme_path.write_text(updated_content) + print(f"✓ Updated {readme_path}") + else: + print(f"✓ {readme_path} already up to date") + return True + + +def main(): + """Main entry point.""" + parser = argparse.ArgumentParser(description="Update README code snippets from source files") + parser.add_argument( + "--check", action="store_true", help="Check mode - verify snippets are up to date without modifying" + ) + parser.add_argument("--readme", default="README.md", help="Path to README file (default: README.md)") + + args = parser.parse_args() + + success = update_readme_snippets(Path(args.readme), check_mode=args.check) + + if not success: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index f13c76fa9..a1620ca17 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -1,8 +1,8 @@ """ Integration tests for FastMCP server functionality. -These tests validate the proper functioning of FastMCP in various configurations, -including with and without authentication. +These tests validate the proper functioning of FastMCP features using focused, +single-feature servers across different transports (SSE and StreamableHTTP). """ import json @@ -10,38 +10,33 @@ import socket import time from collections.abc import Generator -from typing import Any import pytest import uvicorn -from pydantic import AnyUrl, BaseModel, Field -from starlette.applications import Starlette -from starlette.requests import Request - +from pydantic import AnyUrl + +from examples.snippets.servers import ( + basic_prompt, + basic_resource, + basic_tool, + completion, + elicitation, + notifications, + sampling, + tool_progress, +) from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.streamable_http import streamablehttp_client -from mcp.server.fastmcp import Context, FastMCP -from mcp.server.fastmcp.resources import FunctionResource -from mcp.server.transport_security import TransportSecuritySettings -from mcp.shared.context import RequestContext from mcp.types import ( - Completion, - CompletionArgument, - CompletionContext, - CreateMessageRequestParams, CreateMessageResult, ElicitResult, GetPromptResult, InitializeResult, LoggingMessageNotification, ProgressNotification, - PromptReference, ReadResourceResult, - ResourceLink, ResourceListChangedNotification, - ResourceTemplateReference, - SamplingMessage, ServerNotification, TextContent, TextResourceContents, @@ -49,6 +44,29 @@ ) +class NotificationCollector: + """Collects notifications from the server for testing.""" + + def __init__(self): + self.progress_notifications: list = [] + self.log_messages: list = [] + self.resource_notifications: list = [] + self.tool_notifications: list = [] + + async def handle_generic_notification(self, message) -> None: + """Handle any server notification and route to appropriate handler.""" + if isinstance(message, ServerNotification): + if isinstance(message.root, ProgressNotification): + self.progress_notifications.append(message.root.params) + elif isinstance(message.root, LoggingMessageNotification): + self.log_messages.append(message.root.params) + elif isinstance(message.root, ResourceListChangedNotification): + self.resource_notifications.append(message.root.params) + elif isinstance(message.root, ToolListChangedNotification): + self.tool_notifications.append(message.root.params) + + +# Common fixtures @pytest.fixture def server_port() -> int: """Get a free port for testing.""" @@ -63,396 +81,64 @@ def server_url(server_port: int) -> str: return f"http://127.0.0.1:{server_port}" -@pytest.fixture -def http_server_port() -> int: - """Get a free port for testing the StreamableHTTP server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def http_server_url(http_server_port: int) -> str: - """Get the StreamableHTTP server URL for testing.""" - return f"http://127.0.0.1:{http_server_port}" +def run_server_with_transport(module_name: str, port: int, transport: str) -> None: + """Run server with specified transport.""" + # Get the MCP instance based on module name + if module_name == "basic_tool": + mcp = basic_tool.mcp + elif module_name == "basic_resource": + mcp = basic_resource.mcp + elif module_name == "basic_prompt": + mcp = basic_prompt.mcp + elif module_name == "tool_progress": + mcp = tool_progress.mcp + elif module_name == "sampling": + mcp = sampling.mcp + elif module_name == "elicitation": + mcp = elicitation.mcp + elif module_name == "completion": + mcp = completion.mcp + elif module_name == "notifications": + mcp = notifications.mcp + else: + raise ImportError(f"Unknown module: {module_name}") + # Create app based on transport type + if transport == "sse": + app = mcp.sse_app() + elif transport == "streamable-http": + app = mcp.streamable_http_app() + else: + raise ValueError(f"Invalid transport for test server: {transport}") -@pytest.fixture -def stateless_http_server_port() -> int: - """Get a free port for testing the stateless StreamableHTTP server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] + server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=port, log_level="error")) + print(f"Starting {transport} server on port {port}") + server.run() @pytest.fixture -def stateless_http_server_url(stateless_http_server_port: int) -> str: - """Get the stateless StreamableHTTP server URL for testing.""" - return f"http://127.0.0.1:{stateless_http_server_port}" - - -# Create a function to make the FastMCP server app -def make_fastmcp_app(): - """Create a FastMCP server without auth settings.""" - transport_security = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] - ) - mcp = FastMCP(name="NoAuthServer", transport_security=transport_security) - - # Add a simple tool - @mcp.tool(description="A simple echo tool") - def echo(message: str) -> str: - return f"Echo: {message}" - - # Add a tool that uses elicitation - @mcp.tool(description="A tool that uses elicitation") - async def ask_user(prompt: str, ctx: Context) -> str: - class AnswerSchema(BaseModel): - answer: str = Field(description="The user's answer to the question") +def server_transport(request, server_port: int) -> Generator[str, None, None]: + """Start server in a separate process with specified MCP instance and transport. - result = await ctx.elicit(message=f"Tool wants to ask: {prompt}", schema=AnswerSchema) - - if result.action == "accept" and result.data: - return f"User answered: {result.data.answer}" - else: - # Handle cancellation or decline - return f"User cancelled or declined: {result.action}" - - # Create the SSE app - app = mcp.sse_app() - - return mcp, app - - -def make_everything_fastmcp() -> FastMCP: - """Create a FastMCP server with all features enabled for testing.""" - transport_security = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] - ) - mcp = FastMCP(name="EverythingServer", transport_security=transport_security) - - # Tool with context for logging and progress - @mcp.tool(description="A tool that demonstrates logging and progress", title="Progress Tool") - async def tool_with_progress(message: str, ctx: Context, steps: int = 3) -> str: - await ctx.info(f"Starting processing of '{message}' with {steps} steps") - - # Send progress notifications - for i in range(steps): - progress_value = (i + 1) / steps - await ctx.report_progress( - progress=progress_value, - total=1.0, - message=f"Processing step {i + 1} of {steps}", - ) - await ctx.debug(f"Completed step {i + 1}") - - return f"Processed '{message}' in {steps} steps" - - # Simple tool for basic functionality - @mcp.tool(description="A simple echo tool", title="Echo Tool") - def echo(message: str) -> str: - return f"Echo: {message}" - - # Tool that returns ResourceLinks - @mcp.tool(description="Lists files and returns resource links", title="List Files Tool") - def list_files() -> list[ResourceLink]: - """Returns a list of resource links for files matching the pattern.""" - - # Mock some file resources for testing - file_resources = [ - { - "type": "resource_link", - "uri": "file:///project/README.md", - "name": "README.md", - "mimeType": "text/markdown", - } - ] - - result: list[ResourceLink] = [ResourceLink.model_validate(file_json) for file_json in file_resources] - - return result - - # Tool with sampling capability - @mcp.tool(description="A tool that uses sampling to generate content", title="Sampling Tool") - async def sampling_tool(prompt: str, ctx: Context) -> str: - await ctx.info(f"Requesting sampling for prompt: {prompt}") - - # Request sampling from the client - result = await ctx.session.create_message( - messages=[SamplingMessage(role="user", content=TextContent(type="text", text=prompt))], - max_tokens=100, - temperature=0.7, - ) - - await ctx.info(f"Received sampling result from model: {result.model}") - # Handle different content types - if result.content.type == "text": - return f"Sampling result: {result.content.text[:100]}..." - else: - return f"Sampling result: {str(result.content)[:100]}..." - - # Tool that sends notifications and logging - @mcp.tool(description="A tool that demonstrates notifications and logging", title="Notification Tool") - async def notification_tool(message: str, ctx: Context) -> str: - # Send different log levels - await ctx.debug("Debug: Starting notification tool") - await ctx.info(f"Info: Processing message '{message}'") - await ctx.warning("Warning: This is a test warning") - - # Send resource change notifications - await ctx.session.send_resource_list_changed() - await ctx.session.send_tool_list_changed() - - await ctx.info("Completed notification tool successfully") - return f"Sent notifications and logs for: {message}" - - # Resource - static - def get_static_info() -> str: - return "This is static resource content" - - static_resource = FunctionResource( - uri=AnyUrl("resource://static/info"), - name="Static Info", - title="Static Information", - description="Static information resource", - fn=get_static_info, - ) - mcp.add_resource(static_resource) - - # Resource - dynamic function - @mcp.resource("resource://dynamic/{category}", title="Dynamic Resource") - def dynamic_resource(category: str) -> str: - return f"Dynamic resource content for category: {category}" - - # Resource template - @mcp.resource("resource://template/{id}/data", title="Template Resource") - def template_resource(id: str) -> str: - return f"Template resource data for ID: {id}" - - # Prompt - simple - @mcp.prompt(description="A simple prompt", title="Simple Prompt") - def simple_prompt(topic: str) -> str: - return f"Tell me about {topic}" - - # Prompt - complex with multiple messages - @mcp.prompt(description="Complex prompt with context", title="Complex Prompt") - def complex_prompt(user_query: str, context: str = "general") -> str: - # For simplicity, return a single string that incorporates the context - # Since FastMCP doesn't support system messages in the same way - return f"Context: {context}. Query: {user_query}" - - # Resource template with completion support - @mcp.resource("github://repos/{owner}/{repo}", title="GitHub Repository") - def github_repo_resource(owner: str, repo: str) -> str: - return f"Repository: {owner}/{repo}" - - # Add completion handler for the server - @mcp.completion() - async def handle_completion( - ref: PromptReference | ResourceTemplateReference, - argument: CompletionArgument, - context: CompletionContext | None, - ) -> Completion | None: - # Handle GitHub repository completion - if isinstance(ref, ResourceTemplateReference): - if ref.uri == "github://repos/{owner}/{repo}" and argument.name == "repo": - if context and context.arguments and context.arguments.get("owner") == "modelcontextprotocol": - # Return repos for modelcontextprotocol org - return Completion(values=["python-sdk", "typescript-sdk", "specification"], total=3, hasMore=False) - elif context and context.arguments and context.arguments.get("owner") == "test-org": - # Return repos for test-org - return Completion(values=["test-repo1", "test-repo2"], total=2, hasMore=False) - - # Handle prompt completions - if isinstance(ref, PromptReference): - if ref.name == "complex_prompt" and argument.name == "context": - # Complete context values - contexts = ["general", "technical", "business", "academic"] - return Completion( - values=[c for c in contexts if c.startswith(argument.value)], total=None, hasMore=False - ) - - # Default: no completion available - return Completion(values=[], total=0, hasMore=False) - - # Tool that echoes request headers from context - @mcp.tool(description="Echo request headers from context", title="Echo Headers") - def echo_headers(ctx: Context[Any, Any, Request]) -> str: - """Returns the request headers as JSON.""" - headers_info = {} - if ctx.request_context.request: - # Now the type system knows request is a Starlette Request object - headers_info = dict(ctx.request_context.request.headers) - return json.dumps(headers_info) - - # Tool that returns full request context - @mcp.tool(description="Echo request context with custom data", title="Echo Context") - def echo_context(custom_request_id: str, ctx: Context[Any, Any, Request]) -> str: - """Returns request context including headers and custom data.""" - context_data = { - "custom_request_id": custom_request_id, - "headers": {}, - "method": None, - "path": None, - } - if ctx.request_context.request: - request = ctx.request_context.request - context_data["headers"] = dict(request.headers) - context_data["method"] = request.method - context_data["path"] = request.url.path - return json.dumps(context_data) - - # Restaurant booking tool with elicitation - @mcp.tool(description="Book a table at a restaurant with elicitation", title="Restaurant Booking") - async def book_restaurant( - date: str, - time: str, - party_size: int, - ctx: Context, - ) -> str: - """Book a table - uses elicitation if requested date is unavailable.""" - - class AlternativeDateSchema(BaseModel): - checkAlternative: bool = Field(description="Would you like to try another date?") - alternativeDate: str = Field( - default="2024-12-26", - description="What date would you prefer? (YYYY-MM-DD)", - ) - - # For testing: assume dates starting with "2024-12-25" are unavailable - if date.startswith("2024-12-25"): - # Use elicitation to ask about alternatives - result = await ctx.elicit( - message=( - f"No tables available for {party_size} people on {date} " - f"at {time}. Would you like to check another date?" - ), - schema=AlternativeDateSchema, - ) - - if result.action == "accept" and result.data: - if result.data.checkAlternative: - alt_date = result.data.alternativeDate - return f"✅ Booked table for {party_size} on {alt_date} at {time}" - else: - return "❌ No booking made" - elif result.action in ("decline", "cancel"): - return "❌ Booking cancelled" - else: - # Handle case where action is "accept" but data is None - return "❌ No booking data received" - else: - # Available - book directly - return f"✅ Booked table for {party_size} on {date} at {time}" - - return mcp - - -def make_everything_fastmcp_app(): - """Create a comprehensive FastMCP server with SSE transport.""" - mcp = make_everything_fastmcp() - # Create the SSE app - app = mcp.sse_app() - return mcp, app - - -def make_fastmcp_streamable_http_app(): - """Create a FastMCP server with StreamableHTTP transport.""" - transport_security = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] - ) - mcp = FastMCP(name="NoAuthServer", transport_security=transport_security) - - # Add a simple tool - @mcp.tool(description="A simple echo tool") - def echo(message: str) -> str: - return f"Echo: {message}" - - # Create the StreamableHTTP app - app: Starlette = mcp.streamable_http_app() - - return mcp, app - - -def make_everything_fastmcp_streamable_http_app(): - """Create a comprehensive FastMCP server with StreamableHTTP transport.""" - # Create a new instance with different name for HTTP transport - mcp = make_everything_fastmcp() - # We can't change the name after creation, so we'll use the same name - # Create the StreamableHTTP app - app: Starlette = mcp.streamable_http_app() - return mcp, app + Args: + request: pytest request with param tuple of (module_name, transport) + server_port: Port to run the server on + Yields: + str: The transport type ('sse' or 'streamable_http') + """ + module_name, transport = request.param -def make_fastmcp_stateless_http_app(): - """Create a FastMCP server with stateless StreamableHTTP transport.""" - transport_security = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] + proc = multiprocessing.Process( + target=run_server_with_transport, + args=(module_name, server_port, transport), + daemon=True, ) - mcp = FastMCP(name="StatelessServer", stateless_http=True, transport_security=transport_security) - - # Add a simple tool - @mcp.tool(description="A simple echo tool") - def echo(message: str) -> str: - return f"Echo: {message}" - - # Create the StreamableHTTP app - app: Starlette = mcp.streamable_http_app() - - return mcp, app - - -def run_server(server_port: int) -> None: - """Run the server.""" - _, app = make_fastmcp_app() - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"Starting server on port {server_port}") - server.run() - - -def run_everything_legacy_sse_http_server(server_port: int) -> None: - """Run the comprehensive server with all features.""" - _, app = make_everything_fastmcp_app() - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"Starting comprehensive server on port {server_port}") - server.run() - - -def run_streamable_http_server(server_port: int) -> None: - """Run the StreamableHTTP server.""" - _, app = make_fastmcp_streamable_http_app() - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"Starting StreamableHTTP server on port {server_port}") - server.run() - - -def run_everything_server(server_port: int) -> None: - """Run the comprehensive StreamableHTTP server with all features.""" - _, app = make_everything_fastmcp_streamable_http_app() - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"Starting comprehensive StreamableHTTP server on port {server_port}") - server.run() - - -def run_stateless_http_server(server_port: int) -> None: - """Run the stateless StreamableHTTP server.""" - _, app = make_fastmcp_stateless_http_app() - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"Starting stateless StreamableHTTP server on port {server_port}") - server.run() - - -@pytest.fixture() -def server(server_port: int) -> Generator[None, None, None]: - """Start the server in a separate process and clean up after the test.""" - proc = multiprocessing.Process(target=run_server, args=(server_port,), daemon=True) - print("Starting server process") proc.start() # Wait for server to be running max_attempts = 20 attempt = 0 - print("Waiting for server to start") while attempt < max_attempts: try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -464,708 +150,443 @@ def server(server_port: int) -> Generator[None, None, None]: else: raise RuntimeError(f"Server failed to start after {max_attempts} attempts") - yield + yield transport - print("Killing server") proc.kill() proc.join(timeout=2) if proc.is_alive(): print("Server process failed to terminate") -@pytest.fixture() -def streamable_http_server(http_server_port: int) -> Generator[None, None, None]: - """Start the StreamableHTTP server in a separate process.""" - proc = multiprocessing.Process(target=run_streamable_http_server, args=(http_server_port,), daemon=True) - print("Starting StreamableHTTP server process") - proc.start() - - # Wait for server to be running - max_attempts = 20 - attempt = 0 - print("Waiting for StreamableHTTP server to start") - while attempt < max_attempts: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect(("127.0.0.1", http_server_port)) - break - except ConnectionRefusedError: - time.sleep(0.1) - attempt += 1 +# Helper function to create client based on transport +def create_client_for_transport(transport: str, server_url: str): + """Create the appropriate client context manager based on transport type.""" + if transport == "sse": + endpoint = f"{server_url}/sse" + return sse_client(endpoint) + elif transport == "streamable-http": + endpoint = f"{server_url}/mcp" + return streamablehttp_client(endpoint) else: - raise RuntimeError(f"StreamableHTTP server failed to start after {max_attempts} attempts") + raise ValueError(f"Invalid transport: {transport}") - yield - print("Killing StreamableHTTP server") - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): - print("StreamableHTTP server process failed to terminate") +def unpack_streams(client_streams): + """Unpack client streams handling different return values from SSE vs StreamableHTTP. + SSE client returns (read_stream, write_stream) + StreamableHTTP client returns (read_stream, write_stream, session_id_callback) -@pytest.fixture() -def stateless_http_server( - stateless_http_server_port: int, -) -> Generator[None, None, None]: - """Start the stateless StreamableHTTP server in a separate process.""" - proc = multiprocessing.Process( - target=run_stateless_http_server, - args=(stateless_http_server_port,), - daemon=True, - ) - print("Starting stateless StreamableHTTP server process") - proc.start() + Args: + client_streams: Tuple from client context manager - # Wait for server to be running - max_attempts = 20 - attempt = 0 - print("Waiting for stateless StreamableHTTP server to start") - while attempt < max_attempts: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect(("127.0.0.1", stateless_http_server_port)) - break - except ConnectionRefusedError: - time.sleep(0.1) - attempt += 1 + Returns: + Tuple of (read_stream, write_stream) + """ + if len(client_streams) == 2: + return client_streams else: - raise RuntimeError(f"Stateless server failed to start after {max_attempts} attempts") + read_stream, write_stream, _ = client_streams + return read_stream, write_stream - yield - print("Killing stateless StreamableHTTP server") - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): - print("Stateless StreamableHTTP server process failed to terminate") +# Callback functions for testing +async def sampling_callback(context, params) -> CreateMessageResult: + """Sampling callback for tests.""" + return CreateMessageResult( + role="assistant", + content=TextContent( + type="text", + text="This is a simulated LLM response for testing", + ), + model="test-model", + ) + + +async def elicitation_callback(context, params): + """Elicitation callback for tests.""" + # For restaurant booking test + if "No tables available" in params.message: + return ElicitResult( + action="accept", + content={"checkAlternative": True, "alternativeDate": "2024-12-26"}, + ) + else: + return ElicitResult(action="decline") +# Test basic tools @pytest.mark.anyio -async def test_fastmcp_without_auth(server: None, server_url: str) -> None: - """Test that FastMCP works when auth settings are not provided.""" - # Connect to the server - async with sse_client(server_url + "/sse") as streams: - async with ClientSession(*streams) as session: +@pytest.mark.parametrize( + "server_transport", + [ + ("basic_tool", "sse"), + ("basic_tool", "streamable-http"), + ], + indirect=True, +) +async def test_basic_tools(server_transport: str, server_url: str) -> None: + """Test basic tool functionality.""" + transport = server_transport + client_cm = create_client_for_transport(transport, server_url) + + async with client_cm as client_streams: + read_stream, write_stream = unpack_streams(client_streams) + async with ClientSession(read_stream, write_stream) as session: # Test initialization result = await session.initialize() assert isinstance(result, InitializeResult) - assert result.serverInfo.name == "NoAuthServer" + assert result.serverInfo.name == "Tool Example" + assert result.capabilities.tools is not None - # Test that we can call tools without authentication - tool_result = await session.call_tool("echo", {"message": "hello"}) + # Test sum tool + tool_result = await session.call_tool("sum", {"a": 5, "b": 3}) assert len(tool_result.content) == 1 assert isinstance(tool_result.content[0], TextContent) - assert tool_result.content[0].text == "Echo: hello" + assert tool_result.content[0].text == "8" + + # Test weather tool + weather_result = await session.call_tool("get_weather", {"city": "London"}) + assert len(weather_result.content) == 1 + assert isinstance(weather_result.content[0], TextContent) + assert "Weather in London: 22degreesC" in weather_result.content[0].text +# Test resources @pytest.mark.anyio -async def test_fastmcp_streamable_http(streamable_http_server: None, http_server_url: str) -> None: - """Test that FastMCP works with StreamableHTTP transport.""" - # Connect to the server using StreamableHTTP - async with streamablehttp_client(http_server_url + "/mcp") as ( - read_stream, - write_stream, - _, - ): - # Create a session using the client streams +@pytest.mark.parametrize( + "server_transport", + [ + ("basic_resource", "sse"), + ("basic_resource", "streamable-http"), + ], + indirect=True, +) +async def test_basic_resources(server_transport: str, server_url: str) -> None: + """Test basic resource functionality.""" + transport = server_transport + client_cm = create_client_for_transport(transport, server_url) + + async with client_cm as client_streams: + read_stream, write_stream = unpack_streams(client_streams) async with ClientSession(read_stream, write_stream) as session: # Test initialization result = await session.initialize() assert isinstance(result, InitializeResult) - assert result.serverInfo.name == "NoAuthServer" - - # Test that we can call tools without authentication - tool_result = await session.call_tool("echo", {"message": "hello"}) - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - assert tool_result.content[0].text == "Echo: hello" - - + assert result.serverInfo.name == "Resource Example" + assert result.capabilities.resources is not None + + # Test document resource + doc_content = await session.read_resource(AnyUrl("file://documents/readme")) + assert isinstance(doc_content, ReadResourceResult) + assert len(doc_content.contents) == 1 + assert isinstance(doc_content.contents[0], TextResourceContents) + assert "Content of readme" in doc_content.contents[0].text + + # Test settings resource + settings_content = await session.read_resource(AnyUrl("config://settings")) + assert isinstance(settings_content, ReadResourceResult) + assert len(settings_content.contents) == 1 + assert isinstance(settings_content.contents[0], TextResourceContents) + settings_json = json.loads(settings_content.contents[0].text) + assert settings_json["theme"] == "dark" + assert settings_json["language"] == "en" + + +# Test prompts @pytest.mark.anyio -async def test_fastmcp_stateless_streamable_http(stateless_http_server: None, stateless_http_server_url: str) -> None: - """Test that FastMCP works with stateless StreamableHTTP transport.""" - # Connect to the server using StreamableHTTP - async with streamablehttp_client(stateless_http_server_url + "/mcp") as ( - read_stream, - write_stream, - _, - ): +@pytest.mark.parametrize( + "server_transport", + [ + ("basic_prompt", "sse"), + ("basic_prompt", "streamable-http"), + ], + indirect=True, +) +async def test_basic_prompts(server_transport: str, server_url: str) -> None: + """Test basic prompt functionality.""" + transport = server_transport + client_cm = create_client_for_transport(transport, server_url) + + async with client_cm as client_streams: + read_stream, write_stream = unpack_streams(client_streams) async with ClientSession(read_stream, write_stream) as session: + # Test initialization result = await session.initialize() assert isinstance(result, InitializeResult) - assert result.serverInfo.name == "StatelessServer" - tool_result = await session.call_tool("echo", {"message": "hello"}) - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - assert tool_result.content[0].text == "Echo: hello" - - for i in range(3): - tool_result = await session.call_tool("echo", {"message": f"test_{i}"}) - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - assert tool_result.content[0].text == f"Echo: test_{i}" - - -@pytest.fixture -def everything_server_port() -> int: - """Get a free port for testing the comprehensive server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def everything_server_url(everything_server_port: int) -> str: - """Get the comprehensive server URL for testing.""" - return f"http://127.0.0.1:{everything_server_port}" - - -@pytest.fixture -def everything_http_server_port() -> int: - """Get a free port for testing the comprehensive StreamableHTTP server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def everything_http_server_url(everything_http_server_port: int) -> str: - """Get the comprehensive StreamableHTTP server URL for testing.""" - return f"http://127.0.0.1:{everything_http_server_port}" - - -@pytest.fixture() -def everything_server(everything_server_port: int) -> Generator[None, None, None]: - """Start the comprehensive server in a separate process and clean up after.""" - proc = multiprocessing.Process( - target=run_everything_legacy_sse_http_server, - args=(everything_server_port,), - daemon=True, - ) - print("Starting comprehensive server process") - proc.start() - - # Wait for server to be running - max_attempts = 20 - attempt = 0 - print("Waiting for comprehensive server to start") - while attempt < max_attempts: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect(("127.0.0.1", everything_server_port)) - break - except ConnectionRefusedError: - time.sleep(0.1) - attempt += 1 - else: - raise RuntimeError(f"Comprehensive server failed to start after {max_attempts} attempts") - - yield - - print("Killing comprehensive server") - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): - print("Comprehensive server process failed to terminate") - - -@pytest.fixture() -def everything_streamable_http_server( - everything_http_server_port: int, -) -> Generator[None, None, None]: - """Start the comprehensive StreamableHTTP server in a separate process.""" - proc = multiprocessing.Process( - target=run_everything_server, - args=(everything_http_server_port,), - daemon=True, - ) - print("Starting comprehensive StreamableHTTP server process") - proc.start() - - # Wait for server to be running - max_attempts = 20 - attempt = 0 - print("Waiting for comprehensive StreamableHTTP server to start") - while attempt < max_attempts: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect(("127.0.0.1", everything_http_server_port)) - break - except ConnectionRefusedError: - time.sleep(0.1) - attempt += 1 - else: - raise RuntimeError(f"Comprehensive StreamableHTTP server failed to start after {max_attempts} attempts") - - yield - - print("Killing comprehensive StreamableHTTP server") - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): - print("Comprehensive StreamableHTTP server process failed to terminate") - - -class NotificationCollector: - def __init__(self): - self.progress_notifications: list = [] - self.log_messages: list = [] - self.resource_notifications: list = [] - self.tool_notifications: list = [] - - async def handle_progress(self, params) -> None: - self.progress_notifications.append(params) - - async def handle_log(self, params) -> None: - self.log_messages.append(params) - - async def handle_resource_list_changed(self, params) -> None: - self.resource_notifications.append(params) - - async def handle_tool_list_changed(self, params) -> None: - self.tool_notifications.append(params) + assert result.serverInfo.name == "Prompt Example" + assert result.capabilities.prompts is not None + + # Test review_code prompt + prompts = await session.list_prompts() + review_prompt = next((p for p in prompts.prompts if p.name == "review_code"), None) + assert review_prompt is not None + + prompt_result = await session.get_prompt("review_code", {"code": "def hello():\n print('Hello')"}) + assert isinstance(prompt_result, GetPromptResult) + assert len(prompt_result.messages) == 1 + assert isinstance(prompt_result.messages[0].content, TextContent) + assert "Please review this code:" in prompt_result.messages[0].content.text + assert "def hello():" in prompt_result.messages[0].content.text + + # Test debug_error prompt + debug_result = await session.get_prompt( + "debug_error", {"error": "TypeError: 'NoneType' object is not subscriptable"} + ) + assert isinstance(debug_result, GetPromptResult) + assert len(debug_result.messages) == 3 + assert debug_result.messages[0].role == "user" + assert isinstance(debug_result.messages[0].content, TextContent) + assert "I'm seeing this error:" in debug_result.messages[0].content.text + assert debug_result.messages[1].role == "user" + assert isinstance(debug_result.messages[1].content, TextContent) + assert "TypeError" in debug_result.messages[1].content.text + assert debug_result.messages[2].role == "assistant" + assert isinstance(debug_result.messages[2].content, TextContent) + assert "I'll help debug that" in debug_result.messages[2].content.text + + +# Test progress reporting +@pytest.mark.anyio +@pytest.mark.parametrize( + "server_transport", + [ + ("tool_progress", "sse"), + ("tool_progress", "streamable-http"), + ], + indirect=True, +) +async def test_tool_progress(server_transport: str, server_url: str) -> None: + """Test tool progress reporting.""" + transport = server_transport + collector = NotificationCollector() - async def handle_generic_notification(self, message) -> None: - # Check if this is a ServerNotification - if isinstance(message, ServerNotification): - # Check the specific notification type - if isinstance(message.root, ProgressNotification): - await self.handle_progress(message.root.params) - elif isinstance(message.root, LoggingMessageNotification): - await self.handle_log(message.root.params) - elif isinstance(message.root, ResourceListChangedNotification): - await self.handle_resource_list_changed(message.root.params) - elif isinstance(message.root, ToolListChangedNotification): - await self.handle_tool_list_changed(message.root.params) + async def message_handler(message): + await collector.handle_generic_notification(message) + if isinstance(message, Exception): + raise message + client_cm = create_client_for_transport(transport, server_url) -async def create_test_elicitation_callback(context, params): - """Shared elicitation callback for tests. + async with client_cm as client_streams: + read_stream, write_stream = unpack_streams(client_streams) + async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: + # Test initialization + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == "Progress Example" - Handles elicitation requests for restaurant booking tests. - """ - # For restaurant booking test - if "No tables available" in params.message: - return ElicitResult( - action="accept", - content={"checkAlternative": True, "alternativeDate": "2024-12-26"}, - ) - else: - # Default response - return ElicitResult(action="decline") + # Test progress callback + progress_updates = [] + async def progress_callback(progress: float, total: float | None, message: str | None) -> None: + progress_updates.append((progress, total, message)) -async def call_all_mcp_features(session: ClientSession, collector: NotificationCollector) -> None: - """ - Test all MCP features using the provided session. + # Call tool with progress + steps = 3 + tool_result = await session.call_tool( + "long_running_task", + {"task_name": "Test Task", "steps": steps}, + progress_callback=progress_callback, + ) - Args: - session: The MCP client session to test with - collector: Notification collector for capturing server notifications - """ - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.serverInfo.name == "EverythingServer" - - # Check server features are reported - assert result.capabilities.prompts is not None - assert result.capabilities.resources is not None - assert result.capabilities.tools is not None - # Note: logging capability may be None if no tools use context logging - - # Test tools - # 1. Simple echo tool - tool_result = await session.call_tool("echo", {"message": "hello"}) - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - assert tool_result.content[0].text == "Echo: hello" - - # 2. Test tool that returns ResourceLinks - list_files_result = await session.call_tool("list_files") - assert len(list_files_result.content) == 1 - - # Rest should be ResourceLinks - content = list_files_result.content[0] - assert isinstance(content, ResourceLink) - assert str(content.uri).startswith("file:///") - assert content.name is not None - assert content.mimeType is not None - - # Test progress callback functionality - progress_updates = [] - - async def progress_callback(progress: float, total: float | None, message: str | None) -> None: - """Collect progress updates for testing (async version).""" - progress_updates.append((progress, total, message)) - print(f"Progress: {progress}/{total} - {message}") - - test_message = "test" - steps = 3 - params = { - "message": test_message, - "steps": steps, - } - tool_result = await session.call_tool( - "tool_with_progress", - params, - progress_callback=progress_callback, - ) - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - assert f"Processed '{test_message}' in {steps} steps" in tool_result.content[0].text - - # Verify progress callback was called - assert len(progress_updates) == steps - for i, (progress, total, message) in enumerate(progress_updates): - expected_progress = (i + 1) / steps - assert abs(progress - expected_progress) < 0.01 - assert total == 1.0 - assert message is not None - assert f"step {i + 1} of {steps}" in message - - # Verify we received log messages from the tool - # Note: Progress notifications require special handling in the MCP client - # that's not implemented by default, so we focus on testing logging - assert len(collector.log_messages) > 0 - - # 3. Test sampling tool - prompt = "What is the meaning of life?" - sampling_result = await session.call_tool("sampling_tool", {"prompt": prompt}) - assert len(sampling_result.content) == 1 - assert isinstance(sampling_result.content[0], TextContent) - assert "Sampling result:" in sampling_result.content[0].text - assert "This is a simulated LLM response" in sampling_result.content[0].text - - # Verify we received log messages from the sampling tool - assert len(collector.log_messages) > 0 - assert any("Requesting sampling for prompt" in msg.data for msg in collector.log_messages) - assert any("Received sampling result from model" in msg.data for msg in collector.log_messages) - - # 4. Test notification tool - notification_message = "test_notifications" - notification_result = await session.call_tool("notification_tool", {"message": notification_message}) - assert len(notification_result.content) == 1 - assert isinstance(notification_result.content[0], TextContent) - assert "Sent notifications and logs" in notification_result.content[0].text - - # Verify we received various notification types - assert len(collector.log_messages) > 3 # Should have logs from both tools - assert len(collector.resource_notifications) > 0 - assert len(collector.tool_notifications) > 0 - - # Check that we got different log levels - log_levels = [msg.level for msg in collector.log_messages] - assert "debug" in log_levels - assert "info" in log_levels - assert "warning" in log_levels - - # 5. Test elicitation tool - # Test restaurant booking with unavailable date (triggers elicitation) - booking_result = await session.call_tool( - "book_restaurant", - { - "date": "2024-12-25", # Unavailable date to trigger elicitation - "time": "19:00", - "party_size": 4, - }, - ) - assert len(booking_result.content) == 1 - assert isinstance(booking_result.content[0], TextContent) - # Should have booked the alternative date from elicitation callback - assert "✅ Booked table for 4 on 2024-12-26" in booking_result.content[0].text - - # Test resources - # 1. Static resource - resources = await session.list_resources() - # Try using string comparison since AnyUrl might not match directly - static_resource = next( - (r for r in resources.resources if str(r.uri) == "resource://static/info"), - None, - ) - assert static_resource is not None - assert static_resource.name == "Static Info" - - static_content = await session.read_resource(AnyUrl("resource://static/info")) - assert isinstance(static_content, ReadResourceResult) - assert len(static_content.contents) == 1 - assert isinstance(static_content.contents[0], TextResourceContents) - assert static_content.contents[0].text == "This is static resource content" - - # 2. Dynamic resource - resource_category = "test" - dynamic_content = await session.read_resource(AnyUrl(f"resource://dynamic/{resource_category}")) - assert isinstance(dynamic_content, ReadResourceResult) - assert len(dynamic_content.contents) == 1 - assert isinstance(dynamic_content.contents[0], TextResourceContents) - assert f"Dynamic resource content for category: {resource_category}" in dynamic_content.contents[0].text - - # 3. Template resource - resource_id = "456" - template_content = await session.read_resource(AnyUrl(f"resource://template/{resource_id}/data")) - assert isinstance(template_content, ReadResourceResult) - assert len(template_content.contents) == 1 - assert isinstance(template_content.contents[0], TextResourceContents) - assert f"Template resource data for ID: {resource_id}" in template_content.contents[0].text - - # Test prompts - # 1. Simple prompt - prompts = await session.list_prompts() - simple_prompt = next((p for p in prompts.prompts if p.name == "simple_prompt"), None) - assert simple_prompt is not None - - prompt_topic = "AI" - prompt_result = await session.get_prompt("simple_prompt", {"topic": prompt_topic}) - assert isinstance(prompt_result, GetPromptResult) - assert len(prompt_result.messages) >= 1 - # The actual message structure depends on the prompt implementation - - # 2. Complex prompt - complex_prompt = next((p for p in prompts.prompts if p.name == "complex_prompt"), None) - assert complex_prompt is not None - - query = "What is AI?" - context = "technical" - complex_result = await session.get_prompt("complex_prompt", {"user_query": query, "context": context}) - assert isinstance(complex_result, GetPromptResult) - assert len(complex_result.messages) >= 1 - - # Test request context propagation (only works when headers are available) - - headers_result = await session.call_tool("echo_headers", {}) - assert len(headers_result.content) == 1 - assert isinstance(headers_result.content[0], TextContent) - - # If we got headers, verify they exist - headers_data = json.loads(headers_result.content[0].text) - # The headers depend on the transport and test setup - print(f"Received headers: {headers_data}") - - # Test 6: Call tool that returns full context - context_result = await session.call_tool("echo_context", {"custom_request_id": "test-123"}) - assert len(context_result.content) == 1 - assert isinstance(context_result.content[0], TextContent) - - context_data = json.loads(context_result.content[0].text) - assert context_data["custom_request_id"] == "test-123" - # The method should be POST for most transports - if context_data["method"]: - assert context_data["method"] == "POST" - - # Test completion functionality - # 1. Test resource template completion with context - repo_result = await session.complete( - ref=ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"), - argument={"name": "repo", "value": ""}, - context_arguments={"owner": "modelcontextprotocol"}, - ) - assert repo_result.completion.values == ["python-sdk", "typescript-sdk", "specification"] - assert repo_result.completion.total == 3 - assert repo_result.completion.hasMore is False - - # 2. Test with different context - repo_result2 = await session.complete( - ref=ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"), - argument={"name": "repo", "value": ""}, - context_arguments={"owner": "test-org"}, - ) - assert repo_result2.completion.values == ["test-repo1", "test-repo2"] - assert repo_result2.completion.total == 2 + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert "Task 'Test Task' completed" in tool_result.content[0].text - # 3. Test prompt argument completion - context_result = await session.complete( - ref=PromptReference(type="ref/prompt", name="complex_prompt"), - argument={"name": "context", "value": "tech"}, - ) - assert "technical" in context_result.completion.values + # Verify progress updates + assert len(progress_updates) == steps + for i, (progress, total, message) in enumerate(progress_updates): + expected_progress = (i + 1) / steps + assert abs(progress - expected_progress) < 0.01 + assert total == 1.0 + assert f"Step {i + 1}/{steps}" in message - # 4. Test completion without context (should return empty) - no_context_result = await session.complete( - ref=ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"), - argument={"name": "repo", "value": "test"}, - ) - assert no_context_result.completion.values == [] - assert no_context_result.completion.total == 0 + # Verify log messages + assert len(collector.log_messages) > 0 -async def sampling_callback( - context: RequestContext[ClientSession, None], - params: CreateMessageRequestParams, -) -> CreateMessageResult: - # Simulate LLM response based on the input - if params.messages and isinstance(params.messages[0].content, TextContent): - input_text = params.messages[0].content.text - else: - input_text = "No input" - response_text = f"This is a simulated LLM response to: {input_text}" +# Test sampling +@pytest.mark.anyio +@pytest.mark.parametrize( + "server_transport", + [ + ("sampling", "sse"), + ("sampling", "streamable-http"), + ], + indirect=True, +) +async def test_sampling(server_transport: str, server_url: str) -> None: + """Test sampling (LLM interaction) functionality.""" + transport = server_transport + client_cm = create_client_for_transport(transport, server_url) + + async with client_cm as client_streams: + read_stream, write_stream = unpack_streams(client_streams) + async with ClientSession(read_stream, write_stream, sampling_callback=sampling_callback) as session: + # Test initialization + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == "Sampling Example" + assert result.capabilities.tools is not None - model_name = "test-llm-model" - return CreateMessageResult( - role="assistant", - content=TextContent(type="text", text=response_text), - model=model_name, - stopReason="endTurn", - ) + # Test sampling tool + sampling_result = await session.call_tool("generate_poem", {"topic": "nature"}) + assert len(sampling_result.content) == 1 + assert isinstance(sampling_result.content[0], TextContent) + assert "This is a simulated LLM response" in sampling_result.content[0].text +# Test elicitation @pytest.mark.anyio -async def test_fastmcp_all_features_sse(everything_server: None, everything_server_url: str) -> None: - """Test all MCP features work correctly with SSE transport.""" - - # Create notification collector - collector = NotificationCollector() - - # Connect to the server with callbacks - async with sse_client(everything_server_url + "/sse") as streams: - # Set up message handler to capture notifications - async def message_handler(message): - print(f"Received message: {message}") - await collector.handle_generic_notification(message) - if isinstance(message, Exception): - raise message - - async with ClientSession( - *streams, - sampling_callback=sampling_callback, - elicitation_callback=create_test_elicitation_callback, - message_handler=message_handler, - ) as session: - # Run the common test suite - await call_all_mcp_features(session, collector) +@pytest.mark.parametrize( + "server_transport", + [ + ("elicitation", "sse"), + ("elicitation", "streamable-http"), + ], + indirect=True, +) +async def test_elicitation(server_transport: str, server_url: str) -> None: + """Test elicitation (user interaction) functionality.""" + transport = server_transport + client_cm = create_client_for_transport(transport, server_url) + + async with client_cm as client_streams: + read_stream, write_stream = unpack_streams(client_streams) + async with ClientSession(read_stream, write_stream, elicitation_callback=elicitation_callback) as session: + # Test initialization + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == "Elicitation Example" + + # Test booking with unavailable date (triggers elicitation) + booking_result = await session.call_tool( + "book_table", + { + "date": "2024-12-25", # Unavailable date + "time": "19:00", + "party_size": 4, + }, + ) + assert len(booking_result.content) == 1 + assert isinstance(booking_result.content[0], TextContent) + assert "[SUCCESS] Booked for 2024-12-26" in booking_result.content[0].text + + # Test booking with available date (no elicitation) + booking_result = await session.call_tool( + "book_table", + { + "date": "2024-12-20", # Available date + "time": "20:00", + "party_size": 2, + }, + ) + assert len(booking_result.content) == 1 + assert isinstance(booking_result.content[0], TextContent) + assert "[SUCCESS] Booked for 2024-12-20 at 20:00" in booking_result.content[0].text +# Test notifications @pytest.mark.anyio -async def test_fastmcp_all_features_streamable_http( - everything_streamable_http_server: None, everything_http_server_url: str -) -> None: - """Test all MCP features work correctly with StreamableHTTP transport.""" - - # Create notification collector +@pytest.mark.parametrize( + "server_transport", + [ + ("notifications", "sse"), + ("notifications", "streamable-http"), + ], + indirect=True, +) +async def test_notifications(server_transport: str, server_url: str) -> None: + """Test notifications and logging functionality.""" + transport = server_transport collector = NotificationCollector() - # Connect to the server using StreamableHTTP - async with streamablehttp_client(everything_http_server_url + "/mcp") as ( - read_stream, - write_stream, - _, - ): - # Set up message handler to capture notifications - async def message_handler(message): - print(f"Received message: {message}") - await collector.handle_generic_notification(message) - if isinstance(message, Exception): - raise message - - async with ClientSession( - read_stream, - write_stream, - sampling_callback=sampling_callback, - elicitation_callback=create_test_elicitation_callback, - message_handler=message_handler, - ) as session: - # Run the common test suite with HTTP-specific test suffix - await call_all_mcp_features(session, collector) + async def message_handler(message): + await collector.handle_generic_notification(message) + if isinstance(message, Exception): + raise message + client_cm = create_client_for_transport(transport, server_url) -@pytest.mark.anyio -async def test_elicitation_feature(server: None, server_url: str) -> None: - """Test the elicitation feature.""" - - # Create a custom handler for elicitation requests - async def elicitation_callback(context, params): - # Verify the elicitation parameters - if params.message == "Tool wants to ask: What is your name?": - return ElicitResult(content={"answer": "Test User"}, action="accept") - else: - raise ValueError("Unexpected elicitation message") - - # Connect to the server with our custom elicitation handler - async with sse_client(server_url + "/sse") as streams: - async with ClientSession(*streams, elicitation_callback=elicitation_callback) as session: - # First initialize the session + async with client_cm as client_streams: + read_stream, write_stream = unpack_streams(client_streams) + async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: + # Test initialization result = await session.initialize() assert isinstance(result, InitializeResult) - assert result.serverInfo.name == "NoAuthServer" + assert result.serverInfo.name == "Notifications Example" - # Call the tool that uses elicitation - tool_result = await session.call_tool("ask_user", {"prompt": "What is your name?"}) - # Verify the result + # Call tool that generates notifications + tool_result = await session.call_tool("process_data", {"data": "test_data"}) assert len(tool_result.content) == 1 assert isinstance(tool_result.content[0], TextContent) - # # The test should only succeed with the successful elicitation response - assert tool_result.content[0].text == "User answered: Test User" + assert "Processed: test_data" in tool_result.content[0].text + + # Verify log messages at different levels + assert len(collector.log_messages) >= 4 + log_levels = {msg.level for msg in collector.log_messages} + assert "debug" in log_levels + assert "info" in log_levels + assert "warning" in log_levels + assert "error" in log_levels + # Verify resource list changed notification + assert len(collector.resource_notifications) > 0 + +# Test completion @pytest.mark.anyio -async def test_title_precedence(everything_server: None, everything_server_url: str) -> None: - """Test that titles are properly returned for tools, resources, and prompts.""" - from mcp.shared.metadata_utils import get_display_name +@pytest.mark.parametrize( + "server_transport", + [ + ("completion", "sse"), + ("completion", "streamable-http"), + ], + indirect=True, +) +async def test_completion(server_transport: str, server_url: str) -> None: + """Test completion (autocomplete) functionality.""" + transport = server_transport + client_cm = create_client_for_transport(transport, server_url) - async with sse_client(everything_server_url + "/sse") as streams: - async with ClientSession(*streams) as session: - # Initialize the session + async with client_cm as client_streams: + read_stream, write_stream = unpack_streams(client_streams) + async with ClientSession(read_stream, write_stream) as session: + # Test initialization result = await session.initialize() assert isinstance(result, InitializeResult) + assert result.serverInfo.name == "Example" + assert result.capabilities.resources is not None + assert result.capabilities.prompts is not None + + # Test resource completion + from mcp.types import ResourceTemplateReference + + completion_result = await session.complete( + ref=ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"), + argument={"name": "repo", "value": ""}, + context_arguments={"owner": "modelcontextprotocol"}, + ) + + assert completion_result is not None + assert hasattr(completion_result, "completion") + assert completion_result.completion is not None + assert len(completion_result.completion.values) == 3 + assert "python-sdk" in completion_result.completion.values + assert "typescript-sdk" in completion_result.completion.values + assert "specification" in completion_result.completion.values + + # Test prompt completion + from mcp.types import PromptReference + + completion_result = await session.complete( + ref=PromptReference(type="ref/prompt", name="review_code"), + argument={"name": "language", "value": "py"}, + ) - # Test tools have titles - tools_result = await session.list_tools() - assert tools_result.tools - - # Check specific tools have titles - tool_names_to_titles = { - "tool_with_progress": "Progress Tool", - "echo": "Echo Tool", - "sampling_tool": "Sampling Tool", - "notification_tool": "Notification Tool", - "echo_headers": "Echo Headers", - "echo_context": "Echo Context", - "book_restaurant": "Restaurant Booking", - } - - for tool in tools_result.tools: - if tool.name in tool_names_to_titles: - assert tool.title == tool_names_to_titles[tool.name] - # Test get_display_name utility - assert get_display_name(tool) == tool_names_to_titles[tool.name] - - # Test resources have titles - resources_result = await session.list_resources() - assert resources_result.resources - - # Check specific resources have titles - static_resource = next((r for r in resources_result.resources if r.name == "Static Info"), None) - assert static_resource is not None - assert static_resource.title == "Static Information" - assert get_display_name(static_resource) == "Static Information" - - # Test resource templates have titles - resource_templates = await session.list_resource_templates() - assert resource_templates.resourceTemplates - - # Check specific resource templates have titles - template_uris_to_titles = { - "resource://dynamic/{category}": "Dynamic Resource", - "resource://template/{id}/data": "Template Resource", - "github://repos/{owner}/{repo}": "GitHub Repository", - } - - for template in resource_templates.resourceTemplates: - if template.uriTemplate in template_uris_to_titles: - assert template.title == template_uris_to_titles[template.uriTemplate] - assert get_display_name(template) == template_uris_to_titles[template.uriTemplate] - - # Test prompts have titles - prompts_result = await session.list_prompts() - assert prompts_result.prompts - - # Check specific prompts have titles - prompt_names_to_titles = { - "simple_prompt": "Simple Prompt", - "complex_prompt": "Complex Prompt", - } - - for prompt in prompts_result.prompts: - if prompt.name in prompt_names_to_titles: - assert prompt.title == prompt_names_to_titles[prompt.name] - assert get_display_name(prompt) == prompt_names_to_titles[prompt.name] + assert completion_result is not None + assert hasattr(completion_result, "completion") + assert completion_result.completion is not None + assert "python" in completion_result.completion.values + assert all(lang.startswith("py") for lang in completion_result.completion.values) diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py index c30930f7b..37351bc6f 100644 --- a/tests/server/fastmcp/test_server.py +++ b/tests/server/fastmcp/test_server.py @@ -146,7 +146,7 @@ async def test_add_tool_decorator(self): mcp = FastMCP() @mcp.tool() - def add(x: int, y: int) -> int: + def sum(x: int, y: int) -> int: return x + y assert len(mcp._tool_manager.list_tools()) == 1 @@ -158,7 +158,7 @@ async def test_add_tool_decorator_incorrect_usage(self): with pytest.raises(TypeError, match="The @tool decorator was used incorrectly"): @mcp.tool # Missing parentheses #type: ignore - def add(x: int, y: int) -> int: + def sum(x: int, y: int) -> int: return x + y @pytest.mark.anyio diff --git a/tests/server/fastmcp/test_tool_manager.py b/tests/server/fastmcp/test_tool_manager.py index 4b2052da5..27e16cc8e 100644 --- a/tests/server/fastmcp/test_tool_manager.py +++ b/tests/server/fastmcp/test_tool_manager.py @@ -19,23 +19,23 @@ class TestAddTools: def test_basic_function(self): """Test registering and running a basic function.""" - def add(a: int, b: int) -> int: + def sum(a: int, b: int) -> int: """Add two numbers.""" return a + b manager = ToolManager() - manager.add_tool(add) + manager.add_tool(sum) - tool = manager.get_tool("add") + tool = manager.get_tool("sum") assert tool is not None - assert tool.name == "add" + assert tool.name == "sum" assert tool.description == "Add two numbers." assert tool.is_async is False assert tool.parameters["properties"]["a"]["type"] == "integer" assert tool.parameters["properties"]["b"]["type"] == "integer" def test_init_with_tools(self, caplog): - def add(a: int, b: int) -> int: + def sum(a: int, b: int) -> int: return a + b class AddArguments(ArgModelBase): @@ -45,10 +45,10 @@ class AddArguments(ArgModelBase): fn_metadata = FuncMetadata(arg_model=AddArguments) original_tool = Tool( - name="add", + name="sum", title="Add Tool", description="Add two numbers.", - fn=add, + fn=sum, fn_metadata=fn_metadata, is_async=False, parameters=AddArguments.model_json_schema(), @@ -56,13 +56,13 @@ class AddArguments(ArgModelBase): annotations=None, ) manager = ToolManager(tools=[original_tool]) - saved_tool = manager.get_tool("add") + saved_tool = manager.get_tool("sum") assert saved_tool == original_tool # warn on duplicate tools with caplog.at_level(logging.WARNING): manager = ToolManager(True, tools=[original_tool, original_tool]) - assert "Tool already exists: add" in caplog.text + assert "Tool already exists: sum" in caplog.text @pytest.mark.anyio async def test_async_function(self): @@ -182,13 +182,13 @@ def f(x: int) -> int: class TestCallTools: @pytest.mark.anyio async def test_call_tool(self): - def add(a: int, b: int) -> int: + def sum(a: int, b: int) -> int: """Add two numbers.""" return a + b manager = ToolManager() - manager.add_tool(add) - result = await manager.call_tool("add", {"a": 1, "b": 2}) + manager.add_tool(sum) + result = await manager.call_tool("sum", {"a": 1, "b": 2}) assert result == 3 @pytest.mark.anyio @@ -232,25 +232,25 @@ async def __call__(self, x: int) -> int: @pytest.mark.anyio async def test_call_tool_with_default_args(self): - def add(a: int, b: int = 1) -> int: + def sum(a: int, b: int = 1) -> int: """Add two numbers.""" return a + b manager = ToolManager() - manager.add_tool(add) - result = await manager.call_tool("add", {"a": 1}) + manager.add_tool(sum) + result = await manager.call_tool("sum", {"a": 1}) assert result == 2 @pytest.mark.anyio async def test_call_tool_with_missing_args(self): - def add(a: int, b: int) -> int: + def sum(a: int, b: int) -> int: """Add two numbers.""" return a + b manager = ToolManager() - manager.add_tool(add) + manager.add_tool(sum) with pytest.raises(ToolError): - await manager.call_tool("add", {"a": 1}) + await manager.call_tool("sum", {"a": 1}) @pytest.mark.anyio async def test_call_unknown_tool(self): diff --git a/tests/test_examples.py b/tests/test_examples.py index 230e7d394..decffd810 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -56,8 +56,8 @@ async def test_desktop(monkeypatch): monkeypatch.setattr(Path, "home", lambda: Path("/fake/home")) async with client_session(mcp._mcp_server) as client: - # Test the add function - result = await client.call_tool("add", {"a": 1, "b": 2}) + # Test the sum function + result = await client.call_tool("sum", {"a": 1, "b": 2}) assert len(result.content) == 1 content = result.content[0] assert isinstance(content, TextContent) @@ -82,11 +82,13 @@ async def test_desktop(monkeypatch): @pytest.mark.parametrize("example", find_examples("README.md"), ids=str) def test_docs_examples(example: CodeExample, eval_example: EvalExample): - ruff_ignore: list[str] = ["F841", "I001"] + ruff_ignore: list[str] = ["F841", "I001", "F821"] # F821: undefined names (snippets lack imports) - eval_example.set_config(ruff_ignore=ruff_ignore, target_version="py310", line_length=88) + # Use project's actual line length of 120 + eval_example.set_config(ruff_ignore=ruff_ignore, target_version="py310", line_length=120) + # Use Ruff for both formatting and linting (skip Black) if eval_example.update_examples: # pragma: no cover - eval_example.format(example) + eval_example.format_ruff(example) else: - eval_example.lint(example) + eval_example.lint_ruff(example) diff --git a/uv.lock b/uv.lock index 82b8b0859..eaef2e6b9 100644 --- a/uv.lock +++ b/uv.lock @@ -14,6 +14,7 @@ members = [ "mcp-simple-streamablehttp", "mcp-simple-streamablehttp-stateless", "mcp-simple-tool", + "mcp-snippets", ] [[package]] @@ -851,6 +852,17 @@ dev = [ { name = "ruff", specifier = ">=0.6.9" }, ] +[[package]] +name = "mcp-snippets" +version = "0.1.0" +source = { editable = "examples/snippets" } +dependencies = [ + { name = "mcp" }, +] + +[package.metadata] +requires-dist = [{ name = "mcp", editable = "." }] + [[package]] name = "mdurl" version = "0.1.2"